"""Generate V2 blendshapes for the curated scenarios in viewer_e2e/.

For each scenario, for each turn:
    audio → V2 ONNX (LAM + E2F int8) with GT emotion → (T, 52) AnimaSync ordering

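Turns with empty text, missing audio, or under 0.1 s of audio are skipped.
The per-turn outputs are concatenated and written to `<sid>_v2_dataset.json`,
along with an `<sid>_v2_dataset.mp3` symlink to the teacher audio. Both files
sit next to the existing `_dataset` and `_pred_dataset` entries, and the new
entries are registered in `manifest.json`. This lets the comparison viewer
A/B V2 vs V3 on the same audio.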

Usage:
    PYTHONPATH=. python3 -m models.v3_face.infer_v2_compare --all-viewer
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import List, Optional

import librosa
import numpy as np
import onnxruntime as ort

from scripts.compiler.constants import ARKIT_52_NAMES
from scripts.compiler.data_pipeline import lookup_audio_for_scenario

from .infer_e2e import find_scenario

# V2 ONNX + audio feature extractor live in repos outside this package.
sys.path.insert(0, "/dataset/text-to-face-se/LAM_Audio2Expression")
from distillation.student_model import AudioFeatureExtractor  # noqa: E402
from scripts.compiler.abc_experiment import run_v2  # noqa: E402

# Override the V2 ONNX path used in V3's data pipeline (older int8 distill) with
# the streaming variant — closer in time to the missing v8_brow09 production
# build. Pass --onnx to override at the CLI.
# NOTE(review): the default filename (`distill/emotion_face_int8.onnx`) reads
# like an int8 distill build rather than an obviously "streaming" variant —
# confirm this comment still matches the file the path points at.
DEFAULT_V2_ONNX = "/dataset/mead-expression-training/e2f/distill/emotion_face_int8.onnx"

# Repository root. The module runs as `models.v3_face.infer_v2_compare` (see the
# usage line in the module docstring), so the root is two levels above the
# directory containing this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Default directories. Each can be overridden from the CLI in main() via
# --viewer_dir / --audio_dir / --emotion_dir respectively.
DEFAULT_VIEWER_DIR = PROJECT_ROOT / "data" / "viewer_e2e"
DEFAULT_AUDIO_DIR = PROJECT_ROOT / "data" / "audio_preview"
DEFAULT_EMOTION_DIR = PROJECT_ROOT / "data" / "emotion"


# Sample rate the audio is loaded at before being fed to run_v2.
_V2_SAMPLE_RATE = 16000
# Turns with less audio than this (in seconds) are skipped.
_MIN_TURN_SECONDS = 0.1


def _infer_turns(scen: dict, sess, feat, audio_dir: Path):
    """Run V2 on each usable turn of a scenario.

    A turn is skipped when its text is empty/missing, its audio path is
    missing or does not exist, or its audio is shorter than
    ``_MIN_TURN_SECONDS``.

    Returns:
        ``(chunks, used_turns)``: one float32 blendshape array per used turn
        (as returned by ``run_v2``) and the matching turn dicts, in order.
    """
    audio_paths = lookup_audio_for_scenario(scen, audio_dir)
    chunks: List[np.ndarray] = []
    used_turns: List[dict] = []
    # zip() stops at the shorter sequence, so turns without a corresponding
    # audio path entry are dropped as well.
    for turn, ap in zip(scen["turns"], audio_paths):
        text = (turn.get("text") or "").strip()
        if not text or ap is None or not ap.exists():
            continue
        wav, _ = librosa.load(str(ap), sr=_V2_SAMPLE_RATE, mono=True)
        if len(wav) < _V2_SAMPLE_RATE * _MIN_TURN_SECONDS:
            continue
        v2_bs = run_v2(sess, feat, wav, turn.get("emotion", "neutral"))
        chunks.append(v2_bs.astype(np.float32))
        used_turns.append(turn)
    return chunks, used_turns


def _build_viewer_json(new_base: str, bs: np.ndarray, turns: List[dict]) -> dict:
    """Assemble the viewer JSON payload for the concatenated V2 track."""
    turns_meta = []
    for ti, t in enumerate(turns):
        vad = t.get("vad")
        turns_meta.append({
            # Position among the turns kept in this file, not the scenario's
            # original turn index.
            "turn_idx": ti,
            "emotion": t.get("emotion", "neutral"),
            "vad": list(vad) if vad is not None else [0, 0, 0],
            "text": t.get("text", ""),
            "speaker": t.get("speaker", ""),
        })
    return {
        "scenario_id": new_base,
        # NOTE(review): assumes run_v2 emits frames at 30 fps — confirm.
        "fps": 30,
        "num_frames": int(bs.shape[0]),
        "names": ARKIT_52_NAMES,
        "turns": turns_meta,
        "blendshapes": np.round(bs, 4).tolist(),
    }


def _link_teacher_audio(viewer_dir: Path, sid: str, new_base: str) -> None:
    """Point ``<new_base>.mp3`` at the teacher ``<sid>_dataset.mp3``.

    Any existing file or (possibly dangling) symlink at the link path is
    removed first. No link is created when the teacher mp3 is absent. The
    link target is the bare file name, i.e. it is resolved relative to
    *viewer_dir*.
    """
    teacher_mp3 = viewer_dir / f"{sid}_dataset.mp3"
    pred_mp3 = viewer_dir / f"{new_base}.mp3"
    # is_symlink() also catches dangling links, for which exists() is False.
    if pred_mp3.exists() or pred_mp3.is_symlink():
        pred_mp3.unlink()
    if teacher_mp3.exists():
        pred_mp3.symlink_to(teacher_mp3.name)


def predict_v2(sid: str, sess, feat, args) -> Optional[dict]:
    """Run V2 on every usable turn of scenario *sid* and write viewer files.

    Writes ``<viewer_dir>/<sid>_v2_dataset.json`` (UTF-8). It also refreshes
    the ``<sid>_v2_dataset.mp3`` symlink to the teacher audio.

    Args:
        sid: scenario id, resolved with ``find_scenario`` under
            ``args.emotion_dir``.
        sess: V2 ONNX inference session, passed through to ``run_v2``.
        feat: audio feature extractor, passed through to ``run_v2``.
        args: parsed CLI namespace; reads ``emotion_dir``, ``audio_dir`` and
            ``viewer_dir``.

    Returns:
        ``{"sid": sid, "new_base": "<sid>_v2_dataset"}`` on success, or
        ``None`` when the scenario is not found or has no usable turn.
    """
    scen = find_scenario(sid, args.emotion_dir)
    if scen is None:
        print(f"  ✗ {sid}: scenario not found")
        return None

    chunks, valid_turns = _infer_turns(scen, sess, feat, args.audio_dir)
    if not chunks:
        print(f"  ✗ {sid}: no valid turns")
        return None
    # NOTE(review): skipped turns are simply left out of the concatenation.
    # If the teacher mp3 linked below contains audio for them, the V2 track
    # will drift out of sync with it — confirm how `<sid>_dataset.mp3` is built.
    bs = np.concatenate(chunks, axis=0)

    new_base = f"{sid}_v2_dataset"
    viewer_json = _build_viewer_json(new_base, bs, valid_turns)
    out_json = args.viewer_dir / f"{new_base}.json"
    # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding
    # instead of relying on the platform's locale default.
    out_json.write_text(json.dumps(viewer_json, ensure_ascii=False),
                        encoding="utf-8")

    _link_teacher_audio(args.viewer_dir, sid, new_base)

    print(f"  ✓ {sid}  frames={bs.shape[0]}  turns={len(valid_turns)}")
    return {"sid": sid, "new_base": new_base}


def update_manifest(viewer_dir: Path, predictions: List[dict]):
    """Register V2 outputs in ``<viewer_dir>/manifest.json``.

    Each prediction needs ``"sid"`` and ``"new_base"`` keys, as returned by
    ``predict_v2``. Manifest entries are keyed by their ``"base"``:

    - An existing entry with the same base is replaced in place.
    - New entries are appended.
    - ``n_turns``, ``emotions`` and ``text_preview`` are copied from the
      teacher entry ``<sid>_dataset`` when present.

    The manifest is created if missing, and a manifest without a
    ``"scenarios"`` list is treated as empty. It is read and written as UTF-8,
    since it is dumped with ``ensure_ascii=False``. The write goes through a
    temporary file plus ``Path.replace``, so an interrupted run does not leave
    a truncated manifest behind.
    """
    manifest_path = viewer_dir / "manifest.json"
    if manifest_path.exists():
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    else:
        manifest = {}
    # Dict insertion order preserves the existing entry order; re-assigning
    # an existing key keeps its position.
    seen = {s.get("base"): s for s in manifest.get("scenarios", [])}
    for p in predictions:
        new_base = p["new_base"]
        teacher_entry = seen.get(f"{p['sid']}_dataset", {})
        seen[new_base] = {
            "base": new_base,
            "scenario_id": new_base,
            "variants": [],
            "n_turns": teacher_entry.get("n_turns", 1),
            "emotions": teacher_entry.get("emotions", []),
            "text_preview": "[V2] " + teacher_entry.get("text_preview", ""),
        }
    manifest["scenarios"] = list(seen.values())

    tmp_path = manifest_path.with_name(manifest_path.name + ".tmp")
    tmp_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2),
                        encoding="utf-8")
    tmp_path.replace(manifest_path)


def _viewer_scenario_ids(viewer_dir: Path) -> List[str]:
    """List scenario ids of the teacher ``<sid>_dataset.json`` files in *viewer_dir*.

    Files produced by other predictors (``_pred_dataset``, ``_e2e_dataset``,
    ``_v2_dataset``) are excluded. Ids are returned in sorted file-path order.
    """
    derived_markers = ("_pred_dataset", "_e2e_dataset", "_v2_dataset")
    suffix = "_dataset"
    return [
        json_path.stem[: -len(suffix)]
        for json_path in sorted(viewer_dir.glob("*_dataset.json"))
        if not any(marker in json_path.stem for marker in derived_markers)
    ]


def main():
    """CLI entry point: run V2 over the selected scenarios and update the manifest."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio_dir", type=Path, default=DEFAULT_AUDIO_DIR)
    parser.add_argument("--viewer_dir", type=Path, default=DEFAULT_VIEWER_DIR)
    parser.add_argument("--emotion_dir", type=Path, default=DEFAULT_EMOTION_DIR)
    parser.add_argument("-s", "--scenarios", nargs="+", default=None)
    parser.add_argument("--all-viewer", action="store_true")
    parser.add_argument("--onnx", type=str, default=DEFAULT_V2_ONNX,
                        help=f"V2 ONNX path. Default: {DEFAULT_V2_ONNX}")
    args = parser.parse_args()

    # One of the two scenario selectors is required.
    if not (args.scenarios or args.all_viewer):
        parser.error("provide --scenarios or --all-viewer")

    # --all-viewer takes precedence over an explicit --scenarios list.
    if args.all_viewer:
        sids = _viewer_scenario_ids(args.viewer_dir)
        print(f"all-viewer: {len(sids)} scenarios")
    else:
        sids = args.scenarios

    print(f"Loading V2 ONNX: {args.onnx}")
    sess = ort.InferenceSession(args.onnx, providers=["CPUExecutionProvider"])
    feat = AudioFeatureExtractor()
    print("V2 ready ✓\n")

    # Lazily evaluated, so per-scenario progress output still interleaves in order.
    outcomes = (predict_v2(sid, sess, feat, args) for sid in sids)
    predictions = [outcome for outcome in outcomes if outcome]

    if predictions:
        update_manifest(args.viewer_dir, predictions)
    print(f"\nDone. {len(predictions)} V2 outputs written to {args.viewer_dir}")


if __name__ == "__main__":
    main()
