"""V3 face training — distill teacher's .npz targets into a causal TCN.

Usage:
    PYTHONPATH=. python3 -m models.v3_face.train               # full train
    PYTHONPATH=. python3 -m models.v3_face.train --smoke       # 10 scenarios × 5 epochs
    PYTHONPATH=. python3 -m models.v3_face.train --device cuda:1 --epochs 80
"""
from __future__ import annotations

import argparse
import math
import time
from pathlib import Path
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, WeightedRandomSampler

from scripts.compiler.constants import LIPSYNC_ONLY, EXPRESSION_ONLY, SHARED_CHANNELS

from .config import V3FaceConfig
from .dataset import BlendshapeDataset
from .model import V3FaceModel

try:
    import wandb
    _WANDB_AVAILABLE = True
except ImportError:
    _WANDB_AVAILABLE = False

# Directory two levels above the one containing this file (parents[0] is the
# directory that holds this file).
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Default locations used by main(); each can be overridden from the CLI via
# --npz_dir, --emotion_dir and --out_dir respectively.
DEFAULT_NPZ_DIR = PROJECT_ROOT / "data" / "v3_training"
DEFAULT_EMOTION_DIR = PROJECT_ROOT / "data" / "emotion"
DEFAULT_OUT_DIR = PROJECT_ROOT / "models" / "v3_face" / "checkpoints"


# ── Loss helpers ──────────────────────────────────────────────────────

def make_channel_weights(cfg: V3FaceConfig) -> torch.Tensor:
    """Build the per-channel L1 weight vector, shape (cfg.output_dim,).

    Every channel starts at cfg.expression_weight, then groups are overridden
    in this order (a later group wins if a channel appears in several):

    1. LIPSYNC_ONLY channels    → cfg.lipsync_weight
    2. SHARED_CHANNELS          → midpoint of lipsync and expression weights
    3. eye-blink channels 8, 9  → cfg.eye_blink_weight
    """
    # Shared channels sit halfway between the lipsync and expression weights.
    shared_weight = 0.5 * (cfg.lipsync_weight + cfg.expression_weight)

    weights = torch.full((cfg.output_dim,), cfg.expression_weight,
                         dtype=torch.float32)
    override_groups = (
        (LIPSYNC_ONLY, cfg.lipsync_weight),
        (SHARED_CHANNELS, shared_weight),
        ((8, 9), cfg.eye_blink_weight),
    )
    for channels, value in override_groups:
        for ch in channels:
            weights[ch] = value
    return weights


def masked_l1(pred: torch.Tensor, target: torch.Tensor,
              valid_length: torch.Tensor, ch_weights: torch.Tensor) -> torch.Tensor:
    """Channel-weighted L1 loss averaged over the valid (non-pad) frames.

    pred, target:    (B, T, C)
    valid_length:    (B,) int  — frames [0, valid_length) of each sample count
    ch_weights:      (C,)

    Returns a scalar: sum over valid frames of the channel-mean weighted
    absolute error, divided by the number of valid frames (at least 1, so an
    all-padding batch yields 0 instead of NaN).
    """
    _, n_frames, _ = pred.shape
    positions = torch.arange(n_frames, device=pred.device)
    # (B, T): 1.0 where the frame index is below the sample's valid length.
    valid = (positions.unsqueeze(0) < valid_length.unsqueeze(1)).float()

    weighted_err = (pred - target).abs() * ch_weights[None, None, :]
    per_frame = weighted_err.mean(dim=-1) * valid                  # (B, T)
    n_valid = valid.sum().clamp(min=1.0)
    return per_frame.sum() / n_valid


def make_velocity_weights(cfg: V3FaceConfig) -> torch.Tensor:
    """Build the per-channel velocity-penalty weights, shape (cfg.output_dim,).

    All channels start at cfg.velocity_expression_weight. Groups are then
    overridden in order (later entries win on overlap):

    * LIPSYNC_ONLY            → cfg.velocity_lipsync_weight
    * SHARED_CHANNELS         → cfg.velocity_shared_weight
    * eye blink (ch 8, 9)     → cfg.velocity_eye_blink_weight, kept separate
                                so the sharp blink shape is not smoothed away
    * eyeSquint (ch 18, 19)   → cfg.velocity_eye_squint_weight, only if set
                                (v18b: damps jitter amplified by target gain)
    * brows (ch 0-4)          → cfg.velocity_brow_weight, only if set
                                (v18c: same idea, applied to the brow group)
    """
    weights = torch.full((cfg.output_dim,), cfg.velocity_expression_weight,
                         dtype=torch.float32)

    override_groups = [
        (LIPSYNC_ONLY, cfg.velocity_lipsync_weight),
        (SHARED_CHANNELS, cfg.velocity_shared_weight),
        ((8, 9), cfg.velocity_eye_blink_weight),
    ]
    # Optional knobs: None means "leave these channels at whatever the groups
    # above assigned".
    if cfg.velocity_eye_squint_weight is not None:
        override_groups.append(((18, 19), cfg.velocity_eye_squint_weight))
    if cfg.velocity_brow_weight is not None:
        override_groups.append((range(5), cfg.velocity_brow_weight))

    for channels, value in override_groups:
        for ch in channels:
            weights[ch] = value
    return weights


