from typing import Optional, Tuple, Union
import fire
from mmengine.device import get_device
import numpy as np
import torch
from tqdm import tqdm
from transformers import BitImageProcessor
from mmhug.models.custom_transformers.sapiens.heatmap_head import HeatmapHead
from mmhug.models.custom_transformers.sapiens.vit_sapiens import (
    SapiensVisionTransformer,
)
import torch.nn.functional as F
from torchvision.io import read_image, read_video, write_video
from decord import VideoReader
from mmhug.utils.vis_utils import draw_keypoints_sequence
from diffusers.utils import BaseOutput


def is_video_file(video_path: str) -> bool:
    """Return True if ``video_path`` ends with a supported video extension.

    The check is purely textual (the file is never opened) and ignores case,
    so ``clip.MP4`` is accepted just like ``clip.mp4``.
    """
    return video_path.lower().endswith((".mp4", ".avi", ".mov"))


def is_image_file(image_path: str) -> bool:
    """Return True if ``image_path`` ends with a supported image extension.

    The check is purely textual (the file is never opened) and ignores case,
    so ``photo.JPG`` is accepted just like ``photo.jpg``.
    """
    return image_path.lower().endswith((".png", ".jpg", ".jpeg"))


def resize_keypoints_tk2(
    kpts: torch.Tensor,  # [T, K, 2], in (x, y)
    src_size: tuple[int, int] = (768, 1024),  # (W_src, H_src) -- width first!
    dst_size: tuple[int, int] = (512, 512),  # (W_dst, H_dst)
    clamp: bool = True,
    inplace: bool = False,
) -> torch.Tensor:
    """Map pixel coordinates from a ``src_size`` frame to a ``dst_size`` frame.

    Args:
        kpts: Tensor of shape [T, K, 2] whose last dimension is (x, y). If the
            data is stored as (y, x), swap the channels (or the two sizes'
            components) before calling.
        src_size: (W_src, H_src) of the frame the coordinates currently live in.
        dst_size: (W_dst, H_dst) of the target frame.
        clamp: If True, clamp x to [0, W_dst - 1] and y to [0, H_dst - 1].
        inplace: If True, ``kpts`` itself is scaled (and clamped) and returned;
            otherwise ``kpts`` is left untouched and a new tensor is returned.

    Returns:
        The rescaled keypoints, same shape and dtype as ``kpts``.

    Raises:
        AssertionError: If ``kpts`` is not of shape [T, K, 2].

    Note:
        The scale factors are materialised in ``kpts``' dtype, so the function
        is meant for floating-point tensors; with an integer dtype fractional
        factors are truncated.
    """
    assert kpts.ndim == 3 and kpts.shape[-1] == 2, "kpts must be [T, K, 2]"
    src_w, src_h = src_size
    dst_w, dst_h = dst_size
    scale_x = float(dst_w) / float(src_w)
    scale_y = float(dst_h) / float(src_h)

    scale = torch.tensor(
        [scale_x, scale_y], dtype=kpts.dtype, device=kpts.device
    ).view(1, 1, 2)

    if inplace:
        # mul_ writes into the caller's storage; a plain `kpts * scale` would
        # allocate a new tensor and silently ignore the `inplace` request.
        out = kpts.mul_(scale)
    else:
        # The product is already a fresh tensor, so no explicit clone is needed.
        out = kpts * scale

    if clamp:
        out[..., 0].clamp_(0, dst_w - 1)
        out[..., 1].clamp_(0, dst_h - 1)
    return out


class SapiensOutput(BaseOutput):
    """Result container returned by ``PipelineSapiens.__call__``."""

    # NOTE(review): diffusers' own BaseOutput subclasses are usually declared
    # with @dataclass; confirm keyword construction and attribute access behave
    # as intended without the decorator.

    # Keypoints of all processed frames, concatenated along axis 0, with the
    # keypoint score appended as the last channel of the final axis.
    keypoints: np.ndarray
    # Heatmaps of all processed frames, concatenated along axis 0 (float32).
    heatmaps: np.ndarray


