import math
import os
import shutil
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from diffusers.pipelines import DiffusionPipeline
from diffusers.pipelines.wan.pipeline_wan_vace import WanVACEPipeline
from einops import rearrange
import librosa
import numpy as np
import torch
from decord import VideoReader
from tqdm import tqdm
import random
from mmengine import Config
from mmengine.device import get_device
from mmengine.runner import load_checkpoint
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    AutoencoderKLWan,
    UniPCMultistepScheduler,
)
from transformers import (
    BitImageProcessor,
    T5Tokenizer,
    UMT5EncoderModel,
    Wav2Vec2Processor,
)
from mmhug.datasets.transforms.keypoint2mask import sapiens2mask
from mmhug.models.custom_transformers.sapiens.heatmap_head import HeatmapHead

from mmhug.models.custom_transformers.sapiens.vit_sapiens import (
    SapiensVisionTransformer,
)
from mmhug.models.custom_transformers.slipmae.transformer_slipmae_encoder import (
    SlipmaeEncoder,
)
from mmhug.models.custom_transformers.slipmae_wanwace.transformer3d_wanvace_audiopack import (
    AudiopackWanVACETransformer3DModel,
)
from torchvision.io import write_video
from mmhug.models.custom_transformers.unitalker.audio_adapter import AudioAdapter
from mmhug.models.custom_transformers.wav2vec2_interp.wav2vec2_interp import (
    Wav2Vec2InterpModel,
)
from mmhug.registry import MODELS

from diffusers.utils import BaseOutput
from mmhug.trainers.trainer_wanvace_slipmae.trainer_wanvace_slipmae_v5 import (
    WanVaceSlipmaeTrainerV5,
)

from mmhug.utils.io import merge_video_audio
import torch.nn.functional as F

from mmhug.trainers.trainer_slipmae.utils import (
    keepface_mask_prob_from_heatmap,
    keepface_mask_prob_from_kpts,
)

from mmhug.datasets.utils.kpt2face import _fallback_face_indices_ex_ear_308


def video_tensor2numpy(
    video_tensor: torch.Tensor, keep_batch: bool = False
) -> np.ndarray:
    """Convert a ``(B, C, T, H, W)`` video tensor to a channels-last float32 array.

    Args:
        video_tensor: Video laid out as ``(B, C, T, H, W)``; any device/dtype.
        keep_batch: When ``False`` (default) the leading batch axis is squeezed
            away, which requires ``B == 1``.

    Returns:
        ``(B, T, H, W, C)`` if ``keep_batch`` else ``(T, H, W, C)``, float32,
        on the CPU, detached from the autograd graph.
    """
    as_float = video_tensor.detach().float().cpu()
    channels_last = as_float.permute(0, 2, 3, 4, 1).numpy()
    return channels_last if keep_batch else np.squeeze(channels_last, axis=0)


@dataclass
class WanvaceSlipmaeOutput(BaseOutput):
    """Outputs of the Wan-VACE SlipMAE pipeline.

    ``BaseOutput`` subclasses are meant to be dataclasses: ``BaseOutput`` fills
    its dict entries from the dataclass fields in ``__post_init__``. Dataclasses
    require fields without defaults to come before fields with defaults, so the
    optional fields are declared last. Construct with keyword arguments.

    Attributes:
        pred_video: Generated frames, channels-last ``(T, H, W, C)``.
        ref_video: Reference frames, channels-last ``(N, H, W, C)``.
        input_video: Source frames, channels-last ``(T, H, W, C)``.
        masks: Inpainting masks, ``(T, 1, H, W)``; optional.
        masked_input_video: Source frames with the masked region blanked,
            channels-last ``(T, H, W, C)``; optional.
    """

    pred_video: np.ndarray
    ref_video: np.ndarray
    input_video: np.ndarray
    masks: Optional[np.ndarray] = None
    masked_input_video: Optional[np.ndarray] = None


