from einops import rearrange
import fire
from typing import Any, Optional
import numpy as np
import torch
from transformers import BitImageProcessor
from torch import Tensor, nn
import torch.nn.functional as F
from decord import VideoReader
from mmhug.datasets.transforms.load_video_with_label import EMOTION_MAPPING
from mmhug.models.custom_transformers.marlin.marlin import Marlin
from mmhug.models.custom_transformers.sapiens.heatmap_head import HeatmapHead
from mmhug.models.custom_transformers.sapiens.vit_sapiens import (
    SapiensVisionTransformer,
)

from mmhug.registry import MODELS
from mmengine import Config
from mmengine.runner import load_checkpoint
from mmengine.device import get_device


class SlipmaeClassifyPipeline:
    """Video classification inference pipeline built around a SlipMAE backbone.

    Frames are sampled from a video, preprocessed with ``video_processor``,
    encoded by ``backbone`` (optionally with a mask derived from Sapiens
    keypoint heatmaps when ``face_only`` is set), mean-pooled over tokens,
    classified per frame, and the per-frame probabilities are averaged into a
    single video-level prediction.
    """

    def __init__(
        self,
        video_processor,
        backbone,
        classifier: nn.Module,
        use_prompt_tokens: bool = True,
        face_only: bool = False,  # True for emotion, False for action and appearance
        raw_sapiens: Optional["SapiensVisionTransformer"] = None,
        heatmap_head: Optional["HeatmapHead"] = None,
        task: str = "multiclass",  # multiclass for appearance, multilabel for action and emotion
        dtype=torch.bfloat16,
    ):
        """
        Args:
            video_processor: callable ``processor(frames, return_tensors="pt")``
                returning an object with ``pixel_values``.
            backbone: encoder called as ``backbone(x, do_mask=..., [mask_prob=...])``
                whose output exposes ``prompt_tokens`` and ``patch_tokens``.
            classifier: head mapping pooled features to per-frame logits.
            use_prompt_tokens: pool ``prompt_tokens`` if True, else ``patch_tokens``.
            face_only: mask the backbone input with Sapiens keypoint heatmaps;
                requires ``raw_sapiens`` and ``heatmap_head``.
            raw_sapiens: Sapiens ViT used to compute heatmap features.
            heatmap_head: head turning Sapiens features into heatmaps.
            task: "multiclass" (softmax + argmax); any other value is treated
                as multilabel (sigmoid + 0.5 threshold).
            dtype: dtype all modules and inputs are cast to.

        Raises:
            ValueError: if ``face_only`` is set but a Sapiens module is missing.
        """
        super().__init__()
        self.device = get_device()
        self.dtype = dtype
        # This pipeline is inference-only: eval() switches train-only layers
        # (e.g. dropout) to their deterministic inference behaviour.
        self.backbone = backbone.to(self.device, self.dtype).eval()

        self.face_only = face_only
        if face_only:
            if raw_sapiens is None or heatmap_head is None:
                raise ValueError(
                    "face_only=True requires both raw_sapiens and heatmap_head"
                )
            self.raw_sapiens = raw_sapiens.to(self.device, self.dtype).eval()
            self.heatmap_head = heatmap_head.to(self.device, self.dtype).eval()

        self.video_processor = video_processor
        self.classifier = classifier.to(self.device, self.dtype).eval()

        self.task = task
        # Number of consecutive frames per sampled clip.
        self.T = 1
        self.use_prompt_tokens = use_prompt_tokens

    @torch.no_grad()
    def get_sapiens_heatmap(self, imgs: Tensor, max_batch_size: int = 8) -> Tensor:
        """Predict Sapiens keypoint heatmaps for a batch of images.

        The batch is run in chunks of at most ``max_batch_size`` images.

        Args:
            imgs (Tensor): A batch of images, [B, C, H, W].
            max_batch_size (int): Maximum number of images per forward pass.

        Returns:
            Tensor: Heatmaps for all images concatenated along dim 0,
            [B, N, H', W'] (spatial size is whatever ``heatmap_head`` emits).
        """
        heatmaps = []
        for chunk in imgs.split(max_batch_size, dim=0):
            feats = self.raw_sapiens(
                chunk.to(self.device, self.dtype), out_type="featmap"
            ).last_hidden_state
            heatmap, _ = self.heatmap_head(feats, decode_kpt=False)
            heatmaps.append(heatmap)
        return torch.cat(heatmaps, dim=0)

    def _sample_frame_indices(self, num_total: int, max_clips: int):
        """Choose which source frames feed the model.

        Multi-clip test-time sampling (as in VideoMAE/TSN multi-view
        evaluation): up to ``max_clips`` windows of ``self.T`` frames whose
        start indices are spread evenly over the video.

        Returns:
            (clip_starts, frame_indices): the start index of each clip, and
            the flat list of frame indices (clips concatenated in order).

        Raises:
            ValueError: if the video has no frames or ``max_clips < 1``.
        """
        if num_total <= 0:
            raise ValueError("Video contains no frames")
        if max_clips < 1:
            raise ValueError(f"num_frames must be >= 1, got {max_clips}")
        T = self.T
        if num_total <= T:
            # Not enough frames for one clip: repeat/stretch indices via
            # linear interpolation to build a single T-frame clip.
            idx = torch.linspace(0, num_total - 1, steps=T).round().long().tolist()
            return [0], idx  # clip_starts is only a placeholder here
        # Number of clips: about num_total / T, capped at max_clips.
        K = min(max_clips, max(1, int(round(num_total / T))))
        span = num_total - T
        # K evenly spaced start positions in [0, span], so no window overruns.
        clip_starts = torch.linspace(0, span, steps=K).round().long().tolist()
        frame_indices = [s + t for s in clip_starts for t in range(T)]
        return clip_starts, frame_indices

    def _preprocess(self, raw_frames) -> Tensor:
        """Run ``video_processor`` and normalise the result to [F, C, H, W]."""
        # Usually [F, C, H, W]; some processors return [B, C, T, H, W].
        pixel_values = self.video_processor(raw_frames, return_tensors="pt").pixel_values
        if pixel_values.dim() == 5:
            # [B, C, T, H, W] -> take batch 0 and make it frames-first [T, C, H, W]
            pixel_values = pixel_values[0].permute(1, 0, 2, 3).contiguous()
        elif pixel_values.dim() != 4:
            raise ValueError(
                f"Unexpected pixel_values shape: {tuple(pixel_values.shape)}"
            )
        return pixel_values

    def _extract_features(self, clips: Tensor, batch_size: int) -> Tensor:
        """Encode frames with the backbone and mean-pool over dim 1.

        Dim 1 is presumably the token axis of ``prompt_tokens``/``patch_tokens``
        (assumes backbone tokens are [N, tokens, dim] -- confirm).
        """
        mask_prob = None
        if self.face_only:
            heatmap = self.get_sapiens_heatmap(clips, max_batch_size=batch_size)
            # 1 - (max over heatmap channels) on a 32x32 grid, flattened per
            # frame. Presumably this masks away non-face patches (low heatmap
            # response) -- exact semantics live in the backbone.
            resized = F.interpolate(
                heatmap.float(), (32, 32), mode="bilinear", align_corners=False
            ).to(heatmap.dtype)
            mask_prob = (1 - resized.max(dim=1).values).view(heatmap.shape[0], -1)

        feats = []
        for start in range(0, clips.shape[0], batch_size):
            batch = clips[start : start + batch_size]
            if mask_prob is None:
                output = self.backbone(batch, do_mask=False)
            else:
                output = self.backbone(
                    batch,
                    do_mask=True,
                    mask_prob=mask_prob[start : start + batch_size],
                )
            tokens = output.prompt_tokens if self.use_prompt_tokens else output.patch_tokens
            # Park per-batch outputs on the CPU to bound device memory.
            feats.append(tokens.detach().cpu())

        return torch.cat(feats, dim=0).mean(dim=1).to(self.device, self.dtype)

    def _aggregate(self, logits: Tensor, clip_starts) -> dict:
        """Average per-frame probabilities into one video-level prediction."""
        if self.task == "multiclass":
            # Multi-view consensus: average per-clip softmax probabilities.
            probs = torch.softmax(logits.float(), dim=-1).mean(dim=0)  # [num_classes]
            pred = int(probs.argmax().item())
        else:
            # Multilabel: average per-clip sigmoid probabilities.
            probs = torch.sigmoid(logits.float()).mean(dim=0)  # [num_classes]
            pred = (probs > 0.5).int().tolist()
        return {
            "probs": probs.float().cpu(),
            "pred": pred,
            "clip_starts": clip_starts,
        }

    @torch.no_grad()
    def __call__(self, video, num_frames=16, batch_size=8) -> Any:
        """Classify a single video.

        Only the sampled frames are decoded and preprocessed, so memory use
        does not grow with video length.

        Args:
            video: path to a video file (read with decord ``VideoReader``) or
                an array of frames shaped (T, H, W, 3).
            num_frames: maximum number of clips sampled from the video.
            batch_size: frames per backbone / Sapiens forward pass.

        Returns:
            dict with ``probs`` (CPU float tensor [num_classes]), ``pred``
            (class index for multiclass, list of 0/1 flags otherwise) and
            ``clip_starts`` (start frame index of each sampled clip).

        Raises:
            TypeError: if ``video`` is neither a str nor a numpy array.
            ValueError: on an empty video or unexpected array/processor shapes.
        """
        # 1) Count frames without decoding everything.
        if isinstance(video, str):
            reader = VideoReader(video)
            num_total = len(reader)
        elif isinstance(video, np.ndarray):
            if video.ndim != 4:
                raise ValueError(
                    f"Expected frames shaped (T, H, W, 3), got {video.shape}"
                )
            reader = None
            num_total = video.shape[0]
        else:
            raise TypeError(
                f"video must be a path or np.ndarray, got {type(video).__name__}"
            )

        # 2) Pick clips, then decode/preprocess only the frames they use.
        clip_starts, frame_indices = self._sample_frame_indices(num_total, num_frames)
        if reader is not None:
            raw_frames = reader.get_batch(frame_indices).asnumpy()
        else:
            raw_frames = video[frame_indices]
        clips = self._preprocess(raw_frames)
        if clips.shape[0] != len(frame_indices):
            raise ValueError(
                f"video_processor returned {clips.shape[0]} frames, "
                f"expected {len(frame_indices)}"
            )
        clips = clips.to(self.device, self.dtype)

        # 3) Encode, classify each frame, aggregate over clips.
        feat = self._extract_features(clips, batch_size)
        logits = self.classifier(feat)  # [num_sampled_frames, num_classes]
        return self._aggregate(logits, clip_starts)


