import os 
import cv2
import csv
import torch
import pickle
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
# from rvt.utils.vggt_utils import check_keypoint,timeit

# Index used to pick a single entry out of batched data.
# NOTE(review): in the visible part of this file it is only read by
# commented-out code - confirm it is still needed before removing it.
IMG_IDX = 0

# def plot_3d_comparison(ax, data_test, kp_3d_pred, view='overview'):
#     """3D点云可视化"""
#     # 获取ground truth点云（各视角先reshape为(N,3)再合并）
#     gt_points = []
#     for view_name in ['front', 'left_shoulder', 'right_shoulder', 'wrist']:
#         pc = data_test[f'{view_name}_point_cloud']
#         if isinstance(pc, torch.Tensor):
#             pc = pc.cpu().numpy()
#         if pc.ndim == 5:    # (B,1,3,H,W)
#             pc = pc[0,0]    # 取第一个样本，移除冗余维度
#         elif pc.ndim == 4:  # (B,3,H,W)
#             pc = pc[0]      # 取第一个样本
#         elif pc.ndim == 3:  # (3,W,H)
#             pass            
#         else:
#             raise ValueError(f"Unexpected point cloud shape: {pc.shape}")
            
#         pc = pc.transpose(1,2,0)  
#         # pc = data_test[f'{view_name}_point_cloud'].transpose(1,2,0)  # shape: (W,H,3)
#         gt_points.append(pc.reshape(-1,3))          # (16384, 3)
#     gt_points = np.concatenate(gt_points, axis=0)   # (65536, 3)
#     # 自动计算合理的坐标范围（保留10%边界）
#     x_min, x_max = np.percentile(gt_points[:,0], [5, 95])
#     y_min, y_max = np.percentile(gt_points[:,1], [5, 95]) 
#     z_min, z_max = np.percentile(gt_points[:,2], [5, 95])
#     bounds = [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6]
#     # 过滤无效点（全零点和NaN）
#     valid_mask = ~np.all(gt_points == 0, axis=1) & ~np.isnan(gt_points).any(axis=1)
#     gt_points = gt_points[valid_mask]
#     # 绘制ground truth点云（使用更高效的绘制方式）
#     ax.scatter(gt_points[:,0], gt_points[:,1], gt_points[:,2], 
#               c='gray', alpha=0.1, s=5, label='Scene PointCloud')

#     # 绘制关键点（过滤无效点）
#     colors = ['red', 'green', 'blue', 'yellow']
#     for i in range(1, 4):
#         kp = kp_3d_pred[f'kp_{i}']  # (300, 3)
#         if isinstance(kp, torch.Tensor):
#             kp = kp.cpu().numpy()
#         # 处理可能的batch维度
#         if kp.ndim == 3:  # (B,N,3)
#             kp = kp[0]    # 取第一个样本
#         valid_kp = kp[~np.all(kp == 0, axis=1)]  # 过滤全零点 
#         if len(valid_kp) > 0:
#             ax.scatter(valid_kp[0], valid_kp[1], valid_kp[2], #valid_kp[:,0], valid_kp[:,1], valid_kp[:,2],
#                        c=colors[i-1], s=30, marker='o', label=f'KP_{i} ({len(valid_kp)} pts)')
#     # 设置坐标轴
#     ax.set_xlabel('X')
#     ax.set_ylabel('Y') 
#     ax.set_zlabel('Z')
#     ax.set_xticks([])
#     ax.set_yticks([])
#     ax.set_zticks([])
#     ax.legend(loc='upper right', fontsize=8)
#     # 设置视角和范围
#     if view == 'front':
#         ax.view_init(elev=0, azim=0)  # 正前方视角
#     elif view == 'top_down':
#         ax.view_init(elev=90, azim=0)  # 正上方视角
#     else:
#         ax.view_init(elev=30, azim=60)
#     if bounds:
#         ax.set_xlim(bounds[0], bounds[3])
#         ax.set_ylim(bounds[1], bounds[4])
#         ax.set_zlim(bounds[2], bounds[5])

# def plot_error_heatmap(ax, data_test, kp_3d_pred, view_name):
#     """绘制投影误差热力图"""
#     def safe_squeeze(arr):
#         """安全压缩维度，保留至少2D"""
#         if isinstance(arr, torch.Tensor):
#             arr = arr.cpu().numpy()
#         while arr.ndim > 2:
#             arr = arr.squeeze(0)
#         return arr
#     # 获取原始关键点和预测投影
#     kp_idx = ['front', 'left_shoulder', 'right_shoulder', 'wrist'].index(view_name) + 1
#     original_kps = safe_squeeze(data_test[f'kp_{kp_idx}'])
#     scale = 128 / 518
#     # 收集误差数据
#     points = []
#     errors = []
#     for kp_id in range(1, 5):
#         if kp_id-1 == ['front', 'left_shoulder', 'right_shoulder', 'wrist'].index(view_name):
#             kp_world = safe_squeeze(kp_3d_pred[f'kp_{kp_id}'])
#             extrinsics = safe_squeeze(data_test[f'{view_name}_camera_extrinsics'])
#             intrinsics = safe_squeeze(data_test[f'{view_name}_camera_intrinsics'])
#             R = extrinsics[:3,:3]
#             t = extrinsics[:3,3]
#             # kp_cam = (kp_world - t) @ R.T
#             # kp_homo = np.column_stack([kp_cam, np.ones(len(kp_cam))])
#             # kp_pixel = kp_homo @ intrinsics.T
#             kp_cam = (kp_world - t) @ R
#             kp_pixel = kp_cam @ intrinsics.T
#             kp_pixel = kp_pixel[:,:2] / kp_pixel[:,2:]
#             if len(kp_pixel) == len(original_kps):
#                 error = np.linalg.norm(kp_pixel - original_kps * scale, axis=1)
#                 points.append(kp_pixel)
#                 errors.append(error)
#             else:
#                 print(f"Warning: Mismatched keypoints count ({len(kp_pixel)} vs {len(original_kps)})")

#     # 绘制误差图
#     if points and errors:
#         points = np.concatenate(points, axis=0)  # (N,2)
#         errors = np.concatenate(errors, axis=0)  # (N,)
        
#         # 检查维度一致性
#         assert len(points) == len(errors), \
#                f"Points ({len(points)}) and errors ({len(errors)}) must have same length"
#         # 方案1：带误差值的散点图
#         sc = ax.scatter(points[:,0], points[:,1], c=errors,
#                        cmap='Reds', s=30, alpha=0.7,
#                        vmin=0, vmax=20)  # 设置合理的误差范围
#         ax.set_xlabel('X pixel')
#         ax.set_ylabel('Y pixel')
#         ax.set_title('Reprojection Errors')
#         cb = plt.colorbar(sc, ax=ax)
#         cb.set_label('Error (pixels)')
#         # 方案2：误差直方图（备用）
#         # ax.hist(errors, bins=20, color='skyblue')
#         # ax.set_xlabel('Error (pixels)')
#         # ax.set_title('Error Distribution')
#     else:
#         ax.text(0.5, 0.5, 'No valid reprojections', 
#                ha='center', va='center', transform=ax.transAxes)


# # def plot_2d_projection(ax, data_test, img_match, view_name, overlay_type='rgb', kps=None):
# #     """
# #     在2D图像上可视化关键点投影，支持三种不同的背景类型：RGB图像、点云图(PT)和深度图
    
