import os
import sys
import pickle
from pathlib import Path
from typing import List, Tuple, Optional
import copy
import random
import numpy as np
import jax
import jax.numpy as jnp
import pybullet as p
from tqdm import tqdm
import trimesh  # Used to calculate mesh size and alignment
from scipy.spatial.transform import Rotation as sciR
from functools import partial
import pybullet_data

# Ensure the base path is in sys.path
# Only when this file is executed directly as a script: append the parent of the
# directory containing this file to sys.path (if not already present). Presumably
# this lets the project-level `util.*` / `modules.*` imports below resolve when the
# file is run outside the package — confirm against the repo layout.
if __name__ == '__main__':
    BASE_PATH = Path(__file__).parent
    if str(BASE_PATH.parent) not in sys.path:
        sys.path.append(str(BASE_PATH.parent))

# Import custom utilities (assuming these are available in your environment)
import util.latent_obj_util as loutil
import util.transform_util as tutil
import util.camera_util as cutil
# from dataset.scene_generation import SceneCls
# SceneData = SceneCls.SceneData
from util.structs import SceneData
import util.asset_util as asutil
import util.model_util as mutil
import modules.shakey_module as shakey_module

def compute_view_matrix(distance, yaw, pitch, target, up=(0, 0, 1)):
    """
    Build a 4x4 OpenGL-style look-at view matrix for a camera orbiting ``target``.

    The eye is placed ``distance`` away from ``target`` using spherical
    coordinates: ``yaw`` sweeps the eye around the z-axis and a positive
    ``pitch`` raises the eye above the target. The returned matrix maps world
    coordinates into camera coordinates in which the camera looks along -z.

    Args:
        distance: Eye-to-target distance.
        yaw: Horizontal angle in degrees.
        pitch: Vertical angle in degrees.
        target: 3D look-at point.
        up: World up direction used to orient the camera (default +z).

    Returns:
        np.ndarray of shape (4, 4) and dtype float32, equal to ``R @ T``.

    Raises:
        ValueError: if the eye coincides with the target, or if the viewing
            direction is parallel to ``up``. In both cases the camera frame is
            undefined and the matrix would otherwise be filled with NaN.
    """
    # Convert degrees to radians.
    yaw_rad = np.deg2rad(yaw)
    pitch_rad = np.deg2rad(pitch)
    target = np.array(target, dtype=np.float32)

    # Compute camera (eye) position based on spherical coordinates.
    eye = np.array([
        target[0] - distance * np.cos(pitch_rad) * np.sin(yaw_rad),
        target[1] - distance * np.cos(pitch_rad) * np.cos(yaw_rad),
        target[2] + distance * np.sin(pitch_rad)
    ], dtype=np.float32)

    # Build coordinate axes.
    forward = target - eye
    forward_norm = np.linalg.norm(forward)
    if forward_norm == 0:
        raise ValueError("compute_view_matrix: eye coincides with target (zero distance).")
    forward /= forward_norm

    up = np.array(up, dtype=np.float32)
    right = np.cross(forward, up)
    right_norm = np.linalg.norm(right)
    if right_norm == 0:
        # Looking exactly along `up` (e.g. straight down): the cross product vanishes.
        raise ValueError(
            "compute_view_matrix: view direction is parallel to `up`; "
            "use a different up vector or pitch."
        )
    right /= right_norm
    true_up = np.cross(right, forward)

    # Create the rotation part of the view matrix.
    R = np.eye(4, dtype=np.float32)
    R[0, :3] = right
    R[1, :3] = true_up
    R[2, :3] = -forward  # note: view matrix uses -forward

    # Create the translation part.
    T = np.eye(4, dtype=np.float32)
    T[:3, 3] = -eye

    # The view matrix is the product: R * T
    view = R @ T
    return view

def compute_projection_matrix(fov, aspect, near, far):
    """
    Return a 4x4 OpenGL-style perspective projection matrix (float32).

    Args:
        fov: Vertical field of view in degrees.
        aspect: Width / height ratio.
        near: Near clipping distance.
        far: Far clipping distance.
    """
    focal = 1.0 / np.tan(np.deg2rad(fov) / 2.0)
    depth_span = near - far
    return np.array(
        [
            [focal / aspect, 0.0, 0.0, 0.0],
            [0.0, focal, 0.0, 0.0],
            [0.0, 0.0, (far + near) / depth_span, (2 * far * near) / depth_span],
            [0.0, 0.0, -1.0, 0.0],
        ],
        dtype=np.float32,
    )


class SceneObject:
    """
    Contains mesh and pose information of an object.

    Attributes:
        mesh_path: Path to the (visual) mesh file; ``bytes`` input is decoded as UTF-8.
        scale: Mesh scale, either a scalar or a per-axis sequence. Conversions to
            oriCORN / Open3D only use the first component (uniform scaling).
        base_pose: ``(position, quaternion)``; the quaternion is passed unchanged to
            PyBullet and scipy, i.e. ``[x, y, z, w]`` ordering.
        aabb: Axis-aligned bounding box as stored by the caller (e.g. ``p.getAABB``).
        pb_uid: PyBullet body id, or None when not loaded.
        col_mesh_path: Optional path to a separate collision mesh.
    """

    def __init__(
        self,
        mesh_path: str,
        scale: List[float],
        base_pose: Tuple[List[float], List[float]],
        aabb: Tuple[List[float], List[float]],
        pb_uid: Optional[int] = None,
        col_mesh_path: Optional[str] = None,
    ):
        # Mesh paths may arrive as bytes (e.g. from PyBullet shape queries); normalise to str.
        if isinstance(mesh_path, bytes):
            mesh_path = mesh_path.decode('utf-8')
        self.mesh_path = mesh_path
        self.scale = scale
        self.base_pose = base_pose
        self.aabb = aabb
        self.pb_uid = pb_uid
        self.col_mesh_path = col_mesh_path

    def _uniform_scale(self) -> float:
        """Return the scalar scale factor: ``scale`` itself if scalar, else its first component."""
        if np.ndim(self.scale) == 0:
            return float(self.scale)
        return float(self.scale[0])

    def _load_cached_oriCORN(self, models):
        """
        Load the pre-computed canonical oriCORN for this mesh from
        ``assets_oriCORNs/<models.pretrain_ckpt_id>/<dataset>/<file name up to first '.'>.pkl``.
        """
        mesh_dir = os.path.dirname(self.mesh_path)
        if "cvx" in self.mesh_path:
            # Convex-decomposition meshes sit one directory deeper than visual meshes.
            dataset_dir = os.path.dirname(os.path.dirname(mesh_dir))
        else:
            dataset_dir = os.path.dirname(mesh_dir)
        dataset_name = os.path.basename(dataset_dir)
        file_key = os.path.basename(self.mesh_path).split(".")[0]
        cache_path = f'assets_oriCORNs/{models.pretrain_ckpt_id}/{dataset_name}/{file_key}.pkl'
        # NOTE: pickle can execute arbitrary code on load; only use trusted, locally generated caches.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    def convert_to_oriCORN(self, models) -> 'loutil.LatentObjects':
        """
        Converts the object to an oriCORN representation.

        The canonical representation is read from the on-disk cache when possible,
        otherwise taken from ``models.mesh_aligned_canonical_obj``. It is then scaled
        about the origin by the uniform scale and placed at ``base_pose``.

        Args:
            models: An instance of the Models class.

        Returns:
            The oriCORN representation of the object.

        Raises:
            ValueError: if the cache is unavailable and ``models`` does not know the mesh.
        """
        try:
            canonical_oriCORN = self._load_cached_oriCORN(models)
        except Exception:
            # Cache miss (missing/unreadable file, or models lacking cache metadata):
            # fall back to the in-memory canonical objects.
            self.mesh_idx = models.asset_path_util.get_obj_id(self.mesh_path)
            if self.mesh_idx == -1:
                raise ValueError(f"Mesh {self.mesh_path} not found in models.")
            canonical_oriCORN = models.mesh_aligned_canonical_obj[self.mesh_idx]

        obj_oriCORN = canonical_oriCORN.apply_scale(self._uniform_scale(), center=jnp.zeros(3))
        obj_oriCORN = obj_oriCORN.apply_pq_z(
            jnp.array(self.base_pose[0]), jnp.array(self.base_pose[1]), models.rot_configs
        )
        return obj_oriCORN

    def convert_to_o3d(self):
        """Load the mesh with Open3D, scale it about the origin, then rotate and translate it to ``base_pose``."""
        import open3d as o3d
        mesh = o3d.io.read_triangle_mesh(self.mesh_path)
        mesh.compute_vertex_normals()
        mesh.scale(scale=self._uniform_scale(), center=np.zeros(3, dtype=np.float32))
        mesh.rotate(sciR.from_quat(self.base_pose[1]).as_matrix(), center=np.zeros(3))
        mesh.translate(self.base_pose[0])
        return mesh

    def load_in_pb(self):
        """
        Create this object as a dynamic body (mass 1.0) in the connected PyBullet
        client and store the new body id in ``pb_uid``.

        NOTE(review): ``mesh_path`` is used for both visual and collision geometry;
        ``col_mesh_path`` is not used here — confirm whether that is intended.
        """
        visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=self.mesh_path,
            meshScale=self.scale
        )
        collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=self.mesh_path,
            meshScale=self.scale
        )
        self.pb_uid = p.createMultiBody(
            baseMass=1.0,  # Movable objects have mass
            baseCollisionShapeIndex=collision_shape_id,
            baseVisualShapeIndex=visual_shape_id,
            basePosition=self.base_pose[0],
            baseOrientation=self.base_pose[1]
        )

    def set_base_pose(self, base_pose):
        """
        Set the base pose of the object in PyBullet.

        The stored ``base_pose`` is re-read from PyBullet after the reset.

        Args:
            base_pose: A tuple containing the position and orientation.
        """
        p.resetBasePositionAndOrientation(self.pb_uid, base_pose[0], base_pose[1])
        self.base_pose = p.getBasePositionAndOrientation(self.pb_uid)

class Robots:
    """
    Stores what is needed to re-create a robot in PyBullet.

    Attributes:
        robot_urdf_path: URDF file loaded by ``load_in_pb``.
        robot_uid: PyBullet body id (overwritten by ``load_in_pb``).
        q: Configuration ``[base_x, base_y, base_yaw, joint values...]`` with the
            yaw in radians; joint values are written to joint indices 1, 2, ...
        robot_height: z coordinate of the robot base.
    """

    def __init__(self, robot_urdf_path, robot_uid, q, robot_height):
        self.robot_urdf_path = robot_urdf_path
        self.robot_uid = robot_uid
        self.q = q
        self.robot_height = robot_height

    def convert_to_oriCORN(self, models) -> 'loutil.LatentObjects':
        """
        Converts the robot to an oriCORN representation.

        Not implemented for robots yet: always returns None.

        Args:
            models: An instance of the Models class.
        """
        # TODO: look up the robot in models.asset_path_util / mesh_aligned_canonical_obj
        # like SceneObject.convert_to_oriCORN once robot assets are available.
        return None

    def convert_to_o3d(self):
        """Not implemented for robots yet: always returns None."""
        return None

    def load_in_pb(self):
        """
        Load the URDF (fixed base) into the connected PyBullet client and restore
        the stored base pose and joint values.
        """
        self.robot_uid = p.loadURDF(self.robot_urdf_path, useFixedBase=True)

        # Base at (x, y, robot_height), rotated by q[2] radians about the z-axis.
        base_pos = np.zeros(3)
        base_pos[:2] = self.q[:2]
        base_pos[2] = self.robot_height
        base_quat = sciR.from_euler('z', self.q[2]).as_quat()
        p.resetBasePositionAndOrientation(self.robot_uid, base_pos, base_quat)

        # Joint values start at q[3]; joint index 0 is skipped (offset by one).
        for joint_offset, joint_value in enumerate(self.q[3:]):
            p.resetJointState(self.robot_uid, joint_offset + 1, joint_value)


class PybulletScene:
    """
    This class saves the scene information from PyBullet, which contains all necessary information
    to reconstruct the scene in PyBullet again. This includes poses, meshes, scales of both
    movable and fixed objects.
    """

    # Camera sampling modes implemented by ``generate_scene_data``.
    _SUPPORTED_VIEW_TYPES = ('close_side', 'far', 'top')

    def __init__(
        self,
        movable_objects: List['SceneObject'],
        fixed_objects: List['SceneObject'],
        robot: Optional['Robots'] = None,
    ):
        self.movable_objects = movable_objects
        self.fixed_objects = fixed_objects
        self.robot = robot

    @staticmethod
    def _create_mesh_body(obj, mass):
        """Create a PyBullet body for ``obj`` (its mesh as visual + collision shape) and return its id."""
        visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=obj.mesh_path,
            meshScale=obj.scale
        )
        collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=obj.mesh_path,
            meshScale=obj.scale
        )
        return p.createMultiBody(
            baseMass=mass,
            baseCollisionShapeIndex=collision_shape_id,
            baseVisualShapeIndex=visual_shape_id,
            basePosition=obj.base_pose[0],
            baseOrientation=obj.base_pose[1]
        )

    def reconstruct_scene_in_pybullet(self, visualize=True):
        """
        Reconstructs the scene in PyBullet using the saved information.

        Opens a new PyBullet connection, sets gravity, creates fixed objects
        (mass 0) and movable objects (mass 1.0), and loads the robot if present.

        Args:
            visualize (bool): Whether to use GUI visualization.
        """
        if visualize:
            p.connect(p.GUI)

            CAMERA_DISTANCE        = 5.79
            CAMERA_PITCH           = -88.94
            CAMERA_YAW             = -270.0
            CAMERA_TARGET_POSITION = [0, 0, 0]

            p.resetDebugVisualizerCamera(
                cameraDistance=CAMERA_DISTANCE,
                cameraYaw=CAMERA_YAW,
                cameraPitch=CAMERA_PITCH,
                cameraTargetPosition=CAMERA_TARGET_POSITION
            )
        else:
            p.connect(p.DIRECT)

        # Gravity is physics state, not a visualisation option: set it in DIRECT mode too.
        p.setGravity(0, 0, -9.81)

        for obj in self.fixed_objects:
            obj.pb_uid = self._create_mesh_body(obj, mass=0.0)  # Fixed objects have zero mass
        for obj in self.movable_objects:
            obj.pb_uid = self._create_mesh_body(obj, mass=1.0)  # Movable objects have mass

        # Optionally reconstruct the robot
        if self.robot is not None:
            self.robot.load_in_pb()

    def clear_scene(self, aux_pbids=None, visualize=False):
        """
        Remove all fixed and movable object bodies (plus any ``aux_pbids``) from PyBullet.

        Args:
            aux_pbids: Optional iterable of extra body ids to remove.
            visualize: Unused; kept for interface compatibility.

        NOTE(review): when a robot is attached this *loads another copy* of it via
        ``Robots.load_in_pb`` (mirroring ``reconstruct_scene_in_pybullet``) instead of
        removing it, and the previously stored robot id is overwritten. Confirm
        whether the robot should be removed here instead.
        """
        for obj in self.fixed_objects:
            p.removeBody(obj.pb_uid)
        for obj in self.movable_objects:
            p.removeBody(obj.pb_uid)

        if aux_pbids is not None:
            for pbid in aux_pbids:
                p.removeBody(pbid)

        if self.robot is not None:
            self.robot.load_in_pb()

    def convert_to_oriCORNs(self, models):
        """
        Convert every object to oriCORN and store the batched results.

        Sets ``self.movable_oriCORNs`` / ``self.fixed_oriCORNs`` to the per-object
        representations stacked along a new leading axis, or to ``[]`` when the
        respective object list is empty.
        """
        def _stack(oriCORNs):
            return jax.tree_util.tree_map(lambda *x: jnp.stack(x), *oriCORNs)

        movable_oriCORNs = [obj.convert_to_oriCORN(models) for obj in self.movable_objects]
        self.movable_oriCORNs = _stack(movable_oriCORNs) if movable_oriCORNs else []

        fixed_oriCORNs = [obj.convert_to_oriCORN(models) for obj in self.fixed_objects]
        self.fixed_oriCORNs = _stack(fixed_oriCORNs) if fixed_oriCORNs else []

    def convert_to_o3d(self):
        """Return ``(movable_meshes, fixed_meshes)`` as lists of posed Open3D meshes."""
        movable_o3d = [obj.convert_to_o3d() for obj in self.movable_objects]
        fixed_o3d = [obj.convert_to_o3d() for obj in self.fixed_objects]
        return movable_o3d, fixed_o3d

    def convert_to_o3d_aabbs(self):
        """Return dark-blue Open3D boxes built from PyBullet AABBs (movable objects first, then fixed)."""
        import open3d as o3d
        aabb_o3d_list = []
        for obj in [*self.movable_objects, *self.fixed_objects]:
            aabb_min, aabb_max = p.getAABB(obj.pb_uid)
            box = o3d.geometry.AxisAlignedBoundingBox(
                min_bound=np.asarray(aabb_min, dtype=np.float64),
                max_bound=np.asarray(aabb_max, dtype=np.float64),
            )
            box.color = np.array([0.0, 0.0, 0.5])
            aabb_o3d_list.append(box)
        return aabb_o3d_list

    def show_oriCORNs_o3d(self, models, visualize=True):
        """
        Decode all objects' oriCORNs (fixed first, then movable) into a scene mesh.

        Raises:
            ValueError: if the scene has no objects.
        """
        self.convert_to_oriCORNs(models)
        from util.reconstruction_util import create_scene_mesh_from_oriCORNs
        # convert_to_oriCORNs stores empty groups as []; skip them.
        groups = [g for g in (self.fixed_oriCORNs, self.movable_oriCORNs) if not isinstance(g, list)]
        if not groups:
            raise ValueError("Scene has no objects to reconstruct.")
        scene_oriCORNs = groups[0]
        for group in groups[1:]:
            scene_oriCORNs = scene_oriCORNs.concat(group, axis=0)
        dec = jax.jit(partial(models.apply, 'sdf_decoder'))
        return create_scene_mesh_from_oriCORNs(scene_oriCORNs, dec=dec, level=0.1, visualize=visualize)

    def _randomize_colors(self, robot_pbuid):
        """Recolor fixed objects, movable objects, body 0 (assumed floor) and robot links from a random pool."""
        color_pool = np.random.uniform(0, 1, size=(len(self.fixed_objects) + len(self.movable_objects) + 5, 3))
        for obj in [*self.fixed_objects, *self.movable_objects]:
            color = random.choice(color_pool)
            p.changeVisualShape(obj.pb_uid, -1, rgbaColor=color.tolist() + [1])

        # Floor colour - assumes the floor body id is 0.
        floor_color = random.choice(color_pool)
        p.changeVisualShape(0, -1, rgbaColor=floor_color.tolist() + [1])

        if robot_pbuid is not None:
            for link_idx in range(0, p.getNumJoints(robot_pbuid)):
                color = random.choice(color_pool)
                p.changeVisualShape(robot_pbuid, link_idx, rgbaColor=color.tolist() + [1])

    @staticmethod
    def _sample_view(view_type, base_target, yaw_base, yaw_range):
        """Sample ``(look-at target, distance, yaw, pitch)`` for one camera; angles in radians."""
        if view_type == 'close_side':
            view_target = base_target + np.random.uniform(-0.2, 0.2, 3)
            dist = np.random.uniform(0.8, 1.5)
            yaw = np.random.uniform(yaw_base - yaw_range / 2, yaw_base + yaw_range / 2)
            pitch = np.random.uniform(-np.pi * 1 / 3, 0)
        elif view_type == 'far':
            view_target = base_target + np.random.uniform(-0.4, 0.4, 3)
            dist = np.random.uniform(1.0, 5.0)
            yaw = np.random.uniform(-np.pi, np.pi)
            pitch = np.random.uniform(-np.pi * 2 / 5, 0)
        elif view_type == 'top':
            view_target = base_target + np.random.uniform(-0.2, 0.2, 3)
            dist = np.random.uniform(0.5, 1.5)
            yaw = np.random.uniform(-np.pi, np.pi)
            pitch = np.random.uniform(-np.pi * 8 / 17, -np.pi / 4)
        else:
            raise ValueError(f"Unsupported view_type {view_type!r}.")
        return view_target, dist, yaw, pitch

    def generate_scene_data(self, max_obj_no, pixel_size=[238, 420], view_base_target=[0.8, 0, 1.05], view_type='close', num_views=5, randomize_color=True, robot_pbuid=None, max_env_obj_no=2)->'SceneData':
        """
        Render ``num_views`` randomly sampled camera views of the current PyBullet
        scene and pack them, with object / environment / robot state, into SceneData.

        Args:
            max_obj_no: Rows in the zero-padded movable-object arrays.
            pixel_size: Image size ``[height, width]``.
            view_base_target: One 3D base look-at point, or a list of them (one is
                drawn per view).
            view_type: Camera sampling mode: 'close_side', 'far' or 'top'.
            num_views: Number of rendered views.
            randomize_color: Recolor objects, the floor (body 0) and robot links first.
            robot_pbuid: PyBullet id of the robot, or None.
            max_env_obj_no: Rows in the padded fixed-object arrays (default 2).

        Returns:
            SceneData with RGB, depth, segmentation, camera poses/intrinsics and
            object/environment/robot info.

        Raises:
            ValueError: for an unsupported ``view_type`` (this includes the default
                'close', which has no sampling branch), or when there are more
                movable/fixed objects than ``max_obj_no`` / ``max_env_obj_no``.
        """
        # Validate before touching the simulator so failures leave the scene untouched.
        if view_type not in self._SUPPORTED_VIEW_TYPES:
            raise ValueError(
                f"Unsupported view_type {view_type!r}; expected one of {self._SUPPORTED_VIEW_TYPES}."
            )
        if len(self.movable_objects) > max_obj_no:
            raise ValueError(
                f"Scene has {len(self.movable_objects)} movable objects but max_obj_no={max_obj_no}."
            )
        if len(self.fixed_objects) > max_env_obj_no:
            raise ValueError(
                f"Scene has {len(self.fixed_objects)} fixed objects but max_env_obj_no={max_env_obj_no}."
            )

        if randomize_color:
            self._randomize_colors(robot_pbuid)

        height, width = pixel_size[0], pixel_size[1]
        view_base_targets = np.array(view_base_target)
        has_multiple_targets = view_base_targets.ndim >= 2

        # Intrinsics and projection are identical for every view.
        fov = 60
        near = 0.01
        far = 10.0
        cam_intrinsic = cutil.pbfov_to_intrinsic(pixel_size, fov)
        prj_matrix = p.computeProjectionMatrix(
            *cutil.intrinsic_to_pb_lrbt(cam_intrinsic, near=near),
            nearVal=near,
            farVal=far)

        cam_posquats = []
        cam_intrinsics = []
        rgbs = []
        depths = []
        segs = []

        # Yaw window shared by all 'close_side' views of this call.
        yaw_base = np.random.uniform(-np.pi / 3, np.pi / 3)
        yaw_range = np.random.uniform(np.pi / 4, np.pi / 2)
        for _ in range(num_views):
            if has_multiple_targets:
                view_base_target_apply = view_base_targets[np.random.randint(0, len(view_base_target))]
            else:
                view_base_target_apply = view_base_target
            view_target, dist, yaw, pitch = self._sample_view(
                view_type, view_base_target_apply, yaw_base, yaw_range)

            # Camera position: +x rotated by pitch about y, then yaw about z, scaled by dist.
            # NOTE: the camera orbits the un-jittered base target but looks at view_target.
            cam_pos = sciR.from_euler('y', pitch).apply(np.array([1, 0, 0]))
            cam_pos = sciR.from_euler('z', yaw).apply(cam_pos)
            cam_pos = cam_pos * dist + view_base_target_apply

            # A slightly perturbed up vector adds random camera roll.
            view_matrix = p.computeViewMatrix(cam_pos, view_target, np.array([0, 0, 1.]) + np.random.normal(size=(3,)) * 0.1)
            cam_posquats.append(cutil.pb_viewmatrix_to_cam_posquat(view_matrix))
            cam_intrinsics.append(cam_intrinsic)

            pb_img_out = p.getCameraImage(width, height, view_matrix, prj_matrix, renderer=p.ER_TINY_RENDERER)
            rgbs.append(np.reshape(pb_img_out[2], (height, width, 4))[:, :, :3])
            depths.append(np.reshape(pb_img_out[3], (height, width)))
            segs.append(np.reshape(pb_img_out[4], (height, width)))

        # Movable objects, zero-padded to max_obj_no rows.
        # NOTE(review): int8 uids cannot represent PyBullet body ids above 127 — confirm ids stay small.
        obj_posquats = np.zeros((max_obj_no, 7), dtype=np.float16)
        obj_scales = np.zeros((max_obj_no, 3), dtype=np.float16)
        mesh_names = [None] * max_obj_no
        obj_uids = -np.ones((max_obj_no), dtype=np.int8)
        for j, obj in enumerate(self.movable_objects):
            obj_posquats[j] = np.concatenate([obj.base_pose[0], obj.base_pose[1]])
            obj_scales[j] = obj.scale
            mesh_names[j] = obj.mesh_path
            obj_uids[j] = obj.pb_uid

        # Fixed (environment) objects, padded to max_env_obj_no rows.
        env_obj_posquats = np.zeros((max_env_obj_no, 7), dtype=np.float16)
        env_obj_scales = np.zeros((max_env_obj_no, 3), dtype=np.float16)
        env_mesh_names = [None] * max_env_obj_no
        env_obj_uids = -np.ones((max_env_obj_no), dtype=np.int8)
        for j, obj in enumerate(self.fixed_objects):
            env_obj_posquats[j] = np.concatenate([obj.base_pose[0], obj.base_pose[1]])
            env_obj_scales[j] = obj.scale
            env_mesh_names[j] = obj.mesh_path
            env_obj_uids[j] = obj.pb_uid

        robot_posquat = None
        robot_q = None
        if robot_pbuid is not None:
            robot_base_pos, robot_base_quat = p.getBasePositionAndOrientation(robot_pbuid)
            robot_posquat = np.concatenate([robot_base_pos, robot_base_quat]).astype(np.float16)
            # Joints 1..6 only: specific to the UR5 model.
            robot_joint_states = p.getJointStates(robot_pbuid, range(1, 7))
            robot_q = np.array([state[0] for state in robot_joint_states]).astype(np.float16)

        scene_data = SceneData(
            rgbs=np.array(rgbs).astype(np.uint8),
            depths=np.array(depths).astype(np.float16),
            seg=np.array(segs).astype(np.int8),
            cam_posquats=np.array(cam_posquats).astype(np.float16),
            cam_intrinsics=np.array(cam_intrinsics).astype(np.float16),
            ObjInfo_obj_posquats=np.array(obj_posquats).astype(np.float16),
            ObjInfo_scale=np.array(obj_scales).astype(np.float16),
            ObjInfo_uid_list=obj_uids.astype(np.int8),
            # np.bytes_ (np.string_ was removed in NumPy 2.0).
            ObjInfo_mesh_name=np.array(mesh_names).astype(np.bytes_),
            EnvInfo_obj_posquats=np.array(env_obj_posquats).astype(np.float16),
            EnvInfo_scale=np.array(env_obj_scales).astype(np.float16),
            EnvInfo_uid_list=env_obj_uids.astype(np.int8),
            EnvInfo_mesh_name=np.array(env_mesh_names),
            RobotInfo_posquat=robot_posquat,
            RobotInfo_q=robot_q,
            table_params=None,
            robot_params=None,
            nvren_info=None
        )

        # Sanity checks: uniform scales, and padding rows (zero scale) are all-zero and form a contiguous tail.
        assert np.all(scene_data.ObjInfo_scale[..., 0] == scene_data.ObjInfo_scale[..., 1])
        assert np.all(scene_data.ObjInfo_scale[..., 1] == scene_data.ObjInfo_scale[..., 2])
        assert np.sum(np.abs(scene_data.ObjInfo_obj_posquats[scene_data.ObjInfo_scale[..., 0] == 0])) < 1e-5

        in_padding = False
        for i in range(scene_data.ObjInfo_obj_posquats.shape[0]):
            if scene_data.ObjInfo_scale[i, 0] == 0:
                in_padding = True
            if in_padding:
                assert scene_data.ObjInfo_scale[i, 0] == 0

        return scene_data

    def get_obj_by_pbuid(self, pb_uid):
        """Return the first movable, then fixed, object whose ``pb_uid`` matches, or None."""
        for obj in [*self.movable_objects, *self.fixed_objects]:
            if obj.pb_uid == pb_uid:
                return obj
        return None

    def get_pbscene_by_pbuids(self, pb_uids):
        """Return a robot-less sub-scene with the objects matching ``pb_uids`` (in ``pb_uids`` order)."""
        movable_objs = []
        fixed_objs = []
        for pb_uid in pb_uids:
            movable_objs.extend(obj for obj in self.movable_objects if obj.pb_uid == pb_uid)
            fixed_objs.extend(obj for obj in self.fixed_objects if obj.pb_uid == pb_uid)
        return PybulletScene(movable_objs, fixed_objs)

