import json
import numpy as np
import h5py, argparse, pdb
import matplotlib.pyplot as plt
import open3d as o3d
import torch as th
import fpsample
import time
from multiprocessing import Pool
from functools import partial
import sys
import time
import copy
import os
import datetime
from glob import glob
import shutil

from scipy.spatial.transform import Rotation as R
import plotly.graph_objects as go
from typing import Optional, Tuple
from prep_data_utils import pose2mat, pose_transform, mat2pose,\
 plot_frame_from_matrix, plot_pcd_with_matplotlib, write_to_hdf5_per_demo, color_pcd_vis, pcd_vis

# Show numpy floats with three decimals when printing.
np.set_printoptions(precision=3)

# Module-level crop box for the example mobile manipulation data.
# NOTE(review): get_color_pcd_info defines its own local ranges with the same
# names, so these look like fallbacks/documentation only -- confirm.
x_range = [0.0, 2.3]
y_range = [-1.5, 1.5]
z_range = [0.7, 1.5]  # the lower bound removes the table leg

# Camera intrinsic matrix used by depth_to_pcd to back-project pixels
# (for the manipulation only data and the new mobile manipulation data).
K = np.array(
    [
        [174.08, 0.000, 128.000],
        [0.000, 174.08, 128.000],
        [0.000, 0.000, 1.000],
    ]
)

# Height thresholds used to split a point cloud into table / cup / other.
table_mask_height = 1.08
cup_mask_height = 1.18


def process_seg_mask_nav_manip(demo_data):
    """Build per-step boolean masks for the tidy table task.

    The demo is assumed to follow the sequence
    "mobile mp, arm mp, arm replay, mobile mp, arm mp, arm replay"; rows 1 and 3
    of ``left_mp_ranges`` are used as the two arm motion-planning ranges and
    row 2 as the start of the second mobile phase
    (NOTE(review): row semantics inferred from the indexing -- confirm).

    Only ``actions`` and ``left_mp_ranges`` are read, so demos that lack other
    bookkeeping keys can still be processed.

    Args:
        demo_data: mapping with ``'actions'`` (array-like exposing ``.shape``)
            and ``'left_mp_ranges'`` (rows of ``[start, end]`` step indices).

    Returns:
        Tuple ``(replay_seg_mask, manip_seg_mask)`` of bool arrays with shape
        ``(num_steps, 1)`` where ``num_steps = len(actions) - 1``.
    """
    num_steps = demo_data['actions'].shape[0] - 1
    left_mp_ranges = np.array(demo_data["left_mp_ranges"])

    first_arm_mp_start = left_mp_ranges[1, 0]
    first_arm_mp_end = left_mp_ranges[1, 1]
    second_mobile_start = left_mp_ranges[2, 0]
    second_arm_mp_start = left_mp_ranges[3, 0]
    second_arm_mp_end = left_mp_ranges[3, 1]

    # Replay: from the end of each arm motion plan until the next phase starts
    # (or until the end of the demo for the last one).
    replay_seg_mask = np.zeros((num_steps, 1), dtype=bool)
    replay_seg_mask[first_arm_mp_end:second_mobile_start, 0] = True
    replay_seg_mask[second_arm_mp_end:, 0] = True

    # Manipulation: arm motion plan plus the replay that follows it.
    manip_seg_mask = np.zeros((num_steps, 1), dtype=bool)
    manip_seg_mask[first_arm_mp_start:second_mobile_start, 0] = True
    manip_seg_mask[second_arm_mp_start:, 0] = True

    return replay_seg_mask, manip_seg_mask


def process_seg_mask_manip_only(demo_data):
    """Return a ``(num_steps, 1)`` bool mask for manipulation-only demos.

    Steps before the end of the first ``left_mp_ranges`` row stay ``False``;
    every step from that end index onward is ``True``.
    ``num_steps`` is ``len(actions) - 1``.
    """
    step_count = demo_data['actions'].shape[0] - 1
    mp_ranges = np.array(demo_data["left_mp_ranges"])
    mp_end = mp_ranges[0, 1]

    # The mask starts all-False, so only the tail after the motion plan
    # needs to be switched on.
    mask = np.zeros((step_count, 1), dtype=bool)
    mask[mp_end:, 0] = True
    return mask


