from __future__ import annotations

import argparse
import json
import time
from dataclasses import asdict, replace
from functools import partial
from pathlib import Path
from typing import List

import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

import re

from .config import EMOTION_LABELS, MicroAlbertConfig, NUM_EMOTIONS
from .dataset import SeedEmotionDataset, collate_fn
from .eval import evaluate
from .losses import MultitaskLoss, build_data_centroid_anchors, compute_class_weights
from .teacher import KlueTeacherForEmotionVAD
from .tokenizer import MicroTokenizer
from .train import lr_lambda_factory, seed_worker, set_seed

# Pretrained backbone identifier: passed as ``model_name`` to
# KlueTeacherForEmotionVAD and recorded under "model_name" in every checkpoint.
KLUE_MODEL = "klue/roberta-base"


def build_teacher_param_groups(
    model: KlueTeacherForEmotionVAD, cfg: MicroAlbertConfig
) -> List[dict]:
    """Return four AdamW param groups: (backbone, head) x (decay, no-decay).

    A parameter belongs to the head when its qualified name contains
    ``"emotion_head"``, ``"vad_head"`` or ``"pooler"``; everything else is
    backbone. Tensors with ``ndim <= 1`` (biases, norm scales) get
    ``weight_decay=0.0``. Frozen parameters (``requires_grad=False``) are
    skipped.

    Group order is: backbone+decay, backbone+no-decay, head+decay,
    head+no-decay. Backbone groups use ``cfg.backbone_lr``; head groups use
    ``cfg.head_lr``. Groups may be empty.
    """
    head_markers = ("emotion_head", "vad_head", "pooler")
    # Key: (is_head, skip_weight_decay) -> trainable tensors, in model order.
    buckets = {
        (False, False): [],
        (False, True): [],
        (True, False): [],
        (True, True): [],
    }
    for pname, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        in_head = any(marker in pname for marker in head_markers)
        buckets[(in_head, tensor.ndim <= 1)].append(tensor)

    lr_for = {False: cfg.backbone_lr, True: cfg.head_lr}
    return [
        {
            "params": buckets[(in_head, skip_wd)],
            "lr": lr_for[in_head],
            "weight_decay": 0.0 if skip_wd else cfg.weight_decay,
        }
        for in_head in (False, True)
        for skip_wd in (False, True)
    ]


