"""BlendshapeDataset — loads (audio, cond, target) from data_pipeline.py .npz.

Each .npz contains:
  audio:  (T, 80)  log-mel features at 30 fps
  cond:   (T, 19)  16-dim emotion one-hot + 3-dim VAD per frame
  target: (T, 52)  ARKit blendshape values

Split membership is derived from the JSONL files in data/emotion/
(seed_train_final.jsonl / seed_val.jsonl / seed_test.jsonl).

For daily-split scenarios (per-turn pseudo-scenarios like `daily_007_t2`),
parent membership is checked: if the parent dialogue is in the split's JSONL,
all its per-turn splits belong to that split.

Crops:
  - Sequences of at least `crop_frames` frames → random window crop
  - Sequences shorter than `crop_frames` → edge-pad to `crop_frames` and
    return `valid_length` so the trainer can mask the loss to real frames only.
"""
from __future__ import annotations

import json
import re
from pathlib import Path
from typing import List, Set, Tuple

import numpy as np
import torch
from scipy.ndimage import gaussian_filter1d
from torch.utils.data import Dataset

# Same pattern as scripts/dataset_to_viewer.py — daily_NNN_tK or daily_NNN_pM_tK.
# Group 1 (greedy) captures everything before the final `_t<digits>` suffix,
# i.e. the parent id including any `_pM` part; group 2 is the turn number.
_SPLIT_RE = re.compile(r"^(daily_.+)_t(\d+)$")

# Channels boosted by the target-side gain knobs. See V3FaceConfig docstring
# for the rationale. These are applied at __getitem__ time so the model
# learns to produce stronger output natively — no inference-time scaling.
# Each index is annotated with the blendshape name it refers to.
PURE_LIPSYNC_CHANNELS = (
    5,                                  # cheekPuff
    22, 23, 24, 25,                     # jawForward, jawLeft, jawOpen, jawRight
    26,                                 # mouthClose
    31,                                 # mouthFunnel
    32,                                 # mouthLeft
    33, 34,                             # mouthLowerDownLeft, mouthLowerDownRight
    35, 36,                             # mouthPressLeft, mouthPressRight
    37,                                 # mouthPucker
    38,                                 # mouthRight
    39, 40,                             # mouthRollLower, mouthRollUpper
    41, 42,                             # mouthShrugLower, mouthShrugUpper
    45, 46,                             # mouthStretchLeft, mouthStretchRight
    47, 48,                             # mouthUpperUpLeft, mouthUpperUpRight
    51,                                 # tongueOut
)
# Union of EMOTIONAL_PURE_CHANNELS and EMOTIONAL_MOUTH_CHANNELS below; the
# assert after those two tuples checks that they partition this one.
EMOTIONAL_CHANNELS = (
    0, 1, 2, 3, 4,                      # brows (incl. browInnerUp)
    6, 7,                               # cheekSquintLeft, cheekSquintRight
    18, 19,                             # eyeSquintLeft, eyeSquintRight
    20, 21,                             # eyeWideLeft, eyeWideRight
    27, 28,                             # mouthDimpleLeft, mouthDimpleRight
    29, 30,                             # mouthFrownLeft, mouthFrownRight
    43, 44,                             # mouthSmileLeft, mouthSmileRight
    49, 50,                             # noseSneerLeft, noseSneerRight
)

# v18b split — separate gain knob for emotional mouth channels (smile/frown/
# dimple/sneer). These are SHARED channels output by the lipsync branch; at
# high gain they make crisp_mouth's per-channel-max normalization push too
# many frames through the gate, which makes lipsync look over-active.
# Keeping them at moderate gain (~1.4) while pure-expression channels go
# higher (~1.8) gets us strong brows/eyes without the lipsync side-effect.
EMOTIONAL_PURE_CHANNELS = (
    0, 1, 2, 3, 4,                      # brows (incl. browInnerUp)
    6, 7,                               # cheekSquintLeft, cheekSquintRight
    18, 19,                             # eyeSquintLeft, eyeSquintRight
    20, 21,                             # eyeWideLeft, eyeWideRight
)
EMOTIONAL_MOUTH_CHANNELS = (
    27, 28,                             # mouthDimpleLeft, mouthDimpleRight
    29, 30,                             # mouthFrownLeft, mouthFrownRight
    43, 44,                             # mouthSmileLeft, mouthSmileRight
    49, 50,                             # noseSneerLeft, noseSneerRight
)
# Import-time consistency check between the three emotional tuples
# (note: `assert` statements are skipped under `python -O`).
assert tuple(sorted(EMOTIONAL_PURE_CHANNELS + EMOTIONAL_MOUTH_CHANNELS)) == \
       tuple(sorted(EMOTIONAL_CHANNELS)), "split must partition EMOTIONAL"

