import os
from pathlib import Path
BASE_PATH = Path(__file__).parent
import sys
if str(BASE_PATH.parent) not in sys.path:
    sys.path.append(str(BASE_PATH.parent))

import jax.numpy as jnp
import numpy as np
import jax
import pickle
from functools import partial
import matplotlib.pyplot as plt
import time
import optax
import einops
import jax.debug as jdb

import util.model_util as mutil
from util.model_util import aggregate_cost
import util.transform_util as tutil
import util.latent_obj_util as loutil
import util.render_util as rutil
import util.camera_util as cutil
import util.opt_util as optutil
import util.structs as structs
import modules.cost_module as cost_module

try:
    import modules.traj_search_module as traj_search_module
except:
    print('traj_search_module not imported')
import modules.shakey_module as shakey_module

from util.reconstruction_util import create_scene_mesh_from_oriCORNs
import util.reconstruction_util as reconstruction_util

def print_callback(inputs):
    """Debug helper: print ``inputs`` to stdout (usable as a host callback)."""
    payload = inputs
    print(payload)

def vis_callback(inputs):
    '''
    Debug host callback: re-run the collision decoder with debug=True.

    inputs is a 6-tuple
        (link_canonical_oriCORN, scene_oriCORNs, robot_pqc,
         scene_oriCORNs_posquat, col_res_binary, models)
    A trailing singleton axis on col_res_binary is squeezed away. The rows of
    robot_pqc where col_res_binary is False are passed, together with the other
    tensors, to models.apply('col_decoder', ..., reduce_k=64, debug=True, merge=True).
    Nothing is returned.

    NOTE(review): this positional order (links, scene, poses, scene poses) differs
    from the col_decoder call in node_valid_cheker (scene, links, poses) — confirm it
    still matches the current decoder signature.
    '''
    (link_canonical_oriCORN, scene_oriCORNs, robot_pqc,
     scene_oriCORNs_posquat, col_res_binary, models) = inputs
    link_canonical_oriCORN: loutil.LatentObjects

    # e.g. (NB, 1) -> (NB,) so the mask can index the first axis of robot_pqc
    mask = col_res_binary.squeeze(-1) if col_res_binary.ndim > 1 else col_res_binary
    selected_pqc = robot_pqc[jnp.logical_not(mask)]

    models.apply('col_decoder', link_canonical_oriCORN, scene_oriCORNs, selected_pqc, scene_oriCORNs_posquat,
                 reduce_k=16*4, debug=True, merge=True)