class SceneConverter:
    """
    Contains mesh and pose information of all objects in the scene.

    Captures live PyBullet bodies into ``SceneObject`` records, builds
    ``PybulletScene`` snapshots and converts the captured objects to oriCORN.
    """

    # Collision-mesh sub-directories (relative to the visual mesh's grandparent
    # directory), checked in this order.
    _CVX_BASE_DIRS = ('cvx/coacd', 'cvx/8_32_v4', 'cvx/0_coacd')

    def __init__(self):
        self.movable_objects: List[SceneObject] = []
        self.fixed_objects: List[SceneObject] = []
        self.robot_uid: int = None
        self.robot_urdf_path = None
        self.robot = None

    @classmethod
    def _extract_col_path(cls, mesh_path):
        """
        Return the first existing convex-decomposition mesh for ``mesh_path``
        (``<grandparent>/<cvx dir>/<file name>``), else ``mesh_path`` itself, as str.
        Accepts ``bytes`` (as returned by PyBullet shape queries) or ``str``.
        """
        if isinstance(mesh_path, bytes):
            mesh_path = mesh_path.decode('utf-8')
        mesh = Path(mesh_path)
        for cvx_base_dir in cls._CVX_BASE_DIRS:
            candidate = str(mesh.parent.parent / cvx_base_dir / mesh.name)
            if os.path.exists(candidate):
                return candidate
        return mesh_path

    def _snapshot_objects(self, obj_uids):
        """
        Capture pose, AABB, mesh path and scale of each body into SceneObjects.

        Only the first visual shape of each body is used; bodies with no visual
        shape data are skipped.
        """
        snapshots = []
        for obj_uid in obj_uids:
            base_pose = p.getBasePositionAndOrientation(obj_uid)
            visual_data = p.getVisualShapeData(obj_uid)
            aabb = p.getAABB(obj_uid)
            if not visual_data:
                continue
            mesh_path = visual_data[0][4]  # Mesh file path
            scale = visual_data[0][3]      # Scale of the object
            snapshots.append(SceneObject(
                mesh_path, scale, base_pose, aabb, obj_uid,
                col_mesh_path=self._extract_col_path(mesh_path),
            ))
        return snapshots

    def create_pybullet_scene_from_meshes(
        self,
        mesh_names: List[str],
        object_poses: List[Tuple[List[float], List[float]]],
        object_scales: List[List[float]]
    ):
        """
        Creates a PyBullet scene from object meshes.

        Opens a PyBullet GUI connection, loads the UR5 robot (fixed base) and adds
        every mesh as a dynamic body (mass 1.0), appending the resulting
        SceneObjects to ``self.movable_objects``.

        Args:
            mesh_names (List[str]): List of mesh file names.
            object_poses (List[Tuple[List[float], List[float]]]): List of object poses (position, orientation).
            object_scales (List[List[float]]): List of object scales.

        Raises:
            FileNotFoundError: if the robot URDF is missing.
            ValueError: if the three input lists differ in length.
        """
        robot_urdf_path = "assets/ur5/urdf/ur5.urdf"
        if not os.path.exists(robot_urdf_path):
            raise FileNotFoundError(f"Robot URDF path {robot_urdf_path} does not exist. Please check the path.")
        # Validate before opening a GUI connection.
        if not (len(object_poses) == len(mesh_names) == len(object_scales)):
            raise ValueError(
                "mesh_names, object_poses and object_scales must have equal lengths, got "
                f"{len(mesh_names)}, {len(object_poses)}, {len(object_scales)}."
            )

        p.connect(p.GUI)
        self.robot_uid = p.loadURDF(robot_urdf_path, useFixedBase=True)

        for mesh_filename, obj_pose, obj_scale in zip(mesh_names, object_poses, object_scales):
            visual_shape_id = p.createVisualShape(
                shapeType=p.GEOM_MESH,
                fileName=mesh_filename,
                meshScale=obj_scale
            )
            collision_shape_id = p.createCollisionShape(
                shapeType=p.GEOM_MESH,
                fileName=mesh_filename,
                meshScale=obj_scale
            )
            obj_uid = p.createMultiBody(
                baseMass=1.0,
                baseCollisionShapeIndex=collision_shape_id,
                baseVisualShapeIndex=visual_shape_id,
                basePosition=obj_pose[0],
                baseOrientation=obj_pose[1]
            )
            # SceneObject's 4th positional argument is the AABB; pass the uid by keyword.
            self.movable_objects.append(
                SceneObject(mesh_filename, obj_scale, obj_pose, p.getAABB(obj_uid), pb_uid=obj_uid)
            )

    def construct_scene_from_pybullet(
        self,
        movable_obj_uids: List[int],
        fixed_obj_uids: List[int],
        robot_uid: int = None,
        robot_urdf_path: str = None
    ) -> 'PybulletScene':
        """
        Constructs the scene by extracting object information from PyBullet.

        Args:
            movable_obj_uids (List[int]): List of movable object unique IDs.
            fixed_obj_uids (List[int]): List of fixed object unique IDs.
            robot_uid (int, optional): Unique ID of the robot.
            robot_urdf_path (str, optional): URDF path stored on the Robots record.

        Returns:
            PybulletScene: An instance containing the scene information.
        """
        self.robot_urdf_path = robot_urdf_path
        self.movable_objects = self._snapshot_objects(movable_obj_uids)
        self.fixed_objects = self._snapshot_objects(fixed_obj_uids)

        self.robot_uid = robot_uid
        if robot_uid is not None:
            base_pos, base_quat = p.getBasePositionAndOrientation(robot_uid)
            number_of_joints = p.getNumJoints(robot_uid)
            # Joint 0 is skipped, matching Robots.load_in_pb which writes q[3:] to joints 1, 2, ...
            joint_states = p.getJointStates(robot_uid, range(1, number_of_joints))
            joint_q = np.array([joint[0] for joint in joint_states])
            # The rotation vector's z component equals the yaw only for a pure z-rotation (level base).
            base_yaw = sciR.from_quat(base_quat).as_rotvec()[2]
            robot_q = np.concatenate([base_pos[:2], [base_yaw], joint_q])
            self.robot = Robots(robot_urdf_path=self.robot_urdf_path, robot_uid=robot_uid, q=robot_q, robot_height=base_pos[2])
        else:
            self.robot = None

        return PybulletScene(self.movable_objects, self.fixed_objects, self.robot)

    def convert_scene_to_oriCORNs(self, models):
        """
        Converts the entire scene to oriCORN representations.

        Movable objects are stacked along a new leading axis; fixed objects are
        flattened and concatenated along one outer axis. Empty groups become None.

        Args:
            models: An instance of the Models class.
        """
        movable_oriCORNs = [obj.convert_to_oriCORN(models) for obj in self.movable_objects]
        if movable_oriCORNs:
            self.movable_oriCORNs = jax.tree_util.tree_map(
                lambda *x: jnp.stack(x), *movable_oriCORNs
            )
        else:
            self.movable_oriCORNs = None

        fixed_oriCORNs = [obj.convert_to_oriCORN(models).reshape_outer_shape((-1,)) for obj in self.fixed_objects]
        if fixed_oriCORNs:
            self.fixed_oriCORNs = jax.tree_util.tree_map(
                lambda *x: jnp.concat(x), *fixed_oriCORNs
            )
            self.fixed_oriCORNs = self.fixed_oriCORNs.reshape_outer_shape((-1,))
        else:
            self.fixed_oriCORNs = None


def quaternion_multiply(q, r):
    """
    Return the Hamilton product ``q * r``.

    Both inputs and the result use scalar-last ``[x, y, z, w]`` ordering;
    the result is returned as a plain list.
    """
    ax, ay, az, aw = q
    bx, by, bz, bw = r
    return [
        aw*bx + ax*bw + ay*bz - az*by,
        aw*by - ax*bz + ay*bw + az*bx,
        aw*bz + ax*by - ay*bx + az*bw,
        aw*bw - ax*bx - ay*by - az*bz,
    ]

def create_random_obj(models, scene_center, xyz_range, table_uid=None, perform_stabilization=False, num_objects=2, seed=0, visualize=False)->List[int]:
    """
    Spawn randomly sized and posed mesh objects in the current PyBullet world.

    For each of ``num_objects`` objects, a mesh is drawn from
    ``models.asset_path_util.obj_paths`` and scaled so its largest AABB extent
    has a random size in [0.08, 0.3). It is rotated so that extent lies along +z
    and given a random yaw. Up to 10 positions are tried per object. A position
    is rejected if the new body touches any previously placed object or
    ``table_uid``. Objects that cannot be placed are skipped.

    Args:
        models: Object exposing ``asset_path_util.obj_paths`` (mesh file paths).
        scene_center: (x, y, z) offset added to every sampled position.
        xyz_range: x and y are sampled within +/- xyz_range[0] / xyz_range[1]
            of the center; z within [0, xyz_range[2]) above it.
        table_uid: Optional PyBullet body id that new objects must not touch.
        perform_stabilization: If True, simulate 4 s under gravity after placement.
        num_objects: Number of objects to try to place.
        seed: Seed for every random choice made here (mesh, scale, pose, color).
        visualize: When no PyBullet connection exists, connect with GUI instead
            of DIRECT. The connection is left open for the caller.

    Returns:
        PyBullet body ids of the placed objects whose base is still above z=0.1.
        As a side effect, world gravity is set to -9.81 along z.
    """
    # if no pb connection, connect
    if p.getConnectionInfo()['isConnected'] == 0:
        p.connect(p.GUI if visualize else p.DIRECT)

    np_rng = np.random.default_rng(seed)
    object_uids = []
    max_attempts = 10
    for _ in range(num_objects):
        mesh_filename = models.asset_path_util.obj_paths[
            np_rng.integers(low=0, high=len(models.asset_path_util.obj_paths))
        ]

        # Normalize the largest AABB extent to 1, then rescale to a random size.
        bounding_box = trimesh.load(mesh_filename).bounding_box.extents
        normalization_factor = 1 / np.max(bounding_box)
        random_scale = np_rng.uniform(0.08, 0.3)
        obj_scale = [normalization_factor * random_scale] * 3

        # The alignment depends only on the mesh, so compute it once per object:
        # rotate so the largest AABB extent points along +z.
        max_dim_index = np.argmax(bounding_box)
        if max_dim_index == 0:
            q_align = p.getQuaternionFromAxisAngle([0, 1, 0], -np.pi / 2)  # x -> z
        elif max_dim_index == 1:
            q_align = p.getQuaternionFromAxisAngle([1, 0, 0], np.pi / 2)  # y -> z
        else:
            q_align = [0, 0, 0, 1]  # already along z

        for _attempt in range(max_attempts):
            obj_position = [
                np_rng.uniform(-xyz_range[0], xyz_range[0]) + scene_center[0],  # x
                np_rng.uniform(-xyz_range[1], xyz_range[1]) + scene_center[1],  # y
                np_rng.uniform(0, xyz_range[2]) + scene_center[2],              # z (above the table)
            ]
            # Random yaw applied after the alignment rotation.
            q_random = p.getQuaternionFromEuler([0, 0, np_rng.uniform(0, 2 * np.pi)])
            obj_orientation = quaternion_multiply(q_random, q_align)

            visual_shape_id = p.createVisualShape(
                shapeType=p.GEOM_MESH,
                fileName=mesh_filename,
                meshScale=obj_scale
            )
            collision_shape_id = p.createCollisionShape(
                shapeType=p.GEOM_MESH,
                fileName=mesh_filename,
                meshScale=obj_scale
            )
            obj_uid = p.createMultiBody(
                baseMass=1.0,
                baseCollisionShapeIndex=collision_shape_id,
                baseVisualShapeIndex=visual_shape_id,
                basePosition=obj_position,
                baseOrientation=obj_orientation
            )

            obstacles = (object_uids + [table_uid]) if table_uid is not None else object_uids
            if not any(p.getClosestPoints(obj_uid, other_uid, distance=0) for other_uid in obstacles):
                object_uids.append(obj_uid)
                break
            # Overlaps something: discard and resample the pose.
            p.removeBody(obj_uid)
        else:
            print(f"Could not place object {mesh_filename} without collision after {max_attempts} attempts")

    # Stabilize objects on the table
    p.setGravity(0, 0, -9.81)
    if perform_stabilization:
        num_steps = int(4.0 / p.getPhysicsEngineParameters()['fixedTimeStep'])
        for _ in range(num_steps):
            p.stepSimulation()

    # Remove objects that have fallen off the table
    fallen_uids = [uid for uid in object_uids if p.getBasePositionAndOrientation(uid)[0][2] < 0.1]
    for uid in fallen_uids:
        p.removeBody(uid)
    object_uids = [uid for uid in object_uids if uid not in fallen_uids]

    # Random colors come from the seeded generator so the whole scene is reproducible.
    for obj_uid in object_uids:
        color = np_rng.uniform(0, 1, 3)
        p.changeVisualShape(obj_uid, -1, rgbaColor=color.tolist() + [1])

    return object_uids

def create_multibody(mesh_path, position, quaternion, scale, collision_path = None):
    """
    Create a static (zero-mass) PyBullet body from a mesh file.

    Args:
        mesh_path: Mesh used for the visual shape, and for collision unless
            ``collision_path`` is given.
        position: Base position passed to ``createMultiBody``.
        quaternion: Base orientation in PyBullet's [x, y, z, w] order.
        scale: Uniform scale applied on all three axes.
        collision_path: Optional separate mesh for the collision shape.

    Returns:
        The PyBullet body unique id.
    """
    mesh_scale = [scale] * 3
    visual_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=mesh_path,
        meshScale=mesh_scale,
    )
    collision_file = collision_path if collision_path is not None else mesh_path
    collision_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=collision_file,
        meshScale=mesh_scale,
    )
    return p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=collision_id,
        baseVisualShapeIndex=visual_id,
        basePosition=position,
        baseOrientation=quaternion,
    )


