"""Main compiler API: blend parametric + archetype layers."""
from __future__ import annotations

from typing import List, Optional

import numpy as np

from .archetype import archetype_blend
from .constants import (
    BLEND_NEUTRAL_THRESHOLD,
    BLEND_NEUTRAL_WEIGHT,
    BLEND_EXTREME_THRESHOLD,
    BLEND_EXTREME_WEIGHT,
    BLEND_DEFAULT_WEIGHT,
    RBF_SIGMA,
)
from .parametric import (
    parametric_layer,
    apply_conflict_resolution,
    apply_channel_mask,
)
from .utils import validate_vad


def _dynamic_blend_weight(vad: np.ndarray) -> float:
    """Choose the parametric-layer weight for a single query from its VAD magnitude.

    The Euclidean norm of ``vad`` is compared against two thresholds:
    below ``BLEND_NEUTRAL_THRESHOLD`` gives ``BLEND_NEUTRAL_WEIGHT``,
    otherwise above ``BLEND_EXTREME_THRESHOLD`` gives ``BLEND_EXTREME_WEIGHT``,
    and anything else gives ``BLEND_DEFAULT_WEIGHT``. The neutral test is
    evaluated first, so it takes precedence if the thresholds ever overlap.
    """
    radius = float(np.linalg.norm(vad))
    if radius < BLEND_NEUTRAL_THRESHOLD:
        weight = BLEND_NEUTRAL_WEIGHT
    elif radius > BLEND_EXTREME_THRESHOLD:
        weight = BLEND_EXTREME_WEIGHT
    else:
        weight = BLEND_DEFAULT_WEIGHT
    return weight


def compile_blendshapes(
    emotion: Optional[str],
    vad: np.ndarray,
    presets: dict,
    sigma: float = RBF_SIGMA,
    apply_lipsync_mask: bool = True,
    override_blend_weight: Optional[float] = None,
) -> np.ndarray:
    """Compute 52 blendshape values for given (emotion, VAD).

    The parametric layer and the RBF archetype layer are mixed as
    ``w * parametric + (1 - w) * archetype``. Conflict resolution and
    (optionally) the LIPSYNC_ONLY channel mask are then applied, and the
    result is clipped to [0, 1].

    Args:
        emotion: emotion family name (e.g. 'joy') for archetype boosting; None → pure VAD
        vad: (3,) in [-1, 1]
        presets: dict of archetype presets
        sigma: RBF bandwidth (archetype layer)
        apply_lipsync_mask: zero out LIPSYNC_ONLY channels
        override_blend_weight: if set, use this parametric-layer weight instead
            of the magnitude-based dynamic one; every value must lie in [0, 1]

    Returns:
        (52,) float32 in [0, 1]

    Raises:
        ValueError: if ``override_blend_weight`` is outside [0, 1] or NaN.
    """
    if override_blend_weight is not None:
        # Outside [0, 1] the mix stops being a convex combination (one layer
        # gets a negative weight), and NaN would poison every channel because
        # np.clip preserves NaN. Reject both explicitly instead of emitting
        # silently wrong blendshapes. np.asarray keeps array weights working.
        weight_arr = np.asarray(override_blend_weight, dtype=float)
        if not np.all((weight_arr >= 0.0) & (weight_arr <= 1.0)):
            raise ValueError(
                f"override_blend_weight must be in [0, 1], got {override_blend_weight!r}"
            )

    vad = validate_vad(vad)

    # Layer 1: parametric
    param = parametric_layer(vad)

    # Layer 2: archetype (RBF over presets)
    arch = archetype_blend(vad, presets, emotion_hint=emotion, sigma=sigma)

    # Combine: w weights the parametric layer, (1 - w) the archetype layer.
    w = override_blend_weight if override_blend_weight is not None else _dynamic_blend_weight(vad)
    out = param * w + arch * (1.0 - w)

    # Conflict resolution
    out = apply_conflict_resolution(out, vad)

    # Mask LIPSYNC_ONLY channels (compiler doesn't own them)
    if apply_lipsync_mask:
        out = apply_channel_mask(out)

    return np.clip(out, 0.0, 1.0).astype(np.float32)


def compile_batch(
    emotions: List[Optional[str]],
    vads: np.ndarray,
    presets: dict,
    **kwargs,
) -> np.ndarray:
    """Compile blendshapes for a batch of (emotion, VAD) queries.

    Args:
        emotions: one emotion hint (or None) per row of ``vads``
        vads: array-like of shape (N, 3); nested lists of triples are accepted
        presets: dict of archetype presets, shared by every row
        **kwargs: forwarded unchanged to ``compile_blendshapes``

    Returns:
        (N, 52) float32 array whose row i is
        ``compile_blendshapes(emotions[i], vads[i], presets, **kwargs)``.

    Raises:
        ValueError: if ``vads`` is not two-dimensional with 3 columns, or if
            ``len(emotions)`` differs from the number of rows in ``vads``.
    """
    # asarray is a no-op for ndarrays and lets callers pass plain lists.
    vads = np.asarray(vads)
    # Without this check a single (3,) vector would be misread as N=3 rows
    # of scalars and fail later with an unrelated error.
    if vads.ndim != 2 or vads.shape[1] != 3:
        raise ValueError(f"vads must have shape (N, 3), got {vads.shape}")
    N = vads.shape[0]
    if len(emotions) != N:
        raise ValueError(f"emotions length {len(emotions)} != vads {N}")
    out = np.empty((N, 52), dtype=np.float32)
    for i, (emotion, vad) in enumerate(zip(emotions, vads)):
        out[i] = compile_blendshapes(emotion, vad, presets, **kwargs)
    return out
