"""V3 face inference — run the trained student on baked scenarios and write
the prediction as a viewer JSON so you can A/B against the teacher target.

Usage:
    # Predict a few scenarios, write viewer files, also print L1 per channel group
    PYTHONPATH=. python3 -m models.v3_face.infer -s long_001 long_046 long_001_p0

    # Use latest.pt instead of best.pt
    PYTHONPATH=. python3 -m models.v3_face.infer --ckpt models/v3_face/checkpoints/latest.pt -s long_001

    # All curated viewer scenarios (anything currently in data/viewer/)
    PYTHONPATH=. python3 -m models.v3_face.infer --all-viewer

Output (with --variant-tag TAG, "_pred" becomes "_pred_TAG" in every name):
    data/viewer/<sid>_pred_dataset.json   (V3 student prediction)
    data/viewer/<sid>_pred_dataset.mp3    (symlink to teacher's mp3 — same audio)
    data/viewer/manifest.json updated with the _pred entries so the player dropdown shows them
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from scipy.ndimage import gaussian_filter1d

from scripts.compiler.constants import (
    ARKIT_52_NAMES, LIPSYNC_ONLY, EXPRESSION_ONLY, SHARED_CHANNELS,
)
from scripts.compiler.eye_motion import add_blinks as _add_blinks_proc

from .config import V3FaceConfig
from .model import V3FaceModel

# Repo root — this module is run as `models.v3_face.infer`, two levels below it.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_NPZ_DIR = PROJECT_ROOT.joinpath("data", "v3_training")
DEFAULT_VIEWER_DIR = PROJECT_ROOT.joinpath("data", "viewer")
DEFAULT_EMOTION_DIR = PROJECT_ROOT.joinpath("data", "emotion")
DEFAULT_CKPT = PROJECT_ROOT.joinpath("models", "v3_face", "checkpoints", "best.pt")

# eyeBlink L/R — the channels inject_blinks overlays procedural blinks on.
EYE_BLINK = [8, 9]

# mouthClose drives m/b/p phoneme closures. It stays in the crisp pass but
# gets its own (typically lower) sigma, so closures remain sharp while the
# other lipsync channels can take a tiny bit more smoothing.
MOUTH_CLOSE_CH = 26

# V1's crisp_mouth domain (mouth + jaw): every lipsync-branch output, i.e.
# lipsync-only plus shared channels. Expression-branch channels are untouched.
CRISP_CHANNELS = sorted({*LIPSYNC_ONLY, *SHARED_CHANNELS})

# Brow channels (ARKit 0-4), smoothed at inference with an adaptive One-Euro
# filter. Large per-frame deltas (expression transitions) pass through almost
# untouched; small deltas (intra-turn micro-jitter) get more smoothing. The
# defaults are gentle — barely visible, but the intra-turn calm comes through.
# They match V1's brow-deployment cutoff (1.5).
BROW_CHANNELS = list(range(5))
EYE_SQUINT_CHANNELS = [18, 19]
EYE_WIDE_CHANNELS = [20, 21]
EYE_LOOK_CHANNELS = list(range(10, 18))   # gaze direction (pupils)


def _one_euro(signal: np.ndarray, fps: float = 30.0,
              min_cutoff: float = 1.5, beta: float = 0.01,
              d_cutoff: float = 1.0) -> np.ndarray:
    """Adaptive low-pass. Cutoff rises with |dx|, so fast changes pass through
    with minimal smoothing and slow drift gets aggressively smoothed.
    Reference: animasync-face-v1/training/inference.py::OneEuroFilter.
    """
    import math
    T = len(signal)
    if T <= 1:
        return signal.astype(np.float32, copy=True)
    out = np.zeros(T, dtype=np.float32)
    out[0] = float(signal[0])
    dx_prev = 0.0
    te = 1.0 / fps

    def sf(t_e, cutoff):
        r = 2.0 * math.pi * cutoff * t_e
        return r / (r + 1.0)

    for i in range(1, T):
        a_d = sf(te, d_cutoff)
        dx = (float(signal[i]) - out[i - 1]) / te
        dx_hat = a_d * dx + (1.0 - a_d) * dx_prev
        dx_prev = dx_hat
        cutoff = min_cutoff + beta * abs(dx_hat)
        a = sf(te, cutoff)
        out[i] = a * float(signal[i]) + (1.0 - a) * out[i - 1]
    return out


def inject_blinks(blendshapes: np.ndarray, scenario_id: str,
                  mean_interval_s: float = 4.0,
                  expressive_cap: float = 0.5,
                  fps: int = 30) -> np.ndarray:
    """Overlay deterministic procedural blinks on eyeBlink L/R (ch 8, 9).

    Blink timing is not tied to the audio, so the student cannot learn it.
    Instead this reuses the teacher's generator
    (scripts.compiler.eye_motion.add_blinks), seeded with ``scenario_id`` so
    every run of a scenario produces identical blinks.

    Before the overlay, the model's own blink values are capped at
    ``expressive_cap``:

    - Sustained partial closure below the cap survives. This covers squinty
      smiles, sleepiness and crying (typically 0.2-0.4).
    - A degenerate "locked shut" prediction (e.g. 0.99) is clamped.

    add_blinks composes with max(), so the procedural blink peaks always win.

    eyeSquint (18, 19) and eyeWide (20, 21) are capped at 0.84, just under
    add_blinks' 0.85 suppression threshold. A misprediction on those channels
    therefore cannot disable blinking altogether.
    """
    capped = blendshapes.astype(np.float32)   # astype always returns a copy
    blink_cols = [8, 9]
    squint_wide_cols = [18, 19, 20, 21]
    capped[:, blink_cols] = np.minimum(capped[:, blink_cols], expressive_cap)
    capped[:, squint_wide_cols] = np.minimum(capped[:, squint_wide_cols], 0.84)
    return _add_blinks_proc(
        capped,
        seed_str=scenario_id,
        fps=fps,
        mean_interval_s=mean_interval_s,
    )


def smooth_brows(blendshapes: np.ndarray,
                 min_cutoff: float = 1.5,
                 beta: float = 0.01,
                 d_cutoff: float = 1.0,
                 fps: float = 30.0) -> np.ndarray:
    """One-Euro filter each brow channel (``BROW_CHANNELS``) over time.

    Transition spikes keep most of their edge while intra-turn jitter is
    damped. Returns a new float32 array. Note that the final [0, 1] clip
    applies to *all* channels, not only the brows.
    """
    out = blendshapes.astype(np.float32)   # fresh float32 copy
    filter_kwargs = dict(fps=fps, min_cutoff=min_cutoff,
                         beta=beta, d_cutoff=d_cutoff)
    for col in BROW_CHANNELS:
        out[:, col] = _one_euro(out[:, col], **filter_kwargs)
    np.clip(out, 0.0, 1.0, out=out)
    return out


def smooth_eye_squint(blendshapes: np.ndarray,
                       min_cutoff: float = 0.8,
                       beta: float = 0.01,
                       d_cutoff: float = 1.0,
                       fps: float = 30.0) -> np.ndarray:
    """One-Euro filter eyeSquint L/R (``EYE_SQUINT_CHANNELS``, ch 18, 19).

    The default min_cutoff (0.8) is lower than the brows' 1.5, i.e. heavier
    smoothing. Orbicularis oculi moves slowly and sustains its pose, so sub-Hz
    wiggle can be attenuated without losing real squints.

    Returns a new float32 array with every channel clipped to [0, 1].
    """
    smoothed = blendshapes.astype(np.float32)
    for ch in EYE_SQUINT_CHANNELS:
        smoothed[:, ch] = _one_euro(
            smoothed[:, ch],
            fps=fps,
            min_cutoff=min_cutoff,
            beta=beta,
            d_cutoff=d_cutoff,
        )
    return np.clip(smoothed, 0.0, 1.0, out=smoothed)


def smooth_eye_wide(blendshapes: np.ndarray,
                     min_cutoff: float = 0.8,
                     beta: float = 0.01,
                     d_cutoff: float = 1.0,
                     fps: float = 30.0) -> np.ndarray:
    """One-Euro filter eyeWide L/R (``EYE_WIDE_CHANNELS``, ch 20, 21).

    Uses the same settings shape as eyeSquint. eyeWide is the anti-correlated
    partner that fades out when surprise gives way to another emotion, and
    smoothing keeps that cross-fade continuous.

    Returns a new float32 array with every channel clipped to [0, 1].
    """
    arr = blendshapes.astype(np.float32)
    opts = {"fps": fps, "min_cutoff": min_cutoff,
            "beta": beta, "d_cutoff": d_cutoff}
    for idx in EYE_WIDE_CHANNELS:
        column = arr[:, idx]
        arr[:, idx] = _one_euro(column, **opts)
    np.clip(arr, 0.0, 1.0, out=arr)
    return arr


def smooth_eye_look(blendshapes: np.ndarray,
                     min_cutoff: float = 1.0,
                     beta: float = 0.01,
                     d_cutoff: float = 1.0,
                     fps: float = 30.0) -> np.ndarray:
    """One-Euro filter the 8 eyeLook channels (``EYE_LOOK_CHANNELS``, 10-17).

    These channels are gaze direction / pupils. The compiler's iris drift is
    already smooth, so this only removes the model's residual high-frequency
    reproduction noise. The cutoff sits slightly above eyeSquint's so rapid,
    saccade-like gaze shifts still get through.

    Returns a new float32 array with every channel clipped to [0, 1].
    """
    gaze = blendshapes.astype(np.float32)
    for ch in EYE_LOOK_CHANNELS:
        filtered = _one_euro(gaze[:, ch], fps=fps, min_cutoff=min_cutoff,
                             beta=beta, d_cutoff=d_cutoff)
        gaze[:, ch] = filtered
    return np.clip(gaze, 0.0, 1.0, out=gaze)


def crisp_mouth(blendshapes: np.ndarray,
                threshold: float = 0.3,
                scale: float = 1.0,
                pre_smooth_sigma: float = 1.0,
                mouth_close_sigma: float | None = None,
                channels: list[int] | None = None,
                mouth_close_ch: int | None = None) -> np.ndarray:
    """V1's crisp_mouth — a soft-threshold gate on mouth/jaw channels.

    For each gated channel:
      1. Pre-smooth with a small Gaussian. This removes HF noise without
         distorting legitimate phoneme onsets.
      2. Normalize to [0, 1] by the channel's own max (1.0 if the max is <= 0).
      3. Apply a smoothstep gate with edges (0.3·threshold, 1.2·threshold).
         Values below the low edge fade to 0; values above the high edge pass
         through fully.
      4. Scale and clip to [0, 1].

    Args:
        blendshapes: (T, C) array of per-frame values. It is not modified.
        threshold: Sets the gate edges (see step 3). If threshold <= 0 the
            smoothstep has no width, and the gate becomes a hard step at
            1.2·threshold instead of dividing by zero.
        scale: Gain applied after gating.
        pre_smooth_sigma: Gaussian σ (frames) for every gated channel EXCEPT
            mouthClose. 0 disables pre-smoothing.
        mouth_close_sigma: Gaussian σ for the mouthClose channel. If None, it
            falls back to `pre_smooth_sigma`. Set it lower than the main sigma
            to keep m/b/p closures crisp while smoothing the rest.
        channels: Channels to gate. Defaults to ``CRISP_CHANNELS``.
        mouth_close_ch: Channel that uses `mouth_close_sigma`. Defaults to
            ``MOUTH_CLOSE_CH``.

    Returns:
        A new float32 array with the same shape. Channels not in `channels`
        are copied unchanged.

    Reference: animasync-face-v1/deployment/lipsync_distilled.py::crisp_mouth.
    The smoothstep gate is the key piece: it kills sub-threshold jitter (our
    13% HF residual) WITHOUT acting as a low-pass filter that smears phoneme
    attacks.
    """
    if mouth_close_sigma is None:
        mouth_close_sigma = pre_smooth_sigma
    if channels is None:
        channels = CRISP_CHANNELS
    if mouth_close_ch is None:
        mouth_close_ch = MOUTH_CLOSE_CH

    result = np.asarray(blendshapes).astype(np.float32)   # always a copy
    if result.shape[0] == 0:
        # Nothing to gate; also avoids `max()` of an empty array raising.
        return result

    edge0 = threshold * 0.3
    edge1 = threshold * 1.2
    for ch in channels:
        sigma_ch = mouth_close_sigma if ch == mouth_close_ch else pre_smooth_sigma
        vals = result[:, ch].copy()
        if sigma_ch > 0:
            vals = gaussian_filter1d(vals, sigma=sigma_ch)
        peak = float(vals.max())
        normalized = vals / (peak if peak > 0 else 1.0)
        if edge1 > edge0:
            t = np.clip((normalized - edge0) / (edge1 - edge0), 0.0, 1.0)
            gate = t * t * (3.0 - 2.0 * t)
        else:
            # The smoothstep has zero or negative width (threshold <= 0), so
            # degrade to a hard step instead of producing NaN.
            gate = (normalized >= edge1).astype(np.float32)
        result[:, ch] = np.clip(vals * gate * scale, 0.0, 1.0)
    return result


def load_model(ckpt_path: Path, device: torch.device):
    """Load a V3FaceModel checkpoint and return ``(model, cfg)`` in eval mode.

    The checkpoint dict must hold ``"config"`` (a mapping of V3FaceConfig
    fields) and ``"model"`` (a state dict). ``"epoch"`` and ``"val_l1"`` are
    optional and only used in the log line.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    state = torch.load(ckpt_path, map_location=device, weights_only=False)

    cfg = V3FaceConfig()
    for key, value in state["config"].items():
        if not hasattr(cfg, key):
            continue   # silently ignore keys the current config doesn't know
        if key == "dilations" and isinstance(value, list):
            # Checkpoints may store dilations as a list; the config expects a tuple.
            value = tuple(value)
        setattr(cfg, key, value)

    model = V3FaceModel(cfg).to(device)
    model.load_state_dict(state["model"])
    model.eval()

    epoch = state.get("epoch", "?")
    val_l1 = state.get("val_l1", float("nan"))
    print(f"[ckpt] {ckpt_path}  epoch={epoch}  val_l1={val_l1:.4f}  "
          f"params={model.n_params/1e6:.2f}M")
    return model, cfg