def create_table_sampled_scene(models=None, num_objects=15, seed=0, perform_stabilization=True, visualize=False):
    '''
    Build a random capped table-top scene in PyBullet and encode it as oriCORNs.

    Objects are placed one at a time above a table without touching it or each
    other. After each placement the world is optionally simulated for 4 s, and
    any object whose base ends below z=0.1 is removed. Sampling continues until
    ``num_objects + 3`` candidates are on the table. An upper table is then added;
    candidates touching it, and any beyond the first ``num_objects`` survivors,
    are removed.

    Args:
        models: Models instance; loaded from pretrained weights when None.
        num_objects: Number of movable objects in the final scene.
        seed: Seed for mesh, scale and pose sampling.
        perform_stabilization: Simulate 4 s under gravity after each placement.
        visualize: Connect in GUI mode (when not already connected) and render
            the reconstructed scene meshes before returning.

    Returns:
        (movable_oriCORNs, fixed_oriCORNs, pybullet_scene), returned whether or
        not ``visualize`` is set. The fixed objects are [table, upper table].
        ``movable_oriCORNs`` is None when ``num_objects`` is 0.

    Raises:
        AssertionError: if fewer than ``num_objects`` objects survive the
            upper-table check.

    If this function opens the PyBullet connection, it also closes it before
    returning, including on error.
    '''
    scene_center = [0, 0.6, -0.4]  # Center of the scene
    np_rng = np.random.default_rng(seed)

    if models is None:
        models = mutil.Models().load_pretrained_models()

    table_mesh_filename = models.asset_path_util.get_obj_path_from_rel_path('kitchen/table/big_table_1.obj')

    connected_inside = p.getConnectionInfo()['isConnected'] == 0
    if connected_inside:
        p.connect(p.GUI if visualize else p.DIRECT)

    try:
        table_uid = create_multibody(
            mesh_path=table_mesh_filename,
            position=scene_center + np.array([0, 0, 0.4]),
            quaternion=[0, 0, 0, 1],
            scale=1,
        )

        # Gravity and the step count do not change between placements.
        p.setGravity(0, 0, -9.81)
        num_steps = int(4.0 / p.getPhysicsEngineParameters()['fixedTimeStep'])

        object_uids = []
        max_attempts = 100
        # Oversample by 3 so candidates clipped by the upper table can be discarded.
        target_count = num_objects + 3 if num_objects > 0 else 0
        while len(object_uids) < target_count:
            mesh_filename = models.asset_path_util.obj_paths[
                np_rng.integers(low=0, high=len(models.asset_path_util.obj_paths))
            ]

            # Normalize the largest AABB extent to 1, then rescale to a random size.
            bounding_box = trimesh.load(mesh_filename).bounding_box.extents
            normalization_factor = 1 / np.max(bounding_box)
            random_scale = np_rng.uniform(0.15, 0.3)
            obj_scale = [normalization_factor * random_scale] * 3

            # Rotate so the largest AABB extent points along +z (fixed per mesh).
            max_dim_index = np.argmax(bounding_box)
            if max_dim_index == 0:
                q_align = p.getQuaternionFromAxisAngle([0, 1, 0], -np.pi / 2)  # x -> z
            elif max_dim_index == 1:
                q_align = p.getQuaternionFromAxisAngle([1, 0, 0], np.pi / 2)  # y -> z
            else:
                q_align = [0, 0, 0, 1]  # already along z

            for _attempt in range(max_attempts):
                obj_position = [
                    np_rng.uniform(-1.0, 1.0) + scene_center[0],  # x
                    np_rng.uniform(-0.4, 0.4) + scene_center[1],  # y
                    np_rng.uniform(0.8, 1.4) + scene_center[2],   # z (above the table)
                ]
                q_random = p.getQuaternionFromEuler([0, 0, np_rng.uniform(0, 2 * np.pi)])
                obj_orientation = quaternion_multiply(q_random, q_align)

                visual_shape_id = p.createVisualShape(
                    shapeType=p.GEOM_MESH,
                    fileName=mesh_filename,
                    meshScale=obj_scale
                )
                collision_shape_id = p.createCollisionShape(
                    shapeType=p.GEOM_MESH,
                    fileName=mesh_filename,
                    meshScale=obj_scale
                )
                obj_uid = p.createMultiBody(
                    baseMass=1.0,
                    baseCollisionShapeIndex=collision_shape_id,
                    baseVisualShapeIndex=visual_shape_id,
                    basePosition=obj_position,
                    baseOrientation=obj_orientation
                )

                if not any(p.getClosestPoints(obj_uid, other_uid, distance=0)
                           for other_uid in object_uids + [table_uid]):
                    object_uids.append(obj_uid)
                    break
                p.removeBody(obj_uid)
            else:
                print(f"Could not place object {mesh_filename} without collision after {max_attempts} attempts")

            # Let the newly placed object settle.
            if perform_stabilization:
                for _ in range(num_steps):
                    p.stepSimulation()

            # Remove objects that have fallen off the table
            for obj_uid in [uid for uid in object_uids if p.getBasePositionAndOrientation(uid)[0][2] < 0.1]:
                p.removeBody(obj_uid)
                object_uids.remove(obj_uid)

        # Create an upper table to cap the scene
        table_cover_uid = create_multibody(
            mesh_path=table_mesh_filename,
            position=scene_center + np.array([0, 0, 0.9]),
            quaternion=[0, 0, 0, 1],
            scale=1,
        )

        # Keep the first num_objects candidates that do not touch the upper table.
        kept_uids = []
        for obj_uid in object_uids:
            if len(kept_uids) >= num_objects or p.getClosestPoints(obj_uid, table_cover_uid, distance=0):
                p.removeBody(obj_uid)
            else:
                kept_uids.append(obj_uid)
        object_uids = kept_uids
        assert len(object_uids) == num_objects, (
            f"only {len(object_uids)} of {num_objects} objects survived the upper-table check"
        )

        scene_converter = SceneConverter()
        pybullet_scene = scene_converter.construct_scene_from_pybullet(
            object_uids, [table_uid, table_cover_uid], robot_uid=None
        )
        scene_converter.convert_scene_to_oriCORNs(models)

        if visualize:
            scene_objs = scene_converter.fixed_oriCORNs
            if scene_converter.movable_oriCORNs is not None:
                scene_objs = scene_converter.movable_oriCORNs.concat(scene_objs, axis=0)
            from util.reconstruction_util import create_scene_mesh_from_oriCORNs
            dec = jax.jit(partial(models.apply, 'sdf_decoder'))
            create_scene_mesh_from_oriCORNs(scene_objs, dec=dec, visualize=True)

        return scene_converter.movable_oriCORNs, scene_converter.fixed_oriCORNs, pybullet_scene
    finally:
        if connected_inside:
            p.disconnect()

def create_shelf_scene(models=None, num_objects=15, seed=0, perform_stabilization=True, visualize=False):
    '''
    Build a random capped scene (table, upper table, wall) and encode it as oriCORNs.

    Up to ``num_objects`` random objects are placed above the lower table
    without touching it, the wall, or each other. An object that cannot be
    placed in 10 tries is skipped. The world is optionally simulated for 4 s,
    and any object whose base ends below z=0.1 is removed. An upper table is then
    added, and objects touching it are removed as well.

    Args:
        models: Models instance; loaded from pretrained weights when None.
        num_objects: Number of objects to try to place.
        seed: Seed for mesh, scale and pose sampling.
        perform_stabilization: Simulate 4 s under gravity after all placements.
        visualize: Connect in GUI mode (when not already connected) and render
            the reconstructed scene meshes before returning.

    Returns:
        (movable_oriCORNs, fixed_oriCORNs, pybullet_scene), returned whether or
        not ``visualize`` is set. The fixed objects are [table, upper table, wall].
        ``movable_oriCORNs`` is None when no object survives.

    The PyBullet connection is opened only if none exists. In that case it is
    closed before returning, including on error, so an existing caller
    connection is never torn down.
    '''
    scene_center = [0, 0.6, -0.4]  # Center of the scene
    np_rng = np.random.default_rng(seed)

    if models is None:
        models = mutil.Models().load_pretrained_models()
    models.asset_path_util.enrol_or_download_assets()

    table_mesh_filename = models.asset_path_util.get_obj_path_from_rel_path('kitchen/table/big_table_1.obj')

    connected_inside = p.getConnectionInfo()['isConnected'] == 0
    if connected_inside:
        p.connect(p.GUI if visualize else p.DIRECT)

    try:
        table_uid = create_multibody(
            mesh_path=table_mesh_filename,
            position=scene_center + np.array([0, 0, 0.4]),
            quaternion=[0, 0, 0, 1],
            scale=1
        )
        wall_uid = create_multibody(
            mesh_path=models.asset_path_util.get_obj_path_from_rel_path('GoogleScannedObjects/modified/Mario_Party_9_Wii_Game.obj'),
            position=scene_center + np.array([0, 0, 0.4]),
            quaternion=[0, 0, 0, 1],
            scale=1
        )
        # New objects must not start inside any fixed body.
        obstacle_uids = [table_uid, wall_uid]

        object_uids = []
        max_attempts = 10
        for _ in range(num_objects):
            mesh_filename = models.asset_path_util.obj_paths[
                np_rng.integers(low=0, high=len(models.asset_path_util.obj_paths))
            ]

            # Normalize the largest AABB extent to 1, then rescale to a random size.
            bounding_box = trimesh.load(mesh_filename).bounding_box.extents
            normalization_factor = 1 / np.max(bounding_box)
            random_scale = np_rng.uniform(0.08, 0.3)
            obj_scale = [normalization_factor * random_scale] * 3

            # Rotate so the largest AABB extent points along +z (fixed per mesh).
            max_dim_index = np.argmax(bounding_box)
            if max_dim_index == 0:
                q_align = p.getQuaternionFromAxisAngle([0, 1, 0], -np.pi / 2)  # x -> z
            elif max_dim_index == 1:
                q_align = p.getQuaternionFromAxisAngle([1, 0, 0], np.pi / 2)  # y -> z
            else:
                q_align = [0, 0, 0, 1]  # already along z

            for _attempt in range(max_attempts):
                obj_position = [
                    np_rng.uniform(-1.0, 1.0) + scene_center[0],  # x
                    np_rng.uniform(-0.4, 0.4) + scene_center[1],  # y
                    np_rng.uniform(0.8, 1.4) + scene_center[2],   # z (above the table)
                ]
                q_random = p.getQuaternionFromEuler([0, 0, np_rng.uniform(0, 2 * np.pi)])
                obj_orientation = quaternion_multiply(q_random, q_align)

                visual_shape_id = p.createVisualShape(
                    shapeType=p.GEOM_MESH,
                    fileName=mesh_filename,
                    meshScale=obj_scale
                )
                collision_shape_id = p.createCollisionShape(
                    shapeType=p.GEOM_MESH,
                    fileName=mesh_filename,
                    meshScale=obj_scale
                )
                obj_uid = p.createMultiBody(
                    baseMass=1.0,
                    baseCollisionShapeIndex=collision_shape_id,
                    baseVisualShapeIndex=visual_shape_id,
                    basePosition=obj_position,
                    baseOrientation=obj_orientation
                )

                if not any(p.getClosestPoints(obj_uid, other_uid, distance=0)
                           for other_uid in object_uids + obstacle_uids):
                    object_uids.append(obj_uid)
                    break
                p.removeBody(obj_uid)
            else:
                print(f"Could not place object {mesh_filename} without collision after {max_attempts} attempts")

        # Stabilize objects on the table
        p.setGravity(0, 0, -9.81)
        if perform_stabilization:
            num_steps = int(4.0 / p.getPhysicsEngineParameters()['fixedTimeStep'])
            for _ in range(num_steps):
                p.stepSimulation()

        # Remove objects that have fallen off the table
        fallen_uids = [uid for uid in object_uids if p.getBasePositionAndOrientation(uid)[0][2] < 0.1]
        for uid in fallen_uids:
            p.removeBody(uid)
        object_uids = [uid for uid in object_uids if uid not in fallen_uids]

        # Create an upper table to cap the scene
        table_cover_uid = create_multibody(
            mesh_path=table_mesh_filename,
            position=scene_center + np.array([0, 0, 0.9]),
            quaternion=[0, 0, 0, 1],
            scale=1
        )

        # Remove objects that intersect the upper table
        kept_uids = []
        for obj_uid in object_uids:
            if p.getClosestPoints(obj_uid, table_cover_uid, distance=0):
                p.removeBody(obj_uid)
            else:
                kept_uids.append(obj_uid)
        object_uids = kept_uids

        scene_converter = SceneConverter()
        pybullet_scene = scene_converter.construct_scene_from_pybullet(
            object_uids, [table_uid, table_cover_uid, wall_uid], robot_uid=None
        )
        scene_converter.convert_scene_to_oriCORNs(models)

        if visualize:
            scene_objs = scene_converter.fixed_oriCORNs
            if scene_converter.movable_oriCORNs is not None:
                scene_objs = scene_converter.movable_oriCORNs.concat(scene_objs, axis=0)
            from util.reconstruction_util import create_scene_mesh_from_oriCORNs
            dec = jax.jit(partial(models.apply, 'sdf_decoder'))
            create_scene_mesh_from_oriCORNs(scene_objs, dec=dec, visualize=True)

        return scene_converter.movable_oriCORNs, scene_converter.fixed_oriCORNs, pybullet_scene
    finally:
        if connected_inside:
            p.disconnect()


def create_pick_place_scene(models, visualize=False, num_objects=1, seed=None):
    '''
    Build a two-table pick-and-place scene in PyBullet and encode it as oriCORNs.

    Two identical tables are created 3 m apart along x. The first, at the scene
    center, is the start region; the second is the empty goal region. Up to
    ``num_objects`` random objects are sampled around the start table without
    touching either table or each other. They are simulated under gravity for
    4 s, and any object whose base ends below z=0.1 is removed.

    Args:
        models: Models instance providing ``asset_path_util`` and ``apply``.
        visualize: Connect in GUI mode (when not already connected) and render
            the reconstructed scene meshes before returning.
        num_objects: Number of objects to try to place.
        seed: If given, all sampling uses ``np.random.default_rng(seed)``.
            If None (default), the global ``np.random`` state is used.

    Returns:
        (fixed_oriCORNs, movable_oriCORNs, pybullet_scene), returned whether or
        not ``visualize`` is set. Note the fixed objects come first here, unlike
        the table/shelf scene builders. ``movable_oriCORNs`` is None when no
        object survives.

    The PyBullet connection is opened only if none exists. In that case it is
    closed before returning, including on error.
    '''
    scene_center = [0, 0.6, -0.4]  # Center of the scene

    if seed is None:
        randint, uniform = np.random.randint, np.random.uniform
    else:
        rng = np.random.default_rng(seed)
        randint, uniform = rng.integers, rng.uniform

    table_mesh_filename = models.asset_path_util.get_obj_path_from_rel_path('kitchen/table/big_table_1.obj')

    connected_inside = p.getConnectionInfo()['isConnected'] == 0
    if connected_inside:
        p.connect(p.GUI if visualize else p.DIRECT)

    try:
        # table_uids[0] is the start region, table_uids[1] the goal region.
        table_uids = [
            create_multibody(
                mesh_path=table_mesh_filename,
                position=scene_center + np.array(offset),
                quaternion=[0, 0, 0, 1],
                scale=1,
            ) for offset in ([0, 0, 0.4], [3, 0, 0.4])
        ]

        object_uids = []
        max_attempts = 10
        for _ in range(num_objects):
            mesh_filename = models.asset_path_util.obj_paths[
                randint(len(models.asset_path_util.obj_paths))
            ]

            # Normalize the largest AABB extent to 1, then rescale to a random size.
            bounding_box = trimesh.load(mesh_filename).bounding_box.extents
            normalization_factor = 1 / np.max(bounding_box)
            random_scale = uniform(0.08, 0.3)
            obj_scale = [normalization_factor * random_scale] * 3

            # Rotate so the largest AABB extent points along +z (fixed per mesh).
            max_dim_index = np.argmax(bounding_box)
            if max_dim_index == 0:
                q_align = p.getQuaternionFromAxisAngle([0, 1, 0], -np.pi / 2)  # x -> z
            elif max_dim_index == 1:
                q_align = p.getQuaternionFromAxisAngle([1, 0, 0], np.pi / 2)  # y -> z
            else:
                q_align = [0, 0, 0, 1]  # already along z

            for _attempt in range(max_attempts):
                obj_position = [
                    uniform(-1.0, 1.0) + scene_center[0],  # x
                    uniform(-0.4, 0.4) + scene_center[1],  # y
                    uniform(0.8, 1.4) + scene_center[2],   # z (above the table)
                ]
                q_random = p.getQuaternionFromEuler([0, 0, uniform(0, 2 * np.pi)])
                obj_orientation = quaternion_multiply(q_random, q_align)

                visual_shape_id = p.createVisualShape(
                    shapeType=p.GEOM_MESH,
                    fileName=mesh_filename,
                    meshScale=obj_scale
                )
                collision_shape_id = p.createCollisionShape(
                    shapeType=p.GEOM_MESH,
                    fileName=mesh_filename,
                    meshScale=obj_scale
                )
                obj_uid = p.createMultiBody(
                    baseMass=1.0,
                    baseCollisionShapeIndex=collision_shape_id,
                    baseVisualShapeIndex=visual_shape_id,
                    basePosition=obj_position,
                    baseOrientation=obj_orientation
                )

                if not any(p.getClosestPoints(obj_uid, other_uid, distance=0)
                           for other_uid in object_uids + table_uids):
                    object_uids.append(obj_uid)
                    break
                p.removeBody(obj_uid)
            else:
                print(f"Could not place object {mesh_filename} without collision after {max_attempts} attempts")

        # Stabilize objects on the table
        p.setGravity(0, 0, -9.81)
        num_steps = int(4.0 / p.getPhysicsEngineParameters()['fixedTimeStep'])
        for _ in range(num_steps):
            p.stepSimulation()

        # Remove objects that have fallen off the table
        fallen_uids = [uid for uid in object_uids if p.getBasePositionAndOrientation(uid)[0][2] < 0.1]
        for uid in fallen_uids:
            p.removeBody(uid)
        object_uids = [uid for uid in object_uids if uid not in fallen_uids]

        scene_converter = SceneConverter()
        pybullet_scene = scene_converter.construct_scene_from_pybullet(
            object_uids, table_uids, robot_uid=None
        )
        scene_converter.convert_scene_to_oriCORNs(models)

        if visualize:
            scene_objs = scene_converter.fixed_oriCORNs
            if scene_converter.movable_oriCORNs is not None:
                scene_objs = scene_converter.movable_oriCORNs.concat(scene_objs, axis=0)
            from util.reconstruction_util import create_scene_mesh_from_oriCORNs
            # Use the same SDF decoder the other scene builders pass in.
            dec = jax.jit(partial(models.apply, 'sdf_decoder'))
            create_scene_mesh_from_oriCORNs(scene_objs, dec=dec, visualize=True)

        return scene_converter.fixed_oriCORNs, scene_converter.movable_oriCORNs, pybullet_scene
    finally:
        if connected_inside:
            p.disconnect()



def reset_shakey(q, shakey_pb_uid, shakey):
    """
    Teleport the Shakey robot in PyBullet to configuration ``q``.

    ``q`` is laid out as ``[x, y, theta, joint_1, ..., joint_n]``. The base is
    placed at ``(x, y, shakey.robot_height)`` with yaw ``theta``. Each
    ``joint_k`` is written to PyBullet joint index ``k``, so joint 0 is left
    untouched.
    """
    base_z = shakey.robot_height
    base_x, base_y, base_yaw, *arm_q = q
    base_quat = p.getQuaternionFromEuler([0, 0, base_yaw])
    p.resetBasePositionAndOrientation(shakey_pb_uid, [base_x, base_y, base_z], base_quat)

    for joint_index, joint_value in enumerate(arm_q, start=1):
        p.resetJointState(shakey_pb_uid, joint_index, joint_value)


def simulate_pick_and_place(pybullet_scene:PybulletScene, trajs, obj_uid_in_hands, shakey_pb_uid, shakey, sleep=0.1):
    """
    Replay a multi-segment robot trajectory in PyBullet, carrying held objects.

    ``trajs`` is split along axis 0 into ``len(obj_uid_in_hands)`` equal
    segments; ``np.split`` raises ValueError if it does not divide evenly. For
    segment k, ``obj_uid_in_hands[k]`` is an index into
    ``pybullet_scene.movable_objects``, or a negative value when nothing is held.
    While a segment plays, the held object is teleported each waypoint so it keeps
    the pose it had relative to the end effector at the segment's first waypoint.

    Args:
        pybullet_scene: Scene whose movable objects are first reset to their
            stored ``base_pose``. The scene object itself is not modified;
            deep copies are used.
        trajs: Robot configurations ``[x, y, theta, *joints]`` stacked on axis 0.
        obj_uid_in_hands: Per-segment held-object index (negative = none).
        shakey_pb_uid: PyBullet body id of the robot.
        shakey: Robot model providing ``robot_height``, ``FK``, ``ee_idx`` and
            ``set_q_pb``.
        sleep: Seconds to sleep after each waypoint (for viewing).
    """
    import time
    robot_height = shakey.robot_height

    def reset_q_pb(q):
        """Teleport the robot base and joints to q = [x, y, theta, *joints]."""
        x, y, theta, *joint_state = q
        quaternion = p.getQuaternionFromEuler([0, 0, theta])
        p.resetBasePositionAndOrientation(shakey_pb_uid, [x, y, robot_height], quaternion)
        shakey.set_q_pb(shakey_pb_uid, np.array(joint_state))

    # Work on copies so the caller's scene description is left unchanged.
    movable_objects = copy.deepcopy(pybullet_scene.movable_objects)
    for movable_object in movable_objects:
        p.resetBasePositionAndOrientation(movable_object.pb_uid, movable_object.base_pose[0], movable_object.base_pose[1])

    traj_segments = np.split(trajs, len(obj_uid_in_hands), axis=0)
    for traj_segment, obj_id in tqdm(zip(traj_segments, obj_uid_in_hands), total=len(traj_segments)):
        ee_pose = shakey.FK(traj_segment, False)[:, -1, :]
        initial_ee_pose = ee_pose[0]
        uid = None
        constraint_id = None
        ee_to_obj = None
        if obj_id > -1:
            uid = movable_objects[obj_id].pb_uid
            obj_pose = p.getBasePositionAndOrientation(uid)
            constraint_id = p.createConstraint(parentBodyUniqueId=shakey_pb_uid, parentLinkIndex=shakey.ee_idx,
                                               childBodyUniqueId=uid, childLinkIndex=-1,
                                               jointType=p.JOINT_FIXED,
                                               jointAxis=[0, 0, 0],
                                               parentFramePosition=[0, 0, 0],
                                               childFramePosition=[0, 0, 0])
            # Object pose in the initial end-effector frame; constant for the whole segment,
            # so it is computed once instead of at every waypoint.
            ee_to_obj = tutil.pq_multi(
                *tutil.pq_inv(initial_ee_pose[:3], initial_ee_pose[3:]),
                np.array(obj_pose[0]), np.array(obj_pose[1]),
            )

        for q, ee_pose_i in zip(traj_segment, ee_pose):
            reset_q_pb(q)
            if uid is not None:
                changed_obj_pose = tutil.pq_multi(ee_pose_i[:3], ee_pose_i[3:], *ee_to_obj)
                p.resetBasePositionAndOrientation(uid, changed_obj_pose[0], changed_obj_pose[1])
            time.sleep(sleep)

        # After releasing, remove the constraint and store the object's last pose
        if constraint_id is not None:
            p.removeConstraint(constraint_id)
            movable_objects[obj_id].base_pose = p.getBasePositionAndOrientation(uid)


