import librosa
import numpy as np
from torch import Tensor
from torch import nn
import torch
from transformers import Wav2Vec2Processor
import torch.nn.functional as F
from typing import Literal, Tuple
from mmengine import Config
from mmengine.device import get_device
from mmengine.runner import load_checkpoint
from mmhug.models.custom_transformers.sapiens import (
    SapiensVisionTransformer,
    HeatmapHead,
)
from mmhug.models.custom_transformers.wav2vec2_interp import Wav2Vec2InterpModel
from mmhug.trainers.trainer_wan_audio.trainer_motion_extractor_mae_v3 import (
    TrainerMotionExtractorMAEV3,
)
from mmhug.utils.dtype_utils import dtype_from_str
from mmhug.registry import MODELS

from decord.video_reader import VideoReader

from transformers.models.bit.image_processing_bit import BitImageProcessor

from mmhug.evaluators.evaluator_audio_motion_clip import (
    euclidean_distance_matrix,
    evaluate_avsync,
)

from mmhug.trainers.trainer_slipmae.utils import (
    keepface_mask_prob_from_kpts,
)
from mmhug.datasets.utils.kpt2face import _fallback_face_indices_ex_ear_308


def match_accuracy_torch(
    dist: torch.Tensor,  # (T, W) distance matrix
    accept_radius: int = 1,
    tie_policy: Literal["any", "strict"] = "any",
) -> Tuple[float, torch.Tensor]:
    """Fraction of rows whose minimum distance lies near the centre column.

    Each row of ``dist`` holds the distances from one query to ``W`` candidate
    offsets; the correct match is taken to be column ``W // 2``. A row is a
    success when its minimum falls within ``accept_radius`` columns of that
    centre.

    Args:
        dist: 2-D tensor of shape (T, W); smaller is better. NaN entries are
            treated as +inf so they can never be the minimum. Rows with no
            usable value (every entry NaN or +inf) always count as failures.
        accept_radius: half-width of the accepted column band around the centre.
        tie_policy: ``"any"`` -- success if any of the tied minima is inside
            the band; ``"strict"`` -- success only when the minimum is unique
            and inside the band.

    Returns:
        ``(acc, success)`` where ``success`` is a boolean tensor of shape (T,)
        and ``acc`` is its mean (0.0 when T == 0).

    Raises:
        AssertionError: if ``dist`` is not 2-D.
        ValueError: if ``tie_policy`` is not ``"any"`` or ``"strict"``.
    """
    assert dist.dim() == 2, "dist must be (T, W)"
    if tie_policy not in ("any", "strict"):
        raise ValueError("tie_policy must be 'any' or 'strict'")

    T, W = dist.shape
    if W == 0:
        # No candidates at all: nothing can match (and min() would fail).
        return 0.0, torch.zeros(T, dtype=torch.bool, device=dist.device)
    center = W // 2

    # Integer inputs cannot hold +inf, so promote them to float first.
    d = dist if dist.is_floating_point() else dist.float()
    # NaN -> +inf so a NaN is never selected as the minimum.
    d = d.masked_fill(torch.isnan(d), float("inf"))

    mins, _ = d.min(dim=1, keepdim=True)  # (T, 1)
    is_min = d.eq(mins)  # (T, W)
    # A row whose minimum is +inf carries no information; without this guard
    # every column "ties" at inf and the row would count as a hit.
    has_valid = mins.squeeze(1) < float("inf")  # (T,)

    lo = max(0, center - accept_radius)
    hi = min(W, center + accept_radius + 1)

    if tie_policy == "any":
        success = is_min[:, lo:hi].any(dim=1)  # (T,)
    else:  # "strict"
        only_one_min = is_min.sum(dim=1).eq(1)  # (T,)
        min_idx = d.argmin(dim=1)  # (T,)
        success = only_one_min & (min_idx >= lo) & (min_idx < hi)
    success = success & has_valid

    acc = float(success.float().mean().item()) if T > 0 else 0.0
    return acc, success


