"""Quick browser demo: type Korean text → see emotion + VAD prediction.

Usage:
    CUDA_VISIBLE_DEVICES=4 python scripts/demo_server.py \
        --ckpt checkpoints/klue_teacher_clean_ctx2/best.pt \
        --context_window 2 --port 8899
"""
from __future__ import annotations

import argparse
import json
import sys
from functools import partial
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

import torch
from flask import Flask, jsonify, request, send_from_directory

from models.microalbert.config import EMOTION_LABELS, MicroAlbertConfig
from models.microalbert.teacher import KlueTeacherForEmotionVAD
from models.microalbert.tokenizer import MicroTokenizer

# Flask application serving the demo page ("/") and the JSON endpoint ("/predict").
app = Flask(__name__)

# Runtime state assigned by main() before the server starts; the request
# handlers only read these module globals.
model = tok = cfg = None  # model, tokenizer and config loaded from --ckpt
ctx_window = 0  # number of previous turns prepended to the input (0 = no context)

# CSS hex colour per emotion label. index() serialises this mapping into the
# page's JavaScript `COLORS` object; labels missing here get a default colour
# on the client side.
EMOTION_COLORS = dict(
    neutral="#9E9E9E",
    joy="#FFD54F",
    laughter="#FFF176",
    excitement="#FF8A65",
    agreement="#81C784",
    gratitude="#A5D6A7",
    sadness="#64B5F6",
    crying="#42A5F5",
    sulk="#CE93D8",
    apology="#B0BEC5",
    struggle="#FFAB91",
    anger="#EF5350",
    refusal="#F48FB1",
    surprise="#FFE082",
    fluster="#FFCC80",
    shy="#F8BBD0",
)

