import time
import torch
import numpy as np
import torch.nn.functional as F
from dual_stream.utils.peract_utils import _norm_rgb
from torch.cuda.amp import autocast
import torchvision.transforms.functional as TF
from peract_colab.arm.utils import stack_on_channel
from rlbench.backend.observation import Observation
from torchvision.transforms import InterpolationMode
from peract_colab.rlbench.backend.utils import REMOVE_KEYS
from point_renderer.rvt_renderer import RVTBoxRenderer as BoxRenderer
import peract_colab.arm.utils as utils
from typing import List

from dual_stream.utils import rvt_utils
import dual_stream.mvt.utils as mvt_utils
from dual_stream.mvt.augmentation import apply_se3_aug_con
from dual_stream.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from dual_stream.vggt.utils.geometry import unproject_depth_map_to_point_map

# Length of the low-dimensional state vector. `extract_obs` in this file
# concatenates gripper_open, the gripper joint positions, the gripper pose,
# the arm joint positions and a normalized timestep
# (presumably 1 + 2 + 7 + 7 + 1 = 18 for a 7-DoF arm - TODO confirm).
LOW_DIM_SIZE = 18
IMAGE_SIZE = 128
CAMERAS = ["front", "left_shoulder", "right_shoulder", "wrist"]
# Metric volume to be voxelized: [x_min, y_min, z_min, x_max, y_max, z_max].
SCENE_BOUNDS = [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6]
# 100x100x100 voxels (a single voxelization level).
VOXEL_SIZES = [100]
# Sample every n-th frame of a demo.
DEMO_AUGMENTATION_EVERY_N = 10
# Rotation discretization, in degree increments per axis.
ROTATION_RESOLUTION = 5
VARIATION_DESCRIPTIONS_PKL = "variation_descriptions.pkl"
EPISODE_FOLDER = "episode%d"

def _dyn_img_size_from_extent(x_range, y_range, static_range=2.0):
    """Return ``[width, height]`` whose longer side is ``static_range`` and whose
    aspect ratio is ``x_range / y_range``; degenerate extents do not raise."""
    if y_range > 0:
        aspect_ratio = x_range / y_range
    elif x_range > 0:
        # Zero height: mirror what a zero width already produces ([0, static_range]).
        aspect_ratio = float("inf")
    else:
        # Both extents are zero (or NaN): fall back to a square image.
        aspect_ratio = 1.0

    if aspect_ratio > 1:
        return [static_range, static_range / aspect_ratio]
    return [static_range * aspect_ratio, static_range]


def _preprocess_inputs(replay_sample, cameras):
    """Collect per-camera inputs and dynamic (wrist) camera info from a replay sample.

    Args:
        replay_sample: mapping with ``"<camera>_rgb"`` and ``"<camera>_point_cloud"``
            entries for every name in ``cameras`` plus ``"wrist_camera_extrinsics"``.
        cameras: camera names; must not be empty.

    Returns:
        ``(obs, pcds, dyn_cam_info)`` where ``obs`` is a list of
        ``[normalized_rgb, point_cloud]`` pairs, ``pcds`` is the list of point
        clouds only, and ``dyn_cam_info`` holds one ``(R, T, img_size, K)`` tuple
        per batch sample. ``K`` is always ``None``, so the wrist intrinsics are
        not read.

    Raises:
        ValueError: if ``cameras`` is empty.
    """
    obs, pcds = [], []
    for name in cameras:
        rgb = stack_on_channel(replay_sample["%s_rgb" % name])
        pcd = stack_on_channel(replay_sample["%s_point_cloud" % name])
        rgb = _norm_rgb(rgb)

        # obs keeps both rgb and point cloud (used in ARM for other baselines)
        obs.append([rgb, pcd])
        pcds.append(pcd)

    if not pcds:
        raise ValueError("cameras must contain at least one camera name")

    # [bs, 1, 4, 4] per the original shape note
    dynamic_extrinsic = stack_on_channel(
        replay_sample["wrist_camera_extrinsics"]
    ).unsqueeze(1)

    # Point-cloud extent of the LAST camera in `cameras` (the original code used
    # the variable left over from the loop). Point clouds are assumed to be
    # [bs, 3, H, W] - TODO confirm.
    last_pcd = pcds[-1]
    pc_min = last_pcd.min(dim=-1)[0].min(dim=-1)[0]  # [bs, 3]
    pc_max = last_pcd.max(dim=-1)[0].max(dim=-1)[0]  # [bs, 3]
    extents = (pc_max - pc_min).tolist()

    dyn_cam_info = []
    for sample in range(dynamic_extrinsic.shape[0]):
        extrinsic = dynamic_extrinsic[sample]
        R = extrinsic[:, :3, :3]
        T = extrinsic[:, :3, 3]

        # Keep the same overall field size as the static cameras (2.0) while
        # following the aspect ratio of the point cloud's x/y extent.
        x_range, y_range = extents[sample][0], extents[sample][1]
        dyn_img_sizes_w = _dyn_img_size_from_extent(x_range, y_range)

        dyn_cam_info.append((R, T, dyn_img_sizes_w, None))

    return obs, pcds, dyn_cam_info