# Channels damped by the plosive damper at training time. Same set as the
# viewer's PLOSIVE_DAMP_MASK: mouthPress L/R, mouthRoll Lower/Upper,
# mouthShrug Lower/Upper. The damper kicks in when mouthClose (ch 26) on
# the same frame exceeds PLOSIVE_TRIGGER (0.4).
PLOSIVE_DAMP_CHANNELS = (35, 36, 39, 40, 41, 42)
PLOSIVE_TRIGGER = 0.4
MOUTH_CLOSE_CH = 26

# Emotion-one-hot indices that count as "happy" for the browInnerUp
# per-emotion gain override. These four are positive-valence emotions where
# the inner-brow raise should stay near-zero (real psychology — happy faces
# raise OUTER brow + cheek, not inner brow). Order matches the 16-emotion
# convention in data/emotion/emotion_labels.json and inference.py.
HAPPY_EMOTION_INDICES = (1, 2, 3, 5)    # joy, laughter, excitement, gratitude
BROW_INNERUP_CH = 2                     # browInnerUp

# All brow channels — used when `brow_target_gain` overrides the global
# expression gain for the whole brow group (not just browInnerUp).
BROW_CHANNELS = (0, 1, 2, 3, 4)
# eyeWide channels — used when `eye_wide_max` hard-caps the wide-eye look.
EYE_WIDE_CHANNELS = (20, 21)


def _classify(sid: str) -> str:
    if sid.startswith("long_"):
        return "long_"
    if sid.startswith("solo_"):
        return "solo_"
    if sid.startswith("daily_") and "_t" in sid:
        return "daily_-split"
    return "other"


def _parent_sid(sid: str) -> str:
    """Map a per-turn daily-split id to its parent dialogue id.

    ``daily_007_t2`` → ``daily_007``; ``daily_007_p1_t2`` → ``daily_007_p1``.
    Ids that do not match ``_SPLIT_RE`` (long_/solo_ scenarios, unsplit
    dailies) are returned unchanged.
    """
    split_match = _SPLIT_RE.match(sid)
    if split_match is None:
        return sid
    return split_match.group(1)


