from __future__ import annotations

import argparse
import json
import math
import time
from dataclasses import asdict, replace
from functools import partial
from pathlib import Path
from typing import Dict, Iterable, Iterator, List

import torch
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset

from .config import MicroAlbertConfig
from .model import MicroAlbertBackbone
from .tokenizer import MicroTokenizer
from .train import lr_lambda_factory, seed_worker, set_seed

# Model id passed to transformers.AutoModel.from_pretrained for the frozen teacher.
KLUE_MODEL = "klue/roberta-base"
# cache_dir handed to datasets.load_dataset for the Korean Wikipedia dump.
WIKI_CACHE = Path("/dataset/AnimaSync-mic-fix/data/wikipedia_ko")
# KOTE training TSV; build_corpus takes the text from the second tab-separated column.
KOTE_TRAIN = Path("/dataset/KOTE/train.tsv")
# Seed JSONL (one object per line with a "turns" list of {"text": ...} items).
# Used as corpus text by build_corpus and as train_jsonl for the tokenizer in main().
SEED_TRAIN = Path("/dataset/AnimaSync-mic-fix/data/emotion/seed_train_final.jsonl")


def _encode_and_append(
    tok: MicroTokenizer, texts: List[str], samples: List[List[int]], max_len: int
) -> int:
    """Tokenize ``texts`` in one batch and append the usable sequences to ``samples``.

    Args:
        tok: Tokenizer exposing ``batch_encode(texts, max_len)``.
        texts: Raw strings to encode; an empty list is a no-op.
        samples: Output list, extended in place with the kept id sequences.
        max_len: Forwarded unchanged to ``tok.batch_encode``.

    Returns:
        The number of sequences appended to ``samples``.
    """
    if not texts:
        return 0
    # Sequences with fewer than 4 ids are dropped.
    kept = [seq for seq in tok.batch_encode(texts, max_len) if len(seq) >= 4]
    samples.extend(kept)
    return len(kept)


def _encode_stream(
    tok: MicroTokenizer,
    texts: Iterable[str],
    samples: List[List[int]],
    max_len: int,
    encode_batch_size: int,
) -> int:
    """Encode a lazy stream of texts in chunks, appending the results to ``samples``.

    Args:
        tok: Tokenizer forwarded to ``_encode_and_append``.
        texts: Iterable of non-empty strings; consumed exactly once.
        samples: Output list, extended in place.
        max_len: Forwarded to the tokenizer.
        encode_batch_size: Number of texts buffered before each encode call.

    Returns:
        The number of sequences appended to ``samples``.
    """
    added = 0
    buf: List[str] = []
    for text in texts:
        buf.append(text)
        if len(buf) >= encode_batch_size:
            added += _encode_and_append(tok, buf, samples, max_len)
            buf = []
    # Flush the final partial chunk (``_encode_and_append`` ignores an empty list).
    added += _encode_and_append(tok, buf, samples, max_len)
    return added


def _iter_wiki_texts(
    ds_wiki,
    total_wiki: int,
    wiki_limit: int,
    wiki_truncate_chars: int,
    samples: List[List[int]],
) -> Iterator[str]:
    """Yield truncated, non-empty article texts and print progress every 50,000 rows.

    ``samples`` is only read: the progress line reports how many sequences the
    consumer has appended since iteration began. Code after ``yield`` runs
    once the consumer has handled that text, so the count is current.
    """
    base = len(samples)
    t0 = time.time()
    for i, row in enumerate(ds_wiki):
        if wiki_limit and i >= wiki_limit:
            break
        text = row["text"][:wiki_truncate_chars].strip()
        if text:
            yield text
        if (i + 1) % 50000 == 0:
            n_wiki = len(samples) - base
            elapsed = time.time() - t0
            rate = (i + 1) / max(elapsed, 1e-6)
            eta = max(0, (total_wiki - (i + 1)) / max(rate, 1e-6))
            print(
                f"[corpus] wiki {i+1:,}/{total_wiki:,} | samples {n_wiki:,} | "
                f"{elapsed:.0f}s ({rate:.0f}/s) eta {eta:.0f}s",
                flush=True,
            )


def _iter_kote_texts(path: Path) -> Iterator[str]:
    """Yield the non-empty second tab-separated column of each line in ``path``."""
    with path.open(encoding="utf-8") as f:
        for line in f:
            parts = line.split("\t")
            if len(parts) >= 2:
                text = parts[1].strip()
                if text:
                    yield text


