#!/usr/bin/env python3
"""
GPT-5.4-mini paraphrase augmentation for seed_train.jsonl.

Flow:
  1. Load seed_train.jsonl (847 scenarios after merge)
  2. **Skip 1-turn solo scenarios** (paraphrasing short self-talk yields near-duplicates)
  3. Sample scenarios: --stratified N (diverse slice), --random N, or --full
  4. For each scenario, call GPT-5.4-mini to generate 2 text paraphrases
  5. Reconstruct full rows: preserve speaker/emotion, apply VAD ±0.05 jitter
  6. Write each paraphrase as its own row to the output jsonl
  7. Run post-hoc validator: flag 비격식 존댓말 rows with 격식체 endings

Target: 634 multi-turn × 2 paraphrases + original turns ≈ 8K training turns, ~$1.20 total.

Usage (requires `pip install openai` and OPENAI_API_KEY env var):

  # Iteration 1 — stratified 10 for prompt review
  python3 augment_openai.py --stratified 10 --output seed_augmented_iter1.jsonl

  # Iteration 2 — different 10 after prompt revision (change --seed)
  python3 augment_openai.py --stratified 10 --seed 43 --output seed_augmented_iter2.jsonl

  # Canary — 50 random scenarios before full run
  python3 augment_openai.py --random 50 --output seed_augmented_canary.jsonl

  # Full batch — all 847 scenarios (~$0.40, ~15 min)
  python3 augment_openai.py --full --output seed_augmented.jsonl

  # Dry run — print the first payload without calling API
  python3 augment_openai.py --stratified 10 --dry-run --output /tmp/dry.jsonl
"""

import argparse
import json
import os
import random
import sys
import time
from collections import defaultdict
from pathlib import Path

DATA_DIR = Path(__file__).parent
INPUT_PATH = DATA_DIR / "seed_train.jsonl"

MODEL = "gpt-5.4-mini"
PARAPHRASES_PER_SCENARIO = 2

# =============================================================================
# ITERATION 1 PROMPT — edit the three strings below between iterations,
# then rerun with a new --output file. The prompt is intentionally in-source
# so diffs are trivial in git.
# =============================================================================

SYSTEM_PROMPT = """너는 한국어 감정 대화 데이터셋을 확장하는 언어 전문가야.

입력은 다중턴 대화 시나리오이다. 각 시나리오마다 **2개의 paraphrase 버전**을 만든다.
모든 버전은 원본과 동일한 turn 개수, speaker 순서, 감정 흐름을 유지해야 한다.

## 엄격한 규칙
1. **emotion label과 VAD는 입력에 표시되어 있고, paraphrase는 반드시 동일한 감정 강도로 작성한다.**
2. **turn 개수와 speaker 순서는 반드시 보존** — 2개 버전 모두 원본과 동일한 구조여야 함.
3. **대화 문맥 일관성 유지** — 각 turn은 이전 turn의 paraphrase 버전에 자연스럽게 이어져야 함.
4. **어휘/어순/문장 구조/수사법 모두 다양화** — 단순 동의어 치환 금지. 어순 변경, 표현 재구성, 수사법 전환 권장.
5. **스타일 마커 유지** (!!/ㅋㅋ/ㅎㅎ/ㅠㅠ/~/입니다/어/요) — 감정 강도와 직결됨. 원문의 마커 개수에 준하게 유지.
6. **원문 복제 금지** — 모든 turn이 의미 있게 변형되어야 함.
7. **style(격식체/반말/비격식 존댓말) 반드시 유지** — 입력과 동일한 어미 체계를 쓸 것.
   - **격식체**: 모든 문장이 -입니다/-습니다/-까/-시- 어미로 끝나야 함. 반말이나 -요 어미 금지.
   - **반말**: -어/-야/-네/-지/-거든 등 어미. 존댓말 금지.
   - **비격식 존댓말**: **모든 문장이 -요/-네요/-나요/-까요/-어요/-아요로 끝나야 함. -입니다/-습니다/-까 절대 금지**. 긴 메시지(3-4 turn)에서도 마지막 turn까지 -요 어미를 유지할 것. "이메일 톤"으로 drift 금지. 또한 **주어가 사물/추상명사일 때 -시- 주체 높임을 붙이지 말 것** (예: "도움이 되셨어요" ❌ → "도움이 됐어요" ✅).

## 감정별 핵심 보존
- **anger (격식체)**: !! 없이 어휘만 격해짐. "납득할 수 없습니다", "용납하기 어렵습니다"
- **anger (반말)**: 감탄사(아/야/와/하) + !! 개수 유지
- **laughter**: ㅋㅋ/ㅎㅎ 최소 하나 유지
- **crying**: ㅠㅠ 또는 눈물/울음 어휘 필수
- **joy (차분한/낮은 arousal)**: !!나 과도한 감탄 금지. 잔잔/편안/따뜻 톤 유지
- **gratitude/apology (격식체)**: 입니다/드립니다 형식 유지
- **shy/fluster**: 말끝 흐림, 주저하는 표현 유지
- **refusal (격식체)**: "어렵습니다/불가합니다" 계열 어휘 유지

## 출력 형식 (반드시 이 형식만)
JSON object 한 개만 출력. 설명/주석/markdown 금지.

{"source_scenario_id": "<입력의 scenario_id>", "paraphrases": [["v1 turn1", "v1 turn2", ...], ["v2 turn1", ...]]}

- paraphrases는 **정확히 길이 2인 리스트**.
- 각 원소(paraphrase 버전)는 **원본 turns 개수와 동일한 길이의 문자열 리스트**.
- speaker/emotion/vad는 출력하지 않음 (원본 메타데이터를 Python에서 재사용).
- source_scenario_id는 입력의 scenario_id와 정확히 일치해야 함."""


