from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple

# Anchor directory for data lookups: two levels above the directory that
# contains this file (index 2 of the resolved path's parents).
REPO_ROOT = Path(__file__).resolve().parents[2]
# JSON data files that live under <REPO_ROOT>/data/emotion/.
LABELS_PATH = REPO_ROOT.joinpath("data", "emotion", "emotion_labels.json")
ANCHORS_PATH = REPO_ROOT.joinpath("data", "emotion", "emotion_vad_anchors.json")


def _load_labels() -> Tuple[str, ...]:
    """Read the emotion label names from ``LABELS_PATH``, ordered by id.

    The JSON document is expected to contain a ``"labels"`` list whose
    entries each carry an ``"id"`` and a ``"name"``. Entries are sorted by
    id, and the sorted ids must be exactly ``0..15``.

    Returns:
        A tuple of label names where position ``i`` holds the name whose
        id is ``i``.

    Raises:
        ValueError: If the sorted ids are not exactly ``0..15``.
    """
    entries = json.loads(LABELS_PATH.read_text(encoding="utf-8"))["labels"]
    entries = sorted(entries, key=lambda entry: entry["id"])
    ids = [entry["id"] for entry in entries]
    # Validate with an explicit raise rather than ``assert``: assertions are
    # stripped under ``python -O``, which would let a malformed label file
    # load silently.
    if ids != list(range(16)):
        raise ValueError(f"emotion_labels.json ids must be 0..15, got {ids}")
    return tuple(entry["name"] for entry in entries)


# Label tables, computed once at import time from the label file.
EMOTION_LABELS: Tuple[str, ...] = _load_labels()
NUM_EMOTIONS = len(EMOTION_LABELS)
# name -> id, where the id is the label's position in EMOTION_LABELS.
EMOTION_TO_ID: Dict[str, int] = dict(zip(EMOTION_LABELS, range(NUM_EMOTIONS)))
# id -> name, built as the inverse of EMOTION_TO_ID.
ID_TO_EMOTION: Dict[int, str] = {idx: label for label, idx in EMOTION_TO_ID.items()}
assert NUM_EMOTIONS == 16


@dataclass(frozen=True)
class MicroAlbertConfig:
    """Immutable bundle of default hyperparameters.

    Instances are frozen, so fields cannot be reassigned after construction.
    Every field has a default, so ``MicroAlbertConfig()`` is valid as-is.

    NOTE(review): the grouping comments below are inferred from the field
    names only; the code that consumes these values is outside this file.
    """

    # --- sizes, dropout rates and initialisation ---
    vocab_size: int = 16_000
    pad_token_id: int = 0
    embedding_size: int = 128
    hidden_size: int = 384
    num_layers: int = 6
    num_heads: int = 6
    ffn_size: int = 1_536
    max_seq_len: int = 128
    dropout: float = 0.2
    attention_dropout: float = 0.15
    layer_norm_eps: float = 1e-12
    # Same value as the label count asserted at module level (16).
    num_emotions: int = 16
    vad_dim: int = 3
    vad_head_hidden: int = 64
    initializer_range: float = 0.02

    # --- loss weighting and schedules ---
    label_smoothing: float = 0.15
    vad_loss_weight_max: float = 1.5
    vad_warmup_epochs: int = 4
    # One weight per VAD dimension (three entries, matching vad_dim's default).
    vad_dim_weights: Tuple[float, float, float] = (1.0, 1.0, 0.5)
    snap_loss_weight: float = 0.2
    snap_start_epoch: int = 2
    snap_conf_threshold: float = 0.25
    snap_softmax_temp: float = 2.0
    snap_level: int = 3
    range_loss_weight: float = 0.05
    range_target_std: float = 0.35
    smooth_l1_beta: float = 0.1
    # (lower, upper) pair.
    class_weight_clip: Tuple[float, float] = (0.5, 3.0)

    # --- optimiser settings ---
    backbone_lr: float = 2e-4
    head_lr: float = 1e-3
    weight_decay: float = 0.02
    adam_betas: Tuple[float, float] = (0.9, 0.98)
    adam_eps: float = 1e-6
    grad_clip: float = 1.0
    warmup_ratio: float = 0.1

    # --- seed ---
    seed: int = 42
