"""
Phase 2 Augmentation: Claude Haiku 4.5 paraphrase expansion.

Input:  seed_train.jsonl (1,664 train turns, emotion + VAD labeled)
Output: seed_augmented.jsonl (~10K paraphrases, same labels, varied surface form)

Strategy:
- Batch N turns per API call (context efficiency)
- Prompt caching on system + few-shot (90% discount on cached tokens)
- ThreadPoolExecutor for parallelism (~5 concurrent)
- Preserve emotion label exactly; VAD ±0.05 jitter per paraphrase
- Incremental write with fsync per batch (crash-safe resume)
"""

import argparse
import json
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path

import anthropic

DATA_DIR = Path(__file__).parent
INPUT_PATH = DATA_DIR / "seed_train.jsonl"
OUTPUT_PATH = DATA_DIR / "seed_augmented.jsonl"
STATE_PATH = DATA_DIR / "seed_augmented.state.json"

MODEL = "claude-haiku-4-5"
PARAPHRASES_PER_TURN = 5
TURNS_PER_REQUEST = 10
MAX_CONCURRENT = 5

SYSTEM_PROMPT = """너는 한국어 감정 데이터셋을 확장하는 언어 전문가야.

입력으로 (text, emotion, VAD)가 여러 개 주어진다. 각 입력마다 **5개의 한국어 paraphrase**를 만든다.

## 엄격한 규칙
1. **emotion label은 절대 바꾸지 말 것** — 동일 감정 유지.
2. **VAD 값은 ±0.05 이내로만 jitter 허용** — 의미가 크게 달라지면 안 됨.
3. **어휘뿐 아니라 어순/문장 구조까지 다양화** — 단순 동의어 치환 금지.
4. **스타일 마커 유지** (!!/!/ㅋㅋ/ㅠㅠ/~/어/요/입니다 등) — 감정 강도와 직결됨.
5. **한국어 자연스러움 최우선** — 직역체/어색한 번역투 금지.
6. **원문 text를 그대로 복제하지 말 것** — 모두 변형되어야 함.

## 감정별 핵심 마커 보존 원칙
- **anger**: 감탄사(아/야/하/와) + !! 개수로 강도 표현. 격식체 anger는 !!없이 어휘만 격해짐.
- **laughter**: ㅋㅋ/ㅎㅎ 최소 하나는 유지.
- **joy (L1 calm)**: 잔잔/편안/따뜻 톤. !!나 과도한 감탄 금지.
- **crying**: ㅠㅠ 또는 눈물/울음 관련 어휘 포함.
- **sadness**: 낮은 arousal 유지 (조용한 체념 톤).
- **gratitude/apology**: 격식체일 경우 입니다/드립니다 형식 유지.

## 출력 형식 (반드시 이 형식만)
JSON만 출력한다. 설명/주석/markdown 금지. 한 줄씩 JSON Lines 포맷:

{"idx": 0, "paraphrases": ["...", "...", "...", "...", "..."]}
{"idx": 1, "paraphrases": ["...", "...", "...", "...", "..."]}

idx는 입력 순서와 일치해야 한다."""

FEW_SHOT_USER = """## 입력
[
  {"idx": 0, "text": "와 진짜 미친 거 아니야?! 어떻게 이럴 수가 있어!!", "emotion": "anger", "vad": [-0.75, 0.85, 0.35]},
  {"idx": 1, "text": "아 진짜 고마워 ㅠㅠ 덕분에 살았어", "emotion": "gratitude", "vad": [0.68, 0.45, -0.1]},
  {"idx": 2, "text": "ㅋㅋㅋ 너 진짜 웃긴다 ㅋㅋ", "emotion": "laughter", "vad": [0.75, 0.7, 0.25]}
]"""

