"""Regenerate specific (scenario_id, turn_index) pairs only.

Usage:
    python -m scripts.regenerate_turns \
        --turns scripts/_pron_turns.txt \
        --out_dir /dataset/AnimaSync-mic-fix/data/audio_preview_pron_fixed

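Each line in --turns is "<scenario_id>:<turn_index>", e.g. "daily_157_p1:2",
optionally followed by a TAB character and a destination path that overrides
the default <out_dir>/<scenario_id>_t<turn_index>_<emotion>.mp3 location.
Blank lines and lines starting with "#" are ignored. Existing output files
are deleted and re-synthesized.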
"""
import argparse
import asyncio
import json
from pathlib import Path

from scripts.compiler.generate_audio import DEFAULT_SOURCES, load_scenarios
from scripts.compiler.tts import (
    dominant_base_for_turns,
    EMOTION_TO_BASE,
    FEMALE_BY_BASE,
    MALE_BY_BASE,
    synth_all,
)


def parse_turn_spec(path: Path) -> list[tuple[str, int, str | None]]:
    """Each line is `<scenario_id>:<turn_index>` (output goes to --out_dir
    with default filename) OR `<scenario_id>:<turn_index>\t<abs_dest_path>`
    to pin a specific destination (e.g. straight into a setNN/prefixed file).
    """
    out: list[tuple[str, int, str | None]] = []
    for line in path.read_text().splitlines():
        line = line.rstrip()
        if not line.strip() or line.lstrip().startswith("#"):
            continue
        if "\t" in line:
            spec, dest = line.split("\t", 1)
            dest = dest.strip()
        else:
            spec, dest = line.strip(), None
        sid, ti = spec.split(":")
        out.append((sid, int(ti), dest))
    return out


async def main():
    """Regenerate audio for the turns listed in --turns and nothing else.

    Flow: parse the turn spec, look up each (scenario_id, turn_index) in the
    scenarios loaded from --sources, and build parallel per-turn argument
    lists (text, emotion, VAD, output path, voice seed/pool/id). The whole
    batch then goes to ``synth_all`` with the ElevenLabs backend.

    Unknown scenarios, out-of-range turn indices, empty-text turns and
    duplicate output paths are reported on stdout and skipped. They do not
    abort the run. Any existing output file is deleted first, so every
    planned turn is re-synthesized.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--turns", type=Path, required=True,
                    help="File with <scenario_id>:<turn_index> per line")
    ap.add_argument("--sources", nargs="+", type=Path, default=DEFAULT_SOURCES)
    ap.add_argument("--out_dir", type=Path, required=True)
    ap.add_argument("--concurrency", type=int, default=4)
    ap.add_argument("--voice_gender", choices=["male", "female"], default=None,
                    help="Force all turns to a single-gender voice (picked "
                         "per emotion's base from MALE_BY_BASE/FEMALE_BY_BASE).")
    args = ap.parse_args()

    wanted = parse_turn_spec(args.turns)
    print(f"[want] {len(wanted)} specific turns")

    scenarios = load_scenarios(args.sources)
    by_sid = {s["scenario_id"]: s for s in scenarios}

    # Parallel per-turn argument lists for synth_all: index i of every list
    # describes the same turn.
    texts, emotions, vads, out_paths = [], [], [], []
    voice_seeds, voice_pools, voice_ids = [], [], []
    planned: set[Path] = set()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    # Sort on (scenario_id, turn_index) only. Sorting the full tuples would
    # also compare the optional dest values. That raises TypeError (None < str)
    # when the same turn is listed both with and without a destination.
    for sid, ti, dest in sorted(wanted, key=lambda w: (w[0], w[1])):
        scen = by_sid.get(sid)
        if scen is None:
            print(f"[miss] scenario not found: {sid}")
            continue
        # Reject negative indices as well. Otherwise Python indexing would
        # silently pick a turn counted from the end of the scenario.
        if not 0 <= ti < len(scen["turns"]):
            print(f"[miss] {sid}: turn {ti} out of range")
            continue
        turn = scen["turns"][ti]
        if not turn["text"].strip():
            print(f"[skip] {sid}:{ti} has empty text")
            continue

        # Monologue scenarios (long_/solo_) get a female voice chosen once
        # from dominant_base_for_turns over ALL of the scenario's turns.
        # They are also seeded by scenario id (see voice_seeds below).
        is_monologue = sid.startswith(("long_", "solo_"))
        scenario_voice_id = None
        if is_monologue:
            scenario_voice_id = FEMALE_BY_BASE[dominant_base_for_turns(scen["turns"])]
        # --voice_gender overrides the above, picking by this turn's emotion.
        if args.voice_gender:
            base = EMOTION_TO_BASE.get(turn["emotion"] or "neutral", "neutral")
            by_base = MALE_BY_BASE if args.voice_gender == "male" else FEMALE_BY_BASE
            scenario_voice_id = by_base[base]

        if dest:
            out_path = Path(dest)
        else:
            out_path = args.out_dir / f"{sid}_t{ti}_{turn['emotion']}.mp3"
        # Two spec lines resolving to the same file would be synthesized
        # twice, with both writes targeting one path. Keep only the first.
        if out_path in planned:
            print(f"[dup] {sid}:{ti} -> {out_path} already planned, skipping")
            continue
        planned.add(out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # Force overwrite: synth_all skips files that exist & >1000 bytes,
        # which would make a rerun after a scrubber tweak silently do nothing.
        if out_path.exists():
            out_path.unlink()

        texts.append(turn["text"])
        emotions.append(turn["emotion"])
        vads.append(turn.get("vad"))
        out_paths.append(out_path)
        voice_seeds.append(sid if is_monologue else None)
        voice_pools.append(None)
        voice_ids.append(scenario_voice_id)

    print(f"[plan] {len(texts)} turns to synthesize -> {args.out_dir}")
    if not texts:
        return
    await synth_all(
        texts, out_paths,
        backend="elevenlabs",
        concurrency=args.concurrency,
        emotions=emotions,
        vads=vads,
        voice_seeds=voice_seeds,
        voice_pools=voice_pools,
        voice_ids=voice_ids,
    )
    print("[done]")


# Script entry point: asyncio.run() creates a fresh event loop, runs the
# async main() to completion, then closes the loop.
if __name__ == "__main__":
    asyncio.run(main())
