"""Stratified 80/10/10 re-split of emotion dialogue scenarios.

Why this exists:
  The current splits are scenario-disjoint (no leakage) but NOT stratified.
  Train is 97% multi-turn; val is 69% solo. Neutral class is 13.7% in train
  but 6% in val. Macro_f1 measurements become unreliable because test-time
  distribution differs systematically from training distribution.

Stratification strategy:
  Primary key:   scenario type bucket (solo / short-dialogue / long-dialogue)
  Secondary:     per-turn emotion counts of the original (how often each of the 16 emotions appears)
  Unit of split: SOURCE GROUP (original scenario + all its paraphrases stay together)

Augmentation rule:
  Train:      original + all paraphrased variants
  Val/Test:   ORIGINAL ONLY  (paraphrases inflate val/test and leak augmentation signal)

Run:
  python scripts/stratified_resplit.py
"""
from __future__ import annotations

import json
import random
import shutil
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple

# Directory containing the split files. The path is relative, so it resolves
# against the current working directory of the process.
DATA = Path("data/emotion")
# Files that are read and pooled before re-splitting.
INPUT_FILES = ["seed_train_final.jsonl", "seed_val.jsonl", "seed_test.jsonl"]
# Split name -> file written for that split. These are the same file names as
# INPUT_FILES, i.e. the inputs are overwritten in place.
OUTPUT_FILES = {
    "train": "seed_train_final.jsonl",
    "val": "seed_val.jsonl",
    "test": "seed_test.jsonl",
}
# Target fraction of source groups per split (applied within each bucket).
RATIOS = {"train": 0.80, "val": 0.10, "test": 0.10}
# Seed passed to random.seed() before the per-bucket shuffles.
SEED = 42
# Suffix appended to an input file name to form its backup copy.
BACKUP_SUFFIX = ".bak3"


def bucket_of(scenario: dict) -> str:
    """Classify a scenario by how many of its turns carry non-blank text.

    Returns "solo" for at most one such turn, "short" for two or three,
    and "long" for four or more. Turns whose text is empty or whitespace
    only are not counted.
    """
    spoken = [turn for turn in scenario["turns"] if turn["text"].strip()]
    size = len(spoken)
    if size > 3:
        return "long"
    return "short" if size > 1 else "solo"


def emotions_of(scenario: dict) -> Set[str]:
    """Return the distinct emotion labels found on turns with non-blank text."""
    labels: Set[str] = set()
    for turn in scenario["turns"]:
        if not turn["text"].strip():
            # Blank turns do not contribute a label.
            continue
        labels.add(turn["emotion"])
    return labels


def emotion_counts_of(scenario: dict) -> Counter:
    """Tally emotion labels over turns with non-blank text.

    Example result: Counter({'joy': 2, 'anger': 1}).
    """
    return Counter(
        turn["emotion"]
        for turn in scenario["turns"]
        if turn["text"].strip()
    )


def turn_count(scenario: dict) -> int:
    """Return how many turns of the scenario have non-blank text."""
    total = 0
    for turn in scenario["turns"]:
        if turn["text"].strip():
            total += 1
    return total


def is_original(scenario: dict) -> bool:
    """Tell whether a scenario is an original rather than a paraphrase.

    A missing or None ``paraphrase_idx`` (val/test convention) and a value of
    -1 (train convention) both mark an original; any other value does not.
    """
    marker = scenario.get("paraphrase_idx")
    if marker is None:
        return True
    return marker == -1


