import os
import jax
import yaml

jax.config.update("jax_compilation_cache_dir", "__jaxcache__")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")

import jax.numpy as jnp
import numpy as np
import pickle
import glob
from tqdm import tqdm
from functools import partial
import time
import open3d as o3d
import optax
import jax.debug as jdb
import argparse

import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

import util.model_util as mutil
from train_pointfeat import ColDataset
import util.transform_util as tutil
import util.latent_obj_util as loutil
from util.reconstruction_util import create_fps_fcd_from_oriCORNs, FPS_padding
import pybullet as pb
from modules import shakey_module
import einops
import matplotlib.pyplot as plt
import util.structs as structs
import modules.traj_opt_module as traj_opt_module
import util.scene_util as scene_util
import logging
import datetime
import yaml
import util.mp_eval_util as mp_eval_util
from modules.ccd.curobo import decompose_mesh_to_spheres


def flatten_config(d, parent_key='', sep='_'):
    """
    Flatten a nested dictionary by joining the key path with ``sep``.

    With the default ``sep='_'``, ``{'a': {'b': 1}, 'c': 2}`` becomes
    ``{'a_b': 1, 'c': 2}``. Top-level keys are kept unchanged, so parser
    argument names must equal the full joined path, not just the leaf key.

    Args:
        d: Mapping to flatten. Values that are ``dict`` instances are recursed
            into; all other values are copied as-is.
        parent_key: Prefix accumulated from enclosing keys ('' at the top level).
        sep: Separator inserted between a parent key and its child key.

    Returns:
        A new, flat ``dict``.
    """
    items = {}
    for k, v in d.items():
        new_key = k if parent_key == '' else f"{parent_key}{sep}{k}"
        if isinstance(v, dict):
            items.update(flatten_config(v, parent_key=new_key, sep=sep))
        else:
            items[new_key] = v
    return items

def print_and_log(message):
    """
    Echo ``message`` to stdout, then record it at INFO level via ``logging``.
    """
    for emit in (print, logging.info):
        emit(message)

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default=None, help='Path to YAML configuration file')
parser.add_argument('--vel_coef', type=float, default=1.0)
parser.add_argument('--col_coef', type=float, default=5.0)
parser.add_argument('--col_threshold', type=float, default=6.0, help="collision threshold only for ours and stamp methods")
parser.add_argument('--acc_coef_factor', type=float, default=10)
parser.add_argument('--jerk_coef_factor', type=float, default=50)
parser.add_argument('--particle_itr_no', type=int, default=40)
parser.add_argument('--gradient_itr_no', type=int, default=10)
parser.add_argument('--num_mppi_samples', type=int, default=200)
parser.add_argument('--num_seeds', type=int, default=2)
parser.add_argument('--solver_seed', type=int, default=0)
parser.add_argument('--num_trajectory_points_particle', type=int, default=8)
parser.add_argument('--num_trajectory_points_gradient', type=int, default=8)
parser.add_argument('--interpolation_num_particle', type=int, default=3)
parser.add_argument('--interpolation_num_gradient', type=int, default=3)
parser.add_argument('--visualize', type=int, default=True)
parser.add_argument('--num_objects', type=int, default=0)
parser.add_argument('--bundled_samples_num', type=int, default=4)
parser.add_argument('--bundled_order', type=int, default=1)
parser.add_argument('--reduce_k', type=int, default=20)
parser.add_argument('--linesearch_batch_num', type=int, default=6)
parser.add_argument('--curobo_activation_distance', type=float, default=0.040)
parser.add_argument('--curobo_moving_obj_npnts', type=int, default=100)
parser.add_argument('--save_video', type=int, default=1)
parser.add_argument('--ccd_type', type=str, default='ours', choices=['ours', 'curobo', 'stamp'])
parser.add_argument('--broadphase_type', type=str, default='naivebf_independent', choices=['naivebf_independent', 
                                                                                    'naivebf_segment', 
                                                                                    'naivebf_traj', 
                                                                                    'naive_independent', 
                                                                                    'naive_segment', 
                                                                                    'naive_traj', 
                                                                                    'aabb_aabb', 
                                                                                    'timeopt_independent',
                                                                                    'timeopt_segment', 
                                                                                    'timeopt_traj',
                                                                                    'timeoptbf_independent',
                                                                                    'timeoptbf_segment',
                                                                                    'timeoptbf_traj',
                                                                                    ])
# 'bimanual_hard' is handled by the scene-construction code below, so it must be selectable here.
parser.add_argument('--env_type', type=str, default='bimanual', choices=['table', 'clean', 'dish', 'dish_multiple', 'room', 'room_pen', 'bimanual', 'bimanual_hard', 'construction_site', 'construction_site_hard'])
# 'shakey_rg2_close' is handled by the robot-loading code below, so it must be selectable here.
parser.add_argument('--robot_type', type=str, default='im2', choices=['ur5', 'ur5_rg6', 'shakey_rg2', 'shakey_rg2_close', 'shakey_robotiq', 'shakey_no_gripper', 'im2'])
parser.add_argument('--save_env', type=int, default=False)

# First parse only to discover --config; any other CLI arguments are handled by the final parse below.
args, remaining = parser.parse_known_args()

if args.config is not None:
    with open(args.config, 'r') as f:
        # safe_load returns None for an empty file; treat that as "no overrides".
        config = yaml.safe_load(f) or {}

    # Flatten the YAML config so nested keys are joined with '_' to match argument names.
    flat_config = flatten_config(config)

    # set_defaults accepts arbitrary names, so a misspelled key would silently become a
    # new attribute on args instead of overriding anything; surface those keys.
    known_dests = vars(args)
    unknown_keys = [k for k in flat_config if k not in known_dests]
    if unknown_keys:
        print(f"Warning: config keys that match no command-line argument: {unknown_keys}")

    # YAML values become parser defaults, so flags given explicitly on the CLI still win.
    parser.set_defaults(**flat_config)