def masked_velocity(pred: torch.Tensor, target: torch.Tensor,
                    valid_length: torch.Tensor,
                    vel_weights: torch.Tensor) -> torch.Tensor:
    """Channel-weighted L1 between predicted and target frame-to-frame deltas.

    pred, target:    (B, T, C)
    valid_length:    (B,) int — a sample with L valid frames contributes its
                     first L-1 deltas
    vel_weights:     (C,)

    Returns a scalar averaged over the valid deltas (denominator at least 1).
    """
    _, n_frames, _ = pred.shape
    # First differences along time: (B, T-1, C).
    pred_delta = pred[:, 1:] - pred[:, :-1]
    target_delta = target[:, 1:] - target[:, :-1]

    weighted = (pred_delta - target_delta).abs() * vel_weights[None, None, :]
    per_step = weighted.mean(dim=-1)                               # (B, T-1)

    steps = torch.arange(n_frames - 1, device=pred.device)
    valid = (steps.unsqueeze(0) < (valid_length - 1).unsqueeze(1)).float()
    n_valid = valid.sum().clamp(min=1.0)
    return (per_step * valid).sum() / n_valid


# ── Trainer ───────────────────────────────────────────────────────────

def run_epoch(model, dl, optimizer, scheduler, ch_weights, vel_weights, cfg,
              device, train: bool, velocity_scale: float = 1.0) -> Dict[str, float]:
    """Run one pass over `dl` and return the per-batch mean losses.

    With train=True the model is put in train mode and, for every batch, the
    optimizer and the LR scheduler are both stepped after gradient clipping to
    cfg.grad_clip. With train=False the model is put in eval mode and the pass
    runs without gradient tracking.

    The total loss is masked L1 plus masked velocity loss, where the velocity
    weights are multiplied by `velocity_scale` (the caller ramps this during
    warmup so L1 dominates early on).

    Returns {"loss", "l1", "vel"}, each averaged over the number of batches
    (0.0 when `dl` yields nothing). "vel" is only accumulated when
    cfg.velocity_weight > 0.
    """
    model.train(train)
    total_loss, total_l1, total_vel = 0.0, 0.0, 0.0
    n_batches = 0
    batch_keys = ("audio", "cond", "target", "valid_length")

    with torch.set_grad_enabled(train):
        scaled_vel_weights = vel_weights * velocity_scale
        for batch in dl:
            audio, cond, target, valid_length = (
                batch[key].to(device, non_blocking=True) for key in batch_keys
            )

            pred = model(audio, cond)
            loss_l1 = masked_l1(pred, target, valid_length, ch_weights)
            loss_vel = masked_velocity(pred, target, valid_length,
                                       scaled_vel_weights)
            loss = loss_l1 + loss_vel

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                optimizer.step()
                scheduler.step()   # schedule advances per batch, not per epoch

            total_loss += float(loss.item())
            total_l1 += float(loss_l1.item())
            if cfg.velocity_weight > 0:
                total_vel += float(loss_vel.item())
            n_batches += 1

    denom = max(1, n_batches)
    return {
        "loss": total_loss / denom,
        "l1": total_l1 / denom,
        "vel": total_vel / denom,
    }


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the training script.

    --epochs / --batch_size / --lr default to None, meaning "keep the value
    from V3FaceConfig". The remaining knobs are copied onto the config as-is.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--npz_dir", type=Path, default=DEFAULT_NPZ_DIR)
    ap.add_argument("--emotion_dir", type=Path, default=DEFAULT_EMOTION_DIR)
    ap.add_argument("--out_dir", type=Path, default=DEFAULT_OUT_DIR)
    ap.add_argument("--epochs", type=int, default=None)
    ap.add_argument("--batch_size", type=int, default=None)
    ap.add_argument("--lr", type=float, default=None)
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument("--num_workers", type=int, default=4)
    ap.add_argument("--smoke", action="store_true",
                    help="Smoke test: 10 train scenarios, 5 epochs.")
    ap.add_argument("--focus", choices=["all", "lipsync", "expression"],
                    default="all",
                    help="Channel focus. "
                         "'all' = train all 52 channels (default). "
                         "'lipsync' = only LIPSYNC + SHARED channels have "
                         "loss (expression branch sees zero gradient). "
                         "'expression' = only EXPRESSION_ONLY channels have "
                         "loss (lipsync branch sees zero gradient).")
    ap.add_argument("--resume", type=Path, default=None,
                    help="Load model weights from a checkpoint .pt before "
                         "training. Use with --freeze_lipsync for phase-2 "
                         "expression retraining.")
    ap.add_argument("--freeze_lipsync", action="store_true",
                    help="Freeze shared backbone + lipsync branch + lipsync "
                         "head. Only the expression branch + head will train. "
                         "Lipsync output stays bit-for-bit identical to what "
                         "was loaded via --resume.")
    ap.add_argument("--wandb", action="store_true",
                    help="Log to Weights & Biases. Requires `wandb login` once.")
    ap.add_argument("--wandb_project", default="animasync-v3-face",
                    help="W&B project name.")
    ap.add_argument("--wandb_run_name", default=None,
                    help="W&B run name. Defaults to auto-generated.")
    ap.add_argument("--wandb_entity", default=None,
                    help="W&B entity (team or user). Defaults to your default.")
    ap.add_argument("--lipsync-target-gain", type=float, default=1.0,
                    help="Multiply PURE_LIPSYNC target channels (jaw, mouth "
                         "mechanics, tongue, cheekPuff) by this factor at "
                         "load time, then clamp [0,1]. Default 1.0 = no "
                         "change.")
    ap.add_argument("--expression-target-gain", type=float, default=1.0,
                    help="Multiply EMOTIONAL target channels (brows incl. "
                         "innerUp, cheekSquint, eyeSquint, eyeWide, mouth "
                         "Dimple/Frown/Smile, noseSneer) by this factor at "
                         "load time, then clamp [0,1]. Default 1.0 = no "
                         "change. If --emotional-mouth-target-gain is also "
                         "set, THIS knob covers only the pure-expression "
                         "subset (brows + eyeSquint + eyeWide + cheekSquint).")
    ap.add_argument("--emotional-mouth-target-gain", type=float, default=None,
                    help="Optional separate gain for emotional-mouth target "
                         "channels (mouthDimple, mouthFrown, mouthSmile, "
                         "noseSneer). When set, decouples from --expression-"
                         "target-gain so brows/eyes can go higher than mouth. "
                         "Used by v18b to avoid pushing shared mouth channels "
                         "past the point where crisp_mouth normalization "
                         "destabilizes lipsync. None → same as --expression-"
                         "target-gain (backward-compat with v14/v18).")
    ap.add_argument("--velocity-eye-squint-weight", type=float, default=None,
                    help="Per-channel velocity penalty for eyeSquint L/R "
                         "(ch 18, 19). Higher than the default expression "
                         "velocity weight suppresses jitter at high gain. "
                         "None → use velocity_expression_weight (backward-"
                         "compat). Suggested 0.8 for v18b.")
    ap.add_argument("--velocity-brow-weight", type=float, default=None,
                    help="Per-channel velocity penalty for the 5 brow "
                         "channels (ch 0-4). Same mechanism as the eyeSquint "
                         "weight — suppresses brow jitter at high expression "
                         "gain. None → use velocity_expression_weight "
                         "(backward-compat). Suggested 0.7 for v18c.")
    ap.add_argument("--plosive-damp-target", type=float, default=0.0,
                    help="Bake the runtime plosive damper into training "
                         "targets. When mouthClose > 0.4 on a frame, "
                         "mouthPress/Roll/Shrug (ch 35,36,39,40,41,42) get "
                         "multiplied by (1 - this_value * smoothstep). "
                         "0 = off (default). 0.30 = matches the production "
                         "main-viewer setting that prevents 'lips swallowed' "
                         "on m/b/p plosives.")
    ap.add_argument("--smooth-target-sigma-brow", type=float, default=0.0,
                    help="Gaussian σ (frames @ 30 fps) for pre-smoothing the "
                         "5 brow target channels BEFORE the gain. The proper "
                         "fix for brow flicker at high gain — smooths the "
                         "input the model is asked to fit so jitter never "
                         "enters the training signal. 0 = off (default). "
                         "Suggested 2.0 (~67ms) for v18e.")
    ap.add_argument("--smooth-target-sigma-eye-wide", type=float, default=0.0,
                    help="Gaussian σ (frames @ 30 fps) on the eyeWide target "
                         "channels (20, 21) before gain. Pair with eyeSquint "
                         "smoothing to fix surprise → other-emotion "
                         "transitions where eyeWide↓ and eyeSquint↑ cross "
                         "discontinuously. Typical: same value as eyeSquint.")
    ap.add_argument("--smooth-target-sigma-eye-squint", type=float, default=0.0,
                    help="Gaussian σ (frames @ 30 fps) for pre-smoothing the "
                         "eyeSquint L/R target channels BEFORE the gain. "
                         "Real orbicularis oculi is slow + sustained, so we "
                         "can smooth heavier than brows without losing "
                         "useful motion. 0 = off (default). Suggested 3.0 "
                         "(~100ms) for v18e.")
    ap.add_argument("--brow-innerup-happy-gain", type=float, default=None,
                    help="Override the gain on browInnerUp (ch 2) for happy "
                         "emotions (joy / laughter / excitement / gratitude). "
                         "Set to 1.0 with expression-target-gain=2.2 to keep "
                         "browInnerUp un-amplified on happy frames so the "
                         "avatar doesn't look concerned/apologetic when "
                         "saying happy things, while leaving the 2.2× boost "
                         "everywhere else. None → no override (backward "
                         "compat with v14..v18f).")
    ap.add_argument("--brow-happy-gain", type=float, default=None,
                    help="Per-emotion gain override for ALL brow channels "
                         "(ch 0–4) on happy frames (joy / laughter / "
                         "excitement / gratitude). Use when brows are pushed "
                         "very high via --brow-target-gain but should stay "
                         "calm (~v14 level 1.4) on smiling frames. None → "
                         "no override (backward compat).")
    ap.add_argument("--brow-target-gain", type=float, default=None,
                    help="Override target gain for ALL brow channels (ch 0–4) "
                         "independently of --expression-target-gain. Use to "
                         "keep brows at a v14-style mild gain (e.g. 1.4) "
                         "while pushing the rest of expression (cheekSquint / "
                         "eyeSquint / eyeWide) much higher. None → brows ride "
                         "the global expression gain (backward compat).")
    ap.add_argument("--soft-clip", action="store_true",
                    help="Replace the dataset's hard upper-clamp at 1.0 with "
                         "a smooth knee-then-asymptote curve. Preserves "
                         "emotional cross-fade tails that would otherwise be "
                         "chopped off when high target gains push values "
                         "well past 1.0. Linear below --soft-clip-knee, "
                         "asymptotic above. Lower bound stays hard 0.")
    ap.add_argument("--smooth-cond-sigma-emotion", type=float, default=0.0,
                    help="Gaussian σ (in frames @ 30 fps) for smoothing the "
                         "emotion one-hot in cond[:, :16] along time. Smooths "
                         "turn-boundary step changes so the model learns "
                         "gradual emotion cross-fades. Must also be applied "
                         "at inference time. Typical: 5–10 (~170–330 ms). "
                         "0 = no smoothing (backward compat).")
    ap.add_argument("--soft-clip-knee", type=float, default=0.7,
                    help="Knee point for --soft-clip. Below this value the "
                         "curve is exactly linear (no distortion). Above, "
                         "it bends toward 1.0. Default 0.7 = first 70%% of "
                         "the dynamic range untouched. Range (0, 1).")
    ap.add_argument("--eye-wide-max", type=float, default=None,
                    help="Hard cap on eyeWide (ch 20, 21) AFTER gain, before "
                         "the [0,1] clamp. Prevents bug-eyed saturation at "
                         "high expression gain. Typical: 0.5–0.7. None → no "
                         "cap (backward compat).")
    ap.add_argument("--brow-surprise-gain", type=float, default=None,
                    help="Per-frame brow scale-down on surprise/fluster proxy "
                         "(high eyeWide). Mirrors the viewer 'brow cap' "
                         "slider: brows (ch 0–4) get multiplied by "
                         "(1 - wide_ramp * value), where wide_ramp ramps in "
                         "over [0.10, 0.30] of post-gain eyeWide. Typical: "
                         "0.3–0.5. None → no scale-down (backward compat).")
    ap.add_argument("--variant-tag", type=str, default=None,
                    help="Optional suffix appended to checkpoint filenames "
                         "after --focus, so different gain runs don't "
                         "clobber each other. E.g. 'v14' → "
                         "best_lipsync_v14.pt / best_expression_v14.pt.")
    return ap


