import glob
import os
import random
from typing import Any
import numpy as np
import torch
from tqdm import tqdm
from transformers import BitImageProcessor, Wav2Vec2Processor
from torchvision.io import read_video, write_video

from mmengine import Config
from mmengine.device import get_device
from mmengine.runner import load_checkpoint
from mmhug.datasets.transforms.keypoint2mask import sapiens2mask

from mmhug.datasets.utils.kpt2face import _fallback_face_indices_ex_ear_308
from mmhug.models.custom_transformers.sapiens.heatmap_head import HeatmapHead
from mmhug.models.custom_transformers.sapiens.vit_sapiens import (
    SapiensVisionTransformer,
)
from mmhug.models.custom_transformers.wan22_audio.transformer_mae_decoder_v2 import (
    MAEDecoderV2,
)
from mmhug.models.custom_transformers.wan22_audio.transformer_motion_extractor_mae_v2 import (
    SapiensMotionExtractorV2,
)
from mmhug.models.custom_transformers.wav2vec2_interp.wav2vec2_interp import (
    Wav2Vec2InterpModel,
)
from mmhug.registry import MODELS
from mmhug.trainers.trainer_slipmae.trainer_slipmae_pretrain import (
    SlipmaeOuput,
    TrainerSlipmaePretrain,
)
from mmhug.trainers.trainer_slipmae.utils import (
    keepface_mask_prob_from_heatmap,
    keepface_mask_prob_from_kpts,
)
from mmhug.utils.dtype_utils import dtype_from_str
from mmhug.utils.vis_utils import apply_block_mask_to_rgb
from mmhug.utils.pixel_process_utils import inv_normalize


