from __future__ import annotations

import argparse
import json
import math
import re
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Repository root, taken as the third ancestor of this file's resolved path
# (parents[2]), i.e. the script is expected to sit two directories below it.
REPO_ROOT = Path(__file__).resolve().parents[2]
# JSON file read by load_anchors(); its "anchors" entry maps each emotion name
# to a list of records carrying "level" and "vad".
ANCHORS_PATH = REPO_ROOT / "data" / "emotion" / "emotion_vad_anchors.json"
# Not referenced anywhere in the visible part of this file.
# NOTE(review): presumably the label definitions matching the anchors — confirm.
LABELS_PATH = REPO_ROOT / "data" / "emotion" / "emotion_labels.json"
# Default value of --input (absolute path on the host machine).
KOTE_TRAIN = Path("/dataset/KOTE/train.tsv")

# The 16 emotion class names the classifier is allowed to emit. The same names
# are listed in SYSTEM_PROMPT, so the two must be kept in sync by hand.
EMOTION_NAMES = [
    "neutral", "joy", "laughter", "excitement", "agreement", "gratitude",
    "sadness", "crying", "sulk", "apology", "struggle",
    "anger", "refusal", "surprise", "fluster", "shy",
]
# Set form for O(1) membership checks; parse_response() rejects any emotion
# that is not in this set.
EMOTION_SET = set(EMOTION_NAMES)

# System message sent with every sample by build_prompts(). It lists the 16
# classes, asks for a one-sentence "Face:" description followed by a compact
# JSON object, and parse_response() later extracts that JSON object from the
# generated text. This is runtime text: editing it changes model behavior.
SYSTEM_PROMPT = (
    "You are a Korean emotion classifier for an AI avatar facial animation system. "
    "You classify Korean text by what facial expression the speaker would have if they spoke it out loud.\n\n"
    "Available emotion classes (choose exactly ONE):\n"
    "- neutral (중립): calm, no strong emotion — factual statements, neutral questions\n"
    "- joy (기쁨): soft smile, warm affection — pleasant realizations\n"
    "- laughter (웃음): actually laughing, 'ㅋㅋㅋ' reactions, mouth open smile\n"
    "- excitement (흥분): excited, eager, enthusiastic, high-energy anticipation\n"
    "- agreement (동의): agreeing, '네', '응', '맞아' — calm yielding nod\n"
    "- gratitude (감사): thankful, '고마워', '감사' — direct thanks to another person\n"
    "- sadness (슬픔): downcast face, slight frown, low energy\n"
    "- crying (울음): actually crying, tears, 'ㅠㅠ' when sobbing\n"
    "- sulk (삐침): pouty, offended, petty complaint — like a child pouting\n"
    "- apology (사과): apologetic, regretful, '미안', contrite\n"
    "- struggle (고민): troubled, deep thinking, internal conflict — 'how do I...'\n"
    "- anger (분노): angry, furious, rage — red face, scowl\n"
    "- refusal (거절): firm refusal, 'no', 'can't' — saying no to a request\n"
    "- surprise (놀람): 'wow really?!' — expressive pleasant surprise, raised brows\n"
    "- fluster (당황): caught off-guard, mildly uncomfortable, not in control, 'what?!'\n"
    "- shy (수줍음): shy, bashful, avoiding eye contact, soft smile\n\n"
    "TASK:\n"
    "1. Read the Korean text\n"
    "2. Imagine a person speaking it out loud in conversation\n"
    "3. In ONE short sentence, describe what their face would look like\n"
    "4. Then output your final choice as compact JSON\n\n"
    "IMPORTANT:\n"
    "- Classify by the FACE expression, not by the topic\n"
    "- Political rants, praise, and commentary: focus on the emotional tone\n"
    "- Factual statements, questions, and product descriptions without tone = neutral\n"
    "- Sarcasm is often actually anger/sulk, not positive\n"
    "- Be conservative with confidence: 0.9+ only for very clear cases\n\n"
    "Output format (exactly this structure):\n"
    "Face: <one sentence description>\n"
    '{"emotion": "<one of 16>", "intensity": <1-5>, "confidence": <0.0-1.0>}'
)