def eval_acc_from_dists(
    dists_list, v_shift=15, Ks=(5, 7, 9, 11, 13, 15), tol=1, bigger_is_better=True
):
    """Sync-offset accuracy, optionally after temporal smoothing of the scores.

    Args:
        dists_list: length-T sequence of 1-D tensors, each of length
            ``2 * v_shift + 1``. They hold either scores (bigger is better,
            e.g. VocaLiST) or distances (smaller is better, e.g. SyncNet
            Euclidean distances).
        v_shift: index of the zero-offset entry (the centre of each vector).
        Ks: context sizes. For each K, the score rows are averaged over a
            sliding temporal window of ``K - 4`` frames with edge replication.
            For example, K=5 -> 1 frame (no smoothing), K=7 -> 3 frames,
            K=9 -> 5 frames.
        tol: largest ``|predicted offset|`` that still counts as correct.
        bigger_is_better: True for VocaLiST scores; False for SyncNet Euclidean
            distances, which are negated into scores.

    Returns:
        Dict mapping each K to an accuracy in [0, 1]. Every K maps to 0.0 when
        ``dists_list`` is empty.
    """
    rows = list(dists_list)
    if not rows:
        return {K: 0.0 for K in Ks}

    scores = torch.stack(rows, dim=0).float().cpu()  # [T, 2v+1]
    if not bigger_is_better:  # distances -> scores
        scores = -scores

    accs = {}
    for K in Ks:
        win = K - 4  # temporal averaging window length
        if win > 1:
            # Put time on the last axis so 1-D pooling runs along it.
            x = scores.transpose(0, 1).unsqueeze(0)  # [1, 2v+1, T]
            # Asymmetric padding keeps the output length at exactly T for
            # both odd and even window sizes (symmetric win//2 padding gives
            # T + 1 rows when win is even).
            left = win // 2
            x = F.pad(x, pad=(left, win - 1 - left), mode="replicate")
            x = F.avg_pool1d(x, kernel_size=win, stride=1)  # [1, 2v+1, T]
            scores_K = x.squeeze(0).transpose(0, 1)  # [T, 2v+1]
        else:
            scores_K = scores

        idx = scores_K.argmax(dim=1)  # [T]
        offsets = v_shift - idx  # 0 = perfectly aligned
        accs[K] = (offsets.abs() <= tol).float().mean().item()
    return accs


