"""Thin wrapper around LAM Audio2Expression for generating lipsync targets.

Model location: /dataset/text-to-face-se/LAM_Audio2Expression
Checkpoint:     checkpoints_황준희_lr5e-6/best.pt (relative to the model location)
"""
from __future__ import annotations

import math
import os
import sys
import tempfile
import warnings
from pathlib import Path

import librosa
import numpy as np
import torch

# Root of the LAM Audio2Expression checkout; its `models` package and
# `configs/` directory are used by LAMWrapper.
LAM_DIR = Path("/dataset/text-to-face-se/LAM_Audio2Expression")
# Default checkpoint loaded by LAMWrapper, stored inside the checkout.
LAM_CHECKPOINT = LAM_DIR.joinpath("checkpoints_황준희_lr5e-6", "best.pt")


class LAMWrapper:
    """Run a LAM Audio2Expression checkpoint on audio to produce blendshapes.

    The network is constructed with ``expression_dim=52``; inference methods
    return one 52-dimensional blendshape vector per output frame.
    """

    # Width of the identity conditioning vector; must agree with the
    # ``num_identity_classes`` the network is constructed with.
    _NUM_IDENTITY_CLASSES = 12
    # Audio is loaded/resampled to this rate before being fed to the model.
    _SAMPLE_RATE = 16000

    def __init__(self, checkpoint_path: Path | str = LAM_CHECKPOINT, device: str | None = None):
        """Build the network and load weights from ``checkpoint_path``.

        Args:
            checkpoint_path: Path (``Path`` or ``str``) to the ``.pt`` checkpoint.
            device: Torch device string. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.

        Raises:
            FileNotFoundError: If ``LAM_DIR`` or the checkpoint does not exist.
        """
        checkpoint_path = Path(checkpoint_path)
        if not LAM_DIR.exists():
            raise FileNotFoundError(f"LAM_DIR not found: {LAM_DIR}")
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"LAM checkpoint not found: {checkpoint_path}")

        # Make the LAM checkout's `models` package importable. Keep it at the
        # front of sys.path, but don't add a new duplicate per instantiation.
        lam_dir_str = str(LAM_DIR)
        if lam_dir_str in sys.path:
            sys.path.remove(lam_dir_str)
        sys.path.insert(0, lam_dir_str)

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Import and construction run with cwd = LAM_DIR because the wav2vec2
        # config below is given as a path relative to it. The original cwd is
        # always restored, even if loading fails.
        self._orig_cwd = os.getcwd()
        os.chdir(lam_dir_str)
        try:
            from models.network import Audio2Expression

            # Reproducing loader from training/precompute_lam_outputs.py
            self.model = Audio2Expression(
                pretrained_encoder_type="wav2vec",
                pretrained_encoder_path="facebook/wav2vec2-base-960h",
                wav2vec2_config_path="configs/wav2vec2_config.json",
                num_identity_classes=self._NUM_IDENTITY_CLASSES,
                identity_feat_dim=64,
                hidden_dim=512,
                expression_dim=52,
                norm_type="ln",
                use_transformer=False,
            ).to(self.device)

            # weights_only=False unpickles arbitrary objects: only point this
            # at trusted checkpoints.
            ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
            # Weights may be nested under either key, or be a bare state dict.
            state_dict = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))
            # NOTE: str.replace strips these tokens anywhere in a key, not only
            # as a prefix; kept as-is to match the training-side loader this
            # reproduces.
            new_state_dict = {
                k.replace("module.", "").replace("backbone.", ""): v
                for k, v in state_dict.items()
            }
            load_result = self.model.load_state_dict(new_state_dict, strict=False)
            self.model.eval()
        finally:
            os.chdir(self._orig_cwd)

        # strict=False would otherwise silently ignore mismatched weights.
        missing = list(getattr(load_result, "missing_keys", []))
        unexpected = list(getattr(load_result, "unexpected_keys", []))
        if missing or unexpected:
            warnings.warn(
                f"LAM checkpoint {checkpoint_path} does not match the model exactly: "
                f"{len(missing)} missing key(s), {len(unexpected)} unexpected key(s)",
                stacklevel=2,
            )

    @torch.no_grad()
    def infer_audio(self, audio_path: Path, fps: int = 30) -> np.ndarray:
        """Run LAM on an audio file.

        The file is decoded with librosa, down-mixed to mono and resampled to
        16 kHz before inference.

        Args:
            audio_path: Path to an audio file readable by librosa.
            fps: Output frame rate. The model is asked for
                ``ceil(duration_seconds * fps)`` frames.

        Returns:
            (T, 52) float32 blendshapes at `fps` frame rate
        """
        wav, sr = librosa.load(str(audio_path), sr=self._SAMPLE_RATE, mono=True)
        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).to(self.device)
        # Multiply before dividing: len(wav) * fps is an exact integer for
        # integer fps, so float round-off cannot bump the ceiling up by one.
        num_frames = math.ceil(len(wav) * fps / sr)

        input_dict = {
            "input_audio_array": wav_tensor,
            # All-zero identity conditioning vector (presumably "no particular
            # speaker" — TODO confirm against training).
            "id_idx": torch.zeros(1, self._NUM_IDENTITY_CLASSES, device=self.device),
            "time_steps": num_frames,
        }
        output = self.model(input_dict)
        return output.squeeze(0).cpu().numpy().astype(np.float32)

    def infer_wav_array(self, wav: np.ndarray, sr: int = 16000, fps: int = 30) -> np.ndarray:
        """Run LAM on an in-memory waveform.

        The array is round-tripped through a temporary WAV file so it goes
        through exactly the same ``librosa.load`` decode / mono / resample
        path as :meth:`infer_audio`. The temporary file is removed afterwards.

        Args:
            wav: Waveform samples; 2-D input is laid out as (frames, channels),
                as expected by ``soundfile.write``.
            sr: Sample rate of ``wav``. It is resampled to 16 kHz if it differs.
            fps: Output frame rate, forwarded to :meth:`infer_audio`.

        Returns:
            Same as :meth:`infer_audio`.
        """
        import soundfile as sf

        # A temporary *directory* (rather than NamedTemporaryFile) lets the file
        # be written and re-opened by path on every platform, and guarantees
        # cleanup. NOTE: soundfile's default WAV subtype is PCM_16, so float
        # input is quantized to 16 bits on the way through.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir) / "input.wav"
            sf.write(str(tmp_path), wav, sr)
            return self.infer_audio(tmp_path, fps=fps)