class PipelineSapiens:
    """Batched keypoint / heatmap inference with a Sapiens backbone and heatmap head.

    Both models are moved to the device reported by ``get_device()``, cast to
    ``dtype`` and switched to eval mode at construction time.
    """

    def __init__(
        self,
        image_processor: "BitImageProcessor",
        backbone: "SapiensVisionTransformer",
        heatmap_head: "HeatmapHead",
        dtype=torch.bfloat16,
    ):
        """Store the preprocessor and prepare both models for inference.

        Args:
            image_processor: Turns raw uint8 frames into model-ready pixel values.
            backbone: Feature extractor, called with ``out_type="featmap"``.
            heatmap_head: Head producing heatmaps and decoded keypoints.
            dtype: Floating-point dtype both models (and their inputs) are cast to.
        """
        self.device = get_device()
        self.dtype = dtype

        self.image_processor = image_processor
        self.backbone = backbone.to(self.device, self.dtype)
        self.heatmap_head = heatmap_head.to(self.device, self.dtype)

        self.backbone.eval()
        self.heatmap_head.eval()

    # ---------- helper: normalise any array/tensor to (T,H,W,C) uint8 ----------
    def _as_thwc_uint8(self, x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        """Convert a single image or a stack of frames to a (T, H, W, C) uint8 array.

        Accepted layouts are (C,H,W), (H,W,C), (T,C,H,W) and (T,H,W,C) with
        C in {1, 3}. Non-uint8 data is clipped to [0, 255] and cast, so float
        inputs are expected to already be in the 0-255 range.

        Raises:
            ValueError: If the input is not 3D/4D or no channel axis can be found.
        """
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        x = np.asarray(x)

        if x.ndim == 3:
            # Either (C,H,W) or (H,W,C); a leading 1/3 is taken as channels first.
            if x.shape[0] in (1, 3):  # CHW -> HWC
                x = np.transpose(x, (1, 2, 0))
            elif x.shape[-1] in (1, 3):  # already HWC
                pass
            else:
                raise ValueError(f"Unsupported 3D shape: {x.shape}")
            x = x[None]  # -> (1,H,W,C)
        elif x.ndim == 4:
            # Either (T,C,H,W) or (T,H,W,C); channels-last wins when both fit.
            if x.shape[1] in (1, 3) and x.shape[-1] not in (1, 3):  # TCHW -> THWC
                x = np.transpose(x, (0, 2, 3, 1))
            elif x.shape[-1] in (1, 3):  # already THWC
                pass
            else:
                raise ValueError(f"Unsupported 4D shape: {x.shape}")
        else:
            raise ValueError(f"Expect 3D/4D array, got {x.ndim}D")

        if x.dtype != np.uint8:
            x = np.clip(x, 0, 255).astype(np.uint8)
        return x  # (T,H,W,C), uint8

    @torch.no_grad()
    def get_sapiens_heatmap(
        self,
        imgs: torch.Tensor,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run backbone + head on one batch of preprocessed images.

        Args:
            imgs: Preprocessed pixel values, [B, C, H, W]. The whole tensor is
                sent through the model in a single forward pass; callers are
                responsible for choosing a batch size that fits in memory.

        Returns:
            A ``(heatmap, keypoints, keypoint_scores)`` tuple of NumPy arrays.
            ``heatmap`` is float32 (presumably [B, K, h, w] -- confirm against
            the head); ``keypoints`` and ``keypoint_scores`` are the per-sample
            decoder outputs concatenated along axis 0.
        """
        feats = self.backbone(
            imgs.to(self.device, self.dtype), out_type="featmap"
        ).last_hidden_state
        # `pred` is iterated as one entry per sample, each exposing `.keypoints`
        # and `.keypoint_scores` (assumed to be NumPy arrays -- TODO confirm).
        heatmap, pred = self.heatmap_head(feats, decode_kpt=True)

        keypoints = np.concatenate([p.keypoints for p in pred], axis=0)
        keypoints_scores = np.concatenate([p.keypoint_scores for p in pred], axis=0)

        # Upcast to float32 before leaving torch: NumPy has no bfloat16 dtype.
        heatmap = heatmap.float().detach().cpu().numpy()
        return heatmap, keypoints, keypoints_scores

    def _predict_batches(self, batches) -> "SapiensOutput":
        """Preprocess and run every frame batch, then merge the results.

        Args:
            batches: Iterable yielding (B, H, W, C) uint8 frame arrays.

        Raises:
            ValueError: If ``batches`` yields nothing (no frames to process).
        """
        heatmap_chunks = []
        keypoint_chunks = []
        score_chunks = []
        for frames in batches:
            pixel_values = self.image_processor.preprocess(
                frames, return_tensors="pt"
            ).pixel_values
            heatmaps, keypoints, scores = self.get_sapiens_heatmap(pixel_values)
            heatmap_chunks.append(heatmaps)
            keypoint_chunks.append(keypoints)
            score_chunks.append(scores)

        if not heatmap_chunks:
            raise ValueError("No frames to process: the input is empty")

        keypoints = np.concatenate(keypoint_chunks, axis=0)
        scores = np.concatenate(score_chunks, axis=0)
        # Append the score as an extra trailing channel next to the coordinates.
        keypoints = np.concatenate([keypoints, scores[..., None]], axis=-1)
        return SapiensOutput(
            keypoints=keypoints,
            heatmaps=np.concatenate(heatmap_chunks, axis=0),
        )

    @torch.no_grad()
    def __call__(
        self,
        video: Union[str, np.ndarray, torch.Tensor],
        batch_size: int = 64,  # number of frames per forward pass
    ) -> "SapiensOutput":
        """Predict keypoints and heatmaps for a video, an image, or loaded frames.

        Args:
            video: One of
                - a path to a video file: frames are decoded batch by batch with
                  Decord, so the clip is never fully loaded into memory;
                - a path to an image file: treated as a one-frame video;
                - an array/tensor of frames in any layout accepted by
                  ``_as_thwc_uint8``.
            batch_size: Number of frames decoded and pushed through the model
                at a time.

        Returns:
            ``SapiensOutput`` whose ``keypoints`` hold coordinates plus score in
            the last channel and whose ``heatmaps`` are float32, both
            concatenated over all frames along axis 0.

        Raises:
            ValueError: On a non-positive ``batch_size``, an unsupported input
                type or file extension, or an input without frames.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # ---------- Case A: path to a video file (streamed decoding) ----------
        if isinstance(video, str) and is_video_file(video):
            reader = VideoReader(video)  # lazy: frames are decoded on demand
            total = len(reader)
            batches = (
                reader.get_batch(
                    list(range(start, min(start + batch_size, total)))
                ).asnumpy()  # (B,H,W,3) uint8
                for start in tqdm(range(0, total, batch_size))
            )
            return self._predict_batches(batches)

        # ---------- Case B: single image / frames already in memory ----------
        if isinstance(video, str) and is_image_file(video):
            img = read_image(video)  # [C,H,W] uint8 (torch)
            frames = img.permute(1, 2, 0).unsqueeze(0).numpy()  # -> (1,H,W,C)
        elif isinstance(video, (np.ndarray, torch.Tensor)):
            frames = self._as_thwc_uint8(video)  # (T,H,W,C)
        elif isinstance(video, str):
            raise ValueError(f"Unsupported file extension: {video!r}")
        else:
            raise ValueError(f"Unsupported input type: {type(video)}")

        total = frames.shape[0]
        batches = (
            frames[start : start + batch_size]
            for start in range(0, total, batch_size)
        )
        return self._predict_batches(batches)


def infer(
    video_path: str = "data/hallo3/hallo3_training_data/videos_cropped_new/fe0eb399d3372546b6437401d707551b.mp4",
    output_path: str = "keypoints.mp4",
    fps: float = 25,
    backbone_checkpoint: str = "checkpoints/sapiens-pose-0.3b/backbone.pth",
    heatmap_head_checkpoint: str = "checkpoints/sapiens-pose-0.3b/heatmap_head.pth",
):
    """Run Sapiens-0.3B pose estimation on a video and write a keypoint overlay.

    Steps: build the preprocessor, backbone and heatmap head, run the pipeline
    on ``video_path``, then re-read the video, apply the same resize/crop
    geometry without rescaling or normalisation, draw the predicted keypoints
    and save the result.

    Args:
        video_path: Input video to analyse and visualise.
        output_path: Where the annotated video is written.
        fps: Frame rate of the written video.
        backbone_checkpoint: Pretrained weights for the Sapiens backbone.
        heatmap_head_checkpoint: Pretrained weights for the heatmap head.
    """
    input_video_processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": 512},
        do_center_crop=True,
        crop_size={"height": 512, "width": 512},
        do_rescale=True,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    # NOTE(review): the backbone is declared with img_size=(1024, 768) while the
    # processor above feeds 512x512 crops; presumably the model adapts its
    # position embeddings (interpolate_mode="bicubic") -- confirm.
    backbone = SapiensVisionTransformer(
        arch="sapiens_0.3b",
        norm_in=True,  # the input is normalized with mean 0.5 and std 0.5, we need to renormalize with Sapiens mean and std
        img_size=(1024, 768),
        patch_size=16,
        in_channels=3,
        out_indices=-1,
        drop_rate=0.0,
        drop_path_rate=0.0,
        qkv_bias=True,
        norm_cfg=dict(type="LN", eps=1e-6),
        final_norm=True,
        with_cls_token=False,
        frozen_stages=-1,
        interpolate_mode="bicubic",
        layer_scale_init_value=0.0,
        patch_cfg=dict(padding=2),
        layer_cfgs=dict(),
        pre_norm=False,
        out_type="featmap",
        init_cfg=dict(
            type="Pretrained",
            checkpoint=backbone_checkpoint,
        ),
    )
    backbone.init_weights()

    heatmap_head = HeatmapHead(
        in_channels=1024,
        out_channels=308,
        deconv_out_channels=(768, 768),  ## this will 2x at each step. so total is 4x
        deconv_kernel_sizes=(4, 4),
        conv_out_channels=(768, 768),
        conv_kernel_sizes=(1, 1),
        decoder=dict(
            type="UDPHeatmap", input_size=(512, 512), heatmap_size=(192, 256), sigma=6
        ),
        init_cfg=dict(
            type="Pretrained",
            checkpoint=heatmap_head_checkpoint,
        ),
    )
    heatmap_head.init_weights()

    pipeline = PipelineSapiens(
        image_processor=input_video_processor,
        backbone=backbone,
        heatmap_head=heatmap_head,
    )

    sapiens_output: SapiensOutput = pipeline(video_path)
    keypoints = sapiens_output.keypoints

    # Visualisation: same resize + center crop as the model input so the frames
    # share the 512x512 geometry, but pixel values are left un-rescaled and
    # un-normalised for drawing.
    visualize_video_processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": 512},
        do_center_crop=True,
        crop_size={"height": 512, "width": 512},
        do_rescale=False,
        do_normalize=False,
    )
    # read_video returns (video, audio, info); only the video frames are used.
    video = read_video(video_path)[0].numpy()
    video = visualize_video_processor.preprocess(
        video, return_tensors="pt"
    ).pixel_values
    # The processor returns channels-first frames; the drawing helper is given
    # channels-last NumPy frames.
    video = draw_keypoints_sequence(video.permute(0, 2, 3, 1).numpy(), keypoints)

    write_video(output_path, video, fps=fps)


# Script entry point: python-fire exposes `infer` as a command-line interface,
# mapping its keyword arguments to CLI flags (e.g. --video_path=...).
if __name__ == "__main__":
    fire.Fire(infer)
