#!/usr/bin/env python3
"""Convert dataset .npz files (from data_pipeline.py output) into viewer
format so they can be loaded in tools/blendshape-player.html for visual QA.

NEW DEFAULT (2026-05): output goes to `data/viewer_dataset/` (the FULL
archive of every baked .npz). The active player folder `data/viewer/` is
populated separately by `scripts/curate_viewer.py`, which copies a curated
subset (all long_, a sample of solo_/daily-split) so the player dropdown
stays uncluttered.

For each .npz, builds:
  - <viewer_dir>/<sid>_dataset.json  — one viewer JSON per scenario (no A/B/C
    copies: the dataset has no variant split, and any variant motion is
    already baked into the .npz target trajectory)
  - <viewer_dir>/<sid>_dataset.mp3   — concatenated turn audio (written only
    when at least one per-turn audio file is found)
  - manifest.json entry under base "<sid>_dataset"

NOTE: data_pipeline.py bakes brow + eyeSquint tremor into the .npz targets
(and variant B also bakes V2 prosodic motion). When previewing in
blendshape-player.html, all runtime overlays auto-disable for `_dataset_`
scenarios — what you see IS the .npz.

Usage:
    PYTHONPATH=. python3 scripts/dataset_to_viewer.py             # convert all in v3_training/
    PYTHONPATH=. python3 scripts/dataset_to_viewer.py -s long_001 long_046    # specific
"""
from __future__ import annotations

import argparse
import json
import re
import subprocess
from pathlib import Path

import numpy as np

# Pattern: per-turn dialogue split ids. Matches "daily_NNN_tK" and also
# "daily_NNN_pM_tK" (paraphrased dialogues split per turn). Group 1 captures
# the parent scenario_id and group 2 the turn index K, so the original
# scenario can be looked up. Because `.+` is greedy, only the final
# `_t<digits>` suffix is split off.
_SPLIT_RE = re.compile(r"^(daily_.+)_t(\d+)$")

# Directory one level above the directory containing this file; used as the
# base for all default CLI paths.
PROJECT_ROOT = Path(__file__).resolve().parents[1]


# Files at or below this size are treated as empty/failed audio renders.
_MIN_AUDIO_BYTES = 1000


def _is_usable_audio(path: Path) -> bool:
    """Return True if *path* can be stat'ed and is larger than the size floor.

    Any OSError (missing file, dangling symlink, file removed between a glob
    and the stat call) is treated as "not usable" instead of propagating.
    """
    try:
        return path.stat().st_size > _MIN_AUDIO_BYTES
    except OSError:
        return False


def find_audio_for_turns(scen: dict, audio_dir: Path) -> list[Path]:
    """Locate the audio file for every spoken turn of a scenario.

    Turns whose text is missing, None, or whitespace-only are skipped. For
    each remaining turn the file ``<sid>_t<turn>_<emotion>.mp3`` is preferred;
    if it is absent or too small, the alphabetically first usable
    ``<sid>_t<turn>_*.mp3`` is used instead. Turns with no usable file are
    silently omitted, so the result may be shorter than the number of spoken
    turns.

    Args:
        scen: Scenario dict with ``scenario_id`` and ``turns``; per-turn
            dialogue splits additionally carry ``source_scenario_id`` and
            ``source_turn_indices``.
        audio_dir: Directory holding the per-turn ``.mp3`` files.

    Returns:
        Audio paths in turn order.
    """
    # For per-turn dialogue splits, the audio is named after the original
    # (sid, turn_idx) — reroute lookup so the new scenario_id `{sid}_t{ti}`
    # doesn't try to find a non-existent `{sid}_t{ti}_t0_emotion.mp3` file.
    #
    # CAUTION: paraphrase scenarios (long_001_p0, long_001_p1) also carry a
    # `source_scenario_id` field pointing to the original (long_001), but that
    # field on paraphrases marks LINEAGE only — their audio files are named
    # after the paraphrase id (long_001_p0_t0_*.mp3). We only reroute the
    # audio lookup when `source_turn_indices` is ALSO present, which is the
    # actual split-dialogue marker (paraphrases don't have it).
    src_tis = scen.get("source_turn_indices")
    sid = scen["source_scenario_id"] if src_tis is not None else scen["scenario_id"]

    found: list[Path] = []
    for local_ti, turn in enumerate(scen["turns"]):
        # `or ""` also covers an explicit `"text": null` in the JSONL.
        if not (turn.get("text") or "").strip():
            continue
        actual_ti = src_tis[local_ti] if src_tis is not None else local_ti
        emo = turn.get("emotion", "neutral")

        exact = audio_dir / f"{sid}_t{actual_ti}_{emo}.mp3"
        if _is_usable_audio(exact):
            found.append(exact)
            continue

        # Path.glob yields entries in arbitrary order; sort so the fallback
        # choice is reproducible across runs and machines.
        candidates = sorted(
            m for m in audio_dir.glob(f"{sid}_t{actual_ti}_*.mp3")
            if _is_usable_audio(m)
        )
        if candidates:
            found.append(candidates[0])
    return found