args = parser.parse_args()
print("Final arguments:")
print(args)

def to_yml(env_type, ccd_type, fixed_objects, movable_objects, base_se2, shakey: shakey_module.Shakey, loss_args: structs.LossArgs, x, aux_pbids, seed):
    """Export one planning problem and its trajectory ``x`` to a YAML file.

    The file is written to ``temp/video/{env_type}/{ccd_type}/{seed}.yml``;
    parent directories are created if needed. Its contents are:

    * ``mesh``: ``mesh_<k>`` entries, first for every fixed object, then for
      every auxiliary PyBullet body in ``aux_pbids`` whose first visual shape
      has a non-empty mesh file. Bodies without a mesh file are skipped and do
      not consume an index.
    * ``robot``: URDF path, base pose, the trajectory ``x`` as nested lists,
      and one ``moving_obj`` entry per movable object. The ``k``-th movable
      object is attached to the ``k``-th entry of ``shakey.ee_idx`` with the
      ``k``-th 7-vector of ``loss_args.ee_to_obj_pq``.

    Every pose is written as ``[p0, p1, p2, q3, q0, q1, q2]``: the last
    quaternion component is moved to the front (xyzw -> wxyz for
    PyBullet-style quaternions).

    ``movable_objects`` may be ``None`` or empty. In that case no moving
    objects are written and ``loss_args.ee_to_obj_pq`` is not read, so it may
    be ``None``.
    """
    def to_np(pos, quat):
        return np.array(pos), np.array(quat)

    def pose_to_list(pos, quat):
        # Position followed by the quaternion with its last component moved to the front.
        return [*pos.tolist(), quat[3].item(), *quat[:3].tolist()]

    base_pos, base_quat = to_np(*tutil.SE2h2pq(np.array(base_se2), np.array(shakey.robot_height)))
    data = {
        "mesh": {},
        "robot": {
            "urdf_path": shakey.urdf_dir,
            "robot_base_pose": pose_to_list(base_pos, base_quat),
            "traj": np.array(x).tolist(),
            "moving_obj": []
        },
    }
    meshes = data["mesh"]

    for fixed_obj in fixed_objects:
        key = f"mesh_{len(meshes)}"
        meshes[key] = {
            "pose": pose_to_list(*to_np(*fixed_obj.base_pose)),
            "scale": fixed_obj.scale[0],
            "file_path": fixed_obj.mesh_path,
        }

    for aux_id in aux_pbids:
        aux_pos, aux_quat = pb.getBasePositionAndOrientation(aux_id)
        visual_shape = pb.getVisualShapeData(aux_id)[0]
        mesh_path = visual_shape[4]
        mesh_path = mesh_path.decode('utf-8') if isinstance(mesh_path, bytes) else mesh_path
        if mesh_path == "":
            # Nothing to export for a body without a mesh file.
            continue

        key = f"mesh_{len(meshes)}"
        meshes[key] = {
            "pose": [*aux_pos, aux_quat[3], *aux_quat[:3]],
            "scale": visual_shape[3][0],
            "file_path": mesh_path,
        }

    movable_objects = list(movable_objects) if movable_objects is not None else []
    if movable_objects:
        # Only convert the attachment data when there is something attached; for scenes
        # without movable objects ee_to_obj_pq can be None, which cannot be reshaped.
        ee_indices = np.array(shakey.ee_idx).reshape(-1)
        ee_to_obj_pq = np.array(loss_args.ee_to_obj_pq).reshape(-1, 7)
        for k, movable_obj in enumerate(movable_objects):
            row = ee_to_obj_pq[k]
            data["robot"]["moving_obj"].append(
                {
                    "pose": pose_to_list(*to_np(*movable_obj.base_pose)),
                    "scale": movable_obj.scale[0],
                    "file_path": movable_obj.mesh_path,
                    "attached_to": ee_indices[k].item(),
                    "ee_to_obj_pq": [*row[:3].tolist(), row[6].item(), *row[3:6].tolist()],
                }
            )

    filename = f"temp/video/{env_type}/{ccd_type}/{seed}.yml"
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "w") as f:
        yaml.dump(data, f, default_flow_style=False)