class SlipMAEPipeline:
    """Audio-visual sync evaluation pipeline built from SlipMAE components.

    Per-frame "vocal motion" features come from ``mae_encoder``, run with a
    face-keep mask computed from facial keypoints. Audio features come from
    ``audio_encoder`` followed by ``audio_adapter``. ``__call__`` compares the
    two feature sequences with ``evaluate_avsync``.

    Keypoints are either supplied by the caller or estimated on the fly with
    ``raw_sapiens`` + ``heatmap_head`` when both are provided.
    """

    def __init__(
        self,
        mae_encoder,
        video_processor: BitImageProcessor,
        audio_processor: Wav2Vec2Processor,
        audio_encoder: Wav2Vec2InterpModel,
        audio_adapter: nn.Module,
        raw_sapiens: SapiensVisionTransformer = None,
        heatmap_head: HeatmapHead = None,
        dtype=torch.bfloat16,
    ):
        """Store the processors and move every model to the current device in ``dtype``.

        Raises:
            ValueError: if ``raw_sapiens`` is given without ``heatmap_head``.
        """
        self.device = get_device()
        self.dtype = dtype

        self.video_processor = video_processor
        self.audio_processor = audio_processor

        self.mae_encoder = mae_encoder.to(device=self.device, dtype=self.dtype)
        self.audio_adapter = audio_adapter.to(device=self.device, dtype=self.dtype)
        self.audio_encoder = audio_encoder.to(device=self.device, dtype=self.dtype)

        # On-the-fly keypoint estimation is enabled iff a Sapiens backbone is given.
        self.realtime_pose_estimate = raw_sapiens is not None
        self.raw_sapiens = None
        self.heatmap_head = None
        if self.realtime_pose_estimate:
            if heatmap_head is None:
                raise ValueError("raw_sapiens must be provided with heatmap_head")
            self.raw_sapiens = raw_sapiens.to(device=self.device, dtype=self.dtype)
            self.heatmap_head = heatmap_head.to(device=self.device, dtype=self.dtype)

        # Half-width of the sync search window: evaluate_avsync receives a
        # window of 2 * v_shift + 1 frames.
        self.v_shift = 15

    def encode_audio(self, audio: Tensor, video_length: int, sr: int) -> Tensor:
        """Encode raw waveforms into audio features aligned to the video length.

        Steps:
        1. ``self.audio_processor`` (e.g. a ``Wav2Vec2Processor``) converts the
           waveform into model-ready ``input_values``.
        2. ``self.audio_encoder`` runs on them with ``seq_len=video_length`` so
           its output length is aligned with the video frame count, and it
           returns all hidden states.
        3. The last hidden state and every per-layer hidden state are
           concatenated along the feature axis and fed to ``self.audio_adapter``.

        Steps 1-2 run without gradient tracking. The adapter call does not
        disable gradients by itself.

        Args:
            audio: Raw waveforms of shape [B, N] (N = samples per example).
            video_length: Number of video frames to align the audio tokens to.
            sr: Sampling rate of ``audio`` in Hz (e.g. 16000).

        Returns:
            The adapter output, presumably [B, video_length, D]. D depends on
            the adapter; confirm against ``audio_adapter``.
        """
        with torch.no_grad():
            B, N = audio.shape
            # The processor only handles float32 input; cast back afterwards.
            audio_inputs = self.audio_processor(
                audio.float(), sampling_rate=sr, return_tensors="pt"
            ).input_values.to(
                device=audio.device,
                dtype=dtype_from_str(audio.dtype),
            )
            # Fails if the processor added any extra dimension.
            audio_inputs = audio_inputs.view(B, N)
            audio_encoder_output = self.audio_encoder(
                audio_inputs, seq_len=int(video_length), output_hidden_states=True
            )

        # Concatenate the last hidden state with all per-layer hidden states.
        audio_states = torch.cat(
            (audio_encoder_output.last_hidden_state,)
            + audio_encoder_output.hidden_states,
            dim=-1,
        )

        # Project the concatenated multi-layer features with the adapter.
        return self.audio_adapter(audio_states)

    @torch.no_grad()
    def sapiens_face_det(self, video: torch.Tensor, batch_size=8) -> torch.Tensor:
        """Estimate keypoints for every frame with Sapiens + the heatmap head.

        Args:
            video: Frames stacked along dim 0, presumably [T, C, H, W] in the
                same preprocessing as the MAE input. TODO confirm.
            batch_size: Number of frames pushed through Sapiens at once.

        Returns:
            CPU tensor of decoded keypoints with each keypoint's score appended
            as the last channel (presumably [T, K, 3] = (x, y, score)).
        """
        all_keypoints = []
        all_keypoints_scores = []
        for start in range(0, len(video), batch_size):
            # The heatmaps are not needed here, so skip their device->host copy.
            _, keypoints, keypoints_scores = self.get_sapiens_heatmap(
                video[start : start + batch_size], return_heatmap=False
            )
            all_keypoints.append(torch.from_numpy(keypoints))
            all_keypoints_scores.append(torch.from_numpy(keypoints_scores))

        all_keypoints = torch.cat(all_keypoints, dim=0)
        all_keypoints_scores = torch.cat(all_keypoints_scores, dim=0)
        return torch.cat([all_keypoints, all_keypoints_scores.unsqueeze(-1)], dim=-1)

    @torch.no_grad()
    def get_sapiens_heatmap(
        self,
        imgs: torch.Tensor,
        return_heatmap: bool = True,
    ) -> tuple:
        """Run Sapiens + the heatmap head on one batch of images.

        No internal sub-batching is done; callers control the batch size.

        Args:
            imgs: Images of shape [B, C, H, W].
            return_heatmap: When False, the heatmap is not copied to the host
                and ``None`` is returned in its place.

        Returns:
            ``(heatmap, keypoints, keypoint_scores)``:
            - ``heatmap``: float32 numpy array, presumably [B, K, h, w], or
              ``None`` when ``return_heatmap`` is False.
            - ``keypoints``: numpy array of decoded keypoints.
            - ``keypoint_scores``: numpy array of scores.
            The keypoint and score arrays are each the per-sample arrays from
            the head's predictions, concatenated along axis 0.
        """
        feats = self.raw_sapiens(imgs, out_type="featmap").last_hidden_state
        # Heatmap resolution noted by the original author: 256 x 192.
        heatmap, pred = self.heatmap_head(feats, decode_kpt=True)

        keypoints = np.concatenate([p.keypoints for p in pred], axis=0)
        keypoints_scores = np.concatenate([p.keypoint_scores for p in pred], axis=0)

        if return_heatmap:
            heatmap = heatmap.float().detach().cpu().numpy()
        else:
            heatmap = None
        return heatmap, keypoints, keypoints_scores

    def encode_vocal(
        self, imgs: Tensor, keypoint: Tensor = None, max_batch_size: int = 8
    ):
        """Encode frames into per-frame vocal-motion features.

        Args:
            imgs: Frames of shape [B, C, H, W], values in [-1, 1].
            keypoint: Per-frame keypoints indexed as ``keypoint[:, face_idx]``,
                presumably [B, K, 3] with (x, y, score). When None, they are
                estimated with Sapiens, which requires real-time pose estimation.
            max_batch_size: Sapiens batch size, used only when ``keypoint`` is None.

        Returns:
            The last ``prompt_tokens`` entry of the MAE encoder output,
            shape [B, D].
        """
        if keypoint is None:
            keypoint = self.sapiens_face_det(imgs, batch_size=max_batch_size)

        # Only the face keypoint subset drives the keep-face mask. Judging by
        # the helper's name, ear points are excluded.
        face_keypoints = keypoint[:, _fallback_face_indices_ex_ear_308()]
        keepface_mask_prob = keepface_mask_prob_from_kpts(
            face_keypoints, img_hw=imgs.shape[2:]
        )

        outputs = self.mae_encoder(
            imgs, do_mask=True, mask_prob=keepface_mask_prob.to(imgs.device)
        )
        return outputs["prompt_tokens"][:, -1]

    def calc_pdist(self, feat1, feat2, vshift=15):
        """Windowed distances between two feature sequences.

        For each row ``i`` of ``feat1``, compute the distance between the
        L2-normalised ``feat1[i]`` and the L2-normalised rows ``i - vshift``
        to ``i + vshift`` of ``feat2``. ``feat2`` is zero-padded by ``vshift``
        rows at both ends.

        Args:
            feat1: 2-D features, presumably [T, D].
            feat2: 2-D features with the same feature dim.
            vshift: Half-width of the comparison window.

        Returns:
            List with one entry per row of ``feat1``. Each entry is the
            ``euclidean_distance_matrix`` output with its leading dim squeezed,
            presumably shape [2 * vshift + 1].
        """
        win_size = vshift * 2 + 1
        # Pad the time axis (dim -2) so every window has win_size rows.
        feat2_padded = F.pad(feat2, (0, 0, vshift, vshift))

        dists = []
        for i in range(len(feat1)):
            query = F.normalize(feat1[[i]], p=2, dim=-1)  # [1, D]
            window = F.normalize(feat2_padded[i : i + win_size, :], p=2, dim=-1)
            dists.append(euclidean_distance_matrix(query, window).squeeze(0))
        return dists

    @torch.no_grad()
    def __call__(self, video, audio, keypoint=None, batch_size=48):
        """Evaluate audio-visual synchronisation for a single clip.

        Runs under ``torch.no_grad`` so no autograd graph is built for the
        encoders; all features are moved to CPU float32 before evaluation.

        Args:
            video: Path to a video file (decoded with decord) or an
                ``np.ndarray`` of frames accepted by ``self.video_processor``.
            audio: Path to an audio or video file; loaded by librosa at
                16 kHz mono.
            keypoint: Optional per-frame keypoints, given as a ``.npy`` path or
                an array-like. When None, keypoints are estimated with Sapiens,
                which requires ``raw_sapiens``/``heatmap_head`` at construction.
            batch_size: Number of frames per MAE-encoder forward pass.

        Returns:
            The result of ``evaluate_avsync`` on the vocal and audio features.

        Raises:
            ValueError: if ``keypoint`` is None and real-time pose estimation
                is not available.
            TypeError: if ``video`` is neither a path nor an ``np.ndarray``.
        """
        # Fail fast, before decoding the whole video.
        if keypoint is None and not self.realtime_pose_estimate:
            raise ValueError(
                "keypoint is required when the pipeline was built without "
                "raw_sapiens/heatmap_head"
            )

        if isinstance(video, str):
            # decord returns RGB frames by default; no channel reordering here.
            vr = VideoReader(video)
            video = vr.get_batch(range(len(vr))).asnumpy()

        if not isinstance(video, np.ndarray):
            raise TypeError(
                f"video must be a path or np.ndarray, got {type(video).__name__}"
            )
        pixel_values = self.video_processor(video).pixel_values
        # Stack into a single array first; torch.tensor() on a list of
        # ndarrays is very slow.
        video = torch.from_numpy(np.stack(pixel_values)).to(self.device, self.dtype)

        if keypoint is not None:
            if isinstance(keypoint, str):
                keypoint = np.load(keypoint)
            # Keep pixel coordinates in float32: bfloat16 steps are 2 px in
            # [256, 512) and 4 px in [512, 1024). This also matches the float
            # keypoints produced by the Sapiens path.
            keypoint = torch.as_tensor(keypoint, dtype=torch.float32).to(self.device)

        audio, sr = librosa.load(audio, sr=16000, mono=True)
        audio = torch.from_numpy(audio).to(self.device, self.dtype)

        audio_feats = (
            self.encode_audio(audio.unsqueeze(0), video.shape[0], sr)[0]
            .float()
            .detach()
            .cpu()
        )

        vocal_feats = []
        for start in range(0, video.shape[0], batch_size):
            stop = start + batch_size
            keypoint_batch = None if keypoint is None else keypoint[start:stop]
            v = self.encode_vocal(video[start:stop], keypoint_batch)
            vocal_feats.append(v.float().detach().cpu())
        vocal_feats = torch.cat(vocal_feats, dim=0)

        return evaluate_avsync(
            vocal_feats,
            audio_feats,
            win_size=self.v_shift * 2 + 1,
            k_list=(1, 5, 15),
            batch_size=32,
            topk=3,
        )


