from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Dict, List, Set, Tuple

import torch
from torch.utils.data import Dataset

from .config import EMOTION_TO_ID
from .tokenizer import MicroTokenizer


# Polarity sanity rules: each emotion's VAD must match its semantic polarity.
# Catches data drift where augmentation/relabel produces emotion↔VAD polarity mismatch
# (e.g. "crying" with V>0 — happy tears — creates adversarial CE/VAD/snap signals).
# Soft-checked at load time: violations warn to stderr, never raise.
_POLARITY_RULES: Dict[str, Dict[str, float]] = {
    "joy": {"V_min": -0.1}, "laughter": {"V_min": 0.0},
    "excitement": {"V_min": -0.1}, "agreement": {"V_min": -0.1},
    "gratitude": {"V_min": 0.0},
    "sadness": {"V_max": 0.0}, "crying": {"V_max": 0.0},
    "sulk": {"V_max": 0.1}, "apology": {"V_max": 0.1},
    "struggle": {"V_max": 0.1},
    "anger": {"V_max": 0.0, "A_min": 0.0}, "refusal": {"V_max": 0.1},
    "surprise": {"A_min": 0.0}, "fluster": {"A_min": 0.0},
    "neutral": {}, "shy": {},
}


def _polarity_violation(emotion: str, vad: Tuple[float, float, float]) -> bool:
    rule = _POLARITY_RULES.get(emotion, {})
    V, A, _D = vad
    if "V_min" in rule and V < rule["V_min"]:
        return True
    if "V_max" in rule and V > rule["V_max"]:
        return True
    if "A_min" in rule and A < rule["A_min"]:
        return True
    return False


