import os
import cv2
import torch
import argparse
import numpy as np
from torch import nn, Tensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import transforms
from typing import List, Tuple, Union, Optional, Dict
from dual_stream.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from dual_stream.vggt.utils.geometry import unproject_depth_map_to_point_map
import time
from functools import wraps
from vggt.utils.load_fn import load_and_preprocess_images
import torchvision.transforms as TF
from torch.cuda.amp import autocast

DEBUG = False

def _norm_rgb(x):
    """Rescale RGB pixel values from [0, 255] to [-1, 1] as a float tensor."""
    unit = x.float() / 255.0  # [0, 1]
    return unit * 2.0 - 1.0

def prefix_keys(original_dict, prefix):
    """Return a new dict whose keys are ``f"{prefix}_{key}"`` for each original key."""
    return {f"{prefix}_{key}": value for key, value in original_dict.items()}


def timeit(func):
    """Decorator that measures ``func``'s wall-clock time.

    The elapsed time is printed only when the module-level ``DEBUG`` flag is
    set; the wrapped function's return value is passed through unchanged.
    """

    @wraps(func)  # keep the original __name__ / __doc__
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()  # high-resolution timer
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        if DEBUG:
            print(f"[{func.__name__}] 执行时间: {elapsed:.4f} 秒")
        return result

    return wrapper