def load_all() -> List[dict]:
    """Read every scenario from all INPUT_FILES under DATA, in file order.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (for example a stray empty line at the end of a
    file) are skipped instead of aborting the load with a JSON decode error.

    Returns:
        The parsed records, concatenated in the order of INPUT_FILES.

    Raises:
        OSError: if an input file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    out = []
    for fn in INPUT_FILES:
        path = DATA / fn
        with path.open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # json.loads("") raises; an empty line carries no record.
                    continue
                out.append(json.loads(line))
    return out


def group_by_source(scenarios: List[dict]) -> Dict[str, dict]:
    """Bucket scenarios by their source id into original + augmentations.

    The source id is ``source_scenario_id`` when present and truthy, otherwise
    the scenario's own ``scenario_id``. Each group is a dict with the keys
    "original" (one scenario) and "augmentations" (list of scenarios).

    Raises:
        RuntimeError: if any source id ends up without an original.
    """
    by_source: Dict[str, dict] = {}
    for scenario in scenarios:
        key = scenario.get("source_scenario_id") or scenario["scenario_id"]
        original_flag = is_original(scenario)
        entry = by_source.setdefault(key, {"original": None, "augmentations": []})
        if original_flag and entry["original"] is None:
            entry["original"] = scenario
        else:
            # Paraphrases land here, and so does any unexpected second
            # original for the same source: the first original is kept.
            entry["augmentations"].append(scenario)

    # Guard: every source must have an original
    missing = [key for key, entry in by_source.items() if entry["original"] is None]
    if missing:
        raise RuntimeError(f"{len(missing)} source groups have no original: {missing[:5]}")
    return by_source


def annotate_groups(groups: Dict[str, dict]) -> None:
    """Attach stratification features to every group, in place.

    The features are derived from the group's original scenario only:
    "bucket", "emotion_set" and "emotion_counts".
    """
    for group in groups.values():
        source = group["original"]
        group["bucket"] = bucket_of(source)
        group["emotion_set"] = emotions_of(source)
        group["emotion_counts"] = emotion_counts_of(source)


def iterative_stratify(
    items: List[Tuple[str, dict]],
    ratios: Dict[str, float],
) -> Dict[str, str]:
    """Greedy count-based stratification (Sechidis et al. 2011, weighted variant).

    Uses per-turn emotion COUNTS (not just presence) so a 4-joy scenario is
    weighted 4x a single-joy scenario. This prevents rare emotions like anger
    from vanishing into one split when they appear as minority turns inside
    otherwise-dominant scenarios.

    Args:
        items:  list of (source_id, group_dict) within a single bucket; each
                group_dict must provide an "emotion_counts" Counter.
        ratios: target fraction per split, keyed by split name. Any split
                names are accepted; the names need not include "train".

    Returns:
        dict mapping source_id -> split name. Empty when ``items`` is empty.

    Raises:
        ValueError: if ``items`` is non-empty but ``ratios`` is empty.
    """
    n = len(items)
    if n == 0:
        return {}
    if not ratios:
        raise ValueError("ratios must name at least one split")

    # Per-split capacity, in groups. Rounding each share independently can make
    # the capacities sum to slightly more or less than n; one split absorbs the
    # difference. "train" is used when present (the historical behaviour);
    # otherwise the split with the largest ratio takes it, so ratio mappings
    # with other split names no longer fail with a KeyError.
    target = {s: int(round(n * r)) for s, r in ratios.items()}
    diff = n - sum(target.values())
    if diff != 0:
        absorber = "train" if "train" in ratios else max(ratios, key=ratios.get)
        target[absorber] += diff

    # Aggregate turn-level emotion counts across all items in this bucket
    total_emo_counts: Counter = Counter()
    for _, g in items:
        total_emo_counts.update(g["emotion_counts"])

    # Rarest-first: scenarios containing the rarest emotion are placed first
    def rarity_score(g: dict) -> Tuple[int, int]:
        if not g["emotion_counts"]:
            # Groups without any labelled turn sort after everything else.
            return (10**9, 0)
        rarest = min(total_emo_counts[e] for e in g["emotion_counts"])
        # Secondary: scenarios with more rare-emotion turns placed earlier
        rare_turns = sum(c for e, c in g["emotion_counts"].items() if total_emo_counts[e] == rarest)
        return (rarest, -rare_turns)

    # The source id is the final sort key, so the processing order is fully
    # determined by the data and does not depend on the incoming item order.
    ordered = sorted(items, key=lambda sg: (rarity_score(sg[1]), sg[0]))

    emo_count = {s: Counter() for s in ratios}   # turns per emotion placed so far
    assigned_count = {s: 0 for s in ratios}      # groups placed so far
    assignment: Dict[str, str] = {}

    for src, g in ordered:
        best_split = None
        best_score = float("-inf")
        for split, r in ratios.items():
            if assigned_count[split] >= target[split]:
                # Split is at capacity.
                continue
            # Deficit weighted by turn count per emotion
            deficit = 0.0
            for e, c in g["emotion_counts"].items():
                target_e = total_emo_counts[e] * r
                current_e = emo_count[split][e]
                room = max(0.0, target_e - current_e)
                # Each turn contributes one unit of "want" up to room
                deficit += min(float(c), room)
            # Small tiebreak on remaining capacity to fill splits evenly
            capacity_left = target[split] - assigned_count[split]
            score = deficit + 1e-6 * capacity_left
            # Strict comparison: on an exact tie the split listed first in
            # ``ratios`` wins.
            if score > best_score:
                best_score = score
                best_split = split

        if best_split is None:
            # Every split is at capacity; fall back to the least over-full one.
            best_split = max(ratios.keys(), key=lambda s: target[s] - assigned_count[s])

        assignment[src] = best_split
        assigned_count[best_split] += 1
        for e, c in g["emotion_counts"].items():
            emo_count[best_split][e] += c

    return assignment


def report_distribution(scenarios_per_split: Dict[str, List[dict]]) -> None:
    """Print two summary tables for the given split -> scenarios mapping.

    The first table lists, for train/val/test, the number of scenarios, the
    number of non-blank turns and the solo/short/long bucket counts. The
    second lists, per emotion, its share of non-blank turns in each of those
    three splits as a percentage.
    """
    split_order = ("train", "val", "test")
    rule = "=" * 72

    print("\n" + rule)
    print("POST-SPLIT DISTRIBUTION")
    print(rule)

    # Row labels of the emotion table: every emotion seen in any split.
    seen: Set[str] = set()
    for members in scenarios_per_split.values():
        for scenario in members:
            seen.update(emotions_of(scenario))
    emotion_rows = sorted(seen)

    # Scenario-level bucket breakdown
    print(f"\n{'split':<8}{'scenarios':>12}{'turns':>10}{'solo':>8}{'short':>8}{'long':>8}")
    for name in split_order:
        members = scenarios_per_split[name]
        per_bucket = Counter(map(bucket_of, members))
        n_turns = sum(map(turn_count, members))
        print(
            f"{name:<8}{len(members):>12}{n_turns:>10}"
            f"{per_bucket['solo']:>8}{per_bucket['short']:>8}{per_bucket['long']:>8}"
        )

    # Turn-level emotion distribution (what the model actually sees)
    print(f"\n{'emotion':<14}" + "".join(f"{name + '%':>10}" for name in split_order))
    tally_by_split = {}
    turns_by_split = {}
    for name, members in scenarios_per_split.items():
        tally = Counter(
            turn["emotion"]
            for scenario in members
            for turn in scenario["turns"]
            if turn["text"].strip()
        )
        tally_by_split[name] = tally
        turns_by_split[name] = sum(tally.values())

    for emo in emotion_rows:
        cells = []
        for name in split_order:
            denom = turns_by_split[name]
            # A split without any turns is shown as 0.0% instead of dividing by zero.
            share = (tally_by_split[name][emo] / denom * 100) if denom else 0.0
            cells.append(f"{share:>9.1f}%")
        print(f"{emo:<14}" + "".join(cells))


def _free_backup_path(src_path: Path) -> Path:
    """Return a backup path for *src_path* that does not exist yet.

    The first choice is ``<name><BACKUP_SUFFIX>``. If that file is already
    present, ``.1``, ``.2``, ... is appended until an unused name is found, so
    an earlier backup is never overwritten.
    """
    candidate = src_path.with_suffix(src_path.suffix + BACKUP_SUFFIX)
    counter = 1
    while candidate.exists():
        candidate = src_path.with_name(f"{src_path.name}{BACKUP_SUFFIX}.{counter}")
        counter += 1
    return candidate


def main() -> None:
    """Pool all input scenarios, re-split them, and rewrite the split files.

    Steps:
      1. Load every scenario and group paraphrases with their original.
      2. Assign whole source groups to train/val/test, separately per
         scenario-type bucket so each bucket is split by RATIOS.
      3. Keep augmentations only for groups assigned to train.
      4. Back up the existing files, then overwrite them with the new splits.
      5. Print the resulting distribution.

    Re-running the script never overwrites an existing backup: each run writes
    its backups under names that were unused before. Previously a second run
    replaced the backup of the pristine files with the already re-split files,
    permanently losing the augmentations dropped by the first run.
    """
    scenarios = load_all()
    print(f"Loaded {len(scenarios)} scenarios from {len(INPUT_FILES)} input files.")

    groups = group_by_source(scenarios)
    annotate_groups(groups)
    print(f"Unique source groups: {len(groups)}")
    print(f"Bucket distribution: {Counter(g['bucket'] for g in groups.values())}")

    # Stratify within each bucket independently so bucket ratios are preserved
    random.seed(SEED)
    assignment: Dict[str, str] = {}
    for bucket_name in ("solo", "short", "long"):
        bucket_items = [(src, g) for src, g in groups.items() if g["bucket"] == bucket_name]
        random.shuffle(bucket_items)
        sub_assignment = iterative_stratify(bucket_items, RATIOS)
        print(
            f"\n[{bucket_name}] n={len(bucket_items)} "
            f"-> {Counter(sub_assignment.values())}"
        )
        assignment.update(sub_assignment)

    # Build per-split scenario lists
    output: Dict[str, List[dict]] = {"train": [], "val": [], "test": []}
    dropped_augs = 0
    for src, g in groups.items():
        split = assignment[src]
        output[split].append(g["original"])
        if split == "train":
            output[split].extend(g["augmentations"])
        else:
            # Val/test keep the original only; paraphrases are discarded.
            dropped_augs += len(g["augmentations"])

    print(
        f"\nDropped {dropped_augs} augmentations (sources reassigned to val/test "
        f"contribute only their original)."
    )

    # Backup existing files (never clobbering a backup from an earlier run)
    for fn in INPUT_FILES:
        src_path = DATA / fn
        if src_path.exists():
            backup = _free_backup_path(src_path)
            shutil.copy2(src_path, backup)
            print(f"Backed up: {src_path} -> {backup}")

    # Write new files
    for split, out_name in OUTPUT_FILES.items():
        out_path = DATA / out_name
        with out_path.open("w", encoding="utf-8") as f:
            for s in output[split]:
                f.write(json.dumps(s, ensure_ascii=False) + "\n")
        print(f"Wrote {out_path}: {len(output[split])} scenarios")

    report_distribution(output)


# Run the re-split only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
