import argparse
import h5py
import numpy as np
import os
import matplotlib.pyplot as plt
import open3d as o3d
import time

# visualization utility

def color_pcd_vis(color_pcd):
    """Display a colored point cloud next to a coordinate frame using Open3D.

    Args:
        color_pcd: 2-D array; columns 0-2 are used as per-point colors and
            columns 3 onward as point coordinates (so it must have exactly
            6 columns for Open3D's Vector3dVector to accept both slices).
    """
    rgb_cols = color_pcd[:, :3]
    xyz_cols = color_pcd[:, 3:]

    cloud = o3d.geometry.PointCloud()
    cloud.colors = o3d.utility.Vector3dVector(rgb_cols)
    cloud.points = o3d.utility.Vector3dVector(xyz_cols)

    # world-origin axes (0.3 units long) as a visual reference
    frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3, origin=[0, 0, 0])
    o3d.visualization.draw_geometries([cloud, frame])

    print('number points', color_pcd.shape[0])


def plot_eef_traj_density(link_poses):
    """Scatter-plot the left/right end-effector positions in 3D and save the figure.

    Args:
        link_poses: mapping with ``'left_eef'`` and ``'right_eef'`` arrays; columns
            0-2 of each are plotted as X/Y/Z. Every 50th step is labelled with its
            index so direction of travel is readable.

    Writes ``./brs_data_vis/eef_traj_density.png`` (creating the folder if it does
    not exist) and then shows the figure.
    """
    left_eef = link_poses['left_eef']
    right_eef = link_poses['right_eef']

    # plot 3d
    fig = plt.figure(figsize=(15, 15))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(left_eef[:, 0], left_eef[:, 1], left_eef[:, 2], c='r', marker='o', s=0.1, label='left eef')
    ax.scatter(right_eef[:, 0], right_eef[:, 1], right_eef[:, 2], c='b', marker='o', s=0.1, label='right eef')
    ax.legend()
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    # label every 50th step with its time index
    for i in range(0, left_eef.shape[0], 50):
        ax.text(left_eef[i, 0], left_eef[i, 1], left_eef[i, 2], str(i), color='red')
        ax.text(right_eef[i, 0], right_eef[i, 1], right_eef[i, 2], str(i), color='blue')

    plt.title('3D EEF Trajectory')
    # savefig does not create missing directories
    os.makedirs('./brs_data_vis', exist_ok=True)
    plt.savefig('./brs_data_vis/eef_traj_density.png')
    plt.show()
    plt.close(fig)


def plot_gripper_width(gripper_state):
    """Plot left/right gripper position over time and save the figure.

    Args:
        gripper_state: mapping with ``['left_gripper']['gripper_position']`` and
            ``['right_gripper']['gripper_position']`` per-step sequences.

    Writes ``./brs_data_vis/gripper_width.png`` (creating the folder if it does
    not exist) and then shows the figure.
    """
    left_width = gripper_state['left_gripper']['gripper_position']
    right_width = gripper_state['right_gripper']['gripper_position']

    # plot 2d
    fig, ax = plt.subplots()
    ax.plot(left_width, label='left gripper')
    ax.plot(right_width, label='right gripper')
    ax.legend()
    ax.set_xlabel('Time step')
    ax.set_ylabel('Gripper width')
    # savefig does not create missing directories
    os.makedirs('./brs_data_vis', exist_ok=True)
    plt.savefig('./brs_data_vis/gripper_width.png')
    plt.show()
    plt.close(fig)


