from __future__ import annotations

import argparse
import json
import os
import random
import time
from dataclasses import asdict, replace
from functools import partial
from pathlib import Path
from typing import List

import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from .config import EMOTION_LABELS, MicroAlbertConfig, NUM_EMOTIONS
from .dataset import SeedEmotionDataset, collate_fn
from .eval import evaluate
from .losses import MultitaskLoss, build_data_centroid_anchors, compute_class_weights
from .model import MicroAlbertForEmotionVAD
from .tokenizer import MicroTokenizer


def set_seed(seed: int) -> None:
    """Seed every RNG this script uses and put torch into deterministic mode.

    Seeds Python's ``random``, NumPy's global RNG, and torch (CPU plus all CUDA
    devices). Disables cuDNN autotuning and enables deterministic algorithms in
    warn-only mode. Also sets ``CUBLAS_WORKSPACE_CONFIG`` unless the environment
    already defines it.
    """
    # setdefault: an externally exported value takes precedence over ours.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    # SDPA has no deterministic CUDA kernel in torch 2.4; warn_only=True makes
    # non-deterministic ops warn instead of raise (non-bitwise reproducibility).
    torch.use_deterministic_algorithms(True, warn_only=True)


def seed_worker(worker_id: int) -> None:
    """Re-seed NumPy and ``random`` from torch's initial seed.

    Passed as ``worker_init_fn`` to the training DataLoader. ``worker_id`` is
    unused; the seed comes solely from ``torch.initial_seed()``, truncated to
    32 bits because ``np.random.seed`` rejects larger values.
    """
    worker_seed = torch.initial_seed() % 2 ** 32
    for reseed in (np.random.seed, random.seed):
        reseed(worker_seed)


def build_param_groups(model: MicroAlbertForEmotionVAD, cfg: MicroAlbertConfig) -> List[dict]:
    """Split trainable parameters into four AdamW parameter groups.

    A parameter is a "head" parameter if its name contains ``emotion_head`` or
    ``vad_head``; everything else is "backbone". Within each part, parameters
    with ``ndim <= 1`` (biases, LayerNorm weights) get no weight decay.

    Returns, in this order: backbone+decay, backbone+no-decay, head+decay,
    head+no-decay. Backbone groups use ``cfg.backbone_lr`` and head groups use
    ``cfg.head_lr``. Frozen parameters (``requires_grad=False``) are skipped.
    """
    head_markers = ("emotion_head", "vad_head")
    # (is_head, applies_decay) -> params, preserving named_parameters() order.
    buckets = {
        (False, True): [],
        (False, False): [],
        (True, True): [],
        (True, False): [],
    }
    trainable = ((pname, p) for pname, p in model.named_parameters() if p.requires_grad)
    for pname, p in trainable:
        is_head = any(marker in pname for marker in head_markers)
        buckets[(is_head, p.ndim > 1)].append(p)

    layout = (
        (False, True, cfg.backbone_lr, cfg.weight_decay),
        (False, False, cfg.backbone_lr, 0.0),
        (True, True, cfg.head_lr, cfg.weight_decay),
        (True, False, cfg.head_lr, 0.0),
    )
    return [
        {"params": buckets[(is_head, decays)], "lr": lr, "weight_decay": wd}
        for is_head, decays, lr, wd in layout
    ]