# #     参数:
# #         ax (matplotlib.axes.Axes): 绘图的目标坐标轴
# #         data_test (dict): 包含测试数据的字典，应包含点云图和深度图
# #         img_match (numpy.ndarray): 匹配的RGB图像，形状为(3, H, W)
# #         view_name (int): 视图名称（如1,2,3），用于从data_test中提取对应视图的数据
# #         overlay_type (str, optional): 背景类型，可选值：
# #             'rgb' - 使用RGB图像作为背景
# #             'PT' - 使用点云图作为背景
# #             其他值 - 使用深度图作为背景
# #             默认为 'rgb'
# #         kps (numpy.ndarray, optional): 关键点坐标数组，形状为(N, 2)。默认为None
    
# #     返回:
# #         None: 结果直接绘制在传入的坐标轴ax上
# #     """
    
    
# #     # 获取数据
# #     pt = data_test[f'point_map_view_{view_name}'][IMG_IDX].squeeze()     # (518, 518, 3)
# #     rgb = img_match[view_name-1]                                             # (3, 224, 224)
# #     depth = data_test[f'depth_{view_name}'][IMG_IDX].squeeze()      # (518, 518)

# #     # print(f"pt shape is {pt.shape}")
# #     # print(f"rgb shape is {rgb.shape}")
# #     # print(f"depth shape is {depth.shape}")

# #     # 分辨率调整
# #     original_size = pt.shape[IMG_IDX] # 518
# #     target_size= rgb.shape[1] # 224
# #     if original_size != target_size:
# #         resized_pt = F.interpolate(pt.permute(2, 0, 1).float(), size=(target_size, target_size), mode='bilinear', align_corners=False)
# #         pt = resized_pt.permute(1, 2, 0)  # CHW -> HWC
# #         resized_depth = F.interpolate(depth.unsqueeze(0), size=(target_size, target_size), mode='nearest')
# #         depth = resized_depth.squeeze(0)
    
# #     if isinstance(depth, torch.Tensor):
# #         depth = depth.cpu().numpy()
# #     if isinstance(rgb, torch.Tensor):
# #         rgb = rgb.cpu().numpy()
# #     if isinstance(pt, torch.Tensor):
# #         pt = pt.cpu().numpy()
# #     if isinstance(kps, torch.Tensor):
# #         kps = kps.cpu().numpy()
                
# #     # 如果需要将 RGB 从 224x224 缩放至 128x128
# #     # if rgb.shape[-2:] == (target_size, target_size):
# #     #     rgb_128 = cv2.resize(rgb.transpose(1, 2, 0), (128, 128), interpolation=cv2.INTER_AREA)
# #     #     rgb_128 = rgb_128.transpose(2, 0, 1)  # (3, 128, 128)
# #     # else:
# #     #     rgb_128 = rgb

# #     H, W = depth.shape

# #     # 点云归一化
# #     def normalize_point_cloud(pc):
# #         pc = pc.astype(np.float32)
# #         # 各通道独立归一化（保留颜色相对关系）
# #         for c in range(3):
# #             min_val = pc[..., c].min()
# #             max_val = pc[..., c].max()
# #             if max_val > min_val:
# #                 pc[..., c] = (pc[..., c] - min_val) / (max_val - min_val)
# #         return pc

# #     # check_keypoint(kps)
# #     # print(f"rgb shape is {rgb.shape}")

# #     # 显示逻辑
# #     if overlay_type == 'rgb':
# #         img = (rgb * 255).astype(np.uint8) if rgb.max() <= 1 else rgb.astype(np.uint8)
# #         kps = np.clip(kps, [0, 0], [W-1, H-1])
# #         ax.imshow(img.transpose(1, 2, 0))
# #         ax.scatter(kps[:, 0], kps[:, 1], c='yellow', s=40, marker='x', linewidth=1.5, label='Keypoints')
    
# #     elif overlay_type == 'PT':
# #         pt_normalized = normalize_point_cloud(pt)
# #         ax.imshow(pt_normalized)
# #         kps = np.clip(kps, [0, 0], [W-1, H-1])
# #         ax.scatter(kps[:, 0], kps[:, 1], c='yellow', s=40, marker='x', linewidth=1.5, label='Keypoints')
    
# #     else:  # 深度图显示
# #         depth_normalized = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
# #         ax.imshow(depth_normalized, cmap='viridis', vmin=0, vmax=1)
# #         kps = np.clip(kps, [0, 0], [W-1, H-1])
# #         ax.scatter(kps[:, 0], kps[:, 1], c='red', s=40, marker='x', linewidth=1.5, label='Keypoints')

# #     ax.set_xlim(0, W)
# #     ax.set_ylim(H, 0)
# #     ax.legend(loc='upper right', fontsize=8, framealpha=0.5)
# #     ax.set_title(f'View {view_name} - {overlay_type.upper()}', fontsize=10)
# #     ax.grid(False)
# #     ax.set_xticks([])
# #     ax.set_yticks([])
# #     """
# #     绘制2D投影对比图，包括原始关键点和预测关键点。
    
# #     :param ax: matplotlib的axes对象，用于绘图
# #     :param data_test: 测试数据字典，包含RGB、深度图、相机内外参等信息
# #     :param kp_3d_pred: 预测的3D关键点字典
# #     :param view_name: 视角名称（例如'front', 'left_shoulder'等）
# #     :param overlay_type: 可选'rgb'或'depth'，决定背景是RGB图像还是深度图
# #     :param original_kps: 原始2D关键点坐标列表，如果提供，则会画在图像上
# #     """
# #     def safe_transpose(arr):
# #         """安全转置函数（处理3D/4D输入）"""
# #         if isinstance(arr, torch.Tensor):
# #             arr = arr.cpu().numpy()
            
# #         if arr.ndim == 4:  # (B,C,H,W)
# #             arr = arr[0]    # 取第一个样本
# #         elif arr.ndim == 3 and arr.shape[0] == 3:  # (C,H,W)
# #             pass
# #         else:
# #             arr = arr.squeeze()
# #         return arr.transpose(1,2,0)  # (H,W,C)
# #     rgb = safe_transpose(data_test[f'{view_name}_rgb'])  # (H,W,3)
# #     depth = data_test[f'{view_name}_depth']
# #     if isinstance(depth, torch.Tensor):
# #         depth = depth.cpu().numpy()
# #     depth = depth.squeeze()  # (H,W)
# #     intrinsics = data_test[f'{view_name}_camera_intrinsics'].squeeze()
# #     extrinsics = data_test[f'{view_name}_camera_extrinsics'].squeeze()
# #     if isinstance(intrinsics, torch.Tensor):
# #         intrinsics = intrinsics.cpu().numpy()
# #     if isinstance(extrinsics, torch.Tensor):
# #         extrinsics = extrinsics.cpu().numpy()
# #     H, W = depth.shape

# #     # 显示背景（添加分辨率适配）
# #     if overlay_type == 'rgb':
# #         img = (rgb * 255).astype(np.uint8) if rgb.max() <= 1 else rgb.astype(np.uint8)
# #         if original_kps is not None:
# #             if isinstance(original_kps, torch.Tensor):
# #                 original_kps = original_kps.cpu().numpy()
# #             original_kps = np.clip(original_kps, [0, 0], [W-1, H-1])
# #             ax.scatter(original_kps[:,0], original_kps[:,1], 
# #                       c='blue', s=20, marker='x', label='Original KPs')
# #         ax.imshow(img)
# #     else:
# #         ax.imshow(depth, cmap='gray', vmin=0, vmax=2)
    