FEW_SHOT_USER = """## 입력
{"scenario_id": "example_01", "style": "반말", "turns": [{"speaker": "A", "text": "야 진짜 미친 거 아니야?! 어떻게 이럴 수가 있어!!", "emotion": "anger", "vad": [-0.75, 0.85, 0.35]}, {"speaker": "B", "text": "아 진정해 일단 얘기부터 들어봐", "emotion": "neutral", "vad": [0.05, 0.2, 0.3]}, {"speaker": "A", "text": "아니 이걸 어떻게 그냥 넘어가!!", "emotion": "anger", "vad": [-0.7, 0.8, 0.3]}]}"""


FEW_SHOT_ASSISTANT = """{"source_scenario_id": "example_01", "paraphrases": [["야 이게 말이 돼?! 어떻게 이런 짓을!!", "진정해 일단 상황부터 설명해봐", "아니 이걸 넘어갈 수가 있어?!"], ["아 어이없어 진짜!! 이런 게 가능하냐고!!", "좀 숨 고르고 얘기부터 해봐", "야 이걸 그냥 넘기라고?!"]]}"""


# =============================================================================
# End of prompt section. Code below handles batching, API calls, and output.
# =============================================================================


def load_scenarios(path: Path) -> list[dict]:
    return [json.loads(l) for l in path.open()]


def filter_multiturn(scenarios: list[dict]) -> list[dict]:
    """Skip 1-turn scenarios — paraphrasing short self-talk yields near-duplicates.
    Multi-turn paraphrases get much richer augmentation signal."""
    return [s for s in scenarios if len(s["turns"]) >= 2]


def stratified_sample(scenarios: list[dict], n: int, seed: int) -> list[dict]:
    """Pick n scenarios covering diverse (category, style) combinations."""
    rng = random.Random(seed)
    buckets = defaultdict(list)
    for s in scenarios:
        sid = s["scenario_id"]
        cat = "daily" if "daily" in sid else "long" if "long" in sid else "solo"
        buckets[(cat, s["style"])].append(s)

    keys = sorted(buckets.keys())
    for k in keys:
        rng.shuffle(buckets[k])

    picks = []
    idx = 0
    while len(picks) < n and idx < n * 10:
        key = keys[idx % len(keys)]
        if buckets[key]:
            picks.append(buckets[key].pop())
        idx += 1
    return picks[:n]


def random_sample(scenarios: list[dict], n: int, seed: int) -> list[dict]:
    rng = random.Random(seed)
    return rng.sample(scenarios, min(n, len(scenarios)))


def build_user_message(scenario: dict) -> str:
    payload = {
        "scenario_id": scenario["scenario_id"],
        "style": scenario["style"],
        "turns": [
            {
                "speaker": t["speaker"],
                "text": t["text"],
                "emotion": t["emotion"],
                "vad": t["vad"],
            }
            for t in scenario["turns"]
        ],
    }
    return "## 입력\n" + json.dumps(payload, ensure_ascii=False)