def load_anchors(level_fallback: int = 3) -> Dict[Tuple[str, int], List[float]]:
    """Read the VAD anchor table stored at ``ANCHORS_PATH``.

    The file is parsed as JSON and its top-level ``"anchors"`` entry is
    flattened: every record found under an emotion name becomes one
    ``(emotion, level)`` key. If the same pair appears twice, the later
    record wins.

    Args:
        level_fallback: Accepted but not used by this function.

    Returns:
        Mapping from ``(emotion name, int(level))`` to a new list built from
        the record's ``"vad"`` value.
    """
    payload = json.loads(ANCHORS_PATH.read_text(encoding="utf-8"))
    table: Dict[Tuple[str, int], List[float]] = {}
    for emotion, records in payload["anchors"].items():
        for record in records:
            vad = list(record["vad"])
            level = int(record["level"])
            table[(emotion, level)] = vad
    return table


def vad_for(emotion: str, intensity: int, anchors: Dict[Tuple[str, int], List[float]]) -> List[float]:
    """Look up the VAD anchor for an emotion at a given intensity.

    Args:
        emotion: Emotion name used as the first half of the anchor key.
        intensity: Requested level; values outside 1..5 are clamped into range.
        anchors: Table keyed by ``(emotion, level)``.

    Returns:
        The anchor stored for ``(emotion, clamped intensity)``; failing that,
        the anchor for level 3 of the same emotion; failing that,
        ``[0.0, 0.0, 0.0]``. A found anchor is returned as the very list held
        in ``anchors``, not a copy.
    """
    level = max(1, min(5, intensity))
    # Try the exact level first, then the mid-scale level as a stand-in.
    for candidate in ((emotion, level), (emotion, 3)):
        if candidate in anchors:
            return anchors[candidate]
    return [0.0, 0.0, 0.0]


def load_kote(path: Path, limit: int = 0, min_chars: int = 5, max_chars: int = 300) -> List[Tuple[str, str]]:
    """Load ``(id, text)`` pairs from a tab-separated KOTE file.

    Each line is split on tabs; the first field is taken as the sample id and
    the second as the text. Lines with fewer than two fields are skipped, as
    are texts whose stripped length falls outside ``min_chars..max_chars``.

    Args:
        path: TSV file to read (decoded as UTF-8).
        limit: Maximum number of pairs to return; ``0`` or a negative value
            means no cap.
        min_chars: Minimum accepted text length, inclusive.
        max_chars: Maximum accepted text length, inclusive.

    Returns:
        Accepted ``(id, text)`` pairs in file order.
    """
    items: List[Tuple[str, str]] = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            # Remove only the line terminator. Stripping all whitespace here
            # would also eat a leading/trailing tab, silently dropping an
            # empty first or last field and shifting every column by one.
            parts = line.rstrip("\r\n").split("\t")
            if len(parts) < 2:
                continue
            kote_id, text = parts[0].strip(), parts[1].strip()
            if not (min_chars <= len(text) <= max_chars):
                continue
            items.append((kote_id, text))
            # `limit > 0` so that a negative value cannot stop after one row.
            if limit > 0 and len(items) >= limit:
                break
    return items


# Matches a flat "{...}" span containing no nested braces. parse_response()
# uses it to pull candidate JSON objects out of free-form model output.
_JSON_RE = re.compile(r"\{[^{}]*\}")


