"""Live demo server for the V3 face pipeline.

    user types text in browser
        ↓ POST /api/generate {text}
    KlueTeacher (text → emotion + VAD)
        ↓
    ElevenLabs TTS (text + emotion → mp3)
        ↓
    mel features  +  cond from teacher prediction
        ↓
    V3 face model  +  locked post-processing
        ↓
    blendshapes.json (+ audio.mp3) served back to browser
        ↓
    blendshape-player-live.html renders the avatar with audio

Usage:
    ELEVENLABS_API_KEY=sk-... PYTHONPATH=. python3 -m models.v3_face.serve_live

Then open http://localhost:8091/ in your browser.
"""
from __future__ import annotations

import asyncio
import json
import os
import time
from pathlib import Path

import librosa
import numpy as np
import torch
from flask import Flask, request, jsonify, send_from_directory

from scripts.compiler.constants import ARKIT_52_NAMES
from scripts.compiler.data_pipeline import EMOTION_LABELS, FPS, mel_features
from scripts.compiler.tts import synth_one_elevenlabs

from .infer import (
    crisp_mouth, smooth_brows, inject_blinks, load_model,
)
from .infer_e2e import (
    load_teacher, teacher_predict, build_cond_smoothed,
)

# Repository root: two levels above the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Per-request outputs (<sid>.mp3 / <sid>.json) are written here; the directory
# is created at import time so the first request can write immediately.
LIVE_DIR = PROJECT_ROOT.joinpath("data", "viewer_live")
LIVE_DIR.mkdir(parents=True, exist_ok=True)

# Static asset directories exposed by the /tools/ and /avatar/ routes.
TOOLS_DIR = PROJECT_ROOT.joinpath("tools")
AVATAR_DIR = PROJECT_ROOT.joinpath("avatar")

# V3 face model checkpoint.
CKPT = PROJECT_ROOT.joinpath(
    "models", "v3_face", "checkpoints", "best_expression.pt")

# KlueTeacher weights and tokenizer live side by side in one directory.
_TEACHER_DIR = PROJECT_ROOT.joinpath("checkpoints", "klue_teacher_clean_ctx2")
TEACHER_CKPT = _TEACHER_DIR.joinpath("best.pt")
TOKENIZER_DIR = _TEACHER_DIR.joinpath("tokenizer")

# Fixed post-processing settings applied to every prediction by the
# /api/generate handler. Key prefixes indicate the consumer:
#   crisp_* -> crisp_mouth, brow_* -> smooth_brows, blink_* -> inject_blinks.
POST_PROC = dict(
    # crisp_mouth: threshold / scale / pre_smooth_sigma / mouth_close_sigma
    crisp_threshold=0.3,
    crisp_scale=1.0,
    crisp_sigma=1.3,
    crisp_mouthclose_sigma=1.0,
    # smooth_brows: min_cutoff / beta / d_cutoff
    brow_min_cutoff=2.0,
    brow_beta=0.01,
    brow_d_cutoff=1.0,
    # inject_blinks: mean_interval_s / expressive_cap
    blink_interval=6.0,
    blink_expressive_cap=0.5,
)

# Use the first CUDA device when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"[live] device = {device}")

# Both models are loaded once at import time and shared by every request.
print(f"[live] loading V3 from {CKPT}")
model, cfg = load_model(CKPT, device)

print(f"[live] loading KlueTeacher from {TEACHER_CKPT}")
teacher, tokenizer = load_teacher(TEACHER_CKPT, TOKENIZER_DIR, device)

print("[live] ready ✓")

app = Flask(__name__)


@app.route("/")
def index():
    """Serve the live blendshape player page from TOOLS_DIR."""
    page_name = "blendshape-player-live.html"
    tools_root = str(TOOLS_DIR)
    return send_from_directory(tools_root, page_name)


@app.route("/tools/<path:filename>")
def serve_tool(filename):
    """Serve ``filename`` (may contain sub-paths) from TOOLS_DIR."""
    tools_root = str(TOOLS_DIR)
    return send_from_directory(tools_root, filename)


@app.route("/live/<path:filename>")
def serve_live(filename):
    """Serve a generated artifact (``<sid>.json`` / ``<sid>.mp3``) from LIVE_DIR."""
    live_root = str(LIVE_DIR)
    return send_from_directory(live_root, filename)


@app.route("/avatar/<path:filename>")
def serve_avatar(filename):
    """Serve ``filename`` (may contain sub-paths) from AVATAR_DIR."""
    avatar_root = str(AVATAR_DIR)
    return send_from_directory(avatar_root, filename)


