from typing import Optional, Tuple
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
import optax
from typing import Tuple
from functools import partial
from scipy.spatial.transform import Rotation as sciR
import pybullet as p
import os
from typing import List
import jax.debug as jdb

# When this file is executed directly as a script (rather than imported),
# append the parent of this file's directory to sys.path so that the absolute
# `util.*` imports that follow can be resolved (presumably that directory is
# the one containing the `util` package -- confirm against the repo layout).
# This must run before those imports, so keep it above them.
if __name__ == '__main__':
    from pathlib import Path
    BASE_PATH = Path(__file__).parent
    import sys
    # Guard against adding a duplicate sys.path entry.
    if str(BASE_PATH.parent) not in sys.path:
        sys.path.append(str(BASE_PATH.parent))

import util.latent_obj_util as loutil
import util.model_util as mutil
import util.bp_matching_util as bputil
import util.transform_util as tutil

def print_callback(inputs):
    """Write ``inputs`` to standard output.

    Host-side debugging hook, e.g. something that can be handed to
    ``jax.debug.callback`` to inspect values from inside traced code.
    """
    value_to_show = inputs
    print(value_to_show)

def line_search(x, cost_func, grad, intp_no=10):
    """Pick the best point along a search direction by brute-force sampling.

    Candidates ``x + t * grad`` are built for ``intp_no`` step sizes ``t``
    evenly spaced in ``[1e-5, 1]``; the candidate with the lowest cost is
    returned. ``grad`` is used as-is as the step direction (callers pass an
    optimizer update that already points downhill), so no sign flip is
    applied here.

    Args:
        x: Current point, shape ``(..., D)``.
        cost_func: Callable taking an array shaped like ``x`` and returning a
            tuple/sequence whose first element is the cost. The cost may be a
            scalar (one shared step is chosen for every batch element) or have
            the batch shape ``x.shape[:-1]`` (a step is chosen per element).
        grad: Search direction, same shape as ``x``.
        intp_no: Number of candidate step sizes.

    Returns:
        The selected candidate, same shape as ``x``.
    """
    step_sizes = jnp.linspace(1e-5, 1, intp_no)[..., None]        # (intp_no, 1)
    x_intp = x[..., None, :] + step_sizes * grad[..., None, :]    # (..., intp_no, D)

    # Map over the candidate axis (-2). For unbatched x this is axis 0 as
    # before; for batched x mapping axis 0 would iterate over the batch
    # instead of over the candidates.
    cost_intp = jax.vmap(cost_func, in_axes=-2)(x_intp)[0]        # (intp_no, *cost_shape)
    min_idx = jnp.argmin(cost_intp, axis=0)

    # A scalar cost yields one index shared by the whole batch; broadcast it so
    # take_along_axis always sees indices of matching rank.
    min_idx = jnp.broadcast_to(min_idx, x_intp.shape[:-2])
    return jnp.take_along_axis(x_intp, min_idx[..., None, None], axis=-2).squeeze(-2)


def joint_limit_cost_func(q_t, q_l, q_u, eta):
    """Element-wise soft joint-limit penalty.

    Zero well inside ``[q_l, q_u]``; quadratic inside a band of width ``eta``
    next to each limit; linear beyond the limit. Value (``0.5 * eta`` at the
    limit) and slope (magnitude 1) match where the pieces meet, so the
    penalty is smooth enough for gradient-based IK.

    Args:
        q_t: Joint values.
        q_l: Lower limits, broadcastable against ``q_t``.
        q_u: Upper limits, broadcastable against ``q_t``.
        eta: Width of the quadratic transition band (must be non-zero).

    Returns:
        Array of per-element penalties.
    """
    below_lower = q_t < q_l
    near_lower = (q_l <= q_t) & (q_t < q_l + eta)
    above_upper = q_t > q_u
    near_upper = (q_u - eta < q_t) & (q_t <= q_u)

    # jnp.select takes the choice of the FIRST true condition, which matches
    # the priority of a nested jnp.where chain (lower-side cases win).
    return jnp.select(
        [below_lower, near_lower, above_upper, near_upper],
        [
            q_l - q_t + 0.5 * eta,
            0.5 / eta * (q_l - q_t + eta) ** 2,
            q_t - q_u + 0.5 * eta,
            0.5 / eta * (q_t - q_u + eta) ** 2,
        ],
        0.0,
    )