def concat_audio(paths: list[Path], out_mp3: Path) -> None:
    """Concatenate audio files into a single MP3 using ffmpeg's concat demuxer.

    A temporary list file named like *out_mp3* but with a ``.concat.txt``
    suffix is written next to the output and removed afterwards, whether or
    not ffmpeg succeeds.

    Args:
        paths: Input audio files, in playback order.
        out_mp3: Destination file; overwritten if it exists (``-y``).

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with a non-zero status.
        OSError: the list file could not be written or ffmpeg is not
            installed.
    """
    list_file = out_mp3.with_suffix(".concat.txt")
    try:
        # ffmpeg reads the list as bytes; write UTF-8 regardless of locale.
        with list_file.open("w", encoding="utf-8") as f:
            for p in paths:
                # Inside a single-quoted concat entry a literal quote must be
                # written as '\'' (close quote, escaped quote, reopen quote).
                escaped = str(p.resolve()).replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        subprocess.run([
            "ffmpeg", "-y", "-loglevel", "error",
            "-f", "concat", "-safe", "0", "-i", str(list_file),
            "-c:a", "libmp3lame", "-b:a", "128k", str(out_mp3),
        ], check=True)
    finally:
        # Don't leave the helper file behind when ffmpeg fails.
        list_file.unlink(missing_ok=True)


def _load_scenario_lookup(jsonl_dir: Path) -> dict[str, dict]:
    """Index scenarios from the known split files in *jsonl_dir* by scenario_id.

    Reads the train/val/test splits plus the dialogue-split sidecar (written
    by data_pipeline.py --split-dialogues). Missing files and blank lines are
    skipped; when an id appears more than once, the last occurrence wins.
    """
    lookup: dict[str, dict] = {}
    for jsonl_name in ("seed_train_final.jsonl", "seed_val.jsonl",
                       "seed_test.jsonl", "seed_split_dialogues.jsonl"):
        p = jsonl_dir / jsonl_name
        if not p.exists():
            continue
        with p.open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                s = json.loads(line)
                lookup[s["scenario_id"]] = s
    return lookup


def _resolve_scenario(sid: str, scenario_lookup: dict[str, dict]) -> dict | None:
    """Resolve a .npz stem to its scenario metadata.

    Direct hit in the loaded JSONLs first; on miss, try to synthesize a
    per-turn dialogue split by stripping a `_tK` suffix and looking up the
    parent.

    Robust to the sidecar `seed_split_dialogues.jsonl` being missing or only
    containing entries from the most recent --split-dialogues run.

    Returns None when neither the id nor a usable parent turn can be found.
    """
    if sid in scenario_lookup:
        return scenario_lookup[sid]
    m = _SPLIT_RE.match(sid)
    if not m:
        return None
    src_sid, src_ti = m.group(1), int(m.group(2))
    parent = scenario_lookup.get(src_sid)
    if parent is None:
        return None
    try:
        turn = parent["turns"][src_ti]
    except (KeyError, IndexError):
        return None
    return {
        "scenario_id": sid,
        "source_scenario_id": src_sid,
        "source_turn_indices": [src_ti],
        "turns": [turn],
    }


def _spoken_turns_meta(scen: dict) -> list[dict]:
    """Build viewer metadata for every turn that has non-blank text."""
    return [
        {
            "turn_idx": ti,
            "emotion": t.get("emotion", "neutral"),
            "vad": t.get("vad", [0, 0, 0]),
            "text": t.get("text", ""),
            "speaker": t.get("speaker", ""),
        }
        for ti, t in enumerate(scen["turns"])
        # `or ""` also covers an explicit `"text": null`.
        if (t.get("text") or "").strip()
    ]