class SeedEmotionDataset(Dataset):
    """Per-turn emotion / VAD samples loaded from a JSONL file of scenarios.

    Each non-blank line of the file is a JSON object with a ``scenario_id``, an
    optional ``source_scenario_id`` and a ``turns`` list. Every turn whose
    stripped ``text`` is non-empty becomes one sample labelled with its
    ``emotion`` (which must be a key of ``EMOTION_TO_ID``) and its 3-element
    ``vad``. When ``context_window > 0``, preceding turns are kept as dialogue
    context and rendered with ``[SELF]`` / ``[OTHER]`` / ``[SEP]`` markers.

    Emotion/VAD polarity mismatches are reported to stderr after loading. They
    never raise.
    """

    def __init__(
        self,
        jsonl_path: Path,
        tokenizer: MicroTokenizer,
        max_seq_len: int,
        context_window: int = 0,
    ):
        """Parse and validate every turn in ``jsonl_path``.

        Args:
            jsonl_path: JSONL file with one scenario object per line. Blank or
                whitespace-only lines are skipped.
            tokenizer: stored and used later by ``__getitem__`` (its
                ``encode(text, max_seq_len)`` method).
            max_seq_len: forwarded to ``tokenizer.encode``. Presumably a length
                cap, but that depends on the tokenizer; confirm.
            context_window: number of preceding turns to attach as context.
                Counted by raw turn index, so blank turns inside the window use
                up slots. ``0`` disables context and speaker markers.

        Raises:
            json.JSONDecodeError: malformed JSON. The message is prefixed with
                the file path and 1-based line number.
            ValueError: unknown emotion label, or a ``vad`` that is not length 3.
            KeyError: a required field (``turns``, ``scenario_id``, ``text``,
                ``emotion``, ``vad``) is missing.
        """
        self.tok = tokenizer
        self.max_seq_len = max_seq_len
        self.context_window = context_window
        # sample: (prev_turns: List[(text, speaker)], curr_text, curr_speaker, emo_id, vad)
        self.samples: List[Tuple[List[Tuple[str, str]], str, str, int, Tuple[float, float, float]]] = []
        # Distinct source scenarios. Falls back to scenario_id when
        # source_scenario_id is absent or empty.
        self.source_ids: Set[str] = set()
        polarity_warnings: List[Tuple[str, str, Tuple[float, float, float]]] = []
        with Path(jsonl_path).open(encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing in json.loads.
                if not line.strip():
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError as err:
                    # Same exception type for existing handlers, plus the file
                    # position so the bad row can be found.
                    raise json.JSONDecodeError(
                        f"{jsonl_path} line {lineno}: {err.msg}", err.doc, err.pos
                    ) from err
                src = row.get("source_scenario_id") or row["scenario_id"]
                self.source_ids.add(src)
                turns = row["turns"]
                for i, t in enumerate(turns):
                    text = t["text"].strip()
                    if not text:
                        continue
                    emo = t["emotion"]
                    if emo not in EMOTION_TO_ID:
                        raise ValueError(
                            f"unknown emotion {emo!r} at scenario {row['scenario_id']}"
                        )
                    vad = t["vad"]
                    # Explicit raise: `assert` is stripped under `python -O`.
                    if len(vad) != 3:
                        raise ValueError(
                            f"vad must be length-3, got {vad} at scenario {row['scenario_id']}"
                        )
                    vad_tuple = (float(vad[0]), float(vad[1]), float(vad[2]))
                    if _polarity_violation(emo, vad_tuple):
                        polarity_warnings.append((row["scenario_id"], emo, vad_tuple))
                    curr_speaker = t.get("speaker", "?")
                    prev: List[Tuple[str, str]] = []
                    if context_window > 0:
                        # Window over raw turn indices; blank turns are
                        # dropped after slicing.
                        start = max(0, i - context_window)
                        for pt in turns[start:i]:
                            ptx = pt["text"].strip()
                            if ptx:
                                prev.append((ptx, pt.get("speaker", "?")))
                    self.samples.append(
                        (prev, text, curr_speaker, EMOTION_TO_ID[emo], vad_tuple)
                    )
        if polarity_warnings:
            print(
                f"[dataset] WARNING: {len(polarity_warnings)} emotion↔VAD polarity "
                f"violations in {jsonl_path} (first 5):",
                file=sys.stderr,
            )
            for sid, emo, vad in polarity_warnings[:5]:
                print(f"  [{sid}] {emo} V={vad[0]:+.2f} A={vad[1]:+.2f} D={vad[2]:+.2f}",
                      file=sys.stderr)

    def __len__(self) -> int:
        """Return the number of per-turn samples."""
        return len(self.samples)

    def _compose_text(self, prev: List[Tuple[str, str]], curr_text: str, curr_speaker: str) -> str:
        """Render the model input text for one sample.

        With ``context_window <= 0`` the current text is returned unchanged (no
        markers). Otherwise each previous turn is prefixed with ``[SELF]`` if
        its speaker equals ``curr_speaker``, or ``[OTHER]`` if not. The current
        turn is always prefixed with ``[SELF]``. All parts are joined with
        ``" [SEP] "``, so a turn with no context renders as ``"[SELF] text"``.

        NOTE(review): turns without a ``speaker`` field default to ``"?"``, so
        two speaker-less turns compare equal and are tagged ``[SELF]``. Confirm
        this is intended.
        """
        if self.context_window <= 0:
            return curr_text
        parts = [
            f"{'[SELF]' if pspk == curr_speaker else '[OTHER]'} {ptx}"
            for ptx, pspk in prev
        ]
        parts.append(f"[SELF] {curr_text}")
        return " [SEP] ".join(parts)

    def __getitem__(self, idx: int) -> Dict:
        """Return ``{"input_ids", "emotion_id", "vad"}`` for sample ``idx``.

        ``input_ids`` is whatever ``tokenizer.encode(text, max_seq_len)``
        returns. ``vad`` is converted to a list of 3 floats.
        """
        prev, curr_text, curr_speaker, emo_id, vad = self.samples[idx]
        text = self._compose_text(prev, curr_text, curr_speaker)
        ids = self.tok.encode(text, self.max_seq_len)
        return {"input_ids": ids, "emotion_id": emo_id, "vad": list(vad)}

    def emotion_ids(self) -> List[int]:
        """Return the emotion id of every sample, in sample order."""
        return [s[3] for s in self.samples]


def collate_fn(batch: List[Dict], pad_id: int) -> Dict[str, torch.Tensor]:
    """Right-pad a list of dataset items into batch tensors.

    Args:
        batch: items carrying ``input_ids`` (list of token ids),
            ``emotion_id`` (int) and ``vad`` (list of floats).
        pad_id: token id appended to sequences shorter than the longest one.

    Returns:
        ``input_ids`` and ``attention_mask`` as long tensors of shape
        ``(B, T_max)``. The mask is 1 on real tokens and 0 on padding.
        ``emotion_id`` is a long tensor of shape ``(B,)`` and ``vad`` is a
        float32 tensor.

    Raises:
        ValueError: if ``batch`` is empty (raised by ``max``).
    """
    lengths = [len(item["input_ids"]) for item in batch]
    longest = max(lengths)
    padded_rows = [
        item["input_ids"] + [pad_id] * (longest - n)
        for item, n in zip(batch, lengths)
    ]
    mask_rows = [[1] * n + [0] * (longest - n) for n in lengths]
    return {
        "input_ids": torch.tensor(padded_rows, dtype=torch.long),
        "attention_mask": torch.tensor(mask_rows, dtype=torch.long),
        "emotion_id": torch.tensor(
            [item["emotion_id"] for item in batch], dtype=torch.long
        ),
        "vad": torch.tensor([item["vad"] for item in batch], dtype=torch.float32),
    }