@struct.dataclass
class Shakey:
    """Differentiable kinematic model of a URDF robot plus PyBullet/Open3D helpers.

    Stores per-link kinematic data (link/joint offsets, joint axes, parent
    indices), the canonical oriCORN latent shapes of the link meshes and the
    joint limits. Poses are 7-vectors ``[x, y, z, qx, qy, qz, qw]``
    ("pq"/"pqc"); the identity pose is ``[0, 0, 0, 0, 0, 0, 1]``.
    """
    link_canonical_oriCORN: loutil.LatentObjects = None  # canonical latent shape of every link mesh
    canonical_obj_idx: np.array = None  # per-mesh asset index (see models.asset_path_util)
    act_joint_mask: np.array = None  # True for non-fixed joints; one entry per link
    joint_to_link_pq: jnp.array = None  # per-link pose of the link frame w.r.t. its joint frame
    link_to_joint_pq: jnp.array = None  # per-link pose of the joint frame w.r.t. the parent link frame
    mesh_scale: jnp.array = None  # per-mesh scale applied to the ground-truth meshes
    joint_axis: jnp.array = None  # per-joint rotation axis in the joint frame
    parent_link_idx: np.array = None  # parent link index per link; -1 for the root
    link_to_mesh_pq: jnp.array = None  # per-link visual-mesh pose w.r.t. the link frame
    rot_configs: dict = None  # forwarded to LatentObjects.apply_pq_z
    urdf_dir: str = None  # path of the robot URDF
    models: mutil.Models = None  # learned models (collision checkers, occupancy decoder, assets)
    arm_joint_idx: np.array = None  # positions of the actuated joints in the full joint vector
    robot_type: str = None  # e.g. 'ur5', 'im2', 'RobotBimanualV4'
    link_idx_to_pb_idx: np.array = None  # link index -> PyBullet link index (base -> -1)
    collision_check_link_idx: np.array = None  # not used inside this class
    ee_idx: np.array = None  # end-effector link index (scalar, or one per arm)
    robot_height: float = 0.0  # base z used when the base pose comes from [x, y, yaw]
    gripper_tip_offset_from_ee: Tuple[np.array] = None  # (pos, quat) of the gripper center in the EE frame
    q_lower_bound: np.array = None  # actuated-joint lower limits
    q_upper_bound: np.array = None  # actuated-joint upper limits
    robot_fixed_pqc: np.array = None  # base pose used when q carries no base DOFs

    def FK(self, q, oriCORN_out=True, link_pq_out=False):
        """Forward kinematics.

        Args:
            q: Configuration; leading dims are batch dims. The last dim is:
                * ``num_act_joints + 3``: planar base ``[x, y, yaw]`` (base
                  height ``robot_height``) followed by the actuated joints;
                * ``num_act_joints``: actuated joints only; the base pose is
                  ``robot_fixed_pqc`` if set, otherwise the identity;
                * any other length (only allowed when ``robot_fixed_pqc`` is
                  set): used directly as the full per-link joint vector.
            oriCORN_out: Also return the link oriCORNs moved to the mesh poses.
            link_pq_out: Return ``(mesh_pqcs, link_pqs)`` instead; takes
                precedence over ``oriCORN_out``.

        Returns:
            ``(oriCORNs_tf, mesh_pqcs)`` if ``oriCORN_out`` else ``mesh_pqcs``,
            where ``mesh_pqcs`` has shape ``(..., numjoints, 7)``.
        """
        numjoints = self.numjoints
        base_pose = jnp.array([0., 0., 0., 0., 0., 0., 1.])  # identity

        if q.shape[-1] == self.num_act_joints + 3:
            # Leading [x, y, yaw] describe a planar mobile base.
            base_se2 = q[..., :3]
            base_pose = jnp.concatenate([
                base_se2[..., :2],
                self.robot_height * jnp.ones_like(base_se2[..., :1]),
                tutil.aa2q(base_se2[..., 2:3] * jnp.array([0, 0, 1.])),
            ], axis=-1)
            q = q[..., 3:]
        elif self.robot_fixed_pqc is not None:
            base_pose = self.robot_fixed_pqc
        else:
            assert q.shape[-1] == self.num_act_joints

        # Scatter the actuated joint values into the full per-link joint vector.
        if q.shape[-1] == self.num_act_joints:
            entire_q = jnp.zeros(q.shape[:-1] + (numjoints,))
            entire_q = entire_q.at[..., self.arm_joint_idx].set(q)
        else:
            entire_q = q

        # Slot `numjoints` (the extra last slot) holds the base pose, so a parent
        # index of -1 (the root link) reads the base pose via negative indexing.
        leading_shape = q.shape[:-1]
        link_pq_array = jnp.zeros(leading_shape + (numjoints + 1, 7))
        link_pq_array = link_pq_array.at[..., numjoints, :].set(base_pose)

        # jnp arrays so they can be indexed with the traced scan index.
        parent_link_idx = jnp.asarray(self.parent_link_idx)
        act_joint_mask = jnp.asarray(self.act_joint_mask)
        joint_axis = jnp.asarray(self.joint_axis)
        link_to_joint_pq = jnp.asarray(self.link_to_joint_pq)
        joint_to_link_pq = jnp.asarray(self.joint_to_link_pq)

        def body_fun(link_pqs, i):
            # Links are processed in index order; a parent's pose is already
            # written when its child is visited (parents precede children).
            parent_link_pq = link_pqs[..., parent_link_idx[i], :]
            joint_pq = tutil.pq_multi(parent_link_pq, link_to_joint_pq[i])

            def rotate_about_axis(_):
                # Every actuated joint is modelled as a pure rotation of
                # q_i about its axis (zero translation).
                joint_motion = jnp.concatenate([
                    jnp.zeros_like(entire_q[..., :3]),
                    tutil.aa2q(joint_axis[i] * entire_q[..., i, None]),
                ], axis=-1)
                return tutil.pq_multi(joint_pq, joint_motion)

            joint_pq_next = jax.lax.cond(act_joint_mask[i], rotate_about_axis, lambda _: joint_pq, None)
            link_pq = tutil.pq_multi(joint_pq_next, joint_to_link_pq[i])
            return link_pqs.at[..., i, :].set(link_pq), None

        link_pq_array, _ = jax.lax.scan(body_fun, link_pq_array, jnp.arange(numjoints))

        link_pq_final = link_pq_array[..., :numjoints, :]  # drop the base slot
        mesh_pqcs = tutil.pq_multi(link_pq_final, self.link_to_mesh_pq)

        if link_pq_out:
            return mesh_pqcs, link_pq_final
        if oriCORN_out:
            oriCORNs_tf = self.link_canonical_oriCORN.apply_pq_z(mesh_pqcs, self.rot_configs)
            return oriCORNs_tf, mesh_pqcs
        return mesh_pqcs

    def IK_cost_func(self, q, goal_pq, self_collision=False, damped=False, fix_base=False, scene_oriCORNs=None,
                     grasped_oriCORN_wrt_ee: Optional[loutil.LatentObjects] = None,
                     place_oriCORN: Optional[loutil.LatentObjects] = None):
        """IK objective evaluated at configuration ``q``.

        Modes:
          * ``goal_pq`` given (7-vector or ``(pos, quat)`` tuple, expressed
            for the end effector): ``pos_loss`` is the squared position error
            and ``quat_loss`` the squared Frobenius norm of ``R_rel - I``.
          * ``goal_pq`` is None (placement): the grasped object
            ``grasped_oriCORN_wrt_ee`` (in the end-effector frame) is moved with
            the end effector and compared to ``place_oriCORN`` by a symmetric
            nearest-neighbour distance over fps points (weighted 1000) and
            latent features; ``quat_loss`` is 0.

        Extra terms: soft joint-limit penalty, optional learned self-collision
        penalty, optional scene-collision hinge, optional ``1e-7 * ||q||^2``
        damping. With ``fix_base`` and base DOFs present in ``q``, gradients
        through the 3 base DOFs are blocked.

        Returns:
            ``(total_loss, (pos_loss, quat_loss, scene_col_cost))``.
        """
        if fix_base and q.shape[-1] != self.num_act_joints:
            q = q.at[..., :3].set(jax.lax.stop_gradient(q[..., :3]))

        if goal_pq is None:
            assert grasped_oriCORN_wrt_ee is not None
            assert place_oriCORN is not None
        elif isinstance(goal_pq, tuple):
            goal_pq = jnp.concatenate(goal_pq, axis=-1)

        link_pqcs = self.FK(q, oriCORN_out=False)
        ee_pq = link_pqcs[..., self.ee_idx, :]

        if goal_pq is not None:
            pq_dif = tutil.pq_multi(tutil.pq_inv(ee_pq), goal_pq)
            pos_loss = jnp.sum(pq_dif[..., :3] ** 2)
            quat_loss = jnp.sum((tutil.q2R(pq_dif[..., 3:]) - jnp.eye(3)) ** 2)
        else:
            grasped_oriCORNs_world = grasped_oriCORN_wrt_ee.apply_pq_z(ee_pq, self.rot_configs)
            # Pairwise (grasped point, place point) distances.
            pw_fps_dif = jnp.sum(
                (grasped_oriCORNs_world.fps_tf[..., None, :] - place_oriCORN.fps_tf[..., None, :, :]) ** 2,
                axis=-1)
            pw_z_dif = jnp.sum(
                (grasped_oriCORNs_world.z[..., None, :, :] - place_oriCORN.z[..., None, :, :, :]) ** 2,
                axis=(-1, -2))
            pw_loss = 1000 * pw_fps_dif + pw_z_dif
            # Symmetric nearest-neighbour (chamfer-style) matching.
            pos_loss = jnp.sum(jnp.min(pw_loss, axis=-1) + jnp.min(pw_loss, axis=-2))
            quat_loss = 0.0

        self_col_loss = 0.0
        if self_collision:
            # NOTE(review): q[..., 3:] assumes q carries the 3 base DOFs.
            self_col_cost = self.models.apply('self_collision_checker', q[..., 3:]).squeeze(-1)
            self_col_loss = jnp.maximum(self_col_cost, 0).sum()

        scene_col_cost = jnp.array(0.0)
        if scene_oriCORNs is not None:
            # Scene objects are already in the world frame: identity poses.
            scene_pqc = jnp.zeros_like(scene_oriCORNs.pos)
            scene_pqc = jnp.concatenate([scene_pqc, tutil.aa2q(scene_pqc)], axis=-1)
            scene_col_cost = mutil.pairwise_collision_check(
                self.models, self.link_canonical_oriCORN, scene_oriCORNs, link_pqcs, scene_pqc)
            # Hinge: only scores above -0.5 contribute.
            scene_col_cost = jnp.maximum(scene_col_cost + 0.5, 0).sum()

        joint_limit_loss = jnp.sum(joint_limit_cost_func(
            q[..., -self.num_act_joints:], self.q_lower_bound, self.q_upper_bound, 0.1))

        total_loss = pos_loss + 0.02 * quat_loss + joint_limit_loss + self_col_loss + 0.5 * scene_col_cost
        if damped:
            total_loss += 1e-7 * jnp.sum(q ** 2)

        return total_loss, (pos_loss, quat_loss, scene_col_cost)

    def IK(self, cur_q, goal_pq, self_collision=False,
           itr_no=25, damped=True, output_cost=False, grasp_center_coordinate=False, fix_base=False,
           linesearch=True, scene_oriCORNs=None,
           grasped_oriCORN_wrt_ee: Optional[loutil.LatentObjects] = None,
           place_oriCORN: Optional[loutil.LatentObjects] = None):
        """Solve IK by iteratively minimising ``IK_cost_func``.

        Args:
            cur_q: Initial configuration; ``[0, 0, 0, -pi/2, 0, pi/2, 0]`` when None.
            goal_pq: Target end-effector pose (7-vector or ``(pos, quat)``
                tuple), or None for placement mode (see ``IK_cost_func``).
            self_collision, damped, fix_base, scene_oriCORNs,
            grasped_oriCORN_wrt_ee, place_oriCORN: forwarded to ``IK_cost_func``.
            itr_no: Maximum number of iterations.
            output_cost: Also return ``(sqrt(pos_loss), sqrt(quat_loss))``.
            grasp_center_coordinate: ``goal_pq`` is a gripper-center pose;
                convert it to an end-effector pose first.
            linesearch: L-BFGS direction + sampled ``line_search`` if True,
                otherwise Adam with learning rate 3e-2.

        The loop stops early once sqrt(pos_loss) <= 1e-5, sqrt(quat_loss) <= 1e-3
        and scene_col_cost <= 0.5.

        Returns:
            ``q``, or ``(q, (pos_err, rot_err))`` if ``output_cost``.
        """
        if cur_q is None:
            cur_q = jnp.array([0, 0, 0, -np.pi * 0.5, 0, np.pi * 0.5, 0])

        if goal_pq is not None:
            if isinstance(goal_pq, tuple):
                goal_pq = jnp.concatenate(goal_pq, axis=-1)
            if grasp_center_coordinate:
                goal_pq = self.get_ee_from_gripper_center(goal_pq[..., :3], goal_pq[..., 3:])

        ik_cost_func = partial(self.IK_cost_func, goal_pq=goal_pq, self_collision=self_collision,
                               damped=damped, fix_base=fix_base, scene_oriCORNs=scene_oriCORNs,
                               grasped_oriCORN_wrt_ee=grasped_oriCORN_wrt_ee,
                               place_oriCORN=place_oriCORN)
        # has_aux=True -> returns (grad, (pos_loss, quat_loss, scene_col_cost)).
        grad_func = jax.grad(ik_cost_func, has_aux=True)

        if linesearch:
            # L-BFGS only supplies the direction; line_search picks the step.
            optimizer = optax.lbfgs(linesearch=None)
        else:
            optimizer = optax.adam(3e-2)
        opt_state = optimizer.init(cur_q)

        def body_func(carry):
            q, _, opt_state, itr = carry
            # NOTE: the cost is evaluated at q *before* this update, so the
            # convergence test and the reported cost lag one iterate behind.
            grad, cost = grad_func(q)
            updates, opt_state = optimizer.update(grad, opt_state, q)
            if linesearch:
                q = line_search(q, ik_cost_func, updates, intp_no=10)
            else:
                q = optax.apply_updates(q, updates)
            return q, cost, opt_state, itr + 1

        def cond_func(carry):
            _, cost, _, itr = carry
            not_converged = jnp.any(jnp.array([
                jnp.sqrt(cost[0]) > 1e-5,
                jnp.sqrt(cost[1]) > 1e-3,
                cost[2] > 0.5,
            ]))
            return jnp.logical_and(itr < itr_no, not_converged)

        # Large sentinel costs so the convergence test fails on entry.
        init_cost = (1e5, 1e5, 1e5)
        q, cost, _, _ = jax.lax.while_loop(cond_func, body_func, (cur_q, init_cost, opt_state, 0))

        if output_cost:
            return q, (jnp.sqrt(cost[0]), jnp.sqrt(cost[1]))
        return q

    def _o3d_reference_frames(self, ee_pqc):
        """Return Open3D axis frames (end effector, gripper center, world origin) for one EE pose."""
        import open3d as o3d

        def axis_frame(Hmat=None):
            frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=np.zeros(3))
            return frame if Hmat is None else frame.transform(Hmat)

        Hmat_ee = tutil.pq2H(ee_pqc)
        Hmat_gripper = tutil.pq2H(*self.get_gripper_center_from_ee_pq(ee_pqc[..., :3], ee_pqc[..., 3:]))
        return axis_frame(Hmat_ee), axis_frame(Hmat_gripper), axis_frame()

    def show_in_o3d(self, q, visualize=True):
        """Open3D geometries of the robot at ``q`` reconstructed from its oriCORNs.

        Adds end-effector, gripper-center and origin frames (one set per leading
        element of the EE poses) and a mesh decoded by ``models.occ_prediction``.

        Returns:
            List of Open3D geometries.
        """
        from util.reconstruction_util import create_scene_mesh_from_oriCORNs

        oriCORNs_tf, mesh_pqcs = self.FK(q, oriCORN_out=True)

        ee_pqc = mesh_pqcs[..., self.ee_idx, :]
        if ee_pqc.ndim == 1:
            ee_pqc = ee_pqc[None]

        vis_list = []
        for single_ee_pqc in ee_pqc:
            ee_frame, gripper_frame, origin_frame = self._o3d_reference_frames(single_ee_pqc)
            vis_list.extend([ee_frame, gripper_frame, origin_frame])

        dec = jax.jit(self.models.occ_prediction)
        rec_mesh = create_scene_mesh_from_oriCORNs(oriCORNs_tf, dec, qp_bound=1.5, visualize=visualize)
        vis_list.append(rec_mesh)
        return vis_list

    def show_in_o3d_gt(self, q, color=None, visualize=True):
        """Pose the ground-truth link meshes at an unbatched ``q`` with Open3D.

        Meshes are read via ``models.asset_path_util``, scaled by ``mesh_scale``
        and moved to the FK mesh poses.

        Returns:
            None after opening a viewer if ``visualize``; otherwise the meshes.
        """
        import open3d as o3d
        mesh_pqcs = self.FK(q, oriCORN_out=False)

        mesh_list = []
        for i, obj_idx in enumerate(self.canonical_obj_idx):
            mesh = o3d.io.read_triangle_mesh(self.models.asset_path_util.obj_path_by_idx(obj_idx))
            mesh.compute_vertex_normals()
            mesh.scale(self.mesh_scale[i], center=np.zeros(3))
            mesh.transform(tutil.pq2H(mesh_pqcs[i]))
            if color is not None:
                mesh.paint_uniform_color(color)
            mesh_list.append(mesh)

        if not visualize:
            return mesh_list
        o3d.visualization.draw_geometries(mesh_list)

    def show_fps(self, q, visualize=True, scene_oriCORNs: Optional[loutil.LatentObjects] = None, color=None):
        """Show the robot's oriCORN fps points (and optionally the scene's) at ``q``.

        Returns:
            None after opening a viewer if ``visualize``; otherwise the geometries.
        """
        from util.reconstruction_util import create_fps_fcd_from_oriCORNs
        import open3d as o3d

        oriCORNs_tf, mesh_pqcs = self.FK(q, oriCORN_out=True)
        ee_pqc = mesh_pqcs[..., self.ee_idx, :]
        ee_frame, gripper_frame, origin_frame = self._o3d_reference_frames(ee_pqc)

        pcd_vis = create_fps_fcd_from_oriCORNs(oriCORNs_tf, visualize=False)
        if color is not None:
            pcd_vis.paint_uniform_color(color)
        vis_list = [pcd_vis, origin_frame, ee_frame, gripper_frame]

        if scene_oriCORNs is not None:
            vis_list.append(create_fps_fcd_from_oriCORNs(scene_oriCORNs, visualize=False))
        if not visualize:
            return vis_list
        o3d.visualization.draw_geometries(vis_list)

    def show_fps_sphere(self, q, visualize=True, scene_oriCORNs: Optional[loutil.LatentObjects] = None, color=None,
                        grasped_oriCORN: Optional[loutil.LatentObjects] = None, ee_to_obj_pqc: jnp.ndarray = None):
        """Show the robot's fps spheres at ``q``, optionally with a grasped object and the scene.

        ``grasped_oriCORN`` is placed at ``ee_pose * ee_to_obj_pqc``.

        Returns:
            None after opening a viewer if ``visualize``; otherwise the geometries.
        """
        import open3d as o3d

        oriCORNs_tf, mesh_pqcs = self.FK(q, oriCORN_out=True)
        ee_pqc = mesh_pqcs[..., self.ee_idx, :]
        ee_frame, gripper_frame, origin_frame = self._o3d_reference_frames(ee_pqc)

        pcd_vis = oriCORNs_tf.get_fps_sphere_o3d()
        if color is not None:
            # get_fps_sphere_o3d appears to return a list of geometries (it is
            # unpacked / list-concatenated below), which has no
            # paint_uniform_color; paint each geometry instead.
            for geometry in pcd_vis:
                geometry.paint_uniform_color(color)
        if grasped_oriCORN is not None:
            grasped_oriCORNs_tf = grasped_oriCORN.apply_pq_z(tutil.pq_multi(ee_pqc, ee_to_obj_pqc), self.rot_configs)
            pcd_vis += grasped_oriCORNs_tf.get_fps_sphere_o3d()
        vis_list = [*pcd_vis, origin_frame, ee_frame, gripper_frame]

        if scene_oriCORNs is not None:
            vis_list = vis_list + scene_oriCORNs.get_fps_sphere_o3d()
        if not visualize:
            return vis_list
        o3d.visualization.draw_geometries(vis_list)

    def _reset_joint_states_pb(self, shakey_pb_uid, q):
        """Write actuated joint values ``q`` into the PyBullet robot.

        Value ``q[i]`` goes to PyBullet joint ``link_idx_to_pb_idx[i + 1]``
        (entry 0 is the base). NOTE(review): this assumes the actuated joints
        are links ``1 .. num_act_joints``.
        """
        for i in range(self.num_act_joints):
            p.resetJointState(shakey_pb_uid, self.link_idx_to_pb_idx[i + 1], q[i])

    def set_q_pb(self, shakey_pb_uid, q, grasped_obj_pbid=None, grasped_obj_pq=None):
        """Set an existing PyBullet robot to ``q`` and carry grasped objects along.

        Args:
            shakey_pb_uid: PyBullet body id of the robot.
            q: Actuated joints, optionally prefixed by the base ``[x, y, yaw]``.
            grasped_obj_pbid: Iterable of grasped-object body ids, or None.
            grasped_obj_pq: Object poses relative to the end effector(s); row
                ``i`` pairs object ``i`` with end effector ``i``.
        """
        if q.shape[-1] == self.num_act_joints + 3:
            base_pos, base_quat = tutil.SE2h2pq(q[..., :3], self.robot_height)
            q = q[..., 3:]
            p.resetBasePositionAndOrientation(shakey_pb_uid, base_pos, base_quat)

        assert q.shape[-1] == self.num_act_joints
        self._reset_joint_states_pb(shakey_pb_uid, q)

        if grasped_obj_pbid is not None:
            if grasped_obj_pq.ndim == 1:
                grasped_obj_pq = grasped_obj_pq[None]

            gripper_pqc = self.get_ee_pq_pb(shakey_pb_uid, gripper_center=False)
            for oidx, obj_pbid in enumerate(grasped_obj_pbid):
                obj_pos, obj_quat = p.multiplyTransforms(
                    gripper_pqc[oidx, :3], gripper_pqc[oidx, 3:],
                    grasped_obj_pq[oidx, :3], grasped_obj_pq[oidx, 3:])
                p.resetBasePositionAndOrientation(obj_pbid, obj_pos, obj_quat)

    def get_ee_pq_pb(self, shakey_pb_uid, gripper_center=False):
        """End-effector world pose(s) read from PyBullet.

        Uses the link world frame (``getLinkState`` entries 4 and 5), optionally
        shifted to the gripper center.

        Returns:
            Array of shape ``(n_ee, 7)``.
        """
        ee_indices = np.atleast_1d(np.array(self.ee_idx))

        ee_pq_list = []
        for idx in ee_indices:
            link_state = p.getLinkState(shakey_pb_uid, idx, computeForwardKinematics=True)
            ee_pq = (np.array(link_state[4]), np.array(link_state[5]))
            if gripper_center:
                ee_pq = self.get_gripper_center_from_ee_pq(*ee_pq)
            # np.concatenate rather than np.concat (which needs NumPy >= 2.0).
            ee_pq_list.append(np.concatenate(ee_pq))

        return np.stack(ee_pq_list, axis=0)

    def show_in_pb(self, q, base_pos=np.zeros(3,), base_quat=np.array([0,0,0.,1.]), grasped_obj_pbid=None, grasp_pqc=None):
        """Load the robot into PyBullet (GUI if not yet connected) and pose it at ``q``.

        Args:
            q: Actuated joint values.
            base_pos, base_quat: Fixed base pose.
            grasped_obj_pbid: Optional body id to attach to the end effector.
            grasp_pqc: Object pose relative to the (first) end effector, 7-vector.

        Returns:
            PyBullet body id of the loaded robot.
        """
        if p.getConnectionInfo()['isConnected'] == 0:
            p.connect(p.GUI)

        shakey_pb_uid = p.loadURDF(self.urdf_dir, useFixedBase=True, basePosition=base_pos, baseOrientation=base_quat)

        assert q.shape[-1] == self.num_act_joints
        self._reset_joint_states_pb(shakey_pb_uid, q)

        if grasped_obj_pbid is not None:
            # Place the grasped object rigidly relative to the end effector
            # (the first one if there are several).
            ee_pqc = self.get_ee_pq_pb(shakey_pb_uid, gripper_center=False)[0]
            obj_pos, obj_quat = p.multiplyTransforms(ee_pqc[:3], ee_pqc[3:], grasp_pqc[:3], grasp_pqc[3:])
            p.resetBasePositionAndOrientation(grasped_obj_pbid, obj_pos, obj_quat)

        return shakey_pb_uid

    def create_pb(self, se2=np.zeros(3,), visualize=False):
        """Load the robot into PyBullet with self-collision at planar pose ``se2 = [x, y, yaw]``.

        Connects (GUI if ``visualize`` else DIRECT) when no connection exists.
        For 'im2' / 'RobotBimanualV4', a few link pairs are excluded from
        self-collision checking.

        Returns:
            PyBullet body id.
        """
        if p.getConnectionInfo()['isConnected'] == 0:
            p.connect(p.GUI if visualize else p.DIRECT)

        base_pos_3d = np.zeros(3)
        base_pos_3d[:2] = se2[:2]
        base_pos_3d[2] = self.robot_height
        base_quat = sciR.from_euler('z', se2[2]).as_quat()
        shakey_pb_uid = p.loadURDF(self.urdf_dir, useFixedBase=True, basePosition=base_pos_3d,
                                   baseOrientation=base_quat, flags=p.URDF_USE_SELF_COLLISION)

        if self.robot_type in ['im2', 'RobotBimanualV4']:
            # (link_a, link_b) PyBullet link pairs excluded from self-collision.
            for link_a, link_b in ((3, 5), (9, 11), (0, 2), (0, 8)):
                p.setCollisionFilterPair(shakey_pb_uid, shakey_pb_uid, link_a, link_b, 0)
        return shakey_pb_uid

    def get_gripper_center_from_ee_pq(self, ee_pos, ee_quat):
        """Map an end-effector pose to the gripper-center pose.

        Applies ``gripper_tip_offset_from_ee``; example offsets: shakey
        ``([0, 0, 0.2], identity)``, IM2 ``([0, 0.14, -0.01], Rx(-pi/2))``.
        """
        return tutil.pq_multi(ee_pos, ee_quat, *self.gripper_tip_offset_from_ee)

    def get_ee_from_gripper_center(self, gripper_center_pos, gripper_center_quat):
        """Inverse of ``get_gripper_center_from_ee_pq``: gripper-center pose -> EE pose."""
        return tutil.pq_multi(gripper_center_pos, gripper_center_quat,
                              *tutil.pq_inv(*self.gripper_tip_offset_from_ee))

    def get_gripper_tip_line_pnts(self, ee_pos, ee_quat):
        """Sample 30 points on a 0.09-long segment centred at the gripper center.

        The segment lies along the x axis (first rotation-matrix column) of the
        gripper-center frame.

        Args:
            ee_pos: End-effector position(s), shape ``(..., 3)``.
            ee_quat: End-effector quaternion(s), shape ``(..., 4)``.

        Returns:
            Array of shape ``(..., 30, 3)``.
        """
        grasp_len = 0.09
        interpolation_resolution = 30
        finger_end_pos, finger_end_quat = self.get_gripper_center_from_ee_pq(ee_pos, ee_quat)

        segment_dir = tutil.q2R(finger_end_quat)[..., 0]
        finger_end_1 = finger_end_pos + grasp_len / 2 * segment_dir
        finger_end_2 = finger_end_pos - grasp_len / 2 * segment_dir

        t = jnp.linspace(0, 1, interpolation_resolution)[..., None]  # (R, 1)
        # `[..., None, :]` supports both batched and unbatched inputs.
        start = finger_end_1[..., None, :]
        return start + t * (finger_end_2[..., None, :] - start)

    def random_q(self, jkey, outer_shape):
        """Uniform random actuated-joint configurations within the joint limits.

        ``outer_shape`` must be a tuple; the result has shape
        ``outer_shape + (num_act_joints,)``.
        """
        return jax.random.uniform(jkey, shape=outer_shape + (self.num_act_joints,),
                                  minval=self.q_lower_bound, maxval=self.q_upper_bound)

    def get_IK_jit_func(self, robot_base_pqc, grasp_center_coordinate=False):
        """Return a jitted ``IK_func(cur_q, goal_pqc)``.

        If ``cur_q`` holds only the actuated joints, the base ``[x, y, yaw]``
        derived from ``robot_base_pqc`` is prepended and held fixed (25
        iterations); the result then includes those base DOFs. Otherwise the
        base is optimised too (40 iterations). Both use L-BFGS + line search.
        """
        def IK_func(cur_q, goal_pqc):
            if cur_q.shape[-1] == self.num_act_joints:
                robot_se2, _ = tutil.pq2SE2h(*robot_base_pqc)
                # Broadcast the base pose over cur_q's batch dims.
                cur_q = jnp.concatenate([robot_se2 + jnp.zeros_like(cur_q[..., :1]), cur_q], axis=-1)
                return self.IK(cur_q, goal_pqc, itr_no=25, grasp_center_coordinate=grasp_center_coordinate,
                               fix_base=True, linesearch=True)
            return self.IK(cur_q, goal_pqc, itr_no=40, grasp_center_coordinate=grasp_center_coordinate,
                           fix_base=False, linesearch=True)

        return jax.jit(IK_func)

    @property
    def numjoints(self):
        """Number of links/joints (length of ``act_joint_mask``)."""
        return len(self.act_joint_mask)

    @property
    def num_act_joints(self):
        """Number of actuated joints (length of ``q_lower_bound``)."""
        return self.q_lower_bound.shape[0]

    @property
    def gripper_oriCORN(self):
        """Canonical oriCORN of the last link (presumably the gripper)."""
        return self.link_canonical_oriCORN[-1]
    
        