def _vggt_to_numpy(x):
    """Recursively convert anything exposing ``.cpu()`` to NumPy; lists/tuples
    become lists and every other value is returned unchanged."""
    if isinstance(x, (list, tuple)):
        return [_vggt_to_numpy(i) for i in x]
    return x.cpu().numpy() if hasattr(x, 'cpu') else x


def _pad_first_keypoints(keypoints, max_len):
    """Copy the first batch entry of ``keypoints`` into a zero-filled
    ``(1, max_len, 2)`` array; ``None`` yields all zeros."""
    padded = np.zeros((1, max_len, 2))
    if keypoints is None:
        return padded
    first = keypoints[0]
    padded[:, :first.shape[0]] = first
    return padded


def _downsample_vggt_features(features, target_size):
    """Return a copy of ``features`` with point maps (bilinear) and confidence /
    depth maps (nearest) resized to ``target_size`` x ``target_size``."""
    size = (target_size, target_size)
    resized = {}
    for key in ('point_map_view_1', 'point_map_view_2', 'point_map_view_3'):
        if key in features:
            # NHWC -> NCHW for interpolate, then back to NHWC.
            nchw = features[key].permute(0, 3, 1, 2)
            out = F.interpolate(nchw.float(), size=size, mode='bilinear', align_corners=False)
            resized[key] = out.permute(0, 2, 3, 1)
    # NOTE(review): only view 1's confidence map is resized here; views 2 and 3
    # keep their original resolution - confirm this is intended.
    for key in ('point_conf_view_1', 'depth_pred_1', 'depth_pred_2', 'depth_pred_3'):
        if key in features:
            # NHW -> N1HW for interpolate, then drop the channel again.
            out = F.interpolate(features[key].unsqueeze(1), size=size, mode='nearest')
            resized[key] = out.squeeze(1)
    # Everything else is passed through untouched.
    for key, value in features.items():
        if key not in resized:
            resized[key] = value
    return resized