def infer(
    cfg="configs/train/slipmae/slipmae_emotion_faceonly.py",
    checkpoint="work_dirs/slipmae_emotion_faceonly/iter_7000.pth",
    video_path: str = "data/celebv-hq/videos_resampled/resampled_znZDbxVmFbM_0.mp4",
    face_only: bool = True,
    task: str = "multiclass",
):
    """Build a SlipMAE model from config, classify one video and print the result.

    Args:
        cfg: path to an mmengine config; its ``model`` entry is built via MODELS.
        checkpoint: checkpoint file loaded into the built model (on CPU).
        video_path: video to classify.
        face_only: use the model's Sapiens heatmap modules to mask the input;
            requires ``model.raw_sapiens`` and ``model.heatmap_head``.
        task: "multiclass" (softmax/argmax) or another value for multilabel
            (sigmoid/threshold) aggregation; forwarded to the pipeline.
    """
    # 512x512 center crop, normalised to [-1, 1] (mean = std = 0.5).
    input_video_processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": 512},
        do_center_crop=True,
        crop_size={"height": 512, "width": 512},
        do_rescale=True,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    cfg = Config.fromfile(cfg)
    model = MODELS.build(cfg.model)
    load_checkpoint(model, checkpoint, map_location="cpu")
    # Inference: disable train-only behaviour such as dropout.
    model.eval()

    pipeline = SlipmaeClassifyPipeline(
        video_processor=input_video_processor,
        backbone=model.backbone,
        classifier=model.classifier,
        use_prompt_tokens=model.use_prompt_tokens,
        face_only=face_only,
        raw_sapiens=model.raw_sapiens if face_only else None,
        heatmap_head=model.heatmap_head if face_only else None,
        task=task,
    )

    res = pipeline(video_path)
    print(res)


if __name__ == "__main__":
    # Expose infer()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=infer)
