"""Compiler validation tests.

Run:
    cd /dataset/AnimaSync-mic-fix
    python -m scripts.compiler.test_compiler
"""
from __future__ import annotations

import itertools

import numpy as np

from .blend import compile_blendshapes, compile_batch
from .constants import ARKIT_52_NAMES, LIPSYNC_ONLY
from .utils import build_synthetic_presets


def test_anchor_reproducibility(presets, tol=0.15):
    """Check that compiling an anchor's own VAD roughly reproduces its blendshape.

    For every preset, ``compile_blendshapes`` is called with the emotion taken
    from the preset name (the text before the first ``_``), the preset's
    ``"vad"`` entry cast to float32, and ``apply_lipsync_mask=False``.  The
    mean absolute difference between the result and the preset's ``"bs"``
    entry is computed, and the largest such difference is compared to ``tol``.

    Args:
        presets: Mapping of preset name to a mapping with ``"vad"`` and
            ``"bs"`` entries.
        tol: Exclusive upper bound on the worst mean absolute difference.

    Returns:
        True when the worst difference is strictly below ``tol``.  A NaN
        difference for any preset makes the test fail.
    """
    print("\n=== Test 1: Anchor reproducibility ===")
    max_l1 = 0.0
    worst = None
    for name, data in presets.items():
        emotion = name.split("_")[0]
        vad = np.asarray(data["vad"], dtype=np.float32)
        out = compile_blendshapes(emotion, vad, presets, apply_lipsync_mask=False)
        l1 = float(np.abs(out - data["bs"]).mean())
        if np.isnan(l1):
            # NaN compares False against everything, so ``l1 > max_l1`` would
            # silently skip it.  Record it as the worst case and stop: the
            # final ``max_l1 < tol`` is then False and the test fails.
            max_l1, worst = l1, name
            break
        if l1 > max_l1:
            max_l1, worst = l1, name
    passed = max_l1 < tol
    status = "PASS" if passed else "FAIL"
    print(f"  {status} — max L1 distance {max_l1:.4f} (tol {tol}), worst: {worst}")
    return passed


def test_smoothness(presets, tol=0.10, n_probes=20):
    """Check that a ±0.05 VAD perturbation changes the output by less than ``tol``.

    Draws ``n_probes`` random VAD triples uniformly from [-1, 1], adds uniform
    jitter from [-0.05, 0.05] to each (clipped back to [-1, 1]), compiles both
    the original and the jittered point with ``None`` as the first argument of
    ``compile_blendshapes``, and tracks the largest mean absolute difference
    between the two outputs.

    Note: this reseeds NumPy's global random state with 42.

    Args:
        presets: Preset collection forwarded to ``compile_blendshapes``.
        tol: Exclusive upper bound on the worst mean absolute difference.
        n_probes: Number of random VAD points to probe.

    Returns:
        True when the worst difference is strictly below ``tol``.  A NaN
        difference for any probe makes the test fail.
    """
    print("\n=== Test 2: Smoothness (±0.05 VAD jitter) ===")
    np.random.seed(42)
    max_delta = 0.0
    for _ in range(n_probes):
        vad = np.random.uniform(-1, 1, size=3).astype(np.float32)
        jitter = np.random.uniform(-0.05, 0.05, size=3).astype(np.float32)
        vad2 = np.clip(vad + jitter, -1, 1)
        out1 = compile_blendshapes(None, vad, presets)
        out2 = compile_blendshapes(None, vad2, presets)
        delta = float(np.abs(out1 - out2).mean())
        if np.isnan(delta):
            # ``nan > max_delta`` is False, so a NaN output would otherwise be
            # ignored.  Keep it as the result so ``max_delta < tol`` is False.
            max_delta = delta
            break
        if delta > max_delta:
            max_delta = delta
    passed = max_delta < tol
    status = "PASS" if passed else "FAIL"
    print(f"  {status} — max L1 delta for ±0.05 jitter: {max_delta:.4f} (tol {tol})")
    return passed


def test_valid_range(presets, n_probes=50):
    """Check that compiled outputs stay within [0, 1] for random VAD inputs.

    Draws ``n_probes`` random VAD triples uniformly from [-1, 1], compiles
    each with ``None`` as the first argument of ``compile_blendshapes``, and
    counts the probes whose output has any element outside
    ``[-1e-5, 1 + 1e-5]``.

    Note: this reseeds NumPy's global random state with 0.

    Args:
        presets: Preset collection forwarded to ``compile_blendshapes``.
        n_probes: Number of random VAD points to probe.

    Returns:
        True when no probe produced an out-of-range (or NaN) value.
    """
    print("\n=== Test 3: Valid range [0, 1] ===")
    np.random.seed(0)
    lower, upper = -1e-5, 1 + 1e-5  # small slack for float rounding
    violations = 0
    for _ in range(n_probes):
        vad = np.random.uniform(-1, 1, size=3).astype(np.float32)
        out = compile_blendshapes(None, vad, presets)
        # Test for "inside the range" rather than "outside": NaN fails both
        # ``out < lower`` and ``out > upper``, so an outside-check would let
        # NaN through, while an inside-check correctly rejects it.
        in_range = (out >= lower) & (out <= upper)
        if not np.all(in_range):
            violations += 1
    passed = violations == 0
    status = "PASS" if passed else "FAIL"
    print(f"  {status} — {violations}/{n_probes} violations")
    return passed


