from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from .config import ANCHORS_PATH, EMOTION_LABELS, EMOTION_TO_ID, MicroAlbertConfig


def build_anchor_table(
    level: int,
    anchors_path: Optional[Path] = None,
    labels: Optional[Sequence[str]] = None,
) -> torch.Tensor:
    """Build a per-emotion VAD anchor table from the anchors JSON file.

    Reads ``{"anchors": {name: [{"level": ..., "vad": ...}, ...]}}`` and, for
    each emotion name in ``labels`` (in order), takes the ``vad`` of the single
    entry whose ``level`` equals ``level``. Row ``i`` corresponds to ``labels[i]``.

    Args:
        level: Anchor level to select for every emotion.
        anchors_path: JSON file to read; defaults to ``ANCHORS_PATH``.
        labels: Emotion names defining the row order; defaults to
            ``EMOTION_LABELS``.

    Returns:
        Float32 tensor with one row per label.

    Raises:
        KeyError: If a label (or the top-level ``"anchors"`` key) is missing.
        ValueError: If a label does not have exactly one entry at ``level``.
    """
    path = ANCHORS_PATH if anchors_path is None else Path(anchors_path)
    names = EMOTION_LABELS if labels is None else labels
    data = json.loads(path.read_text(encoding="utf-8"))["anchors"]
    rows = []
    for name in names:
        matches = [e for e in data[name] if e["level"] == level]
        # Explicit raise instead of assert: asserts vanish under `python -O`,
        # which would turn a missing entry into an opaque IndexError and
        # silently pick the first of several duplicate entries.
        if len(matches) != 1:
            raise ValueError(
                f"expected exactly one level-{level} entry for {name}, "
                f"found {len(matches)}"
            )
        rows.append(matches[0]["vad"])
    return torch.tensor(rows, dtype=torch.float32)


def build_data_centroid_anchors(
    train_jsonl: Path, num_emotions: int, fallback_level: int = 3
) -> torch.Tensor:
    """Compute per-emotion VAD centroids from training data.

    Replaces hand-curated anchors (which were 0.25-0.40 from data means)
    with empirical class centroids — snap loss now pulls toward what the
    data actually shows for each class. Classes with zero training samples
    fall back to the JSON anchors at ``fallback_level`` (level 3 by default).

    Args:
        train_jsonl: JSON Lines file; each non-blank line is an object whose
            ``"turns"`` list holds dicts with ``"emotion"`` and ``"vad"`` keys.
            Turns whose emotion is not in ``EMOTION_TO_ID``, or whose ``vad``
            does not have exactly 3 values, are skipped.
        num_emotions: Number of rows in the result; ids from ``EMOTION_TO_ID``
            index these rows.
        fallback_level: Anchor level used for classes with no samples.

    Returns:
        Float32 tensor of shape ``(num_emotions, 3)``.
    """
    sums = np.zeros((num_emotions, 3), dtype=np.float64)
    counts = np.zeros(num_emotions, dtype=np.int64)
    with Path(train_jsonl).open(encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing empty line): json.loads
            # raises on whitespace-only input.
            if not line.strip():
                continue
            row = json.loads(line)
            for turn in row["turns"]:
                emo_id = EMOTION_TO_ID.get(turn["emotion"])
                if emo_id is None:
                    continue  # label outside the known emotion set
                vad = turn["vad"]
                if len(vad) != 3:
                    continue  # skip malformed VAD entries
                sums[emo_id] += vad
                counts[emo_id] += 1
    # clip(min=1) avoids 0/0 for empty classes; those rows are replaced below.
    centroids = sums / counts[:, None].clip(min=1)
    missing = np.flatnonzero(counts == 0)
    if missing.size:
        json_anchors = build_anchor_table(fallback_level).numpy()
        centroids[missing] = json_anchors[missing]
    return torch.tensor(centroids, dtype=torch.float32)


def compute_class_weights(
    emotion_ids: np.ndarray,
    num_classes: int,
    clip_range: Tuple[float, float],
) -> torch.Tensor:
    """Return clipped, mean-normalized "balanced" class weights for cross-entropy.

    Uses the same formula as scikit-learn's ``compute_class_weight("balanced")``:
    ``n_samples / (n_present_classes * count[c])``. scikit-learn raises when a
    class in ``range(num_classes)`` never occurs in ``emotion_ids``. Here, such
    absent classes get the upper clip bound instead, which is the limit of the
    formula as the count approaches zero. Weights are then clipped to
    ``clip_range`` and divided by their mean, so they average to 1.0.

    Args:
        emotion_ids: Integer class labels, one per training example.
        num_classes: Number of classes; labels must lie in ``[0, num_classes)``.
        clip_range: ``(low, high)`` bounds applied before mean normalization.

    Returns:
        Float32 tensor of shape ``(num_classes,)``.

    Raises:
        ValueError: If any label falls outside ``[0, num_classes)``.
    """
    labels = np.asarray(emotion_ids, dtype=np.int64).ravel()
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(
            f"emotion_ids must lie in [0, {num_classes}); "
            f"got range [{labels.min()}, {labels.max()}]"
        )
    counts = np.bincount(labels, minlength=num_classes)
    present = counts > 0
    low, high = clip_range
    # Absent classes: the balanced weight diverges as count -> 0, so use the ceiling.
    w = np.full(num_classes, float(high), dtype=np.float64)
    w[present] = labels.size / (np.count_nonzero(present) * counts[present])
    w = np.clip(w, low, high)
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32)