class WanvaceSlipmaePipelineV5(WanVACEPipeline):
    """Chunked Wan-VACE inference that re-generates the masked lower face of a video.

    For every chunk of ``num_frames_per_chunk`` frames the pipeline:

    1. masks the lower face of the source frames (Sapiens keypoints ->
       ``sapiens2mask``);
    2. builds VACE control latents from the masked frames, reference frames
       randomly sampled from the source video, and the last frame generated for
       the previous chunk;
    3. denoises with the transformer. The transformer is conditioned on the text
       prompt and on a per-frame driving signal: wav2vec audio features, or
       SlipMAE motion tokens of a driving video.
    """

    # Sample rate at which audio is loaded, framed and fed to the audio
    # processor/encoder.
    AUDIO_SAMPLE_RATE = 16000

    def __init__(
        self,
        vae: AutoencoderKLWan,
        transformer: AudiopackWanVACETransformer3DModel,
        # text
        tokenizer: T5Tokenizer,
        text_encoder: UMT5EncoderModel,
        # audio
        audio_processor: Wav2Vec2Processor,
        audio_encoder: Wav2Vec2InterpModel,
        audio_adapter: AudioAdapter,
        slipmae_encoder: SlipmaeEncoder,
        # video processor
        video_processor: BitImageProcessor,
        scheduler: FlowMatchEulerDiscreteScheduler,
        # face keypoint detection
        raw_sapiens: SapiensVisionTransformer,
        heatmap_head: HeatmapHead,
        dtype=torch.bfloat16,
    ):
        # WanVACEPipeline.__init__ is bypassed because this pipeline registers a
        # different set of components.
        DiffusionPipeline.__init__(self)
        self.torch_device = get_device()
        self.torch_dtype = dtype

        def to_pipeline_device(module):
            # Move a torch module onto the pipeline device/dtype.
            return module.to(device=self.torch_device, dtype=self.torch_dtype)

        self.register_modules(
            vae=to_pipeline_device(vae),
            transformer=to_pipeline_device(transformer),
            tokenizer=tokenizer,
            text_encoder=to_pipeline_device(text_encoder),
            audio_processor=audio_processor,
            audio_encoder=to_pipeline_device(audio_encoder),
            audio_adapter=to_pipeline_device(audio_adapter),
            raw_sapiens=to_pipeline_device(raw_sapiens),
            heatmap_head=to_pipeline_device(heatmap_head),
            slipmae_encoder=to_pipeline_device(slipmae_encoder),
            scheduler=scheduler,
            video_processor=video_processor,
        )

        self.vae_scale_factor_temporal = (
            2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
        )
        self.vae_scale_factor_spatial = (
            2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        )

    # ------------------------------------------------------------------ I/O

    def load_audio(
        self, audio: Union[str, np.ndarray, torch.Tensor], sr: int = 16000
    ) -> torch.Tensor:
        """Load audio as a waveform tensor sampled at ``AUDIO_SAMPLE_RATE``.

        Args:
            audio: A file path (decoded by librosa as mono at
                ``AUDIO_SAMPLE_RATE``), or a waveform array/tensor.
            sr: Sample rate of ``audio`` when it is an array/tensor; such
                waveforms are resampled when ``sr != AUDIO_SAMPLE_RATE``.
                Ignored for file paths.

        Returns:
            The waveform as a ``torch.Tensor``.

        Raises:
            TypeError: If ``audio`` is not a path, ndarray or tensor.
        """
        target_sr = self.AUDIO_SAMPLE_RATE
        if isinstance(audio, str):
            audio = librosa.load(audio, sr=target_sr, mono=True)[0]
        elif sr != target_sr and isinstance(audio, (np.ndarray, torch.Tensor)):
            if isinstance(audio, torch.Tensor):
                audio = audio.detach().float().cpu().numpy()
            audio = librosa.resample(
                np.asarray(audio, dtype=np.float32), orig_sr=sr, target_sr=target_sr
            )

        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if not isinstance(audio, torch.Tensor):
            raise TypeError(
                f"audio must be a path, np.ndarray or torch.Tensor, got {type(audio)}"
            )
        return audio

    def load_video(self, video: Union[str, np.ndarray]) -> torch.Tensor:
        """Decode (if needed) and preprocess a video with ``video_processor``.

        Args:
            video: A file path (decoded with decord, all frames) or a frame
                array (presumably ``(T, H, W, C)`` as decord returns — confirm
                for array inputs).

        Returns:
            A CPU tensor stacking the processor's per-frame ``pixel_values``,
            i.e. ``(T, C, H, W)`` for an image processor.

        Raises:
            TypeError: If ``video`` is neither a path nor an ndarray.
        """
        if isinstance(video, str):
            vr = VideoReader(video)
            video = vr.get_batch(range(len(vr))).asnumpy()
        if not isinstance(video, np.ndarray):
            raise TypeError(f"video must be a path or np.ndarray, got {type(video)}")
        pixel_values = self.video_processor(video).pixel_values
        # Stacking first avoids the slow tensor construction from a list of arrays.
        return torch.from_numpy(np.stack(pixel_values))

    def loop_video(self, video, keypoints, masks, tgt_length):
        """Extend video/keypoints/masks to ``tgt_length`` frames by ping-pong looping.

        Even repetitions are played forward and odd ones reversed along dim 0;
        the concatenation is truncated to ``tgt_length``.

        Raises:
            ValueError: If ``video`` has no frames.
        """
        if len(video) == 0:
            raise ValueError("Cannot loop an empty video")
        num_loops = math.ceil(tgt_length / len(video))

        def ping_pong(x):
            pieces = [
                x if loop_idx % 2 == 0 else torch.flip(x, dims=(0,))
                for loop_idx in range(num_loops)
            ]
            return torch.cat(pieces, dim=0)[:tgt_length]

        return ping_pong(video), ping_pong(keypoints), ping_pong(masks)

    @staticmethod
    def _round_up_to_chunk(num_frames: int, chunk_length: int) -> int:
        """Smallest positive multiple of ``chunk_length`` that is >= ``num_frames``."""
        return max(-(-num_frames // chunk_length), 1) * chunk_length

    # ------------------------------------------------------------- inference

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        video: Union[str, np.ndarray],
        audio: Union[str, np.ndarray, torch.Tensor] = None,
        driven_video: Union[str, np.ndarray] = None,
        num_frames_per_chunk: int = 17,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        negative_prompt: str = "",
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 512,
        sr=16000,
        fps=25,
        num_ref_img=1,
    ):
        """Re-generate the lower face of ``video`` driven by audio or a driving video.

        Args:
            prompt: Text prompt.
            video: Source video (path or frame array).
            audio: Driving audio (path, or waveform at ``sr``). Mutually
                exclusive with ``driven_video``.
            driven_video: Driving video (path or frame array). Mutually
                exclusive with ``audio``.
            num_frames_per_chunk: Frames generated per denoising pass.
            num_inference_steps: Scheduler steps per chunk.
            guidance_scale: CFG scale; guidance is active when it is > 1.
            negative_prompt: Negative prompt used for CFG.
            num_videos_per_prompt: Forwarded to ``encode_prompt``.
            max_sequence_length: Forwarded to ``encode_prompt``.
            sr: Sample rate of an array/tensor ``audio``.
            fps: Video frame rate used to align audio samples to frames.
            num_ref_img: Number of reference frames sampled per chunk.

        Returns:
            WanvaceSlipmaeOutput with:

            * ``pred_video``: ``(T, H, W, C)`` uint8.
            * ``masks``: ``(T, 1, H, W)`` float.
            * ``input_video`` and ``masked_input_video``: ``(T, H, W, C)``
              float scaled to [0, 255].
            * ``ref_video``: ``(N, H, W, C)`` float in the processor's
              normalized range.

            ``T`` is the length of the driving signal.

        Raises:
            ValueError: If both or neither of ``audio``/``driven_video`` are given.
        """
        self._guidance_scale = guidance_scale

        if audio is not None and driven_video is not None:
            raise ValueError("Only one of audio and driven_video can be provided")

        video = self.load_video(video).to(
            device=self.torch_device, dtype=self.torch_dtype
        )

        vocal_feats, padded_length, tgt_length = self._encode_driving_signal(
            audio, driven_video, sr=sr, fps=fps, chunk_length=num_frames_per_chunk
        )

        # Source frames beyond the padded length are never used.
        if len(video) > padded_length:
            video = video[:padded_length]

        keypoints, masks = self.sapiens_face_det(video)

        # Loop the source video (and its keypoints/masks) to padded_length.
        video, keypoints, masks = self.loop_video(
            video, keypoints, masks, padded_length
        )

        print(f"Padded video length: {padded_length}")
        print(f"Final generated video length: {tgt_length}")

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=self.torch_device,
            dtype=self.torch_dtype,
        )

        ref_img = self._sample_ref_frames(video, tgt_length, num_ref_img)
        # The first chunk has no previously generated frame, so the first
        # reference frame stands in for it. Shape [1, c, 1, h, w] (a single
        # frame, whatever num_ref_img is).
        last_frame = ref_img[:1]
        used_ref_img = ref_img

        num_inferences = math.ceil(padded_length / num_frames_per_chunk)
        pred_video = []
        for i in tqdm(range(num_inferences), desc="Doing inference..."):
            used_ref_img = ref_img
            ref_latents = rearrange(
                self.encode_video(ref_img), "t c 1 h w -> 1 c t h w"
            )
            # Draw the next chunk's references here, which keeps the global RNG
            # order identical to previous versions of this pipeline.
            ref_img = self._sample_ref_frames(video, tgt_length, num_ref_img)
            # [1, c, 1, h, w]
            last_frame_latents = self.encode_video(last_frame)
            # [1, c, num_ref_img + 1, h, w]
            cond_latents = torch.cat([ref_latents, last_frame_latents], dim=2)

            chunk = slice(i * num_frames_per_chunk, (i + 1) * num_frames_per_chunk)
            vocal_feat = vocal_feats[:, chunk]
            # [1, t, h, w]
            mask = masks[chunk].unsqueeze(0).to(self.torch_device)
            v = rearrange(video[chunk], "t c h w -> 1 c t h w").to(
                self.torch_device, self.torch_dtype
            )

            latents = self._denoise_chunk(
                v,
                mask,
                cond_latents,
                vocal_feat,
                prompt_embeds,
                negative_prompt_embeds,
                num_ref_img=num_ref_img,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
            # Drop the conditioning latent frames before decoding. [1, c, t, h, w]
            pred_video_chunk = self.decode_video(latents[:, :, num_ref_img + 1 :])
            last_frame = pred_video_chunk[:, :, -1:]
            # t h w c
            pred_video.append(video_tensor2numpy(pred_video_chunk))

        # t h w c in range [-1, 1]. Decoded values may overshoot that range, so
        # clip before the uint8 cast, which would otherwise wrap around.
        pred_video = np.concatenate(pred_video, axis=0)[:tgt_length]
        pred_video = np.clip((pred_video * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)

        masks = rearrange(masks[:tgt_length], "t h w -> t 1 h w").float().cpu().numpy()

        input_video = video[:tgt_length].float().cpu().numpy()
        masked_input_video = (1 - masks) * input_video
        input_video = (rearrange(input_video, "t c h w -> t h w c") * 0.5 + 0.5) * 255
        masked_input_video = (
            rearrange(masked_input_video, "t c h w -> t h w c") * 0.5 + 0.5
        ) * 255

        # Report the references used for the last generated chunk. The final
        # draw inside the loop is never consumed.
        # n h w c
        ref_video = video_tensor2numpy(
            rearrange(used_ref_img, "n c 1 h w -> 1 c n h w")
        )

        return WanvaceSlipmaeOutput(
            pred_video=pred_video,
            masks=masks,
            ref_video=ref_video,
            input_video=input_video,
            masked_input_video=masked_input_video,
        )

    def _encode_driving_signal(self, audio, driven_video, sr, fps, chunk_length):
        """Encode the driving audio or driving video into per-frame features.

        Returns:
            ``(vocal_feats, padded_length, tgt_length)``:

            * ``vocal_feats`` indexes frames along dim 1.
            * ``padded_length`` is the frame count rounded up to whole chunks.
            * ``tgt_length`` is the un-padded frame count.

        Raises:
            ValueError: If neither input is provided.
        """
        if audio is not None:
            # audio: [1, N_sample] at AUDIO_SAMPLE_RATE
            audio, padded_length, tgt_length = self.prepare_audio(
                audio, sr=sr, video_fps=fps, chunk_length=chunk_length
            )
            vocal_feats = self.encode_audio(
                audio,
                padded_length,
                self.AUDIO_SAMPLE_RATE,
                chunk_length=chunk_length,
                video_fps=fps,
            )
        elif driven_video is not None:
            # t c h w
            driven_video, padded_length, tgt_length = self.prepare_driven_video(
                driven_video, chunk_length=chunk_length
            )
            driven_video = driven_video.to(self.torch_device, self.torch_dtype)
            driven_keypoint, _ = self.sapiens_face_det(driven_video)
            driven_keypoint = driven_keypoint.unsqueeze(0).to(
                self.torch_device, self.torch_dtype
            )
            driven_video = rearrange(driven_video, "t c h w -> 1 c t h w")
            vocal_feats = self.encode_visual(driven_video, driven_keypoint)
        else:
            raise ValueError("Either audio or driven_video must be provided")
        return vocal_feats, padded_length, tgt_length

    def _sample_ref_frames(
        self, video: torch.Tensor, tgt_length: int, num_ref_img: int
    ) -> torch.Tensor:
        """Pick ``num_ref_img`` random frames from ``video[:tgt_length]`` as [n, c, 1, h, w]."""
        ref_idx = torch.randint(0, tgt_length, (num_ref_img,))
        return rearrange(video[ref_idx], "t c h w -> t c 1 h w").to(
            self.torch_device, self.torch_dtype
        )

    def _denoise_chunk(
        self,
        v: torch.Tensor,
        mask: torch.Tensor,
        cond_latents: torch.Tensor,
        vocal_feat: torch.Tensor,
        prompt_embeds: torch.Tensor,
        negative_prompt_embeds: Optional[torch.Tensor],
        num_ref_img: int,
        num_inference_steps: int,
        guidance_scale: float,
    ) -> torch.Tensor:
        """Run the full denoising loop for one chunk and return its final latents.

        Args:
            v: Source frames of the chunk, [1, c, T, H, W].
            mask: Inpainting mask of the chunk, [1, T, H, W].
            cond_latents: Reference + previous-frame latents, [1, c, n + 1, h, w].
            vocal_feat: Driving features of the chunk.
        """
        _, _, T, H, W = v.shape

        vace_condition_latents = self.prepare_vace_condition(
            v, mask, cond_latents
        ).to(self.torch_device, self.torch_dtype)

        # Each extra `vae_scale_factor_temporal` pixel frames adds one latent
        # frame, so the noise gets (num_ref_img + 1) extra latent frames
        # matching the conditioning frames: (T - 1) // 4 + 1 + num_ref_img + 1.
        latents = self.prepare_latents(
            1,
            num_channels_latents=16,
            height=H,
            width=W,
            num_frames=T + self.vae_scale_factor_temporal * (num_ref_img + 1),
        ).to(self.torch_device, self.torch_dtype)

        # Reset the scheduler state for every chunk.
        self.scheduler.set_timesteps(num_inference_steps, device=self.torch_device)
        for ts in tqdm(self.scheduler.timesteps, desc="Denoising ... "):
            timestep = ts.expand(latents.shape[0])

            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                audio_hidden_states=vocal_feat,
                control_hidden_states=vace_condition_latents,
            ).sample

            if self.do_classifier_free_guidance:
                noise_uncond = self.transformer(
                    hidden_states=latents,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds,
                    audio_hidden_states=vocal_feat,
                    control_hidden_states=vace_condition_latents,
                ).sample
                noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
            latents = self.scheduler.step(noise_pred, ts, latents, return_dict=False)[0]
        return latents

    # --------------------------------------------------------------- VAE

    def _latent_stats(self, ref: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (latents_mean, 1 / latents_std), broadcastable to ``ref``.

        Both tensors are cast to ``ref``'s device/dtype first; the reciprocal is
        taken in that dtype.
        """
        shape = (1, self.vae.config.z_dim) + (1,) * (ref.ndim - 2)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(shape)
            .to(ref.device, ref.dtype)
        )
        inv_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(shape).to(
            ref.device, ref.dtype
        )
        return latents_mean, inv_std

    @torch.no_grad()
    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode pixels to normalized latents: ``(z - mean) / std``.

        ``z`` is *sampled* from the VAE posterior, so the result is stochastic.
        """
        z = self.vae.encode(video).latent_dist.sample()
        latents_mean, inv_std = self._latent_stats(z)
        return (z - latents_mean) * inv_std

    @torch.no_grad()
    def decode_video(self, latents: torch.Tensor) -> torch.Tensor:
        """Undo the latent normalization of ``encode_video`` and decode to pixels.

        Expects 5-D latents ``[b, c, t, h, w]``.
        """
        latents = latents.to(self.vae.dtype)
        latents_mean, inv_std = self._latent_stats(latents)
        latents = latents / inv_std + latents_mean
        return self.vae.decode(latents).sample

    # -------------------------------------------------------------- audio

    def prepare_audio(
        self,
        audio: Union[str, np.ndarray, torch.Tensor],
        sr: int = 16000,
        video_fps: int = 25,
        chunk_length: int = 17,
    ) -> Tuple[torch.Tensor, int, int]:
        """Load audio and zero-pad it to a whole number of video chunks.

        Args:
            audio: Path, or waveform sampled at ``sr``.
            sr: Sample rate of an array/tensor ``audio``. The audio is
                resampled to ``AUDIO_SAMPLE_RATE``, which is used for all
                arithmetic below.
            video_fps: Video frame rate used to convert samples to frames.
            chunk_length: Frames per generation chunk.

        Returns:
            ``(audio, padded_length, audio_num_frames)``:

            * ``audio`` is the padded waveform ``[1, N]``.
            * ``padded_length`` is the frame count rounded up to whole chunks
              (at least one chunk).
            * ``audio_num_frames`` is the number of whole video frames the
              original audio covers; the generated video is cropped to it.
        """
        target_sr = self.AUDIO_SAMPLE_RATE
        audio = self.load_audio(audio, sr=sr).unsqueeze(0)
        # Integer arithmetic keeps frame/sample conversions exact.
        audio_num_frames = int(audio.shape[1] * video_fps // target_sr)
        padded_length = self._round_up_to_chunk(audio_num_frames, chunk_length)

        num_padded_samples = math.ceil(padded_length * target_sr / video_fps)
        # A negative pad amount makes F.pad crop trailing samples, which happens
        # when the audio ends part-way through a frame at a chunk boundary.
        audio = F.pad(
            audio,
            (0, num_padded_samples - audio.shape[1]),
            mode="constant",
            value=0,
        )
        assert (
            math.ceil(audio.shape[1] * video_fps / target_sr) == padded_length
        ), f"audio shape: {audio.shape}, padded length: {padded_length}, padded audio num frames: {audio.shape[1] / target_sr * video_fps}"

        return audio, padded_length, audio_num_frames

    @torch.no_grad()
    def encode_audio(
        self,
        audio: torch.Tensor,
        video_length: int,
        sr: int = 16000,
        chunk_length=17,
        video_fps=25,
    ) -> torch.Tensor:
        """Encode a waveform into per-frame audio features, one chunk at a time.

        Steps:

        1. ``self.audio_processor`` converts the waveform into model inputs.
        2. The inputs are split into windows of ``chunk_length`` video frames.
           Each window goes through ``self.audio_encoder`` with
           ``seq_len`` = the number of frames the window covers.
        3. The encoder's last hidden state is concatenated with all of its
           hidden states along the feature axis.
        4. The per-window results are concatenated along time and passed
           through ``self.audio_adapter``.

        Args:
            audio: Waveforms ``[B, N]``.
            video_length: Number of video frames to cover (a multiple of
                ``chunk_length`` when called from ``__call__``).
            sr: Sample rate of ``audio``.
            chunk_length: Video frames per encoding window.
            video_fps: Video frame rate used to map frames to samples.

        Returns:
            The ``audio_adapter`` output; frames are indexed along dim 1.
        """
        B, N = audio.shape

        audio_inputs = self.audio_processor(
            audio.float(), sampling_rate=sr, return_tensors="pt"
        ).input_values.to(
            device=self.torch_device,
            dtype=self.torch_dtype,
        )
        # The processor may add a leading axis; reshape back to [B, N].
        audio_inputs = audio_inputs.view(B, N)

        audio_encoder_output = []
        for start in tqdm(range(0, video_length, chunk_length), desc="Encoding audio"):
            begin = int(start * sr // video_fps)
            end = int((start + chunk_length) * sr // video_fps)
            inp = audio_inputs[:, begin:end]
            if inp.shape[1] == 0:
                break
            output = self.audio_encoder(
                inp,
                seq_len=int(inp.shape[1] * video_fps // sr),
                output_hidden_states=True,
            )
            states = torch.cat(
                (output.last_hidden_state,) + output.hidden_states,
                dim=-1,
            )
            audio_encoder_output.append(states)

        audio_encoder_output = torch.cat(audio_encoder_output, dim=1)
        return self.audio_adapter(audio_encoder_output)

    # ------------------------------------------------------- driving video

    def prepare_driven_video(
        self, video: Union[str, np.ndarray], chunk_length: int = 25
    ) -> Tuple[torch.Tensor, int, int]:
        """Load a driving video and reflect-pad it to a whole number of chunks.

        Padding continues with a mirrored (reflect, no repeated edge frame)
        index pattern: ``0, 1, ..., n-1, n-2, ..., 0, 1, ...``.

        Returns:
            ``(video, padded_length, num_frames)``, where ``num_frames`` is the
            original frame count.

        Raises:
            ValueError: If the video has no frames.
        """
        video = self.load_video(video)
        num_frames = video.shape[0]
        if num_frames == 0:
            raise ValueError("The driving video has no frames")

        padded_length = self._round_up_to_chunk(num_frames, chunk_length)

        if padded_length > num_frames:
            if num_frames == 1:
                # A single frame cannot be mirrored; repeat it instead of
                # taking a modulo by zero.
                index = torch.zeros(padded_length, dtype=torch.long, device=video.device)
            else:
                last = num_frames - 1
                positions = torch.arange(padded_length, device=video.device)
                index = last - torch.abs(positions % (2 * last) - last)
            video = video[index]
        return video, padded_length, num_frames

    # ------------------------------------------------------- face detection

    @torch.no_grad()
    def sapiens_face_det(
        self, video: torch.Tensor, batch_size=8
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Detect Sapiens keypoints for each frame and build lower-face masks.

        Args:
            video: Frames ``[T, C, H, W]``.
            batch_size: Frames per Sapiens forward pass.

        Returns:
            ``(keypoints, lower_face_mask)``:

            * ``keypoints`` is a CPU tensor; detected coordinates and the
              keypoint score are concatenated on the last axis.
            * ``lower_face_mask`` is whatever ``sapiens2mask`` returns for
              ``mask_area="lower_face"``, with ``mask_expand=(0, 0, 0, 20)``.
        """
        H, W = video.shape[-2:]
        all_keypoints = []
        all_keypoints_scores = []
        for i in range(0, len(video), batch_size):
            _, keypoints, keypoints_scores = self.get_sapiens_heatmap(
                video[i : i + batch_size]
            )
            all_keypoints.append(torch.from_numpy(keypoints))
            all_keypoints_scores.append(torch.from_numpy(keypoints_scores))
        all_keypoints = torch.cat(all_keypoints, dim=0)
        all_keypoints_scores = torch.cat(all_keypoints_scores, dim=0)
        keypoints = torch.cat(
            [all_keypoints, all_keypoints_scores.unsqueeze(-1)], dim=-1
        )
        lower_face_mask = sapiens2mask(
            H, W, keypoints, mask_area="lower_face", mask_expand=(0, 0, 0, 20)
        )
        return keypoints, lower_face_mask

    @torch.no_grad()
    def get_sapiens_heatmap(
        self,
        imgs: torch.Tensor,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the Sapiens backbone and heatmap head on one batch of frames.

        The whole batch goes through a single forward pass; there is no internal
        sub-batching.

        Args:
            imgs: Frames ``[B, C, H, W]``.

        Returns:
            ``(heatmap, keypoints, keypoint_scores)`` as NumPy arrays:

            * ``heatmap`` is the head's heatmap output converted to float32 on
              the CPU.
            * ``keypoints`` and ``keypoint_scores`` are the per-sample
              predictions concatenated along axis 0.
        """
        feats = self.raw_sapiens(imgs, out_type="featmap").last_hidden_state
        # 256 192
        heatmap, pred = self.heatmap_head(feats, decode_kpt=True)

        keypoints = np.concatenate([p.keypoints for p in pred], axis=0)
        keypoints_scores = np.concatenate([p.keypoint_scores for p in pred], axis=0)

        heatmap = heatmap.float().detach().cpu().numpy()
        return heatmap, keypoints, keypoints_scores

    # -------------------------------------------------------- VACE condition

    def prepare_vace_condition(
        self, vace_video: torch.Tensor, mask: torch.Tensor, ref_latents: torch.Tensor
    ) -> torch.Tensor:
        """Build the VACE control latents for one chunk.

        Args:
            vace_video: Source frames ``[b, c, t, h, w]``.
            mask: Inpainting mask ``[b, t, h, w]`` (1 = region to generate).
            ref_latents: Conditioning latents ``[b, c, n, h', w']``: the
                reference frames followed by the previous chunk's last frame.

        Returns:
            ``[b, 2c + s*s, n + t', h', w']``, where ``s`` is
            ``vae_scale_factor_spatial`` and ``t'`` is the latent length of
            ``t`` frames.
        """
        num_cond_frames = ref_latents.shape[2]
        mask = mask.unsqueeze(1)
        # Blank the masked region before anything is encoded, so the model never
        # sees the original pixels there.
        vace_video = self.mask2gray(vace_video, mask)

        # Part 1: region latents.
        # "Inactive" region (kept pixels), as defined in Wan-VACE.
        condition_region = vace_video * (1 - mask)
        # "Reactive" region. The masked pixels were already blanked above, so
        # for a binary {0, 1} mask this is all zeros.
        mask_region = vace_video * mask
        condition_region_latents = self.encode_video(condition_region)
        mask_region_latents = self.encode_video(mask_region)
        # [b, 2c, t', h', w']
        vace_video_latents = torch.concat(
            (condition_region_latents, mask_region_latents), dim=1
        )

        # Prepend the conditioning latents; their "reactive" half is zero.
        # [b, 2c, n, h', w'] + [b, 2c, t', h', w'] -> [b, 2c, n + t', h', w']
        ref_latents = torch.concat((ref_latents, torch.zeros_like(ref_latents)), dim=1)
        vace_video_latents = torch.concat((ref_latents, vace_video_latents), dim=2)

        # Part 2: mask latents.
        # Spatial downsample by folding each s x s pixel block into channels.
        s = self.vae_scale_factor_spatial
        mask = rearrange(mask, "b 1 t (h p) (w q) -> b (p q) t h w", p=s, q=s)
        # Temporal downsample to the latent length ceil(t / temporal_factor).
        t_scale = self.vae_scale_factor_temporal
        vace_mask_latents = F.interpolate(
            mask.float(),
            size=(
                (mask.shape[2] + t_scale - 1) // t_scale,
                mask.shape[3],
                mask.shape[4],
            ),
            mode="nearest-exact",
        )

        # The conditioning frames are never inpainted: prepend zero masks.
        # [b, s*s, n, h', w'] + [b, s*s, t', h', w'] -> [b, s*s, n + t', h', w']
        zero_masks = vace_mask_latents.new_zeros(
            (
                vace_mask_latents.shape[0],
                vace_mask_latents.shape[1],
                num_cond_frames,
                vace_mask_latents.shape[3],
                vace_mask_latents.shape[4],
            )
        )
        vace_mask_latents = torch.concat((zero_masks, vace_mask_latents), dim=2)

        # Final layout: [b, 2c + s*s, n + t', h', w']
        return torch.concat((vace_video_latents, vace_mask_latents), dim=1)

    def mask2gray(self, video: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Replace the masked region with gray.

        In [-1, 1]-normalized pixels, gray (127.5) is 0, so the masked pixels
        are zeroed.

        Args:
            video: ``[B, C, T, H, W]`` in [-1, 1].
            mask: ``[B, 1, T, H, W]``; 0 = keep, 1 = masked.

        Returns:
            ``[B, C, T, H, W]`` in [-1, 1].
        """
        return video * (1 - mask.to(video.device))

    # ------------------------------------------------------ visual driving

    @torch.no_grad()
    def encode_visual(self, video, kpt=None, batch_size=32):
        """Extract per-frame SlipMAE motion tokens from a driving video.

        Args:
            video: Frames ``[b, c, t, h, w]``.
            kpt: Optional keypoints ``[b, t, k, c]``. When omitted, face masks
                are derived from Sapiens heatmaps instead.
            batch_size: Frames per SlipMAE forward pass.

        Returns:
            ``[b, t, c]`` tensor: the last prompt token of each frame.
        """
        B = video.shape[0]
        video = rearrange(video, "b c t h w -> (b t) c h w")
        if kpt is not None:
            kpt = rearrange(kpt, "b t k c -> (b t) k c")
        face_indices = _fallback_face_indices_ex_ear_308()

        all_vocal_feature = []
        for i in range(0, video.shape[0], batch_size):
            video_batch = video[i : i + batch_size]
            if kpt is not None:
                keepface_mask_prob = keepface_mask_prob_from_kpts(
                    kpt[i : i + batch_size][:, face_indices],
                    img_hw=video_batch.shape[-2:],
                )
            else:
                # get_sapiens_heatmap returns (heatmap, keypoints, scores) as
                # NumPy arrays; only the heatmap is needed here. Moved to the
                # frames' device on the assumption that the helper expects a
                # tensor there (the keypoint branch does) — TODO confirm.
                heatmap_batch, _, _ = self.get_sapiens_heatmap(video_batch)
                heatmap_batch = torch.from_numpy(heatmap_batch).to(video_batch.device)
                keepface_mask_prob = keepface_mask_prob_from_heatmap(
                    heatmap_batch[:, face_indices]
                )

            encoder_output = self.slipmae_encoder(
                video_batch, do_mask=True, mask_prob=keepface_mask_prob
            )
            # (b t) c
            all_vocal_feature.append(encoder_output["prompt_tokens"][:, -1])

        # (b t) c -> b t c
        all_vocal_feature = torch.cat(all_vocal_feature, dim=0)
        return rearrange(all_vocal_feature, "(b t) c -> b t c", b=B)


def _stem(path: str) -> str:
    """Return the base name of ``path`` with its final extension removed.

    ``str.rstrip(".mp4")`` strips a *set of characters*, not a suffix (e.g.
    ``"F100-124.mp4"`` would become ``"F100-12"``). ``os.path.splitext``
    removes exactly the extension.
    """
    return os.path.splitext(os.path.basename(path))[0]


def _build_sapiens_face_models():
    """Build the pretrained Sapiens-0.3B pose backbone and its 308-keypoint heatmap head."""
    sapiens = SapiensVisionTransformer(
        arch="sapiens_0.3b",
        norm_in=True,  # the input is normalized with mean 0.5 and std 0.5, we need to renormalize with Sapiens mean and std
        img_size=(1024, 768),
        patch_size=16,
        in_channels=3,
        out_indices=-1,
        drop_rate=0.0,
        drop_path_rate=0.0,
        qkv_bias=True,
        norm_cfg=dict(type="LN", eps=1e-6),
        final_norm=True,
        with_cls_token=False,
        frozen_stages=-1,
        interpolate_mode="bicubic",
        layer_scale_init_value=0.0,
        patch_cfg=dict(padding=2),
        layer_cfgs=dict(),
        pre_norm=False,
        out_type="featmap",
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/backbone.pth",
        ),
    )
    sapiens.init_weights()

    heatmap_head = HeatmapHead(
        in_channels=1024,
        out_channels=308,
        deconv_out_channels=(768, 768),  # each deconv doubles resolution: 4x total
        deconv_kernel_sizes=(4, 4),
        conv_out_channels=(768, 768),
        conv_kernel_sizes=(1, 1),
        decoder=dict(
            type="UDPHeatmap", input_size=(512, 512), heatmap_size=(192, 256), sigma=6
        ),
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/heatmap_head.pth",
        ),
    )
    heatmap_head.init_weights()
    return sapiens, heatmap_head


def infer(
    video_path="data/vfhq/video_resampled/Clip+_BH0a6hi0nI+P0+C1+F8964-9100.mp4",
    # audio_path="data/RAVDNESS/video_resampled/01-01-01-01-01-02-03.mp4",
    audio_path=None,
    driven_video_path="data/vfhq/video_resampled/Clip+_BH0a6hi0nI+P0+C0+F10612-10735.mp4",
    # driven_video_path="data/vfhq/video_resampled/Clip+zXbxb4m_k9U+P0+C0+F3249-3415.mp4",
    # audio_path="data/vfhq/video_resampled/Clip+zXbxb4m_k9U+P0+C0+F3249-3415.mp4",
    prompt="Adjust the speaker’s mouth shapes based on the input audio.",
    cfg="configs/train/wanvace_slipmae/wanvacev5_slipmae_1.3b_lowerface.py",
    checkpoint="work_dirs/wanvacev5_slipmae_1.3b_lowerface/iter_186500.pth",
    output_path: str = "output/wanvacev5/",
    guidance_scale: float = 1.0,
    num_ref_img=1,
    num_inference_steps=50,
):
    """Run ``WanvaceSlipmaePipelineV5`` on one source video and save the results.

    Outputs are written to
    ``<output_path>/<cfg stem>/<checkpoint stem>/<video stem>/<driving stem>/``:

    * ``output.mp4``: input | masked input | prediction, side by side.
    * ``audio.wav``: a verbatim copy of the driving file.
    * ``output_audio.mp4``: the video merged with that audio.

    When ``audio_path`` is given it takes precedence and ``driven_video_path``
    is ignored; this is needed because ``driven_video_path`` has a non-None
    default and the pipeline rejects receiving both.

    Raises:
        ValueError: If ``audio_path`` is not ``.wav``/``.mp4``, or if no
            driving input is provided.
    """
    # ---------- Choose the driving signal -----------
    if audio_path is not None:
        if not audio_path.endswith((".wav", ".mp4")):
            raise ValueError(f"audio_path must be a .wav or .mp4 file, got {audio_path!r}")
        driven_video_path = None
        driving_path = audio_path
    elif driven_video_path is not None:
        driving_path = driven_video_path
    else:
        raise ValueError("Either audio_path or driven_video_path must be provided")

    # ---------- Build output path -----------
    output_path = os.path.join(
        output_path,
        _stem(cfg),
        _stem(checkpoint),
        _stem(video_path),
        _stem(driving_path),
    )
    os.makedirs(output_path, exist_ok=True)

    # ---------- Build model, load checkpoint -----------
    cfg = Config.fromfile(cfg)
    model: WanVaceSlipmaeTrainerV5 = MODELS.build(cfg.model)
    load_checkpoint(model, checkpoint, map_location="cpu")

    input_video_processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": 512},
        do_center_crop=True,
        crop_size={"height": 512, "width": 512},
        do_rescale=True,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    # Sapiens for face keypoint detection.
    sapiens, heatmap_head = _build_sapiens_face_models()

    scheduler = UniPCMultistepScheduler.from_pretrained(
        "checkpoints/Wan2.1-VACE-1.3B-diffusers", subfolder="scheduler"
    )

    pipeline = WanvaceSlipmaePipelineV5(
        vae=model.vae,
        transformer=model.transformer,
        tokenizer=model.tokenizer,
        text_encoder=model.text_encoder,
        raw_sapiens=sapiens,
        heatmap_head=heatmap_head,
        video_processor=input_video_processor,
        audio_processor=model.audio_processor,
        audio_encoder=model.audio_encoder,
        audio_adapter=model.audio_adapter,
        slipmae_encoder=model.slipmae_encoder,
        scheduler=scheduler,
    )

    output = pipeline(
        prompt,
        video_path,
        audio_path,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        driven_video=driven_video_path,
        num_ref_img=num_ref_img,
    )

    # (t, h, w, c) arrays concatenated along width: input | masked | prediction.
    video = np.concatenate(
        (output.input_video, output.masked_input_video, output.pred_video), axis=-2
    )
    video_file = os.path.join(output_path, "output.mp4")
    write_video(video_file, video, fps=25)

    # The driving file is copied verbatim under the name "audio.wav", even when
    # it is an .mp4, and used as the audio source for the merged video.
    audio_file = os.path.join(output_path, "audio.wav")
    shutil.copy(driving_path, audio_file)
    merge_video_audio(
        video_file,
        audio_file,
        os.path.join(output_path, "output_audio.mp4"),
        fps=25,
        sr=16000,
    )


if __name__ == "__main__":
    # Expose `infer` as a command-line interface through python-fire.
    from fire import Fire

    Fire(infer)