def preprocess_tensor_images(tensor, mode="crop"):
    """Resize a batch of multi-view images for VGGT, mirroring ``load_and_preprocess_images``.

    Args:
        tensor: (B, V, C, H, W) image tensor; all views share the same H, W.
        mode: "crop" -- resize the width to 518, scale the height by the aspect
              ratio (rounded to a multiple of 14), then center-crop the height
              to 518 if it is larger;
              "pad"  -- resize so the longer side is 518, then pad both sides
              up to 518x518 with the value 1.0.

    Returns:
        (B, V, C, H', W') tensor on the input's device. In "pad" mode
        H' = W' = 518; in "crop" mode W' = 518 and H' <= 518.

    Raises:
        ValueError: if ``tensor`` is not 5-D or ``mode`` is unknown.
    """
    if tensor.dim() != 5:
        raise ValueError(
            f"expected a (B, V, C, H, W) tensor, got shape {tuple(tensor.shape)}"
        )
    if mode not in ("crop", "pad"):
        raise ValueError(f"mode must be 'crop' or 'pad', got {mode!r}")

    target_size = 518
    batch, num_views, channels, height, width = tensor.shape
    # Interpolation is independent per sample, so fold the views into the batch
    # dimension and resize every image in one call.
    flat = tensor.reshape(batch * num_views, channels, height, width)

    if mode == "crop":
        new_width = target_size
        # Keep the aspect ratio; round the height to a multiple of the 14-px patch.
        new_height = round(height * (new_width / width) / 14) * 14
        flat = F.interpolate(
            flat, size=(new_height, new_width), mode="bicubic", align_corners=False
        )
        # Center-crop the height down to 518 if needed.
        if new_height > target_size:
            start_y = (new_height - target_size) // 2
            flat = flat[:, :, start_y : start_y + target_size, :]
    else:  # "pad"
        scale = target_size / max(height, width)
        new_height, new_width = int(height * scale), int(width * scale)
        flat = F.interpolate(
            flat, size=(new_height, new_width), mode="bicubic", align_corners=False
        )
        pad_h = target_size - new_height
        pad_w = target_size - new_width
        # F.pad returns a new tensor; its last-dim pair comes first: (left, right, top, bottom).
        flat = F.pad(
            flat,
            (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
            value=1.0,
        )

    return flat.reshape(batch, num_views, channels, flat.shape[-2], flat.shape[-1])
 
def render(pc, img_feat, img_aug, mvt1_or_mvt2, mvt1, mvt2, dyn_cam_info=None, renderer=None):
    """Render per-sample point clouds with ``renderer`` and stack the results.

    Args:
        pc: sequence of per-sample point clouds. When ``add_corr`` is set, each one
            is concatenated with its features along the last dim. Presumably each
            is (N_i, 3) -- TODO confirm.
        img_feat: sequence of per-sample point features aligned with ``pc``.
        img_aug: scale of the uniform noise added to the rendered images; 0 disables it.
        mvt1_or_mvt2: True selects ``mvt1``, False selects ``mvt2``.
        mvt1, mvt2: config objects read for ``add_corr``, ``norm_corr``,
            ``add_pixel_loc`` and ``pixel_loc``; the chosen one also gets an
            ``img`` attribute written for visualization.
        dyn_cam_info: optional per-sample camera info; each entry is passed to
            the renderer wrapped in a 1-tuple.
        renderer: callable ``renderer(pc, feat, fix_cam=True, dyn_cam_info=...)``.

    Returns:
        The per-sample renderer outputs concatenated along a new leading batch
        dim and permuted with ``(0, 1, 4, 2, 3)``. This presumably yields
        (B, num_views, C, H, W) -- TODO confirm. Pixel-location channels are
        appended on dim 2 when ``add_pixel_loc`` is set.
    """
    assert isinstance(mvt1_or_mvt2, bool)
    mvt = mvt1 if mvt1_or_mvt2 else mvt2

    def _render_one(points, feats, cam_info):
        # The renderer expects the camera info as a 1-tuple, or None.
        return renderer(
            points,
            feats,
            fix_cam=True,
            dyn_cam_info=(cam_info,) if cam_info is not None else None,
        ).unsqueeze(0)

    with torch.no_grad():
        with autocast(enabled=False):
            if dyn_cam_info is None:
                dyn_cam_info_itr = (None,) * len(pc)
            else:
                dyn_cam_info_itr = dyn_cam_info

            img = []
            for _pc, _img_feat, _dyn_cam_info in zip(pc, img_feat, dyn_cam_info_itr):
                if not mvt.add_corr:
                    feats = _img_feat
                elif mvt.norm_corr:
                    # Normalize coordinates by the max |xyz|. An empty cloud keeps a
                    # scale of 1.0; max_pc is then a plain float with no .device, so
                    # take the device from the cloud itself.
                    max_pc = 1.0 if len(_pc) == 0 else torch.max(torch.abs(_pc))
                    _img_feat = _img_feat.to(_pc.device)
                    feats = torch.cat((_pc / max_pc, _img_feat), dim=-1)
                else:
                    feats = torch.cat((_pc, _img_feat), dim=-1)
                img.append(_render_one(_pc, feats, _dyn_cam_info))

    img = torch.cat(img, 0)
    img = img.permute(0, 1, 4, 2, 3)

    # for visualization purposes (drop the 3 coordinate channels when present)
    if mvt.add_corr:
        mvt.img = img[:, :, 3:].clone().detach()
    else:
        mvt.img = img.clone().detach()

    # image augmentation
    if img_aug != 0:
        stdv = img_aug * torch.rand(1, device=img.device)
        # values in [-stdv, stdv]
        noise = stdv * ((2 * torch.rand(*img.shape, device=img.device)) - 1)
        img = torch.clamp(img + noise, -1, 1)

    if mvt.add_pixel_loc:
        bs = img.shape[0]
        pixel_loc = mvt.pixel_loc.to(img.device)
        img = torch.cat((img, pixel_loc.unsqueeze(0).repeat(bs, 1, 1, 1, 1)), dim=2)

    return img


def check_keypoint(kp):
    """Print the x / y coordinate ranges of a keypoint array, skipping NaNs.

    Args:
        kp (torch.Tensor | np.ndarray): keypoints whose last dim holds (x, y),
            e.g. (N, 2) or (B, N, 2). Tensors are copied to the CPU first.
    """
    coords = kp.cpu().numpy() if isinstance(kp, torch.Tensor) else kp
    all_x = coords[..., 0].ravel()
    all_y = coords[..., 1].ravel()
    kept_x = all_x[~np.isnan(all_x)]
    kept_y = all_y[~np.isnan(all_y)]

    print("\nKeypoint Coordinate Ranges :")
    if len(kept_x) == 0:
        print("WARNING: All coordinates are NaN!")
        return
    print(f"X (width)  min: {np.min(kept_x):.1f}, max: {np.max(kept_x):.1f}")
    print(f"Y (height) min: {np.min(kept_y):.1f}, max: {np.max(kept_y):.1f}")
    print(f"Valid points: {len(kept_x)}/{len(all_x)}")

def sigmoid(tensor, temp=1.0):
    """Temperature-scaled logistic function ``1 / (1 + exp(-tensor / temp))``.

    The exponent is clamped to [-50, 50] before ``exp`` so extreme inputs
    cannot overflow.
    """
    clipped = torch.clamp(-tensor / temp, min=-50, max=50)
    return 1.0 / (torch.exp(clipped) + 1.0)


def compute_projection(P, points_3d):
    """Project 3D points to pixel coordinates with a batched 3x4 projection matrix.

    Args:
        P: (B, 3, 4) projection matrix.
        points_3d: (B, ..., 3) points. Any number of middle dims is accepted,
            and the tensor need not be contiguous.

    Returns:
        (B, ..., 2) pixel coordinates ``(x / z, y / z)``, where
        ``[x, y, z] = P @ [X, Y, Z, 1]``. ``1e-8`` is added to ``z`` before
        dividing.
    """
    B = P.shape[0]
    orig_shape = points_3d.shape[:-1]
    # reshape (not view) so non-contiguous inputs such as sliced point maps work
    points_flat = points_3d.reshape(B, -1, 3)  # (B, N, 3)
    # homogeneous coordinate with the same dtype/device as the points
    ones = torch.ones_like(points_flat[..., :1])
    points_h = torch.cat([points_flat, ones], dim=-1)  # (B, N, 4)

    # (B, 3, 4) @ (B, 4, N) -> (B, 3, N) -> (B, N, 3)
    proj_h = torch.bmm(P, points_h.transpose(1, 2)).transpose(1, 2)
    proj_points = proj_h[..., :2] / (proj_h[..., 2:3] + 1e-8)

    return proj_points.reshape(*orig_shape, 2)


# dino patch size is even, so the pixel corner is not really aligned, potential improvements here, borrowed from DINO-Tracker
def interpolate_features(
    descriptors, pts, h, w, normalize=True, patch_size=14, stride=14
):
    """Bilinearly sample per-point features from a patch-level feature map.

    Pixel coordinates are mapped affinely so that the centre of the first patch
    (``patch_size / 2``) maps to -1 and the centre of the last full patch maps
    to +1. The features are then sampled with
    ``grid_sample(align_corners=True, padding_mode="border")``.

    Args:
        descriptors: (B, C, H', W') feature map.
        pts: (B, N, 2) pixel coordinates ordered (x, y).
        h: image height in pixels.
        w: image width in pixels.
        normalize: if True, L2-normalise the sampled features over channels.
        patch_size: patch size in pixels.
        stride: patch stride in pixels.

    Returns:
        (B, C, N) sampled features.
    """
    half_patch = patch_size / 2

    def _axis_affine(size):
        # Centre of the last patch that fits inside ``size`` pixels, then the
        # scale/offset taking [half_patch, last_centre] onto [-1, 1].
        last_centre = ((size - patch_size) // stride) * stride + half_patch
        scale = 2 / (last_centre - half_patch)
        offset = 1 - last_centre * 2 / (last_centre - half_patch)
        return scale, offset

    aw, bw = _axis_affine(w)
    ah, bh = _axis_affine(h)
    scale = torch.tensor([[aw, ah]]).to(pts).float()
    offset = torch.tensor([[bw, bh]]).to(pts).float()

    grid = (scale * pts + offset).unsqueeze(-3)  # (B, 1, N, 2)
    sampled = F.grid_sample(
        descriptors, grid, align_corners=True, padding_mode="border"
    ).squeeze(-2)  # (B, C, 1, N) -> (B, C, N)

    if not normalize:
        return sampled
    return F.normalize(sampled, dim=1)


def convert_camera_to_world(point_map, extrinsic):
    """Map camera-frame points back to world coordinates.

    For a camera-from-world extrinsic ``[R | t]`` (``p_cam = R @ p_world + t``),
    this returns ``p_world = R^{-1} (p_cam - t)``.

    Args:
        point_map: (B, H, W, 3) camera-frame points. The tensor need not be
            contiguous.
        extrinsic: (B, 3, 4) ``[R | t]``.

    Returns:
        world_points: (B, H, W, 3)
    """
    R = extrinsic[:, :, :3]  # (B, 3, 3)
    t = extrinsic[:, :, 3].unsqueeze(1)  # (B, 1, 3), broadcasts over points
    R_inv = torch.inverse(R)  # (B, 3, 3)

    B, H, W, _ = point_map.shape
    # reshape (not view) so non-contiguous point maps are accepted
    points_flat = point_map.reshape(B, -1, 3)  # (B, H*W, 3)

    # (B, 3, 3) @ (B, 3, H*W) -> (B, 3, H*W) -> (B, H*W, 3)
    transformed = torch.bmm(R_inv, (points_flat - t).transpose(1, 2)).transpose(1, 2)

    return transformed.reshape(B, H, W, 3)


def sample_keypoints_nms(mask, conf, N, min_distance, device=None):
    """Sample N keypoints per batch item from local maxima of a confidence map.

    Confidence is kept only inside ``mask`` (zero elsewhere). Max pooling with a
    window of ``2 * min_distance + 1`` finds local maxima (NMS). For each batch
    item the surviving peaks are then:
      * none      -> N rows of (0, 0);
      * more than N -> a random subset of N (``torch.randperm``);
      * at most N -> each peak repeated ceil(N / M) times, truncated to N.

    Args:
        mask (torch.Tensor): (B, H, W) bool mask of valid pixels.
        conf (torch.Tensor): (B, H, W) confidence map.
        N (int): keypoints to return per batch item.
        min_distance (int): NMS radius in pixels.
        device (torch.device, optional): device for the outputs; defaults to
            ``mask.device``.

    Returns:
        torch.Tensor: (B, N, 2) int64 keypoints as (row, col).
    """
    device = mask.device if device is None else device
    batch, _, _ = mask.shape

    scores = torch.zeros_like(mask, dtype=torch.float32, device=device)
    scores[mask] = conf[mask]

    window = 2 * int(min_distance) + 1
    local_max = F.max_pool2d(
        scores[:, None], kernel_size=window, stride=1, padding=window // 2
    )[:, 0]
    peaks = ((scores - local_max).abs() < 1e-6) & mask

    def _pick(coords):
        count = coords.shape[0]
        if count == 0:
            return torch.zeros((N, 2), device=device, dtype=torch.int64)
        if count > N:
            return coords[torch.randperm(count, device=device)[:N]]
        # Too few peaks: repeat each one enough times to reach N.
        reps = (N + count - 1) // count
        return torch.repeat_interleave(coords, reps, dim=0)[:N]

    return torch.stack(
        [_pick(torch.nonzero(peaks[b], as_tuple=False)) for b in range(batch)]
    )  # (B, N, 2)


def get_coview_mask(point_map, P, image_shape):
    """Mark points of ``point_map`` that are visible in the image of camera ``P``.

    Args:
        point_map: (B, H, W, 3) 3D points in the frame that ``P`` maps from.
        P: (B, 3, 4) projection matrix (intrinsic @ extrinsic).
        image_shape: (H_img, W_img) of the target image.

    Returns:
        mask: (B, H, W) bool. True where the point has positive depth in the
        target camera AND its projection lies inside the image bounds.
    """
    H_img, W_img = image_shape
    B, H, W, _ = point_map.shape

    proj_points = compute_projection(P, point_map)  # (B, H, W, 2)
    u = proj_points[..., 0]
    v = proj_points[..., 1]

    # Depth along the optical axis: third row of P applied to [X, Y, Z, 1].
    # Points behind the camera (depth <= 0) flip sign through the perspective
    # divide and could otherwise land inside the image bounds.
    depth = (
        torch.einsum("bnk,bk->bn", point_map.reshape(B, -1, 3), P[:, 2, :3])
        + P[:, 2, 3:4]
    ).reshape(B, H, W)

    return (depth > 0) & (u >= 0) & (u < W_img) & (v >= 0) & (v < H_img)


def get_coview_masks(vggt_features, image_shape, num_view=3):
    """Compute co-visibility masks between consecutive views.

    View ``i`` is tested against view ``i + 1``, with the last view wrapping to
    the first. For 3 views this gives the pairs (1->2), (2->3), (3->1).

    Args:
        vggt_features: dict with keys 'point_map_view_{i}' (B, H, W, 3),
            'extrinsic_{i}' (B, 3, 4) and 'intrinsic_{i}' (B, 3, 3) for i in
            1..num_view.
        image_shape: (H, W) of the images.
        num_view: number of views to read from ``vggt_features``.

    Returns:
        masks: tuple of ``num_view`` (B, H, W) bool masks. Entry ``i`` marks
        view ``i``'s points that project inside view ``(i + 1) % num_view``.
    """
    view_ids = range(1, num_view + 1)
    point_maps = [vggt_features[f"point_map_view_{i}"] for i in view_ids]  # (B, H, W, 3)
    extrinsics = [vggt_features[f"extrinsic_{i}"] for i in view_ids]  # (B, 3, 4)
    intrinsics = [vggt_features[f"intrinsic_{i}"] for i in view_ids]  # (B, 3, 3)

    # NOTE(review): every view is transformed with the FIRST view's extrinsic.
    # Presumably the point maps already share view 1's frame, so this is
    # (near-)identity -- confirm before changing.
    world_point_maps = [
        convert_camera_to_world(point_map, extrinsics[0]) for point_map in point_maps
    ]

    # (B, 3, 4) projection matrices: intrinsic @ [R | t]
    Ps = [torch.bmm(K, E) for K, E in zip(intrinsics, extrinsics)]

    pairings = [(src, (src + 1) % num_view) for src in range(num_view)]

    return tuple(
        get_coview_mask(world_point_maps[src], Ps[dst], image_shape)
        for src, dst in pairings
    )



# @timeit
def sample_keypoints(
    vggt_features,
    image_shape,
    images,
    aggregated_tokens_list,
    ps_idx,
    model,
    device,
    num_keypoints=300,
    min_distance=5,
    border=3,
):
    """Sample keypoints in view 1 and track them into views 2 and 3.

    Keypoints are chosen by NMS on view 1's point confidence inside view 1's
    co-visibility mask. They are passed to ``model.track_head`` as (x, y) query
    points, and the final tracking iteration gives their positions in views 2
    and 3. A correspondence is kept only if all three positions lie at least
    ``border`` pixels inside the image.

    Args:
        vggt_features: dict from ``extract_vggt_features`` (needs
            'point_conf_view_1' and the per-view point maps / cameras).
        image_shape: (H, W) of the VGGT input images.
        images, aggregated_tokens_list, ps_idx: forwarded to ``model.track_head``.
        model: model exposing ``track_head``.
        device: device for the sampled and output tensors.
        num_keypoints: query points per batch item.
        min_distance: NMS radius in pixels.
        border: minimum distance (pixels) from every image edge.

    Returns:
        (kp_1, kp_2, kp_3, valid_kp, mask_1, mask_2, mask_3):
          * ``kp_i``: float (B, num_keypoints, 2) (x, y) tensors. Valid
            correspondences are packed at the front and the rest is NaN.
          * ``valid_kp``: (B, num_keypoints) bool mask in the original
            sampling order.
          * ``mask_i``: the co-visibility masks.
        All seven values are None if the sampler returns None.
    """
    point_conf_view_1 = vggt_features["point_conf_view_1"]

    mask_1, mask_2, mask_3 = get_coview_masks(vggt_features, image_shape)  # (B, H, W) each

    # NMS on view-1 confidence inside its co-visibility mask -> (B, num_keypoints, 2) as (row, col)
    sampled_kp_1 = sample_keypoints_nms(
        mask_1,
        point_conf_view_1,
        N=num_keypoints,
        min_distance=min_distance,
        device=device,
    )
    if sampled_kp_1 is None:
        print("No keypoints found in the first view.")
        # Same arity as the normal return so callers can always unpack 7 values.
        return (None,) * 7
    sampled_kp_1 = sampled_kp_1[:, :, [1, 0]].int()  # (row, col) -> (x, y)

    # Only the predicted tracks are used; the visibility/confidence scores are ignored.
    tracks, _vis_score, _conf_score = model.track_head(
        aggregated_tokens_list, images, ps_idx, query_points=sampled_kp_1
    )
    final_tracks = tracks[-1]  # last iteration; dim 1 presumably indexes views -- TODO confirm
    sampled_kp_2 = final_tracks[:, 1].int()  # (x, y)
    sampled_kp_3 = final_tracks[:, 2].int()

    mh, mw = image_shape
    height, width = int(mh), int(mw)

    def _inside(kp):
        # True where a (x, y) keypoint keeps `border` pixels from every edge.
        x, y = kp[:, :, 0], kp[:, :, 1]
        return (x >= border) & (x < width - border) & (y >= border) & (y < height - border)

    valid_kp = _inside(sampled_kp_1) & _inside(sampled_kp_2) & _inside(sampled_kp_3)  # (B, num_keypoints)

    bs = valid_kp.shape[0]
    kp_1 = torch.full((bs, num_keypoints, 2), float("nan"), device=device)
    kp_2 = torch.full_like(kp_1, float("nan"))
    kp_3 = torch.full_like(kp_1, float("nan"))

    # Pack each item's valid correspondences at the front; the tail stays NaN.
    for b in range(bs):
        mask_b = valid_kp[b]
        num_valid = int(mask_b.sum().item())
        if num_valid == 0:
            continue
        kp_1[b, :num_valid] = sampled_kp_1[b][mask_b]
        kp_2[b, :num_valid] = sampled_kp_2[b][mask_b]
        kp_3[b, :num_valid] = sampled_kp_3[b][mask_b]

    return kp_1, kp_2, kp_3, valid_kp, mask_1, mask_2, mask_3

# @timeit
def extract_vggt_features(rgb_vggt, model, device, return_attn=False):
    """Run the VGGT model once and collect per-view geometry outputs.

    Steps, all under ``torch.no_grad()``:
      1. ``model.aggregator`` runs under CUDA autocast: bfloat16 when the GPU's
         compute-capability major version is >= 8, float16 otherwise.
      2. ``model.camera_head``, ``model.depth_head`` and ``model.point_head``
         run outside autocast.
      3. Each batch item's depth maps are unprojected with the predicted
         cameras on the CPU via NumPy.

    Args:
        rgb_vggt: image tensor passed directly to ``model.aggregator``; dim 1 is
            the number of views. Presumably (B, V, 3, H, W) -- TODO confirm.
        model: model exposing ``aggregator``, ``camera_head``, ``depth_head``
            and ``point_head``.
        device: device the unprojected point maps are moved to.
        return_attn: if True, ``model.aggregator`` is expected to also return
            attention weights, and the return value changes (see below).

    Returns:
        If ``return_attn`` is False: ``(results, image_shape, images,
        aggregated_tokens_list, ps_idx)``.
        If True: ``(results, aggregated_tokens_list)``, with extra
        ``cost_{i}`` entries in ``results``.

        ``results`` holds, for each view ``i`` (1-based):
          * ``point_map_view_{i}``: depth-unprojected points on ``device``;
          * ``point_conf_view_{i}``;
          * ``extrinsic_{i}``;
          * ``intrinsic_{i}``;
          * ``depth_pred_{i}``: depth with the last dim squeezed.
        It also holds ``image_shape`` (last two dims of ``rgb_vggt``),
        ``images`` and ``ps_idx``.
    """
    num_view = rgb_vggt.shape[1]
    # assert num_view == 3, f"shape of rgb_vggt is {rgb_vggt.shape}"
    # Autocast dtype: bfloat16 on GPUs with compute capability >= 8, else float16.
    vggt_dtype = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
        else torch.float16
    )
    with torch.no_grad():
        start = time.time()
        # Only the aggregator runs under autocast; the heads below run outside it.
        with torch.cuda.amp.autocast(dtype=vggt_dtype):
            images = rgb_vggt  # used as-is; no batch dimension is added here

            # process
            # images = load_and_preprocess_images(image_names).to(device)
            if return_attn:
                aggregated_tokens_list, ps_idx, attn = model.aggregator(
                    images
                )  # attn: presumably (B*S, num_heads, P, P) global attention weights
            else:
                aggregated_tokens_list, ps_idx = model.aggregator(
                    images
                )
        if DEBUG:
            print(f"aggregator 耗时: {time.time() - start:.4f} 秒")
        # Predict Cameras
        start = time.time()
        pose_enc = model.camera_head(aggregated_tokens_list)[-1]
        # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
        extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
        if DEBUG:
            print(f"extrinsic 耗时: {time.time() - start:.4f} 秒")

        # Predict Depth Maps
        start = time.time()
        depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)
        if DEBUG:
            print(f"depth_map 耗时: {time.time() - start:.4f} 秒")

        # Predict Point Maps (only point_conf is used below; point_map is not stored)
        start = time.time()
        point_map, point_conf = model.point_head(aggregated_tokens_list, images, ps_idx)
        if DEBUG:
            print(f"point_map 耗时: {time.time() - start:.4f} 秒")

        # Construct 3D Points from Depth Maps and Cameras
        # which usually leads to more accurate 3D points than point map branch.
        # Runs per batch item on the CPU via NumPy.
        point_maps_by_unprojection = []
        start = time.time()
        for i in range(depth_map.size(0)):
            point_map_by_unprojection = unproject_depth_map_to_point_map(
                depth_map[i].cpu().numpy(),  # (V, 518, 518, 1)
                extrinsic[i].cpu().numpy(),  # (V, 3, 4)
                intrinsic[i].cpu().numpy(),
            )
            point_maps_by_unprojection.append(
                torch.from_numpy(point_map_by_unprojection).float()
            )
        point_map_by_unprojection = torch.stack(
            point_maps_by_unprojection
        )  # (B, V, 518, 518, 3)
        if DEBUG:
            print(f"point_map unprojection 耗时: {time.time() - start:.4f} 秒")


        # Build the per-view outputs
        results = {}

        # Point map, confidence, cameras and depth for every view (1-based keys)
        for i in range(num_view):
            results[f"point_map_view_{i+1}"] = (
                point_map_by_unprojection[:, i, ...].detach().clone().to(device)
            )
            results[f"point_conf_view_{i+1}"] = point_conf[:, i]
            results[f"extrinsic_{i+1}"] = extrinsic[:, i]
            results[f"intrinsic_{i+1}"] = intrinsic[:, i]
            results[f"depth_pred_{i+1}"] = depth_map[:, i].squeeze(-1)

        results["image_shape"] = tuple(rgb_vggt.shape[-2:])
        results["images"] = images
        results["ps_idx"] = ps_idx

        if return_attn:
            # Per-view attention "cost": the head-averaged attention of each chunk.
            # NOTE(review): attn.chunk(num_view, dim=0) splits the first axis into
            # num_view contiguous blocks; that groups by view only if that axis is
            # view-major (or B == 1) -- confirm the aggregator's ordering.
            # NOTE(review): cost_views is built but never returned.
            cost_views = []
            for i in range(num_view):
                view_attn = attn.chunk(num_view, dim=0)[i]
                cost_view = view_attn.mean(dim=1)
                results[f"cost_{i+1}"] = cost_view
                cost_views.append(cost_view)

            return results, aggregated_tokens_list

        return (
            results,
            tuple(rgb_vggt.shape[-2:]),  # image_shape
            images,
            aggregated_tokens_list,
            ps_idx,
        )


