from __future__ import annotations

import torch
from torch import nn
from transformers import AutoModel


class KlueTeacherForEmotionVAD(nn.Module):
    """Pretrained text encoder with an emotion-classification head and a VAD regression head.

    The backbone's first-token hidden state is projected through a
    ``Linear -> Tanh`` layer (``self.pooler``); both task heads read that
    pooled vector. "VAD" presumably stands for valence/arousal/dominance
    (``vad_dim`` defaults to 3) -- confirm against the training code.

    Args:
        model_name: Model id or local path handed to ``AutoModel.from_pretrained``.
        num_emotions: Width of the emotion logit vector.
        vad_dim: Width of the VAD regression output.
        dropout: Used as the backbone's ``hidden_dropout_prob``, as the dropout
            inside the VAD head, and as the fallback for ``attention_dropout``.
        vad_head_hidden: Hidden width of the two-layer VAD MLP.
        attention_dropout: Backbone ``attention_probs_dropout_prob``; ``None``
            means "reuse ``dropout``".
    """

    def __init__(
        self,
        model_name: str = "klue/roberta-base",
        num_emotions: int = 16,
        vad_dim: int = 3,
        dropout: float = 0.1,
        vad_head_hidden: int = 64,
        attention_dropout: float | None = None,
    ):
        super().__init__()

        # Only the backbone's dropout probabilities are overridden (to regularize
        # fine-tuning on small datasets); its pretrained weights load unchanged.
        if attention_dropout is None:
            attn_p = dropout
        else:
            attn_p = attention_dropout
        config_overrides = {
            "hidden_dropout_prob": dropout,
            "attention_probs_dropout_prob": attn_p,
        }
        self.backbone = AutoModel.from_pretrained(model_name, **config_overrides)
        hidden = self.backbone.config.hidden_size

        # Keep this construction order (pooler, emotion head, VAD head): it fixes
        # the order of self.parameters() and of the state_dict entries.
        self.pooler = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh())
        self.emotion_head = nn.Linear(hidden, num_emotions)
        vad_layers = [
            nn.Linear(hidden, vad_head_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(vad_head_hidden, vad_dim),
        ]
        self.vad_head = nn.Sequential(*vad_layers)

        self._init_heads()

    def _init_heads(self) -> None:
        """Re-initialize every ``nn.Linear`` in the pooler and both task heads.

        Weights are drawn from a normal distribution with mean 0 and std 0.02;
        biases are zeroed. The backbone is not touched.
        """
        heads = (self.pooler, self.emotion_head, self.vad_head)
        linears = (
            layer
            for head in heads
            for layer in head.modules()
            if isinstance(layer, nn.Linear)
        )
        for layer in linears:
            nn.init.normal_(layer.weight, mean=0.0, std=0.02)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> dict:
        """Encode a batch and run both task heads.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Attention mask forwarded to the backbone.

        Returns:
            A dict with, in this order:
              * ``"emotion_logits"`` -- ``emotion_head`` applied to the pooled vector;
              * ``"vad"`` -- ``vad_head`` applied to the pooled vector;
              * ``"pooled"`` -- ``pooler`` applied to the hidden state at index 0
                of dim 1 (the first token, assuming the usual
                (batch, seq_len, hidden) layout -- confirm for the backbone);
              * ``"last_hidden_state"`` -- the backbone's ``last_hidden_state``.
        """
        hidden_states = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        pooled = self.pooler(hidden_states[:, 0, :])
        emotion_logits = self.emotion_head(pooled)
        vad = self.vad_head(pooled)
        return dict(
            emotion_logits=emotion_logits,
            vad=vad,
            pooled=pooled,
            last_hidden_state=hidden_states,
        )

    def num_params(self) -> int:
        """Return the total element count of parameters with ``requires_grad=True``."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total
