"""Parametric layer: emotion + VAD → 52 blendshapes via closed-form functions.

Implementation of docs/research/vad-to-arkit-blendshape-mapping.md §4 & §5.

VAD convention: [-1, 1] scale (native anchor convention).
  V⁺ = max(0, V), V⁻ = max(0, -V), same for A, D.
"""
from __future__ import annotations

import numpy as np

from .constants import (
    LIPSYNC_ONLY,
    GAMMA_MOUTH_SMILE_FROWN,
    GAMMA_EYE,
    GAMMA_BROW,
    GAMMA_NOSE,
    VALENCE_EXCLUSION_THRESHOLD,
    DOMINANCE_BROW_CONFLICT_HIGH,
    DOMINANCE_BROW_CONFLICT_LOW,
    AROUSAL_VISIBILITY_GATE,
    AROUSAL_EXTREME_AMP_THRESHOLD,
    AROUSAL_EXTREME_AMP_COEFF,
)


def _soft_gate(x: float, threshold: float, width: float = 0.1) -> float:
    """Smooth step from 0 → 1 around `threshold` with `width` transition."""
    return float(np.clip((x - threshold) / width + 0.5, 0.0, 1.0))


def _pow_safe(x: float, gamma: float) -> float:
    """Raise to power only if positive, return 0 otherwise."""
    return float(np.power(max(0.0, x), gamma)) if x > 0 else 0.0


def parametric_layer(vad: np.ndarray) -> np.ndarray:
    """Map a VAD triple to 52 blendshape weights via closed-form rules.

    Each driven channel is a function of the rectified half-axes
    (V+/V-, A+/A-, D+/D-), shaped by the GAMMA_* exponents and, for some
    channels, by hard (boolean) or soft (``_soft_gate``) gates. Left/right
    pairs always receive the same value. Channels not written below (gaze
    targeting 12-17 and the lipsync-owned jaw/mouth/tongue channels) stay 0.

    vad: indexable holding (valence, arousal, dominance) at indices 0-2,
        nominally in [-1, 1].
    Returns: float32 array of shape (52,), clipped to [0, 1].
    """
    valence, arousal, dominance = float(vad[0]), float(vad[1]), float(vad[2])

    # Rectified half-axes: both are >= 0 and at most one of each pair is non-zero.
    v_pos, v_neg = max(0.0, valence), max(0.0, -valence)
    a_pos, a_neg = max(0.0, arousal), max(0.0, -arousal)
    d_pos, d_neg = max(0.0, dominance), max(0.0, -dominance)

    # Shaped drivers reused by several channels.
    brow_d_pos = _pow_safe(d_pos, GAMMA_BROW)
    brow_v_neg = _pow_safe(v_neg, GAMMA_BROW)
    brow_d_neg = _pow_safe(d_neg, GAMMA_BROW)
    eye_a_pos = _pow_safe(a_pos, GAMMA_EYE)
    eye_d_pos = _pow_safe(d_pos, GAMMA_EYE)
    eye_d_neg = _pow_safe(d_neg, GAMMA_EYE)
    eye_v_pos = _pow_safe(v_pos, GAMMA_EYE)
    smile_v_pos = _pow_safe(v_pos, GAMMA_MOUTH_SMILE_FROWN)
    frown_v_neg = _pow_safe(v_neg, GAMMA_MOUTH_SMILE_FROWN)
    calm = _pow_safe(a_neg, 1.0)

    # Brows: dominance lowers them; negative valence, arousal and submission
    # (negative dominance) raise the inner brow.
    brow_down = (
        0.55 * brow_d_pos * _soft_gate(dominance, 0.2)
        + 0.20 * brow_v_neg * float(dominance >= 0.3)
    )
    brow_inner_up = min(
        1.0,
        0.55 * brow_v_neg
        + 0.45 * eye_a_pos * (1.0 - 0.5 * d_pos)
        + 0.60 * brow_d_neg,
    )
    brow_outer_up = 0.40 * eye_a_pos + 0.35 * eye_d_neg

    # Cheeks: puff only when D > 0.6 and V < -0.4; squint follows positive
    # valence above a soft gate at V = 0.3.
    cheek_puff = (
        0.10 * _pow_safe(d_pos, 2.0) * _pow_safe(v_neg, 2.0)
        * float(dominance > 0.6 and valence < -0.4)
    )
    cheek_squint = 0.65 * smile_v_pos * _soft_gate(valence, 0.3)

    # Eyes: squint takes the strongest of its three drivers.
    eye_blink = 0.25 * calm
    eye_look_down = 0.30 * eye_d_neg * _soft_gate(-dominance, 0.4)
    eye_squint = max(
        0.40 * eye_v_pos * _soft_gate(valence, 0.2),  # Duchenne
        0.30 * eye_d_pos,  # dominance
        0.20 * calm * float(abs(valence) < 0.3),  # relaxed, near-neutral V
    )
    eye_wide = min(
        1.0,
        0.70 * _pow_safe(a_pos, 1.5)
        + 0.30 * eye_d_neg * _pow_safe(a_pos, 0.5),
    )

    # Jaw: 0.02 resting opening plus an arousal term that only kicks in above A = 0.3.
    jaw_open = 0.02 + 0.43 * _pow_safe(a_pos, 1.0) * float(arousal > 0.3)

    # Mouth.
    mouth_dimple = 0.25 * smile_v_pos
    mouth_frown = 0.75 * frown_v_neg
    mouth_press = min(
        1.0,
        0.25 * brow_d_pos + 0.30 * brow_v_neg * float(dominance >= 0.2),
    )
    mouth_shrug = 0.20 * eye_a_pos + 0.15 * eye_d_pos
    mouth_smile = 0.85 * smile_v_pos

    # Nose: the stronger of disgust / dominance, plus a boost above A = 0.75.
    sneer_disgust = 0.60 * _pow_safe(v_neg, GAMMA_NOSE) * float(valence < -0.3)
    sneer_dominance = 0.35 * _pow_safe(d_pos, GAMMA_NOSE)
    sneer_arousal = 0.15 * _pow_safe(a_pos, GAMMA_NOSE) * float(arousal > 0.75)
    nose_sneer = min(1.0, max(sneer_disgust, sneer_dominance) + sneer_arousal)

    # Channel index (ARKit alphabetical order) -> value.
    channel_values = (
        ((0, 1), brow_down),  # browDownLeft/Right
        ((2,), brow_inner_up),  # browInnerUp
        ((3, 4), brow_outer_up),  # browOuterUpLeft/Right
        ((5,), cheek_puff),  # cheekPuff
        ((6, 7), cheek_squint),  # cheekSquintLeft/Right
        ((8, 9), eye_blink),  # eyeBlinkLeft/Right
        ((10, 11), eye_look_down),  # eyeLookDownLeft/Right
        ((18, 19), eye_squint),  # eyeSquintLeft/Right
        ((20, 21), eye_wide),  # eyeWideLeft/Right
        ((24,), jaw_open),  # jawOpen
        ((27, 28), mouth_dimple),  # mouthDimpleLeft/Right
        ((29, 30), mouth_frown),  # mouthFrownLeft/Right
        ((35, 36), mouth_press),  # mouthPressLeft/Right
        ((41, 42), mouth_shrug),  # mouthShrugLower/Upper
        ((43, 44), mouth_smile),  # mouthSmileLeft/Right
        ((49, 50), nose_sneer),  # noseSneerLeft/Right
    )
    weights = np.zeros(52, dtype=np.float32)
    for channels, value in channel_values:
        weights[list(channels)] = value

    return np.clip(weights, 0.0, 1.0).astype(np.float32)


