import torch
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import torch.nn.functional as F
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import math
from typing import Literal, Optional, Tuple
from matplotlib import animation

def save_colorbar(
    cmap_name="magma",
    vmin=0.0,
    vmax=1.0,
    orientation="horizontal",
    ticks=None,
    filename="./bar.png",
    figsize=(6, 1.0),
    dpi=150,
    label=None,
):
    """
    Save a standalone colorbar image mapping [vmin, vmax] -> colormap to `filename`.

    Args:
        cmap_name: matplotlib colormap name.
        vmin, vmax: value range represented by the bar.
        orientation: "horizontal" or "vertical" (forwarded to ``Figure.colorbar``).
        ticks: explicit tick positions; if None, three ticks (vmin, midpoint,
            vmax) labelled with two decimals are used.
        filename: output image path.
        figsize, dpi: figure size (inches) and resolution.
        label: optional axis label for the bar.

    Returns:
        ``filename``.
    """
    cmap = plt.get_cmap(cmap_name)
    mappable = ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap=cmap)
    mappable.set_array([])  # needed for colorbar

    fig = plt.figure(figsize=figsize, dpi=dpi)
    # try/finally so the figure is released even if drawing/saving fails
    # (e.g. unwritable path); operate on `fig` directly instead of relying on
    # pyplot's global "current figure".
    try:
        cax = fig.add_axes([0.05, 0.25, 0.9, 0.5])  # left, bottom, width, height
        cbar = fig.colorbar(mappable, cax=cax, orientation=orientation, ticks=ticks)
        if label is not None:
            cbar.set_label(label)
        # default ticks: ends and midpoint of [vmin, vmax], two decimals
        if ticks is None:
            mid = (vmin + vmax) / 2.0
            cbar.set_ticks([vmin, mid, vmax])
            cbar.set_ticklabels([f"{vmin:.2f}", f"{mid:.2f}", f"{vmax:.2f}"])
        fig.savefig(filename, bbox_inches="tight", pad_inches=0.02)
    finally:
        plt.close(fig)
    return filename