def depth_to_pcd(
        depth_all, pose_all, base_link_pose_all, 
        step_index=0, 
        ref_frame='world', 
        max_depth = 2,
        intrinsics=None,
        ):
    """Back-project one depth frame into an organized 3D point map.

    Args:
        depth_all: per-step depth images; ``depth_all[step_index]`` must be 2D
            after ``squeeze()``.
        pose_all: per-step camera poses; the first three entries are used as
            the position and the remainder as a quaternion in scipy's
            ``[x, y, z, w]`` order.
        base_link_pose_all: per-step robot base poses (position followed by
            quaternion, passed to ``pose2mat``); only read when
            ``ref_frame == 'robot'``.
        step_index: which step to convert.
        ref_frame: ``'world'`` or ``'robot'`` -- frame of the returned points.
        max_depth: depth values above this are set to 0 before projection.
        intrinsics: optional 3x3 camera intrinsic matrix; defaults to the
            module-level ``K``.

    Returns:
        Array of shape ``(h, w, 3)`` holding one 3D point per pixel.

    Raises:
        ValueError: if ``ref_frame`` is neither ``'world'`` nor ``'robot'``.
    """
    if ref_frame not in ('world', 'robot'):
        raise ValueError("ref_frame must be 'world' or 'robot'")

    # Copy before filtering: squeeze() on the caller's array returns a view,
    # so zeroing far pixels in place would silently modify the input depth.
    depth = np.array(depth_all[step_index]).squeeze()
    pose = pose_all[step_index]
    cam_matrix = K if intrinsics is None else np.asarray(intrinsics, dtype=float)

    # Homogeneous camera pose built from position + quaternion.
    # (The original code calls this "world_to_cam_tf"; it is applied to
    # camera-frame points to obtain world-frame points.)
    rot = R.from_quat(pose[3:])  # scipy expects [x, y, z, w]
    rot_add = R.from_euler('x', np.pi).as_matrix()  # handle the cam_to_img transformation
    world_to_cam_tf = np.eye(4)
    world_to_cam_tf[:3, :3] = rot.as_matrix() @ rot_add
    world_to_cam_tf[:3, 3] = pose[:3]

    # Drop measurements beyond max_depth; they collapse onto the camera origin.
    depth[depth > max_depth] = 0
    h, w = depth.shape
    assert depth.min() >= 0

    # Homogeneous pixel coordinates (u = column, v = row), one per pixel.
    v, u = np.meshgrid(np.arange(h), np.arange(w), indexing="ij", sparse=False)
    uv = np.dstack((u, v, np.ones_like(u)))

    # Camera-frame points: depth * K^-1 [u, v, 1]^T, then made homogeneous.
    Kinv = np.linalg.inv(cam_matrix)
    pc = depth.reshape(-1, 1) * (uv.reshape(-1, 3) @ Kinv.T)
    pc = np.concatenate([pc, np.ones((h * w, 1))], axis=-1)  # shape (H*W, 4)

    # Row-vector form of p_world = T @ p_cam is p_cam^T @ T^T.
    pc = pc @ world_to_cam_tf.T
    if ref_frame == 'robot':
        base_link_pose = base_link_pose_all[step_index]
        world_to_robot_tf = pose2mat((th.from_numpy(base_link_pose[:3]), th.from_numpy(base_link_pose[3:]))).numpy()
        robot_to_world_tf = np.linalg.inv(world_to_robot_tf)
        pc = pc @ robot_to_world_tf.T
    return pc[:, :3].reshape(h, w, 3)


def transform_eef_to_world(base_link_pose, eef_pos, eef_quat):
    """Compose each end-effector pose with the robot base pose, step by step.

    Inputs are numpy arrays (per the original shape notes):
        base_link_pose: (traj_length, 7), position followed by quaternion
        eef_pos: (traj_length, 3)
        eef_quat: (traj_length, 4)

    Returns a numpy array whose rows are the transformed position followed by
    the transformed quaternion, as produced by ``pose_transform``.
    """
    base_t = th.from_numpy(base_link_pose)
    pos_t = th.from_numpy(eef_pos)
    quat_t = th.from_numpy(eef_quat)

    # Output buffers mirror the shape and dtype of the inputs.
    world_pos = np.zeros_like(eef_pos)
    world_quat = np.zeros_like(eef_quat)

    for step in range(pos_t.shape[0]):
        base_pos, base_quat = base_t[step, :3], base_t[step, 3:]
        new_pos, new_quat = pose_transform(base_pos, base_quat, pos_t[step], quat_t[step])
        world_pos[step] = new_pos.numpy()
        world_quat[step] = new_quat.numpy()

    # concatenate the eef pos and quat
    return np.concatenate([world_pos, world_quat], axis=1)


