"""Brow + eyeSquint tremor baked into training targets.

Direct Python port of the runtime tremor in
`tools/blendshape-player.html` so dataset `.npz` targets contain the same
micro-motion that we've been judging in the viewer.

Channels:
  - 0, 1: browDownLeft/Right       (silence-gated)
  - 2:    browInnerUp              (silence-gated)
  - 3, 4: browOuterUpLeft/Right    (silence-gated)
  - 18, 19: eyeSquintLeft/Right    (always-on)

Bilateral pairs share a single noise stream so L/R move together
(orbicularis oculi tone, brow micro-furrowing).
"""
from __future__ import annotations

from typing import List, Optional, Tuple

import numpy as np
from scipy.ndimage import gaussian_filter1d


# ARKit-52 channel indices touched by the tremor. They are positions in
# ARKIT_52_NAMES and must be kept in sync with constants.py by hand.
BROW_DOWN_L_R: Tuple[int, int] = (0, 1)        # browDownLeft, browDownRight
BROW_INNER_UP: int = 2                          # browInnerUp (single channel)
BROW_OUTER_UP_L_R: Tuple[int, int] = (3, 4)    # browOuterUpLeft, browOuterUpRight
EYE_SQUINT_L_R: Tuple[int, int] = (18, 19)     # eyeSquintLeft, eyeSquintRight


def _hash_str(s: str) -> int:
    """FNV-1a-ish 32-bit hash. Mirrors player's `_hashStr` so identical
    scenario IDs produce identical noise streams in Python and JS."""
    h = 2166136261
    for c in s:
        h = ((h ^ ord(c)) * 16777619) & 0xFFFFFFFF
    return h


def _mulberry32(seed: int):
    """Deterministic uniform-[0,1) PRNG matching player's `_mulberry32`."""
    state = seed & 0xFFFFFFFF

    def rand() -> float:
        nonlocal state
        state = (state + 0x6D2B79F5) & 0xFFFFFFFF
        t = state
        # t = imul(t ^ (t >>> 15), 1 | t)
        t = ((t ^ (t >> 15)) * (1 | t)) & 0xFFFFFFFF
        # t = (t + imul(t ^ (t >>> 7), 61 | t)) ^ t
        t = ((t + ((t ^ (t >> 7)) * (61 | t))) & 0xFFFFFFFF) ^ t
        t = t & 0xFFFFFFFF
        # ((t ^ (t >>> 14)) >>> 0) / 4294967296
        return ((t ^ (t >> 14)) & 0xFFFFFFFF) / 4294967296.0

    return rand


def _gauss_rand_stream(rng, n: int) -> np.ndarray:
    """Box-Muller on top of a uniform PRNG (matches player's `_gaussRand`)."""
    out = np.empty(n, dtype=np.float32)
    for i in range(n):
        u1 = 0.0
        while u1 == 0.0:
            u1 = rng()
        u2 = rng()
        out[i] = float(np.sqrt(-2.0 * np.log(u1)) * np.cos(2.0 * np.pi * u2))
    return out


def _smooth_rms_normalize(x: np.ndarray, sigma: float) -> np.ndarray:
    """Player's `_gaussianSmooth(input, sigma, true)`: Gaussian filter with
    edge-clamp boundary + per-stream RMS normalization to unit RMS."""
    out = gaussian_filter1d(x.astype(np.float32), sigma=float(sigma),
                            mode="nearest")
    rms = float(np.sqrt(np.mean(out * out)) + 1e-6)
    if rms > 0:
        out = out / rms
    return out.astype(np.float32)


def generate_noise_streams(T: int, scenario_id: str,
                           sigma: float = 1.5) -> dict:
    """Build the four smoothed noise streams used by the tremor.

    There is one stream per channel group, keyed 'down', 'innerUp', 'outerUp'
    and 'eyeSquint'. Each key gets its own mulberry32 generator, seeded with
    hash(scenario_id) XOR a fixed per-key salt. That generator yields T
    Box-Muller normal samples, which are Gaussian-smoothed (σ = `sigma`) and
    rescaled to unit RMS. The same scenario_id always produces the same streams.
    """
    base_seed = _hash_str(scenario_id)
    salt_by_key = (
        ('down', 0xA1A1),
        ('innerUp', 0xB2B2),
        ('outerUp', 0xC3C3),
        ('eyeSquint', 0xD4D4),
    )
    # Generators are independent, so building each stream right after
    # creating its generator gives the same values as creating all first.
    return {
        key: _smooth_rms_normalize(
            _gauss_rand_stream(_mulberry32(base_seed ^ salt), T), sigma)
        for key, salt in salt_by_key
    }


