#!/usr/bin/env python3
"""Post-process .npz target files: apply Gaussian smoothing to brow channels.

Useful when you want a smoother teacher target without re-running the full
data_pipeline.py (which re-invokes LAM + V2 ONNX per scenario). This script
only touches the saved `target` array on brow channels — it doesn't recompute
audio, cond, or any non-brow channel.

Usage:
    # Smooth in place with σ=2 (mild smoothing)
    PYTHONPATH=. python3 scripts/postprocess_brow_smooth.py --sigma 2.0 --in-place

    # Write to a new dir so you can A/B (keeps the originals intact)
    PYTHONPATH=. python3 scripts/postprocess_brow_smooth.py --sigma 4.0 \\
        --out_dir data/v3_training_brow_s4

    # Smooth every group in `all_v2` (brow, cheek, nose, eyesquint), not just brows
    PYTHONPATH=. python3 scripts/postprocess_brow_smooth.py --sigma 2.0 --in-place \\
        --channels brow,cheek,nose,eyesquint

σ guide (σ is in frames at 30 fps):
    1.0 ≈ 33 ms — barely visible smoothing
    2.0 ≈ 67 ms — gentle, kills high-frequency jitter
    4.0 ≈ 133 ms — moderate, may damp V2 syllable-rate prosody
    8.0 ≈ 267 ms — heavy, brows become nearly static

After running, re-convert to viewer:
    python3 scripts/dataset_to_viewer.py
    python3 scripts/curate_viewer.py --clean
"""
from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np
from scipy.ndimage import gaussian_filter1d

# Project root: the directory one level above the folder containing this
# script (parents[1] of the resolved file path). Used as the base of the
# default --npz_dir.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Named channel groups accepted by --channels. Values are ARKit-52 column
# indices and must stay in sync with scripts/compiler/constants.py.
CHANNEL_GROUPS = dict(
    # browDownL/R, browInnerUp, browOuterUpL/R
    brow=[0, 1, 2, 3, 4],
    # cheekSquintL/R
    cheek=[10, 11],
    # noseSneerL/R
    nose=[16, 17],
    # eyeSquintL/R
    eyesquint=[18, 19],
    # eyeWideL/R
    eyewide=[20, 21],
    # variant A's mask: brow + cheek + nose + eyesquint
    all_v2=[0, 1, 2, 3, 4, 10, 11, 16, 17, 18, 19],
)


def resolve_channels(spec: str) -> list[int]:
    """Expand a comma-separated channel spec into sorted, unique indices.

    Each item is trimmed and lower-cased, then interpreted either as a key of
    ``CHANNEL_GROUPS`` (contributing every index in that group) or as a
    literal integer index. Empty items are ignored, so stray commas are
    harmless.

    Args:
        spec: e.g. ``"brow,cheek"``, ``"0, 1, 20"`` or a mix of both.

    Returns:
        The selected channel indices, de-duplicated and in ascending order.

    Raises:
        SystemExit: if an item is neither a known group nor an integer.
    """
    selected: set[int] = set()
    tokens = (raw.strip().lower() for raw in spec.split(","))
    for token in tokens:
        if not token:
            continue

        group = CHANNEL_GROUPS.get(token)
        if group is not None:
            selected.update(group)
            continue

        # Not a group name: it has to parse as a plain integer index.
        try:
            selected.add(int(token))
        except ValueError:
            raise SystemExit(f"unknown channel: {token!r} "
                             f"(known: {list(CHANNEL_GROUPS)} or integer)")
    return sorted(selected)


def _check_target(target: np.ndarray, channels: list[int], path: Path) -> None:
    """Exit with a clear message if ``channels`` cannot index ``target``."""
    if target.ndim != 2:
        raise SystemExit(f"{path.name}: expected a 2-D target (frames, channels), "
                         f"got shape {target.shape}")
    n_channels = target.shape[1]
    bad = [ch for ch in channels if not -n_channels <= ch < n_channels]
    if bad:
        raise SystemExit(f"{path.name}: channel index out of range {bad} "
                         f"(target has {n_channels} channels)")


