#!/usr/bin/env python3
"""Curate a subset from the full dataset viewer archive into the active
player folder.

Reads from `data/viewer_dataset/` (full archive written by
dataset_to_viewer.py) and copies a curated selection of scenarios into
`data/viewer/` so the player dropdown stays focused. Rebuilds
`data/viewer/manifest.json` with only the curated entries.

Default selection (deterministic with --seed):
  - longs:   ALL    (multi-turn monologues — main signal for emotion arcs)
  - solos:   30     (sampled, diversified across emotion classes)
  - dailies: 20     (per-turn dialogue splits, diversified across emotions)

Usage:
    python3 scripts/curate_viewer.py                  # default selection
    python3 scripts/curate_viewer.py --solos 50 --dailies 0
    python3 scripts/curate_viewer.py --clean          # also remove old *_dataset_* from viewer/
"""
from __future__ import annotations

import argparse
import json
import random
import shutil
from collections import defaultdict
from pathlib import Path

# Project root: the directory one level above the folder holding this script
# (the script is invoked as scripts/curate_viewer.py). The default --archive
# and --viewer paths in main() are resolved relative to it.
PROJECT_ROOT = Path(__file__).resolve().parents[1]


def _classify(sid: str) -> str | None:
    """Return the selection bucket a scenario id belongs to.

    Callers strip the "_dataset" marker from archive ids before calling.
    Ids starting with "long_" or "solo_" map to that prefix; ids starting
    with "daily_" that also contain "_t" (per-turn splits) map to
    "daily_-split". Anything else yields None.
    """
    for prefix in ("long_", "solo_"):
        if sid.startswith(prefix):
            return prefix
    is_daily_split = sid.startswith("daily_") and "_t" in sid
    return "daily_-split" if is_daily_split else None


def _sample_diverse(entries: list[dict], n: int, rng: random.Random) -> list[dict]:
    """Pick up to *n* entries, cycling across primary emotions for variety.

    Entries are grouped by the first item of their ``emotions`` list
    (``"neutral"`` when that list is missing or empty). The group order and
    the contents of every group are shuffled with *rng*; then each round takes
    one entry from every non-empty group until *n* entries are collected.

    Returns ``[]`` when ``n <= 0`` or *entries* is empty, and an unshuffled
    shallow copy of *entries* when *n* is at least ``len(entries)``.
    """
    if n <= 0 or not entries:
        return []
    if n >= len(entries):
        return list(entries)

    groups: dict[str, list[dict]] = {}
    for entry in entries:
        primary = (entry.get("emotions") or ["neutral"])[0]
        groups.setdefault(primary, []).append(entry)

    # RNG consumption order matters for --seed reproducibility: shuffle the
    # group order first, then each group in first-seen order.
    order = list(groups)
    rng.shuffle(order)
    for pool in groups.values():
        rng.shuffle(pool)

    chosen: list[dict] = []
    exhausted = False
    while len(chosen) < n and not exhausted:
        exhausted = True  # cleared as soon as any group yields an entry
        for emotion in order:
            pool = groups[emotion]
            if not pool:
                continue
            chosen.append(pool.pop())
            exhausted = False
            if len(chosen) >= n:
                break
    return chosen


def _copy_scenario(s: dict, archive: Path, viewer: Path) -> int:
    """Copy a scenario's ``<scenario_id>.json`` and ``<scenario_id>.mp3``
    from *archive* into *viewer*, skipping whichever does not exist.

    Files are copied with ``shutil.copy2`` (contents plus metadata) under the
    same file name. Exactly one JSON and one MP3 per scenario are considered;
    there is no A/B/C variant duplication. Returns how many files were copied
    (0, 1 or 2).
    """
    stem = s["scenario_id"]   # e.g. "long_001_dataset"
    copied = 0
    for suffix in (".json", ".mp3"):
        source = archive / f"{stem}{suffix}"
        if not source.exists():
            continue
        shutil.copy2(source, viewer / source.name)
        copied += 1
    return copied


def _scenario_category(s: dict) -> str | None:
    """Return the bucket of a manifest entry.

    The "_dataset" marker is stripped from ``scenario_id`` (archive ids look
    like "long_001_dataset") before classifying with ``_classify``.
    """
    sid = s.get("scenario_id") or ""
    return _classify(sid.replace("_dataset", ""))


def _bucket_scenarios(scenarios: list[dict]) -> dict[str, list[dict]]:
    """Group archive manifest entries by bucket, dropping unrecognised ids."""
    buckets: dict[str, list[dict]] = {"long_": [], "solo_": [], "daily_-split": []}
    for s in scenarios:
        cat = _scenario_category(s)
        if cat:
            buckets[cat].append(s)
    return buckets