def plot_gripper_width_and_action(gripper_state, action):
    """Plot each gripper's measured position against its commanded action.

    Args:
        gripper_state: mapping with ``['left_gripper']['gripper_position']`` and
            ``['right_gripper']['gripper_position']`` per-step sequences.
        action: mapping with ``'left_gripper'`` and ``'right_gripper'`` per-step
            action sequences.

    Left and right grippers go in two stacked subplots. Writes
    ``./brs_data_vis/gripper_width_and_action.png`` (creating the folder if it
    does not exist) and then shows the figure.
    """
    left_width = gripper_state['left_gripper']['gripper_position']
    right_width = gripper_state['right_gripper']['gripper_position']
    action_left = action['left_gripper']
    action_right = action['right_gripper']

    # plot 2d, left and right in two subplots
    fig, (ax1, ax2) = plt.subplots(2, 1)
    fig.set_size_inches(15, 10)
    ax1.plot(left_width, label='left gripper width')
    ax1.plot(action_left, label='action')
    ax1.legend()
    ax1.set_xlabel('Time step')
    ax1.set_ylabel('Gripper width and action')
    ax1.set_title('left gripper width and action')

    ax2.plot(right_width, label='right gripper width')
    ax2.plot(action_right, label='action')
    ax2.legend()
    ax2.set_xlabel('Time step')
    ax2.set_ylabel('Gripper width and action')
    ax2.set_title('right gripper width and action')

    # figure-level title; plt.title would overwrite ax2's title instead
    fig.suptitle('Gripper width and action')
    # savefig does not create missing directories
    os.makedirs('./brs_data_vis', exist_ok=True)
    plt.savefig('./brs_data_vis/gripper_width_and_action.png')
    plt.show()
    plt.close(fig)


def plot_base_velocity(base_velocity):
    """Plot the three base-velocity components over time and save the figure.

    Args:
        base_velocity: 2-D array indexed as ``base_velocity[:, 0..2]``; each
            column is drawn as one line (labelled x, y, z).

    Writes ``./brs_data_vis/base_velocity.png`` (same folder as the other plot
    helpers, created if missing) and then shows the figure.
    """
    fig, ax = plt.subplots()
    for dim, axis_name in enumerate(('x', 'y', 'z')):
        ax.plot(base_velocity[:, dim], label=f'base velocity {axis_name}')
    ax.legend()
    ax.set_xlabel('Time step')
    ax.set_ylabel('Base velocity')
    # savefig does not create missing directories
    os.makedirs('./brs_data_vis', exist_ok=True)
    plt.savefig('./brs_data_vis/base_velocity.png')
    plt.show()
    plt.close(fig)


def _empty_brs_obs():
    """Return a fresh BRS observation dict with every field initialised empty.

    Fields that the conversion does not fill stay as empty lists; they are kept
    so the written file carries the full key layout.
    """
    # Author's layout notes for unfilled fields: depth/img frames are
    # (num steps, 94, 168[, 3]); fused point cloud has 4096 points per step.
    cameras = ('head', 'left_wrist', 'right_wrist')

    def camera_stream(frame_key):
        return {frame_key: [], 'stamp': []}

    def gripper_fields():
        return {'gripper_effort': [], 'gripper_position': [], 'gripper_velocity': [], 'seq': [], 'stamp': []}

    def joint_fields():
        return {'joint_effort': [], 'joint_position': [], 'joint_velocity': [], 'seq': [], 'stamp': []}

    return {
        'depth': {cam: camera_stream('depth') for cam in cameras},
        'gripper_state': {side: gripper_fields() for side in ('left_gripper', 'right_gripper')},
        'joint_state': {part: joint_fields() for part in ('left_arm', 'right_arm', 'torso')},
        'link_poses': {'head': [], 'left_eef': [], 'right_eef': []},
        'odom': {
            'angular_velocity': [],
            'base_velocity': [],
            'base_qpos': [],
            'linear_velocity': [],
            'orientation': [],
            'position': [],
            'stamp': [],
        },
        'point_cloud': {'fused': {'padding_mask': [], 'rgb': [], 'xyz': []}},
        'rgb': {cam: camera_stream('img') for cam in cameras},
        'object_init_states': {},
        'object_states': {},
        'task_info': {},
    }


def _split_brs_action(rm_action):
    """Split the flat robomimic action array into the per-part BRS action dict.

    Column layout (21 dims for R1): mobile base 0:3 | torso 3:7 | left arm 7:13 |
    left gripper 13 | right arm 14:20 | right gripper 20. Gripper entries are 1-D.
    """
    actions = np.array(rm_action)
    return {
        'left_arm': actions[:, 7:13].copy(),
        'left_gripper': actions[:, 13].copy(),
        'mobile_base': actions[:, 0:3].copy(),
        'right_arm': actions[:, 14:20].copy(),
        'right_gripper': actions[:, 20].copy(),
        'torso': actions[:, 3:7].copy(),
    }