def _config_from_args(args: argparse.Namespace):
    """Return a fresh V3FaceConfig with the CLI overrides applied."""
    cfg = V3FaceConfig()
    # `is not None` rather than truthiness: an explicitly passed 0 must not be
    # silently replaced by the config default.
    if args.epochs is not None:
        cfg.n_epochs = args.epochs
    if args.batch_size is not None:
        cfg.batch_size = args.batch_size
    if args.lr is not None:
        cfg.learning_rate = args.lr
    cfg.lipsync_target_gain = args.lipsync_target_gain
    cfg.expression_target_gain = args.expression_target_gain
    cfg.emotional_mouth_target_gain = args.emotional_mouth_target_gain
    cfg.velocity_eye_squint_weight = args.velocity_eye_squint_weight
    cfg.velocity_brow_weight = args.velocity_brow_weight
    cfg.plosive_damp_target = args.plosive_damp_target
    cfg.smooth_target_sigma_brow = args.smooth_target_sigma_brow
    cfg.smooth_target_sigma_eye_squint = args.smooth_target_sigma_eye_squint
    cfg.smooth_target_sigma_eye_wide = args.smooth_target_sigma_eye_wide
    cfg.brow_innerup_happy_gain = args.brow_innerup_happy_gain
    cfg.brow_happy_gain = args.brow_happy_gain
    cfg.brow_target_gain = args.brow_target_gain
    cfg.brow_surprise_gain = args.brow_surprise_gain
    cfg.eye_wide_max = args.eye_wide_max
    cfg.soft_clip = args.soft_clip
    cfg.soft_clip_knee = args.soft_clip_knee
    cfg.smooth_cond_sigma_emotion = args.smooth_cond_sigma_emotion
    return cfg


