from __future__ import annotations

import json
from pathlib import Path
from typing import List

from tokenizers import Tokenizer
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel as ByteLevelPreTok
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import BpeTrainer

_HF_MODEL = "klue/roberta-base"
_SPECIAL = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"]
_BPE_VOCAB_SIZE = 16000


class MicroTokenizer:
    """Uniform ``encode`` / ``batch_encode`` facade over two tokenizer backends.

    * ``mode == "hf"``: ``backend`` is a ``transformers`` tokenizer loaded from
      ``_HF_MODEL``.
    * ``mode == "bpe"``: ``backend`` is a ``tokenizers.Tokenizer`` trained by
      ``_train_bpe`` (or loaded from the ``tokenizer.json`` it saves), whose
      post-processor wraps each sequence as ``[CLS] ... [SEP]``.

    In both modes the returned id lists are at most ``max_len`` long, and
    truncation is done by the backend so the special tokens are kept.
    """

    # Speaker-relative markers for dialogue context concatenation. Adding them
    # grows the vocabulary; callers must resize model token embeddings.
    SPEAKER_TOKENS = ("[SELF]", "[OTHER]")

    def __init__(
        self,
        backend,
        pad_id: int,
        cls_id: int,
        sep_id: int,
        vocab_size: int,
        mode: str,
    ):
        """Store the backend and its resolved special-token ids.

        Args:
            backend: HF tokenizer (``mode="hf"``) or ``tokenizers.Tokenizer``
                (``mode="bpe"``).
            pad_id / cls_id / sep_id: ids of the padding / start / end tokens.
            vocab_size: total vocabulary size, including added tokens.
            mode: ``"hf"`` or ``"bpe"``; selects how ``encode`` calls ``backend``.
        """
        self.backend = backend
        self.pad_id = pad_id
        self.cls_id = cls_id
        self.sep_id = sep_id
        self.vocab_size = vocab_size
        self.mode = mode

    def __len__(self) -> int:
        return self.vocab_size

    def _hf_input_ids(self, texts, max_len: int):
        """Run the HF tokenizer with the shared encode options; return ``input_ids``."""
        out = self.backend(
            texts,
            add_special_tokens=True,
            truncation=True,
            max_length=max_len,
            return_attention_mask=False,
        )
        return out["input_ids"]

    def _sync_bpe_truncation(self, max_len: int) -> None:
        """Point the BPE backend's truncation at ``max_len`` if it isn't already.

        Truncating inside ``tokenizers`` (instead of slicing the result)
        reserves room for the post-processor's ``[CLS]``/``[SEP]``, so the
        trailing ``[SEP]`` survives, matching the HF path. It also lifts the
        limit baked in when the tokenizer was trained/saved with another
        ``max_len``. Note: this mutates the shared backend.
        """
        current = self.backend.truncation
        if current is None or current.get("max_length") != max_len:
            self.backend.enable_truncation(max_length=max_len)

    def encode(self, text: str, max_len: int) -> List[int]:
        """Encode one text to at most ``max_len`` ids, special tokens included."""
        if self.mode == "hf":
            return list(self._hf_input_ids(text, max_len))
        self._sync_bpe_truncation(max_len)
        return list(self.backend.encode(text).ids)

    def batch_encode(self, texts: List[str], max_len: int) -> List[List[int]]:
        """Encode each text to at most ``max_len`` ids, special tokens included."""
        if not texts:
            return []
        if self.mode == "hf":
            return [list(ids) for ids in self._hf_input_ids(texts, max_len)]
        self._sync_bpe_truncation(max_len)
        return [list(e.ids) for e in self.backend.encode_batch(texts)]

    @classmethod
    def build(
        cls,
        save_dir: Path,
        train_jsonl: Path,
        max_len: int,
        prefer_hf: bool = True,
        add_speaker_tokens: bool = False,
    ) -> "MicroTokenizer":
        """Return a ready tokenizer, trying sources in priority order.

        1. Cached ``save_dir/tokenizer.json``. This is skipped when speaker
           tokens are requested, so the HF path gets the first chance to
           provide them.
        2. The pretrained HF tokenizer, when ``prefer_hf``.
        3. With speaker tokens requested: the cached ``tokenizer.json`` with
           the speaker tokens appended. Reusing it keeps existing token ids
           stable instead of retraining over it.
        4. A fresh BPE trained on ``train_jsonl`` and saved to ``save_dir``.

        Load failures in steps 1-3 are printed and fall through to the next step.
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        existing = save_dir / "tokenizer.json"
        if existing.exists() and not add_speaker_tokens:
            try:
                return cls._load_bpe(existing)
            except Exception as e:
                print(f"[tokenizer] local tokenizer.json load failed: {e}")
        if prefer_hf:
            try:
                return cls._load_hf(max_len, add_speaker_tokens=add_speaker_tokens)
            except Exception as e:
                print(f"[tokenizer] HF path failed: {e}; falling back to local BPE")
        if existing.exists() and add_speaker_tokens:
            try:
                return cls._load_bpe(existing, add_speaker_tokens=True)
            except Exception as e:
                print(f"[tokenizer] local tokenizer.json load failed: {e}")
        return cls._train_bpe(
            train_jsonl, save_dir, max_len, add_speaker_tokens=add_speaker_tokens
        )

    @classmethod
    def _load_hf(cls, max_len: int, add_speaker_tokens: bool = False) -> "MicroTokenizer":
        """Load the pretrained ``_HF_MODEL`` tokenizer and smoke-test it.

        Raises:
            RuntimeError: if the probe encodes to fewer than 3 ids, or the
                tokenizer lacks a pad/cls/sep token id.
        """
        from transformers import AutoTokenizer

        tok = AutoTokenizer.from_pretrained(_HF_MODEL)
        if add_speaker_tokens:
            tok.add_special_tokens(
                {"additional_special_tokens": list(cls.SPEAKER_TOKENS)}
            )
        probe = tok(
            "안녕하세요",
            add_special_tokens=True,
            truncation=True,
            max_length=max_len,
            return_attention_mask=False,
        )
        # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
        if len(probe["input_ids"]) < 3:
            raise RuntimeError("HF tokenizer smoke test failed")
        special = {
            "pad": tok.pad_token_id,
            "cls": tok.cls_token_id,
            "sep": tok.sep_token_id,
        }
        missing = [name for name, idx in special.items() if idx is None]
        if missing:
            raise RuntimeError(f"HF tokenizer is missing special token ids: {missing}")
        return cls(
            backend=tok,
            pad_id=int(special["pad"]),
            cls_id=int(special["cls"]),
            sep_id=int(special["sep"]),
            vocab_size=len(tok),
            mode="hf",
        )

    @classmethod
    def _train_bpe(
        cls,
        jsonl_path: Path,
        save_dir: Path,
        max_len: int,
        add_speaker_tokens: bool = False,
    ) -> "MicroTokenizer":
        """Train a byte-level BPE on every non-empty ``turns[*].text`` in ``jsonl_path``.

        The base tokenizer is saved to ``save_dir/tokenizer.json``. Speaker
        tokens, if requested, are appended only afterwards and in memory. The
        saved base vocabulary is therefore identical with or without them, and
        ``_load_bpe`` can re-append them to the same ids later.
        """
        texts: List[str] = []
        with Path(jsonl_path).open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # blank lines are not JSON; skip them
                row = json.loads(line)
                for t in row["turns"]:
                    txt = t["text"].strip()
                    if txt:
                        texts.append(txt)
        tok = Tokenizer(BPE(unk_token="[UNK]"))
        tok.pre_tokenizer = ByteLevelPreTok(add_prefix_space=False)
        tok.decoder = ByteLevelDecoder()
        trainer = BpeTrainer(
            vocab_size=_BPE_VOCAB_SIZE,
            special_tokens=_SPECIAL,
            initial_alphabet=ByteLevelPreTok.alphabet(),
        )
        tok.train_from_iterator(texts, trainer=trainer)
        cls_id = tok.token_to_id("[CLS]")
        sep_id = tok.token_to_id("[SEP]")
        tok.post_processor = TemplateProcessing(
            single="[CLS] $A [SEP]",
            special_tokens=[("[CLS]", cls_id), ("[SEP]", sep_id)],
        )
        tok.enable_truncation(max_length=max_len)
        tok.save(str(Path(save_dir) / "tokenizer.json"))
        if add_speaker_tokens:
            tok.add_special_tokens(list(cls.SPEAKER_TOKENS))
        return cls._from_bpe(tok)

    @classmethod
    def _load_bpe(cls, path: Path, add_speaker_tokens: bool = False) -> "MicroTokenizer":
        """Load a saved ``tokenizers`` JSON, optionally appending the speaker tokens."""
        tok = Tokenizer.from_file(str(path))
        if add_speaker_tokens:
            tok.add_special_tokens(list(cls.SPEAKER_TOKENS))
        return cls._from_bpe(tok)

    @classmethod
    def _from_bpe(cls, tok) -> "MicroTokenizer":
        """Wrap a ``tokenizers.Tokenizer``, resolving and validating its special ids.

        Raises:
            ValueError: if ``[PAD]``, ``[CLS]`` or ``[SEP]`` is not in the vocabulary.
        """
        ids = {name: tok.token_to_id(name) for name in ("[PAD]", "[CLS]", "[SEP]")}
        missing = [name for name, idx in ids.items() if idx is None]
        if missing:
            raise ValueError(f"BPE tokenizer is missing special tokens: {missing}")
        return cls(
            backend=tok,
            pad_id=ids["[PAD]"],
            cls_id=ids["[CLS]"],
            sep_id=ids["[SEP]"],
            vocab_size=tok.get_vocab_size(),
            mode="bpe",
        )
