import os
from typing import Dict, Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from transformers import Wav2Vec2Processor
from einops import rearrange
from functools import partial
from torch import nn
import logging

from mmhug.models.custom_transformers.slipmae.transformer_slipmae_encoder import (
    SlipmaeEncoder,
)
from mmhug.models.custom_transformers.wav2vec2_interp.wav2vec2_interp import (
    Wav2Vec2InterpModel,
)
from mmhug.trainers.base_trainer_model import BaseTrainerModel

from mmhug.losses.utils.diff_aug import DiffAugment

# from mmhug.losses.utils.diff_aug import DiffAugment
from mmhug.losses.contrastive_loss import SyncInfoNCE
from mmhug.registry import HF_MODELS, MODELS
from mmhug.trainers.trainer_slipmae.utils import (
    keepface_mask_prob_from_heatmap,
    keepface_mask_prob_from_kpts,
)
from mmhug.datasets.utils.kpt2face import _fallback_face_indices_ex_ear_308
from mmhug.utils.dtype_utils import dtype_from_str
from mmhug.models.custom_transformers.sapiens import (
    SapiensVisionTransformer,
    HeatmapHead,
)
from mmhug.models.custom_transformers.wan22_audio import (
    MAEDecoderV2,
    SapiensMotionExtractorV2,
)
from mmengine import print_log
from mmhug.structures import DataSample
from diffusers.utils import BaseOutput


class SlipmaeOuput(BaseOutput):
    """Container returned by ``TrainerSlipmaePretrain.forward_mae``.

    Every field defaults to ``None``. ``pixel_vocal`` / ``pixel_audio`` are only
    passed when ``forward_mae`` runs with ``encode_only=False``.

    NOTE(review): the class name is misspelled ("Ouput"); it is kept because
    other modules may import it under this name. The class is also not
    decorated with ``@dataclass``; with diffusers' ``BaseOutput`` (an
    ``OrderedDict`` subclass) the keyword arguments then appear to be stored as
    mapping items with ``None`` values included, and callers index e.g.
    ``output["pixel_audio"]`` — confirm before adding ``@dataclass``, whose
    ``__post_init__`` drops ``None`` entries from the mapping.
    """

    # `ids_restore` returned by the randomly masked encoder pass.
    ids_restore: Optional[Tensor] = None
    # Audio features for the encoded frames, flattened to [(B*T), C].
    audio_embedding: Optional[Tensor] = None
    # Vocal / non-vocal motion: prompt tokens 2 and 1 of the motion encoder pass.
    vocal_motion: Optional[Tensor] = None
    nonvocal_motion: Optional[Tensor] = None
    # Identity embedding (prompt token 0), permuted by `idx_identity_shuffle`.
    id_embedding: Optional[Tensor] = None
    # A `torch.randperm` tensor, or `list(range(T))` when id shuffling is off.
    idx_identity_shuffle: Optional[Tensor] = None
    # Masks returned by the random-mask pass and the face-aware-mask pass.
    mask: Optional[Tensor] = None
    keepface_mask: Optional[Tensor] = None
    # Indices of the frames that were encoded (a subset when train_minibatch is set).
    rand_idx: Optional[Tensor] = None
    # Sapiens heatmaps (real-time pose estimation only) / input keypoints otherwise.
    heatmap: Optional[Tensor] = None
    keypoint: Optional[Tensor] = None
    # Reconstructions from [id, non-vocal, vocal] and [id, non-vocal, audio] tokens.
    pixel_vocal: Optional[Tensor] = None
    pixel_audio: Optional[Tensor] = None
    # Not set by TrainerSlipmaePretrain.forward_mae.
    patch_tokens: Optional[Tensor] = None