def _clean_viewer(viewer: Path) -> int:
    """Best-effort removal of ``*_dataset_*`` files in *viewer*.

    The pattern already covers stale ``*_dataset_*.concat.txt`` artifacts.
    Files that cannot be removed are reported, not raised. Returns the number
    of files actually removed.
    """
    removed = 0
    # Materialise the listing first so the directory is not mutated while it
    # is being iterated.
    for p in sorted(viewer.glob("*_dataset_*")):
        try:
            p.unlink()
        except OSError as exc:
            print(f"--clean: could not remove {p}: {exc}")
        else:
            removed += 1
    return removed


def _is_dataset_entry(e: dict) -> bool:
    """True if a viewer-manifest entry refers to a dataset-archive scenario.

    ``base`` is checked first. Entries without ``base``, such as the ones this
    script copies from the archive, are identified by ``scenario_id``. This
    ensures a re-run replaces them instead of duplicating them.
    """
    marker = e.get("base") or e.get("scenario_id") or ""
    return "_dataset" in str(marker)


def _write_manifest(viewer: Path, selected: list[dict]) -> None:
    """Rewrite ``viewer/manifest.json`` as the non-dataset entries already present
    (e.g. _d65x20 references) followed by the newly *selected* dataset entries.

    Previously curated dataset entries are dropped so only the current
    selection remains. UTF-8 is used explicitly because ``ensure_ascii=False``
    can emit non-ASCII text, which the platform default encoding may not handle.
    """
    manifest_fp = viewer / "manifest.json"
    kept: list[dict] = []
    if manifest_fp.exists():
        existing = json.loads(manifest_fp.read_text(encoding="utf-8"))
        kept = [e for e in existing.get("scenarios", []) if not _is_dataset_entry(e)]
    kept.extend(selected)
    manifest_fp.write_text(
        json.dumps({"scenarios": kept}, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )


def main() -> None:
    """CLI entry point.

    Reads the archive manifest, picks a seeded subset (all/some longs plus
    emotion-diverse solos and daily splits), copies their files into the
    viewer folder, and rewrites the viewer manifest.

    Raises:
        SystemExit: if the archive manifest is missing or nothing matched.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--archive", type=Path,
                    default=PROJECT_ROOT / "data" / "viewer_dataset",
                    help="Full archive folder (dataset_to_viewer.py output).")
    ap.add_argument("--viewer", type=Path,
                    default=PROJECT_ROOT / "data" / "viewer",
                    help="Active player folder (curated subset).")
    ap.add_argument("--longs", type=int, default=0,
                    help="How many long_ samples (0 = all). Default 0.")
    ap.add_argument("--solos", type=int, default=30,
                    help="How many solo_ samples. Default 30.")
    ap.add_argument("--dailies", type=int, default=20,
                    help="How many daily_-split samples. Default 20.")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--clean", action="store_true",
                    help="Remove existing *_dataset_* files from viewer "
                         "BEFORE curation (recommended after re-baking).")
    args = ap.parse_args()

    rng = random.Random(args.seed)

    archive_manifest_fp = args.archive / "manifest.json"
    if not archive_manifest_fp.exists():
        raise SystemExit(f"Archive manifest not found: {archive_manifest_fp}\n"
                          "Run dataset_to_viewer.py first.")
    archive_manifest = json.loads(archive_manifest_fp.read_text(encoding="utf-8"))
    buckets = _bucket_scenarios(archive_manifest.get("scenarios", []))

    # The selection order (longs, solos, dailies) fixes the RNG call sequence,
    # which keeps a given --seed reproducible.
    selected: list[dict] = []
    if args.longs <= 0 or args.longs >= len(buckets["long_"]):
        selected.extend(buckets["long_"])
    else:
        selected.extend(_sample_diverse(buckets["long_"], args.longs, rng))
    selected.extend(_sample_diverse(buckets["solo_"], args.solos, rng))
    selected.extend(_sample_diverse(buckets["daily_-split"], args.dailies, rng))

    if not selected:
        raise SystemExit(f"No scenarios matched. Archive has: "
                          f"long_={len(buckets['long_'])}, "
                          f"solo_={len(buckets['solo_'])}, "
                          f"daily_-split={len(buckets['daily_-split'])}")

    args.viewer.mkdir(parents=True, exist_ok=True)

    if args.clean:
        n_removed = _clean_viewer(args.viewer)
        print(f"--clean: removed {n_removed} *_dataset_* files from {args.viewer}")

    total_written = sum(_copy_scenario(s, args.archive, args.viewer) for s in selected)
    _write_manifest(args.viewer, selected)

    # Count with the same classifier used for bucketing so the summary agrees.
    counts = {"long_": 0, "solo_": 0, "daily_-split": 0}
    for s in selected:
        counts[_scenario_category(s)] += 1
    print(f"curated: {counts['long_']} long, {counts['solo_']} solo, "
          f"{counts['daily_-split']} daily-split "
          f"= {len(selected)} scenarios, {total_written} files written")
    print(f"viewer: {args.viewer}")
    print(f"archive: {args.archive}")


# Run the CLI only when executed as a script; importing the module only
# defines functions and constants.
if __name__ == "__main__":
    main()