def process_view_feature(
    feature,
    kp,
    patch_size,
    stride,
    normalize,
    resize_factor,
    img_size=518,
):
    """Sample ``feature`` at a ragged batch of keypoints (currently unused).

    Steps:
      1. Scale each ``kp[i]`` ((N_i, 2) keypoints) by ``resize_factor`` and
         zero-pad it to the longest entry.
      2. Sample features with ``interpolate_features`` (unnormalised) and zero
         the padded slots.
      3. If ``normalize``, L2-normalise the valid rows.

    Returns:
        (B, max_N, C) tensor.
    """
    counts = [points.shape[0] for points in kp]
    longest = max(counts)
    padded = torch.zeros(len(kp), longest, 2, device=feature.device)
    present = torch.zeros(len(kp), longest, dtype=torch.bool, device=feature.device)

    for row, (points, count) in enumerate(zip(kp, counts)):
        padded[row, :count] = points * resize_factor
        present[row, :count] = True

    sampled = interpolate_features(
        feature,
        padded,
        h=img_size,
        w=img_size,
        patch_size=patch_size,
        stride=stride,
        normalize=False,
    ).permute(0, 2, 1)  # (B, max_N, C)
    sampled = sampled * present.unsqueeze(-1)

    if normalize:
        rows = sampled[present]
        if rows.numel() > 0:
            sampled[present] = F.normalize(rows, p=2, dim=-1)
    return sampled