def parse_response(text: str) -> Optional[Dict]:
    """Extract the classifier's verdict from a raw model response.

    Every flat ``{...}`` span in ``text`` is tried from last to first; the
    first one that decodes as JSON and names an emotion from ``EMOTION_SET``
    is used.

    Args:
        text: Decoded model output (may be empty).

    Returns:
        A dict with ``"emotion"`` (lower-cased, stripped), ``"intensity"``
        (int clamped to 1..5) and ``"confidence"`` (float clamped to
        0.0..1.0), or ``None`` when no usable object is found. Missing,
        non-numeric or non-finite values fall back to intensity 3 and
        confidence 0.5.
    """
    if not text:
        return None
    # Take the last JSON-looking chunk (model may echo prompt)
    for raw in reversed(_JSON_RE.findall(text)):
        try:
            obj = json.loads(raw)
        except json.JSONDecodeError:
            continue
        if "emotion" not in obj:
            continue
        emo = str(obj["emotion"]).strip().lower()
        if emo not in EMOTION_SET:
            continue
        try:
            intensity = int(obj.get("intensity", 3))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: json.loads accepts Infinity / 1e999, and
            # int(float("inf")) overflows instead of raising ValueError.
            intensity = 3
        try:
            confidence = float(obj.get("confidence", 0.5))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: a huge integer literal cannot be converted.
            confidence = 0.5
        if not math.isfinite(confidence):
            # NaN slips through the min()/max() clamp below and would come
            # out as 1.0, letting garbage pass the confidence filter.
            confidence = 0.5
        return {
            "emotion": emo,
            "intensity": max(1, min(5, intensity)),
            "confidence": max(0.0, min(1.0, confidence)),
        }
    return None


def build_prompts(tokenizer, texts: List[str]) -> List[str]:
    """Render one chat-formatted prompt string per input text.

    Each prompt is a two-message conversation: ``SYSTEM_PROMPT`` as the system
    message and ``"Text: <text>"`` as the user message, passed through the
    tokenizer's chat template with the generation prompt appended and without
    tokenizing.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        texts: Raw texts to classify.

    Returns:
        Prompt strings in the same order as ``texts``.
    """

    def render(sample: str) -> str:
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Text: {sample}"},
        ]
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    return [render(sample) for sample in texts]