def extract_obs(obs: Observation,
                cameras,
                t: int = 0,
                prev_action=None,
                channels_last: bool = False,
                episode_length: int = 10,
                vggt_model=None,
                device=None):
    """Convert an RLBench ``Observation`` into a flat dict of NumPy arrays.

    Args:
        obs: observation to convert. It is mutated: ``joint_velocities`` and
            ``wrist_camera_matrix`` are set to ``None`` and not restored, and
            ``gripper_joint_positions`` is clipped to ``[0, 0.04]``.
        cameras: camera names whose extrinsics/intrinsics are copied from
            ``obs.misc``.
        t: current timestep, used for the normalized time feature.
        prev_action: unused.
        channels_last: when False, 3-D arrays are transposed to channel-first
            and other arrays gain a leading axis.
        episode_length: episode horizon used to normalize ``t``.
        vggt_model: optional VGGT model; when given, keypoints (``kp_*``) and
            VGGT features for two rendering stages are added to the result.
        device: device used for rendering and VGGT inference.

    Returns:
        dict with the observation arrays, ``low_dim_state`` (robot state plus
        time feature), ``ignore_collisions``, camera matrices and, optionally,
        the VGGT entries.
    """
    obs.joint_velocities = None
    grip_mat = obs.gripper_matrix
    grip_pose = obs.gripper_pose
    joint_pos = obs.joint_positions
    # Temporarily blank these fields so they are filtered out of vars(obs).
    obs.gripper_pose = None
    obs.gripper_matrix = None
    obs.wrist_camera_matrix = None
    obs.joint_positions = None
    if obs.gripper_joint_positions is not None:
        obs.gripper_joint_positions = np.clip(
            obs.gripper_joint_positions, 0., 0.04)
    obs_dict = vars(obs)
    obs_dict = {k: v for k, v in obs_dict.items() if v is not None and not k.endswith('_paths')}
    robot_state = np.array([
        obs.gripper_open,
        *obs.gripper_joint_positions,
        *grip_pose,
        *joint_pos])
    # remove low-level proprioception variables that are not needed
    obs_dict = {k: v for k, v in obs_dict.items()
                if k not in REMOVE_KEYS}

    if not channels_last:
        # Keep only array-like entries; lists are converted so `.ndim` exists.
        arrays = {k: np.asarray(v) for k, v in obs_dict.items()
                  if isinstance(v, (np.ndarray, list))}
        # swap channels from last dim to 1st dim
        obs_dict = {k: np.transpose(v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0)
                    for k, v in arrays.items()}
    else:
        # add extra dim to depth data
        obs_dict = {k: v if v.ndim == 3 else np.expand_dims(v, -1)
                    for k, v in obs_dict.items()}
    obs_dict['low_dim_state'] = np.array(robot_state, dtype=np.float32)

    # binary variable indicating if collisions are allowed or not while planning paths to reach poses
    obs_dict['ignore_collisions'] = np.array([obs.ignore_collisions], dtype=np.float32)
    for k in [k for k in obs_dict if 'point_cloud' in k]:
        obs_dict[k] = obs_dict[k].astype(np.float32)

    for camera_name in cameras:
        obs_dict['%s_camera_extrinsics' % camera_name] = obs.misc['%s_camera_extrinsics' % camera_name]
        obs_dict['%s_camera_intrinsics' % camera_name] = obs.misc['%s_camera_intrinsics' % camera_name]

    # add timestep to low_dim_state, mapped linearly from t=0 -> 1 to t=last -> -1
    horizon = float(episode_length - 1)
    if horizon == 0:
        horizon = 1.0  # single-step episode: avoid dividing by zero
    time_feat = (1. - (t / horizon)) * 2. - 1.
    obs_dict['low_dim_state'] = np.concatenate(
        [obs_dict['low_dim_state'], [time_feat]]).astype(np.float32)

    obs.gripper_matrix = grip_mat
    obs.joint_positions = joint_pos
    obs.gripper_pose = grip_pose

    if vggt_model is not None:
        MAX_LEN = 300  # number of keypoint slots per view
        action_gripper_pose = torch.from_numpy(obs.gripper_pose).unsqueeze(0)  # (1, 7)

        # NOTE(review): this loop uses the module-level CAMERAS, not the
        # `cameras` argument - confirm this is intended.
        cam_obs, pcds = [], []
        for n in CAMERAS:
            rgb = stack_on_channel(torch.from_numpy(getattr(obs, f"{n}_rgb").transpose(2, 0, 1)).unsqueeze(0).unsqueeze(0))
            pcd = stack_on_channel(torch.from_numpy(getattr(obs, f"{n}_point_cloud").transpose(2, 0, 1)).unsqueeze(0).unsqueeze(0))
            rgb = _norm_rgb(rgb)
            cam_obs.append([rgb, pcd])
            pcds.append(pcd)
        pc, img_feat = rvt_utils.get_pc_img_feat(
            cam_obs,
            pcds,
        )

        action_trans_con, action_rot, pc = apply_se3_aug_con(
            pcd=pc,
            action_gripper_pose=action_gripper_pose,
            bounds=torch.tensor(SCENE_BOUNDS),
            trans_aug_range=torch.tensor([0.125, 0.125, 0.125]),
            rot_aug_range=torch.tensor([0.0, 0.0, 45.0]),
        )
        action_trans_con = torch.tensor(action_trans_con)  # [1, 3]
        pc, img_feat = rvt_utils.move_pc_in_bound(pc, img_feat, SCENE_BOUNDS, no_op=False)  # len = 1
        wpt = [x[:3] for x in action_trans_con]
        wpt_local = []
        for _pc, _wpt in zip(pc, wpt):
            wpt_in_cube, _ = mvt_utils.place_pc_in_cube(_pc, _wpt, with_mean_or_bounds=False, scene_bounds=SCENE_BOUNDS)
            wpt_local.append(wpt_in_cube.unsqueeze(0))
        wpt_local = torch.cat(wpt_local, axis=0)    # [1, 3]
        pc = [mvt_utils.place_pc_in_cube(_pc, with_mean_or_bounds=False, scene_bounds=SCENE_BOUNDS)[0] for _pc in pc]
        # NOTE(review): a new renderer is built on every call; consider reusing one.
        renderer = BoxRenderer(
            device=device,
            img_size=(224, 224),
            three_views=True,
            with_depth=True,
        )
        # Stage 1 renders the whole scene; stage 2 renders the cloud transformed
        # by trans_pc around a noisy copy of the waypoint.
        img = render(pc=pc, img_feat=img_feat, img_aug=0, mvt1_or_mvt2=True, renderer=renderer)
        wpt_local_stage_one_noisy = mvt_utils.add_uni_noi(wpt_local.clone().detach(), 2 * 0.05)
        pc_st2, rev_trans = mvt_utils.trans_pc(pc, loc=wpt_local_stage_one_noisy, sca=4)
        img_st2 = render(pc=pc_st2, img_feat=img_feat, img_aug=0, mvt1_or_mvt2=False, renderer=renderer)

        # Channels 3:6 of the rendered views are fed to VGGT.
        rgb_vggt_1 = preprocess_tensor_images(img[:, :, 3:6]).to(device)
        rgb_vggt_2 = preprocess_tensor_images(img_st2[:, :, 3:6]).to(device)

        stage_features = {}
        for suffix, rgb_vggt in (("", rgb_vggt_1), ("_st2", rgb_vggt_2)):
            features, image_shape, images, aggregated_tokens_list, ps_idx = extract_vggt_features(
                rgb_vggt, vggt_model, device, return_attn=False)
            keypoints = sample_keypoints(
                features, image_shape, images, aggregated_tokens_list, ps_idx,
                vggt_model, device, num_keypoints=MAX_LEN, min_distance=5)
            # Only the three per-view keypoint sets are stored.
            for view, kp in enumerate(keypoints[:3], start=1):
                obs_dict[f"kp_{view}{suffix}"] = _pad_first_keypoints(_vggt_to_numpy(kp), MAX_LEN)
            stage_features[suffix] = features

        # Resize the dense maps (down to 128x128) before storing them.
        for suffix, features in stage_features.items():
            processed = _downsample_vggt_features(features, target_size=128)
            obs_dict[f"vggt_features{suffix}"] = {k: _vggt_to_numpy(v) for k, v in processed.items()}

    return obs_dict


