#!/usr/bin/env python3
"""V2 spike — verify ONNX model runs on AnimaSync audio, visualize prosody signal.

Usage:
    python3 scripts/compiler/v2_spike.py
"""
import sys
from pathlib import Path
import numpy as np
import librosa
import onnxruntime as ort
import matplotlib.pyplot as plt

sys.path.insert(0, '/dataset/text-to-face-se/LAM_Audio2Expression')
from distillation.student_model import AudioFeatureExtractor

# Absolute, machine-specific locations of the model, the input clip and the output figure.
ONNX_PATH = '/dataset/mead-expression-training/e2f/distill/emotion_face_v8_brow09.onnx'
AUDIO_PATH = '/dataset/AnimaSync-mic-fix/data/audio_preview/daily_003_t0_struggle.mp3'
OUT_PNG = '/dataset/AnimaSync-mic-fix/data/v2_spike_daily_003_t0.png'

SR = 16000  # sample rate (Hz) requested from librosa.load
FPS = 30    # frames per second assumed when converting frame counts to seconds

# Base emotions in the column order of the model's emotion-conditioning input.
_BASE_EMOTIONS = ('neutral', 'joy', 'anger', 'sadness', 'surprise')
EMOTION_TO_IDX = {emo: pos for pos, emo in enumerate(_BASE_EMOTIONS)}

# Fine-grained clip labels, grouped by the base emotion each one collapses to.
_SUBS_BY_BASE = {
    'sadness': ('struggle', 'apology', 'sadness'),
    'anger': ('anger', 'frustration', 'rage'),
    'joy': ('joy', 'excitement', 'relief'),
    'surprise': ('surprise', 'fear'),
    'neutral': ('neutral',),
}
SUB_TO_BASE = {sub: base for base, subs in _SUBS_BY_BASE.items() for sub in subs}

# Output columns plotted by main(), top to bottom, as (short name, column index).
# Plot labels are rendered as "<name> (<column>)" so a label can never disagree
# with the column it actually shows.
_PLOT_SPEC = (
    ('browInnerUp', 43),
    ('browDownL', 41),
    ('browDownR', 42),
    ('browOuterUpL', 44),
    ('browOuterUpR', 45),
    ('cheekSquintL', 47),
    ('cheekSquintR', 48),
    ('noseSneerL', 49),
    # NOTE(review): in Apple's ARKit blendshape ordering, column 10 is
    # eyeLookOutRight and eyeSquintLeft is column 5 — confirm this label
    # against the model's actual channel order.
    ('eyeSquintL', 10),
    ('jawOpen', 17),  # for comparison — lipsync channel
)
CHANNELS_TO_PLOT = {f'{label} ({col})': col for label, col in _PLOT_SPEC}

# Output columns summarized in main()'s statistics printout.
# NOTE(review): under ARKit ordering, 10-13 would be the right eye's
# lookOut/lookUp/squint/wide channels — verify this selection is intended.
EXPRESSION_CHANNELS = [10, 11, 12, 13, 41, 42, 43, 44, 45, 47, 48, 49, 50]


def _build_emotion(sub_emotion, n_frames):
    """Build the emotion-conditioning tensor for a fine-grained emotion label.

    Args:
        sub_emotion: clip label (e.g. ``'struggle'``); mapped to a base emotion
            via ``SUB_TO_BASE``, falling back to ``'neutral'`` if unknown.
        n_frames: number of frames T to condition.

    Returns:
        ``(base, idx, emotion)`` where ``emotion`` is a float32 array of shape
        ``(1, n_frames, len(EMOTION_TO_IDX))``. The ``idx`` column is 1.0 on
        every frame, except for ``'neutral'``, which is left all-zero.
    """
    base = SUB_TO_BASE.get(sub_emotion, 'neutral')
    idx = EMOTION_TO_IDX[base]
    emotion = np.zeros((1, n_frames, len(EMOTION_TO_IDX)), dtype=np.float32)
    if base != 'neutral':
        emotion[0, :, idx] = 1.0
    return base, idx, emotion


def _print_expression_stats(bs):
    """Print per-channel mean/std and mean-centered range for EXPRESSION_CHANNELS.

    Args:
        bs: 2-D model output indexed as ``bs[frame, channel]``; only the
            columns listed in ``EXPRESSION_CHANNELS`` are summarized.
    """
    expr = bs[:, EXPRESSION_CHANNELS]
    mean_per_ch = expr.mean(axis=0)
    std_per_ch = expr.std(axis=0)
    motion = expr - mean_per_ch  # mean-centered per channel
    print('\nExpression-channel stats (mean ± std across time):')
    for i, ch in enumerate(EXPRESSION_CHANNELS):
        print(f'  ch {ch:2d}: mean={mean_per_ch[i]:+.3f}  std={std_per_ch[i]:.3f}  '
              f'motion_range=[{motion[:,i].min():+.3f}, {motion[:,i].max():+.3f}]')
    print(f'\nMean motion amplitude across expression channels: {np.abs(motion).mean():.4f}')
    print(f'Max motion amplitude: {np.abs(motion).max():.4f}')