def fps_downsample(color_pcd, num_points_to_sample):
    """Resize a colored point cloud to exactly ``num_points_to_sample`` rows.

    Args:
        color_pcd: 2D array whose first three columns are color and whose
            remaining columns are the point coordinates.
        num_points_to_sample: number of rows in the returned array.

    Returns:
        Array with ``num_points_to_sample`` rows. Clouds that are too large are
        reduced with ``fpsample.bucket_fps_kdline_sampling`` on the coordinate
        columns; clouds that are too small are padded with rows drawn at
        random (with replacement) from the input, appended after the originals.

    Raises:
        ValueError: if the cloud is empty but points are requested -- there is
            nothing to pad from. (Previously this surfaced as a confusing
            ``NameError`` after a swallowed exception.)
    """
    num_points = color_pcd.shape[0]
    if num_points > num_points_to_sample:
        pc = color_pcd[:, 3:]
        color_img = color_pcd[:, :3]
        kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(pc, num_points_to_sample, h=5)
        color_pcd = np.concatenate(
            [color_img[kdline_fps_samples_idx], pc[kdline_fps_samples_idx]], axis=-1
        )
    else:
        pad_number_of_points = num_points_to_sample - num_points
        if num_points == 0 and pad_number_of_points > 0:
            raise ValueError(
                "the number of point is 0!! cannot pad an empty point cloud "
                "to {} points".format(num_points_to_sample)
            )
        # random sample padding points
        random_idx = np.random.choice(num_points, pad_number_of_points, replace=True)
        # concatenate the padding points
        color_pcd = np.concatenate([color_pcd, color_pcd[random_idx]], axis=0)
    assert color_pcd.shape[0] == num_points_to_sample, "color_pcd shape is not equal to num_points_to_sample"
    return color_pcd


def get_pcd_portion(color_pcd):
    """Count how many points fall into the table, cup and other height bands.

    Column 5 of ``color_pcd`` is compared against ``table_mask_height`` and
    ``cup_mask_height``. All comparisons are strict, so a point lying exactly
    on a threshold is not counted in any band.

    Returns a tuple ``(table_count, cup_count, other_count)``.
    """
    heights = color_pcd[:, 5]
    below_table = heights < table_mask_height
    in_cup_band = (heights > table_mask_height) & (heights < cup_mask_height)
    above_cup = heights > cup_mask_height
    return (
        int(np.count_nonzero(below_table)),
        int(np.count_nonzero(in_cup_band)),
        int(np.count_nonzero(above_cup)),
    )

def get_instance_index(seg_instance, obj_name='teacup_601'):
    """Look up the integer instance id that maps to ``obj_name``.

    ``seg_instance`` maps id keys to object names. The first key (in the
    mapping's iteration order) whose value equals ``obj_name`` is returned as
    an ``int``; ``None`` is returned when no value matches.
    """
    for instance_id, instance_name in seg_instance.items():
        if instance_name == obj_name:
            return int(instance_id)
    return None