def render(pc, img_feat, img_aug, mvt1_or_mvt2, renderer):
    """Render point clouds to multi-view images and append pixel-location channels.

    Args:
        pc: iterable of point clouds, one per batch sample.
        img_feat: iterable of per-point features, aligned with ``pc``.
        img_aug: noise magnitude; ``0`` disables the augmentation.
        mvt1_or_mvt2: must be a ``bool``; it is only type-checked here.
        renderer: callable renderer exposing ``renderer.renderer.device``. Its
            output per sample is treated as ``(num_views, H, W, C)``.

    Returns:
        Tensor of shape ``(bs, num_views, C + 3, H, W)``. The last three channels
        are the view index, row and column coordinates, each spanning
        ``[-1, 1]``. The view count and image size are taken from the rendered
        output rather than assumed to be 3 views of 224x224.
    """
    assert isinstance(mvt1_or_mvt2, bool)
    device = renderer.renderer.device

    views = []
    with torch.no_grad(), autocast(enabled=False):
        for _pc, _img_feat in zip(pc, img_feat):
            _pc = _pc.to(device)
            _img_feat = _img_feat.to(device)
            # An empty cloud has no maximum; 1.0 keeps the division well-defined.
            max_pc = 1.0 if len(_pc) == 0 else torch.max(torch.abs(_pc))
            rendered = renderer(
                _pc,
                torch.cat((_pc / max_pc, _img_feat), dim=-1),
                fix_cam=True,
                dyn_cam_info=None,
            )
            views.append(rendered.unsqueeze(0))

    # (bs, views, H, W, C) -> (bs, views, C, H, W)
    img = torch.cat(views, 0).permute(0, 1, 4, 2, 3)

    # image augmentation
    if img_aug != 0:
        stdv = img_aug * torch.rand(1, device=device)
        # values in [-stdv, stdv]
        noise = stdv * ((2 * torch.rand(*img.shape, device=device)) - 1)
        img = torch.clamp(img + noise, -1, 1)

    bs, num_img, _, height, width = img.shape
    pixel_loc = torch.zeros((num_img, 3, height, width), device=img.device)
    pixel_loc[:, 0] = torch.linspace(-1, 1, num_img, device=img.device).unsqueeze(-1).unsqueeze(-1)
    pixel_loc[:, 1] = torch.linspace(-1, 1, height, device=img.device).unsqueeze(0).unsqueeze(-1)
    pixel_loc[:, 2] = torch.linspace(-1, 1, width, device=img.device).unsqueeze(0).unsqueeze(0)

    return torch.cat(
        (img, pixel_loc.unsqueeze(0).repeat(bs, 1, 1, 1, 1)), dim=2
    )