def node_valid_cheker(models:mutil.Models, shakey:shakey_module.Shakey, robot_states, jkey, col_args:structs.LossArgs, merge_robot_state=False):
    '''
    Check a batch of robot configurations (nodes) for scene and self collision.

    One scene, at most one grasped object, multiple robot states in a batch.
    To evaluate several scenes / grasped objects / grasp poses, wrap in jax.vmap.

    Args:
        models: model bundle; its 'col_decoder' and 'self_collision_checker'
            heads are called through models.apply.
        shakey: robot kinematics; FK, collision_check_link_idx, ee_idx,
            link_canonical_oriCORN and num_act_joints are used.
        robot_states: (..., NB, D) robot configurations. Leading dims
            (robot_states.shape[:-2]) are treated as scene batch dims; the last
            shakey.num_act_joints entries of D are fed to the self-collision checker.
        jkey: not used by this function.
        col_args: structs.LossArgs or None. Fields read:
            fixed_oriCORNs  -> scene_oriCORNs (NS, ...)
            moving_oriCORNs -> grasped_oriCORN (optional)
            ee_to_obj_pq    -> grasp_pq_wrt_ee (7,), required when a grasped object is given
        merge_robot_state: bool
            False: each robot state is evaluated separately -> outputs over (..., NB)
            True:  link poses of all NB states are merged into one query and the
                   results are reduced over NB -> outputs over (...) (i.e. () when
                   robot_states is (NB, D))

    Returns:
        state_invalid_mask: bool; True where the scene collision cost exceeds -1
            or a self-collision logit exceeds 0.
        state_cost: col_cost + self_col_cost.
        aux: dict {'self_col_cost': ..., 'col_cost': ...}.
    NOTE(review): exact output shapes depend on the col_decoder output, which is not
    visible here — the shapes above follow the original contract; confirm.
    '''

    if col_args is not None:
        scene_oriCORNs, grasped_oriCORN, grasp_pq_wrt_ee = col_args.fixed_oriCORNs, col_args.moving_oriCORNs, col_args.ee_to_obj_pq
    else:
        # NOTE(review): scene_oriCORNs stays None and is still passed to col_decoder below — confirm the decoder accepts None.
        scene_oriCORNs, grasped_oriCORN, grasp_pq_wrt_ee = None, None, None

    if scene_oriCORNs is not None:
        # scene outer rank is either 1, or exactly one less than robot_states' rank
        assert scene_oriCORNs.ndim == 1 or robot_states.ndim - scene_oriCORNs.ndim == 1
    grasped = False
    if grasped_oriCORN is not None:
        grasped = True
        # assert robot_states.ndim - grasped_oriCORN.ndim == 2
        assert grasp_pq_wrt_ee is not None
        # assert robot_states.ndim - grasp_pq_wrt_ee.ndim == 1
        # promote a single object / single pose to a leading axis of size 1
        if grasped_oriCORN.ndim == 0:
            grasped_oriCORN = grasped_oriCORN[None]
        if grasp_pq_wrt_ee.ndim == 1:
            grasp_pq_wrt_ee = grasp_pq_wrt_ee[None]

    # dims of robot_states in front of (NB, D)
    scene_batch_shape = robot_states.shape[:-2]

    shakey_collision_check_link_idx = shakey.collision_check_link_idx

    robot_pqc = shakey.FK(robot_states, oriCORN_out=False) # (NB, NR, 7)
    ee_pqc = robot_pqc[...,shakey.ee_idx,:]
    # keep only the links that participate in collision checking
    robot_pqc = robot_pqc[...,shakey_collision_check_link_idx,:]

    # scene_oriCORNs_pos = scene_oriCORNs.pos # (NS, 3)

    link_canonical_oriCORN = shakey.link_canonical_oriCORN[shakey_collision_check_link_idx]
    
    if grasped:
        # Treat the grasped object as one extra link: broadcast the link shapes over the
        # scene batch dims, append the object shape, and append its pose
        # (ee pose composed with grasp_pq_wrt_ee) to robot_pqc.
        for dim in scene_batch_shape:
            link_canonical_oriCORN = link_canonical_oriCORN.extend_and_repeat_outer_shape(dim, axis=-2)
        link_canonical_oriCORN = link_canonical_oriCORN.concat(grasped_oriCORN, axis=-1)
        robot_pqc = jnp.concat([robot_pqc, tutil.pq_multi(ee_pqc, grasp_pq_wrt_ee)[...,None,:]], axis=-2)

    # scene_oriCORNs_posquat = jnp.c_[jnp.zeros_like(scene_oriCORNs_pos), tutil.qzero(scene_oriCORNs_pos.shape[:-1])] # (NS, 7) # add dummy query dim

    if merge_robot_state:
        # Repeat the link shapes once per robot state and flatten (NB, NR) into a single
        # link axis, so the decoder sees all NB states as one merged query.
        # assert len(scene_batch_shape) == 0
        link_canonical_oriCORN = link_canonical_oriCORN.extend_and_repeat_outer_shape(robot_states.shape[-2], axis=-2) # (NB NR ...)
        link_canonical_oriCORN = link_canonical_oriCORN.reshape_outer_shape(scene_batch_shape + (-1,)) # (NB*NR ...)
        if len(scene_batch_shape) == 0:
            robot_pqc = robot_pqc.reshape(1, -1, *robot_pqc.shape[-1:]) # (1, NB*NR, 7) add dummy query dim
        else:
            robot_pqc = robot_pqc.reshape(*scene_batch_shape, -1, *robot_pqc.shape[-1:]) # (*scene_batch_shape, NB*NR, 7)
    # (1, 400, 7)
    # (50, 8, 7)
    # dim for col decoder
    # Here, NB is for query dimensions
    # link_canonical_oriCORN : (NR, ...)
    # robot_pqc: (NB, NR, 7)
    # scene_oriCORNs: (NS, ...)
    # scene_oriCORNs_posquat: (NS, 7) or (NB, NS, 7)
    # if merge, scene dimensions are merged into one objects
    # col_cost = models.apply('col_decoder', link_canonical_oriCORN, scene_oriCORNs, robot_pqc, scene_oriCORNs_posquat, merge=True) # (NB, 1)
    reduce_k = 16*2
    # reduce_fps = 50
    # col_cost = models.apply('col_decoder', link_canonical_oriCORN.reduce_fps(reduce_fps, jax.random.PRNGKey(0)), scene_oriCORNs.reduce_fps(reduce_fps, jax.random.PRNGKey(0)), robot_pqc, scene_oriCORNs_posquat, reduce_k=reduce_k, merge=True) # (NB, 1)

    # scene objects are passed untransformed; links are posed by robot_pqc
    col_cost = models.apply('col_decoder', scene_oriCORNs, link_canonical_oriCORN,
                             robot_pqc, reduce_k=reduce_k, merge=True) # (NB, 1)
    # NOTE(review): takes the first decoder output — presumably the collision cost; confirm
    col_cost = col_cost[0]
    # col_cost = pairwise_collision_check(models, link_canonical_oriCORN, scene_oriCORNs, robot_pqc, scene_oriCORNs_posquat)[...,None] # not perged

    if merge_robot_state and len(scene_batch_shape) == 0:
        # drop the dummy query dim added above
        col_cost = col_cost.squeeze(-2)
    col_res_binary = col_cost > -1 # True for collision
    state_invalid_mask = jnp.any(col_res_binary, axis=(-1))
    col_cost = jnp.max(col_cost, axis=(-1))

    # debug visualizer
    # jax.lax.cond(col_res_binary[0][0], lambda: jdb.callback(lambda args: models.apply('col_decoder', *args, reduce_k=reduce_k, debug=True, merge=True), (link_canonical_oriCORN, scene_oriCORNs, robot_pqc[:1], scene_oriCORNs_posquat)),
    #              lambda: None)
    # jax.lax.cond(col_res_binary[0][0], lambda: jdb.callback(vis_callback, (link_canonical_oriCORN, scene_oriCORNs, robot_pqc, scene_oriCORNs_posquat, state_invalid_mask, models)),
    #              lambda: None)

    # perform self collision (only the actuated joints are fed to the checker; logit > 0 = collision)
    self_col_logit = models.apply('self_collision_checker', robot_states[..., -shakey.num_act_joints:]).squeeze(-1)
    if merge_robot_state:
        # reduce over the NB robot states to match the merged scene result
        state_invalid_mask = jnp.logical_or(state_invalid_mask, jnp.any(self_col_logit > 0.0, axis=(-1)))
        self_col_cost = aggregate_cost(self_col_logit, (-1,))
    else:
        state_invalid_mask = jnp.logical_or(state_invalid_mask, (self_col_logit > 0.0))
        self_col_cost = self_col_logit
    # self_col_cost = 0
    
    # jax.lax.cond(col_res_binary[0][0], lambda: jdb.callback(vis_callback, (link_canonical_oriCORN, scene_oriCORNs, robot_pqc, scene_oriCORNs_posquat, state_invalid_mask, models)),
    #              lambda: None)

    state_cost = col_cost + self_col_cost



    return state_invalid_mask, state_cost, {'self_col_cost':self_col_cost, 'col_cost':col_cost}



