"""Archetype layer: RBF interpolation over authored preset blendshapes.

Per docs/research/vad-to-arkit-blendshape-mapping.md §6.1.
"""
from __future__ import annotations

from typing import Optional

import numpy as np

from .constants import (
    RBF_SIGMA,
    RBF_EMOTION_FAMILY_BOOST,
    RBF_NEUTRAL_DEEMPHASIS,
)


def _extract_family(preset_name: str) -> str:
    """Extract emotion family from preset name (e.g. 'joy_L3' → 'joy')."""
    if "_" in preset_name:
        return preset_name.split("_")[0]
    return preset_name


def archetype_blend(
    vad: np.ndarray,
    presets: dict,
    emotion_hint: Optional[str] = None,
    sigma: float = RBF_SIGMA,
) -> np.ndarray:
    """Weight-blend preset blendshapes by RBF distance in VAD space.

    Each preset gets a Gaussian weight ``exp(-d**2 / (2 * sigma**2))``, where
    ``d`` is the Euclidean distance between the query VAD and the preset's
    anchor VAD. When ``emotion_hint`` is given, two adjustments apply:

    - Presets whose family matches the hint are multiplied by
      ``RBF_EMOTION_FAMILY_BOOST``. The family is the name prefix before the
      first ``_``.
    - ``neutral`` presets are multiplied by ``RBF_NEUTRAL_DEEMPHASIS``.

    The weights are then normalized to sum to 1 and used to average the
    presets' blendshape vectors.

    If the summed (boosted) weight is below 1e-6, the query is far from every
    anchor relative to ``sigma``. In that case a copy of the blendshapes of
    the preset nearest by raw VAD distance is returned instead; the hint is
    not considered for this choice.

    Args:
        vad: (3,) float, query VAD in [-1, 1]
        presets: dict {name: {"vad": [V,A,D], "bs": np.ndarray(52,)}}
        emotion_hint: optional family name (e.g. 'joy'); boosts same-family presets
        sigma: RBF Gaussian bandwidth; must be non-zero

    Returns:
        (52,) float32 weighted blend of preset blendshapes; ``zeros(52)`` when
        ``presets`` is empty.

    Raises:
        ValueError: if ``sigma`` is zero (the Gaussian width is undefined).
    """
    if not presets:
        return np.zeros(52, dtype=np.float32)
    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    vad = np.asarray(vad, dtype=np.float32)
    names = list(presets)
    anchors = np.stack(
        [np.asarray(presets[name]["vad"], dtype=np.float32) for name in names],
        axis=0,
    )  # (N, 3)
    # Kept as a list so the nearest-anchor fallback works even if the blend
    # path (which stacks them) would reject mismatched shapes.
    bs_list = [np.asarray(presets[name]["bs"], dtype=np.float32) for name in names]

    # Distances are computed once and reused by the fallback below.
    distances = np.linalg.norm(anchors - vad, axis=1)  # (N,)
    weights = np.exp(
        -(distances.astype(np.float64) ** 2) / (2.0 * sigma * sigma)
    )

    # Emotion-family boosting (option C)
    if emotion_hint is not None:
        for i, name in enumerate(names):
            family = _extract_family(name)
            if family == emotion_hint:
                weights[i] *= RBF_EMOTION_FAMILY_BOOST
            elif family == "neutral":
                # Reached only when the hint is not "neutral" (a neutral hint
                # matches the branch above), so neutral is de-emphasized
                # whenever a different emotion is requested.
                weights[i] *= RBF_NEUTRAL_DEEMPHASIS

    total = float(weights.sum())

    # Edge case: all weights tiny (input far from every anchor) -> snap to the
    # nearest anchor rather than normalizing vanishing weights.
    if total < 1e-6:
        nearest = int(np.argmin(distances))
        return bs_list[nearest].copy()

    weights /= total
    bs_stack = np.stack(bs_list, axis=0)  # (N, 52)
    return (weights @ bs_stack).astype(np.float32)