def get_color_pcd_info(demo_data, stream_vis_sign=False, single_frame_vis_sign=False, table_heuristic=False, seg_heuristic=True, random_type='D0'):
    """Fuse the eye / left-wrist / right-wrist cameras into one colored point cloud per step.

    For every step the three depth images are back-projected into the robot
    frame, concatenated with their RGB values, cropped to a bounding box that
    depends on ``random_type`` and resampled to 4096 points.

    Args:
        demo_data: one demo group; must provide ``'obs'`` (camera rgb /
            depth_linear / seg_instance streams, camera poses and
            ``base_link_pose``) and ``'obs_info'`` (per-step JSON bytes).
        stream_vis_sign: stream every step into a single Open3D window.
        single_frame_vis_sign: call ``color_pcd_vis`` on every step.
        table_heuristic: resample table / non-table points separately (50/50),
            split at ``table_mask_height``.
        seg_heuristic: resample teacup / other points separately (30/70),
            using the instance segmentation masks.
        random_type: ``'D0'`` or ``'D1'``; selects the crop box.

    Returns:
        dict with
        ``'combined::color_point_cloud'`` -- array (num_steps, 4096, 6), columns
        are rgb in [0, 1] followed by xyz;
        ``'pcd_max'`` / ``'pcd_min'`` -- per-axis xyz extrema over the demo;
        ``'input_pcd_portion'`` -- (table, cup, other) counts before resampling
        (only filled when ``table_heuristic`` is set);
        ``'output_pcd_portion'`` -- the same counts after resampling.

    Raises:
        AssertionError: if both heuristics or both visualization flags are set.
        ValueError: if no crop box is defined for ``random_type``.
    """
    assert not (table_heuristic and seg_heuristic), "only one of table_heuristic and seg_heuristic can be True"
    assert not (stream_vis_sign and single_frame_vis_sign), "only one of stream_vis_sign and single_frame_vis_sign can be True"

    # Crop box (x, y, z ranges) per randomization level. The z range removes
    # the table leg and the ceiling light.
    crop_box_by_random_type = {
        'D0': ([0.0, 2.3], [-1.5, 1.5], [0.7, 1.5]),
        'D1': ([0.0, 4], [-3, 3], [0.7, 1.5]),
    }
    if random_type not in crop_box_by_random_type:
        # Fail loudly instead of dropping into a debugger, which would hang a
        # non-interactive (cluster) job.
        raise ValueError(
            "need to specify the pcd range for the random type {!r}; "
            "supported types are {}".format(random_type, sorted(crop_box_by_random_type))
        )
    x_range, y_range, z_range = crop_box_by_random_type[random_type]

    obs = demo_data["obs"]
    base_link_pose = np.array(obs['base_link_pose']) # (traj_length, 7)

    # get camera and eef information
    eye_rgb = np.array(obs['robot_r1::robot_r1:eyes:Camera:0::rgb'])[:, :, :, :3]
    eye_depth = np.array(obs['robot_r1::robot_r1:eyes:Camera:0::depth_linear'])[:, :, :, None]
    eye_pose = np.array(obs['robot_r1:eyes:Camera:0_pose']) # (traj_length, 7)

    left_cam_rgb = np.array(obs['robot_r1::robot_r1:left_eef_link:Camera:0::rgb'])[:, :, :, :3]
    left_cam_depth = np.array(obs['robot_r1::robot_r1:left_eef_link:Camera:0::depth_linear'])[:, :, :, None]
    left_cam_pose = np.array(obs['robot_r1:left_eef_link:Camera:0_pose']) # (traj_length, 7)

    right_cam_rgb = np.array(obs['robot_r1::robot_r1:right_eef_link:Camera:0::rgb'])[:, :, :, :3]
    right_cam_depth = np.array(obs['robot_r1::robot_r1:right_eef_link:Camera:0::depth_linear'])[:, :, :, None]
    right_cam_pose = np.array(obs['robot_r1:right_eef_link:Camera:0_pose']) # (traj_length, 7)

    left_obj_seg_mask = np.array(obs['robot_r1::robot_r1:left_eef_link:Camera:0::seg_instance'])[:, :, :, None]
    right_obj_seg_mask = np.array(obs['robot_r1::robot_r1:right_eef_link:Camera:0::seg_instance'])[:, :, :, None]
    eye_obj_seg_mask = np.array(obs['robot_r1::robot_r1:eyes:Camera:0::seg_instance'])[:, :, :, None]

    # Find the teacup's instance id: the id reported by the eye camera at the
    # first step where it sees the cup is used for every step and every camera.
    # NOTE(review): the cup may be absent from some frames and the id is not
    # recomputed per frame -- confirm ids are stable across steps and cameras.
    cup_index = None
    obs_info = demo_data['obs_info']
    for step_index in range(eye_depth.shape[0]):
        # read string from the hdf5 file
        obs_step = json.loads(obs_info[step_index].decode('utf-8'))['robot_r1']
        left_cup_index = get_instance_index(
            obs_step['robot_r1:left_eef_link:Camera:0']['seg_instance'], obj_name='teacup_601')
        right_cup_index = get_instance_index(
            obs_step["robot_r1:right_eef_link:Camera:0"]['seg_instance'], obj_name='teacup_601')
        eye_cup_index = get_instance_index(
            obs_step["robot_r1:eyes:Camera:0"]['seg_instance'], obj_name='teacup_601')

        print("")
        print('step', step_index)
        print('left cup index', left_cup_index)
        print('right cup index', right_cup_index)
        print('eye cup index', eye_cup_index)

        if eye_cup_index is not None:
            cup_index = eye_cup_index
            break

    # Sampling configuration (constant for the whole demo).
    ref_frame = 'robot'
    num_points_to_sample = 4096
    tea_cup_ratio = 0.3
    table_ratio = 0.5

    vis = None
    stream_pcd = None
    first_stream_frame = True
    if stream_vis_sign:
        vis = o3d.visualization.Visualizer()
        vis.create_window()
        stream_pcd = o3d.geometry.PointCloud()
        # o3d visualizer add axis
        axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3, origin=[0, 0, 0])
        vis.add_geometry(axis)

    pcd_max = []
    pcd_min = []
    input_pcd_portion = []
    output_pcd_portion = []
    color_pcd_demo = []
    try:
        for i in range(eye_depth.shape[0]):
            eye_cam_pcd = depth_to_pcd(eye_depth, eye_pose, base_link_pose, step_index=i, ref_frame=ref_frame)
            left_cam_pcd = depth_to_pcd(left_cam_depth, left_cam_pose, base_link_pose, step_index=i, ref_frame=ref_frame)
            right_cam_pcd = depth_to_pcd(right_cam_depth, right_cam_pose, base_link_pose, step_index=i, ref_frame=ref_frame)

            # Per-camera rows: rgb (3) + xyz (3), plus the instance id when the
            # segmentation heuristic needs it.
            if seg_heuristic:
                eye_cam_rgbd = np.concatenate([eye_rgb[i]/255.0, eye_cam_pcd, eye_obj_seg_mask[i]], axis=-1).reshape(-1,7)
                left_cam_rgbd = np.concatenate([left_cam_rgb[i]/255.0, left_cam_pcd, left_obj_seg_mask[i]], axis=-1).reshape(-1,7)
                right_cam_rgbd = np.concatenate([right_cam_rgb[i]/255.0, right_cam_pcd, right_obj_seg_mask[i]], axis=-1).reshape(-1,7)
            else:
                eye_cam_rgbd = np.concatenate([eye_rgb[i]/255.0, eye_cam_pcd], axis=-1).reshape(-1,6)
                left_cam_rgbd = np.concatenate([left_cam_rgb[i]/255.0, left_cam_pcd], axis=-1).reshape(-1,6)
                right_cam_rgbd = np.concatenate([right_cam_rgb[i]/255.0, right_cam_pcd], axis=-1).reshape(-1,6)

            color_pcd = np.concatenate([eye_cam_rgbd, left_cam_rgbd, right_cam_rgbd], axis=0)

            # clip point cloud with a bounding box
            mask = (color_pcd[:, 3] > x_range[0]) & (color_pcd[:, 3] < x_range[1]) & (color_pcd[:, 4] > y_range[0]) & (color_pcd[:, 4] < y_range[1]) & (color_pcd[:, 5] > z_range[0]) & (color_pcd[:, 5] < z_range[1])
            color_pcd = color_pcd[mask]

            print('step index', i, 'seg_heuristic', seg_heuristic, 'table_heuristic', table_heuristic)
            if seg_heuristic:
                # upweight the cup points
                teacup_mask = color_pcd[:, 6] == cup_index

                teacup_pcd = color_pcd[teacup_mask][: , :6]
                other_pcd = color_pcd[~teacup_mask][: , :6]
                color_pcd = color_pcd[:,:6]

                cup_samples = int(num_points_to_sample * tea_cup_ratio)
                other_samples = num_points_to_sample - cup_samples
                if teacup_pcd.shape[0] == 0 or other_pcd.shape[0] == 0:
                    # One group is empty: fall back to sampling the whole cloud.
                    if teacup_pcd.shape[0] == 0:
                        print('teacup pcd shape is 0')
                    else:
                        print('other pcd shape is 0')
                    color_pcd = fps_downsample(color_pcd, num_points_to_sample)
                else:
                    teacup_pcd_ds = fps_downsample(teacup_pcd, cup_samples)
                    other_pcd_ds = fps_downsample(other_pcd, other_samples)
                    color_pcd = np.concatenate([teacup_pcd_ds, other_pcd_ds], axis=0)
                assert color_pcd.shape[0] == num_points_to_sample, "color_pcd shape is not equal to num_points_to_sample"
                assert color_pcd.shape[1] == 6, "color_pcd shape 1 is not equal to 6"

            elif table_heuristic:
                table_pt_num, cup_pt_num, other_num = get_pcd_portion(color_pcd)
                input_pcd_portion.append([table_pt_num, cup_pt_num, other_num])

                # split the pcd based on table and not table and then do down sample differently
                table_mask = color_pcd[:, 5] < table_mask_height
                table_pcd = color_pcd[table_mask]
                not_table_pcd = color_pcd[~table_mask]
                table_samples = int(num_points_to_sample * table_ratio)
                not_table_samples = num_points_to_sample - table_samples

                if table_pcd.shape[0] == 0 or not_table_pcd.shape[0] == 0:
                    print('step index', i)
                    if table_pcd.shape[0] == 0:
                        print('table pcd shape is 0')
                    else:
                        print('not table pcd shape is 0')
                    color_pcd = fps_downsample(color_pcd, num_points_to_sample)
                else:
                    table_pcd_ds = fps_downsample(table_pcd, table_samples)
                    not_table_pcd_ds = fps_downsample(not_table_pcd, not_table_samples)
                    color_pcd = np.concatenate([table_pcd_ds, not_table_pcd_ds], axis=0)
                assert color_pcd.shape[0] == num_points_to_sample, "color_pcd shape is not equal to num_points_to_sample"

            else:
                color_pcd = fps_downsample(color_pcd, num_points_to_sample)
                assert color_pcd.shape[0] == num_points_to_sample, "color_pcd shape is not equal to num_points_to_sample"

            # get post processing portion
            table_pt_num, cup_pt_num, other_num = get_pcd_portion(color_pcd)
            output_pcd_portion.append([table_pt_num, cup_pt_num, other_num])

            if stream_vis_sign or single_frame_vis_sign:
                print('step index', i)
                print("")

            pcd_max.append(np.max(color_pcd[:, 3:], axis=0))
            pcd_min.append(np.min(color_pcd[:, 3:], axis=0))
            color_pcd_demo.append(color_pcd)

            if stream_vis_sign:
                stream_pcd.colors = o3d.utility.Vector3dVector(color_pcd[:, :3])
                stream_pcd.points = o3d.utility.Vector3dVector(color_pcd[:, 3:])
                if first_stream_frame:
                    vis.add_geometry(stream_pcd)
                    first_stream_frame = False
                else:
                    vis.update_geometry(stream_pcd)

                vis.poll_events()
                vis.update_renderer()
            if single_frame_vis_sign:
                color_pcd_vis(color_pcd)
    finally:
        # Release the streaming window even if a step fails; otherwise every
        # processed demo would leave another window open.
        if vis is not None:
            vis.destroy_window()

    # post processing some information
    pcd_max = np.max(np.array(pcd_max), axis=0)
    pcd_min = np.min(np.array(pcd_min), axis=0)

    return {
        'combined::color_point_cloud': np.array(color_pcd_demo),
        'pcd_max': pcd_max,
        'pcd_min': pcd_min,
        'input_pcd_portion': input_pcd_portion,
        'output_pcd_portion': output_pcd_portion,
    }


