#!/usr/bin/env python3
"""A/B/C experiment: compare V2-dynamics mask strategies on one scenario.

Generates 3 variants of daily_003 viewer bundles:
    A: strict 11-channel mask (brow+cheekSquint+noseSneer+eyeSquint)
    B: tiered ~19-channel mask (A + smile/frown/eyeWide/dimple at reduced α)
    C: no V2 — current baseline (compiler + LAM only)

Usage:
    python3 scripts/compiler/abc_experiment.py [--scenario daily_003]
"""
from __future__ import annotations

import argparse
import json
import subprocess
import sys
from pathlib import Path

import librosa
import numpy as np
import onnxruntime as ort

sys.path.insert(0, '/dataset/text-to-face-se/LAM_Audio2Expression')
from distillation.student_model import AudioFeatureExtractor

from scripts.compiler.expressive import compile_expressive_batch, MOUTH_CHEEK_OVERLAY_CHANNELS
from scripts.compiler.parametric import parametric_layer
from scripts.compiler.constants import (
    ARKIT_52_NAMES, LAM_WEIGHTS_SHARED, LIPSYNC_ONLY, EXPRESSION_ONLY,
    SHARED_CHANNELS,
)
from scripts.compiler.data_pipeline import merge_lam_compiler, speech_gate
from scripts.compiler.eye_motion import apply_eye_motion
from scripts.compiler.lam_wrapper import LAMWrapper
from scripts.compiler.utils import load_presets_from_json, build_synthetic_presets

# Filesystem layout. PROJECT_ROOT is the third ancestor directory of this
# file (parents[2]); every other project path is expressed relative to it.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Scenario JSONL files, searched in this order by find_scenario().
SCENARIOS = [
    PROJECT_ROOT.joinpath("data", "emotion", jsonl_name)
    for jsonl_name in (
        "seed_train_final.jsonl",
        "seed_val.jsonl",
        "seed_test.jsonl",
    )
]
AUDIO_DIR = PROJECT_ROOT.joinpath("data", "audio_preview")
PRESETS = PROJECT_ROOT.joinpath("expression_presets.json")
# Absolute, machine-specific location of the V2 ONNX model.
ONNX_V2 = '/dataset/mead-expression-training/e2f/distill/emotion_face_int8.onnx'
OUT_DIR = PROJECT_ROOT.joinpath("data", "viewer")
# Frame rate written into every exported viewer JSON.
FPS = 30

# ── V2's native ARKit ordering (MEAD config) ──
# Column order of the blendshape array produced by the V2 ONNX model.
# The position of each name IS the contract: V2_TO_ANIMA is derived from it
# via V2_NAMES.index(), so reordering, renaming or deduplicating entries
# silently remaps channels. Keep this list in sync with the model's export.
V2_NAMES = [
    'eyeBlinkLeft', 'eyeBlinkRight',
    'eyeLookDownLeft', 'eyeLookDownRight',
    'eyeLookInLeft', 'eyeLookInRight',
    'eyeLookOutLeft', 'eyeLookOutRight',
    'eyeLookUpLeft', 'eyeLookUpRight',
    'eyeSquintLeft', 'eyeSquintRight',
    'eyeWideLeft', 'eyeWideRight',
    'jawForward', 'jawLeft', 'jawRight', 'jawOpen',
    'mouthClose', 'mouthFunnel', 'mouthPucker',
    'mouthLeft', 'mouthRight',
    'mouthSmileLeft', 'mouthSmileRight',
    'mouthFrownLeft', 'mouthFrownRight',
    'mouthDimpleLeft', 'mouthDimpleRight',
    'mouthStretchLeft', 'mouthStretchRight',
    'mouthRollLower', 'mouthRollUpper',
    'mouthShrugLower', 'mouthShrugUpper',
    'mouthPressLeft', 'mouthPressRight',
    'mouthLowerDownLeft', 'mouthLowerDownRight',
    'mouthUpperUpLeft', 'mouthUpperUpRight',
    'browDownLeft', 'browDownRight',
    'browInnerUp',
    'browOuterUpLeft', 'browOuterUpRight',
    'cheekPuff',
    'cheekSquintLeft', 'cheekSquintRight',
    'noseSneerLeft', 'noseSneerRight',
    'tongueOut',
]
# Import-time sanity check on the table length (stripped under `python -O`).
assert len(V2_NAMES) == 52

# Gather index that reorders V2 columns into AnimaSync order: entry k is the
# V2 column carrying ARKIT_52_NAMES[k], so `v2_frames[:, V2_TO_ANIMA]` is laid
# out like ARKIT_52_NAMES. Raises ValueError at import time if an AnimaSync
# name is missing from V2_NAMES.
V2_TO_ANIMA = np.asarray(
    list(map(V2_NAMES.index, ARKIT_52_NAMES)),
    dtype=np.int64,
)

# ── Sub-emotion → 5-base mapping (for V2 conditioning) ──
# Slot of each base emotion in the 5-wide one-hot vector built by run_v2().
EMO_TO_IDX = {
    base: slot
    for slot, base in enumerate(
        ('neutral', 'joy', 'anger', 'sadness', 'surprise')
    )
}
# Every scenario emotion label collapses onto one of the five base families;
# labels not listed here are treated as 'neutral' by the lookups in this file.
SUB_TO_BASE = {
    sub: base
    for base, members in (
        ('neutral', ('neutral',)),
        ('joy', ('joy', 'laughter', 'excitement', 'agreement', 'gratitude')),
        ('sadness', ('sadness', 'crying', 'sulk', 'apology', 'struggle')),
        ('anger', ('anger', 'refusal')),
        ('surprise', ('surprise', 'fluster', 'shy')),
    )
    for sub in members
}

# ── Channel masks by NAME (AnimaSync ordering) ──
# Variant A: the strict set of channels that may receive V2 dynamics.
MASK_A_STRICT = [
    'browInnerUp',
    'browDownLeft', 'browDownRight',
    'browOuterUpLeft', 'browOuterUpRight',
    'cheekSquintLeft', 'cheekSquintRight',
    'noseSneerLeft', 'noseSneerRight',
    'eyeSquintLeft', 'eyeSquintRight',
]