# ———————————————————————————————————上面是间接调用的工具函数—————————————————————————————————————————————————
# ———————————————————————————————————下面是直接调用的工具函数—————————————————————————————————————————————————
def get_3d_keypoints_from_gt(data, view_names=("1", "2", "3")):
    """Look up the 3D point under each 2D keypoint in every view's point map.

    Keypoints live in 518x518 image space. They are rescaled to a 128x128 grid,
    truncated to ints and clipped to [0, 127] before indexing the point map.

    Args:
        data (dict): for each view index ``i`` (1-based, one per entry of
            ``view_names``):
            - ``kp_{i}``: (B, N, 2) keypoints; extra singleton dims such as
              (B, 1, N, 2) are accepted. NaN marks an invalid keypoint.
            - ``point_map_view_{i}``: (B, H, W, 3) or (B, 1, H, W, 3) point map,
              as a tensor or ndarray.
            The batch size is taken from the first tensor/ndarray value in ``data``.
        view_names: one entry per view; only its length is used.

    Returns:
        kp_3d_gt (dict): ``kp_{i}`` -> float32 tensor (B, N, 3). Entries stay
        zero for NaN keypoints, out-of-range keypoints, and all-zero points.
    """
    scale = 128 / 518  # keypoints are in 518x518 image space; lookup uses a 128x128 grid
    batch_size = next(
        v for v in data.values() if isinstance(v, (torch.Tensor, np.ndarray))
    ).shape[0]

    kp_3d_gt = {}
    for idx in range(1, len(view_names) + 1):
        kp_data = data[f"kp_{idx}"]
        if isinstance(kp_data, torch.Tensor):
            kp_data = kp_data.detach().cpu().numpy()
        # reshape (not squeeze) so a batch size or keypoint count of 1 is kept
        kp_2d = np.asarray(kp_data, dtype=np.float32).reshape(batch_size, -1, 2)

        valid = ~np.isnan(kp_2d).any(axis=-1)  # (B, N)
        kp_pixel = (np.where(valid[..., None], kp_2d, 0) * scale).astype(int)
        kp_pixel = np.clip(kp_pixel, [0, 0], [127, 127])  # 128x128 grid

        point_cloud = data[f"point_map_view_{idx}"]
        if isinstance(point_cloud, torch.Tensor):
            point_cloud = point_cloud.detach().cpu().numpy()
        point_cloud = np.asarray(point_cloud)
        if point_cloud.ndim == 5:  # (B, 1, H, W, 3)
            point_cloud = point_cloud.squeeze(1)
        point_cloud = point_cloud.transpose(0, 3, 1, 2)  # (B, 3, H, W)

        # NOTE(review): indexing is pc[:, kp[..., 0], kp[..., 1]], i.e. the first
        # keypoint coordinate selects dim 2 (H). If keypoints are (x, y) and the map
        # is (B, H, W, 3) this is transposed -- confirm the data layout.
        _, _, size_1, size_2 = point_cloud.shape
        first, second = kp_pixel[..., 0], kp_pixel[..., 1]
        inside = valid & (first >= 0) & (first < size_1) & (second >= 0) & (second < size_2)

        # Gather every keypoint at once; out-of-range slots read index 0 and are masked below.
        batch_idx = np.arange(batch_size)[:, None]
        xyz = point_cloud[
            batch_idx, :, np.where(inside, first, 0), np.where(inside, second, 0)
        ]  # (B, N, 3)

        keep = inside & ~np.all(xyz == 0, axis=-1)  # also drop all-zero points
        batch_kp_3d = np.zeros((batch_size, kp_pixel.shape[1], 3), dtype=np.float32)
        batch_kp_3d[keep] = xyz[keep]

        kp_3d_gt[f"kp_{idx}"] = torch.from_numpy(batch_kp_3d)
    return kp_3d_gt

