from cycler import V
from einops import rearrange
import numpy as np
import torch
import torch.nn.functional as F
from transformers import BitImageProcessor
from decord import VideoReader


def keepface_mask_prob_from_kpts(
    keypoints: torch.Tensor,  # [B, K, 3] -> (x, y, conf)
    grid_size: int = 32,
    img_hw: (
        tuple[int, int] | None
    ) = None,  # (H_img, W_img)；若坐标已归一化则可为 None 并 normalized=True
    normalized: bool = False,
    top_m: int = 100,  # 只取置信度最高的 M 个关键点（集中且稳）
    base_sigma_grid: float = 1.25,  # 以 32x32 网格为单位的基准 σ（1~2 通常足够集中）
    conf_beta: float = 0.5,  # σ ∝ 1/(conf^β)；高置信更窄
    min_sigma: float = 0.8,  # σ 的下界，避免过尖导致数值不稳
    gamma_sharpen: float = 1.4,  # 对 max 响应幂次锐化（>1 更集中）
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    输出: keepface_mask_prob ∈ [0,1]，形状 [B, grid_size*grid_size]
    语义与你的原式一致：靠近关键点 -> 概率小；远离关键点 -> 概率大。
    """
    assert keypoints.ndim == 3 and keypoints.size(-1) == 3, "keypoints 应为 [B,K,3]"
    device = keypoints.device
    B, K, _ = keypoints.shape

    xy = keypoints[..., :2]  # [B,K,2]
    conf = keypoints[..., 2].clamp_(0.0, 1.0)  # [B,K]

    # 仅保留 top-m 高置信关键点（保证“由点主导”，避免弱点抹平峰值）
    if top_m < K:
        top_idx = conf.topk(top_m, dim=1).indices  # [B, M]
        batch_idx = torch.arange(B, device=device).unsqueeze(-1)
        xy = xy[batch_idx, top_idx, :]  # [B,M,2]
        conf = conf[batch_idx, top_idx]  # [B,M]
        K = top_m

    # 坐标映射到 32x32 网格
    if normalized:
        xy_grid = xy * (grid_size - 1)
    else:
        assert img_hw is not None, "未归一化坐标需给出 img_hw=(H,W)"
        Himg, Wimg = img_hw
        xg = xy[..., 0] * (grid_size - 1) / max(Wimg - 1, 1)
        yg = xy[..., 1] * (grid_size - 1) / max(Himg - 1, 1)
        xy_grid = torch.stack([xg, yg], dim=-1)  # [B,K,2]

    # 网格中心 P=G*G
    ys = torch.linspace(0, grid_size - 1, grid_size, device=device)
    xs = torch.linspace(0, grid_size - 1, grid_size, device=device)
    Y, X = torch.meshgrid(ys, xs, indexing="ij")
    centers = torch.stack([X, Y], dim=-1).view(1, grid_size * grid_size, 2)  # [1,P,2]

    # 每点自适应 σ：σ_k = max(min_sigma, base_sigma_grid / (conf^β))
    sigma_k = base_sigma_grid / (conf.clamp(min=eps) ** conf_beta)  # [B,K]
    sigma_k = torch.maximum(sigma_k, torch.full_like(sigma_k, min_sigma))
    sigma2_k = sigma_k.unsqueeze(1) ** 2  # [B,1,K]

    # 距离与响应：等价于为每个关键点在 32x32 上放置高斯；再对 K 取最大值
    # dist2: [B,P,K]
    dist2 = ((centers.unsqueeze(2) - xy_grid.unsqueeze(1)) ** 2).sum(dim=-1)
    resp = torch.exp(-dist2 / (2.0 * sigma2_k))  # [B,P,K]
    max_resp = resp.max(dim=-1).values  # [B,P]

    # 锐化（可选）：把中高响应进一步拉高，峰更集中
    max_resp = max_resp.clamp(0, 1) ** gamma_sharpen  # [B,P]

    # 最终概率：1 - max
    keepface_mask_prob = (1.0 - max_resp).clamp(0.0, 1.0)  # [B,P]
    return keepface_mask_prob


def keepface_mask_prob_from_heatmap(
    heatmap: torch.Tensor,  # [B, K, H, W]
    grid_size: int = 32,
) -> torch.Tensor:
    """Build a per-cell mask probability map from keypoint heatmaps.

    The heatmaps are bilinearly resized to ``grid_size x grid_size``, the
    maximum over the keypoint channel is taken and subtracted from 1, so
    cells with a strong heatmap response get a low value.

    Args:
        heatmap: Tensor of shape ``[B, K, H, W]``.
        grid_size: Side length of the output grid. Defaults to 32, the
            previously hard-coded size.

    Returns:
        Tensor of shape ``[B, grid_size * grid_size]`` in ``heatmap``'s dtype.
        The result is not clamped, so it lies in ``[0, 1]`` only when the
        heatmap values do.
    """
    # Upcast to float32 for the interpolation, then cast back to the input dtype.
    resized = F.interpolate(
        heatmap.float(),
        (grid_size, grid_size),
        mode="bilinear",
        align_corners=False,
    ).to(heatmap.dtype)
    # Strongest response over all keypoint channels for each cell.
    peak = resized.max(dim=1).values  # [B, G, G]
    return (1 - peak).reshape(heatmap.shape[0], -1)


if __name__ == "__main__":
    # Demo / visual check: run the Sapiens pose pipeline on one clip, turn the
    # face keypoints and face heatmaps into mask probabilities with the two
    # helpers above, let the MAE encoder produce a mask for each, and write
    # both masked videos side by side to "output.mp4".
    from mmhug.datasets.utils.kpt2face import _fallback_face_indices_ex_ear_308
    from mmhug.models.custom_transformers.sapiens.heatmap_head import HeatmapHead

    from mmhug.models.custom_transformers.sapiens.vit_sapiens import (
        SapiensVisionTransformer,
    )
    from mmhug.models.custom_transformers.slipmae.transformer_slipmae_encoder import (
        SlipmaeEncoder,
    )
    from mmhug.models.custom_transformers.syncnet.utils import write_video
    from mmhug.pipelines.pipeline_sapiens import PipelineSapiens, SapiensOutput
    from mmhug.registry import HF_MODELS

    # NOTE(review): `torch` is already imported at module level; this
    # re-import is redundant.
    import torch
    import mmengine

    device = mmengine.device.get_device()
    dtype = torch.bfloat16
    video_path: str = "data/celebv-hq/videos_resampled/resampled_zwrE99cctuw_24.mp4"

    # Preprocessing: resize the shortest edge to 512, center-crop to 512x512,
    # rescale and normalize with mean 0.5 / std 0.5 per channel.
    input_video_processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": 512},
        do_center_crop=True,
        crop_size={"height": 512, "width": 512},
        do_rescale=True,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    # Pose backbone, weights loaded from the sapiens-pose-0.3b checkpoint.
    # NOTE(review): img_size here is (1024, 768) while the processor above
    # produces 512x512 crops — confirm the backbone handles the mismatch.
    backbone = SapiensVisionTransformer(
        arch="sapiens_0.3b",
        norm_in=True,  # the input is normalized with mean 0.5 and std 0.5, we need to renormalize with Sapiens mean and std
        img_size=(1024, 768),
        patch_size=16,
        in_channels=3,
        out_indices=-1,
        drop_rate=0.0,
        drop_path_rate=0.0,
        qkv_bias=True,
        norm_cfg=dict(type="LN", eps=1e-6),
        final_norm=True,
        with_cls_token=False,
        frozen_stages=-1,
        interpolate_mode="bicubic",
        layer_scale_init_value=0.0,
        patch_cfg=dict(padding=2),
        layer_cfgs=dict(),
        pre_norm=False,
        out_type="featmap",
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/backbone.pth",
        ),
    )
    backbone.init_weights()

    # Heatmap head with 308 output channels (one per keypoint).
    heatmap_head = HeatmapHead(
        in_channels=1024,
        out_channels=308,
        deconv_out_channels=(768, 768),  ## this will 2x at each step. so total is 4x
        deconv_kernel_sizes=(4, 4),
        conv_out_channels=(768, 768),
        conv_kernel_sizes=(1, 1),
        decoder=dict(
            type="UDPHeatmap", input_size=(512, 512), heatmap_size=(192, 256), sigma=6
        ),
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/heatmap_head.pth",
        ),
    )
    heatmap_head.init_weights()

    # MAE encoder built through the registry and moved to `device` / bfloat16.
    # It is configured with the same backbone checkpoint path as above.
    mae_encoder: SlipmaeEncoder = HF_MODELS.build(
        dict(
            type="SlipmaeEncoder",
            arch="sapiens_0.3b",
            norm_in=True,  # the input is normalized with mean 0.5 and std 0.5, we need to renormalize with Sapiens mean and std
            img_size=(512, 512),
            patch_size=16,
            in_channels=3,
            drop_rate=0.0,
            drop_path_rate=0.0,
            qkv_bias=True,
            norm_cfg=dict(type="LN", eps=1e-6),
            frozen_stages=-1,
            interpolate_mode="bicubic",
            layer_scale_init_value=0.0,
            pre_norm=False,
            num_extra_tokens=3,
            mask_ratio=0.75,
            init_cfg=dict(
                type="Pretrained",
                checkpoint="checkpoints/sapiens-pose-0.3b/backbone.pth",
            ),
        )
    ).to(device, dtype)

    # NOTE(review): unlike `mae_encoder`, `backbone` and `heatmap_head` are not
    # moved to `device` / `dtype` here — presumably PipelineSapiens handles
    # placement; confirm.
    pipeline = PipelineSapiens(
        image_processor=input_video_processor,
        backbone=backbone,
        heatmap_head=heatmap_head,
    )

    # Run pose estimation on the clip. Both outputs are consumed as numpy
    # arrays below (indexed, then passed to torch.from_numpy).
    sapiens_output: SapiensOutput = pipeline(video_path)
    keypoints = sapiens_output.keypoints
    heatmaps = sapiens_output.heatmaps

    # Mask probabilities from the selected face keypoints.
    # assumes the keypoints are pixel coordinates in the 512x512 crop and that
    # the last axis is (x, y, conf) — TODO confirm against PipelineSapiens.
    mask_ratio_keypoints = keepface_mask_prob_from_kpts(
        torch.from_numpy(keypoints[:, _fallback_face_indices_ex_ear_308()]).to(
            device, dtype
        ),
        img_hw=(512, 512),
    )
    # Debug output: the keypoint indices used for the face selection.
    print(_fallback_face_indices_ex_ear_308())
    # Mask probabilities from the heatmaps of the same keypoint indices.
    mask_ratio_heatmaps = keepface_mask_prob_from_heatmap(
        torch.from_numpy(heatmaps[:, _fallback_face_indices_ex_ear_308()]).to(
            device, dtype
        )
    )

    # Decode every frame of the clip and preprocess it with the same processor
    # that was handed to the pipeline.
    vr = VideoReader(video_path)
    video = vr.get_batch(range(len(vr))).asnumpy()
    video = torch.tensor(
        input_video_processor.preprocess(video).pixel_values, dtype=dtype, device=device
    )

    # Let the encoder draw a mask from each probability map; only the "mask"
    # entry of the returned mapping is used.
    mask_keypoints = (
        mae_encoder.forward(video, do_mask=True, mask_prob=mask_ratio_keypoints)["mask"]
        .detach()
        .cpu()
    )

    mask_heatmaps = (
        mae_encoder.forward(video, do_mask=True, mask_prob=mask_ratio_heatmaps)["mask"]
        .detach()
        .cpu()
    )

    # Unflatten each mask to a 32x32 grid (presumably 512 / patch_size 16) and
    # upsample to 512x512 with nearest-exact interpolation, which replicates
    # grid values without blending them.
    mask_keypoints = rearrange(mask_keypoints, "b (h w) -> b h w", h=32, w=32)
    mask_keypoints = F.interpolate(
        mask_keypoints.unsqueeze(1),
        size=(512, 512),
        mode="nearest-exact",
    )

    mask_heatmaps = rearrange(mask_heatmaps, "b (h w) -> b h w", h=32, w=32)
    mask_heatmaps = F.interpolate(
        mask_heatmaps.unsqueeze(1),
        size=(512, 512),
        mode="nearest-exact",
    )

    # Apply the masks by element-wise multiplication; the [T,1,H,W] mask
    # broadcasts over the channel axis.
    # NOTE(review): which mask value marks a removed patch is defined by
    # SlipmaeEncoder — confirm that multiplying shows the intended region.
    video_mask_kpt = video.cpu() * mask_keypoints
    video_mask_heatmap = video.cpu() * mask_heatmaps

    # t c h w
    # Concatenate along the last (width) axis: keypoint-based result on the
    # left, heatmap-based result on the right.
    output_video = torch.cat([video_mask_kpt, video_mask_heatmap], dim=-1)
    # t h w c
    output_video = output_video.permute(0, 2, 3, 1)
    # Undo the mean 0.5 / std 0.5 normalization and scale to 0..255.
    output_video = (output_video * 0.5 + 0.5) * 255

    # save_video
    # NOTE(review): values are cast to uint8 without an explicit clamp, and fps
    # is hard-coded to 25 — confirm it matches the source clip.
    write_video("output.mp4", output_video.float().numpy().astype(np.uint8), fps=25)