# #     colors = plt.get_cmap('tab10').colors
# #     markers = ['*', 's', 'D', 'o']
# #     # 投影关键点
# #     for kp_id in range(1, 5):
# #         kp_world = kp_3d_pred[f'kp_{kp_id}']
# #         if isinstance(kp_world, torch.Tensor):
# #             kp_world = kp_world.cpu().numpy()
# #         # 处理可能的batch维度
# #         if kp_world.ndim == 3:  # (B,N,3)
# #             kp_world = kp_world[0]  # (N,3)
# #         # 世界→相机坐标系 (使用外参的逆变换)
# #         R = extrinsics[:3,:3]
# #         t = extrinsics[:3,3]
# #         # kp_cam = (kp_world - t) @ R.T
# #         # kp_homo = np.column_stack([kp_cam, np.ones(len(kp_cam))])
# #         # kp_pixel = kp_homo @ intrinsics.T
# #         kp_cam = (kp_world - t) @ R  # 等价于 R.T @ (kp_world - t)
# #         # 相机→像素坐标
# #         kp_pixel = kp_cam @ intrinsics.T
# #         kp_pixel = kp_pixel[:,:2] / kp_pixel[:,2:]
# #         # 过滤有效点
# #         valid_mask = (
# #             (kp_pixel[:,0] >= 0) & (kp_pixel[:,0] < W) &
# #             (kp_pixel[:,1] >= 0) & (kp_pixel[:,1] < H) &
# #             (kp_cam[:,2] > 0)  # z值必须为正（在相机前方）
# #         )
# #         kp_pixel = kp_pixel[valid_mask]
# #         # 绘制     当前视角的关键点，使用特殊标记
# #         current_view = (kp_id-1 == ['front', 'left_shoulder', 'right_shoulder', 'wrist'].index(view_name))
# #         label = f'KP_{kp_id}' + (' (current)' if current_view else '')
        
# #         if len(kp_pixel) > 0:
# #             ax.scatter(kp_pixel[:,1], kp_pixel[:,0], c=[colors[kp_id-1]], 
# #                        s=50 if current_view else 30, marker=markers[kp_id-1], alpha=0.8,
# #                       edgecolors='white', linewidth=0.5, label=label)

# #     ax.set_xlim(0, W)
# #     ax.set_ylim(H, 0)
# #     ax.legend(loc='upper right', fontsize=8)
# #     ax.set_title(f'{view_name} Projection')
# def plot_2d_projection(ax, img_rgb, kps, title='View'):
#     """
#     在2D图像上直接绘制关键点（不需要相机参数投影）
    
#     参数:
#         ax: matplotlib坐标轴对象
#         img_rgb: RGB图像 (形状为 (3, H, W) 或 (H, W, 3))
#         kps: 关键点坐标数组, 形状为 (N, 2)
#         title: 图像标题
#     """
#     # 处理图像格式
#     if isinstance(img_rgb, torch.Tensor):
#         img_rgb = img_rgb.cpu().numpy()
    
#     # 确保图像为 (H, W, 3) 格式
#     if img_rgb.shape[0] == 3:
#         img_rgb = img_rgb.transpose(1, 2, 0)
    
#     # 归一化图像到 [0, 255]
#     if img_rgb.max() <= 1.0:
#         img_rgb = (img_rgb * 255).astype(np.uint8)
#     else:
#         img_rgb = img_rgb.astype(np.uint8)
    
#     # 处理关键点坐标
#     if isinstance(kps, torch.Tensor):
#         kps = kps.cpu().numpy()
    
#     H, W = img_rgb.shape[:2]
    
#     # 显示RGB图像
#     ax.imshow(img_rgb)
    
#     # 确保关键点在图像范围内
#     kps = np.clip(kps, [0, 0], [W-1, H-1])
    
#     # 绘制关键点（300个点）
#     ax.scatter(kps[:, 0], kps[:, 1], c='yellow', s=20, marker='o', 
#                edgecolors='black', linewidth=0.5, alpha=0.7, label='Keypoints')
    
#     # 设置坐标轴属性
#     ax.set_xlim(0, W)
#     ax.set_ylim(H, 0)  # 翻转y轴以匹配图像坐标系
#     ax.set_title(title)
#     ax.legend(loc='upper right', fontsize=8)
#     ax.set_xticks([])
#     ax.set_yticks([])


# # # @timeit
# # def visualize_comparison(data, rgb , save_dir="debug_runs/comparison_results"):
# #     """
# #     多模态比较预测关键点与原始点云
# #     包含：3D点云对比、2D投影对比、深度图叠加、误差分析
# #     """
# #     os.makedirs(save_dir, exist_ok=True)
# #     print("================================")
# #     print(f"save to dir {save_dir}!")
# #     print("================================")
    
    
    
# #     # ================= 1. 3D点云对比可视化 =================
# #     fig = plt.figure(figsize=(18, 10))
# #     kp_3d_pred = {
# #             "kp_1": batch["kp_1"],
# #             "kp_2": batch["kp_2"],
# #             "kp_3": batch["kp_3"]
# #         }
# #     # 3D视图1
# #     ax1 = fig.add_subplot(121, projection='3d')
# #     plot_3d_comparison(ax1, data, kp_3d_pred, view='overview')
# #     # 3D视图2（不同视角）
# #     ax2 = fig.add_subplot(122, projection='3d')
# #     plot_3d_comparison(ax2, data, kp_3d_pred, view='top_down')
# #     plt.tight_layout()
# #     plt.savefig(f"{save_dir}/3d_comparison.png", dpi=300)
# #     plt.close()
# #     # ================= 2. 各相机视角2D投影对比 =================
# #     view_mapping = {
# #         '1': 1, 
# #         '2': 2,
# #         '3': 3,

# #     }
# #     original_kp = {
# #         '1': data['kp_1'][IMG_IDX].squeeze(0) ,
# #         '2': data['kp_2'][IMG_IDX].squeeze(0) ,
# #         '3': data['kp_3'][IMG_IDX].squeeze(0) ,
   
# #     }    
# #     for view_name, view_idx in view_mapping.items():
# #         fig = plt.figure(figsize=(15, 6))
# #         # 点图
# #         ax1 = fig.add_subplot(131)
# #         plot_2d_projection(ax1, data, rgb , view_idx, overlay_type='PT', kps=original_kp[view_name])
# #         ax1.set_title(f'{view_name}  Pointmap')
# #         # 深度图
# #         ax2 = fig.add_subplot(132)
# #         plot_2d_projection(ax2, data, rgb ,view_idx, overlay_type='depth', kps=original_kp[view_name])
# #         ax2.set_title(f'{view_name}  Depth')
# #         # RGB图像
# #         ax3 = fig.add_subplot(133)
# #         plot_2d_projection(ax3, data, rgb , view_idx , overlay_type='rgb', kps=original_kp[view_name])
# #         ax3.set_title(f'{view_name}  RGB')
# #         plt.tight_layout()
# #         plt.savefig(f"{save_dir}/{view_name}.png", dpi=300)
# #         plt.close()
# def visualize_comparison(data, rgb_vggt_1, save_dir="debug_runs/comparison_results"):
#     """
#     在RGB图像上直接可视化关键点
#     """
#     os.makedirs(save_dir, exist_ok=True)
#     print(f"Save to directory: {save_dir}")
    