def robomimic_to_brs(robomimic_dataset_path, traj_vis=False):
    """Reformat a single-demo robomimic hdf5 file into the nested BRS dict layout.

    Args:
        robomimic_dataset_path: path to an hdf5 file; only ``data/demo_0`` is read
            (each file is assumed to hold one demo).
        traj_vis: if True, plot the base velocity after loading.

    Returns:
        ``(brs_data_dict, manip_portion, replay_portion)``. ``brs_data_dict`` has
        ``'action'`` and ``'obs'`` entries. Each portion is the fraction of steps
        summed from ``manip_seg_mask`` / ``replay_seg_mask``; it is ``np.nan``
        when that mask is absent (the original raised UnboundLocalError then).
    """
    manip_portion = np.nan
    replay_portion = np.nan

    with h5py.File(robomimic_dataset_path, "r") as f:
        data_group = f["data"]['demo_0']  # assume each file has one single demo

        brs_action = _split_brs_action(data_group["actions"])

        brs_obs = _empty_brs_obs()
        rm_obs = data_group['obs']

        # prop_eef_state columns: base velocity 0:3 | torso 3:7 | left arm 7:13 |
        # left eef pose 13:20 | left gripper 20 | right arm 21:27 |
        # right eef pose 27:34 | right gripper 34. Read once, slice in memory.
        prop = np.array(rm_obs['prop_eef_state'])
        brs_obs['odom']['base_velocity'] = prop[:, 0:3].copy()
        brs_obs['joint_state']['torso']['joint_position'] = prop[:, 3:7].copy()
        brs_obs['joint_state']['left_arm']['joint_position'] = prop[:, 7:13].copy()
        brs_obs['link_poses']['left_eef'] = prop[:, 13:20].copy()
        brs_obs['gripper_state']['left_gripper']['gripper_position'] = prop[:, 20].copy()
        brs_obs['joint_state']['right_arm']['joint_position'] = prop[:, 21:27].copy()
        brs_obs['link_poses']['right_eef'] = prop[:, 27:34].copy()
        brs_obs['gripper_state']['right_gripper']['gripper_position'] = prop[:, 34].copy()

        # point cloud channels: rgb in 0:3, xyz in 3:6
        if 'combined::color_point_cloud' in rm_obs:
            color_pc = rm_obs['combined::color_point_cloud']
            fused = brs_obs['point_cloud']['fused']
            fused['rgb'] = np.array(color_pc[:, :, :3])
            fused['xyz'] = np.array(color_pc[:, :, 3:6])
            # assuming all points are valid
            fused['padding_mask'] = np.ones(fused['rgb'].shape[:2], dtype=bool)

        # not used in brs but still queried when sampling batch data: placeholder
        brs_obs['link_poses']['head'] = np.ones(brs_obs['link_poses']['right_eef'].shape)

        # TODO: this should be task relevant
        for obj_key in ('object::teacup_601', 'object::drop_in_sink_awvzkn_0'):
            if obj_key in rm_obs:
                obj_traj = np.array(rm_obs[obj_key])
                brs_obs['object_init_states'][obj_key] = obj_traj[0, :].copy()
                # object_states are keyed without the "object::" prefix
                brs_obs['object_states'][obj_key.split('::', 1)[1]] = obj_traj

        # add base position back to odom
        brs_obs['odom']['base_qpos'] = np.array(rm_obs['joint_qpos'][:, 0:6])

        task_info = brs_obs['task_info']
        task_info['seg_mask'] = np.array(rm_obs['seg_mask'])
        if "manip_seg_mask" in rm_obs:
            manip_mask = np.array(rm_obs['manip_seg_mask'])
            task_info['manip_seg_mask'] = manip_mask
            episode_len = manip_mask.shape[0]
            manip_portion = np.sum(manip_mask) / episode_len
            print("")
            print('episode length', episode_len)
            print('manip portion', manip_portion)
        if "replay_seg_mask" in rm_obs:
            replay_mask = np.array(rm_obs['replay_seg_mask'])
            task_info['replay_seg_mask'] = replay_mask
            replay_portion = np.sum(replay_mask) / replay_mask.shape[0]
            print('replay portion', replay_portion)

    if traj_vis:
        plot_base_velocity(brs_obs['odom']['base_velocity'])
        # plot_eef_traj_density(brs_obs['link_poses'])
        # plot_gripper_width(brs_obs['gripper_state'])
        # plot_gripper_width_and_action(brs_obs['gripper_state'], brs_action)

    brs_data_dict = {
        'action': brs_action,
        'obs': brs_obs,
    }
    return brs_data_dict, manip_portion, replay_portion