def create_cleaning_scene(models=None, num_objects=15, seed=0, shakey=None, robot_pb_uid=None, fk=None, ik=None, visualize=False):
    """Build a table-and-chairs scene with a single mug rigidly attached to the robot.

    The scene is produced by rejection sampling in two stages:

    1. Sample a start configuration ``init_q`` uniformly within the joint
       bounds. Accept it only if all of the following hold:
       - the robot has no contact points;
       - its end-effector position lies inside the table's AABB;
       - the mug, placed at ``ee_to_obj`` relative to the end effector,
         does not touch the table or any chair.
    2. Sample a mug goal pose and solve the matching end-effector pose with
       ``ik``. The goal position is drawn over the table's xy extent, between
       the table top and one table-height above it. Accept the solution only
       if both of the following hold:
       - ``fk`` reproduces the target within 0.01 in position and 0.01 in
         sign-invariant quaternion distance;
       - the robot does not touch the table or any chair.

    Both stages retry until they succeed.

    Args:
        models: Model bundle providing ``asset_path_util``; required.
        num_objects: Unused; kept for interface compatibility.
        seed: Seed for both the NumPy generator and the JAX PRNG key.
        shakey: Robot helper exposing ``set_q_pb``; required.
        robot_pb_uid: PyBullet body id of the robot.
        fk: Forward kinematics ``q -> pose``; its output is indexed here as
            ``[:3]`` position and ``[3:]`` quaternion.
        ik: Inverse kinematics ``(seed_q, (pos, quat)) -> array``; entries
            from index 3 onward are used as the joint solution.
        visualize: Use the PyBullet GUI if this call opens the connection.

    Returns:
        ``(movable_oriCORNs, fixed_oriCORNs, pybullet_scene, init_q, goal_q,
        ee_to_obj)``. The mug is at its sampled goal pose when the scene is
        captured. If this call opened the PyBullet connection, it is closed
        before returning.
    """
    assert shakey is not None, "Shakey is required for cleaning scene"
    assert models is not None, "models is required for cleaning scene"

    connected_inside = False
    if p.getConnectionInfo()['isConnected'] == 0:
        connected_inside = True
        p.connect(p.GUI if visualize else p.DIRECT)

    np_rng = np.random.default_rng(seed)
    jkey = jax.random.PRNGKey(seed)
    scene_center = np.array([0.25, 0.6, -0.4])  # Center of the scene
    table_uid = create_multibody(
        mesh_path = models.asset_path_util.get_obj_path_from_rel_path('kitchen/table/big_table_1.obj'),
        position = scene_center + np.array([0, 0, 0.4]),
        quaternion = [0, 0, 0, 1],
        scale = 1.0,
    )
    chair_uids = [
        create_multibody(
            mesh_path = models.asset_path_util.get_obj_path_from_rel_path('kitchen/jokkmokk/JokkmokkChair.obj'),
            position = position,
            quaternion = quaternion,
            scale = 1.0,
        ) for position, quaternion in [
            (scene_center + np.array([-0.28, 0.3, 0.6]), sciR.from_euler('xyz', [0, 0, -np.pi/2]).as_quat()),
            (scene_center + np.array([0.28, 0.3, 0.6]), sciR.from_euler('xyz', [0, 0, -np.pi/2]).as_quat()),
            (scene_center + np.array([0.28, -0.4, 0.6]), sciR.from_euler('xyz', [0, 0, np.pi/2]).as_quat()),
        ]
    ]
    table_aabb = p.getAABB(table_uid)
    env_uids = [table_uid] + chair_uids

    # Joint bounds for sampling start configurations; their midpoint seeds IK.
    lower_bound = np.array([-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -3.0718])
    upper_bound = np.array([2.8973, 1.7628, 2.8973, 0.0698, 2.8973, 0.0698])

    # Fixed mug pose in the end-effector frame: (position, scipy xyzw quaternion).
    ee_to_obj = (
        np.array([0.03, 0, 0.29]),
        sciR.from_euler('xyz', [-np.pi/2, 0, 0]).as_quat(),
    )

    def _touches_env(body_uid):
        # True if body_uid is within zero distance of the table or any chair.
        return any(p.getClosestPoints(body_uid, env_uid, distance=0) for env_uid in env_uids)

    while True:
        # Split first and sample with the fresh subkey, so that no key is used
        # both for drawing samples and for deriving later keys.
        jkey, subkey = jax.random.split(jkey)
        init_q = jax.random.uniform(subkey, (6,), minval=lower_bound, maxval=upper_bound)

        shakey.set_q_pb(robot_pb_uid, init_q)
        p.performCollisionDetection()
        if len(p.getContactPoints(robot_pb_uid)) != 0:
            continue

        ee_pqc = fk(init_q)
        ee_position = ee_pqc[:3]
        # The end effector must lie inside the table's axis-aligned bounding box.
        inside_table = all(
            table_aabb[0][axis] <= ee_position[axis] <= table_aabb[1][axis]
            for axis in range(3)
        )
        if not inside_table:
            continue

        object_pqc = tutil.pq_multi(
            ee_pqc[:3],
            ee_pqc[3:],
            ee_to_obj[0],
            ee_to_obj[1],
        )
        moving_obj_uid = create_multibody(
            mesh_path = models.asset_path_util.get_obj_path_from_rel_path('NOCS/modified/mug-46ed9dad0440c043d33646b0990bb4a.obj'),
            position = object_pqc[0],
            quaternion = object_pqc[1],
            scale = 0.3,
        )
        p.performCollisionDetection()
        if _touches_env(moving_obj_uid):
            # Remove the object and try again
            p.removeBody(moving_obj_uid)
            continue

        # Sample a goal pose for the mug and a matching collision-free robot configuration.
        # NOTE(review): only the robot is checked against the environment at the
        # goal; the mug's goal pose itself may intersect the table. Confirm
        # whether that is intended.
        while True:
            object_goal_pq = (
                np.array([
                    np_rng.uniform(table_aabb[0][0], table_aabb[1][0]),
                    np_rng.uniform(table_aabb[0][1], table_aabb[1][1]),
                    np_rng.uniform(table_aabb[1][2], table_aabb[1][2] + (table_aabb[1][2] - table_aabb[0][2])),
                ]),
                sciR.from_euler(
                    'xyz',
                    [
                        np_rng.uniform(-np.pi, np.pi),
                        np_rng.uniform(-np.pi, np.pi),
                        np_rng.uniform(-np.pi, np.pi),
                    ]
                ).as_quat(),
            )
            p.resetBasePositionAndOrientation(moving_obj_uid, object_goal_pq[0], object_goal_pq[1])
            # End-effector pose that puts the held mug at the goal pose.
            ee_pq = tutil.pq_multi(
                object_goal_pq[0],
                object_goal_pq[1],
                *tutil.pq_inv(
                    ee_to_obj[0],
                    ee_to_obj[1],
                ),
            )
            ik_q = ik((lower_bound + upper_bound) / 2, ee_pq)
            goal_q = ik_q[3:]
            shakey.set_q_pb(robot_pb_uid, goal_q)

            converged_ee_pq = np.asarray(fk(goal_q))
            pos_err = np.linalg.norm(converged_ee_pq[:3] - np.asarray(ee_pq[0]))
            # q and -q encode the same rotation, so measure the distance to the
            # closer sign. A raw difference rejects correct IK solutions whose
            # quaternion happens to come back negated.
            quat_conv = converged_ee_pq[3:]
            quat_goal = np.asarray(ee_pq[1])
            quat_err = min(np.linalg.norm(quat_conv - quat_goal), np.linalg.norm(quat_conv + quat_goal))
            if pos_err > 0.01 or quat_err > 0.01:
                continue

            p.performCollisionDetection()
            if _touches_env(robot_pb_uid):
                continue
            break
        break

    # Construct the scene
    movable_obj_uids = [moving_obj_uid]
    fixed_obj_uids = env_uids
    scene_converter = SceneConverter()
    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        movable_obj_uids, fixed_obj_uids, robot_uid=None
    )

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)
    if connected_inside:
        p.disconnect()
    return scene_converter.movable_oriCORNs, scene_converter.fixed_oriCORNs, pybullet_scene, init_q, goal_q, ee_to_obj



def create_dish_scene(models:mutil.Models, seed, shakey:shakey_module.Shakey, ik_func, robot_pb_uid, visualize=True):
    '''Build a single-dish scene: a dish, a dish holder on a visual-only support box, and a floor.

    The dish is held in the gripper at offset ``ee_to_obj``. Its spawn pose
    next to the holder defines the placement goal.

    A start configuration is sampled around a nominal pose: the base joint is
    drawn from {-pi/3, -pi/2, pi/2}, and every joint gets +-0.1 rad of noise.
    Sampling repeats until neither the robot nor the dish has contacts.
    ``ik_func`` then solves for the placement configuration. If the result
    misses the target by more than 1e-3 (in position, or in sign-invariant
    quaternion distance), every body created here is removed and the
    function retries with ``seed + 1``.

    Asset source:
    https://app.gazebosim.org/GoogleResearch/fuel/collections/Scanned%20Objects%20by%20Google%20Research

    Args:
        models: Model bundle providing ``asset_path_util``.
        seed: Seed for the NumPy generator.
        shakey: Robot helper providing ``set_q_pb``, ``get_ee_pq_pb`` and
            ``gripper_tip_offset_from_ee``.
        ik_func: ``(seed_q, (pos, quat)) -> array``; entries from index 3
            onward are used as the joint solution.
        robot_pb_uid: PyBullet body id of the robot.
        visualize: Use the PyBullet GUI if no connection exists yet.

    Returns:
        ``(movable_oriCORNs, fixed_oriCORNs, pybullet_scene, init_q, goal_q,
        ee_to_obj, aux_pbids)``. ``aux_pbids`` holds the floor and the support
        box, which are not part of the oriCORN scene.
    '''

    if p.getConnectionInfo()['isConnected'] == 0:
        if visualize:
            p.connect(p.GUI)
        else:
            p.connect(p.DIRECT)

    # Fixed debug-camera viewpoint; hide the GUI side panels.
    p.resetDebugVisualizerCamera(cameraDistance=1.00, cameraYaw=-546.00, cameraPitch=-18.94, cameraTargetPosition=[0.07, 0.00, 0.01])
    p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)

    np_rng = np.random.default_rng(seed)
    # Jitter the whole arrangement around a nominal center.
    env_center = np.array([0,0.4,0.]) + np_rng.uniform([-0.2, -0.1, -0.15], [0.2, 0.1, 0.02])

    # Ground plane below the scene.
    floor_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=p.createCollisionShape(
            shapeType=p.GEOM_PLANE,
        ),
        baseVisualShapeIndex=p.createVisualShape(
            shapeType=p.GEOM_PLANE,
            rgbaColor=[0.8, 0.8, 0.8, 1.0],
        ),
        basePosition=[0, 0, -0.55],
        baseOrientation=[0, 0, 0, 1],
    )

    # Each entry: (relative mesh path, translation from env_center,
    # uniform mesh scale, z offset of the dish in the gripper frame).
    dish_options = [
        ('GoogleScannedObjects/modified/Ecoforms_Plant_Saucer_SQ8COR.obj', np.array([0.087, 0, 0.18]), 1.1, 0.09),
        ('GoogleScannedObjects/modified/Ecoforms_Plant_Saucer_S20MOCHA.obj', np.array([0.087, 0, 0.18]), 1.1, 0.09),
        ('GoogleScannedObjects/modified/Ecoforms_Saucer_SQ3_Turquoise.obj', np.array([0.087, 0, 0.14]), 1.3, 0.040),
        ('GoogleScannedObjects/modified/Threshold_Tray_Rectangle_Porcelain.obj', np.array([0.087, 0, 0.25]), 1.1, 0.17)
    ]

    dish_obj_filename, dish_translate, dish_scale, g2obj = dish_options[np_rng.integers(low=0, high=len(dish_options))]
    dish_obj_filename = models.asset_path_util.get_obj_path_from_rel_path(dish_obj_filename)
    # Collision mesh path: <mesh dir>/../cvx/8_32_v4/<same file name>.
    dish_cvx_filename = os.path.join(os.path.dirname(os.path.dirname(dish_obj_filename)), 'cvx/8_32_v4', os.path.basename(dish_obj_filename))

    dish_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=dish_obj_filename,
        meshScale=[dish_scale] * 3
    )
    dish_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=dish_cvx_filename,
        meshScale=[dish_scale] * 3
    )
    dish_position = env_center + dish_translate
    dish_orientation = (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()
    dish_uid = p.createMultiBody(
        baseMass=0.1,
        baseCollisionShapeIndex=dish_collision_shape_id,
        baseVisualShapeIndex=dish_visual_shape_id,
        basePosition=dish_position,
        baseOrientation=dish_orientation
    )

    dishholder_obj_filename = 'GoogleScannedObjects/modified/Poppin_File_Sorter_White.obj'
    dishholder_obj_filename = models.asset_path_util.get_obj_path_from_rel_path(dishholder_obj_filename)
    dishholder_cvx_filename = os.path.join(os.path.dirname(os.path.dirname(dishholder_obj_filename)), 'cvx/8_32_v4', os.path.basename(dishholder_obj_filename))

    dishholder_position = env_center + np.array([0,0,0.08])
    dishholder_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=dishholder_obj_filename,
        meshScale=[1.0, 1.0, 1.0]
    )
    dishholder_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=dishholder_cvx_filename,
        meshScale=[1.0, 1.0, 1.0]
    )
    dishholder_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=dishholder_collision_shape_id,
        baseVisualShapeIndex=dishholder_visual_shape_id,
        basePosition=dishholder_position,
        baseOrientation=[0, 0, 0, 1]
    )

    # Box under the dishholder. It is visual-only: no collision shape is attached.
    dishholder_support_position = dishholder_position + np.array([0, 0, -0.07-0.4])
    dishholder_support_half_extents = [0.12, 0.08, 0.4]
    dishholder_support_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_BOX,
        halfExtents=dishholder_support_half_extents
    )
    dishholder_support_uid = p.createMultiBody(
        baseMass=0.0,
        baseVisualShapeIndex=dishholder_support_visual_shape_id,
        basePosition=dishholder_support_position,
        baseOrientation=[0, 0, 0, 1]
    )

    # Park the dish at the origin before capturing the scene.
    # dish_position / dish_orientation still hold its placement pose for the goal below.
    p.resetBasePositionAndOrientation(dish_uid, [0,0,0], [0,0,0,1.])

    movable_obj_uids = [dish_uid]
    fixed_obj_uids = [dishholder_uid]

    scene_converter = SceneConverter()
    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        movable_obj_uids, fixed_obj_uids, robot_uid=None
    )

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)

    # Dish pose in the gripper frame, composed with the gripper-tip offset to
    # give the dish pose in the end-effector frame.
    gripper_to_obj_quat = (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()
    gripper_to_obj = (jnp.array([0,0,g2obj]), gripper_to_obj_quat)

    ee_to_obj = tutil.pq_multi(*shakey.gripper_tip_offset_from_ee, *gripper_to_obj)
    ee_to_obj_pqc = np.concat(ee_to_obj, axis=-1)

    # End-effector pose that puts the held dish at its placement pose.
    ee_pq = tutil.pq_multi(
                dish_position,
                dish_orientation,
                *tutil.pq_inv(
                    ee_to_obj[0],
                    ee_to_obj[1],
                ),
            )
    goal_pos = np.asarray(ee_pq[0])
    goal_quat = np.asarray(ee_pq[1])

    while True:
        init_q = np.array([-np.pi/4, -np.pi/2, -np.pi/2, -np.pi/2, np.pi/2, np.pi])
        init_q[0] = np_rng.choice([-np.pi/3, -np.pi/2, np.pi/2])
        init_q += np_rng.uniform(-0.1, 0.1, size=init_q.shape[-1])
        # Presumably also places [dish_uid] at ee_to_obj_pqc relative to the end effector — see Shakey.set_q_pb.
        shakey.set_q_pb(robot_pb_uid, init_q, [dish_uid], ee_to_obj_pqc)
        p.performCollisionDetection()
        if len(p.getContactPoints(robot_pb_uid)) != 0 or len(p.getContactPoints(dish_uid)) != 0:
            continue

        ik_q = ik_func(init_q, ee_pq)
        goal_q = ik_q[3:]
        shakey.set_q_pb(robot_pb_uid, goal_q, [dish_uid], ee_to_obj_pqc)

        ee_pqc_ik = np.asarray(shakey.get_ee_pq_pb(robot_pb_uid))
        quat_ik = ee_pqc_ik[..., 3:]
        pos_err = np.linalg.norm(ee_pqc_ik[..., :3] - goal_pos)
        # q and -q are the same rotation: compare against the closer sign.
        # Otherwise a correct IK solution whose quaternion comes back negated
        # throws the whole scene away.
        quat_err = min(np.linalg.norm(quat_ik - goal_quat), np.linalg.norm(quat_ik + goal_quat))
        if pos_err > 1e-3 or quat_err > 1e-3:
            # IK did not reach the placement pose: tear down and retry with the next seed.
            p.removeBody(dish_uid)
            p.removeBody(dishholder_uid)
            p.removeBody(dishholder_support_uid)
            p.removeBody(floor_uid)
            print('run again')
            return create_dish_scene(models, seed+1, shakey, ik_func, robot_pb_uid, visualize=visualize)

        p.performCollisionDetection()
        if len(p.getContactPoints(robot_pb_uid)) != 0 or len(p.getContactPoints(dish_uid)) != 0:
            continue
        break

    aux_pbids = [floor_uid, dishholder_support_uid]

    # Assign random color to the floor
    floor_color = np_rng.uniform(0, 1, 3)
    p.changeVisualShape(floor_uid, -1, rgbaColor=floor_color.tolist() + [1])
    dishholder_color = np_rng.uniform(0, 1, 3)
    p.changeVisualShape(dishholder_uid, -1, rgbaColor=dishholder_color.tolist() + [1])
    p.changeVisualShape(dishholder_support_uid, -1, rgbaColor=dishholder_color.tolist() + [1])

    return scene_converter.movable_oriCORNs, scene_converter.fixed_oriCORNs, pybullet_scene, init_q, goal_q, ee_to_obj, aux_pbids

def create_multiple_dish_scene(models:mutil.Models, seed, shakey:shakey_module.Shakey, ik_func, robot_pb_uid, visualize=True):
    '''Build a multi-dish pick-and-place scene: dishes go from wall shelves into a dish holder.

    For each slot in ``orders``, a dish is spawned at its placement pose near
    the holder; the slot offset comes from ``place_order_translates``. Up to
    three attempts are made per dish. Each attempt:

    * solves IK for the place pose and rejects it if it misses by more than
      1e-2, or if the robot or the dish is in contact;
    * steps the simulation for about 1 s under gravity and records where the
      dish settles;
    * solves IK for a fixed pick pose, offset (0.6, 0, 0.3) from the matching
      shelf's base, and rejects it if it misses by more than 1e-1 or if there
      are contacts.

    If any dish exhausts its attempts, every body created here is removed
    and the function recurses with a seed drawn from the current generator.

    Side effect: sets the global PyBullet gravity to (0, 0, -9.8).

    Asset source:
    https://app.gazebosim.org/GoogleResearch/fuel/collections/Scanned%20Objects%20by%20Google%20Research

    Args:
        models: Model bundle providing ``asset_path_util``.
        seed: Seed for the NumPy generator.
        shakey: Robot helper providing ``set_q_pb``, ``get_ee_pq_pb`` and
            ``gripper_tip_offset_from_ee``.
        ik_func: ``(seed_q, target_pose) -> array``; entries from index 3
            onward are used as the joint solution.
            NOTE(review): it is called with a ``(pos, quat)`` tuple for place
            and with a concatenated 7-vector for pick; confirm it accepts both.
        robot_pb_uid: PyBullet body id of the robot.
        visualize: Use the PyBullet GUI if no connection exists yet.

    Returns:
        ``(movable_oriCORNs, fixed_oriCORNs, pybullet_scene, init_qs, goal_qs,
        ee_to_objs, poses_after_sequence, aux_pbids, view_matrix)``.
        ``init_qs``, ``goal_qs`` and ``poses_after_sequence`` are stacked with
        two rows per dish, in (pick, place) order. ``ee_to_objs`` has one
        entry per dish.
    '''

    if p.getConnectionInfo()['isConnected'] == 0:
        if visualize:
            p.connect(p.GUI)
        else:
            p.connect(p.DIRECT)

    # Fixed debug-camera viewpoint; hide the GUI side panels.
    p.resetDebugVisualizerCamera(cameraDistance=1.00, cameraYaw=-546.00, cameraPitch=-18.94, cameraTargetPosition=[0.07, 0.00, 0.01])
    p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)

    np_rng = np.random.default_rng(seed)
    # Jitter the arrangement. Note: the y offset is fixed at +0.1 because low == high on that axis.
    env_center = np.array([0,0.5,0.]) + np_rng.uniform([-0.2, 0.1, -0.15], [0.1, 0.1, 0.02])

    # Ground plane below the scene.
    floor_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=p.createCollisionShape(
            shapeType=p.GEOM_PLANE,
        ),
        baseVisualShapeIndex=p.createVisualShape(
            shapeType=p.GEOM_PLANE,
            rgbaColor=[0.8, 0.8, 0.8, 1.0],
        ),
        basePosition=[0, 0, -0.55],
        baseOrientation=[0, 0, 0, 1],
    )

    # Subtracted from the AABB half-extent when computing the grasp offset.
    gripper_offset = 0.01

    # Each sampler maps a dish AABB to a (position, quaternion) dish pose in the gripper frame.
    def y_axis_grasp_sampler(aabb):
        aabb_y_offset = (aabb[1][1] - aabb[0][1]) / 2
        gripper_to_obj_quat = (sciR.from_euler('x', np.pi/2)*sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()
        gripper_to_obj = (jnp.array([0, 0, aabb_y_offset - gripper_offset]), gripper_to_obj_quat)
        return gripper_to_obj

    def z_axis_grasp_sampler(aabb):
        aabb_z_offset = (aabb[1][2] - aabb[0][2]) / 2
        gripper_to_obj_quat = (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()
        gripper_to_obj = (jnp.array([0, 0, aabb_z_offset - gripper_offset]), gripper_to_obj_quat)
        return gripper_to_obj

    def rectangle_grasp_sampler(aabb):
        # Randomly pick between the y-axis and z-axis grasps.
        sample_axis = np_rng.choice(['y', 'z'])
        if sample_axis == 'y':
            return y_axis_grasp_sampler(aabb)
        elif sample_axis == 'z':
            return z_axis_grasp_sampler(aabb)

    def round_grasp_sampler(aabb):
        # Grasp at a random angle in [0, pi/2) about x, at the AABB's half y-extent.
        radius_offset = (aabb[1][1] - aabb[0][1]) / 2
        angle = np_rng.uniform(0, np.pi/2)
        gripper_to_obj_quat = (sciR.from_euler('x', angle)*sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()
        gripper_to_obj = (jnp.array([0, 0, radius_offset - gripper_offset]), gripper_to_obj_quat)
        return gripper_to_obj

    # Alternative holder meshes: "GoogleScannedObjects/modified/Rubbermaid_Large_Drainer.obj", "Markings_Letter_Holder".
    dishholder_obj_filename = 'GoogleScannedObjects/modified/Poppin_File_Sorter_White.obj'
    dishholder_obj_filename = models.asset_path_util.get_obj_path_from_rel_path(dishholder_obj_filename)
    # Collision mesh path: <mesh dir>/../cvx/8_32_v4/<same file name>.
    dishholder_cvx_filename = os.path.join(os.path.dirname(os.path.dirname(dishholder_obj_filename)), 'cvx/8_32_v4', os.path.basename(dishholder_obj_filename))

    dishholder_position = env_center + np.array([0,0,0.08])
    dishholder_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=dishholder_obj_filename,
        meshScale=[1.0, 1.0, 1.0]
    )
    dishholder_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=dishholder_cvx_filename,
        meshScale=[1.0, 1.0, 1.0]
    )
    dishholder_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=dishholder_collision_shape_id,
        baseVisualShapeIndex=dishholder_visual_shape_id,
        basePosition=dishholder_position,
        baseOrientation=[0, 0, 0, 1]
    )

    # Alternative shelf mesh: "own_assets/shelf/modified/wall-shelf-028.obj".
    shelf_obj_filename = "shelf_gen/cvx/shelf_assembled.obj"
    shelf_obj_filename = models.asset_path_util.get_obj_path_from_rel_path(shelf_obj_filename)
    # Shelves beside the holder. In the loop below, shelf i provides the pick pose for dish i.
    shelf_positions = [
        env_center + np.array([-1.2, -0.5, 0]),
        env_center + np.array([-1.2, -0.5, 0.2]),
        env_center + np.array([-1.2, -1.0, 0]),
        env_center + np.array([-1.2, -1.0, 0.2]),
    ]
    shelf_uids = [
        create_multibody(
            shelf_obj_filename,
            position = shelf_position,
            quaternion = np.array([0,0,0,1]),
            scale = 3,
        ) for shelf_position in shelf_positions
    ]

    # Box under the dishholder. It is visual-only: no collision shape is attached.
    dishholder_support_position = dishholder_position + np.array([0, 0, -0.07-0.4])
    dishholder_support_half_extents = [0.12, 0.08, 0.4]
    dishholder_support_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_BOX,
        halfExtents=dishholder_support_half_extents
    )
    dishholder_support_uid = p.createMultiBody(
        baseMass=0.0,
        baseVisualShapeIndex=dishholder_support_visual_shape_id,
        basePosition=dishholder_support_position,
        baseOrientation=[0, 0, 0, 1]
    )

    # Each entry: (relative mesh path, translation from env_center, uniform mesh scale, grasp sampler).
    outside_dish_options = [
        (
            'GoogleScannedObjects/modified/Ecoforms_Plant_Saucer_SQ8COR.obj',
            np.array([0.087, 0, 0.18]),
            1.1,
            z_axis_grasp_sampler,
        ),
        (
            'GoogleScannedObjects/modified/Ecoforms_Plant_Saucer_S20MOCHA.obj',
            np.array([0.087, 0, 0.18]),
            1.1,
            z_axis_grasp_sampler,
        ),
        (
            'GoogleScannedObjects/modified/Ecoforms_Saucer_SQ3_Turquoise.obj',
            np.array([0.087, 0, 0.16]),
            1.2,
            y_axis_grasp_sampler,
        ),
        (
            "GoogleScannedObjects/modified/Cole_Hardware_Saucer_Glazed_6.obj",
            np.array([0.087, 0, 0.18]),
            1.1,
            z_axis_grasp_sampler,
        ),
        (
            "GoogleScannedObjects/modified/Cole_Hardware_Saucer_Electric.obj",
            np.array([0.087, 0, 0.18]),
            1.1,
            z_axis_grasp_sampler,
        ),
        (
            "GoogleScannedObjects/modified/Corningware_CW_by_Corningware_3qt_Oblong_Casserole_Dish_Blue.obj",
            np.array([0.087, 0, 0.18]),
            0.5,
            z_axis_grasp_sampler,
        ),
        (
            "GoogleScannedObjects/modified/Cole_Hardware_Plant_Saucer_Brown_125.obj",
            np.array([0.087, 0, 0.18]),
            0.65,
            z_axis_grasp_sampler,
        ),
        (
            "GoogleScannedObjects/modified/Threshold_Salad_Plate_Square_Rim_Porcelain.obj",
            np.array([0.087, 0, 0.18]),
            0.85,
            z_axis_grasp_sampler,
        ),
    ]

    inside_dish_options = [
        (
            'GoogleScannedObjects/modified/Ecoforms_Saucer_SQ3_Turquoise.obj',
            np.array([0.087, 0, 0.16]),
            1.2,
            y_axis_grasp_sampler,
        ),
        (
            "GoogleScannedObjects/modified/Cole_Hardware_Plant_Saucer_Brown_125.obj",
            np.array([0.087, 0, 0.20]),
            0.65,
            z_axis_grasp_sampler,
        ),
    ]

    # dish_sequences[i] is the option index for the i-th dish. Entries 0-1 are
    # drawn for outside_dish_options and entries 2-3 for inside_dish_options.
    # NOTE(review): this only lines up because orders[0:2] are in {0, 3}
    # (outside slots) and orders[2] is 1 (an inside slot). Reordering `orders`
    # can index an outside-range value into inside_dish_options.
    dish_sequences = np.concat([
        np_rng.integers(low=0, high=len(outside_dish_options), size=2),
        np_rng.integers(low=0, high=len(inside_dish_options), size=2),
    ])
    # Holder slot for each dish: slots 0 and 3 take outside dishes, the others take inside dishes.
    orders = [0, 3, 1]
    # Translation of each holder slot relative to the dish's base placement position.
    place_order_translates = [
        np.array([0., 0., 0]),
        np.array([-0.0575, 0, 0]),
        np.array([-0.115, 0, 0]),
        np.array([-0.1725, 0, 0]),
    ]

    print(dish_sequences, orders)

    dish_uids = []
    ee_to_objs = []

    init_qs = []
    goal_qs = []

    success = True
    poses_after_sequence = []
    p.setGravity(0, 0, -9.8)
    time_step = p.getPhysicsEngineParameters()['fixedTimeStep']
    # Number of simulation steps covering ~1 s, used to let a placed dish settle.
    num_steps = int(1.0 / time_step)
    initial_seed = np.array([-np.pi/4, -np.pi/2, -np.pi/2, -np.pi/2, np.pi/2, np.pi])
    dish_aabbs = []

    # zip stops at the shortest input (currently `orders`), so not every shelf is necessarily used.
    for shelf_uid, order, dish_idx in zip(shelf_uids, orders, dish_sequences):
        if order in [0, 3]:
            dish_options = outside_dish_options
        else:
            dish_options = inside_dish_options

        dish_obj_filename, dish_translate, dish_scale, gripper_to_obj_sampler = dish_options[dish_idx]
        dish_translate = dish_translate + place_order_translates[order]

        dish_obj_filename = models.asset_path_util.get_obj_path_from_rel_path(dish_obj_filename)
        dish_cvx_filename = os.path.join(os.path.dirname(os.path.dirname(dish_obj_filename)), 'cvx/8_32_v4', os.path.basename(dish_obj_filename))

        dish_visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=dish_obj_filename,
            meshScale=[dish_scale] * 3
        )
        dish_collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=dish_cvx_filename,
            meshScale=[dish_scale] * 3
        )
        dish_position = env_center + dish_translate
        dish_orientation = (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()
        dish_uid = p.createMultiBody(
            baseMass=0.1,
            baseCollisionShapeIndex=dish_collision_shape_id,
            baseVisualShapeIndex=dish_visual_shape_id,
            basePosition=dish_position,
            baseOrientation=dish_orientation
        )
        dish_uids.append(dish_uid)
        dish_aabb = p.getAABB(dish_uid)
        dish_aabbs.append(dish_aabb)

        for _attempt in range(3):
            # --- place pose: end-effector pose that puts the held dish at its slot ---
            gripper_to_obj = gripper_to_obj_sampler(dish_aabb)
            ee_to_obj = tutil.pq_multi(*shakey.gripper_tip_offset_from_ee, *gripper_to_obj)
            ee_to_obj_pqc = np.concat(ee_to_obj, axis=-1)

            ee_pq = tutil.pq_multi(
                dish_position,
                dish_orientation,
                *tutil.pq_inv(
                    ee_to_obj[0],
                    ee_to_obj[1],
                ),
            )

            ik_q = ik_func(initial_seed, ee_pq)
            goal_q = ik_q[3:]
            shakey.set_q_pb(robot_pb_uid, goal_q, [dish_uid], ee_to_obj_pqc)

            ee_pqc_ik = shakey.get_ee_pq_pb(robot_pb_uid)
            # Orientation error is measured on rotation matrices, which is insensitive to quaternion sign.
            if np.linalg.norm(ee_pqc_ik[...,:3] - ee_pq[0]) > 1e-2 or np.linalg.norm(sciR.from_quat(ee_pqc_ik[...,3:]).as_matrix() - sciR.from_quat(ee_pq[1]).as_matrix()) > 1e-2:
                print('ik not converged')
                continue

            is_collision = False
            p.performCollisionDetection()
            for contact_dish_uid in dish_uids:
                if len(p.getContactPoints(robot_pb_uid, contact_dish_uid)) != 0:
                    is_collision = True
                    break
            if len(p.getContactPoints(dish_uid)) != 0:
                is_collision = True
            if is_collision:
                continue

            # Let the dish (and any previously placed dishes) settle under gravity.
            for _ in range(num_steps):
                p.stepSimulation()

            obj_pose_after_place = p.getBasePositionAndOrientation(dish_uid)
            place_goal_q = goal_q

            # --- pick pose: fixed end-effector pose relative to the shelf base ---
            shelf_position = p.getBasePositionAndOrientation(shelf_uid)[0]

            ee_pq = np.concatenate([
                np.array(shelf_position) + np.array([0.6, 0, 0.3]),
                (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi)).as_quat(),
            ])
            # The pick motion starts where the previous place ended (or at a default pose for the first dish).
            pick_init_q = goal_qs[-1] if len(goal_qs) > 0 else np.array([-np.pi/4, -np.pi/2, -np.pi/2, -np.pi/2, np.pi/2, np.pi])
            ik_q = ik_func(
                np.array([2.2749536, -2.0172122, -1.6899214, 0.56555176, 0.86664075, 1.5708045]),
                ee_pq
            )
            pick_goal_q = ik_q[3:]
            # The place motion starts from the pick configuration.
            place_init_q = ik_q[3:]
            print(pick_goal_q)
            shakey.set_q_pb(robot_pb_uid, pick_goal_q, [dish_uid], ee_to_obj_pqc)
            obj_pose_when_pick =  p.getBasePositionAndOrientation(dish_uid)

            ee_pqc_ik = shakey.get_ee_pq_pb(robot_pb_uid)
            # NOTE(review): the pick tolerance (1e-1) is 10x looser than the place tolerance.
            if np.linalg.norm(ee_pqc_ik[...,:3] - ee_pq[:3]) > 1e-1 or np.linalg.norm(sciR.from_quat(ee_pqc_ik[...,3:]).as_matrix() - sciR.from_quat(ee_pq[3:]).as_matrix()) > 1e-1:
                print('ik not converged')
                continue

            is_collision = False
            p.performCollisionDetection()
            for contact_dish_uid in dish_uids:
                if len(p.getContactPoints(robot_pb_uid, contact_dish_uid)) != 0:
                    is_collision = True
                    break
            if len(p.getContactPoints(dish_uid)) != 0:
                is_collision = True
            if is_collision:
                continue

            # move dish to place after pose
            p.resetBasePositionAndOrientation(dish_uid, *obj_pose_after_place)
            obj_pose_when_pick = jnp.concatenate([jnp.array(obj_pose_when_pick[0]), jnp.array(obj_pose_when_pick[1])], axis=0)
            obj_pose_after_place = jnp.concatenate([jnp.array(obj_pose_after_place[0]), jnp.array(obj_pose_after_place[1])], axis=0)

            poses_after_sequence.extend([obj_pose_when_pick, obj_pose_after_place])
            ee_to_objs.extend([ee_to_obj])
            init_qs.extend([pick_init_q, place_init_q])
            goal_qs.extend([pick_goal_q, place_goal_q])

            break
        else:
            # All attempts for this dish failed.
            success = False
            break
    if not success:
        # Tear down everything created here and retry with a fresh seed.
        for dish_uid in dish_uids:
            p.removeBody(dish_uid)
        for shelf_uid in shelf_uids:
            p.removeBody(shelf_uid)
        p.removeBody(dishholder_uid)
        p.removeBody(dishholder_support_uid)
        p.removeBody(floor_uid)
        next_seed = np_rng.integers(low=0, high=10**6)
        return create_multiple_dish_scene(models, next_seed, shakey, ik_func, robot_pb_uid, visualize=visualize)

    # Park all dishes at the origin before capturing the scene.
    for dish_uid in dish_uids:
        p.resetBasePositionAndOrientation(dish_uid, [0, 0, 0], [0, 0, 0, 1.])
    movable_obj_uids = dish_uids
    fixed_obj_uids = [dishholder_uid, *shelf_uids]
    scene_converter = SceneConverter()
    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        movable_obj_uids, fixed_obj_uids, robot_uid=None
    )

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)

    aux_pbids = [floor_uid, dishholder_support_uid]

    # Assign random color to the floor
    floor_color = np_rng.uniform(0, 1, 3)
    p.changeVisualShape(floor_uid, -1, rgbaColor=floor_color.tolist() + [1])
    dishholder_color = np_rng.uniform(0, 1, 3)
    p.changeVisualShape(dishholder_uid, -1, rgbaColor=dishholder_color.tolist() + [1])
    p.changeVisualShape(dishholder_support_uid, -1, rgbaColor=dishholder_color.tolist() + [1])
    init_qs = jnp.stack(init_qs)
    goal_qs = jnp.stack(goal_qs)
    poses_after_sequence = jnp.stack(poses_after_sequence)

    # Camera used both for the debug visualizer and for the returned view matrix.
    cameraDistance        = 2.20
    cameraYaw             = -594.40
    cameraPitch           = -38.54
    cameraTargetPosition  = [-0.38, -0.32, -0.43]
    p.resetDebugVisualizerCamera(cameraDistance=cameraDistance, cameraYaw=cameraYaw, cameraPitch=cameraPitch, cameraTargetPosition=cameraTargetPosition)

    view_matrix = p.computeViewMatrixFromYawPitchRoll(
        cameraTargetPosition, 
        cameraDistance, 
        cameraYaw, 
        cameraPitch, 
        roll=0,
        upAxisIndex=2
    )

    print(seed)
    return scene_converter.movable_oriCORNs, scene_converter.fixed_oriCORNs, pybullet_scene, init_qs, goal_qs, ee_to_objs, poses_after_sequence, aux_pbids, view_matrix