def parse_response(text: str, expected_id: str, expected_turns: int) -> tuple[list[list[str]] | None, str]:
    """Returns (paraphrases, error_message). On success, error_message is ''."""
    try:
        obj = json.loads(text)
    except json.JSONDecodeError as e:
        return None, f"invalid JSON: {e}"
    if obj.get("source_scenario_id") != expected_id:
        return None, f"source_scenario_id mismatch: got {obj.get('source_scenario_id')!r}, expected {expected_id!r}"
    paras = obj.get("paraphrases")
    if not isinstance(paras, list):
        return None, f"paraphrases is not a list (got {type(paras).__name__})"
    if len(paras) != PARAPHRASES_PER_SCENARIO:
        return None, f"wrong paraphrase count: got {len(paras)}, expected {PARAPHRASES_PER_SCENARIO}"
    for i, p in enumerate(paras):
        if not isinstance(p, list):
            return None, f"paraphrase[{i}] is not a list"
        if len(p) != expected_turns:
            return None, f"paraphrase[{i}] wrong turn count: got {len(p)}, expected {expected_turns}"
        if not all(isinstance(t, str) and t.strip() for t in p):
            return None, f"paraphrase[{i}] has empty or non-string turn"
    return paras, ""


def call_api(client, scenario: dict, verbose: bool = False, attempt: int = 0) -> list[list[str]] | None:
    from openai import APIConnectionError, InternalServerError, RateLimitError

    try:
        resp = client.chat.completions.create(
            model=MODEL,
            max_completion_tokens=2048,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": FEW_SHOT_USER},
                {"role": "assistant", "content": FEW_SHOT_ASSISTANT},
                {"role": "user", "content": build_user_message(scenario)},
            ],
        )
        text = resp.choices[0].message.content or ""
        paras, err = parse_response(text, scenario["scenario_id"], len(scenario["turns"]))
        if paras is None and verbose:
            print(f"  PARSE ERROR: {err}", file=sys.stderr)
            print(f"  RAW RESPONSE:\n{text}\n", file=sys.stderr)
        return paras
    except (RateLimitError, APIConnectionError, InternalServerError) as e:
        if attempt >= 3:
            raise
        delay = 2 ** attempt + random.random()
        print(f"  retry {attempt + 1} after {delay:.1f}s ({type(e).__name__})", file=sys.stderr)
        time.sleep(delay)
        return call_api(client, scenario, verbose, attempt + 1)


def jitter_vad(vad: list[float], rng: random.Random) -> list[float]:
    return [round(max(-1.0, min(1.0, v + rng.uniform(-0.05, 0.05))), 3) for v in vad]


# 격식체 endings that should NEVER appear in 비격식 존댓말 rows
FORMAL_ENDING_PATTERNS = [
    "습니다", "습니까", "입니다", "입니까", "겠습니다", "드립니다", "시오",
]


def validate_output(path: Path) -> None:
    """Post-run sanity check. Flags 비격식 존댓말 rows that drifted to 격식체."""
    issues = []
    with path.open() as f:
        for line in f:
            row = json.loads(line)
            if row.get("style") != "비격식 존댓말":
                continue
            for t in row["turns"]:
                text = t["text"].rstrip("?!.\u2026 ")
                for pat in FORMAL_ENDING_PATTERNS:
                    if text.endswith(pat):
                        issues.append((row["scenario_id"], row["paraphrase_idx"], t["text"]))
                        break

    if issues:
        print(f"\n⚠  {len(issues)} 비격식 존댓말 rows drifted to 격식체 endings:", file=sys.stderr)
        for sid, idx, text in issues[:10]:
            print(f"    {sid} (p{idx}): {text}", file=sys.stderr)
        if len(issues) > 10:
            print(f"    ... and {len(issues) - 10} more", file=sys.stderr)
        print(f"\n  Consider regenerating flagged rows or manually fixing endings.", file=sys.stderr)
    else:
        print(f"\n✓ Register validation passed: 0 register drift in 비격식 존댓말 rows")