def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of the relabeling script.

    Returns:
        Namespace with ``model``, ``input``, ``output``, ``raw_output``,
        ``batch_size``, ``max_new_tokens``, ``limit``, ``min_confidence`` and
        ``device``. ``device`` defaults to ``"cuda"`` when torch reports CUDA
        as available, otherwise ``"cpu"``.
    """
    emotion_dir = REPO_ROOT / "data" / "emotion"
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--input", type=Path, default=KOTE_TRAIN)
    parser.add_argument("--output", type=Path, default=emotion_dir / "kote_relabeled.jsonl")
    parser.add_argument("--raw_output", type=Path, default=emotion_dir / "kote_relabeled_raw.jsonl")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--max_new_tokens", type=int, default=150)
    parser.add_argument("--limit", type=int, default=0, help="Cap KOTE samples (0=all)")
    parser.add_argument("--min_confidence", type=float, default=0.7)
    parser.add_argument("--device", type=str, default=default_device)
    return parser.parse_args()


def _build_scenario(kote_id: str, text: str, parsed: Dict, vad: List[float]) -> Dict:
    """Wrap one accepted sample as a single-turn scenario record."""
    scenario_id = f"kote_{kote_id}"
    return {
        "scenario_id": scenario_id,
        "source_scenario_id": scenario_id,
        "paraphrase_idx": -1,
        "setting": "KOTE online comment (relabeled)",
        "style": "비격식",
        "turns": [
            {
                "speaker": "A",
                "text": text,
                "emotion": parsed["emotion"],
                "vad": vad,
                "intensity": parsed["intensity"],
                "confidence": parsed["confidence"],
            }
        ],
    }


def main() -> None:
    """Relabel KOTE samples with an instruction-tuned causal LM.

    Loads the tokenizer and model named by ``--model``, reads samples from
    ``--input``, generates one response per sample in batches and writes two
    JSONL files:

    * ``--raw_output``: every sample with the raw response and parsed result.
    * ``--output``: only samples that parsed successfully and whose confidence
      is at least ``--min_confidence``, wrapped as scenario records.

    Both files are opened in ``"w"`` mode, so existing content is overwritten.

    Raises:
        ValueError: If ``--batch_size`` is smaller than 1.
    """
    args = parse_args()
    # Fail before the expensive model load: a zero step makes range() raise,
    # and a negative one would silently process nothing after truncating the
    # output files.
    if args.batch_size < 1:
        raise ValueError(f"--batch_size must be >= 1, got {args.batch_size}")

    print(f"[load] tokenizer from {args.model}", flush=True)
    # Left padding keeps every prompt flush against the generated tokens.
    tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"[load] model from {args.model}", flush=True)
    model = AutoModelForCausalLM.from_pretrained(args.model, dtype=torch.bfloat16)
    model = model.to(args.device)
    model.eval()

    anchors = load_anchors()
    print(f"[data] loading KOTE from {args.input}", flush=True)
    items = load_kote(args.input, limit=args.limit)
    total = len(items)
    print(f"[data] total samples: {total:,}", flush=True)

    # The two outputs may point at different directories; create both.
    for out_path in (args.output, args.raw_output):
        out_path.parent.mkdir(parents=True, exist_ok=True)

    kept = 0
    parse_fail = 0
    low_conf = 0

    t0 = time.time()
    # Context managers guarantee both files are closed even if generation
    # raises part-way through the run.
    with args.raw_output.open("w", encoding="utf-8") as raw_f, \
            args.output.open("w", encoding="utf-8") as out_f, \
            torch.inference_mode():
        for start in range(0, total, args.batch_size):
            batch = items[start : start + args.batch_size]
            ids, texts = zip(*batch)
            prompts = build_prompts(tokenizer, list(texts))
            enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=1024)
            enc = {k: v.to(args.device) for k, v in enc.items()}
            gen = model.generate(
                **enc,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
            # Drop the (padded) prompt columns so only the continuation is decoded.
            out_tokens = gen[:, enc["input_ids"].shape[1]:]
            decoded = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)

            for kote_id, text, raw_resp in zip(ids, texts, decoded):
                parsed = parse_response(raw_resp)
                raw_row = {
                    "kote_id": kote_id,
                    "text": text,
                    "response": raw_resp,
                    "parsed": parsed,
                }
                raw_f.write(json.dumps(raw_row, ensure_ascii=False) + "\n")
                if parsed is None:
                    parse_fail += 1
                    continue
                if parsed["confidence"] < args.min_confidence:
                    low_conf += 1
                    continue
                vad = vad_for(parsed["emotion"], parsed["intensity"], anchors)
                scenario = _build_scenario(kote_id, text, parsed, vad)
                out_f.write(json.dumps(scenario, ensure_ascii=False) + "\n")
                kept += 1
            # Flush once per batch so an interrupted run keeps finished work.
            raw_f.flush()
            out_f.flush()

            done = start + len(batch)
            if done % (args.batch_size * 10) == 0 or done == total:
                elapsed = time.time() - t0
                rate = done / max(elapsed, 1e-6)
                eta = (total - done) / max(rate, 1e-6)
                print(
                    f"[progress] {done}/{total} | kept={kept} parse_fail={parse_fail} "
                    f"low_conf={low_conf} | {elapsed:.0f}s ({rate:.1f}/s) eta {eta:.0f}s",
                    flush=True,
                )

    print(
        f"[done] total={total:,} kept={kept:,} parse_fail={parse_fail:,} "
        f"low_conf={low_conf:,}",
        flush=True,
    )
    print(f"[done] filtered output → {args.output}", flush=True)
    print(f"[done] raw responses → {args.raw_output}", flush=True)


# Run the relabeling pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