def draw_coordinate_frame(pos_in, orn_in, length=0.1):
    """Draw RGB axis triads in PyBullet for one or more poses.

    Args:
        pos_in: Frame origin(s), shape ``(3,)`` or ``(..., 3)``. Any array-like
            (list, NumPy or JAX array) is accepted.
        orn_in: Frame orientation(s) as ``[x, y, z, w]`` quaternions, shape
            ``(4,)`` or ``(..., 4)`` with the same leading shape as ``pos_in``.
        length: Length of each drawn axis.

    The X, Y and Z axes are drawn in red, green and blue, with line width 3,
    via ``p.addUserDebugLine`` (an active PyBullet connection is required).
    """
    # Flatten any batch dims so a single pose and a batch share one code path;
    # np.asarray also accepts plain lists and JAX arrays.
    pos_batch = np.asarray(pos_in, dtype=np.float64).reshape(-1, 3)
    orn_batch = np.asarray(orn_in, dtype=np.float64).reshape(-1, 4)
    axis_colors = ([1, 0, 0], [0, 1, 0], [0, 0, 1])  # X red, Y green, Z blue

    for pos, orn in zip(pos_batch, orn_batch):
        # Columns of the rotation matrix are the frame's axes in world coords.
        rotation = sciR.from_quat(orn).as_matrix()
        for axis_idx, color in enumerate(axis_colors):
            axis_end = pos + length * rotation[:, axis_idx]
            p.addUserDebugLine(pos, axis_end, color, 3)