def build_llrd_param_groups(
    model: KlueTeacherForEmotionVAD, cfg: MicroAlbertConfig, decay: float = 0.9
) -> List[dict]:
    """Build AdamW param groups with Layer-wise LR Decay for KLUE-RoBERTa fine-tuning.

    Lower layers learn slowly (preserve pretrained linguistic knowledge),
    upper layers learn faster (task-specific adaptation). Standard recipe
    for fine-tuning pretrained transformers on small datasets.

    LR assignment:
      * Parameters whose name contains ``"emotion_head"``, ``"vad_head"`` or
        ``"pooler"`` use ``cfg.head_lr``.
      * Other (backbone) parameters use
        ``cfg.backbone_lr * decay ** (n_layers - depth)``. Here ``depth`` is
        the index matched from ``.encoder.layer.<i>.``, ``-1`` for names
        containing ``"embeddings"``, and ``n_layers - 1`` for any other
        backbone parameter.
    Tensors with ``ndim <= 1`` get ``weight_decay=0.0``; others use
    ``cfg.weight_decay``. Frozen parameters are skipped.

    Tensors that share the same ``(lr, weight_decay)`` pair are merged into one
    group. Per-parameter optimizer math is unchanged, and AdamW's multi-tensor
    (foreach) path can batch tensors within each group instead of running once
    per tensor.

    Args:
        model: teacher model; ``model.backbone.config.num_hidden_layers`` sets
            the layer count used for decay.
        cfg: supplies ``backbone_lr``, ``head_lr`` and ``weight_decay``.
        decay: multiplicative LR factor applied once per layer below the top.

    Returns:
        A list of ``{"params", "lr", "weight_decay"}`` dicts for ``AdamW``,
        ordered by the first parameter seen for each setting.
    """
    head_keys = ("emotion_head", "vad_head", "pooler")
    n_layers = model.backbone.config.num_hidden_layers  # 12 for KLUE-RoBERTa-base
    layer_re = re.compile(r"\.encoder\.layer\.(\d+)\.")

    # (lr, weight_decay) -> tensors; dict insertion order keeps grouping stable.
    grouped: dict[tuple[float, float], list] = {}
    layer_lrs: dict[int, float] = {}  # for logging
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        is_head = any(k in name for k in head_keys)
        wd = 0.0 if param.ndim <= 1 else cfg.weight_decay
        if is_head:
            lr = cfg.head_lr
        else:
            m = layer_re.search(name)
            if m is not None:
                depth = int(m.group(1))  # 0..n_layers-1
            elif "embeddings" in name:
                depth = -1  # treat embeddings as layer "below" 0
            else:
                # backbone bits not in encoder.layer or embeddings (rare):
                # treat as topmost backbone layer (highest LR)
                depth = n_layers - 1
            lr = cfg.backbone_lr * (decay ** (n_layers - depth))
            layer_lrs[depth] = lr
        grouped.setdefault((lr, wd), []).append(param)

    groups: List[dict] = [
        {"params": params, "lr": lr, "weight_decay": wd}
        for (lr, wd), params in grouped.items()
    ]

    print(f"[llrd] decay={decay} num_layers={n_layers}")
    for depth in sorted(layer_lrs):
        label = "embeddings" if depth == -1 else f"layer{depth:>2}"
        print(f"[llrd]   {label}: lr={layer_lrs[depth]:.3e}")
    print(f"[llrd]   head: lr={cfg.head_lr:.3e}")
    return groups


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for teacher training.

    Every option except ``--context_window`` is declared from a single
    ``(flag, type, default)`` table, kept in the order the options appear in
    ``--help``. ``--device`` defaults to ``"cuda"`` when
    ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    option_table = [
        ("--train_jsonl", Path, Path("data/emotion/seed_train_final.jsonl")),
        ("--val_jsonl", Path, Path("data/emotion/seed_val.jsonl")),
        ("--test_jsonl", Path, Path("data/emotion/seed_test.jsonl")),
        ("--output_dir", Path, Path("checkpoints/klue_teacher")),
        ("--epochs", int, 12),
        ("--batch_size", int, 32),
        ("--max_seq_len", int, 128),
        ("--backbone_lr", float, 2e-5),
        ("--head_lr", float, 1e-4),
        ("--num_workers", int, 2),
        ("--seed", int, 42),
        ("--device", str, default_device),
        ("--log_every", int, 25),
        ("--wandb_project", str, None),
        ("--wandb_run_name", str, None),
        ("--wandb_entity", str, None),
    ]

    parser = argparse.ArgumentParser()
    for flag, converter, fallback in option_table:
        parser.add_argument(flag, type=converter, default=fallback)
    # The only option carrying help text.
    parser.add_argument(
        "--context_window",
        type=int,
        default=0,
        help="Prev turns to concat with current (0=baseline; 2=recommended MVP)",
    )
    return parser.parse_args()


