import av
import os
import csv
import cv2
import torch
import pickle
import numpy as np
from tqdm import tqdm
# from moviepy.editor import *
import multiprocessing as mp
from functools import partial
import matplotlib.pyplot as plt
from moviepy.video.io.VideoFileClip import VideoFileClip
# from pytorch3d.transforms import euler_angles_to_matrix
from natsort import natsorted  # 自然排序（episode0, episode1,...）
from torch.cuda.amp import autocast

import sys
sys.path.append('/projects/colosseum/SAM2Act/sam2Act_COLOSSEUM')
import sam2act_colosseum.mvt.utils as mvt_utils
import sam2act_colosseum.rvt.rvt_utils as rvt_utils
from sam2act_colosseum.configs.config import SCENE_BOUNDS
from point_renderer.rvt_renderer import RVTBoxRenderer as BoxRenderer
from sam2act_colosseum.utils.tools import stack_on_channel, _norm_rgb, get_stored_demo, CAMERAS


def png_to_mp4(src_png_dir, dst_mp4_path, fps=30):
    """
    Encode a directory of PNG frames into a single MP4 video.

    Frames are taken in natural sort order (frame_2 before frame_10) and the
    output resolution is taken from the first frame.

    :param src_png_dir: directory holding the PNG frames
        (e.g. .../episode0/wrist_rgb/)
    :param dst_mp4_path: path of the MP4 file to write
        (e.g. .../close_jar/ep_0.mp4)
    :param fps: frame rate of the output video
    :raises ValueError: if the directory has no PNG files or a frame cannot
        be decoded
    :raises OSError: if the video writer cannot be opened for dst_mp4_path
    """
    # os.makedirs('') raises, so only create the parent when there is one
    # (dst_mp4_path may be a bare file name in the current directory).
    dst_dir = os.path.dirname(dst_mp4_path)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)

    png_files = natsorted([
        f for f in os.listdir(src_png_dir)
        if f.endswith('.png')
    ])

    if not png_files:
        raise ValueError(f"No PNG files found in {src_png_dir}")

    # cv2.imread returns None instead of raising when a file cannot be decoded.
    first_path = os.path.join(src_png_dir, png_files[0])
    first_frame = cv2.imread(first_path)
    if first_frame is None:
        raise ValueError(f"Failed to read image {first_path}")
    h, w = first_frame.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(dst_mp4_path, fourcc, fps, (w, h))
    try:
        # A writer that failed to open silently discards every frame.
        if not writer.isOpened():
            raise OSError(f"Cannot open video writer for {dst_mp4_path}")

        for png in tqdm(png_files, desc=f"Processing {os.path.basename(dst_mp4_path)}"):
            frame_path = os.path.join(src_png_dir, png)
            frame = cv2.imread(frame_path)
            if frame is None:
                raise ValueError(f"Failed to read image {frame_path}")
            writer.write(frame)
    finally:
        # Always finalize the container, even when a frame fails mid-way.
        writer.release()

#==================================================================================================
# 下面是工具函数
#==================================================================================================
# Four homogeneous points (x, y, z, 1): the origin followed by one point at
# 0.1 along each of the x, y and z axes.
# NOTE(review): the name suggests these mark the end-effector frame axes;
# not referenced in the visible part of this file - confirm at use sites.
EndEffectorPts = [
    [0, 0, 0, 1],
    [0.1, 0, 0, 1],
    [0, 0.1, 0, 1],
    [0, 0, 0.1, 1]
]
# 4x4 identity matrix.
# NOTE(review): presumably the gripper -> end-effector frame conversion
# (currently a no-op); confirm at use sites.
Gripper2EEFCvt = [
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
]

def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Build elementary rotation matrices about a single coordinate axis.

    Args:
        axis: Axis label, one of "X", "Y" or "Z".
        angle: Tensor of any shape holding rotation angles in radians.

    Returns:
        Tensor of shape ``angle.shape + (3, 3)`` with one rotation matrix
        per input angle.

    Raises:
        ValueError: if ``axis`` is not "X", "Y" or "Z".
    """
    c, s = torch.cos(angle), torch.sin(angle)
    o, z = torch.ones_like(angle), torch.zeros_like(angle)

    # Row-major entries of the 3x3 matrix for each supported axis.
    layouts = {
        "X": (o, z, z, z, c, -s, z, s, c),
        "Y": (c, z, s, z, o, z, -s, z, c),
        "Z": (c, -s, z, s, c, z, z, z, o),
    }
    entries = layouts.get(axis)
    if entries is None:
        raise ValueError("letter must be either X, Y or Z.")

    flat = torch.stack(entries, -1)
    return flat.reshape(angle.shape + (3, 3))

def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Turn Euler angles (radians) into rotation matrices.

    The three elementary rotations named by ``convention`` are built in order
    and multiplied left to right: ``R = R0 @ R1 @ R2``.

    Args:
        euler_angles: Tensor of shape (..., 3) with angles in radians.
        convention: Three uppercase letters from {"X", "Y", "Z"}; the middle
            letter must differ from both of its neighbours.

    Returns:
        Rotation matrices as a tensor of shape (..., 3, 3).

    Raises:
        ValueError: on a malformed ``euler_angles`` shape or ``convention``.
    """
    valid_letters = ("X", "Y", "Z")

    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    first, middle, last = convention[0], convention[1], convention[2]
    if middle == first or middle == last:
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in valid_letters:
            raise ValueError(f"Invalid letter {letter} in convention string.")

    per_axis_angles = torch.unbind(euler_angles, -1)
    rot_a, rot_b, rot_c = (
        _axis_angle_rotation(axis, ang)
        for axis, ang in zip(convention, per_axis_angles)
    )
    return torch.matmul(torch.matmul(rot_a, rot_b), rot_c)

def get_extrinsic_matrices(extrinsics):
    """
    Derive world->camera and camera->world transforms from camera extrinsics.

    The extrinsics are used as the world->camera transform as-is (copied).
    The camera->world transform is built with the closed-form rigid inverse
    [R^T | -R^T t] rather than a general matrix inverse, so it is only the
    true inverse when the upper-left 3x3 block is a rotation.

    Args:
        extrinsics: camera extrinsic matrices, shape [num_views, 4, 4].

    Returns:
        w2c: copy of ``extrinsics``, shape [num_views, 4, 4]
        c2w: camera->world matrices, shape [num_views, 4, 4]
    """
    w2c = extrinsics.clone()

    # Batched over all views at once instead of looping per view.
    rot_t = w2c[:, :3, :3].transpose(-1, -2)   # R^T for every view
    trans = w2c[:, :3, 3].unsqueeze(-1)        # t as a column vector

    c2w = torch.zeros_like(w2c)                # bottom row starts as zeros
    c2w[:, :3, :3] = rot_t
    c2w[:, :3, 3] = -(rot_t @ trans).squeeze(-1)
    c2w[:, 3, 3] = 1

    return w2c, c2w

def get_transformation_matrix_from_euler(euler_angles):
    """
    Convert (position, Euler angles) rows into 4x4 homogeneous transforms.

    The rotation part comes from this module's ``euler_angles_to_matrix``
    with the "XYZ" convention (angles in radians).

    Args:
        euler_angles: tensor of shape [batch_size, 6];
            columns 0..2 are the position (x, y, z),
            columns 3..5 are the Euler angles (rot_x, rot_y, rot_z).

    Returns:
        Tensor of shape [batch_size, 4, 4] on the input's device. It is
        created with ``torch.eye``'s default dtype.
    """
    xyz = euler_angles[:, :3]
    angles = euler_angles[:, 3:]

    # Start from a stack of identities so the bottom row is already [0, 0, 0, 1].
    n_rows = euler_angles.shape[0]
    identity = torch.eye(4, device=euler_angles.device)
    transform_matrix = identity.unsqueeze(0).repeat(n_rows, 1, 1)

    transform_matrix[:, :3, :3] = euler_angles_to_matrix(angles, convention="XYZ")
    transform_matrix[:, :3, 3] = xyz
    return transform_matrix

