"""Dump label-noise suspects: samples where model is confident in a different
class than the true label. Outputs a JSONL review file with scenario_id,
turn_index, text, old label, model prediction, confidences, the top-3
predicted classes, and the turn's VAD values.

Usage:
    python -m models.microalbert.report_suspects \
        --ckpt checkpoints/klue_teacher_v3_snap_llrd/best.pt \
        --split data/emotion/seed_val.jsonl \
        --out data/emotion/suspects_val.jsonl \
        --min_pred_conf 0.5 --max_true_conf 0.15
"""
from __future__ import annotations

import argparse
import json
from functools import partial
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

from .config import EMOTION_LABELS, EMOTION_TO_ID, MicroAlbertConfig
from .dataset import SeedEmotionDataset, collate_fn
from .teacher import KlueTeacherForEmotionVAD
from .tokenizer import MicroTokenizer


@torch.no_grad()
def collect(model, loader, device):
    """Run the model over every batch and return emotion probabilities.

    Args:
        model: Object with an ``eval()`` method that is callable as
            ``model(input_ids, attention_mask)`` and returns a mapping
            containing an ``"emotion_logits"`` tensor. It is switched to
            eval mode and left that way.
        loader: Iterable of batches; each batch is a mapping with
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        A NumPy array holding the softmax (over the last axis) of the emotion
        logits of every batch, concatenated along axis 0 in loader order.
        An empty loader yields an empty ``(0, 0)`` float32 array rather than
        raising from ``np.concatenate``.
    """
    model.eval()
    batch_probs = []
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        logits = model(input_ids, attention_mask)["emotion_logits"]
        # Move to CPU before converting: .numpy() only works on host tensors.
        batch_probs.append(torch.softmax(logits, dim=-1).cpu().numpy())
    if not batch_probs:
        # np.concatenate([]) raises ValueError; an empty result lets callers
        # handle "no samples" through the normal len()-based path.
        return np.zeros((0, 0), dtype=np.float32)
    return np.concatenate(batch_probs, axis=0)