# Variant B: channel name → base α (0.5 = tier-1, 0.25 = tier-2).
# Tier-1 lists the same channels as MASK_A_STRICT; tier-2 adds
# smile / frown / eyeWide / dimple at a reduced weight.
MASK_B_TIERED = {
    **dict.fromkeys(
        (
            'browInnerUp',
            'browDownLeft', 'browDownRight',
            'browOuterUpLeft', 'browOuterUpRight',
            'cheekSquintLeft', 'cheekSquintRight',
            'noseSneerLeft', 'noseSneerRight',
            'eyeSquintLeft', 'eyeSquintRight',
        ),
        0.5,
    ),
    **dict.fromkeys(
        (
            'mouthSmileLeft', 'mouthSmileRight',
            'mouthFrownLeft', 'mouthFrownRight',
            'eyeWideLeft', 'eyeWideRight',
            'mouthDimpleLeft', 'mouthDimpleRight',
        ),
        0.25,
    ),
}

# Base α for variant A channels that have no ALPHA_BOOST entry.
ALPHA_STRICT = 0.5

# Per-channel α override. apply_v2_dynamics() looks a channel up here first
# and only falls back to ALPHA_STRICT / the MASK_B_TIERED value when the
# name is absent, so these values REPLACE the base α rather than scale it.
# Brow channels pass V2's motion amplitude through at α=1.0 on top of the
# user's preset anchor; squint and sneer channels are amplified to α=1.5.
# The preset envelope cap applied afterwards still bounds the peaks.
ALPHA_BOOST = {
    **dict.fromkeys(
        (
            'browInnerUp',
            'browDownLeft', 'browDownRight',
            'browOuterUpLeft', 'browOuterUpRight',
        ),
        1.0,
    ),
    **dict.fromkeys(
        (
            'eyeSquintLeft', 'eyeSquintRight',
            'cheekSquintLeft', 'cheekSquintRight',
            'noseSneerLeft', 'noseSneerRight',
        ),
        1.5,
    ),
}


def get_preset_envelope(emotion, presets):
    """Return the per-channel (lo, hi) range spanned by an emotion's presets.

    Looks up the `<emotion>_L1`, `<emotion>_L3` and `<emotion>_L5` entries of
    `presets` (whichever exist) and takes the element-wise minimum and
    maximum of their 'bs' vectors. `lo` starts from 1 and `hi` from 0, so
    `lo` never exceeds 1 and `hi` never drops below 0.

    When none of the three keys exists the envelope is wide open:
    all-zero `lo` and all-one `hi`.

    Option E correction: on every channel in MOUTH_CHEEK_OVERLAY_CHANNELS,
    `hi` is raised to at least the larger of parametric_layer's outputs at
    V=+1 and V=-1 (both with the remaining two inputs at 1.0). Without this
    the cap would clamp the parametric overlay back down on emotions whose
    authored mouth shape is 0 (e.g. surprise).

    Returns:
        (lo, hi): two float32 arrays of shape (52,).
    """
    lo = np.ones(52, dtype=np.float32)
    hi = np.zeros(52, dtype=np.float32)
    anchors_seen = 0
    for level in (1, 3, 5):
        key = f"{emotion}_L{level}"
        if key not in presets:
            continue
        pose = np.asarray(presets[key]['bs'], dtype=np.float32)
        lo, hi = np.minimum(lo, pose), np.maximum(hi, pose)
        anchors_seen += 1

    if anchors_seen == 0:
        # Unknown emotion family: no authored range to respect.
        return np.zeros(52, dtype=np.float32), np.ones(52, dtype=np.float32)

    # Peak overlay output over the two valence extremes.
    overlay_peak = np.maximum(
        parametric_layer(np.array([+1.0, 1.0, 1.0], dtype=np.float32)),
        parametric_layer(np.array([-1.0, 1.0, 1.0], dtype=np.float32)),
    )
    for ch in MOUTH_CHEEK_OVERLAY_CHANNELS:
        peak = float(overlay_peak[ch])
        if peak > hi[ch]:
            hi[ch] = peak

    return lo, hi

# ── Helpers ──

def name_to_idx(names, ordering=ARKIT_52_NAMES):
    """Map blendshape names to their positions in `ordering` (int64 array).

    Raises ValueError (from `ordering.index`) when a name is not present.
    """
    positions = (ordering.index(channel_name) for channel_name in names)
    return np.fromiter(positions, dtype=np.int64)


def cross_emotion_compile(vads: np.ndarray, presets: dict,
                          sigma: float = 0.4) -> np.ndarray:
    """RBF over ALL preset anchors based on VAD distance (cross-emotion blend).

    Unlike `compile_expressive_batch`, which interpolates only within an
    emotion family using L1/L3/L5 anchors, this blends across every preset
    passed in. Frames whose VAD lies between, e.g., crying and sulk anchors
    get a weighted mix of both shapes.

    Weights are a normalised Gaussian kernel of the squared VAD distance.
    The smallest distance of each frame is subtracted before exponentiating
    (softmax-style shift): the normalised weights are mathematically
    unchanged, but the nearest anchor always has weight exp(0) = 1 before
    normalisation, so the sum can never underflow. Previously a frame far
    from every anchor (or a small sigma) made all weights underflow,
    normalisation was skipped, and the pose collapsed to ~0 instead of
    snapping to the nearest anchor.

    Args:
        vads: (T, 3) per-frame VAD trajectory.
        presets: mapping whose values each carry a 'vad' vector and a 'bs'
            blendshape vector (all 'bs' vectors must share one length).
        sigma: RBF kernel width. Smaller = nearest anchor dominates,
            larger = broader blend.

    Returns:
        (T, C) float32 pose array clipped to [0, 1], where C is the length
        of the preset 'bs' vectors (52 for the project presets).

    Raises:
        ValueError: if `presets` is empty.
        ZeroDivisionError: if `sigma` is 0.
    """
    if not presets:
        raise ValueError("cross_emotion_compile needs at least one preset anchor")

    anchor_vads = np.asarray([p['vad'] for p in presets.values()],
                             dtype=np.float32)  # (N, 3)
    anchor_bs = np.asarray([p['bs'] for p in presets.values()],
                           dtype=np.float32)    # (N, C)
    vads = np.asarray(vads)
    inv_2sig2 = 1.0 / (2.0 * sigma ** 2)

    # Squared distance from every frame to every anchor: (T, N).
    diff = vads[:, None, :] - anchor_vads[None, :, :]
    d2 = np.sum(diff ** 2, axis=2)
    # Shift so each frame's nearest anchor sits at distance 0 (see docstring).
    d2 = d2 - d2.min(axis=1, keepdims=True)
    w = np.exp(-d2 * inv_2sig2)
    w = w / w.sum(axis=1, keepdims=True)  # row sum is always >= 1

    out = w @ anchor_bs
    return np.clip(out, 0.0, 1.0).astype(np.float32)