def silence_gate_from_wav(wav: np.ndarray, sr: int, T: int,
                          fps: int = 30) -> np.ndarray:
    """Per-frame soft silence indicator (0 = speech, 1 = silent).

    Mirrors the player's `_computeSilenceGate`.

    1. `wav` is cut into T consecutive hops of ``max(1, int(sr / fps))`` samples.
    2. Each hop's RMS is converted to dB via ``20 * log10(rms)``, so amplitude
       1.0 is 0 dB.
    3. The dB value is mapped through a clipped linear ramp: at or below -45 dB
       the frame is fully silent (1.0), and at or above -33 dB it is fully
       speech (0.0). Hops that lie past the end of `wav` count as silent.
    4. The raw indicator is Gaussian-smoothed with σ = 6 frames (edge-clamped,
       no RMS normalization).

    The dB thresholds are absolute, so they presumably assume float samples
    scaled to [-1, 1]. Integer PCM is promoted to float64 (without rescaling)
    so that squaring it cannot overflow.

    Args:
        wav: audio samples.
        sr: sample rate in Hz.
        T: number of output frames.
        fps: frame rate used to derive the hop size.

    Returns:
        float32 array of shape (T,).
    """
    SILENCE_DB = -45.0   # at/below this level a frame is fully "silent"
    SOFT_RANGE = 12.0    # width of the linear ramp up to fully "speech"
    if T == 0:
        return np.zeros(0, dtype=np.float32)
    wav = np.asarray(wav)
    if wav.dtype.kind in "biu":
        # e.g. int16 200 * 200 wraps negative, and sqrt of that gives NaN.
        wav = wav.astype(np.float64)
    hop = max(1, int(sr / fps))
    n = len(wav)
    raw = np.zeros(T, dtype=np.float32)
    for i in range(T):
        a = i * hop
        b = min(a + hop, n)
        if b > a:
            seg = wav[a:b]
            rms = float(np.sqrt(np.mean(seg * seg)) + 1e-9)
        else:
            rms = 1e-9  # no samples left: treat as silence
        db = 20.0 * float(np.log10(rms))
        g = (SILENCE_DB + SOFT_RANGE - db) / SOFT_RANGE
        raw[i] = float(np.clip(g, 0.0, 1.0))
    # σ=6 smoothing matches the player (no RMS normalize).
    return gaussian_filter1d(raw, sigma=6.0, mode="nearest").astype(np.float32)


def _fit_gate_length(silence_gate: np.ndarray, T: int) -> np.ndarray:
    """Return `silence_gate` resized to exactly T frames (assumes a 1-D gate).

    A longer gate is truncated. A shorter gate is extended by repeating its
    last value exactly (no float32 round-trip). An empty gate becomes all zeros,
    i.e. "speech", so no gated brow tremor is added rather than raising
    IndexError.
    """
    gate = np.asarray(silence_gate)
    n = gate.shape[0]
    if n == T:
        return gate
    if n > T:
        return gate[:T]
    if n == 0:
        return np.zeros(T, dtype=np.float32)
    # Should only happen if the caller computed the gate for fewer frames.
    return np.pad(gate, (0, T - n), mode="edge")


def apply_tremor(target: np.ndarray,
                 silence_gate: np.ndarray,
                 scenario_id: str,
                 amp: float = 0.014,
                 sigma: float = 1.5) -> np.ndarray:
    """Bake brow/eyeSquint tremor into a copy of `target`.

    Brow channels get ``amp * silence_gate * noise``, so they move only where
    the gate indicates silence. EyeSquint channels get ``amp * noise`` on
    every frame. Bilateral pairs share one noise stream. The result is clamped
    to [0, 1].

    Args:
        target: per-frame blendshape values indexed as ``target[:, channel]``
            (presumably (T, 52) ARKit order; TODO confirm). Not modified.
        silence_gate: per-frame silence indicator (1 = silent). It is truncated
            or edge-padded to T frames if its length differs.
        scenario_id: seeds the noise streams; the same id gives the same tremor.
        amp: jitter amplitude. Values <= 0 disable the tremor.
        sigma: Gaussian smoothing of the noise, in frames. Values <= 0 disable
            the tremor.

    Returns:
        Always a new float32 array with the same shape as `target`. When the
        tremor is disabled, or T == 0, it is an unclamped float32 copy.
    """
    T = target.shape[0]
    if amp <= 0.0 or sigma <= 0.0 or T == 0:
        # Nothing to add, but still honor the "new array" contract so callers
        # never alias the input.
        return target.astype(np.float32, copy=True)
    gate = _fit_gate_length(silence_gate, T)
    streams = generate_noise_streams(T, scenario_id, sigma=sigma)
    gated = amp * gate  # silence-gated scale per frame
    jitter_by_channels = (
        (BROW_DOWN_L_R, gated * streams['down']),
        ((BROW_INNER_UP,), gated * streams['innerUp']),
        (BROW_OUTER_UP_L_R, gated * streams['outerUp']),
        (EYE_SQUINT_L_R, amp * streams['eyeSquint']),  # always-on, ungated
    )
    out = target.astype(np.float32, copy=True)
    for channels, jitter in jitter_by_channels:
        jitter = jitter.astype(np.float32)
        for ch in channels:
            out[:, ch] += jitter
    return np.clip(out, 0.0, 1.0).astype(np.float32, copy=False)