def _preprocess_inputs(replay_sample, cameras):
    """
    Convert one stored observation into renderer inputs.

    For every camera name ``n`` in ``cameras`` the attributes ``{n}_rgb`` and
    ``{n}_point_cloud`` are read from ``replay_sample`` (assumed to be
    [H, W, 3] numpy arrays - TODO confirm) and turned into [1, 3, H, W]
    tensors; the RGB image is additionally passed through ``_norm_rgb``.

    The wrist camera extrinsics found in ``replay_sample.misc`` are used to
    describe one dynamic camera. The rendered image extent of that camera is
    derived from the x/y extent of the LAST camera's point cloud.
    NOTE(review): using only the last camera's point cloud mirrors the
    original behaviour - confirm that this is intended.

    Args:
        replay_sample: observation exposing ``{camera}_rgb``,
            ``{camera}_point_cloud`` and ``misc``.
        cameras: iterable of camera name prefixes; must not be empty.

    Returns:
        obs: list of ``[rgb, pcd]`` pairs, one per camera.
        pcds: list of point-cloud tensors, one per camera.
        dyn_cam_info: list with one ``(R, T, dyn_img_sizes_w, K)`` tuple,
            where ``K`` is always None.

    Raises:
        ValueError: if ``cameras`` is empty.
    """
    obs, pcds = [], []
    for n in cameras:
        rgb = torch.from_numpy(getattr(replay_sample, f"{n}_rgb")).unsqueeze(0)  # [H,W,3] -> [1,H,W,3]
        rgb = rgb.permute(0, 3, 1, 2)  # [1,H,W,3] -> [1,3,H,W]
        rgb = _norm_rgb(rgb)

        pcd = torch.from_numpy(getattr(replay_sample, f"{n}_point_cloud"))  # [H,W,3]
        pcd = pcd.permute(2, 0, 1).unsqueeze(0)  # [H,W,3] -> [1,3,H,W]

        obs.append([rgb, pcd])
        pcds.append(pcd)

    if not pcds:
        # Previously surfaced as a NameError on the loop variable further down.
        raise ValueError("cameras must contain at least one camera name")

    misc = getattr(replay_sample, "misc")
    dynamic_extrinsic = torch.from_numpy(misc["wrist_camera_extrinsics"]).unsqueeze(0)  # [1, 4, 4]
    # Read so a missing key fails here, as before; the value itself is not
    # used because K is always returned as None.
    dynamic_intrinsic = torch.from_numpy(misc["wrist_camera_intrinsics"]).unsqueeze(0)  # [1, 3, 3]

    # Per-channel bounds of the last camera's point cloud. The tensor is
    # always 4-D here, so reducing the two spatial axes yields [1, C].
    last_pcd = pcds[-1]
    pc_min = last_pcd.amin(dim=(-1, -2))
    pc_max = last_pcd.amax(dim=(-1, -2))

    # Defensive: tile the channel axis if fewer than 3 channels are present.
    if pc_min.size(-1) < 3:
        pc_min = pc_min.repeat(1, 3)
        pc_max = pc_max.repeat(1, 3)

    static_range = 2.0
    dyn_cam_info = []
    for sample in range(dynamic_extrinsic.shape[0]):
        extrinsic = dynamic_extrinsic[sample].unsqueeze(0)
        current_pc_min = pc_min[sample, :3]
        current_pc_max = pc_max[sample, :3]

        R = ensure_float32(extrinsic[:, :3, :3])                # (num_dyn_cam = 1, 3, 3)
        T = ensure_float32(extrinsic[:, :3, 3])                 # (num_dyn_cam = 1, 3)

        # Clamp the extents so a degenerate (flat) cloud cannot produce a
        # zero or huge aspect ratio.
        x_range = (current_pc_max[0] - current_pc_min[0]).clamp(min=0.1)
        y_range = (current_pc_max[1] - current_pc_min[1]).clamp(min=0.1)

        # Keep the longer side at static_range and shrink the other one.
        aspect_ratio = x_range / y_range
        if aspect_ratio > 1:
            dyn_img_sizes_w = [ensure_float32(static_range), ensure_float32(static_range / aspect_ratio)]
        else:
            dyn_img_sizes_w = [ensure_float32(static_range * aspect_ratio), ensure_float32(static_range)]
        K = None
        dyn_cam_info.append((R, T, dyn_img_sizes_w, K))
    return obs, pcds, dyn_cam_info

def ensure_float32(x):
    """
    Recursively convert tensors and Python numbers to float32 tensors.

    * tensors are cast with ``.float()``;
    * lists / tuples are rebuilt as the same container type with every item
      converted;
    * ``int`` / ``float`` scalars become 0-d float32 tensors;
    * anything else (None, strings, numpy arrays, ...) is returned as-is.
    """
    if isinstance(x, (list, tuple)):
        return type(x)(map(ensure_float32, x))
    if torch.is_tensor(x):
        return x.float()
    if isinstance(x, (int, float)):
        return torch.tensor(x, dtype=torch.float32)
    return x

def batch_convert(root_train_dir, dst_gaze_dir):
    """
    Convert the wrist-camera PNG sequence of every episode into an MP4.

    For each task directory under ``root_train_dir`` the episodes in
    ``<task>/all_variations/episodes/episodeN/wrist_rgb`` are encoded to
    ``<dst_gaze_dir>/<task>/ep_N.mp4``. Task directories without an episodes
    folder and episodes without a ``wrist_rgb`` folder are skipped.

    :param root_train_dir: root of the training data (.../rlbench/train/)
    :param dst_gaze_dir: output root (.../rlbench/future/full_scale/gaze/)
    """
    for task_name in os.listdir(root_train_dir):
        task_dir = os.path.join(root_train_dir, task_name)
        if not os.path.isdir(task_dir):
            continue
        os.makedirs(f"{dst_gaze_dir}/{task_name}", exist_ok=True)

        episodes_dir = os.path.join(task_dir, 'all_variations', 'episodes')
        # A stray directory without the expected layout must not abort the
        # whole batch (os.listdir would raise FileNotFoundError).
        if not os.path.isdir(episodes_dir):
            print(f"目录不存在: {episodes_dir}")
            continue

        for ep_dir in natsorted(os.listdir(episodes_dir)):
            if not ep_dir.startswith('episode'):
                continue

            src_png_dir = os.path.join(episodes_dir, ep_dir, 'wrist_rgb')
            ep_num = ep_dir.replace('episode', '')
            dst_mp4 = os.path.join(dst_gaze_dir, f"{task_name}/ep_{ep_num}.mp4")

            if os.path.exists(src_png_dir):
                png_to_mp4(src_png_dir, dst_mp4)