@MODELS.register_module(force=True)
class TrainerSlipmaePretrain(BaseTrainerModel):
    """MAE pre-training trainer that splits each frame into identity,
    non-vocal motion and vocal motion tokens.

    ``forward_mae`` runs ``mae_encoder`` twice on a clip:

    * with random masking — ``prompt_tokens[:, 0]`` becomes the identity
      embedding and the visible ``patch_tokens`` feed the decoder;
    * on (optionally augmented) frames with a mask probability derived from
      face keypoints / Sapiens heatmaps (``keepface_mask_prob_from_*``) —
      ``prompt_tokens[:, 1]`` / ``prompt_tokens[:, 2]`` become the non-vocal /
      vocal motion tokens.

    ``mae_decoder`` reconstructs the frames from ``[identity, non-vocal, vocal]``
    plus patch tokens, and a second time with the vocal token replaced by the
    audio embedding. ``forward_loss`` combines both pixel MSE losses, a
    ``SyncInfoNCE`` loss between audio embedding and vocal motion, and an
    optional orthogonality loss between the three tokens.

    Only batch size 1 (a single clip) is supported; inside the encoder and
    decoder the frames of that clip act as the batch.
    """

    audio_encoder: Wav2Vec2InterpModel
    audio_processor: Wav2Vec2Processor
    mae_encoder: SlipmaeEncoder
    raw_sapiens: SapiensVisionTransformer
    heatmap_head: HeatmapHead
    mae_decoder: MAEDecoderV2
    contrastive_loss: SyncInfoNCE

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen sub-modules frozen.

        The audio encoder, the optional Sapiens pose estimator (backbone and
        heatmap head) and the optional audio filter model are never trained:
        whenever training mode is entered they are put back into eval mode
        with gradients disabled.

        Returns:
            ``self``, like :meth:`torch.nn.Module.train`. This matters because
            ``nn.Module.eval()`` is implemented as ``return self.train(False)``.
        """
        super().train(mode)
        if mode:
            self.audio_encoder.eval()
            self.audio_encoder.requires_grad_(False)

            if self.realtime_pose_estimate:
                self.raw_sapiens.eval()
                self.raw_sapiens.requires_grad_(False)

                self.heatmap_head.eval()
                self.heatmap_head.requires_grad_(False)

            if self.audio_filter_model is not None:
                self.audio_filter_model.eval()
                self.audio_filter_model.requires_grad_(False)
        return self

    def __init__(
        self,
        audio_encoder: Wav2Vec2InterpModel,
        audio_processor: Wav2Vec2Processor,
        mae_encoder: SapiensMotionExtractorV2,
        mae_decoder: MAEDecoderV2,
        raw_sapiens: SapiensVisionTransformer = None,
        heatmap_head: HeatmapHead = None,
        audio_filter_model: Optional[Wav2Vec2InterpModel] = None,
        audio_adapter=dict(in_feature=768, out_feature=1024),
        audio_layer: str = "last",
        loss_cfg=dict(
            temporal_neighbors=1,  # treat temporal neighbor frames as positive samples for contrastive loss
            pixel_loss_weight=1,
            cl_loss_weight=1,
            ortho_loss_weight=0.1,
            ortho_loss_mode="cov",
        ),
        shuffle_id: bool = True,
        motion_augment: Optional[str] = "color",
        train_minibatch: Optional[
            int
        ] = None,  # when training, randomly choose train_minibatch frames from each video. For audio, we firstly encode the entire audio and choose corresponding frames from audio features.
        init_cfg=None,
    ):
        """
        Args:
            audio_encoder: build config (``HF_MODELS.build``) of the frozen
                audio encoder.
            audio_processor: build config of the waveform processor.
            mae_encoder: build config of the MAE encoder; its ``mask_ratio`` is
                copied to ``self.mask_ratio``.
            mae_decoder: build config of the MAE decoder.
            raw_sapiens: optional build config of a frozen Sapiens backbone.
                When given, pose heatmaps are estimated on the fly and any
                ``keypoints`` input is ignored. Requires ``heatmap_head``.
            heatmap_head: build config of the frozen heatmap head.
            audio_filter_model: optional build config of a frozen audio model
                whose frame self-similarity is passed to the contrastive loss.
            audio_adapter: ``in_feature`` / ``out_feature`` of the trainable
                ``Linear + LayerNorm`` applied to audio features.
            audio_layer: ``"last"`` (last hidden state) or ``"all"``
                (concatenation of hidden states along the feature dim).
            loss_cfg: ``temporal_neighbors`` (``k`` of ``SyncInfoNCE``),
                ``pixel_loss_weight``, ``cl_loss_weight``,
                ``ortho_loss_weight`` (>0 enables and scales the
                orthogonality loss) and ``ortho_loss_mode``.
            shuffle_id: randomly permute identity embeddings across frames
                before decoding.
            motion_augment: ``DiffAugment`` policy applied to the frames of the
                motion encoder pass; ``None`` disables augmentation.
            train_minibatch: number of frames randomly sampled per clip during
                training; ``None`` uses all frames.
            init_cfg: forwarded to ``BaseTrainerModel``.
        """
        super().__init__(init_cfg=init_cfg)
        self.audio_encoder = HF_MODELS.build(audio_encoder)
        self.audio_encoder.eval()
        self.audio_encoder.requires_grad_(False)

        self.audio_adapter = nn.Sequential(
            nn.Linear(audio_adapter["in_feature"], audio_adapter["out_feature"]),
            nn.LayerNorm(audio_adapter["out_feature"]),
        )

        self.audio_layer = audio_layer
        self.audio_processor = HF_MODELS.build(audio_processor)

        self.mae_encoder = HF_MODELS.build(mae_encoder)
        self.mae_encoder.init_weights()

        self.mae_decoder = HF_MODELS.build(mae_decoder)
        self.mae_decoder.init_weights()

        self.realtime_pose_estimate = False
        if raw_sapiens is not None:
            self.realtime_pose_estimate = True
            assert heatmap_head is not None
            self.raw_sapiens = HF_MODELS.build(raw_sapiens)
            self.raw_sapiens.init_weights()
            self.raw_sapiens.eval()
            self.raw_sapiens.requires_grad_(False)

            self.heatmap_head = HF_MODELS.build(heatmap_head)
            self.heatmap_head.eval()
            self.heatmap_head.requires_grad_(False)
            self.heatmap_head.init_weights()

        if audio_filter_model is not None:
            self.audio_filter_model = HF_MODELS.build(audio_filter_model)
            self.audio_filter_model.eval()
            self.audio_filter_model.requires_grad_(False)
        else:
            self.audio_filter_model = None

        self.contrastive_loss = SyncInfoNCE(
            k=loss_cfg.get("temporal_neighbors", 0),
        )

        self.loss_cfg = loss_cfg
        self.shuffle_id = shuffle_id

        self.mask_ratio = self.mae_encoder.mask_ratio

        self.motion_augment = None
        if motion_augment is not None:
            self.motion_augment = partial(DiffAugment, policy=motion_augment)

        self.train_minibatch = train_minibatch
        self._collect_trainable_params()

    def _prepare_audio_inputs(self, audio: Tensor, sr: int) -> Tensor:
        """Run ``self.audio_processor`` on a [B, N] waveform batch.

        Returns the processor's ``input_values`` on ``audio``'s device and
        dtype, viewed back to [B, N].
        """
        B, N = audio.shape
        # processor can only process float32
        audio_inputs = self.audio_processor(
            audio.float(), sampling_rate=sr, return_tensors="pt"
        ).input_values.to(
            device=audio.device,
            dtype=dtype_from_str(audio.dtype),
        )
        # .view fails loudly if the processor changed the number of samples;
        # extra size-1 dims it may add are folded away.
        return audio_inputs.view(B, N)

    def encode_audio(self, audio: Tensor, video_length: int, sr: int) -> Tensor:
        """
        Encode raw audio waveforms into adapted hidden representations,
        aligned with the expected video length.

        Steps:
        1. ``self.audio_processor`` converts the waveform into model inputs.
        2. The frozen ``self.audio_encoder`` produces hidden states (no grad),
           with ``seq_len=video_length`` to align the output length to frames.
        3. The hidden states selected by ``self.audio_layer`` pass through the
           trainable ``self.audio_adapter``.

        Args:
            audio (Tensor):
                A batch of raw audio waveforms of shape [B, T_audio].
            video_length (int):
                The length of the corresponding video in frames (used for alignment).
            sr (int):
                Sampling rate of the audio (e.g., 16000 Hz).

        Returns:
            Tensor:
                A tensor of shape [B, L, D], where L is the number of audio
                tokens (aligned to ``video_length`` if supported by the
                encoder) and D is ``audio_adapter["out_feature"]``.

        Raises:
            ValueError: if ``self.audio_layer`` is neither "last" nor "all".
        """
        with torch.no_grad():
            audio_inputs = self._prepare_audio_inputs(audio, sr)
            audio_encoder_output = self.audio_encoder(
                audio_inputs, seq_len=int(video_length), output_hidden_states=True
            )

        if self.audio_layer == "last":
            audio_states = audio_encoder_output.last_hidden_state
        elif self.audio_layer == "all":
            # NOTE(review): for HF wav2vec2, hidden_states usually already ends
            # with the final layer output, so that layer appears twice here. Kept
            # as-is because audio_adapter["in_feature"] depends on this width.
            audio_states = torch.cat(
                (audio_encoder_output.last_hidden_state,)
                + audio_encoder_output.hidden_states,
                dim=-1,
            )
        else:
            raise ValueError(f"Unknown audio_layer: {self.audio_layer}")

        return self.audio_adapter(audio_states)

    @torch.no_grad()
    def encode_audio_filter(self, audio: Tensor, video_length: int, sr: int) -> Tensor:
        """Encode audio with the frozen ``audio_filter_model``.

        Returns:
            Tensor: the filter model's last hidden state, [B, L, D].

        Raises:
            AssertionError: if no ``audio_filter_model`` was configured.
        """
        assert self.audio_filter_model is not None, "audio_filter_model not found"
        audio_inputs = self._prepare_audio_inputs(audio, sr)
        # `seq_len` aligns the output length with the video frame count
        audio_encoder_output = self.audio_filter_model(
            audio_inputs, seq_len=int(video_length), output_hidden_states=True
        )
        return audio_encoder_output.last_hidden_state

    def forward_mae(
        self,
        imgs: Tensor,
        audio: Tensor,
        keypoints: Optional[Tensor] = None,
        sr: int = 16000,
        audio_frames: int = None,
        encode_only: bool = False,
    ):
        """
        Encode one clip and, unless ``encode_only``, reconstruct its frames.

        Args:
            imgs (Tensor):
                Video frames of shape [B, T, C, H, W]; only B == 1 is supported.
            audio (Tensor):
                Raw audio waveforms of shape [B, N_sample].
            keypoints (Tensor, optional):
                Keypoints with leading [B, T] dims, used to build the face-aware
                mask probability. Required unless the model estimates pose in
                real time, in which case it is ignored.
            sr (int, optional):
                Sampling rate of the audio (e.g., 16000 Hz).
            audio_frames (int, optional):
                The number of audio frames to align with the video frames.
                If None, the num_frames of imgs will be used.
            encode_only (bool):
                If True, skip the decoder and return only the encodings.

        Returns:
            SlipmaeOuput: encodings, masks, the sampled frame indices and,
            unless ``encode_only``, the reconstructions. ``pixel_audio`` is
            None when ``audio_frames`` differs from the number of video frames.
        """
        assert (
            imgs.shape[0] == 1
        ), "only support batch size == 1 in current implementation"

        ori_video_length = imgs.shape[1]
        if self.train_minibatch is not None and self.training:
            # During training, randomly choose train_minibatch frames from each video.
            # For audio, the entire clip is encoded first and the matching
            # frames are picked from the audio features below.
            rand_idx = torch.randperm(ori_video_length)[: self.train_minibatch]
            imgs = imgs[:, rand_idx]
            if keypoints is not None:
                keypoints = keypoints[:, rand_idx]
        else:
            rand_idx = torch.arange(ori_video_length)

        # Drop the size-1 batch dim so the frames act as the batch. Keypoints
        # must be squeezed on both paths above; otherwise the face-index
        # selection below would index the time axis instead of the keypoint axis.
        imgs = imgs.squeeze(0)
        if keypoints is not None:
            keypoints = keypoints.squeeze(0)

        audio_frames = audio_frames or ori_video_length
        if self.realtime_pose_estimate:
            keypoints = None
            heatmap = self.get_sapiens_heatmap(imgs)
            keepface_mask_prob = keepface_mask_prob_from_heatmap(
                heatmap[:, _fallback_face_indices_ex_ear_308()]
            )
        else:
            assert keypoints is not None
            heatmap = None
            keepface_mask_prob = keepface_mask_prob_from_kpts(
                keypoints[:, _fallback_face_indices_ex_ear_308()],
                img_hw=imgs.shape[2:],
            )

        # Pass 1: random masking -> identity token + visible patch tokens.
        encoder_output = self.mae_encoder(imgs, do_mask=True, mask_prob=None)
        ids_restore_random = encoder_output["ids_restore"]
        id_embedding = encoder_output["prompt_tokens"][:, 0]
        # [T, L_after_drop, C]
        patch_tokens = encoder_output["patch_tokens"]

        # Pass 2: face-aware masking on (optionally augmented) frames -> motion tokens.
        motion_imgs = imgs
        if self.motion_augment is not None:
            motion_imgs = self.motion_augment(motion_imgs)
        motion_encoder_output = self.mae_encoder(
            motion_imgs, do_mask=True, mask_prob=keepface_mask_prob
        )
        nonvocal_motion, vocal_motion = motion_encoder_output["prompt_tokens"][
            :, 1:3
        ].unbind(dim=1)

        # [B, T_audio, C] -> keep the sampled frames -> [(B T), C]
        audio_embedding = self.encode_audio(audio, audio_frames, sr)
        audio_embedding = audio_embedding[:, rand_idx]
        audio_embedding = rearrange(audio_embedding, "b t c -> (b t) c")

        # Randomly shuffle identity embeddings: within one video the identity
        # is the same, so any frame's identity token should reconstruct any frame.
        if self.shuffle_id:
            idx_identity_shuffle = torch.randperm(id_embedding.shape[0])
            id_embedding = id_embedding[idx_identity_shuffle]
        else:
            # Same type (tensor) as the shuffled branch.
            idx_identity_shuffle = torch.arange(id_embedding.shape[0])

        encodings = dict(
            ids_restore=ids_restore_random,
            audio_embedding=audio_embedding,
            vocal_motion=vocal_motion,
            nonvocal_motion=nonvocal_motion,
            id_embedding=id_embedding,
            idx_identity_shuffle=idx_identity_shuffle,
            heatmap=heatmap,
            keypoint=keypoints,
            mask=encoder_output["mask"],
            keepface_mask=motion_encoder_output["mask"],
            rand_idx=rand_idx,
        )
        if encode_only:
            return SlipmaeOuput(**encodings)

        pixel_vocal = self.forward_mae_decode(
            id_embedding,
            nonvocal_motion,
            vocal_motion,
            patch_tokens,
            ids_restore_random,
        )

        if audio_frames == ori_video_length:
            pixel_audio = self.forward_mae_decode(
                id_embedding,
                nonvocal_motion,
                audio_embedding,
                patch_tokens,
                ids_restore_random,
            )
        else:
            print_log(
                f"Skip reconstruction from audio because of audio and video duration mismatch, audio_frames: {audio_frames}, T: {ori_video_length}",
                logger="current",
                level=logging.INFO,
            )
            pixel_audio = None

        return SlipmaeOuput(
            pixel_vocal=pixel_vocal,
            pixel_audio=pixel_audio,
            **encodings,
        )

    def forward_mae_decode(
        self, id_embedding, nonvocal_motion, vocal_motion, patch_tokens, ids_restore
    ):
        """Decode frames from three prompt tokens plus visible patch tokens.

        The three [T, C] tokens are stacked into [T, 3, C] and prepended to
        ``patch_tokens`` along the token dim before calling ``mae_decoder``.
        """
        prompt_tokens = torch.stack([id_embedding, nonvocal_motion, vocal_motion], dim=1)
        return self.mae_decoder(
            torch.cat([prompt_tokens, patch_tokens], dim=1),
            ids_restore=ids_restore,
        )

    def forward_loss(self, batch):
        """Compute the pre-training losses for one clip.

        Args:
            batch (dict): Needs ``"audio"`` ([1, N] waveform), ``"video"``
                ([1, T, C, H, W] frames) and ``"audio_metadata"``, whose
                optional ``"sr"`` entry is a per-sample sequence (first value
                used, default 16000). ``"keypoint"`` is required unless the
                model estimates pose in real time.

        Returns:
            dict: weighted ``pixel_loss_vocal``, ``pixel_loss_audio`` (omitted
            when no audio reconstruction was produced), ``cl_loss`` and, when
            ``loss_cfg["ortho_loss_weight"] > 0``, ``ortho_loss``.
        """
        audio = batch["audio"]
        sr = int(batch["audio_metadata"].get("sr", [16000] * audio.shape[0])[0])
        imgs = batch["video"]
        keypoints = batch.get("keypoint", None)

        B, T = imgs.shape[:2]
        assert (
            B == 1
        ), f"only support batch size == 1 in current implementation, but got {B}"

        mae_output = self.forward_mae(imgs, audio, keypoints=keypoints, sr=sr)
        audio_embedding = mae_output["audio_embedding"]
        pixel_vocal = mae_output["pixel_vocal"]
        pixel_audio = mae_output["pixel_audio"]
        vocal_motion = mae_output["vocal_motion"]
        nonvocal_motion = mae_output["nonvocal_motion"]
        id_embedding = mae_output["id_embedding"]
        rand_idx = mae_output["rand_idx"]
        # Targets are the (possibly sub-sampled) frames that were encoded.
        gt = imgs.squeeze(0)[rand_idx]

        pixel_weight = self.loss_cfg.get("pixel_loss_weight", 1)
        loss_dict = {"pixel_loss_vocal": F.mse_loss(pixel_vocal, gt) * pixel_weight}
        if pixel_audio is not None:
            loss_dict["pixel_loss_audio"] = F.mse_loss(pixel_audio, gt) * pixel_weight

        # ---------------------------- Contrastive loss ----------------------------
        audio_self_distance = None
        if self.audio_filter_model is not None:
            audio_self_distance = self.get_audio_self_distance(audio, T, sr)
            # Restrict the [T, T] matrix to the sampled frames.
            sel = rand_idx.to(audio_self_distance.device)
            audio_self_distance = audio_self_distance.index_select(0, sel).index_select(
                1, sel
            )
        cl_loss = self.contrastive_loss(
            audio_embedding,
            vocal_motion,
            audio_self_distance,
            None,
        )
        loss_dict["cl_loss"] = cl_loss * self.loss_cfg.get("cl_loss_weight", 1)

        # ---------------------------- Orthogonality loss --------------------------
        ortho_weight = self.loss_cfg.get("ortho_loss_weight", 0)
        if ortho_weight > 0:
            ortho_loss, _ = self.orthogonality_loss(
                id_embedding,
                nonvocal_motion,
                vocal_motion,
                # Same default as the constructor's default loss_cfg.
                mode=self.loss_cfg.get("ortho_loss_mode", "cov"),
            )
            # The weight both enables and scales this term, like the other losses.
            loss_dict["ortho_loss"] = ortho_loss * ortho_weight

        return loss_dict

    @torch.no_grad()
    def get_sapiens_heatmap(self, imgs: Tensor, max_batch_size: int = 32) -> Tensor:
        """
        Get Sapiens heatmaps from images, in chunks of at most ``max_batch_size``.

        Args:
            imgs (Tensor):
                A batch of imgs: [B, C, H, W]
            max_batch_size (int):
                Largest number of images sent through the backbone at once.
        Returns:
            heatmap (Tensor):
                A batch of heatmap: [B, N, H', W'] as produced by ``heatmap_head``.
        """
        if imgs.shape[0] > max_batch_size:
            return torch.cat(
                [
                    self.get_sapiens_heatmap(chunk, max_batch_size)
                    for chunk in imgs.split(max_batch_size, dim=0)
                ],
                dim=0,
            )
        feats = self.raw_sapiens(imgs, out_type="featmap").last_hidden_state
        heatmap, _ = self.heatmap_head(feats, decode_kpt=False)
        return heatmap

    @torch.no_grad()
    def get_audio_self_distance(
        self, audio: Tensor, video_length: int, sr: int
    ) -> Tensor:
        """Frame-by-frame cosine similarity of the audio filter model features.

        Despite the name, the result is a similarity (higher = more alike):
        the dot product of L2-normalized last hidden states. Requires B == 1.

        Returns:
            Tensor: [T, T] with ``T == video_length`` when the filter model
            honours ``seq_len``.
        """
        audio_hidden_states = self.encode_audio_filter(audio, video_length, sr)
        audio_hidden_states = rearrange(audio_hidden_states, "1 t d -> t d")
        normalized = F.normalize(audio_hidden_states, p=2, dim=-1)
        return normalized.matmul(normalized.T)

    @staticmethod
    def orthogonality_loss(
        z_id, z_non, z_voc, mode="cosine", eps=1e-6, normalize=True, reduction="mean"
    ):
        """
        Compute orthogonality / decorrelation loss between three token embeddings.
        Inputs:
            z_id, z_non, z_voc: tensors, shape [B, C]
            mode: 'cosine' or 'cov'
                - 'cosine': mean squared cosine similarity per-sample between token pairs
                - 'cov': squared Frobenius norm of the cross-covariance between
                  each pair of tokens (centered along the batch dim)
            eps: numerical epsilon used when normalizing in 'cosine' mode
            normalize: use F.normalize (True) or x / (||x|| + eps) (False);
                both normalize, only used in 'cosine' mode
            reduction: 'mean' or 'sum' for final scalar reduction
                ('cov' + 'mean' divides the summed norms by C^2)
        Returns:
            loss: scalar tensor
            info: dict with diagnostics
                - 'cosine_mean' (if mode=='cosine'): mean of cosines (not squared)
                - 'cosine_sq_mean' (if mode=='cosine'): mean of cos^2
                - 'cov_offdiag_norm', 'cov12_frob', 'cov13_frob', 'cov23_frob'
                  (if mode=='cov'): summed and per-pair squared norms
        Raises:
            ValueError: on an unknown ``mode`` or ``reduction``.
        """
        assert z_id.ndim == 2 and z_non.ndim == 2 and z_voc.ndim == 2
        if reduction not in ("mean", "sum"):
            raise ValueError("reduction must be 'mean' or 'sum'")
        B, C = z_id.shape

        if mode == "cosine":
            if normalize:
                z_id_n = F.normalize(z_id, p=2, dim=-1, eps=eps)
                z_non_n = F.normalize(z_non, p=2, dim=-1, eps=eps)
                z_voc_n = F.normalize(z_voc, p=2, dim=-1, eps=eps)
            else:
                # safe normalization for numerical stability in cosine computation
                z_id_n = z_id / (z_id.norm(dim=-1, keepdim=True) + eps)
                z_non_n = z_non / (z_non.norm(dim=-1, keepdim=True) + eps)
                z_voc_n = z_voc / (z_voc.norm(dim=-1, keepdim=True) + eps)

            # pairwise cosine per sample, each [B]
            sim_id_non = (z_id_n * z_non_n).sum(dim=-1)
            sim_id_voc = (z_id_n * z_voc_n).sum(dim=-1)
            sim_non_voc = (z_non_n * z_voc_n).sum(dim=-1)

            # squared cosine (penalize magnitude regardless of sign), [3*B]
            sim_sq = torch.cat([sim_id_non**2, sim_id_voc**2, sim_non_voc**2], dim=0)
            loss = sim_sq.mean() if reduction == "mean" else sim_sq.sum()

            info = {
                "cosine_mean": torch.cat([sim_id_non, sim_id_voc, sim_non_voc], dim=0)
                .mean()
                .detach(),
                # Always the mean, even when the loss uses reduction="sum".
                "cosine_sq_mean": sim_sq.mean().detach(),
            }
            return loss, info

        elif mode == "cov":
            # center along batch dim
            z1 = z_id - z_id.mean(dim=0, keepdim=True)
            z2 = z_non - z_non.mean(dim=0, keepdim=True)
            z3 = z_voc - z_voc.mean(dim=0, keepdim=True)

            # cross-covariance blocks, each [C, C]: cov_{12} = z1^T z2 / (B-1)
            denom = max(B - 1, 1)
            cov12 = (z1.t() @ z2) / denom
            cov13 = (z1.t() @ z3) / denom
            cov23 = (z2.t() @ z3) / denom

            # squared Frobenius norms (we want to minimize cross-covariance magnitude)
            frob12 = cov12.pow(2).sum()
            frob13 = cov13.pow(2).sum()
            frob23 = cov23.pow(2).sum()
            total = frob12 + frob13 + frob23

            # "mean" normalizes by C^2 so the scale is independent of feature dim
            loss = total / (C * C) if reduction == "mean" else total

            info = {
                "cov_offdiag_norm": total.detach(),
                "cov12_frob": frob12.detach(),
                "cov13_frob": frob13.detach(),
                "cov23_frob": frob23.detach(),
            }
            return loss, info

        else:
            raise ValueError("mode must be 'cosine' or 'cov'")

    @torch.no_grad()
    def forward_predict(self, batch: Dict[str, Dict[str, Tensor]]) -> Dict[str, Tensor]:
        """Run ``forward_mae`` on one clip and return per-sample ``DataSample``s.

        Every tensor output gets a leading batch dim of 1, is merged into
        ``batch`` and the result is split into a sample list.
        """
        imgs = batch["video"]
        audio = batch["audio"]
        sr = int(batch["audio_metadata"].get("sr", [16000] * audio.shape[0])[0])

        # Pass by keyword: the third positional parameter is `keypoints`, not `sr`.
        mae_output = self.forward_mae(
            imgs,
            audio,
            keypoints=batch.get("keypoint", None),
            sr=sr,
        )

        mae_output = {
            k: v.unsqueeze(0) if isinstance(v, Tensor) else v
            for k, v in mae_output.items()
        }
        batch.update(mae_output)
        sample_list = DataSample(**batch)
        # split sample_list into a sample list
        return sample_list.split()


@torch.no_grad()
def infer(
    cfg="configs/train/motion_extractor/motion_extractor_mae_v3_crossattn_shuffleframe.py",
    checkpoint="work_dirs/motion_extractor_mae_v3_crossattn_shuffleframe/iter_56000.pth",
    video_path="data/processed/video_crop_v1/20240307_7343374805606861607_0465897_0468515_fps22.000_shot_001_seg_005.mp4",
    audio_path=None,
    max_duration: float = 20,
    fps: int = 25,
    output_path: str = "./out_motion_extractor_mae_v3",
    do_av_mapping: bool = True,
    dtype: torch.dtype = torch.bfloat16,
    audio_sr: Optional[int] = 16000,
):
    """Run ``TrainerSlipmaePretrain.forward_mae`` on one video and save a
    side-by-side visualization (with its audio muxed in).

    Columns: input frames, reconstruction from vocal motion, reconstruction
    from audio (if produced), audio-to-motion reordered frames (if
    ``do_av_mapping``), identity-shuffled frames, both masks, and heatmaps
    (only when the model estimates pose in real time).

    Args:
        cfg: model config file.
        checkpoint: checkpoint to load.
        video_path: input video.
        audio_path: audio file; defaults to ``video_path``.
        max_duration: seconds of audio / ``max_duration * fps`` frames to use.
        fps: frame rate assumed for trimming and used for the written video.
        output_path: root output dir; results go to
            ``<output_path>/<cfg name>/<checkpoint name>/<video name>``.
        do_av_mapping: Whether to reorder the video w.r.t. the input audio.
        dtype: model / input dtype; a string such as ``"float32"`` (as passed
            by the CLI) is also accepted.
        audio_sr: sampling rate the audio is resampled to before it is fed to
            the model (the wav2vec2-style processor rejects other rates);
            ``None`` disables resampling.
    """
    import torchaudio
    from torchvision.io import read_video, write_video
    from mmengine.device import get_device
    from mmengine import Config
    from mmengine.runner import load_checkpoint
    from transformers.models.bit.image_processing_bit import BitImageProcessor
    from mmhug.evaluators.evaluator_audio_motion_clip import euclidean_distance_matrix
    from mmhug.utils.pixel_process_utils import inv_normalize
    from mmhug.utils.io import merge_video_audio

    from mmhug.utils.vis_utils import (
        blend_image_with_heatmaps,
        apply_block_mask_to_rgb,
    )

    if isinstance(dtype, str):
        # fire passes CLI values as strings, e.g. "float32" or "torch.float32"
        dtype = getattr(torch, dtype.split(".")[-1])

    def _stem(path):
        # splitext drops only the final extension; str.rstrip(".py") would strip
        # any trailing '.', 'p', 'y' characters.
        return os.path.splitext(os.path.basename(path))[0]

    vis_video = []

    output_path = os.path.join(
        output_path, _stem(cfg), _stem(checkpoint), _stem(video_path)
    )
    os.makedirs(output_path, exist_ok=True)

    device = get_device()

    # 1. load cfg & model
    cfg = Config.fromfile(cfg)
    model: TrainerSlipmaePretrain = MODELS.build(cfg.model)
    load_checkpoint(model, checkpoint, map_location="cpu")
    # Cast the whole model to `dtype`; the inputs are cast to the same dtype below.
    model = model.to(device, dtype)
    model.eval()

    # 2. processors
    input_video_processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": 512},
        do_center_crop=True,
        crop_size={"height": 512, "width": 512},
        do_rescale=True,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    # 3. load audio: mono, trimmed to max_duration
    if audio_path is None:
        audio_path = video_path
    audio, sr = torchaudio.load(audio_path)  # (channels, samples)
    audio = torch.mean(audio, dim=0, keepdim=True)  # (1, samples)
    audio = audio[:, : int(max_duration * sr)]
    # float32 CPU copy at the original rate, used for audio.wav so the saved
    # track is not quantized by the `dtype` cast below.
    audio_to_save = audio
    model_sr = sr
    if audio_sr is not None and sr != audio_sr:
        audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=audio_sr)
        model_sr = audio_sr
    audio = audio.to(device, dtype)

    # 4. load video frames: read_video returns (frames, audio, info), frames (T, H, W, C) uint8
    frames = read_video(video_path, pts_unit="sec")[0][: int(max_duration * fps)]
    frames = (
        input_video_processor(frames, return_tensors="pt")
        .pixel_values.unsqueeze(0)
        .to(device, dtype)
    )

    # Pass by keyword: the third positional parameter of forward_mae is `keypoints`.
    mae_output = model.forward_mae(frames, audio, keypoints=None, sr=model_sr)

    frames = frames.squeeze(0)
    vis_video.append(frames)

    T, C, H, W = frames.shape
    vis_video.append(mae_output["pixel_vocal"])
    pixel_audio = mae_output["pixel_audio"]
    if pixel_audio is not None:
        vis_video.append(pixel_audio)
    vocal_motion = mae_output["vocal_motion"]
    audio_embedding = mae_output["audio_embedding"]
    idx_identity_shuffle = mae_output["idx_identity_shuffle"]
    heatmap = mae_output["heatmap"]

    # Video the identity shuffle is visualized on; the audio-reordered video
    # when AV mapping is enabled, otherwise the input frames.
    id_shuffle_source = frames
    if do_av_mapping:
        # reorder frames w.r.t. audio: nearest vocal-motion code per audio code
        audio_codes = F.normalize(audio_embedding, p=2, dim=-1)  # (T, D)
        vocal_codes = F.normalize(vocal_motion, p=2, dim=-1)  # (T, D)

        dist_mat = euclidean_distance_matrix(audio_codes, vocal_codes)  # (T, T)
        # For each audio index, the best matching motion index
        top_1_idxs = torch.argmin(dist_mat, dim=1)
        reordered_video_audio = frames[top_1_idxs]  # (T, C, H, W)
        vis_video.append(reordered_video_audio)
        accuracy = (top_1_idxs == torch.arange(T, device=device)).float().mean()
        print(f"AV mapping accuracy: {accuracy.item():.4f}")
        id_shuffle_source = reordered_video_audio
    # reorder frames w.r.t. identity
    vis_video.append(id_shuffle_source[idx_identity_shuffle])

    # apply masks to rgb
    vis_video.append(apply_block_mask_to_rgb(frames, mae_output["mask"]))
    vis_video.append(apply_block_mask_to_rgb(frames, mae_output["keepface_mask"]))

    # visualize heatmap (only produced with real-time pose estimation)
    if heatmap is not None:
        heatmap = F.interpolate(heatmap, size=(H, W), mode="bilinear")
        vis_video.append(blend_image_with_heatmaps(frames, heatmap))

    # concatenate all visualizations side by side along width
    vis_video = torch.cat(vis_video, dim=-1).float().detach().cpu()

    # de-normalize and convert for saving, then (T, C, H, W) -> (T, H, W, C)
    out = inv_normalize(vis_video, scale_back=True)
    out_np = out.permute(0, 2, 3, 1).numpy()

    # save audio + video, then mux
    write_video(os.path.join(output_path, "output.mp4"), out_np, fps=fps)
    torchaudio.save(f"{output_path}/audio.wav", audio_to_save, sr)
    merge_video_audio(
        f"{output_path}/output.mp4",
        f"{output_path}/audio.wav",
        f"{output_path}/output_audio.mp4",
        fps=fps,
        sr=sr,
    )
    print("Saved output to", output_path)


if __name__ == "__main__":
    from fire import Fire

    # Expose `infer` on the command line; each keyword argument becomes a flag.
    Fire(infer)
