"""Clamp VAD values to each emotion's anchor bounding box.

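For each turn, any V/A/D dimension that falls outside its emotion's anchor
range (the L1-L5 anchor min/max, widened by the margin on both sides) is
clamped to the nearest boundary. This catches manually assigned VAD values
that contradict the anchor system. Results are written back to the JSONL
files in place unless --dry-run is given.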

Usage:
    python scripts/clamp_vad.py [--margin 0.10] [--dry-run]
"""
from __future__ import annotations

import argparse
import json
from collections import Counter
from pathlib import Path

# Relative paths, resolved against the current working directory.
DATA = Path("data", "emotion")
ANCHORS_PATH = DATA.joinpath("emotion_vad_anchors.json")
# Dataset splits to clamp; each name is joined onto DATA.
FILES = [
    "seed_train_final.jsonl",
    "seed_val.jsonl",
    "seed_test.jsonl",
]


def load_anchor_bounds(margin: float) -> dict:
    anchors = json.loads(ANCHORS_PATH.read_text(encoding="utf-8"))["anchors"]
    bounds = {}
    for emo, entries in anchors.items():
        vs = [e["vad"][0] for e in entries]
        as_ = [e["vad"][1] for e in entries]
        ds = [e["vad"][2] for e in entries]
        bounds[emo] = {
            "v": (min(vs) - margin, max(vs) + margin),
            "a": (min(as_) - margin, max(as_) + margin),
            "d": (min(ds) - margin, max(ds) + margin),
        }
    return bounds


def clamp(val: float, lo: float, hi: float) -> float:
    """Limit ``val`` to at most ``hi``, then to at least ``lo``.

    The lower bound is applied last, so it wins when ``lo > hi``.
    """
    # Cap from above: keep val only when strictly below hi.
    capped = val if val < hi else hi
    # Floor from below: keep capped only when strictly above lo.
    return capped if capped > lo else lo


_MAX_EXAMPLES = 20  # number of per-turn clamp descriptions printed at the end


def _read_jsonl(path: Path) -> list[dict]:
    """Parse ``path`` as JSON Lines, skipping blank lines (e.g. a trailing one)."""
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _write_jsonl_atomic(path: Path, records: list[dict]) -> None:
    """Write ``records`` to ``path`` as JSON Lines without risking truncation.

    The data is written to a sibling ``<name>.tmp`` file which is then renamed
    over ``path``. An interrupted write therefore leaves the original intact.
    """
    tmp = path.with_name(path.name + ".tmp")
    try:
        with tmp.open("w", encoding="utf-8") as f:
            f.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
        tmp.replace(path)
    except BaseException:
        # Remove the partial temp file, then propagate (incl. KeyboardInterrupt).
        if tmp.exists():
            tmp.unlink()
        raise


def _clamp_vad(vad, box: dict) -> tuple[list[float], list[str]]:
    """Clamp a ``[v, a, d]`` triple to ``box`` (keys ``"v"``, ``"a"``, ``"d"``).

    Returns ``(new_vad, changes)``. ``new_vad`` has every dimension clamped
    and rounded to 3 decimals. ``changes`` holds one ``"V:+0.80→+0.60"``-style
    note per dimension whose rounded value moved, and is empty when the triple
    already lies inside the box.
    """
    if len(vad) != 3:
        raise ValueError(f"expected 3 VAD values, got {len(vad)}: {vad!r}")
    new_vad = []
    changes = []
    for label, value in zip("VAD", vad):
        lo, hi = box[label.lower()]
        clamped = round(clamp(value, lo, hi), 3)
        if clamped != round(value, 3):
            changes.append(f"{label}:{value:+.2f}→{clamped:+.2f}")
        new_vad.append(clamped)
    return new_vad, changes


def main():
    """Clamp every turn's VAD in FILES to its emotion's anchor box and report.

    All files are parsed and clamped before any is written. A malformed
    record in a later split therefore aborts the run without modifying
    earlier files. Files with no clamped turns are left untouched, and with
    ``--dry-run`` no file is written at all.
    """
    ap = argparse.ArgumentParser(
        description="Clamp VAD values to each emotion's anchor bounding box."
    )
    ap.add_argument("--margin", type=float, default=0.10,
                    help="slack added around each emotion's anchor min/max")
    ap.add_argument("--dry-run", action="store_true",
                    help="report what would change without writing files")
    args = ap.parse_args()

    bounds = load_anchor_bounds(args.margin)
    total_clamped = 0
    total_turns = 0
    clamped_by_emo = Counter()
    skipped_by_emo = Counter()  # turns whose emotion has no anchor box
    examples = []
    results = []  # (filename, path, records, n_clamped); written after all parse

    for fn in FILES:
        path = DATA / fn
        records = _read_jsonl(path)
        file_clamped = 0
        for r in records:
            for i, t in enumerate(r["turns"]):
                emo = t["emotion"]
                box = bounds.get(emo)
                if box is None:
                    skipped_by_emo[emo] += 1
                    continue
                total_turns += 1
                new_vad, changes = _clamp_vad(t["vad"], box)
                if not changes:
                    continue
                if len(examples) < _MAX_EXAMPLES:
                    examples.append(
                        f"  {r['scenario_id']} t{i} {emo}: {' '.join(changes)}"
                    )
                t["vad"] = new_vad
                file_clamped += 1
                clamped_by_emo[emo] += 1
        total_clamped += file_clamped
        results.append((fn, path, records, file_clamped))

    for fn, path, records, file_clamped in results:
        # Skip no-op rewrites so untouched files keep their exact bytes.
        if file_clamped and not args.dry_run:
            _write_jsonl_atomic(path, records)
        print(f"{fn}: {file_clamped} turns clamped")

    print(f"\nTotal: {total_clamped}/{total_turns} turns clamped "
          f"(margin={args.margin})")
    print("\nBy emotion:")
    for emo, n in clamped_by_emo.most_common():
        print(f"  {emo:<12s} {n}")
    if skipped_by_emo:
        skipped = ", ".join(f"{e}={n}" for e, n in skipped_by_emo.most_common())
        print(f"\nSkipped {sum(skipped_by_emo.values())} turns with no anchors: "
              f"{skipped}")
    if examples:
        print(f"\nFirst {len(examples)} examples:")
        for ex in examples:
            print(ex)
    if args.dry_run:
        print("\n[DRY RUN — no files modified]")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