#     # 定义视图映射（根据你的实际视图配置调整）
#     view_mapping = {
#         '1': 0,  # 第一个视图
#         '2': 1,  # 第二个视图
#         '3': 2   # 第三个视图
#     }
    
#     # 获取每个视图的关键点（假设关键点存储在data中）
#     print(f"[DEBUG] data['kp_1'].shape = {data['kp_1'].shape}, data['kp_1'][0].shape = {data['kp_1'][0].shape}, \
#         data['kp_1'][0].squeeze(0).shape = {data['kp_1'][0].squeeze(0).shape}")
#     original_kp = {
#         '1': data['kp_1'][0].squeeze(0),  # 取第一个样本，并去除多余的维度
#         '2': data['kp_2'][0].squeeze(0),
#         '3': data['kp_3'][0].squeeze(0)
#     }
    
#     # 每个视图单独绘制
#     for view_name, view_idx in view_mapping.items():
#         # 获取当前视图的图像（从rgb_vggt_1中提取）
#         print(f"[DEBUG] rgb_vggt_1[view_idx].shape = {rgb_vggt_1[view_idx].shape}")
#         view_img = rgb_vggt_1[view_idx]  # 取出当前视图的图像
        
#         # 创建图像
#         fig, ax = plt.subplots(figsize=(8, 6))
        
#         # 调用修改后的函数
#         plot_2d_projection(
#             ax=ax,
#             img_rgb=view_img,
#             kps=original_kp[view_name],
#             title=f'View {view_name} Keypoints'
#         )
        
#         # 保存图像
#         plt.tight_layout()
#         plt.savefig(f"{save_dir}/view_{view_name}.png", dpi=200)
#         plt.close()

# def visualize_comparison_st2(data, rgb_vggt_2, save_dir="debug_runs/comparison_results"):
#     """
#     在RGB图像上直接可视化关键点
#     """
#     os.makedirs(save_dir, exist_ok=True)
#     print(f"Save to directory: {save_dir}")
    
#     # 定义视图映射（根据你的实际视图配置调整）
#     view_mapping = {
#         '1': 0,  # 第一个视图
#         '2': 1,  # 第二个视图
#         '3': 2   # 第三个视图
#     }
    
#     # 获取每个视图的关键点（假设关键点存储在data中）
#     original_kp = {
#         '1': data['kp_1_st2'][0].squeeze(0),  # 取第一个样本，并去除多余的维度
#         '2': data['kp_2_st2'][0].squeeze(0),
#         '3': data['kp_3_st2'][0].squeeze(0)
#     }
    
#     # 每个视图单独绘制
#     for view_name, view_idx in view_mapping.items():
#         # 获取当前视图的图像（从rgb_vggt_2中提取）
#         view_img = rgb_vggt_2[view_idx]  # 取出当前视图的图像
        
#         # 创建图像
#         fig, ax = plt.subplots(figsize=(8, 6))
        
#         # 调用修改后的函数
#         plot_2d_projection(
#             ax=ax,
#             img_rgb=view_img,
#             kps=original_kp[view_name],
#             title=f'View {view_name} Keypoints'
#         )
        
#         # 保存图像
#         plt.tight_layout()
#         plt.savefig(f"{save_dir}/view_{view_name}_st2.png", dpi=200)
#         plt.close()

def plot_3d_comparison(ax, data_test, kp_3d_pred, view='overview'):
    """Draw the merged scene point cloud and the 3D keypoints on a 3D axes.

    Args:
        ax: 3D Matplotlib axes to draw on.
        data_test: Mapping providing ``'<view>_point_cloud'`` for the views
            ``front``, ``left_shoulder``, ``right_shoulder`` and ``wrist``.
            Each entry (tensor or array) is treated as channel-first with
            shape (B, 1, 3, H, W), (B, 3, H, W) or (3, H, W); only the first
            sample of a batch is drawn.
        kp_3d_pred: Mapping providing ``'kp_1'`` .. ``'kp_3'``. Each entry is
            an (N, 3) array/tensor, optionally preceded by batch dimensions
            of which the first element is used.
        view: ``'front'`` or ``'top_down'`` select a fixed camera; any other
            value uses the default oblique camera.

    Raises:
        ValueError: If a point cloud has an unsupported number of dimensions.
    """
    def to_numpy(arr):
        # detach() keeps tensors that still track gradients convertible.
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return arr

    # Flatten every camera's point map to (H*W, 3) and merge them.
    clouds = []
    for view_name in ['front', 'left_shoulder', 'right_shoulder', 'wrist']:
        pc = to_numpy(data_test[f'{view_name}_point_cloud'])
        if pc.ndim == 5:      # (B, 1, 3, H, W): first sample, drop singleton axis
            pc = pc[0, 0]
        elif pc.ndim == 4:    # (B, 3, H, W): first sample
            pc = pc[0]
        elif pc.ndim != 3:    # anything but (3, H, W) is unsupported
            raise ValueError(f"Unexpected point cloud shape: {pc.shape}")
        clouds.append(pc.transpose(1, 2, 0).reshape(-1, 3))
    gt_points = np.concatenate(clouds, axis=0)

    # Drop invalid points: all-zero rows and rows containing NaN.
    valid_mask = ~np.all(gt_points == 0, axis=1) & ~np.isnan(gt_points).any(axis=1)
    gt_points = gt_points[valid_mask]
    ax.scatter(gt_points[:, 0], gt_points[:, 1], gt_points[:, 2],
               c='gray', alpha=0.1, s=5, label='Scene PointCloud')

    # Keypoint sets, one colour per set; all-zero rows are treated as padding.
    colors = ['red', 'green', 'blue']
    for i, color in enumerate(colors, start=1):
        kp = to_numpy(kp_3d_pred[f'kp_{i}'])
        while kp.ndim > 2:    # strip leading batch dimensions -> (N, 3)
            kp = kp[0]
        valid_kp = kp[~np.all(kp == 0, axis=1)]
        if len(valid_kp) > 0:
            # Index the coordinate COLUMNS; indexing rows 0/1/2 would plot
            # three unrelated points and fail with fewer than three keypoints.
            ax.scatter(valid_kp[:, 0], valid_kp[:, 1], valid_kp[:, 2],
                       c=color, s=30, marker='o',
                       label=f'KP_{i} ({len(valid_kp)} pts)')

    # Axis cosmetics.
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.legend(loc='upper right', fontsize=8)

    # Camera placement.
    if view == 'front':
        ax.view_init(elev=0, azim=0)
    elif view == 'top_down':
        ax.view_init(elev=90, azim=0)
    else:
        ax.view_init(elev=30, azim=60)

    # Fixed display volume (hard-coded limits for the x, y and z axes).
    ax.set_xlim(-0.3, 0.7)
    ax.set_ylim(-0.5, 0.5)
    ax.set_zlim(0.6, 1.6)