def lr_lambda_factory(num_warmup: int, num_total: int):
    """Build a ``LambdaLR`` multiplier: linear warmup, then linear decay to 0.

    For ``step < num_warmup`` the factor is ``step / num_warmup``, floored at
    1e-8. Afterwards it falls linearly from 1.0 at ``num_warmup`` to 0.0 at
    ``num_total`` and stays at 0.0 beyond that. Denominators are clamped to at
    least 1, so degenerate inputs never divide by zero.
    """
    warmup_span = max(1, num_warmup)
    decay_span = max(1, num_total - num_warmup)

    def fn(step: int) -> float:
        if step >= num_warmup:
            steps_left = max(0, num_total - step)
            return max(0.0, steps_left / decay_span)
        return max(1e-8, step / warmup_span)

    return fn


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a fine-tuning run.

    Defaults point at the seed emotion splits under ``data/emotion/`` and at a
    ``checkpoints/microalbert_smoke`` output directory. ``--device`` defaults to
    ``"cuda"`` when torch reports CUDA available, otherwise ``"cpu"``.
    ``--prefer_hf_tokenizer`` and ``--no_hf_tokenizer`` toggle one boolean that
    defaults to True. Setting ``--wandb_project`` enables W&B logging.
    """
    parser = argparse.ArgumentParser()
    opt = parser.add_argument

    # Data splits and output location.
    opt("--train_jsonl", type=Path, default=Path("data/emotion/seed_train_final.jsonl"))
    opt("--val_jsonl", type=Path, default=Path("data/emotion/seed_val.jsonl"))
    opt("--test_jsonl", type=Path, default=Path("data/emotion/seed_test.jsonl"))
    opt("--output_dir", type=Path, default=Path("checkpoints/microalbert_smoke"))

    # Integer hyper-parameters, registered in CLI/help order.
    int_options = (
        ("--epochs", 12),
        ("--batch_size", 128),
        ("--max_seq_len", 128),
        ("--num_workers", 2),
        ("--seed", 42),
    )
    for flag, default in int_options:
        opt(flag, type=int, default=default)

    opt("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    # Paired on/off switches sharing one destination.
    opt("--prefer_hf_tokenizer", dest="prefer_hf_tokenizer", action="store_true")
    opt("--no_hf_tokenizer", dest="prefer_hf_tokenizer", action="store_false")
    parser.set_defaults(prefer_hf_tokenizer=True)

    opt("--log_every", type=int, default=25)
    opt("--init_from", type=Path, default=None, help="Distilled backbone checkpoint to warm-start from")
    opt("--wandb_project", type=str, default=None, help="W&B project name (enables logging)")
    opt("--wandb_run_name", type=str, default=None)
    opt("--wandb_entity", type=str, default=None)
    return parser.parse_args()


def assert_no_leakage(train_ds: SeedEmotionDataset, val_ds: SeedEmotionDataset, test_ds: SeedEmotionDataset) -> None:
    """Fail fast if any train source id also appears in the val or test split.

    Each dataset exposes a set-like ``source_ids``. Presumably paraphrases of
    one original share a source id, so any overlap means a paraphrase of an
    evaluation example was trained on.

    Raises:
        AssertionError: on the first overlapping split (val checked before
            test). The message shows up to five leaked ids. It is raised
            explicitly rather than via ``assert`` so the check still runs under
            ``python -O``.
    """
    for split_name, other_ds in (("val", val_ds), ("test", test_ds)):
        leaked = train_ds.source_ids & other_ds.source_ids
        if leaked:
            raise AssertionError(
                f"paraphrase leak train↔{split_name}: {sorted(leaked)[:5]}..."
            )


def main() -> None:
    """Fine-tune ``MicroAlbertForEmotionVAD`` on the seed emotion/VAD splits.

    All settings come from ``parse_args()``. Pipeline:

    1. Seed RNGs; build (and save under ``<output_dir>/tokenizer``) the tokenizer.
    2. Load train/val/test datasets and refuse to run if train source ids
       overlap val or test.
    3. Compute class weights and data-centroid VAD anchors from the train split.
    4. Optionally warm-start the backbone from ``--init_from``.
    5. Train with AdamW and a linear warmup/decay schedule; evaluate on val after
       every epoch.
    6. After each epoch, write ``metrics.json`` and ``latest.pt``. Write
       ``best.pt`` whenever the validation score improves.
    7. Stop early after two consecutive post-warmup "collapse" epochs, then
       print pass/fail gates for the last completed epoch's val metrics.

    The test split is loaded only for the leakage check; it is not evaluated.

    Raises:
        AssertionError: empty train/val split, split leakage, or a parameter
            count outside [3.5M, 12M].
        ValueError: the train split cannot fill a single batch. With
            ``drop_last=True`` that would mean zero optimisation steps.
    """
    args = parse_args()
    set_seed(args.seed)
    device = args.device
    # Compare the device *type* so "cuda:0"-style strings also get pinned host
    # memory (a plain `device == "cuda"` check misses them).
    pin_memory = torch.device(device).type == "cuda"
    output_dir: Path = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    tok = MicroTokenizer.build(
        save_dir=output_dir / "tokenizer",
        train_jsonl=args.train_jsonl,
        max_len=args.max_seq_len,
        prefer_hf=args.prefer_hf_tokenizer,
    )
    print(
        f"[tokenizer] mode={tok.mode} vocab={len(tok)} pad_id={tok.pad_id} "
        f"cls_id={tok.cls_id} sep_id={tok.sep_id}"
    )

    # Vocab size / pad id follow the tokenizer; seq len and seed follow the CLI.
    cfg = replace(
        MicroAlbertConfig(),
        vocab_size=len(tok),
        pad_token_id=tok.pad_id,
        max_seq_len=args.max_seq_len,
        seed=args.seed,
    )

    train_ds = SeedEmotionDataset(args.train_jsonl, tok, args.max_seq_len)
    val_ds = SeedEmotionDataset(args.val_jsonl, tok, args.max_seq_len)
    test_ds = SeedEmotionDataset(args.test_jsonl, tok, args.max_seq_len)
    # Explicit raise (not `assert`) so the guard survives `python -O`.
    if len(train_ds) == 0 or len(val_ds) == 0:
        raise AssertionError("empty dataset after loading")
    assert_no_leakage(train_ds, val_ds, test_ds)
    print(f"[data] train={len(train_ds)} val={len(val_ds)} test={len(test_ds)}")

    cw = compute_class_weights(
        np.array(train_ds.emotion_ids()), NUM_EMOTIONS, cfg.class_weight_clip
    )
    print(
        f"[class_weights] min={float(cw.min()):.3f} max={float(cw.max()):.3f} "
        f"mean={float(cw.mean()):.3f}"
    )

    # A dedicated seeded generator plus seed_worker keep shuffling and
    # per-worker RNG state reproducible across runs.
    g = torch.Generator()
    g.manual_seed(args.seed)
    collate = partial(collate_fn, pad_id=tok.pad_id)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate,
        num_workers=args.num_workers,
        worker_init_fn=seed_worker,
        generator=g,
        pin_memory=pin_memory,
        drop_last=True,
    )
    if len(train_loader) == 0:
        # drop_last=True discards the final partial batch, so a split smaller
        # than one batch would otherwise "train" for zero steps without error.
        raise ValueError(
            f"train split has {len(train_ds)} examples, fewer than "
            f"batch_size={args.batch_size}; no full training batch can be formed"
        )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )

    model = MicroAlbertForEmotionVAD(cfg).to(device)

    if args.init_from is not None:
        init_ckpt = torch.load(args.init_from, map_location=device)
        # Accept either a full training checkpoint or a bare state dict.
        state = init_ckpt.get("model_state_dict", init_ckpt)
        # Keys without any "backbone." prefix are treated as backbone-only and
        # re-prefixed to match this model's parameter names.
        if not any(k.startswith("backbone.") for k in state.keys()):
            state = {f"backbone.{k}": v for k, v in state.items()}
        # strict=False: presumably the task heads are absent from a distilled
        # backbone checkpoint; missing/unexpected counts are reported below.
        missing, unexpected = model.load_state_dict(state, strict=False)
        loaded = len(state) - len(unexpected)
        print(f"[init_from] loaded {loaded} keys from {args.init_from}; missing_in_ckpt={len(missing)} unexpected={len(unexpected)}")

    n_params = model.num_params()
    print(f"[model] params={n_params:,}")
    if not 3_500_000 <= n_params <= 12_000_000:
        raise AssertionError(
            f"param count {n_params:,} outside [3.5M, 12M]; arch drift"
        )

    anchors = build_data_centroid_anchors(args.train_jsonl, NUM_EMOTIONS)
    print("[anchors] using data centroids from train split:")
    for i, name in enumerate(EMOTION_LABELS):
        v, a, d = anchors[i].tolist()
        print(f"  {name:11s}: V={v:+.3f} A={a:+.3f} D={d:+.3f}")
    loss_fn = MultitaskLoss(cfg, class_weights=cw, anchors=anchors).to(device)

    param_groups = build_param_groups(model, cfg)
    optimizer = AdamW(param_groups, betas=cfg.adam_betas, eps=cfg.adam_eps)
    num_total = args.epochs * len(train_loader)
    num_warmup = max(1, int(num_total * cfg.warmup_ratio))
    # The scheduler is stepped once per optimizer step, not per epoch.
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_factory(num_warmup, num_total))
    print(f"[train] total_steps={num_total} warmup_steps={num_warmup}")

    wandb_run = None
    if args.wandb_project:
        import wandb
        # Path values are stringified so the W&B config stays serialisable.
        run_args = {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()}
        wandb_run = wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.wandb_run_name,
            config={**asdict(cfg), **run_args, "stage": "finetune", "n_params": n_params},
        )
        print(f"[wandb] logging to {wandb_run.url}")

    history = []
    best_score = -float("inf")
    best_epoch = -1
    collapse_streak = 0
    global_step = 0

    for epoch in range(args.epochs):
        model.train()
        t0 = time.time()
        running_ce = 0.0
        running_total = 0.0
        running_emo_acc = 0.0
        running_vad_mae = 0.0
        running_vad_mae_v = 0.0
        running_vad_mae_a = 0.0
        running_vad_mae_d = 0.0
        n_steps = 0
        for step, batch in enumerate(train_loader):
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)
            emo = batch["emotion_id"].to(device, non_blocking=True)
            vad = batch["vad"].to(device, non_blocking=True)

            out = model(input_ids, attn)
            # The loss receives the epoch, presumably for epoch-dependent
            # weighting (cf. aux["w_vad"], aux["w_snap"]) — see MultitaskLoss.
            total, comps, aux = loss_fn(
                out["emotion_logits"], out["vad"], emo, vad, epoch=epoch
            )

            optimizer.zero_grad(set_to_none=True)
            total.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            optimizer.step()
            scheduler.step()

            # Per-batch train metrics; VAD MAE is reported per column (V, A, D).
            with torch.no_grad():
                emo_pred = out["emotion_logits"].argmax(dim=-1)
                batch_emo_acc = (emo_pred == emo).float().mean().item()
                vad_diff = (out["vad"].detach() - vad).abs()
                batch_vad_mae_v = vad_diff[:, 0].mean().item()
                batch_vad_mae_a = vad_diff[:, 1].mean().item()
                batch_vad_mae_d = vad_diff[:, 2].mean().item()
                batch_vad_mae = vad_diff.mean().item()

            running_ce += comps["ce"]
            running_total += float(total.detach())
            running_emo_acc += batch_emo_acc
            running_vad_mae += batch_vad_mae
            running_vad_mae_v += batch_vad_mae_v
            running_vad_mae_a += batch_vad_mae_a
            running_vad_mae_d += batch_vad_mae_d
            n_steps += 1
            global_step += 1
            if step % args.log_every == 0:
                lr0 = scheduler.get_last_lr()[0]  # group 0: backbone, with decay
                print(
                    f"epoch {epoch+1}/{args.epochs} | step {step+1}/{len(train_loader)} | "
                    f"loss {float(total.detach()):.4f} "
                    f"(ce={comps['ce']:.4f} vad={comps['vad']:.4f} "
                    f"snap={comps['snap']:.4f} rng={comps['range']:.4f}) | "
                    f"emo_acc={batch_emo_acc:.3f} "
                    f"mae=({batch_vad_mae_v:.3f},{batch_vad_mae_a:.3f},{batch_vad_mae_d:.3f}) | "
                    f"w_vad={aux['w_vad']:.2f} w_snap={aux['w_snap']:.2f} "
                    f"gate={aux['gate_rate']:.2f} | lr={lr0:.2e}"
                )
                if wandb_run is not None:
                    wandb_run.log(
                        {
                            "train/loss": float(total.detach()),
                            "train/ce": comps["ce"],
                            "train/vad": comps["vad"],
                            "train/snap": comps["snap"],
                            "train/range": comps["range"],
                            "train/emo_acc": batch_emo_acc,
                            "train/vad_mae": batch_vad_mae,
                            "train/vad_mae_v": batch_vad_mae_v,
                            "train/vad_mae_a": batch_vad_mae_a,
                            "train/vad_mae_d": batch_vad_mae_d,
                            "train/w_vad": aux["w_vad"],
                            "train/w_snap": aux["w_snap"],
                            "train/gate_rate": aux["gate_rate"],
                            "train/lr": lr0,
                            "train/epoch": epoch + 1,
                        },
                        step=global_step,
                    )
        avg_ce = running_ce / max(1, n_steps)
        avg_total = running_total / max(1, n_steps)
        avg_emo_acc = running_emo_acc / max(1, n_steps)
        avg_vad_mae = running_vad_mae / max(1, n_steps)
        avg_vad_mae_v = running_vad_mae_v / max(1, n_steps)
        avg_vad_mae_a = running_vad_mae_a / max(1, n_steps)
        avg_vad_mae_d = running_vad_mae_d / max(1, n_steps)

        val_metrics = evaluate(model, val_loader, device)
        elapsed = time.time() - t0
        # Equal-weight blend of macro-F1 and (1 - MAE/2). The /2 presumably maps
        # MAE over a width-2 VAD range into [0, 1] — TODO confirm the VAD range.
        score = (
            0.5 * val_metrics["macro_f1"]
            + 0.5 * (1.0 - val_metrics["vad_mae_mean"] / 2.0)
        )

        # Collapse heuristic: one class takes >50% of val predictions, or some
        # VAD output dimension has std < 0.10. An epoch counts towards the abort
        # streak only once epoch >= snap_start_epoch + 2.
        dominant_ratio = val_metrics["dominant_ratio"]
        vad_std = val_metrics["vad_std"]
        collapsed = dominant_ratio > 0.50 or any(s < 0.10 for s in vad_std)
        post_warmup = epoch >= cfg.snap_start_epoch + 2
        if collapsed and post_warmup:
            collapse_streak += 1
        else:
            collapse_streak = 0

        warn = " [COLLAPSE_WARN]" if collapsed else ""
        mae = val_metrics["vad_mae"]
        r = val_metrics["vad_pearson_r"]
        print(
            f"[ep{epoch+1} train] emo_acc={avg_emo_acc:.4f} "
            f"vad_mae=({avg_vad_mae_v:.3f},{avg_vad_mae_a:.3f},{avg_vad_mae_d:.3f}) "
            f"ce={avg_ce:.4f} total={avg_total:.4f}"
        )
        print(
            f"[ep{epoch+1} val]   macro_f1={val_metrics['macro_f1']:.4f} "
            f"vad_mae=({mae[0]:.3f},{mae[1]:.3f},{mae[2]:.3f}) "
            f"vad_r=({r[0]:.3f},{r[1]:.3f},{r[2]:.3f}) "
            f"pred_classes={val_metrics['pred_class_count']}/{NUM_EMOTIONS} "
            f"dom={dominant_ratio:.2f} "
            f"vad_std=({vad_std[0]:.2f},{vad_std[1]:.2f},{vad_std[2]:.2f}) "
            f"score={score:.4f} time={elapsed:.1f}s{warn}"
        )

        history.append(
            {
                "epoch": epoch + 1,
                "train_ce": avg_ce,
                "train_total": avg_total,
                "train_emo_acc": avg_emo_acc,
                "train_vad_mae": avg_vad_mae,
                "train_vad_mae_v": avg_vad_mae_v,
                "train_vad_mae_a": avg_vad_mae_a,
                "train_vad_mae_d": avg_vad_mae_d,
                "score": score,
                "collapsed_flag": collapsed,
                **val_metrics,
            }
        )
        # ensure_ascii=False emits raw non-ASCII, so pin the file encoding
        # instead of relying on the platform locale.
        (output_dir / "metrics.json").write_text(
            json.dumps(history, indent=2, ensure_ascii=False), encoding="utf-8"
        )

        if wandb_run is not None:
            wandb_run.log(
                {
                    "val/macro_f1": val_metrics["macro_f1"],
                    "val/vad_mae_mean": val_metrics["vad_mae_mean"],
                    "val/vad_mae_v": val_metrics["vad_mae"][0],
                    "val/vad_mae_a": val_metrics["vad_mae"][1],
                    "val/vad_mae_d": val_metrics["vad_mae"][2],
                    "val/vad_r_v": val_metrics["vad_pearson_r"][0],
                    "val/vad_r_a": val_metrics["vad_pearson_r"][1],
                    "val/vad_r_d": val_metrics["vad_pearson_r"][2],
                    "val/pred_class_count": val_metrics["pred_class_count"],
                    "val/dominant_ratio": val_metrics["dominant_ratio"],
                    "val/score": score,
                    "val/epoch": epoch + 1,
                    "val/collapsed_flag": int(collapsed),
                    "epoch_avg/train_ce": avg_ce,
                    "epoch_avg/train_total": avg_total,
                    "epoch_avg/train_emo_acc": avg_emo_acc,
                    "epoch_avg/train_vad_mae": avg_vad_mae,
                    "epoch_avg/train_vad_mae_v": avg_vad_mae_v,
                    "epoch_avg/train_vad_mae_a": avg_vad_mae_a,
                    "epoch_avg/train_vad_mae_d": avg_vad_mae_d,
                },
                step=global_step,
            )

        # Update the best-so-far bookkeeping BEFORE building the checkpoint, so
        # best.pt / latest.pt record the current best score, not the previous one.
        is_best = score > best_score
        if is_best:
            best_score = score
            best_epoch = epoch + 1

        epoch_ckpt = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch + 1,
            "config": asdict(cfg),
            "class_weights": cw.tolist(),
            "tokenizer_mode": tok.mode,
            "tokenizer_vocab_size": len(tok),
            "tokenizer_pad_id": tok.pad_id,
            "best_score": best_score,
            "best_epoch": best_epoch,
            "score": score,
            "metrics": val_metrics,
            "seed": args.seed,
        }
        torch.save(epoch_ckpt, output_dir / "latest.pt")
        if is_best:
            torch.save(epoch_ckpt, output_dir / "best.pt")

        if collapse_streak >= 2:
            print("[abort] 2 consecutive collapse epochs after warmup — stopping early")
            break

    print(f"done. best score={best_score:.4f} @ epoch {best_epoch}")

    if history:
        # Gates use the last completed epoch (after an early abort, a collapsed
        # one), not necessarily the best-scoring epoch.
        final = history[-1]
        gates = {
            "emotion_macro_f1>=0.35": final["macro_f1"] >= 0.35,
            "vad_mae_per_dim<=0.25": all(m <= 0.25 for m in final["vad_mae"]),
        }
        all_pass = all(gates.values())
        print(f"[gates] {gates} -> {'PASS' if all_pass else 'FAIL'}")

    if wandb_run is not None:
        wandb_run.summary["best_score"] = best_score
        wandb_run.summary["best_epoch"] = best_epoch
        wandb_run.finish()


if __name__ == "__main__":
    # Run training only when executed as a script/module, not when imported.
    main()
