"""Expressive compile: blend user-authored L1/L3/L5 presets within emotion family.

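Philosophy: when a turn has an emotion label, the user's authored presets ARE
the ground truth. VAD determines intensity level (L1 mild, L5 extreme) via RBF
distance weighting on anchor VAD — NOT by blending with parametric rules or
other emotion families. The one exception is the optional "Option E" overlay,
which max-merges the parametric layer onto a small set of mouth/cheek channels
so that valence can still color shapes the authored preset leaves empty.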

Result: full expressivity of authored L5 poses at extreme VAD, faithful fall-off
to L3/L1 at milder VAD. Emotion label is dominant.
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np

from .constants import LIPSYNC_ONLY
from .parametric import parametric_layer
from .utils import validate_vad

# JSON file holding the per-emotion, per-level anchor VAD points that
# _load_anchors() parses. It is resolved relative to this module:
# parents[2] of the resolved file path (two levels above the directory that
# contains this file), then data/emotion/emotion_vad_anchors.json.
_ANCHOR_PATH = Path(__file__).resolve().parents[2] / "data" / "emotion" / "emotion_vad_anchors.json"

# RBF bandwidth for within-emotion L1/L3/L5 blending.
# Small σ → sharp level selection; larger σ → smoother transitions.
# This is the default `sigma` of compile_expressive / compile_expressive_batch
# and is compared against Euclidean distances in VAD space.
EXPRESSIVE_SIGMA = 0.35

# Option E: channel-masked parametric overlay.
# These are "valence-coloring" channels — mouth/cheek shapes that should
# follow VAD continuously even when the authored preset is silent on them.
# "Surprise-defining" channels (eyeWide, browInnerUp, browOuterUp, jawOpen)
# are deliberately NOT in this mask so the authored emotion structure is
# preserved.
# The values are indices into the 52-dim blendshape vector; the names in the
# trailing comments are the intended channel for each index (the index→name
# mapping itself is defined outside this file — verify there when editing).
# NOTE: this list object is used directly as a default argument value below,
# so it must not be mutated in place.
MOUTH_CHEEK_OVERLAY_CHANNELS = [
    6, 7,    # cheekSquintLeft, cheekSquintRight
    27, 28,  # mouthDimpleLeft, mouthDimpleRight
    29, 30,  # mouthFrownLeft,  mouthFrownRight
    43, 44,  # mouthSmileLeft,  mouthSmileRight
]

# Memo for _load_anchors(): None until the anchor file has been parsed once,
# afterwards {emotion: {level: anchor VAD as float32 array}}.
_anchor_cache: Optional[Dict[str, Dict[int, np.ndarray]]] = None


def _load_anchors() -> Dict[str, Dict[int, np.ndarray]]:
    """Return the anchor VAD table, parsing the JSON file on first use.

    The file at ``_ANCHOR_PATH`` is read as UTF-8 JSON and its ``"anchors"``
    mapping is converted to ``{emotion: {level: vad}}``, where ``level`` is
    coerced to ``int`` and ``vad`` to a float32 array. The result is stored in
    the module-level ``_anchor_cache`` so later calls return the same object
    without touching the disk again.
    """
    global _anchor_cache
    if _anchor_cache is None:
        payload = json.loads(_ANCHOR_PATH.read_text(encoding="utf-8"))
        table: Dict[str, Dict[int, np.ndarray]] = {}
        for family, records in payload["anchors"].items():
            per_level: Dict[int, np.ndarray] = {}
            for record in records:
                level = int(record["level"])
                per_level[level] = np.asarray(record["vad"], dtype=np.float32)
            table[family] = per_level
        # Publish only after the whole file parsed, so a failure part-way
        # through never leaves a half-built cache behind.
        _anchor_cache = table
    return _anchor_cache


def compile_expressive(
    emotion: Optional[str],
    vad: np.ndarray,
    presets: Dict[str, dict],
    sigma: float = EXPRESSIVE_SIGMA,
    parametric_overlay_channels: Optional[List[int]] = MOUTH_CHEEK_OVERLAY_CHANNELS,
    parametric_overlay_intensity: float = 1.0,
) -> np.ndarray:
    """Return 52-dim blendshape by blending authored L1/L3/L5 presets of emotion,
    with optional channel-masked parametric overlay (Option E).

    Args:
        emotion: emotion family name (e.g. 'crying'). None → neutral.
        vad: (3,) target VAD. Passed through ``validate_vad`` before use.
        presets: dict with keys like 'crying_L1', 'crying_L3', 'crying_L5';
            each value is a dict whose ``"bs"`` entry is the blendshape vector.
        sigma: RBF bandwidth for level blending. Only ``sigma ** 2`` is used;
            ``sigma == 0`` selects the nearest anchor outright.
        parametric_overlay_channels: channel indices on which to add
            the parametric layer's output via max() merge. Defaults to
            mouth/cheek "valence-coloring" channels. Pass None or [] to disable.
        parametric_overlay_intensity: scalar α applied to the parametric
            output before max-merge (1.0 = pure max, lower = milder coloring).

    Returns:
        float32 array clipped to [0, 1]. For ``None``/``"neutral"`` this is the
        clipped ``neutral_L3`` preset (zeros if absent). For an emotion with
        no anchors, or with no level that has both a preset and an anchor,
        it is ``zeros(52)``. The parametric overlay is applied only on the
        blended path, not on these early returns.

    Option E rationale: emotion + within-emotion RBF defines the structural
    pose (eyes, brows, jaw). Valence then continuously colors the mouth/cheek
    so happy-surprise (V+) and sad-surprise (V−) become visually distinct
    without re-authoring presets per V quadrant.
    """
    vad = validate_vad(vad)

    if emotion is None or emotion == "neutral":
        neutral = presets.get("neutral_L3")
        if neutral is not None:
            return np.clip(np.asarray(neutral["bs"], dtype=np.float32), 0, 1)
        return np.zeros(52, dtype=np.float32)

    anchors = _load_anchors().get(emotion)
    if anchors is None:
        return np.zeros(52, dtype=np.float32)

    # Collect the authored levels that have BOTH a preset and an anchor VAD.
    anchor_vads: List[np.ndarray] = []
    bs_vectors: List[np.ndarray] = []
    for level in (1, 3, 5):
        key = f"{emotion}_L{level}"
        if key not in presets or level not in anchors:
            continue
        anchor_vads.append(anchors[level])
        bs_vectors.append(np.asarray(presets[key]["bs"], dtype=np.float32))

    if not bs_vectors:
        return np.zeros(52, dtype=np.float32)

    # RBF weights over authored levels, computed in float64.
    # Subtracting the smallest squared distance before exponentiating leaves
    # the normalized weights mathematically unchanged but guarantees the
    # nearest anchor gets weight exp(0) == 1, so the sum can never underflow
    # to zero (which previously turned the whole pose into NaN when the target
    # was far from every anchor relative to sigma).
    sq_dists = np.array(
        [np.linalg.norm(vad - av) ** 2 for av in anchor_vads], dtype=np.float64
    )
    nearest = sq_dists.min()
    two_sigma_sq = 2.0 * float(sigma) ** 2
    if two_sigma_sq > 0.0:
        weights = np.exp(-(sq_dists - nearest) / two_sigma_sq)
    else:
        # Zero bandwidth is the hard-selection limit: nearest anchor(s) only.
        weights = (sq_dists == nearest).astype(np.float64)
    weights /= weights.sum()

    # Weighted blend (within-emotion authored ground truth)
    bs_stack = np.stack(bs_vectors, axis=0)  # (n_levels, n_channels)
    out = (weights[:, None] * bs_stack).sum(axis=0)

    # Option E: max-merge parametric layer onto valence-coloring channels.
    # Authored values always win when they exceed the parametric output —
    # this preserves authored intent and only fills in mouth/cheek shape
    # when the preset is silent (e.g. surprise has no authored mouth shape).
    # Explicit None/length test so tuples and index arrays work as well as lists.
    if parametric_overlay_channels is not None and len(parametric_overlay_channels) > 0:
        param = parametric_layer(vad)
        mask = np.asarray(parametric_overlay_channels, dtype=np.int64)
        out[mask] = np.maximum(out[mask], parametric_overlay_intensity * param[mask])

    return np.clip(out, 0.0, 1.0).astype(np.float32)


def compile_expressive_batch(
    emotions: List[Optional[str]],
    vads: np.ndarray,
    presets: Dict[str, dict],
    sigma: float = EXPRESSIVE_SIGMA,
    parametric_overlay_channels: Optional[List[int]] = MOUTH_CHEEK_OVERLAY_CHANNELS,
    parametric_overlay_intensity: float = 1.0,
) -> np.ndarray:
    """Compile one pose per row of ``vads`` via ``compile_expressive``.

    Args:
        emotions: emotion label (or None) for each row; indexed in step with
            ``vads``.
        vads: (N, 3) array of target VADs; ``vads.shape[0]`` sets the batch size.
        presets: authored preset dict, forwarded unchanged.
        sigma: RBF bandwidth, forwarded unchanged.
        parametric_overlay_channels: overlay channel indices, forwarded unchanged.
        parametric_overlay_intensity: overlay scale, forwarded unchanged.

    Returns:
        (N, 52) float32 array whose i-th row is the pose for
        ``(emotions[i], vads[i])``.
    """
    row_count = vads.shape[0]
    # Options that are identical for every row of the batch.
    shared_options = {
        "sigma": sigma,
        "parametric_overlay_channels": parametric_overlay_channels,
        "parametric_overlay_intensity": parametric_overlay_intensity,
    }
    poses = np.empty((row_count, 52), dtype=np.float32)
    for row in range(row_count):
        poses[row] = compile_expressive(
            emotions[row], vads[row], presets, **shared_options
        )
    return poses