def batch_convert_render(root_train_dir, dst_gaze_dir):
    """
    Render every frame of every episode and encode the renders as MP4s.

    Two images are rendered per frame and written as PNGs:

    * stage 1 - the point cloud placed in the scene cube;
    * stage 2 - the same cloud re-centred on the gripper position and
      scaled by 4 (``mvt_utils.trans_pc(..., sca=4)``).

    PNGs go to ``<dst_gaze_dir>/render_st{1,2}/<task>/ep_N/frame_XXXX.png``
    and are then encoded to ``<dst_gaze_dir>/<task>/ep_N_st{1,2}.mp4``.

    :param root_train_dir: root of the training data (.../rlbench/train/)
    :param dst_gaze_dir: output root (.../rlbench/future/full_scale/gaze/)
    """
    # The renderer is created lazily on first use and then reused for all
    # frames instead of being rebuilt on every render call.
    renderer_cache = []

    def render(pc, img_feat, dyn_cam_info):
        """Render each (pc, feat, cam) triple; returns [B, views, C, H, W]."""
        if not renderer_cache:
            renderer_cache.append(
                BoxRenderer(
                    device="cuda:0",
                    img_size=(224, 224),
                    three_views=True,
                    with_depth=True,
                )
            )
        renderer = renderer_cache[0]

        views = []
        with torch.no_grad(), autocast(enabled=False):
            for _pc, _img_feat, _dyn_cam_info in zip(pc, img_feat, dyn_cam_info):
                _pc = ensure_float32(_pc).view(-1, 3)
                _img_feat = ensure_float32(_img_feat)

                # Camera tuple is (R, T, image extents, K); K is passed through.
                cam = (
                    ensure_float32(_dyn_cam_info[0]),
                    ensure_float32(_dyn_cam_info[1]),
                    [ensure_float32(x) for x in _dyn_cam_info[2]],
                    _dyn_cam_info[3],
                )

                # Features = normalized xyz (3 channels) followed by image features.
                max_pc = 1.0 if len(_pc) == 0 else torch.max(torch.abs(_pc))
                views.append(
                    renderer(
                        _pc,
                        torch.cat((_pc / max_pc, _img_feat), dim=-1),
                        fix_cam=True,
                        dyn_cam_info=(cam,)
                    ).unsqueeze(0)
                )
        img = torch.cat(views, 0)
        return img.permute(0, 1, 4, 2, 3)

    def save_view(img, png_path):
        """Write view 3, feature channels 3:6 of a render as a BGR PNG."""
        # Channels 0:3 are the normalized xyz features, so 3:6 are the first
        # three image-feature channels.
        # NOTE(review): view index 3 is presumably the dynamic (wrist) camera
        # that follows the three fixed views - confirm against the renderer.
        rgb = img[:, 3, 3:6, :, :].squeeze(0)
        rgb = rgb.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
        # Clip before the uint8 cast so out-of-range values saturate
        # instead of wrapping around.
        rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8)
        cv2.imwrite(png_path, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

    for task_name in os.listdir(root_train_dir):
        task_dir = os.path.join(root_train_dir, task_name)
        if not os.path.isdir(task_dir):
            continue
        os.makedirs(f"{dst_gaze_dir}/{task_name}", exist_ok=True)

        episodes_dir = os.path.join(task_dir, 'all_variations', 'episodes')
        # Skip directories without the expected layout instead of crashing.
        if not os.path.isdir(episodes_dir):
            print(f"目录不存在: {episodes_dir}")
            continue

        for ep_dir in natsorted(os.listdir(episodes_dir)):
            if not ep_dir.startswith('episode'):
                continue

            index = int(ep_dir[7:])  # strip the "episode" prefix
            dst_mp4_st1 = os.path.join(dst_gaze_dir, f"{task_name}/ep_{index}_st1.mp4")
            dst_mp4_st2 = os.path.join(dst_gaze_dir, f"{task_name}/ep_{index}_st2.mp4")
            st1_png_dir = os.path.join(dst_gaze_dir, f"render_st1/{task_name}/ep_{index}")
            st2_png_dir = os.path.join(dst_gaze_dir, f"render_st2/{task_name}/ep_{index}")
            os.makedirs(st1_png_dir, exist_ok=True)
            os.makedirs(st2_png_dir, exist_ok=True)

            demo = get_stored_demo(episodes_dir, index)
            for i in range(len(demo)):
                obs, pcd, dyn_cam_info = _preprocess_inputs(demo[i], CAMERAS)
                pc, img_feat = rvt_utils.get_pc_img_feat(obs, pcd)
                pc, img_feat = rvt_utils.move_pc_in_bound(
                    pc, img_feat, SCENE_BOUNDS, no_op=False
                )

                # Gripper position expressed in the scene cube's local frame.
                action_gripper_pose = torch.tensor(demo[i].gripper_pose).unsqueeze(0)
                wpt = [x[:3] for x in action_gripper_pose[:, 0:3]]
                wpt_local = []
                for _pc, _wpt in zip(pc, wpt):
                    local_wpt, _ = mvt_utils.place_pc_in_cube(
                        _pc, _wpt, with_mean_or_bounds=False, scene_bounds=SCENE_BOUNDS
                    )
                    wpt_local.append(local_wpt.unsqueeze(0))
                wpt_local = torch.cat(wpt_local, axis=0)

                pc = [mvt_utils.place_pc_in_cube(_pc, with_mean_or_bounds=False,
                                                 scene_bounds=SCENE_BOUNDS)[0] for _pc in pc]

                # Stage 1: whole scene.
                img_st1 = render(pc=pc, img_feat=img_feat, dyn_cam_info=dyn_cam_info)
                save_view(img_st1, os.path.join(st1_png_dir, f"frame_{i:04d}.png"))

                # Stage 2: zoom in around the gripper position.
                pc, _ = mvt_utils.trans_pc(pc, loc=wpt_local, sca=4)
                img_st2 = render(pc=pc, img_feat=img_feat, dyn_cam_info=dyn_cam_info)
                save_view(img_st2, os.path.join(st2_png_dir, f"frame_{i:04d}.png"))

            png_to_mp4(st1_png_dir, dst_mp4_st1)
            png_to_mp4(st2_png_dir, dst_mp4_st2)


def get_ee_projection(pose, extrinsics, intrinsic):
    """
    Project end-effector positions into the image plane of each camera.

    Args:
        pose: [batch_size, >=3] tensor; only the first three columns
            (x, y, z in world coordinates) are used.
        extrinsics: [num_views, 4, 4] world->camera matrices.
        intrinsic: [num_views, 3, 3] camera intrinsic matrices.

    Returns:
        int64 tensor of rounded pixel coordinates with shape
        [num_views, batch_size, 2]; points whose camera-space depth is not
        positive are reported as (-1, -1). A leading view axis of size 1 is
        squeezed away (``squeeze(0)``); the batch axis is never squeezed.
    """
    n_points = pose.shape[0]
    n_views = extrinsics.shape[0]
    dev = pose.device

    # World points in homogeneous form: [batch_size, 4].
    ones = torch.ones(n_points, 1, device=dev)
    world_h = torch.cat([pose[:, :3], ones], dim=-1)

    # Row-vector convention: p_cam = p_world @ E^T -> [num_views, batch_size, 4].
    cam_h = world_h.unsqueeze(0) @ extrinsics.transpose(-1, -2)

    # Only points in front of the camera can be projected.
    in_front = cam_h[:, :, 2] > 0

    # Apply the intrinsics, then the perspective divide.
    cam_xyz = cam_h[:, :, :3]
    pix_h = (intrinsic.unsqueeze(1) @ cam_xyz.unsqueeze(-1)).squeeze(-1)
    pix = (pix_h[:, :, :2] / pix_h[:, :, 2:]).round().to(torch.int64)

    out = torch.full((n_views, n_points, 2), -1, device=dev, dtype=torch.int64)
    out[in_front] = pix[in_front]
    return out.squeeze(0)



def world_to_pixel(pose_3d, extrinsics, intrinsic, img_size=(128, 128),
                   use_fixed_point=True):
    """
    Return the pixel location associated with a gripper pose.

    By default (``use_fixed_point=True``) no projection is performed: the
    function returns the constant pixel (width // 2, height - 25), i.e. a
    point centred horizontally, 25 px above the bottom edge. This is what
    the function has always returned; the projection code used to sit
    unreachable behind that early return.

    With ``use_fixed_point=False`` the first three components of ``pose_3d``
    are projected through ``extrinsics`` and ``intrinsic`` and clamped to the
    image bounds.

    Args:
        pose_3d: pose vector whose first three entries are the world x, y, z.
            Must be a tensor in the default mode (its device is used); may
            also be an ndarray or sequence when ``use_fixed_point=False``.
        extrinsics: world->camera matrices, shape [num_views, 4, 4].
        intrinsic: camera intrinsics, shape [num_views, 3, 3].
        img_size: (width, height) of the image in pixels.
        use_fixed_point: keep the historical constant output (default) or
            run the real projection.

    Returns:
        float32 tensor of shape [1, 2] holding (u, v). In projection mode a
        point with non-positive camera depth is returned as (-1, -1), and
        only the first view's result is returned.
    """
    if use_fixed_point:
        bottom_margin_px = 25
        return torch.tensor([[img_size[0] // 2, img_size[1] - bottom_margin_px]],
                            device=pose_3d.device, dtype=torch.float32)

    def safe_to_tensor(data):
        # Always work on a float32 copy so the caller's data is untouched.
        if isinstance(data, torch.Tensor):
            return data.clone().float()
        elif isinstance(data, np.ndarray):
            return torch.from_numpy(data).float()
        else:
            return torch.tensor(data).float()

    pose_3d = safe_to_tensor(pose_3d[:3]).unsqueeze(0)  # [3] -> [1, 3]
    extrinsics = safe_to_tensor(extrinsics)
    intrinsic = safe_to_tensor(intrinsic)

    batch_size = pose_3d.shape[0]
    num_views = extrinsics.shape[0]

    # 1. Homogeneous world coordinates: [batch_size, 4].
    points_hom = torch.cat([
        pose_3d,
        torch.ones(batch_size, 1, device=pose_3d.device)
    ], dim=-1)

    # 2. World -> camera: [num_views, batch_size, 4].
    points_cam = torch.matmul(
        points_hom.unsqueeze(0),
        extrinsics.transpose(-1, -2)
    )

    # 3. Only points in front of the camera are valid.
    valid_mask = points_cam[:, :, 2] > 0

    # 4. Apply intrinsics: [num_views, batch_size, 3].
    uv_hom = torch.matmul(
        intrinsic.unsqueeze(1),
        points_cam[:, :, :3].unsqueeze(-1)
    ).squeeze(-1)

    # Perspective divide, then clamp into the image.
    uv = uv_hom[:, :, :2] / uv_hom[:, :, 2:]
    uv[:, :, 0] = uv[:, :, 0].clamp(0, img_size[0] - 1)
    uv[:, :, 1] = uv[:, :, 1].clamp(0, img_size[1] - 1)

    # 5. Invalid points keep the sentinel value -1.
    result = torch.full(
        (num_views, batch_size, 2),
        -1,
        device=pose_3d.device,
        dtype=torch.float32
    )
    result[valid_mask] = uv[valid_mask]

    # Only the first view is returned.
    return result[0]






def save_gaze_csv(vid, timestamps, projections, output_dir):
    """
    Write gaze samples to ``<output_dir>/<vid>.csv``.

    The file has the header ``timestamp, x_norm, y_norm, confidence`` and one
    row per sample; values are rounded to 6 decimals and confidence is always
    1.0. The function never raises: any failure is printed and reported
    through the return value, and a partially written file is removed.

    :param vid: video id, used as the CSV file name (without extension)
    :param timestamps: sequence of timestamps, one per sample
    :param projections: sequence of (x, y) pairs, same length as timestamps
    :param output_dir: directory to write into (created if missing)
    :return: True when the file was written completely, False otherwise
    """
    file_path = None
    try:
        if len(timestamps) != len(projections):
            raise ValueError(f"数据长度不匹配: {len(timestamps)}时间戳 vs {len(projections)}坐标")
        if not all(len(p) == 2 for p in projections):
            raise ValueError("坐标格式错误，应为[[x,y],...]")

        os.makedirs(output_dir, exist_ok=True)
        file_path = os.path.join(output_dir, f"{vid}.csv")

        with open(file_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['timestamp', 'x_norm', 'y_norm', 'confidence'])
            for t, (x, y) in zip(timestamps, projections):
                writer.writerow([
                    round(float(t), 6),
                    round(float(x), 6),
                    round(float(y), 6),
                    1.0
                ])
        return True

    except Exception as e:
        print(f"保存 {vid}.csv 失败: {type(e).__name__}: {str(e)}")
        # Best-effort removal of a half-written file; only filesystem errors
        # are tolerated here so unrelated bugs are not masked.
        if file_path is not None:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except OSError as cleanup_err:
                print(f"无法删除不完整文件 {file_path}: {cleanup_err}")
        return False
    

def _gaze_to_float_tensor(data):
    """Return a float32 tensor copy of a tensor, ndarray or sequence."""
    if isinstance(data, torch.Tensor):
        return data.clone().float()
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data).clone().float()
    return torch.tensor(data).float()


def batch_process_poses(data_root, output_dir, dst_gaze):
    """
    Generate one gaze CSV per episode from the stored low-dimensional data.

    For every ``<data_root>/<task>/all_variations/episodes/<episode>`` that
    has a matching video ``<dst_gaze>/<task>/ep_N.mp4``, the observations in
    ``low_dim_obs.pkl`` are turned into normalized (x, y) gaze points via
    ``world_to_pixel`` and written with ``save_gaze_csv``. Timestamps are
    ``frame_index / fps`` using the video's own frame rate.

    Problems are reported with ``print`` and the affected frame / episode /
    task is skipped; the function itself does not raise for them.

    :param data_root: root directory holding the per-task episode data
    :param output_dir: root directory for the CSV output
    :param dst_gaze: root directory holding the MP4 files
    """
    for task_dir in sorted(os.listdir(data_root)):
        task_path = os.path.join(data_root, task_dir)
        episodes_dir = os.path.join(task_path, 'all_variations', 'episodes')
        if not os.path.exists(episodes_dir):
            print(f"目录不存在: {episodes_dir}")
            continue
        task_output_dir = os.path.join(output_dir, task_dir)
        os.makedirs(task_output_dir, exist_ok=True)
        if not os.access(task_output_dir, os.W_OK):
            print(f"错误: 无写入权限 {task_output_dir}")
            continue
        for episode in sorted(os.listdir(episodes_dir)):
            episode_path = os.path.join(episodes_dir, episode)
            vid = f"ep_{episode.replace('episode', '')}"
            video_path = os.path.join(dst_gaze, f"{task_dir}/{vid}.mp4")
            if not os.path.exists(video_path):
                print(f"视频文件不存在: {video_path}")
                continue

            # Per-episode error bookkeeping.
            has_error = False
            error_msgs = []
            missing_attr_printed = False
            missing_misc_printed = False
            try:
                # Read the video metadata; the capture is released on every path.
                cap = cv2.VideoCapture(video_path)
                try:
                    if not cap.isOpened():
                        print(f"无法打开视频: {video_path}")
                        continue
                    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    fps = cap.get(cv2.CAP_PROP_FPS)
                finally:
                    cap.release()

                # w, h and fps are all used as divisors below; zero values
                # used to end in a header-only CSV.
                if not (w > 0 and h > 0 and fps > 0):
                    print(f"跳过 {vid} - 视频元数据无效 (w={w}, h={h}, fps={fps})")
                    continue

                pkl_path = os.path.join(episode_path, "low_dim_obs.pkl")
                if not os.path.exists(pkl_path):
                    print(f"跳过 {vid} - pkl文件不存在")
                    continue

                # NOTE: pickle.load executes arbitrary code; only use this on
                # trusted dataset files.
                with open(pkl_path, "rb") as f:
                    obs_data = pickle.load(f)

                if not hasattr(obs_data, '_observations'):
                    print(f"跳过 {vid} - 观测数据格式异常")
                    continue

                num_frames = len(obs_data._observations)

                timestamps = []
                projections = []
                valid_frames = 0
                for obs_number in range(num_frames):
                    obs = obs_data._observations[obs_number]
                    if not all(hasattr(obs, attr) for attr in ['gripper_pose', 'misc']):
                        if not missing_attr_printed:
                            print(f"警告: {task_dir}/{vid} 缺少基础属性")
                            missing_attr_printed = True
                        continue
                    misc = obs.misc
                    if not all(attr in misc.keys() for attr in ['wrist_camera_extrinsics', 'wrist_camera_intrinsics']):
                        if not missing_misc_printed:
                            print(f"警告: {task_dir}/{vid} misc缺少必要属性 (实际存在: {list(misc.keys())})")
                            missing_misc_printed = True
                        continue
                    try:
                        extrinsics = _gaze_to_float_tensor(misc.get('wrist_camera_extrinsics')).unsqueeze(0)
                        intrinsic = _gaze_to_float_tensor(misc.get('wrist_camera_intrinsics')).unsqueeze(0)
                        gripper_pose = _gaze_to_float_tensor(obs.gripper_pose)

                        uvs = world_to_pixel(gripper_pose, extrinsics, intrinsic, img_size=(w, h))
                        if uvs.dim() != 2 or uvs.shape != torch.Size([1, 2]):
                            print(f"帧 {obs_number} 异常输出形状: {uvs.shape}")
                            continue
                        # Check for NaN first: NaN also fails the range test
                        # below and would be reported as "out of range".
                        if torch.isnan(uvs).any():
                            print(f"帧 {obs_number} 投影结果包含NaN值")
                            continue

                        # Clamp into the image, then normalize to [0, 1].
                        u = float(torch.clamp(uvs[0, 0], 0, w - 1))
                        v = float(torch.clamp(uvs[0, 1], 0, h - 1))
                        u_norm = u / w
                        v_norm = v / h

                        if not (0 <= u_norm <= 1) or not (0 <= v_norm <= 1):
                            print(f"帧 {obs_number} 坐标超出范围: ({u_norm:.3f}, {v_norm:.3f})")
                            continue

                        # Count the frame only once both lists were extended.
                        timestamps.append(obs_number / fps)
                        projections.append([u_norm, v_norm])
                        valid_frames += 1

                    except Exception as e:
                        has_error = True
                        error_msgs.append(f"帧 {obs_number} 处理异常: {str(e)}")
                        continue

                # Print a summary only when something went wrong.
                if has_error or valid_frames != num_frames:
                    print(f"\n处理任务 {task_dir} {vid} | 总帧数: {num_frames}")

                    for msg in error_msgs:
                        print(msg)

                    if missing_attr_printed:
                        print(f"警告: {task_dir}/{vid} 缺少基础属性")
                    if missing_misc_printed:
                        print(f"警告: {task_dir}/{vid} misc缺少必要属性")

                    print(f"有效帧数: {valid_frames}/{num_frames}")

                if valid_frames > 0:
                    success = save_gaze_csv(vid, timestamps, projections, task_output_dir)
                    if success:
                        # Double-check that the file really exists on disk.
                        csv_path = os.path.join(task_output_dir, f"{vid}.csv")
                        if not os.path.exists(csv_path):
                            print(f"错误: 文件未生成 {csv_path}")
                    else:
                        print(f"保存失败: {vid}")

            except Exception as e:
                print(f"处理失败 {task_dir}/{vid} | 错误: {str(e)}")


def process_single_task(task_dir, data_root, output_dir, dst_gaze, renderer, preloaded_data):
    """
    Write the stage-1 and stage-2 gaze CSVs for every episode of one task.

    For each episode listed in ``preloaded_data['episodes']`` and each stage
    (``st1``, ``st2``) the gripper position is placed in the scene cube (for
    ``st2`` additionally passed through ``mvt_utils.trans_pc(..., sca=4)``),
    projected with ``renderer.get_pt_loc_on_img`` and normalized by the size
    of the matching video ``<dst_gaze>/<task_dir>/ep_N_<stage>.mp4``.
    Existing CSVs are left untouched.

    :param task_dir: task directory name (relative to data_root)
    :param data_root: root directory of the episode data
    :param output_dir: root directory for the CSV output
    :param dst_gaze: root directory holding the rendered MP4 files
    :param renderer: renderer exposing ``device`` and ``get_pt_loc_on_img``
    :param preloaded_data: dict with an ``'episodes'`` list whose items carry
        at least a ``'name'`` key (see ``load_task_data``)
    :return: ``(task_dir, True, None)`` on success,
        ``(task_dir, False, error_message)`` if any exception occurred;
        the first exception aborts the remaining episodes of the task.
    """
    task_path = os.path.join(data_root, task_dir)
    task_output_dir = os.path.join(output_dir, task_dir)
    os.makedirs(task_output_dir, exist_ok=True)

    try:
        for episode_data in preloaded_data['episodes']:
            episode_name = episode_data['name']
            episode_path = os.path.join(task_path, 'all_variations', 'episodes', episode_name)

            # Loaded lazily and shared by both stages of the episode.
            num_frames = None
            demo = None

            for st in ['st1', 'st2']:
                vid = f"ep_{episode_name.replace('episode', '')}_{st}"
                csv_path = os.path.join(task_output_dir, f"{vid}.csv")

                if os.path.exists(csv_path):
                    print(f"跳过 {vid} - CSV文件已存在")
                    continue

                video_path = os.path.join(dst_gaze, task_dir, f"{vid}.mp4")
                if not os.path.exists(video_path):
                    print(f"跳过 {vid} - 视频文件不存在: {video_path}")
                    continue

                # Create the capture outside the try so that `finally` never
                # touches an unbound name.
                cap = cv2.VideoCapture(video_path)
                try:
                    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    fps = cap.get(cv2.CAP_PROP_FPS)
                finally:
                    cap.release()

                # All three are used as divisors below; an unreadable video
                # reports zeros and must not abort the whole task.
                if not (w > 0 and h > 0 and fps > 0):
                    print(f"跳过 {vid} - 视频元数据无效 (w={w}, h={h}, fps={fps})")
                    continue

                pkl_path = os.path.join(episode_path, "low_dim_obs.pkl")
                if not os.path.exists(pkl_path):
                    print(f"跳过 {vid} - pkl文件不存在")
                    continue
                if num_frames is None:
                    # NOTE: pickle.load executes arbitrary code; trusted data only.
                    with open(pkl_path, "rb") as f:
                        num_frames = len(pickle.load(f)._observations)
                if demo is None:
                    demo = get_stored_demo(os.path.dirname(episode_path), int(episode_name[7:]))

                timestamps = []
                projections = []
                valid_frames = 0
                for obs_number in range(num_frames):
                    obs, pcd, dyn_cam_info = _preprocess_inputs(demo[obs_number], CAMERAS)
                    pc, img_feat = rvt_utils.get_pc_img_feat(obs, pcd)
                    pc, img_feat = rvt_utils.move_pc_in_bound(
                        pc, img_feat, SCENE_BOUNDS, no_op=False
                    )

                    # Gripper position in the scene cube's local frame.
                    action_gripper_pose = torch.tensor(demo[obs_number].gripper_pose).unsqueeze(0)
                    wpt = [x[:3] for x in action_gripper_pose[:, 0:3]]
                    wpt_local = []
                    for _pc, _wpt in zip(pc, wpt):
                        local_wpt, _ = mvt_utils.place_pc_in_cube(
                            _pc, _wpt, with_mean_or_bounds=False, scene_bounds=SCENE_BOUNDS
                        )
                        wpt_local.append(local_wpt.unsqueeze(0))
                    wpt_local = torch.cat(wpt_local, axis=0)

                    if st == 'st2':
                        wpt_local, _ = mvt_utils.trans_pc(wpt_local, loc=wpt_local, sca=4)

                    wpt_local = ensure_float32(wpt_local).to(renderer.device)  # [1, 3]
                    # Projection result is (bs, 1, num_img, 2); squeeze -> (num_img, 2).
                    wpt_img = renderer.get_pt_loc_on_img(
                        wpt_local.unsqueeze(1), fix_cam=True, dyn_cam_info=dyn_cam_info
                    ).squeeze()
                    # NOTE(review): index 3 is presumably the dynamic camera
                    # view - confirm against the renderer's view order.
                    uvs = wpt_img[3].unsqueeze(0)

                    if uvs.dim() != 2 or uvs.shape != torch.Size([1, 2]):
                        print(f"帧 {obs_number} 异常输出形状: {uvs.shape}")
                        continue
                    # NaN also fails the range test below, so check it first.
                    if torch.isnan(uvs).any():
                        print(f"帧 {obs_number} 投影结果包含NaN值")
                        continue

                    # Clamp into the image, then normalize to [0, 1].
                    u = float(torch.clamp(uvs[0, 0], 0, w - 1))
                    v = float(torch.clamp(uvs[0, 1], 0, h - 1))
                    u_norm = u / w
                    v_norm = v / h

                    if not (0 <= u_norm <= 1) or not (0 <= v_norm <= 1):
                        print(f"帧 {obs_number} 坐标超出范围: ({u_norm:.3f}, {v_norm:.3f})")
                        continue

                    timestamps.append(obs_number / fps)
                    projections.append([u_norm, v_norm])
                    valid_frames += 1

                if valid_frames > 0:
                    success = save_gaze_csv(vid, timestamps, projections, task_output_dir)
                    if not success:
                        print(f"保存失败: {vid}")

        return (task_dir, True, None)
    except Exception as e:
        return (task_dir, False, str(e))
    
# 在进程初始化时预加载所有PKL文件
def load_task_data(task_dir, data_root, dst_gaze):
    """
    Read the pickled low-dimensional observations of every episode of a task.

    Episodes are looked up in ``<data_root>/<task_dir>/all_variations/episodes``
    and visited in lexicographic order. Hidden entries (names starting with
    ``.``), non-directories, episodes without ``low_dim_obs.pkl`` and episodes
    whose pickle cannot be loaded are skipped; the latter two print a message.

    :param task_dir: Task directory name, relative to ``data_root``
    :param data_root: Dataset root that contains the task directories
    :param dst_gaze: Root directory of the rendered videos; only used to build
        each episode's ``video_path``
    :return: ``{'task_dir': task_dir, 'episodes': [...]}`` where every episode
        is a dict with the keys ``name``, ``data`` and ``video_path``
    """
    episodes_root = os.path.join(data_root, task_dir, 'all_variations', 'episodes')
    loaded = []

    for entry in sorted(os.listdir(episodes_root)):
        entry_path = os.path.join(episodes_root, entry)
        # Ignore system files such as .DS_Store and anything that is not a folder.
        if entry.startswith('.') or not os.path.isdir(entry_path):
            continue

        pkl_path = os.path.join(entry_path, "low_dim_obs.pkl")
        if not os.path.exists(pkl_path):
            print(f"跳过 {entry} - pkl文件不存在: {pkl_path}")
            continue

        try:
            # NOTE(review): pickle.load can execute arbitrary code - only point
            # this at datasets you trust.
            with open(pkl_path, "rb") as handle:
                payload = pickle.load(handle)
            record = {
                'name': entry,
                'data': payload,
                # entry[7:] drops the first seven characters of the folder name
                # (the "episode" prefix in names such as "episode12").
                'video_path': os.path.join(dst_gaze, f"{task_dir}/ep_{entry[7:]}_st1.mp4"),
            }
        except Exception as err:
            print(f"加载 {pkl_path} 失败: {str(err)}")
            continue

        loaded.append(record)

    return {'task_dir': task_dir, 'episodes': loaded}

def _run_tasks_for_gpu(task_list, data_root, output_dir, dst_gaze, gpu_id, renderer):
    """
    Run every task of ``task_list`` sequentially on a single GPU.

    :param task_list: Iterable of task directory names
    :param data_root: Dataset root that contains the task directories
    :param output_dir: Forwarded unchanged to ``process_single_task``
    :param dst_gaze: Root directory of the rendered videos
    :param gpu_id: Physical index of the GPU this process works on
    :param renderer: Renderer instance created by the caller for this GPU
    :return: None. Failures are printed, never raised, so one broken task does
        not stop the remaining ones.
    """
    import torch

    # The caller builds ``renderer`` on ``cuda:{gpu_id}`` before this function
    # runs, so the process must keep the global device numbering. Changing
    # CUDA_VISIBLE_DEVICES here would either be ignored (CUDA already
    # initialised) or make ``cuda:{gpu_id}`` an invalid ordinal, and selecting
    # device 0 would point every worker's current device at the first GPU.
    torch.cuda.set_device(gpu_id)

    for task_dir in task_list:
        try:
            print(f"Processing {task_dir} on GPU {gpu_id}")
            # Load per task and let the data go out of scope afterwards: each
            # task is visited once, so keeping earlier tasks cached would only
            # pin their pickles in memory.
            task_data = load_task_data(task_dir, data_root, dst_gaze)
            result = process_single_task(task_dir, data_root, output_dir, dst_gaze, renderer, task_data)
        except Exception as e:
            print(f"Error processing {task_dir} on GPU {gpu_id}: {str(e)}")
            continue

        # process_single_task reports failures through a (task_dir, ok, error)
        # tuple instead of raising; surface them rather than dropping them.
        if isinstance(result, tuple) and len(result) == 3 and not result[1]:
            print(f"Error processing {task_dir} on GPU {gpu_id}: {result[2]}")

# GPU worker process: one per device, consumes task chunks from the shared queue
def worker(gpu_id, queue, data_root, output_dir, dst_gaze):
    """
    Entry point of one GPU worker process.

    Builds a 224x224 three-view ``BoxRenderer`` with depth on ``cuda:{gpu_id}``
    and then consumes task chunks from ``queue`` until it stays empty.

    :param gpu_id: Index of the GPU this worker owns
    :param queue: Multiprocessing queue whose items are lists of task
        directory names
    :param data_root: Dataset root that contains the task directories
    :param output_dir: Forwarded to the per-task processing
    :param dst_gaze: Root directory of the rendered videos
    :return: None
    """
    # Imported locally because the ``queue`` parameter shadows the module name.
    from queue import Empty

    # All chunks are enqueued before the workers start, so a ``get`` that
    # times out means the queue is drained. The timeout only has to cover the
    # short delay until the producer's feeder thread has flushed its buffer.
    idle_timeout = 5.0

    renderer = BoxRenderer(
        device=f"cuda:{gpu_id}",
        img_size=(224, 224),
        three_views=True,
        with_depth=True,
    )

    while True:
        # ``empty()`` followed by a blocking ``get()`` is a check-then-act
        # race: another worker can take the last chunk in between and this
        # process would then block forever. A timed ``get`` avoids that.
        try:
            tasks = queue.get(timeout=idle_timeout)
        except Empty:
            break
        _run_tasks_for_gpu(tasks, data_root, output_dir, dst_gaze, gpu_id, renderer)

def batch_process_poses_parallel(data_root, output_dir, dst_gaze, chunk_size = 8):
    """
    Process all task directories under ``data_root`` with one worker per GPU.

    The sorted task directories are split into chunks of ``chunk_size`` and
    put on a shared queue; one spawned process per visible CUDA device then
    consumes chunks until the queue is drained.

    :param data_root: Dataset root; every sub-directory is treated as a task
    :param output_dir: Forwarded to the workers
    :param dst_gaze: Root directory of the rendered videos
    :param chunk_size: Number of tasks handed to a worker at a time
    :return: None
    """
    task_dirs = [d for d in sorted(os.listdir(data_root))
                 if os.path.isdir(os.path.join(data_root, d))]

    num_gpus = torch.cuda.device_count()
    print(f"Available GPUs: {num_gpus}")

    # Without a consumer the queued chunks would never be processed; say so
    # instead of returning as if the run had succeeded.
    if num_gpus == 0:
        print("No CUDA device available; no task was processed.")
        return
    if not task_dirs:
        print(f"No task directories found in {data_root}")
        return

    task_chunks = [task_dirs[i:i+chunk_size]
                   for i in range(0, len(task_dirs), chunk_size)]

    # CUDA cannot be re-initialised in forked children, hence "spawn".
    if mp.get_start_method() != "spawn":
        mp.set_start_method("spawn", force=True)
    ctx = mp.get_context("spawn")
    task_queue = ctx.Queue()

    # Fill the queue completely before any worker starts.
    for chunk in task_chunks:
        task_queue.put(chunk)

    processes = []
    for gpu_id in range(num_gpus):
        p = ctx.Process(target=worker, args=(gpu_id, task_queue, data_root, output_dir, dst_gaze))
        p.start()
        processes.append(p)

    for gpu_id, p in enumerate(processes):
        p.join()
        # A crashed worker would otherwise go unnoticed.
        if p.exitcode != 0:
            print(f"Worker for GPU {gpu_id} exited with code {p.exitcode}")


def _hms_to_seconds(stamp):
    """Convert an ``H:M:S`` string with integer fields into whole seconds."""
    hours, minutes, seconds = stamp.split(':')
    return int(hours) * 3600 + int(minutes) * 60 + int(seconds)


def clip_videos(source_path, save_path, untrack_csv, clip_len=5):
    """
    Cut every ``.mp4`` below ``source_path`` into consecutive short clips.

    Clips are written to ``<save_path>/<relative dir>/<video stem>/`` and named
    ``<dir name>_<video stem>_t<start>_t<end>.mp4``. The last clip of a video
    may be shorter than ``clip_len``. A video that fails is reported with a
    message and skipped.

    :param source_path: Directory that is walked recursively for long videos
    :param save_path: Root directory of the generated clips (created if needed)
    :param untrack_csv: Optional CSV used to remove untracked intervals. Each
        row is read as ``video stem, start H:M:S, end H:M:S, ..., integer``;
        any clip that overlaps one of a video's intervals is dropped. Pass
        ``None`` to keep every clip.
    :param clip_len: Clip length in whole seconds (default 5)
    :return: None
    """
    os.makedirs(save_path, exist_ok=True)

    # video stem -> list of [start_sec, end_sec, last_column]
    untracked = {}
    if untrack_csv is not None:
        with open(untrack_csv, 'r', newline='') as f:
            for line in csv.reader(f):
                if not line:
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                interval = [_hms_to_seconds(line[1]), _hms_to_seconds(line[2]), int(line[-1])]
                untracked.setdefault(line[0], []).append(interval)

    for root, _, files in os.walk(source_path):
        for item in files:
            if not item.endswith('.mp4'):
                continue

            video_path = os.path.join(root, item)
            vid = os.path.splitext(item)[0]
            rel_path = os.path.relpath(root, source_path)
            output_dir = os.path.join(save_path, rel_path, vid)
            os.makedirs(output_dir, exist_ok=True)
            blocked = untracked.get(vid, ())
            try:
                with VideoFileClip(video_path) as video:
                    duration = video.duration
                    # Collect every window that does not touch an untracked
                    # interval; ``end`` is capped at the video duration.
                    segments = []
                    for start in range(0, int(duration), clip_len):
                        end = min(start + clip_len, duration)
                        if any(not (end < lo or start > hi) for lo, hi, _tag in blocked):
                            continue
                        segments.append((start, end))
                    for start, end in tqdm(segments, desc=f"Processing {vid}"):
                        output_path = os.path.join(
                            output_dir,
                            f"{os.path.basename(rel_path)}_{vid}_t{start}_t{end}.mp4"
                        )
                        with video.subclip(start, end) as clip:
                            clip.write_videofile(output_path)
            except Exception as e:
                # Best effort: one unreadable video must not stop the batch.
                print(f"Failed on {video_path}: {str(e)}")

def _classify_gaze_frames(lines, frames_length, fps, ref_width, ref_height,
                          saccade_threshold, desc):
    """Return one ``[frame, x, y, gaze_type]`` row per video frame.

    ``lines`` are the CSV data rows (timestamp, x, y) in file order; each
    frame takes the row whose timestamp is closest to ``frame / fps``.
    """
    gaze_loc = []
    j = 0
    last = len(lines) - 1
    for i in tqdm(range(frames_length), leave=False, desc=desc):
        time_stamp = i / fps
        # Advance to the first row at or after the frame time ...
        while j < last and float(lines[j][0]) < time_stamp:
            j += 1
        # ... then fall back to the previous row when it is strictly closer.
        if j > 0 and abs(float(lines[j-1][0]) - time_stamp) < abs(float(lines[j][0]) - time_stamp):
            row = lines[j-1]
        else:
            row = lines[j]

        try:
            x = float(row[1])
            y = float(row[2])
        except (IndexError, ValueError):
            x, y = 0.5, 0.5  # unreadable row: fall back to the image centre

        # Movement relative to the previous frame, scaled to the reference
        # resolution and compared against the threshold.
        if i == 0:
            gaze_type = 0
        else:
            dx = (x - gaze_loc[-1][1]) * ref_width
            dy = (y - gaze_loc[-1][2]) * ref_height
            movement = np.sqrt(dx**2 + dy**2)
            gaze_type = 0 if movement <= saccade_threshold else 1

        # Out-of-range coordinates override the label and are clipped.
        if not (0 <= x <= 1 and 0 <= y <= 1):
            gaze_type = 2
            x = np.clip(x, 0, 1)
            y = np.clip(y, 0, 1)

        gaze_loc.append([i, x, y, gaze_type])
    return gaze_loc


def get_rlbench_frame_label(data_path, save_path, ref_width=1088, ref_height=1080,
                            saccade_threshold=40):
    """
    Turn the per-video gaze CSVs into per-frame label CSVs.

    For every ``<data_path>/gaze/<task>/<vid>.csv`` the matching video
    ``<data_path>/full_scale.gaze/<task>/<vid>.mp4`` is opened to obtain its
    frame rate and frame count, and ``<save_path>/<task>/<vid>_frame_label.csv``
    is written with the columns ``frame, x, y, gaze_type``.

    ``gaze_type`` is 0 when the point moved at most ``saccade_threshold``
    since the previous frame, 1 when it moved further, and 2 when the
    coordinates were outside ``[0, 1]`` (they are then clipped). Type 3 is
    still counted in the summary, but no frame is currently labelled with it.

    A file that fails is reported with a message and skipped.

    :param data_path: Directory that holds ``gaze`` and ``full_scale.gaze``
    :param save_path: Root directory of the generated label files
    :param ref_width: Width used to scale normalised x movement (default 1088)
    :param ref_height: Height used to scale normalised y movement (default 1080)
    :param saccade_threshold: Movement above which a frame is labelled 1
    :return: None; the overall rates are printed
    """
    all_frames_num = 0
    all_saccade_num = 0
    all_trimmed_num = 0
    all_untracked_num = 0

    gaze_dir = os.path.join(data_path, 'gaze')  # source of the CSV files
    video_dir = os.path.join(data_path, 'full_scale.gaze')  # source of the MP4 files
    for task_dir in os.listdir(gaze_dir):
        task_csv_dir = os.path.join(gaze_dir, task_dir)
        if not os.path.isdir(task_csv_dir):
            continue

        for csv_file in os.listdir(task_csv_dir):
            if not csv_file.endswith('.csv'):
                continue
            vid = os.path.splitext(csv_file)[0]
            mp4_path = os.path.join(video_dir, task_dir, f'{vid}.mp4')
            try:
                with open(os.path.join(task_csv_dir, csv_file), 'r', newline='') as f:
                    reader = csv.reader(f)
                    next(reader, None)  # skip the header row
                    # A usable row needs at least timestamp, x and y.
                    lines = [line for line in reader if len(line) >= 3]

                if not lines:
                    print(f"警告: {csv_file} 无有效数据")
                    continue

                # Only metadata is needed; close the container right away so
                # file handles do not pile up over thousands of videos.
                with av.open(mp4_path) as container:
                    stream = container.streams.video[0]
                    fps = float(stream.average_rate)
                    frames_length = stream.frames

                gaze_loc = _classify_gaze_frames(
                    lines, frames_length, fps, ref_width, ref_height,
                    saccade_threshold, desc=f"Processing {vid}",
                )

                all_frames_num += len(gaze_loc)
                for item in gaze_loc:
                    if item[3] == 1:
                        all_saccade_num += 1
                    elif item[3] == 2:
                        all_trimmed_num += 1
                    elif item[3] == 3:
                        all_untracked_num += 1

                os.makedirs(os.path.join(save_path, task_dir), exist_ok=True)
                with open(os.path.join(save_path, task_dir, f'{vid}_frame_label.csv'), 'w', newline='') as f:
                    csv_writer = csv.writer(f)
                    csv_writer.writerow(['frame', 'x', 'y', 'gaze_type'])
                    csv_writer.writerows(gaze_loc)

            except Exception as e:
                print(f"处理 {task_dir}/{csv_file} 失败: {str(e)}")
                continue

    # Nothing processed (empty input or every file failed): the rates below
    # would divide by zero.
    if all_frames_num == 0:
        print('No frames were labelled; rates are undefined.')
        return

    print('All saccade rate:', all_saccade_num / all_frames_num,
          'All trimmed rate:', all_trimmed_num / all_frames_num,
          'All untracked rate:', all_untracked_num / all_frames_num)

import os
from pathlib import Path
def generate_train_csv(root_dir, output_file="train_rlbench_gaze.csv"):
    """
    Walk a clip directory and write the list of training videos.

    Expected layout: ``<root_dir>/<task>/ep_<n>.../<task>_ep_<n>_st*_t*_t*.mp4``
    with ``0 <= n <= 99``. Every match is written on its own line, sorted, as a
    path relative to the PARENT of ``root_dir`` (so each line starts with the
    name of ``root_dir`` itself, e.g. ``clips.gaze/...``).

    :param root_dir: Root directory to walk (e.g. the ``clips.gaze`` directory)
    :param output_file: Path of the list file; missing parent directories are
        created
    :return: None
    """
    root_dir = Path(root_dir)
    video_paths = []

    for task_dir in root_dir.iterdir():
        if not task_dir.is_dir():
            continue

        # Episode folders are named ep_<n> followed by an optional suffix.
        for ep_dir in task_dir.iterdir():
            if not ep_dir.is_dir() or not ep_dir.name.startswith("ep_"):
                continue

            try:
                ep_num = int(ep_dir.name.split("_")[1])
            except (IndexError, ValueError):
                continue
            if not 0 <= ep_num <= 99:
                continue

            for video_file in ep_dir.glob(f"{task_dir.name}_ep_{ep_num}_st*_t*_t*.mp4"):
                # Check the name from the right (e.g.
                # close_jar_ep_0_st1_t0_t5), because the task name itself may
                # contain underscores.
                parts = video_file.stem.split("_")
                if len(parts) >= 5 and parts[-5] == "ep" and parts[-2].startswith("t") and parts[-1].startswith("t"):
                    video_paths.append(os.path.relpath(video_file, start=root_dir.parent))

    # The default call site writes into a sub-directory; make sure it exists.
    out_parent = os.path.dirname(output_file)
    if out_parent:
        os.makedirs(out_parent, exist_ok=True)

    with open(output_file, "w") as f:
        f.writelines(f"{path}\n" for path in sorted(video_paths))

    print(f"生成完成！共找到 {len(video_paths)} 个视频文件，已写入 {output_file}")

def visualize_projection(video_path, csv_path, frame_index=0, output_img_path="projection_visualization.png"):
    """
    Draw the projected point of one frame onto that frame and save the image.

    The point is taken from CSV row number ``frame_index`` (columns ``x_norm``
    and ``y_norm``), i.e. the rows are assumed to line up one-to-one with the
    video frames.

    :param video_path: Path of the video file
    :param csv_path: Path of the generated CSV file
    :param frame_index: Index of the frame to show (default: frame 0)
    :param output_img_path: Where the annotated image is written
    :raises ValueError: If the video cannot be opened, the frame cannot be
        read, or the CSV has no row ``frame_index``
    """
    # Grab the frame and release the capture immediately, so the handle is
    # freed on every path - including failures while parsing the CSV.
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("无法打开视频文件")
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        raise ValueError("无法读取指定帧")

    x_coords = []
    y_coords = []
    with open(csv_path, 'r', newline='') as f:
        for row in csv.DictReader(f):
            x_coords.append(float(row['x_norm']))
            y_coords.append(float(row['y_norm']))

    if frame_index >= len(x_coords):
        raise ValueError("帧索引超出CSV数据范围")

    x_norm = x_coords[frame_index]
    y_norm = y_coords[frame_index]

    # Normalised -> pixel coordinates.
    height, width = frame.shape[:2]
    x_pixel = int(x_norm * width)
    y_pixel = int(y_norm * height)

    print(f"归一化坐标: ({x_norm}, {y_norm})")
    print(f"像素坐标: ({x_pixel}, {y_pixel})")
    print(f"视频帧尺寸: {width}x{height}")

    marked_frame = cv2.circle(frame.copy(),
                              (x_pixel, y_pixel),
                              radius=3,
                              color=(0, 255, 0),  # BGR: green
                              thickness=2)

    # Centre the label above the point, clamped so it stays inside the frame.
    text = f"({x_norm:.2f}, {y_norm:.2f})"
    text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
    text_x = max(10, min(x_pixel - text_size[0]//2, width - text_size[0] - 10))
    text_y = max(30, min(y_pixel - 20, height - 10))
    cv2.putText(marked_frame, text, (text_x, text_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Frame number, fixed in the top-left corner.
    cv2.putText(marked_frame, f"Frame: {frame_index}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    # Saved with OpenCV; converting with cv2.COLOR_BGR2RGB and using
    # matplotlib's imshow/savefig would be an alternative.
    cv2.imwrite(output_img_path, marked_frame)
    print(f"可视化结果已保存到: {output_img_path}")

def visualize_projection_video(video_path, csv_path, output_video_path="output_tracking.mp4", fps=None):
    """
    Render a copy of the video with the projected point and its recent
    trajectory drawn on every frame.

    CSV row ``k`` (columns ``x_norm`` and ``y_norm``) is drawn on frame ``k``;
    processing stops after ``min(frame count, CSV rows)`` frames.

    :param video_path: Path of the input video
    :param csv_path: Path of the CSV file with the coordinates
    :param output_video_path: Path of the output video
    :param fps: Output frame rate (defaults to the input video's rate)
    :raises ValueError: If the input video cannot be opened
    """
    from collections import deque

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError("无法打开视频文件")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        input_fps = cap.get(cv2.CAP_PROP_FPS)

        # Parse the CSV before creating the writer, so a malformed CSV does
        # not leave an empty output file behind.
        x_coords = []
        y_coords = []
        with open(csv_path, 'r', newline='') as f:
            for row in csv.DictReader(f):
                x_coords.append(float(row['x_norm']))
                y_coords.append(float(row['y_norm']))

        min_length = min(frame_count, len(x_coords))
        print(f"视频总帧数: {frame_count} | CSV数据点数: {len(x_coords)}")
        print(f"将处理前 {min_length} 帧")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(
            output_video_path,
            fourcc,
            fps if fps else input_fps,
            (width, height),
            isColor=True
        )

        # Only the 20 most recent points are kept; deque drops the oldest.
        trajectory = deque(maxlen=20)

        for frame_idx in range(min_length):
            ret, frame = cap.read()
            if not ret:
                break

            x_norm = x_coords[frame_idx]
            y_norm = y_coords[frame_idx]
            x_pixel = int(x_norm * width)
            y_pixel = int(y_norm * height)
            trajectory.append((x_pixel, y_pixel))

            # Everything is drawn on this one image. Drawing each segment on
            # its own throwaway copy would lose the trajectory entirely.
            marked_frame = frame
            points = list(trajectory)
            for i in range(1, len(points)):
                alpha = i / len(points)
                # BGR fade: oldest segments red, newest segments green.
                color = (0, int(255 * alpha), int(255 * (1-alpha)))
                cv2.line(marked_frame, points[i-1], points[i], color, 2)

            # Current point in green.
            cv2.circle(marked_frame, (x_pixel, y_pixel), radius=3, color=(0, 255, 0), thickness=2)

            # Frame number, fixed in the top-left corner.
            info_text = f"Frame: {frame_idx}"
            cv2.putText(marked_frame, info_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            # Coordinate label above the point, clamped to stay in the frame.
            text = f"Pos: ({x_norm:.2f}, {y_norm:.2f})"
            text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            text_x = max(10, min(x_pixel - text_size[0]//2, width - text_size[0] - 10))
            text_y = max(30, min(y_pixel - 20, height - 10))
            cv2.putText(marked_frame, text, (text_x, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            out.write(marked_frame)

            if frame_idx % 100 == 0:
                print(f"已处理 {frame_idx}/{min_length} 帧")
    finally:
        # Release on every path. No GUI window is opened here, so there is
        # nothing to destroy (NOTE(review): GUI calls are believed to fail on
        # headless OpenCV builds - confirm against your install).
        cap.release()
        if out is not None:
            out.release()

    print(f"视频生成完成，保存至: {output_video_path}")


def main():
    """
    Script entry point for the RLBench gaze data preparation.

    The pipeline steps are toggled by hand: the steps that are commented out
    below (rendering/conversion, clipping, pose projection, visualisation,
    frame labelling) are kept for reference, and only the training list
    generation is currently active.
    """
    RENDER = True
    path_to_rlbench = '/fs-computility/efm/shared/datasets/Official_Manipulation_Data/sim/colosseum/rlbench/future'  

    save_path = f'{path_to_rlbench}/clips.gaze'
    untracked_csv = None

    # pre-save rendered dynamic images/videos
    root_train = "/fs-computility/efm/shared/datasets/Official_Manipulation_Data/sim/colosseum/rlbench/train"
    dst_gaze = f'{path_to_rlbench}/full_scale.gaze'
    # os.makedirs(dst_gaze, exist_ok=True)# DONE
    # if RENDER:
    #     batch_convert_render(root_train, dst_gaze)
    # else:
    #     batch_convert(root_train, dst_gaze) # DONE

    # clip_videos(source_path=dst_gaze, save_path=save_path, untrack_csv=untracked_csv) # DONE

    # output_dir = f'{path_to_rlbench}/gaze'  # used to store the CSV files
    # os.makedirs(output_dir, exist_ok=True) # DONE
    # if RENDER:
    #     batch_process_poses_parallel(root_train, output_dir, dst_gaze)
    # else:
    #     batch_process_poses(root_train, output_dir, dst_gaze) # DONE
    # # Visualisation
    # visualize_task = "insert_onto_square_peg"
    # visualize_epoch = "0"
    # visualize_root = "/fs-computility/efm/shared/datasets/Official_Manipulation_Data/sim/colosseum/rlbench/future"
    # if RENDER:
    #     video_path = f"{visualize_root}/full_scale.gaze/{visualize_task}/ep_{visualize_epoch}_st1.mp4"
    #     csv_path = f"{visualize_root}/gaze/{visualize_task}/ep_{visualize_epoch}_st1.csv"
    #     visualize_projection(video_path, csv_path, frame_index=158, output_img_path=f"{visualize_task}_ep{visualize_epoch}_st1.png")
    #     visualize_projection_video(video_path, csv_path,
    #         output_video_path=f"{visualize_task}_ep{visualize_epoch}_st1.mp4", fps=30  # optional argument
    #     )
    #     video_path = f"{visualize_root}/full_scale.gaze/{visualize_task}/ep_{visualize_epoch}_st2.mp4"
    #     csv_path = f"{visualize_root}/gaze/{visualize_task}/ep_{visualize_epoch}_st2.csv"
    #     visualize_projection(video_path, csv_path, frame_index=158, output_img_path=f"{visualize_task}_ep{visualize_epoch}_st2.png")
    #     visualize_projection_video(video_path, csv_path,
    #         output_video_path=f"{visualize_task}_ep{visualize_epoch}_st2.mp4", fps=30  # optional argument
    #     )
    # else:
    #     video_path = f"{visualize_root}/full_scale.gaze/{visualize_task}/ep_{visualize_epoch}.mp4"
    #     csv_path = f"{visualize_root}/gaze/{visualize_task}/ep_{visualize_epoch}.csv"
    #     visualize_projection(video_path, csv_path, frame_index=158, output_img_path=f"{visualize_task}_ep{visualize_epoch}.png")
    #     visualize_projection_video(video_path, csv_path,
    #         output_video_path=f"{visualize_task}_ep{visualize_epoch}.mp4", fps=30  # optional argument
    #     )

    # get_rlbench_frame_label(data_path=path_to_rlbench, save_path=f'{path_to_rlbench}/gaze_frame_label')
    # Active step: list every clip below clips.gaze in the training CSV.
    generate_train_csv("/fs-computility/efm/shared/datasets/Official_Manipulation_Data/sim/colosseum/rlbench/future/clips.gaze", output_file="data/train_rlbench_gaze.csv")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


