"""V3 face model — split-branch causal TCN with FiLM conditioning.

Architecture:

    audio (T, 80) ⊕ cond (T, 19)  →  Linear(99 → hidden)
                                      │
                                      ▼
            shared backbone: 6× DilatedCausalConv1d (d=1..32) + FiLM
                                      │
                       ┌──────────────┴──────────────┐
                       ▼                             ▼
       lipsync branch                       expression branch
       2× TCN (d=64,128)                    2× TCN (d=64,128)
                       ▼                             ▼
       Linear(hidden → 31)                  Linear(hidden → 21)
       sigmoid                              sigmoid
                       │                             │
                       └──────────┬──────────────────┘
                                  ▼
                       combined (T, 52) blendshape output
                       (lipsync values at LIPSYNC + SHARED indices,
                        expression values at EXPRESSION_ONLY indices)

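Freezing:
    model.freeze_lipsync()  →  input projection + shared backbone + lipsync
                                branch + lipsync head get
                                requires_grad=False. Expression branch +
                                head remain trainable. Lipsync output then
                                depends only on the audio + cond inputs, so
                                retraining the expression branch cannot
                                shift it (bit-for-bit only while dropout in
                                the frozen modules is inactive).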

~3.7 M params at hidden=192, ~1.2 ms/frame CPU.
Quantization-friendly (Conv1d + Linear + GELU + sigmoid, no attention).
"""
from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional as F

from scripts.compiler.constants import (
    LIPSYNC_ONLY, EXPRESSION_ONLY, SHARED_CHANNELS,
)

from .config import V3FaceConfig


class FiLM(nn.Module):
    """Feature-wise Linear Modulation driven by a per-frame condition.

    One linear layer maps ``cond`` to a (scale, shift) pair, each
    ``hidden_dim`` wide, split along the last axis. The input is then
    modulated as ``x * (1 + scale) + shift``. The projection is
    zero-initialised, so a freshly constructed layer is the identity map.
    """

    def __init__(self, cond_dim: int, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, hidden_dim * 2)
        # Zero weight and bias -> scale == shift == 0 -> identity at start.
        for tensor in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(tensor)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        scale, shift = self.proj(cond).chunk(2, dim=-1)
        return x * (scale + 1.0) + shift


class CausalConv1d(nn.Module):
    """Dilated 1D convolution padded only on the left (past) side.

    Padding by ``dilation * (kernel_size - 1)`` zeros on the left keeps the
    output length equal to the input length. It also ensures output frame
    ``t`` only sees input frames ``<= t``.
    """

    def __init__(self, channels: int, kernel_size: int, dilation: int):
        super().__init__()
        receptive_span = dilation * (kernel_size - 1)
        self.left_pad = receptive_span
        self.conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            dilation=dilation,
            padding=0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        padded = F.pad(x, [self.left_pad, 0])
        return self.conv(padded)


class TCNBlock(nn.Module):
    """Residual causal-conv block with FiLM conditioning.

    Input and output are (B, T, hidden_dim); the convolutions run on the
    transposed (B, hidden_dim, T) view. The residual branch is:

        conv1 → LayerNorm → GELU → dropout
              → conv2 → LayerNorm → FiLM(cond) → GELU → dropout

    and its result is added back onto the block input.
    """

    def __init__(self, hidden_dim: int, kernel_size: int, dilation: int,
                 cond_dim: int, dropout: float):
        super().__init__()
        # Construction/registration order (conv1, conv2, norm1, norm2, film,
        # dropout) is preserved: it fixes state_dict key order and the RNG
        # draws consumed by weight initialisation.
        self.conv1, self.conv2 = (
            CausalConv1d(hidden_dim, kernel_size, dilation) for _ in range(2)
        )
        self.norm1, self.norm2 = (nn.LayerNorm(hidden_dim) for _ in range(2))
        self.film = FiLM(cond_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _along_time(conv: nn.Module, seq: torch.Tensor) -> torch.Tensor:
        """Apply a channels-first conv to a (B, T, C) sequence."""
        return conv(seq.transpose(1, 2)).transpose(1, 2)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        stage1 = self.dropout(F.gelu(self.norm1(self._along_time(self.conv1, x))))
        stage2 = self.norm2(self._along_time(self.conv2, stage1))
        stage2 = self.dropout(F.gelu(self.film(stage2, cond)))
        return x + stage2


# Channel split between the two output branches (fixed for the V3 lifetime).
# Entries are ARKit channel indices, de-duplicated and sorted ascending.
#   lipsync branch    — audio-driven mouth/jaw motion (LIPSYNC_ONLY ∪ SHARED)
#   expression branch — emotion/prosody-driven brow/cheek/eye motion
LIPSYNC_BRANCH_CHANNELS = sorted({*LIPSYNC_ONLY, *SHARED_CHANNELS})  # expected 31
EXPRESSION_BRANCH_CHANNELS = sorted({*EXPRESSION_ONLY})               # expected 21


class V3FaceModel(nn.Module):
    """Split-branch causal TCN: (audio, cond) → (T, 52) blendshapes.

    A shared input projection + TCN backbone feeds two independent TCN
    branches. The lipsync branch predicts the channels listed in
    ``LIPSYNC_BRANCH_CHANNELS``. The expression branch predicts those in
    ``EXPRESSION_BRANCH_CHANNELS``. Each branch ends in a sigmoid, and the
    two results are scattered into one ``cfg.output_dim``-wide tensor.
    """

    def __init__(self, cfg: V3FaceConfig):
        super().__init__()
        self.cfg = cfg
        # Set by freeze_lipsync(); makes train() keep the frozen path in eval.
        self._lipsync_frozen = False

        # Validate the channel split with real exceptions: ``assert`` is
        # stripped under ``python -O``, and a bad split would otherwise make
        # forward() silently overwrite channels or leave some at zero.
        n_lipsync = len(LIPSYNC_BRANCH_CHANNELS)
        n_expression = len(EXPRESSION_BRANCH_CHANNELS)
        if n_lipsync + n_expression != cfg.output_dim:
            raise ValueError(
                f"channel split mismatch: {n_lipsync} + {n_expression} != {cfg.output_dim}"
            )
        lipsync_set = set(LIPSYNC_BRANCH_CHANNELS)
        expression_set = set(EXPRESSION_BRANCH_CHANNELS)
        overlap = sorted(lipsync_set & expression_set)
        if overlap:
            raise ValueError(f"channels assigned to both branches: {overlap}")
        if lipsync_set | expression_set != set(range(cfg.output_dim)):
            # Counts match and sets are disjoint, so some index must lie
            # outside [0, output_dim).
            raise ValueError(
                f"channel split does not cover range(0, {cfg.output_dim}) exactly"
            )

        # Channel index buffers (registered so they move to model device)
        self.register_buffer("lipsync_idx",
                             torch.tensor(LIPSYNC_BRANCH_CHANNELS, dtype=torch.long))
        self.register_buffer("expression_idx",
                             torch.tensor(EXPRESSION_BRANCH_CHANNELS, dtype=torch.long))

        # ─── Shared encoder ───────────────────────────────────────────
        self.input_proj = nn.Linear(cfg.audio_dim + cfg.cond_dim, cfg.hidden_dim)
        self.shared_blocks = nn.ModuleList([
            TCNBlock(cfg.hidden_dim, cfg.kernel_size, d, cfg.cond_dim, cfg.dropout)
            for d in cfg.shared_dilations
        ])

        # ─── Lipsync branch ───────────────────────────────────────────
        self.lipsync_blocks = nn.ModuleList([
            TCNBlock(cfg.hidden_dim, cfg.kernel_size, d, cfg.cond_dim, cfg.dropout)
            for d in cfg.branch_dilations
        ])
        self.lipsync_head = nn.Linear(cfg.hidden_dim, n_lipsync)

        # ─── Expression branch ────────────────────────────────────────
        self.expression_blocks = nn.ModuleList([
            TCNBlock(cfg.hidden_dim, cfg.kernel_size, d, cfg.cond_dim, cfg.dropout)
            for d in cfg.branch_dilations
        ])
        self.expression_head = nn.Linear(cfg.hidden_dim, n_expression)

    def forward(self, audio: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Args:
            audio: (B, T, audio_dim)
            cond:  (B, T, cond_dim)
        Returns:
            blendshapes: (B, T, output_dim) in [0, 1]
        """
        x = self.input_proj(torch.cat([audio, cond], dim=-1))
        for block in self.shared_blocks:
            x = block(x, cond)

        # Lipsync branch
        lx = x
        for block in self.lipsync_blocks:
            lx = block(lx, cond)
        lipsync_out = torch.sigmoid(self.lipsync_head(lx))        # (B, T, n_lipsync)

        # Expression branch
        ex = x
        for block in self.expression_blocks:
            ex = block(ex, cond)
        expression_out = torch.sigmoid(self.expression_head(ex))  # (B, T, n_expression)

        # Combine: scatter each branch's output into its ARKit channels.
        # The split is validated in __init__ to be disjoint and complete.
        B, T = audio.shape[0], audio.shape[1]
        out = lipsync_out.new_zeros(B, T, self.cfg.output_dim)
        out[..., self.lipsync_idx] = lipsync_out
        out[..., self.expression_idx] = expression_out
        return out

    def _lipsync_path_modules(self) -> tuple[nn.Module, ...]:
        """Modules whose parameters feed the lipsync output."""
        return (self.input_proj, self.shared_blocks,
                self.lipsync_blocks, self.lipsync_head)

    def train(self, mode: bool = True) -> "V3FaceModel":
        """Set train/eval mode, keeping a frozen lipsync path in eval.

        Once ``freeze_lipsync()`` has run, the frozen modules stay in eval
        mode even when the model is switched to train mode. This stops their
        dropout layers from perturbing the (frozen) lipsync output. ``eval()``
        routes through this method as well.
        """
        super().train(mode)
        # getattr: tolerate instances pickled before this attribute existed.
        if getattr(self, "_lipsync_frozen", False):
            for module in self._lipsync_path_modules():
                module.eval()
        return self

    def freeze_lipsync(self) -> int:
        """Freeze input projection + shared backbone + lipsync branch + head.

        Sets ``requires_grad=False`` on every parameter on the lipsync output
        path. It also switches those modules to eval mode, so their dropout
        is disabled. ``train()`` keeps them there on later mode changes.
        After this, lipsync output is a deterministic function of
        ``audio + cond``, and the expression branch can be retrained without
        drift in lipsync.

        Note: the expression branch consumes the shared backbone's output, so
        it will also see dropout-free backbone features while retraining.

        Returns the number of frozen parameters.
        """
        frozen = 0
        for module in self._lipsync_path_modules():
            for p in module.parameters():
                p.requires_grad_(False)
                frozen += p.numel()
        self._lipsync_frozen = True
        # requires_grad alone does not disable nn.Dropout; eval() does.
        for module in self._lipsync_path_modules():
            module.eval()
        return frozen

    @property
    def n_params(self) -> int:
        """Total number of parameters (buffers excluded)."""
        return sum(p.numel() for p in self.parameters())

    @property
    def n_trainable(self) -> int:
        """Number of parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @property
    def size_mb(self) -> float:
        """Approximate float32 parameter size in MiB (buffers/overhead excluded)."""
        return self.n_params * 4 / (1024 * 1024)