def path_valid_cheker(models:mutil.Models, shakey:shakey_module.Shakey, robot_states, jkey, 
                      col_args:structs.LossArgs, collaps_ndim_in_robot_states=2, nitp=8, visualize=False):
    '''
    Check straight-line edges in configuration space for scene and self collision.

    Each edge is linearly interpolated at `nitp` evenly spaced points (both endpoints
    included). The collision links (plus an optional grasped object) are posed by FK
    at every interpolated state, and the collision decoder is run with path_check=True.
    Self collision is evaluated at every interpolated state as well.

    Args:
        models: model bundle; 'col_decoder' and 'self_collision_checker' are called
            through models.apply.
        shakey: robot kinematics; FK, collision_check_link_idx, ee_idx,
            link_canonical_oriCORN and num_act_joints are used.
        robot_states: (..., 2, D) edge endpoints; [..., 0, :] is the start state and
            [..., 1, :] the end state.
        jkey: PRNG key forwarded to the collision decoder.
        col_args: structs.LossArgs or None (fixed_oriCORNs -> scene_oriCORNs (NS, ...),
            moving_oriCORNs -> grasped object, ee_to_obj_pq -> grasp pose).
        collaps_ndim_in_robot_states: not used by this function.
        nitp: number of interpolation points per edge.
        visualize: if True, the decoder is called with debug=True and its raw output is
            returned immediately (NOT the 3-tuple below).

    Returns (visualize=False), with (...) = robot_states.shape[:-2]:
        state_invalid_mask: (...) bool; True if the aggregated scene collision cost
            exceeds -1 or any interpolated state has a self-collision logit > 0.
        state_cost: (...) aggregated scene collision cost + aggregated self-collision cost.
        aux: dict {'self_col_cost': ..., 'col_cost': ...}.
    '''

    if col_args is not None:
        scene_oriCORNs, grasped_oriCORN, grasp_pq_wrt_ee = col_args.fixed_oriCORNs, col_args.moving_oriCORNs, col_args.ee_to_obj_pq
    else:
        scene_oriCORNs, grasped_oriCORN, grasp_pq_wrt_ee = None, None, None

    grasped = False
    if grasped_oriCORN is not None:
        grasped = True
        assert grasp_pq_wrt_ee is not None
        # promote a single object / single pose to a leading axis of size 1
        if grasped_oriCORN.ndim == 0:
            grasped_oriCORN = grasped_oriCORN[None]
        if grasp_pq_wrt_ee.ndim == 1:
            grasp_pq_wrt_ee = grasp_pq_wrt_ee[None]

    shakey_collision_check_link_idx = shakey.collision_check_link_idx

    # start + s * (end - start) for s in linspace(0, 1, nitp)
    robot_states_itp = robot_states[...,0:1,:] + jnp.linspace(0, 1, nitp, endpoint=True)[...,None] * (robot_states[...,1:2,:] - robot_states[...,0:1,:]) # (NB, nitp, 6)
    robot_pqc = shakey.FK(robot_states_itp, oriCORN_out=False) # (NB, nitp, NR, 7)
    ee_pqc = robot_pqc[...,shakey.ee_idx,:]
    robot_pqc = robot_pqc[...,shakey_collision_check_link_idx,:] # (NB, nitp, NR, 7)
    link_canonical_oriCORN = shakey.link_canonical_oriCORN[shakey_collision_check_link_idx]

    if grasped:
        # Append the grasped object as an extra link posed at ee_pose * grasp_pq_wrt_ee.
        # NOTE(review): this broadcasts over grasped_oriCORN.shape[:-1], whereas
        # node_valid_cheker broadcasts over robot_states.shape[:-2] — confirm which is intended.
        for dim in grasped_oriCORN.shape[:-1]:
            link_canonical_oriCORN = link_canonical_oriCORN.extend_and_repeat_outer_shape(dim, axis=-2)
        link_canonical_oriCORN = link_canonical_oriCORN.concat(grasped_oriCORN, axis=-1)
        robot_pqc = jnp.concat([robot_pqc, tutil.pq_multi(ee_pqc, grasp_pq_wrt_ee[...,None,:])[...,None,:]], axis=-2)

    # scene_oriCORNs_pos = scene_oriCORNs.pos # (NS, 3)

    # scene_oriCORNs_posquat = jnp.c_[jnp.zeros_like(scene_oriCORNs_pos), tutil.qzero(scene_oriCORNs_pos.shape[:-1])] # (NS, 7)

    # def col_check_body_fn(carry, i):
    #     col_cost = mutil.pairwise_collision_check(models, link_canonical_oriCORN, scene_oriCORNs, robot_pqc[i], scene_oriCORNs_posquat)
    #     return carry, col_cost
    
    # _, col_cost = jax.lax.scan(col_check_body_fn, None, jnp.arange(robot_pqc.shape[0]))


    # pairwise collision check
    # col_cost = mutil.pairwise_collision_check(models, link_canonical_oriCORN, scene_oriCORNs, robot_pqc, scene_oriCORNs_posquat) # (NB, nitp)


    # flatten all leading edge-batch dims into one; restored after aggregation below
    original_shape = robot_pqc.shape
    robot_pqc = robot_pqc.reshape(-1, *robot_pqc.shape[-3:]) # (batch, NITP, NLINK, ND)

    # link_oriCORNs_tf = link_canonical_oriCORN.apply_pq_z(robot_pqc, models.rot_configs)
    # import open3d as o3d
    # for i in range(link_oriCORNs_tf.shape[0]):
    #     cur_link_oriCORNs = link_oriCORNs_tf[i]
    #     o3d.visualization.draw_geometries(cur_link_oriCORNs.get_fps_sphere_o3d())
            

    col_cost = models.apply('col_decoder', scene_oriCORNs, link_canonical_oriCORN, pq_transform_B=robot_pqc,
                                                                        reduce_k=40, path_check=True, jkey=jkey, debug=visualize)
    if visualize:
        # debug path: return the raw decoder output, skipping all aggregation
        return col_cost
    # reduce the trailing three decoder axes to one cost per edge
    col_cost = mutil.aggregate_cost(col_cost, axes=(-1,-2,-3))
    col_cost = col_cost.reshape(original_shape[:-3]) # (NB, NEDGE)
    state_invalid_mask = col_cost > -1 # True for collision
    # state_invalid_mask = jnp.any(col_res_binary, axis=(-1))
    # col_cost = aggregate_cost(col_cost, (-1,))

    # add speed penalty
    # vel = robot_states[...,-1,:] - robot_states[...,0,:]
    # vel = jnp.linalg.norm(vel, axis=-1)
    # col_cost = vel*col_cost

    # perform self collision at every interpolated state; an edge is invalid if any logit > 0
    self_col_logit = models.apply('self_collision_checker', robot_states_itp[..., -shakey.num_act_joints:]).squeeze(-1)
    state_invalid_mask = jnp.logical_or(state_invalid_mask, jnp.any(self_col_logit > 0, axis=(-1,)))
    self_col_cost = aggregate_cost(self_col_logit, (-1,))
    # self_col_cost = 0

    state_cost = col_cost + self_col_cost

    return state_invalid_mask, state_cost, {'self_col_cost':self_col_cost, 'col_cost':col_cost}