def save_brs_data(output_file_path, data):
    """Write a BRS data dict (``{'action': ..., 'obs': ...}``) to an hdf5 file.

    Layout written under ``demo_0``:
      * ``action/<name>`` -- one dataset per action entry;
      * ``obs/<key>/<name>`` -- for the "flat" observation keys listed below;
      * ``obs/<key>/<sub_key>/<field>`` -- for every other observation key,
        whose value is a dict of dicts.
    """
    flat_obs_keys = ('link_poses', 'odom', 'object_init_states', 'object_states', 'task_info')

    with h5py.File(output_file_path, "w") as h5_out:
        demo = h5_out.create_group("demo_0")

        actions_out = demo.create_group("action")
        for name, values in data['action'].items():
            actions_out.create_dataset(name, data=values)

        obs_out = demo.create_group("obs")
        for obs_key, obs_entry in data['obs'].items():
            key_group = obs_out.create_group(obs_key)

            if obs_key in flat_obs_keys:
                # one level: names map straight to arrays
                for name, values in obs_entry.items():
                    key_group.create_dataset(name, data=values)
                continue

            # two levels: sub-group per sub_key, dataset per field
            for sub_key, fields in obs_entry.items():
                sub_group = key_group.create_group(sub_key)
                for field_name, values in fields.items():
                    sub_group.create_dataset(field_name, data=values)

    print('Finished writing the data to', output_file_path)


def clip_traj_based_on_cup_pose(cup_pose):
    """Build a per-step keep-mask from the cup's height (column 2 of ``cup_pose``).

    Args:
        cup_pose: array indexed as ``cup_pose[:, 2]`` for the cup z coordinate;
            ``data_clean`` passes ``obs/object_states/teacup_601`` (commented
            there as num_steps x 7).

    Returns:
        Boolean mask of length ``cup_pose.shape[0]``.

    NOTE(review): as written this mask is always all True. It starts from
    ``np.ones`` and the only assignment sets entries to True again, so no step
    is ever clipped. If clipping is intended, either start from ``np.zeros`` or
    assign ``False``. Which one is correct depends on whether steps with
    cup_z < 0.9 should be kept or dropped -- TODO confirm with the task
    definition.
    """
    act_mask = np.ones(cup_pose.shape[0], dtype=bool)
    cup_z = cup_pose[:, 2]
    # NOTE(review): no-op -- act_mask is already all True (see docstring).
    act_mask[cup_z < 0.9] = True
    print('clip based on cup')
    print('remaining steps', sum(act_mask))
    print('original length', len(act_mask))
    return act_mask

def remove_no_act_segment_tidy_table(traj, gripper_position, seg_mask, window_size=5, tol=1e-3):
    """Return a keep-mask that drops steps where the arm and gripper stood still.

    For each start index ``s`` the state at ``s`` (the ``traj`` row plus the
    gripper position) is compared with the state ``window_size`` steps later via
    ``np.allclose(..., atol=tol)``. When they match, steps ``s+1 .. s+window_size``
    are marked for removal -- but only those whose ``seg_mask`` entry equals 1
    (the caller passes the replay segment mask).

    Args:
        traj: per-step pose rows, indexed ``traj[t]``.
        gripper_position: per-step scalar gripper positions.
        seg_mask: per-step segment flags; only steps equal to 1 can be removed.
        window_size: lookahead distance in steps.
        tol: absolute tolerance for the "no motion" comparison.

    Returns:
        Boolean array of length ``traj.shape[0]``; False means drop the step.
    """
    keep = np.ones(traj.shape[0], dtype=bool)
    num_starts = len(seg_mask) - window_size

    for start in range(num_starts):
        stop = start + window_size
        before = np.concatenate([traj[start], gripper_position[start][None]])
        after = np.concatenate([traj[stop], gripper_position[stop][None]])
        if not np.allclose(before, after, atol=tol):
            continue
        # the whole window is static: drop its steps that lie in the segment
        for step in range(start + 1, stop + 1):
            if seg_mask[step] == 1:
                keep[step] = False

    return keep