class MultitaskLoss(nn.Module):
    """Combined loss for emotion classification plus VAD regression.

    ``total = CE + w_vad * L_vad + w_snap * L_snap + cfg.range_loss_weight * L_range``.
    ``w_vad`` and ``w_snap`` follow the epoch schedule in ``current_weights``.
    """

    def __init__(
        self,
        cfg: MicroAlbertConfig,
        class_weights: torch.Tensor,
        anchors: Optional[torch.Tensor] = None,
    ):
        """Store the config and register the loss constants as buffers.

        Args:
            cfg: Config supplying the loss hyper-parameters read by this class.
            class_weights: Per-class cross-entropy weights, stored as float32.
            anchors: VAD anchor table indexed by predicted class id. Defaults to
                the JSON anchors at ``cfg.snap_level`` for backward compatibility;
                pass data centroids for the snap fix. Stored as float32.
        """
        super().__init__()
        self.cfg = cfg
        self.register_buffer("class_weights", class_weights.float())
        if anchors is None:
            anchors = build_anchor_table(cfg.snap_level)
        # Cast like class_weights. Otherwise a float64 table (e.g. from
        # torch.from_numpy) would mix dtypes with float32 predictions in the
        # snap loss and promote the total loss.
        self.register_buffer("anchors", torch.as_tensor(anchors, dtype=torch.float32))
        self.register_buffer(
            "vad_dim_weights", torch.tensor(cfg.vad_dim_weights, dtype=torch.float32)
        )
        # Dividing by the weight sum turns the weighted per-dim sum into a weighted mean.
        self._dim_w_sum = float(sum(cfg.vad_dim_weights))

    def current_weights(self, epoch: int) -> Tuple[float, float]:
        """Return ``(w_vad, w_snap)`` for ``epoch``.

        ``w_vad`` ramps linearly up to ``cfg.vad_loss_weight_max``. It reaches
        that value once ``epoch + 1 >= cfg.vad_warmup_epochs``. ``w_snap`` is
        ``cfg.snap_loss_weight`` from ``cfg.snap_start_epoch`` onward and 0.0
        before that.
        """
        cfg = self.cfg
        ramp = min(1.0, (epoch + 1) / max(1, cfg.vad_warmup_epochs))
        w_vad = cfg.vad_loss_weight_max * ramp
        w_snap = cfg.snap_loss_weight if epoch >= cfg.snap_start_epoch else 0.0
        return w_vad, w_snap

    def forward(
        self,
        emotion_logits: torch.Tensor,
        vad_pred: torch.Tensor,
        emotion_target: torch.Tensor,
        vad_target: torch.Tensor,
        epoch: int,
    ) -> Tuple[torch.Tensor, Dict[str, float], Dict[str, float]]:
        """Compute the total loss plus per-term diagnostics.

        Args:
            emotion_logits: Class logits, (batch, num_classes) as consumed by
                ``F.cross_entropy`` with class-index targets.
            vad_pred: Predicted VAD values; reduced over the last dim and over
                the batch dim (dim 0).
            emotion_target: Class-index targets.
            vad_target: VAD regression targets, same shape as ``vad_pred``.
            epoch: Current epoch, used for the loss-weight schedule.

        Returns:
            ``(total, components, aux)``. ``total`` is the differentiable
            scalar loss. ``components`` holds the detached ``ce`` / ``vad`` /
            ``snap`` / ``range`` values. ``aux`` holds ``w_vad``, ``w_snap``
            and the snap ``gate_rate``.
        """
        cfg = self.cfg

        l_ce = F.cross_entropy(
            emotion_logits,
            emotion_target,
            weight=self.class_weights,
            label_smoothing=cfg.label_smoothing,
        )

        vad_elem = F.smooth_l1_loss(
            vad_pred, vad_target, reduction="none", beta=cfg.smooth_l1_beta
        )
        l_vad = (vad_elem * self.vad_dim_weights).sum(dim=-1).mean() / self._dim_w_sum

        # Snap: classifier gradient blocked via detach; backbone+pooler+vad_head still update
        # via pred_vad → L_snap path. Top-1 argmax anchor, confidence gated.
        soft_probs = F.softmax(emotion_logits.detach() / cfg.snap_softmax_temp, dim=-1)
        top_prob, top_idx = soft_probs.max(dim=-1)
        gate = (top_prob >= cfg.snap_conf_threshold).float()
        anchor_target = self.anchors[top_idx]
        snap_per = F.smooth_l1_loss(
            vad_pred, anchor_target, reduction="none", beta=cfg.smooth_l1_beta
        ).mean(dim=-1)
        # Average over gated samples only; clamp avoids 0/0 when nothing passes the gate.
        l_snap = (snap_per * gate).sum() / gate.sum().clamp(min=1.0)
        gate_rate = float(gate.mean().item())

        # L_range: diversity guard, not a clamp. tanh removed from vad_head, so spread is learned.
        # Only penalizes per-dim batch std *below* cfg.range_target_std.
        if vad_pred.size(0) >= 2:
            batch_std = vad_pred.std(dim=0)
            l_range = F.relu(cfg.range_target_std - batch_std).pow(2).mean()
        else:
            l_range = vad_pred.new_zeros(())

        w_vad, w_snap = self.current_weights(epoch)

        total = (
            l_ce
            + w_vad * l_vad
            + w_snap * l_snap
            + cfg.range_loss_weight * l_range
        )

        components = {
            "ce": float(l_ce.detach().item()),
            "vad": float(l_vad.detach().item()),
            "snap": float(l_snap.detach().item()),
            "range": float(l_range.detach().item()),
        }
        aux = {"w_vad": float(w_vad), "w_snap": float(w_snap), "gate_rate": gate_rate}
        return total, components, aux