def _pose_matrices_to_pose_array(pose_matrices):
    """Convert a stack of pose matrices into rows of position followed by quaternion (via ``mat2pose``)."""
    poses = []
    for pose_matrix in pose_matrices:
        pos, quat = mat2pose(th.from_numpy(pose_matrix))
        poses.append(np.concatenate([pos.numpy(), quat.numpy()]))
    return np.array(poses)


def process_rm_r1_tidy_table(file_path, obs_type, sample_type='fps', 
                             with_color=True, vis_sign=False, concat_all=False, output_path=None, 
                             table_heuristic=False, seg_heuristic=True,
                             random_type='D0',
                             force_output_path=False):
    """Convert every demo of one tidy-table HDF5 file into robomimic-style transitions.

    For each demo under the ``"data"`` group this builds ``obs`` / ``next_obs``
    (shifted by one step), ``actions``, ``rewards`` and ``dones`` and hands the
    result to ``write_to_hdf5_per_demo``. Demos whose point cloud cannot be
    computed are reported and skipped.

    Args:
        file_path: source HDF5 file.
        obs_type, sample_type, with_color, concat_all: accepted for interface
            compatibility; not used by this function.
        vis_sign: stream the point clouds to an Open3D window while processing.
        output_path: destination passed to ``write_to_hdf5_per_demo``.
        table_heuristic, seg_heuristic, random_type: forwarded to
            ``get_color_pcd_info``.
        force_output_path: forwarded to ``write_to_hdf5_per_demo``.
    """
    obs_key_list = [
        'joint_qpos', 
        'prop_state',
        'prop_eef_state',
        'prop_eef_basepose',
        'combined::color_point_cloud',
        'object::drop_in_sink_awvzkn_0',
        'object::teacup_601',
    ]

    with h5py.File(file_path, "r") as hdf:
        # Access a group or dataset
        group = hdf["data"]
        print('num demos', len(group.keys()))
        # process data for each demo
        for demo_key in group.keys():
            demo_data = group[demo_key]
            print("")
            print('Start processing', demo_key) 

            obs_dict = {}
            next_obs_dict = {}

            # change the object pose matrices to position + quaternion rows
            object_poses = demo_data["datagen_info"]['object_poses']
            cup_pose = _pose_matrices_to_pose_array(np.array(object_poses['teacup_601']))
            sink_pose = _pose_matrices_to_pose_array(np.array(object_poses['drop_in_sink_awvzkn_0']))
            print('cup pose', cup_pose.shape)
            print('sink pose', sink_pose.shape)

            # Point cloud extraction is the step that fails on bad demos; skip
            # such demos but keep going. Only Exception is caught so that
            # Ctrl-C / SystemExit still stop the run.
            try:
                pcd_start_time = time.time()
                color_pcd_dict = get_color_pcd_info(demo_data, stream_vis_sign=vis_sign, 
                                                    table_heuristic=table_heuristic, 
                                                    seg_heuristic=seg_heuristic,
                                                    random_type=random_type
                                                    )
            except Exception as err:
                print('file cannot be processed: ', file_path, 'with demo key', demo_key)
                print('reason: ', repr(err))
                print("")
                continue

            pcd_demo = color_pcd_dict['combined::color_point_cloud']
            print('get color pcd time', time.time() - pcd_start_time)
            obs_dict['pcd_max'] = color_pcd_dict['pcd_max']
            obs_dict['pcd_min'] = color_pcd_dict['pcd_min']
            obs_dict['input_pcd_portion'] = color_pcd_dict['input_pcd_portion']
            obs_dict['output_pcd_portion'] = color_pcd_dict['output_pcd_portion']
            
            # The final recorded step only serves as the last next_obs.
            num_steps = demo_data['actions'].shape[0] - 1
            actions = demo_data["actions"][:-1] # actions already in range [-1, 1]

            # get rewards and dones
            # assume the data are expert demonstrations and only the last step is the success step
            rewards = np.zeros(num_steps)
            rewards[-1] = 1
            dones = np.zeros(num_steps)
            dones[-1] = 1

            # Per-step sources for the derived keys; every other key is read
            # straight from the demo's 'obs' group.
            for obs_key in obs_key_list:
                if "point_cloud" in obs_key:
                    source = pcd_demo
                elif obs_key == "object::teacup_601":
                    source = cup_pose
                elif obs_key == "object::drop_in_sink_awvzkn_0":
                    source = sink_pose
                else:
                    source = demo_data['obs'][obs_key]
                obs_dict[obs_key] = source[:-1]
                next_obs_dict[obs_key] = source[1:]

            demo_dict = {
                "obs": obs_dict,
                "next_obs": next_obs_dict,
                "actions": actions,
                "rewards": rewards,
                "dones": dones
            }

            episode_length = demo_dict["obs"]["joint_qpos"].shape[0]
            print('episode length: ', episode_length)
            write_to_hdf5_per_demo({demo_key: demo_dict}, output_path, force_output_path=force_output_path)


