from typing import Dict, List, Tuple
from einops import rearrange
from sympy import Q
from torch import Tensor, nn
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
import torch.nn.functional as F
from transformers import T5Tokenizer, UMT5EncoderModel, Wav2Vec2Processor
from mmengine.device import get_device
from diffusers import AutoencoderKLWan
from mmhug.models.custom_transformers.sapiens.heatmap_head import HeatmapHead
from mmhug.models.custom_transformers.sapiens.vit_sapiens import (
    SapiensVisionTransformer,
)
from mmhug.models.custom_transformers.slipmae.transformer_slipmae_encoder import (
    SlipmaeEncoder,
)
from mmhug.models.custom_transformers.slipmae_wanwace.transformer3d_wanvace_audiopack import (
    AudiopackWanVACETransformer3DModel,
)
from mmhug.models.custom_transformers.wav2vec2_interp.wav2vec2_interp import (
    Wav2Vec2InterpModel,
)
from mmhug.registry import HF_MODELS, MODELS
from mmhug.schedulers.timestep_samplers.dist_uniform_timestepsampler import (
    DistUniformTimestepSampler,
)
from mmhug.schedulers.utils.timestep_utils import get_sigmas
from mmhug.trainers.base_trainer_model import BaseTrainerModel
from mmhug.trainers.trainer_slipmae.utils import (
    keepface_mask_prob_from_heatmap,
    keepface_mask_prob_from_kpts,
)
from mmhug.trainers.trainer_wan.trainer_wan22_ti2v import Wan22TI2VTrainer
from mmhug.utils.dtype_utils import dtype_from_str
from mmhug.datasets.utils.kpt2face import _fallback_face_indices_ex_ear_308


def fuse_sample(
    a: torch.Tensor, b: torch.Tensor, p: float, generator: torch.Generator | None = None
):
    assert a.shape == b.shape and a.ndim == 3
    B, N, C = a.shape
    mask = torch.rand((B, 1, 1), device=a.device, generator=generator) < p  # [B,1,1]
    out = torch.where(mask, a, b)  # [B,N,C]；mask 会按广播规则扩展
    return out, mask.squeeze(-1).squeeze(-1)


from torch.nn.attention import sdpa_kernel, SDPBackend
import torch.nn.functional as F

# Global, import-time side effect: restrict scaled_dot_product_attention to the
# MATH backend for the whole process. This is meant to run in the training entry
# point (after importing the model, before training starts).
# NOTE(review): WanVaceSlipmaeTrainerV5.__init__ later calls
# torch.backends.cuda.enable_flash_sdp(True) / enable_mem_efficient_sdp(True).
# Once a trainer is constructed, this MATH-only restriction therefore appears to
# be (at least partly) overridden. Confirm which configuration is intended.
SDPA_CTX = sdpa_kernel(
    SDPBackend.MATH
)  # Alternatively: [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]
SDPA_CTX.__enter__()  # Never exited here; remember to call __exit__() when training ends.