def heatmaps_to_rgb_tensor(
    mix_logits,
    target_size,
    cmap_name="magma",
    device=None,
    dtype=torch.float32,
    interp_mode="bilinear",
    cmap_mode="linear",
    lut_bins=256,
    percentile_clip=(0.0, 100.0),
    vmin=0,
    vmax=1,
    gamma=1.0,
    save_colorbar_flag=True,
    colorbar_path="./bar.png",
    colorbar_figsize=(6, 1.0),
    colorbar_dpi=150,
):
    """
    Convert mix_logits (B, T, H, W) -> RGB tensor (B, T, 3, H_out, W_out) in [0,1],
    and optionally save a standalone colorbar to `colorbar_path` (./bar.png).

    Pipeline: resize every (H, W) map to ``target_size`` with
    ``F.interpolate``, rescale [vmin, vmax] linearly to [0, 1] (clamped),
    optionally apply ``gamma``, then look values up in a ``lut_bins``-entry
    table sampled from the matplotlib colormap ``cmap_name``.

    Args:
        mix_logits: torch.Tensor or np.ndarray of shape (B, T, H, W).
        target_size: (H_out, W_out).
        cmap_name: matplotlib colormap name.
        device: computation/output device; defaults to the input's device.
        dtype: floating dtype used for computation and the output.
        interp_mode: ``F.interpolate`` mode (e.g. "bilinear", "nearest").
        cmap_mode: "nearest" rounds to the closest LUT entry; any other value
            linearly interpolates between neighbouring entries.
        lut_bins: number of colormap samples in the lookup table.
        percentile_clip: (lo, hi) percentiles used to fill in ``vmin`` and/or
            ``vmax`` when either is None; if None, falls back to (0, 1).
        vmin, vmax: value range mapped to the colormap ends (None = derive).
        gamma: exponent applied after scaling; None or 1.0 disables it.
        save_colorbar_flag: if True, write a colorbar for [vmin, vmax] via
            ``save_colorbar`` and print its path.
        colorbar_path, colorbar_figsize, colorbar_dpi: forwarded to
            ``save_colorbar``.

    Returns:
        torch.Tensor of shape (B, T, 3, H_out, W_out).

    Raises:
        ValueError: if ``mix_logits`` is not a tensor/ndarray or not 4-D.
    """
    # prepare tensor; ascontiguousarray lets torch accept negative-stride arrays
    if isinstance(mix_logits, np.ndarray):
        mix = torch.from_numpy(np.ascontiguousarray(mix_logits))
    else:
        mix = mix_logits
    if not torch.is_tensor(mix):
        raise ValueError("mix_logits must be torch.Tensor or numpy array")
    if mix.ndim != 4:
        raise ValueError(
            f"mix_logits must be 4-D (B, T, H, W), got shape {tuple(mix.shape)}"
        )

    B, T, H, W = mix.shape
    H_out, W_out = target_size

    if device is None:
        device = mix.device
    mix = mix.to(device=device, dtype=dtype)

    # resize; reshape (not view) because the input may be non-contiguous
    mix_bt = mix.reshape(B * T, 1, H, W)
    interp_kwargs = {}
    if interp_mode in ("bilinear", "bicubic"):
        # align_corners is only accepted by interpolating modes
        interp_kwargs["align_corners"] = False
    resized = F.interpolate(mix_bt, size=(H_out, W_out), mode=interp_mode, **interp_kwargs)
    resized = resized.squeeze(1)  # (B*T, H_out, W_out)

    # compute vmin/vmax via percentiles if requested
    if vmin is None or vmax is None:
        if percentile_clip is not None:
            lo_pct, hi_pct = percentile_clip
            flat = resized.reshape(-1)
            try:
                lo = torch.quantile(flat, lo_pct / 100.0).item()
                hi = torch.quantile(flat, hi_pct / 100.0).item()
            except Exception:
                # torch.quantile rejects very large inputs and non-float32/64
                # dtypes; numpy handles both (upcast first for e.g. bfloat16).
                flat_np = flat.detach().float().cpu().numpy()
                lo = float(np.percentile(flat_np, lo_pct))
                hi = float(np.percentile(flat_np, hi_pct))
        else:
            lo, hi = 0.0, 1.0
        if vmin is None:
            vmin = lo
        if vmax is None:
            vmax = hi

    # save colorbar BEFORE any clipping/scaling if requested
    if save_colorbar_flag:
        save_colorbar(
            cmap_name=cmap_name,
            vmin=vmin,
            vmax=vmax,
            filename=colorbar_path,
            figsize=colorbar_figsize,
            dpi=colorbar_dpi,
        )
        print("colorbar saved to", colorbar_path)

    # scale to [0,1] with vmin/vmax (eps guards vmax == vmin), then apply gamma
    eps = 1e-6
    scaled = (resized - float(vmin)) / (float(vmax) - float(vmin) + eps)
    scaled = scaled.clamp(0.0, 1.0)
    if gamma is not None and gamma != 1.0:
        scaled = scaled.pow(float(gamma))

    # build LUT; pyplot.get_cmap still exists where matplotlib.cm.get_cmap
    # was removed (matplotlib >= 3.9)
    cmap = plt.get_cmap(cmap_name)
    lut = cmap(np.linspace(0.0, 1.0, lut_bins))[:, :3].astype(np.float32)
    lut = torch.from_numpy(lut).to(device=device, dtype=dtype)  # (lut_bins, 3)

    # map via LUT; indexing with an (N, H, W) index tensor yields (N, H, W, 3)
    pos = scaled * (lut_bins - 1)
    if cmap_mode == "nearest":
        idx = torch.round(pos).long().clamp(0, lut_bins - 1)
        rgb = lut[idx]
    else:
        idx0 = torch.floor(pos).long().clamp(0, lut_bins - 1)
        idx1 = (idx0 + 1).clamp(max=lut_bins - 1)
        frac = (pos - idx0.to(dtype)).unsqueeze(-1)
        rgb = lut[idx0] * (1.0 - frac) + lut[idx1] * frac

    # reorder (B*T, H_out, W_out, 3) -> (B, T, 3, H_out, W_out)
    return rgb.permute(0, 3, 1, 2).reshape(B, T, 3, H_out, W_out)


def _ensure_float01_tensor(t: torch.Tensor):
    """Convert to float32 tensor in [0,1]. Accept uint8 or float inputs."""
    if not torch.is_tensor(t):
        t = torch.tensor(t)
    if t.dtype == torch.uint8:
        t = t.to(dtype=torch.float32) / 255.0
    else:
        t = t.to(dtype=torch.float32)
    return t.clamp(0.0, 1.0)


def _repeat_expand_mask(mask: torch.Tensor, scale: int):
    """
    Expand mask by repeating each cell scale x scale times.
    mask: (B, T, Hm, Wm) -> returns (B, T, Hm*scale, Wm*scale)
    """
    return mask.repeat_interleave(scale, dim=2).repeat_interleave(scale, dim=3)