FEW_SHOT_ASSISTANT = """{"idx": 0, "paraphrases": ["야 이게 말이 돼?! 어떻게 이런 짓을!!", "아 진짜 어이없네!! 이런 일이 어떻게 벌어지냐고!!", "하 미쳤나 진짜?! 이걸 어떻게 받아들이라는 거야!!", "와 이건 좀 심한 거 아니야?! 진짜 너무하잖아!!", "아니 이게 가능한 일이야?! 말도 안 되는 소리지!!"]}
{"idx": 1, "paraphrases": ["정말 너무 고마워 ㅠㅠ 덕분에 진짜 한시름 놓았어", "아 진짜 너 아니었으면 어쩔 뻔했어 ㅠㅠ 고맙다", "고마워 ㅠㅠ 진심으로, 너 덕분에 살았다니까", "야 진짜 고마워서 눈물 날 뻔 ㅠㅠ 덕분이야", "너무 고마워 ㅠㅠ 진짜 네 덕분에 다 풀렸어"]}
{"idx": 2, "paraphrases": ["야 ㅋㅋㅋ 진짜 너 때문에 웃겨 죽겠어 ㅋㅋ", "ㅋㅋㅋㅋ 아 진짜 웃겨 너란 애는 ㅋㅋ", "ㅎㅎㅎ 뭔 소리야 ㅋㅋ 너무 웃기잖아", "아 ㅋㅋㅋ 그만해 진짜 배꼽 빠지겠네 ㅋㅋ", "ㅋㅋㅋ 너 개그맨 해도 되겠다 진짜 ㅋㅋ"]}"""


@dataclass
class TurnRef:
    scenario_id: str
    turn_idx: int
    setting: str
    style: str
    speaker: str
    text: str
    emotion: str
    vad: list


def load_train_turns(path: Path) -> list[TurnRef]:
    refs = []
    with path.open() as f:
        for line in f:
            s = json.loads(line)
            for i, t in enumerate(s["turns"]):
                refs.append(
                    TurnRef(
                        scenario_id=s["scenario_id"],
                        turn_idx=i,
                        setting=s.get("setting", ""),
                        style=s.get("style", ""),
                        speaker=t["speaker"],
                        text=t["text"],
                        emotion=t["emotion"],
                        vad=t["vad"],
                    )
                )
    return refs


def load_state() -> set[tuple[str, int]]:
    """Return set of (scenario_id, turn_idx) already processed."""
    if not STATE_PATH.exists():
        return set()
    with STATE_PATH.open() as f:
        data = json.load(f)
    return {tuple(x) for x in data.get("done", [])}


def save_state(done: set[tuple[str, int]]) -> None:
    tmp = STATE_PATH.with_suffix(".tmp")
    with tmp.open("w") as f:
        json.dump({"done": sorted([list(x) for x in done])}, f)
    tmp.replace(STATE_PATH)


def build_user_msg(batch: list[TurnRef]) -> str:
    items = [
        {"idx": i, "text": t.text, "emotion": t.emotion, "vad": t.vad}
        for i, t in enumerate(batch)
    ]
    return "## 입력\n" + json.dumps(items, ensure_ascii=False, indent=2)


def parse_response(text: str, batch_size: int) -> dict[int, list[str]]:
    """Parse JSON-Lines response into {idx: [paraphrases]}."""
    out: dict[int, list[str]] = {}
    for line in text.strip().split("\n"):
        line = line.strip()
        if not line or not line.startswith("{"):
            continue
        try:
            obj = json.loads(line)
        except json.JSONDecodeError:
            continue
        idx = obj.get("idx")
        paras = obj.get("paraphrases", [])
        if isinstance(idx, int) and 0 <= idx < batch_size and isinstance(paras, list):
            out[idx] = [p for p in paras if isinstance(p, str) and p.strip()]
    return out


def jitter_vad(vad: list[float], rng: random.Random) -> list[float]:
    return [round(max(-1.0, min(1.0, v + rng.uniform(-0.05, 0.05))), 3) for v in vad]