def plot_2d_projection(ax, data_test, kp_3d_pred, rgb_vggt, overlay_type='rgb',
                       original_kps=None, view_name='front'):
    """Project the 3D keypoints into one camera view and draw them in 2D.

    Args:
        ax: Matplotlib axes to draw on.
        data_test: Mapping providing ``'<view_name>_depth'``,
            ``'<view_name>_camera_intrinsics'`` and
            ``'<view_name>_camera_extrinsics'`` (tensors or arrays).
        kp_3d_pred: Mapping with 3D keypoints under ``'kp_1'`` .. ``'kp_4'``;
            each entry is (N, 3) or (B, N, 3) (first sample used). Missing
            entries are skipped.
        rgb_vggt: Background image, (B, C, H, W) or (C, H, W). Only read when
            ``overlay_type == 'rgb'``.
        overlay_type: ``'rgb'`` draws the RGB image as background; any other
            value draws the depth map.
        original_kps: Optional (N, 2) reference keypoints, drawn only on the
            RGB background after clipping to the image bounds.
        view_name: Camera view whose depth/calibration entries are read from
            ``data_test``. The body always used this name but it was missing
            from the signature (NameError); it is now a trailing keyword
            argument so existing positional calls keep working.
    """
    view_order = ['front', 'left_shoulder', 'right_shoulder', 'wrist']

    def to_numpy(arr):
        # detach() keeps tensors that still track gradients convertible.
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return arr

    def safe_transpose(arr):
        """Return the image as (H, W, C), taking the first sample of a batch."""
        arr = to_numpy(arr)
        if arr.ndim == 4:                              # (B, C, H, W)
            arr = arr[0]
        elif not (arr.ndim == 3 and arr.shape[0] == 3):
            arr = arr.squeeze()
        return arr.transpose(1, 2, 0)

    depth = to_numpy(data_test[f'{view_name}_depth']).squeeze()            # (H, W)
    intrinsics = to_numpy(data_test[f'{view_name}_camera_intrinsics']).squeeze()
    extrinsics = to_numpy(data_test[f'{view_name}_camera_extrinsics']).squeeze()
    H, W = depth.shape

    # Background layer.
    if overlay_type == 'rgb':
        # Converted lazily so depth-only calls do not need a valid image.
        rgb = safe_transpose(rgb_vggt)
        img = (rgb * 255).astype(np.uint8) if rgb.max() <= 1 else rgb.astype(np.uint8)
        if original_kps is not None:
            original_kps = np.clip(to_numpy(original_kps), [0, 0], [W - 1, H - 1])
            ax.scatter(original_kps[:, 0], original_kps[:, 1],
                       c='blue', s=20, marker='x', label='Original KPs')
        ax.imshow(img)
    else:
        ax.imshow(depth, cmap='gray', vmin=0, vmax=2)

    colors = plt.get_cmap('tab10').colors
    markers = ['*', 's', 'D', 'o']

    # The calibration does not change per keypoint set, so read it once.
    R = extrinsics[:3, :3]
    t = extrinsics[:3, 3]
    # 1-based id of the keypoint set belonging to this view (None if the view
    # is not one of the four known names).
    current_kp_id = view_order.index(view_name) + 1 if view_name in view_order else None

    for kp_id in range(1, 5):
        key = f'kp_{kp_id}'
        if key not in kp_3d_pred:
            # Callers may provide fewer than four keypoint sets.
            continue
        kp_world = to_numpy(kp_3d_pred[key])
        if kp_world.ndim == 3:        # (B, N, 3) -> first sample
            kp_world = kp_world[0]

        # World -> camera: row-vector form of R.T @ (p - t), i.e. the inverse
        # of p_world = R @ p_cam + t.
        kp_cam = (kp_world - t) @ R
        # Camera -> pixel, followed by the perspective divide.
        kp_pixel = kp_cam @ intrinsics.T
        kp_pixel = kp_pixel[:, :2] / kp_pixel[:, 2:]

        # Keep points inside the image and in front of the camera.
        valid_mask = (
            (kp_pixel[:, 0] >= 0) & (kp_pixel[:, 0] < W) &
            (kp_pixel[:, 1] >= 0) & (kp_pixel[:, 1] < H) &
            (kp_cam[:, 2] > 0)
        )
        kp_pixel = kp_pixel[valid_mask]

        current_view = (kp_id == current_kp_id)
        label = f'KP_{kp_id}' + (' (current)' if current_view else '')

        if len(kp_pixel) > 0:
            # NOTE(review): columns are plotted as (1, 0) although the mask
            # above treats column 0 as x and column 1 as y - confirm whether
            # this swap is intentional.
            ax.scatter(kp_pixel[:, 1], kp_pixel[:, 0], c=[colors[kp_id - 1]],
                       s=50 if current_view else 30, marker=markers[kp_id - 1],
                       alpha=0.8, edgecolors='white', linewidth=0.5, label=label)

    ax.set_xlim(0, W)
    ax.set_ylim(H, 0)   # image convention: y grows downwards
    ax.legend(loc='upper right', fontsize=8)
    ax.set_title(f'{view_name} Projection')



def plot_error_heatmap(ax, data_test, kp_3d_pred, view_name):
    """Scatter the reprojected keypoints of one view, coloured by pixel error.

    The 3D keypoints belonging to ``view_name`` are projected with that view's
    camera parameters and compared with the 2D keypoints stored in
    ``data_test`` after multiplying the latter by 128 / 518 (presumably the
    ratio between the render and the keypoint resolution - confirm).

    Args:
        ax: Matplotlib axes to draw on.
        data_test: Mapping providing ``'kp_<i>'``,
            ``'<view_name>_camera_extrinsics'`` and
            ``'<view_name>_camera_intrinsics'``.
        kp_3d_pred: Mapping providing the 3D keypoints under ``'kp_<i>'``.
        view_name: One of ``'front'``, ``'left_shoulder'``,
            ``'right_shoulder'``, ``'wrist'``; its 1-based position selects
            ``<i>``.

    Raises:
        ValueError: If ``view_name`` is not one of the four known views.
    """
    def first_sample_2d(arr):
        """Convert to numpy and strip leading dimensions down to 2D."""
        if isinstance(arr, torch.Tensor):
            arr = arr.detach().cpu().numpy()
        # Indexing [0] equals squeeze(0) for singleton axes but, unlike
        # squeeze, also works for batches larger than one (first sample).
        while arr.ndim > 2:
            arr = arr[0]
        return arr

    kp_idx = ['front', 'left_shoulder', 'right_shoulder', 'wrist'].index(view_name) + 1
    scale = 128 / 518

    original_kps = first_sample_2d(data_test[f'kp_{kp_idx}'])
    kp_world = first_sample_2d(kp_3d_pred[f'kp_{kp_idx}'])
    extrinsics = first_sample_2d(data_test[f'{view_name}_camera_extrinsics'])
    intrinsics = first_sample_2d(data_test[f'{view_name}_camera_intrinsics'])

    # World -> camera (row-vector form of R.T @ (p - t)), then camera -> pixel
    # followed by the perspective divide.
    R = extrinsics[:3, :3]
    t = extrinsics[:3, 3]
    kp_cam = (kp_world - t) @ R
    kp_pixel = kp_cam @ intrinsics.T
    kp_pixel = kp_pixel[:, :2] / kp_pixel[:, 2:]

    if len(kp_pixel) != len(original_kps):
        # Without a one-to-one correspondence no error can be computed.
        print(f"Warning: Mismatched keypoints count ({len(kp_pixel)} vs {len(original_kps)})")
        ax.text(0.5, 0.5, 'No valid reprojections',
                ha='center', va='center', transform=ax.transAxes)
        return

    errors = np.linalg.norm(kp_pixel - original_kps * scale, axis=1)

    # Scatter plot coloured by error; the colour scale is capped at 20 px.
    sc = ax.scatter(kp_pixel[:, 0], kp_pixel[:, 1], c=errors,
                    cmap='Reds', s=30, alpha=0.7,
                    vmin=0, vmax=20)
    ax.set_xlabel('X pixel')
    ax.set_ylabel('Y pixel')
    ax.set_title('Reprojection Errors')
    cb = plt.colorbar(sc, ax=ax)
    cb.set_label('Error (pixels)')

