"""Print an sklearn classification_report, a confusion matrix, and the most
frequent confusions for a trained teacher checkpoint on a given split
(the validation split by default).

Usage:
    python -m models.microalbert.report \
        --ckpt checkpoints/klue_teacher_v3_snap_llrd/best.pt \
        --split data/emotion/seed_val.jsonl
"""
from __future__ import annotations

import argparse
from functools import partial
from pathlib import Path

import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader

from .config import EMOTION_LABELS, MicroAlbertConfig
from .dataset import SeedEmotionDataset, collate_fn
from .teacher import KlueTeacherForEmotionVAD
from .tokenizer import MicroTokenizer


@torch.no_grad()
def collect_preds(model, loader, device):
    """Run ``model`` over ``loader`` and return ``(y_true, y_pred)`` as numpy arrays.

    Switches ``model`` to eval mode (the change persists after return) and runs
    without autograd. Each batch must be a mapping with ``"input_ids"``,
    ``"attention_mask"`` and ``"emotion_id"`` tensors. The model is called as
    ``model(input_ids, attention_mask)`` and must return a mapping containing
    ``"emotion_logits"``; the prediction is the argmax over its last dim
    (presumably logits are ``(batch, num_emotions)`` -- TODO confirm).

    Args:
        model: Module to evaluate.
        loader: Iterable of batches, e.g. a ``DataLoader``.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(y_true, y_pred)``: per-batch arrays concatenated in loader order.
        Both are empty int64 arrays when ``loader`` yields no batches.
    """
    model.eval()
    preds, trues = [], []
    for batch in loader:
        out = model(batch["input_ids"].to(device), batch["attention_mask"].to(device))
        preds.append(out["emotion_logits"].argmax(dim=-1).cpu().numpy())
        # .cpu() is a no-op for CPU tensors; it keeps .numpy() working if the
        # labels ever arrive on an accelerator.
        trues.append(batch["emotion_id"].cpu().numpy())
    if not preds:
        # np.concatenate raises on an empty list; return empty arrays instead.
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    return np.concatenate(trues), np.concatenate(preds)


def _print_confusion_matrix(cm, target_names):
    """Print ``cm`` with one named row per true label and numbered predicted-label columns."""
    name_w = max(len(n) for n in target_names)
    # Header is padded by name_w + 2 to line up with the "<name> |" row prefix.
    header = " " * (name_w + 2) + " ".join(f"{j:>4d}" for j in range(len(target_names)))
    print(header)
    for name, row in zip(target_names, cm):
        print(f"{name:<{name_w}} |" + " ".join(f"{v:>4d}" for v in row))


def _print_top_confusions(cm, target_names, top_k):
    """Print the ``top_k`` largest non-zero off-diagonal cells of ``cm`` as true → pred lines."""
    name_w = max(len(n) for n in target_names)
    n = len(target_names)
    off = [
        (int(cm[i, j]), target_names[i], target_names[j])
        for i in range(n)
        for j in range(n)
        if i != j and cm[i, j] > 0
    ]
    # Descending by count; ties fall back to reverse name order (tuple sort).
    off.sort(reverse=True)
    for cnt, t, p in off[:top_k]:
        print(f"  {t:<{name_w}s} → {p:<{name_w}s}  {cnt}")


def main():
    """Evaluate a teacher checkpoint on a JSONL split and print diagnostics.

    Prints, in order: sklearn's per-class ``classification_report``, the
    confusion matrix (rows = true label, cols = predicted label), and the
    ``--top_k`` most frequent off-diagonal confusions.

    Exits via ``SystemExit`` with a message when the checkpoint's
    ``num_emotions`` does not match ``len(EMOTION_LABELS)`` or when the split
    yields no examples.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=Path, required=True)
    ap.add_argument("--split", type=Path, default=Path("data/emotion/seed_val.jsonl"))
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--model_name", default="klue/roberta-base")
    ap.add_argument("--context_window", type=int, default=0,
                    help="Must match training config of the checkpoint being evaluated")
    ap.add_argument("--top_k", type=int, default=15,
                    help="Number of most frequent off-diagonal confusions to list")
    args = ap.parse_args()
    if args.top_k < 0:
        ap.error("--top_k must be non-negative")

    # NOTE(security): weights_only=False unpickles arbitrary objects and can run
    # code embedded in the file -- only evaluate checkpoints from trusted sources.
    ckpt = torch.load(args.ckpt, map_location=args.device, weights_only=False)
    cfg = MicroAlbertConfig(**ckpt["config"]) if isinstance(ckpt.get("config"), dict) else MicroAlbertConfig()

    labels = list(range(cfg.num_emotions))
    target_names = list(EMOTION_LABELS)
    # Validate before the inference pass. With labels= given, sklearn only warns
    # on a size mismatch and zips names against rows, which would silently
    # mislabel the report; the matrix printing would also index past the names.
    if len(target_names) != len(labels):
        raise SystemExit(
            f"Checkpoint has num_emotions={cfg.num_emotions} but EMOTION_LABELS has "
            f"{len(target_names)} entries; cannot label the report."
        )

    # NOTE(review): the evaluation split is passed as train_jsonl. If
    # MicroTokenizer.build fits a new vocab when save_dir is missing, it would be
    # fit on eval data -- confirm it loads the tokenizer saved next to the ckpt.
    tok = MicroTokenizer.build(
        save_dir=args.ckpt.parent / "tokenizer",
        train_jsonl=args.split,
        max_len=cfg.max_seq_len,
        add_speaker_tokens=(args.context_window > 0),
    )

    model = KlueTeacherForEmotionVAD(
        # A model_name recorded in the checkpoint takes precedence over --model_name.
        model_name=ckpt.get("model_name", args.model_name),
        num_emotions=cfg.num_emotions,
        vad_dim=cfg.vad_dim,
        dropout=cfg.dropout,
        vad_head_hidden=cfg.vad_head_hidden,
        attention_dropout=cfg.attention_dropout,
    ).to(args.device)
    # Resize before loading so the embedding shape matches the saved state dict
    # (e.g. when extra tokens such as speaker tokens were added).
    if len(tok) != model.backbone.config.vocab_size:
        model.backbone.resize_token_embeddings(len(tok))
    model.load_state_dict(ckpt["model_state_dict"])

    ds = SeedEmotionDataset(args.split, tok, cfg.max_seq_len, context_window=args.context_window)
    if len(ds) == 0:
        raise SystemExit(f"No examples found in {args.split}")
    loader = DataLoader(
        ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=partial(collate_fn, pad_id=tok.pad_id),
    )

    y_true, y_pred = collect_preds(model, loader, args.device)

    print(f"\n=== classification_report ({args.split.name}, n={len(y_true)}) ===")
    print(classification_report(
        y_true, y_pred,
        labels=labels, target_names=target_names,
        digits=3, zero_division=0,
    ))

    cm = confusion_matrix(y_true, y_pred, labels=labels)

    print("\n=== confusion matrix (rows=true, cols=pred) ===")
    _print_confusion_matrix(cm, target_names)

    print("\n=== top confusions (true → pred, count) ===")
    _print_top_confusions(cm, target_names, args.top_k)


# Script entry point. Run as ``python -m models.microalbert.report``: the
# relative imports at the top of this module fail if the file is executed
# directly as a standalone script.
if __name__ == "__main__":
    main()