def highpass_per_channel(x: np.ndarray, sigma_frames: float = 15.0) -> np.ndarray:
    """Strip slow drift by subtracting a Gaussian-blurred copy of the input.

    The blur runs along axis 0 with edge values repeated at the boundaries
    ('nearest'); what remains is the residual around that slow trend.
    sigma_frames=15 corresponds to a 0.5 s Gaussian sigma at 30 fps, which
    keeps sub-second prosody.
    """
    from scipy.ndimage import gaussian_filter1d
    slow_trend = gaussian_filter1d(
        x, sigma=sigma_frames, axis=0, mode='nearest',
    )
    residual = x - slow_trend
    return residual


def bandpass_prosody(x: np.ndarray, hp_sigma: float = 15.0,
                     lp_sigma: float = 4.0) -> np.ndarray:
    """Isolate prosody-rate motion: drop slow drift, then blur away jitter.

    Stage 1 is highpass_per_channel() with `hp_sigma`; stage 2 is a Gaussian
    low-pass with `lp_sigma` along axis 0 ('nearest' edge handling).
    lp_sigma=4 is a ~133 ms sigma at 30 fps, intended to smooth phoneme-rate
    jitter while preserving syllable-rate prosody (~150-300 ms).
    """
    from scipy.ndimage import gaussian_filter1d
    drift_free = highpass_per_channel(x, sigma_frames=hp_sigma)
    band = gaussian_filter1d(
        drift_free, sigma=lp_sigma, axis=0, mode='nearest',
    )
    return band


def _one_euro_filter(signal: np.ndarray, fps: float = 30.0,
                     min_cutoff: float = 1.5, beta: float = 0.5,
                     d_cutoff: float = 1.0) -> np.ndarray:
    """One-Euro filter (ported from mead training). Peak-preserving low-pass.

    An exponential smoother whose cutoff rises with the (smoothed) speed of
    the signal: slow segments are smoothed at `min_cutoff`, fast segments
    at `min_cutoff + beta * |speed|`, so peaks are not flattened.

    Args:
        signal: 1-D sequence of samples taken at `fps`.
        fps: sampling rate in frames per second.
        min_cutoff: cutoff frequency used when the signal is not moving.
        beta: how strongly speed raises the cutoff.
        d_cutoff: cutoff frequency of the low-pass applied to the derivative.

    Returns:
        float32 array with the same length as `signal`. An empty input
        yields an empty array (it used to raise IndexError).
    """
    n_samples = len(signal)
    out = np.zeros(n_samples, dtype=np.float32)
    if n_samples == 0:
        return out

    te = 1.0 / fps

    def sf(cutoff):
        # Smoothing factor of a first-order low-pass at `cutoff` Hz.
        r = 2.0 * np.pi * cutoff * te
        return r / (r + 1.0)

    # The derivative smoothing factor depends only on constants, so it is
    # computed once instead of on every sample.
    a_d = sf(d_cutoff)

    out[0] = signal[0]
    dx_prev = 0.0
    for i in range(1, n_samples):
        # Speed estimate relative to the previous *filtered* sample.
        dx = (signal[i] - out[i - 1]) / te
        dx_hat = a_d * dx + (1.0 - a_d) * dx_prev
        dx_prev = dx_hat
        a = sf(min_cutoff + beta * abs(dx_hat))
        out[i] = a * signal[i] + (1.0 - a) * out[i - 1]
    return out


# Brow channels — high-magnitude swings between extreme emotions (sad↔anger,
# surprise↔anger) look unnatural under straight-line LERP. We route them
# through a brief neutral pause when a big swing is detected.
# NOTE(review): these are hard-coded column indices; they are only correct
# while ARKIT_52_NAMES starts with the five brow names in this order — confirm
# against scripts.compiler.constants if that table ever changes.
BROW_CHANNELS = [0, 1, 2, 3, 4]  # browDownL, browDownR, browInnerUp, browOuterUpL, browOuterUpR
# Only route through neutral when a brow channel CHANGES a lot across the
# boundary (e.g., sad→anger flips browInnerUp from 0.7→0.0 AND browDown from
# 0.0→0.5). Surprise→sad has both endpoints raised, so |delta| is small and
# we just LERP normally.
# Threshold on |prev − next| for a single brow channel (strictly greater than).
BROW_SWING_DELTA = 0.40
# Fraction of the crossfade window, centred on its midpoint, during which a
# qualifying brow channel is held at 0.
NEUTRAL_PAUSE_FRACTION = 0.20  # shorter pause — was 0.35, looked artificial


def _brow_pass_through_zero(prev_v: float, next_v: float, t: float) -> float:
    """Value of one brow channel at position t ∈ [0, 1] of a crossfade.

    Instead of interpolating prev_v → next_v directly, the channel relaxes
    to zero, rests there, then rises to the new value. With
    PAUSE = NEUTRAL_PAUSE_FRACTION:
      t < 0.5 − PAUSE/2          : prev_v → 0 (cosine ramp-down)
      0.5 − PAUSE/2 ≤ t < 0.5 + PAUSE/2 : held at 0 (neutral pause)
      t ≥ 0.5 + PAUSE/2          : 0 → next_v (cosine ramp-up)
    """
    half_pause = NEUTRAL_PAUSE_FRACTION / 2
    ramp_span = 0.5 - half_pause   # duration of each ramp, in units of t
    pause_end = 0.5 + half_pause

    if t < ramp_span:
        # Ramp-down phase: weight on prev_v falls from 1 to 0.
        phase = t / ramp_span if ramp_span > 0 else 1.0
        rise = 0.5 * (1.0 - np.cos(phase * np.pi))
        return (1.0 - rise) * prev_v

    if t < pause_end:
        # Neutral pause.
        return 0.0

    # Ramp-up phase: weight on next_v climbs from 0 to 1.
    phase = (t - pause_end) / ramp_span if ramp_span > 0 else 1.0
    rise = 0.5 * (1.0 - np.cos(phase * np.pi))
    return rise * next_v