def per_channel_l1(pred: np.ndarray, target: np.ndarray) -> Dict[str, float]:
    """Mean absolute error, reported overall and per channel category.

    Both inputs are (T, 52). If their frame counts differ, both are truncated
    to the shorter length before comparing.

    Returns keys: overall, lipsync, expression, shared, blink.
    """
    n_frames = min(pred.shape[0], target.shape[0])
    mae = np.abs(pred[:n_frames] - target[:n_frames]).mean(axis=0)  # one value per channel

    def group_mean(indices) -> float:
        return float(mae[indices].mean())

    return {
        "overall": float(mae.mean()),
        "lipsync": group_mean(LIPSYNC_ONLY),
        "expression": group_mean(EXPRESSION_ONLY),
        "shared": group_mean(SHARED_CHANNELS),
        "blink": group_mean(EYE_BLINK),
    }


def load_turns_meta(sid: str, emotion_dir: Path) -> List[dict]:
    """Find turn metadata for a scenario across the three split JSONLs.

    Files are searched in order: seed_train_final, seed_val, seed_test.
    Within a file:

    - An exact ``scenario_id == sid`` row wins. All of its turns with
      non-blank text are returned, keeping their original indices.
    - Otherwise, if ``sid`` has the daily-split form ``daily_<parent>_t<N>``,
      the first ``<parent>`` row that has a turn N yields that single turn.

    Missing files, blank lines and out-of-range turn indices are skipped.
    Returns ``[]`` when nothing matches.
    """
    import re

    split_match = re.match(r"^(daily_.+)_t(\d+)$", sid)
    parent_id = split_match.group(1) if split_match else None
    split_turn = int(split_match.group(2)) if split_match else -1

    def _turn_meta(turn_idx: int, turn: dict) -> dict:
        # One viewer "turns" entry with the defaults used across this script.
        return {"turn_idx": turn_idx,
                "emotion": turn.get("emotion", "neutral"),
                "vad": turn.get("vad", [0, 0, 0]),
                "text": turn.get("text", ""),
                "speaker": turn.get("speaker", "")}

    for fname in ("seed_train_final.jsonl", "seed_val.jsonl", "seed_test.jsonl"):
        path = emotion_dir / fname
        if not path.exists():
            continue
        parent_turn = None
        # Explicit UTF-8: turn text is non-ASCII and must not depend on locale.
        with path.open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue   # tolerate blank / trailing lines in the JSONL
                row = json.loads(line)
                row_id = row.get("scenario_id")
                if row_id == sid:
                    return [_turn_meta(ti, t)
                            for ti, t in enumerate(row["turns"])
                            if t.get("text", "").strip()]
                # Remember the first parent hit but keep scanning, so an exact
                # match later in the same file still takes precedence.
                if parent_turn is None and parent_id is not None and row_id == parent_id:
                    turns = row.get("turns", [])
                    if 0 <= split_turn < len(turns):
                        parent_turn = turns[split_turn]
        if parent_turn is not None:
            return [_turn_meta(split_turn, parent_turn)]
    return []