HTML = """<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>Emotion Demo</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body { font-family: -apple-system, sans-serif; background: #1a1a2e; color: #eee;
       display: flex; justify-content: center; padding: 40px; }
.container { max-width: 700px; width: 100%; }
h1 { font-size: 1.4em; margin-bottom: 20px; color: #aaa; }
.input-row { display: flex; gap: 10px; margin-bottom: 10px; }
input[type=text] { flex: 1; padding: 14px; font-size: 16px; border: 1px solid #333;
                   background: #16213e; color: #eee; border-radius: 8px; outline: none; }
input[type=text]:focus { border-color: #5c7cfa; }
button { padding: 14px 28px; background: #5c7cfa; color: white; border: none;
         border-radius: 8px; font-size: 16px; cursor: pointer; }
button:hover { background: #4263eb; }
.context-area { margin-bottom: 20px; }
.context-area textarea { width: 100%; height: 60px; padding: 10px; font-size: 13px;
    background: #16213e; color: #aaa; border: 1px solid #333; border-radius: 8px;
    resize: vertical; outline: none; }
.result { margin-top: 24px; padding: 24px; background: #16213e; border-radius: 12px;
          display: none; }
.emotion-label { font-size: 2em; font-weight: bold; margin-bottom: 8px; }
.confidence { color: #aaa; margin-bottom: 20px; }
.vad-row { display: flex; gap: 16px; margin-bottom: 16px; }
.vad-item { flex: 1; }
.vad-name { font-size: 0.8em; color: #888; margin-bottom: 4px; }
.vad-bar-bg { height: 8px; background: #333; border-radius: 4px; position: relative; }
.vad-bar { height: 8px; border-radius: 4px; position: absolute; }
.vad-val { font-size: 0.9em; margin-top: 4px; }
.top3 { margin-top: 16px; }
.top3-item { display: flex; align-items: center; gap: 8px; margin-bottom: 6px; }
.top3-dot { width: 12px; height: 12px; border-radius: 50%; }
.top3-name { width: 80px; font-size: 0.85em; }
.top3-bar-bg { flex: 1; height: 6px; background: #333; border-radius: 3px; }
.top3-bar { height: 6px; border-radius: 3px; }
.top3-pct { width: 45px; text-align: right; font-size: 0.8em; color: #aaa; }
.history { margin-top: 20px; }
.history-item { padding: 8px 12px; margin-bottom: 4px; border-radius: 6px;
                font-size: 0.85em; display: flex; justify-content: space-between; }
</style></head>
<body>
<div class="container">
    <h1>MicroALBERT Emotion Demo</h1>
    <div class="context-area">
        <textarea id="context" placeholder="(optional) previous turns, one per line&#10;e.g.:&#10;나 오늘 시험 봤어&#10;어떻게 됐어?"></textarea>
    </div>
    <div class="input-row">
        <input type="text" id="text" placeholder="Type Korean text..." autofocus>
        <button onclick="predict()">Predict</button>
    </div>
    <div class="result" id="result">
        <div class="emotion-label" id="emo"></div>
        <div class="confidence" id="conf"></div>
        <div class="vad-row" id="vad"></div>
        <div class="top3" id="top3"></div>
    </div>
    <div class="history" id="history"></div>
</div>
<script>
const COLORS = EMOTION_COLORS_JSON;
async function predict() {
    const text = document.getElementById('text').value.trim();
    if (!text) return;
    const ctx = document.getElementById('context').value.trim().split('\\n').filter(l=>l);
    const res = await fetch('/predict', {
        method: 'POST', headers: {'Content-Type': 'application/json'},
        body: JSON.stringify({text, context: ctx})
    });
    const d = await res.json();
    document.getElementById('result').style.display = 'block';
    document.getElementById('emo').textContent = d.emotion;
    document.getElementById('emo').style.color = COLORS[d.emotion] || '#fff';
    document.getElementById('conf').textContent = `confidence: ${(d.confidence*100).toFixed(1)}%`;

    // VAD bars
    const vadEl = document.getElementById('vad');
    vadEl.innerHTML = '';
    ['Valence','Arousal','Dominance'].forEach((name,i) => {
        const v = d.vad[i];
        const pct = ((v + 1) / 2 * 100);
        const color = i===0 ? (v>0?'#4CAF50':'#f44336') : i===1 ? '#FF9800' : '#2196F3';
        vadEl.innerHTML += `<div class="vad-item"><div class="vad-name">${name}</div>
            <div class="vad-bar-bg"><div class="vad-bar" style="left:50%;width:${Math.abs(v)*50}%;
            ${v>0?'':'transform:translateX(-100%);'}background:${color}"></div></div>
            <div class="vad-val">${v>0?'+':''}${v.toFixed(3)}</div></div>`;
    });

    // Top 3
    const t3 = document.getElementById('top3');
    t3.innerHTML = '<div style="font-size:0.8em;color:#666;margin-bottom:8px">Top predictions</div>';
    d.top5.forEach(t => {
        t3.innerHTML += `<div class="top3-item">
            <div class="top3-dot" style="background:${COLORS[t.emotion]||'#666'}"></div>
            <div class="top3-name">${t.emotion}</div>
            <div class="top3-bar-bg"><div class="top3-bar" style="width:${t.prob*100}%;background:${COLORS[t.emotion]||'#666'}"></div></div>
            <div class="top3-pct">${(t.prob*100).toFixed(1)}%</div></div>`;
    });

    // History
    const hist = document.getElementById('history');
    hist.innerHTML = `<div class="history-item" style="background:${COLORS[d.emotion]||'#333'}22;border-left:3px solid ${COLORS[d.emotion]||'#666'}">
        <span>${text}</span><span style="color:${COLORS[d.emotion]}">${d.emotion} (${(d.confidence*100).toFixed(0)}%)</span>
    </div>` + hist.innerHTML;

    document.getElementById('text').value = '';
    document.getElementById('text').focus();
}
document.getElementById('text').addEventListener('keydown', e => { if(e.key==='Enter') predict(); });
</script>
</body></html>"""


@app.route("/")
def index():
    html = HTML.replace("EMOTION_COLORS_JSON", json.dumps(EMOTION_COLORS))
    return html