@app.route("/api/generate", methods=["POST"])
def generate():
    """Run the text → audio → blendshapes pipeline for one utterance.

    Request body (JSON object):
        text:     non-empty string to speak (required).
        voice_id: optional value forwarded to the TTS call as ``voice_id``.

    Pipeline: ``teacher_predict`` (emotion + VAD) → ``synth_one_elevenlabs``
    (mp3) → ``mel_features`` → ``build_cond_smoothed`` → V3 model forward →
    ``crisp_mouth`` / ``smooth_brows`` / ``inject_blinks``. The mp3 and a
    ``<sid>.json`` blendshape file are written to ``LIVE_DIR`` and can be
    fetched through the ``/live/`` route.

    Returns:
        A JSON response with the artifact URLs, predicted emotion and VAD,
        the top-3 emotion probabilities, the frame count and per-stage
        timings in milliseconds. Error responses are ``{"error": ...}`` with
        status 400 when ``text`` is missing, empty or not a string, and 500
        when TTS raises, produces no file, or yields under 0.1 s of audio.
    """
    payload = request.get_json(silent=True)
    # A valid JSON body may still be a list, string or number; only an object
    # carries the fields read below, so anything else is treated as empty.
    data = payload if isinstance(payload, dict) else {}
    raw_text = data.get("text")
    # A non-string "text" has no .strip(); treat it as missing so the client
    # gets the 400 below instead of an unhandled server error.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    voice_override = data.get("voice_id") or None
    if not text:
        return jsonify({"error": "empty text"}), 400

    t_start = time.perf_counter()
    # Millisecond timestamp id; also used as the output file stem, the TTS
    # voice_seed and the scenario_id given to inject_blinks.
    sid = f"live_{int(time.time() * 1000)}"

    # 1. KlueTeacher: text → emotion + VAD
    t0 = time.perf_counter()
    emos, probs, vads = teacher_predict(teacher, tokenizer, [text], device)
    emo, vad = emos[0], vads[0]
    t_teacher = time.perf_counter() - t0

    # 2. TTS: synthesize voice matching emotion
    t0 = time.perf_counter()
    audio_path = LIVE_DIR / f"{sid}.mp3"
    try:
        asyncio.run(synth_one_elevenlabs(
            text=text, out_path=audio_path,
            voice_id=voice_override,
            emotion=emo, vad=list(vad),
            voice_seed=sid,
        ))
    except Exception as e:
        # Request boundary: report any TTS failure to the client as a 500.
        return jsonify({"error": f"TTS failed: {e}"}), 500
    if not audio_path.exists():
        return jsonify({"error": "TTS returned no audio"}), 500
    t_tts = time.perf_counter() - t0

    # 3. Audio → mel features
    t0 = time.perf_counter()
    wav, sr = librosa.load(str(audio_path), sr=16000, mono=True)
    # Reject clips shorter than 0.1 s at the 16 kHz load rate.
    if len(wav) < 16000 * 0.1:
        return jsonify({"error": "audio too short"}), 500
    mel = mel_features(wav, sr=sr, fps=FPS)
    T = mel.shape[0]
    t_mel = time.perf_counter() - t0

    # 4. Build cond (single turn — no cross-turn smoothing needed)
    cond = build_cond_smoothed([emo], np.expand_dims(vad, axis=0), [T],
                                vad_smooth_sigma=0.0)

    # 5. V3 forward
    t0 = time.perf_counter()
    audio_t = torch.from_numpy(mel.astype(np.float32)).unsqueeze(0).to(device)
    cond_t = torch.from_numpy(cond.astype(np.float32)).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(audio_t, cond_t).squeeze(0).cpu().numpy().astype(np.float32)
    t_v3 = time.perf_counter() - t0

    # 6. Locked post-processing (same as offline e2e)
    t0 = time.perf_counter()
    pred = crisp_mouth(pred,
                       threshold=POST_PROC["crisp_threshold"],
                       scale=POST_PROC["crisp_scale"],
                       pre_smooth_sigma=POST_PROC["crisp_sigma"],
                       mouth_close_sigma=POST_PROC["crisp_mouthclose_sigma"])
    pred = smooth_brows(pred,
                        min_cutoff=POST_PROC["brow_min_cutoff"],
                        beta=POST_PROC["brow_beta"],
                        d_cutoff=POST_PROC["brow_d_cutoff"])
    pred = inject_blinks(pred, scenario_id=sid,
                         mean_interval_s=POST_PROC["blink_interval"],
                         expressive_cap=POST_PROC["blink_expressive_cap"])
    t_post = time.perf_counter() - t0

    # 7. Write JSON
    # Three most probable emotions, highest probability first.
    top3 = [
        {"label": EMOTION_LABELS[i], "prob": float(probs[0][i])}
        for i in np.argsort(probs[0])[-3:][::-1]
    ]
    num_frames = int(pred.shape[0])
    blendshapes_json = {
        "scenario_id": sid,
        # Frames were generated at FPS (see mel_features above); advertise
        # that same rate so playback stays aligned with the audio.
        "fps": FPS,
        "num_frames": num_frames,
        "names": ARKIT_52_NAMES,
        "turns": [{
            "turn_idx": 0,
            "emotion": emo,
            "vad": vad.tolist(),
            "text": text,
            "top3_emotions": top3,
        }],
        "blendshapes": np.round(pred, 4).tolist(),
    }
    json_path = LIVE_DIR / f"{sid}.json"
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # written as UTF-8 explicitly rather than in the locale's default encoding.
    json_path.write_text(json.dumps(blendshapes_json, ensure_ascii=False),
                         encoding="utf-8")

    t_total = time.perf_counter() - t_start
    print(f"[live] {sid} text={text[:30]!r} emo={emo} "
          f"vad={vad.round(2).tolist()} "
          f"T={T} total={t_total:.2f}s "
          f"(teacher={t_teacher:.2f}s tts={t_tts:.2f}s "
          f"v3={t_v3:.3f}s post={t_post:.3f}s)")

    return jsonify({
        "scenario_id": sid,
        "blendshapes_url": f"/live/{sid}.json",
        "audio_url": f"/live/{sid}.mp3",
        "emotion": emo,
        "vad": vad.tolist(),
        "top3": top3,
        "num_frames": num_frames,
        "timing": {
            "teacher_ms": int(t_teacher * 1000),
            "tts_ms": int(t_tts * 1000),
            "mel_ms": int(t_mel * 1000),
            "v3_ms": int(t_v3 * 1000),
            "post_ms": int(t_post * 1000),
            "total_ms": int(t_total * 1000),
        },
    })


if __name__ == "__main__":
    # The PORT environment variable overrides the default port 8091.
    # host="0.0.0.0" listens on all network interfaces, not just localhost.
    # threaded=False keeps the development server single-threaded.
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "8091")),
        debug=False,
        threaded=False,
    )