def predict_and_save(sid: str, model, device, args) -> Dict[str, object]:
    """Run the model on one scenario, write its viewer files, and score it.

    Steps:
      1. Read ``<args.npz_dir>/<sid>.npz`` (arrays ``audio``, ``cond``,
         ``target``) and run the student.
      2. Apply whichever post-processing passes are enabled on ``args``, in
         this order: crisp mouth, brow, eyeSquint, eyeWide and eyeLook
         smoothing, then procedural blinks.
      3. Write ``<args.viewer_dir>/<sid>_pred[_<tag>]_dataset.json`` plus an
         ``.mp3`` symlink to the teacher's audio.

    Returns:
        ``{"sid": ..., "new_base": ..., **per_channel_l1(pred, target)}``, or
        ``{}`` when the npz does not exist (the caller skips empty results).
    """
    npz_path = args.npz_dir / f"{sid}.npz"
    if not npz_path.exists():
        print(f"  ✗ {sid}: no npz found at {npz_path}")
        return {}
    # np.load on an .npz keeps the zip file open until it is closed. The
    # context manager releases it; this runs once per scenario under
    # --all-viewer, so leaked handles would accumulate.
    with np.load(npz_path) as data:
        audio_np = data["audio"].astype(np.float32)
        cond_np = data["cond"].astype(np.float32)
        target = data["target"].astype(np.float32)
    audio = torch.from_numpy(audio_np).unsqueeze(0).to(device)
    # Mirror the training-time emotion one-hot smoothing so the model sees
    # the same cond distribution it was trained on. This must match the
    # `--smooth-cond-sigma-emotion` value used to train this checkpoint.
    smooth_sigma = getattr(args, "smooth_cond_sigma_emotion", 0.0) or 0.0
    if smooth_sigma > 0.0:
        cond_np[:, :16] = gaussian_filter1d(cond_np[:, :16],
                                             sigma=smooth_sigma,
                                             axis=0, mode="nearest")
    cond = torch.from_numpy(cond_np).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(audio, cond).squeeze(0).cpu().numpy().astype(np.float32)

    if getattr(args, "crisp", False):
        pred = crisp_mouth(
            pred,
            threshold=args.crisp_threshold,
            scale=args.crisp_scale,
            pre_smooth_sigma=args.crisp_sigma,
            mouth_close_sigma=args.crisp_mouthclose_sigma,
        )

    if getattr(args, "brow_smooth", False):
        pred = smooth_brows(
            pred,
            min_cutoff=args.brow_min_cutoff,
            beta=args.brow_beta,
            d_cutoff=args.brow_d_cutoff,
        )

    if getattr(args, "eye_squint_smooth", False):
        pred = smooth_eye_squint(
            pred,
            min_cutoff=args.eye_squint_min_cutoff,
            beta=args.eye_squint_beta,
            d_cutoff=args.eye_squint_d_cutoff,
        )

    if getattr(args, "eye_wide_smooth", False):
        pred = smooth_eye_wide(
            pred,
            min_cutoff=args.eye_wide_min_cutoff,
            beta=args.eye_wide_beta,
            d_cutoff=args.eye_wide_d_cutoff,
        )

    if getattr(args, "eye_look_smooth", False):
        pred = smooth_eye_look(
            pred,
            min_cutoff=args.eye_look_min_cutoff,
            beta=args.eye_look_beta,
            d_cutoff=args.eye_look_d_cutoff,
        )

    if getattr(args, "add_blinks", False):
        pred = inject_blinks(pred, scenario_id=sid,
                             mean_interval_s=args.blink_interval,
                             expressive_cap=args.blink_expressive_cap)

    metrics = per_channel_l1(pred, target)

    # Write the viewer JSON with a _pred[_variant]_dataset suffix. The
    # `_dataset` tail keeps it visible in the player's dataset-filtered
    # dropdown. The optional `_variant` between _pred and _dataset lets
    # several gain variants coexist without clobbering each other.
    variant_part = f"_{args.variant_tag}" if args.variant_tag else ""
    new_base = f"{sid}_pred{variant_part}_dataset"
    viewer_json = {
        "scenario_id": new_base,
        "fps": 30,
        "num_frames": int(pred.shape[0]),
        "names": ARKIT_52_NAMES,
        "turns": load_turns_meta(sid, args.emotion_dir),
        "blendshapes": np.round(pred, 4).tolist(),
    }
    args.viewer_dir.mkdir(parents=True, exist_ok=True)
    out_json = args.viewer_dir / f"{new_base}.json"
    # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding
    # instead of relying on the platform locale.
    out_json.write_text(json.dumps(viewer_json, ensure_ascii=False),
                        encoding="utf-8")

    # Audio: symlink to the teacher's mp3 — same audio, no need to re-concat.
    teacher_mp3 = args.viewer_dir / f"{sid}_dataset.mp3"
    pred_mp3 = args.viewer_dir / f"{new_base}.mp3"
    if pred_mp3.exists() or pred_mp3.is_symlink():
        pred_mp3.unlink()
    if teacher_mp3.exists():
        # Relative target: both files live in the same directory.
        pred_mp3.symlink_to(teacher_mp3.name)
    else:
        print(f"  ! {sid}: teacher mp3 not found at {teacher_mp3} — "
              "the prediction will play silently. Run dataset_to_viewer.py first.")

    print(f"  ✓ {sid}  L1: overall={metrics['overall']:.4f}  "
          f"lip={metrics['lipsync']:.4f}  exp={metrics['expression']:.4f}  "
          f"shr={metrics['shared']:.4f}  blink={metrics['blink']:.4f}  "
          f"({pred.shape[0]} frames)")
    return {"sid": sid, "new_base": new_base, **metrics}