def _save_channel_plot(bs, title, out_path):
    """Save one stacked time-series subplot per CHANNELS_TO_PLOT entry.

    Args:
        bs: 2-D model output indexed as ``bs[frame, channel]``.
        title: figure title.
        out_path: destination image path; missing parent directories are created.
    """
    t_axis = np.arange(bs.shape[0]) / FPS
    # squeeze=False keeps `axes` 2-D even when only one channel is plotted.
    fig, axes = plt.subplots(len(CHANNELS_TO_PLOT), 1, figsize=(12, 14),
                             sharex=True, squeeze=False)
    axes = axes[:, 0]
    for ax, (name, ch) in zip(axes, CHANNELS_TO_PLOT.items()):
        y = bs[:, ch]
        ax.plot(t_axis, y, linewidth=0.9)
        ax.axhline(y.mean(), color='r', linestyle='--', alpha=0.4, linewidth=0.5)
        ax.set_ylabel(name, fontsize=8)
        # Keep the [-0.1, 1.05] baseline, but widen it in either direction so
        # out-of-range outputs are visible instead of clipped.
        ax.set_ylim(min(-0.1, float(y.min()) * 1.1), max(1.05, float(y.max()) * 1.1))
        ax.grid(True, alpha=0.3)
    axes[-1].set_xlabel('time (s)')
    fig.suptitle(title, fontsize=10)
    fig.tight_layout()
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=100)
    plt.close(fig)


def main():
    """Run the V2 ONNX model on one AnimaSync clip, print stats and save a plot.

    Steps: load ``AUDIO_PATH`` as mono at ``SR``; extract per-frame features
    with ``AudioFeatureExtractor``; build a one-hot emotion tensor for the
    clip's label; run ``ONNX_PATH`` on CPU with inputs ``'features'`` and
    ``'emotion'``; print statistics for ``EXPRESSION_CHANNELS``; save a plot of
    ``CHANNELS_TO_PLOT`` to ``OUT_PNG``.

    Exits with a message if no feature frames are produced.
    """
    # Fine-grained emotion label of the clip in AUDIO_PATH (see SUB_TO_BASE).
    sub_emotion = 'struggle'
    clip_name = Path(AUDIO_PATH).stem

    # 1. Load audio
    print(f'Loading: {AUDIO_PATH}')
    audio, _ = librosa.load(AUDIO_PATH, sr=SR, mono=True)
    duration = len(audio) / SR
    print(f'  duration: {duration:.2f}s  samples: {len(audio)}')

    # 2. Extract 141-dim features
    print('Extracting 141-dim features...')
    extractor = AudioFeatureExtractor()
    feats = extractor.extract(audio)  # (T, 141)
    T = feats.shape[0]
    print(f'  features: {feats.shape}  → {T / FPS:.2f}s at {FPS}fps')
    if T == 0:
        # Every later step (stats, plot limits) needs at least one frame.
        sys.exit('No feature frames extracted — audio is empty or too short.')

    # 3. Build emotion conditioning tensor (1, T, n_emotions).
    base, idx, emotion = _build_emotion(sub_emotion, T)
    intensity = float(emotion[0, :, idx].max())  # 0.0 for neutral (all-zero)
    print(f'  emotion: {sub_emotion} → {base} (idx {idx}), intensity {intensity:.1f}')

    # 4. Load ONNX + infer
    print(f'Loading ONNX: {ONNX_PATH}')
    sess = ort.InferenceSession(ONNX_PATH, providers=['CPUExecutionProvider'])
    feats_in = feats.astype(np.float32)[None, ...]  # (1, T, 141)
    out = sess.run(None, {'features': feats_in, 'emotion': emotion})[0]  # (1, T, 52)
    bs = out[0]  # (T, 52)
    print(f'  output: {bs.shape}  range [{bs.min():.3f}, {bs.max():.3f}]')

    # 5. Stats on expression channels
    _print_expression_stats(bs)

    # 6. Plot
    print(f'\nSaving plot: {OUT_PNG}')
    title = (f'V2 output  |  {clip_name}  |  emotion={base}  |  '
             f'duration={duration:.1f}s')
    _save_channel_plot(bs, title, OUT_PNG)
    print('Done.')


if __name__ == '__main__':
    # Run the spike only when executed as a script, not when imported.
    main()