def crossfade_turn_boundaries(comp_stack, turn_lengths, fade_frames: int = 8):
    """Concatenate per-turn poses and ease across every turn boundary.

    Around each boundary a window of up to `fade_frames` frames (half on
    each side, clipped to the sequence) is overwritten with a cosine-eased
    blend between the last frame of the outgoing turn and the first frame
    of the incoming turn. Cosine easing gives the ramp-up/ramp-down of a
    real expression shift rather than constant velocity. Whatever the turns
    contained inside the window is replaced by that blend.

    Brow channels (BROW_CHANNELS) whose value changes by more than
    BROW_SWING_DELTA across the boundary do not blend directly: they follow
    _brow_pass_through_zero(), relaxing to neutral, pausing, then rising to
    the new value (Option B). Smaller brow changes blend like every other
    channel, so e.g. "brows stayed up" is preserved.

    Args:
        comp_stack: list of (T_i, 52) arrays, one per turn.
        turn_lengths: parallel list of T_i.
        fade_frames: total window length in frames; windows of one frame or
            less are skipped.

    Returns:
        Concatenated (total_T, 52) float32 array.
    """
    blended = np.concatenate(comp_stack, axis=0).astype(np.float32)
    total_frames = blended.shape[0]
    half_window = fade_frames // 2

    boundary = 0
    for turn_no, n_frames in enumerate(turn_lengths[:-1]):
        boundary += n_frames
        pose_out = comp_stack[turn_no][-1]       # outgoing turn, final frame
        pose_in = comp_stack[turn_no + 1][0]     # incoming turn, first frame
        win_lo = max(0, boundary - half_window)
        win_hi = min(total_frames, boundary + half_window)
        span = win_hi - win_lo
        if span <= 1:
            continue

        # Brow channels with a large enough swing take the neutral detour.
        swinging = [
            ch for ch in BROW_CHANNELS
            if abs(float(pose_out[ch]) - float(pose_in[ch])) > BROW_SWING_DELTA
        ]

        for step in range(span):
            frame = win_lo + step
            t = step / (span - 1)
            ease = 0.5 * (1.0 - np.cos(t * np.pi))
            # Every channel gets the eased blend first ...
            blended[frame] = (1.0 - ease) * pose_out + ease * pose_in
            # ... then the qualifying brow channels are overridden.
            for ch in swinging:
                blended[frame, ch] = _brow_pass_through_zero(
                    pose_out[ch], pose_in[ch], t)
    return blended


def smooth_expression_channels(target: np.ndarray,
                                min_cutoff: float = 1.5,
                                beta: float = 0.5) -> np.ndarray:
    """One-Euro-filter the expression channels of a (T, C) pose sequence.

    Filtered columns are EXPRESSION_ONLY plus SHARED_CHANNELS without
    channel 24 (jawOpen), which is left untouched because it is
    lipsync-critical. The input is not modified; the result is clipped to
    [0, 1] and returned as float32.
    """
    smoothed = target.copy()
    shared_without_jaw = set(SHARED_CHANNELS).difference({24})
    channels = sorted(shared_without_jaw.union(EXPRESSION_ONLY))
    for ch in channels:
        column = smoothed[:, ch]
        smoothed[:, ch] = _one_euro_filter(
            column, min_cutoff=min_cutoff, beta=beta)
    bounded = np.clip(smoothed, 0.0, 1.0)
    return bounded.astype(np.float32)