def create_pen_scene(models:mutil.Models, seed, shakey:shakey_module.Shakey, ik_func, robot_pb_uid, visualize=True):
    '''
    Build a tabletop scene with a movable pen next to a fixed pen holder.

    Assets (Google Scanned Objects):
        'Granimals_20_Wooden_ABC_Blocks_Wagon_g2TinmUGGHI'  # pen (movable)
        'Markings_Desk_Caddy'                               # pen holder (fixed)
        'Markings_Letter_Holder'                            # alternative pen holder (not used here)
    https://app.gazebosim.org/GoogleResearch/fuel/collections/Scanned%20Objects%20by%20Google%20Research

    Args:
        models: model bundle providing `asset_path_util` and `mesh_aligned_canonical_obj`.
        seed: seed for the random offset of the scene center.
        shakey: robot wrapper used for the gripper-tip offset and for setting joint states.
        ik_func: callable `ik_func(init_q, ee_pq)`; only `ik_q[3:]` of its result is
            kept as the goal configuration.
        robot_pb_uid: PyBullet body id of the robot.
        visualize: if no PyBullet connection exists yet, connect with GUI (True)
            or DIRECT (False).

    Returns:
        (pen canonical oriCORN scaled to the spawned size, fixed oriCORNs,
         PyBullet scene, init_q, goal_q, ee_to_obj)
    '''
    if p.getConnectionInfo()['isConnected'] == 0:
        p.connect(p.GUI if visualize else p.DIRECT)

    # Scene center, randomly offset as a function of `seed`.
    env_center = np.array([0,0.4,0.]) + np.random.default_rng(seed).uniform([-0.2, -0.1, -0.1], [0.2, 0.1, 0.1])

    # --- Pen (movable): raw mesh for visuals, convex decomposition for collision.
    dish_obj_filename = 'GoogleScannedObjects/modified/Granimals_20_Wooden_ABC_Blocks_Wagon_g2TinmUGGHI.obj'
    dish_obj_filename = models.asset_path_util.get_obj_path_from_rel_path(dish_obj_filename)
    dish_cvx_filename = os.path.join(os.path.dirname(os.path.dirname(dish_obj_filename)), 'cvx/8_32_v4', os.path.basename(dish_obj_filename))
    dish_scale = 0.8
    dish_translate = np.array([0.0, -0.05, 0.14])

    dish_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=dish_obj_filename,
        meshScale=[dish_scale] * 3
    )
    dish_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=dish_cvx_filename,
        meshScale=[dish_scale] * 3
    )
    dish_position = env_center + dish_translate
    # Rotation of pi/2 about the y axis (axis-angle).
    dish_orientation = tutil.aa2q(np.array([0, np.pi/2, 0]))
    dish_uid = p.createMultiBody(
        baseMass=0.1,
        baseCollisionShapeIndex=dish_collision_shape_id,
        baseVisualShapeIndex=dish_visual_shape_id,
        basePosition=dish_position,
        baseOrientation=dish_orientation
    )

    # --- Pen holder (fixed, zero mass).
    dishholder_obj_filename = 'GoogleScannedObjects/modified/Markings_Desk_Caddy.obj'
    dishholder_obj_filename = models.asset_path_util.get_obj_path_from_rel_path(dishholder_obj_filename)
    dishholder_cvx_filename = os.path.join(os.path.dirname(os.path.dirname(dishholder_obj_filename)), 'cvx/8_32_v4', os.path.basename(dishholder_obj_filename))

    dishholder_position = env_center + np.array([0,0,0.08])
    dishholder_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=dishholder_obj_filename,
        meshScale=[1.0, 1.0, 1.0]
    )
    dishholder_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=dishholder_cvx_filename,
        meshScale=[1.0, 1.0, 1.0]
    )
    dishholder_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=dishholder_collision_shape_id,
        baseVisualShapeIndex=dishholder_visual_shape_id,
        basePosition=dishholder_position,
        baseOrientation=[0, 0, 0, 1]
    )

    movable_obj_uids = [dish_uid]
    fixed_obj_uids = [dishholder_uid]
    scene_converter = SceneConverter()
    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        movable_obj_uids, fixed_obj_uids, robot_uid=None
    )

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)

    # Goal: end-effector pose that grasps the pen with the given gripper->object offset.
    gripper_to_obj = (jnp.array([0,0,0.050]), tutil.aa2q(np.array([0, -np.pi/2, 0])))
    ee_to_obj = tutil.pq_multi(*shakey.gripper_tip_offset_from_ee, *gripper_to_obj)

    ee_pq = tutil.pq_multi(
                dish_position,
                dish_orientation,
                *tutil.pq_inv(
                    ee_to_obj[0],
                    ee_to_obj[1],
                ),
            )
    init_q = np.array([-np.pi/4, -np.pi/2, -np.pi/2, -np.pi/2, np.pi/2, np.pi])
    shakey.set_q_pb(robot_pb_uid, init_q)
    ik_q = ik_func(init_q, ee_pq)
    goal_q = ik_q[3:]
    shakey.set_q_pb(robot_pb_uid, goal_q)

    # The canonical oriCORN must use the same scale as the spawned PyBullet body.
    dish_oriCORN = models.mesh_aligned_canonical_obj[models.asset_path_util.get_obj_id(dish_obj_filename)]
    dish_oriCORN = dish_oriCORN.apply_scale(dish_scale)

    return dish_oriCORN, scene_converter.fixed_oriCORNs, pybullet_scene, init_q, goal_q, ee_to_obj

def create_room_scene(models:mutil.Models, seed, shakey:shakey_module.Shakey, robot_pb_uid, visualize=True):
    '''
    Build an empty room scene and sample collision-free start/goal robot states.

    A floor plane (visual + collision) and the room mesh `room_level2.obj`
    (scaled by 2) are spawned at the origin. Only the room mesh is registered as a
    fixed object of the scene; the floor body is not passed to
    `construct_scene_from_pybullet`.

    Args:
        models: model bundle passed to `SceneConverter.convert_scene_to_oriCORNs`.
        seed: seed for start/goal configuration sampling.
        shakey: robot wrapper used for setting joint states.
        robot_pb_uid: PyBullet body id of the robot.
        visualize: if no PyBullet connection exists yet, connect with GUI (True)
            or DIRECT (False).

    Returns:
        (fixed oriCORNs, PyBullet scene, init_q, goal_q)
    '''
    if p.getConnectionInfo()['isConnected'] == 0:
        p.connect(p.GUI if visualize else p.DIRECT)

    env_center = np.zeros(3)

    # Floor plane (not part of the oriCORN scene, but it does take part in
    # PyBullet contact queries).
    pybullet_data_path = pybullet_data.getDataPath()
    floor_mesh_filename = os.path.join(pybullet_data_path, 'plane.obj')
    floor_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=floor_mesh_filename,
        meshScale=[2, 2, 1]
    )
    floor_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=floor_mesh_filename,
        meshScale=[2, 2, 1]
    )
    floor_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=floor_collision_shape_id,
        baseVisualShapeIndex=floor_visual_shape_id,
        basePosition=env_center,
        baseOrientation=[0, 0, 0, 1]
    )

    # Room mesh (fixed); the same file is used for visuals and collision.
    room_obj_filename = "assets/room/raw/room_level2.obj"
    room_cvx_filename = "assets/room/raw/room_level2.obj"
    env_scale = 2.0

    room_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=room_obj_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    room_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=room_cvx_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    room_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=room_collision_shape_id,
        baseVisualShapeIndex=room_visual_shape_id,
        basePosition=env_center,
        baseOrientation=[0, 0, 0, 1]
    )

    scene_converter = SceneConverter()
    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        [], [room_uid], robot_uid=None
    )

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)

    np_rng = np.random.default_rng(seed)

    def valid_state_sample(np_rng, pos_bound_lower:np.ndarray, pos_bound_upper:np.ndarray):
        '''
        Rejection-sample a configuration in which the robot has no contact points.

        The sampled vector is [pos_bound entries..., one angle in [-pi, pi], six
        bounded joint values] (presumably planar base position, base yaw, then arm
        joints — TODO confirm against `shakey`). The robot is left at the accepted
        sample. NOTE: loops forever if no collision-free sample exists in the bounds.
        '''
        lower_bound = np.array([*pos_bound_lower.tolist(), -np.pi, -2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -3.0718])
        upper_bound = np.array([*pos_bound_upper.tolist(), np.pi, 2.8973, 1.7628, 2.8973, 0.0698, 2.8973, 0.0698])

        while True:
            init_q = np_rng.uniform(lower_bound, upper_bound)
            shakey.set_q_pb(robot_pb_uid, init_q)
            p.performCollisionDetection()
            col_res = p.getContactPoints(robot_pb_uid)
            if len(col_res) == 0:
                break
        return init_q

    # Start on the negative-x side of the room, goal on the positive-x side.
    init_q = valid_state_sample(np_rng, np.array([-5, -2.0]), np.array([-1, 2.0]))
    shakey.set_q_pb(robot_pb_uid, init_q)
    goal_q = valid_state_sample(np_rng, np.array([1, -2.0]), np.array([5, 2.0]))
    shakey.set_q_pb(robot_pb_uid, goal_q)

    return scene_converter.fixed_oriCORNs, pybullet_scene, init_q, goal_q



