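"""TTS backends for data pipeline.

Two implementations:
    - ElevenLabs (paid, emotional prosody) — recommended
    - edge-tts (free fallback)

Env vars:
    ELEVENLABS_API_KEY  (required for elevenlabs unless api_key is passed explicitly)
    ELEVENLABS_VOICE_ID (optional; forces one voice for every clip, bypassing the
                         per-emotion voice pools that are used by default)
    ELEVENLABS_MODEL    (optional; ElevenLabs model ID, default "eleven_v3")
"""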
from __future__ import annotations

import asyncio
import hashlib
import logging
import os
from pathlib import Path
from typing import List

# Module logger; all backends in this file report under the "tts" channel.
LOG = logging.getLogger(name="tts")


class QuotaExhausted(RuntimeError):
    """Raised when ElevenLabs returns 401/402/quota — non-retryable.
    synth_all catches this once and aborts the whole batch so the user
    can see a single clear message and re-run with a different key."""

# Default ElevenLabs model. v3 went GA on 2026-02-02; supports audio tags
# (e.g. [excited], [whispers]) and is the intended target for this pipeline.
# Override via ELEVENLABS_MODEL env var if you need to fall back to v2.
# An empty ELEVENLABS_MODEL (e.g. `export ELEVENLABS_MODEL=`) is treated as
# unset. Otherwise "" would be sent as the model_id, and the "v3" check in
# synth_one_elevenlabs would silently disable audio tags.
DEFAULT_EL_MODEL = os.getenv("ELEVENLABS_MODEL") or "eleven_v3"

# ── Voice pool: 5 base emotions × {female, male} → ElevenLabs voice IDs ──
# 16 sub-emotions map to their parent base emotion (per emotion_labels.json).
# Pool is expanded with both genders to prevent voice-emotion shortcut memorization.
# Convention relied on below: idx 0 of every pool is a female voice and idx 1
# a male voice. The gendered lookups further down are DERIVED from those
# positions, so each voice ID is written in exactly one place.
# NOTE(review): idx 2-3 carry no gender annotation — confirm before relying
# on their gender.
VOICE_POOLS_BY_BASE = {
    "anger": [
        "FCdKzv68Ofr4VUDcZXIy",  # female
        "KsAmBSHXsuxsZ1lZZXlG",  # male
        "oHB9Xhox1bqMl1Tvkmel",
        "ZubHeGTOAkECknc02Zmo",
    ],
    "sadness": [
        "m3yAHyFEFKtbCIM5n7GF",  # female
        "k9073AMdU5sAUtPMH1il",  # male
        "hWXqitL3DEOLD49pgNWR",
        "z2P4oCxSHhXan3ew4COv",
    ],
    "joy": [
        "TbMNBJ27fH2U0VgpSNko",  # female
        "1W00IGEmNmwmsDeYy7ag",  # male
        "OEaq3WGNtNvFJ5co9mJE",
        "xi3rF0t7dg7uN2M0WUhr",
    ],
    "neutral": [
        "8tsLeAV5vPVuzCCvqbbU",  # female
        "s0XGIcqmceN2l7kjsqoZ",  # male
        "iWLjl1zCuqXRkW6494ve",
        "8QclKarwLctvzN2plke3",
    ],
    "surprise": [
        "XfNU2rGpBa01ckF309OY",  # female
        "93nuHbke4dTER9x2pDwE",  # male
        "YDDaC9XKjODs7hY78qEW",
        "BNr4zvrC1bGIdIstzjFQ",
    ],
}

# Direct base-emotion → female-voice lookup (idx 0 of each pool). Used for
# "dominant-emotion" voice picking in monologues — see
# dominant_base_for_turns.
FEMALE_BY_BASE = {base: pool[0] for base, pool in VOICE_POOLS_BY_BASE.items()}

# Same shape but male — idx 1 from each pool. Used when a regen needs to
# force a male voice (e.g. fixing pronunciation on dialogue turns whose
# A/B speakers are both male).
MALE_BY_BASE = {base: pool[1] for base, pool in VOICE_POOLS_BY_BASE.items()}

# Monologue pool: only known-female voices, drawn from idx 0 of each base
# pool, in pool order (anger, sadness, joy, neutral, surprise).
# Used so long_*/solo_* scenarios stay gender-consistent end-to-end (within a
# scenario the seed=sid lock already pins one voice; this guarantees that voice
# is always female across scenarios too).
FEMALE_MONOLOGUE_POOL = list(FEMALE_BY_BASE.values())


def dominant_base_for_turns(turns):
    """Return the base emotion that occurs most often across a scenario's turns.

    `turns` is a list of dicts carrying 'emotion' and optionally 'text' and
    'vad'. Turns whose 'text' is missing, empty or whitespace-only are
    ignored, so blank padding turns cannot vote 'neutral'. Each remaining
    turn's emotion is mapped to its base via EMOTION_TO_BASE (unknown or
    missing → 'neutral').

    Ties on count are broken by the largest summed |arousal| (vad[1]) among
    the tied bases. If no turn has real text, returns 'neutral'.
    """
    spoken = [turn for turn in turns if (turn.get("text") or "").strip()]
    if not spoken:
        return "neutral"

    from collections import Counter

    votes = Counter()
    arousal_total = {}
    for turn in spoken:
        base = EMOTION_TO_BASE.get(turn.get("emotion") or "neutral", "neutral")
        votes[base] += 1
        vad = turn.get("vad") or [0.0, 0.0, 0.0]
        magnitude = abs(float(vad[1])) if len(vad) >= 2 else 0.0
        arousal_total[base] = arousal_total.get(base, 0.0) + magnitude

    best = max(votes.values())
    leaders = [base for base, n in votes.items() if n == best]
    if len(leaders) == 1:
        return leaders[0]
    # Tie-break: the base whose turns were, in total, the most aroused.
    return max(leaders, key=lambda base: arousal_total.get(base, 0.0))

# 16 emotions → base emotion (via parent field in emotion_labels.json).
# Declared as parent → children groups and flattened, so each sub-emotion
# is listed exactly once under its parent. Lookups elsewhere fall back to
# 'neutral' for labels not listed here.
EMOTION_TO_BASE = {
    child: parent
    for parent, children in (
        ("neutral", ("neutral",)),
        ("joy", ("joy", "laughter", "excitement", "agreement", "gratitude")),
        ("sadness", ("sadness", "crying", "sulk", "apology", "struggle")),
        ("anger", ("anger", "refusal")),
        ("surprise", ("surprise", "fluster", "shy")),
    )
    for child in children
}

# Legacy override: when ELEVENLABS_VOICE_ID is set to a non-empty value,
# voice_id_for_emotion returns it for every clip (unless a per-call override
# is given), bypassing the per-emotion voice pools. None when unset.
SINGLE_VOICE_OVERRIDE = os.environ.get("ELEVENLABS_VOICE_ID")


def voice_id_for_emotion(emotion: str | None, override: str | None = None,
                          seed: str | None = None,
                          pool_override: str | None = None) -> str:
    """Pick voice ID from pool keyed by emotion → base emotion.

    seed: stable key (e.g. text or scenario_id+turn_idx) to deterministically
    pick a voice from the pool. Without seed, picks first voice in pool.
    Deterministic selection ensures same turn always gets same voice on re-run.

    pool_override: force voice selection from a specific base-emotion pool
    (e.g. 'neutral' for monologues — same voice across all emotions in turn).
    """
    if override:
        return override
    if SINGLE_VOICE_OVERRIDE:
        return SINGLE_VOICE_OVERRIDE
    if pool_override == "female_monologue":
        pool = FEMALE_MONOLOGUE_POOL
    elif pool_override:
        pool = VOICE_POOLS_BY_BASE[pool_override]
    else:
        base = EMOTION_TO_BASE.get(emotion or "neutral", "neutral")
        pool = VOICE_POOLS_BY_BASE[base]
    if seed is None:
        return pool[0]
    # Stable hash → voice index, so resume/skip logic stays consistent
    h = int(hashlib.md5(seed.encode("utf-8")).hexdigest(), 16)
    return pool[h % len(pool)]


def _build_voice_settings(vad: list | None):
    """Map VAD → ElevenLabs voice_settings for expressive prosody.

    For v3 + audio tags: tags carry the emotion, voice_settings carry
    consistency. Stacking high `style` on top of expressive tags
    triple-amplifies expression and produces distortion (raspy / accent
    drift / foreigner-sounding) — that was the failure mode observed on
    2026-05-07. We keep style modest and stability solid.

    Intensity is |A| magnitude — both high-positive (excited) and
    low-negative (sad) arousal scale prosody equally.
    """
    if vad is None:
        return {"stability": 0.45, "similarity_boost": 0.75, "style": 0.30, "use_speaker_boost": True}
    V, A, D = vad
    intensity = abs(float(A))
    # Brighter style ramp (0.20 baseline → ~0.45 max) so v3 has room to express
    # tag-driven emotion — too-low style was muting joy/excitement on 2026-05-07.
    style = float(max(0.0, min(0.50, 0.20 + 0.25 * intensity)))
    # Stability slightly lower than before (0.40-0.50) to allow more variation
    stability = float(max(0.40, min(0.50, 0.50 - 0.10 * intensity)))
    similarity_boost = 0.75
    return {
        "stability": stability,
        "similarity_boost": similarity_boost,
        "style": style,
        "use_speaker_boost": True,
    }


def _ensure_trailing_punct(text: str) -> str:
    """Strip *text* and append "." unless it already ends a sentence.

    v3 tends to cut the tail of a line that lacks sentence-final
    punctuation. Accepted terminators are ASCII `.?!`, the ellipsis `…`,
    and the CJK forms `。？！` (kept as a safety net for pasted text).
    `、` (U+3001) is intentionally NOT accepted: it is a Japanese
    reading-comma, not a Korean sentence terminator, and was once listed
    here by mistake. Whitespace-only input yields "".
    """
    stripped = text.strip()
    terminated = stripped.endswith(tuple(".?!。？！…"))
    if stripped and not terminated:
        return f"{stripped}."
    return stripped


# Combined matcher for Korean jamo emoticon runs: sob runs, laugh runs, then
# bare vowel runs. The pieces below are adjacent raw-string literals, so the
# compiled pattern text is a single `a|b|c|d` alternation.
# NOTE(review): not referenced by the code visible in this file
# (_scrub_korean_emoticons uses its own inline patterns). The last `[ㅡ]{2,}`
# branch can never win, because ㅡ is already in the vowel class before it.
# Candidate for cleanup.
_KOREAN_EMOTICON_RE = __import__("re").compile(
    r"[ㅠㅜ]{2,}"        # sob: ㅠㅠ / ㅜㅜ
    r"|[ㅋㅎ]{2,}"       # laugh: ㅋㅋ / ㅎㅎ
    r"|[ㅏㅓㅗㅡㅣ]{2,}"  # bare vowel runs
    r"|[ㅡ]{2,}"
)


def _scrub_korean_emoticons(text: str, emotion: str | None) -> tuple[str, bool, bool]:
    """v3 reads bare Korean jamo runs (ㅠㅠ, ㅋㅋ, ㅎㅎ) literally — comes out
    as "어어/유유/크크" instead of sobbing/laughing. Replace with v3 audio
    tags so they render as real reactions.

    ㅠㅠ / ㅜㅜ → [crying]    (canonical v3 demo tag — fires reliably
                              where [sniffles] was getting ignored
                              mid-Korean-sentence)
    ㅋㅋ / ㅎㅎ → [laughs]    (canonical v3 demo tag — fires reliably
                              where [chuckles] was getting ignored)
    bare vowel runs (ㅏㅏ, ㅡㅡ…) → stripped (garbage)

    No leading period — that produced a hard sentence break before the
    reaction (audible "speak…long pause…laugh" instead of "speak laugh").
    The canonical tags fire mid-clause without it.

    Returns (cleaned_text, had_sob, had_chuckle) — flags reserved but
    head-tag suppression is OFF by default now. Head tag carries the
    emotional prosody for the whole line; the body tag just stamps the
    reaction sound. Suppressing the head was making the voiced emotion
    drop out entirely.
    """
    import re
    had_sob = bool(re.search(r"[ㅠㅜ]{2,}", text))
    had_chuckle = bool(re.search(r"[ㅋㅎ]{2,}", text))
    text = re.sub(r"[ㅠㅜ]{2,}", " [crying] ", text)
    text = re.sub(r"[ㅋㅎ]{2,}", " [laughs] ", text)
    text = re.sub(r"[ㅏㅓㅗㅡㅣ]{2,}", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text, had_sob, had_chuckle


def _pad_short_text(text: str, min_chars: int = 12) -> str:
    """Give v3 a trailing pause on very short inputs (e.g. "응.", "네.").

    v3 fails or hallucinates on very short inputs. When the stripped text is
    shorter than `min_chars`, trailing whitespace and periods are removed
    and an ellipsis ("...") is appended. v3 reads the ellipsis as a brief
    pause, not extra speech. No filler words are added. Longer text is
    returned unchanged.

    Args:
        text: Line to pad.
        min_chars: Length of the whitespace-stripped text below which
            padding is applied.

    Returns:
        The padded or original text.
    """
    if len(text.strip()) >= min_chars:
        return text
    # Strip whitespace before the periods: otherwise "응. " keeps its period
    # (rstrip(".") stops at the space) and becomes "응. ..." instead of "응...".
    return text.rstrip().rstrip(".") + "..."


def _format_text_with_emotion(text: str, emotion: str | None, use_v3_tags: bool) -> str:
    """Clean *text* for ElevenLabs and, on v3 models, prefix an emotion audio tag.

    Steps (order matters):
      1. Replace any caller-supplied ``[...]`` span with a space. v3 treats
         every bracketed span as an audio tag, so stray ones would conflict
         with the tag injected here.
      2. Turn Korean emoticon runs (ㅠㅠ / ㅋㅋ) into inline reaction tags.
      3. Pad short lines (v3 hallucinates on <12-char inputs).
      4. Guarantee sentence-final punctuation so the last word isn't cut.
      5. When tags are enabled and the emotion is non-neutral, prepend
         ``". [tag] "``. The leading period absorbs v3's tendency to
         swallow the first token after a tag, which would otherwise
         truncate the first real word.

    The head tag is omitted when `use_v3_tags` is false, `emotion` is
    empty/None or "neutral", the emotion has no tag, or the body already
    contains the corresponding inline reaction (see below).
    """
    import re

    body = re.sub(r"\[[^\]]*\]", " ", text)
    body, had_sob, had_chuckle = _scrub_korean_emoticons(body, emotion)
    body = _ensure_trailing_punct(_pad_short_text(body))

    tags_wanted = use_v3_tags and emotion and emotion != "neutral"
    if not tags_wanted:
        return body

    # v3 audio tags. Cross-checked against ElevenLabs official tag categories
    # (Emotional States / Reactions / Cognitive Beats / Tone Cues) on the
    # v3 docs page. Tags marked OFFICIAL appear in the listed examples;
    # ADJ-VARIANT tags are emotion-name adjectives that follow the same
    # pattern as listed ones (e.g. [angry] mirrors [sorrowful], [frustrated]).
    # `None` means "no tag — let text + voice_settings carry the emotion".
    head_tags = {
        "joy":        "[happily]",         # OFFICIAL (v3 launch examples) — more energetic than [cheerfully]
        "laughter":   "[laughs]",          # OFFICIAL — Reactions
        "excitement": "[excited]",         # OFFICIAL — Emotional States
        "agreement":  None,                # no tag (low arousal positive)
        "gratitude":  None,                # no tag (Korean words carry it)
        "sadness":    "[sorrowful]",       # OFFICIAL — Emotional States
        "crying":     "[crying]",          # ADJ-VARIANT (reaction-family)
        "sulk":       "[sighs]",           # OFFICIAL — Reactions
        "apology":    "[resigned tone]",   # OFFICIAL — Cognitive Beats
        "struggle":   "[frustrated]",      # OFFICIAL — Emotional States
        "anger":      "[angry]",           # ADJ-VARIANT
        "refusal":    "[flatly]",          # OFFICIAL — Tone Cues
        "surprise":   "[surprised]",       # ADJ-VARIANT
        "fluster":    "[stammers]",        # OFFICIAL — Cognitive Beats
        "shy":        "[whispers]",        # OFFICIAL — Reactions
    }
    head = head_tags.get(emotion, "") or ""

    # When the body already carries an inline [laughs]/[crying] from the
    # ㅋㅋ/ㅠㅠ scrub, the head tag is dropped so the reaction fires where the
    # emoticon was, not also at the start of the line.
    # NOTE(review): this also drops emotional-state head tags ([happily],
    # [excited], [sorrowful], [sighs], [resigned tone]), not just the
    # duplicate [laughs]/[crying]. That presumably removes the whole-line
    # emotional prosody those tags provide — confirm this is intended.
    laugh_family = ("laughter", "joy", "excitement")
    sob_family = ("sadness", "crying", "sulk", "apology")
    if (had_chuckle and emotion in laugh_family) or (had_sob and emotion in sob_family):
        head = ""

    if not head:
        return body
    return f". {head} {body}".strip()


def _classify_el_error(exc: BaseException) -> str | None:
    """Classify an ElevenLabs call failure for synth_one_elevenlabs' retry loop.

    Returns:
        "quota"  — 401/402/403 or an unauthorized/quota/billing/payment_required
                   message; never retried.
        "429"    — rate limited (retried with a longer backoff base).
        "409"    — already_running (voice library add race on first use).
        "5xx"    — 500/502/503/504 upstream errors.
        "net"    — timeout / connection / remote-disconnected messages.
        None     — anything else; the caller re-raises it unchanged.

    The integer `status_code` attribute is preferred when the exception has
    one. Otherwise the message is scanned for whole 4xx/5xx numbers. The
    previous bare-substring test matched digits inside headers, IDs or
    timings (e.g. the "401" in "14010ms"), which misfired as quota errors
    and aborted whole batches. It also looked for " 500 " with surrounding
    spaces, which never matches messages of the form "status_code: 500, ...",
    so 5xx errors were not retried.
    """
    import re

    msg = str(exc)
    low = msg.lower()
    status = getattr(exc, "status_code", None)
    if isinstance(status, int):
        codes = {status}
    else:
        codes = {int(code) for code in re.findall(r"\b([45]\d\d)\b", msg)}

    quota_words = ("quota", "unauthorized", "payment_required", "billing")
    if codes & {401, 402, 403} or any(word in low for word in quota_words):
        return "quota"
    if 429 in codes or "rate_limit" in low or "too_many_requests" in low:
        return "429"
    if 409 in codes or "already_running" in low:
        return "409"
    if codes & {500, 502, 503, 504}:
        return "5xx"
    if any(word in low for word in ("timeout", "connection", "remote disconnected")):
        return "net"
    return None


def _write_quota_diagnostics(exc: BaseException) -> None:
    """Best-effort dump of a quota/auth exception to /tmp/tts_quota_error.txt.

    The QuotaExhausted message keeps only the first 200 characters, so the
    full text plus the exception's `status_code` / `body` attributes are
    saved here for diagnosis. A failure to write is logged at debug level
    and otherwise ignored, so it never masks the quota error itself.
    """
    try:
        Path("/tmp/tts_quota_error.txt").write_text(
            f"=== full exception ===\n{exc}\n\n"
            f"=== type ===\n{type(exc).__name__}\n\n"
            f"=== status_code attr ===\n{getattr(exc, 'status_code', 'n/a')}\n\n"
            f"=== body attr ===\n{getattr(exc, 'body', 'n/a')}\n",
            encoding='utf-8',
        )
    except Exception as dump_err:  # diagnostics only — never hide the real error
        LOG.debug(f"[tts] could not write quota diagnostics: {dump_err}")


async def synth_one_elevenlabs(text: str, out_path: Path, voice_id: str = None,
                                 model: str = None, api_key: str = None,
                                 emotion: str = None, vad: list = None,
                                 voice_seed: str = None,
                                 voice_pool: str = None):
    """Synthesize one clip with ElevenLabs, using emotion + VAD for prosody.

    Args:
        text: Line to speak. It is cleaned and, on v3 models, emotion-tagged
            by _format_text_with_emotion.
        out_path: Destination MP3. The API stream is first written to a
            sibling `<stem>.raw.mp3`, then re-encoded by ffmpeg with 180 ms
            of silence at start and end. If ffmpeg is missing or fails, the
            raw audio is moved to `out_path` unpadded.
        voice_id: Explicit voice. When None, a voice is picked from the
            emotion's pool (or `voice_pool`) with a stable seed.
        model: Model ID; defaults to DEFAULT_EL_MODEL. Audio tags are only
            added when the ID contains "v3" (case-insensitive).
        api_key: Defaults to the ELEVENLABS_API_KEY env var.
        emotion: Emotion label; selects the audio tag and the voice pool.
        vad: VAD values passed to _build_voice_settings.
        voice_seed: Stable key for the deterministic voice pick; defaults to
            `out_path.stem`.
        voice_pool: Pool override (e.g. 'neutral' for monologues, or
            'female_monologue').

    Raises:
        RuntimeError: No API key is available.
        QuotaExhausted: Auth/quota/billing failure (not retried).
        Exception: Any non-retryable error, or the last transient error once
            all 6 attempts are used, re-raised unchanged.
    """
    import random
    import subprocess

    from elevenlabs.client import ElevenLabs

    # Auto-select voice by emotion + stable seed if not explicitly overridden
    voice_id = voice_id_for_emotion(emotion, override=voice_id,
                                     seed=voice_seed or out_path.stem,
                                     pool_override=voice_pool)
    model = model or DEFAULT_EL_MODEL
    api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
    if not api_key:
        raise RuntimeError("ELEVENLABS_API_KEY env var required")

    # Detect v3 model for audio-tag support
    use_v3_tags = "v3" in model.lower()
    formatted_text = _format_text_with_emotion(text, emotion, use_v3_tags)
    voice_settings = _build_voice_settings(vad)

    def _sync():
        """Blocking API call + file writes; run in the default executor."""
        client = ElevenLabs(api_key=api_key)
        audio = client.text_to_speech.convert(
            voice_id=voice_id,
            text=formatted_text,
            model_id=model,
            # Higher quality to avoid boundary clipping at low bitrates
            output_format="mp3_44100_128",
            voice_settings=voice_settings,
        )

        # Write raw response to a sibling file (same filesystem as out_path)
        # to avoid cross-device link errors when falling back.
        raw_path = out_path.with_suffix(".raw.mp3")
        try:
            with raw_path.open("wb") as f:
                for chunk in audio:
                    f.write(chunk)
        except BaseException:
            # Don't leave a truncated stream lying around if the download dies.
            raw_path.unlink(missing_ok=True)
            raise

        # Post-process: pad 180ms silence at start + end so nothing is clipped
        # adelay = leading silence, apad pad_dur = trailing silence
        try:
            subprocess.run(
                [
                    "ffmpeg", "-y", "-loglevel", "error",
                    "-i", str(raw_path),
                    "-af", "adelay=180|180,apad=pad_dur=0.18",
                    "-c:a", "libmp3lame", "-b:a", "128k",
                    str(out_path),
                ],
                check=True,
            )
            raw_path.unlink(missing_ok=True)
        except (FileNotFoundError, subprocess.CalledProcessError) as e:
            # Fallback: keep the raw audio if ffmpeg is unavailable/failed
            LOG.warning(f"[tts] ffmpeg pad failed ({e}); saving raw audio")
            raw_path.replace(out_path)

    # Retry transient errors (409 already_running, 429 rate limit, 5xx,
    # network blips) with exponential backoff + jitter, capped at 60s.
    # Auth/quota errors hard-fail immediately with QuotaExhausted so
    # synth_all can abort the whole batch with one clear message.
    loop = asyncio.get_running_loop()
    max_retries = 6
    for attempt in range(max_retries):
        try:
            await loop.run_in_executor(None, _sync)
            return
        except Exception as e:
            kind = _classify_el_error(e)
            if kind == "quota":
                _write_quota_diagnostics(e)
                raise QuotaExhausted(f"ElevenLabs auth/quota error: {str(e)[:200]}") from e
            if kind is not None and attempt < max_retries - 1:
                base = 8.0 if kind == "429" else 2.0
                delay = min(60.0, base * (2 ** attempt) + random.uniform(0, 1.0))
                LOG.info(f"[tts] {kind} on voice {voice_id[:8]}…, retry {attempt+1}/{max_retries} in {delay:.1f}s")
                await asyncio.sleep(delay)
                continue
            raise


async def synth_one_edge(text: str, out_path: Path, voice: str = "ko-KR-SunHiNeural"):
    """Free fallback: synthesize one clip by running the `edge-tts` CLI.

    Args:
        text: Line to speak (passed as-is; no emotion tagging).
        out_path: File edge-tts writes the audio to (`--write-media`).
        voice: edge-tts voice name; defaults to a Korean neural voice.

    Raises:
        RuntimeError: edge-tts exited non-zero; includes the first 200
            characters of its stderr.
    """
    proc = await asyncio.create_subprocess_exec(
        "edge-tts",
        "--voice", voice,
        # `--text=<value>` keeps argparse from treating a line that starts
        # with "-" (e.g. "-그래?") as an unknown option.
        f"--text={text}",
        "--write-media", str(out_path),
        stdout=asyncio.subprocess.DEVNULL,
        stderr=asyncio.subprocess.PIPE,
    )
    _, err = await proc.communicate()
    if proc.returncode != 0:
        # Lenient decode so a non-UTF-8 stderr can't replace the real
        # failure with a UnicodeDecodeError.
        detail = (err or b"").decode("utf-8", errors="replace")
        raise RuntimeError(f"edge-tts failed: {detail[:200]}")


async def synth_all(
    texts: List[str],
    out_paths: List[Path],
    backend: str = "elevenlabs",
    concurrency: int = 4,
    emotions: List[str] = None,   # per-item emotion labels
    vads: List[list] = None,       # per-item VAD triples
    voice_seeds: List[str] = None, # per-item stable seed for voice pick
    voice_pools: List[str] = None, # per-item pool override (e.g. 'neutral')
    voice_ids: List[str] = None,   # per-item explicit voice ID (overrides pool/seed)
    **kwargs,
) -> List[bool]:
    """Synthesize a batch of clips concurrently; return one success flag per item.

    Items whose output file already exists and is larger than 1000 bytes are
    skipped and counted as successful, so re-running a partial batch resumes
    where it stopped. A per-item failure is logged and leaves that item's
    flag False; it does not stop the batch. A QuotaExhausted from any item
    latches an abort: workers that have not started yet return immediately,
    and a summary is printed at the end.

    Args:
        texts: Text per clip.
        out_paths: Destination file per clip; must align with `texts`.
        backend: "elevenlabs" selects ElevenLabs; any other value uses edge-tts.
        concurrency: Maximum number of clips in flight at once (>= 1).
        emotions, vads, voice_seeds, voice_pools, voice_ids: Optional per-item
            lists aligned with `texts`, forwarded only to the ElevenLabs
            backend. None entries in voice_seeds / voice_pools / voice_ids
            keep the automatic choice for that item.
        **kwargs: Forwarded unchanged to the backend call for every item.

    Returns:
        List of booleans; True where the clip was produced or already on disk.

    Raises:
        ValueError: `out_paths` or any provided per-item list differs in
            length from `texts`, or `concurrency` < 1 for a non-empty batch.
            asyncio.Semaphore(0) would block every worker forever.
    """
    n_items = len(texts)
    if len(out_paths) != n_items:
        raise ValueError(f"out_paths has {len(out_paths)} items but texts has {n_items}")
    per_item_lists = {
        "emotions": emotions,
        "vads": vads,
        "voice_seeds": voice_seeds,
        "voice_pools": voice_pools,
        "voice_ids": voice_ids,
    }
    for list_name, seq in per_item_lists.items():
        if seq is not None and len(seq) != n_items:
            raise ValueError(f"{list_name} has {len(seq)} items but texts has {n_items}")
    if n_items == 0:
        return []
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    semaphore = asyncio.Semaphore(concurrency)
    ok = [False] * n_items
    skipped = 0
    abort_reason = None  # set by the first QuotaExhausted; stops queued workers

    synth_fn = synth_one_elevenlabs if backend == "elevenlabs" else synth_one_edge

    def item_kwargs(i: int) -> dict:
        """Shared kwargs plus item *i*'s per-item fields (ElevenLabs only)."""
        per_kwargs = dict(kwargs)
        if backend != "elevenlabs":
            return per_kwargs
        # emotion/vad are forwarded even when the entry is None (same as the
        # callee's default). The voice-selection fields are forwarded only
        # when set, so a None entry keeps the automatic pool/seed choice.
        if emotions is not None:
            per_kwargs["emotion"] = emotions[i]
        if vads is not None:
            per_kwargs["vad"] = vads[i]
        for key, seq in (("voice_seed", voice_seeds),
                         ("voice_pool", voice_pools),
                         ("voice_id", voice_ids)):
            if seq is not None and seq[i] is not None:
                per_kwargs[key] = seq[i]
        return per_kwargs

    async def one(i: int):
        nonlocal skipped, abort_reason
        async with semaphore:
            if abort_reason is not None:
                return
            path = out_paths[i]
            if path.exists() and path.stat().st_size > 1000:
                ok[i] = True
                skipped += 1
                return
            try:
                await synth_fn(texts[i], path, **item_kwargs(i))
                ok[i] = True
            except QuotaExhausted as e:
                # Latch the abort so all queued workers stop firing 401s.
                abort_reason = str(e)
                LOG.error(f"[tts] QUOTA EXHAUSTED — aborting batch: {e}")
            except Exception as e:
                LOG.warning(f"[tts] fail {i}: {e}")

    await asyncio.gather(*(one(i) for i in range(n_items)))
    if skipped:
        print(f"[tts] {skipped}/{n_items} turns skipped (already on disk)")
    if abort_reason is not None:
        print(f"\n[tts] ⚠ batch aborted on quota/auth error.\n"
              f"      reason: {abort_reason[:200]}\n"
              f"      partial results saved. Re-run with a different "
              f"ELEVENLABS_API_KEY to resume — already-done turns will be skipped.")
    return ok