def load_urdf_kinematics(urdf_dirs=None, robot_height=0.0, models:mutil.Models=None, use_collision_mesh=False, gui=False):

    # models = models.load_self_collision_model()

    if urdf_dirs is None:
        urdf_dirs = "assets/ur5/urdf/shakey_open.urdf"
        # urdf_dirs = "assets/RobotBimanualV4/urdf/RobotBimanualV4.urdf"

    if urdf_dirs.split('/')[-3] == 'ur5':
        robot_name = 'ur5'
    else:
        robot_name = 'RobotBimanualV4' # 'shakey'
    
    if not os.path.exists(urdf_dirs):
        assert False, 'urdf file not found: {}'.format(urdf_dirs)

    if models is None:
        assert canonical_obj_filenames is not None
        assert mesh_aligned_canonical_obj is not None
        assert rot_configs is not None
    else:
        canonical_obj_filenames = models.canonical_latent_obj_filename_list
        mesh_aligned_canonical_obj = models.mesh_aligned_canonical_obj
        rot_configs = models.rot_configs

    dataset_names = [cof.split('/')[-3] for cof in canonical_obj_filenames]
    basenames = [cof.split('/')[-1].split('.')[0] for cof in canonical_obj_filenames]
    id_basenames = [dn+bn for dn, bn in zip(dataset_names, basenames)]

    # import ur5 robot in pybullet

    p.connect(p.GUI if gui else p.DIRECT)
    shakey_pb_uid = p.loadURDF(urdf_dirs, useFixedBase=True)

    numjoints = p.getNumJoints(shakey_pb_uid)

    canonical_obj_idx_list = []
    joint_type_list = []
    link_to_mesh_pq_list = []
    rel_joint_pq_list = []
    cur_link_idx_list = []
    parent_link_idx_list = []
    mesh_scale_list = []
    joint_axis_list = []
    link_name_list = []
    joint_limit_min_list = []
    joint_limit_max_list = []
    pb_link_idx_to_mesh_idx = {}
    link_idx_to_pb_idx = -1*np.ones(len(p.getVisualShapeData(shakey_pb_uid)), dtype=np.int32)
    for i, vis_shape_data in enumerate(p.getVisualShapeData(shakey_pb_uid)):
        link_idx = vis_shape_data[1]
        mesh_scale = vis_shape_data[3]
        mesh_name = vis_shape_data[4]
        local_visual_frame_pos = vis_shape_data[5] # relative to link/joint frame
        local_visual_frame_quat = vis_shape_data[6]
        # id_name = str(mesh_name).split('/')[-1].split('.')[0]
        # ds_name = str(mesh_name).split('/')[-3]
        # if ds_name=='':
        #     ds_name = str(mesh_name).split('/')[-4]
        # if use_collision_mesh:
        #     id_name = 'cvx_'+id_name
        # id_name = ds_name + id_name
        # for j in range(len(canonical_obj_filenames)):
        #     # if dataset_names[j] in (['ur5', 'rg2_gripper'] if robot_name=='ur5' else ['RobotBimanualV4', 'gripper_urdf']):
        #     if id_name == id_basenames[j]:
        #         canonical_obj_idx = j
        #         break
        # assert j < len(canonical_obj_filenames)-1, 'No matching canonical object found for {}'.format(id_name)

        mesh_name_for_obj_select = mesh_name.decode().replace('//','/')
        if use_collision_mesh:
            mesh_name_for_obj_select = os.path.join(os.path.dirname(mesh_name_for_obj_select), 'cvx_'+mesh_name_for_obj_select.split('/')[-1])
        canonical_obj_idx = models.asset_path_util.get_obj_id(mesh_name_for_obj_select)
        assert canonical_obj_idx != -1, 'No matching canonical object found for {}'.format(mesh_name)

        # joint_info = p.getJointInfo(shakey_pb_uid, i)
        if link_idx==-1:
            joint_info = None
            joint_name = 'base'
            joint_type = p.JOINT_FIXED
            link_name = 'base'
            joint_axis = np.array([0,0,0])
            joint_pos_wrt_parent_frame = np.array([0., 0., 0.])
            joint_quat_wrt_parent_frame = np.array([0., 0., 0., 1.])
            parent_idx = -1
            pb_link_idx_to_mesh_idx[-1] = 0
            link_idx_to_pb_idx[0] = -1
            joint_limit_min = 0
            joint_limit_max = 0
        else:
            joint_info = p.getJointInfo(shakey_pb_uid, link_idx)
            joint_name = joint_info[1]
            joint_type = joint_info[2]
            link_name = joint_info[12]
            joint_axis = joint_info[13]
            joint_pos_wrt_parent_frame = joint_info[14]
            joint_quat_wrt_parent_frame = joint_info[15]
            joint_limit_min = joint_info[8]
            joint_limit_max = joint_info[9]
            # parent_idx = joint_info[16]
            if joint_info[16] == -1 and -1 not in pb_link_idx_to_mesh_idx:
                pb_link_idx_to_mesh_idx[-1] = -1
            parent_idx = pb_link_idx_to_mesh_idx[joint_info[16]]
            pb_link_idx_to_mesh_idx[link_idx] = i
            link_idx_to_pb_idx[i] = link_idx
        # p.JOINT_REVOLUTE
        # p.JOINT_FIXED

        joint_limit_min_list.append(joint_limit_min)
        joint_limit_max_list.append(joint_limit_max)
        canonical_obj_idx_list.append(canonical_obj_idx)
        joint_type_list.append(joint_type)
        joint_axis_list.append(np.array(joint_axis))
        link_to_mesh_pq_list.append((np.array(local_visual_frame_pos), np.array(local_visual_frame_quat)))
        rel_joint_pq_list.append((np.array(joint_pos_wrt_parent_frame), np.array(joint_quat_wrt_parent_frame)))
        cur_link_idx_list.append(i)
        parent_link_idx_list.append(parent_idx)
        mesh_scale_list.append(mesh_scale)
        link_name_list.append(link_name)

    # calculate link frame transform
    link_frame_to_com_list = []
    link_pq_list_gt = []
    joint_to_link_pq_list = []
    link_to_joint_pq_list = []
    joint_pq_list = []
    # for i in range(numjoints):
    for cnt, pbidx in enumerate(pb_link_idx_to_mesh_idx):
        # i - pb link idx
        i = pb_link_idx_to_mesh_idx[pbidx]
        if i<0:
            continue
        parent_idx = parent_link_idx_list[i]
        if pbidx==-1:
            link_frame_to_com_list.append((np.zeros(3), np.array([0,0,0,1])))
            link_pq_list_gt.append((np.zeros(3), np.array([0,0,0,1])))
            link_pq = np.zeros(3), np.array([0,0,0,1])
            cur_com_quat = np.array([0,0,0,1])
        else:
            link_state = p.getLinkState(shakey_pb_uid, pbidx)
            link_frame_to_com_list.append((np.array(link_state[2]), np.array(link_state[3])))
            link_pq = np.array(link_state[4]), np.array(link_state[5])
            cur_com_quat = link_state[1]

        if i==0:
            link_to_joint_pq = rel_joint_pq_list[i][0], rel_joint_pq_list[i][1]
            joint_pq = link_to_joint_pq
        else:
            assert parent_idx <= i, 'Parent link index is smaller than current link index'
            if parent_idx == -1:
                parent_link_state = (np.zeros(3), np.array([0,0,0,1]))
            else:
                parent_link_state = p.getLinkState(shakey_pb_uid, parent_idx)
            joint_pq_gt = p.multiplyTransforms(parent_link_state[0], parent_link_state[1], rel_joint_pq_list[i][0], rel_joint_pq_list[i][1])
            link_to_joint_pq = tutil.pq_multi(*link_frame_to_com_list[parent_idx], rel_joint_pq_list[i][0], rel_joint_pq_list[i][1])
            joint_pq = tutil.pq_multi(*link_pq_list_gt[parent_idx], *link_to_joint_pq)
        
        # wiered.. joint axis from pybullet is expressed in 'next' com frame
        joint_axis_tmp = tutil.qaction(np.array(cur_com_quat), joint_axis_list[i])
        joint_axis_list[i] = tutil.qaction(tutil.qinv(joint_pq[1]), joint_axis_tmp)
        
        link_to_joint_pq_list.append(link_to_joint_pq)
        joint_pq_list.append(joint_pq)
        joint_to_link_pq_list.append(tutil.pq_multi(*tutil.pq_inv(*joint_pq), *link_pq))
        link_pq_list_gt.append(link_pq)


    act_joint_mask = np.array(joint_type_list) != p.JOINT_FIXED
    num_act_joints = np.sum(act_joint_mask)

    joint_to_link_pq_arr = jnp.stack([jnp.concat(pq) for pq in joint_to_link_pq_list], 0)
    link_to_joint_pq_arr = jnp.stack([jnp.concat(pq) for pq in link_to_joint_pq_list], 0)

    joint_limit_min_arr = np.array(joint_limit_min_list)[act_joint_mask]
    joint_limit_max_arr = np.array(joint_limit_max_list)[act_joint_mask]

    mesh_scale_arr = jnp.array(mesh_scale_list)
    assert np.all(mesh_scale_arr[...,0] == mesh_scale_arr[...,1])
    assert np.all(mesh_scale_arr[...,1] == mesh_scale_arr[...,2])
    mesh_scale_arr = mesh_scale_arr[...,0]
    joint_axis_arr = jnp.array(joint_axis_list)
    canonical_obj_idx_arr = np.array(canonical_obj_idx_list)
    parent_link_idx_arr = np.array(parent_link_idx_list)
    link_to_mesh_pq_arr = jnp.stack([jnp.concat(pq) for pq in link_to_mesh_pq_list], 0)

    shakey_oriCORNs = mesh_aligned_canonical_obj[canonical_obj_idx_arr]
    shakey_oriCORNs = shakey_oriCORNs.apply_scale(mesh_scale_arr, center=np.zeros(3))

    if urdf_dirs.split('/')[-1]=='shakey_open_rg6.urdf':
        gripper_tip_offset_from_ee = (np.array([0,0,0.265]), np.array([0,0,0,1.])) # for shakey
        ee_idx = 6
    if urdf_dirs.split('/')[-1]=='shakey_robotiq_open.urdf':
        gripper_tip_offset_from_ee = (np.array([0,0,0.150]), np.array([0,0,0,1.])) # for robotiq
        ee_idx = 6
    elif robot_name == 'ur5':
        gripper_tip_offset_from_ee = (np.array([0,0,0.200]), np.array([0,0,0,1.])) # for shakey
        ee_idx = 6
    elif robot_name == 'RobotBimanualV4':
        gripper_tip_offset_from_ee = (np.array([0,0.14,-0.01]), sciR.from_euler('x', -np.pi/2).as_quat()) # for IM2
        ee_idx = np.array([6, 12])
    else:
        raise NotImplementedError

    shakey = Shakey( 
            robot_type=robot_name,
            link_idx_to_pb_idx=link_idx_to_pb_idx,
           link_canonical_oriCORN=shakey_oriCORNs,
           canonical_obj_idx=canonical_obj_idx_arr,
           act_joint_mask=act_joint_mask, 
           collision_check_link_idx=np.arange(len(canonical_obj_idx_arr)),
           arm_joint_idx=np.where(act_joint_mask)[0],
           joint_to_link_pq=joint_to_link_pq_arr, 
           link_to_joint_pq=link_to_joint_pq_arr, 
           mesh_scale=mesh_scale_arr, 
           joint_axis=joint_axis_arr, 
           parent_link_idx=parent_link_idx_arr, 
           link_to_mesh_pq=link_to_mesh_pq_arr,
           q_lower_bound=joint_limit_min_arr,
            q_upper_bound=joint_limit_max_arr,
           robot_height=robot_height,
           models=models,
           gripper_tip_offset_from_ee=gripper_tip_offset_from_ee,
           ee_idx=ee_idx,
            rot_configs=rot_configs,
            urdf_dir=urdf_dirs)

    p.disconnect()

    return shakey