def find_scenario(sid):
    """Return the first scenario record whose 'scenario_id' equals `sid`.

    Searches the JSONL files listed in SCENARIOS in order, skipping files
    that do not exist. Files are read as UTF-8 (matching how this script
    writes its own JSON). Blank lines and records that are not objects or
    have no 'scenario_id' are skipped rather than aborting the search.

    Returns:
        The parsed record (dict), or None when no file contains `sid`.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    for path in SCENARIOS:
        if not path.exists():
            continue
        with path.open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank separator/trailing lines
                record = json.loads(line)
                if (isinstance(record, dict)
                        and 'scenario_id' in record
                        and record['scenario_id'] == sid):
                    return record
    return None


def find_audio(sid, ti, emo, min_bytes=1000):
    """Locate the preview mp3 for turn `ti` of scenario `sid`.

    Prefers the exact `<sid>_t<ti>_<emo>.mp3` in AUDIO_DIR. If that file is
    missing or too small, falls back to any `<sid>_t<ti>_*.mp3`; candidates
    are sorted so the fallback is the same on every run (glob results come
    back in no guaranteed order).

    Args:
        sid: scenario id.
        ti: turn index.
        emo: emotion label used in the preferred file name.
        min_bytes: files of this size or smaller are treated as unusable.

    Returns:
        Path of the chosen file, or None when no usable file exists.
    """
    exact = AUDIO_DIR / f"{sid}_t{ti}_{emo}.mp3"
    if exact.exists() and exact.stat().st_size > min_bytes:
        return exact
    candidates = sorted(
        m for m in AUDIO_DIR.glob(f"{sid}_t{ti}_*.mp3")
        if m.stat().st_size > min_bytes
    )
    return candidates[0] if candidates else None


def run_v2(sess, feat_extractor, wav, emotion_name):
    """Run the V2 ONNX model on one waveform; return AnimaSync-ordered frames.

    The emotion label is collapsed to its base family via SUB_TO_BASE
    (unknown labels → 'neutral') and fed as a per-frame one-hot of width 5.
    For 'neutral' the conditioning is left all-zero, i.e. slot 0 of
    EMO_TO_IDX is never set.
    NOTE(review): confirm that an all-zero vector is the neutral
    conditioning the model was trained with.

    Args:
        sess: session object exposing `run(None, feeds)`.
        feat_extractor: object exposing `extract(wav)`; its first output
            axis is taken as the frame count.
        wav: audio samples handed to `feat_extractor.extract`.
        emotion_name: scenario emotion label.

    Returns:
        float32 array whose columns have been reordered from V2 order to
        AnimaSync order with V2_TO_ANIMA.
    """
    features = feat_extractor.extract(wav).astype(np.float32)
    n_frames = features.shape[0]

    family = SUB_TO_BASE.get(emotion_name, 'neutral')
    conditioning = np.zeros((1, n_frames, 5), dtype=np.float32)
    if family != 'neutral':
        conditioning[0, :, EMO_TO_IDX[family]] = 1.0

    feeds = {
        'features': features[None, ...],   # add the batch axis
        'emotion': conditioning,
    }
    outputs = sess.run(None, feeds)
    v2_frames = outputs[0][0]              # first output, batch item 0
    return v2_frames[:, V2_TO_ANIMA].astype(np.float32)


def apply_v2_dynamics(comp_bs, v2_bs, variant, envelope_lo=None, envelope_hi=None):
    """Layer band-passed V2 motion onto the compiled pose for one variant.

    Variants:
        'A': channels in MASK_A_STRICT, base α = ALPHA_STRICT.
        'B': channels in MASK_B_TIERED, base α = the mapped value.
        'C': baseline; `comp_bs` is returned as-is (same object, no copy,
             no clipping).

    For every masked channel the effective α is ALPHA_BOOST[name] when the
    name has an entry there, otherwise the base α; the boost REPLACES the
    base value, it is not multiplied with it. The channel becomes
    `comp_bs + α · bandpass_prosody(v2_bs)`.

    If `envelope_hi` is given, the result is capped at it, but only where
    the static anchor has headroom (envelope_hi − comp_bs > 1e-4). Where the
    anchor already sits at the ceiling (e.g. anger's browDown at preset L5),
    capping would chop the positive half of the zero-mean V2 motion while
    keeping the negative half, shifting the channel downward and weakening
    the emotion; those samples are left to the final [0, 1] clip so the
    motion stays symmetric around the anchor.

    `envelope_lo` is accepted for symmetry with get_preset_envelope() but is
    not used.

    Returns:
        float32 array shaped like `comp_bs`, clipped to [0, 1] (variants A
        and B); `comp_bs` itself for variant C.

    Raises:
        ValueError: for any variant other than 'A', 'B' or 'C'.
    """
    if variant == 'C':
        return comp_bs  # no V2

    prosody = bandpass_prosody(v2_bs, hp_sigma=15.0, lp_sigma=4.0)  # (T, 52)
    out = comp_bs.copy()

    # Channel name → base α for the requested variant.
    if variant == 'A':
        base_alphas = {name: ALPHA_STRICT for name in MASK_A_STRICT}
    elif variant == 'B':
        base_alphas = MASK_B_TIERED
    else:
        raise ValueError(f"unknown variant: {variant}")

    for name, base_alpha in base_alphas.items():
        ch = ARKIT_52_NAMES.index(name)
        gain = ALPHA_BOOST.get(name, base_alpha)
        out[:, ch] = comp_bs[:, ch] + gain * prosody[:, ch]

    if envelope_hi is not None:
        ceiling = envelope_hi[None, :]
        has_headroom = (ceiling - comp_bs) > 1e-4
        out = np.where(has_headroom, np.minimum(out, ceiling), out)

    return np.clip(out, 0.0, 1.0).astype(np.float32)


def concat_audio(paths, out_mp3):
    """Concatenate audio files into `out_mp3` with ffmpeg's concat demuxer.

    Writes a temporary `<out_mp3 stem>.concat.txt` list next to the output,
    runs ffmpeg (re-encoding to 128 kbit/s MP3 via libmp3lame), and removes
    the list file afterwards, including when ffmpeg fails.

    Args:
        paths: iterable of pathlib.Path inputs, in playback order.
        out_mp3: pathlib.Path of the file to create (overwritten if present).

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    list_file = out_mp3.with_suffix(".concat.txt")
    try:
        with list_file.open("w", encoding="utf-8") as f:
            for p in paths:
                # Inside single quotes the concat list has no escape
                # character; a literal ' must be written as '\'' (close the
                # quote, escaped quote, reopen).
                quoted = str(p.resolve()).replace("'", "'\\''")
                f.write(f"file '{quoted}'\n")
        subprocess.run([
            "ffmpeg", "-y", "-loglevel", "error",
            "-f", "concat", "-safe", "0", "-i", str(list_file),
            "-c:a", "libmp3lame", "-b:a", "128k", str(out_mp3),
        ], check=True)
    finally:
        # Do not leave the list file behind when ffmpeg (or the write) fails.
        if list_file.exists():
            list_file.unlink()


def export_viewer_json(target, scen, turns_meta, sid_out, num_frames):
    """Write one viewer bundle to OUT_DIR/<sid_out>.json and return its path.

    The bundle carries the output scenario id, FPS, frame count, channel
    names (ARKIT_52_NAMES), per-turn metadata and the blendshape frames
    rounded to 4 decimals. The file is UTF-8 with non-ASCII text kept
    verbatim. `scen` is accepted but not used.
    """
    destination = OUT_DIR / f"{sid_out}.json"
    frames = np.round(target, 4).tolist()
    bundle = dict(
        scenario_id=sid_out,
        fps=FPS,
        num_frames=int(num_frames),
        names=ARKIT_52_NAMES,
        turns=turns_meta,
        blendshapes=frames,
    )
    with destination.open("w", encoding="utf-8") as handle:
        json.dump(bundle, handle, ensure_ascii=False)
    return destination


def update_manifest(base_id: str, variants: list, turns_meta: list, scen: dict):
    """Insert or replace this scenario's entry in OUT_DIR/manifest.json.

    The manifest is what the player uses to list scenarios. Any existing
    entry with the same 'base' is dropped and a fresh one is appended, so
    the most recently rendered scenario is last. Other top-level keys and
    other entries are preserved.

    An unreadable or malformed manifest (I/O error, invalid JSON, top level
    not an object, 'scenarios' not a list) is replaced by a fresh one; a
    warning is printed to stderr when the file could not be read or parsed,
    instead of failing silently. The file is read and written as UTF-8
    because it stores non-ASCII text verbatim.

    Args:
        base_id: output id shared by all variants of this render.
        variants: variant labels available for `base_id`.
        turns_meta: per-turn metadata dicts ('text', 'emotion', ...).
        scen: source scenario record; only 'scenario_id' is read.
    """
    manifest_path = OUT_DIR / "manifest.json"
    manifest = None
    if manifest_path.exists():
        try:
            manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"[manifest] WARN: could not read {manifest_path} ({err}); "
                  f"starting a fresh manifest", file=sys.stderr)
    if not isinstance(manifest, dict):
        manifest = {}
    existing = manifest.get("scenarios")
    if not isinstance(existing, list):
        existing = []

    # Drop existing entry with same base_id, then append fresh
    manifest["scenarios"] = [
        s for s in existing
        if not (isinstance(s, dict) and s.get("base") == base_id)
    ]

    # Build a short label from first turn
    first_turn = turns_meta[0] if turns_meta else {}
    text_preview = (first_turn.get("text") or "")[:50]
    emotions_in_turn = [t.get("emotion", "?") for t in turns_meta]
    manifest["scenarios"].append({
        "base": base_id,
        "variants": variants,
        "scenario_id": scen.get("scenario_id", base_id),
        "n_turns": len(turns_meta),
        "emotions": emotions_in_turn,
        "text_preview": text_preview,
    })
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )


# ── Main ──

def main() -> None:
    """Render the A/B/C variants of one scenario into viewer bundles.

    Pipeline, in order:
      1. Parse CLI options and locate the scenario record.
      2. Load LAM, the V2 ONNX session, the audio feature extractor and the
         expression presets (synthetic presets if the JSON is missing).
      3. Pass 1: per turn, load audio and run LAM, V2 and the speech gate.
      4. Optional per-turn adjustments: persistence pose scale, causal VAD
         damping, Gaussian smoothing of the per-frame VAD trajectory.
      5. Pass 2: compile each turn's pose, optionally blend the
         cross-emotion pose, apply persistence scaling / brow floor, then
         build the A, B and C variants with apply_v2_dynamics().
      6. Per variant: crossfade turn boundaries, merge with LAM, smooth
         expression channels, add eye motion.
      7. Write the shared mp3, one JSON + mp3 symlink per variant, and the
         manifest entry.

    Exits via SystemExit when the scenario is not found or no turn could be
    processed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--scenario", default="daily_003")
    ap.add_argument("--turn", type=int, default=None,
                    help="If set, process only this single turn index")
    ap.add_argument("--option-e-intensity", type=float, default=1.0,
                    help="α scalar for Option E parametric mouth/cheek overlay "
                         "(1.0=full, 0.7=−30%%, 0.0=disabled)")
    ap.add_argument("--fade-frames", type=int, default=None,
                    help="Override turn-boundary crossfade duration in frames "
                         "(default: 48 for monologues ≥3 turns, 8 otherwise)")
    ap.add_argument("--vad-smooth-sigma", type=float, default=15.0,
                    help="Gaussian sigma (in frames) for cross-turn VAD trajectory "
                         "smoothing. 0 = disabled (per-turn step). 15 ≈ 0.5s ramp "
                         "around boundaries. NOTE: symmetric kernel — offline use only.")
    ap.add_argument("--vad-damp-gamma", type=float, default=0.0,
                    help="Causal monologue VAD damping. Each turn's VAD becomes "
                         "γ·raw + (1−γ)·running_mean(past_turns). 0=disabled, "
                         "0.6=moderate, 0.4=strong. First turn always unchanged. "
                         "Single-turn scenarios bypassed entirely. Real-time safe.")
    ap.add_argument("--vad-damp-beta", type=float, default=0.7,
                    help="Running-mean update rate for causal damping. "
                         "running_mean = β·old + (1−β)·new_turn. 0.7 = slow drift.")
    ap.add_argument("--persistence-damping", type=float, default=1.0,
                    help="Pose-level scale toward zero for multi-turn fleeting "
                         "emotions. Default 1.0 (off). When <1.0, isolated "
                         "turns have comp_bs · ps; paired turns midpoint; "
                         "sustained/single-turn unaffected. Pair with "
                         "--brow-v2-floor to keep V2 prosody jitter intact "
                         "(otherwise clipping at 0 squashes V2 by ~50%%). "
                         "Recommended 0.5.")
    ap.add_argument("--brow-v2-floor", type=float, default=0.0,
                    help="Minimum value held on brow channels (0-4) AFTER "
                         "pose damping, before V2 dynamics is added. Gives V2 "
                         "negative prosody excursions room to oscillate "
                         "without clipping at zero. Recommended 0.06-0.10 "
                         "when --persistence-damping is on. Adds a small "
                         "constant brow lift even when authored value is 0.")
    ap.add_argument("--out-suffix", default="",
                    help="Suffix appended to output sid so experimental "
                         "renders don't overwrite the canonical scenario.")
    ap.add_argument("--blink-interval", type=float, default=4.0,
                    help="Mean seconds between blinks (Poisson). Default 4.0. "
                         "Lower = more frequent. Real conversational blink rate "
                         "is ~3-4s.")
    ap.add_argument("--cross-emotion-weight", type=float, default=0.0,
                    help="Weight (0-1) for cross-emotion VAD-driven pose. "
                         "Default 0.0 (pure within-emotion authored anchor). "
                         "When >0, blends a VAD-distance-weighted RBF across "
                         "ALL preset anchors (regardless of emotion family) "
                         "with the current authored within-emotion pose. "
                         "Final = (1-w)·authored + w·cross_emotion. Recommended "
                         "0.3 for 'authored with a bit of cross-emotion bleed'.")
    ap.add_argument("--cross-emotion-sigma", type=float, default=0.4,
                    help="Gaussian RBF sigma for cross-emotion VAD distance "
                         "weighting. Smaller = tighter (only nearest anchors "
                         "matter), larger = broader blend. Default 0.4.")
    args = ap.parse_args()
    sid = args.scenario

    OUT_DIR.mkdir(parents=True, exist_ok=True)

    scen = find_scenario(sid)
    if scen is None:
        raise SystemExit(f"scenario not found: {sid}")

    # Load models
    print("[setup] loading LAM...")
    lam = LAMWrapper()
    print("[setup] loading V2 ONNX...")
    v2_sess = ort.InferenceSession(ONNX_V2, providers=['CPUExecutionProvider'])
    v2_feat = AudioFeatureExtractor()
    print(f"[setup] loading presets: {PRESETS}")
    if PRESETS.exists():
        presets = load_presets_from_json(PRESETS)
        print(f"  loaded {len(presets)} presets")
    else:
        # Fall back to generated presets so the script still runs without
        # the authored JSON.
        presets = build_synthetic_presets()
        print(f"  WARN: no presets JSON, using synthetic ({len(presets)})")

    # Per-variant accumulator
    variants = ['A', 'B', 'C']
    targets = {v: [] for v in variants}      # replaced by the final (T, 52) array per variant
    comp_stacks = {v: [] for v in variants}  # per-turn comp_bs (for crossfade)
    lam_stack = []
    gate_stack = []
    turn_lengths = []
    audio_paths = []
    turns_meta = []

    # ── Pass 1: collect per-turn audio + LAM + V2 + gate (defer compile) ──
    collected = []
    for ti, turn in enumerate(scen['turns']):
        if args.turn is not None and ti != args.turn:
            continue
        if not turn.get('text', '').strip():
            continue
        emo = turn.get('emotion', 'neutral')
        vad = turn.get('vad', [0, 0, 0])
        ap_ = find_audio(sid, ti, emo)
        if ap_ is None:
            print(f"  [t{ti}] no audio, skip")
            continue

        wav, sr = librosa.load(str(ap_), sr=16000, mono=True)
        # Skip clips shorter than 0.1 s.
        if len(wav) < 16000 * 0.1:
            continue

        print(f"  [t{ti}] {emo:10s}  wav={len(wav)/sr:.2f}s")

        lam_bs = lam.infer_audio(ap_)
        # LAM's frame count is the reference length for this turn.
        T = lam_bs.shape[0]
        v2_bs = run_v2(v2_sess, v2_feat, wav, emo)
        # Align V2 to LAM: truncate if longer, repeat the last frame if shorter.
        if v2_bs.shape[0] > T:
            v2_bs = v2_bs[:T]
        elif v2_bs.shape[0] < T:
            pad = T - v2_bs.shape[0]
            v2_bs = np.concatenate([v2_bs, np.tile(v2_bs[-1:], (pad, 1))], axis=0)
        gate = speech_gate(lam_bs)

        collected.append({
            'turn_idx': ti, 'emotion': emo, 'vad': vad, 'ap': ap_,
            'lam_bs': lam_bs, 'v2_bs': v2_bs, 'gate': gate, 'T': T,
            'text': turn.get('text'), 'speaker': turn.get('speaker'),
            'persist_scale': 1.0,
        })

    if not collected:
        raise SystemExit("no turns processed")

    # ── Persistence-based pose damping (multi-turn fleeting-emotion damping).
    # Assigns each turn a `persist_scale` from the length of the run of
    # consecutive turns that share its base emotion: a run of 1 gets
    # `fleeting`, a run of 2 gets the midpoint between `fleeting` and 1.0,
    # longer runs get 1.0. The scale is applied to the compiled pose in
    # Pass 2 (the turn's VAD is not changed here). Real performance does
    # this naturally — a fleeting mid-narrative emotion is gentler than a
    # sustained one. Single-turn scenarios are untouched. Same-base-
    # throughout monologues are also untouched (the emotion owns the
    # scenario → full commitment).
    if args.persistence_damping < 1.0 and len(collected) > 1:
        fleeting = float(args.persistence_damping)
        paired = 0.5 * (fleeting + 1.0)
        bases = [SUB_TO_BASE.get(c['emotion'], 'neutral') for c in collected]
        if len(set(bases)) == 1:
            print(f"[persist-damp] all turns share base '{bases[0]}'; skipped")
        else:
            # run_len[k] = length of the run of equal bases containing turn k.
            n = len(bases)
            run_len = [0] * n
            i = 0
            while i < n:
                j = i
                while j < n and bases[j] == bases[i]:
                    j += 1
                for k in range(i, j):
                    run_len[k] = j - i
                i = j
            scales = []
            for k in range(n):
                p = run_len[k]
                s = fleeting if p == 1 else (paired if p == 2 else 1.0)
                scales.append(s)
                collected[k]['persist_scale'] = s
            print(f"[persist-damp] bases={bases}  persistence={run_len}  "
                  f"scales={[round(s,2) for s in scales]}  (pose-level)")

    # ── Causal monologue VAD damping (Step 1 — past-only, real-time safe).
    # Pulls each turn's VAD toward a running mean of *past* turns. First turn
    # is always left untouched so single-turn expressiveness is preserved
    # (single-turn scenarios bypass this block entirely via len>1 guard).
    if args.vad_damp_gamma > 0 and len(collected) > 1:
        γ = float(args.vad_damp_gamma)
        β = float(args.vad_damp_beta)
        running_mean = np.array(collected[0]['vad'], dtype=np.float32)
        # First turn unchanged; running_mean seeded from it.
        for c in collected[1:]:
            raw = np.array(c['vad'], dtype=np.float32)
            # Damp using past-only running_mean FIRST, then update mean.
            # (Updating first would fold current turn into "past" and weaken the pull.)
            damped = γ * raw + (1.0 - γ) * running_mean
            # The mean is updated with the RAW value, not the damped one.
            running_mean = β * running_mean + (1.0 - β) * raw
            c['vad'] = damped.tolist()
        print(f"[vad-damp] γ={γ} β={β} applied to {len(collected)-1} non-first turns "
              f"(causal, real-time safe)")

    # ── Build per-frame VAD trajectory across the whole scenario, then smooth.
    # This is "Option B": authored per-turn VAD becomes the *peak* of that
    # turn's emotion; the trajectory between peaks ramps via Gaussian filter
    # with sigma = args.vad_smooth_sigma frames. sigma=0 → step function (off).
    all_vads = np.concatenate([
        np.tile(np.asarray(c['vad'], dtype=np.float32), (c['T'], 1))
        for c in collected
    ], axis=0)
    if args.vad_smooth_sigma > 0 and len(collected) > 1:
        from scipy.ndimage import gaussian_filter1d
        all_vads = gaussian_filter1d(
            all_vads, sigma=args.vad_smooth_sigma, axis=0, mode='nearest'
        ).astype(np.float32)
        print(f"[vad-smooth] σ={args.vad_smooth_sigma}f applied across "
              f"{len(collected)} turns")

    # ── Pass 2: compile per-turn using smoothed VAD slice + apply V2 dynamics
    cursor = 0
    for c in collected:
        T = c['T']
        emo = c['emotion']
        # This turn's rows of the scenario-wide VAD trajectory.
        vad_slice = all_vads[cursor:cursor + T]
        cursor += T

        comp_bs = compile_expressive_batch(
            emotions=[emo] * T,
            vads=vad_slice,
            presets=presets,
            parametric_overlay_intensity=args.option_e_intensity,
        )

        # Optional cross-emotion VAD blending. Mixes a VAD-distance-RBF over
        # ALL preset anchors with the within-emotion authored pose. Default
        # weight 0 = pure within-emotion (legacy). >0 = blend.
        if args.cross_emotion_weight > 0.0:
            xemo = cross_emotion_compile(vad_slice, presets,
                                         sigma=args.cross_emotion_sigma)
            w = float(args.cross_emotion_weight)
            comp_bs = ((1.0 - w) * comp_bs + w * xemo).astype(np.float32)

        # Pose-level persistence damping (scale toward zero). Uniformly
        # shrinks every channel — produces the "subdued overall" feel.
        # Pair with --brow-v2-floor to give V2 prosody headroom so its
        # negative excursions don't clip when comp_bs is small.
        ps = float(c.get('persist_scale', 1.0))
        if ps < 1.0:
            comp_bs = (comp_bs * ps).astype(np.float32)
            if args.brow_v2_floor > 0.0:
                # Lift brow channels (0-4) to at least the floor so V2's
                # bandpass-prosody negative excursions have room before clip.
                # Note: the floor is only applied to damped turns (ps < 1).
                for ch in BROW_CHANNELS:
                    comp_bs[:, ch] = np.maximum(comp_bs[:, ch],
                                                args.brow_v2_floor)

        # The envelope is looked up by the turn's own emotion label.
        env_lo, env_hi = get_preset_envelope(emo, presets)
        for v in variants:
            comp_modified = apply_v2_dynamics(
                comp_bs, c['v2_bs'], v, envelope_lo=env_lo, envelope_hi=env_hi
            )
            comp_stacks[v].append(comp_modified)
        lam_stack.append(c['lam_bs'])
        gate_stack.append(c['gate'])
        turn_lengths.append(T)
        audio_paths.append(c['ap'])
        turns_meta.append({
            'turn_idx': c['turn_idx'], 'emotion': c['emotion'], 'vad': c['vad'],
            'text': c['text'], 'speaker': c['speaker'],
        })

    # Defensive: `collected` is non-empty here, so this cannot trigger.
    if not audio_paths:
        raise SystemExit("no turns processed")

    # Crossfade comp_bs at turn boundaries, then merge with concatenated LAM
    # Longer fade + lower min_cutoff for multi-turn scenarios with emotion shifts
    is_monologue = len(turn_lengths) >= 3
    # Default 48 frames (1.6s) for monologues — cosine-eased crossfade reads
    # naturally at this duration. CLI --fade-frames overrides.
    fade_frames = args.fade_frames if args.fade_frames is not None else (
        48 if is_monologue else 8
    )
    min_cutoff = 0.8 if is_monologue else 1.5     # heavier smoothing for monologues
    lam_cat = np.concatenate(lam_stack, axis=0)
    gate_cat = np.concatenate(gate_stack, axis=0)
    print(f"[smooth] fade_frames={fade_frames}  min_cutoff={min_cutoff}  "
          f"(monologue={is_monologue})")
    for v in variants:
        comp_cat = crossfade_turn_boundaries(comp_stacks[v], turn_lengths,
                                             fade_frames=fade_frames)
        merged = merge_lam_compiler(lam_cat, comp_cat, gate_cat)
        merged = smooth_expression_channels(merged, min_cutoff=min_cutoff, beta=0.5)
        # Final pass: eye motion (blinks etc.), seeded with the scenario id
        # and variant letter.
        merged = apply_eye_motion(merged, seed_str=f"{sid}::{v}", fps=30,
                                  blink_interval_s=args.blink_interval)
        targets[v] = merged

    # Concat audio once (shared across variants)
    turn_suffix = f"_t{args.turn}" if args.turn is not None else ""
    shared_mp3 = OUT_DIR / f"{sid}{turn_suffix}{args.out_suffix}_ABC.mp3"
    concat_audio(audio_paths, shared_mp3)
    print(f"[audio] {shared_mp3}")

    # Export 3 JSONs, each referencing the shared mp3
    print("\n[export]")
    suffix = f"{turn_suffix}{args.out_suffix}"
    for v in variants:
        concat_target = targets[v]
        sid_out = f"{sid}{suffix}_{v}"
        json_out = export_viewer_json(
            concat_target, scen, turns_meta, sid_out, concat_target.shape[0]
        )
        # Symlink the shared mp3 (relative, same directory) so the player
        # finds it via sid_out; any previous file or link is replaced.
        variant_mp3 = OUT_DIR / f"{sid_out}.mp3"
        if variant_mp3.exists() or variant_mp3.is_symlink():
            variant_mp3.unlink()
        variant_mp3.symlink_to(shared_mp3.name)
        size_kb = json_out.stat().st_size // 1024
        print(f"  {v}: {json_out.name} ({size_kb} KB)  shape={concat_target.shape}")

    # Update manifest so the player can list this scenario
    base_id = f"{sid}{suffix}"
    update_manifest(base_id, variants, turns_meta, scen)
    print(f"\n[manifest] updated: {OUT_DIR / 'manifest.json'}")
    print(f"\nOpen in browser:")
    print(f"  http://localhost:8890/tools/blendshape-player.html?scenario={base_id}")


# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