def plot_2d_projection_simple(ax, img_rgb, kps, title='View'):
    """Draw 2D keypoints directly on top of an RGB image.

    Args:
        ax: Matplotlib axes to draw on.
        img_rgb: Image as tensor or array, laid out (H, W, C) or (C, H, W).
            Images whose maximum is <= 1 are scaled by 255; values are then
            clipped to [0, 255] before the uint8 cast.
        kps: Keypoint pixel coordinates whose last dimension is 2 (x, y).
            Singleton dimensions are squeezed and any remaining leading
            dimensions are flattened to (N, 2).
        title: Axes title.

    Returns:
        None. When ``kps`` does not end in a dimension of size 2 a warning is
        printed and nothing is drawn.
    """
    # detach() keeps tensors that still track gradients convertible.
    if isinstance(img_rgb, torch.Tensor):
        img_rgb = img_rgb.detach().cpu().numpy()
    if isinstance(kps, torch.Tensor):
        kps = kps.detach().cpu().numpy()
    img_rgb = np.asarray(img_rgb)
    kps = np.asarray(kps)

    # Bring channel-first images to (H, W, C).
    if img_rgb.ndim == 3 and img_rgb.shape[0] == 3:
        img_rgb = img_rgb.transpose(1, 2, 0)

    # Scale [0, 1] images to [0, 255]; clip so out-of-range values do not wrap
    # around in the uint8 cast.
    if img_rgb.max() <= 1.0:
        img_rgb = img_rgb * 255
    img_rgb = np.clip(img_rgb, 0, 255).astype(np.uint8)

    if kps.ndim > 2:
        kps = kps.squeeze()
    if kps.ndim == 0 or kps.shape[-1] != 2:
        print(f"Warning: Unexpected keypoint shape {kps.shape}, skipping plot")
        return
    # squeeze() collapses a single keypoint to shape (2,); always restore the
    # (N, 2) layout the column indexing below relies on.
    kps = kps.reshape(-1, 2)

    H, W = img_rgb.shape[:2]

    ax.imshow(img_rgb)

    # Keep every marker inside the image.
    kps_clipped = np.clip(kps, [0, 0], [W - 1, H - 1])

    # 'x' is an unfilled marker: Matplotlib ignores edgecolors for it (and
    # warns), so only the face colour is passed.
    ax.scatter(kps_clipped[:, 0], kps_clipped[:, 1],
               c='yellow', s=50, marker='x', linewidths=2,
               alpha=0.9, label='Keypoints')

    ax.set_xlim(0, W)
    ax.set_ylim(H, 0)  # flip y so the origin is top-left like the image
    ax.set_title(title)
    ax.legend(loc='upper right', fontsize=8)
    ax.set_xticks([])
    ax.set_yticks([])

def visualize_comparison(kp_3d_pred, rgb_vggt, save_dir="debug_runs/comparison_results"):
    """Overlay the 2D keypoints of the first batch sample on its first three views.

    Writes ``all_views.png`` (the three views side by side) plus one
    ``view_<n>.png`` per view into ``save_dir``, creating the directory when
    it does not exist.

    Args:
        kp_3d_pred: Mapping with ``'kp_1'``, ``'kp_2'`` and ``'kp_3'``; each is
            indexed as ``[0, 0, 0]`` and multiplied by 224 / 518. Per the
            original note the expected shape is (8, 1, 1, 300, 2).
        rgb_vggt: Batched multi-view images indexed as ``[0, view]``. Per the
            original note the expected shape is (8, 4, 3, 224, 224); the
            fourth view is not used.
        save_dir: Output directory.
    """
    os.makedirs(save_dir, exist_ok=True)

    sample = 0  # only the first sample of the batch is visualised
    view_ids = ['1', '2', '3']

    # Keypoints of the selected sample, rescaled by 224 / 518.
    scaled_kps = {
        vid: kp_3d_pred[f'kp_{vid}'][sample, 0, 0] * 224 / 518
        for vid in view_ids
    }
    # Matching image per view (views 0..2 of the selected sample).
    images = {vid: rgb_vggt[sample, pos] for pos, vid in enumerate(view_ids)}

    def draw(target_ax, vid):
        plot_2d_projection_simple(
            ax=target_ax,
            img_rgb=images[vid],
            kps=scaled_kps[vid],
            title=f'View {vid}'
        )

    # Combined figure with the three views next to each other.
    _, grid = plt.subplots(1, 3, figsize=(18, 6))
    for target_ax, vid in zip(grid, view_ids):
        draw(target_ax, vid)
    plt.tight_layout()
    plt.savefig(f"{save_dir}/all_views.png", dpi=300, bbox_inches='tight')
    plt.close()

    # One standalone figure per view.
    for vid in view_ids:
        _, single_ax = plt.subplots(figsize=(8, 6))
        draw(single_ax, vid)
        plt.savefig(f"{save_dir}/view_{vid}.png", dpi=200, bbox_inches='tight')
        plt.close()

        

# def visualize_comparison(data_test, kp_3d_pred, rgb_vggt, save_dir="debug_runs/comparison_results"):
#     """
#     多模态比较预测关键点与原始点云
#     包含：3D点云对比、2D投影对比、深度图叠加、误差分析
#     """
#     os.makedirs(save_dir, exist_ok=True)
#     # ========== 0. 数据预处理（自动处理张量转换） ==========
#     def to_numpy(data):
#         if isinstance(data, torch.Tensor):
#             return data.detach().cpu().numpy()
#         return np.array(data)
    
#     # 转换所有输入数据
#     data_test = {k: to_numpy(v) for k, v in data_test.items()}
#     kp_3d_pred = {k: to_numpy(v) for k, v in kp_3d_pred.items()}

#     # kp_3d_pred['kp_1/2/3'].shape =(8, 1, 1, 300, 2)
#     # rgb_vggt.shape = torch.Size([8, 4, 3, 224, 224])
#     print(f"[DEBUG] {kp_3d_pred.keys()}, kp_3d_pred['kp_1'].max ={kp_3d_pred['kp_1'].max()}, \
#         kp_3d_pred['kp_2'].max ={kp_3d_pred['kp_2'].max()}, kp_3d_pred['kp_3'].max ={kp_3d_pred['kp_3'].max()}")