def preprocess_tensor_images(tensor, mode="crop"):
    """Resize a batch of multi-view image tensors towards 518 pixels.

    The docstring being replaced says this follows the same logic as
    ``load_and_preprocess_images``, applied to in-memory tensors.

    Args:
        tensor: 5-D tensor ``[B, S, C, H, W]`` (``S`` views per sample).
        mode: ``"crop"`` resizes the width to 518, rounds the height to a
            multiple of 14 and center-crops it to 518 if it is larger;
            ``"pad"`` scales the longer side to 518 and pads the rest with 1.0.

    Returns:
        Tensor ``[B, S, C, H', 518]``. ``H'`` is 518 except in ``"crop"`` mode
        when the scaled height is smaller than 518, which is left unpadded.

    Raises:
        ValueError: if ``mode`` is unknown or ``tensor`` is not 5-D.
    """
    if mode not in ("crop", "pad"):
        raise ValueError("Mode must be either 'crop' or 'pad'")
    if tensor.dim() != 5:
        raise ValueError(
            f"expected a 5-D tensor [B, S, C, H, W], got shape {tuple(tensor.shape)}"
        )

    target_size = 518
    num_img = tensor.shape[1]
    height, width = tensor.shape[-2:]

    # The target geometry is identical for every view, so compute it once.
    if mode == "crop":
        new_width = target_size
        new_height = round(height * (new_width / width) / 14) * 14
    else:
        scale = target_size / max(height, width)
        new_height, new_width = int(height * scale), int(width * scale)
        pad_h = target_size - new_height
        pad_w = target_size - new_width

    resized_images = []
    for img_index in range(num_img):
        view = tensor[:, img_index]  # [B, C, H, W]
        if mode == "crop":
            view = TF.resize(view, [new_height, new_width], interpolation=InterpolationMode.BICUBIC, antialias=True)
            if new_height > target_size:
                # center-crop the height down to the target size
                start_y = (new_height - target_size) // 2
                view = view[:, :, start_y : start_y + target_size, :]
        else:
            view = TF.resize(view, [new_height, new_width])
            # padding order is [left, top, right, bottom]
            view = TF.pad(view, [pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=1.0)
        resized_images.append(view)

    return torch.stack(resized_images, dim=1)


def extract_vggt_features(rgb_vggt, model, device, return_attn=False):
    """Run the VGGT aggregator and heads on ``rgb_vggt`` and collect per-view outputs.

    Args:
        rgb_vggt: image tensor whose dimension 1 is the view axis and whose last
            two dimensions are the image height and width.
        model: object exposing ``aggregator``, ``camera_head``, ``depth_head``
            and ``point_head``.
        device: device the unprojected point maps are moved to.
        return_attn: when True the aggregator is expected to also return an
            attention tensor, and per-view ``cost_<i>`` entries are added.

    Returns:
        ``(results, aggregated_tokens_list)`` when ``return_attn`` is True,
        otherwise ``(results, image_shape, images, aggregated_tokens_list, ps_idx)``.
        ``results`` holds, for every view ``i`` (1-based), ``point_map_view_i``,
        ``point_conf_view_i``, ``extrinsic_i``, ``intrinsic_i`` and
        ``depth_pred_i``, plus ``image_shape``, ``images`` and ``ps_idx``.
    """
    DEBUG = False  # set to True to print per-stage wall-clock timings
    num_view = rgb_vggt.shape[1]
    # assert num_view == 3, f"shape of rgb_vggt is {rgb_vggt.shape}"
    bf16_capable = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    vggt_dtype = torch.bfloat16 if bf16_capable else torch.float16

    with torch.no_grad():
        images = rgb_vggt
        start = time.time()
        # Only the aggregator runs under autocast; the heads run after the block.
        with torch.cuda.amp.autocast(dtype=vggt_dtype):
            aggregator_out = model.aggregator(images)
            if return_attn:
                # attn: global attention weights; (B*S, num_heads, P, P) per the original note
                aggregated_tokens_list, ps_idx, attn = aggregator_out
            else:
                aggregated_tokens_list, ps_idx = aggregator_out
        if DEBUG:
            print(f"aggregator 耗时: {time.time() - start:.4f} 秒")

        # Cameras: extrinsic and intrinsic matrices, following OpenCV convention
        # (camera from world).
        start = time.time()
        pose_enc = model.camera_head(aggregated_tokens_list)[-1]
        extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
        if DEBUG:
            print(f"extrinsic 耗时: {time.time() - start:.4f} 秒")

        # Depth maps.
        start = time.time()
        depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)
        if DEBUG:
            print(f"depth_map 耗时: {time.time() - start:.4f} 秒")

        # Point maps (only the confidence is used below).
        start = time.time()
        point_map, point_conf = model.point_head(aggregated_tokens_list, images, ps_idx)
        if DEBUG:
            print(f"point_map 耗时: {time.time() - start:.4f} 秒")

        # 3D points from depth maps and cameras, which usually leads to more
        # accurate 3D points than the point map branch.
        start = time.time()
        unprojected = [
            torch.from_numpy(
                unproject_depth_map_to_point_map(
                    depth_map[b].cpu().numpy(),  # (V, 518, 518, 1)
                    extrinsic[b].cpu().numpy(),  # (V, 3, 4)
                    intrinsic[b].cpu().numpy(),
                )
            ).float()
            for b in range(depth_map.size(0))
        ]
        point_map_by_unprojection = torch.stack(unprojected)  # (B, V, 518, 518, 3)
        if DEBUG:
            print(f"point_map unprojection 耗时: {time.time() - start:.4f} 秒")

        image_shape = tuple(rgb_vggt.shape[-2:])

        # Per-view entries, keyed with 1-based view numbers.
        results = {}
        for idx in range(num_view):
            tag = idx + 1
            results[f"point_map_view_{tag}"] = (
                point_map_by_unprojection[:, idx, ...].detach().clone().to(device)
            )
            results[f"point_conf_view_{tag}"] = point_conf[:, idx]
            results[f"extrinsic_{tag}"] = extrinsic[:, idx]
            results[f"intrinsic_{tag}"] = intrinsic[:, idx]
            results[f"depth_pred_{tag}"] = depth_map[:, idx].squeeze(-1)

        results["image_shape"] = image_shape
        results["images"] = images
        results["ps_idx"] = ps_idx

        if not return_attn:
            return (
                results,
                tuple(rgb_vggt.shape[-2:]),
                images,
                aggregated_tokens_list,
                ps_idx,
            )

        # Attention: split along dim 0 into one chunk per view, then average
        # over dim 1.
        for idx in range(num_view):
            view_attn = attn.chunk(num_view, dim=0)[idx]
            results[f"cost_{idx + 1}"] = view_attn.mean(dim=1)

        return results, aggregated_tokens_list