def _print_run_banner(cfg) -> None:
    """Print one line per non-default target/velocity/smoothing knob on `cfg`."""
    if (cfg.lipsync_target_gain != 1.0
            or cfg.expression_target_gain != 1.0
            or cfg.emotional_mouth_target_gain is not None):
        mouth_g = (cfg.emotional_mouth_target_gain
                   if cfg.emotional_mouth_target_gain is not None
                   else cfg.expression_target_gain)
        print(f"[target-gain] lipsync×{cfg.lipsync_target_gain:.2f}  "
              f"expression×{cfg.expression_target_gain:.2f}  "
              f"emotional-mouth×{mouth_g:.2f}  "
              f"(applied per-channel in dataset, then clamped [0,1])")
    if cfg.velocity_eye_squint_weight is not None:
        print(f"[velocity] eyeSquint(ch 18,19)×{cfg.velocity_eye_squint_weight:.2f}  "
              f"(vs default expression {cfg.velocity_expression_weight:.2f})")
    if cfg.velocity_brow_weight is not None:
        print(f"[velocity] brows(ch 0-4)×{cfg.velocity_brow_weight:.2f}  "
              f"(vs default expression {cfg.velocity_expression_weight:.2f})")
    if cfg.plosive_damp_target > 0.0:
        print(f"[plosive-damper] target damp={cfg.plosive_damp_target:.2f}  "
              f"(applied in dataset when mouthClose > 0.4)")
    if cfg.smooth_target_sigma_brow > 0.0:
        print(f"[target-smooth] brow(ch 0-4) σ={cfg.smooth_target_sigma_brow:.2f} frames "
              f"(~{cfg.smooth_target_sigma_brow * 1000 / cfg.fps:.0f} ms @ {cfg.fps} fps)")
    if cfg.smooth_target_sigma_eye_squint > 0.0:
        print(f"[target-smooth] eyeSquint(ch 18,19) σ={cfg.smooth_target_sigma_eye_squint:.2f} frames "
              f"(~{cfg.smooth_target_sigma_eye_squint * 1000 / cfg.fps:.0f} ms @ {cfg.fps} fps)")
    if cfg.smooth_target_sigma_eye_wide > 0.0:
        print(f"[target-smooth] eyeWide(ch 20,21) σ={cfg.smooth_target_sigma_eye_wide:.2f} frames "
              f"(~{cfg.smooth_target_sigma_eye_wide * 1000 / cfg.fps:.0f} ms @ {cfg.fps} fps)")
    if cfg.brow_innerup_happy_gain is not None:
        print(f"[per-emotion] browInnerUp(ch 2) on happy frames uses gain "
              f"{cfg.brow_innerup_happy_gain:.2f} (vs {cfg.expression_target_gain:.2f} elsewhere)")
    if cfg.brow_target_gain is not None:
        print(f"[target-gain] brows(ch 0-4)×{cfg.brow_target_gain:.2f}  "
              f"(overrides expression×{cfg.expression_target_gain:.2f} for brow group only)")
    if cfg.brow_happy_gain is not None:
        brow_elsewhere = (cfg.brow_target_gain
                          if cfg.brow_target_gain is not None
                          else cfg.expression_target_gain)
        print(f"[per-emotion] brows(ch 0-4) on happy frames use gain "
              f"{cfg.brow_happy_gain:.2f} (vs {brow_elsewhere:.2f} elsewhere)")
    if cfg.brow_surprise_gain is not None:
        print(f"[per-frame] brows(ch 0-4) scaled by (1 - wide_ramp × "
              f"{cfg.brow_surprise_gain:.2f}) on surprise/fluster frames "
              f"(eyeWide proxy, ramp [0.10, 0.30])")
    if cfg.eye_wide_max is not None:
        print(f"[cap] eyeWide(ch 20,21) capped at {cfg.eye_wide_max:.2f} after gain")
    if cfg.soft_clip:
        print(f"[soft-clip] saturation knee at {cfg.soft_clip_knee:.2f} "
              f"(linear below, asymptotic above)")
    if cfg.smooth_cond_sigma_emotion > 0.0:
        print(f"[cond-smooth] emotion one-hot σ={cfg.smooth_cond_sigma_emotion:.1f} "
              f"frames (~{cfg.smooth_cond_sigma_emotion * 1000 / cfg.fps:.0f} ms @ {cfg.fps} fps)")