def shakey_to_trimesh(urdf_dir, link_pqc):
    """Build one trimesh of the robot, posing each visual mesh with ``link_pqc``.

    The URDF is loaded into the currently connected pybullet client only long
    enough to read its visual shape data; the temporary body is removed again
    before the meshes are processed. Each visual mesh file is loaded with
    trimesh, scaled by the pybullet-reported mesh scale, transformed by the
    matching row of ``link_pqc`` and finally all meshes are concatenated.

    Args:
        urdf_dir: URDF path passed to ``p.loadURDF``.
        link_pqc: array indexed as ``link_pqc[i, :3]`` (position) and
            ``link_pqc[i, 3:]`` (orientation) for the i-th entry returned by
            ``p.getVisualShapeData``. The orientation convention is whatever
            ``tutil.pq2H`` expects (presumably an xyzw quaternion) — confirm.

    Returns:
        ``trimesh.util.concatenate`` of every mesh that loaded successfully.
        Visual shapes without a mesh file, or whose file fails to load, are
        skipped (load failures are printed).
    """
    # HEAD only imports Path under ``__main__``; import locally so the function
    # also works when this file is imported as a module.
    from pathlib import Path
    import trimesh

    uid = p.loadURDF(urdf_dir)
    try:
        visual_shapes = p.getVisualShapeData(uid)
    finally:
        # The body is only needed for the query above; don't leave an extra
        # robot in the pybullet client on every call.
        p.removeBody(uid)

    meshes = []
    for i, visual_shape in enumerate(visual_shapes):
        # Entries are (objectUniqueId, linkIndex, visualGeometryType, dimensions,
        # meshAssetFileName, localVisualFramePosition,
        # localVisualFrameOrientation, rgbaColor[, textureUniqueId]); index
        # instead of unpacking so the optional trailing field is tolerated.
        dimensions = visual_shape[3]
        filename = visual_shape[4]

        if isinstance(filename, bytes):
            filename = filename.decode('utf-8')
        # NOTE(review): strips a hard-coded UR5 asset prefix from the reported
        # path; whether this is right for other robots is unclear — confirm.
        filename = filename.replace('assets/ur5/urdf/', '')

        if not filename:
            # No mesh file for this visual shape: skip it instead of reusing the
            # previous iteration's mesh (or hitting an unbound name).
            continue

        mesh_path = Path(filename)
        try:
            mesh = trimesh.load_mesh(str(mesh_path))
            mesh.apply_scale(dimensions)
        except Exception as e:
            # Best effort: report and skip meshes that cannot be loaded.
            print(f"Failed to load mesh {mesh_path}: {e}")
            continue

        link_pos = link_pqc[i, :3]
        link_orn = link_pqc[i, 3:]
        mesh.apply_transform(tutil.pq2H(link_pos, link_orn))
        meshes.append(mesh)

    return trimesh.util.concatenate(meshes)