def create_room_pen_scene(models:mutil.Models, seed, ik_func, shakey:shakey_module.Shakey, robot_pb_uid, visualize=True):
    '''
    Build a room scene with a movable pen placed in a pen holder on a small table.

    The pen holder, its table and the pen are spawned at a random offset around
    (2.0, 0, 0.43) from the environment center. For each placement, up to 100 IK
    seeds are tried to find a grasp configuration in which the robot touches
    nothing but the pen. A placement whose configuration still collides is
    removed and resampled, up to 100 placements. Finally a collision-free
    initial configuration is rejection-sampled.

    Visual colors use the global `np.random` state, so they are not controlled
    by `seed`.

    Args:
        models: model bundle providing `asset_path_util` and `mesh_aligned_canonical_obj`.
        seed: seed for placement, IK-seed and initial-configuration sampling.
        ik_func: callable `ik_func(ik_init_q, ee_pq)` returning a robot configuration.
        shakey: robot wrapper used for the gripper-tip offset and for setting joint states.
        robot_pb_uid: PyBullet body id of the robot.
        visualize: if no PyBullet connection exists yet, connect with GUI (True)
            or DIRECT (False).

    Returns:
        (pen canonical oriCORN scaled to the spawned size,
         fixed oriCORNs (room wall and pen holder), PyBullet scene, init_q, goal_q,
         ee_to_obj concatenated into one array,
         aux_pbids = [pen-holder table, floor, visual-only wall])

    Raises:
        RuntimeError: if no collision-free placement is found within 100 attempts.
    '''
    max_placement_attempts = 100
    max_ik_attempts = 100

    if p.getConnectionInfo()['isConnected'] == 0:
        p.connect(p.GUI if visualize else p.DIRECT)

    env_center = np.zeros(3)

    # --- Floor (visual + collision), randomly colored; returned via aux_pbids.
    pybullet_data_path = pybullet_data.getDataPath()
    floor_mesh_filename = os.path.join(pybullet_data_path, 'plane.obj')
    floor_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=floor_mesh_filename,
        meshScale=[1, 1, 1]
    )
    floor_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=floor_mesh_filename,
        meshScale=[1, 1, 1]
    )
    floor_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=floor_collision_shape_id,
        baseVisualShapeIndex=floor_visual_shape_id,
        basePosition=env_center,
        baseOrientation=[0, 0, 0, 1]
    )
    # Remove texture from the floor and assign a random color
    floor_color = np.random.uniform(0, 1, 3)
    p.changeVisualShape(floor_uid, -1, rgbaColor=floor_color.tolist() + [1])

    # --- Room walls. `wall_uid` has collision and is a fixed scene object;
    # `wall2_uid` is visual-only (no collision shape).
    wall_obj_filename = "assets/room/raw/room_no_floor_v2.obj"
    wall_cvx_filename = "assets/room/raw/room_no_floor_v2.obj"
    env_scale = 2.0

    wall_position = env_center
    wall_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=wall_obj_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    wall_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=wall_cvx_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    wall_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=wall_collision_shape_id,
        baseVisualShapeIndex=wall_visual_shape_id,
        basePosition=wall_position,
        baseOrientation=[0, 0, 0, 1]
    )

    wall2_obj_filename = "assets/room/raw/room_no_floor_wall.obj"
    wall2_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=wall2_obj_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    wall2_uid = p.createMultiBody(
        baseMass=0.0,
        baseVisualShapeIndex=wall2_visual_shape_id,
        basePosition=wall_position,
        baseOrientation=[0, 0, 0, 1]
    )

    # --- Loop-invariant asset paths and grasp geometry.
    penholder_obj_filename = models.asset_path_util.get_obj_path_from_rel_path(
        'GoogleScannedObjects/modified/Markings_Desk_Caddy.obj')
    # NOTE(review): the pen holder's *visual* shape also uses the convex-decomposition
    # mesh (create_pen_scene uses the raw .obj for visuals); kept as-is — confirm intended.
    penholder_cvx_filename = os.path.join(os.path.dirname(os.path.dirname(penholder_obj_filename)), 'cvx/8_32_v4', os.path.basename(penholder_obj_filename))
    penholder_scale = 2.0
    penholder_table_half_extents = [0.05, 0.22, 0.2]

    dish_obj_filename = models.asset_path_util.get_obj_path_from_rel_path(
        'GoogleScannedObjects/modified/Granimals_20_Wooden_ABC_Blocks_Wagon_g2TinmUGGHI.obj')
    dish_cvx_filename = os.path.join(os.path.dirname(os.path.dirname(dish_obj_filename)), 'cvx/8_32_v4', os.path.basename(dish_obj_filename))
    dish_scale = 0.7*penholder_scale
    dish_translate = np.array([0.0, -0.09*penholder_scale, 0.06*penholder_scale])
    dish_orientation = tutil.aa2q(np.array([0, np.pi/2, 0]))

    gripper_to_obj = (jnp.array([0,0,0.05*penholder_scale]), (sciR.from_euler('z', np.pi/2)*sciR.from_euler('y', -np.pi/2)).as_quat())
    ee_to_obj = tutil.pq_multi(*shakey.gripper_tip_offset_from_ee, *gripper_to_obj)

    def _robot_contact_free(allowed_uid):
        '''Return True if the robot's only contacts (if any) are with `allowed_uid`.'''
        p.performCollisionDetection()
        contacts = p.getContactPoints(robot_pb_uid)
        return len(contacts) == 0 or all(ct[2] == allowed_uid for ct in contacts)

    np_rng = np.random.default_rng(seed)
    for _placement_attempt in range(max_placement_attempts):
        penholder_position = env_center + np.array([2.0,0,0.43]) + np_rng.uniform([-0.1, -1.5, -0.05], [0.1, 1.5, 0.05])
        penholder_visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=penholder_cvx_filename,
            meshScale=[penholder_scale]*3
        )
        penholder_collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=penholder_cvx_filename,
            meshScale=[penholder_scale]*3
        )
        penholder_uid = p.createMultiBody(
            baseMass=0.0,
            baseCollisionShapeIndex=penholder_collision_shape_id,
            baseVisualShapeIndex=penholder_visual_shape_id,
            basePosition=penholder_position,
            baseOrientation=[0, 0, 0, 1]
        )

        # Box "table" directly under the pen holder.
        penholder_table_position = penholder_position + np.array([0, 0, -0.25])
        penholder_table_visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_BOX,
            halfExtents=penholder_table_half_extents
        )
        penholder_table_collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_BOX,
            halfExtents=penholder_table_half_extents
        )
        penholder_table_uid = p.createMultiBody(
            baseMass=0.0,
            baseCollisionShapeIndex=penholder_table_collision_shape_id,
            baseVisualShapeIndex=penholder_table_visual_shape_id,
            basePosition=penholder_table_position,
            baseOrientation=[0, 0, 0, 1]
        )

        penholder_color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(penholder_uid, -1, rgbaColor=penholder_color.tolist() + [1])
        p.changeVisualShape(penholder_table_uid, -1, rgbaColor=penholder_color.tolist() + [1])
        wall_color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(wall_uid, -1, rgbaColor=wall_color.tolist() + [1])
        p.changeVisualShape(wall2_uid, -1, rgbaColor=wall_color.tolist() + [1])

        # Pen, placed relative to the pen holder.
        dish_visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=dish_obj_filename,
            meshScale=[dish_scale] * 3
        )
        dish_collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=dish_cvx_filename,
            meshScale=[dish_scale] * 3
        )
        dish_position = penholder_position + dish_translate
        dish_uid = p.createMultiBody(
            baseMass=0.1,
            baseCollisionShapeIndex=dish_collision_shape_id,
            baseVisualShapeIndex=dish_visual_shape_id,
            basePosition=dish_position,
            baseOrientation=dish_orientation
        )

        # End-effector pose that grasps the pen at this placement.
        ee_pq = tutil.pq_multi(
                    dish_position,
                    dish_orientation,
                    *tutil.pq_inv(
                        ee_to_obj[0],
                        ee_to_obj[1],
                    ),
                )

        # Try several IK seeds (second entry randomized) until only the pen is touched.
        for _ik_attempt in range(max_ik_attempts):
            ik_init_q = np.array([0, np_rng.uniform(-1.5, 1.5), np.pi, 0, -np.pi/2, -np.pi/2, -np.pi/2, np.pi/2, 0.0])
            goal_q = ik_func(ik_init_q, ee_pq)
            shakey.set_q_pb(robot_pb_uid, goal_q)
            if _robot_contact_free(dish_uid):
                break

        if _robot_contact_free(dish_uid):
            break
        # Placement rejected: drop its bodies before resampling.
        p.removeBody(penholder_uid)
        p.removeBody(penholder_table_uid)
        p.removeBody(dish_uid)
    else:
        # Without this, the removed body uids of the last attempt would be used below.
        raise RuntimeError(
            f'create_room_pen_scene: no collision-free pen placement found after '
            f'{max_placement_attempts} attempts (seed={seed})'
        )

    aux_pbids = [penholder_table_uid, floor_uid, wall2_uid]

    scene_converter = SceneConverter()
    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        [dish_uid],
        [wall_uid, penholder_uid],
        robot_uid=None
    )

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)

    def valid_state_sample(np_rng, pos_bound_lower:np.ndarray, pos_bound_upper:np.ndarray):
        '''
        Rejection-sample a configuration in which the robot has no contact points.

        The sampled vector is [pos_bound entries..., one angle in [pi/2, 3pi/2], six
        bounded joint values]. The robot is left at the accepted sample.
        NOTE: loops forever if no collision-free sample exists in the bounds.
        '''
        lower_bound = np.array([*pos_bound_lower.tolist(), np.pi/2, -2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -3.0718])
        upper_bound = np.array([*pos_bound_upper.tolist(), np.pi*3/2, 2.8973, 1.7628, 2.8973, 0.0698, 2.8973, 0.0698])

        while True:
            init_q = np_rng.uniform(lower_bound, upper_bound)
            shakey.set_q_pb(robot_pb_uid, init_q)
            p.performCollisionDetection()
            col_res = p.getContactPoints(robot_pb_uid)
            if len(col_res) == 0:
                break
        return init_q

    init_q = valid_state_sample(np_rng, np.array([-3.0, -1.6]), np.array([-1, 1.6]))
    ee_to_obj = jnp.concat(ee_to_obj)

    # The canonical oriCORN must use the same scale as the spawned PyBullet body.
    movable_canonical_oriCORN = models.mesh_aligned_canonical_obj[models.asset_path_util.get_obj_id(dish_obj_filename)]
    movable_canonical_oriCORN = movable_canonical_oriCORN.apply_scale(dish_scale)

    return movable_canonical_oriCORN, scene_converter.fixed_oriCORNs, pybullet_scene, init_q, goal_q, ee_to_obj, aux_pbids



def create_bimanual_insertion(models:mutil.Models, seed, ik_func, shakey:shakey_module.Shakey, robot_pb_uid, visualize=False):
    '''
    Build a bimanual peg-in-hole scene in which one gripper holds the peg and the other the hole.

    Peg and hole are both movable bodies, attached to the robot through
    `ee_to_obj` (one row per object). A random collision-free `init_q` is
    rejection-sampled. `goal_q` is obtained by IK with the peg offset by
    `peg_offset` (uniform in [0.02, 0.04]) along +z from the hole.

    Args:
        models: model bundle passed to `SceneConverter.convert_scene_to_oriCORNs`.
        seed: seed for the peg offset and the initial-configuration sampling.
        ik_func: callable `ik_func(q_init, pq_ee)`. It is called with a zero initial
            guess and the stacked end-effector target poses of both objects.
        shakey: robot wrapper (joint bounds, gripper-tip offset, joint-state setter).
        robot_pb_uid: PyBullet body id of the robot.
        visualize: unused; kept for interface compatibility with the other scene builders.

    Returns:
        (movable oriCORNs, fixed oriCORNs, PyBullet scene,
         init_q[..., -shakey.num_act_joints:], goal_q[..., -shakey.num_act_joints:],
         ee_to_obj, fixed_moving_idx_pair)

    Raises:
        AssertionError: if the IK goal configuration is in collision.
    '''
    # Debug camera and GUI panels.
    p.resetDebugVisualizerCamera(cameraDistance=1.00, cameraYaw=-261.2, cameraPitch=-34.8, cameraTargetPosition=[0.01, 0.05, 0.4])
    p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)

    np_rng = np.random.default_rng(seed)

    # --- Hole.
    hole_center = np.array([0.3, 0, 0.8])
    hole_orientation = sciR.from_euler('z', np.pi/2).as_quat()
    env_scale = 0.19
    hole_obj_filename = 'assets/assembly/raw/hole_v5.obj'
    hole_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=hole_obj_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    hole_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=hole_obj_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    hole_uid = p.createMultiBody(
        baseMass=1.0,
        baseVisualShapeIndex=hole_visual_shape_id,
        baseCollisionShapeIndex=hole_collision_shape_id,
        basePosition=hole_center,
        baseOrientation=hole_orientation
    )

    # --- Peg, offset along +z from the hole by a random amount.
    hold_offset = 0.045
    peg_offset = np_rng.uniform(0.020, 0.040)

    peg_center = hole_center + np.array([0, 0, peg_offset])
    peg_orientation = hole_orientation
    peg_scale = env_scale * 1.0

    peg_obj_filename = 'assets/assembly/raw/peg_v5.obj'
    peg_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=peg_obj_filename,
        meshScale=[peg_scale]*3
    )
    peg_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=peg_obj_filename,
        meshScale=[peg_scale]*3
    )
    peg_uid = p.createMultiBody(
        baseMass=1.0,
        baseVisualShapeIndex=peg_visual_shape_id,
        baseCollisionShapeIndex=peg_collision_shape_id,
        basePosition=peg_center,
        baseOrientation=peg_orientation
    )

    movable_obj_uids = [peg_uid, hole_uid]

    # Disable collisions between robot link 6 and the peg, and between link 12
    # and the hole (presumably the links that hold each object — TODO confirm).
    p.setCollisionFilterPair(robot_pb_uid, peg_uid, 6, -1, 0)
    p.setCollisionFilterPair(robot_pb_uid, hole_uid, 12, -1, 0)

    # Object->gripper offsets, one row per object: [position (3), quaternion (4)].
    peg_grasp_quat = (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()
    hole_grasp_quat = (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()
    obj_to_gripper_offset = np.stack([
        np.concat([np.array([-0.02, 0, hold_offset-peg_offset]), peg_grasp_quat]),
        np.concat([np.array([-0.11, 0.0, 0.0]), hole_grasp_quat]),
    ], axis=0)
    gripper_to_obj_pqc = tutil.pq_inv(obj_to_gripper_offset)

    ee_to_obj = tutil.pq_multi(jnp.concat(shakey.gripper_tip_offset_from_ee), gripper_to_obj_pqc)

    def sample_valid_q():
        '''
        Rejection-sample a configuration in which the robot has no contact points.

        Joint limits are scaled by 0.3; joints 4 and 10 are instead sampled within
        +-pi/6 of fixed nominal values. Iterative (not recursive) so a long run of
        rejected samples cannot hit Python's recursion limit. Loops until success.
        '''
        nominal = np.array([-1.065, -2.2996])
        lower_bound = np.array(copy.deepcopy(shakey.q_lower_bound))*0.3
        lower_bound[[4,10]] = nominal - np.pi/6
        upper_bound = np.array(copy.deepcopy(shakey.q_upper_bound))*0.3
        upper_bound[[4,10]] = nominal + np.pi/6
        while True:
            q_sample = np_rng.uniform(lower_bound, upper_bound)
            shakey.set_q_pb(robot_pb_uid, q_sample, movable_obj_uids, ee_to_obj)
            p.performCollisionDetection()
            if len(p.getContactPoints(robot_pb_uid)) == 0:
                return q_sample

    init_q = sample_valid_q()

    # Goal: end-effector poses that put both objects at their assembled poses.
    pq_ee = np.stack([np.concat([peg_center, peg_orientation]),
                      np.concat([hole_center, hole_orientation])], axis=0)
    pq_ee = tutil.pq_multi(pq_ee, obj_to_gripper_offset)
    goal_q = ik_func(np.zeros_like(init_q), pq_ee)

    shakey.set_q_pb(robot_pb_uid, goal_q, movable_obj_uids, ee_to_obj)

    p.performCollisionDetection()
    col_res = p.getContactPoints(robot_pb_uid)
    if len(col_res) != 0:
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is kept as AssertionError for backward compatibility.
        raise AssertionError(f'goal configuration is in collision ({len(col_res)} contact points)')

    # Reset objects to the origin before building the scene (presumably so the
    # oriCORNs are captured in the object frame — TODO confirm with SceneConverter).
    p.resetBasePositionAndOrientation(peg_uid, (0,0,0), (0,0,0,1))
    p.resetBasePositionAndOrientation(hole_uid, (0,0,0), (0,0,0,1))
    scene_converter = SceneConverter()

    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        movable_obj_uids,
        [],
        robot_uid=None
    )

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)

    # Re-attach the objects to the grippers at the goal configuration.
    shakey.set_q_pb(robot_pb_uid, goal_q, movable_obj_uids, ee_to_obj)

    fixed_moving_idx_pair = [(np.array([11, 12, -1]), np.array([0, 5, 6, -2])),
                             (np.array([0]), np.array([5, 6, -2]))]

    # Random colors (global np.random, not controlled by `seed`).
    for obj in pybullet_scene.movable_objects:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(obj.pb_uid, -1, rgbaColor=color.tolist() + [1])

    for obj in pybullet_scene.fixed_objects:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(obj.pb_uid, -1, rgbaColor=color.tolist() + [1])

    # Color every non-negative robot link index referenced in fixed_moving_idx_pair.
    colored_link_idx = [np.concat(fimi) for fimi in fixed_moving_idx_pair]
    colored_link_idx = np.concat(colored_link_idx, axis=0)
    colored_link_idx = np.unique(colored_link_idx)
    colored_link_idx = colored_link_idx[colored_link_idx >= 0]
    for link_idx in colored_link_idx:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(robot_pb_uid, link_idx, rgbaColor=color.tolist() + [1])

    return (scene_converter.movable_oriCORNs, scene_converter.fixed_oriCORNs, pybullet_scene,
            init_q[...,-shakey.num_act_joints:], goal_q[...,-shakey.num_act_joints:], ee_to_obj, fixed_moving_idx_pair)