def Bezier_curve_3points(p1, p2, p3, t):
    '''
    Evaluate a quadratic Bezier curve defined by control points p1, p2, p3.

    p1, p2, p3 -> (f,) control points
    t -> (s,) curve parameters (0 gives p1, 1 gives p3)
    returns -> (s, f), where row i is
        (1 - t_i)^2 * p1 + 2 (1 - t_i) t_i * p2 + t_i^2 * p3
    '''
    tt = t[..., None]
    rest = 1 - tt
    # quadratic Bernstein basis weights
    b0 = rest ** 2
    b1 = 2 * rest * tt
    b2 = tt ** 2
    return b0 * p1[None] + b1 * p2[None] + b2 * p3[None]

def interval_based_interpolations(waypnts, gap):
    '''
    Resample a waypoint path so consecutive samples are roughly `gap` apart.

    The sample count is int(total_path_length / gap) + 1 and the resampling is
    delegated to way_points_to_trajectory with the cosine time profile.
    Raises AssertionError if the total path length is not greater than `gap`.
    Requires concrete (non-traced) inputs because of the int() conversion.
    '''
    seg_lengths = jnp.linalg.norm(jnp.diff(waypnts, axis=0), axis=-1)
    path_length = seg_lengths.sum()
    assert path_length > gap
    num_samples = int(path_length / gap) + 1
    return way_points_to_trajectory(waypnts, num_samples, cos_transition=True)