def update_manifest(viewer_dir: Path, predictions: List[dict]):
    """Add or refresh prediction entries in ``<viewer_dir>/manifest.json``.

    Each prediction dict must carry ``sid`` and ``new_base``, as returned by
    ``predict_and_save``. ``n_turns``, ``emotions`` and ``text_preview`` are
    copied from the teacher entry ``<sid>_dataset`` when it exists.

    Entries are keyed by ``base``. Re-running replaces an existing prediction
    entry in place instead of duplicating it, and the other entries keep
    their order. A missing manifest file or missing ``"scenarios"`` list is
    treated as empty.
    """
    manifest_path = viewer_dir / "manifest.json"
    # The manifest is dumped with ensure_ascii=False, so read and write it as
    # UTF-8 explicitly rather than in the platform's locale encoding.
    if manifest_path.exists():
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    else:
        manifest = {}
    seen = {s.get("base"): s for s in manifest.get("scenarios") or []}
    for p in predictions:
        new_base = p["new_base"]
        # Look up the teacher entry to copy its turns/emotions metadata.
        teacher_entry = seen.get(f"{p['sid']}_dataset", {})
        seen[new_base] = {
            "base": new_base,
            "scenario_id": new_base,
            "variants": [],
            "n_turns": teacher_entry.get("n_turns", 1),
            "emotions": teacher_entry.get("emotions", []),
            "text_preview": "[V3 pred] " + teacher_entry.get("text_preview", ""),
        }
    manifest["scenarios"] = list(seen.values())
    manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2),
                             encoding="utf-8")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: paths, scenario selection, and every
    post-processing knob read by ``predict_and_save``."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=Path, default=DEFAULT_CKPT)
    ap.add_argument("--npz_dir", type=Path, default=DEFAULT_NPZ_DIR)
    ap.add_argument("--viewer_dir", type=Path, default=DEFAULT_VIEWER_DIR)
    ap.add_argument("--emotion_dir", type=Path, default=DEFAULT_EMOTION_DIR)
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument("--smooth-cond-sigma-emotion", type=float, default=0.0,
                    help="Gaussian σ (frames @ 30 fps) for smoothing the "
                         "emotion one-hot in cond[:, :16]. MUST match the "
                         "value used during training of the loaded "
                         "checkpoint. 0 = no smoothing (default).")
    ap.add_argument("--variant-tag", type=str, default=None,
                    help="Optional tag baked into output filenames "
                         "({sid}_pred_{tag}_dataset.json). Use to keep "
                         "predictions from different gain variants from "
                         "clobbering each other in data/viewer/.")
    ap.add_argument("-s", "--scenarios", nargs="+", default=None,
                    help="Specific scenario IDs to predict.")
    ap.add_argument("--all-viewer", action="store_true",
                    help="Predict every scenario currently curated in data/viewer/ "
                         "(i.e. every *_dataset.json without _pred).")
    ap.add_argument("--crisp", action="store_true",
                    help="Apply V1-style crisp_mouth post-processing to lipsync "
                         "channels (Gaussian pre-smooth + smoothstep soft-"
                         "threshold gate). Kills sub-threshold mouth jitter "
                         "without smearing phoneme attacks. Reference: "
                         "animasync-face-v1/deployment/lipsync_distilled.py.")
    ap.add_argument("--crisp-threshold", type=float, default=0.3,
                    help="Smoothstep center for crisp_mouth gate. Default 0.3.")
    ap.add_argument("--crisp-scale", type=float, default=1.0,
                    help="Amplification factor after the gate. Default 1.0 "
                         "(V1 used 1.2 to compensate for the gate's attenuation).")
    ap.add_argument("--crisp-sigma", type=float, default=1.0,
                    help="Gaussian σ (frames) for pre-smoothing before the "
                         "gate, applied to every lipsync channel EXCEPT "
                         "mouthClose. Default 1.0 (V1's smooth_frames=2 → σ=1).")
    ap.add_argument("--crisp-mouthclose-sigma", type=float, default=1.0,
                    help="Separate Gaussian σ for mouthClose (ch 26). Default "
                         "1.0. Keep at or below --crisp-sigma so m/b/p phoneme "
                         "closures stay sharp while the rest of the mouth can "
                         "be slightly more smoothed.")
    ap.add_argument("--brow-smooth", action="store_true",
                    help="Apply gentle One-Euro adaptive filter to brow "
                         "channels (0-4). Transition spikes pass through; "
                         "intra-turn micro-jitter is softened. Matches V1's "
                         "brow deployment.")
    ap.add_argument("--brow-min-cutoff", type=float, default=1.5,
                    help="One-Euro min_cutoff for brows. V1 used 1.5 in "
                         "deployment, 2.0 in training inference. Default 1.5 "
                         "(barely visible smoothing). Raise toward 2.0 for "
                         "even subtler effect.")
    ap.add_argument("--brow-beta", type=float, default=0.01,
                    help="One-Euro β — how fast the cutoff rises with |dx|. "
                         "Higher = transitions pass through faster. Default "
                         "0.01 (V1's brow-deployment value).")
    ap.add_argument("--brow-d-cutoff", type=float, default=1.0,
                    help="One-Euro derivative cutoff. Default 1.0.")
    ap.add_argument("--eye-squint-smooth", action="store_true",
                    help="Apply One-Euro adaptive filter to eyeSquint L/R "
                         "(ch 18, 19). Mirrors --brow-smooth but heavier — "
                         "orbicularis oculi is naturally slow so we can "
                         "smooth sub-Hz wiggle hard without killing real "
                         "squints. Use this when training-time target "
                         "smoothing alone doesn't kill the visible jitter.")
    ap.add_argument("--eye-squint-min-cutoff", type=float, default=0.8,
                    help="One-Euro min_cutoff for eyeSquint. Lower than "
                         "brows (default 0.8 vs 1.5) for heavier smoothing.")
    ap.add_argument("--eye-squint-beta", type=float, default=0.01,
                    help="One-Euro β for eyeSquint. Default 0.01.")
    ap.add_argument("--eye-squint-d-cutoff", type=float, default=1.0,
                    help="One-Euro derivative cutoff for eyeSquint. Default 1.0.")
    ap.add_argument("--eye-wide-smooth", action="store_true",
                    help="Apply One-Euro adaptive filter to eyeWide L/R "
                         "(ch 20, 21). Paired with eyeSquint for surprise → "
                         "other-emotion transitions where eyeWide↓ and "
                         "eyeSquint↑ cross discontinuously.")
    ap.add_argument("--eye-wide-min-cutoff", type=float, default=0.8,
                    help="One-Euro min_cutoff for eyeWide. Default 0.8.")
    ap.add_argument("--eye-wide-beta", type=float, default=0.01,
                    help="One-Euro β for eyeWide. Default 0.01.")
    ap.add_argument("--eye-wide-d-cutoff", type=float, default=1.0,
                    help="One-Euro derivative cutoff for eyeWide. Default 1.0.")
    ap.add_argument("--eye-look-smooth", action="store_true",
                    help="Apply One-Euro adaptive filter to the 8 eyeLook "
                         "channels (10–17, gaze direction / pupils). The "
                         "compiler's iris drift is already smooth; this "
                         "kills the model's residual reproduction noise. "
                         "Cutoff slightly higher than eyeSquint so saccade-"
                         "like rapid shifts still pass.")
    ap.add_argument("--eye-look-min-cutoff", type=float, default=1.0,
                    help="One-Euro min_cutoff for eyeLook (pupils). Default 1.0.")
    ap.add_argument("--eye-look-beta", type=float, default=0.01,
                    help="One-Euro β for eyeLook. Default 0.01.")
    ap.add_argument("--eye-look-d-cutoff", type=float, default=1.0,
                    help="One-Euro derivative cutoff for eyeLook. Default 1.0.")
    ap.add_argument("--add-blinks", action="store_true",
                    help="Cap V3's own blink channels (8, 9) at "
                         "--blink-expressive-cap, then overlay procedural "
                         "Poisson-distributed blinks via max(). Uses the same "
                         "generator as the training teacher "
                         "(eye_motion.add_blinks), seeded deterministically "
                         "by scenario_id. Use this because blink timing isn't "
                         "audio-determined and V3 cannot learn it from audio "
                         "alone.")
    ap.add_argument("--blink-interval", type=float, default=4.0,
                    help="Mean blink interval in seconds. Default 4.0 "
                         "(matches the teacher's bake-time value). Real human "
                         "blink rate is 15-20/min ≈ 3-4 s.")
    ap.add_argument("--blink-expressive-cap", type=float, default=0.5,
                    help="Cap for the model's pre-existing eyeBlink (ch 8, 9) "
                         "values. Below the cap, expressive sustained eye "
                         "closure (squinty smile, sleepy) survives. Above, "
                         "broken predictions (locked-shut eyes) are clamped. "
                         "Procedural blinks always poke through via max(). "
                         "Default 0.5.")
    return ap