if __name__ == '__main__':
    # Demo / smoke test: build the kinematic model from a URDF, visualise random
    # joint configurations, then run a jitted IK solve toward an FK-derived target.
    models = mutil.Models().load_pretrained_models()
    # Active URDF; the commented-out lines are alternative robot descriptions.
    urdf_dir = "assets/ur5/urdf/shakey_open.urdf"
    # urdf_dir = "assets/ur5/urdf/shakey_open_rg6.urdf"
    # urdf_dir = "assets/RobotBimanualV4/urdf/RobotBimanualV4.urdf"
    # urdf_dir = "assets/RobotBimanualV4/urdf/RobotBimanualV4_onearm.urdf"
    # urdf_dir = "/home/dongwon/object_set//assets/ur5/urdf/shakey_open_rg6.urdf"
    # urdf_dir = "assets/ur5/urdf/shakey_robotiq_open.urdf"
    shakey = load_urdf_kinematics(urdf_dir, models=models, use_collision_mesh=False, gui=True)

    # Visualise 100 joint configurations sampled uniformly within the joint
    # limits (alternative viewers are left commented out).
    for _ in range(100):
        random_q = np.random.uniform(shakey.q_lower_bound, shakey.q_upper_bound)
        # shakey.show_in_pb(random_q)
        # shakey.show_fps(random_q)
        shakey.show_in_o3d(random_q)
        # shakey.show_in_o3d_gt(random_q)

    # Only the shape/dtype of this sample is used: it is immediately replaced
    # by the all-zeros configuration.
    random_q = np.random.uniform(shakey.q_lower_bound, shakey.q_upper_bound)
    random_q = np.zeros_like(random_q)
    # NOTE(review): the meaning of FK's second positional argument is not
    # visible in this chunk — confirm.
    link_pqcs = shakey.FK(random_q, False)
    shakey.show_in_pb(random_q)

    # Disabled experiment: camera pose expressed relative to the end effector.
    # ee_pqc = link_pqcs[shakey.ee_idx]
    # POS_EC = np.array([-0.0325, -0.040, 0.13375+0.0213]).astype(np.float32)
    # QUAT_EC = (sciR.from_euler('x', np.pi)).as_quat().astype(np.float32)
    # cam_pq = tutil.pq_multi(ee_pqc[...,:3], ee_pqc[...,3:], POS_EC, QUAT_EC)

    # Draw the gripper-center frame and the raw end-effector frame; each
    # link_pqcs row is split into [..., :3] position and [..., 3:] orientation.
    # NOTE(review): draw_coordinate_frame is not defined in this chunk;
    # presumably a module-level helper elsewhere in the file.
    draw_coordinate_frame(*shakey.get_gripper_center_from_ee_pq(link_pqcs[shakey.ee_idx][...,:3], link_pqcs[shakey.ee_idx][...,3:]))
    draw_coordinate_frame(link_pqcs[shakey.ee_idx][...,:3], link_pqcs[shakey.ee_idx][...,3:])
    # draw_coordinate_frame(*cam_pq)
    ik_target = link_pqcs[shakey.ee_idx]
    # IK wrapped with jax.jit; itr_no=5 is presumably the iteration count.
    # The first call below also pays the JIT compilation cost.
    ik_jit_func = jax.jit(partial(shakey.IK, itr_no=5, output_cost=True, linesearch=True))
    # ik_jit_func = partial(shakey.IK, itr_no=500, output_cost=True, linesearch=False)
    # q, IK_cost = shakey.IK(np.zeros(14), ik_target, itr_no=10, output_cost=True)
    
    # NOTE(review): the initial guess (random_q, all zeros) is the very
    # configuration ik_target was computed from, so this mainly checks that IK
    # compiles and runs rather than that it converges from a distant start.
    print('start plans')
    q, IK_cost = ik_jit_func(random_q, ik_target)
    print('end plans')
    shakey.show_in_pb(q)
    print(IK_cost)

    # NOTE(review): looks like a leftover debug print.
    print(1)