def uniform_dowmsample(gripper_position, window_size=3):
    """Return a keep-mask: every ``window_size``-th step plus gripper-motion steps.

    Step ``i`` is always kept when ``|gripper_position[i+1] - gripper_position[i]|``
    exceeds 1e-5 in every dimension, so gripper open/close transitions survive
    downsampling. (The function name's spelling is kept for existing callers.)

    Returns:
        Boolean array with one entry per step.
    """
    n_steps = gripper_position.shape[0]

    # regular stride: steps 0, window_size, 2*window_size, ...
    stride_keep = np.zeros(n_steps, dtype=bool)
    stride_keep[list(range(0, len(gripper_position), window_size))] = True

    # steps whose next step changes the gripper position
    step_delta = np.abs(gripper_position[1:] - gripper_position[:-1])
    gripper_moving = np.zeros(n_steps, dtype=bool)
    for idx, delta in enumerate(step_delta):
        gripper_moving[idx] = bool(np.all(delta > 1e-5))

    return stride_keep | gripper_moving


def get_episode_length_stats(args):
    """Print max/min/mean episode length over all demos in ``args.source_file``.

    A demo's length is the length of its ``obs/task_info/seg_mask`` dataset.

    Returns:
        np.ndarray with one length per demo, in file iteration order. It is
        empty when the file has no demos. The lengths are returned (rather than
        stopping in a debugger) so callers can inspect them.
    """
    with h5py.File(args.source_file, 'r') as f:
        # len() of an h5py dataset only reads its shape, not the data
        traj_length_list = np.array([len(f[key]['obs/task_info/seg_mask']) for key in f.keys()])

    if traj_length_list.size == 0:
        # np.max/np.min raise on empty arrays
        print('no demos found in', args.source_file)
        return traj_length_list

    print('max traj length', np.max(traj_length_list))
    print('min traj length', np.min(traj_length_list))
    print('mean traj length', np.mean(traj_length_list))
    return traj_length_list



def _copy_masked(src_group, dst_group, mask, per_step=True):
    """Recursively copy an h5py group, masking per-step datasets along axis 0.

    A dataset counts as per-step (and is masked) only when ``per_step`` is True
    and its first dimension equals ``len(mask)``. Everything else is copied
    unchanged.
    """
    for name, item in src_group.items():
        if isinstance(item, h5py.Group):
            _copy_masked(item, dst_group.create_group(name), mask, per_step)
            continue
        value = np.array(item)
        if per_step and value.ndim > 0 and len(value) == len(mask):
            value = value[mask]
        dst_group.create_dataset(name, data=value)