def main():
    """CLI entry point: convert dataset .npz files into viewer JSON/MP3 files.

    For each selected scenario this writes `<sid><suffix>.json` (and
    `<sid><suffix>.mp3` when turn audio is found) into the viewer directory,
    then rewrites `manifest.json` there, replacing any existing entry that has
    the same `base`. Scenarios without a .npz or without resolvable metadata
    are reported and skipped.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--npz_dir", type=Path,
                    default=PROJECT_ROOT / "data/v3_training")
    ap.add_argument("--scenarios_jsonl", type=Path,
                    default=PROJECT_ROOT / "data/emotion/seed_train_final.jsonl")
    ap.add_argument("--audio_dir", type=Path,
                    default=PROJECT_ROOT / "data/audio_preview")
    ap.add_argument("--viewer_dir", type=Path,
                    default=PROJECT_ROOT / "data/viewer_dataset",
                    help="Output dir for the full archive. Curate into "
                         "data/viewer/ via scripts/curate_viewer.py.")
    ap.add_argument("--suffix", default="_dataset")
    ap.add_argument("-s", "--scenarios", nargs="*", default=None,
                    help="Specific scenario IDs to convert. Default: all .npz in dir.")
    args = ap.parse_args()

    from scripts.compiler.constants import ARKIT_52_NAMES

    # Only the directory of --scenarios_jsonl is used; the split file names
    # themselves are fixed.
    scenario_lookup = _load_scenario_lookup(args.scenarios_jsonl.parent)

    # Determine scenarios to process
    if args.scenarios:
        sids = args.scenarios
    else:
        sids = sorted(p.stem for p in args.npz_dir.glob("*.npz"))
    print(f"converting {len(sids)} scenarios")

    # Ensure output dir exists (new default `data/viewer_dataset/` won't on
    # first run).
    args.viewer_dir.mkdir(parents=True, exist_ok=True)

    # Load manifest
    manifest_fp = args.viewer_dir / "manifest.json"
    if manifest_fp.exists():
        manifest = json.loads(manifest_fp.read_text(encoding="utf-8"))
    else:
        manifest = {"scenarios": []}

    # Entries produced by this run, keyed by base. They are merged into the
    # manifest once at the end instead of re-filtering the whole list for
    # every scenario.
    new_entries: dict[str, dict] = {}
    converted = 0
    for sid in sids:
        npz_path = args.npz_dir / f"{sid}.npz"
        if not npz_path.exists():
            print(f"  ✗ {sid}: no .npz found")
            continue
        scen = _resolve_scenario(sid, scenario_lookup)
        if scen is None:
            print(f"  ✗ {sid}: not in scenario JSONL")
            continue
        # Close the archive right after reading so a full-archive run doesn't
        # hold one open file handle per scenario.
        with np.load(npz_path) as data:
            target = data["target"]   # (T, 52)
        T = target.shape[0]
        new_base = sid + args.suffix

        # Concat per-turn audio (only turns with text — same logic as data_pipeline).
        # One mp3 per scenario, named simply <new_base>.mp3 (no A/B/C copies —
        # the V2 variant is baked into the .npz; the viewer doesn't need to
        # compare variants side-by-side anymore).
        audio_paths = find_audio_for_turns(scen, args.audio_dir)
        out_mp3 = args.viewer_dir / f"{new_base}.mp3"
        if audio_paths:
            concat_audio(audio_paths, out_mp3)

        # Build viewer JSON
        turns_meta = _spoken_turns_meta(scen)
        viewer_json = {
            "scenario_id": new_base,
            "fps": 30,
            "num_frames": int(T),
            "names": ARKIT_52_NAMES,
            "turns": turns_meta,
            "blendshapes": np.round(target, 4).tolist(),
        }
        # One JSON per scenario — no A/B/C duplication. ensure_ascii=False
        # emits raw non-ASCII text, so the encoding must be explicit.
        json_out = args.viewer_dir / f"{new_base}.json"
        json_out.write_text(json.dumps(viewer_json, ensure_ascii=False),
                            encoding="utf-8")

        # Manifest entry. pop-then-insert keeps a re-converted base at the
        # end, matching remove-then-append ordering.
        new_entries.pop(new_base, None)
        new_entries[new_base] = {
            "base": new_base,
            "scenario_id": new_base,
            "variants": [],
            "n_turns": len(turns_meta),
            "emotions": [t["emotion"] for t in turns_meta],
            "text_preview": (turns_meta[0].get("text", "") if turns_meta else "")[:50],
        }
        converted += 1
        print(f"  ✓ {new_base}  ({T} frames, {len(turns_meta)} turns)")

    # Keep untouched entries in their existing order, then append this run's.
    kept = [s for s in manifest.get("scenarios", [])
            if s.get("base") not in new_entries]
    manifest["scenarios"] = kept + list(new_entries.values())
    manifest_fp.write_text(json.dumps(manifest, ensure_ascii=False, indent=2),
                           encoding="utf-8")
    print(f"\nmanifest updated. converted {converted}/{len(sids)} scenarios.")


if __name__ == "__main__":
    main()