def _keypoints_within_margin(keypoints, image_shape, margin=3):
    """Boolean mask of ``(x, y)`` keypoints lying at least ``margin`` pixels inside the image."""
    mh, mw = image_shape
    x = keypoints[:, :, 0]
    y = keypoints[:, :, 1]
    return (
        (x >= margin)
        & (x < int(mw) - margin)
        & (y >= margin)
        & (y < int(mh) - margin)
    )


def sample_keypoints(
    vggt_features,
    image_shape,
    images,
    aggregated_tokens_list,
    ps_idx,
    model,
    device,
    num_keypoints=300,
    min_distance=5,
):
    """Sample keypoints in view 1 and track them into views 2 and 3.

    Keypoints are local maxima of ``point_conf_view_1`` inside the first
    co-visibility mask (see ``sample_keypoints_nms``); ``model.track_head``
    then gives their locations in the other two views.

    Args:
        vggt_features: feature dict produced by ``extract_vggt_features``.
        image_shape: ``(H, W)`` of the VGGT input images.
        images, aggregated_tokens_list, ps_idx: forwarded to ``model.track_head``.
        model: object exposing ``track_head``.
        device: device of the returned keypoint tensors.
        num_keypoints: number of keypoint slots per view.
        min_distance: NMS radius in pixels.

    Returns:
        ``(kp_1, kp_2, kp_3, valid_kp, mask_1, mask_2, mask_3)``. Each ``kp_i``
        is ``(B, num_keypoints, 2)`` in ``(x, y)`` order; the keypoints that are
        valid in all three views are packed at the front and the remaining
        slots are NaN. If no keypoints are available, seven ``None`` values are
        returned so the arity matches the normal path.
    """
    point_conf_view_1 = vggt_features["point_conf_view_1"]

    mask_1, mask_2, mask_3 = get_coview_masks(vggt_features, image_shape)  # (B, H, W)

    sampled_kp_1 = sample_keypoints_nms(
        mask_1,
        point_conf_view_1,
        N=num_keypoints,
        min_distance=min_distance,
        device=device,
    )

    if sampled_kp_1 is None:
        print("No keypoints found in the first view.")
        return None, None, None, None, None, None, None
    sampled_kp_1 = sampled_kp_1[:, :, [1, 0]].int()  # (row, col) -> (x, y)

    sampled_kp_o, vis_score, conf_score = model.track_head(
        aggregated_tokens_list, images, ps_idx, query_points=sampled_kp_1
    )

    # The last element of the track output is used; indices 1 and 2 along
    # dimension 1 select the second and third view.
    sampled_kp_2 = sampled_kp_o[-1][:, 1].int()  # (x, y)
    sampled_kp_3 = sampled_kp_o[-1][:, 2].int()

    # A keypoint is kept only if it is inside the margin in all three views.
    valid_kp = (
        _keypoints_within_margin(sampled_kp_1, image_shape)
        & _keypoints_within_margin(sampled_kp_2, image_shape)
        & _keypoints_within_margin(sampled_kp_3, image_shape)
    )  # (B, num_keypoints)

    bs = valid_kp.shape[0]
    kp_1 = torch.full((bs, num_keypoints, 2), float('nan'), device=device)
    kp_2 = torch.full((bs, num_keypoints, 2), float('nan'), device=device)
    kp_3 = torch.full((bs, num_keypoints, 2), float('nan'), device=device)

    for b in range(bs):
        mask_b = valid_kp[b]
        num_valid = mask_b.sum().item()
        if num_valid > 0:
            # Pack the valid keypoints at the front; the tail stays NaN.
            kp_1[b, :num_valid] = sampled_kp_1[b][mask_b]
            kp_2[b, :num_valid] = sampled_kp_2[b][mask_b]
            kp_3[b, :num_valid] = sampled_kp_3[b][mask_b]

    return kp_1, kp_2, kp_3, valid_kp, mask_1, mask_2, mask_3


def get_coview_masks(vggt_features, image_shape, num_view=3):
    """Compute co-visibility masks between consecutive views.

    View ``i`` is paired with view ``(i + 1) % num_view``; for the default three
    views this gives (1 -> 2), (2 -> 3), (3 -> 1).

    Args:
        vggt_features: dict with keys ``point_map_view_<i>``, ``extrinsic_<i>``
            and ``intrinsic_<i>`` for ``i`` in ``1..num_view``.
        image_shape: ``(H, W)`` of the images the points are projected into.
        num_view: number of views to pair up.

    Returns:
        Tuple of ``num_view`` boolean masks, one per source view, marking the
        source pixels whose 3D point projects inside the paired view.

    Raises:
        ValueError: if ``num_view`` is smaller than 1.
    """
    if num_view < 1:
        raise ValueError("num_view must be at least 1")

    view_ids = range(1, num_view + 1)
    point_maps = [vggt_features[f"point_map_view_{i}"] for i in view_ids]  # (B, H, W, 3)
    extrinsics = [vggt_features[f"extrinsic_{i}"] for i in view_ids]  # (B, 3, 4)
    intrinsics = [vggt_features[f"intrinsic_{i}"] for i in view_ids]  # (B, 3, 3)

    # NOTE(review): every view is converted with the FIRST view's extrinsic.
    # Presumably the point maps are already expressed in view 1's frame -
    # confirm against unproject_depth_map_to_point_map.
    reference_extrinsic = extrinsics[0]
    world_point_maps = [
        convert_camera_to_world(point_map, reference_extrinsic)
        for point_map in point_maps
    ]

    # Projection matrix per view: intrinsic (3x3) @ extrinsic (3x4) -> (B, 3, 4).
    projections = [torch.bmm(K, E) for K, E in zip(intrinsics, extrinsics)]

    masks = []
    for src in range(num_view):
        dst = (src + 1) % num_view
        masks.append(
            get_coview_mask(world_point_maps[src], projections[dst], image_shape)
        )  # (B, H, W)

    return tuple(masks)


