#!/usr/bin/env python3
"""Probe streaming v2i ONNX on the anger audio — compare with v8_brow09."""
import sys
from pathlib import Path
import numpy as np
import librosa
import onnxruntime as ort

sys.path.insert(0, '/dataset/text-to-face-se/LAM_Audio2Expression')
from distillation.student_model import AudioFeatureExtractor

# Input clip and the streaming ONNX export being probed.
AUDIO = '/dataset/AnimaSync-mic-fix/data/audio_preview/daily_002_t2_anger.mp3'
STREAMING = '/dataset/mead-expression-training/e2f/distill/emotion_face_streaming.onnx'

# Number of feature frames handed to each sess.run() call by run_streaming().
CHUNK = 5

# Recurrent state threaded between chunks by run_streaming():
#   lstm_h, lstm_c -> (LSTM_LAYERS, 1, LSTM_H)
#   conv_ctx       -> (1, CONV_CTX_FRAMES, LSTM_H)
LSTM_LAYERS = 2
LSTM_H = 320
CONV_CTX_FRAMES = 16

# Output column index -> display name for the channels main() reports on.
# Despite the name, this also covers cheek- and eye-squint channels.
# NOTE(review): indices are columns of run_streaming()'s output (documented as
# V2 native order); the index->name pairing is assumed — verify against the
# model's channel list. Insertion order is the print order.
BROW_CH = {
    41: 'browDownL',
    42: 'browDownR',
    43: 'browInnerUp',
    44: 'browOuterUpL',
    45: 'browOuterUpR',
    47: 'cheekSquintL',
    48: 'cheekSquintR',
    10: 'eyeSquintL',
}


def run_streaming(sess, feats, emotion_idx, intensity=1.0):
    """Run chunked streaming inference, threading recurrent state between calls.

    The feature sequence is zero-padded at the tail to a multiple of CHUNK
    frames. It is fed to ``sess`` one CHUNK-frame window at a time, and the
    LSTM / conv-context state returned by each call is passed into the next.
    Outputs for the padding frames are discarded.

    Args:
        sess: ``onnxruntime.InferenceSession`` (or any object with a compatible
            ``run(output_names, input_feed)``) taking inputs ``features``,
            ``emotion``, ``lstm_h``, ``lstm_c`` and ``conv_ctx``.
        feats: (T, F) per-frame features (F is 141 for the extractor used in
            this script). Converted to contiguous float32 before use.
        emotion_idx: slot of the 5-wide emotion vector to set; any negative
            value leaves the vector all-zero (neutral).
        intensity: value written into the selected emotion slot.

    Returns:
        (T, 52) float32 array in V2 native order.

    Raises:
        ValueError: if ``feats`` is not 2-D.
    """
    # ONNX Runtime does not cast input dtypes, so a float64 array would be
    # rejected by sess.run; match the float32 used for the state tensors.
    feats = np.ascontiguousarray(feats, dtype=np.float32)
    if feats.ndim != 2:
        raise ValueError(f"feats must be 2-D (T, F), got shape {feats.shape}")
    T = feats.shape[0]

    # Zero-pad the tail so the sequence splits into whole CHUNK-frame windows;
    # the pad width follows feats' own feature dimension.
    pad = (CHUNK - T % CHUNK) % CHUNK
    if pad:
        feats = np.pad(feats, ((0, pad), (0, 0)))
    T_padded = feats.shape[0]
    n_chunks = T_padded // CHUNK

    # Recurrent state starts at zero and is replaced by each call's outputs.
    lstm_h = np.zeros((LSTM_LAYERS, 1, LSTM_H), dtype=np.float32)
    lstm_c = np.zeros((LSTM_LAYERS, 1, LSTM_H), dtype=np.float32)
    conv_ctx = np.zeros((1, CONV_CTX_FRAMES, LSTM_H), dtype=np.float32)

    # Emotion conditioning: one vector repeated on every frame of the chunk.
    emo_chunk = np.zeros((1, CHUNK, 5), dtype=np.float32)
    if emotion_idx >= 0:
        emo_chunk[0, :, emotion_idx] = intensity

    out_all = np.zeros((T_padded, 52), dtype=np.float32)
    for c in range(n_chunks):
        window = slice(c * CHUNK, (c + 1) * CHUNK)
        outs = sess.run(None, {
            'features': feats[window][None, ...],  # (1, CHUNK, F)
            'emotion': emo_chunk,
            'lstm_h': lstm_h,
            'lstm_c': lstm_c,
            'conv_ctx': conv_ctx,
        })
        # NOTE(review): outputs are unpacked positionally, assuming the export
        # orders them (blendshapes, lstm_h, lstm_c, conv_ctx) — confirm with
        # sess.get_outputs() if the model is re-exported.
        bs, lstm_h, lstm_c, conv_ctx = outs
        out_all[window] = bs[0]

    # Drop the rows produced for padding frames.
    return out_all[:T]


def main():
    """Report brow/cheek/eye channel statistics for three emotion conditions.

    Loads AUDIO as 16 kHz mono, extracts features, then runs the STREAMING
    model once each for neutral (idx -1), anger (2) and sadness (3). For every
    channel in BROW_CH it prints the mean, std and [min, max]. The marker is
    '✓' when std > 0.01, '·' when std <= 0.01 but mean > 0.05, and '✗'
    otherwise.
    """
    audio_path = Path(AUDIO)
    model_path = Path(STREAMING)

    wav, sr = librosa.load(AUDIO, sr=16000, mono=True)
    feats = AudioFeatureExtractor().extract(wav).astype(np.float32)
    n_frames = feats.shape[0]
    duration_s = len(wav) / sr
    print(f"Audio: {audio_path.name}  ({duration_s:.2f}s, T={n_frames})")

    sess = ort.InferenceSession(STREAMING, providers=['CPUExecutionProvider'])
    print(f"Model: {model_path.name}\n")

    conditions = (('neutral', -1), ('anger', 2), ('sadness', 3))
    header = f"  {'channel':20s} │ {'baseline μ':>10s} │ {'motion σ':>10s} │ range"

    for emo_name, emo_idx in conditions:
        out = run_streaming(sess, feats, emo_idx, intensity=1.0)
        print(f"═══ emotion={emo_name} (idx {emo_idx}) ═══")
        print(header)
        for ch, name in BROW_CH.items():
            column = out[:, ch]
            m = column.mean()
            s = column.std()
            lo = column.min()
            hi = column.max()
            if s > 0.01:
                mark = ' ✓'   # channel moves
            elif m > 0.05:
                mark = ' ·'   # static but offset from zero
            else:
                mark = ' ✗'   # flat near zero
            print(f"  {name:20s} │ {m:>10.3f} │ {s:>10.4f} │ [{lo:.2f}, {hi:.2f}]{mark}")
        print()


if __name__ == "__main__":
    # Only run the probe when executed directly, not on import.
    main()
