"""Standalone audio generation for quality evaluation.

Generates ElevenLabs audio for a sample of seed scenario turns, so you can
listen and evaluate quality BEFORE running the full data pipeline.

Usage:
    # Set env vars first:
    export ELEVENLABS_API_KEY=...
    export ELEVENLABS_VOICE_ID=...    # optional
    export ELEVENLABS_MODEL=eleven_multilingual_v2   # or eleven_v3

    # Generate audio for the next 10 unfinished scenarios (in order):
    python -m scripts.compiler.generate_audio --batch_size 10

    # Preview the character count and synthesis plan without calling the API:
    python -m scripts.compiler.generate_audio --batch_size 10 --dry-run

    # One turn per emotion (16 total, good for quality matrix):
    python -m scripts.compiler.generate_audio --one-per-emotion

    # Specific scenarios:
    python -m scripts.compiler.generate_audio --scenarios long_125,daily_001
"""
from __future__ import annotations

import argparse
import asyncio
import json
import random
from pathlib import Path
from typing import List

from .tts import synth_all

# Repository root. This module is run as scripts.compiler.generate_audio, so
# the root is the third ancestor of the resolved file path.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Directory holding the seed emotion scenario files.
EMO_DIR = PROJECT_ROOT.joinpath("data", "emotion")
# Scenario pools combined by default, in this order: train, val, test.
DEFAULT_SOURCES = [
    EMO_DIR / split_file
    for split_file in (
        "seed_train_final.jsonl",
        "seed_val.jsonl",
        "seed_test.jsonl",
    )
]
# Default output directory for the generated preview audio and its manifest.
AUDIO_OUT = PROJECT_ROOT.joinpath("data", "audio_preview")