if __name__ == "__main__":
    
    date_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    date_str = date_str + f"_{args.env_type}_{args.ccd_type}"
    log_dir = os.path.join('traj_opt_logs', date_str)
    os.makedirs(log_dir, exist_ok=True)
    # add log text file
    logging.basicConfig(filename=os.path.join(log_dir, 'logs.txt'), level=logging.INFO)

    # config logging
    logging.info(args.__dict__)

    seed = 0

    # load oriCORN model
    models = mutil.Models().load_pretrained_models()
    
    # robot
    use_collision_mesh = False
    if args.robot_type == 'ur5_rg6':
        urdf_dirs = "assets/ur5/urdf/shakey_open_rg6.urdf"
        models = models.load_self_collision_model('shakey')
    elif args.robot_type == 'shakey_no_gripper':
        urdf_dirs = "assets/ur5/urdf/shakey_no_gripper.urdf"
        models = models.load_self_collision_model('shakey')
    elif args.robot_type == 'shakey_robotiq':
        urdf_dirs = "assets/ur5/urdf/shakey_robotiq_open.urdf"
        models = models.load_self_collision_model('shakey_robotiq')
    elif args.robot_type == 'shakey_rg2_close':
        urdf_dirs = "assets/ur5/urdf/shakey_rg2_close.urdf"
        models = models.load_self_collision_model('shakey')
    elif args.robot_type == 'ur5':
        urdf_dirs = "assets/ur5/urdf/ur5.urdf"
        models = models.load_self_collision_model('ur5')
    elif args.robot_type == 'im2':
        urdf_dirs = "assets/RobotBimanualV4/urdf/RobotBimanualV4.urdf"
        models = models.load_self_collision_model('im2')
        use_collision_mesh = True
    elif args.robot_type == 'shakey_rg2':
        urdf_dirs = "assets/ur5/urdf/shakey_open.urdf"
        models = models.load_self_collision_model('shakey')
    else:
        raise ValueError(f'invalid robot type: {args.robot_type}')

    se2_bounds = None
    if args.env_type in ['room', 'room_pen']:
        robot_height = 0.55
    elif args.env_type in ['construction_site', 'construction_site_hard']:
        se2_bounds = np.array([[-1.0, -5.0, -1.2*np.pi], 
                               [1.0, 5.0, 1.2*np.pi]])
        robot_height = 0.34
    else:
        robot_height = 0.0


    shakey = shakey_module.load_urdf_kinematics(
        urdf_dirs=urdf_dirs,
        models=models,
        use_collision_mesh=use_collision_mesh,
        robot_height=robot_height,
    )

    if args.visualize:
        pb.connect(pb.GUI)

    if args.env_type in ['bimanual', 'bimanual_hard']:
        base_se2 = np.array([0, 0, 0.])
    else:
        base_se2 = np.array([0, -0.15, -np.pi/2.0])
    robot_pb_uid = shakey.create_pb(se2=base_se2)

    def search_node_visualizer(node):
        """Show joint configuration ``node`` on the PyBullet robot.

        ``robot_pb_uid``, ``obj_in_hand`` and ``ee_to_obj_pq`` are module-level
        names assigned inside the ``__main__`` block. They are looked up when
        this function is called, so the current scene's values are used.
        """
        pb_call_args = (robot_pb_uid, node, obj_in_hand, ee_to_obj_pq)
        shakey.set_q_pb(*pb_call_args)

    base_pqc = tutil.SE2h2pq(base_se2, np.array(shakey.robot_height))
    robot_base_pqc = jnp.concat(base_pqc, axis=-1)
    traj_optimizer = traj_opt_module.TrajectoryOptimizer(models, shakey, robot_base_pqc=robot_base_pqc, 
                                                         col_coef=args.col_coef,
                                                         ccd_type=args.ccd_type,
                                                         broadphase_type=args.broadphase_type,
                                                         num_trajectory_points_particle=args.num_trajectory_points_particle,
                                                         num_trajectory_points_gradient=args.num_trajectory_points_gradient,
                                                         particle_itr_no=args.particle_itr_no,
                                                         gradient_itr_no=args.gradient_itr_no,
                                                         num_mppi_samples=args.num_mppi_samples,
                                                         interpolation_num_particle=args.interpolation_num_particle,
                                                         interpolation_num_gradient=args.interpolation_num_gradient,
                                                         reduce_k=args.reduce_k,
                                                         num_seeds=args.num_seeds,
                                                         vel_coef=args.vel_coef,
                                                         acc_coef_factor=args.acc_coef_factor,
                                                         jerk_coef_factor=args.jerk_coef_factor,
                                                         collision_threshold=args.col_threshold,
                                                         bundled_samples_num=args.bundled_samples_num,
                                                         bundled_order=args.bundled_order,
                                                         linesearch_batch_num=args.linesearch_batch_num,
                                                         curobo_activation_distance=args.curobo_activation_distance,
                                                         se2_bounds=se2_bounds)

    @jax.jit
    def fk(q):
        """Return ``shakey.FK(q, oriCORN_out=False)[shakey.ee_idx]`` for joints ``q``.

        When the last axis of ``q`` has exactly ``shakey.num_act_joints``
        entries, the result is additionally composed with ``robot_base_pqc``
        through ``tutil.pq_multi``. Presumably this maps it from the base frame
        to the world frame; TODO confirm. The branch depends only on the static
        shape of ``q``, so it is resolved at trace time.
        """
        ee_pose = shakey.FK(q, oriCORN_out=False)[shakey.ee_idx]
        if q.shape[-1] != shakey.num_act_joints:
            return ee_pose
        return tutil.pq_multi(robot_base_pqc, ee_pose)
    ik = shakey.get_IK_jit_func((robot_base_pqc[:3], robot_base_pqc[3:]), grasp_center_coordinate=False)
    ik_grasp = shakey.get_IK_jit_func((robot_base_pqc[:3], robot_base_pqc[3:]), grasp_center_coordinate=True)

    def sample_valid_q(z_range, jkey, robot_pb_uid):
        """Rejection-sample a 6-element joint configuration.

        Each attempt draws ``q`` uniformly within the hard-coded bounds below
        and applies it to ``robot_pb_uid`` in PyBullet. The attempt is accepted
        when both of these hold:

        * PyBullet reports no contact points involving the robot.
        * ``pqc = fk(q)`` satisfies ``z_range[0] < pqc[2] < z_range[1]`` and
          ``pqc[1] > 0.2``.

        Args:
            z_range: ``(low, high)`` exclusive bounds on ``fk(q)[2]``.
            jkey: JAX PRNG key. It is split on every attempt and never used
                directly for sampling.
            robot_pb_uid: PyBullet body id used for the contact check.

        Returns:
            The accepted configuration, shape ``(6,)``.

        Note: loops forever if no configuration satisfies the constraints.
        """
        # Hard-coded bounds for 6 joints; TODO confirm they match the active robot type.
        lower_bound = np.array([-2.8973, -np.pi-np.pi/6.0, -2.8973, -3.0718, -2.8973, -3.0718])
        upper_bound = np.array([2.8973, -np.pi/6.0, 2.8973, 0.0698, 2.8973, 0.0698])
        while True:
            # Sample with a fresh subkey and keep jkey only for further splitting, so a
            # key is never both consumed by a sampler and split again (JAX key reuse).
            jkey, subkey = jax.random.split(jkey)
            q_random = jax.random.uniform(subkey, (6,), minval=lower_bound, maxval=upper_bound)
            shakey.set_q_pb(robot_pb_uid, q_random)
            pb.performCollisionDetection()
            if len(pb.getContactPoints(robot_pb_uid)) != 0:
                continue
            pqc = fk(q_random)
            if z_range[0] < pqc[2] < z_range[1] and pqc[1] > 0.2:
                return q_random

    pb.resetDebugVisualizerCamera(
        cameraDistance=1.79, 
        cameraYaw=-443.20,
        cameraPitch=0.26,
        cameraTargetPosition=[-0.05, 0.07, 0.30]
    )

    # if args.ccd_type == 'curobo' and args.env_type != 'bimanual':
    if args.ccd_type == 'curobo':
        voxel_intervals=[0.1] + [0.07]*shakey.num_act_joints
        num_surface_samples=[10] + [50]*shakey.num_act_joints
        # voxel_intervals=[0.3, 0.07, 0.07, 0.07, 0.07, 0.07, 0.03]
        # num_surface_samples=[10, 10, 30, 30, 20, 20, 40]
        grasping_objects=[]
        traj_optimizer.cost_module_cls.ccd_cls.enroll_robot(
            shakey,
            voxel_intervals=voxel_intervals,
            num_surface_samples=num_surface_samples,
            visualize=False,
        )

    # hyperparameter
    success_list = []
    success_particle_list = []
    max_pen_depth_list = []
    max_pen_depth_particle_list = []
    elapsed_time_list = []
    valid_mask_pred_list = []
    seeds = range(500)
    for itr, seed in enumerate(seeds):
        jkey = jax.random.PRNGKey(seed)

        ee_to_obj_pq = None
        moving_obj = None
        plane_params = None
        fixed_moving_idx_pair = None
        multiple_sequences = False
        mesh_ids = None
        view_matrix = None
        poses_after_sequence = None
        aux_pbids = []
        if itr%10==0:
            pb.resetSimulation()
            robot_pb_uid = shakey.create_pb(se2=base_se2)

        if args.env_type == 'table':
            random_obj, environment_obj, pybullet_scene = scene_util.create_table_sampled_scene(models=models, num_objects=args.num_objects, seed=seed)
            # init, goal sampling
            init_q = sample_valid_q([0.4, 1.0], jkey, robot_pb_uid)
            goal_q = sample_valid_q([0, 0.4], jkey, robot_pb_uid)
            if random_obj is None:
                fixed_obj = environment_obj
            else:
                fixed_obj = environment_obj.concat(random_obj, axis=0)
            moving_obj = None
            ee_to_obj_pq = None
        elif args.env_type == 'clean':
            moving_obj, fixed_obj, pybullet_scene, init_q, goal_q, ee_to_obj_pq = scene_util.create_cleaning_scene(
                models=models,
                num_objects=args.num_objects,
                seed=seed,
                shakey=shakey,
                robot_pb_uid=robot_pb_uid,
                fk=fk,
                ik=ik,
                visualize=True
            )
            ee_to_obj_pq = jnp.concatenate(ee_to_obj_pq, axis=-1)
        elif args.env_type == 'dish':
            moving_obj, fixed_obj, pybullet_scene, init_q, goal_q, ee_to_obj_pq, aux_pbids = scene_util.create_dish_scene(
                models=models,
                seed=seed,
                shakey=shakey,
                ik_func=ik,
                robot_pb_uid=robot_pb_uid,
            )
            ee_to_obj_pq = jnp.concatenate(ee_to_obj_pq, axis=-1)
            self_collision = False
        elif args.env_type == 'dish_multiple':
            moving_obj, fixed_obj, pybullet_scene, init_q, goal_q, ee_to_obj_pq, poses_after_sequence, aux_pbids, view_matrix = scene_util.create_multiple_dish_scene(
                models=models,
                seed=seed,
                shakey=shakey,
                ik_func=ik,
                robot_pb_uid=robot_pb_uid,
            )
            ee_to_obj_pq = [
                jnp.concatenate(single_ee_to_obj_pq, axis=-1) for single_ee_to_obj_pq in ee_to_obj_pq
            ]
            self_collision = False
            multiple_sequences = True
            sequence_num = init_q.shape[0]
        elif args.env_type == 'room':
            fixed_obj, pybullet_scene, init_q, goal_q = scene_util.create_room_scene(
                models=models,
                seed=seed,
                shakey=shakey,
                robot_pb_uid=robot_pb_uid,
            )
        elif args.env_type == 'room_pen':
            moving_obj, fixed_obj, pybullet_scene, init_q, goal_q, ee_to_obj_pq, aux_pbids = scene_util.create_room_pen_scene(
                models=models,
                seed=seed,
                ik_func=ik,
                shakey=shakey,
                robot_pb_uid=robot_pb_uid,
            )
        elif args.env_type == 'construction_site':
            moving_obj, fixed_obj, pybullet_scene, init_q, goal_q, ee_to_obj_pq, aux_pbids, plane_params, fixed_moving_idx_pair = scene_util.create_construction_site(
                models=models,
                seed=seed,
                ik_func=ik,
                shakey=shakey,
                robot_pb_uid=robot_pb_uid,
            )
            self_collision = False

        elif args.env_type == 'construction_site_hard':
            moving_obj, fixed_obj, pybullet_scene, init_q, goal_q, ee_to_obj_pq, aux_pbids, plane_params, fixed_moving_idx_pair, view_matrix = scene_util.create_construction_site_hard(
                models=models,
                seed=seed,
                ik_func=ik,
                shakey=shakey,
                robot_pb_uid=robot_pb_uid,
            )
            self_collision = False
            
        elif args.env_type == 'bimanual':
            moving_obj, fixed_obj, pybullet_scene, init_q, goal_q, ee_to_obj_pq, fixed_moving_idx_pair = scene_util.create_bimanual_insertion(
                models=models,
                seed=seed,
                ik_func=ik_grasp,
                shakey=shakey,
                robot_pb_uid=robot_pb_uid,
            )
            self_collision = True
        
        elif args.env_type == 'bimanual_hard':
            moving_obj, fixed_obj, pybullet_scene, init_q, goal_q, ee_to_obj_pq, fixed_moving_idx_pair, view_matrix =\
                  scene_util.create_bimanual_insertion_hard(
                models=models,
                seed=seed,
                ik_func=ik_grasp,
                shakey=shakey,
                robot_pb_uid=robot_pb_uid,
            )
            self_collision = True
        
        if pybullet_scene is not None and pybullet_scene.movable_objects is not None:
            obj_in_hand = [mo.pb_uid for mo in pybullet_scene.movable_objects]
        else:
            obj_in_hand = None

        moving_spheres = []

        if args.ccd_type == 'curobo' and not multiple_sequences:
            path_list = []
            scale_list = []
            pqc_list = []
            for obj in pybullet_scene.fixed_objects:
                # path_list.append(obj.col_mesh_path)
                path_list.append(obj.mesh_path)
                scale_list.append(obj.scale)
                pqc_list.append(np.concatenate(obj.base_pose, axis=-1))
            mesh_ids = traj_optimizer.cost_module_cls.ccd_cls.enroll_meshes(path_list, scale_list, pqc_list)

            if moving_obj is not None:
                for midx, obj in enumerate(pybullet_scene.movable_objects):
                    o3d_mesh = obj.convert_to_o3d()
                    moving_spheres.append(decompose_mesh_to_spheres(o3d_mesh, 0.05, None, num_of_pnts=args.curobo_moving_obj_npnts))

                    print(f"moving object {obj.mesh_path} spheres: {moving_spheres[-1].shape[0]}")

                    # surface_pcd = o3d_mesh.sample_points_uniformly(number_of_points=2000)
                    # surface_pnts = np.asarray(surface_pcd.points)
                    # fps_pnts = FPS_padding(surface_pnts, args.curobo_moving_obj_npnts, None)
                    # surface_spheres = jnp.concat([fps_pnts, np.full((fps_pnts.shape[0], 1), 0.0005)], axis=-1)
                    # moving_spheres.append(surface_spheres)

        if args.ccd_type == 'curobo' and fixed_moving_idx_pair is not None:
            assert len(pybullet_scene.fixed_objects) == 0
            mesh_names = [
                models.asset_path_util.obj_paths[i] for i in shakey.canonical_obj_idx
            ] + [
                obj.mesh_path for obj in pybullet_scene.movable_objects
            ]
            scales = [
                shakey.mesh_scale[i] for i in shakey.canonical_obj_idx
            ] + [
                obj.scale for obj in pybullet_scene.movable_objects
            ]
            pqcs = jnp.concatenate([
                jnp.concatenate([jnp.zeros((len(shakey.canonical_obj_idx),6)), jnp.ones((len(shakey.canonical_obj_idx),1))], axis=1),
                jnp.stack([np.concatenate(obj.base_pose, axis=-1) for obj in pybullet_scene.movable_objects], axis=0),
            ], axis=0)
            mesh_ids = traj_optimizer.cost_module_cls.ccd_cls.enroll_meshes_batch(mesh_names, scales, pqcs)
            print("enrol mesh ids with fixed_moving_idx_pair")

        if itr == 0:
            traj_opt_jit = jax.jit(lambda init_q, goal_q, loss_args, jkey: traj_optimizer.perform_multiple_seed(
                init_q,
                goal_q,
                loss_args=structs.LossArgs(
                    fixed_oriCORNs=loss_args.fixed_oriCORNs,
                    moving_oriCORNs=loss_args.moving_oriCORNs,
                    ee_to_obj_pq=loss_args.ee_to_obj_pq,
                    plane_params=loss_args.plane_params,
                    fixed_moving_idx_pair=fixed_moving_idx_pair, # fixed_moving_idx_pair to static
                    moving_spheres=loss_args.moving_spheres,
                    mesh_ids=loss_args.mesh_ids,
                ),
                jkey=jkey,
            ))
            loss_func_jit = jax.jit(lambda x, loss_args, jkey: traj_optimizer.cost_module_cls.traj_opt_cost(
                x,
                loss_args=structs.LossArgs(
                    fixed_oriCORNs=loss_args.fixed_oriCORNs,
                    moving_oriCORNs=loss_args.moving_oriCORNs,
                    ee_to_obj_pq=loss_args.ee_to_obj_pq,
                    plane_params=loss_args.plane_params,
                    fixed_moving_idx_pair=fixed_moving_idx_pair, # fixed_moving_idx_pair to static
                    moving_spheres=loss_args.moving_spheres,
                    mesh_ids=loss_args.mesh_ids,
                ),
                jkey=jkey,
                interpolation_num=args.interpolation_num_gradient
            ))
            if multiple_sequences:
                traj_opt_jits = [
                    jax.jit(lambda init_q, goal_q, loss_args, jkey: traj_optimizer.perform_multiple_seed(
                        init_q,
                        goal_q,
                        loss_args=structs.LossArgs(
                            fixed_oriCORNs=loss_args.fixed_oriCORNs,
                            moving_oriCORNs=loss_args.moving_oriCORNs,
                            ee_to_obj_pq=loss_args.ee_to_obj_pq,
                            plane_params=loss_args.plane_params,
                            fixed_moving_idx_pair=fixed_moving_idx_pair, # fixed_moving_idx_pair to static
                            moving_spheres=loss_args.moving_spheres,
                            mesh_ids=loss_args.mesh_ids,
                        ),
                        jkey=jkey,
                    )) for i in range(sequence_num)
                ]
                loss_func_jits = [
                    jax.jit(lambda x, loss_args, jkey: traj_optimizer.cost_module_cls.traj_opt_cost(
                        x,
                        loss_args=structs.LossArgs(
                            fixed_oriCORNs=loss_args.fixed_oriCORNs,
                            moving_oriCORNs=loss_args.moving_oriCORNs,
                            ee_to_obj_pq=loss_args.ee_to_obj_pq,
                            plane_params=loss_args.plane_params,
                            fixed_moving_idx_pair=fixed_moving_idx_pair, # fixed_moving_idx_pair to static
                            moving_spheres=loss_args.moving_spheres,
                            mesh_ids=loss_args.mesh_ids,
                        ),
                        jkey=jkey,
                        interpolation_num=args.interpolation_num_gradient
                    )) for i in range(sequence_num)
                ]

        # shakey.show_fps_sphere(init_q, scene_oriCORNs=fixed_obj, grasped_oriCORN=moving_obj, ee_to_obj_pqc=ee_to_obj_pq)
        # sphere_mask = shakey.show_fps_sphere(init_q, grasped_oriCORN=moving_obj, ee_to_obj_pqc=ee_to_obj_pq, visualize=False)

        loss_args = structs.LossArgs(
            fixed_obj,
            moving_obj,
            ee_to_obj_pq,
            plane_params,
            fixed_moving_idx_pair,
            moving_spheres,
            mesh_ids,
        )

        print('start traj opt')
        if multiple_sequences:
            elapsed_time = 0
            success_total = True
            success_particle_total = True
            max_pen_depth_total = 0
            max_pen_depth_particle_total = 0
            valid_mask_pred_total = True
            moving_obj_number = loss_args.moving_oriCORNs.shape[0]
            for i, (traj_opt_jit, loss_func_jit) in enumerate(zip(traj_opt_jits, loss_func_jits)):
                is_grasping_obj = (i % 2 == 1)
                grasping_obj_idx = i // 2
                fixed_obj:loutil.LatentObjects = loss_args.fixed_oriCORNs
                moving_obj:loutil.LatentObjects = loss_args.moving_oriCORNs
                moving_spheres = []

                moving_obj_pose = jnp.stack([
                    poses_after_sequence[2 * j + 1] if 2 * j + 2 <= i else poses_after_sequence[2 * j]
                    for j in range(moving_obj_number)
                ])
                if is_grasping_obj:
                    # Every movable object other than the grasped one becomes a static
                    # obstacle at its current pose for this sequence step.
                    if grasping_obj_idx > 0:
                        fixed_obj = fixed_obj.concat(
                            moving_obj[:grasping_obj_idx].apply_pq_z(moving_obj_pose[:grasping_obj_idx, :3], moving_obj_pose[:grasping_obj_idx, 3:], models.rot_configs),
                            axis=0,
                        )
                    # Mirror the head guard: only concatenate when at least one object
                    # follows the grasped one, so an empty slice is never appended.
                    if grasping_obj_idx + 1 < moving_obj_number:
                        fixed_obj = fixed_obj.concat(
                            moving_obj[grasping_obj_idx+1:].apply_pq_z(moving_obj_pose[grasping_obj_idx+1:, :3], moving_obj_pose[grasping_obj_idx+1:, 3:], models.rot_configs),
                            axis=0,
                        )
                    moving_obj = moving_obj[grasping_obj_idx:grasping_obj_idx+1]
                    ee_to_obj_pq = loss_args.ee_to_obj_pq[grasping_obj_idx]
                else:
                    # Nothing in hand: all movable objects are obstacles at their current poses.
                    fixed_obj = fixed_obj.concat(
                        moving_obj.apply_pq_z(moving_obj_pose[:, :3], moving_obj_pose[:, 3:], models.rot_configs),
                        axis=0,
                    )
                    moving_obj = None
                    ee_to_obj_pq = None

                for j in range(moving_obj_number):
                    pybullet_scene.movable_objects[j].set_base_pose((np.array(moving_obj_pose[j, :3]), np.array(moving_obj_pose[j, 3:])))

                if args.ccd_type == 'ours' and args.broadphase_type.split('_')[0][-2:] != 'bf':
                    traj_optimizer.cost_module_cls.ccd_cls.broad_phase_cls.enroll_bvh(fixed_obj)
                elif args.ccd_type == 'curobo':
                    path_list = []
                    scale_list = []
                    pqc_list = []
                    for obj in pybullet_scene.fixed_objects:
                        # path_list.append(obj.col_mesh_path)
                        path_list.append(obj.mesh_path)
                        scale_list.append(obj.scale)
                        pqc_list.append(np.concatenate(obj.base_pose, axis=-1))
                    for midx, obj in enumerate(pybullet_scene.movable_objects):
                        # Skip only the object currently held; every other movable object is
                        # enrolled as a static mesh at its pose for this sequence step.
                        # Compare indices by value (==), not identity ('is').
                        if is_grasping_obj and midx == grasping_obj_idx:
                            continue
                        path_list.append(obj.mesh_path)
                        scale_list.append(obj.scale)
                        pqc_list.append(moving_obj_pose[midx])

                    mesh_ids = traj_optimizer.cost_module_cls.ccd_cls.enroll_meshes(path_list, scale_list, pqc_list)
                    if is_grasping_obj:
                        obj = pybullet_scene.movable_objects[grasping_obj_idx]
                        obj.set_base_pose((np.array([0,0,0]), np.array([0,0,0,1])))
                        o3d_mesh = obj.convert_to_o3d()
                        surface_spheres = decompose_mesh_to_spheres(o3d_mesh, 0.05, None, num_of_pnts=args.curobo_moving_obj_npnts)
                        moving_spheres = [surface_spheres]
                        print(f"moving object {obj.mesh_path} spheres: {moving_spheres[-1].shape[0]}")
                        obj.set_base_pose((moving_obj_pose[grasping_obj_idx, :3], moving_obj_pose[grasping_obj_idx, 3:]))

                start_time = time.time()
                loss_args_i = structs.LossArgs(
                    fixed_oriCORNs=fixed_obj,
                    moving_oriCORNs=moving_obj,
                    ee_to_obj_pq=ee_to_obj_pq,
                    plane_params=loss_args.plane_params,
                    fixed_moving_idx_pair=loss_args.fixed_moving_idx_pair,
                    moving_spheres=moving_spheres,
                    mesh_ids=mesh_ids,
                )
                x, traj_opt_state = traj_opt_jit(init_q[i], goal_q[i], loss_args_i, jax.random.PRNGKey(args.solver_seed+seed))
                x = jax.block_until_ready(x)
                elapsed_time += time.time() - start_time

                _, loss_aux = loss_func_jit(x, loss_args_i, jkey)
                video_path = None # os.path.join(log_dir, 'video', f'video_{args.env_type}_{args.ccd_type}_{seed}_{i}.mp4')
                current_obj_in_hand = obj_in_hand[grasping_obj_idx:grasping_obj_idx+1] if is_grasping_obj else None
                ee_to_obj_pq = np.array(ee_to_obj_pq) if is_grasping_obj else None
                success, max_pen_depth = mp_eval_util.simulate_in_pb(
                    x,
                    robot_pb_uid,
                    shakey,
                    sleep_time=0.010/x.shape[-2],
                    evaluate=not args.visualize,
                    obj_in_hand=current_obj_in_hand,
                    ee_to_obj_pq=ee_to_obj_pq,
                    self_collision=self_collision,
                    # ignore_collision=obj_in_hand[i+1:],
                    video_dir=video_path,
                    view_matrix=view_matrix,
                )
                success_particle, max_pen_depth_particle = False, 0
                mp_eval_util.simulate_in_pb(
                    traj_opt_state['opt_aux_info']['x_particle'],
                    robot_pb_uid,
                    shakey,
                    sleep_time=0.010/x.shape[-2],
                    evaluate=not args.visualize,
                    obj_in_hand=current_obj_in_hand,
                    ee_to_obj_pq=ee_to_obj_pq,
                    self_collision=self_collision,
                    # ignore_collision=obj_in_hand[i+1:],
                    video_dir=None,
                    view_matrix=view_matrix,
                )
                print(success)

                success_total = success_total and success
                max_pen_depth_total = min(max_pen_depth_total, max_pen_depth)
                success_particle_total = success_particle_total and success_particle
                max_pen_depth_particle_total = min(max_pen_depth_particle_total, max_pen_depth_particle)

                opt_aux = {}
                for k in traj_opt_state['opt_aux_info']:
                    if traj_opt_state['opt_aux_info'][k].ndim < 2:
                        opt_aux[k] = traj_opt_state['opt_aux_info'][k]

                valid_mask_pred_total = valid_mask_pred_total or np.logical_not(opt_aux['min_invalid_mask'])
                try:
                    # print(f"CUR_INFO: num itr: {traj_opt_state['i']}, col_loss: {traj_opt_state['col_loss']}, loss:{traj_opt_state['min_loss']}, invalid_mask: {traj_opt_state['invalid_mask']}, elapsed_time: {elapsed_time}")
                    print_and_log(f"CUR_INFO: num itr: {traj_opt_state['i']}, {opt_aux}, elapsed_time: {elapsed_time}")
                except:
                    pass
                if args.save_env:
                    if is_grasping_obj:
                        to_yml(args.env_type, args.ccd_type, pybullet_scene.fixed_objects + pybullet_scene.movable_objects[:grasping_obj_idx] + pybullet_scene.movable_objects[grasping_obj_idx+1:], pybullet_scene.movable_objects[grasping_obj_idx:grasping_obj_idx+1], base_se2, shakey, loss_args, x, aux_pbids,  2 * moving_obj_number * seed + i)
                    else:
                        to_yml(args.env_type, args.ccd_type, pybullet_scene.fixed_objects + pybullet_scene.movable_objects, [], base_se2, shakey, loss_args, x, aux_pbids,  2 * moving_obj_number * seed + i)

            success = success_total
            success_particle = success_particle_total
            max_pen_depth = max_pen_depth_total
            max_pen_depth_particle = max_pen_depth_particle_total
            valid_mask_pred = valid_mask_pred_total
            print_and_log(loss_aux)

        else:
            if args.ccd_type == 'ours' and args.broadphase_type.split('_')[0][-2:] != 'bf':
                traj_optimizer.cost_module_cls.ccd_cls.broad_phase_cls.enroll_bvh(fixed_obj)
            start_time = time.time()
            x, traj_opt_state = traj_opt_jit(init_q, goal_q, loss_args, jax.random.PRNGKey(args.solver_seed+seed))
            x = jax.block_until_ready(x)
            elapsed_time = time.time() - start_time
            _, loss_aux = loss_func_jit(x, loss_args, jkey)

            if args.save_video:
                video_path = os.path.join(log_dir, 'video', f'video_{args.env_type}_{args.ccd_type}_{seed}.mp4')
            else:
                video_path = None
            success, max_pen_depth = mp_eval_util.simulate_in_pb(
                x,
                robot_pb_uid,
                shakey,
                sleep_time=0.010/x.shape[-2],
                evaluate=not args.visualize,
                obj_in_hand=obj_in_hand,
                ee_to_obj_pq=np.array(ee_to_obj_pq),
                self_collision=self_collision,
                video_dir=video_path,
                view_matrix=view_matrix,
            )
            success_particle, max_pen_depth_particle = mp_eval_util.simulate_in_pb(
                traj_opt_state['opt_aux_info']['x_particle'],
                robot_pb_uid,
                shakey,
                sleep_time=0.010/x.shape[-2],
                evaluate=not args.visualize,
                obj_in_hand=obj_in_hand,
                ee_to_obj_pq=np.array(ee_to_obj_pq),
                self_collision=self_collision,
            )

            opt_aux = {}
            for k in traj_opt_state['opt_aux_info']:
                if traj_opt_state['opt_aux_info'][k].ndim < 2:
                    opt_aux[k] = traj_opt_state['opt_aux_info'][k]

            valid_mask_pred = np.logical_not(opt_aux['min_invalid_mask'])

            try:
                # print(f"CUR_INFO: num itr: {traj_opt_state['i']}, col_loss: {traj_opt_state['col_loss']}, loss:{traj_opt_state['min_loss']}, invalid_mask: {traj_opt_state['invalid_mask']}, elapsed_time: {elapsed_time}")
                print_and_log(f"CUR_INFO: num itr: {traj_opt_state['i']}, {opt_aux}, elapsed_time: {elapsed_time}")
            except:
                pass
            print_and_log(loss_aux)

        valid_mask_pred_list.append(valid_mask_pred)
        elapsed_time_list.append(elapsed_time)
        success_list.append(success)
        max_pen_depth_list.append(max_pen_depth)
        success_particle_list.append(success_particle)
        max_pen_depth_particle_list.append(max_pen_depth_particle)

        # Compute precision and recall of the predicted-valid flags (valid_mask_pred_list) against simulated success (success_list)
        succ_precision = np.sum(np.logical_and(success_list, valid_mask_pred_list)) / np.sum(valid_mask_pred_list)
        succ_recall = np.sum(np.logical_and(success_list, valid_mask_pred_list)) / np.sum(success_list)
        print_and_log(f'seed: {seed}, valid_precision: {succ_precision:.4f}, valid_recall: {succ_recall:.4f}')

        print_and_log(f'seed: {seed}, success rate: {np.mean(success_list):.2f} , succ rate particle {np.mean(success_particle_list):.2f} / {len(success_list)}, max_pen_depth: {np.mean(max_pen_depth_list)},  elapsed time: {np.mean(elapsed_time_list[1:]):.2f}s')


        # if (success_list[-1]) and not loss_aux['collision_binary']:
        #     print(1)

        # if loss_aux['collision_binary']:
        #     with open('collision_log.pkl', 'wb') as f:
        #         pickle.dump((x, loss_args, pybullet_scene), f)
        #     traj_optimizer.cost_module_cls.traj_opt_cost(x, loss_args, jkey, interpolation_num=args.interpolation_num_gradient, visualize=True)

        # if (success_list[-1]) and loss_aux['collision_binary']:
        #     with open('collision_log.pkl', 'wb') as f:
        #         pickle.dump((x, loss_args, pybullet_scene), f)
        #     traj_optimizer.cost_module_cls.traj_opt_cost(x, loss_args, jkey, interpolation_num=args.interpolation_num_gradient, visualize=True)

        # if (not success_list[-1]) and not loss_aux['collision_binary']:
        #     with open('collision_log.pkl', 'wb') as f:
        #         pickle.dump((x, loss_args, pybullet_scene), f)
        #     print('collision not detected but success is False')
        #     traj_optimizer.cost_module_cls.traj_opt_cost(x, loss_args, jkey, interpolation_num=args.interpolation_num_gradient, visualize=True)


        if args.save_env and not multiple_sequences:
            to_yml(args.env_type, args.ccd_type, pybullet_scene.fixed_objects, pybullet_scene.movable_objects, base_se2, shakey, loss_args, x, aux_pbids, seed)

        if pybullet_scene is not None:
            pybullet_scene.clear_scene(aux_pbids)

    print_and_log(f'{success_list}')
    print_and_log(f'{success_particle_list}')
    print_and_log(f'{valid_mask_pred_list}')
    print_and_log(f'{max_pen_depth_list}')
    print_and_log(f'{max_pen_depth_particle_list}')
    print_and_log(f'{elapsed_time_list[1:]}')
    print_and_log(f'success rate: {np.mean(success_list):.4f} / {len(success_list)}, elapsed time: {np.mean(elapsed_time_list[1:]):.4f}s, max_pen_depth: {np.mean(max_pen_depth_list)}')
    print_and_log(f'success rate particle: {np.mean(success_particle_list):.4f} / {len(success_particle_list)}, max_pen_depth: {np.mean(max_pen_depth_list):.4f}, max_pen_depth_particle: {np.mean(max_pen_depth_particle_list):.4f}')

    # logging.info(loss_aux_particle)
    # logging.info(loss_aux)