@MODELS.register_module()
class WanVaceSlipmaeTrainerV5(Wan22TI2VTrainer):
    """Wan-VACE trainer driven by audio (Wav2Vec2) or visual (Slipmae) motion features.

    The following components are frozen and only produce conditions:
    - the VAE;
    - the UMT5 text encoder;
    - the Wav2Vec2 audio encoder and its adapter;
    - the Slipmae encoder;
    - the optional Sapiens keypoint model.

    Trainable parameters live in ``self.transformer``: all of it by default, or
    only its audio and VACE sub-modules when ``vace_only=True``.
    ``forward_loss`` trains it with a flow-matching objective
    (target = noise - latents).
    """

    def __init__(
        self,
        vae,
        transformer,
        # text
        tokenizer,
        text_encoder,
        # audio
        audio_processor,
        audio_encoder,
        audio_adapter,
        # vocal motion encoding
        slipmae_encoder,
        scheduler,
        # keypoint detection
        raw_sapiens: SapiensVisionTransformer = None,
        heatmap_head: HeatmapHead = None,
        # training settings
        prompt_drop_rate: float = 0.1,
        audio_drop_rate: float = 0.0,
        first_frame_drop_rate: float = 0.1,
        use_visual_rate: float = 0.5,
        vace_only: bool = False,
        init_cfg=None,
    ):
        """Build every sub-model from its registry config and freeze all but the transformer.

        Args:
            vae: HF_MODELS config of the Wan VAE. Its ``from_pretrained.device_map``
                is overwritten in place with the current device.
            transformer: HF_MODELS config of the audio-pack Wan-VACE transformer.
            tokenizer: config of the T5 tokenizer.
            text_encoder: config of the UMT5 encoder. It is placed on the
                current device.
            audio_processor: config of the Wav2Vec2 processor.
            audio_encoder: config of the Wav2Vec2 encoder.
            audio_adapter: config of the module applied to the concatenated
                Wav2Vec2 hidden states.
            slipmae_encoder: config of the Slipmae visual motion encoder.
            scheduler: config of the flow-matching Euler scheduler.
            raw_sapiens: optional Sapiens backbone config. When given, face
                regions can be estimated from heatmaps for batches without
                keypoints.
            heatmap_head: Sapiens heatmap head config; required when
                ``raw_sapiens`` is given.
            prompt_drop_rate: stored as ``self.prompt_drop_rate``.
            audio_drop_rate: probability of zeroing the vocal condition in training.
            first_frame_drop_rate: probability of replacing the first frame with
                the first reference image in training.
            use_visual_rate: probability of conditioning on Slipmae visual
                features instead of audio features.
            vace_only: train only the transformer's audio and VACE sub-modules.
            init_cfg: forwarded to ``BaseTrainerModel``.
        """
        # NOTE(review): these calls set process-wide SDPA backend flags. They
        # re-enable the flash / memory-efficient kernels, which looks like it
        # undoes the MATH-only `SDPA_CTX` entered at module import time.
        # Confirm which configuration is intended.
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(True)

        device = get_device()
        # The base trainer is initialised directly (Wan22TI2VTrainer.__init__ is
        # not called); every sub-model is built below.
        BaseTrainerModel.__init__(self, init_cfg=init_cfg)

        # ─── VAE ───
        vae["from_pretrained"]["device_map"] = device
        self.vae: AutoencoderKLWan = HF_MODELS.build(vae)
        # Per-channel latent mean / std from the VAE config, shaped to broadcast
        # over [B, C, T, H, W].
        z_dim = self.vae.config.z_dim
        self.register_buffer(
            "latents_mean",
            torch.tensor(self.vae.config.latents_mean, dtype=torch.float32).view(
                1, z_dim, 1, 1, 1
            ),
        )
        self.register_buffer(
            "latents_std",
            torch.tensor(self.vae.config.latents_std, dtype=torch.float32).view(
                1, z_dim, 1, 1, 1
            ),
        )

        # ─── Wan-VACE transformer ───
        self.transformer: AudiopackWanVACETransformer3DModel = HF_MODELS.build(
            transformer
        )

        # ─── Text encoder ───
        self.tokenizer: T5Tokenizer = HF_MODELS.build(tokenizer)
        text_encoder["from_pretrained"]["device_map"] = device
        self.text_encoder: UMT5EncoderModel = HF_MODELS.build(text_encoder)

        # ─── Audio encoder ───
        self.audio_encoder: Wav2Vec2InterpModel = HF_MODELS.build(audio_encoder)
        self.audio_processor: Wav2Vec2Processor = HF_MODELS.build(audio_processor)

        self.use_audio_adapter = True
        self.audio_adapter = HF_MODELS.build(audio_adapter)

        # ─── Visual (vocal motion) encoder and optional keypoint model ───
        self.realtime_pose_estimate = False
        if raw_sapiens is not None:
            assert heatmap_head is not None
            self.raw_sapiens: SapiensVisionTransformer = HF_MODELS.build(raw_sapiens)
            self.heatmap_head: HeatmapHead = HF_MODELS.build(heatmap_head)
            self.realtime_pose_estimate = True

        self.slipmae_encoder: SlipmaeEncoder = HF_MODELS.build(slipmae_encoder)

        # ─── Scheduler and timestep sampler ───
        self.scheduler: FlowMatchEulerDiscreteScheduler = HF_MODELS.build(scheduler)
        self.timestep_sampler = DistUniformTimestepSampler(
            1000, uniform_sampling=True, start_num_idx=0
        )

        # Freeze everything, then re-enable only the trainable transformer parts.
        self.eval()
        self.requires_grad_(False)
        self.vace_only = vace_only
        self.enable_transformer_train()

        self.prompt_drop_rate = prompt_drop_rate
        self.use_visual_rate = use_visual_rate
        self.audio_drop_rate = audio_drop_rate
        self.first_frame_drop_rate = first_frame_drop_rate

        self.to(torch.float32)
        # Surface the result of the meta-parameter check instead of discarding it.
        no_meta, meta_params = self.check_no_meta_params()
        if not no_meta:
            import warnings

            warnings.warn(
                f"{type(self).__name__} still has parameters on the 'meta' device "
                f"after construction: {meta_params}",
                stacklevel=2,
            )

        self._collect_trainable_params()
        torch.cuda.empty_cache()

    def train(self, mode: bool = True):
        """Set train/eval mode while keeping the frozen sub-modules in eval mode.

        ``nn.Module.train(True)`` recursively switches every sub-module to
        training mode. That would re-activate dropout and any other
        training-only behaviour inside the frozen encoders.

        After delegating to the parent, each direct child without trainable
        parameters is therefore put back into eval mode. The trainable
        transformer parts are then re-enabled.

        Returns:
            ``self``, following the ``nn.Module.train`` contract.
        """
        super().train(mode)
        if mode:
            for child in self.children():
                if not any(param.requires_grad for param in child.parameters()):
                    child.eval()
            self.enable_transformer_train()
        return self

    def enable_transformer_train(self):
        """Put the trainable part of the transformer in train mode with gradients.

        With ``vace_only=False`` the whole transformer is trained. Otherwise the
        transformer is frozen and in eval mode, except for its ``audio_pack``,
        ``audio_proj_in``, ``vace_blocks`` and ``vace_patch_embedding``
        sub-modules.
        """
        if not self.vace_only:
            self.transformer.train()
            self.transformer.requires_grad_(True)
            return

        self.transformer.eval()
        self.transformer.requires_grad_(False)
        for name in (
            "audio_pack",
            "audio_proj_in",
            "vace_blocks",
            "vace_patch_embedding",
        ):
            submodule = getattr(self.transformer, name)
            submodule.train()
            submodule.requires_grad_(True)

    def check_no_meta_params(module: nn.Module) -> Tuple[bool, List[str]]:
        """Report parameters that still live on the ``meta`` device.

        Defined without ``self``: when called as ``self.check_no_meta_params()``
        the instance is bound to ``module``. The parameter name is kept for
        backward compatibility.

        Returns:
            ``(True, [])`` when every parameter is materialised, otherwise
            ``(False, names)`` listing the offending parameter names.
        """
        meta_params = [
            name
            for name, param in module.named_parameters()
            if param.device.type == "meta"
        ]
        return (len(meta_params) == 0, meta_params)

    def forward_loss(self, batch: Dict[str, Dict[str, Tensor]]) -> Dict[str, Tensor]:
        """Compute the flow-matching loss for one training batch.

        Batch entries read here:

        - ``video``: [B, 1 + T, C, H, W]. Frame 0 is the first-frame condition;
          the remaining T frames are generated. T must equal
          ``1 + k * vae.config.scale_factor_temporal``.
        - ``ref_img``: [B, N, C, H, W] reference images.
        - ``mask``: [B, 1 + T, H, W]. 1 marks pixels to generate, 0 marks
          pixels to keep.
        - ``keypoint`` (optional): [B, 1 + T, K, C], aligned with ``video``.
        - ``audio``: [B, T_audio] waveforms.
        - ``audio_metadata["sr"]``: sampling rate. The first sample's value is
          used for the whole batch; 16000 is assumed when the key is missing.
        - ``caption``: text prompts.

        Returns:
            ``{"loss": loss}``: the MSE between the predicted and the target
            velocity (noise - latents).
        """
        # ─── 0. Load inputs ──────────────────────────────────────────────────
        video = batch["video"]  # [B, 1 + T, C, H, W]
        first_frame = video[:, :1]
        video = video[:, 1:]

        scale_t = self.vae.config.scale_factor_temporal
        assert (video.shape[1] - 1) % scale_t == 0, (
            f"video length {video.shape[1]} (excluding the first frame) must be "
            f"1 + k * {scale_t}"
        )
        ref_img = batch["ref_img"]
        num_ref_img = ref_img.shape[1]
        B, T = video.shape[:2]

        video = rearrange(video, "b t c h w -> b c t h w")
        first_frame = rearrange(first_frame, "b t c h w -> b c t h w")
        ref_img = rearrange(ref_img, "b n c h w -> (b n) c 1 h w")

        # Randomly replace the first frame with the first reference image. One
        # draw is made for the whole batch. Only reference 0 is taken so the
        # batch dimension stays B; the flattened `ref_img` has B * N frames and
        # would break the latent concatenation below whenever N > 1.
        if self.training and self.first_frame_drop_rate > 0:
            if torch.rand(1).item() < self.first_frame_drop_rate:
                first_frame = rearrange(
                    batch["ref_img"][:, :1], "b t c h w -> b c t h w"
                )

        keypoint = batch.get("keypoint", None)
        if keypoint is not None:
            keypoint = keypoint[:, 1:]  # drop the first-frame keypoints, like `video`

        # [B, 1 + T, H, W] -> [B, 1, T, H, W]; 0 = keep, 1 = generate
        mask = batch["mask"].unsqueeze(1)[:, :, 1:]
        audio = batch["audio"]  # [B, T_audio]
        sr = int(batch["audio_metadata"].get("sr", [16000] * B)[0])
        captions = batch["caption"]

        # ─── 1. Video & reference image encoding ─────────────────────────────
        latents = self.encode_video(video)  # b c t h w
        ref_latents = rearrange(
            self.encode_video(ref_img), "(b n) c 1 h w -> b c n h w", n=num_ref_img
        )
        first_frame_latents = self.encode_video(first_frame)
        # Reference latents followed by the first-frame latent, prepended to the
        # video latents along the time axis.
        ref_latents = torch.cat([ref_latents, first_frame_latents], dim=2)
        latents = torch.concat([ref_latents, latents], dim=2)

        # ─── 2. Vocal feature encoding ───────────────────────────────────────
        # With probability `use_visual_rate`, condition on visual motion
        # features (Slipmae); otherwise condition on audio features
        # (Wav2Vec2 + adapter).
        use_visual = (
            self.use_visual_rate > 0 and torch.rand(1).item() < self.use_visual_rate
        )
        if use_visual:
            vocal_hidden_states = self.encode_visual(video, keypoint)
        else:
            vocal_hidden_states = self.encode_audio(audio, video_length=T, sr=sr)

        # Randomly zero the vocal condition (training only, like the first-frame drop).
        if self.training and self.audio_drop_rate > 0:
            if torch.rand(1).item() < self.audio_drop_rate:
                vocal_hidden_states = torch.zeros_like(vocal_hidden_states)

        # ─── 3. Text encoding ────────────────────────────────────────────────
        # NOTE(review): `prompt_drop_rate` is not applied in this method.
        # Presumably the inherited `encode_prompt` reads `self.prompt_drop_rate`;
        # confirm, otherwise captions are never dropped.
        text_states = self.encode_prompt(captions)

        # ─── 4. VACE conditions ──────────────────────────────────────────────
        vace_condition_latents = self.prepare_vace_condition(video, mask, ref_latents)

        # ─── 5. Timestep sampling ────────────────────────────────────────────
        indices = self.timestep_sampler(
            B, device=self.scheduler.timesteps.device
        ).long()
        timesteps = self.scheduler.timesteps[indices].to(device=latents.device)

        # ─── 6. Add noise: x_s = (1 - s) * x_0 + s * noise ───────────────────
        noise = torch.randn_like(latents)
        sigmas = get_sigmas(
            self.scheduler, timesteps, n_dim=latents.ndim, dtype=latents.dtype
        )
        noisy_latents = (1 - sigmas) * latents + sigmas * noise
        targets = noise - latents  # velocity target

        # ─── 7. Forward ──────────────────────────────────────────────────────
        model_pred = self.transformer(
            hidden_states=noisy_latents,
            timestep=timesteps,
            encoder_hidden_states=text_states,
            audio_hidden_states=vocal_hidden_states,
            control_hidden_states=vace_condition_latents,
        ).sample

        # ─── 8. Loss ─────────────────────────────────────────────────────────
        # NOTE(review): the loss also covers the reference / first-frame latent
        # frames at the start of the time axis, although those frames are
        # conditions rather than content to generate. Confirm whether they
        # should be excluded (e.g. slice `[:, :, ref_latents.shape[2]:]`).
        loss = F.mse_loss(model_pred.float(), targets.float())
        return {"loss": loss}

    @torch.no_grad()
    def encode_audio(self, audio: Tensor, video_length: int, sr: int) -> Tensor:
        """
        Encode raw audio waveforms into hidden representations aligned with the video.

        Steps:
        1. ``self.audio_processor`` turns the waveforms into model-ready input values.
        2. ``self.audio_encoder`` encodes them, with ``seq_len=video_length``
           requesting an output aligned to the video frame count.
        3. The last hidden state and all per-layer hidden states are concatenated
           along the feature axis and passed through ``self.audio_adapter``.

        Args:
            audio (Tensor):
                A batch of raw audio waveforms of shape [B, T_audio].
            video_length (int):
                The length of the corresponding video in frames (used for alignment).
            sr (int):
                Sampling rate of the audio (e.g., 16000 Hz).

        Returns:
            Tensor:
                The audio adapter output, expected to be [B, L, D] with L aligned
                to ``video_length`` when the encoder supports it.
        """
        B, N = audio.shape

        audio_inputs = self.audio_processor(
            audio.float(), sampling_rate=sr, return_tensors="pt"
        ).input_values.to(
            device=audio.device,
            dtype=dtype_from_str(audio.dtype),
        )
        # The processor output must hold exactly B * N values; `view` raises otherwise.
        audio_inputs = audio_inputs.view(B, N)

        audio_encoder_output = self.audio_encoder(
            audio_inputs, seq_len=int(video_length), output_hidden_states=True
        )

        # Concatenate the last hidden state with every layer's hidden states.
        audio_states = torch.cat(
            (audio_encoder_output.last_hidden_state,)
            + audio_encoder_output.hidden_states,
            dim=-1,
        )
        return self.audio_adapter(audio_states)

    @torch.no_grad()
    def mask2gray(self, video: Tensor, mask: Tensor) -> Tensor:
        """
        Set the masked region to gray.

        Gray (127.5 in [0, 255]) is 0 in the [-1, 1] range, so this simply zeroes
        the masked pixels.

        Args:
            video (Tensor):
                Video tensor of shape [B, C, T, H, W]. In range [-1, 1].
            mask (Tensor):
                Binary mask tensor of shape [B, 1, T, H, W]. 0 for unmasked region,
                1 for masked region.

        Returns:
            Tensor:
                Video tensor of shape [B, C, T, H, W]. In range [-1, 1].
        """
        return video * (1 - mask)

    @torch.no_grad()
    def prepare_vace_condition(
        self, vace_video: Tensor, mask: Tensor, ref_latents: Tensor
    ) -> Tensor:
        """Build the VACE control latents.

        Args:
            vace_video: [B, C, T, H, W] pixel video in [-1, 1].
            mask: [B, 1, T, H, W]; 1 marks the region to generate.
            ref_latents: [B, z, N, h, w] latents of the conditioning frames. They
                are prepended in time to the video latents.

        Returns:
            Tensor of shape [B, 2*z + s*s, N + t, h, w], where
            ``s = vae.config.scale_factor_spatial`` and ``t`` is the latent
            length of the video. The channels are, in order: the inactive
            latents, the reactive latents, and the space-to-depth mask.
        """
        num_ref_img = ref_latents.shape[2]
        scale_s = self.vae.config.scale_factor_spatial
        scale_t = self.vae.config.scale_factor_temporal

        # Part 1: video latents, split into "inactive" (kept) and "reactive"
        # (generated) regions as in Wan-VACE. Masked pixels are first set to gray
        # (0). For a binary mask the reactive input is therefore all zeros.
        vace_video = self.mask2gray(vace_video, mask)
        condition_region = vace_video * (1 - mask)  # inactive
        mask_region = vace_video * mask  # reactive
        # b 2c t h w
        vace_video_latents = torch.concat(
            (self.encode_video(condition_region), self.encode_video(mask_region)),
            dim=1,
        )

        # The conditioning frames have no reactive part: zero-pad them to 2c
        # channels and prepend them in time.
        ref_latents = torch.concat((ref_latents, torch.zeros_like(ref_latents)), dim=1)
        vace_video_latents = torch.concat((ref_latents, vace_video_latents), dim=2)

        # Part 2: mask latents. Fold each s x s pixel patch into channels...
        mask = rearrange(
            mask,
            "b 1 t (h p) (w q) -> b (p q) t h w",
            p=scale_s,
            q=scale_s,
        )
        # ...then resample time to ceil(t / scale_t) latent frames.
        vace_mask_latents = F.interpolate(
            mask.float(),
            size=(
                (mask.shape[2] + scale_t - 1) // scale_t,
                mask.shape[3],
                mask.shape[4],
            ),
            mode="nearest-exact",
        )
        # Conditioning frames are never masked: prepend all-zero mask frames.
        ref_mask = vace_mask_latents.new_zeros(
            (
                vace_mask_latents.shape[0],
                vace_mask_latents.shape[1],
                num_ref_img,
                vace_mask_latents.shape[3],
                vace_mask_latents.shape[4],
            )
        )
        vace_mask_latents = torch.concat((ref_mask, vace_mask_latents), dim=2)

        return torch.concat((vace_video_latents, vace_mask_latents), dim=1)

    @torch.no_grad()
    def encode_visual(self, video, kpt=None, batch_size=32):
        """Encode per-frame visual motion features with the Slipmae encoder.

        Args:
            video: [B, C, T, H, W] frames.
            kpt: optional keypoints [B, T, K, C]. When given, the face keep-mask
                probability is derived from them. Otherwise it is derived from
                Sapiens heatmaps, which requires ``raw_sapiens`` / ``heatmap_head``.
            batch_size: number of frames passed to the encoders at once.

        Returns:
            [B, T, D]: the last ``prompt_tokens`` entry of each frame.

        Raises:
            RuntimeError: if ``kpt`` is None and no Sapiens model was configured.
        """
        if kpt is None and not self.realtime_pose_estimate:
            raise RuntimeError(
                "encode_visual needs keypoints when raw_sapiens/heatmap_head "
                "were not configured"
            )
        B = video.shape[0]
        frames = rearrange(video, "b c t h w -> (b t) c h w")
        if kpt is not None:
            kpt = rearrange(kpt, "b t k c -> (b t) k c")
        face_indices = _fallback_face_indices_ex_ear_308()

        features = []
        for start in range(0, frames.shape[0], batch_size):
            frame_batch = frames[start : start + batch_size]
            if kpt is not None:
                keepface_mask_prob = keepface_mask_prob_from_kpts(
                    kpt[start : start + batch_size][:, face_indices],
                    img_hw=frame_batch.shape[-2:],
                )
            else:
                heatmap_batch = self.get_sapiens_heatmap(frame_batch)
                keepface_mask_prob = keepface_mask_prob_from_heatmap(
                    heatmap_batch[:, face_indices]
                )

            encoder_output = self.slipmae_encoder(
                frame_batch, do_mask=True, mask_prob=keepface_mask_prob
            )
            # (b t) c: last prompt token of each frame
            features.append(encoder_output["prompt_tokens"][:, -1])

        return rearrange(torch.cat(features, dim=0), "(b t) c -> b t c", b=B)

    @torch.no_grad()
    def get_sapiens_heatmap(self, imgs: Tensor, max_batch_size: int = 32) -> Tensor:
        """
        Get Sapiens heatmaps for a batch of images, at most ``max_batch_size`` at a time.

        Args:
            imgs (Tensor):
                A batch of imgs: [B, C, H, W]
            max_batch_size (int):
                Maximum number of images sent to the Sapiens model per call.

        Returns:
            heatmap (Tensor):
                A batch of heatmap: [B, N, H, W]
        """
        if imgs.shape[0] > max_batch_size:
            # Forward `max_batch_size` so the caller's limit applies to every chunk.
            return torch.cat(
                [
                    self.get_sapiens_heatmap(chunk, max_batch_size)
                    for chunk in imgs.split(max_batch_size)
                ],
                dim=0,
            )
        feats = self.raw_sapiens(imgs, out_type="featmap").last_hidden_state
        heatmap, _ = self.heatmap_head(feats, decode_kpt=False)
        return heatmap