def test_emotion_distinctness(presets, tol=0.02):
    """Check that the compiled outputs of different emotions are pairwise distinct.

    Emotions are the distinct prefixes (text before the first ``_``) of the
    preset names.  Each emotion is compiled at the VAD of its ``<emotion>_L3``
    preset, or, when that preset is missing, at the VAD of the first preset
    belonging to that emotion.  The smallest mean absolute difference between
    any two emotions' outputs is compared to ``tol``.

    Args:
        presets: Mapping of preset name to a mapping with a ``"vad"`` entry.
        tol: Exclusive lower bound on the smallest pairwise difference.

    Returns:
        True when the smallest pairwise difference is strictly above ``tol``
        (also when fewer than two emotions exist, since there are no pairs).
        A NaN difference for any pair makes the test fail.
    """
    print("\n=== Test 4: Emotion distinctness (L3) ===")
    emotions_to_test = sorted({name.split("_")[0] for name in presets})

    # Collect the L3 (mid) VAD for each emotion.
    l3_vads = {}
    for emo in emotions_to_test:
        candidate = f"{emo}_L3"
        if candidate not in presets:
            # Fallback: any level of this emotion.  Match on the exact emotion
            # token rather than ``startswith(emo)`` so that e.g. "sad" cannot
            # pick up a "sadness_*" preset.  A match always exists because
            # ``emo`` was derived from a preset name.
            candidate = next(n for n in presets if n.split("_")[0] == emo)
        l3_vads[emo] = np.asarray(presets[candidate]["vad"], dtype=np.float32)

    outputs = {
        emo: compile_blendshapes(emo, vad, presets)
        for emo, vad in l3_vads.items()
    }

    min_dist = float("inf")
    worst_pair = None
    for a, b in itertools.combinations(l3_vads, 2):
        d = float(np.abs(outputs[a] - outputs[b]).mean())
        if np.isnan(d):
            # ``nan < min_dist`` is False, so a NaN output would otherwise be
            # ignored.  Keep it as the result so ``min_dist > tol`` is False.
            min_dist, worst_pair = d, (a, b)
            break
        if d < min_dist:
            min_dist, worst_pair = d, (a, b)
    passed = min_dist > tol
    status = "PASS" if passed else "FAIL"
    print(f"  {status} — min pairwise L1: {min_dist:.4f} (tol {tol}), closest pair: {worst_pair}")
    return passed


def test_intensity_monotonicity(presets):
    """Check that mean activation does not decrease from L1 to L3 to L5.

    Only emotions (text before the first ``_`` of a preset name) that have
    all of ``<emotion>_L1``, ``<emotion>_L3`` and ``<emotion>_L5`` presets are
    checked.  Each level's VAD is compiled for that emotion and the mean of
    the output is taken; an emotion violates the check unless
    ``mean(L1) <= mean(L3) <= mean(L5)``.

    Args:
        presets: Mapping of preset name to a mapping with a ``"vad"`` entry.

    Returns:
        True when no checked emotion violates the ordering.
    """
    print("\n=== Test 5: Intensity progression (L1 < L3 < L5) ===")
    levels = ("L1", "L3", "L5")

    # dict.fromkeys drops duplicates while keeping first-seen order.
    seen_emotions = dict.fromkeys(key.split("_")[0] for key in presets)
    complete = [
        emotion
        for emotion in seen_emotions
        if all(f"{emotion}_{level}" in presets for level in levels)
    ]

    violations = []
    for emotion in complete:
        means = []
        for level in levels:
            anchor_vad = np.asarray(presets[f"{emotion}_{level}"]["vad"])
            compiled = compile_blendshapes(emotion, anchor_vad, presets)
            means.append(float(compiled.mean()))
        low, mid, high = means
        # Keep the negated chained comparison: it also flags NaN means,
        # which a rewritten ``low > mid or mid > high`` would not.
        if not (low <= mid <= high):
            violations.append((emotion, low, mid, high))

    if violations:
        status = f"PARTIAL ({len(violations)} violations)"
    else:
        status = "PASS"
    print(f"  {status} — checked {len(complete)} emotions with full L1/L3/L5")
    for emotion, low, mid, high in violations[:3]:
        print(f"    non-monotonic: {emotion}: L1={low:.3f} L3={mid:.3f} L5={high:.3f}")
    return not violations


def main():
    """Build the synthetic presets, run every check and print a summary.

    Returns:
        True when all checks passed, False otherwise.
    """
    print("Building synthetic presets (anchor VAD × parametric layer)...")
    presets = build_synthetic_presets()
    print(f"  → {len(presets)} synthetic presets loaded")

    results = {
        "anchor_reproducibility": test_anchor_reproducibility(presets),
        "smoothness": test_smoothness(presets),
        "valid_range": test_valid_range(presets),
        "emotion_distinctness": test_emotion_distinctness(presets),
        "intensity_monotonicity": test_intensity_monotonicity(presets),
    }

    print("\n" + "=" * 60)
    print(f"{'SUMMARY':^60}")
    print("=" * 60)
    for name, passed in results.items():
        print(f"  {'✓' if passed else '✗'} {name}")
    all_passed = all(results.values())
    print(f"\n  {'ALL PASSED' if all_passed else 'SOME FAILED'}")
    return all_passed


if __name__ == "__main__":
    # Propagate the outcome through the process exit status so that shells
    # and CI can detect a failing run instead of always seeing exit code 0.
    raise SystemExit(0 if main() else 1)