def create_bimanual_insertion_hard(models:mutil.Models, seed, ik_func, shakey:shakey_module.Shakey, robot_pb_uid, visualize=False):

    # transform pybullet debug camera
    cameraTargetPosition = [0.01, 0.05, 0.5]
    cameraDistance = 1.2
    cameraYaw = -261.2
    cameraPitch = -34.8
    p.resetDebugVisualizerCamera(cameraDistance=cameraDistance, cameraYaw=cameraYaw, 
                                 cameraPitch=cameraPitch, cameraTargetPosition=cameraTargetPosition)
    p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)

    view_matrix = p.computeViewMatrixFromYawPitchRoll(
        cameraTargetPosition, 
        cameraDistance, 
        cameraYaw, 
        cameraPitch, 
        roll=0,
        upAxisIndex=2
    )

    
    # jkey = jax.random.PRNGKey(seed)
    np_rng = np.random.default_rng(seed)
    obj_to_gripper_offset = None
    peg_scale = None

    env_idx = np_rng.integers(0, 4)

    if env_idx == 0:
        # kinfe hole
        # env_scale = 0.19
        env_scale = 0.25
        peg_scale = 0.22
        hold_offset = 0.05
        peg_offset = 0.0
        hole_obj_filename = 'assets/assembly/raw/hole_v5.obj'
        # peg_obj_filename = 'assets/assembly/raw/peg_v5.obj'
        peg_obj_filename = 'assets/assembly/cvx/coacd/peg_v5.obj'
        # peg_obj_filename = 'assets/assembly/raw/knife.obj'
        # peg_obj_filename = 'assets/assembly/cvx/coacd/knife.obj'
        env_center = np.array([0.3, 0.1, 0.75])
        scale_factor = env_scale/0.19
        hole_obj_to_gripper_position = scale_factor*np.array([-0.08, 0.0, 0.01])
        env_orientation = sciR.from_euler('z', np.pi/2).as_quat()
        obj_to_gripper_offset = np.stack([np.concat([scale_factor*np.array([-0.02, 0, hold_offset-peg_offset]), (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()]), 
                            np.concat([hole_obj_to_gripper_position, (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()])], axis=0)
        peg_offset *= scale_factor
    elif env_idx == 1:
        # dual peg
        # env_scale = 0.3
        env_scale = 0.45
        hold_offset = 0.08
        peg_offset = 0.00
        hole_obj_filename = 'assets/assembly/cvx/coacd/hole_var4.obj'
        # hole_obj_filename = 'assets/assembly/raw/hole_var4.obj'
        peg_obj_filename = 'assets/assembly/raw/peg_var4.obj'
        env_center = np.array([0.3, 0.1, 0.75])
        scale_factor = env_scale/0.3
        hole_obj_to_gripper_position = scale_factor*np.array([-0.11, 0.0, 0.0])
        env_orientation = (sciR.from_euler('z', np.pi/2)).as_quat()
        obj_to_gripper_offset = np.stack([np.concat([scale_factor*np.array([-0.03, 0, hold_offset-peg_offset]), (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()]), 
                            np.concat([hole_obj_to_gripper_position, (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()])], axis=0)
        peg_offset *= scale_factor
    elif env_idx == 2:
        # chair
        env_scale = 0.19
        # env_scale = 0.15
        peg_offset = 0.02
        hole_obj_filename = 'assets/assembly/raw/hole_var5.obj'
        peg_obj_filename = 'assets/assembly/cvx/coacd/peg_var5.obj'
        # peg_obj_filename = 'assets/assembly/raw/peg_var5.obj'
        env_center = np.array([0.46, 0.0, 0.75])
        scale_factor = env_scale/0.15
        env_orientation = (sciR.from_euler('z', -np.pi/6)*sciR.from_euler('z', np.pi/2)).as_quat()
        obj_to_gripper_offset = np.stack([np.concat([scale_factor*np.array([0.10, 0.10, 0.020]), (sciR.from_euler('z', np.pi/4)*sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()]), 
                            np.concat([scale_factor*np.array([-0.15, 0.0, 0.01]), (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()])], axis=0)
        peg_offset *= scale_factor
    elif env_idx == 3:
        # bookshelf
        # env_scale = 0.25
        env_scale = 0.35
        peg_offset = 0.00
        # hole_obj_filename = 'assets/assembly/cvx/coacd/hole_var6.obj'
        hole_obj_filename = 'assets/assembly/raw/hole_var6.obj'
        peg_obj_filename = 'assets/assembly/raw/peg_var6.obj'
        env_center = np.array([0.35, 0.0, 0.75])
        env_orientation = (sciR.from_euler('z', np.pi/2)).as_quat()
        scale_factor = env_scale/0.25
        obj_to_gripper_offset = np.stack([np.concat([scale_factor*np.array([0.04, 0.01, 0.10]), (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()]), 
                            np.concat([scale_factor*np.array([-0.03, -0.05, 0.04]), (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()])], axis=0)
        peg_offset *= scale_factor

    # elif env_idx == 4:
    #     env_scale = 0.25
    #     peg_offset = 0.00
    #     # hole_obj_filename = 'assets/assembly/raw/peg_var6.obj'
    #     hole_obj_filename = 'assets/assembly/cvx/coacd/hole_var7.obj'
    #     peg_obj_filename = 'assets/assembly/raw/peg_var7.obj'
    #     env_center = np.array([0.35, 0.0, 0.75])
    #     env_orientation = (sciR.from_euler('z', np.pi/2)).as_quat()
    #     obj_to_gripper_offset = np.stack([np.concat([np.array([0.07, 0.01, 0.020]), (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()]), 
    #                         np.concat([np.array([-0.07, -0.06, 0.04]), (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()])], axis=0)


    hole_center = env_center
    hole_orientation = env_orientation
    hole_col_filename = os.path.join(os.path.dirname(os.path.dirname(hole_obj_filename)), 'cvx/coacd', os.path.basename(hole_obj_filename))
    if not os.path.exists(hole_col_filename):
        hole_col_filename = hole_obj_filename
    hole_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=hole_obj_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    hole_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=hole_col_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    hole_uid = p.createMultiBody(
        baseMass=1.0,
        baseVisualShapeIndex=hole_visual_shape_id,
        baseCollisionShapeIndex=hole_collision_shape_id,
        basePosition=hole_center,
        baseOrientation=hole_orientation
    )


    peg_center = hole_center + np. array([0,0,peg_offset])
    peg_orientation = hole_orientation
    if peg_scale is None:
        peg_scale = env_scale

    # peg_obj_filename = 'assets/assembly/raw/peg_v3.obj'
    peg_col_filename = os.path.join(os.path.dirname(os.path.dirname(peg_obj_filename)), 'cvx/coacd', os.path.basename(peg_obj_filename))
    if not os.path.exists(peg_col_filename):
        peg_col_filename = peg_obj_filename
    peg_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=peg_obj_filename,
        meshScale=[peg_scale]*3
    )
    peg_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=peg_col_filename,
        meshScale=[peg_scale]*3
    )
    peg_uid = p.createMultiBody(
        baseMass=1.0,
        baseVisualShapeIndex=peg_visual_shape_id,
        baseCollisionShapeIndex=peg_collision_shape_id,
        basePosition=peg_center,
        baseOrientation=peg_orientation
    )

    movable_obj_uids = [peg_uid, hole_uid]

    # collision filter
    p.setCollisionFilterPair(robot_pb_uid, peg_uid, 6, -1, 0)
    # p.setCollisionFilterPair(robot_pb_uid, peg_uid, 5, -1, 0)
    p.setCollisionFilterPair(robot_pb_uid, hole_uid, 12, -1, 0)
    # p.setCollisionFilterPair(robot_pb_uid, hole_uid, 11, -1, 0)


    # obj_to_gripper_offset = np.stack([np.concat([np.array([-0.02, 0, 0.03]), (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()]), 
    #                   np.concat([np.array([-0.11, 0.0, 0.0]), (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()])], axis=0)

    # if obj_to_gripper_offset is None:
    #     obj_to_gripper_offset = np.stack([np.concat([np.array([-0.02, 0, hold_offset-peg_offset]), (sciR.from_euler('y', -np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()]), 
    #                     np.concat([hole_obj_to_gripper_position, (sciR.from_euler('y', np.pi/2)*sciR.from_euler('z', np.pi/2)).as_quat()])], axis=0)
    gripper_to_obj_pqc = tutil.pq_inv(obj_to_gripper_offset)

    ee_to_obj = tutil.pq_multi(jnp.concat(shakey.gripper_tip_offset_from_ee), gripper_to_obj_pqc)

    def sample_valid_q():
        lower_bound = np.array(copy.deepcopy(shakey.q_lower_bound))
        lower_bound = lower_bound*0.3
        lower_bound[[4,10]] = np.array([-1.065, -2.2996]) - np.pi/6
        upper_bound = np.array(copy.deepcopy(shakey.q_upper_bound))
        upper_bound = upper_bound*0.3
        upper_bound[[4,10]] = np.array([-1.065, -2.2996]) + np.pi/6
        q_sample = np_rng.uniform(lower_bound, upper_bound)
        shakey.set_q_pb(robot_pb_uid, q_sample, movable_obj_uids, ee_to_obj)
        p.performCollisionDetection()
        col_res = p.getContactPoints(robot_pb_uid)
        if len(col_res) != 0:
            return sample_valid_q()
        col_res = p.getContactPoints(peg_uid)
        # for cr in col_res:
        if hole_uid in [cr[2] for cr in col_res]:
            return sample_valid_q()
        return q_sample


    obj_pq_base = np.stack([np.concat([peg_center, peg_orientation]), 
                      np.concat([hole_center, hole_orientation])], axis=0)
    init_q = sample_valid_q()
    
    def sample_valid_obj_pq():

        valid = False
        while not valid:
            rand_pos = np_rng.uniform([-0.08, -0.08, -0.08], [0.08, 0.08, 0.08])
            rand_quat = np_rng.normal(0, 1, 3)*0.2
            rand_quat = tutil.qExp(rand_quat)
            # rand_quat /= np.linalg.norm(rand_quat)
            # rand_pos = np.array([0,0,0])
            # rand_quat = sciR.from_euler('y', np.pi/10).as_quat()

            obj_pq = tutil.pq_multi(tutil.pq_multi(obj_pq_base[0], jnp.concat([rand_pos, rand_quat])),
                            tutil.pq_multi(tutil.pq_inv(obj_pq_base[0]), obj_pq_base))

            pq_ee = tutil.pq_multi(obj_pq, obj_to_gripper_offset)
            goal_q = ik_func(np.zeros_like(init_q), pq_ee)

            shakey.set_q_pb(robot_pb_uid, goal_q, movable_obj_uids, ee_to_obj)
            ee_pqs_ik = shakey.get_ee_pq_pb(robot_pb_uid, True)
            valid = True
            if not (np.linalg.norm(pq_ee[...,:3] - ee_pqs_ik[...,:3]) < 1e-3 and np.linalg.norm(pq_ee[...,3:] - ee_pqs_ik[...,3:]) < 4e-3):
                valid = False

            p.performCollisionDetection()
            col_res = p.getContactPoints(robot_pb_uid)
            if len(col_res) != 0:
                valid = False

            # p.resetBasePositionAndOrientation(peg_uid, obj_pq[0,:3], obj_pq[0,3:])
            # p.resetBasePositionAndOrientation(hole_uid, obj_pq[1,:3], obj_pq[1,3:])
            
        return obj_pq, goal_q
        




    # init_q = sample_valid_q()

    # pq_ee = tutil.pq_multi(obj_pq, obj_to_gripper_offset)
    # goal_q = ik_func(np.zeros_like(init_q), pq_ee)

    obj_pq, goal_q = sample_valid_obj_pq()

    # shakey.set_q_pb(robot_pb_uid, goal_q, movable_obj_uids, ee_to_obj)

    # ee_pqs_ik = shakey.get_ee_pq_pb(robot_pb_uid, True)

    # assert np.linalg.norm(pq_ee[...,:3] - ee_pqs_ik[...,:3]) < 5e-3 and np.linalg.norm(pq_ee[...,3:] - ee_pqs_ik[...,3:]) < 2e-2

    # init_q =  goal_q - np.array([0,0,0,
    #                              -0.3, 0, -0.3, 0, 0, 0,
    #                              0, 0.2, 0.3, 0, 0, 0])

    # shakey.set_q_pb(robot_pb_uid, init_q, movable_obj_uids, ee_to_obj)

    p.performCollisionDetection()
    col_res = p.getContactPoints(robot_pb_uid)
    assert len(col_res) == 0

    p.resetBasePositionAndOrientation(peg_uid, (0,0,0), (0,0,0,1))
    p.resetBasePositionAndOrientation(hole_uid, (0,0,0), (0,0,0,1))
    scene_converter = SceneConverter()

    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        movable_obj_uids, 
        [],
        robot_uid=None
    )

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)
    
    shakey.set_q_pb(robot_pb_uid, goal_q, movable_obj_uids, ee_to_obj)

    # fixed_moving_idx_pair = [(np.array([11, 12, -1]), np.array([0, 5, 6, -2])),
    #                          (np.array([0]), np.array([5, 6, -2]))]
    
    fixed_moving_idx_pair = [(np.array([-1]), np.array([-2])),
                            # (np.array([11, 12, -1]), np.array([5, 6, -2])),
                            (np.array([-1]), np.array([5, 6])),
                            (np.array([-2]), np.array([11, 12])),
                             (np.array([0]), np.array([5, 6, 11, 12, -1, -2]))]
    
    # give colors to the robot and objects
    # Assign random colors to movable objects
    for obj in pybullet_scene.movable_objects:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(obj.pb_uid, -1, rgbaColor=color.tolist() + [1])

    # Assign random colors to fixed objects
    for obj in pybullet_scene.fixed_objects:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(obj.pb_uid, -1, rgbaColor=color.tolist() + [1])

    # Assign random colors to the robot's links
    colored_link_idx = [np.concat(fimi) for fimi in fixed_moving_idx_pair]
    colored_link_idx = np.concat(colored_link_idx, axis=0)
    colored_link_idx = np.unique(colored_link_idx)
    colored_link_idx = colored_link_idx[colored_link_idx >= 0]
    for i in colored_link_idx:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(robot_pb_uid, i, rgbaColor=color.tolist() + [1])
    
    # print pb ids
    print(f"peg_uid: {peg_uid}, hole_uid: {hole_uid}, robot_pb_uid: {robot_pb_uid}")

    # fixed_moving_idx_pair = None
    return (scene_converter.movable_oriCORNs, scene_converter.fixed_oriCORNs, pybullet_scene, 
            init_q[...,-shakey.num_act_joints:], goal_q[...,-shakey.num_act_joints:], ee_to_obj, fixed_moving_idx_pair, view_matrix)


def create_construction_site(models:mutil.Models, seed, ik_func, shakey:shakey_module.Shakey, robot_pb_uid, visualize=True):
    '''
    Build the construction-site scene in PyBullet and pick a start/goal configuration.

    The scene contains a floor, a set of fixed obstacle meshes (``wall_uids``), an
    auxiliary wall mesh, and one movable object that is carried by the robot
    (``shakey.set_q_pb`` is given the object uid together with ``ee_to_obj``).
    ``goal_q`` comes from IK so that the object lands on a fixed goal pose;
    ``init_q`` is rejection-sampled with a generator seeded by ``seed`` until the
    robot and the object are contact free and keep a clearance margin from the
    environment. Display colors use the global, unseeded ``np.random``.

    Args:
        models: model bundle forwarded to ``SceneConverter.convert_scene_to_oriCORNs``.
        seed: seed of the NumPy generator used to sample ``init_q``.
        ik_func: inverse-kinematics solver called as ``ik_func(q_init, ee_pq) -> q``.
        shakey: robot model (uses ``set_q_pb`` and ``gripper_tip_offset_from_ee``).
        robot_pb_uid: PyBullet body id of the robot.
        visualize: when no PyBullet connection exists yet, connect with ``p.GUI``
            if True, otherwise with ``p.DIRECT``.

    Returns:
        ``(movable_oriCORNs, fixed_oriCORNs, pybullet_scene, init_q, goal_q,
        ee_to_obj, aux_pbids, plane_params, fixed_moving_idx_pair)`` where
        ``aux_pbids`` is ``[floor_uid, wall2_uid]`` and ``fixed_moving_idx_pair``
        is always ``None``.

    Raises:
        RuntimeError: if no IK attempt yields a goal configuration in which the
            robot touches nothing except the carried object.
    '''

    if p.getConnectionInfo()['isConnected'] == 0:
        p.connect(p.GUI if visualize else p.DIRECT)

    # Debug-visualizer camera pose; hide the GUI side panels.
    p.resetDebugVisualizerCamera(cameraDistance=5.39, cameraYaw=-324.40, cameraPitch=-56.94, cameraTargetPosition=[-0.80, -1.05, -1.17])
    p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)

    env_center = np.zeros(3)

    # Floor: plane.obj for rendering, and a thick box whose top face is at z=0
    # for collision.
    pybullet_data_path = pybullet_data.getDataPath()
    floor_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=os.path.join(pybullet_data_path, 'plane.obj'),
        meshScale=[1, 1, 1]
    )
    floor_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_BOX,
        halfExtents=[10, 10, 0.1],
        collisionFramePosition=[0, 0, -0.1],
    )
    floor_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=floor_collision_shape_id,
        baseVisualShapeIndex=floor_visual_shape_id,
        basePosition=env_center,
        baseOrientation=[0, 0, 0, 1]
    )
    # Replace the floor texture with a flat random color.
    floor_color = np.random.uniform(0, 1, 3)
    p.changeVisualShape(floor_uid, -1, rgbaColor=floor_color.tolist() + [1])

    # Fixed obstacles; every mesh is placed at the scene origin with a common scale.
    wall_obj_filenames = [
        "assets/construction_site/raw/construction_site_obstacle_cart_v6.obj",
        "assets/construction_site/raw/construction_site_obstacle_pipe_v5.obj",
        "assets/construction_site/raw/construction_site_obstacle_beam_v5.obj",
        "assets/construction_site/raw/construction_site_shelf_v4.obj",
    ]
    env_scale = 2.5
    wall_uids = []
    for wall_obj_filename in wall_obj_filenames:
        # Collide against the CoACD convex decomposition (<asset dir>/cvx/coacd/<name>)
        # when it exists, otherwise against the raw mesh.
        cvx_obj_filename = os.path.join(os.path.dirname(os.path.dirname(wall_obj_filename)), 'cvx/coacd', os.path.basename(wall_obj_filename))
        if not os.path.exists(cvx_obj_filename):
            cvx_obj_filename = wall_obj_filename
        wall_visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=wall_obj_filename,
            meshScale=[env_scale, env_scale, env_scale]
        )
        wall_collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=cvx_obj_filename,
            meshScale=[env_scale, env_scale, env_scale]
        )
        wall_uid = p.createMultiBody(
            baseMass=0.0,
            baseCollisionShapeIndex=wall_collision_shape_id,
            baseVisualShapeIndex=wall_visual_shape_id,
            basePosition=env_center,
            baseOrientation=[0, 0, 0, 1]
        )
        wall_uids.append(wall_uid)

    # Auxiliary wall: rendered at env_scale, while its collision mesh is
    # stretched 3x along z.
    wall2_obj_filename = "assets/construction_site/raw/construction_site_moving_wall.obj"
    wall2_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=wall2_obj_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    wall2_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=wall2_obj_filename,
        meshScale=[env_scale, env_scale, 3*env_scale]
    )
    wall2_uid = p.createMultiBody(
        baseMass=0.0,
        baseVisualShapeIndex=wall2_visual_shape_id,
        baseCollisionShapeIndex=wall2_collision_shape_id,
        basePosition=env_center,
        baseOrientation=[0, 0, 0, 1]
    )

    np_rng = np.random.default_rng(seed)

    # Carried object and its goal pose.
    dish_obj_filename = 'assets/construction_site/raw/construction_site_moving_obj_v4.obj'
    dish_scale = env_scale*1.1
    dish_position = np.array([0.0, env_scale*1.69, env_scale*0.40]) # goal position
    dish_orientation = tutil.aa2q(np.array([0, 0, 0]))

    # Grasp: gripper-to-object offset pre-rotated by 120 deg about x, then composed
    # with the gripper-tip offset to obtain the end-effector-to-object transform.
    gripper_to_obj = np.array([0, 0.14, -0.01])*env_scale, np.array([0,0,0,1])
    gripper_to_obj = p.multiplyTransforms([0,0,0], sciR.from_euler('x', 2*np.pi/3).as_quat(), *gripper_to_obj)
    gripper_to_obj = (np.array(gripper_to_obj[0]), np.array(gripper_to_obj[1]))
    ee_to_obj = tutil.pq_multi(*shakey.gripper_tip_offset_from_ee, *gripper_to_obj)

    # End-effector pose that puts the object at its goal pose.
    ee_pq = tutil.pq_multi(
                dish_position,
                dish_orientation,
                *tutil.pq_inv(
                    ee_to_obj[0],
                    ee_to_obj[1],
                ),
            )
    ik_init_q = np.array([0, 0, -np.pi/2, 0, -np.pi/2, -np.pi/2, -np.pi/2, np.pi/2, 0.0])

    # Solve IK for the goal configuration. Accept it when the robot's contacts
    # (if any) are all with the carried object; otherwise remove the object and retry.
    max_goal_trials = 100
    for _ in range(max_goal_trials):
        dish_visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=dish_obj_filename,
            meshScale=[dish_scale] * 3
        )
        dish_collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=dish_obj_filename,
            meshScale=[dish_scale] * 3
        )
        dish_uid = p.createMultiBody(
            baseMass=0.1,
            baseCollisionShapeIndex=dish_collision_shape_id,
            baseVisualShapeIndex=dish_visual_shape_id,
            basePosition=dish_position,
            baseOrientation=dish_orientation
        )

        goal_q = ik_func(ik_init_q, ee_pq)
        shakey.set_q_pb(robot_pb_uid, goal_q)

        p.performCollisionDetection()
        contacts = p.getContactPoints(robot_pb_uid)
        if all(ct[2] == dish_uid for ct in contacts):
            break
        print('retry')
        p.removeBody(dish_uid)
    else:
        # Every attempt failed and dish_uid now names a removed body; fail loudly
        # instead of handing that stale uid to the code below.
        raise RuntimeError(f'No collision-free goal configuration found in {max_goal_trials} IK trials')

    aux_pbids = [floor_uid, wall2_uid]
    moving_obj_uids = [dish_uid]

    # Reset the object to the world origin before building the scene
    # (presumably so the converter records it in its own frame; TODO confirm).
    p.resetBasePositionAndOrientation(dish_uid, (0,0,0), (0,0,0,1))
    scene_converter = SceneConverter()
    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        moving_obj_uids,
        wall_uids,
        robot_uid=None
    )

    print(f'pybullet uids, moving_obj_uids: {moving_obj_uids}, fixed_uids: {wall_uids}, floor_uid: {floor_uid}, wall_uid: {wall2_uid}, robot_uid: {robot_pb_uid}')

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)

    ee_to_obj = jnp.concat(ee_to_obj)

    def valid_state_sample(np_rng, pos_bound_lower:np.ndarray, pos_bound_upper:np.ndarray):
        '''Rejection-sample a configuration uniformly within fixed bounds.

        A sample is accepted when neither the robot nor the carried object has
        contact points, and both keep at least 0.010 from the floor and (nominally)
        0.150 from the auxiliary wall and the fixed obstacles.
        '''
        # Leading entries come from pos_bound_* (presumably base x/y; confirm);
        # the rest are fixed per-joint limits.
        lower_bound = np.array([*pos_bound_lower.tolist(), -np.pi,
                                -np.pi/2, -2.1, -1.8973, -2.0718, 0.2, -0.5])
        upper_bound = np.array([*pos_bound_upper.tolist(), 0,
                                np.pi/2, -0.3, 1.8973, 0.0698, 2.5, 0.5])
        check_uids = aux_pbids + wall_uids

        def keeps_clearance():
            for env_uid in check_uids:
                threshold = 0.010 if env_uid == floor_uid else 0.150
                for body_uid in (robot_pb_uid, dish_uid):
                    # NOTE(review): getClosestPoints only reports pairs closer than
                    # distance=0.1, so the 0.150 threshold effectively acts as 0.1.
                    # Kept unchanged so seeded samples stay reproducible.
                    closest = p.getClosestPoints(body_uid, env_uid, distance=0.1)
                    # cr[8] is the contact distance of each reported pair.
                    if any(cr[8] < threshold for cr in closest):
                        return False
            return True

        while True:
            init_q = np_rng.uniform(lower_bound, upper_bound)
            shakey.set_q_pb(robot_pb_uid, init_q, moving_obj_uids, ee_to_obj)
            p.performCollisionDetection()
            if (keeps_clearance()
                    and len(p.getContactPoints(robot_pb_uid)) == 0
                    and len(p.getContactPoints(dish_uid)) == 0):
                return init_q

    init_q = valid_state_sample(np_rng, np.array([-0.8, -5.0]), np.array([0.8, -3.5]))

    shakey.set_q_pb(robot_pb_uid, init_q, moving_obj_uids, ee_to_obj)

    # Planar workspace bounds, one row per plane as [nx, ny, nz, d]
    # (the sign convention is defined by the consumer; confirm there).
    plane_params = np.array([[0,0,1,0.00],
                             [1,0,0,-0.4*env_scale],
                             [-1,0,0,-0.4*env_scale],
                             [0,1,0,-2.*env_scale],
                             [0,-1,0,-2.*env_scale],
                             ])

    # Random display colors for movable objects, fixed objects and auxiliary bodies.
    for obj in pybullet_scene.movable_objects:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(obj.pb_uid, -1, rgbaColor=color.tolist() + [1])

    for obj in pybullet_scene.fixed_objects:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(obj.pb_uid, -1, rgbaColor=color.tolist() + [1])

    for aux_pbid in aux_pbids:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(aux_pbid, -1, rgbaColor=color.tolist() + [1])

    fixed_moving_idx_pair = None

    return scene_converter.movable_oriCORNs, scene_converter.fixed_oriCORNs, pybullet_scene, init_q, goal_q, ee_to_obj, aux_pbids, plane_params, fixed_moving_idx_pair