def _dataset_kwargs(cfg) -> dict:
    """Keyword arguments shared by the train AND val BlendshapeDataset.

    Both splits must shape their targets identically, otherwise the validation
    L1 used for best-checkpoint selection is measured against different
    targets than the ones being trained on. Building the kwargs once keeps the
    two constructor calls from drifting apart.
    """
    return dict(
        crop_frames=cfg.crop_frames,
        lipsync_target_gain=cfg.lipsync_target_gain,
        expression_target_gain=cfg.expression_target_gain,
        emotional_mouth_target_gain=cfg.emotional_mouth_target_gain,
        plosive_damp_target=cfg.plosive_damp_target,
        smooth_target_sigma_brow=cfg.smooth_target_sigma_brow,
        smooth_target_sigma_eye_squint=cfg.smooth_target_sigma_eye_squint,
        smooth_target_sigma_eye_wide=cfg.smooth_target_sigma_eye_wide,
        brow_innerup_happy_gain=cfg.brow_innerup_happy_gain,
        brow_happy_gain=cfg.brow_happy_gain,
        brow_target_gain=cfg.brow_target_gain,
        brow_surprise_gain=cfg.brow_surprise_gain,
        eye_wide_max=cfg.eye_wide_max,
        soft_clip=cfg.soft_clip,
        soft_clip_knee=cfg.soft_clip_knee,
        smooth_cond_sigma_emotion=cfg.smooth_cond_sigma_emotion,
    )


