"""V3 end-to-end pipeline test:

    text  →  KlueTeacher  →  (emotion argmax, VAD)  →  cond
    audio →  mel
                                            ↓
                          V3 face model + locked post-processing
                                            ↓
                            (T, 52) ARKit blendshapes
                                            ↓
                         <sid>_e2e_dataset.json (viewer)

Writes alongside the existing `_dataset` (teacher GT) and `_pred_dataset`
(V3 with GT cond) entries so the viewer dropdown shows all three for
each scenario — lets us A/B which step degrades.

Usage:
    PYTHONPATH=. python3 -m models.v3_face.infer_e2e \
        --ckpt models/v3_face/checkpoints/best_expression.pt \
        -s long_001 long_100 long_046

    PYTHONPATH=. python3 -m models.v3_face.infer_e2e --all-viewer
"""
from __future__ import annotations

import argparse
import json
import re
from pathlib import Path
from typing import Dict, List, Optional

import librosa
import numpy as np
import torch
from scipy.ndimage import gaussian_filter1d
from transformers import AutoTokenizer

from scripts.compiler.constants import ARKIT_52_NAMES
from scripts.compiler.data_pipeline import (
    EMOTION_LABELS, EMOTION_TO_IDX, FPS, mel_features,
    lookup_audio_for_scenario,
)
from models.microalbert.teacher import KlueTeacherForEmotionVAD

from .infer import (
    crisp_mouth, smooth_brows, inject_blinks, load_model,
)

# Third ancestor of this file's resolved path. With the module run as
# `models.v3_face.infer_e2e` (see the usage in the module docstring) that is
# the repository root.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Default KLUE teacher checkpoint and its (sibling) tokenizer directory.
DEFAULT_TEACHER_CKPT = (PROJECT_ROOT / "checkpoints" /
                        "klue_teacher_clean_ctx2" / "best.pt")
DEFAULT_TOKENIZER_DIR = (PROJECT_ROOT / "checkpoints" /
                         "klue_teacher_clean_ctx2" / "tokenizer")
# Default locations for per-turn audio, viewer output, and the scenario
# JSONL splits searched by find_scenario().
DEFAULT_AUDIO_DIR = PROJECT_ROOT / "data" / "audio_preview"
DEFAULT_VIEWER_DIR = PROJECT_ROOT / "data" / "viewer_e2e"
DEFAULT_EMOTION_DIR = PROJECT_ROOT / "data" / "emotion"
# Default V3 face-model checkpoint.
DEFAULT_CKPT = (PROJECT_ROOT / "models" / "v3_face" /
                "checkpoints" / "best_expression.pt")

# Matches per-turn split ids of the form `daily_<...>_t<digits>`.
# Group 1 is the source scenario id, group 2 the turn index within it.
SPLIT_RE = re.compile(r"^(daily_.+)_t(\d+)$")