def _infer_mask_hw_from_L(L: int, H: int, W: int) -> Tuple[int, int]:
    """
    尝试从 L 推断出合适的 (Hm, Wm)，满足 H % Hm == 0 and W % Wm == 0，
    且 Hm * Wm == L。若找到多组解，返回第一个匹配项（遍历较小的因子优先）。
    若找不到则抛出 ValueError。
    """
    # 遍历所有因子
    for hm in range(1, int(math.sqrt(L)) + 1):
        if L % hm == 0:
            wm = L // hm
            # 两种排列都试一遍：(hm, wm) 和 (wm, hm)
            if H % hm == 0 and W % wm == 0:
                return hm, wm
            if H % wm == 0 and W % hm == 0:
                return wm, hm
    # 若上半部分未找到，完整遍历（L 可能较小，这样也可行）
    for hm in range(int(math.sqrt(L)) + 1, L + 1):
        if L % hm == 0:
            wm = L // hm
            if H % hm == 0 and W % wm == 0:
                return hm, wm
            if H % wm == 0 and W % hm == 0:
                return wm, hm
    raise ValueError(
        f"Cannot infer (Hm, Wm) from L={L} that divides image H={H}, W={W}. "
        "Make sure L == Hm*Wm and H%Hm==0 and W%Wm==0 for some factorization."
    )


def apply_block_mask_to_rgb(
    rgb: torch.Tensor,
    mask: torch.Tensor,
    alpha: float = 0.5,
    color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
) -> torch.Tensor:
    """
    Blend a solid ``color`` over ``rgb`` wherever a coarse patch mask is set.

    Args:
        rgb: [B, C, H, W] tensor (C >= 3), values assumed in [0, 1].
        mask: [B, Hm, Wm] or flattened [B, L]. A flattened mask requires L to
            be a perfect square and is reshaped to Hm == Wm == sqrt(L).
            bool/int/float masks are all accepted (cast to ``rgb.dtype``).
        alpha: overlay strength in [0, 1] (0 = no change, 1 = full color).
        color: (r, g, b) overlay applied to the first three channels; any
            extra channels are left unchanged.

    Returns:
        [B, C, H, W] tensor computed as
        ``rgb * (1 - m * alpha) + overlay * (m * alpha)``, where ``m`` is the
        mask upsampled by per-cell replication to (H, W).

    Raises:
        ValueError: on bad ranks, batch mismatch, non-square flat masks,
            image sizes not divisible by the mask grid, or alpha outside [0, 1].
    """
    if rgb.ndim != 4:
        raise ValueError(f"rgb must be 4D [B, C, H, W], got shape {rgb.shape}")

    batch, channels, height, width = rgb.shape
    if channels < 3:
        raise ValueError(f"rgb should have at least 3 channels (RGB). got C={channels}")

    if mask.ndim not in (2, 3):
        raise ValueError(
            f"mask must be shape [B, Hm, Wm] or [B, L], got ndim={mask.ndim}, shape={mask.shape}"
        )
    if mask.shape[0] != batch:
        raise ValueError(f"mask batch size {mask.shape[0]} != rgb batch size {batch}")

    if mask.ndim == 2:
        # Flattened [B, L]: only a square grid can be recovered.
        length = mask.shape[1]
        side = int(math.isqrt(length))
        if side * side != length:
            raise ValueError(
                f"Flattened mask length L={length} is not a perfect square; cannot infer Hm=Wm."
            )
        mask = mask.view(batch, side, side)
    grid_h, grid_w = mask.shape[1], mask.shape[2]

    if height % grid_h != 0 or width % grid_w != 0:
        raise ValueError(
            f"Image HxW = ({height},{width}) must be divisible by mask Hm x Wm = ({grid_h},{grid_w})."
        )
    patch_h = height // grid_h
    patch_w = width // grid_w

    # Replicate each mask cell over its pixel patch: [B, 1, H, W].
    pixel_mask = (
        mask.to(dtype=rgb.dtype, device=rgb.device)
        .unsqueeze(1)
        .repeat_interleave(patch_h, dim=2)
        .repeat_interleave(patch_w, dim=3)
    )

    tint = torch.tensor(color, dtype=rgb.dtype, device=rgb.device).view(1, 3, 1, 1)
    if channels == 3:
        overlay = tint.expand(batch, 3, height, width)
    else:
        # Extra channels blend with themselves, i.e. stay unchanged.
        overlay = rgb.clone()
        overlay[:, :3, :, :] = tint.expand(batch, 3, height, width)

    if not (0.0 <= alpha <= 1.0):
        raise ValueError(f"alpha must be in [0,1], got {alpha}")

    weight = pixel_mask.expand(-1, channels, -1, -1) * alpha
    return rgb * (1.0 - weight) + overlay * weight