class BlendshapeDataset(Dataset):
    """Blendshape training windows: one (audio, cond, target) item per scenario.

    Scenarios are the ``*.npz`` files in ``npz_dir`` whose parent id (per
    ``_parent_sid``) is listed in ``split_jsonl``. ``__getitem__`` loads one
    file, optionally smooths the emotion conditioning, reshapes the target
    with the configured knobs, and returns a ``crop_frames``-long window.

    Target reshaping order in ``__getitem__``:
      1. per-group Gaussian pre-smoothing (brow / eyeSquint / eyeWide)
      2. per-channel gain (``_target_gain``)
      3. happy-frame browInnerUp override, then happy-frame whole-brow override
      4. surprise (high post-gain eyeWide) brow attenuation
      5. plosive damper on press/roll/shrug channels
      6. eyeWide hard cap
      7. final hard clamp to [0, 1] or soft clip (only when a gain, damper,
         cap, or soft-clip knob is active)
    """

    def __init__(
        self,
        npz_dir: Path,
        split_jsonl: Path,
        crop_frames: int = 240,
        lipsync_target_gain: float = 1.0,
        expression_target_gain: float = 1.0,
        emotional_mouth_target_gain: float | None = None,
        plosive_damp_target: float = 0.0,
        smooth_target_sigma_brow: float = 0.0,
        smooth_target_sigma_eye_squint: float = 0.0,
        smooth_target_sigma_eye_wide: float = 0.0,
        brow_innerup_happy_gain: float | None = None,
        brow_happy_gain: float | None = None,
        brow_target_gain: float | None = None,
        brow_surprise_gain: float | None = None,
        eye_wide_max: float | None = None,
        soft_clip: bool = False,
        soft_clip_knee: float = 0.7,
        smooth_cond_sigma_emotion: float = 0.0,
    ):
        """Index the split and precompute the per-channel tables.

        Args:
            npz_dir: Directory of ``<scenario_id>.npz`` files.
            split_jsonl: JSONL file; each non-blank line is an object with a
                ``"scenario_id"`` key (parent ids for daily splits).
            crop_frames: Length, in frames, of every returned window.
            lipsync_target_gain: Gain on ``PURE_LIPSYNC_CHANNELS``.
            expression_target_gain: Gain on ``EMOTIONAL_PURE_CHANNELS`` (and
                on ``EMOTIONAL_MOUTH_CHANNELS`` when the next arg is None).
            emotional_mouth_target_gain: Separate gain for
                ``EMOTIONAL_MOUTH_CHANNELS``; None → ``expression_target_gain``.
            plosive_damp_target: Plosive damper strength; 0 disables it.
            smooth_target_sigma_brow: Gaussian sigma (frames) for the brow
                targets (ch 0–4) and for the happy-frame mask; 0 disables.
            smooth_target_sigma_eye_squint: Gaussian sigma for ch 18–19.
            smooth_target_sigma_eye_wide: Gaussian sigma for ch 20–21.
            brow_innerup_happy_gain: Effective browInnerUp gain on fully
                "happy" frames; None disables.
            brow_happy_gain: Effective gain for every brow channel on fully
                "happy" frames; None disables.
            brow_target_gain: Gain for ``BROW_CHANNELS`` that overrides the
                expression gain; None keeps the expression gain.
            brow_surprise_gain: Brow attenuation strength on high-eyeWide
                frames; None or <= 0 disables.
            eye_wide_max: Upper cap on eyeWide after gain; None or >= 1
                disables the cap.
            soft_clip: Use a soft-knee upper clip instead of a hard clamp.
            soft_clip_knee: Knee of the soft clip; should be < 1 (the curve
                divides by ``1 - soft_clip_knee``).
            smooth_cond_sigma_emotion: Gaussian sigma (frames) applied along
                time to the 16 emotion columns of ``cond``; 0 disables.
        """
        self.npz_dir = Path(npz_dir)
        self.crop_frames = int(crop_frames)
        self.lipsync_target_gain = float(lipsync_target_gain)
        self.expression_target_gain = float(expression_target_gain)
        # None → emotional-mouth channels share expression_target_gain
        # (backward compat with v14/v18). A separate value (v18b) keeps
        # strong brow/eye gain without over-pushing smile/frown/dimple/sneer
        # through crisp_mouth's per-channel-max normalization.
        if emotional_mouth_target_gain is None:
            emotional_mouth_target_gain = expression_target_gain
        self.emotional_mouth_target_gain = float(emotional_mouth_target_gain)

        # None → brows (ch 0–4) ride the global expression gain; otherwise
        # this value wins for the brow group while the rest of
        # EMOTIONAL_PURE keeps the global gain.
        self.brow_target_gain = brow_target_gain

        # Per-channel gain vector, built once. Channels in no group
        # (eyeBlink, eyeLook, ...) stay at 1.0. Assignment order matters:
        # the brow override comes after EMOTIONAL_PURE so it wins.
        self._target_gain = np.ones(52, dtype=np.float32)
        self._target_gain[list(PURE_LIPSYNC_CHANNELS)] = self.lipsync_target_gain
        self._target_gain[list(EMOTIONAL_PURE_CHANNELS)] = self.expression_target_gain
        self._target_gain[list(EMOTIONAL_MOUTH_CHANNELS)] = self.emotional_mouth_target_gain
        if self.brow_target_gain is not None:
            self._target_gain[list(BROW_CHANNELS)] = float(self.brow_target_gain)
        self._gain_active = (self.lipsync_target_gain != 1.0
                             or self.expression_target_gain != 1.0
                             or self.emotional_mouth_target_gain != 1.0
                             or (self.brow_target_gain is not None
                                 and float(self.brow_target_gain) != 1.0))

        # Hard cap on eyeWide (ch 20, 21) after gain, before the final clamp.
        # Prevents the bug-eyed look on highly-amplified surprise frames.
        self.eye_wide_max = (None if eye_wide_max is None
                             else float(eye_wide_max))

        # Soft clip replaces the hard upper clamp with a knee-then-asymptote
        # curve so gain-boosted cross-fade tails are not chopped flat.
        self.soft_clip = bool(soft_clip)
        self.soft_clip_knee = float(soft_clip_knee)
        self._soft_clip_one_minus_knee = 1.0 - self.soft_clip_knee

        # Temporal smoothing of the emotion one-hot (cond[:, :16]) turns
        # turn-boundary steps into ramps so the model learns cross-fades.
        self.smooth_cond_sigma_emotion = float(smooth_cond_sigma_emotion)

        self.plosive_damp_target = float(plosive_damp_target)
        # Boolean channel mask so the damper is one masked multiply per item.
        self._plosive_damp_mask = np.zeros(52, dtype=bool)
        self._plosive_damp_mask[list(PLOSIVE_DAMP_CHANNELS)] = True

        # Per-group target pre-smoothing (applied BEFORE gain), sigma in
        # frames. Only groups with sigma > 0 are listed, so __getitem__
        # loops over exactly the channels being smoothed.
        self.smooth_target_sigma_brow = float(smooth_target_sigma_brow)
        self.smooth_target_sigma_eye_squint = float(smooth_target_sigma_eye_squint)
        self.smooth_target_sigma_eye_wide = float(smooth_target_sigma_eye_wide)
        self._smooth_groups = []
        if self.smooth_target_sigma_brow > 0.0:
            self._smooth_groups.append((BROW_CHANNELS, self.smooth_target_sigma_brow))
        if self.smooth_target_sigma_eye_squint > 0.0:
            self._smooth_groups.append(((18, 19), self.smooth_target_sigma_eye_squint))
        if self.smooth_target_sigma_eye_wide > 0.0:
            self._smooth_groups.append((EYE_WIDE_CHANNELS, self.smooth_target_sigma_eye_wide))

        # Per-emotion gain override for browInnerUp only.
        self.brow_innerup_happy_gain = brow_innerup_happy_gain
        # Per-emotion gain override for ALL brow channels (ch 0–4).
        self.brow_happy_gain = brow_happy_gain
        # Per-frame brow scale-down on the surprise/fluster proxy (high
        # eyeWide). Mirrors the viewer "brow cap" knob, baked into the target.
        self.brow_surprise_gain = (None if brow_surprise_gain is None
                                   else float(brow_surprise_gain))

        # Scenario ids in this split (parent ids for daily_, canonical ids
        # for long_/solo_). JSON is UTF-8 by spec, so don't depend on the
        # platform's locale encoding; blank lines are skipped rather than
        # crashing json.loads.
        split_ids: Set[str] = set()
        with Path(split_jsonl).open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                row = json.loads(line)
                split_ids.add(row["scenario_id"])

        # Keep only .npz files whose parent scenario is in this split.
        self.entries: List[Tuple[str, str]] = []  # (sid, category)
        for p in sorted(self.npz_dir.glob("*.npz")):
            sid = p.stem
            if _parent_sid(sid) in split_ids:
                self.entries.append((sid, _classify(sid)))

    def __len__(self) -> int:
        return len(self.entries)

    def _happy_mask(self, cond: np.ndarray) -> np.ndarray:
        """Per-frame "happy" weight in [0, 1] derived from ``cond``.

        Sum of the HAPPY_EMOTION_INDICES columns, clipped to [0, 1], so mixed
        emotions blend proportionally. It is Gaussian-smoothed with the brow
        sigma when that is enabled, so brow overrides do not introduce
        high-frequency steps at emotion boundaries.
        """
        happy = np.clip(cond[:, list(HAPPY_EMOTION_INDICES)].sum(axis=1), 0.0, 1.0)
        if self.smooth_target_sigma_brow > 0.0:
            happy = gaussian_filter1d(happy,
                                      sigma=self.smooth_target_sigma_brow,
                                      mode="nearest")
        return happy

    def __getitem__(self, idx: int):
        """Load, reshape, and crop/pad one scenario.

        Returns:
            dict with float32 tensors ``audio``, ``cond``, ``target`` of
            ``crop_frames`` rows each (feature widths as stored in the .npz),
            ``valid_length`` (long scalar: number of real, un-padded frames),
            and the ``scenario_id`` / ``category`` strings.
        """
        sid, cat = self.entries[idx]
        # An .npz load returns an NpzFile that keeps the archive open until
        # closed; the context manager releases the handle right after the
        # three arrays are read into memory (matters with many workers).
        with np.load(self.npz_dir / f"{sid}.npz") as data:
            audio = data["audio"].astype(np.float32)
            cond = data["cond"].astype(np.float32)
            target = data["target"].astype(np.float32)

        # Smooth the emotion one-hot (first 16 cond channels) along time.
        # VAD (cond[:, 16:19]) is left untouched.
        if self.smooth_cond_sigma_emotion > 0.0:
            cond[:, :16] = gaussian_filter1d(cond[:, :16],
                                             sigma=self.smooth_cond_sigma_emotion,
                                             axis=0, mode="nearest")

        # Target pre-smoothing, BEFORE gain, so the gain doesn't amplify
        # flicker. 'nearest' extends edge values instead of padding zeros.
        for chs, sigma in self._smooth_groups:
            for ch in chs:
                target[:, ch] = gaussian_filter1d(target[:, ch], sigma=sigma,
                                                  mode="nearest")

        # Target-side gain. Channels with gain 1.0 are unchanged.
        if self._gain_active:
            target = target * self._target_gain[None, :]

        # Happy-frame brow overrides. Each blends linearly from the applied
        # gain (happy=0) to the override gain (happy=1). The applied gain
        # has already been multiplied in, so the corrective factor is
        # (1 - happy) + happy * override / applied_gain. A zero applied gain
        # cannot be corrected multiplicatively (and would divide by zero),
        # so such channels are skipped. An applied gain of exactly 1.0
        # disables the browInnerUp override (pre-existing behavior).
        # NOTE(review): when both overrides are set, ch 2 is corrected twice;
        # on a fully happy frame its effective gain becomes
        # brow_innerup_happy_gain * brow_happy_gain / applied_gain. Confirm
        # this stacking is intended.
        innerup_gain = float(self._target_gain[BROW_INNERUP_CH])
        do_innerup = (self.brow_innerup_happy_gain is not None
                      and self._gain_active
                      and innerup_gain != 1.0
                      and innerup_gain != 0.0)
        do_brow_happy = self.brow_happy_gain is not None and self._gain_active
        if do_innerup or do_brow_happy:
            happy = self._happy_mask(cond)
            if do_innerup:
                adjust = (1.0 - happy) + happy * (self.brow_innerup_happy_gain / innerup_gain)
                target[:, BROW_INNERUP_CH] *= adjust
            if do_brow_happy:
                for ch in BROW_CHANNELS:
                    full_gain_ch = float(self._target_gain[ch])
                    if full_gain_ch == 0.0:
                        continue
                    adjust = (1.0 - happy) + happy * (self.brow_happy_gain / full_gain_ch)
                    target[:, ch] *= adjust

        # Brow scale-down on surprise/fluster, using POST-gain eyeWide as the
        # proxy (as the viewer's "brow cap" slider does). Attenuation ramps in
        # over eyeWide in [0.10, 0.30]. The factor is floored at 0 (as the
        # plosive damper does) so brow_surprise_gain > 1 can never produce
        # negative targets, even when the final clamp below is skipped.
        if (self.brow_surprise_gain is not None
                and self.brow_surprise_gain > 0.0):
            eye_wide = np.maximum(target[:, 20], target[:, 21])
            wide_ramp = np.clip((eye_wide - 0.10) / 0.20, 0.0, 1.0)
            brow_factor = np.maximum(0.0, 1.0 - wide_ramp * self.brow_surprise_gain)
            for ch in BROW_CHANNELS:
                target[:, ch] *= brow_factor

        # Plosive damper: where post-gain mouthClose exceeds the trigger,
        # attenuate press/roll/shrug linearly in the excess (floored at 0),
        # so the model learns not to stack them on m/b/p.
        if self.plosive_damp_target > 0.0:
            mc = target[:, MOUTH_CLOSE_CH]            # (T,)
            t = np.clip((mc - PLOSIVE_TRIGGER) / (1.0 - PLOSIVE_TRIGGER), 0.0, 1.0)
            atten = np.maximum(0.0, 1.0 - self.plosive_damp_target * t)   # (T,)
            target[:, self._plosive_damp_mask] *= atten[:, None]

        # eyeWide hard cap (after gain, before the final clamp) so amplified
        # surprise frames don't all saturate into a perma-bug-eyed look.
        if self.eye_wide_max is not None and self.eye_wide_max < 1.0:
            for ch in EYE_WIDE_CHANNELS:
                target[:, ch] = np.minimum(target[:, ch], self.eye_wide_max)

        # Final saturation. Hard clamp to [0, 1] by default; with soft_clip
        # the upper side uses a smooth knee so gain-boosted ramps keep their
        # tails. The lower bound is a hard 0 in both modes.
        if (self._gain_active or self.plosive_damp_target > 0.0
                or self.eye_wide_max is not None or self.soft_clip):
            if self.soft_clip:
                # Soft knee:  y = x                              if x <= k
                #             y = 1 - (1-k) * exp(-(x-k)/(1-k))  if x > k
                # Evaluated only on above-knee entries; evaluating exp()
                # below the knee (as np.where would) can overflow float32
                # for knees close to 1.
                k = self.soft_clip_knee
                one_minus_k = self._soft_clip_one_minus_knee
                above = target > k
                target[above] = 1.0 - one_minus_k * np.exp(-(target[above] - k) / one_minus_k)
                target = np.maximum(target, 0.0).astype(np.float32)
            else:
                target = np.clip(target, 0.0, 1.0).astype(np.float32)

        # Random window crop if long enough, otherwise edge-pad and report
        # how many frames are real.
        T = target.shape[0]
        if T >= self.crop_frames:
            start = int(np.random.randint(0, T - self.crop_frames + 1))
            audio = audio[start : start + self.crop_frames]
            cond = cond[start : start + self.crop_frames]
            target = target[start : start + self.crop_frames]
            valid_length = self.crop_frames
        else:
            pad = self.crop_frames - T
            audio = np.pad(audio, ((0, pad), (0, 0)), mode="edge")
            cond = np.pad(cond, ((0, pad), (0, 0)), mode="edge")
            target = np.pad(target, ((0, pad), (0, 0)), mode="edge")
            valid_length = T

        return {
            "audio": torch.from_numpy(audio),
            "cond": torch.from_numpy(cond),
            "target": torch.from_numpy(target),
            "valid_length": torch.tensor(valid_length, dtype=torch.long),
            "scenario_id": sid,
            "category": cat,
        }

    def get_sample_weights(self, long_weight: float = 5.0) -> np.ndarray:
        """Sample weights for WeightedRandomSampler.

        Long_ scenarios get `long_weight`, everything else gets 1.0.
        Used to compensate for the 5715 daily-split + 253 solo_ vs only 363
        long_ files — without this, the model sees ~6% multi-emotion gradient.
        """
        weights = np.ones(len(self.entries), dtype=np.float32)
        for i, (_, cat) in enumerate(self.entries):
            if cat == "long_":
                weights[i] = long_weight
        return weights

    def category_counts(self) -> dict:
        """Return a ``{category: number_of_entries}`` dict for this split."""
        counts: dict = {}
        for _, cat in self.entries:
            counts[cat] = counts.get(cat, 0) + 1
        return counts