def data_clean(args):
    """Filter every demo in ``args.source_file`` and write it to ``args.target_file``.

    For each demo a per-step keep-mask is built from:
      * ``remove_no_act_segment_tidy_table`` on the left eef pose and left gripper
        position, which drops static windows inside the replay segment;
      * ``clip_traj_based_on_cup_pose`` on ``obs/object_states/teacup_601``, only
        when ``args.clip_traj`` is set, so demos without a teacup still work;
      * ``uniform_dowmsample`` when ``args.downsample`` is set.

    The mask is applied along axis 0 to every action dataset, and to every obs
    dataset whose first dimension equals the episode length.
    ``object_init_states`` is copied unmasked: it holds one pose per object, not
    one per step.

    Args:
        args: namespace with ``source_file``, ``target_file``, ``downsample``,
            ``downsample_window_size`` and (optional) ``clip_traj``.
    """
    clip_traj = getattr(args, 'clip_traj', False)
    num_demos_cleaned = 0

    # both files are context-managed so the target is closed even on error
    with h5py.File(args.target_file, "w") as f_new, h5py.File(args.source_file, 'r') as f:
        demo_keys = list(f.keys())
        for key in demo_keys:
            print("")
            print(key)
            demo = f[key]

            # construct mask for the trajectory
            replay_seg_mask = np.array(demo['obs/task_info/replay_seg_mask'])
            left_eef_traj = np.array(demo['obs/link_poses/left_eef'])  # num_steps, 7
            left_gripper_position = np.array(demo['obs/gripper_state/left_gripper/gripper_position'])
            mask = remove_no_act_segment_tidy_table(
                left_eef_traj, left_gripper_position, replay_seg_mask, window_size=5, tol=1e-3)

            if clip_traj:
                cup_pose = np.array(demo['obs/object_states/teacup_601'])  # num_steps, 7
                mask = np.logical_and(mask, clip_traj_based_on_cup_pose(cup_pose))

            if args.downsample:
                downsample_mask = uniform_dowmsample(left_gripper_position, window_size=args.downsample_window_size)
                mask = np.logical_and(mask, downsample_mask)

            print('original length', len(mask))
            print('reamaining steps', sum(mask))

            demo_group = f_new.create_group(key)

            # every action is per-step, so the mask is applied unconditionally
            action_group = demo_group.create_group("action")
            for action_key, action_ds in demo['action'].items():
                action_group.create_dataset(action_key, data=np.array(action_ds)[mask])

            obs_group = demo_group.create_group("obs")
            for obs_key, obs_src in demo['obs'].items():
                _copy_masked(obs_src, obs_group.create_group(obs_key), mask,
                             per_step=(obs_key != 'object_init_states'))

            # check the segmentation mask of the cleaned demo
            seg_mask = np.array(obs_group['task_info/seg_mask'])
            if len(seg_mask):
                print('replay portion', sum(seg_mask) / len(seg_mask))
            else:
                print('replay portion', 'n/a (all steps removed)')
            num_demos_cleaned += 1

        print('num demos saved after data clean', num_demos_cleaned)
        print('num demos in the source file', len(demo_keys))

    print('data saved to', args.target_file)


def _natural_sort_key(name):
    """Sort key comparing embedded integers numerically (``demo_2`` < ``demo_10``)."""
    import re
    parts = re.split(r'(\d+)', name)
    # re.split with a capture group alternates text / digit-run, so odd slots are numbers
    return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]


def _save_init_cup_poses(data_folder_name, data_name):
    """Save each demo's initial teacup pose from ``<folder>/<name>.hdf5`` to ``<folder>/<name>.npy``."""
    if data_folder_name is None:
        raise ValueError('data_folder_name is required when save_init_pose is enabled')
    file_name = os.path.join(data_folder_name, data_name + '.hdf5')
    init_cup_pose_list = []
    with h5py.File(file_name, "r") as f:
        for demo_key in f.keys():
            cup_init_pose = np.array(f[demo_key]['obs/object_init_states/object::teacup_601'])
            print('tea cup init pose', cup_init_pose)
            init_cup_pose_list.append(cup_init_pose)
    print('init_cup_pose_list', len(init_cup_pose_list))
    np.save(os.path.join(data_folder_name, data_name + '.npy'), np.array(init_cup_pose_list))


def _collect_worker_files(file_folder):
    """Return hdf5 paths found in each ``<file_folder>/<prefix>_<id>[/tmp]`` folder, ordered by worker id."""
    # skip entries that are not "<prefix>_<int>" directories instead of crashing in int()
    worker_folder_list = [
        name for name in os.listdir(file_folder)
        if os.path.isdir(os.path.join(file_folder, name)) and name.split('_')[-1].isdigit()
    ]
    worker_folder_list.sort(key=lambda x: int(x.split('_')[-1]))  # sorted from 0
    print('worker_folder_list', worker_folder_list)

    all_file_path = []
    for worker_folder_name in worker_folder_list:
        print('worker_folder_name', worker_folder_name)
        worker_folder_path = os.path.join(file_folder, worker_folder_name, 'tmp')
        if not os.path.exists(worker_folder_path):
            worker_folder_path = os.path.join(file_folder, worker_folder_name)
        file_names = sorted(os.listdir(worker_folder_path), key=_natural_sort_key)
        print('len saved files', len(file_names))
        all_file_path.extend(
            os.path.join(worker_folder_path, name) for name in file_names if name.endswith('.hdf5'))
    return all_file_path


