from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional as F

from .config import MicroAlbertConfig


class MicroAlbertEmbeddings(nn.Module):
    """Token + learned absolute position embeddings, projected to ``hidden_size``.

    Embeddings are learned in ``cfg.embedding_size`` and then mapped to
    ``cfg.hidden_size`` by a bias-free linear projection. After that they are
    layer-normalized and passed through dropout.
    """

    def __init__(self, cfg: MicroAlbertConfig):
        super().__init__()
        self.token_emb = nn.Embedding(
            cfg.vocab_size, cfg.embedding_size, padding_idx=cfg.pad_token_id
        )
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.embedding_size)
        self.emb_proj = nn.Linear(cfg.embedding_size, cfg.hidden_size, bias=False)
        self.layer_norm = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.dropout)
        # Non-persistent buffer: rebuilt from cfg at construction, never stored in
        # the state_dict, but still moved along with the module by .to(device).
        self.register_buffer(
            "position_ids",
            torch.arange(cfg.max_seq_len).unsqueeze(0),
            persistent=False,
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed a batch of token ids.

        Args:
            input_ids: integer token ids; dim 1 is the sequence dimension.

        Returns:
            Tensor of shape ``(*input_ids.shape, hidden_size)``.

        Raises:
            ValueError: if the sequence length exceeds ``max_seq_len``.
        """
        seq_len = input_ids.size(1)
        max_len = self.position_ids.size(1)
        if seq_len > max_len:
            # Slicing would silently truncate the positions. The add below would
            # then fail with an opaque broadcast error, or (max_seq_len == 1)
            # silently reuse one position embedding for every token.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )
        pos = self.position_ids[:, :seq_len]
        x = self.token_emb(input_ids) + self.pos_emb(pos)
        x = self.emb_proj(x)
        x = self.layer_norm(x)
        return self.dropout(x)


class MicroAlbertAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection, computed via SDPA.

    The padding mask is applied to keys only. Every query position, padded ones
    included, attends over the non-padded keys of its own sequence.
    """

    # Mask dtypes accepted by ``forward`` (``torch.long`` is ``torch.int64``).
    # Float masks are rejected rather than coerced with ``.bool()``: an additive
    # 0 / -inf mask would come out inverted by that coercion.
    _ALLOWED_MASK_DTYPES = (torch.bool, torch.int32, torch.int64)

    def __init__(self, cfg: MicroAlbertConfig):
        super().__init__()
        if cfg.hidden_size % cfg.num_heads != 0:
            # Explicit raise instead of ``assert``: asserts are removed under -O.
            raise ValueError(
                f"hidden_size ({cfg.hidden_size}) must be divisible by "
                f"num_heads ({cfg.num_heads})"
            )
        self.num_heads = cfg.num_heads
        self.head_dim = cfg.hidden_size // cfg.num_heads
        self.hidden_size = cfg.hidden_size
        self.qkv = nn.Linear(cfg.hidden_size, 3 * cfg.hidden_size)
        self.out_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
        self.attention_dropout = cfg.attention_dropout

    def forward(self, h: torch.Tensor, attention_mask: torch.Tensor, return_qkv: bool = False):
        """Apply self-attention to ``h``.

        Args:
            h: hidden states of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: ``(batch, seq_len)`` key mask. Nonzero/True marks a
                real token; 0/False marks padding.
            return_qkv: if true, also return the Q, K and V projections.

        Returns:
            ``out`` of shape ``(batch, seq_len, hidden_size)``. When
            ``return_qkv`` is set, returns ``(out, q, k, v)``; q, k and v are the
            ``(batch, seq_len, hidden_size)`` projections before the head split.

        Raises:
            TypeError: if ``attention_mask`` is not bool, int32 or int64.
        """
        if attention_mask.dtype not in self._ALLOWED_MASK_DTYPES:
            raise TypeError(
                "attention_mask must be bool, int32 or int64 (1 = real token), "
                f"got {attention_mask.dtype}"
            )
        B, L, _ = h.shape
        qkv = self.qkv(h)
        q_flat, k_flat, v_flat = qkv.split(self.hidden_size, dim=-1)
        q = q_flat.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = k_flat.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = v_flat.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        # SDPA bool mask: True = may attend. Shape (B, 1, 1, L) broadcasts over
        # heads and query positions, so only keys are masked.
        mask4 = attention_mask.bool()[:, None, None, :]
        drop = self.attention_dropout if self.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask4, dropout_p=drop)
        out = out.transpose(1, 2).contiguous().view(B, L, self.hidden_size)
        out = self.out_proj(out)
        if return_qkv:
            return out, q_flat, k_flat, v_flat
        return out


