"""Restore original VAD values from bak3 backups and apply smart re-clamping.

Steps:
1. Load all bak3 files into a unified lookup dict keyed by (scenario_id, turn_index)
2. For each turn in current data:
   - If the emotion label was CHANGED (differs from bak3): keep current VAD
   - If the emotion label is unchanged: restore VAD from bak3
3. Special handling for gratitude turns with unchanged labels:
   - Restore original V, A from bak3
   - Apply D offset of -0.10 to shift toward new v1.5 anchor range
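4. Smart re-clamp to each emotion's v1.5 anchor range widened by a margin
   (default 0.30, set via --clamp-margin) so that only true outliers are clamped
5. Validation: print stats and count out-of-bounds (OOB) VAD dimensions at the
   same margin (expected to be 0)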

Usage:
    python scripts/restore_and_smart_clamp_vad.py --dry-run                 # Preview only
    python scripts/restore_and_smart_clamp_vad.py                           # Apply changes
    python scripts/restore_and_smart_clamp_vad.py --clamp-margin 0.25       # Custom clamp margin
"""
from __future__ import annotations

import argparse
import json
import statistics
from collections import Counter, defaultdict
from pathlib import Path

DATA = Path("data") / "emotion"
ANCHORS_PATH = DATA / "emotion_vad_anchors.json"

# Working data files (rewritten in place) and their ".bak3" snapshots, which
# are the source of the original VAD values to restore.
CURRENT_FILES = ["seed_train_final.jsonl", "seed_val.jsonl", "seed_test.jsonl"]
BAK3_FILES = [f"{name}.bak3" for name in CURRENT_FILES]

# Dominance shift applied to restored gratitude turns (v1.4 -> v1.5 anchors):
# v1.5 moved gratitude D down by roughly 0.10 at every level to separate it from joy.
GRATITUDE_D_OFFSET = -0.10


def load_bak3_lookup(data_dir: Path | None = None, filenames: list[str] | None = None) -> dict:
    """Load the bak3 backups into a dict keyed by ``(scenario_id, turn_index)``.

    Each value is ``{"emotion": <label>, "vad": [v, a, d]}`` copied from the
    backup turn (``vad`` is a fresh list). Missing backup files are reported and
    skipped. Blank lines in the JSONL files are ignored.

    Args:
        data_dir: Directory holding the backup files (defaults to ``DATA``).
        filenames: Backup file names to read (defaults to ``BAK3_FILES``).

    Returns:
        The lookup dict. If a key occurs more than once, the last occurrence
        wins and a warning with the duplicate count is printed.
    """
    if data_dir is None:
        data_dir = DATA
    if filenames is None:
        filenames = BAK3_FILES

    lookup = {}
    total_turns = 0
    files_loaded = 0
    duplicates = 0
    for fn in filenames:
        path = data_dir / fn
        if not path.exists():
            print(f"WARNING: {fn} does not exist!")
            continue
        files_loaded += 1
        with path.open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines instead of crashing json.loads
                r = json.loads(line)
                sid = r["scenario_id"]
                for i, t in enumerate(r["turns"]):
                    key = (sid, i)
                    if key in lookup:
                        duplicates += 1
                    lookup[key] = {
                        "emotion": t["emotion"],
                        "vad": list(t["vad"]),
                    }
                    total_turns += 1
    if duplicates:
        print(
            f"WARNING: {duplicates} duplicate (scenario_id, turn_index) keys in bak3; "
            "last occurrence kept"
        )
    # Report the files actually read, not the number configured.
    print(f"[bak3] Loaded {total_turns} turns from {files_loaded} backup files")
    return lookup


def load_current_data(data_dir: Path | None = None, filenames: list[str] | None = None) -> dict:
    """Load the current data files into a dict: filename -> list of scenario dicts.

    Each non-blank JSONL line becomes one scenario dict. Blank lines are
    skipped rather than crashing ``json.loads``.

    Args:
        data_dir: Directory holding the data files (defaults to ``DATA``).
        filenames: Data file names to read (defaults to ``CURRENT_FILES``).

    Raises:
        FileNotFoundError: If any of the data files is missing.
    """
    if data_dir is None:
        data_dir = DATA
    if filenames is None:
        filenames = CURRENT_FILES

    data = {}
    for fn in filenames:
        with (data_dir / fn).open(encoding="utf-8") as f:
            data[fn] = [json.loads(line) for line in f if line.strip()]
    return data


def load_anchor_bounds(margin: float, anchors_path: Path | None = None) -> dict:
    """Load v1.5 per-emotion VAD bounds, widened by ``margin`` on every side.

    The file must contain ``{"anchors": {emotion: [{"vad": [v, a, d], ...}, ...]}}``.
    For each emotion and each dimension, the bound is the min/max over that
    emotion's anchor entries, extended by ``margin``.

    Args:
        margin: Amount subtracted from each minimum and added to each maximum.
        anchors_path: Anchors JSON file (defaults to ``ANCHORS_PATH``).

    Returns:
        ``{emotion: {"v": (lo, hi), "a": (lo, hi), "d": (lo, hi)}}``. An emotion
        with no anchor entries is left out (with a warning) instead of raising
        an opaque ``min()`` error.
    """
    if anchors_path is None:
        anchors_path = ANCHORS_PATH
    anchors = json.loads(anchors_path.read_text(encoding="utf-8"))["anchors"]
    bounds = {}
    for emo, entries in anchors.items():
        if not entries:
            print(f"WARNING: no anchors for {emo!r}; it will not be clamped")
            continue
        vs = [e["vad"][0] for e in entries]
        as_ = [e["vad"][1] for e in entries]
        ds = [e["vad"][2] for e in entries]
        bounds[emo] = {
            "v": (min(vs) - margin, max(vs) + margin),
            "a": (min(as_) - margin, max(as_) + margin),
            "d": (min(ds) - margin, max(ds) + margin),
        }
    return bounds


def clamp(val: float, lo: float, hi: float) -> float:
    """Bound ``val`` to the interval ``[lo, hi]``.

    The upper cap is applied first, so if ``lo > hi`` the result is ``lo``.
    """
    capped_above = min(hi, val)
    return max(lo, capped_above)


def shift_gratitude_d(old_d: float) -> float:
    """Return a gratitude dominance value moved by ``GRATITUDE_D_OFFSET``.

    A constant offset (instead of a rescale) keeps the spread of the original
    values while moving the whole distribution lower (humility/indebtedness).
    The result is rounded to 3 decimals.
    """
    shifted = old_d + GRATITUDE_D_OFFSET
    return round(shifted, 3)


def _load_label_corrections(path: Path) -> dict:
    """Read label corrections into ``{(scenario_id, turn_index): {"old": ..., "new": ...}}``.

    Each non-blank line of ``path`` is a JSON object with ``scenario_id``,
    ``turn_index``, ``old`` and ``new`` keys. A missing file yields ``{}``.
    """
    label_changes = {}
    if not path.exists():
        return label_changes
    with path.open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            c = json.loads(line)
            label_changes[(c["scenario_id"], c["turn_index"])] = {"old": c["old"], "new": c["new"]}
    return label_changes


def _mean_std(values: list[float]) -> tuple[float, float]:
    """Return ``(mean, population std dev)`` of ``values``; std is 0.0 for a single value."""
    mean = statistics.mean(values)
    std = statistics.pstdev(values) if len(values) > 1 else 0.0
    return mean, std


def _count_oob(current: dict, bounds: dict, tol: float = 0.001) -> int:
    """Count VAD dimensions lying outside their emotion's bounds by more than ``tol``.

    ``current`` maps file name -> list of scenarios, each with a ``turns`` list.
    Turns whose emotion has no entry in ``bounds`` are not checked. The small
    tolerance presumably absorbs the 3-decimal rounding applied after clamping.
    """
    oob = 0
    for scenarios in current.values():
        for scenario in scenarios:
            for turn in scenario["turns"]:
                b = bounds.get(turn["emotion"])
                if b is None:
                    continue
                v, a, d = turn["vad"]
                for val, (lo, hi) in ((v, b["v"]), (a, b["a"]), (d, b["d"])):
                    if val < lo - tol or val > hi + tol:
                        oob += 1
    return oob


def _backup_and_write(current: dict) -> None:
    """Back up each current file to ``<name>.bak4``, then rewrite it from ``current``.

    The backup is a byte-for-byte copy. An existing ``.bak4`` is left untouched,
    so re-running the script cannot replace the pre-restoration snapshot with
    already-restored data. Each file is first written to a ``.tmp`` sibling and
    then moved into place, so an interrupted write never truncates a data file.
    """
    for fn in CURRENT_FILES:
        path = DATA / fn
        bak_path = DATA / (fn + ".bak4")
        if bak_path.exists():
            print(f"  Backup: {fn}.bak4 already exists, keeping it")
            continue
        bak_path.write_bytes(path.read_bytes())
        print(f"  Backup: {fn} -> {fn}.bak4")

    for fn, scenarios in current.items():
        path = DATA / fn
        tmp_path = DATA / (fn + ".tmp")
        with tmp_path.open("w", encoding="utf-8") as f:
            for s in scenarios:
                f.write(json.dumps(s, ensure_ascii=False) + "\n")
        tmp_path.replace(path)
        print(f"  Written: {fn}")


def main():
    """Restore bak3 VAD values, re-clamp to widened v1.5 anchor bounds, and report.

    Command-line options:
        --dry-run       compute and print everything without touching any file.
        --clamp-margin  margin added around each emotion's anchor min/max
                        before clamping (default 0.30).

    Unless ``--dry-run`` is given, each current file is backed up to ``.bak4``
    (an existing ``.bak4`` is kept) and then rewritten.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true", help="Preview without modifying files")
    ap.add_argument("--clamp-margin", type=float, default=0.30, help="Smart clamp margin (default 0.30)")
    args = ap.parse_args()

    print("=" * 70)
    print("VAD Restoration & Smart Re-clamping")
    print("=" * 70)

    # Step 0: Load everything
    bak3 = load_bak3_lookup()
    current = load_current_data()
    bounds = load_anchor_bounds(args.clamp_margin)

    # Only the count is reported. Whether a turn's label changed is decided
    # below by comparing the current label with the bak3 label.
    label_changes = _load_label_corrections(DATA / "label_corrections.jsonl")
    print(f"[corrections] Loaded {len(label_changes)} label corrections")

    # Stats tracking
    stats = {
        "total_turns": 0,
        "restored_from_bak3": 0,
        "kept_current_label_changed": 0,
        "kept_current_no_bak3": 0,
        "gratitude_d_remapped": 0,
        "smart_clamped": 0,
        "unchanged": 0,
    }
    restored_by_emo = Counter()
    clamp_by_emo = Counter()
    clamp_examples = []
    restore_examples = []
    gratitude_examples = []
    no_bak3_turns = []

    # Steps 1-3: restore VAD from bak3 for every turn whose label is unchanged
    for scenarios in current.values():
        for scenario in scenarios:
            sid = scenario["scenario_id"]
            for ti, turn in enumerate(scenario["turns"]):
                stats["total_turns"] += 1
                current_emo = turn["emotion"]
                old_vad = list(turn["vad"])

                bak3_entry = bak3.get((sid, ti))
                if bak3_entry is None:
                    stats["kept_current_no_bak3"] += 1
                    if len(no_bak3_turns) < 10:
                        no_bak3_turns.append(f"  {sid} t{ti} ({current_emo})")
                    continue

                if current_emo != bak3_entry["emotion"]:
                    # Label changed -> keep current VAD (it was regenerated for the new label)
                    stats["kept_current_label_changed"] += 1
                    continue

                bak3_vad = bak3_entry["vad"]
                stats["restored_from_bak3"] += 1
                restored_by_emo[current_emo] += 1

                if current_emo == "gratitude":
                    # Restore V, A from bak3; shift D toward the v1.5 gratitude anchors
                    new_v, new_a = bak3_vad[0], bak3_vad[1]
                    new_d = shift_gratitude_d(bak3_vad[2])
                    turn["vad"] = [round(new_v, 3), round(new_a, 3), new_d]
                    stats["gratitude_d_remapped"] += 1
                    if len(gratitude_examples) < 10:
                        gratitude_examples.append(
                            f"  {sid} t{ti}: D {bak3_vad[2]:+.3f} (bak3) -> {new_d:+.3f} (remapped), "
                            f"V {old_vad[0]:+.3f}->{new_v:+.3f}, A {old_vad[1]:+.3f}->{new_a:+.3f}"
                        )
                    continue

                # Any other emotion: straight restore
                turn["vad"] = [round(bak3_vad[0], 3), round(bak3_vad[1], 3), round(bak3_vad[2], 3)]
                if old_vad == turn["vad"]:
                    stats["unchanged"] += 1
                elif len(restore_examples) < 15:
                    new_vad = turn["vad"]
                    restore_examples.append(
                        f"  {sid} t{ti} {current_emo}: "
                        f"[{old_vad[0]:+.3f},{old_vad[1]:+.3f},{old_vad[2]:+.3f}] -> "
                        f"[{new_vad[0]:+.3f},{new_vad[1]:+.3f},{new_vad[2]:+.3f}]"
                    )

    # Step 4: Smart re-clamping against the margin-widened anchor bounds
    print(f"\n{'=' * 70}")
    print(f"Step 4: Smart Re-clamping (margin={args.clamp_margin})")
    print(f"{'=' * 70}")

    for fn, scenarios in current.items():
        file_clamped = 0
        for scenario in scenarios:
            for ti, turn in enumerate(scenario["turns"]):
                emo = turn["emotion"]
                b = bounds.get(emo)
                if b is None:
                    continue
                v, a, d = turn["vad"]
                nv = round(clamp(v, *b["v"]), 3)
                na = round(clamp(a, *b["a"]), 3)
                nd = round(clamp(d, *b["d"]), 3)
                # A dimension counts as clamped only if it differs after rounding both sides.
                dims = [
                    f"{label}:{old:+.3f}->{new:+.3f}"
                    for label, old, new in (("V", v, nv), ("A", a, na), ("D", d, nd))
                    if new != round(old, 3)
                ]
                if not dims:
                    continue
                if len(clamp_examples) < 30:
                    clamp_examples.append(
                        f"  {scenario['scenario_id']} t{ti} {emo}: {' '.join(dims)}"
                    )
                turn["vad"] = [nv, na, nd]
                file_clamped += 1
                stats["smart_clamped"] += 1
                clamp_by_emo[emo] += 1
        print(f"  {fn}: {file_clamped} turns clamped")

    # Step 5: Write results (if not dry-run)
    if not args.dry_run:
        _backup_and_write(current)

    # Report
    print(f"\n{'=' * 70}")
    print("REPORT")
    print(f"{'=' * 70}")
    print(f"\nTotal turns processed: {stats['total_turns']}")
    print(f"VAD restored from bak3: {stats['restored_from_bak3']}")
    print(f"  - gratitude D remapped: {stats['gratitude_d_remapped']}")
    print(f"  - unchanged (already same): {stats['unchanged']}")
    print(f"Kept current VAD (label changed): {stats['kept_current_label_changed']}")
    print(f"Kept current VAD (no bak3 match): {stats['kept_current_no_bak3']}")
    print(f"Smart clamped (margin={args.clamp_margin}): {stats['smart_clamped']}")

    print("\nRestored by emotion:")
    for emo, n in restored_by_emo.most_common():
        print(f"  {emo:<12s} {n}")

    if restore_examples:
        print("\nSample VAD restorations (first 15 where value changed):")
        for ex in restore_examples:
            print(ex)

    if gratitude_examples:
        print("\nSample gratitude D remappings:")
        for ex in gratitude_examples:
            print(ex)

    if no_bak3_turns:
        print("\nTurns not found in bak3 (first 10):")
        for ex in no_bak3_turns:
            print(ex)

    if stats["smart_clamped"] > 0:
        print(f"\nSmart clamp by emotion (margin={args.clamp_margin}):")
        for emo, n in clamp_by_emo.most_common():
            print(f"  {emo:<12s} {n}")
        # Only the first 30 clamp examples are collected above.
        print("\nClamped examples (first 30):")
        for ex in clamp_examples:
            print(ex)
    else:
        print(f"\nNo turns needed clamping at margin={args.clamp_margin}")

    # Validation: count dimensions still outside the bounds
    print(f"\n{'=' * 70}")
    print(f"Validation: OOB check at margin={args.clamp_margin}")
    print(f"{'=' * 70}")
    print(f"OOB dimensions: {_count_oob(current, bounds)}")

    # VAD distribution stats per emotion
    print(f"\n{'=' * 70}")
    print("VAD Distribution Stats (per emotion, after restoration)")
    print(f"{'=' * 70}")
    emo_vads = defaultdict(lambda: {"v": [], "a": [], "d": []})
    for scenarios in current.values():
        for scenario in scenarios:
            for turn in scenario["turns"]:
                v, a, d = turn["vad"]
                vals = emo_vads[turn["emotion"]]
                vals["v"].append(v)
                vals["a"].append(a)
                vals["d"].append(d)

    print(f"{'Emotion':<12s} {'N':>5s}  {'V_mean':>7s} {'V_std':>6s}  {'A_mean':>7s} {'A_std':>6s}  {'D_mean':>7s} {'D_std':>6s}")
    print("-" * 75)
    for emo in sorted(emo_vads):
        vals = emo_vads[emo]
        n = len(vals["v"])
        vm, v_std = _mean_std(vals["v"])
        am, a_std = _mean_std(vals["a"])
        dm, d_std = _mean_std(vals["d"])
        print(f"{emo:<12s} {n:>5d}  {vm:>+7.3f} {v_std:>6.3f}  {am:>+7.3f} {a_std:>6.3f}  {dm:>+7.3f} {d_std:>6.3f}")

    if args.dry_run:
        print("\n[DRY RUN -- no files modified]")
    else:
        print("\nFiles updated successfully. Backups saved as .bak4")


if __name__ == "__main__":
    main()