# @partial(jax.jit, static_argnums=[1,2])
def way_points_to_trajectory(waypnts, resolution, cos_transition=True):
    """Resample a piecewise-linear waypoint path at ``resolution`` points.

    Samples are placed by normalised arc length along the polyline through
    ``waypnts``. With ``cos_transition`` the arc-length parameter follows
    ``(1 - cos(pi * s)) / 2``, which places samples more densely near both ends
    (ease-in / ease-out). Segments shorter than 1e-4 of the total path length
    are treated as zero-length and skipped.

    Args:
        waypnts: array of shape (N, D) with N >= 2 waypoints.
        resolution: number of output samples (Python int).
        cos_transition: use the cosine ease-in / ease-out time profile.

    Returns:
        jax array of shape (resolution, D). The first and last rows are exactly
        ``waypnts[0]`` and ``waypnts[-1]``. If all waypoints coincide, the
        result is that point repeated (no NaNs).
    """
    wp_len = jnp.linalg.norm(waypnts[1:] - waypnts[:-1], axis=-1)
    wp_len = wp_len/jnp.sum(wp_len).clip(1e-5)
    # drop negligible segments (relative to the whole path)
    wp_len = jnp.where(wp_len<1e-4, 0, wp_len)
    # Renormalise. The denominator is clipped: if every segment was dropped
    # (e.g. identical waypoints) the sum is 0 and 0/0 would make the whole
    # trajectory NaN.
    wp_len = wp_len/jnp.sum(wp_len).clip(1e-5)
    wp_len_cumsum = jnp.cumsum(wp_len)
    wp_len_cumsum = jnp.concatenate([jnp.array([0]),wp_len_cumsum], 0)
    # force the end of the cumulative profile to exactly 1 (rounding / degenerate paths)
    wp_len_cumsum = wp_len_cumsum.at[-1].set(1.0)
    indicator = jnp.linspace(0, 1, resolution)
    if cos_transition:
        indicator = (-jnp.cos(indicator*jnp.pi)+1)/2.
    # index of the segment containing each sample
    included_idx = jnp.sum(indicator[...,None] > wp_len_cumsum[1:], axis=-1)

    # fraction of the segment still remaining -> weight of the segment's start point
    upper_residual = (wp_len_cumsum[included_idx+1] - indicator)/wp_len[included_idx].clip(1e-5)
    upper_residual = upper_residual.clip(0.,1.)
    bottom_residual = 1.-upper_residual

    traj = waypnts[included_idx] * upper_residual[...,None] + waypnts[included_idx+1] * bottom_residual[...,None]
    # samples that land on a zero-length segment snap to its start point
    traj = jnp.where(wp_len[included_idx][...,None] < 1e-4, waypnts[included_idx], traj)
    traj = traj.at[0].set(waypnts[0])
    traj = traj.at[-1].set(waypnts[-1])

    return traj