# @timeit
def get_vggt_feature_map(rgb_vggt, vggt_model, device, visualize=False):
    """Run VGGT and track 300 keypoints from view 1 into views 2 and 3.

    Args:
        rgb_vggt: multi-view images; dim 1 indexes views (three are used here).
            Moved to ``device`` before inference.
        vggt_model: VGGT model used for features and tracking.
        device: inference device.
        visualize: if True, also add per-view ``rgb_{i}`` images and
            ``depth_{i}`` predictions to the keypoint dict.

    Returns:
        (vggt_features, kp_and_pc):
          * ``vggt_features``: the dict produced by ``extract_vggt_features``.
          * ``kp_and_pc``: holds ``kp_1``..``kp_3`` and
            ``point_map_view_1``..``point_map_view_3``, plus the visualization
            entries when requested.
    """
    vggt_features, image_shape, images, aggregated_tokens_list, ps_idx = (
        extract_vggt_features(rgb_vggt.to(device), vggt_model, device=device)
    )

    kp_1, kp_2, kp_3, *_ = sample_keypoints(
        vggt_features,
        image_shape,
        images,
        aggregated_tokens_list,
        ps_idx,
        vggt_model,
        device=device,
        num_keypoints=300,
        min_distance=5,
    )

    kp_and_pc = {"kp_1": kp_1, "kp_2": kp_2, "kp_3": kp_3}
    for view_name in ("1", "2", "3"):
        kp_and_pc[f"point_map_view_{view_name}"] = vggt_features[f"point_map_view_{view_name}"]

    if visualize:
        for view_idx, view_name in enumerate(("1", "2", "3")):
            # 'images' is indexed (batch, view, ...): select the view on dim 1,
            # matching the per-view depth_pred_{i} entries.
            kp_and_pc[f"rgb_{view_name}"] = vggt_features["images"][:, view_idx]
            kp_and_pc[f"depth_{view_name}"] = vggt_features[f"depth_pred_{view_name}"]

    return vggt_features, kp_and_pc