def get_coview_mask(point_map, P, image_shape):
    """Mark the points of ``point_map`` that project inside the target image.

    Args:
        point_map: (B, H, W, 3) 3D points.
        P: (B, 3, 4) projection matrix (intrinsic @ extrinsic).
        image_shape: (H_img, W_img) of the target image.
    Returns:
        mask: (B, H, W) boolean tensor, True where 0 <= u < W_img and
        0 <= v < H_img.
    """
    img_h, img_w = image_shape
    # The unpacking enforces a rank-4 input; the values themselves are not needed.
    _batch, _rows, _cols, _ = point_map.shape

    pixels = compute_projection(P, point_map)  # (B, H, W, 2)
    u, v = pixels[..., 0], pixels[..., 1]

    inside_width = (u >= 0) & (u < img_w)
    inside_height = (v >= 0) & (v < img_h)
    return inside_width & inside_height


def sample_keypoints_nms(mask, conf, N, min_distance, device=None):
    """Sample keypoints with non-maximum suppression (NMS).

    Args:
        mask (torch.Tensor): boolean tensor (B, H, W) marking the valid region.
        conf (torch.Tensor): confidence map (B, H, W); any floating dtype.
        N (int): number of keypoints to return per batch element.
        min_distance (int): minimum distance between keypoints, in pixels.
        device (torch.device, optional): compute device; defaults to
            ``mask.device``.

    Returns:
        torch.Tensor: (B, N, 2) int64 keypoints in (row, col) order.

    Steps:
        1. Build a score map holding the confidence inside the valid region and
           zero elsewhere.
        2. Max-pool it with a (2 * min_distance + 1) window.
        3. Keep valid pixels whose score equals the pooled score (local maxima).
        4. Per batch element: M > N candidates are randomly subsampled, M == 0
           yields zeros, and M < N candidates are tiled so that every candidate
           appears at least once before any is repeated.
    """
    if device is None:
        device = mask.device
    B, H, W = mask.shape

    score_map = torch.zeros_like(mask, dtype=torch.float32, device=device)
    # Cast explicitly: indexed assignment requires matching dtypes, and conf
    # may be half/bfloat16.
    score_map[mask] = conf[mask].to(device=device, dtype=torch.float32)

    kernel_size = int(min_distance) * 2 + 1
    pad = kernel_size // 2

    pooled = F.max_pool2d(
        score_map.unsqueeze(1), kernel_size=kernel_size, stride=1, padding=pad
    ).squeeze(1)

    eps = 1e-6
    nms_mask = ((score_map - pooled).abs() < eps) & mask

    keypoints_list = []
    for b in range(B):
        keypoints = torch.nonzero(nms_mask[b], as_tuple=False)  # (M, 2)
        M = keypoints.shape[0]
        if M == 0:
            sampled_keypoints = torch.zeros((N, 2), device=device, dtype=torch.int64)
        elif M > N:
            perm = torch.randperm(M, device=device)[:N]
            sampled_keypoints = keypoints[perm]
        else:
            # Too few candidates: tile the whole set (k0, k1, ..., k0, k1, ...)
            # and truncate. Interleaved repetition (k0, k0, k1, k1, ...) would
            # drop the trailing candidates after truncation.
            repeat_times = (N + M - 1) // M
            sampled_keypoints = keypoints.repeat(repeat_times, 1)[:N]
        keypoints_list.append(sampled_keypoints)
    return torch.stack(keypoints_list)  # (B, N, 2)


def convert_camera_to_world(point_map, extrinsic):
    """Map camera-frame points to the world frame: ``x_w = R^-1 (x_c - t)``.

    Args:
        point_map: (B, H, W, 3) points in the camera frame; may be
            non-contiguous.
        extrinsic: (B, 3, 4) - [R | t], camera-from-world.
    Returns:
        world_points: (B, H, W, 3)
    """
    R = extrinsic[:, :, :3]  # (B, 3, 3)
    t = extrinsic[:, :, 3].unsqueeze(1)  # (B, 1, 3)
    R_inv = torch.inverse(R)  # (B, 3, 3)

    # Flatten the pixel grid for the batched matmul. reshape (not view) so
    # sliced or permuted inputs are accepted as well.
    B, H, W, _ = point_map.shape
    points_flat = point_map.reshape(B, -1, 3)  # (B, H*W, 3)

    # (B, 3, 3) @ (B, 3, H*W) -> (B, 3, H*W) -> (B, H*W, 3)
    transformed = torch.bmm(R_inv, (points_flat - t).transpose(1, 2)).transpose(1, 2)

    return transformed.reshape(B, H, W, 3)