def _iter_seed_texts(path: Path) -> Iterator[str]:
    """Yield the non-empty ``text`` of every turn in a JSONL file of dialogues."""
    with path.open(encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip
            # them instead of letting json.loads raise.
            if not line.strip():
                continue
            row = json.loads(line)
            for turn in row["turns"]:
                text = turn["text"].strip()
                if text:
                    yield text


def build_corpus(
    tok: MicroTokenizer,
    max_len: int,
    wiki_truncate_chars: int = 1500,
    wiki_limit: int = 0,
    encode_batch_size: int = 5000,
) -> List[List[int]]:
    """Tokenize Korean Wikipedia, KOTE and the seed dialogues into one sample list.

    Sources are processed in that order and their encoded sequences are
    concatenated. Sequences shorter than 4 ids are dropped by
    ``_encode_and_append``.

    Args:
        tok: Tokenizer used to encode every text.
        max_len: Maximum sequence length forwarded to the tokenizer.
        wiki_truncate_chars: Each Wikipedia article is cut to this many
            characters before encoding.
        wiki_limit: Maximum number of Wikipedia articles to read; 0 means all.
        encode_batch_size: Number of texts buffered per tokenizer call.

    Returns:
        A list of token-id sequences, one per kept text.

    Raises:
        json.JSONDecodeError: If a non-blank line of the seed file is not JSON.
        OSError: If the KOTE or seed file cannot be opened.
    """
    samples: List[List[int]] = []

    print("[corpus] loading Korean Wikipedia...", flush=True)
    # Imported lazily so the module can be imported without `datasets`.
    from datasets import load_dataset

    ds_wiki = load_dataset(
        "wikimedia/wikipedia",
        "20231101.ko",
        cache_dir=str(WIKI_CACHE),
        split="train",
    )
    total_wiki = len(ds_wiki) if not wiki_limit else min(len(ds_wiki), wiki_limit)
    print(f"[corpus] wiki articles: {len(ds_wiki):,} (processing {total_wiki:,})", flush=True)
    wiki_texts = _iter_wiki_texts(ds_wiki, total_wiki, wiki_limit, wiki_truncate_chars, samples)
    n_wiki = _encode_stream(tok, wiki_texts, samples, max_len, encode_batch_size)
    print(f"[corpus] wiki samples: {n_wiki:,}", flush=True)

    print("[corpus] loading KOTE...", flush=True)
    n_kote = _encode_stream(
        tok, _iter_kote_texts(KOTE_TRAIN), samples, max_len, encode_batch_size
    )
    print(f"[corpus] kote samples: {n_kote:,}", flush=True)

    print("[corpus] loading seed...", flush=True)
    n_seed = _encode_stream(
        tok, _iter_seed_texts(SEED_TRAIN), samples, max_len, encode_batch_size
    )
    print(f"[corpus] seed samples: {n_seed:,}", flush=True)

    print(f"[corpus] total: {len(samples):,}", flush=True)
    return samples


class CorpusDataset(Dataset):
    """Map-style dataset over a list of pre-tokenized id sequences."""

    def __init__(self, samples: List[List[int]]):
        # Kept by reference; the list is not copied.
        self.samples = samples

    def __len__(self) -> int:
        """Return the number of stored sequences."""
        n_items = len(self.samples)
        return n_items

    def __getitem__(self, idx: int) -> Dict:
        """Return the sequence at ``idx`` wrapped as ``{"input_ids": ids}``."""
        ids = self.samples[idx]
        return {"input_ids": ids}


def collate_corpus(batch: List[Dict], pad_id: int) -> Dict[str, torch.Tensor]:
    """Right-pad a batch of id sequences to a common length.

    Args:
        batch: Items shaped like ``{"input_ids": [int, ...]}``.
        pad_id: Token id used to fill the padded positions.

    Returns:
        A dict with two ``torch.long`` tensors of shape
        ``(len(batch), longest sequence)``: ``input_ids`` and
        ``attention_mask`` (1 for real tokens, 0 for padding).
    """
    seqs = [item["input_ids"] for item in batch]
    width = max(len(seq) for seq in seqs)
    padded = [seq + [pad_id] * (width - len(seq)) for seq in seqs]
    mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in seqs]
    return {
        "input_ids": torch.tensor(padded, dtype=torch.long),
        "attention_mask": torch.tensor(mask, dtype=torch.long),
    }


