from __future__ import annotations

from typing import Dict, List

import numpy as np
import torch
from sklearn.metrics import f1_score

from .config import EMOTION_LABELS, NUM_EMOTIONS


def safe_pearson(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the Pearson correlation of two 1-D arrays, or 0.0 if undefined.

    The correlation is treated as undefined, and 0.0 is returned, when:

    * there are fewer than two samples;
    * either input is (numerically) constant, i.e. std < 1e-8;
    * ``np.corrcoef`` produces a non-finite value.

    Returns:
        The correlation coefficient as a Python ``float``.
    """
    # ``or`` short-circuits, so the std checks only run when size >= 2.
    degenerate = (
        y_true.size < 2
        or np.std(y_true) < 1e-8
        or np.std(y_pred) < 1e-8
    )
    if degenerate:
        return 0.0

    with np.errstate(invalid="ignore"):
        corr_matrix = np.corrcoef(y_true, y_pred)
    coefficient = corr_matrix[0, 1]

    if not np.isfinite(coefficient):
        return 0.0
    return float(coefficient)


@torch.no_grad()
def evaluate(model, loader, device: str) -> Dict:
    """Run ``model`` over every batch in ``loader`` and compute eval metrics.

    The model is put in eval mode for the pass. If it was in training mode
    beforehand, training mode is restored before returning (even on error).
    This means calling ``evaluate`` from inside a training loop does not
    silently leave dropout / batch-norm in inference mode afterwards.

    Args:
        model: Called as ``model(input_ids, attention_mask)``. It must
            return a mapping with ``"emotion_logits"`` (class scores on the
            last dim) and ``"vad"`` tensors.
        loader: Iterable of batch mappings. Each batch provides
            ``"input_ids"``, ``"attention_mask"``, ``"emotion_id"`` and
            ``"vad"`` tensors. The label tensors are expected on CPU, since
            ``.numpy()`` is called on them directly.
        device: Device the model inputs are moved to.

    Returns:
        A dict with:

        * ``macro_f1`` / ``per_class_f1`` -- emotion classification F1 over
          labels ``0..NUM_EMOTIONS-1``, with zero_division=0.
        * ``vad_mae`` / ``vad_mae_mean`` -- per-column and mean absolute error
          for VAD columns 0-2 (presumably valence/arousal/dominance).
        * ``vad_pearson_r`` -- per-column Pearson r via ``safe_pearson``.
        * ``pred_class_count`` / ``pred_hist`` / ``dominant_ratio`` -- the
          spread of predicted classes; useful for spotting collapse to a
          single class.
        * ``vad_std`` -- per-column std of VAD predictions.
        * ``num_samples`` -- number of evaluated samples.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    emo_preds: List[np.ndarray] = []
    emo_trues: List[np.ndarray] = []
    vad_preds: List[np.ndarray] = []
    vad_trues: List[np.ndarray] = []

    # Remember the caller's mode so it can be restored afterwards.
    was_training = bool(getattr(model, "training", False))
    model.eval()
    try:
        for batch in loader:
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)
            out = model(input_ids, attn)
            emo_preds.append(out["emotion_logits"].argmax(dim=-1).cpu().numpy())
            emo_trues.append(batch["emotion_id"].numpy())
            vad_out = out["vad"]
            # NumPy has no bfloat16, so .numpy() on a bf16 tensor raises.
            # Upcast only that dtype; all others keep their original dtype.
            if vad_out.dtype == torch.bfloat16:
                vad_out = vad_out.float()
            vad_preds.append(vad_out.cpu().numpy())
            vad_trues.append(batch["vad"].numpy())
    finally:
        if was_training:
            model.train()

    if not emo_preds:
        # Without this, np.concatenate([]) fails with an opaque message.
        raise ValueError("evaluate() received an empty loader; no metrics to compute")

    emo_pred = np.concatenate(emo_preds)
    emo_true = np.concatenate(emo_trues)
    vad_pred = np.concatenate(vad_preds, axis=0)
    vad_true = np.concatenate(vad_trues, axis=0)

    # Fix the label set so classes absent from both y_true and y_pred still
    # get a (zero) score instead of being dropped from the macro average.
    labels = list(range(NUM_EMOTIONS))
    macro_f1 = float(
        f1_score(emo_true, emo_pred, labels=labels, average="macro", zero_division=0)
    )
    per_class = f1_score(
        emo_true, emo_pred, labels=labels, average=None, zero_division=0
    )
    per_class_dict = {EMOTION_LABELS[i]: float(per_class[i]) for i in range(NUM_EMOTIONS)}

    mae_dims = np.abs(vad_pred - vad_true).mean(axis=0)
    mae_mean = float(mae_dims.mean())
    r_v = safe_pearson(vad_true[:, 0], vad_pred[:, 0])
    r_a = safe_pearson(vad_true[:, 1], vad_pred[:, 1])
    r_d = safe_pearson(vad_true[:, 2], vad_pred[:, 2])

    # Prediction-distribution diagnostics.
    pred_hist = np.bincount(emo_pred, minlength=NUM_EMOTIONS)
    pred_class_count = int((pred_hist > 0).sum())
    dominant_ratio = float(pred_hist.max() / max(len(emo_pred), 1))
    vad_stds = vad_pred.std(axis=0)

    return {
        "macro_f1": macro_f1,
        "per_class_f1": per_class_dict,
        "vad_mae": [float(mae_dims[0]), float(mae_dims[1]), float(mae_dims[2])],
        "vad_mae_mean": mae_mean,
        "vad_pearson_r": [r_v, r_a, r_d],
        "pred_class_count": pred_class_count,
        "pred_hist": pred_hist.tolist(),
        "dominant_ratio": dominant_ratio,
        "vad_std": [float(vad_stds[0]), float(vad_stds[1]), float(vad_stds[2])],
        "num_samples": int(len(emo_true)),
    }