def process_rm_r1_tidy_table_cluster(folder_path, obs_type, sample_type='fps', with_color=True, 
                                     vis_sign=False, concat_all=False, output_path=None, start_worker_index=0,
                                     table_heuristic=False, seg_heuristic=True,
                                     random_type='D0', baseline='momagen'):
    """Process the ``demo.hdf5`` of one data-generation worker folder.

    The worker folder is ``r1_tidy_table_worker_<start_worker_index>`` under
    ``folder_path``; the result is written to
    ``<output_path>/<worker folder>/rm_demo.hdf5``. Nothing is done when the
    worker still has a ``tmp`` folder (its data has not been merged yet).
    Files that fail are reported and skipped.

    Args:
        folder_path: root folder that contains the worker folders.
        output_path: root folder for the processed files.
        start_worker_index: index of the worker folder to process.
        baseline: ``'momagen'``, ``'mimicgen'`` or ``'skillgen'``; the latter
            two read from ``demo_src_r1_tidy_table_<baseline>_task_<random_type>``.
        obs_type, sample_type, with_color, vis_sign, concat_all,
        table_heuristic, seg_heuristic, random_type: forwarded to
            ``process_rm_r1_tidy_table``.

    Returns:
        None.
    """
    # get all the demo.hdf5 file path
    worker_folder_list = [
        'r1_tidy_table_worker_{}'.format(int(start_worker_index)), 
        ]

    # NOTE(review): the merge check always looks in the momagen-style folder
    # name, even for the mimicgen / skillgen baselines -- confirm intended.
    tmp_folder_path = os.path.join(folder_path, worker_folder_list[0], f"demo_src_r1_tidy_table_task_{random_type}", "tmp")
    if os.path.exists(tmp_folder_path):
        # the tmp folder exists, meaning the data is not merged
        print('the tmp folder exists, meaning the data is not merged')
        return 
    
    # Use the ``baseline`` argument rather than the script-level ``args`` so
    # the function also works when imported from another module.
    if baseline == 'mimicgen' or baseline == 'skillgen':
        demo_folder_name = f"demo_src_r1_tidy_table_{baseline}_task_{random_type}"
    else:
        demo_folder_name = f"demo_src_r1_tidy_table_task_{random_type}"

    file_list_cannot_be_processed = []
    # process each file
    for worker_folder_path in worker_folder_list:
        file_path = os.path.join(folder_path, worker_folder_path, demo_folder_name, "demo.hdf5")
        print("")
        print('Start processing', file_path)
        print("")

        output_folder = os.path.join(output_path, worker_folder_path)
        print('output folder: ', output_folder)
        os.makedirs(output_folder, exist_ok=True)
        new_output_path = os.path.join(output_folder, "rm_demo.hdf5")
        try:
            process_rm_r1_tidy_table(
                file_path=file_path,
                obs_type=obs_type,
                sample_type=sample_type,
                with_color=with_color,
                vis_sign=vis_sign,
                concat_all=concat_all,
                output_path=new_output_path,
                table_heuristic=table_heuristic,
                seg_heuristic=seg_heuristic,
                random_type=random_type,
            )
        except Exception as err:
            # Best effort: record the failure and move on. Exception (not a
            # bare except) keeps Ctrl-C / SystemExit working.
            print('file cannot be processed: ', file_path)
            print('reason: ', repr(err))
            file_list_cannot_be_processed.append(file_path)
            print("")
            print('file list cannot be processed: ', file_list_cannot_be_processed)
            print('total number of files cannot be processed: ', len(file_list_cannot_be_processed))
            print("")
            continue