def inequality_cost_fn(x, margin=0.0):
    '''
    Squared-hinge penalty for an inequality constraint "x + margin <= 0".

    Positive x (after adding margin) is a violation and costs (x + margin)^2;
    anything at or below zero costs 0.
    '''
    shifted = x + margin
    hinge = jnp.maximum(shifted, 0)
    return hinge ** 2

class MotionPlanner(object):
    '''
    PRM-based motion planner whose validity checks come from a cost module.

    Both validity checkers wrap ``cost_module_cls.traj_opt_cost`` and are
    jit-compiled once in ``__init__``; ``plan`` hands them to
    ``traj_search_module.PRM_node_only``.
    '''

    def __init__(self, cost_module_cls:cost_module.CostModules):
        '''
        Args:
            cost_module_cls: cost module providing ``traj_opt_cost``, ``shakey``
                (robot kinematics and joint bounds) and ``se2_bounds`` (mobile-base
                bounds; only required when planning over base + arm states).
        '''
        self.cost_module_cls = cost_module_cls

        # jit-compiled validity checkers passed to the PRM search
        self.col_checker = jax.jit(self.col_checker_nojit)
        self.path_valid_cheker = jax.jit(self.path_valid_cheker_nojit)

        self.shakey = self.cost_module_cls.shakey
        self.se2_bounds = self.cost_module_cls.se2_bounds

    def col_checker_nojit(self, traj, col_args:structs.LossArgs, jkey):
        '''
        Discrete (per-state) validity check.

        Calls traj_opt_cost with interpolation_num=None and ccd=False.

        Returns:
            col_mask: loss_aux['collision_binary'] reported by the cost module.
            loss: the loss returned by traj_opt_cost.
        '''
        loss, loss_aux = self.cost_module_cls.traj_opt_cost(traj, col_args, jkey, interpolation_num=None, ccd=False)
        col_mask = loss_aux['collision_binary']
        return col_mask, loss

    def path_valid_cheker_nojit(self, traj, col_args:structs.LossArgs, jkey):
        '''
        Edge (path) validity check.

        Calls traj_opt_cost with interpolation_num=8 and ccd=True.

        Returns:
            col_mask: loss_aux['collision_binary'] reported by the cost module.
            loss: the loss returned by traj_opt_cost.
        '''
        loss, loss_aux = self.cost_module_cls.traj_opt_cost(traj, col_args, jkey, interpolation_num=8, ccd=True, visualize=False)
        col_mask = loss_aux['collision_binary']
        return col_mask, loss

    def plan(self, jkey, init_state, goal_state, loss_args:structs.LossArgs,
             node_size=6000, num_neighbors=8, node_one_batch_size=2000, path_one_batch_size=2000,
             node_visualize_func=None):
        '''
        Plan a path from init_state to goal_state with a PRM.

        If init_state has shakey.num_act_joints entries, the search runs in arm
        joint space. Otherwise the state is [SE(2) base (3 values), arm joints],
        and se2_bounds must be set.

        Args:
            jkey: JAX PRNG key forwarded to the search.
            init_state, goal_state: start / goal configurations.
            loss_args: scene description forwarded to both validity checkers.
            node_size, num_neighbors: PRM parameters forwarded to PRM_node_only
                (presumably node count and neighbours per node).
            node_one_batch_size, path_one_batch_size: batch sizes forwarded to
                PRM_node_only for node / edge checks.
            node_visualize_func: optional callback forwarded to PRM_node_only.

        Returns:
            (path, aux) with path a jnp array of waypoints and
            aux = {'traj_non_rf': path}; (None, None) when PRM_node_only returns None.
        '''
        if init_state.shape[-1] == self.shakey.num_act_joints:
            # arm-only configuration space
            joint_min = self.shakey.q_lower_bound
            joint_max = self.shakey.q_upper_bound
            node_scale = None
        else:
            # base + arm configuration space: prepend the SE(2) bounds
            assert self.se2_bounds is not None
            joint_min = jnp.concatenate([self.se2_bounds[0], self.shakey.q_lower_bound])
            joint_max = jnp.concatenate([self.se2_bounds[1], self.shakey.q_upper_bound])
            # np.concatenate rather than np.concat: the latter only exists in NumPy >= 2.0
            node_scale = np.concatenate([np.array([2,2,1.]), np.ones(self.shakey.q_upper_bound.shape)])
        path_res, aux = traj_search_module.PRM_node_only(jkey, init_state, goal_state, node_size, num_neighbors,
                                                joint_max, joint_min, self.col_checker, self.path_valid_cheker,
                                                state_scale=node_scale,
                                                col_args=loss_args,
                                                node_one_batch_size=node_one_batch_size,
                                                path_one_batch_size=path_one_batch_size,
                                                node_visualize_func=node_visualize_func)
        if path_res is None:
            return None, None
        path_res = jnp.array(path_res)
        # Optional densification, e.g.:
        # path_res = way_points_to_trajectory(path_res, resolution=16, cos_transition=False)

        return path_res, {'traj_non_rf':path_res}

