"""
Apply a batch of paraphrases to seed_augmented.jsonl.

Usage:
    python3 apply_paraphrases.py <start_idx> <batch_json_file>

batch_json_file contains a list-of-lists of paraphrase strings, one entry per
consecutive train turn starting at <start_idx>:

    [
      ["para1", "para2", "para3", "para4", "para5"],  # turn start_idx
      ["para1", "para2", "para3", "para4", "para5"],  # turn start_idx + 1
      ...
    ]

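The script looks up each source turn in seed_train.jsonl (turns are numbered
consecutively across all scenarios, in file order), copies its metadata,
jitters each VAD component by up to ±0.05 (clamped to [-1, 1]), and appends
one row per non-empty paraphrase to seed_augmented.jsonl.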
"""

import json
import random
import sys
from pathlib import Path

# Input and output data files live next to this script.
DATA_DIR = Path(__file__).parent
TRAIN_PATH = DATA_DIR.joinpath("seed_train.jsonl")  # source scenarios, read-only
OUT_PATH = DATA_DIR.joinpath("seed_augmented.jsonl")  # augmented rows, appended to

# Fixed seed: every invocation of the script starts from the same jitter sequence.
RNG = random.Random(42)


def load_flat_turns(path=None):
    """Flatten the turns of every scenario in a JSONL file into one list.

    Each non-blank line of the file is one scenario; its ``turns`` are emitted
    in order, so a turn's position in the returned list is its global turn
    index across all scenarios.

    Args:
        path: JSONL file to read (str or Path). Defaults to ``TRAIN_PATH``
            when None.

    Returns:
        A list of dicts, one per turn, with keys ``scenario_id``,
        ``turn_idx`` (position within its own scenario), ``setting`` and
        ``style`` (empty string when the scenario lacks them), and the turn's
        ``speaker``, ``text``, ``emotion`` and ``vad``.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
        KeyError: a scenario lacks ``scenario_id``/``turns``, or a turn lacks
            ``speaker``/``text``/``emotion``/``vad``.
    """
    if path is None:
        path = TRAIN_PATH
    out = []
    # Explicit UTF-8: the data may contain non-ASCII text, and the platform
    # default encoding is not UTF-8 everywhere.
    with Path(path).open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line) in the JSONL.
                continue
            s = json.loads(line)
            for i, t in enumerate(s["turns"]):
                out.append(
                    {
                        "scenario_id": s["scenario_id"],
                        "turn_idx": i,
                        "setting": s.get("setting", ""),
                        "style": s.get("style", ""),
                        "speaker": t["speaker"],
                        "text": t["text"],
                        "emotion": t["emotion"],
                        "vad": t["vad"],
                    }
                )
    return out


def jitter(v, *, spread=0.05, rng=None):
    """Return *v* plus uniform noise, clamped to [-1, 1] and rounded to 3 dp.

    Args:
        v: The value to perturb (one VAD component).
        spread: Half-width of the uniform noise; the offset is drawn from
            ``[-spread, spread]``. Defaults to 0.05.
        rng: ``random.Random`` instance to draw from. Defaults to the
            module-level ``RNG``.

    Returns:
        The jittered value as a float rounded to 3 decimal places.
    """
    if rng is None:
        rng = RNG
    noisy = v + rng.uniform(-spread, spread)
    # Clamp before rounding so the result can never leave [-1, 1].
    return round(max(-1.0, min(1.0, noisy)), 3)


def _fail(msg):
    """Print *msg* to stderr and exit with status 1."""
    print(msg, file=sys.stderr)
    sys.exit(1)


def _make_row(src, k, text):
    """Build one augmented row for paraphrase *k* (*text*) of source turn *src*.

    The row copies the source turn's metadata and emotion, and jitters every
    VAD component independently.
    """
    return {
        "source_scenario_id": src["scenario_id"],
        "source_turn_idx": src["turn_idx"],
        "paraphrase_idx": k,
        "setting": src["setting"],
        "style": src["style"],
        "speaker": src["speaker"],
        "text": text,
        "emotion": src["emotion"],
        "vad": [jitter(v) for v in src["vad"]],
    }


def main():
    """Append jittered paraphrase rows for one batch of flattened train turns.

    Reads ``<start_idx>`` and ``<batch_json_file>`` from ``sys.argv``. Entry i
    of the batch (a list of strings) is paired with flattened train turn
    ``start_idx + i``, and one row is appended to ``OUT_PATH`` per non-empty
    string. Entries that are not lists are skipped with a warning.

    Argument and range errors exit with status 1 before anything is written.
    All rows are built before ``OUT_PATH`` is opened, so an exception while
    building them leaves the output untouched.

    Note: rows are appended, so re-running the same batch duplicates them.
    """
    if len(sys.argv) != 3:
        _fail("usage: apply_paraphrases.py <start_idx> <batch_json>")

    try:
        start = int(sys.argv[1])
    except ValueError:
        _fail(f"ERROR: start_idx must be an integer, got {sys.argv[1]!r}")
    if start < 0:
        # A negative index would silently wrap around to the end of `turns`.
        _fail(f"ERROR: start_idx must be >= 0, got {start}")

    batch_file = Path(sys.argv[2])
    with batch_file.open(encoding="utf-8") as f:
        batch = json.load(f)
    if not isinstance(batch, list):
        _fail(f"ERROR: {batch_file} must contain a JSON list of paraphrase lists")

    turns = load_flat_turns()
    end = start + len(batch)
    if end > len(turns):
        _fail(f"ERROR: batch goes to {end} but only {len(turns)} train turns exist")

    lines = []
    for offset, paraphrases in enumerate(batch):
        turn_no = start + offset
        if not isinstance(paraphrases, list):
            print(f"WARN: turn {turn_no} paraphrases not a list, skipped", file=sys.stderr)
            continue
        src = turns[turn_no]
        for k, text in enumerate(paraphrases):
            # Skip non-strings and blank strings, but keep k unchanged so
            # paraphrase_idx still matches the position in the batch entry.
            if not isinstance(text, str) or not text.strip():
                continue
            lines.append(json.dumps(_make_row(src, k, text), ensure_ascii=False) + "\n")

    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters.
    with OUT_PATH.open("a", encoding="utf-8") as fout:
        fout.writelines(lines)

    span = f"{start}..{end - 1}" if batch else "none"
    print(
        f"OK: +{len(lines)} rows | source turns {span} done | "
        f"total progress {end}/{len(turns)} turns"
    )


if __name__ == "__main__":
    main()