def _smooth_channels(target: np.ndarray, channels: list[int],
                     sigma: float) -> np.ndarray:
    """Return a float32 copy of ``target`` with only ``channels`` smoothed.

    The selected columns are Gaussian-filtered along axis 0 and clipped to
    [0, 1]; all other columns are passed through unmodified (apart from the
    float32 cast).
    """
    out = np.array(target, dtype=np.float32)  # np.array always copies here
    # axis=0 is the axis the per-channel slices run along; mode="nearest"
    # repeats the first/last value past the edges.
    smoothed = gaussian_filter1d(out[:, channels], sigma=sigma,
                                 axis=0, mode="nearest")
    out[:, channels] = np.clip(smoothed, 0.0, 1.0)
    return out


def _save_npz_atomic(dest: Path, **arrays: np.ndarray) -> None:
    """Write ``arrays`` as a compressed .npz at ``dest`` via temp file + rename."""
    # The ".tmp" suffix keeps a stray temp file out of any later "*.npz" glob.
    tmp = dest.with_name(dest.name + ".tmp")
    try:
        with open(tmp, "wb") as fh:
            np.savez_compressed(fh, **arrays)
        tmp.replace(dest)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a half-written temp behind.
        if tmp.exists():
            tmp.unlink()
        raise


def main() -> None:
    """CLI entry point: smooth selected ``target`` channels of every .npz file.

    Reads each ``*.npz`` in ``--npz_dir``, Gaussian-smooths the requested
    channels of its ``target`` array, and writes ``audio``, ``cond`` and the
    new ``target`` either back over the input (``--in-place``) or into
    ``--out_dir``. Channels that were not selected are written back unchanged
    (apart from the float32 cast).

    Raises:
        SystemExit: on invalid arguments, when no input files are found, or
            when a file's ``target`` cannot be indexed by the chosen channels.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--npz_dir", type=Path,
                    default=PROJECT_ROOT / "data" / "v3_training")
    ap.add_argument("--sigma", type=float, default=2.0,
                    help="Gaussian σ in frames (30fps). Default 2.0 (≈67ms).")
    ap.add_argument("--channels", default="brow",
                    help="Comma-separated channel groups or indices. "
                         f"Groups: {list(CHANNEL_GROUPS.keys())}. Default 'brow'.")
    g = ap.add_mutually_exclusive_group(required=True)
    g.add_argument("--in-place", action="store_true",
                   help="Overwrite .npz files in --npz_dir.")
    g.add_argument("--out_dir", type=Path, default=None,
                   help="Write smoothed copies here (originals untouched).")
    args = ap.parse_args()

    # Reject σ values the filter cannot use before touching any file.
    if not (np.isfinite(args.sigma) and args.sigma > 0):
        ap.error(f"--sigma must be a positive finite number, got {args.sigma}")

    channels = resolve_channels(args.channels)
    if not channels:
        # Otherwise every file would be rewritten without any smoothing.
        ap.error("--channels selected no channels")

    files = sorted(args.npz_dir.glob("*.npz"))
    if not files:
        raise SystemExit(f"no .npz found in {args.npz_dir}")

    # Created only once we know there is something to write.
    target_dir = args.npz_dir if args.in_place else args.out_dir
    target_dir.mkdir(parents=True, exist_ok=True)

    print(f"smoothing channels {channels} with σ={args.sigma} "
          f"({args.sigma * 1000 / 30:.0f} ms equivalent)")
    print(f"input:  {args.npz_dir}  ({len(files)} files)")
    print(f"output: {target_dir}  {'(in-place)' if args.in_place else ''}")
    print()

    for i, p in enumerate(files, 1):
        # Load all arrays and close the archive before writing, so no file
        # handle leaks and in-place mode never overwrites an open file.
        with np.load(p) as d:
            audio, cond, raw_target = d["audio"], d["cond"], d["target"]

        _check_target(raw_target, channels, p)
        target = _smooth_channels(raw_target, channels, args.sigma)
        _save_npz_atomic(target_dir / p.name,
                         audio=audio, cond=cond, target=target)

        if i % 200 == 0 or i == len(files):
            print(f"  {i}/{len(files)}  {p.name}")

    print(f"\ndone — {len(files)} files written to {target_dir}")
    print("re-run dataset_to_viewer.py to update viewer previews.")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