def relation_loss(
    q_t: torch.Tensor,
    k_t: torch.Tensor,
    v_t: torch.Tensor,
    q_s: torch.Tensor,
    k_s: torch.Tensor,
    v_s: torch.Tensor,
    attention_mask: torch.Tensor,
    num_relation_heads: int,
) -> torch.Tensor:
    """Average KL divergence between teacher and student self-relation maps.

    Each of the Q, K and V projections is split into ``num_relation_heads``
    heads. Within each head a token-to-token relation distribution is formed
    as ``softmax(x @ x^T / sqrt(head_dim))`` over the non-padded positions.
    The loss is KL(teacher || student), averaged over valid query positions
    and heads, then over the three projections.

    Args:
        q_t, k_t, v_t: Teacher projections of shape ``(B, L, H_t)``.
        q_s, k_s, v_s: Student projections of shape ``(B, L, H_s)``.
        attention_mask: ``(B, L)`` mask, non-zero at real tokens.
        num_relation_heads: Number of relation heads; must divide both
            ``H_t`` and ``H_s``.

    Returns:
        A scalar tensor holding the loss.
    """
    B, L, H_t = q_t.shape
    H_s = q_s.shape[-1]
    R = num_relation_heads
    assert H_t % R == 0 and H_s % R == 0, f"H_t={H_t}, H_s={H_s} must be divisible by R={R}"

    # The masks and the normaliser are the same for Q, K and V, so build them once.
    valid = attention_mask.bool()
    pad_keys = ~valid[:, None, None, :]
    query_weight = valid[:, None, :].float()
    denom = (query_weight.sum() * R).clamp(min=1.0)

    def relation_probs(x: torch.Tensor, head_dim: int) -> torch.Tensor:
        # (B, L, H) -> (B, R, L, head_dim), then scaled self dot-products.
        heads = x.reshape(B, L, R, head_dim).transpose(1, 2).contiguous()
        logits = torch.matmul(heads, heads.transpose(-2, -1)) / math.sqrt(head_dim)
        # Padded key positions get a large negative score before the softmax.
        return torch.softmax(logits.masked_fill(pad_keys, -1e9), dim=-1)

    terms = []
    for teacher_x, student_x in ((q_t, q_s), (k_t, k_s), (v_t, v_s)):
        teacher_rel = relation_probs(teacher_x, H_t // R)
        # The epsilon keeps the log finite where the student probability is 0.
        student_log_rel = torch.log(relation_probs(student_x, H_s // R) + 1e-9)
        per_query = F.kl_div(student_log_rel, teacher_rel, reduction="none").sum(dim=-1)
        # Rows belonging to padded query positions are excluded from the mean.
        terms.append((per_query * query_weight).sum() / denom)
    return sum(terms) / len(terms)


class TeacherWithQKV:
    """Frozen teacher model that records its last-layer Q/K/V projections.

    Forward hooks on the ``query``, ``key`` and ``value`` modules of the final
    encoder layer's self-attention store each module's detached output in
    ``self.q``, ``self.k`` and ``self.v`` on every forward pass.
    """

    def __init__(self, model_name: str, device: str):
        from transformers import AutoModel

        loaded = AutoModel.from_pretrained(model_name)
        self.model = loaded.to(device).eval()
        # The teacher is never trained.
        for param in self.model.parameters():
            param.requires_grad = False
        self.q = None
        self.k = None
        self.v = None
        last_attn = self.model.encoder.layer[-1].attention.self
        hooked = (
            ("q", last_attn.query),
            ("k", last_attn.key),
            ("v", last_attn.value),
        )
        for slot, projection in hooked:
            projection.register_forward_hook(self._cap(slot))

    def _cap(self, name: str):
        """Build a forward hook that stores the module output under ``self.<name>``."""

        def hook(module, inputs, output):
            captured = output.detach()
            setattr(self, name, captured)

        return hook

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """Run the teacher and return the captured ``(q, k, v)`` tensors.

        The model's own output is discarded; only the values stored by the
        hooks during this call are returned.
        """
        self.model(input_ids=input_ids, attention_mask=attention_mask)
        captured = (self.q, self.k, self.v)
        return captured


def parse_args() -> argparse.Namespace:
    """Parse the distillation script's command-line options from ``sys.argv``.

    Returns:
        The populated namespace; every option has a default, so none is required.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--output_dir", {"type": Path, "default": Path("checkpoints/microalbert_distill")}),
        ("--max_steps", {"type": int, "default": 30000}),
        ("--batch_size", {"type": int, "default": 128}),
        ("--max_seq_len", {"type": int, "default": 128}),
        ("--lr", {"type": float, "default": 2e-4}),
        ("--weight_decay", {"type": float, "default": 0.01}),
        ("--warmup_ratio", {"type": float, "default": 0.10}),
        ("--num_relation_heads", {"type": int, "default": 48}),
        ("--num_workers", {"type": int, "default": 2}),
        ("--seed", {"type": int, "default": 42}),
        ("--device", {"type": str, "default": default_device}),
        ("--log_every", {"type": int, "default": 50}),
        ("--save_every", {"type": int, "default": 5000}),
        (
            "--wiki_limit",
            {"type": int, "default": 0, "help": "Cap wiki articles for quick runs (0=all)"},
        ),
        (
            "--wandb_project",
            {"type": str, "default": None, "help": "W&B project name (enables logging)"},
        ),
        ("--wandb_run_name", {"type": str, "default": None}),
        ("--wandb_entity", {"type": str, "default": None}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def main() -> None:
    """Distil the teacher's last-layer Q/K/V relations into a MicroAlbert backbone.

    Builds the tokenizer and corpus, creates the student and the frozen
    teacher, then optimises ``relation_loss`` for ``--max_steps`` steps. The
    data loader is re-iterated as often as needed to reach that count.

    Files written under ``--output_dir``:
        tokenizer/    tokenizer files (``save_dir`` given to ``MicroTokenizer.build``).
        metrics.json  logged history, rewritten at every log step.
        latest.pt     resumable checkpoint (model, optimizer, scheduler).
        final.pt      model weights and metadata only.

    ``--log_every 0`` / ``--save_every 0`` disable periodic logging / saving.

    Raises:
        ValueError: If the tokenizer is not in HF mode, its vocabulary size
            differs from the teacher's, a hidden size is not divisible by
            ``--num_relation_heads``, or the corpus is too small to form a
            single batch.
    """
    args = parse_args()
    set_seed(args.seed)
    device = args.device
    output_dir: Path = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    tok = MicroTokenizer.build(
        save_dir=output_dir / "tokenizer",
        train_jsonl=SEED_TRAIN,
        max_len=args.max_seq_len,
        prefer_hf=True,
    )
    print(f"[tokenizer] mode={tok.mode} vocab={len(tok)} pad_id={tok.pad_id}")
    # Checked before the corpus is built so that a wrong tokenizer does not
    # cost a full encoding pass first. Explicit raises (not asserts) so the
    # checks survive `python -O`.
    if tok.mode != "hf":
        raise ValueError("distill requires the HF tokenizer to match the KLUE teacher")

    samples = build_corpus(tok, args.max_seq_len, wiki_limit=args.wiki_limit)
    dataset = CorpusDataset(samples)

    cfg = MicroAlbertConfig()
    cfg = replace(
        cfg,
        vocab_size=len(tok),
        pad_token_id=tok.pad_id,
        max_seq_len=args.max_seq_len,
        seed=args.seed,
    )
    student = MicroAlbertBackbone(cfg).to(device)
    student.train()
    n_student = sum(p.numel() for p in student.parameters())
    print(f"[student] backbone params={n_student:,}")

    print("[teacher] loading KLUE-RoBERTa-base...")
    teacher = TeacherWithQKV(KLUE_MODEL, device=device)
    teacher_hidden = teacher.model.config.hidden_size
    teacher_vocab = teacher.model.config.vocab_size
    print(f"[teacher] hidden={teacher_hidden} vocab={teacher_vocab} student_hidden={cfg.hidden_size}")
    if len(tok) != teacher_vocab:
        raise ValueError(f"tokenizer vocab {len(tok)} != teacher vocab {teacher_vocab}")
    if teacher_hidden % args.num_relation_heads != 0:
        raise ValueError(
            f"teacher hidden {teacher_hidden} not divisible by R={args.num_relation_heads}"
        )
    if cfg.hidden_size % args.num_relation_heads != 0:
        raise ValueError(
            f"student hidden {cfg.hidden_size} not divisible by R={args.num_relation_heads}"
        )

    # Dedicated generator so the shuffle order depends only on --seed.
    g = torch.Generator()
    g.manual_seed(args.seed)
    collate = partial(collate_corpus, pad_id=tok.pad_id)
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate,
        num_workers=args.num_workers,
        worker_init_fn=seed_worker,
        generator=g,
        # Also matches indexed devices such as "cuda:0".
        pin_memory=str(device).startswith("cuda"),
        drop_last=True,
    )
    print(f"[loader] {len(loader)} steps per epoch (batch={args.batch_size})")
    if args.max_steps > 0 and len(loader) == 0:
        # With drop_last=True a corpus smaller than one batch yields nothing,
        # and the `while step < max_steps` loop below would never terminate.
        raise ValueError(
            f"corpus has {len(dataset)} samples, fewer than batch_size={args.batch_size}; "
            "the loader yields no batches"
        )

    # Weight decay is applied only to parameters with more than one dimension.
    no_decay = [p for n, p in student.named_parameters() if p.ndim <= 1]
    decay = [p for n, p in student.named_parameters() if p.ndim > 1]
    param_groups = [
        {"params": decay, "lr": args.lr, "weight_decay": args.weight_decay},
        {"params": no_decay, "lr": args.lr, "weight_decay": 0.0},
    ]
    optimizer = AdamW(param_groups, betas=cfg.adam_betas, eps=cfg.adam_eps)
    num_warmup = max(1, int(args.max_steps * args.warmup_ratio))
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_factory(num_warmup, args.max_steps))
    print(f"[optim] lr={args.lr} warmup={num_warmup} total={args.max_steps}")

    wandb_run = None
    if args.wandb_project:
        # Imported only when W&B logging is requested.
        import wandb
        wandb_run = wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.wandb_run_name,
            config={**asdict(cfg), **vars(args), "stage": "distill", "student_params": n_student},
        )
        print(f"[wandb] logging to {wandb_run.url}")

    history = []
    step = 0
    t0 = time.time()
    running_loss = 0.0
    n_running = 0

    # The outer loop restarts the loader (a new epoch) until max_steps is reached.
    while step < args.max_steps:
        for batch in loader:
            if step >= args.max_steps:
                break
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)

            q_t, k_t, v_t = teacher.forward(input_ids, attn)
            _, q_s, k_s, v_s = student(input_ids, attn, return_last_qkv=True)

            loss = relation_loss(
                q_t, k_t, v_t, q_s, k_s, v_s, attn,
                num_relation_heads=args.num_relation_heads,
            )

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), cfg.grad_clip)
            optimizer.step()
            scheduler.step()

            # Read the scalar once and reuse it for all logging below.
            loss_val = loss.item()
            running_loss += loss_val
            n_running += 1
            step += 1

            if args.log_every and step % args.log_every == 0:
                elapsed = time.time() - t0
                avg = running_loss / max(1, n_running)
                lr0 = scheduler.get_last_lr()[0]
                rate = step / max(elapsed, 1e-6)
                print(
                    f"step {step}/{args.max_steps} | loss {loss_val:.4f} avg {avg:.4f} | "
                    f"lr {lr0:.2e} | {elapsed:.0f}s ({rate:.1f} step/s)"
                )
                history.append({"step": step, "loss": loss_val, "avg_loss": avg, "lr": lr0})
                if wandb_run is not None:
                    wandb_run.log(
                        {"train/loss": loss_val, "train/avg_loss": avg, "train/lr": lr0, "train/step_per_sec": rate},
                        step=step,
                    )
                # The running average covers only the window since the last log.
                running_loss = 0.0
                n_running = 0
                (output_dir / "metrics.json").write_text(
                    json.dumps(history, indent=2, ensure_ascii=False),
                    encoding="utf-8",
                )

            if (args.save_every and step % args.save_every == 0) or step == args.max_steps:
                ckpt = {
                    "model_state_dict": student.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "step": step,
                    "config": asdict(cfg),
                    "tokenizer_mode": tok.mode,
                    "tokenizer_vocab_size": len(tok),
                    "tokenizer_pad_id": tok.pad_id,
                }
                torch.save(ckpt, output_dir / "latest.pt")
                print(f"[save] latest.pt @ step {step}")

    # The final checkpoint omits optimizer and scheduler state.
    final_ckpt = {
        "model_state_dict": student.state_dict(),
        "step": step,
        "config": asdict(cfg),
        "tokenizer_mode": tok.mode,
        "tokenizer_vocab_size": len(tok),
        "tokenizer_pad_id": tok.pad_id,
    }
    torch.save(final_ckpt, output_dir / "final.pt")
    print(f"[done] saved final.pt to {output_dir / 'final.pt'}")

    if wandb_run is not None:
        wandb_run.finish()


# Run the distillation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