#     # ================= 1. 3D点云对比可视化 =================
#     # fig = plt.figure(figsize=(18, 10))
#     # # 3D视图1
#     # ax1 = fig.add_subplot(121, projection='3d')
#     # plot_3d_comparison(ax1, data_test, kp_3d_pred, view='overview')
#     # # 3D视图2（不同视角）
#     # ax2 = fig.add_subplot(122, projection='3d')
#     # plot_3d_comparison(ax2, data_test, kp_3d_pred, view='top_down')
#     # plt.tight_layout()
#     # plt.savefig(f"{save_dir}/3d_comparison.png", dpi=300)
#     # plt.close()
#     # ================= 2. 各相机视角2D投影对比 =================
#     view_mapping = {
#         'front': 'front',
#         'top': 'top',
#         'right': 'right'
#     }
#     original_keypoints = {
#         'front': data_test['kp_1'].squeeze() * 128 / 518,
#         'top': data_test['kp_2'].squeeze() * 128 / 518,
#         'right': data_test['kp_3'].squeeze() * 128 / 518,
#         # 'wrist': data_test['kp_4'].squeeze() * 128 / 518
#     }    
#     for view_name in view_mapping.keys():
#         fig = plt.figure(figsize=(15, 6))
#         # RGB图像
#         ax1 = fig.add_subplot(131)
#         plot_2d_projection(ax1, data_test, kp_3d_pred, rgb_vggt[0,1], overlay_type='rgb', original_kps=original_keypoints[view_name])
#         ax1.set_title(f'{view_name} RGB Projection')
#         # 深度图
#         ax2 = fig.add_subplot(132)
#         plot_2d_projection(ax2, data_test, kp_3d_pred, rgb_vggt, overlay_type='depth')
#         ax2.set_title(f'{view_name} Depth Projection')
#         # 误差热力图
#         ax3 = fig.add_subplot(133)
#         plot_error_heatmap(ax3, data_test, kp_3d_pred, view_name)
#         ax3.set_title(f'{view_name} Error Heatmap')
#         plt.tight_layout()
#         plt.savefig(f"{save_dir}/2d_{view_name}_comparison.png", dpi=300)
#         plt.close()


def _draw_stacked_bar(ax, x, values, colors, width, label_fmt, label_color, min_label=None):
    """Draw one stacked bar at ``x`` and label each segment at its vertical centre.

    :param ax: Matplotlib axes to draw on.
    :param x: X position of the bar.
    :param values: Segment heights, stacked bottom-to-top in the given order.
    :param colors: One face colour per segment.
    :param width: Bar width.
    :param label_fmt: Format spec (e.g. ``'.3f'``) used for the segment labels.
    :param label_color: Text colour of the segment labels.
    :param min_label: If given, only segments strictly taller than this value
        are labelled; ``None`` labels every segment.
    """
    bottom = 0
    for val, color in zip(values, colors):
        ax.bar(x, val, width, bottom=bottom, color=color)
        if min_label is None or val > min_label:
            ax.text(x, bottom + val/2, format(val, label_fmt),
                    ha='center', va='center', color=label_color, fontsize=22, weight='bold')
        bottom += val