def main():
    """CLI entry point: parse flags, build data/model/optimizer, train.

    Flow: parse args → build config → datasets (optionally shrunk by --smoke)
    → model (optional --resume / --freeze_lipsync) → optional W&B → AdamW with
    per-step warmup + cosine LR schedule → epoch loop that writes
    `latest{suffix}.pt` every epoch and `best{suffix}.pt` whenever validation
    L1 improves.

    Exits via argparse error for incompatible flags and via SystemExit when
    the --resume checkpoint does not exist.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()

    # --focus lipsync puts loss only on lipsync-side channels, while
    # --freeze_lipsync leaves only the expression branch + head trainable
    # (see the two flags' help text). Together nothing receives a useful
    # gradient, so reject the combination before doing any work.
    if args.freeze_lipsync and args.focus == "lipsync":
        ap.error("--focus lipsync cannot be combined with --freeze_lipsync: "
                 "the loss would only cover channels whose parameters are "
                 "frozen, so nothing would train.")

    cfg = _config_from_args(args)

    args.out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device(args.device)

    _print_run_banner(cfg)

    # ─── Datasets ──────────────────────────────────────────────────
    ds_kwargs = _dataset_kwargs(cfg)
    train_ds = BlendshapeDataset(args.npz_dir,
                                 args.emotion_dir / "seed_train_final.jsonl",
                                 **ds_kwargs)
    val_ds = BlendshapeDataset(args.npz_dir,
                               args.emotion_dir / "seed_val.jsonl",
                               **ds_kwargs)

    if args.smoke:
        # Pick a diverse smoke set: 4 long_ + 3 solo_ + 3 daily-split if possible
        # NOTE(review): the third category string is "daily_-split" while the
        # comments in this file say "daily-split" — confirm it matches the
        # category labels the dataset actually produces.
        smoke_picks = []
        for want_cat, n in (("long_", 4), ("solo_", 3), ("daily_-split", 3)):
            picks = [e for e in train_ds.entries if e[1] == want_cat][:n]
            smoke_picks.extend(picks)
        if not smoke_picks:
            smoke_picks = train_ds.entries[:10]
        train_ds.entries = smoke_picks
        val_ds.entries = val_ds.entries[:5]
        cfg.n_epochs = 5
        cfg.warmup_steps = 5
        cfg.batch_size = 4   # so 10 samples → 3 batches/epoch
        print(f"[smoke] train={len(train_ds)} val={len(val_ds)} "
              f"epochs={cfg.n_epochs} batch={cfg.batch_size}")

    print(f"train: {len(train_ds)} scenarios — counts: {train_ds.category_counts()}")
    print(f"val:   {len(val_ds)} scenarios — counts: {val_ds.category_counts()}")

    # Weighted sampler boosts long_ (multi-turn) to balance the daily-split flood
    sample_weights = train_ds.get_sample_weights(cfg.long_oversample_weight)
    sampler = WeightedRandomSampler(sample_weights,
                                    num_samples=len(train_ds),
                                    replacement=True)

    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, sampler=sampler,
                          num_workers=args.num_workers, pin_memory=True,
                          drop_last=False)
    val_dl = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                        num_workers=args.num_workers, pin_memory=True)

    # ─── Model ─────────────────────────────────────────────────────
    model = V3FaceModel(cfg).to(device)
    print(f"V3FaceModel (split-branch): {model.n_params/1e6:.2f}M params, "
          f"~{model.size_mb:.1f} MB fp32 (~{model.size_mb/4:.1f} MB int8)")

    # ─── Resume from checkpoint (optional) ─────────────────────────
    if args.resume is not None:
        if not args.resume.exists():
            raise SystemExit(f"--resume checkpoint not found: {args.resume}")
        # SECURITY: weights_only=False unpickles arbitrary objects from the
        # checkpoint file — only pass --resume files you trust.
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        print(f"[resume] loaded weights from {args.resume}  "
              f"(prev epoch={ckpt.get('epoch', '?')}, "
              f"val_l1={ckpt.get('val_l1', float('nan')):.4f})")

    # ─── Freeze lipsync (phase 2) ──────────────────────────────────
    if args.freeze_lipsync:
        frozen_n = model.freeze_lipsync()
        print(f"[freeze_lipsync] froze {frozen_n/1e6:.2f}M params "
              f"(shared backbone + lipsync branch + lipsync head). "
              f"Trainable: {model.n_trainable/1e6:.2f}M (expression branch + head).")
        # The matching loss-weight masking happens in the focus-mask section
        # below, once the weight vectors exist.

    # ─── W&B init (opt-in) ─────────────────────────────────────────
    use_wandb = args.wandb and _WANDB_AVAILABLE
    if args.wandb and not _WANDB_AVAILABLE:
        print("[wandb] requested but `wandb` not installed — skipping. "
              "Install with: pip install wandb")
    if use_wandb:
        run_name = args.wandb_run_name or (
            f"v3face_h{cfg.hidden_dim}_b{len(cfg.shared_dilations) + len(cfg.branch_dilations)}"
            f"_lr{cfg.learning_rate:.0e}"
            + ("_smoke" if args.smoke else "")
        )
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=run_name,
            config={
                **cfg.__dict__,
                "n_params": model.n_params,
                "size_mb_fp32": model.size_mb,
                "size_mb_int8": model.size_mb / 4,
                "train_scenarios": len(train_ds),
                "val_scenarios": len(val_ds),
                "train_category_counts": train_ds.category_counts(),
                "val_category_counts": val_ds.category_counts(),
                "device": str(device),
                "smoke": args.smoke,
            },
        )
        wandb.watch(model, log="gradients", log_freq=100)
        print(f"[wandb] logging to {wandb.run.get_url()}")

    # ─── Optim + LR schedule (warmup + cosine) ─────────────────────
    # Filter to trainable params only — when --freeze_lipsync is set,
    # the optimizer should ignore the frozen shared backbone + lipsync side.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params,
                      lr=cfg.learning_rate,
                      betas=(0.9, 0.95),
                      weight_decay=cfg.weight_decay)

    # The scheduler is stepped once per training batch (see run_epoch), so the
    # schedule length is measured in optimizer steps, not epochs.
    total_steps = max(1, len(train_dl) * cfg.n_epochs)

    def lr_lambda(step: int) -> float:
        # Linear warmup to the base LR, then cosine decay to 0 at total_steps.
        if step < cfg.warmup_steps:
            return step / max(1, cfg.warmup_steps)
        prog = (step - cfg.warmup_steps) / max(1, total_steps - cfg.warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, prog)))

    scheduler = LambdaLR(optimizer, lr_lambda)
    ch_weights = make_channel_weights(cfg).to(device)
    vel_weights = make_velocity_weights(cfg).to(device)

    # Apply focus mask: zero out non-focus channels in L1 + velocity weights.
    # `--focus lipsync`    → expression channels masked (only lipsync trains)
    # `--focus expression` → lipsync channels masked (only expression trains)
    # `--freeze_lipsync` implies lipsync channels are also masked (gradients
    # on those would be wasted since producing params are no_grad).
    from .model import LIPSYNC_BRANCH_CHANNELS, EXPRESSION_BRANCH_CHANNELS
    if args.focus == "lipsync":
        for ch in EXPRESSION_BRANCH_CHANNELS:
            ch_weights[ch] = 0.0
            vel_weights[ch] = 0.0
        print(f"[focus=lipsync] masked {len(EXPRESSION_BRANCH_CHANNELS)} expression channels")
    elif args.focus == "expression" or args.freeze_lipsync:
        for ch in LIPSYNC_BRANCH_CHANNELS:
            ch_weights[ch] = 0.0
            vel_weights[ch] = 0.0
        print(f"[focus=expression{'/freeze_lipsync' if args.freeze_lipsync else ''}] "
              f"masked {len(LIPSYNC_BRANCH_CHANNELS)} lipsync channels")

    # Suffix checkpoint with focus + optional variant tag so phases AND
    # gain variants don't clobber each other. Examples:
    #   focus=all                          → best.pt
    #   focus=lipsync                      → best_lipsync.pt
    #   focus=expression                   → best_expression.pt
    #   focus=lipsync,    variant=v14      → best_lipsync_v14.pt
    #   focus=expression, variant=v18      → best_expression_v18.pt
    suffix = "" if args.focus == "all" else f"_{args.focus}"
    if args.variant_tag:
        suffix = f"{suffix}_{args.variant_tag}"

    # ─── Train loop ────────────────────────────────────────────────
    best_val = float("inf")
    for epoch in range(cfg.n_epochs):
        t0 = time.time()
        # Linear velocity warmup: 0 at epoch 0 → 1.0 at velocity_warmup_epochs.
        velocity_scale = (
            min(1.0, epoch / cfg.velocity_warmup_epochs)
            if cfg.velocity_warmup_epochs > 0 else 1.0
        )
        tr = run_epoch(model, train_dl, optimizer, scheduler,
                       ch_weights, vel_weights, cfg, device, train=True,
                       velocity_scale=velocity_scale)
        va = run_epoch(model, val_dl, optimizer, scheduler,
                       ch_weights, vel_weights, cfg, device, train=False,
                       velocity_scale=velocity_scale)
        dt = time.time() - t0
        print(f"epoch {epoch:3d}  "
              f"train l1={tr['l1']:.4f} vel={tr['vel']:.4f}  "
              f"val l1={va['l1']:.4f} vel={va['vel']:.4f}  "
              f"lr={scheduler.get_last_lr()[0]:.2e}  "
              f"vw={velocity_scale:.2f}  "
              f"{dt:.1f}s")

        # Model selection uses validation L1 only (not the velocity term,
        # whose scale changes during warmup).
        is_best = va["l1"] < best_val
        if is_best:
            best_val = va["l1"]
        ckpt = {
            "model": model.state_dict(),
            "config": cfg.__dict__,
            "epoch": epoch,
            "val_l1": va["l1"],
            "val_vel": va["vel"],
            "train_l1": tr["l1"],
            "train_vel": tr["vel"],
        }
        if is_best:
            torch.save(ckpt, args.out_dir / f"best{suffix}.pt")
            print(f"  → saved best{suffix} (val l1={best_val:.4f})")
        torch.save(ckpt, args.out_dir / f"latest{suffix}.pt")

        if use_wandb:
            wandb.log({
                "epoch": epoch,
                "train/l1": tr["l1"],
                "train/velocity": tr["vel"],
                "train/loss": tr["loss"],
                "val/l1": va["l1"],
                "val/velocity": va["vel"],
                "val/loss": va["loss"],
                "val/best_l1": best_val,
                "lr": scheduler.get_last_lr()[0],
                "epoch_seconds": dt,
            }, step=epoch)

    print(f"\nDone. best val l1: {best_val:.4f}")
    print(f"checkpoints: {args.out_dir}")
    if use_wandb:
        wandb.summary["best_val_l1"] = best_val
        wandb.finish()


# Run training only when executed as a script / via `python -m`, not when the
# module is imported.
if __name__ == "__main__":
    main()