def visualize_image_and_heatmap(
    image,
    heatmaps,
    mode="max",            # 'max' | 'mean' | 'sum'
    cmap="viridis",
    figsize=(10, 5),
    save_path=None,        # .png/.jpg (static) or .mp4/.gif (video mode)
    dpi=150,
    fps=25,                # only used in video mode
    per_frame_norm=True,   # min-max normalize each frame's combined heatmap to [0,1]
    alpha_image=0.5,       # opacity of the background image in the right panel
    alpha_hm=0.6           # opacity of the heatmap overlay in the right panel
):
    """
    Show an image (or video) next to the same image overlaid with a heatmap.

    Args:
      image: (H, W) grayscale, (H, W, 3) RGB or (T, H, W, 3) RGB frames.
        Values are treated as 0..255 if the maximum exceeds 1.0, otherwise as
        0..1; the result is clipped to [0, 1].
      heatmaps: (K, h, w) or (T, K, h, w). The K maps of each frame are
        reduced with ``mode`` and resized to (H, W).
      mode: reduction over K: 'max' | 'mean' | 'sum'.
      cmap: matplotlib colormap for the heatmap (displayed with vmin=0, vmax=1).
      figsize, dpi: figure size and save resolution.
      save_path: optional output path (see below).
      fps: animation frame rate (video mode only).
      per_frame_norm: if True, each frame's combined map is min-max normalized
        (constant maps become all zeros).
      alpha_image, alpha_hm: opacities in the right-hand panel.

    Single frame (T == 1): draws a two-panel figure with a colorbar, optionally
    saves it to ``save_path``, then calls ``plt.show()``.
    Multiple frames: builds a ``FuncAnimation``. A ``save_path`` ending in
    .mp4 (FFmpeg) or .gif (Pillow) saves the animation, printing instead of
    raising on failure; any other suffix saves the current figure (first
    frame). Then calls ``plt.show()``.

    Returns:
      None.

    Raises:
      AssertionError: on unsupported shapes or mismatched frame counts.
      ValueError: on an unknown ``mode``.
      RuntimeError: if the heatmaps cannot be resized.
    """

    # ---------- helper: reduce over K ----------
    def _combine(hm_k, how):
        """Reduce a (K, h, w) stack to (h, w)."""
        if how == "max":
            return np.max(hm_k, axis=0)
        if how == "mean":
            return np.mean(hm_k, axis=0)
        if how == "sum":
            return np.sum(hm_k, axis=0)
        raise ValueError("mode must be one of ['max','mean','sum']")

    # ---------- helper: resize a 2-D map to (H, W) ----------
    def _resize_2d(arr, H, W):
        """arr: (h, w) -> (H, W); OpenCV preferred, PyTorch bilinear fallback."""
        h, w = arr.shape[-2], arr.shape[-1]
        if (h, w) == (H, W):
            return arr
        try:
            import cv2
            # INTER_AREA when shrinking, INTER_LINEAR when enlarging.
            inter = cv2.INTER_AREA if (H < h or W < w) else cv2.INTER_LINEAR
            return cv2.resize(arr.astype(np.float32), (W, H), interpolation=inter)
        except Exception:
            # OpenCV missing or failed: fall back to torch (imported at module level).
            try:
                ten = torch.from_numpy(arr.astype(np.float32))[None, None]  # (1,1,h,w)
                out = F.interpolate(ten, size=(H, W), mode="bilinear", align_corners=False)
                return out[0, 0].numpy()
            except Exception as e:
                raise RuntimeError(
                    "Need OpenCV (cv2) or PyTorch to resize heatmaps."
                ) from e

    # ---------- image -> float (T, H, W, 3) in [0, 1] ----------
    img = np.asarray(image)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)  # grayscale -> pseudo-RGB
    if img.ndim == 3:
        img = img[None, ...]  # -> (T=1, H, W, 3)
    assert img.ndim == 4 and img.shape[-1] == 3, f"image must be HWC or THWC RGB, got {img.shape}"
    if img.dtype != np.float32 and img.dtype != np.float64:
        img = img.astype(np.float32)
    if img.max() > 1.0:  # heuristic: assume 0..255 input
        img = img / 255.0
    img = np.clip(img, 0.0, 1.0)

    T, H, W, _ = img.shape

    # ---------- heatmaps -> (T, K, h, w) ----------
    hm = np.asarray(heatmaps)
    if hm.ndim == 3:  # (K, h, w) -> (1, K, h, w)
        hm = hm[None, ...]
    assert hm.ndim == 4, f"heatmaps must be (K,H,W) or (T,K,H,W), got {hm.shape}"
    assert hm.shape[0] == T, f"frame count mismatch: image T={T}, heatmap T={hm.shape[0]}"

    # ---------- per frame: reduce over K -> resize -> (optionally) normalize ----------
    # `frame_map` (not `cm`) so the module-level matplotlib `cm` is not shadowed.
    combined = np.empty((T, H, W), dtype=np.float32)
    for t in range(T):
        frame_map = _resize_2d(_combine(hm[t], mode), H, W)  # (H, W)
        if per_frame_norm:
            lo, hi = float(frame_map.min()), float(frame_map.max())
            if hi - lo < 1e-12:
                frame_map = np.zeros_like(frame_map)
            else:
                frame_map = (frame_map - lo) / (hi - lo + 1e-8)
        combined[t] = frame_map.astype(np.float32)

    # ---------- single image: side-by-side panels + colorbar ----------
    if T == 1:
        fig, axes = plt.subplots(1, 2, figsize=figsize)
        axes[0].imshow(img[0], interpolation="nearest")
        axes[0].set_title("Original Image")
        axes[0].axis("off")

        axes[1].imshow(img[0], interpolation="nearest", alpha=alpha_image)
        hm_im = axes[1].imshow(combined[0], interpolation="nearest", cmap=cmap, alpha=alpha_hm, vmin=0, vmax=1)
        axes[1].set_title(f"Image + Combined Heatmap ({mode})")
        axes[1].axis("off")

        fig.colorbar(hm_im, ax=axes, fraction=0.046, pad=0.04)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
            print(f"已保存到: {save_path}")
        plt.show()
        return

    # ---------- video: animation (left raw frame, right overlay) ----------
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    ax0, ax1 = axes
    ax0.set_title("Original Frame")
    ax1.set_title(f"Frame + Combined Heatmap ({mode})")
    for ax in axes:
        ax.axis("off")

    im0 = ax0.imshow(img[0], interpolation="nearest")
    im1_bg = ax1.imshow(img[0], interpolation="nearest", alpha=alpha_image)
    hm_im = ax1.imshow(combined[0], interpolation="nearest", cmap=cmap, alpha=alpha_hm, vmin=0, vmax=1)
    fig.colorbar(hm_im, ax=axes, fraction=0.046, pad=0.04)

    def _update(i):
        im0.set_data(img[i])
        im1_bg.set_data(img[i])
        hm_im.set_data(combined[i])
        return im0, im1_bg, hm_im

    # keep a reference to the animation until plt.show() returns
    ani = animation.FuncAnimation(fig, _update, frames=T, interval=1000.0 / fps, blit=False)

    if save_path:
        if save_path.lower().endswith(".mp4"):
            try:
                writer = animation.FFMpegWriter(fps=fps)
                ani.save(save_path, writer=writer, dpi=dpi)
                print(f"已保存视频到: {save_path}")
            except Exception as e:
                print(f"保存 MP4 失败：{e}\n请确认本机已安装 FFmpeg。")
        elif save_path.lower().endswith(".gif"):
            try:
                writer = animation.PillowWriter(fps=fps)
                ani.save(save_path, writer=writer, dpi=dpi)
                print(f"已保存 GIF 到: {save_path}")
            except Exception as e:
                print(f"保存 GIF 失败：{e}")
        else:
            # any other suffix: save the current figure (first frame) as a still
            plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
            print(f"已保存首帧静态图到: {save_path}")

    plt.tight_layout()
    plt.show()