def load_teacher(ckpt_path: Path, tokenizer_dir: Path, device: torch.device):
    """Build the KLUE-RoBERTa emotion+VAD teacher and a matching tokenizer.

    The tokenizer is rebuilt at runtime from the base model name stored in
    the checkpoint config (default ``klue/roberta-base``) with ``[SELF]`` and
    ``[OTHER]`` registered as additional special tokens. The backbone's
    embedding table is resized to the tokenizer length *before* the saved
    weights are loaded, so the checkpoint's embedding shape lines up.

    Args:
        ckpt_path: Teacher checkpoint file passed to ``torch.load``.
        tokenizer_dir: Accepted for interface compatibility; this function
            does not read it (the tokenizer is reconstructed, not loaded
            from disk).
        device: Device the checkpoint is mapped to and the model moved to.

    Returns:
        ``(teacher, tokenizer)`` with the teacher in eval mode on ``device``.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only point this
    # at checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    is_mapping = isinstance(checkpoint, dict)
    config = checkpoint.get("config", {}) if is_mapping else {}
    backbone_name = config.get("model_name", "klue/roberta-base")

    # The saved tokenizer dir is intentionally empty for this checkpoint set:
    # training adds [SELF]/[OTHER] to klue/roberta-base at runtime
    # (vocab 32000 → 32002). See models/microalbert/tokenizer.py:98.
    tokenizer = AutoTokenizer.from_pretrained(backbone_name)
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["[SELF]", "[OTHER]"]}
    )

    teacher = KlueTeacherForEmotionVAD(
        model_name=backbone_name,
        num_emotions=config.get("num_emotions", 16),
        vad_dim=config.get("vad_dim", 3),
    )
    # Grow the embedding matrix to the training-time vocabulary if needed.
    vocab_size = len(tokenizer)
    if teacher.backbone.config.vocab_size != vocab_size:
        teacher.backbone.resize_token_embeddings(vocab_size)
    teacher = teacher.to(device).eval()

    # Pick the weights out of whichever container layout the file uses.
    if not is_mapping:
        raw_state = checkpoint
    elif "model" in checkpoint:
        raw_state = checkpoint["model"]
    elif "model_state_dict" in checkpoint:
        raw_state = checkpoint["model_state_dict"]
    else:
        raw_state = checkpoint

    # Drop every "module." fragment and the first "model." fragment from
    # each key so wrapped checkpoints map onto the bare teacher.
    cleaned_state = {}
    for key, tensor in raw_state.items():
        new_key = key.replace("module.", "")
        new_key = new_key.replace("model.", "", 1)
        cleaned_state[new_key] = tensor

    missing, unexpected = teacher.load_state_dict(cleaned_state, strict=False)
    if missing or unexpected:
        print(f"[teacher] state_dict load: {len(missing)} missing, "
              f"{len(unexpected)} unexpected (first 3 missing: {missing[:3]})")
    print(f"[teacher] loaded {teacher.num_params()/1e6:.1f}M params from {ckpt_path} "
          f"(vocab={len(tokenizer)})")
    return teacher, tokenizer


def _iter_jsonl_rows(path: Path):
    """Yield one parsed JSON object per non-blank line of a UTF-8 JSONL file."""
    # Explicit UTF-8: the scenario text is Korean and must not depend on the
    # platform's default locale encoding.
    with path.open(encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)


def find_scenario(sid: str, emotion_dir: Path) -> Optional[dict]:
    """Locate a scenario across the three split JSONL files.

    Files are searched in order (train, val, test); missing files are
    skipped. Within each file an exact ``scenario_id`` match is preferred.
    If none is found and ``sid`` has the ``daily_*_t<i>`` split form, the
    same file is searched for the source scenario and a single-turn scenario
    is synthesized, carrying ``_source_scenario_id`` / ``_source_turn_indices``
    so audio lookup can reroute to the source scenario.

    Args:
        sid: Scenario id, or a per-turn split id ``<source>_t<i>``.
        emotion_dir: Directory containing the ``seed_*.jsonl`` files.

    Returns:
        The parsed scenario dict (or the synthesized single-turn dict), or
        ``None`` when nothing usable is found. A split id whose turn index
        is past the end of the source scenario is reported and treated as
        not found in that file instead of raising ``IndexError``.
    """
    for fname in ("seed_train_final.jsonl", "seed_val.jsonl", "seed_test.jsonl"):
        path = emotion_dir / fname
        if not path.exists():
            continue

        # Pass 1: exact id match (streams and returns early).
        for row in _iter_jsonl_rows(path):
            if row["scenario_id"] == sid:
                return row

        # Pass 2: only for split ids — look up the source scenario.
        m = SPLIT_RE.match(sid)
        if m is None:
            continue
        source_id, turn_idx = m.group(1), int(m.group(2))
        for row in _iter_jsonl_rows(path):
            if row["scenario_id"] != source_id:
                continue
            turns = row.get("turns", [])
            if turn_idx >= len(turns):
                print(f"  ! {sid}: turn index {turn_idx} out of range for "
                      f"{source_id} ({len(turns)} turns) in {fname}")
                break
            return {
                "scenario_id": sid,
                "_source_scenario_id": source_id,
                "_source_turn_indices": [turn_idx],
                "turns": [turns[turn_idx]],
            }
    return None


@torch.no_grad()
def teacher_predict(teacher, tokenizer, texts: List[str], device, max_length=128):
    """Classify a batch of turn texts with the KLUE teacher.

    The texts are tokenized together (padded, truncated to ``max_length``),
    moved to ``device`` and passed through ``teacher`` in a single forward
    call with gradients disabled.

    Returns:
        A 3-tuple ``(labels, probs, vads)``: the argmax emotion label per
        text (looked up in ``EMOTION_LABELS``), the softmax of the emotion
        logits as a NumPy array, and the ``"vad"`` head output as a NumPy
        array.
    """
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    batch = batch.to(device)

    outputs = teacher(batch["input_ids"], batch["attention_mask"])
    emotion_logits = outputs["emotion_logits"]

    probabilities = torch.softmax(emotion_logits, dim=-1).cpu().numpy()
    top_indices = emotion_logits.argmax(dim=-1).cpu().numpy()
    labels = [EMOTION_LABELS[idx] for idx in top_indices]
    vad_values = outputs["vad"].cpu().numpy()
    return labels, probabilities, vad_values


def build_cond_smoothed(per_turn_emotions: List[str],
                        per_turn_vads: np.ndarray,
                        turn_Ts: List[int],
                        vad_smooth_sigma: float = 30.0) -> np.ndarray:
    """Expand per-turn teacher predictions into a per-frame (T, 19) cond.

    Columns 0-15 hold a hard one-hot emotion (unknown labels fall back to
    index 0) and columns 16-18 hold the turn's VAD; each turn's values are
    repeated for that turn's frame count. When there is more than one turn
    and ``vad_smooth_sigma`` is positive, every column is Gaussian-smoothed
    along the time axis (edge mode ``"nearest"``) so the signal blends
    across turn boundaries. Per the data pipeline this mirrors, this is the
    cond shape V3 was trained on — only the labels come from the teacher
    instead of GT.

    Returns:
        float32 array of shape ``(sum(turn_Ts), 19)``.
    """
    n_emotions, n_vad = 16, 3
    total_frames = sum(turn_Ts)
    cond = np.zeros((total_frames, n_emotions + n_vad), dtype=np.float32)

    start = 0
    for label, vad, n_frames in zip(per_turn_emotions, per_turn_vads, turn_Ts):
        stop = start + n_frames
        cond[start:stop, EMOTION_TO_IDX.get(label, 0)] = 1.0
        cond[start:stop, n_emotions:] = np.asarray(vad, dtype=np.float32)
        start = stop

    needs_smoothing = vad_smooth_sigma > 0 and len(turn_Ts) > 1
    if needs_smoothing:
        # Filtering runs along axis 0 only, so each column is smoothed
        # independently — one call covers both the emotion and VAD columns.
        cond = gaussian_filter1d(cond, sigma=vad_smooth_sigma,
                                 axis=0, mode="nearest").astype(np.float32)
    return cond


def predict_and_save(sid: str, model, cfg, teacher, tokenizer, device,
                     args) -> Optional[dict]:
    """Run the end-to-end pipeline for one scenario and write viewer files.

    Steps: look up the scenario, load each turn's audio and compute mel
    features, predict per-turn emotion/VAD from text with the teacher, build
    the smoothed per-frame cond, run the V3 model, apply the enabled
    post-processing passes, then write ``<sid>_e2e_dataset.json`` into
    ``args.viewer_dir`` and link ``<sid>_e2e_dataset.mp3`` to the existing
    ``<sid>_dataset.mp3`` when that file is present.

    Turns with empty text, missing audio, or audio shorter than 0.1 s are
    skipped.

    Args:
        sid: Scenario id (or per-turn split id).
        model: V3 face model, called as ``model(audio, cond)``.
        cfg: Unused here; kept so the call signature stays stable.
        teacher, tokenizer: Outputs of ``load_teacher``.
        device: Device for the model inputs.
        args: Parsed CLI namespace (directories and post-processing knobs).

    Returns:
        ``{"sid", "new_base", "matches", "total"}`` on success, where
        ``matches`` counts turns whose teacher argmax equals the turn's
        ``"emotion"`` label; ``None`` if the scenario is missing or has no
        usable turns.
    """
    scen = find_scenario(sid, args.emotion_dir)
    if scen is None:
        print(f"  ✗ {sid}: scenario not found")
        return None

    audio_paths = lookup_audio_for_scenario(scen, args.audio_dir)

    turn_mels: List[np.ndarray] = []
    turn_Ts: List[int] = []
    valid_turns: List[dict] = []
    for turn, ap in zip(scen["turns"], audio_paths):
        text = turn.get("text", "").strip()
        if not text or ap is None or not ap.exists():
            continue
        wav, sr = librosa.load(str(ap), sr=16000, mono=True)
        # Drop clips shorter than 0.1 s at the 16 kHz load rate.
        if len(wav) < 16000 * 0.1:
            continue
        mel = mel_features(wav, sr=sr, fps=FPS)
        turn_mels.append(mel)
        turn_Ts.append(mel.shape[0])
        valid_turns.append(turn)
    if not valid_turns:
        print(f"  ✗ {sid}: no valid turns")
        return None

    texts = [t["text"] for t in valid_turns]
    pred_emos, pred_probs, pred_vads = teacher_predict(
        teacher, tokenizer, texts, device
    )
    cond = build_cond_smoothed(pred_emos, pred_vads, turn_Ts,
                               vad_smooth_sigma=args.vad_smooth_sigma)
    audio = np.concatenate(turn_mels, axis=0).astype(np.float32)

    audio_t = torch.from_numpy(audio).unsqueeze(0).to(device)
    cond_t = torch.from_numpy(cond).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(audio_t, cond_t).squeeze(0).cpu().numpy().astype(np.float32)

    if args.crisp:
        pred = crisp_mouth(pred, threshold=args.crisp_threshold,
                           scale=args.crisp_scale,
                           pre_smooth_sigma=args.crisp_sigma,
                           mouth_close_sigma=args.crisp_mouthclose_sigma)
    if args.brow_smooth:
        pred = smooth_brows(pred, min_cutoff=args.brow_min_cutoff,
                            beta=args.brow_beta, d_cutoff=args.brow_d_cutoff)
    if args.add_blinks:
        pred = inject_blinks(pred, scenario_id=sid,
                             mean_interval_s=args.blink_interval,
                             expressive_cap=args.blink_expressive_cap)

    # Per-turn meta — record both teacher prediction and GT so the viewer
    # (or any postmortem script) can show the disagreement.
    turns_meta = []
    for ti, (t, emo, vad, probs) in enumerate(
            zip(valid_turns, pred_emos, pred_vads, pred_probs)):
        top3 = [
            {"label": EMOTION_LABELS[i], "prob": float(probs[i])}
            for i in np.argsort(probs)[-3:][::-1]
        ]
        turns_meta.append({
            "turn_idx": ti,
            "emotion": emo,                         # teacher's argmax
            "gt_emotion": t.get("emotion", "neutral"),
            "vad": vad.tolist(),                    # teacher's VAD head
            "gt_vad": list(t.get("vad", [0, 0, 0])),
            "text": t.get("text", ""),
            "speaker": t.get("speaker", ""),
            "top3_emotions": top3,
        })

    new_base = f"{sid}_e2e_dataset"
    viewer_json = {
        "scenario_id": new_base,
        # NOTE(review): hard-coded while mel features use FPS — confirm the
        # two agree.
        "fps": 30,
        "num_frames": int(pred.shape[0]),
        "names": ARKIT_52_NAMES,
        "turns": turns_meta,
        "blendshapes": np.round(pred, 4).tolist(),
    }
    # The output directory may not exist yet when running with -s on a
    # fresh checkout.
    args.viewer_dir.mkdir(parents=True, exist_ok=True)
    out_json = args.viewer_dir / f"{new_base}.json"
    # ensure_ascii=False emits raw Korean text, so the file encoding must be
    # pinned rather than left to the platform locale.
    out_json.write_text(json.dumps(viewer_json, ensure_ascii=False),
                        encoding="utf-8")

    teacher_mp3 = args.viewer_dir / f"{sid}_dataset.mp3"
    pred_mp3 = args.viewer_dir / f"{new_base}.mp3"
    # Replace any stale link/file; is_symlink() also catches dangling links,
    # for which exists() is False.
    if pred_mp3.exists() or pred_mp3.is_symlink():
        pred_mp3.unlink()
    if teacher_mp3.exists():
        # Relative link so the viewer directory stays relocatable.
        pred_mp3.symlink_to(teacher_mp3.name)
    else:
        print(f"  ! {sid}: teacher mp3 missing; e2e will play silent.")

    matches = sum(1 for t, emo in zip(valid_turns, pred_emos)
                  if t.get("emotion", "neutral") == emo)
    print(f"  ✓ {sid}  emotion match {matches}/{len(valid_turns)}  "
          f"frames={pred.shape[0]}")
    return {"sid": sid, "new_base": new_base,
            "matches": matches, "total": len(valid_turns)}


def update_manifest(viewer_dir: Path, predictions: List[dict]):
    """Add or refresh the ``_e2e_dataset`` entries in the viewer manifest.

    Reads ``viewer_dir/manifest.json`` if present (otherwise starts empty),
    upserts one entry per prediction keyed by its ``new_base``, and writes
    the manifest back. Each new entry copies ``n_turns``, ``emotions`` and
    ``text_preview`` from the matching ``<sid>_dataset`` entry when one
    exists. Existing entries and any other top-level manifest keys are
    preserved.

    Args:
        viewer_dir: Directory holding ``manifest.json``.
        predictions: Dicts with at least ``"sid"`` and ``"new_base"``.
    """
    manifest_path = viewer_dir / "manifest.json"
    # The manifest holds raw (non-ASCII-escaped) text, so read and write it
    # as UTF-8 regardless of the platform locale.
    if manifest_path.exists():
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    else:
        manifest = {}

    # Tolerate a manifest that has no "scenarios" list yet.
    by_base = {entry.get("base"): entry
               for entry in manifest.get("scenarios", [])}
    for pred in predictions:
        new_base = pred["new_base"]
        teacher_entry = by_base.get(f"{pred['sid']}_dataset", {})
        by_base[new_base] = {
            "base": new_base,
            "scenario_id": new_base,
            "variants": [],
            "n_turns": teacher_entry.get("n_turns", 1),
            "emotions": teacher_entry.get("emotions", []),
            "text_preview": "[V3 e2e] " + teacher_entry.get("text_preview", ""),
        }
    manifest["scenarios"] = list(by_base.values())
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )


def main():
    """CLI entry point: run the E2E pipeline over the requested scenarios.

    Scenario ids come from ``-s/--scenarios`` or, with ``--all-viewer``, from
    every ``<sid>_dataset.json`` in ``--viewer_dir`` (derived
    ``_pred_dataset`` / ``_e2e_dataset`` files are excluded). Loads the V3
    model and the teacher once, processes each scenario, updates the viewer
    manifest when at least one scenario succeeded, and prints the aggregate
    teacher-vs-label match rate.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=Path, default=DEFAULT_CKPT)
    ap.add_argument("--teacher_ckpt", type=Path, default=DEFAULT_TEACHER_CKPT)
    ap.add_argument("--tokenizer_dir", type=Path, default=DEFAULT_TOKENIZER_DIR)
    ap.add_argument("--audio_dir", type=Path, default=DEFAULT_AUDIO_DIR)
    ap.add_argument("--viewer_dir", type=Path, default=DEFAULT_VIEWER_DIR)
    ap.add_argument("--emotion_dir", type=Path, default=DEFAULT_EMOTION_DIR)
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument("-s", "--scenarios", nargs="+", default=None)
    ap.add_argument("--all-viewer", action="store_true",
                    help="Run E2E on every scenario currently curated in "
                         "--viewer_dir. Preserves existing _dataset and "
                         "_pred_dataset entries; only adds _e2e_dataset.")
    ap.add_argument("--vad-smooth-sigma", type=float, default=30.0,
                    help="Cross-turn Gaussian σ. Default 30 (matches data_pipeline).")
    # Post-processing — same locked-in defaults as the user's preferred config.
    ap.add_argument("--crisp", action=argparse.BooleanOptionalAction, default=True)
    ap.add_argument("--crisp-threshold", type=float, default=0.3)
    ap.add_argument("--crisp-scale", type=float, default=1.0)
    ap.add_argument("--crisp-sigma", type=float, default=1.3)
    ap.add_argument("--crisp-mouthclose-sigma", type=float, default=1.0)
    ap.add_argument("--brow-smooth", action=argparse.BooleanOptionalAction, default=True)
    ap.add_argument("--brow-min-cutoff", type=float, default=2.0)
    ap.add_argument("--brow-beta", type=float, default=0.01)
    ap.add_argument("--brow-d-cutoff", type=float, default=1.0)
    ap.add_argument("--add-blinks", action=argparse.BooleanOptionalAction, default=True)
    ap.add_argument("--blink-interval", type=float, default=6.0)
    ap.add_argument("--blink-expressive-cap", type=float, default=0.5)
    args = ap.parse_args()

    if not args.scenarios and not args.all_viewer:
        ap.error("provide --scenarios or --all-viewer")

    if args.all_viewer:
        # Every match ends in "_dataset"; drop the derived variants by
        # suffix (not substring) so only the base entries remain.
        derived_suffixes = ("_pred_dataset", "_e2e_dataset")
        sids = [
            p.stem.removesuffix("_dataset")
            for p in sorted(args.viewer_dir.glob("*_dataset.json"))
            if not p.stem.endswith(derived_suffixes)
        ]
        print(f"all-viewer: {len(sids)} scenarios curated")
    else:
        sids = args.scenarios

    device = torch.device(args.device)
    # The default is cuda:0; fall back rather than crash on CPU-only hosts.
    if device.type == "cuda" and not torch.cuda.is_available():
        print(f"  ! CUDA unavailable; using CPU instead of {args.device}")
        device = torch.device("cpu")

    print(f"Loading V3 from {args.ckpt}")
    model, cfg = load_model(args.ckpt, device)
    print(f"Loading KlueTeacher from {args.teacher_ckpt}")
    teacher, tokenizer = load_teacher(args.teacher_ckpt, args.tokenizer_dir, device)

    print("\nPipeline: text → KlueTeacher → V3 → blendshapes")
    print(f"Predicting {len(sids)} scenarios → {args.viewer_dir}\n")
    predictions = []
    total_m = 0
    total_t = 0
    for sid in sids:
        r = predict_and_save(sid, model, cfg, teacher, tokenizer, device, args)
        if r:
            predictions.append(r)
            total_m += r["matches"]
            total_t += r["total"]

    if predictions:
        update_manifest(args.viewer_dir, predictions)
        print("\n── teacher argmax accuracy ──")
        print(f"  matches GT label: {total_m}/{total_t}  "
              f"({100 * total_m / max(total_t, 1):.1f}%)")
    print("\nDone. Each scenario now appears up to 3× in viewer dropdown:")
    print("  <sid>_dataset       teacher GT")
    print("  <sid>_pred_dataset  V3 with GT cond (argmax of teacher's training labels)")
    print("  <sid>_e2e_dataset   V3 with KlueTeacher-predicted cond")


# Run the CLI only when executed as a script / via `python -m`, not on import.
if __name__ == "__main__":
    main()