def visualize_time(save_dir):
    """Render the hard-coded training-time breakdown as a bar chart.

    Four bars are drawn: stacked bars for the static-view and dynamic-view
    streams and a single bar for the action head on the left Y axis, and a
    stacked bar for 3D supervision generation on a twin (right) Y axis.
    The figure is written to ``{save_dir}/time_analysis.png``.

    The ``STIXGeneral`` font is applied only while this figure is built and
    saved; the global matplotlib rcParams are restored afterwards.

    :param save_dir: Directory for the output image; created if missing.
    """
    from matplotlib.patches import Patch

    # Stream / action-head timings, plotted against the left axis
    # (axis label: x 1e-3 minutes).
    static = [0.079, 0.727, 0.568]
    dynamic = [0.0243, 0.324, 2.14]
    action_head = 4.06

    # 3D-supervision timings, plotted against the right axis
    # (axis label: x 1e-2 minutes).
    # NOTE(review): the commented-out donut variant that follows this function
    # lists the same three stages as 1.00 / 0.88 / 0.892 x 1e-2 minutes, i.e.
    # 10x smaller than these numbers under the same unit. Confirm whether the
    # values or the right-axis unit label is the intended one.
    supervision = [10.00, 8.80, 8.92]

    stream_colors = ['#0098B2', '#70CDBE', '#ADE7A8']
    supervision_colors = ['#0074B3', '#6CA3D4', '#98CFE6']
    action_head_color = '#FFDF97'

    x_positions = np.array([0, 1, 2, 3])
    width = 0.7

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # rc_context keeps the font override local to this figure instead of
    # changing the font of every later plot in the process.
    with plt.rc_context({'font.family': 'STIXGeneral'}):
        fig, ax1 = plt.subplots(figsize=(8, 8))

        # Left axis: the two streams as stacked bars. Segments of 0.1 or less
        # are too thin for a readable label and are left unlabelled.
        _draw_stacked_bar(ax1, x_positions[0], static, stream_colors, width,
                          '.3f', 'darkgreen', min_label=0.1)
        _draw_stacked_bar(ax1, x_positions[1], dynamic, stream_colors, width,
                          '.3f', 'darkgreen', min_label=0.1)

        # Left axis: action head as a single bar.
        ax1.bar(x_positions[2], action_head, width, color=action_head_color)
        ax1.text(x_positions[2], action_head/2, f'{action_head:.2f}',
                 ha='center', va='center', color='#CC5500', fontsize=22, weight='bold')

        # Faint placeholder in the fourth slot on the left axis; the real
        # supervision bar is drawn on the twin axis at the same x position.
        ax1.bar(x_positions[3], 1.0, width, color='#F0F0F0', alpha=0.3)

        ax1.set_ylabel('Time (x 1e-3 minutes)', fontsize=18)
        ax1.set_ylim(0, 5)  # scale is only meaningful for the first three bars
        ax1.set_xticks(x_positions)
        ax1.set_xticklabels([
            'Static-view\nStream',
            'Dynamic-view\nStream',
            'RVT-2 Action\nHead',
            '3D Supervision\nGeneration'
        ], fontsize=18)
        ax1.grid(axis='y', linestyle='--', alpha=0.3)
        ax1.set_xlim(-0.5, 3.5)

        # Right axis: 3D supervision generation, every segment labelled.
        ax2 = ax1.twinx()
        _draw_stacked_bar(ax2, x_positions[3], supervision, supervision_colors, width,
                          '.1f', 'white')
        ax2.set_ylim(0, 30)
        ax2.set_ylabel('Time (x 1e-2 minutes)', fontsize=18, color='#555555', labelpad=10)
        ax2.tick_params(axis='y', colors='#555555')

        # Three separate legends: stream submodules, supervision submodules,
        # and the action head.
        column1_elements = [
            Patch(facecolor='#0098B2', label='Token Embedding'),
            Patch(facecolor='#70CDBE', label='Attention'),
            Patch(facecolor='#ADE7A8', label='Decoder')
        ]
        column2_elements = [
            Patch(facecolor='#0074B3', label='VGGT Aggregator'),
            Patch(facecolor='#6CA3D4', label='VGGT Decoder'),
            Patch(facecolor='#98CFE6', label='Keypoints Extraction')
        ]

        legend1 = ax1.legend(
            handles=column1_elements,
            loc='upper left',
            bbox_to_anchor=(0.00, 0.98),
            title="Dual Stream Submodules",
            fontsize=15,
            frameon=True,
            framealpha=0.8
        )
        legend2 = ax1.legend(
            handles=column2_elements,
            loc='upper right',
            bbox_to_anchor=(0.415, 0.75),
            title="3D Supervision Submodules",
            fontsize=15,
            frameon=True,
            framealpha=0.8
        )
        ax1.legend(
            handles=[Patch(facecolor=action_head_color, label='Action Head')],
            loc='upper center',
            bbox_to_anchor=(0.6, 0.98),
            fontsize=15,
            frameon=True,
            framealpha=0.8
        )

        # Each ax.legend() call replaces the axes' current legend, so the two
        # earlier legends have to be re-attached as plain artists.
        ax1.add_artist(legend1)
        ax1.add_artist(legend2)

        plt.title('Training Time Analysis of Cortical Policy Components',
                  fontsize=18, pad=15, weight='bold')

        plt.tight_layout()
        plt.subplots_adjust(top=0.85)  # leave headroom for the top legends
        plt.savefig(f"{save_dir}/time_analysis.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    # import matplotlib
    # matplotlib.use('Agg')
    # import matplotlib.pyplot as plt
    # import numpy as np
    # from matplotlib.patches import Patch

    # # 数据准备
    # streams = ['Static-view', 'Dynamic-view', 'Action Head']
    # stream_times = [1.37, 2.48, 4.06]  # ×10⁻³分钟

    # supervision = ['VGGT Agg', 'VGGT Pred', 'Keypoints']
    # supervision_times = [1.00, 0.88, 0.892]  # ×10⁻²分钟

    # # 创建画布
    # fig, ax = plt.subplots(figsize=(10, 8))

    # # 中心环形图：双流与Action Head
    # wedges1, texts1 = ax.pie(
    #     stream_times, 
    #     radius=1.2,
    #     colors=['#4C72B0', '#55A868', '#C44E52'],
    #     wedgeprops=dict(width=0.3, edgecolor='w'),
    #     labels=streams,  # 直接在pie函数中设置标签
    #     labeldistance=0.8,  # 调整标签距离
    #     textprops=dict(color="w", weight="bold", fontsize=10)  # 设置文本属性
    # )

    # # 外环：3D监督生成
    # wedges2, texts2 = ax.pie(
    #     [sum(supervision_times)] * 3,  # 分成三等份
    #     radius=1.5,
    #     colors=['#DDDDDD', '#AAAAAA', '#888888'],  # 不同灰度区分
    #     wedgeprops=dict(width=0.2, edgecolor='w'),
    #     startangle=90  # 从12点方向开始
    # )

    # # 添加标题
    # plt.title('Training Time Analysis\n(Area Proportional to Time)', pad=20, fontsize=14)

    # # 添加图例解释颜色
    # legend_elements = [
    #     Patch(facecolor='#4C72B0', label='Static-view Stream'),
    #     Patch(facecolor='#55A868', label='Dynamic-view Stream'),
    #     Patch(facecolor='#C44E52', label='Action Head'),
    #     Patch(facecolor='#DDDDDD', label='VGGT Aggregator'),
    #     Patch(facecolor='#AAAAAA', label='VGGT Prediction'),
    #     Patch(facecolor='#888888', label='Keypoints Selection')
    # ]

    # plt.legend(
    #     handles=legend_elements, 
    #     loc='upper center', 
    #     bbox_to_anchor=(0.5, -0.05),
    #     ncol=3,
    #     fontsize=9
    # )

    # # 添加时间值标签
    # for i, wedge in enumerate(wedges1):
    #     angle = (wedge.theta2 - wedge.theta1)/2. + wedge.theta1
    #     x = 1.1 * np.cos(np.deg2rad(angle))
    #     y = 1.1 * np.sin(np.deg2rad(angle))
    #     ax.text(x, y, f'{stream_times[i]}×10⁻³', 
    #             ha='center', va='center', fontsize=8)

    # # 添加3D监督总时间
    # ax.text(1.7, 0, f'Total: {sum(supervision_times):.2f}×10⁻²', 
    #         ha='center', va='center', fontsize=9, backgroundcolor='white')

    # plt.tight_layout()
    # plt.savefig('training_time_donut.png', dpi=300, bbox_inches='tight')

def visualize_projection(video_path, csv_path, frame_index=0, output_img_path="projection_visualization.png"):
    """
    Mark the projected point of one video frame and save the result as an image.

    Row ``frame_index`` of the CSV supplies the normalised coordinates
    (columns ``x_norm`` and ``y_norm``), which are multiplied by the frame
    width/height to obtain pixel coordinates. A green circle, the coordinate
    text and the frame number are drawn on a copy of the frame, which is then
    written with OpenCV.

    :param video_path: Path of the video file.
    :param csv_path: Path of the CSV file; only ``x_norm`` / ``y_norm`` are read.
    :param frame_index: Index of the frame (and CSV row) to visualise (default 0).
    :param output_img_path: Path the annotated image is written to.
    :raises ValueError: If the video cannot be opened, the frame cannot be
        read, or ``frame_index`` is beyond the last CSV row.
    :raises IOError: If OpenCV reports that the image could not be written.
    """
    # Grab the requested frame, then release the capture immediately so the
    # handle is not leaked if anything later in the function raises.
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("无法打开视频文件")
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        raise ValueError("无法读取指定帧")

    # newline='' is the csv-module recommended way to open CSV files.
    with open(csv_path, 'r', newline='') as f:
        coords = [(float(row['x_norm']), float(row['y_norm']))
                  for row in csv.DictReader(f)]

    if frame_index >= len(coords):
        raise ValueError("帧索引超出CSV数据范围")

    x_norm, y_norm = coords[frame_index]

    # Normalised -> pixel coordinates.
    height, width = frame.shape[:2]
    x_pixel = int(x_norm * width)
    y_pixel = int(y_norm * height)

    print(f"归一化坐标: ({x_norm}, {y_norm})")
    print(f"像素坐标: ({x_pixel}, {y_pixel})")
    print(f"视频帧尺寸: {width}x{height}")

    # Draw on a copy so the decoded frame itself stays untouched.
    marked_frame = cv2.circle(frame.copy(),
                              (x_pixel, y_pixel),
                              radius=3,
                              color=(0, 255, 0),  # BGR: green
                              thickness=2)

    # Coordinate label, centred above the point and clamped so it stays
    # inside the image borders.
    text = f"({x_norm:.2f}, {y_norm:.2f})"
    text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
    text_x = max(10, min(x_pixel - text_size[0]//2, width - text_size[0] - 10))
    text_y = max(30, min(y_pixel - 20, height - 10))
    cv2.putText(marked_frame, text, (text_x, text_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Frame number, fixed in the top-left corner.
    cv2.putText(marked_frame, f"Frame: {frame_index}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    # cv2.imwrite signals failure through its return value; do not report
    # success unless the file was actually written.
    if not cv2.imwrite(output_img_path, marked_frame):
        raise IOError(f"无法保存图像: {output_img_path}")
    print(f"可视化结果已保存到: {output_img_path}")


if __name__ == "__main__":
    # visualize_time("/RVT/dual_stream/debug_runs")
    visualize_root = "/dataset/rlbench/future"
    visualize_task = "insert_onto_square_peg"
    visualize_epoch = "0"
    output_root = "/RVT/dual_stream/debug_runs"

    # Render frame 125 for both the st1 and st2 clips of the same episode.
    for stage in ("st1", "st2"):
        clip = f"ep_{visualize_epoch}_{stage}"
        visualize_projection(
            f"{visualize_root}/full_scale.gaze/{visualize_task}/{clip}.mp4",
            f"{visualize_root}/gaze/{visualize_task}/{clip}.csv",
            frame_index=125,
            output_img_path=f"{output_root}/{visualize_task}_ep{visualize_epoch}_{stage}.png",
        )