# train.py
# Script to train policies in Isaac Gym
#
# Copyright (c) 2018-2023, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import logging
import os
# from datetime import datetime

# noinspection PyUnresolvedReferences
import isaacgym
import sys
sys.path.append('.')
sys.path.append('..')

# import hydra
# from isaacgymenvs.learning import calm_agent, calm_models, calm_network_builder, calm_players
# from isaacgymenvs.learning import encamp_network_builder, encamp_agent
# from isaacgymenvs.utils.rlgames_utils import multi_gpu_get_rank
# from isaacgymenvs.pbt.pbt import PbtAlgoObserver, initial_pbt_check
# from omegaconf import DictConfig, OmegaConf
# from hydra.utils import to_absolute_path
# from isaacgymenvs.tasks import isaacgym_task_map
# import gym

# from isaacgymenvs.utils.reformat import omegaconf_to_dict, print_dict
# from isaacgymenvs.utils.utils import set_np_formatting, set_seed

# import torch
# import numpy as np
# import random

from omegaconf import open_dict


from multiprocessing import Pool
from multiprocessing import Process
import argparse
import numpy as np

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    
    ##### pool settings ####
    parser.add_argument("--launch_type", type=str, default='trajectory')
    parser.add_argument("--tracking_data_sv_root", type=str, default='')
    parser.add_argument("--subj_nm", type=str, default='')
    parser.add_argument("--debug", action='store_true', default=False)
    parser.add_argument("--num_frames", type=int, default=150, help="number of vector envs")
    parser.add_argument("--base_dir", type=str, default='', help="Mocap save info file")
   
   
    ##### experiment settings #####
    # parser.add_argument("--additional_tag",  type=str, default='', help="path to the optimized qtars")
    parser.add_argument("--hand_type",  type=str, default='allegro', help="path to the optimized qtars")
    
    
    ##### isaacgym settings #####
    parser.add_argument("--numEnvs", type=int, default=8000)
    parser.add_argument("--minibatch_size", type=int, default=8000)
    parser.add_argument("--use_relative_control", type=str2bool,  default=False)
    parser.add_argument("--goal_cond", type=str2bool, default=False)
    # parser.add_argument("--object_name", type=str, default='')
    parser.add_argument("--obs_type", type=str, default='pure_state_wref_wdelta')
    parser.add_argument("--rigid_obj_density", type=float, default=500)
    parser.add_argument("--glb_trans_vel_scale", type=float, default=0.1)
    parser.add_argument("--glb_rot_vel_scale", type=float, default=0.1)
    # export additiona_tag="kinebais_wdelta_rewhandpos_dist_"
    parser.add_argument("--additional_tag", type=str, default='kinebais_wdelta_rewhandpos_dist_')
    parser.add_argument("--dt", type=float, default=0.0166)
    parser.add_argument("--test", type=str2bool,  default=False)
    parser.add_argument("--use_kinematics_bias", type=str2bool,  default=True)  
    parser.add_argument("--w_obj_ornt",  type=str2bool,   default=False)  
    parser.add_argument("--separate_stages",  type=str2bool,  default=False)  
    parser.add_argument("--kinematics_only",  type=str2bool,   default=False) 
    parser.add_argument("--use_fingertips",  type=str2bool,   default=True) 
    parser.add_argument("--use_kinematics_bias_wdelta",  type=str2bool,  default=True)  
    parser.add_argument("--hand_pose_guidance_glb_trans_coef", type=float, default=0.6)
    parser.add_argument("--hand_pose_guidance_glb_rot_coef", type=float, default=0.1)
    parser.add_argument("--hand_pose_guidance_fingerpose_coef", type=float, default=0.1)
    parser.add_argument("--rew_finger_obj_dist_coef", type=float, default=0.5)
    parser.add_argument("--rew_delta_hand_pose_coef", type=float, default=0.5)
    parser.add_argument("--nn_gpus", type=int, default=8)
    parser.add_argument("--st_idx", type=int, default=0)
    parser.add_argument("--dofSpeedScale", type=float, default=20)
    # use_twostage_rew
    parser.add_argument("--use_twostage_rew", type=str2bool,  default=False)
    parser.add_argument("--dataset_type", type=str, default='grab')
    
    ### sim steup ###
    parser.add_argument("--ground_distance", type=float, default=0.0)
    parser.add_argument("--use_canonical_state",  type=str2bool,   default=False) 
    parser.add_argument("--disable_gravity",  type=str2bool,   default=False) 
    parser.add_argument("--data_inst_flag", type=str, default='')
    # right_hand_dist_thres #
    parser.add_argument("--right_hand_dist_thres", type=float, default=0.12)
    parser.add_argument("--checkpoint", type=str, default='')
    parser.add_argument("--max_epochs", type=int, default=1000)
    # use_real_twostage_rew # use_real_twostage_rew # # two stage rewards # #
    parser.add_argument("--use_real_twostage_rew",  type=str2bool,   default=False) 
    # start_grasping_fr
    parser.add_argument("--start_grasping_fr",  type=str2bool,   default=False) 
    # 
    # controlFrequencyInv
    parser.add_argument("--controlFrequencyInv", type=int, default=1)
    parser.add_argument("--use_interpolated_data",  type=str2bool,   default=False) 
    # episodeLength
    parser.add_argument("--episodeLength", type=int, default=1000)
    # start_frame
    parser.add_argument("--start_frame", type=int, default=0)
    # rew_obj_pose_coef
    parser.add_argument("--rew_obj_pose_coef", type=float, default=1.0)
    # goal_dist_thres
    parser.add_argument("--goal_dist_thres", type=float, default=0.0)
    # lifting_separate_stages
    parser.add_argument("--lifting_separate_stages",  type=str2bool,   default=False) 
    
    # strict_lifting_separate_stages
    parser.add_argument("--strict_lifting_separate_stages",  type=str2bool,   default=False) 
    parser.add_argument("--add_table",  type=str2bool,   default=False) 
    # table_z_dim
    parser.add_argument("--table_z_dim", type=float, default=0.5)
    parser.add_argument("--headless", type=str2bool,   default=True)
    parser.add_argument("--target_object_name", type=str, default='')
    parser.add_argument("--target_mocap_sv_fn", type=str, default='')
    # use_taco_obj_traj
    parser.add_argument("--use_taco_obj_traj", type=str2bool,   default=True)
    # pre_optimized_traj
    parser.add_argument("--pre_optimized_traj", type=str, default='')
    # pre optimized traj 
    # pre_load_trajectories, obj_type_to_pre_optimized_traj
    ### TODO: add pre_load_trajectories, obj_type_to_pre_optimized_traj ###
    parser.add_argument("--pre_load_trajectories", type=str2bool,   default=False)
    parser.add_argument("--obj_type_to_pre_optimized_traj", type=str, default='')
    parser.add_argument("--subj_idx", type=int, default=0)
    # parser.add_argument("--hand_type", type=str, default='allegro')
    # use_vision
    parser.add_argument("--use_vision", type=str2bool,   default=False)
    parser.add_argument("--use_dagger", type=str2bool,   default=False)
    parser.add_argument("--use_generalist_policy", type=str2bool,   default=False)
    # 
    parser.add_argument("--use_hand_actions_rew", type=str2bool,   default=True)
    # supervised_training
    parser.add_argument("--supervised_training", type=str2bool,   default=False)
    # checkpoint
    parser.add_argument("--test_inst_tag", type=str, default='')
    parser.add_argument("--test_optimized_res", type=str, default='')
    parser.add_argument("--training_mode", type=str, default='regular')
