"""Dump N actual text examples per target emotion with true vs predicted labels.

Usage:
    python -m models.microalbert.report_examples \
        --ckpt checkpoints/klue_teacher_v3_snap_llrd/best.pt \
        --split data/emotion/seed_val.jsonl \
        --emotions sadness agreement struggle fluster shy \
        --n 10
"""
from __future__ import annotations

import argparse
import json
from functools import partial
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

from .config import EMOTION_LABELS, EMOTION_TO_ID, MicroAlbertConfig
from .dataset import SeedEmotionDataset, collate_fn
from .teacher import KlueTeacherForEmotionVAD
from .tokenizer import MicroTokenizer


@torch.no_grad()
def collect(model, loader, device):
    """Run ``model`` over ``loader`` and gather argmax predictions and softmax probabilities.

    Args:
        model: module whose forward takes ``(input_ids, attention_mask)`` and returns a
            mapping with an ``"emotion_logits"`` tensor. Argmax and softmax are taken over
            its last dim (presumably ``(batch, num_emotions)`` — confirm against the model).
        loader: iterable of batches, each a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(preds, probs)`` with rows in loader order. ``preds`` is a 1-D int64 array of
        argmax class ids. ``probs`` is the float32 softmax over the last logits dim. An
        empty loader yields an int64 array of shape ``(0,)`` and a float32 array of shape
        ``(0, 0)`` instead of raising from ``np.concatenate``.
    """
    model.eval()
    pred_chunks, prob_chunks = [], []
    for batch in loader:
        out = model(batch["input_ids"].to(device), batch["attention_mask"].to(device))
        # Upcast before softmax/numpy: half-precision softmax loses accuracy, and
        # bfloat16 tensors cannot be converted to numpy at all.
        logits = out["emotion_logits"].float()
        pred_chunks.append(logits.argmax(dim=-1).cpu().numpy())
        prob_chunks.append(torch.softmax(logits, dim=-1).cpu().numpy())
    if not pred_chunks:
        # np.concatenate([]) raises; an empty split should just produce no rows.
        return np.empty(0, dtype=np.int64), np.empty((0, 0), dtype=np.float32)
    return np.concatenate(pred_chunks), np.concatenate(prob_chunks, axis=0)


def _report_emotion(emo, eid, texts, trues, preds, probs, n):
    """Print the summary header and up to ``n`` example lines for one true emotion.

    Examples are the first ``n`` samples, in split order, whose true label id is ``eid``.
    Each line shows ``OK``/``MIS``, the predicted label with its probability, the
    probability given to the true label, and the text. Newlines in the text are
    flattened and the text is cut to 120 characters.
    """
    support_idx = np.flatnonzero(trues == eid)
    idx = support_idx[:n]  # stable order, first N
    n_corr = int((preds[idx] == eid).sum())
    print(f"\n=== {emo} (true=id{eid}, support={len(support_idx)}, "
          f"showing {len(idx)}, correct={n_corr}/{len(idx)}) ===")
    for i in idx:
        p_id = int(preds[i])
        p_name = EMOTION_LABELS[p_id]
        mark = "OK " if p_id == eid else "MIS"
        conf = probs[i, p_id]
        true_conf = probs[i, eid]
        text = texts[i].replace("\n", " ")
        if len(text) > 120:
            text = text[:117] + "..."
        print(f"  [{mark}] pred={p_name:<12s} p={conf:.2f}  "
              f"(true_p={true_conf:.2f})  | {text}")


def main():
    """CLI entry point: load a teacher checkpoint, predict on a split, print examples.

    For each name in ``--emotions``, print the first ``--n`` samples whose true label
    is that emotion. Each sample is marked correct/incorrect against the argmax
    prediction. Unknown emotion names are reported and skipped.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=Path, required=True)
    ap.add_argument("--split", type=Path, default=Path("data/emotion/seed_val.jsonl"))
    ap.add_argument("--emotions", nargs="+", required=True)
    ap.add_argument("--n", type=int, default=10)
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--batch_size", type=int, default=32)
    args = ap.parse_args()
    if args.n < 0:
        # A negative slice bound would silently mean "all but the last |n|".
        ap.error("--n must be >= 0")

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary objects;
    # only point --ckpt at checkpoints from a trusted source.
    ckpt = torch.load(args.ckpt, map_location=args.device, weights_only=False)
    # Fall back to default hyper-parameters when the checkpoint carries no config dict.
    cfg = MicroAlbertConfig(**ckpt["config"]) if isinstance(ckpt.get("config"), dict) else MicroAlbertConfig()

    # NOTE(review): build() is given the evaluated split as train_jsonl; if it fits a
    # new tokenizer when none is saved under save_dir, that would use eval data — confirm.
    tok = MicroTokenizer.build(
        save_dir=args.ckpt.parent / "tokenizer",
        train_jsonl=args.split,
        max_len=cfg.max_seq_len,
    )
    model = KlueTeacherForEmotionVAD(
        model_name=ckpt.get("model_name", "klue/roberta-base"),
        num_emotions=cfg.num_emotions,
        vad_dim=cfg.vad_dim,
        dropout=cfg.dropout,
        vad_head_hidden=cfg.vad_head_hidden,
        attention_dropout=cfg.attention_dropout,
    ).to(args.device)
    model.load_state_dict(ckpt["model_state_dict"])

    ds = SeedEmotionDataset(args.split, tok, cfg.max_seq_len)
    loader = DataLoader(
        ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=partial(collate_fn, pad_id=tok.pad_id),
    )
    preds, probs = collect(model, loader, args.device)

    # samples are (text, emo_id, vad_tuple). shuffle=False keeps prediction rows in
    # dataset order; this assumes ds[i] is built from ds.samples[i].
    texts = [s[0] for s in ds.samples]
    trues = np.array([s[1] for s in ds.samples])
    if len(trues) != len(preds):
        raise RuntimeError(
            f"prediction count ({len(preds)}) does not match dataset samples "
            f"({len(trues)}); cannot align texts with predictions"
        )

    for emo in args.emotions:
        if emo not in EMOTION_TO_ID:
            print(f"[skip] unknown emotion: {emo}")
            continue
        _report_emotion(emo, EMOTION_TO_ID[emo], texts, trues, preds, probs, args.n)


# Script entry point, e.g. `python -m models.microalbert.report_examples ...`.
if __name__ == "__main__":
    main()