def create_construction_site_hard(models:mutil.Models, seed, ik_func, shakey:shakey_module.Shakey, robot_pb_uid, visualize=True):
    '''
    Build the randomized "hard" construction-site scene and pick a start/goal configuration.

    The layout and ``init_q`` are both drawn from one generator seeded with ``seed``:
    - one mesh is chosen from each of three obstacle pools;
    - each obstacle (plus a pipe) gets a random position offset;
    - with probability 0.5 the obstacle layout is rotated 180 deg about z and
      shifted by -1.0 along y.

    A cave-wall mesh (an auxiliary body) is placed at ``(0, 0.09*env_scale, ceil_offset)``.
    One movable object (a pickaxe mesh) is carried by the robot. ``goal_q`` comes
    from IK so that the object lands on a fixed goal pose. ``init_q`` is
    rejection-sampled until the robot and the object are contact free and keep
    a clearance margin from the environment. Display colors use the global,
    unseeded ``np.random``.

    Args:
        models: model bundle forwarded to ``SceneConverter.convert_scene_to_oriCORNs``.
        seed: seed of the NumPy generator used for the layout and for ``init_q``.
        ik_func: inverse-kinematics solver called as ``ik_func(q_init, ee_pq) -> q``.
        shakey: robot model (uses ``set_q_pb`` and ``gripper_tip_offset_from_ee``).
        robot_pb_uid: PyBullet body id of the robot.
        visualize: when no PyBullet connection exists yet, connect with ``p.GUI``
            if True, otherwise with ``p.DIRECT``.

    Returns:
        ``(movable_oriCORNs, fixed_oriCORNs, pybullet_scene, init_q, goal_q,
        ee_to_obj, aux_pbids, plane_params, fixed_moving_idx_pair, view_matrix)``
        where:
        - ``aux_pbids`` is ``[floor_uid, wall2_uid]``;
        - ``fixed_moving_idx_pair`` is always ``None``;
        - ``view_matrix`` is computed from the debug-camera parameters set here.

    Raises:
        RuntimeError: if no IK attempt yields a goal configuration in which the
            robot touches nothing except the carried object.
    '''

    np_rng = np.random.default_rng(seed)

    if p.getConnectionInfo()['isConnected'] == 0:
        p.connect(p.GUI if visualize else p.DIRECT)

    # Debug-visualizer camera; the same parameters produce the returned view matrix.
    cameraDistance        = 5.79
    cameraYaw             = -348.40
    cameraPitch           = -32.94
    cameraTargetPosition  = [-0.80, -1.05, -1.17]
    p.resetDebugVisualizerCamera(cameraDistance=cameraDistance, cameraYaw=cameraYaw, cameraPitch=cameraPitch, cameraTargetPosition=cameraTargetPosition)

    view_matrix = p.computeViewMatrixFromYawPitchRoll(
        cameraTargetPosition,
        cameraDistance,
        cameraYaw,
        cameraPitch,
        roll=0,
        upAxisIndex=2
    )

    p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)

    env_center = np.zeros(3)
    # z offset of the cave wall; also raises the ceiling/inclined planes below.
    ceil_offset = 0.20

    # Floor: plane.obj for rendering, and a thick box whose top face is at z=0
    # for collision.
    pybullet_data_path = pybullet_data.getDataPath()
    floor_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=os.path.join(pybullet_data_path, 'plane.obj'),
        meshScale=[1, 1, 1]
    )
    floor_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_BOX,
        halfExtents=[10, 10, 0.1],
        collisionFramePosition=[0, 0, -0.1],
    )
    floor_uid = p.createMultiBody(
        baseMass=0.0,
        baseCollisionShapeIndex=floor_collision_shape_id,
        baseVisualShapeIndex=floor_visual_shape_id,
        basePosition=env_center,
        baseOrientation=[0, 0, 0, 1]
    )
    # Replace the floor texture with a flat random color.
    floor_color = np.random.uniform(0, 1, 3)
    p.changeVisualShape(floor_uid, -1, rgbaColor=floor_color.tolist() + [1])

    obstacle1_meshs = [
        "assets/construction_site/raw/construction_site_obstacle1_roadcone.obj",
        "assets/construction_site/raw/construction_site_obstacle1_beam.obj",
    ]
    obstacle2_meshs = [
        "assets/construction_site/raw/construction_site_obstacle2_drum.obj",
        "assets/construction_site/raw/construction_site_obstacle2_roadblock.obj",
        "assets/construction_site/raw/construction_site_obstacle_cart_v6.obj",
    ]
    obstacle3_meshs = [
        "assets/construction_site/raw/construction_site_obstacle_beam_v5.obj",
        "assets/construction_site/raw/construction_site_obstacle3_roadblock.obj",
    ]

    # Randomized layout. The order of the np_rng draws below determines the
    # seeded scene; do not reorder them.
    obstacle1_mesh = np_rng.choice(obstacle1_meshs)
    obstacle2_mesh = np_rng.choice(obstacle2_meshs)
    obstacle3_mesh = np_rng.choice(obstacle3_meshs)

    pipe_random_offset = np_rng.uniform([0,-0.1,-0.03], [0,0.1,0.10])
    obstacle1_offset = np_rng.uniform([-0.7, -0.05, 0], [0.7, 0.05, 0.])
    obstacle2_offset = np_rng.uniform([-0.1, -0.1, 0], [0.1, 0.1, 0.030])
    obstacle3_offset = np_rng.uniform([-0.05, -0.1, 0], [0.05, 0.1, 0.030])

    wall_obj_filenames = [
        "assets/construction_site/raw/construction_site_obstacle_pipe_v5.obj",
        obstacle1_mesh,
        obstacle2_mesh,
        obstacle3_mesh,
    ]

    obstacle_offsets = [
        pipe_random_offset,
        obstacle1_offset,
        obstacle2_offset,
        obstacle3_offset,
    ]

    # With probability 0.5, mirror the layout: rotate 180 deg about z and shift -1.0 in y.
    flip_mask = np_rng.uniform(0.0, 1.0) > 0.5
    wall_orientation = sciR.from_euler('z', np.pi).as_quat() if flip_mask else [0, 0, 0, 1]

    env_scale = 2.5
    wall_uids = []
    for wall_obj_filename, obstacle_offset in zip(wall_obj_filenames, obstacle_offsets):
        # Collide against the CoACD convex decomposition (<asset dir>/cvx/coacd/<name>)
        # when it exists, otherwise against the raw mesh.
        cvx_obj_filename = os.path.join(os.path.dirname(os.path.dirname(wall_obj_filename)), 'cvx/coacd', os.path.basename(wall_obj_filename))
        if not os.path.exists(cvx_obj_filename):
            cvx_obj_filename = wall_obj_filename
        wall_position = env_center + obstacle_offset
        if flip_mask:
            wall_position[1] -= 1.0
        wall_visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=wall_obj_filename,
            meshScale=[env_scale, env_scale, env_scale]
        )
        wall_collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=cvx_obj_filename,
            meshScale=[env_scale, env_scale, env_scale]
        )
        wall_uid = p.createMultiBody(
            baseMass=0.0,
            baseCollisionShapeIndex=wall_collision_shape_id,
            baseVisualShapeIndex=wall_visual_shape_id,
            basePosition=wall_position,
            baseOrientation=wall_orientation
        )
        wall_uids.append(wall_uid)

    # Cave wall (auxiliary body), raised by ceil_offset.
    wall2_obj_filename = "assets/construction_site/raw/construction_site_cave_wall.obj"
    wall2_visual_shape_id = p.createVisualShape(
        shapeType=p.GEOM_MESH,
        fileName=wall2_obj_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    wall2_collision_shape_id = p.createCollisionShape(
        shapeType=p.GEOM_MESH,
        fileName=wall2_obj_filename,
        meshScale=[env_scale, env_scale, env_scale]
    )
    wall2_uid = p.createMultiBody(
        baseMass=0.0,
        baseVisualShapeIndex=wall2_visual_shape_id,
        baseCollisionShapeIndex=wall2_collision_shape_id,
        basePosition=[0,0.09*env_scale,ceil_offset],
        baseOrientation=[0, 0, 0, 1]
    )

    # Carried object and its goal pose.
    dish_obj_filename = 'assets/construction_site/raw/construction_site_pickaxe_v3.obj'
    dish_scale = env_scale*1.5
    dish_position = np.array([0.0, env_scale*1.69, env_scale*0.40]) # goal position
    dish_orientation = tutil.aa2q(np.array([0, 0, 0]))

    # Grasp: gripper-to-object offset pre-rotated by 120 deg about x, then composed
    # with the gripper-tip offset to obtain the end-effector-to-object transform.
    gripper_to_obj = np.array([0, 0.03, -0.01])*env_scale, np.array([0,0,0,1])
    gripper_to_obj = p.multiplyTransforms([0,0,0], sciR.from_euler('x', 2*np.pi/3).as_quat(), *gripper_to_obj)
    gripper_to_obj = (np.array(gripper_to_obj[0]), np.array(gripper_to_obj[1]))
    ee_to_obj = tutil.pq_multi(*shakey.gripper_tip_offset_from_ee, *gripper_to_obj)

    # End-effector pose that puts the object at its goal pose.
    ee_pq = tutil.pq_multi(
                dish_position,
                dish_orientation,
                *tutil.pq_inv(
                    ee_to_obj[0],
                    ee_to_obj[1],
                ),
            )
    ik_init_q = np.array([0, 0, -np.pi/2, 0, -np.pi/2, -np.pi/2, -np.pi/2, np.pi/2, 0.0])

    # Solve IK for the goal configuration. Accept it when the robot's contacts
    # (if any) are all with the carried object; otherwise remove the object and retry.
    max_goal_trials = 100
    for _ in range(max_goal_trials):
        dish_visual_shape_id = p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=dish_obj_filename,
            meshScale=[dish_scale] * 3
        )
        dish_collision_shape_id = p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=dish_obj_filename,
            meshScale=[dish_scale] * 3
        )
        dish_uid = p.createMultiBody(
            baseMass=0.1,
            baseCollisionShapeIndex=dish_collision_shape_id,
            baseVisualShapeIndex=dish_visual_shape_id,
            basePosition=dish_position,
            baseOrientation=dish_orientation
        )

        goal_q = ik_func(ik_init_q, ee_pq)
        shakey.set_q_pb(robot_pb_uid, goal_q)

        p.performCollisionDetection()
        contacts = p.getContactPoints(robot_pb_uid)
        if all(ct[2] == dish_uid for ct in contacts):
            break
        print('retry')
        p.removeBody(dish_uid)
    else:
        # Every attempt failed and dish_uid now names a removed body; fail loudly
        # instead of handing that stale uid to the code below.
        raise RuntimeError(f'No collision-free goal configuration found in {max_goal_trials} IK trials')

    aux_pbids = [floor_uid, wall2_uid]
    moving_obj_uids = [dish_uid]

    # Reset the object to the world origin before building the scene
    # (presumably so the converter records it in its own frame; TODO confirm).
    p.resetBasePositionAndOrientation(dish_uid, (0,0,0), (0,0,0,1))
    scene_converter = SceneConverter()
    pybullet_scene = scene_converter.construct_scene_from_pybullet(
        moving_obj_uids,
        wall_uids,
        robot_uid=None
    )

    print(f'pybullet uids, moving_obj_uids: {moving_obj_uids}, fixed_uids: {wall_uids}, floor_uid: {floor_uid}, wall_uid: {wall2_uid}, robot_uid: {robot_pb_uid}')

    # Convert the scene to oriCORNs
    scene_converter.convert_scene_to_oriCORNs(models)

    ee_to_obj = jnp.concat(ee_to_obj)

    def valid_state_sample(np_rng, pos_bound_lower:np.ndarray, pos_bound_upper:np.ndarray):
        '''Rejection-sample a configuration uniformly within fixed bounds.

        A sample is accepted when neither the robot nor the carried object has
        contact points, and both keep at least 0.010 from the floor and (nominally)
        0.150 from the cave wall and the fixed obstacles.
        '''
        # Leading entries come from pos_bound_* (presumably base x/y; confirm);
        # the rest are fixed per-joint limits.
        lower_bound = np.array([*pos_bound_lower.tolist(), -np.pi,
                                -np.pi/2, -2.1, -1.8973, -2.0718, 0.2, -0.5])
        upper_bound = np.array([*pos_bound_upper.tolist(), 0,
                                np.pi/2, -0.3, 1.8973, 0.0698, 2.5, 0.5])
        check_uids = aux_pbids + wall_uids

        def keeps_clearance():
            for env_uid in check_uids:
                threshold = 0.010 if env_uid == floor_uid else 0.150
                for body_uid in (robot_pb_uid, dish_uid):
                    # NOTE(review): getClosestPoints only reports pairs closer than
                    # distance=0.1, so the 0.150 threshold effectively acts as 0.1.
                    # Kept unchanged so seeded samples stay reproducible.
                    closest = p.getClosestPoints(body_uid, env_uid, distance=0.1)
                    # cr[8] is the contact distance of each reported pair.
                    if any(cr[8] < threshold for cr in closest):
                        return False
            return True

        while True:
            init_q = np_rng.uniform(lower_bound, upper_bound)
            shakey.set_q_pb(robot_pb_uid, init_q, moving_obj_uids, ee_to_obj)
            p.performCollisionDetection()
            if (keeps_clearance()
                    and len(p.getContactPoints(robot_pb_uid)) == 0
                    and len(p.getContactPoints(dish_uid)) == 0):
                return init_q

    init_q = valid_state_sample(np_rng, np.array([-0.8, -5.5]), np.array([0.8, -4]))

    shakey.set_q_pb(robot_pb_uid, init_q, moving_obj_uids, ee_to_obj)

    # Offset of the two inclined planes: projection of a reference point (in
    # env_scale units) onto a 45-deg normal, minus a 0.020 wall thickness.
    # NOTE(review): the projection uses a y-z normal (0, 1/sqrt2, 1/sqrt2), while the
    # inclined planes below use x-z normals; confirm this is intended.
    inclined_pnt = np.array((0, 0.32, 0.75+ceil_offset/env_scale))
    inclined_dist = np.sum(inclined_pnt*np.array([0,1/np.sqrt(2),1/np.sqrt(2)]))
    inclined_dist -= 0.020 # thickness

    # Planar workspace bounds, one row per plane as [nx, ny, nz, d]:
    # floor, two side walls, ceiling and two inclined ceiling planes.
    # The sign convention is defined by the consumer; confirm there.
    plane_params = np.array([[0,0,1,0.0],
                             [1,0,0,-0.4*env_scale],
                             [-1,0,0,-0.4*env_scale],
                             [0,0,-1,-0.80*env_scale-ceil_offset],
                             [-1/np.sqrt(2),0,-1/np.sqrt(2),-inclined_dist*env_scale],
                             [1/np.sqrt(2),0,-1/np.sqrt(2),-inclined_dist*env_scale],
                             ])

    # Random display colors for movable objects, fixed objects and auxiliary bodies.
    for obj in pybullet_scene.movable_objects:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(obj.pb_uid, -1, rgbaColor=color.tolist() + [1])

    for obj in pybullet_scene.fixed_objects:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(obj.pb_uid, -1, rgbaColor=color.tolist() + [1])

    for aux_pbid in aux_pbids:
        color = np.random.uniform(0, 1, 3)
        p.changeVisualShape(aux_pbid, -1, rgbaColor=color.tolist() + [1])

    fixed_moving_idx_pair = None

    return scene_converter.movable_oriCORNs, scene_converter.fixed_oriCORNs, pybullet_scene, init_q, goal_q, ee_to_obj, aux_pbids, plane_params, fixed_moving_idx_pair, view_matrix






def visualize_shape_locality(skip_indices=()):
    """Show hand-picked mesh pairs in Open3D, one viewer window per pair.

    Each entry of ``mesh_paths`` is a pair of mesh files. For every mesh of a
    pair, the matching entries of ``scales``, ``translations``, ``rotations``
    and ``colors`` are applied in this order:

    1. scale about the origin,
    2. translate,
    3. rotate about the origin (``get_rotation_matrix_from_xyz``), which also
       rotates the translation from step 2,
    4. paint a uniform color.

    Each transformed pair is then drawn with
    ``o3d.visualization.draw_geometries``.

    Args:
        skip_indices: Indices of pairs in ``mesh_paths`` that should not be
            shown. The default shows every pair.
    """
    import open3d as o3d

    # NOTE: absolute, machine-specific asset paths.
    mesh_paths = [
        (
            "/home/rogga/research/efficient_planning/dataset/ur5/meshes/rg6_gripper/modified/ee_rg6_gripper.obj",
            "/home/rogga/research/efficient_planning/dataset/objaverse_v1/modified/98651ed9be254682afd785489f0f27f8.obj",
        ),
        (
            # "/home/rogga/research/efficient_planning/dataset/GoogleScannedObjects/modified/ASICS_GEL1140V_WhiteRoyalSilver.obj",
            "/home/rogga/research/efficient_planning/dataset/objaverse_v1/modified/4914f3e9e79c4117bdad44c2ab8f12ae.obj",
            "/home/rogga/research/efficient_planning/dataset/GoogleScannedObjects/modified/Calphalon_Kitchen_Essentials_12_Cast_Iron_Fry_Pan_Black.obj",
        )
    ]
    scales = (
        (
            1, 0.1,
        ),
        (
            0.1, 1,
        ),
    )
    translations = (
        (
            (-0.1, 0, -0.2),
            (0, 0, 0),
        ),
        (
            (-0.4,-0.29, 0.15),
            (0, 0, 0),
        ),
    )
    rotations = (
        (
            (0, 0, 0),
            (0, -0.3, 1.57),
        ),
        (
            (0, 0, 0),
            (225 - 1.6, 0, 0),
        ),
    )
    # One color per position within a pair (first mesh, second mesh).
    colors = (
        # (1, 0.82, 0.4),
        (0.937, 0.278, 0.435),
        (0.067, 0.541, 0.698),
    )

    skip = set(skip_indices)
    pairs = zip(mesh_paths, translations, rotations, scales)
    for i, (paths, pair_translations, pair_rotations, pair_scales) in enumerate(pairs):
        if i in skip:
            continue

        geometries = []
        for path, translation, rotation, scale, color in zip(
            paths, pair_translations, pair_rotations, pair_scales, colors
        ):
            mesh = o3d.io.read_triangle_mesh(path)
            mesh = mesh.scale(scale, center=(0, 0, 0))
            mesh.translate(translation)
            # Rotation is about the world origin (not the mesh center), so the
            # translation applied above is rotated as well.
            mesh.rotate(mesh.get_rotation_matrix_from_xyz(rotation), center=(0, 0, 0))
            mesh.compute_vertex_normals()
            mesh.paint_uniform_color(np.array(color))
            geometries.append(mesh)

        # One blocking viewer window per pair.
        o3d.visualization.draw_geometries(geometries)
    
def _rotation_matrix_from_vectors(vec1, vec2):
    """Return the 3x3 rotation matrix that rotates direction ``vec1`` onto ``vec2``.

    Uses Rodrigues' formula built from the cross and dot products of the
    normalized inputs. Nearly parallel inputs return the identity. Nearly
    anti-parallel inputs return a 180-degree rotation about an axis
    perpendicular to ``vec1``.

    Args:
        vec1: Source 3D vector (need not be unit length).
        vec2: Destination 3D vector (need not be unit length).

    Raises:
        ValueError: If either vector has zero length (its direction is undefined).
    """
    norm1 = np.linalg.norm(vec1)
    norm2 = np.linalg.norm(vec2)
    if norm1 == 0 or norm2 == 0:
        raise ValueError("cannot align zero-length vectors")
    a = vec1 / norm1
    b = vec2 / norm2
    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)
    if s < 1e-8:
        if c >= 0:
            # Same direction: no rotation needed.
            return np.eye(3)
        # Opposite direction: 180-degree rotation about any axis perpendicular
        # to vec1. Avoid picking an axis nearly parallel to a.
        axis = np.array([1, 0, 0])
        if np.abs(a[0]) > 0.9:
            axis = np.array([0, 1, 0])
        v = np.cross(a, axis)
        v = v / np.linalg.norm(v)
        K = np.array([[0, -v[2], v[1]],
                      [v[2], 0, -v[0]],
                      [-v[1], v[0], 0]])
        # Rodrigues with theta = pi: R = I + 2 K^2.
        return np.eye(3) + 2 * K @ K
    # Skew-symmetric cross-product matrix of v.
    K = np.array([[0, -v[2], v[1]],
                  [v[2], 0, -v[0]],
                  [-v[1], v[0], 0]])
    return np.eye(3) + K + K @ K * ((1 - c) / (s ** 2))


def _trajectory_tangents(traj_points, fallback, eps=1e-6):
    """Return a unit tangent direction for every point of a polyline.

    Interior points use the central difference ``p[i+1] - p[i-1]``. The first
    point uses ``p[1] - p[0]`` and the last uses ``p[-1] - p[-2]``. A tangent
    shorter than ``eps`` is replaced by ``fallback`` instead of being
    normalized, because a zero vector has no direction. This happens with
    duplicate consecutive points or a single-point trajectory.

    Args:
        traj_points: Ordered sequence of 3D points.
        fallback: Direction used for degenerate tangents.
        eps: Minimum tangent length considered non-degenerate.

    Returns:
        A list with one float array per input point (empty for empty input).
    """
    pts = np.asarray(traj_points, dtype=float)
    n = len(pts)
    fallback = np.asarray(fallback, dtype=float)
    tangents = []
    for i in range(n):
        # Clamping the neighbor indices yields the one-sided differences at the ends.
        tangent = pts[min(i + 1, n - 1)] - pts[max(i - 1, 0)]
        tangent_norm = np.linalg.norm(tangent)
        tangents.append(tangent / tangent_norm if tangent_norm >= eps else fallback.copy())
    return tangents


def visualize_time_locality():
    """Show a mesh swept along a sample trajectory in an Open3D viewer.

    The sample trajectory has 7 points in the z = 0 plane: x is evenly spaced
    in [-1.5, 1.5] and y = 1 - cos(x), clipped at 0.3. A mesh is loaded,
    scaled by 0.1, colored and rotated, then copies of it are placed at each
    trajectory point, oriented along the trajectory. The trajectory polyline
    and a fixed box are drawn alongside the copies.
    """
    import open3d as o3d

    def visualize_trajectory_and_place_mesh(trajectory_points, mesh):
        """Draw a trajectory polyline with oriented copies of ``mesh`` at each point.

        Each copy is rotated about its own center so that the mesh's +x axis
        follows the trajectory tangent. Degenerate tangents keep the mesh
        unrotated. The copy is then translated so that its center sits on the
        trajectory point. A fixed blue box is drawn as well.

        Args:
            trajectory_points: Ordered collection of 3D points ([x, y, z]).
            mesh: Open3D TriangleMesh to copy. The original is not modified.
        """
        traj_points = np.array(trajectory_points)

        # Red polyline connecting consecutive trajectory points.
        lines = [[i, i + 1] for i in range(len(traj_points) - 1)]
        line_set = o3d.geometry.LineSet(
            points=o3d.utility.Vector3dVector(traj_points),
            lines=o3d.utility.Vector2iVector(lines)
        )
        red_color = [1.0, 0.0, 0.0]
        line_set.colors = o3d.utility.Vector3dVector([red_color for _ in lines])

        # The mesh's forward direction is taken to be +x.
        ref_vector = np.array([1.0, 0.0, 0.0])
        tangents = _trajectory_tangents(traj_points, fallback=ref_vector)

        mesh_copies = []
        for pt, tangent in zip(traj_points, tangents):
            R = _rotation_matrix_from_vectors(ref_vector, tangent)
            mesh_copy = copy.deepcopy(mesh)
            mesh_center = mesh_copy.get_center()
            mesh_copy.rotate(R, center=mesh_center)
            # Move the (unchanged) center onto the trajectory point.
            mesh_copy.translate(pt - mesh_center, relative=True)
            mesh_copies.append(mesh_copy)

        fixed_mesh = o3d.geometry.TriangleMesh.create_box(width=2, height=0.2, depth=0.2)
        fixed_mesh.compute_vertex_normals()
        fixed_mesh.paint_uniform_color([0.067, 0.541, 0.698])
        fixed_mesh.translate((-1, -0.3, 0))

        geometries = mesh_copies + [line_set] + [fixed_mesh]
        o3d.visualization.draw_geometries(geometries)

    # Sample trajectory: y = 1 - cos(x), clipped at `threshold`, in the z = 0 plane.
    x_values = np.linspace(-1.5, 1.5, 7)
    y_values = -np.cos(x_values) + 1
    threshold = 0.3
    y_values[y_values > threshold] = threshold
    z_values = np.zeros_like(x_values)
    trajectory_points = np.array([x_values, y_values, z_values]).T
    print(trajectory_points)

    # NOTE: absolute, machine-specific asset path.
    mesh = o3d.io.read_triangle_mesh("/home/rogga/research/efficient_planning/dataset/objaverse_v1/modified/4914f3e9e79c4117bdad44c2ab8f12ae.obj")
    mesh = mesh.scale(0.1, center=(0, 0, 0))
    mesh.paint_uniform_color([0.937, 0.278, 0.435])
    mesh.rotate(mesh.get_rotation_matrix_from_xyz((1.57, 0, 0)), center=(0, 0, 0))
    mesh.compute_vertex_normals()

    visualize_trajectory_and_place_mesh(trajectory_points, mesh)


# if __name__ == '__main__':
    # create_table_sampled_scene(visualize=True)
    # _, _, pybullet_scene = create_table_sampled_scene(asset_base_dir='/home/dongwon/research/object_set', visualize=False)
    # pybullet_scene.reconstruct_scene_in_pybullet(visualize=True)
    # # _, _, pybullet_scene = create_table_sampled_scene(asset_base_dir='/home/dongwon/research/object_set', visualize=False)
    # # pybullet_scene.reconstruct_scene_in_pybullet(visualize=True)
    # asset_base_dir = '/home/rogga/research/efficient_planning/dataset'
    # _, _, pybullet_scene = create_pick_place_scene(asset_base_dir=asset_base_dir, visualize=True)
    # pybullet_scene.reconstruct_scene_in_pybullet(visualize=True)