def _read_sample_meta(split_path):
    """Re-read the split JSONL and return one metadata dict per non-empty turn.

    Rows are visited in file order and turns in list order; turns whose text
    is empty after stripping are skipped. This is intended to reproduce the
    sample order of SeedEmotionDataset (assumption -- the caller checks the
    resulting count against the number of model outputs). Blank lines in the
    file are ignored instead of crashing ``json.loads``.
    """
    sample_meta = []
    with open(split_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            sid = row["scenario_id"]
            for ti, turn in enumerate(row["turns"]):
                text = turn["text"].strip()
                if not text:
                    continue
                sample_meta.append({
                    "scenario_id": sid,
                    "turn_index": ti,
                    "text": text,
                    "emotion": turn["emotion"],
                    "vad": turn["vad"],
                })
    return sample_meta


def _select_suspects(sample_meta, probs, min_pred_conf, max_true_conf,
                     only_true=None):
    """Build review records for samples the model confidently disagrees with.

    A sample is kept when the argmax class differs from the annotated label,
    the argmax probability is at least ``min_pred_conf``, and the probability
    assigned to the annotated label is at most ``max_true_conf``. When
    ``only_true`` is a non-empty collection, samples whose annotated label is
    not in it are dropped. ``probs[i]`` must correspond to ``sample_meta[i]``.

    Raises:
        KeyError: if an annotated label is missing from EMOTION_TO_ID.
    """
    suspects = []
    for row_probs, meta in zip(probs, sample_meta):
        eid_true = EMOTION_TO_ID[meta["emotion"]]
        if only_true and meta["emotion"] not in only_true:
            continue
        pred = int(row_probs.argmax())
        if pred == eid_true:
            continue
        pred_conf = float(row_probs[pred])
        true_conf = float(row_probs[eid_true])
        if pred_conf < min_pred_conf or true_conf > max_true_conf:
            continue
        top3 = np.argsort(-row_probs)[:3].tolist()
        suspects.append({
            "scenario_id": meta["scenario_id"],
            "turn_index": meta["turn_index"],
            "text": meta["text"],
            "old": meta["emotion"],
            "model_pred": EMOTION_LABELS[pred],
            "pred_conf": round(pred_conf, 3),
            "true_conf": round(true_conf, 3),
            "top3": [
                {"emo": EMOTION_LABELS[j], "p": round(float(row_probs[j]), 3)}
                for j in top3
            ],
            "vad": meta["vad"],
        })
    return suspects


def main():
    """CLI entry point: score a split with a checkpoint and dump suspects.

    Loads the checkpoint given by ``--ckpt``, computes emotion probabilities
    for every sample of ``--split``, selects samples where the model is
    confident in a class other than the annotated one, writes them to
    ``--out`` as JSONL sorted by descending ``pred_conf``, and prints a
    per-label summary.

    Raises:
        RuntimeError: if the number of samples recovered from the raw file
            differs from the number of model outputs.
        SystemExit: via argparse, on invalid arguments (including unknown
            ``--only_true`` labels).
    """
    from collections import Counter

    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=Path, required=True)
    ap.add_argument("--split", type=Path, default=Path("data/emotion/seed_val.jsonl"))
    ap.add_argument("--out", type=Path, required=True)
    ap.add_argument("--min_pred_conf", type=float, default=0.5)
    ap.add_argument("--max_true_conf", type=float, default=0.15)
    ap.add_argument("--only_true", nargs="*", default=None,
                    help="Restrict to samples whose *true* label is one of these.")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--context_window", type=int, default=0,
                    help="Must match training config of the checkpoint being evaluated")
    args = ap.parse_args()

    # Validate the label filter before the expensive model load: a misspelled
    # label would otherwise silently produce an empty report.
    only_true = set(args.only_true) if args.only_true else None
    if only_true:
        unknown = sorted(lbl for lbl in only_true if lbl not in EMOTION_TO_ID)
        if unknown:
            ap.error(f"--only_true: unknown emotion label(s): {', '.join(unknown)}")

    # SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
    # objects, which can execute code. Only point --ckpt at trusted files.
    ckpt = torch.load(args.ckpt, map_location=args.device, weights_only=False)
    # Fall back to the default config when the checkpoint carries none.
    cfg = MicroAlbertConfig(**ckpt["config"]) if isinstance(ckpt.get("config"), dict) else MicroAlbertConfig()

    # NOTE(review): presumably this reuses the tokenizer saved next to the
    # checkpoint and only falls back to building from `train_jsonl` -- confirm
    # against MicroTokenizer.build.
    tok = MicroTokenizer.build(
        save_dir=args.ckpt.parent / "tokenizer",
        train_jsonl=args.split,
        max_len=cfg.max_seq_len,
        add_speaker_tokens=(args.context_window > 0),
    )
    model = KlueTeacherForEmotionVAD(
        model_name=ckpt.get("model_name", "klue/roberta-base"),
        num_emotions=cfg.num_emotions,
        vad_dim=cfg.vad_dim,
        dropout=cfg.dropout,
        vad_head_hidden=cfg.vad_head_hidden,
        attention_dropout=cfg.attention_dropout,
    ).to(args.device)
    # The embedding matrix must match the tokenizer size before the state
    # dict is loaded, otherwise the shapes would not line up.
    if len(tok) != model.backbone.config.vocab_size:
        model.backbone.resize_token_embeddings(len(tok))
    model.load_state_dict(ckpt["model_state_dict"])

    ds = SeedEmotionDataset(args.split, tok, cfg.max_seq_len, context_window=args.context_window)
    # shuffle=False keeps the output rows aligned with the dataset order.
    loader = DataLoader(
        ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=partial(collate_fn, pad_id=tok.pad_id),
    )
    probs = collect(model, loader, args.device)

    # Re-read raw file to recover scenario_id + turn_index in the exact
    # iteration order SeedEmotionDataset uses.
    sample_meta = _read_sample_meta(args.split)

    # Explicit raise instead of `assert`: the check must survive `python -O`,
    # because a misalignment would attribute predictions to the wrong turns.
    if len(sample_meta) != len(probs):
        raise RuntimeError(
            f"mismatch: meta={len(sample_meta)} probs={len(probs)}"
        )

    suspects = _select_suspects(
        sample_meta, probs, args.min_pred_conf, args.max_true_conf, only_true,
    )

    # Sort by descending pred_conf (most confident suspects first)
    suspects.sort(key=lambda x: -x["pred_conf"])

    args.out.parent.mkdir(parents=True, exist_ok=True)
    with args.out.open("w", encoding="utf-8") as f:
        for s in suspects:
            f.write(json.dumps(s, ensure_ascii=False) + "\n")

    print(f"[suspects] total={len(suspects)}  out={args.out}")
    # Summary by true class
    by_old = Counter(s["old"] for s in suspects)
    for emo, n in by_old.most_common():
        print(f"  {emo:<12s} {n}")


# Script entry point: run the CLI only when this file is executed directly
# (e.g. via `python -m ...`), not when it is imported.
if __name__ == "__main__":
    main()