def call_api(
    client: anthropic.Anthropic,
    batch: list[TurnRef],
    attempt: int = 0,
) -> dict[int, list[str]]:
    try:
        resp = client.messages.create(
            model=MODEL,
            max_tokens=4096,
            system=[
                {
                    "type": "text",
                    "text": SYSTEM_PROMPT,
                    "cache_control": {"type": "ephemeral"},
                }
            ],
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": FEW_SHOT_USER,
                            "cache_control": {"type": "ephemeral"},
                        }
                    ],
                },
                {"role": "assistant", "content": FEW_SHOT_ASSISTANT},
                {"role": "user", "content": build_user_msg(batch)},
            ],
        )
        text = "".join(b.text for b in resp.content if b.type == "text")
        return parse_response(text, len(batch))
    except (anthropic.RateLimitError, anthropic.APIConnectionError, anthropic.InternalServerError) as e:
        if attempt >= 4:
            raise
        delay = 2 ** attempt + random.random()
        print(f"  retry {attempt + 1} after {delay:.1f}s ({type(e).__name__})", file=sys.stderr)
        time.sleep(delay)
        return call_api(client, batch, attempt + 1)


def process_batch(
    client: anthropic.Anthropic,
    batch: list[TurnRef],
    rng: random.Random,
) -> list[dict]:
    parsed = call_api(client, batch)
    out_rows = []
    for i, ref in enumerate(batch):
        paras = parsed.get(i, [])
        for k, text in enumerate(paras[:PARAPHRASES_PER_TURN]):
            out_rows.append(
                {
                    "source_scenario_id": ref.scenario_id,
                    "source_turn_idx": ref.turn_idx,
                    "paraphrase_idx": k,
                    "setting": ref.setting,
                    "style": ref.style,
                    "speaker": ref.speaker,
                    "text": text,
                    "emotion": ref.emotion,
                    "vad": jitter_vad(ref.vad, rng),
                }
            )
    return out_rows


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true", help="Process 2 batches then stop")
    ap.add_argument("--limit", type=int, default=None, help="Max turns to process")
    ap.add_argument("--resume", action="store_true", help="Skip turns already in state file")
    args = ap.parse_args()

    if not os.environ.get("ANTHROPIC_API_KEY"):
        print("ERROR: ANTHROPIC_API_KEY not set", file=sys.stderr)
        sys.exit(1)

    client = anthropic.Anthropic()
    rng = random.Random(42)

    turns = load_train_turns(INPUT_PATH)
    done = load_state() if args.resume else set()
    pending = [t for t in turns if (t.scenario_id, t.turn_idx) not in done]

    if args.limit:
        pending = pending[: args.limit]
    if args.dry_run:
        pending = pending[: TURNS_PER_REQUEST * 2]

    print(f"Loaded {len(turns)} train turns ({len(done)} already done, {len(pending)} pending)")
    print(f"Model: {MODEL} | batch={TURNS_PER_REQUEST} | parallel={MAX_CONCURRENT} | {PARAPHRASES_PER_TURN}× per turn")

    batches = [pending[i : i + TURNS_PER_REQUEST] for i in range(0, len(pending), TURNS_PER_REQUEST)]
    print(f"Total batches: {len(batches)}")

    mode = "a" if args.resume and OUTPUT_PATH.exists() else "w"
    t0 = time.time()
    written = 0
    tokens_in = tokens_out = tokens_cached = 0

    with OUTPUT_PATH.open(mode) as fout, ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as pool:
        futures = {pool.submit(process_batch, client, b, rng): b for b in batches}
        for n, fut in enumerate(as_completed(futures), 1):
            batch = futures[fut]
            try:
                rows = fut.result()
            except Exception as e:
                print(f"  batch {n} FAILED: {type(e).__name__}: {e}", file=sys.stderr)
                continue
            for row in rows:
                fout.write(json.dumps(row, ensure_ascii=False) + "\n")
            fout.flush()
            os.fsync(fout.fileno())
            for ref in batch:
                done.add((ref.scenario_id, ref.turn_idx))
            save_state(done)
            written += len(rows)
            elapsed = time.time() - t0
            rate = n / elapsed if elapsed > 0 else 0
            eta = (len(batches) - n) / rate if rate > 0 else 0
            print(
                f"[{n:4d}/{len(batches)}] +{len(rows):3d} rows | total {written:5d} | "
                f"{elapsed:5.0f}s elapsed | eta {eta:5.0f}s"
            )

    print(f"\nDone. Wrote {written} augmented rows to {OUTPUT_PATH}")
    print(f"Elapsed: {time.time() - t0:.0f}s")


if __name__ == "__main__":
    main()