def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a channels-first 4-D tensor to a channels-last NumPy array.

    The axes are permuted with ``(0, 2, 3, 1)``, i.e. the second axis is moved
    to the end.

    Args:
        tensor: A 4-D tensor. It may live on any device and may require grad.

    Returns:
        A NumPy array holding the permuted data. ``bfloat16`` input is
        returned as ``float32`` because NumPy has no bfloat16 dtype.
    """
    tensor = tensor.detach()
    if tensor.dtype == torch.bfloat16:
        # NumPy cannot represent bfloat16; up-cast instead of raising.
        tensor = tensor.float()
    # ``Tensor.numpy()`` only works for CPU tensors that do not require grad.
    return tensor.permute(0, 2, 3, 1).cpu().numpy()


class PipelineSlipmae:
    """Inference pipeline around a SLIP-MAE encoder/decoder pair.

    The pipeline loads up to four videos (identity, patch, ambient, vocal),
    encodes each of them with the MAE encoder, recombines the identity
    embedding, the non-vocal motion token, the vocal motion token and the
    patch tokens, and decodes the mix with the MAE decoder.  A Sapiens
    backbone plus heatmap head provide the facial keypoints used to build the
    face-preserving mask probability.
    """

    def __init__(
        self,
        audio_encoder: Wav2Vec2InterpModel,
        audio_processor: Wav2Vec2Processor,
        mae_encoder: SapiensMotionExtractorV2,
        mae_decoder: MAEDecoderV2,
        raw_sapiens: SapiensVisionTransformer = None,
        heatmap_head: HeatmapHead = None,
        dtype=torch.bfloat16,
        audio_layer: str = "last",
        audio_adapter=None,
    ):
        """Move all modules to the inference device and build the preprocessor.

        Args:
            audio_encoder: Audio encoder used by :meth:`encode_audio`.
            audio_processor: Processor that turns waveforms into encoder input.
            mae_encoder: MAE encoder used by :meth:`encode_video`.
            mae_decoder: MAE decoder used by :meth:`mae_decode`.
            raw_sapiens: Optional Sapiens backbone; required for
                :meth:`get_sapiens_heatmap` and everything built on it.
            heatmap_head: Optional heatmap head; required together with
                ``raw_sapiens``.
            dtype: Dtype all modules and loaded videos are cast to.
            audio_layer: ``"last"`` to use the last hidden state, ``"all"`` to
                concatenate the last hidden state with every hidden state.
            audio_adapter: Optional callable applied to the selected audio
                states; when ``None`` the states are returned unchanged.
        """
        self.device = get_device()
        self.dtype = dtype
        self.audio_encoder = self._prepare(audio_encoder)
        self.audio_processor = audio_processor
        self.mae_encoder = self._prepare(mae_encoder)
        self.mae_decoder = self._prepare(mae_decoder)
        # These two default to ``None``; only move them when they were given.
        self.raw_sapiens = self._prepare(raw_sapiens, required=False)
        self.heatmap_head = self._prepare(heatmap_head, required=False)
        # ``encode_audio`` reads both attributes, so they must always exist.
        self.audio_layer = audio_layer
        self.audio_adapter = audio_adapter

        # Resize the short side to 512, center-crop to 512x512 and normalize
        # with mean/std 0.5 per channel.
        self.input_video_processor = BitImageProcessor(
            do_resize=True,
            size={"shortest_edge": 512},
            do_center_crop=True,
            crop_size={"height": 512, "width": 512},
            do_rescale=True,
            do_normalize=True,
            image_mean=[0.5, 0.5, 0.5],
            image_std=[0.5, 0.5, 0.5],
        )

    def _prepare(self, module, required: bool = True):
        """Return ``module`` on the pipeline device/dtype in eval mode.

        ``None`` is passed through when ``required`` is false so that optional
        components can be omitted.
        """
        if module is None and not required:
            return None
        # This class only runs inference, so train-time behaviour is disabled.
        return module.to(self.device, self.dtype).eval()

    @torch.no_grad()
    def get_sapiens_heatmap(
        self,
        imgs: torch.Tensor,
    ) -> tuple:
        """Run the Sapiens backbone and heatmap head on a batch of frames.

        Args:
            imgs: Batch of frames, presumably ``[B, C, H, W]`` (TODO confirm
                against the backbone).

        Returns:
            A tuple ``(heatmap, keypoints, keypoint_scores)`` of NumPy arrays.
            ``heatmap`` is the float32 CPU copy of the head output; the other
            two are the per-sample predictions concatenated along axis 0.

        Raises:
            RuntimeError: If the pipeline was built without ``raw_sapiens`` or
                ``heatmap_head``.
        """
        if self.raw_sapiens is None or self.heatmap_head is None:
            raise RuntimeError(
                "get_sapiens_heatmap requires both raw_sapiens and heatmap_head"
            )

        feats = self.raw_sapiens(imgs, out_type="featmap").last_hidden_state
        heatmap, pred = self.heatmap_head(feats, decode_kpt=True)

        keypoints = np.concatenate([p.keypoints for p in pred], axis=0)
        keypoints_scores = np.concatenate([p.keypoint_scores for p in pred], axis=0)

        # Up-cast before leaving torch: NumPy cannot hold bfloat16.
        heatmap = heatmap.float().detach().cpu().numpy()
        return heatmap, keypoints, keypoints_scores

    def sapiens_face_det(self, video: torch.Tensor, batch_size=8) -> tuple:
        """Detect keypoints for every frame and derive a lower-face mask.

        Args:
            video: Frames indexed along the first axis; the last two axes are
                taken as height and width.
            batch_size: Number of frames sent through Sapiens at once.

        Returns:
            A tuple ``(keypoints, lower_face_mask)``.  ``keypoints`` carries
            the keypoint scores appended as an extra last channel;
            ``lower_face_mask`` is the output of ``sapiens2mask``.

        Raises:
            ValueError: If ``video`` contains no frames.
        """
        if len(video) == 0:
            raise ValueError("sapiens_face_det received a video with no frames")

        H, W = video.shape[-2:]
        all_keypoints = []
        all_keypoints_scores = []
        for start in tqdm(range(0, len(video), batch_size)):
            video_batch = video[start : start + batch_size]
            _, keypoints, keypoints_scores = self.get_sapiens_heatmap(video_batch)
            all_keypoints.append(torch.from_numpy(keypoints))
            all_keypoints_scores.append(torch.from_numpy(keypoints_scores))
        all_keypoints = torch.cat(all_keypoints, dim=0)
        all_keypoints_scores = torch.cat(all_keypoints_scores, dim=0)
        # Append the confidence as one more channel after the coordinates.
        keypoints = torch.cat(
            [all_keypoints, all_keypoints_scores.unsqueeze(-1)], dim=-1
        )
        lower_face_mask = sapiens2mask(
            H, W, keypoints, mask_area="lower_face", mask_expand=(0, 0, 0, 20)
        )
        return keypoints, lower_face_mask

    def encode_audio(
        self, audio: torch.Tensor, video_length: int, sr: int
    ) -> torch.Tensor:
        """Encode raw waveforms into audio features aligned to a video length.

        The waveform is first passed through ``self.audio_processor`` and then
        through ``self.audio_encoder``.  Depending on ``self.audio_layer`` the
        result is either the last hidden state or the concatenation of the
        last hidden state with all hidden states along the feature axis.  If
        an ``audio_adapter`` was supplied it is applied last.

        Args:
            audio: Batch of waveforms of shape ``[B, T_audio]``.
            video_length: Number of video frames, forwarded to the encoder as
                ``seq_len``.
            sr: Sampling rate of ``audio``.

        Returns:
            The selected (and optionally adapted) audio states.

        Raises:
            ValueError: If ``self.audio_layer`` is neither ``"last"`` nor
                ``"all"``.
        """
        # Fail before doing any expensive work.
        if self.audio_layer not in ("last", "all"):
            raise ValueError(f"Unknown audio_layer: {self.audio_layer}")

        with torch.no_grad():
            B, N = audio.shape
            # The processor can only process float32.  Its output is cast to
            # the pipeline device/dtype because that is where the encoder
            # weights were moved in ``__init__``.
            audio_inputs = self.audio_processor(
                audio.float(), sampling_rate=sr, return_tensors="pt"
            ).input_values.to(device=self.device, dtype=self.dtype)
            # Fails loudly if the processor changed the number of elements.
            audio_inputs = audio_inputs.view(B, N)
            audio_encoder_output = self.audio_encoder(
                audio_inputs, seq_len=int(video_length), output_hidden_states=True
            )

        if self.audio_layer == "last":
            audio_states = audio_encoder_output.last_hidden_state
        else:
            # Concatenate every layer output along the feature axis.
            audio_states = torch.cat(
                (audio_encoder_output.last_hidden_state,)
                + audio_encoder_output.hidden_states,
                dim=-1,
            )

        if self.audio_adapter is not None:
            audio_states = self.audio_adapter(audio_states)
        return audio_states

    def encode_video(
        self,
        video: torch.Tensor,
    ) -> SlipmaeOuput:
        """Encode a clip twice: with the default mask and a keypoint-driven mask.

        The first encoder pass (``mask_prob=None``) supplies the identity
        embedding, the patch tokens, the mask and ``ids_restore``.  The second
        pass uses a mask probability derived from the facial keypoints and
        supplies the non-vocal and vocal motion tokens.

        Args:
            video: Frames as ``[T, C, H, W]``; a 5-D input has its first axis
                squeezed (presumably a batch axis of size 1 — TODO confirm).

        Returns:
            A ``SlipmaeOuput`` holding the tokens and masks of both passes.
        """
        # Only strip a leading axis from 5-D input; an unconditional
        # ``squeeze(0)`` would delete the frame axis of a one-frame clip.
        if video.dim() == 5:
            video = video.squeeze(0)

        keypoints, _ = self.sapiens_face_det(video)
        keepface_mask_prob = keepface_mask_prob_from_kpts(
            keypoints[:, _fallback_face_indices_ex_ear_308()],
            img_hw=video.shape[-2:],
        ).to(self.device)

        encoder_output = self.mae_encoder(video, do_mask=True, mask_prob=None)
        motion_encoder_output = self.mae_encoder(
            video, do_mask=True, mask_prob=keepface_mask_prob
        )

        # Prompt token 0 is used as identity, tokens 1 and 2 as motion.
        nonvocal_motion, vocal_motion = motion_encoder_output["prompt_tokens"][
            :, 1:3
        ].unbind(dim=1)

        return SlipmaeOuput(
            vocal_motion=vocal_motion,
            nonvocal_motion=nonvocal_motion,
            patch_tokens=encoder_output["patch_tokens"],
            id_embedding=encoder_output["prompt_tokens"][:, 0],
            mask=encoder_output["mask"],
            keepface_mask=motion_encoder_output["mask"],
            ids_restore=encoder_output["ids_restore"],
        )

    def mae_decode(
        self,
        identity_token: torch.Tensor,
        ambient_token: torch.Tensor,
        vocal_token: torch.Tensor,
        patch_tokens: torch.Tensor,
        ids_restore: torch.Tensor,
    ) -> Any:
        """Decode the three prompt tokens plus patch tokens with the MAE decoder.

        The prompt tokens are stacked along dim 1 in the order identity,
        ambient, vocal and prepended to ``patch_tokens``.

        Returns:
            Whatever ``self.mae_decoder`` returns for the assembled sequence.
        """
        prompt_tokens = torch.stack(
            [identity_token, ambient_token, vocal_token], dim=1
        )
        return self.mae_decoder(
            torch.cat([prompt_tokens, patch_tokens], dim=1),
            ids_restore=ids_restore,
        )

    def load_video(self, video_path: str, max_frames: int = None) -> torch.Tensor:
        """Read a video file and preprocess its frames.

        Args:
            video_path: Path of the video to read.
            max_frames: If given, only the first ``max_frames`` frames are
                preprocessed.  Preprocessing is per frame, so the kept frames
                are the same as when truncating afterwards.

        Returns:
            The preprocessed frames on ``self.device`` with ``self.dtype``.
        """
        frames = read_video(video_path)[0]
        if max_frames is not None:
            frames = frames[:max_frames]
        # Ask the processor for a tensor directly instead of building one from
        # a Python list of arrays, which is very slow.
        pixel_values = self.input_video_processor(
            frames, return_tensors="pt"
        ).pixel_values
        return pixel_values.to(self.device, self.dtype)

    @torch.no_grad()
    def __call__(
        self,
        video: str = None,
        patch_video: str = None,
        identity_video: str = None,
        ambient_video: str = None,
        vocal_video: str = None,
        num_frames: int = 8,
    ) -> Any:
        """Run the full encode/recombine/decode pipeline.

        Args:
            video: If given, it is used for all four roles and the other path
                arguments are ignored.
            patch_video: Source of the patch tokens and ``ids_restore``.
            identity_video: Source of the identity embedding.
            ambient_video: Source of the non-vocal motion token.
            vocal_video: Source of the vocal motion token.
            num_frames: Maximum number of leading frames used from every
                video; ``None`` uses the length of the shortest video.

        Returns:
            A dict with the masked visualizations, the decoder output under
            ``"pred_video"`` and the four input clips.  Tensor values are
            de-normalized and returned as channels-last float32 NumPy arrays;
            non-tensor values are returned unchanged.

        Raises:
            ValueError: If ``video`` is ``None`` and any of the four role
                paths is missing, if ``num_frames`` is smaller than 1, or if a
                video has no frames.
        """
        if video is not None:
            identity_video = ambient_video = vocal_video = patch_video = video
        else:
            roles = {
                "identity_video": identity_video,
                "patch_video": patch_video,
                "ambient_video": ambient_video,
                "vocal_video": vocal_video,
            }
            missing = [name for name, path in roles.items() if path is None]
            if missing:
                raise ValueError(
                    "either `video` or all role videos must be given; "
                    f"missing: {', '.join(missing)}"
                )
        if num_frames is not None and num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")

        # Decode each distinct file once; repeated paths receive an
        # independent copy so the four clips never share storage.
        cache = {}

        def _load(path):
            if path in cache:
                return cache[path].clone()
            cache[path] = self.load_video(path, max_frames=num_frames)
            return cache[path]

        identity_clip = _load(identity_video)
        patch_clip = _load(patch_video)
        ambient_clip = _load(ambient_video)
        vocal_clip = _load(vocal_video)

        # Align all clips to a common length so their tokens can be stacked.
        lengths = [
            clip.shape[0]
            for clip in (identity_clip, patch_clip, ambient_clip, vocal_clip)
        ]
        clip_len = min(lengths) if num_frames is None else min(num_frames, *lengths)
        if clip_len == 0:
            raise ValueError("at least one input video contains no frames")
        identity_clip = identity_clip[:clip_len]
        ambient_clip = ambient_clip[:clip_len]
        vocal_clip = vocal_clip[:clip_len]
        patch_clip = patch_clip[:clip_len]

        # Encoding order is kept fixed (identity, patch, ambient, vocal) so
        # any random masking inside the encoder sees the same RNG sequence.
        identity_encoder_output = self.encode_video(identity_clip)
        patch_encoder_output = self.encode_video(patch_clip)
        ambient_encoder_output = self.encode_video(ambient_clip)
        vocal_encoder_output = self.encode_video(vocal_clip)

        decode_video = self.mae_decode(
            identity_token=identity_encoder_output.id_embedding,
            ambient_token=ambient_encoder_output.nonvocal_motion,
            vocal_token=vocal_encoder_output.vocal_motion,
            patch_tokens=patch_encoder_output.patch_tokens,
            ids_restore=patch_encoder_output.ids_restore,
        )

        result = {
            "mask_uniform_rgb": apply_block_mask_to_rgb(
                patch_clip, patch_encoder_output.mask
            ),
            "mask_keepface_ambient_rgb": apply_block_mask_to_rgb(
                ambient_clip, ambient_encoder_output.keepface_mask
            ),
            "mask_keepface_vocal_rgb": apply_block_mask_to_rgb(
                vocal_clip, vocal_encoder_output.keepface_mask
            ),
            "pred_video": decode_video,
            "identity_video": identity_clip,
            "patch_video": patch_clip,
            "ambient_video": ambient_clip,
            "vocal_video": vocal_clip,
        }

        # Undo the input normalization and hand back channels-last arrays.
        return {
            key: (
                inv_normalize(value, scale_back=True)
                .permute(0, 2, 3, 1)
                .float()
                .cpu()
                .numpy()
                if isinstance(value, torch.Tensor)
                else value
            )
            for key, value in result.items()
        }


def main(
    cfg: str = "configs/train/slipmae/slipmae_pretrain.py",
    checkpoint: str = "work_dirs/slipmae_pretrain/best_iter_58000.pth",
    video_root: str = "data/hallo3/hallo3_training_data/videos_cropped_new/",
    sample_num: int = 100,
    output_root="outputs/slipmae_pretrain",
    seed: int = 42,
):
    """Run the SLIP-MAE pipeline on randomly chosen videos and save the results.

    Args:
        cfg: Path of the mmengine config describing the model.
        checkpoint: Path of the trained checkpoint loaded into the model.
        video_root: Directory whose ``*.mp4`` files are sampled from.
        sample_num: Number of videos to draw (with replacement).
        output_root: Directory receiving one sub-directory per video.
        seed: Seed applied to both ``random`` and ``torch``.

    Raises:
        FileNotFoundError: If ``video_root`` contains no ``*.mp4`` file.
    """
    # Video selection uses ``random``, so it has to be seeded as well.
    random.seed(seed)
    torch.manual_seed(seed)

    # 1. load cfg & model
    config = Config.fromfile(cfg)
    model: TrainerSlipmaePretrain = MODELS.build(config.model)
    load_checkpoint(model, checkpoint, map_location="cpu")
    model.eval()

    sapiens = SapiensVisionTransformer(
        arch="sapiens_0.3b",
        norm_in=True,  # the input is normalized with mean 0.5 and std 0.5, we need to renormalize with Sapiens mean and std
        img_size=(1024, 768),
        patch_size=16,
        in_channels=3,
        out_indices=-1,
        drop_rate=0.0,
        drop_path_rate=0.0,
        qkv_bias=True,
        norm_cfg=dict(type="LN", eps=1e-6),
        final_norm=True,
        with_cls_token=False,
        frozen_stages=-1,
        interpolate_mode="bicubic",
        layer_scale_init_value=0.0,
        patch_cfg=dict(padding=2),
        layer_cfgs=dict(),
        pre_norm=False,
        out_type="featmap",
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/backbone.pth",
        ),
    )
    sapiens.init_weights()
    sapiens.eval()

    heatmap_head = HeatmapHead(
        in_channels=1024,
        out_channels=308,
        deconv_out_channels=(768, 768),  # each deconv upsamples 2x, 4x in total
        deconv_kernel_sizes=(4, 4),
        conv_out_channels=(768, 768),
        conv_kernel_sizes=(1, 1),
        decoder=dict(
            type="UDPHeatmap", input_size=(512, 512), heatmap_size=(192, 256), sigma=6
        ),
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/heatmap_head.pth",
        ),
    )
    heatmap_head.init_weights()
    heatmap_head.eval()

    pipeline = PipelineSlipmae(
        audio_encoder=model.audio_encoder,
        audio_processor=model.audio_processor,
        mae_encoder=model.mae_encoder,
        mae_decoder=model.mae_decoder,
        raw_sapiens=sapiens,
        heatmap_head=heatmap_head,
    )

    # 2. sample videos; sorting makes the seeded selection independent of the
    # file-system listing order.
    all_videos = sorted(glob.glob(os.path.join(video_root, "*.mp4")))
    if not all_videos:
        raise FileNotFoundError(f"No .mp4 files found in {video_root!r}")

    for _ in tqdm(range(sample_num)):
        torch.cuda.empty_cache()
        # The same video is used for every role; pass patch_video /
        # identity_video / ambient_video / vocal_video instead to mix sources.
        video = random.choice(all_videos)
        result = pipeline(video=video)

        # 3. save result
        video_name = os.path.splitext(os.path.basename(video))[0]
        output_path = os.path.join(output_root, video_name)
        os.makedirs(output_path, exist_ok=True)

        panels = []
        for key, frames in result.items():
            if not isinstance(frames, np.ndarray):
                continue
            # ``write_video`` stores uint8 frames; clamp first so out-of-range
            # predictions saturate instead of wrapping around.
            clip = torch.from_numpy(frames).clamp(0, 255).to(torch.uint8)
            write_video(
                f"{output_path}/{key}.mp4",
                clip,
                fps=30,
            )
            panels.append(clip)

        if panels:
            # Place all clips side by side (dim 2 is the width of [T, H, W, C]).
            write_video(
                f"{output_path}/entire_video.mp4",
                torch.cat(panels, dim=2),
                fps=30,
            )
        print(f"Saved results to {output_root}/{video_name}/")


if __name__ == "__main__":
    # Expose ``main`` as a command-line interface; python-fire maps the
    # keyword arguments of ``main`` to command-line flags.
    from fire import Fire

    Fire(main)