# Script entry point: parse the CLI options and process one worker folder.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # add file path argument
    parser.add_argument("--file_path", 
                        type=str, 
                        default="", 
                        help="hdf5 file path")
    parser.add_argument("--folder_path",
                        type=str, 
                        default="", 
                        help="hdf5 folder path")
    # add output file path argument
    parser.add_argument("--output_path", 
                        type=str, 
                        default="", 
                        help="output hdf5 file path")
    # add observation type
    parser.add_argument("--obs_type", 
                        type=str, 
                        default="point_cloud", 
                        help="observation key type", 
                        choices=["low_dim", "rgb", "depth", "point_cloud"])
    # train val split ratio
    parser.add_argument(
        "--split_ratio",
        type=float,
        default=0.1,
        help="validation ratio, in (0, 1)"
    )
    # pcd number of samples
    parser.add_argument(
        "--num_pcd_samples",
        type=int,
        default=1024,
        help="number of samples after processing pcd"
    )
    parser.add_argument(
        "--vis_sign",
        action="store_true",
        help="whether visualize the pcd when processing"
    )
    parser.add_argument(
        "--fps",
        action="store_true",
        help="use farthest point sampling to sample the point cloud"
    )
    parser.add_argument(
        "--random",
        action="store_true",
        help="use random sampling to sample the point cloud"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="debug mode: only save the first 50 steps of each demo"
    )
    parser.add_argument(
        "--with_color",
        action="store_true",
        help="process the point cloud with color (currently always enabled by this script)"
    )
    parser.add_argument(
        "--process_subtasks",
        action="store_true",
        help="process the subtaks"
    )

    parser.add_argument(
        "--start_worker_index",
        type=int,
        default=0,
        help="index of the data worker folder to process"
    )
    parser.add_argument(
        "--table_heuristic",
        action="store_true",
        help="whether need to use table heuristic to process the pcd"
    )
    parser.add_argument(
        "--seg_heuristic",
        action="store_true",
        help="whether need to use segmentation heuristic to process the pcd"
    )
    parser.add_argument(
        "--random_type",
        type=str,
        default='D0',
        help="random type for the data, D0, D1, D2"
    )
    parser.add_argument(
        "--baseline",
        default='momagen',
        type=str,
        help="baseline for the data, momagen, mimicgen, skillgen"
    )

    args = parser.parse_args()
    assert args.baseline in ['momagen', 'mimicgen', 'skillgen'], "baseline must be momagen, mimicgen or skillgen"

    file_path = args.file_path
    output_path = args.output_path
    # Baseline runs write into a sibling output location tagged with the baseline name.
    if args.baseline == 'mimicgen' or args.baseline == 'skillgen':
        args.output_path = args.output_path + "_" + args.baseline
        output_path = args.output_path

    # Already at module scope here, so a plain assignment creates the global.
    NUM_POINTS_TO_SAMPLE = args.num_pcd_samples
    
    sample_type='fps'

    # change the processed file name accordingly
    # (only takes effect when on_cluster is False, see below)
    if args.obs_type == "point_cloud":
        output_path = output_path.replace(".hdf5", "_{}_{}.hdf5".format(sample_type, args.num_pcd_samples))
        if args.with_color:
            output_path = output_path.replace(".hdf5", "_color.hdf5")
    if args.debug:
        output_path = output_path.replace(".hdf5", "_debug.hdf5")
    if args.obs_type == "rgb":
        output_path = output_path.replace(".hdf5", "_rgb.hdf5")

    # On the cluster the output path is used verbatim, which discards the
    # renaming above.
    on_cluster = True
    if on_cluster:
        output_path = args.output_path

    # hacky for mobile manipulation: these overrides ignore the CLI flags
    folder_path = args.folder_path
    args.vis_sign = True
    args.with_color = True

    assert args.random_type in ['D0', 'D1', 'D2'], "random type must be D0, D1 or D2"


    # The segmentation heuristic is always used, regardless of the CLI flags.
    args.seg_heuristic = True
    args.table_heuristic = False

    # process_rm_r1_tidy_table_cluster returns None; the name is kept for
    # compatibility with the original script.
    robomimic_dataset = process_rm_r1_tidy_table_cluster(
        folder_path=folder_path,
        obs_type=args.obs_type,
        sample_type=sample_type,
        with_color=args.with_color,
        vis_sign=args.vis_sign,
        concat_all=False,
        output_path=output_path,
        start_worker_index=args.start_worker_index,
        table_heuristic=args.table_heuristic,
        seg_heuristic=args.seg_heuristic,
        random_type=args.random_type,
        baseline = args.baseline,
    )