if __name__ == '__main__':
    # Demo: load the pretrained models and the "shakey" URDF kinematics, scatter random
    # non-overlapping objects in a PyBullet scene, run the trajectory optimiser between
    # two fixed configurations and replay the result in the PyBullet GUI.

    import modules.traj_opt_module as traj_opt_module
    from tqdm import tqdm

    models = mutil.Models().load_pretrained_models()
    models = models.load_self_collision_model()

    urdf_dirs = "assets/ur5/urdf/shakey_open.urdf"
    shakey = shakey_module.load_urdf_kinematics(urdf_dirs, models=models)

    traj_optimizer = traj_opt_module.TrajectoryOptimizer(models, shakey)
    traj_opt_jit = jax.jit(traj_optimizer.perform_multiple_seed)

    # MotionPlanner now takes a cost_module.CostModules instance, so the old
    # ``MotionPlanner(models, shakey)`` call raised TypeError. This demo only runs the
    # trajectory optimiser and needed the planner solely for ``shakey.ee_idx``, so the
    # kinematics object is used directly below.

    init_state = np.array([1.5, 0, 0, 0, np.pi/2, 0])
    goal_state = np.array([-1.5, 0, 0, 0, np.pi/2, 0])

    # pybullet visualizer
    import pybullet as p
    p.connect(p.GUI)
    # p.connect(p.DIRECT)
    shakey_pb_uid = p.loadURDF(urdf_dirs, useFixedBase=True)

    def reset_q_pb(q):
        # the 6 arm joints are set at pybullet joint indices 1..6 (index 0 is skipped)
        for i in range(6):
            p.resetJointState(shakey_pb_uid, i+1, q[i])
    reset_q_pb(init_state)

    # Scatter scene objects at random poses, rejecting any placement that produces a
    # contact with something already in the PyBullet world.
    num_scene_obj = 10
    scene_oriCORNs_list = []
    quat_key = jax.random.PRNGKey(0)
    while len(scene_oriCORNs_list) < num_scene_obj:
        random_pos = np.random.uniform(np.array([0.1,-0.8,-1]), np.array([0.8,0.8,1.0]), size=(3,))
        # fresh key per sample; reusing PRNGKey(0) gave every object the same orientation
        quat_key, quat_subkey = jax.random.split(quat_key)
        random_quat = tutil.qrand((), quat_subkey)
        random_obj_idx = np.random.choice(models.canonical_latent_obj.shape[0])
        scale = models.scale_to_origin[random_obj_idx] * 0.20

        obj_filename = models.asset_path_util.obj_paths[random_obj_idx]
        vis_uid = p.createVisualShape(shapeType=p.GEOM_MESH, fileName=obj_filename, meshScale=[scale,scale,scale])
        col_uid = p.createCollisionShape(shapeType=p.GEOM_MESH, fileName=obj_filename, meshScale=[scale,scale,scale])
        obj_pb_uid = p.createMultiBody(baseVisualShapeIndex=vis_uid, baseCollisionShapeIndex=col_uid, basePosition=random_pos, baseOrientation=random_quat)

        p.performCollisionDetection()
        if len(p.getContactPoints()) != 0:
            # removeBody expects the multibody id returned by createMultiBody,
            # not the visual-shape index
            p.removeBody(obj_pb_uid)
            continue

        scene_oriCORNs = models.mesh_aligned_canonical_obj[random_obj_idx].apply_scale(scale, center=jnp.zeros(3))
        scene_oriCORNs = scene_oriCORNs.apply_pq_z(random_pos, random_quat, models.rot_configs)
        scene_oriCORNs_list.append(scene_oriCORNs)

    # stack the per-object pytrees into one batched LatentObjects along axis 0
    scene_oriCORNs = jax.tree_util.tree_map(lambda *args: jnp.stack(args, axis=0), *scene_oriCORNs_list)

    random_obj_idx = np.random.choice(models.canonical_latent_obj.shape[0])
    scale = models.scale_to_origin[random_obj_idx] * 0.20
    grasped_oriCORN = models.mesh_aligned_canonical_obj[random_obj_idx].apply_scale(scale, center=jnp.zeros(3))

    # grasped obj (visual only)
    obj_filename = models.asset_path_util.obj_paths[random_obj_idx]
    vis_uid = p.createVisualShape(shapeType=p.GEOM_MESH, fileName=obj_filename, meshScale=[scale,scale,scale])
    grasped_obj_pbid = p.createMultiBody(baseVisualShapeIndex=vis_uid)

    # grasp pose of the object relative to the end effector: position (3) + quaternion (4)
    grasp_pqc = jnp.array([0,0,0.2,0,0,0,1.0])

    def reset_grasped_obj():
        # place the grasped object at ee_pose * grasp_pqc
        epos, equat = p.getLinkState(shakey_pb_uid, shakey.ee_idx, computeForwardKinematics=True)[4:6]
        new_epos, new_equat = p.multiplyTransforms(epos, equat, grasp_pqc[:3], grasp_pqc[3:])
        p.resetBasePositionAndOrientation(grasped_obj_pbid, new_epos, new_equat)
    reset_grasped_obj()

    jkey = jax.random.PRNGKey(0)
    loss_args = structs.LossArgs(scene_oriCORNs)
    traj = traj_opt_jit(init_state, goal_state, loss_args, jkey)

    # replay a densely resampled version of the optimised trajectory
    traj_interpolated = way_points_to_trajectory(traj, resolution=2000, cos_transition=True)
    for q in tqdm(traj_interpolated):
        reset_q_pb(q)
        reset_grasped_obj()
        time.sleep(0.001)

    print(1)