@app.route("/predict", methods=["POST"])
def predict():
    data = request.json
    text = data.get("text", "").strip()
    context = data.get("context", [])
    if not text:
        return jsonify({"error": "empty text"}), 400

    # Build input with context
    if ctx_window > 0 and context:
        parts = []
        for i, prev in enumerate(context[-ctx_window:]):
            parts.append(f"[OTHER] {prev}")
        parts.append(f"[SELF] {text}")
        full_text = " [SEP] ".join(parts)
    elif ctx_window > 0:
        full_text = f"[SELF] {text}"
    else:
        full_text = text

    ids = tok.encode(full_text, cfg.max_seq_len)
    input_ids = torch.tensor([ids], dtype=torch.long).to(next(model.parameters()).device)
    attn = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(input_ids, attn)
        logits = out["emotion_logits"][0]
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
        vad = out["vad"][0].cpu().numpy().tolist()

    pred_id = int(probs.argmax())
    top5_ids = probs.argsort()[::-1][:5]

    return jsonify({
        "emotion": EMOTION_LABELS[pred_id],
        "confidence": float(probs[pred_id]),
        "vad": [round(v, 3) for v in vad],
        "top5": [
            {"emotion": EMOTION_LABELS[int(i)], "prob": round(float(probs[i]), 4)}
            for i in top5_ids
        ],
    })


def main():
    """Parse CLI args, load checkpoint, tokenizer and model, then run the server.

    Assigns the module globals ``model``, ``tok``, ``cfg`` and ``ctx_window``
    that the ``/predict`` handler reads, then blocks inside ``app.run``.
    """
    global model, tok, cfg, ctx_window

    ap = argparse.ArgumentParser(description="Browser demo for emotion + VAD prediction.")
    ap.add_argument("--ckpt", type=Path, required=True)
    ap.add_argument("--context_window", type=int, default=0)
    ap.add_argument("--host", default="0.0.0.0",
                    help="interface to bind (use 127.0.0.1 to keep the demo local)")
    ap.add_argument("--port", type=int, default=8899)
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--tokenizer_jsonl", type=Path,
                    default=Path("data/emotion/seed_val.jsonl"),
                    help="JSONL passed to MicroTokenizer.build (relative to the CWD)")
    args = ap.parse_args()
    ctx_window = args.context_window

    # NOTE: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
    ckpt = torch.load(args.ckpt, map_location=args.device, weights_only=False)
    raw_cfg = ckpt.get("config")
    if isinstance(raw_cfg, dict):
        cfg = MicroAlbertConfig(**raw_cfg)
    elif isinstance(raw_cfg, MicroAlbertConfig):
        cfg = raw_cfg
    else:
        # Falling back silently can mismatch the trained head sizes; make it visible.
        print("WARNING: checkpoint has no usable 'config'; using MicroAlbertConfig() defaults",
              file=sys.stderr)
        cfg = MicroAlbertConfig()

    tok = MicroTokenizer.build(
        save_dir=args.ckpt.parent / "tokenizer",
        train_jsonl=args.tokenizer_jsonl,
        max_len=cfg.max_seq_len,
        # Speaker tokens ([SELF]/[OTHER]) are only needed when context is used.
        add_speaker_tokens=(ctx_window > 0),
    )

    model = KlueTeacherForEmotionVAD(
        model_name=ckpt.get("model_name", "klue/roberta-base"),
        num_emotions=cfg.num_emotions,
        vad_dim=cfg.vad_dim,
        dropout=0.0,
        vad_head_hidden=cfg.vad_head_hidden,
        attention_dropout=0.0,
    ).to(args.device)
    # Match the embedding table to the tokenizer before loading weights,
    # since added special tokens change the vocabulary size.
    if len(tok) != model.backbone.config.vocab_size:
        model.backbone.resize_token_embeddings(len(tok))
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    print(f"\nModel loaded: {args.ckpt}")
    print(f"Context window: {ctx_window}")
    print(f"Device: {args.device}")
    print(f"Listening on {args.host}:{args.port}")
    print(f"Open http://localhost:{args.port}\n")

    app.run(host=args.host, port=args.port, debug=False)


if __name__ == "__main__":
    main()