def _discover_viewer_sids(viewer_dir: Path) -> List[str]:
    """Return the teacher scenario IDs curated in ``viewer_dir``.

    Every ``<sid>_dataset.json`` counts except prediction outputs, i.e. both
    ``<sid>_pred_dataset.json`` and the variant-tagged
    ``<sid>_pred_<tag>_dataset.json``. Both contain ``_pred_``. Tagged
    variants must be excluded too, otherwise they would be re-predicted
    under bogus scenario IDs.
    """
    sids = []
    for path in sorted(viewer_dir.glob("*_dataset.json")):
        stem = path.stem  # e.g. "long_001_dataset"
        if "_pred_" in stem:
            continue
        sids.append(stem[: -len("_dataset")])
    return sids


def main():
    """CLI entry point: parse args, select scenarios, predict, update manifest."""
    ap = _build_arg_parser()
    args = ap.parse_args()

    if not args.scenarios and not args.all_viewer:
        ap.error("provide --scenarios or --all-viewer")

    if args.all_viewer:
        # Use whatever is currently in the viewer dropdown (teacher entries).
        sids = _discover_viewer_sids(args.viewer_dir)
        print(f"all-viewer: {len(sids)} scenarios")
    else:
        sids = args.scenarios

    device = torch.device(args.device)
    model, _cfg = load_model(args.ckpt, device)

    print(f"\nPredicting {len(sids)} scenarios → {args.viewer_dir}")
    predictions = []
    for sid in sids:
        result = predict_and_save(sid, model, device, args)
        if result:
            predictions.append(result)

    if predictions:
        update_manifest(args.viewer_dir, predictions)

        def _avg(key: str) -> float:
            return float(np.mean([p[key] for p in predictions]))

        print(f"\n── aggregate over {len(predictions)} scenarios ──")
        print(f"  overall L1: {_avg('overall'):.4f}")
        print(f"  lipsync L1: {_avg('lipsync'):.4f}")
        print(f"  expression L1: {_avg('expression'):.4f}")
        print(f"  shared L1: {_avg('shared'):.4f}")
        print(f"  blink L1: {_avg('blink'):.4f}")

    pred_name = (f"<sid>_pred_{args.variant_tag}_dataset"
                 if args.variant_tag else "<sid>_pred_dataset")
    print("\nDone. Hard-reload viewer. Each scenario now appears twice:")
    print("  <sid>_dataset      = teacher target")
    print(f"  {pred_name} = V3 student prediction")


# Script entry point; see the module docstring for `python -m` usage.
if __name__ == "__main__":
    main()
