#!/usr/bin/env python3
"""Post-process: add audio-gated bilateral brow jitter to an existing render.

Loads a rendered scenario from data/viewer/, adds filtered Gaussian noise on
the 5 brow channels (gated by audio silence so V2 dominates during speech),
writes new variant JSONs and updates manifest.

Usage:
    python3 scripts/post_brow_jitter.py --scenario long_046_pose50 \
        --amp 0.014 --freq-sigma 1.5 --out-suffix _jit14

Iterate fast:
    python3 scripts/post_brow_jitter.py -s long_046_pose50 -a 0.012 -f 1.2 -o _t1
    python3 scripts/post_brow_jitter.py -s long_046_pose50 -a 0.018 -f 0.9 -o _t2

Frequency reference (σ=Gaussian sigma in frames at 30 fps):
    σ=0.5  → ~7 Hz dominant   (fast tremor)
    σ=0.7  → ~5 Hz            (mid)
    σ=1.0  → ~3 Hz
    σ=1.5  → ~2 Hz            (slow drift)
    σ=2.0  → ~1.5 Hz          (very slow drift)
"""
from __future__ import annotations

import argparse
import hashlib
import json
from pathlib import Path

import librosa
import numpy as np
from scipy.ndimage import gaussian_filter1d

# Project root: the parent of the directory that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Directory holding the rendered scenario JSONs, their MP3s and manifest.json.
VIEWER = PROJECT_ROOT / "data" / "viewer"
# Frame rate used to map audio samples onto blendshape frames and to report
# the approximate dominant jitter frequency.
FPS = 30

# Bilateral correlation: L/R brow channels share noise stream so they don't
# look "drunk" (independent left/right twitching). Three muscle groups.
# Values are column indices into each per-frame blendshape row; every channel
# listed under one key receives the same noise stream.
GROUPS = {
    'down':    [0, 1],   # browDownLeft, browDownRight
    'innerUp': [2],      # browInnerUp (single)
    'outerUp': [3, 4],   # browOuterUpLeft, browOuterUpRight
}


def make_noise(rng: np.random.Generator, n: int, sigma: float) -> np.ndarray:
    """Return ``n`` samples of Gaussian-smoothed white noise scaled to ~unit RMS.

    Args:
        rng: Generator the white noise is drawn from (its state advances).
        n: Number of samples to produce.
        sigma: Standard deviation of the Gaussian smoothing kernel, in samples.

    Returns:
        1-D array of length ``n``.
    """
    white = rng.standard_normal(n).astype(np.float32)
    smoothed = gaussian_filter1d(white, sigma=sigma, mode='nearest')
    rms = np.sqrt(np.mean(smoothed ** 2))
    # The epsilon keeps the division finite when the smoothed signal is all zeros.
    return smoothed / (rms + 1e-6)


def silence_gate(mp3_path: Path, n_frames: int, fps: int = FPS,
                 silence_db: float = -45.0, soft_range: float = 12.0,
                 smooth_sigma: float = 6.0) -> np.ndarray:
    """Per-frame gate ∈ [0, 1]: 1 = silent, 0 = active speech.

    The audio is loaded as mono at 22050 Hz and split into one window per
    video frame. Each window's RMS level (in dB) is mapped linearly to a gate
    value: 1 at or below ``silence_db``, 0 at or above
    ``silence_db + soft_range``. The gate is then smoothed over time.

    Args:
        mp3_path: Audio file to analyse.
        n_frames: Number of video frames to produce gate values for.
        fps: Video frame rate used to size the per-frame audio windows.
        silence_db: Level (dB) at or below which a frame counts as fully silent.
        soft_range: Width (dB) of the linear ramp from silent to active.
        smooth_sigma: Gaussian smoothing sigma, in frames.

    Returns:
        Array of ``n_frames`` gate values. Frames that start past the end of
        the audio are treated as silent.
    """
    wav, sr = librosa.load(str(mp3_path), sr=22050, mono=True)
    n_samples = len(wav)

    rms = np.zeros(n_frames, dtype=np.float32)
    for i in range(n_frames):
        # Derive each edge from the exact sr/fps ratio instead of a truncated
        # integer hop, so windows do not drift against the video when sr/fps
        # is not an integer. For 22050 Hz / 30 fps the edges are unchanged.
        start = int(i * sr / fps)
        if start >= n_samples:
            break  # remaining frames lie past the audio; leave RMS at 0
        # Always keep at least one sample so the mean is never taken over an
        # empty slice (which would yield NaN when fps exceeds sr).
        stop = min(max(int((i + 1) * sr / fps), start + 1), n_samples)
        rms[i] = np.sqrt(np.mean(wav[start:stop] ** 2))
    rms += 1e-9  # floor so log10 stays finite on digital silence

    db = 20.0 * np.log10(rms)
    gate = np.clip((silence_db + soft_range - db) / soft_range, 0.0, 1.0)
    return gaussian_filter1d(gate, sigma=smooth_sigma, mode='nearest')


