#!/usr/bin/env python3
"""Compare V2 brow output for ANGER across all ONNX versions."""
import sys
from pathlib import Path
import numpy as np
import librosa
import onnxruntime as ort

sys.path.insert(0, '/dataset/text-to-face-se/LAM_Audio2Expression')
from distillation.student_model import AudioFeatureExtractor

ONNX_DIR = Path("/dataset/mead-expression-training/e2f/distill")
AUDIO = "/dataset/AnimaSync-mic-fix/data/audio_preview/daily_002_t2_anger.mp3"

# Output channels to tabulate, as {index into the model's 52-wide output: label}.
# Indices follow the V2 native ARKit ordering. Despite the name, the set covers
# the cheek-squint pair and eyeSquintL in addition to the brow channels.
# Dict insertion order is the column order of the printed table.
# NOTE(review): in Apple's reference ARKit ordering index 10 is eyeLookOutRight
# (eyeSquintLeft is 5); 10 is only eyeSquintL if V2's native order differs
# there (e.g. L/R-interleaved eye channels) — confirm against the V2 list.
BROW_CH = {
    41: "browDownL",
    42: "browDownR",
    43: "browInnerUp",
    44: "browOuterUpL",
    45: "browOuterUpR",
    47: "cheekSquintL",
    48: "cheekSquintR",
    10: "eyeSquintL",
}

# Model files (relative to ONNX_DIR) to compare. int8-quantized exports are
# deliberately excluded because quantization loses precision.
MODELS = [
    "emotion_face_v2.onnx",
    "emotion_face.onnx",
    "emotion_face_v8_brow09.onnx",
    "emotion_face_streaming.onnx",
]

# Decode the clip as 16 kHz mono and run the student feature extractor on it.
wav, sr = librosa.load(AUDIO, sr=16000, mono=True)
feats = AudioFeatureExtractor().extract(wav)
feats = feats.astype(np.float32)
T = feats.shape[0]  # frame count: first axis of the extracted features

# Emotion conditioning: a (1, T, 5) float32 array that is zero everywhere
# except slot 2 (anger), which is held at 1.0 on every frame.
emotion = np.zeros((1, T, 5), dtype=np.float32)
emotion[..., 2] = 1.0

print("Audio: {}  ({:.2f}s, T={})".format(Path(AUDIO).name, len(wav) / sr, T))
print("Emotion: anger (idx 2), intensity 1.0\n")

def _table_row(label, cells, *, label_w=35, cell_w=13):
    """Format one table line: a left-aligned label column, then fixed-width cells.

    The header and every data row are built through this helper so their
    "│" separators line up. ``cell_w=13`` fits the widest channel names
    (12 chars, e.g. "browOuterUpL") and a negative mean such as "μ-0.01 σ0.034".
    """
    return f"{label:{label_w}s} │ " + " │ ".join(f"{c:{cell_w}s}" for c in cells)


header = _table_row("model", BROW_CH.values())
print(header)
print("─" * len(header))

# Highest channel index read below; a model's output must be wider than this.
_max_ch = max(BROW_CH)

for model in MODELS:
    onnx_path = ONNX_DIR / model
    if not onnx_path.exists():
        # Report instead of silently dropping, so a missing file shows up in the table.
        print(f"{model:35s} │ MISSING ({onnx_path})")
        continue
    try:
        sess = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
        input_names = {i.name for i in sess.get_inputs()}
        # Some streaming models have a different input signature
        if input_names != {'features', 'emotion'}:
            print(f"{model:35s} │ SKIP (inputs: {input_names})")
            continue
        # First output, first batch item -> expected (T, 52) in V2 order.
        out = sess.run(None, {'features': feats[None, ...], 'emotion': emotion})[0][0]
    except Exception as e:
        # Best-effort across models: report the failure and move on to the next one.
        print(f"{model:35s} │ ERROR: {type(e).__name__}: {e}")
        continue

    # An unexpected output layout would otherwise raise on out[:, ch] and abort
    # the comparison for every remaining model.
    if out.ndim != 2 or out.shape[1] <= _max_ch:
        print(f"{model:35s} │ SKIP (unexpected output shape {out.shape})")
        continue

    # Per-channel mean (μ) and standard deviation (σ) over all frames.
    cells = [f"μ{out[:, ch].mean():.2f} σ{out[:, ch].std():.3f}" for ch in BROW_CH]
    print(_table_row(model, cells))