#     export preload_experiences_tf=True
# export preload_experiences_path=''
    # preload_experiences_tf, preload_experiences_path
    parser.add_argument("--preload_experiences_tf", type=str2bool,   default=False)
    parser.add_argument("--preload_experiences_path", type=str,   default=None)
    # single_instance_training
    parser.add_argument("--single_instance_training", type=str2bool,   default=False)
    # parser.add_argument("--checkpoint", type=str, default='') # 
    parser.add_argument("--generalist_tune_all_instnaces", type=str2bool,   default=False)
    # sampleds_with_object_code_fn
    parser.add_argument("--sampleds_with_object_code_fn", type=str, default='')
    # log_path
    parser.add_argument("--log_path", type=str, default='./runs')
    # grab_inst_tag_to_optimized_res_fn: '/root/diffsim/softzoo/softzoo/diffusion/assets/data_inst_tag_to_optimized_res.npy'
    # grab_inst_tag_to_optimized_res_fn, taco_inst_tag_to_optimized_res_fn
    parser.add_argument("--grab_inst_tag_to_optimized_res_fn", type=str, default='/root/diffsim/softzoo/softzoo/diffusion/assets/data_inst_tag_to_optimized_res.npy')
    parser.add_argument("--taco_inst_tag_to_optimized_res_fn", type=str, default='')
    # single_instance_tag # obj_type_to_optimized_res_fn #
    parser.add_argument("--single_instance_tag", type=str, default='')
    parser.add_argument("--obj_type_to_optimized_res_fn", type=str, default='')
    # supervised_loss_coef
    parser.add_argument("--supervised_loss_coef", type=float, default=0.0005)
    # pure_supervised_training
    parser.add_argument("--pure_supervised_training", type=str2bool,   default=False)
    # inst_tag_to_latent_feature_fn
    parser.add_argument("--inst_tag_to_latent_feature_fn", type=str, default='')
    
    
    parser.add_argument("--object_type_to_latent_feature_fn", type=str, default='./uni_manip/tds_diffusion_exp/allegro_tracking_kine_diff_AE_Diff_trainAE_vnew_/obj_type_to_obj_feat.npy')
    # --grab_obj_type_to_opt_res_fn=${grab_obj_type_to_opt_res_fn} --taco_obj_type_to_opt_res_fn=${taco_obj_type_to_opt_res_fn} 
    parser.add_argument("--grab_obj_type_to_opt_res_fn", type=str, default='')
    parser.add_argument("--taco_obj_type_to_opt_res_fn", type=str, default='')
    parser.add_argument("--maxx_inst_nn", type=int, default=10000)
    # tracking_save_info_fn, tracking_info_st_tag
    parser.add_argument("--tracking_save_info_fn", type=str, default="./data/GRAB_Tracking_PK_reduced/data")
    parser.add_argument("--tracking_info_st_tag", type=str, default='passive_active_info_')
    parser.add_argument("--only_training_on_succ_samples", type=str2bool,   default=False)
    parser.add_argument("--exclude_inst_tag_to_opt_res_fn", type=str, default='')
    parser.add_argument("--rew_filter", type=str2bool,   default=False)
    parser.add_argument("--rew_low_threshold", type=float, default=0.0)
    # use_teacher_model
    parser.add_argument("--use_teacher_model", type=str2bool,   default=False)
    # use_strict_maxx_nn_ts
    parser.add_argument("--use_strict_maxx_nn_ts", type=str2bool,   default=False)
    # taco_interped_data_sv_additional_tag
    parser.add_argument("--taco_interped_data_sv_additional_tag", type=str, default='')
    # strict_maxx_nn_ts
    parser.add_argument("--strict_maxx_nn_ts", type=int, default=150)
    # parser.add_argument("--inst_tag_to_latent_feature_fn", type=str, default='')
    parser.add_argument("--grab_train_test_setting", type=str2bool,   default=False)
    # use_local_canonical_state
    parser.add_argument("--use_local_canonical_state", type=str2bool,   default=False)
    # bound_loss_coef
    parser.add_argument("--bound_loss_coef", type=float, default=0.0001)
    # rew_grab_thres, rew_taco_thres
    parser.add_argument("--rew_grab_thres", type=float, default=50.0)
    parser.add_argument("--rew_taco_thres", type=float, default=200.0)
    # rew_smoothness_coef
    parser.add_argument("--rew_smoothness_coef", type=float, default=0.0)
    # obj_type_to_base_traj_fn
    parser.add_argument("--obj_type_to_base_traj_fn", type=str, default='')
    parser.add_argument("--use_base_traj", type=str2bool,   default=False)

    parser.add_argument("--rew_thres_with_selected_insts", type=str2bool,   default=False)
    parser.add_argument("--selected_inst_idxes_dict", type=str, default='')
    # customize_damping
    parser.add_argument("--customize_damping", type=str2bool,   default=False)
    # customize_global_damping
    parser.add_argument("--customize_global_damping", type=str2bool,   default=False)
    parser.add_argument("--train_on_all_trajs", type=str2bool,   default=False)
    parser.add_argument("--eval_split_trajs", type=str2bool,   default=False)
    # single_instance_state_based_train
    parser.add_argument("--single_instance_state_based_train", type=str2bool,   default=False)
    # data_selection_ratio
    parser.add_argument("--data_selection_ratio", type=float,   default=1.0)
    # test_taco_tag = 'taco_20231024_'
    # export test_taco_tag='taco_20231024_
    parser.add_argument("--test_taco_tag", type=str, default='taco_20231024_')
    # wo_vel_obs
    parser.add_argument("--wo_vel_obs", type=str2bool,   default=False)

    
    args = parser.parse_args()
    
    if len(args.subj_nm) > 0:
        args.subj_idx = int(args.subj_nm[1:])
    
    if len(args.obj_type_to_optimized_res_fn) > 0 and os.path.exists(args.obj_type_to_optimized_res_fn):
        obj_type_to_optimized_res = np.load(args.obj_type_to_optimized_res_fn, allow_pickle=True).item()
        pure_obj_type_to_optimized_res = {}
        for key in obj_type_to_optimized_res:
            if isinstance(key, tuple):
                pure_obj_type_to_optimized_res[key[0]] = obj_type_to_optimized_res[key]
            else:
                pure_obj_type_to_optimized_res[key] = obj_type_to_optimized_res[key]
        # print(f"obj_type_to_optimized_res: {pure_obj_type_to_optimized_res.keys()}")
    else:
        pure_obj_type_to_optimized_res = None 
    
    if len(args.exclude_inst_tag_to_opt_res_fn) > 0 and os.path.exists(args.exclude_inst_tag_to_opt_res_fn):
        exclude_inst_tag_to_opt_res_raw = np.load(args.exclude_inst_tag_to_opt_res_fn, allow_pickle=True).item()
        exclude_inst_tag_to_opt_res = {}
        for key in exclude_inst_tag_to_opt_res_raw:
            if isinstance(key, tuple):
                exclude_inst_tag_to_opt_res[key[0]] = exclude_inst_tag_to_opt_res_raw[key]
            else:
                exclude_inst_tag_to_opt_res[key] = exclude_inst_tag_to_opt_res_raw[key]
        print(f"exclude_inst_tag_to_opt_res: {exclude_inst_tag_to_opt_res.keys()}")
    else:
        exclude_inst_tag_to_opt_res = None    
        
            

    def launch_one_process(cur_grab_data_tag, traj_grab_data_tag, cuda_idx, pre_optimized_traj=None):
        print(f"pre_optimized_traj: {pre_optimized_traj}")
        obs_type = args.obs_type
        # use_small_sigmas = args.use_small_sigmas
        # finger_urdf_template = args.finger_urdf_template
        # finger_near_palm_joint_idx = args.finger_near_palm_joint_idx
        # constraint_level = args.constraint_level
        # object_type = cur_grab_data_tag # cur_grab_dta_tag #
        object_name = cur_grab_data_tag
        # task_type = "mocap_tracking"
        if args.hand_type == 'allegro':
            if args.dataset_type == 'grab':
                mocap_sv_info_fn = f"{args.tracking_data_sv_root}/passive_active_info_{traj_grab_data_tag}.npy"
            elif args.dataset_type == 'taco':

                if len(args.taco_interped_data_sv_additional_tag) > 0:
                    mocap_sv_info_fn = f"passive_active_info_ori_grab_s2_phone_call_1_interped_{traj_grab_data_tag}_v2_{args.taco_interped_data_sv_additional_tag}.npy"
                else:
                    mocap_sv_info_fn = f"passive_active_info_ori_grab_s2_phone_call_1_interped_{traj_grab_data_tag}_v2.npy"
                # mocap_sv_info_fn = f"passive_active_info_ori_grab_s2_phone_call_1_interped_{traj_grab_data_tag}_v2.npy"
                mocap_sv_info_fn = os.path.join(args.tracking_data_sv_root, mocap_sv_info_fn)
            
            else:
                raise ValueError
        
        elif args.hand_type == 'leap':
            mocap_sv_info_fn = f"./data/TACO_Tracking_PK_LEAP/data/passive_active_info_ori_grab_s2_phone_call_1_interped_{traj_grab_data_tag}_v2_interpfr_60_interpfr2_60_nntrans_40.npy"
        else:
            raise ValueError
        # #
        
        print(f"mocap_sv_info_fn: {mocap_sv_info_fn}")
        
        checkpoint = ''
        tag = f"tracking_{object_name}"
        
        if args.launch_type == 'trajectory':
            if args.hand_type == 'allegro': ## modify the traing name
                train_name = f"tracking_{object_name}_obs_{obs_type}_density_{args.rigid_obj_density}_trans_{args.glb_trans_vel_scale}_rot_{args.glb_rot_vel_scale}_goalcond_{args.goal_cond}_{args.additional_tag}"
            elif args.hand_type == 'leap':
                train_name = f"tracking_{object_name}_{args.hand_type}_obs_{obs_type}_density_{args.rigid_obj_density}_trans_{args.glb_trans_vel_scale}_rot_{args.glb_rot_vel_scale}_goalcond_{args.goal_cond}_{args.additional_tag}"
            else:
                raise ValueError
        elif args.launch_type == 'object_type':
            if args.hand_type == 'allegro':
                train_name = f"tracking_{object_name}_traj_{traj_grab_data_tag}_obs_{obs_type}_density_{args.rigid_obj_density}_trans_{args.glb_trans_vel_scale}_rot_{args.glb_rot_vel_scale}_goalcond_{args.goal_cond}_{args.additional_tag}"
            elif args.hand_type == 'leap':
                train_name = f"tracking_{object_name}_traj_{traj_grab_data_tag}_{args.hand_type}_obs_{obs_type}_density_{args.rigid_obj_density}_trans_{args.glb_trans_vel_scale}_rot_{args.glb_rot_vel_scale}_goalcond_{args.goal_cond}_{args.additional_tag}"
            else:
                raise ValueError
        else:
            raise ValueError
        
        
        full_experiment_name = train_name
        
        if args.headless:
            capture_video = False
            force_render = False
        else:
            capture_video = True
            force_render = True
        
        # 
        # use_small_sigmas = "--use_small_sigmas" if args.use_small_sigmas else ""
        # use_relaxed_model = "--use_relaxed_model" if args.use_relaxed_model else ""
        # w_hand_table_collision = "--w_hand_table_collision" if args.w_hand_table_collision else ""
        
        
        print(f"test: {args.test}")
        
        if args.headless:
            cuda_visible_text = f"CUDA_VISIBLE_DEVICES={cuda_idx} "
        else:
            cuda_visible_text = ''
            
        exp_dir = './exp/IsaacGymEnvs/isaacgymenvs'
        if not os.path.exists(exp_dir):
            exp_dir = '.'
        
        if len(args.target_object_name) > 0:
            object_name=args.target_object_name
        
        if len(args.target_mocap_sv_fn) > 0:
            mocap_sv_info_fn = args.target_mocap_sv_fn
            
        if pre_optimized_traj is not None and len(pre_optimized_traj) > 0:
            cur_pre_optimized_traj = pre_optimized_traj
        else:
            cur_pre_optimized_traj = args.pre_optimized_traj
            
        if args.use_vision:
            task_type = "AllegroHandTrackingVision"
            train_type = "HumanoidPPOVision"
            
            enableCameraSensors = True
            if args.use_dagger:
                # /root/diffsim/IsaacGymEnvs2/isaacgymenvs/cfg/train/HumanoidPPOSupervised.yaml
                task_type = "AllegroHandTrackingVision"
                train_type = "HumanoidPPOVisionDAgger"
            print(f"task_type: {task_type}, train_type: {train_type}")
        else:
            # 
            if args.use_generalist_policy:
                task_type = "AllegroHandTrackingGeneralist"
                train_type = "HumanoidPPO"
                # test_inst_tag, test_optimized_res
                if args.supervised_training:
                    task_type = "AllegroHandTrackingGeneralist"
                    train_type = "HumanoidPPOSupervised"
                    
                    if args.single_instance_state_based_train:
                        train_type = "HumanoidPPOSupervisedSN"
                        print(f"using SN")
                    
                    training_mode_config = f"train.params.config.training_mode={args.training_mode}"
                    test_inst_config = f"task.env.test_inst_tag={args.test_inst_tag} task.env.test_optimized_res={args.test_optimized_res}"
                    # # preload_experiences_tf, preload_experiences_path
                    preload_experience_config = f"train.params.config.preload_experiences_tf={args.preload_experiences_tf} train.params.config.preload_experiences_path={args.preload_experiences_path}"
                    single_instance_training_config = f"train.params.config.single_instance_training={args.single_instance_training}"
                    
                    # if len(args.sampleds_with_object_code_fn) == 0 and pre_optimized_traj is not None and len(pre_optimized_traj) > 0 and os.path.exists(pre_optimized_traj):
                    #     sampleds_with_object_code_fn_config = f"task.env.sampleds_with_object_code_fn={pre_optimized_traj}"
                    # else:
                    #     sampleds_with_object_code_fn_config = f"task.env.sampleds_with_object_code_fn={args.sampleds_with_object_code_fn}"
                    sampleds_with_object_code_fn_config = f"task.env.sampleds_with_object_code_fn={args.sampleds_with_object_code_fn}"
                    
                    # # ori grab s2 apple lift #
                    if args.generalist_tune_all_instnaces:
                        test_inst_config = f"task.env.test_inst_tag={cur_grab_data_tag} task.env.test_optimized_res={pre_optimized_traj}"
                        single_instance_training_config = f"train.params.config.single_instance_training={True}"
                        preload_experience_config = f"train.params.config.preload_experiences_tf={False} train.params.config.preload_experiences_path={''}"

                    # log_path
                    log_path_config = f"train.params.config.log_path={args.log_path}"
                    train_dir_config = f"train.params.config.train_dir={args.log_path}"
                    single_instance_tag_config = f"train.params.config.single_instance_tag={args.single_instance_tag}"
                    obj_type_to_optimized_res_fn_config = f"train.params.config.obj_type_to_optimized_res_fn={args.obj_type_to_optimized_res_fn}"
                    supervised_loss_coef_config = f"train.params.config.supervised_loss_coef={args.supervised_loss_coef}"
                    pure_supervised_training_config = f"train.params.config.pure_supervised_training={args.pure_supervised_training}"
                    inst_tag_to_latent_feature_fn_config = f"task.env.inst_tag_to_latent_feature_fn={args.inst_tag_to_latent_feature_fn}"
                    object_type_to_latent_feature_fn_config = f"task.env.object_type_to_latent_feature_fn={args.object_type_to_latent_feature_fn}"
                    # --grab_obj_type_to_opt_res_fn=${grab_obj_type_to_opt_res_fn} --taco_obj_type_to_opt_res_fn=${taco_obj_type_to_opt_res_fn} 
                    obj_type_to_opt_res_config = f"task.env.grab_obj_type_to_opt_res_fn={args.grab_obj_type_to_opt_res_fn} task.env.taco_obj_type_to_opt_res_fn={args.taco_obj_type_to_opt_res_fn} train.params.config.grab_obj_type_to_opt_res_fn={args.grab_obj_type_to_opt_res_fn} train.params.config.taco_obj_type_to_opt_res_fn={args.taco_obj_type_to_opt_res_fn}"
                    use_teacher_model_config = f"train.params.config.use_teacher_model={args.use_teacher_model}"
                    bound_loss_coef_config = f"train.params.config.bounds_loss_coef={args.bound_loss_coef}"
                    data_selection_ratio_config = f"task.env.data_selection_ratio={args.data_selection_ratio}"
                    
                else:
                    training_mode_config  = ""
                    test_inst_config = ""
                    single_instance_training_config = ""
                    sampleds_with_object_code_fn_config = ""
                    log_path_config = ""
                    single_instance_tag_config = ""
                    obj_type_to_optimized_res_fn_config = ""
                    supervised_loss_coef_config = ""
                    pure_supervised_training_config = ""
                    inst_tag_to_latent_feature_fn_config = ""
                    object_type_to_latent_feature_fn_config = ""
                    obj_type_to_opt_res_config = ""
                    use_teacher_model_config = ""
                    bound_loss_coef_config = ""
                    data_selection_ratio_config = ""
                
                # grab_opt_res_config  taco_opt_res_config
                grab_opt_res_config = f"task.env.grab_inst_tag_to_optimized_res_fn={args.grab_inst_tag_to_optimized_res_fn}"
                taco_opt_res_config = f"task.env.taco_inst_tag_to_optimized_res_fn={args.taco_inst_tag_to_optimized_res_fn}"
                maxx_inst_nn_config= f"task.env.maxx_inst_nn={args.maxx_inst_nn}"
                # tracking_save_info_fn, tracking_info_st_tag
                tracking_folder_info_config = f"task.env.tracking_save_info_fn={args.tracking_save_info_fn} task.env.tracking_info_st_tag={args.tracking_info_st_tag}"
                # obj_type_to_opt_res_config
                # task.env.taco_inst_tag_to_optimized_res_fn=${taco_inst_tag_to_optimized_res_fn}
                only_training_on_succ_samples_config = f"task.env.only_training_on_succ_samples={args.only_training_on_succ_samples}"
                
                use_strict_maxx_nn_ts_config = f"task.env.use_strict_maxx_nn_ts={args.use_strict_maxx_nn_ts}"
                taco_interped_data_sv_additional_tag_config = f"task.env.taco_interped_data_sv_additional_tag={args.taco_interped_data_sv_additional_tag}"
                strict_maxx_nn_ts_config=  f"task.env.strict_maxx_nn_ts={args.strict_maxx_nn_ts}"
                grab_train_test_setting_config = f"task.env.grab_train_test_setting={args.grab_train_test_setting}"
                use_local_canonical_state_config = f"task.env.use_local_canonical_state={args.use_local_canonical_state}"
                 # rew_grab_thres, rew_taco_thres
                rew_thres_config = f"task.env.rew_grab_thres={args.rew_grab_thres} task.env.rew_taco_thres={args.rew_taco_thres}"
                rew_smoothness_coef_config = f"task.env.rew_smoothness_coef={args.rew_smoothness_coef}"
                obj_type_to_base_traj_fn_config = f"task.env.obj_type_to_base_traj_fn={args.obj_type_to_base_traj_fn}"
                use_base_traj_config = f"task.env.use_base_traj={args.use_base_traj}"
                # rew_thres_with_selected_insts, selected_inst_idxes_dict
                rew_thres_with_selected_insts_config = f"task.env.rew_thres_with_selected_insts={args.rew_thres_with_selected_insts}"
                selected_inst_idxes_dict_config = f"task.env.selected_inst_idxes_dict={args.selected_inst_idxes_dict}"
                customize_damping_config = f"task.env.customize_damping={args.customize_damping}"
                customize_global_damping_config = f"task.env.customize_global_damping={args.customize_global_damping}"
                train_on_all_trajs_config = f"task.env.train_on_all_trajs={args.train_on_all_trajs}"
                single_instance_state_based_train_config = f"task.env.single_instance_state_based_train={args.single_instance_state_based_train}"
                wo_vel_obs_config  = f"task.env.wo_vel_obs=${args.wo_vel_obs}"
                
            else:
                task_type = "AllegroHandTracking"
                train_type = "HumanoidPPO"
                training_mode_config = ""
                test_inst_config = ""

                single_instance_training_config = ""
                sampleds_with_object_code_fn_config = ""
                log_path_config = ""
                
                grab_opt_res_config = f""
                taco_opt_res_config = f""
                
                single_instance_tag_config = ""
                obj_type_to_optimized_res_fn_config = "" #
                
                supervised_loss_coef_config = ""
                pure_supervised_training_config = ""

                inst_tag_to_latent_feature_fn_config = ""
                object_type_to_latent_feature_fn_config = ""
                obj_type_to_opt_res_config = ""
                
                maxx_inst_nn_config = ""
                tracking_folder_info_config= ""
                
                only_training_on_succ_samples_config = ""
                use_teacher_model_config = ""
                
                use_strict_maxx_nn_ts_config = ""
                taco_interped_data_sv_additional_tag_config = ""
                strict_maxx_nn_ts_config = ""
                grab_train_test_setting_config= ""
                
                use_local_canonical_state_config = ""
                bound_loss_coef_config = ""
                rew_thres_config = ""
                rew_smoothness_coef_config = ""
                obj_type_to_base_traj_fn_config = ""
                use_base_traj_config = ""
                
                rew_thres_with_selected_insts_config = ""
                selected_inst_idxes_dict_config = ""
                customize_damping_config = ""
                customize_global_damping_config = ""
                train_on_all_trajs_config = ""
                single_instance_state_based_train_config = ""
                data_selection_ratio_config = ""
                wo_vel_obs_config = ""

            enableCameraSensors = False
        # if args.use_generalist_policy: #
        # grab_inst_tag_to_optimized_res_fn, taco_inst_tag_to_optimized_res_fn
        
        if args.use_vision:
            cuda_visible_text = f"CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7  " 
            #
            cmd = f"{cuda_visible_text} python train.py task={task_type} train={train_type} sim_device='cuda:{cuda_idx}' rl_device='cuda:{cuda_idx}'  capture_video={capture_video} force_render={force_render} headless={args.headless}   task.env.numEnvs={args.numEnvs} train.params.config.minibatch_size={args.minibatch_size}  task.env.useRelativeControl={args.use_relative_control}  train.params.config.max_epochs={args.max_epochs} task.env.mocap_sv_info_fn={mocap_sv_info_fn} task.env.goal_cond={args.goal_cond} task.env.object_name={object_name} tag={tag} train.params.config.name={train_name} train.params.config.full_experiment_name={full_experiment_name} task.sim.dt={args.dt} test={args.test} task.env.use_kinematics_bias={args.use_kinematics_bias} task.env.w_obj_ornt={args.w_obj_ornt} task.env.observationType={obs_type}  task.env.separate_stages={args.separate_stages} task.env.rigid_obj_density={args.rigid_obj_density}   task.env.kinematics_only={args.kinematics_only}  task.env.use_fingertips={args.use_fingertips}  task.env.glb_trans_vel_scale={args.glb_trans_vel_scale} task.env.glb_rot_vel_scale={args.glb_rot_vel_scale} task.env.use_kinematics_bias_wdelta={args.use_kinematics_bias_wdelta} task.env.hand_pose_guidance_glb_trans_coef={args.hand_pose_guidance_glb_trans_coef} task.env.hand_pose_guidance_glb_rot_coef={args.hand_pose_guidance_glb_rot_coef} task.env.hand_pose_guidance_fingerpose_coef={args.hand_pose_guidance_fingerpose_coef} task.env.rew_finger_obj_dist_coef={args.rew_finger_obj_dist_coef} task.env.rew_delta_hand_pose_coef={args.rew_delta_hand_pose_coef} task.env.dofSpeedScale={args.dofSpeedScale} task.env.use_twostage_rew={args.use_twostage_rew} task.env.ground_distance={args.ground_distance} task.env.use_canonical_state={args.use_canonical_state} task.env.disable_obj_gravity={args.disable_gravity} train.params.config.save_best_after=50 task.env.right_hand_dist_thres={args.right_hand_dist_thres} checkpoint={args.checkpoint} task.env.use_real_twostage_rew={args.use_real_twostage_rew} task.env.start_grasping_fr={args.start_grasping_fr} task.env.controlFrequencyInv={args.controlFrequencyInv} task.env.episodeLength={args.episodeLength} task.env.start_frame={args.start_frame} task.env.rew_obj_pose_coef={args.rew_obj_pose_coef} task.env.goal_dist_thres={args.goal_dist_thres} task.env.lifting_separate_stages={args.lifting_separate_stages} task.env.strict_lifting_separate_stages={args.strict_lifting_separate_stages} task.env.table_z_dim={args.table_z_dim} task.env.add_table={args.add_table} exp_dir={exp_dir} task.env.use_taco_obj_traj={args.use_taco_obj_traj} task.env.pre_optimized_traj={ cur_pre_optimized_traj } task.env.hand_type={ args.hand_type } enableCameraSensors={enableCameraSensors} graphics_device_id={cuda_idx}"
        else:
            cmd = f"{cuda_visible_text} python train.py task={task_type} train={train_type} sim_device='cuda:0' rl_device='cuda:0'  capture_video={capture_video} force_render={force_render} headless={args.headless}   task.env.numEnvs={args.numEnvs} train.params.config.minibatch_size={args.minibatch_size}  task.env.useRelativeControl={args.use_relative_control}  train.params.config.max_epochs={args.max_epochs} task.env.mocap_sv_info_fn={mocap_sv_info_fn} task.env.goal_cond={args.goal_cond} task.env.object_name={object_name} tag={tag} train.params.config.name={train_name} train.params.config.full_experiment_name={full_experiment_name} task.sim.dt={args.dt} test={args.test} task.env.use_kinematics_bias={args.use_kinematics_bias} task.env.w_obj_ornt={args.w_obj_ornt} task.env.observationType={obs_type}  task.env.separate_stages={args.separate_stages} task.env.rigid_obj_density={args.rigid_obj_density}   task.env.kinematics_only={args.kinematics_only}  task.env.use_fingertips={args.use_fingertips}  task.env.glb_trans_vel_scale={args.glb_trans_vel_scale} task.env.glb_rot_vel_scale={args.glb_rot_vel_scale} task.env.use_kinematics_bias_wdelta={args.use_kinematics_bias_wdelta} task.env.hand_pose_guidance_glb_trans_coef={args.hand_pose_guidance_glb_trans_coef} task.env.hand_pose_guidance_glb_rot_coef={args.hand_pose_guidance_glb_rot_coef} task.env.hand_pose_guidance_fingerpose_coef={args.hand_pose_guidance_fingerpose_coef} task.env.rew_finger_obj_dist_coef={args.rew_finger_obj_dist_coef} task.env.rew_delta_hand_pose_coef={args.rew_delta_hand_pose_coef} task.env.dofSpeedScale={args.dofSpeedScale} task.env.use_twostage_rew={args.use_twostage_rew} task.env.ground_distance={args.ground_distance} task.env.use_canonical_state={args.use_canonical_state} task.env.disable_obj_gravity={args.disable_gravity} train.params.config.save_best_after=50 task.env.right_hand_dist_thres={args.right_hand_dist_thres} checkpoint={args.checkpoint} task.env.use_real_twostage_rew={args.use_real_twostage_rew} task.env.start_grasping_fr={args.start_grasping_fr} task.env.controlFrequencyInv={args.controlFrequencyInv} task.env.episodeLength={args.episodeLength} task.env.start_frame={args.start_frame} task.env.rew_obj_pose_coef={args.rew_obj_pose_coef} task.env.goal_dist_thres={args.goal_dist_thres} task.env.lifting_separate_stages={args.lifting_separate_stages} task.env.strict_lifting_separate_stages={args.strict_lifting_separate_stages} task.env.table_z_dim={args.table_z_dim} task.env.add_table={args.add_table} exp_dir={exp_dir} task.env.use_taco_obj_traj={args.use_taco_obj_traj} task.env.pre_optimized_traj={ cur_pre_optimized_traj } task.env.hand_type={ args.hand_type } enableCameraSensors={enableCameraSensors} graphics_device_id={cuda_idx} task.env.use_hand_actions_rew={args.use_hand_actions_rew} task.env.supervised_training={args.supervised_training} {training_mode_config} {test_inst_config} {preload_experience_config} {single_instance_training_config} {sampleds_with_object_code_fn_config} {log_path_config} {train_dir_config} {grab_opt_res_config}  {taco_opt_res_config} {single_instance_tag_config} {obj_type_to_optimized_res_fn_config} {supervised_loss_coef_config} {pure_supervised_training_config} {inst_tag_to_latent_feature_fn_config} {object_type_to_latent_feature_fn_config} {obj_type_to_opt_res_config} {maxx_inst_nn_config} {tracking_folder_info_config} {only_training_on_succ_samples_config} {use_teacher_model_config} {use_strict_maxx_nn_ts_config} {taco_interped_data_sv_additional_tag_config} {strict_maxx_nn_ts_config} {grab_train_test_setting_config} {use_local_canonical_state_config} {bound_loss_coef_config} {rew_thres_config} {rew_smoothness_coef_config} {obj_type_to_base_traj_fn_config} {use_base_traj_config} {rew_thres_with_selected_insts_config} {selected_inst_idxes_dict_config} {customize_damping_config} {customize_global_damping_config} {train_on_all_trajs_config} {single_instance_state_based_train_config} {data_selection_ratio_config} {wo_vel_obs_config}"   
            
            
        print(cmd)
        os.system(cmd)
    
    
    
    tracking_data_sv_root = args.tracking_data_sv_root
    
    
    if args.launch_type != 'trajectory':
        # grab_data_nm_idx_dict #
        # grab_tracking_data_root = args.tracking_data_sv_root
        if args.subj_idx == 2 or args.subj_idx < 1:
            data_nm_idx_dict_sv_fn = "grab_data_nm_idx_dict.npy"
        else:
            data_nm_idx_dict_sv_fn = f"grab_data_nm_idx_dict_s{args.subj_idx}.npy"
        data_nm_idx_dict_sv_fn = os.path.join(tracking_data_sv_root, data_nm_idx_dict_sv_fn)
        data_nm_idx_dict = np.load(data_nm_idx_dict_sv_fn, allow_pickle=True).item()
        data_nm_to_idx = data_nm_idx_dict['data_nm_to_idx']
        idx_to_data_nm = data_nm_idx_dict['idx_to_data_nm'] # data nm to idx # 
        
    # 
    def find_similar_objs(obj_index):
        grab_cross_obj_diff_arr_fn = "../assets/grab_cross_obj_verts_diff.npy"
        grab_cross_obj_diff_arr = np.load(grab_cross_obj_diff_arr_fn)
        cur_obj_diff_arr = grab_cross_obj_diff_arr[obj_index]
        cur_obj_sorted_nei_idxes = np.argsort(cur_obj_diff_arr, axis=0)
        cur_obj_sorted_nei_idxes = cur_obj_sorted_nei_idxes[1: 1 + 10]
        cur_obj_sorted_nei_idxes = cur_obj_sorted_nei_idxes.tolist()
        return cur_obj_sorted_nei_idxes

    

    # ./data/GRAB_Tracking_LEAP_PK/data/leap_passive_active_info_ori_grab_s1_alarmclock_lift.npy
    if args.dataset_type == 'grab':
        # passive_active_info_ori_grab_s2_pyramidlarge_lift.npy
        starting_str = "passive_active_info_ori_grab_"
        passive_active_info_tag = "passive_active_info_"
        
        if args.hand_type == 'leap':
            starting_str = "leap_" + starting_str
            passive_active_info_tag = "leap_" + passive_active_info_tag
        
        tot_tracking_data = os.listdir(tracking_data_sv_root)
        if args.num_frames == 150:
            tot_tracking_data = [fn for fn in tot_tracking_data if fn[: len(starting_str)] == starting_str and fn.endswith(".npy") and "_nf_" not in fn]
        else:
            nf_tag = f"_nf_{args.num_frames}"
            tot_tracking_data = [fn for fn in tot_tracking_data if fn[: len(starting_str)] == starting_str and fn.endswith(".npy") and nf_tag in fn]
        
        if len(args.subj_nm) > 0:
            subj_tag = f"_{args.subj_nm}_"
            tot_tracking_data = [fn for fn in tot_tracking_data if subj_tag in fn]
    elif args.dataset_type == 'taco':
        taso_inst_st_flag = 'taco_'
        mesh_sv_root = "/root/diffsim/UniDexGrasp/dexgrasp_policy/assets/meshdatav3_scaled/sem"
        tot_mesh_folders = os.listdir(mesh_sv_root)
        tot_mesh_folders = [fn for fn in tot_mesh_folders if fn[: len(taso_inst_st_flag)] == taso_inst_st_flag]
        
        
        if args.eval_split_trajs:
            # test_taco_tag = 'taco_20231024_'
            test_taco_tag = args.test_taco_tag # get the test taco tag #
            tot_mesh_folders = [fn for fn in tot_mesh_folders if test_taco_tag in fn]
        
        # modified_tag = "_modifed"
        # interped_tag = "_interped"
        # find tracking data
        tot_tracking_data = tot_mesh_folders
        passive_active_info_tag = ''
    
    else:
        raise ValueError(f"Unrecognized dataset_type: {args.dataset_type}")
    
    nn_gpus = args.nn_gpus
    
    ### and also add the grab instance to the optimized res dict
    # pre_load_trajectories, obj_type_to_pre_optimized_traj
    pre_load_trajectories = args.pre_load_trajectories
    print(f"pre_load_trajectories: {pre_load_trajectories}")
    if pre_load_trajectories: # load pre optimized trajectories ## # load pre optimized trajs #
        obj_type_to_pre_optimized_traj = args.obj_type_to_pre_optimized_traj 
        assert len(obj_type_to_pre_optimized_traj) > 0 and os.path.exists(obj_type_to_pre_optimized_traj)
        obj_type_to_pre_optimized_traj = np.load(obj_type_to_pre_optimized_traj, allow_pickle=True).item()
        #### obj type to pre optimized traj ####
    else:
        obj_type_to_pre_optimized_traj = None
    
    print(f"launch_type: {args.launch_type}")
    if args.launch_type == 'trajectory':
        # not using them 
        tot_grab_data_tag = []
        for cur_tracking_data in tot_tracking_data:
            cur_grab_data_tag = cur_tracking_data.split(".")[0][len(passive_active_info_tag):]
            print(f"cur_grab_data_tag: {cur_grab_data_tag}")
            if exclude_inst_tag_to_opt_res is not None and cur_grab_data_tag in exclude_inst_tag_to_opt_res:
                continue
            
            traj_grab_data_tag = cur_grab_data_tag
            
            if pure_obj_type_to_optimized_res and args.rew_filter:
                if cur_grab_data_tag not in pure_obj_type_to_optimized_res:
                    continue
                
                # obj rew and the grab data tag #
                cur_obj_rew = pure_obj_type_to_optimized_res[cur_grab_data_tag][0]
                print(f"cur_grab_data_tag: {cur_grab_data_tag}, cur_obj_rew: {cur_obj_rew}")
                if cur_obj_rew > args.rew_low_threshold:
                    continue
                print(f"cur_grab_data_tag: {cur_grab_data_tag}, cur_obj_rew: {cur_obj_rew}")
            
            
            if obj_type_to_pre_optimized_traj is not None:
                key_of_opt_traj = list(obj_type_to_pre_optimized_traj.keys())[0]
                
                if isinstance(key_of_opt_traj, tuple):
                    if 'taco' in cur_grab_data_tag:
                        cur_grab_data_tag_key = (cur_grab_data_tag, 'ori_grab_s2_phone_call_1')
                    else:
                        cur_grab_data_tag_key = ( cur_grab_data_tag, cur_grab_data_tag )
                    # cur_pre_optimized_traj = obj_type_to_pre_optimized_traj[(cur_grab_data_tag, cur_grab_data_tag)]
                else: # grab data kye #
                    cur_grab_data_tag_key = cur_grab_data_tag
                    # cur_pre_optimized_traj = obj_type_to_pre_optimized_traj[ cur_grab_data_tag ]
                    
                
                if cur_grab_data_tag_key not in obj_type_to_pre_optimized_traj:
                    if args.train_on_all_trajs:
                        cur_pre_optimized_traj = ["./uni_manip/isaacgym_rl_exp_taco_grab_interpseq_interpfr_60_interpfr2_60_nntrans_40_eval/tracking_TACO_taco_20230930_037_INTERPSEQ_ori_grab_s2_phone_call_1_obs_pure_state_wref_wdelta_density_500.0_trans_0.5_rot_0.5_goalcond_False_kinebias_t0.5r0.5f20_rfd_0.3_rh_0.5_interpfr_60_interpfr2_60_nntrans_40_04-20-05-36/ts_to_hand_obj_obs_reset_1.npy"]
                    else:
                        continue
                else:
                    cur_pre_optimized_traj = obj_type_to_pre_optimized_traj[ cur_grab_data_tag_key ]
                # if isinstance(key_of_opt_traj, tuple): # grab data tag key #
                #     cur_pre_optimized_traj = obj_type_to_pre_optimized_traj[(cur_grab_data_tag, cur_grab_data_tag)]
                # else:
                #     cur_pre_optimized_traj = obj_type_to_pre_optimized_traj[ cur_grab_data_tag ]
                cur_pre_optimized_traj = cur_pre_optimized_traj[0]
                cur_pre_optimized_traj_sorted = cur_pre_optimized_traj.replace(".npy", "_sorted.npy")
                cur_pre_optimized_traj_sorted_best = cur_pre_optimized_traj_sorted.replace(".npy", "_best.npy")
                if not os.path.exists(cur_pre_optimized_traj_sorted_best):
                    continue
                cur_pre_optimized_traj = cur_pre_optimized_traj_sorted_best
            else:
                cur_pre_optimized_traj = None
            print(f"cur_grab_data_tag: {cur_grab_data_tag}, cur_pre_optimized_traj: {cur_pre_optimized_traj}")
            cur_cuda_idx = len(tot_grab_data_tag) % nn_gpus
            tot_grab_data_tag.append(
                [cur_grab_data_tag, traj_grab_data_tag, cur_cuda_idx, cur_pre_optimized_traj]
            )
            
            if args.debug:
                break
            
    elif args.launch_type == 'object_type':
        tot_grab_data_tag = []
        for cur_tracking_data in tot_tracking_data:
            cur_grab_data_tag = cur_tracking_data.split(".")[0][len(passive_active_info_tag):]
            traj_grab_data_tag = cur_grab_data_tag
            
            if '_nf_' in cur_grab_data_tag:
                pure_obj_type = cur_grab_data_tag.split('_nf_')[0] #
            else:
                pure_obj_type = cur_grab_data_tag
            cur_idx = data_nm_to_idx[pure_obj_type]
            # ori_grab_sx_xxx # 
            tot_data_names = list(data_nm_to_idx.keys())
            print(f"pure_obj_type: {pure_obj_type}")
            cur_obj_name = pure_obj_type.split("_")[3]
            grab_obj_idx_dict_fn = f"../assets/grab_obj_name_idx_dict.npy"
            grab_obj_idx_dict = np.load(grab_obj_idx_dict_fn, allow_pickle=True).item()
            grab_obj_nm_to_idx = grab_obj_idx_dict['grab_obj_name_to_idx']
            cru_obj_idx = grab_obj_nm_to_idx[cur_obj_name] # get the current object index #
            cur_obj_sorted_nei_idxes = find_similar_objs(cru_obj_idx)
            cur_obj_sorted_nei_names = [grab_obj_idx_dict['grab_idx_to_obj_name'][idx] for idx in cur_obj_sorted_nei_idxes]
            for i_obj, nei_obj_name in enumerate(cur_obj_sorted_nei_names):
                pure_nei_obj_name = None 
                for cur_candi_pure_obj_name in tot_data_names:
                    if nei_obj_name in cur_candi_pure_obj_name:
                        pure_nei_obj_name = cur_candi_pure_obj_name
                        break
                if pure_nei_obj_name is None:
                    continue
                cur_cuda_idx = len(tot_grab_data_tag) % nn_gpus
                tot_grab_data_tag.append(
                    [pure_nei_obj_name, traj_grab_data_tag, cur_cuda_idx, None]
                )
    else:
        raise ValueError(f"Launch type {args.launch_type} not supported")
    
    print(f"tot_tracking_data : {tot_tracking_data}")
    
    ### tot grab data tag ###
    ### tot grab data tag ###
    tot_grab_data_tag = tot_grab_data_tag[args.st_idx: ]
    
    if args.debug:
        tot_grab_data_tag = tot_grab_data_tag[:1]
    
    
    # generalist_tune_all_instnaces  # data inst flag is not None #
    if (not args.generalist_tune_all_instnaces) and (args.data_inst_flag is not None) and len(args.data_inst_flag) > 0:
        data_inst_flag = args.data_inst_flag
        cur_cuda_idx = args.st_idx
        if obj_type_to_pre_optimized_traj is not None:
            # key_of_opt_traj = obj_type_to_pre_optimized_traj.keys()[0]
            key_of_opt_traj = list(obj_type_to_pre_optimized_traj.keys())[0]
            if isinstance(key_of_opt_traj, tuple):
                cur_pre_optimized_traj = obj_type_to_pre_optimized_traj[ (data_inst_flag, data_inst_flag) ]
            else:
                cur_pre_optimized_traj = obj_type_to_pre_optimized_traj[ data_inst_flag ]
            # cur_pre_optimized_traj = obj_type_to_pre_optimized_traj[ (data_inst_flag, data_inst_flag) ]
            # cur_pre_optimized_traj = obj_type_to_pre_optimized_traj[ data_inst_flag ]
            cur_pre_optimized_traj = cur_pre_optimized_traj[0] 
        else:
            cur_pre_optimized_traj = None
        tot_grab_data_tag = [
            [data_inst_flag, data_inst_flag, cur_cuda_idx, cur_pre_optimized_traj]
        ]
    #
    max_pool_size = nn_gpus * 1
    
    for i_st in range(0, len(tot_grab_data_tag), max_pool_size):
        i_ed = i_st + max_pool_size
        i_ed = min(i_ed, len(tot_grab_data_tag))
        cur_batch_grab_data_tags = tot_grab_data_tag[i_st: i_ed]
        
        cur_thread_processes = []
        
        for cur_grab_data_tag in cur_batch_grab_data_tags:
            # existing = judge_whether_trained(tot_tracking_logs, cur_grab_data_tag)
            # if existing:
            #     print(f" cur_grab_data_tag: {cur_grab_data_tag} has been trained")
            #     continue
            cur_thread_processes.append(
                Process(target=launch_one_process, args=(cur_grab_data_tag))
            )
            # 
            cur_thread_processes[-1].start()
        for p in cur_thread_processes:
            p.join()
    
    
    # launch_rlg_hydra()