def compute_projection(P, points_3d):
    """Project 3D points to pixel coordinates with a batched projection matrix.

    Args:
        P: (B, 3, 4) torch tensor, projection matrix.
        points_3d: (B, ..., 3) tensor of 3D world points; may be non-contiguous.

    Returns:
        proj_points: (B, ..., 2) tensor of 2D pixel coordinates. A small
        epsilon (1e-8) is added to the homogeneous depth before dividing.
    """
    B = P.shape[0]
    orig_shape = points_3d.shape[:-1]
    # reshape (not view) so sliced or permuted inputs are accepted as well.
    points_flat = points_3d.reshape(B, -1, 3)  # (B, N, 3)
    # new_ones keeps the dtype and device of the points.
    ones = points_flat.new_ones((B, points_flat.shape[1], 1))
    points_h = torch.cat([points_flat, ones], dim=-1)  # (B, N, 4)

    # Batch matrix multiplication: (B, 3, 4) @ (B, 4, N) -> (B, 3, N) -> (B, N, 3)
    proj_h = torch.bmm(P, points_h.transpose(1, 2)).transpose(1, 2)

    # Perspective divide.
    proj_points = proj_h[..., :2] / (proj_h[..., 2:3] + 1e-8)

    # Restore the original leading dimensions.
    return proj_points.reshape(*orig_shape, 2)


# extract CLIP language features for goal string
def _clip_encode_text(clip_model, text):
    """Encode tokenized text with a CLIP-style text tower.

    Returns ``(pooled, emb)``: ``emb`` is a copy of the per-token output after
    the final layer norm, and ``pooled`` is the feature at each row's
    ``text.argmax(dim=-1)`` position multiplied by
    ``clip_model.text_projection``.
    """
    dtype = clip_model.dtype

    # [batch_size, n_ctx, d_model]
    tokens = clip_model.token_embedding(text).type(dtype)
    tokens = tokens + clip_model.positional_embedding.type(dtype)

    # The transformer is fed sequence-first: NLD -> LND, then back to NLD.
    hidden = clip_model.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)
    hidden = clip_model.ln_final(hidden).type(dtype)

    emb = hidden.clone()

    batch_index = torch.arange(hidden.shape[0])
    pooled_position = text.argmax(dim=-1)
    pooled = hidden[batch_index, pooled_position] @ clip_model.text_projection

    return pooled, emb


# discretize translation, rotation, gripper open, and ignore collision actions
def _get_action(
    obs_tp1: Observation,
    obs_tm1: Observation,
    rlbench_scene_bounds: List[float],  # metric 3D bounds of the scene
    voxel_sizes: List[int],
    rotation_resolution: int,
    crop_augmentation: bool,
):
    """Build the discretized action for moving to the pose stored in ``obs_tp1``.

    Args:
        obs_tp1: observation whose ``gripper_pose`` / ``gripper_open`` define the
            target action.
        obs_tm1: observation whose ``ignore_collisions`` flag is used.
        rlbench_scene_bounds: ``[x_min, y_min, z_min, x_max, y_max, z_max]``.
        voxel_sizes: voxel resolution per voxelization level.
        rotation_resolution: degree increment used to discretize the rotation.
        crop_augmentation: unused.

    Returns:
        ``(trans_indicies, rot_and_grip_indicies, ignore_collisions,
        continuous_action, attention_coordinates)`` where ``continuous_action``
        is the gripper pose with the gripper-open value appended.
    """
    gripper_pose = obs_tp1.gripper_pose

    # Flip the sign so the last quaternion component is non-negative.
    quat = utils.normalize_quaternion(gripper_pose[3:])
    if quat[-1] < 0:
        quat = -quat
    disc_rot = utils.quaternion_to_discrete_euler(quat, rotation_resolution)

    bounds = np.array(rlbench_scene_bounds)
    ignore_collisions = int(obs_tm1.ignore_collisions)

    trans_indicies = []
    attention_coordinates = []
    # only single voxelization-level is used in PerAct
    for vox_size in voxel_sizes:
        index = utils.point_to_voxel_index(gripper_pose[:3], vox_size, bounds)
        trans_indicies.extend(index.tolist())
        voxel_extent = (bounds[3:] - bounds[:3]) / vox_size
        # Metric coordinate of the selected voxel index.
        attention_coordinates.append(bounds[:3] + voxel_extent * index)

    grip = float(obs_tp1.gripper_open)
    rot_and_grip_indicies = disc_rot.tolist()
    rot_and_grip_indicies.append(int(obs_tp1.gripper_open))

    continuous_action = np.concatenate([gripper_pose, np.array([grip])])
    return (
        trans_indicies,
        rot_and_grip_indicies,
        ignore_collisions,
        continuous_action,
        attention_coordinates,
    )