def main():
    ap = argparse.ArgumentParser()
    grp = ap.add_mutually_exclusive_group(required=True)
    grp.add_argument("--stratified", type=int, metavar="N", help="Stratified sample of N scenarios")
    grp.add_argument("--random", type=int, metavar="N", help="Random sample of N scenarios")
    grp.add_argument("--full", action="store_true", help="Process all 847 scenarios")
    grp.add_argument("--scenario-id", type=str, metavar="ID", help="Process a single specific scenario by ID")
    ap.add_argument("--seed", type=int, default=42, help="RNG seed for sampling and VAD jitter")
    ap.add_argument("--output", type=Path, required=True, help="Output jsonl path")
    ap.add_argument("--dry-run", action="store_true", help="Print first request payload without calling API")
    ap.add_argument("--verbose", action="store_true", help="Print raw API response on parse failure")
    args = ap.parse_args()

    scenarios = load_scenarios(INPUT_PATH)
    n_total = len(scenarios)
    # Skip 1-turn scenarios (--scenario-id bypasses this for targeted debugging)
    if not args.scenario_id:
        scenarios = filter_multiturn(scenarios)
    n_multiturn = len(scenarios)

    if args.stratified:
        pending = stratified_sample(scenarios, args.stratified, args.seed)
    elif args.random:
        pending = random_sample(scenarios, args.random, args.seed)
    elif args.scenario_id:
        pending = [s for s in scenarios if s["scenario_id"] == args.scenario_id]
        if not pending:
            print(f"ERROR: scenario_id {args.scenario_id!r} not found", file=sys.stderr)
            sys.exit(1)
    else:
        pending = scenarios

    print(f"Loaded {n_total} scenarios ({n_total - n_multiturn} 1-turn skipped), will process {len(pending)}")
    print(f"Model: {MODEL} | {PARAPHRASES_PER_SCENARIO} paraphrases per scenario")
    print(f"Output: {args.output}")

    if args.dry_run:
        print("\n--- DRY RUN: first request ---")
        print("SYSTEM:")
        print(SYSTEM_PROMPT)
        print("\nUSER (few-shot):")
        print(FEW_SHOT_USER)
        print("\nASSISTANT (few-shot):")
        print(FEW_SHOT_ASSISTANT)
        print("\nUSER (actual):")
        print(build_user_message(pending[0]))
        print(f"\n(Would send {len(pending)} such requests)")
        return

    if not os.environ.get("OPENAI_API_KEY"):
        print("ERROR: OPENAI_API_KEY not set", file=sys.stderr)
        sys.exit(1)

    try:
        from openai import OpenAI
    except ImportError:
        print("ERROR: openai package not installed. Run: pip install openai", file=sys.stderr)
        sys.exit(1)

    client = OpenAI()
    rng = random.Random(args.seed)

    written = 0
    failed = 0
    t0 = time.time()

    with args.output.open("w") as fout:
        for i, src in enumerate(pending, 1):
            try:
                paras = call_api(client, src, verbose=args.verbose)
            except Exception as e:
                print(f"[{i}/{len(pending)}] {src['scenario_id']} FAILED: {type(e).__name__}: {e}", file=sys.stderr)
                failed += 1
                continue

            if not paras:
                print(f"[{i}/{len(pending)}] {src['scenario_id']} PARSE FAILED", file=sys.stderr)
                failed += 1
                continue

            for k, turn_texts in enumerate(paras):
                row = {
                    "scenario_id": f"{src['scenario_id']}_p{k}",
                    "source_scenario_id": src["scenario_id"],
                    "paraphrase_idx": k,
                    "setting": src.get("setting", ""),
                    "style": src["style"],
                    "turns": [
                        {
                            "speaker": t["speaker"],
                            "text": paraphrased,
                            "emotion": t["emotion"],
                            "vad": jitter_vad(t["vad"], rng),
                        }
                        for t, paraphrased in zip(src["turns"], turn_texts)
                    ],
                }
                fout.write(json.dumps(row, ensure_ascii=False) + "\n")
                written += 1
            fout.flush()
            print(f"[{i}/{len(pending)}] {src['scenario_id']} ok (+{PARAPHRASES_PER_SCENARIO} paraphrases)")

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s: {written} paraphrase rows written, {failed} scenarios failed")
    print(f"Output: {args.output}")

    validate_output(args.output)


if __name__ == "__main__":
    main()