def align_features_with_kp(
    desc_1, desc_2, desc_3, match_input_dict, img_size, img_patch_size
):
    """Sample each view's feature map at that view's keypoints.

    Args:
        desc_1, desc_2, desc_3: (B, C, H, W) feature maps of views 1-3.
        match_input_dict: None, or a dict with:
            - 'match_kpts_img': indexable by view (0, 1, 2). Entry ``v`` is a
              (B, N, 2) tensor of pixel coordinates for view ``v``; NaN marks an
              invalid keypoint.
            - 'normalize': if truthy, L2-normalise each sampled feature.
        img_size: image height/width in pixels (square images).
        img_patch_size: feature-map patch size, also used as the stride.

    Returns:
        If ``match_input_dict`` is None, the three descriptors unchanged.
        Otherwise three (B, N, C) tensors that are zero at invalid keypoints.
    """
    if match_input_dict is None:
        return desc_1, desc_2, desc_3

    match_kpts_img = match_input_dict["match_kpts_img"]
    normalize = match_input_dict["normalize"]

    def process_view(desc, view_idx):
        """Sample ``desc`` (B, C, H, W) at view ``view_idx``'s keypoints -> (B, N, C)."""
        view_kpts = match_kpts_img[view_idx]  # (B, N, 2)

        # A keypoint is valid only if neither coordinate is NaN.
        valid_mask = ~torch.isnan(view_kpts).any(dim=-1)  # (B, N)

        # grid_sample must not see NaNs: park invalid points at 0 (masked out below).
        pts_clean = view_kpts.clone().float()
        pts_clean[~valid_mask] = 0

        feat = interpolate_features(
            desc,
            pts_clean,
            h=img_size,
            w=img_size,
            patch_size=img_patch_size,
            stride=img_patch_size,
            normalize=False,
        )  # (B, C, N); no extra squeeze, which would collapse N == 1

        feat = feat * valid_mask.unsqueeze(1).float()  # zero invalid keypoints
        feat = feat.permute(0, 2, 1)  # (B, N, C)

        if normalize:
            feat = F.normalize(feat, p=2, dim=-1)
        return feat

    return tuple(
        process_view(desc, view_idx)
        for view_idx, desc in enumerate((desc_1, desc_2, desc_3))
    )