def _print_portion_stats(name, values):
    """Print mean/std/max/min of per-demo portions, or a note when there are none."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        print(name, 'no processed demos')
        return
    print(name, 'mean', np.mean(values), 'std', np.std(values), 'max', np.max(values), 'min', np.min(values))


def _convert_files(args, all_file_path):
    """Convert each robomimic file to ``<output_folder>/demo_<n>_brs.hdf5``; return portion lists."""
    manip_portion_list, replay_portion_list, nav_portion_list = [], [], []
    for count_ptr, robomimic_dataset_path in enumerate(all_file_path, start=1):
        print('count_ptr', count_ptr)
        print("Processing file:", robomimic_dataset_path)
        try:
            brs_data_dict, manip_portion, replay_portion = robomimic_to_brs(
                robomimic_dataset_path, traj_vis=args.traj_vis)
            os.makedirs(args.output_folder, exist_ok=True)
            output_file_name = os.path.join(args.output_folder, f"demo_{count_ptr}_brs.hdf5")
            # kept from the original; purpose not evident from this file
            time.sleep(1)
            print('sleep for 1 second')
            save_brs_data(output_file_name, brs_data_dict)
        except Exception as err:  # skip a bad file but report why
            print('error in processing file', robomimic_dataset_path, repr(err))
            continue
        manip_portion_list.append(manip_portion)
        replay_portion_list.append(replay_portion)
        nav_portion_list.append(1 - manip_portion - replay_portion)

    # summary once, after all files
    print('total number of processed demos', len(manip_portion_list))
    _print_portion_stats('manip portion', manip_portion_list)
    _print_portion_stats('replay portion', replay_portion_list)
    _print_portion_stats('nav portion', nav_portion_list)
    return manip_portion_list, replay_portion_list, nav_portion_list


def _base_vel_stats(base_vel_info):
    """Print and return per-dimension (mean, max, min) over all collected base velocities."""
    stacked = np.concatenate(base_vel_info)
    base_vel_mean = np.mean(stacked, axis=0)
    base_vel_max = np.max(stacked, axis=0)
    base_vel_min = np.min(stacked, axis=0)
    print("")
    print('base_vel_mean', base_vel_mean)
    print('base_vel_max', base_vel_max)
    print('base_vel_min', base_vel_min)
    print('concat', np.stack([base_vel_mean, base_vel_max, base_vel_min], axis=0))
    print("")
    return base_vel_mean, base_vel_max, base_vel_min


def _merge_converted(output_folder, new_hdf5_path):
    """Copy each converted file's ``demo_0`` into ``new_hdf5_path`` as ``demo_<i>``.

    Input files are taken in natural-sort order so the ``demo_<i>`` numbering
    is reproducible. The merged file itself is excluded if it lives in
    ``output_folder``. Returns the list of per-demo base-velocity arrays.
    """
    merged_abspath = os.path.abspath(new_hdf5_path)
    all_file_names = sorted(
        (name for name in os.listdir(output_folder)
         if name.endswith('.hdf5')
         and os.path.abspath(os.path.join(output_folder, name)) != merged_abspath),
        key=_natural_sort_key,
    )
    base_vel_info = []
    with h5py.File(new_hdf5_path, "w") as f_new:
        for i, file_name in enumerate(all_file_names):
            file_path = os.path.join(output_folder, file_name)
            print(i, 'file name', file_path, 'in total number of processed demos', len(all_file_names))
            with h5py.File(file_path, "r") as f:
                base_vel_info.append(np.array(f['demo_0/obs/odom/base_velocity']))
                f.copy("demo_0", f_new, name="demo_{}".format(i))
            if i % 10 == 0:
                _base_vel_stats(base_vel_info)  # periodic progress report
    return base_vel_info


def main_cluster(args, data_folder_name=None, data_name='r1_tidy_table_full'):
    """Run the conversion pipeline.

    Stages run in this order:
      1. export initial cup poses (``args.save_init_pose``);
      2. convert per-worker robomimic files (``args.loop_over_files``);
      3. merge them into ``args.new_hdf5_path`` (``args.merge_sign``);
      4. clean that merged file into ``args.target_file`` (always runs).

    The CLI flags are honoured as given.

    Args:
        args: parsed CLI namespace (see the ``__main__`` parser).
        data_folder_name: folder holding ``<data_name>.hdf5``. It is only needed
            for the init-pose export; it defaults to None so ``main_cluster(args)``
            works.
        data_name: base file name used by the init-pose export.
    """
    if getattr(args, 'save_init_pose', False):
        _save_init_cup_poses(data_folder_name, data_name)

    all_file_path = []
    if args.file_folder:
        all_file_path = _collect_worker_files(args.file_folder)
        print('the number of demos in the processed data', len(all_file_path))

    manip_portion_list, replay_portion_list, nav_portion_list = [], [], []
    if getattr(args, 'loop_over_files', False):
        manip_portion_list, replay_portion_list, nav_portion_list = _convert_files(args, all_file_path)

    if getattr(args, 'merge_sign', False):
        # write demos in order to new file
        base_vel_info = _merge_converted(args.output_folder, args.new_hdf5_path)
        print('data processed and merged successfully')
        if manip_portion_list:  # portions exist only if conversion ran in this call
            print('total number of processed demos', len(manip_portion_list))
            _print_portion_stats('manip portion', manip_portion_list)
            _print_portion_stats('replay portion', replay_portion_list)
            _print_portion_stats('nav portion', nav_portion_list)
        if base_vel_info:
            # final stats over ALL merged demos, not just up to the last progress checkpoint
            print('base_vel_info')
            _base_vel_stats(base_vel_info)

    data_clean_sign = True
    if data_clean_sign:
        # start data cleaning on the merged file
        args.source_file = args.new_hdf5_path
        data_clean(args)






if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # input / output locations
    parser.add_argument("--file_folder",
                        type=str,
                        help="folder holding per-worker subfolders (<prefix>_<id>[/tmp]) of robomimic hdf5 files")
    parser.add_argument("--output_folder",
                        type=str,
                        help="folder for the converted per-demo BRS hdf5 files")
    parser.add_argument("--new_hdf5_path",
                        type=str,
                        help="path of the merged BRS hdf5 file; also the data-cleaning input")

    # pipeline stages
    parser.add_argument("--loop_over_files",
                        action='store_true',
                        help="whether need to loop over files")
    parser.add_argument("--merge_sign",
                        action='store_true',
                        help="whether need to merge the files")
    parser.add_argument("--new_hdf5_name",
                        type=str,
                        default="merged_brs_data.hdf5",
                        help="merged hdf5 file name")

    parser.add_argument("--traj_vis",
                        action='store_true',
                        help="whether need to visualize the trajectory")

    parser.add_argument("--save_init_pose",
                        action='store_true',
                        help="whether need to save the init pose")

    # data cleaning
    parser.add_argument('--source_file',
                        type=str,
                        default="",
                        help='source hdf5 file for data cleaning (main_cluster replaces it with --new_hdf5_path)'
                        )
    parser.add_argument('--target_file',
                        type=str,
                        default="",
                        help='output hdf5 file written by data cleaning'
                        )
    parser.add_argument('--window_size',
                        type=int,
                        default=5,
                        help='window size for downsampling'
                        )
    parser.add_argument('--clip_traj',
                        action='store_true',
                        help='clip the trajectory based on cup pose'
                        )
    parser.add_argument("--downsample",
                        action='store_true',
                        help="downsample the trajectory"
                        )
    parser.add_argument("--downsample_window_size",
                        type=int,
                        default=3,
                        help="window size for downsampling"
                        )
    parser.add_argument("--limit_episode_length",
                        type=int,
                        default=650,
                        help="limit the episode length"
                        )
    parser.add_argument("--worker_id",
                        type=int,
                        default=0,
                        help="worker id for the file folder"
                        )

    args = parser.parse_args()

    # data_folder_name is passed explicitly: it is a required positional in some
    # versions of main_cluster and is only used by the save_init_pose stage
    main_cluster(args, data_folder_name=None)