def find_source_mp3(scenario: str, variant: str) -> Path:
    """Return the actual MP3 file for a scenario variant.

    Lookup order inside VIEWER:
      1. ``{scenario}_{variant}.mp3`` — if it is a symlink, the link target
         (one level, taken relative to VIEWER) is returned instead of the link.
      2. ``{scenario}_ABC.mp3`` as a shared fallback.

    A symlink whose target is missing is treated like a missing file, so the
    lookup continues with the fallback rather than returning a path that does
    not exist.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    direct = VIEWER / f"{scenario}_{variant}.mp3"
    if direct.is_symlink():
        target = VIEWER / direct.readlink()
        if target.exists():
            return target
        # Dangling link: fall through to the shared fallback below.
    elif direct.exists():
        return direct
    abc = VIEWER / f"{scenario}_ABC.mp3"
    if abc.exists():
        return abc
    raise FileNotFoundError(f"no mp3 found for {scenario}_{variant}")


def process_variant(scenario: str, variant: str, new_base: str,
                    amp: float, freq_sigma: float) -> tuple[int, float]:
    """Add silence-gated brow jitter to one variant and write it as *new_base*.

    Reads ``{scenario}_{variant}.json`` from VIEWER, adds
    ``amp * gate * noise`` to every channel listed in GROUPS (channels of one
    group share a noise stream), clips the result to [0, 1] and writes
    ``{new_base}_{variant}.json``. A ``{new_base}_{variant}.mp3`` symlink is
    (re)created pointing at the audio file the gate was computed from.

    Args:
        scenario: Source scenario base name.
        variant: Variant identifier, e.g. ``"A"``.
        new_base: Output scenario base name; must differ from ``scenario``.
        amp: Jitter amplitude in blendshape units.
        freq_sigma: Gaussian sigma (frames) used to smooth the noise.

    Returns:
        (silent_frame_count, max_gate_value): the number of frames whose gate
        exceeds 0.5 and the largest gate value.

    Raises:
        ValueError: if ``new_base`` equals ``scenario``.
        FileNotFoundError: if the source JSON or its audio cannot be found.
    """
    if new_base == scenario:
        # The output audio link would be the source audio itself: it would be
        # unlinked below and then replaced by a link pointing at itself.
        raise ValueError(
            f"new_base must differ from scenario (got {new_base!r} for both)"
        )

    src_fp = VIEWER / f"{scenario}_{variant}.json"
    if not src_fp.exists():
        raise FileNotFoundError(f"source JSON not found: {src_fp}")

    data = json.loads(src_fp.read_text(encoding="utf-8"))
    bs = np.array(data["blendshapes"], dtype=np.float32)
    n_frames = bs.shape[0]

    mp3 = find_source_mp3(scenario, variant)
    gate = silence_gate(mp3, n_frames)

    # Deterministic seed per output base + variant: re-running with the same
    # output name reproduces the same noise; a different suffix reseeds it.
    digest = hashlib.md5(f"{new_base}::{variant}".encode("utf-8"),
                         usedforsecurity=False).hexdigest()
    rng = np.random.default_rng(int(digest[:16], 16))

    # Draw one stream per group (in GROUPS order) before applying any of them.
    streams = {g: make_noise(rng, n_frames, freq_sigma) for g in GROUPS}
    for grp, channels in GROUPS.items():
        jitter = amp * gate * streams[grp]
        for ch in channels:
            bs[:, ch] += jitter

    bs = np.clip(bs, 0.0, 1.0)
    # Round in float64: a rounded float32 widens to a long decimal in tolist()
    # (e.g. 0.1234000027179718), which bloats the JSON and defeats the rounding.
    data["blendshapes"] = np.round(bs.astype(np.float64), 4).tolist()
    data["scenario_id"] = new_base

    out_fp = VIEWER / f"{new_base}_{variant}.json"
    # ensure_ascii=False emits raw non-ASCII characters, so pin the encoding.
    out_fp.write_text(json.dumps(data, ensure_ascii=False), encoding="utf-8")

    # Link to the audio that was actually used for gating (this may be the
    # shared fallback file), relative to VIEWER when it lives under it.
    try:
        link_target = mp3.relative_to(VIEWER)
    except ValueError:
        link_target = mp3
    out_mp3 = VIEWER / f"{new_base}_{variant}.mp3"
    if out_mp3.exists() or out_mp3.is_symlink():
        out_mp3.unlink()
    out_mp3.symlink_to(link_target)

    return int((gate > 0.5).sum()), float(gate.max())


def update_manifest(scenario: str, new_base: str, variants: list[str]) -> None:
    """Register *new_base* in ``manifest.json`` as a copy of *scenario*'s entry.

    The new entry copies every field of the entry whose ``base`` equals
    ``scenario`` and overrides ``base``, ``scenario_id`` and ``variants``.
    Any existing entry with ``base == new_base`` is replaced. If no entry for
    ``scenario`` exists, the manifest is left untouched and a warning is
    printed.

    Args:
        scenario: Base name of the source entry to copy.
        new_base: Base name of the entry to add.
        variants: Variant identifiers to store on the new entry.
    """
    mfp = VIEWER / "manifest.json"
    manifest = json.loads(mfp.read_text(encoding="utf-8"))
    entries = manifest["scenarios"]

    # Look the source up before dropping stale entries so it is still found
    # when scenario and new_base are the same name.
    src_entry = next((s for s in entries if s.get("base") == scenario), None)
    if src_entry is None:
        print(f"[warn] manifest has no entry with base '{scenario}'; "
              f"'{new_base}' was not added")
        return

    updated = [s for s in entries if s.get("base") != new_base]
    updated.append({
        **src_entry,
        "base": new_base,
        "scenario_id": new_base,
        "variants": variants,
    })
    manifest["scenarios"] = updated
    # ensure_ascii=False emits raw non-ASCII characters, so pin the encoding.
    mfp.write_text(json.dumps(manifest, ensure_ascii=False, indent=2),
                   encoding="utf-8")


def main() -> None:
    """CLI entry point: parse arguments, process each variant, update the manifest.

    Exits with a usage error (via argparse) when no variant is given, when the
    output suffix is empty, or when the smoothing sigma is not positive.
    """
    ap = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                 description=__doc__)
    ap.add_argument("-s", "--scenario", required=True,
                    help="Source scenario base, e.g. 'long_046_pose50'")
    ap.add_argument("-a", "--amp", type=float, default=0.014,
                    help="Jitter amplitude in blendshape units. Default 0.014.")
    ap.add_argument("-f", "--freq-sigma", type=float, default=1.5,
                    help="Gaussian sigma in frames (lower=faster). "
                         "0.5≈7Hz, 1.0≈3Hz, 1.5≈2Hz, 2.0≈1.5Hz. Default 1.5.")
    ap.add_argument("-o", "--out-suffix", default="_jit",
                    help="Appended to scenario for output base name. Default '_jit'.")
    ap.add_argument("-v", "--variants", default="A,B",
                    help="Comma-separated variants to process. Default 'A,B'.")
    args = ap.parse_args()

    variants = [v.strip() for v in args.variants.split(",") if v.strip()]
    # Validate before touching any file so a bad invocation cannot leave a
    # half-written scenario behind.
    if not variants:
        ap.error("--variants must name at least one variant")
    if not args.out_suffix:
        # An empty suffix makes the output base equal the source, whose audio
        # link would be unlinked and re-created pointing at itself.
        ap.error("--out-suffix must not be empty")
    if not args.freq_sigma > 0:
        # Also rejects NaN; zero would divide by zero in the estimate below.
        ap.error("--freq-sigma must be greater than 0")

    new_base = args.scenario + args.out_suffix
    dom_hz = 0.7 * FPS / (2 * np.pi * args.freq_sigma)
    print(f"[params] amp={args.amp}  σ={args.freq_sigma}f → ~{dom_hz:.1f}Hz dominant")
    print(f"[i/o]    {args.scenario}  →  {new_base}  (variants: {variants})")

    for v in variants:
        sf, mg = process_variant(args.scenario, v, new_base,
                                 args.amp, args.freq_sigma)
        print(f"  {v}:  silent frames {sf}  max gate {mg:.2f}")

    update_manifest(args.scenario, new_base, variants)
    print(f"\nmanifest updated: '{new_base}' added")
    print(f"open: http://localhost:8890/tools/blendshape-player.html"
          f"?scenario={new_base}&variant={variants[0]}")


if __name__ == "__main__":
    main()