def load_scenarios(source_paths: list) -> list:
    """Load every scenario from the given JSONL files into one list.

    Files are read in the order given and scenarios keep their order within
    each file. Every non-blank line must hold one JSON document.

    Args:
        source_paths: Paths (``Path`` or ``str``) of the JSONL files to
            combine.

    Returns:
        The decoded scenarios, concatenated across all files that exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    scenarios = []
    for source in source_paths:
        path = Path(source)
        if not path.exists():
            # Best-effort: a missing split is skipped, but say so — otherwise
            # a mistyped path is indistinguishable from an empty pool.
            print(f"[warn] source not found, skipping: {path}")
            continue
        # Explicit UTF-8 so non-ASCII text decodes the same on every platform
        # instead of depending on the locale's default encoding.
        with path.open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                scenarios.append(json.loads(line))
    return scenarios


# Audio files smaller than this many bytes are treated as failed/truncated.
_MIN_VALID_MP3_BYTES = 1000


def scenario_already_done(out_dir: Path, scenario_id: str, num_turns: int,
                          turn_indices: list = None) -> bool:
    """Check whether every turn of a scenario already has a valid mp3 file.

    A turn counts as done when at least one file matching
    ``{scenario_id}_t{index}_*.mp3`` in ``out_dir`` is at least 1000 bytes.

    ``*.raw.mp3`` orphans are ignored — these are written mid-synthesis and
    deleted on success; if a process is killed they remain on disk and
    would otherwise make a partially-failed scenario look complete.

    Args:
        out_dir: Directory holding the generated audio.
        scenario_id: Scenario whose files are looked up.
        num_turns: Number of turns; indices ``0..num_turns-1`` are checked
            when ``turn_indices`` is not given.
        turn_indices: Optional explicit turn indices to check instead of
            ``range(num_turns)``. Needed when blank turns were skipped,
            because audio files are named after a turn's original index,
            not its rank among the non-blank turns.

    Returns:
        True if every checked turn has a valid mp3, otherwise False.
    """
    indices = range(num_turns) if turn_indices is None else turn_indices
    for ti in indices:
        candidates = (
            m for m in out_dir.glob(f"{scenario_id}_t{ti}_*.mp3")
            if not m.name.endswith(".raw.mp3")
        )
        # Accept the turn if ANY candidate is large enough. Inspecting only
        # the first glob hit made the answer depend on directory listing
        # order whenever more than one file matched the turn.
        if not any(m.stat().st_size >= _MIN_VALID_MP3_BYTES for m in candidates):
            return False
    return True


def flatten_turns_for_scenarios(scenarios: list):
    """Return a list of ``(scenario_id, turn_index, turn_dict)`` tuples.

    Turns whose text is blank are left out. ``turn_index`` is the turn's
    position in the scenario's full ``turns`` list, blank turns included.
    """
    return [
        (scen["scenario_id"], ti, turn)
        for scen in scenarios
        for ti, turn in enumerate(scen["turns"])
        if turn["text"].strip()
    ]


def pick_scenarios_batch(scenarios: list, out_dir: Path, batch_size: int,
                          one_per_emotion: bool = False,
                          specific: list = None):
    """Pick the scenarios to process in this run.

    Modes, in priority order:

    - specific: list of scenario_ids to target (overrides batch logic).
      Matching scenarios are returned in pool order whether or not their
      audio already exists; IDs not present in the pool are ignored.
    - one_per_emotion: diagnostic mode. The first non-blank turn seen for
      each emotion is wrapped in a synthetic one-turn scenario.
    - default: the next ``batch_size`` unfinished scenarios in pool order.
      Scenarios with no non-blank turn, or whose audio is complete, are
      skipped.

    Args:
        scenarios: Scenario dicts with ``scenario_id`` and ``turns`` keys.
        out_dir: Directory checked for existing audio (default mode only).
        batch_size: Maximum number of scenarios returned in default mode;
            zero or a negative value yields an empty batch.
        one_per_emotion: Enable the diagnostic one-turn-per-emotion mode.
        specific: Scenario IDs to select explicitly; ``None`` or empty
            means no explicit selection.

    Returns:
        A list of scenario dicts to process.
    """
    if specific:
        requested = set(specific)
        return [s for s in scenarios if s["scenario_id"] in requested]

    if one_per_emotion:
        seen = set()
        synth = []
        for sid, ti, turn in flatten_turns_for_scenarios(scenarios):
            emo = turn["emotion"]
            if emo in seen:
                continue
            seen.add(emo)
            # Wrap the selected turn into a synthetic "scenario" of 1 turn.
            synth.append({
                "scenario_id": f"{sid}_singleTurn{ti}_{emo}",
                "turns": [turn],
            })
        return synth

    # Default: next unfinished batch.
    if batch_size <= 0:
        # Without this guard the loop below would still return one scenario,
        # because it appends before comparing against the limit.
        return []
    unfinished = []
    for scen in scenarios:
        # Original indices of the non-blank turns. Audio files are named by
        # these indices, so they must be checked rather than 0..count-1.
        spoken = [ti for ti, t in enumerate(scen["turns"]) if t["text"].strip()]
        if not spoken:
            continue
        if scenario_already_done(out_dir, scen["scenario_id"], len(spoken),
                                 turn_indices=spoken):
            continue
        unfinished.append(scen)
        if len(unfinished) >= batch_size:
            break
    return unfinished


async def main():
    """Command-line entry point: pick scenarios, synthesize audio, write manifest.

    Steps:
      1. Parse arguments and load the scenario pool from ``--sources``.
      2. Report the pool size and how many scenarios already have audio.
      3. Select this run's scenarios (explicit IDs, one-per-emotion, or the
         next unfinished batch).
      4. Flatten them into per-turn work items, print the character/cost
         plan, and stop there under ``--dry-run``.
      5. Call ``synth_all`` and rebuild ``manifest.jsonl`` in the output
         directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--sources", nargs="+", type=Path, default=DEFAULT_SOURCES,
                    help="Scenario jsonl files to combine (default: train+val+test)")
    ap.add_argument("--out_dir", type=Path, default=AUDIO_OUT)
    ap.add_argument("--batch_size", type=int, default=10,
                    help="Number of scenarios per batch. Default 10.")
    ap.add_argument("--one-per-emotion", action="store_true", dest="one_per_emotion",
                    help="Diagnostic mode: 16 turns (one per emotion) instead of scenarios")
    ap.add_argument("--scenarios",
                    type=lambda s: [x.strip() for x in s.split(",") if x.strip()],
                    default=None,
                    help="Specific scenario IDs (comma-separated), overrides batch logic")
    ap.add_argument("--concurrency", type=int, default=4)
    ap.add_argument("--voice_id", default=None, help="Override auto emotion→voice mapping")
    ap.add_argument("--backend", choices=["elevenlabs", "edge"], default="elevenlabs")
    ap.add_argument("--dry-run", action="store_true", dest="dry_run",
                    help="Print char count + voice plan, then exit without API calls")
    args = ap.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)

    def count_spoken(scen):
        """Number of turns in ``scen`` whose text is not blank."""
        return sum(1 for t in scen["turns"] if t["text"].strip())

    scenarios = load_scenarios(args.sources)
    total_scenarios = len(scenarios)
    total_turns = sum(count_spoken(s) for s in scenarios)
    print(f"[pool] {total_scenarios} scenarios, {total_turns} turns total")

    # Count already-done scenarios (progress report only).
    # NOTE(review): scenario_already_done is given a turn COUNT here and then
    # checks indices 0..count-1, while the files written below are named by
    # the turn's original index. A scenario with a blank turn ahead of a
    # non-blank one may therefore never be counted as done.
    done_count = sum(
        1 for s in scenarios
        if scenario_already_done(args.out_dir, s["scenario_id"], count_spoken(s))
    )
    print(f"[progress] {done_count}/{total_scenarios} scenarios already have audio")

    picked_scenarios = pick_scenarios_batch(
        scenarios, args.out_dir, args.batch_size,
        one_per_emotion=args.one_per_emotion,
        specific=args.scenarios,
    )
    print(f"[select] {len(picked_scenarios)} scenarios this batch")

    if not picked_scenarios:
        if args.scenarios:
            # Explicit IDs are never filtered by "already done", so an empty
            # selection here means no requested ID exists in the pool.
            print("[info] Nothing to do — none of the requested scenario IDs "
                  "were found in the sources.")
        else:
            print("[info] Nothing to do — all selected scenarios already have audio.")
        return

    # Flatten to per-turn work items.
    # For monologues (long_*, solo_*): lock voice to the SCENARIO'S DOMINANT
    # base emotion (most-frequent across turns). This keeps the voice
    # character matched to the overall emotional arc — a sad-themed monologue
    # uses sadness-female, a happy-themed one uses joy-female, etc.
    # For dialogues: keep the per-turn pool/seed mechanism.
    from .tts import dominant_base_for_turns, FEMALE_BY_BASE
    texts, emotions, vads, out_paths, manifest = [], [], [], [], []
    voice_seeds, voice_pools, voice_ids = [], [], []
    for scen in picked_scenarios:
        sid = scen["scenario_id"]
        is_monologue = sid.startswith("long_") or sid.startswith("solo_")
        scenario_voice_id = None
        if is_monologue:
            dominant_base = dominant_base_for_turns(scen["turns"])
            scenario_voice_id = FEMALE_BY_BASE[dominant_base]
        for ti, turn in enumerate(scen["turns"]):
            if not turn["text"].strip():
                continue  # blank turns produce no audio and no manifest row
            # File name keeps the turn's original index, blank turns included.
            fname = f"{sid}_t{ti}_{turn['emotion']}.mp3"
            path = args.out_dir / fname
            texts.append(turn["text"])
            emotions.append(turn["emotion"])
            vads.append(turn.get("vad"))
            out_paths.append(path)
            voice_seeds.append(sid if is_monologue else None)
            voice_pools.append(None)  # monologues bypass pool via voice_ids
            voice_ids.append(scenario_voice_id)  # None for dialogues
            manifest.append({
                "file": fname, "scenario_id": sid,
                "turn_index": ti, "emotion": turn["emotion"],
                "vad": turn.get("vad"), "text": turn["text"],
                "monologue": is_monologue,
                "voice_id": scenario_voice_id,
            })

    # --voice_id is forwarded to synth_all only when the user supplied it.
    tts_kwargs = {}
    if args.voice_id:
        tts_kwargs["voice_id"] = args.voice_id

    if any(v is not None for v in voice_ids):
        n_dom = sum(1 for v in voice_ids if v is not None)
        print(f"[voice] {n_dom} turns locked to dominant-emotion voice "
              f"({len(set(v for v in voice_ids if v))} unique voices)")

    total_chars = sum(len(t) for t in texts)
    # Stat each target once and reuse the answer for both the turn count and
    # the character count of the plan.
    already_present = [p.exists() and p.stat().st_size > 1000 for p in out_paths]
    will_skip = sum(already_present)
    will_synth = len(texts) - will_skip
    will_synth_chars = sum(
        len(t) for t, present in zip(texts, already_present) if not present
    )
    print(f"[chars] {total_chars} characters across {len(texts)} turns")
    print(f"[plan]  {will_skip} turns will be skipped (already on disk)")
    # Cost figure is a rough estimate at $0.30 per 1000 characters.
    print(f"[plan]  {will_synth} turns will be synthesized "
          f"(~{will_synth_chars} chars / ~${will_synth_chars/1000*0.30:.2f})")
    if args.dry_run:
        print("[dry-run] exiting without API calls")
        return

    print(f"[tts] {len(texts)} turns, backend={args.backend}, concurrency={args.concurrency}")
    ok_flags = await synth_all(
        texts, out_paths,
        backend=args.backend,
        concurrency=args.concurrency,
        emotions=emotions,
        vads=vads,
        voice_seeds=voice_seeds,
        voice_pools=voice_pools,
        voice_ids=voice_ids,
        **tts_kwargs,
    )

    # Rebuild manifest from disk: keep all rows whose audio file still
    # exists, then merge in this batch's rows. Authoritative — survives any
    # prior dupe accumulation. Earlier "append-with-dedupe" approach broke
    # when older runs had stray rows the dedupe set picked up.
    manifest_path = args.out_dir / "manifest.jsonl"
    on_disk = {p.name for p in args.out_dir.glob("*.mp3")
               if not p.name.endswith(".raw.mp3")}
    kept = {}  # file -> row dict
    if manifest_path.exists():
        with manifest_path.open(encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    continue  # corrupt or blank line: drop it from the rebuild
                if not isinstance(row, dict):
                    continue  # valid JSON but not a manifest row
                fn = row.get("file")
                if isinstance(fn, str) and fn in on_disk:
                    kept[fn] = row  # last-wins on duplicates
    # Overlay this batch's results
    new_count = 0
    for m, ok in zip(manifest, ok_flags):
        if m["file"] not in kept:
            new_count += 1
        m["ok"] = ok
        kept[m["file"]] = m
    # Write fresh manifest, sorted by file name
    with manifest_path.open("w", encoding="utf-8") as f:
        for fn in sorted(kept):
            f.write(json.dumps(kept[fn], ensure_ascii=False) + "\n")
    print(f"[manifest] {len(kept)} rows total (+{new_count} new this batch, "
          f"{len(manifest) - new_count} updated)")

    successes = sum(ok_flags)
    print(f"\n[done] {successes}/{len(texts)} turns generated ({len(picked_scenarios)} scenarios)")
    print(f"  audio:    {args.out_dir}")
    print(f"  manifest: {manifest_path}")
    print("\n[next batch] just re-run the same command — already-done scenarios will be skipped")


if __name__ == "__main__":
    asyncio.run(main())