def main() -> None:
    """Fine-tune the KLUE-RoBERTa teacher for emotion classification + VAD regression.

    Steps:
      1. Parse CLI args, seed RNGs, and build the HF tokenizer (with speaker
         tokens when ``--context_window > 0``).
      2. Load the train/val/test splits. The test split is only loaded to
         report its size.
      3. Build the teacher and the multitask loss. Class weights come from the
         train labels; snap-loss anchors are train-split VAD centroids.
      4. Train with AdamW over layer-wise-LR-decayed param groups and a
         LambdaLR schedule.

    After every epoch the function:
      * evaluates on the validation split;
      * appends the epoch record to ``<output_dir>/metrics.json``;
      * saves ``latest.pt``;
      * saves ``best.pt`` when the selection score improves.

    The selection score is ``0.5 * macro_f1 + 0.5 * (1 - vad_mae_mean / 2)``.
    The ``/ 2`` presumably assumes VAD targets span a range of width 2
    (TODO confirm). Metrics are also logged to Weights & Biases when
    ``--wandb_project`` is given.

    Raises:
        RuntimeError: if the tokenizer was not built in HF mode.
        ValueError: if the train or val split is empty, or the train split is
            smaller than one batch (``drop_last=True`` would yield no steps).
    """
    args = parse_args()
    set_seed(args.seed)
    device = args.device
    output_dir: Path = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    use_ctx = args.context_window > 0
    tok = MicroTokenizer.build(
        save_dir=output_dir / "tokenizer",
        train_jsonl=args.train_jsonl,
        max_len=args.max_seq_len,
        prefer_hf=True,
        add_speaker_tokens=use_ctx,
    )
    # Explicit check instead of `assert` so it still fires under `python -O`.
    if tok.mode != "hf":
        raise RuntimeError("teacher training requires HF tokenizer to match KLUE-RoBERTa")
    print(f"[tokenizer] mode={tok.mode} vocab={len(tok)} pad_id={tok.pad_id} "
          f"context_window={args.context_window}")

    cfg = MicroAlbertConfig()
    cfg = replace(
        cfg,
        vocab_size=len(tok),
        pad_token_id=tok.pad_id,
        max_seq_len=args.max_seq_len,
        backbone_lr=args.backbone_lr,
        head_lr=args.head_lr,
        seed=args.seed,
    )

    train_ds = SeedEmotionDataset(args.train_jsonl, tok, args.max_seq_len, context_window=args.context_window)
    val_ds = SeedEmotionDataset(args.val_jsonl, tok, args.max_seq_len, context_window=args.context_window)
    test_ds = SeedEmotionDataset(args.test_jsonl, tok, args.max_seq_len, context_window=args.context_window)
    if len(train_ds) == 0 or len(val_ds) == 0:
        raise ValueError("empty dataset after loading")
    print(f"[data] train={len(train_ds)} val={len(val_ds)} test={len(test_ds)}")

    cw = compute_class_weights(
        np.array(train_ds.emotion_ids()), NUM_EMOTIONS, cfg.class_weight_clip
    )
    print(
        f"[class_weights] min={float(cw.min()):.3f} max={float(cw.max()):.3f} "
        f"mean={float(cw.mean()):.3f}"
    )

    g = torch.Generator()
    g.manual_seed(args.seed)
    collate = partial(collate_fn, pad_id=tok.pad_id)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate,
        num_workers=args.num_workers,
        worker_init_fn=seed_worker,
        generator=g,
        pin_memory=(device == "cuda"),
        drop_last=True,
    )
    # With drop_last=True, a train split smaller than one batch yields zero
    # optimizer steps per epoch; fail loudly instead of "training" on nothing.
    if len(train_loader) == 0:
        raise ValueError(
            f"train split ({len(train_ds)} examples) is smaller than "
            f"batch_size={args.batch_size}; no full batch with drop_last=True"
        )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate,
        num_workers=args.num_workers,
        pin_memory=(device == "cuda"),
    )

    print(f"[teacher] loading {KLUE_MODEL}...")
    model = KlueTeacherForEmotionVAD(
        model_name=KLUE_MODEL,
        num_emotions=NUM_EMOTIONS,
        vad_dim=3,
        dropout=cfg.dropout,
        vad_head_hidden=cfg.vad_head_hidden,
        attention_dropout=cfg.attention_dropout,
    ).to(device)
    if len(tok) != model.backbone.config.vocab_size:
        old = model.backbone.config.vocab_size
        model.backbone.resize_token_embeddings(len(tok))
        print(f"[teacher] resized token embeddings {old} -> {len(tok)} for speaker tokens")
    n_params = model.num_params()
    print(f"[teacher] params={n_params:,}")

    # Use empirical class centroids for snap loss (replaces hand-curated JSON anchors,
    # which were 0.25-0.40 from data means → snap loss was pulling in wrong direction).
    anchors = build_data_centroid_anchors(args.train_jsonl, NUM_EMOTIONS)
    print("[anchors] using data centroids from train split:")
    for i, name in enumerate(EMOTION_LABELS):
        v, a, d = anchors[i].tolist()
        print(f"  {name:11s}: V={v:+.3f} A={a:+.3f} D={d:+.3f}")
    loss_fn = MultitaskLoss(cfg, class_weights=cw, anchors=anchors).to(device)

    param_groups = build_llrd_param_groups(model, cfg, decay=0.9)
    optimizer = AdamW(param_groups, betas=cfg.adam_betas, eps=cfg.adam_eps)
    num_total = args.epochs * len(train_loader)
    num_warmup = max(1, int(num_total * cfg.warmup_ratio))
    # Keep a handle on the shared LR multiplier so logging can report the
    # configured backbone/head LRs at the current step.
    lr_fn = lr_lambda_factory(num_warmup, num_total)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_fn)
    print(f"[train] total_steps={num_total} warmup_steps={num_warmup}")

    wandb_run = None
    if args.wandb_project:
        import wandb
        wandb_run = wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.wandb_run_name,
            config={
                **asdict(cfg),
                **{k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()},
                "stage": "train_teacher",
                "n_params": n_params,
            },
        )
        print(f"[wandb] logging to {wandb_run.url}")

    history = []
    best_score = -float("inf")
    best_epoch = -1
    global_step = 0

    for epoch in range(args.epochs):
        model.train()
        t0 = time.time()
        running_ce = 0.0
        running_total = 0.0
        running_emo_acc = 0.0
        running_vad_mae = 0.0
        running_vad_mae_v = 0.0
        running_vad_mae_a = 0.0
        running_vad_mae_d = 0.0
        n_steps = 0
        for step, batch in enumerate(train_loader):
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)
            emo = batch["emotion_id"].to(device, non_blocking=True)
            vad = batch["vad"].to(device, non_blocking=True)

            out = model(input_ids, attn)
            total, comps, aux = loss_fn(
                out["emotion_logits"], out["vad"], emo, vad, epoch=epoch
            )

            optimizer.zero_grad(set_to_none=True)
            total.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            optimizer.step()
            scheduler.step()

            with torch.no_grad():
                emo_pred = out["emotion_logits"].argmax(dim=-1)
                batch_emo_acc = (emo_pred == emo).float().mean().item()
                vad_diff = (out["vad"].detach() - vad).abs()
                batch_vad_mae_v = vad_diff[:, 0].mean().item()
                batch_vad_mae_a = vad_diff[:, 1].mean().item()
                batch_vad_mae_d = vad_diff[:, 2].mean().item()
                batch_vad_mae = vad_diff.mean().item()

            running_ce += comps["ce"]
            running_total += float(total.detach())
            running_emo_acc += batch_emo_acc
            running_vad_mae += batch_vad_mae
            running_vad_mae_v += batch_vad_mae_v
            running_vad_mae_a += batch_vad_mae_a
            running_vad_mae_d += batch_vad_mae_d
            n_steps += 1
            global_step += 1
            if step % args.log_every == 0:
                # Param groups come from build_llrd_param_groups, so positional
                # indices into get_last_lr() do not identify backbone vs head
                # groups. All groups share one LambdaLR multiplier, so report
                # the configured LRs scaled by it. lr_bb is the nominal
                # (pre-layer-decay) backbone LR.
                lr_scale = lr_fn(scheduler.last_epoch)
                lr_bb = cfg.backbone_lr * lr_scale
                lr_hd = cfg.head_lr * lr_scale
                print(
                    f"epoch {epoch+1}/{args.epochs} | step {step+1}/{len(train_loader)} | "
                    f"loss {float(total.detach()):.4f} "
                    f"(ce={comps['ce']:.4f} vad={comps['vad']:.4f} "
                    f"snap={comps['snap']:.4f} rng={comps['range']:.4f}) | "
                    f"emo_acc={batch_emo_acc:.3f} "
                    f"mae=({batch_vad_mae_v:.3f},{batch_vad_mae_a:.3f},{batch_vad_mae_d:.3f}) | "
                    f"w_vad={aux['w_vad']:.2f} w_snap={aux['w_snap']:.2f} "
                    f"gate={aux['gate_rate']:.2f} | lr_bb={lr_bb:.2e} lr_hd={lr_hd:.2e}"
                )
                if wandb_run is not None:
                    wandb_run.log(
                        {
                            "train/loss": float(total.detach()),
                            "train/ce": comps["ce"],
                            "train/vad": comps["vad"],
                            "train/snap": comps["snap"],
                            "train/range": comps["range"],
                            "train/emo_acc": batch_emo_acc,
                            "train/vad_mae": batch_vad_mae,
                            "train/vad_mae_v": batch_vad_mae_v,
                            "train/vad_mae_a": batch_vad_mae_a,
                            "train/vad_mae_d": batch_vad_mae_d,
                            "train/w_vad": aux["w_vad"],
                            "train/w_snap": aux["w_snap"],
                            "train/gate_rate": aux["gate_rate"],
                            "train/lr_backbone": lr_bb,
                            "train/lr_head": lr_hd,
                            "train/epoch": epoch + 1,
                        },
                        step=global_step,
                    )

        avg_ce = running_ce / max(1, n_steps)
        avg_total = running_total / max(1, n_steps)
        avg_emo_acc = running_emo_acc / max(1, n_steps)
        avg_vad_mae = running_vad_mae / max(1, n_steps)
        avg_vad_mae_v = running_vad_mae_v / max(1, n_steps)
        avg_vad_mae_a = running_vad_mae_a / max(1, n_steps)
        avg_vad_mae_d = running_vad_mae_d / max(1, n_steps)

        val_metrics = evaluate(model, val_loader, device)
        elapsed = time.time() - t0
        score = (
            0.5 * val_metrics["macro_f1"]
            + 0.5 * (1.0 - val_metrics["vad_mae_mean"] / 2.0)
        )

        mae = val_metrics["vad_mae"]
        r = val_metrics["vad_pearson_r"]
        print(
            f"[ep{epoch+1} train] emo_acc={avg_emo_acc:.4f} "
            f"vad_mae=({avg_vad_mae_v:.3f},{avg_vad_mae_a:.3f},{avg_vad_mae_d:.3f}) "
            f"ce={avg_ce:.4f} total={avg_total:.4f}"
        )
        print(
            f"[ep{epoch+1} val]   macro_f1={val_metrics['macro_f1']:.4f} "
            f"vad_mae=({mae[0]:.3f},{mae[1]:.3f},{mae[2]:.3f}) "
            f"vad_r=({r[0]:.3f},{r[1]:.3f},{r[2]:.3f}) "
            f"pred_classes={val_metrics['pred_class_count']}/{NUM_EMOTIONS} "
            f"dom={val_metrics['dominant_ratio']:.2f} "
            f"score={score:.4f} time={elapsed:.1f}s"
        )

        history.append(
            {
                "epoch": epoch + 1,
                "train_ce": avg_ce,
                "train_total": avg_total,
                "train_emo_acc": avg_emo_acc,
                "train_vad_mae": avg_vad_mae,
                "train_vad_mae_v": avg_vad_mae_v,
                "train_vad_mae_a": avg_vad_mae_a,
                "train_vad_mae_d": avg_vad_mae_d,
                "score": score,
                **val_metrics,
            }
        )
        (output_dir / "metrics.json").write_text(
            json.dumps(history, indent=2, ensure_ascii=False)
        )

        if wandb_run is not None:
            wandb_run.log(
                {
                    "val/macro_f1": val_metrics["macro_f1"],
                    "val/vad_mae_mean": val_metrics["vad_mae_mean"],
                    "val/vad_mae_v": val_metrics["vad_mae"][0],
                    "val/vad_mae_a": val_metrics["vad_mae"][1],
                    "val/vad_mae_d": val_metrics["vad_mae"][2],
                    "val/vad_r_v": val_metrics["vad_pearson_r"][0],
                    "val/vad_r_a": val_metrics["vad_pearson_r"][1],
                    "val/vad_r_d": val_metrics["vad_pearson_r"][2],
                    "val/pred_class_count": val_metrics["pred_class_count"],
                    "val/dominant_ratio": val_metrics["dominant_ratio"],
                    "val/score": score,
                    "val/epoch": epoch + 1,
                    "epoch_avg/train_ce": avg_ce,
                    "epoch_avg/train_total": avg_total,
                    "epoch_avg/train_emo_acc": avg_emo_acc,
                    "epoch_avg/train_vad_mae": avg_vad_mae,
                    "epoch_avg/train_vad_mae_v": avg_vad_mae_v,
                    "epoch_avg/train_vad_mae_a": avg_vad_mae_a,
                    "epoch_avg/train_vad_mae_d": avg_vad_mae_d,
                },
                step=global_step,
            )

        # Decide "best" before building the checkpoint so best.pt records its
        # own score and epoch rather than those of the previous best.
        is_best = score > best_score
        if is_best:
            best_score = score
            best_epoch = epoch + 1

        ckpt = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch + 1,
            "config": asdict(cfg),
            "class_weights": cw.tolist(),
            "tokenizer_mode": tok.mode,
            "tokenizer_vocab_size": len(tok),
            "tokenizer_pad_id": tok.pad_id,
            "best_score": best_score,
            "best_epoch": best_epoch,
            "metrics": val_metrics,
            "seed": args.seed,
            "model_name": KLUE_MODEL,
        }
        torch.save(ckpt, output_dir / "latest.pt")
        if is_best:
            torch.save(ckpt, output_dir / "best.pt")

    print(f"done. best score={best_score:.4f} @ epoch {best_epoch}")

    if wandb_run is not None:
        wandb_run.summary["best_score"] = best_score
        wandb_run.summary["best_epoch"] = best_epoch
        wandb_run.finish()


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