class SharedTransformerBlock(nn.Module):
    """Post-LayerNorm transformer layer: self-attention, then a GELU feed-forward.

    Each sub-layer output goes through dropout, is added to its input, and is
    then layer-normalized.
    """

    def __init__(self, cfg: MicroAlbertConfig):
        super().__init__()
        hidden = cfg.hidden_size
        eps = cfg.layer_norm_eps
        # Submodule creation order is kept as-is: it determines parameter
        # construction/RNG order and the order of state_dict entries.
        self.attn = MicroAlbertAttention(cfg)
        self.attn_ln = nn.LayerNorm(hidden, eps=eps)
        self.ffn_up = nn.Linear(hidden, cfg.ffn_size)
        self.ffn_down = nn.Linear(cfg.ffn_size, hidden)
        self.ffn_ln = nn.LayerNorm(hidden, eps=eps)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, h: torch.Tensor, attention_mask: torch.Tensor, return_qkv: bool = False):
        """Run one layer over ``h``.

        Returns the new hidden states. When ``return_qkv`` is set, returns
        ``(hidden, q, k, v)``, with q/k/v passed through from the attention module.
        """
        attn_result = (
            self.attn(h, attention_mask, return_qkv=True)
            if return_qkv
            else (self.attn(h, attention_mask),)
        )
        attn_out, projections = attn_result[0], attn_result[1:]
        x = self.attn_ln(h + self.dropout(attn_out))
        ff = self.ffn_down(F.gelu(self.ffn_up(x)))
        x = self.ffn_ln(x + self.dropout(ff))
        return (x, *projections) if return_qkv else x


class MicroAlbertBackbone(nn.Module):
    """Encoder: embeddings followed by one block applied ``num_layers`` times.

    A single ``SharedTransformerBlock`` instance is reused for every pass. All
    layers therefore share weights: extra depth adds compute but no parameters.
    """

    def __init__(self, cfg: MicroAlbertConfig):
        super().__init__()
        self.cfg = cfg
        self.embeddings = MicroAlbertEmbeddings(cfg)
        self.shared_block = SharedTransformerBlock(cfg)
        self.num_layers = cfg.num_layers
        # Runs after all submodules exist so every one of them gets initialized.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize ``module`` in place; used as the ``nn.Module.apply`` callback.

        Linear and Embedding weights are drawn from N(0, initializer_range).
        Linear biases and the Embedding padding row (if any) are zeroed.
        LayerNorm gets weight=1, bias=0. Other module types are left untouched.
        """
        std = self.cfg.initializer_range
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return
        pad_row = module.padding_idx
        if pad_row is not None:
            with torch.no_grad():
                module.weight[pad_row].zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        return_last_qkv: bool = False,
    ):
        """Embed ``input_ids`` and run the shared block ``num_layers`` times.

        Returns the final hidden states. When ``return_last_qkv`` is set, returns
        ``(hidden, q, k, v)`` with q/k/v taken from the final pass. All three are
        ``None`` if ``num_layers`` < 1.
        """
        hidden = self.embeddings(input_ids)
        # When Q/K/V are requested, the final pass is run separately below.
        plain_passes = self.num_layers - 1 if return_last_qkv else self.num_layers
        for _ in range(plain_passes):
            hidden = self.shared_block(hidden, attention_mask)
        if not return_last_qkv:
            return hidden
        if self.num_layers < 1:
            return hidden, None, None, None
        hidden, q, k, v = self.shared_block(hidden, attention_mask, return_qkv=True)
        return hidden, q, k, v


class MicroAlbertForEmotionVAD(nn.Module):
    """Shared-weight encoder with an emotion-classification head and a VAD-regression head.

    Both heads read the pooled first-token representation, ``tanh(Linear(h[:, 0]))``.
    """

    def __init__(self, cfg: MicroAlbertConfig):
        super().__init__()
        self.cfg = cfg
        hidden = cfg.hidden_size
        self.backbone = MicroAlbertBackbone(cfg)
        self.pooler = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh())
        self.emotion_head = nn.Linear(hidden, cfg.num_emotions)
        # The VAD output is deliberately left without a final Tanh: saturation
        # near the ±0.95 anchors suppresses gradients roughly 5–10×.
        self.vad_head = nn.Sequential(
            nn.Linear(hidden, cfg.vad_head_hidden),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.vad_head_hidden, cfg.vad_dim),
        )
        # NOTE: ``apply`` also recurses into ``self.backbone``, so the backbone's
        # weights are redrawn here after its own initialization.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize ``module`` in place; used as the ``nn.Module.apply`` callback.

        Linear and Embedding weights are drawn from N(0, initializer_range).
        Linear biases and the Embedding padding row (if any) are zeroed.
        LayerNorm gets weight=1, bias=0.
        """
        std = self.cfg.initializer_range
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding) and module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> dict:
        """Encode the batch and apply both heads to the pooled first token.

        Returns a dict with keys ``"emotion_logits"``, ``"vad"`` and ``"pooled"``.
        """
        hidden = self.backbone(input_ids, attention_mask)
        pooled = self.pooler(hidden[:, 0, :])
        emotion_logits = self.emotion_head(pooled)
        vad = self.vad_head(pooled)
        return {"emotion_logits": emotion_logits, "vad": vad, "pooled": pooled}

    def num_params(self) -> int:
        """Return the number of trainable (``requires_grad``) parameter elements."""
        trainable = (param for param in self.parameters() if param.requires_grad)
        return sum(param.numel() for param in trainable)