def blend_image_with_heatmaps(
    image: torch.Tensor,
    heatmaps: torch.Tensor,
    mode: str = "max",
    cmap: str = "viridis",
    alpha: float = 0.6,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Blend a batch of images with their (aggregated) heatmaps.

    The K heatmaps of each sample are reduced with ``mode``, min-max
    normalized per sample to [0, 1] (constant maps become 0), and used both as
    the overlay content and, scaled by ``alpha``, as the per-pixel blend
    weight ``m``: ``out = image * (1 - m) + overlay * m``.

    Args:
        image: [B, C, H, W] tensor, values in [0, 1].
        heatmaps: [B, K, H, W] tensor; moved to ``image``'s device.
        mode: 'max' | 'mean' | 'sum' reduction over K.
        cmap: matplotlib colormap name, used only when C == 3 to colorize the
            overlay; for any other C the normalized map itself is used as a
            grayscale overlay on every channel.
        alpha: scale of the blend weight, in [0, 1].
        eps: lower bound on the normalization range (avoids division by zero).

    Returns:
        [B, C, H, W] tensor clamped to [0, 1], same dtype & device as ``image``.

    Raises:
        ValueError: on bad ranks, batch/spatial mismatch, or unknown ``mode``.
    """
    if image.ndim != 4:
        raise ValueError(f"image must be [B,C,H,W], got {image.shape}")
    if heatmaps.ndim != 4:
        raise ValueError(f"heatmaps must be [B,K,H,W], got {heatmaps.shape}")

    B, C, H, W = image.shape
    B_h, K, H_h, W_h = heatmaps.shape
    if B != B_h:
        raise ValueError(
            f"batch size mismatch: image batch {B} vs heatmaps batch {B_h}"
        )
    if H != H_h or W != W_h:
        raise ValueError(
            f"spatial size mismatch: image {(H,W)} vs heatmaps {(H_h,W_h)}"
        )

    device = image.device
    dtype = image.dtype
    heatmaps = heatmaps.to(device=device)

    # aggregate across K -> [B, H, W]
    if mode == "max":
        combined = heatmaps.amax(dim=1)
    elif mode == "mean":
        combined = heatmaps.mean(dim=1)
    elif mode == "sum":
        combined = heatmaps.sum(dim=1)
    else:
        raise ValueError("mode must be one of ['max','mean','sum']")

    # per-sample min-max normalization; clamp(min=eps) guards constant maps
    per_sample = combined.reshape(B, -1)
    combined_min = per_sample.amin(dim=1).view(B, 1, 1)
    combined_max = per_sample.amax(dim=1).view(B, 1, 1)
    denom = (combined_max - combined_min).clamp(min=eps)
    combined_norm = (combined - combined_min) / denom  # [B, H, W]

    # blend weight [B, 1, H, W]; broadcasts over the channel dim
    mask_alpha = (combined_norm * alpha).unsqueeze(1)

    if C == 3:
        # Colorize the whole batch at once: Colormap accepts any array shape
        # and returns (..., 4) RGBA floats. pyplot.get_cmap replaces
        # matplotlib.cm.get_cmap, which was removed in matplotlib 3.9.
        cmap_func = plt.get_cmap(cmap)
        rgba = cmap_func(combined_norm.float().detach().cpu().numpy())  # [B, H, W, 4]
        overlay = torch.from_numpy(
            np.ascontiguousarray(rgba[..., :3].transpose(0, 3, 1, 2))
        ).to(dtype=dtype, device=device)  # [B, 3, H, W]
    else:
        # Grayscale overlay for C != 3 (broadcast over all channels).
        overlay = combined_norm.unsqueeze(1)  # [B, 1, H, W]

    out = image * (1.0 - mask_alpha) + overlay * mask_alpha

    # ensure same dtype/device & clamp to [0,1]
    return out.to(device=device, dtype=dtype).clamp(0.0, 1.0)


def visualize_keypoint_heatmaps(
    heatmap: np.ndarray, colors: np.ndarray | None = None, to_uint8: bool = True
):
    """
    将关键点热力图可视化为 RGB 图（或序列），像素值由“各通道最大值”组成。
    输入:
      - heatmap: np.ndarray, 形状为 (K, H, W) 或 (T, K, H, W)，数值应在 [0, 1]
      - colors:  None 或形状 (K, 3) 的数组，RGB ∈ [0, 1]；若为 None，自动用 tab20 生成
      - to_uint8: 是否输出 uint8（0..255）。False 则输出 float32（0..1）
    输出:
      - out: 若输入为 (K, H, W) → (H, W, 3)
             若输入为 (T, K, H, W) → (T, H, W, 3)
    """
    hm = np.asarray(heatmap)
    if hm.ndim == 3:
        # (K, H, W)
        K, H, W = hm.shape
        hm_exp = hm[..., None]  # (K, H, W, 1)
        axis_k = 0
    elif hm.ndim == 4:
        # (T, K, H, W)
        T, K, H, W = hm.shape
        hm_exp = hm[..., None]  # (T, K, H, W, 1)
        axis_k = 1
    else:
        raise ValueError("heatmap shape must be (K,H,W) or (T,K,H,W)")

    # 保障取值范围
    hm_exp = np.clip(hm_exp, 0.0, 1.0)

    # 颜色表：若未提供，则用 matplotlib 的 tab20 生成 K 个颜色
    if colors is None:
        try:
            import matplotlib.pyplot as plt

            cmap = plt.get_cmap(
                "tab20"
            )  # 常用定性调色板 :contentReference[oaicite:1]{index=1}
            cols = np.array([cmap(i % 20)[:3] for i in range(K)], dtype=np.float32)
        except Exception:
            # 兜底：HSV 均匀取样
            h = np.linspace(0, 1, K, endpoint=False)
            cols = np.stack(
                [
                    np.abs(np.cos(2 * np.pi * h)),  # 简单产生 0..1 的 RGB
                    np.abs(np.cos(2 * np.pi * (h + 1 / 3))),
                    np.abs(np.cos(2 * np.pi * (h + 2 / 3))),
                ],
                axis=1,
            ).astype(np.float32)
            cols /= cols.max(axis=0, keepdims=True) + 1e-8
    else:
        cols = np.asarray(colors, dtype=np.float32)
        if cols.shape != (K, 3):
            raise ValueError(f"colors must have shape (K,3), got {cols.shape}")

    # 广播到 (K, 1, 1, 3) 或 (1, K, 1, 1, 3) 后逐像素取 RGB 最大值（沿 K 轴）
    if hm.ndim == 3:
        cols_b = cols[:, None, None, :]  # (K,1,1,3)
        weighted = (
            hm_exp * cols_b
        )  # (K,H,W,3) 逐元素广播 :contentReference[oaicite:2]{index=2}
        out = weighted.max(
            axis=axis_k
        )  # (H,W,3)   沿 K 轴取最大 :contentReference[oaicite:3]{index=3}
    else:
        cols_b = cols[None, :, None, None, :]  # (1,K,1,1,3)
        weighted = hm_exp * cols_b  # (T,K,H,W,3)
        out = weighted.max(axis=axis_k)  # (T,H,W,3)

    out = np.clip(out, 0.0, 1.0).astype(np.float32)
    if to_uint8:
        out = (out * 255.0 + 0.5).astype(np.uint8)
    return out


import math
from typing import Iterable, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np


# Drawing color: a 3-tuple of ints passed straight to OpenCV drawing calls,
# or a string, which the keypoint drawing helpers replace with a fixed
# default tuple.
Color = Union[str, Tuple[int, int, int]]  # BGR in cv2
# Integer pixel coordinate (x, y) as passed to OpenCV drawing calls.
Point = Tuple[int, int]
Link = Tuple[int, int]  # (i, j) over keypoint indices


def _as_thwc_rgb(img: np.ndarray) -> np.ndarray:
    """Accept HWC RGB or THWC RGB, return (T,H,W,3) uint8 RGB."""
    x = np.asarray(img)
    if x.ndim == 3:
        H, W, C = x.shape
        assert C == 3, f"Expect HWC RGB, got shape={x.shape}"
        x = x[None]
    elif x.ndim == 4:
        assert x.shape[-1] == 3, f"Expect THWC RGB, got shape={x.shape}"
    else:
        raise ValueError(f"Unsupported image ndim={x.ndim}")
    if x.dtype != np.uint8:
        x = np.clip(x, 0, 255).astype(np.uint8)
    return x


def _colors_like_k(
    k: int, color: Optional[Union[Color, Sequence[Color]]]
) -> List[Optional[Color]]:
    """
    Expand a color spec into a list of exactly ``k`` per-item colors.

    ``None``, a string, or a single numeric tuple such as ``(0, 255, 0)`` is
    broadcast to all ``k`` items. Any other sequence (a list, or a tuple whose
    elements are themselves colors) is treated as one color per item.

    Raises:
        ValueError: if a per-item sequence does not have length ``k``.
    """
    import numbers

    # Only a tuple of plain numbers is a single color. A tuple of tuples
    # (e.g. ((0, 255, 0), (255, 0, 0))) is a per-item list and must not be
    # broadcast as one "color".
    is_single_tuple = isinstance(color, tuple) and all(
        isinstance(c, numbers.Number) for c in color
    )
    if color is None or isinstance(color, str) or is_single_tuple:
        return [color] * k
    colors = list(color)
    if len(colors) != k:
        raise ValueError(f"len(color)={len(colors)} must equal K={k}")
    return colors


def _draw_line_mmpose(
    img: np.ndarray, p1: Point, p2: Point, color: Color, thickness: int
):
    """
    Draw an anti-aliased straight segment from ``p1`` to ``p2`` onto ``img``
    in place with ``cv2.line``.

    Any ``color`` that is not a tuple (e.g. a string) is replaced by
    (0, 255, 0).
    """
    line_color = color
    if not isinstance(line_color, tuple):
        line_color = (0, 255, 0)
    cv2.line(img, p1, p2, color=line_color, thickness=thickness, lineType=cv2.LINE_AA)


def _draw_link_openpose(
    img: np.ndarray, p1: Point, p2: Point, color: Color, line_width: int
):
    """
    Draw an OpenPose-style "stick" limb between ``p1`` and ``p2`` in place.

    The limb is a filled ellipse centred on the segment midpoint, rotated to
    the segment angle, with semi-axes (half the segment length, line_width),
    polygonized by ``cv2.ellipse2Poly`` and filled with
    ``cv2.fillConvexPoly``. Non-tuple colors fall back to (0, 255, 0).
    """
    xs = np.array([p1[0], p2[0]], dtype=np.int32)
    ys = np.array([p1[1], p2[1]], dtype=np.int32)
    center = (int(float(xs.mean())), int(float(ys.mean())))

    dx = xs[0] - xs[1]
    dy = ys[0] - ys[1]
    half_length = int(float(np.hypot(dy, dx)) / 2)
    angle_deg = int(math.degrees(math.atan2(dy, dx)))

    polygon = cv2.ellipse2Poly(center, (half_length, int(line_width)), angle_deg, 0, 360, 1)
    fill_color = color if isinstance(color, tuple) else (0, 255, 0)
    cv2.fillConvexPoly(img, polygon, color=fill_color, lineType=cv2.LINE_AA)


def draw_keypoints_sequence(
    image: np.ndarray,
    keypoints: np.ndarray,  # (T, K, 3) -> x, y, score; (K, 3) also accepted for one frame
    skeleton: Optional[Sequence[Link]] = None,
    kpt_color: Optional[Union[Color, Sequence[Color]]] = (0, 255, 0),
    link_color: Optional[Union[Color, Sequence[Color]]] = (255, 0, 0),
    kpt_thr: float = 0.3,
    radius: int = 3,
    line_width: int = 2,
    show_kpt_idx: bool = False,
    skeleton_style: str = "mmpose",  # 'mmpose' | 'openpose' (only limb drawing differs)
) -> np.ndarray:
    """
    Draw keypoints and skeleton links on an image or a frame sequence.

    Args:
      image: HWC or THWC RGB image (converted with ``_as_thwc_rgb``; non-uint8
        data is clipped to 0..255 and cast).
      keypoints: (T, K, 3) array of (x, y, score) in pixel coordinates of
        ``image``. A (K, 3) array is treated as a single frame.
      skeleton: optional list of links [(i, j), ...] over keypoint indices;
        links with out-of-range indices are skipped.
      kpt_color / link_color: None (skip drawing), a single color, or one
        color per keypoint / per link. Tuples are written by OpenCV into the
        RGB array's channels in order, so (255, 0, 0) renders red in the
        output. Strings are replaced by default tuples.
      kpt_thr: keypoints (and links touching them) with score < kpt_thr are
        not drawn.
      radius: keypoint circle radius.
      line_width: link thickness (mmpose) or ellipse minor semi-axis (openpose).
      show_kpt_idx: if True, write each keypoint's index next to it.
      skeleton_style: 'mmpose' draws straight lines on top of the image;
        'openpose' draws filled-ellipse limbs onto a black canvas that
        replaces the frame, so the input pixels are not kept.

    Points whose coordinates are non-finite or fall outside the image are
    skipped, as are links touching such points.

    Returns:
      The drawn image(s), uint8 RGB; HWC for a 3-D input, otherwise THWC.
    """
    # np.ndim also works for list inputs that _as_thwc_rgb accepts.
    single_frame = np.ndim(image) == 3
    img_seq = _as_thwc_rgb(image)  # (T, H, W, 3) RGB
    kpts = np.asarray(keypoints, dtype=np.float32)
    if kpts.ndim == 2:
        kpts = kpts[None]  # (K, 3) -> (1, K, 3)
    assert (
        kpts.ndim == 3 and kpts.shape[-1] == 3
    ), f"keypoints must be (T,K,3), got {kpts.shape}"
    T_img, H, W, _ = img_seq.shape
    T_kp, K, _ = kpts.shape
    assert T_kp == T_img, f"frame count mismatch: image T={T_img}, keypoints T={T_kp}"

    # per-keypoint / per-link colors
    kpt_colors = _colors_like_k(K, kpt_color)
    link_colors = None
    if skeleton is not None and link_color is not None:
        link_colors = _colors_like_k(len(skeleton), link_color)

    out = img_seq.copy()
    for t in range(T_img):
        canvas = np.zeros_like(out[t]) if skeleton_style == "openpose" else out[t]
        kp = kpts[t]  # (K, 3)
        scores = kp[:, 2]
        # NaN/inf coordinates cannot be rounded to pixels; skip them.
        coords_ok = np.isfinite(kp[:, :2]).all(axis=1)

        # links
        if skeleton is not None and link_colors is not None:
            for sk_id, (i, j) in enumerate(skeleton):
                if i < 0 or j < 0 or i >= K or j >= K:
                    continue
                if scores[i] < kpt_thr or scores[j] < kpt_thr:
                    continue
                if not (coords_ok[i] and coords_ok[j]):
                    continue
                xi, yi = int(round(kp[i, 0])), int(round(kp[i, 1]))
                xj, yj = int(round(kp[j, 0])), int(round(kp[j, 1]))
                # skip links with an endpoint outside the image
                if not (0 <= xi < W and 0 <= yi < H and 0 <= xj < W and 0 <= yj < H):
                    continue
                color = link_colors[sk_id]
                if color is None:
                    continue
                if isinstance(color, str):
                    color = (255, 0, 0)  # default link color
                if skeleton_style == "openpose":
                    _draw_link_openpose(canvas, (xi, yi), (xj, yj), color, line_width)
                else:
                    _draw_line_mmpose(
                        canvas, (xi, yi), (xj, yj), color, thickness=line_width
                    )

        # keypoints
        for kid in range(K):
            if scores[kid] < kpt_thr or not coords_ok[kid]:
                continue
            x, y = int(round(kp[kid, 0])), int(round(kp[kid, 1]))
            if not (0 <= x < W and 0 <= y < H):
                continue
            color = kpt_colors[kid]
            if color is None:
                continue
            if isinstance(color, str):
                color = (0, 255, 0)  # default keypoint color
            cv2.circle(
                canvas, (x, y), radius, color, thickness=-1, lineType=cv2.LINE_AA
            )
            if show_kpt_idx:
                # label the point with its index
                cv2.putText(
                    canvas,
                    str(kid),
                    (x + radius, y - radius),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    thickness=1,
                    lineType=cv2.LINE_AA,
                )

        out[t] = canvas

    # single-frame input -> single-frame output
    return out[0] if single_frame else out