def apply_conflict_resolution(bs: np.ndarray, vad: np.ndarray) -> np.ndarray:
    """Apply the §5.4 conflict-resolution rules to a 52-dim blendshape vector.

    Operates on a copy of ``bs`` and applies, in this order:

    * Rule 2, valence polarity exclusion: scale mouthFrown (29, 30) when
      V >= VALENCE_EXCLUSION_THRESHOLD, or mouthSmile (43, 44) when
      V <= -VALENCE_EXCLUSION_THRESHOLD.
    * Rule 3, brow conflict: above DOMINANCE_BROW_CONFLICT_HIGH, browInnerUp
      (2) fades out over a 0.2-wide band of D; below
      DOMINANCE_BROW_CONFLICT_LOW, browDown (0, 1) is zeroed.
    * Rule 4, visibility gate: when A < AROUSAL_VISIBILITY_GATE, every channel
      is scaled by (A + 1) / (1 + gate), floored at 0.
    * Rule 1, extreme-arousal amplification: when
      A > AROUSAL_EXTREME_AMP_THRESHOLD, every channel is multiplied by
      1 + (A - threshold) * AROUSAL_EXTREME_AMP_COEFF.

    bs: 52-element blendshape vector (left unmodified).
    vad: indexable holding (valence, arousal, dominance) at indices 0-2.
    Returns: float32 array clipped to [0, 1].
    """
    valence = float(vad[0])
    arousal = float(vad[1])
    dominance = float(vad[2])

    result = bs.copy()

    # Rule 2: smile <-> frown exclusion.
    exclusion = VALENCE_EXCLUSION_THRESHOLD
    if valence >= exclusion:
        # NOTE(review): assuming the threshold is positive, valence / exclusion
        # is >= 1 on this branch, so this factor is always 0 (a hard cut, not a
        # ramp). Confirm against §5.4 whether a gradual fade was intended.
        factor = max(0.0, 1.0 - valence / exclusion)
        for channel in (29, 30):  # mouthFrownLeft/Right
            result[channel] *= factor
    elif valence <= -exclusion:
        # Mirror of the branch above; the same remark applies.
        factor = max(0.0, 1.0 - (-valence) / exclusion)
        for channel in (43, 44):  # mouthSmileLeft/Right
            result[channel] *= factor

    # Rule 3: browDown vs browInnerUp.
    high = DOMINANCE_BROW_CONFLICT_HIGH
    low = DOMINANCE_BROW_CONFLICT_LOW
    if dominance > high:
        fade = min(1.0, (dominance - high) / 0.2)
        result[2] *= 1.0 - fade  # browInnerUp
    elif dominance < low:
        result[0] = 0.0  # browDownLeft
        result[1] = 0.0  # browDownRight

    # Rule 4: global attenuation at very low arousal (1.0 at the gate, 0 at A = -1).
    gate = AROUSAL_VISIBILITY_GATE
    if arousal < gate:
        result *= max(0.0, (arousal + 1.0) / (1.0 + gate))

    # Rule 1: global amplification at very high arousal; the final clip caps it.
    amp_start = AROUSAL_EXTREME_AMP_THRESHOLD
    if arousal > amp_start:
        result *= 1.0 + (arousal - amp_start) * AROUSAL_EXTREME_AMP_COEFF

    return np.clip(result, 0.0, 1.0).astype(np.float32)


def apply_channel_mask(bs: np.ndarray) -> np.ndarray:
    """Return a copy of ``bs`` with every LIPSYNC_ONLY channel set to 0.

    Those channels are not owned by this compiler, so any value written to
    them upstream is discarded. ``bs`` itself is left unchanged.
    """
    masked = bs.copy()
    # list(...) forces integer-array (fancy) indexing even if the constant is a
    # tuple or set; an empty list is a no-op.
    masked[list(LIPSYNC_ONLY)] = 0.0
    return masked