def infer(
    cfg="configs/train/slipmae/slipmae_pretrain.py",
    checkpoint="work_dirs/slipmae_pretrain/iter_3000.pth",
    video_path="data/MEAD/video_resampled/M003_left_30_happy_level_1_013.mp4",
    audio_path=None,
    keypoint_path="data/MEAD/keypoint308/M003_left_30_happy_level_1_013.npy",
    dtype=torch.bfloat16,
):
    """Build a SlipMAEPipeline from a config + checkpoint and evaluate one clip.

    Intended as the ``fire`` CLI entry point; the evaluation result is printed.

    Args:
        cfg: mmengine config file whose ``model`` entry builds the trainer model.
        checkpoint: Checkpoint loaded into that model (on CPU first).
        video_path: Video to evaluate.
        audio_path: Audio source; defaults to ``video_path`` (its audio track).
        keypoint_path: ``.npy`` keypoint file. Pass None to estimate keypoints
            with Sapiens, which requires ``model.realtime_pose_estimate``.
        dtype: A ``torch.dtype``, or its name as a string (e.g. ``"float16"``
            or ``"torch.float16"``), which is how ``fire`` delivers CLI values.

    Raises:
        ValueError: if ``dtype`` is a string that names no torch dtype.
    """
    if isinstance(dtype, str):
        # fire passes `--dtype float16` as a plain string; resolve it here,
        # otherwise `.to(dtype=...)` fails inside the pipeline.
        resolved = getattr(torch, dtype.split(".")[-1], None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"unknown torch dtype: {dtype!r}")
        dtype = resolved

    audio_path = video_path if audio_path is None else audio_path

    # 1. Load the config and the trained model.
    cfg = Config.fromfile(cfg)
    model: TrainerMotionExtractorMAEV3 = MODELS.build(cfg.model)
    load_checkpoint(model, checkpoint, map_location="cpu")
    model.eval()

    # 2. Video preprocessing: resize + 512x512 centre crop, scale to [-1, 1].
    input_video_processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": 512},
        do_center_crop=True,
        crop_size={"height": 512, "width": 512},
        do_rescale=True,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    # 3. Assemble the pipeline from the trained sub-modules.
    pipeline = SlipMAEPipeline(
        mae_encoder=model.mae_encoder,
        audio_processor=model.audio_processor,
        audio_encoder=model.audio_encoder,
        audio_adapter=model.audio_adapter,
        heatmap_head=model.heatmap_head if model.realtime_pose_estimate else None,
        raw_sapiens=model.raw_sapiens if model.realtime_pose_estimate else None,
        video_processor=input_video_processor,
        dtype=dtype,
    )
    # Drop the trainer wrapper; the pipeline keeps the sub-modules it needs.
    del model
    torch.cuda.empty_cache()

    eval_res = pipeline(video_path, audio_path, keypoint_path)
    print(eval_res)


if __name__ == "__main__":
    from fire import Fire

    # Expose infer()'s keyword arguments as command-line flags.
    Fire(infer)
