import os
import sys

# BASEDIR is the directory two levels above the directory containing this file.
# It is prepended to sys.path, presumably so that absolute project imports such
# as `util.transform_util` below resolve. This must run before those imports.
BASEDIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

import jax
import jax.numpy as jnp
import open3d as o3d
import trimesh
import coacd
import numpy as np
from typing import Tuple
import hashlib
import pickle

import util.transform_util as tutil
from jax_libccd import _gjk_epa_module
# Register every FFI handler exported by the libccd GJK/EPA extension with JAX.
# Presumably this includes the "compute_penetration" target called below — confirm.
for name, target in _gjk_epa_module.registrations().items():
  jax.ffi.register_ffi_target(name, target)

# Relative import: this module must be imported as part of its package
# (e.g. via `python -m ...`); executing the file directly as a script fails here.
from .base import ContinuousCollisionCostBase

def compute_penetration(mesh1, mesh2):
    """Run the registered GJK/EPA "compute_penetration" FFI target on point-set pairs.

    Args:
        mesh1: point sets whose shape excluding the last two axes defines the
            batch shape of the outputs (``ccd`` passes ``[N, NP, 3]``).
        mesh2: second point set of each pair; assumed to share mesh1's leading
            batch axes (the point count may differ).

    Returns:
        ``(penetration_depth, penetration_dir, contact_points)``, float32 arrays of
        shapes ``batch``, ``batch + (3,)`` and ``batch + (3,)``.
    """
    batch_shape = tuple(mesh1.shape[:-2])
    vec3_spec = jax.ShapeDtypeStruct(batch_shape + (3,), jnp.float32)
    output_specs = (
        jax.ShapeDtypeStruct(batch_shape, jnp.float32),  # penetration depth
        vec3_spec,  # penetration direction
        vec3_spec,  # contact point
    )
    penetration_fn = jax.ffi.ffi_call(
        "compute_penetration",
        output_specs,
        vmap_method="broadcast_all",
    )
    depth, direction, contact = penetration_fn(mesh1, mesh2)
    return depth, direction, contact

def hash(mesh_name, scale, pqc, link_idx):
    """Return a deterministic cache key in ``[0, 10**8)`` for one enrolled mesh.

    Each component is hashed with SHA-256 (``scale``, ``pqc`` and ``link_idx``
    via ``str()``). Each digest is reduced modulo 10**8, and the sum is reduced
    modulo 10**8 again. Unlike the builtin ``hash`` it shadows, the result is
    stable across interpreter runs, so it can name files on disk.
    """
    modulus = 10**8

    def _digest(text):
        return int(hashlib.sha256(text.encode('utf-8')).hexdigest(), 16) % modulus

    components = (mesh_name, str(scale), str(pqc), str(link_idx))
    return sum(_digest(component) for component in components) % modulus


class TrajOptCCD(ContinuousCollisionCostBase):
    """Continuous-collision cost between moving meshes and fixed meshes.

    Every enrolled mesh is decomposed into convex pieces with CoACD and stored as
    a padded set of vertices. During evaluation, each moving piece is swept over
    every trajectory segment: its placed points at waypoints ``t`` and ``t + 1``
    are merged into one point set. Each swept set is then queried against every
    fixed piece through the ``compute_penetration`` GJK/EPA target.
    """

    def __init__(self, col_coef):
        """Create an empty cost.

        Args:
            col_coef: factor applied to the summed penetration depth in ``__call__``.
        """
        self.col_coef = col_coef
        # One entry per fixed mesh (list of convex point sets); after
        # ``enroll_meshes`` each entry is a padded ``[NOB_CONVEX, NP_MAX, 3]`` array.
        self.fixed_meshes = []
        self.fixed_mesh_point_nums = []  # unpadded vertex count of each fixed convex piece
        # (sic) misspelled attribute name kept for backward compatibility.
        self.moving_mehses = []
        self.moving_mesh_point_nums = []  # unpadded vertex count of each moving convex piece
        self.use_cache = True  # reuse decompositions pickled under cache_dir
        self.cache_dir = os.path.join(BASEDIR, "cache", "convex_meshes")
        os.makedirs(self.cache_dir, exist_ok=True)

    def pad_meshes(self, mesh_sets):
        """Pad all convex point sets to a common length and stack them per mesh.

        Args:
            mesh_sets: list of meshes; each mesh is a sequence of ``[NP_i, 3]``
                point arrays, one per convex piece.

        Returns:
            list of ``[NOB_CONVEX, NP_MAX, 3]`` arrays, one per mesh. ``NOB_CONVEX``
            may differ between meshes; ``NP_MAX`` is the longest piece over all
            meshes. An empty input gives an empty list.
        """
        if not mesh_sets:
            return []
        max_len = max(len(piece) for mesh_set in mesh_sets for piece in mesh_set)
        padded_meshes = []
        for mesh_set in mesh_sets:
            padded_pieces = []
            for piece in mesh_set:
                pad_len = max_len - len(piece)
                if pad_len > 0:
                    # Repeat the last vertex: duplicated points do not change the
                    # convex hull of the set.
                    pad = jnp.broadcast_to(piece[-1, :], (pad_len, piece.shape[-1]))
                    piece = jnp.concatenate([piece, pad], axis=0)
                padded_pieces.append(piece)
            padded_meshes.append(jnp.stack(padded_pieces, axis=0))
        return padded_meshes

    @staticmethod
    def _decompose_mesh(mesh_name, scale, pqc):
        """Load, scale and pose a mesh, then split it into convex pieces with CoACD.

        Returns:
            ``(meshes, point_nums)``: one ``[NP_i, 3]`` jnp array of unique vertices
            per convex piece, and the list of ``NP_i``.

        Raises:
            ValueError: if no triangles could be read from ``mesh_name``.
        """
        mesh_o3d = o3d.io.read_triangle_mesh(mesh_name)
        # Open3D returns an empty mesh (with only a warning) when loading fails.
        if not mesh_o3d.has_triangles():
            raise ValueError(f"no triangles could be read from mesh file {mesh_name!r}")

        scale_arr = np.asarray(scale, dtype=np.float64).reshape(-1)
        if scale_arr.size == 1:
            mesh_o3d.scale(float(scale_arr[0]), center=(0, 0, 0))
        else:
            # Per-axis scale about the origin (Open3D's `scale` is uniform only).
            scaled_vertices = np.asarray(mesh_o3d.vertices) * scale_arr
            mesh_o3d.vertices = o3d.utility.Vector3dVector(scaled_vertices)

        H_mat = jnp.eye(4) if pqc is None else tutil.pq2H(pqc)
        # Open3D takes a float64 4x4 matrix; convert the JAX array explicitly.
        mesh_o3d.transform(np.asarray(H_mat, dtype=np.float64))

        vertices = np.array(mesh_o3d.vertices)
        faces = np.array(mesh_o3d.triangles)
        convex_parts = coacd.run_coacd(coacd.Mesh(vertices, faces), decimate=True, max_ch_vertex=25)

        meshes, point_nums = [], []
        for part_vertices, _ in convex_parts:
            # Only the vertex set is kept; duplicated vertices are removed.
            points = np.unique(np.asarray(part_vertices), axis=0)
            meshes.append(jnp.array(points))
            point_nums.append(len(points))
        return meshes, point_nums

    def enroll_mesh(self, mesh_name, scale, pqc, link_idx):
        """Convex-decompose one mesh (or load the cached result) and register it.

        NOTE: the cache key only encodes ``(mesh_name, scale, pqc, link_idx)``.
        Clear ``cache_dir`` after changing the decomposition settings or the mesh
        files themselves.

        Args:
            mesh_name: path of a mesh file readable by ``o3d.io.read_triangle_mesh``.
            scale: scalar or per-axis sequence, applied about the origin before ``pqc``.
            pqc: pose converted with ``tutil.pq2H``; ``None`` means identity.
            link_idx: ``-1`` registers a fixed mesh, any other value a moving mesh.

        Raises:
            ValueError: if the mesh file yields no triangles.
        """
        key = hash(mesh_name, scale, pqc, link_idx)
        file_name = os.path.join(self.cache_dir, f"{key}.pkl")
        if self.use_cache and os.path.exists(file_name):
            # NOTE: pickle can execute code on load; only trust locally written caches.
            with open(file_name, "rb") as f:
                meshes, point_nums, is_fixed_mesh = pickle.load(f)
        else:
            meshes, point_nums = self._decompose_mesh(mesh_name, scale, pqc)
            is_fixed_mesh = (link_idx == -1)
            # Write to a temporary file and rename it into place, so an interrupted
            # run never leaves a truncated pickle that breaks later loads.
            tmp_file_name = f"{file_name}.tmp"
            with open(tmp_file_name, "wb") as f:
                pickle.dump((meshes, point_nums, is_fixed_mesh), f)
            os.replace(tmp_file_name, file_name)

        if is_fixed_mesh:
            self.fixed_meshes.append(meshes)
            self.fixed_mesh_point_nums.extend(point_nums)
        else:
            self.moving_mehses.append(meshes)
            self.moving_mesh_point_nums.extend(point_nums)

    def enroll_meshes(self, mesh_names, scales, pqcs=None, link_idxs=None):
        """Enroll several meshes, then pad and stack all registered point sets.

        Args:
            mesh_names: mesh file paths.
            scales: per-mesh scale (see ``enroll_mesh``).
            pqcs: per-mesh poses; ``None`` uses the identity pose for every mesh.
            link_idxs: per-mesh link index (``-1`` = fixed); ``None`` enrolls every
                mesh as fixed.

        Raises:
            ValueError: if the per-mesh sequences differ in length, or if no fixed
                mesh has been enrolled (``fixed_point_sets`` needs at least one).
        """
        mesh_names = list(mesh_names)
        num_meshes = len(mesh_names)
        scales = list(scales)
        pqcs = [None] * num_meshes if pqcs is None else list(pqcs)
        link_idxs = [-1] * num_meshes if link_idxs is None else list(link_idxs)
        if not (len(scales) == len(pqcs) == len(link_idxs) == num_meshes):
            raise ValueError(
                f"length mismatch: {num_meshes} mesh_names, {len(scales)} scales, "
                f"{len(pqcs)} pqcs, {len(link_idxs)} link_idxs"
            )

        # Undo the array conversions of a previous call so enrollment can be repeated.
        self.fixed_meshes = list(self.fixed_meshes)
        self.moving_mehses = list(self.moving_mehses)
        self.fixed_mesh_point_nums = list(self.fixed_mesh_point_nums)
        self.moving_mesh_point_nums = list(self.moving_mesh_point_nums)

        for mesh_name, scale, pqc, link_idx in zip(mesh_names, scales, pqcs, link_idxs):
            self.enroll_mesh(mesh_name, scale, pqc, link_idx)

        if not self.fixed_meshes:
            raise ValueError("at least one fixed mesh (link_idx == -1) must be enrolled")

        self.fixed_meshes = self.pad_meshes(self.fixed_meshes)
        self.moving_mehses = self.pad_meshes(self.moving_mehses)
        self.fixed_mesh_point_nums = jnp.array(self.fixed_mesh_point_nums, dtype=jnp.int32)
        self.moving_mesh_point_nums = jnp.array(self.moving_mesh_point_nums, dtype=jnp.int32)
        self.fixed_point_sets = jnp.concatenate(self.fixed_meshes, axis=0)  # [NOA_CONVEX, NP_MAX, 3]

    def place_meshes(self, link_pqc):
        """Transform each enrolled moving mesh by its link pose.

        Args:
            link_pqc: ``[B, T, NOB, 7]`` poses; index ``i`` of axis 2 places the
                ``i``-th enrolled moving mesh.

        Returns:
            ``[B, T, NOB_CONVEX, NP_MAX, 3]`` placed points, where ``NOB_CONVEX`` is
            the total number of moving convex pieces.
        """
        _, _, num_moving_obj, _ = link_pqc.shape
        assert num_moving_obj == len(self.moving_mehses), f"NOB: {num_moving_obj} != len(self.moving_mehses): {len(self.moving_mehses)}"
        moving_point_sets = [
            # pose [B, T, 1, 1, 7] broadcast against points [1, 1, NOB_CONVEX_i, NP_MAX, 3]
            tutil.pq_action(link_pqc[:, :, None, None, i], point_set[None, None])
            for i, point_set in enumerate(self.moving_mehses)
        ]
        return jnp.concatenate(moving_point_sets, axis=2)  # [B, T, NOB_CONVEX, NP_MAX, 3]

    def ccd(self, moving_point_sets):
        """Penetration query between swept moving pieces and every fixed piece.

        Args:
            moving_point_sets: ``[B, T, NOB_CONVEX, NP_MAX, 3]`` placed moving points
                (as produced by ``place_meshes``).

        Returns:
            ``(penetration_depth, penetration_dir)`` with shapes
            ``[B, T-1, NOB_CONVEX, NOA_CONVEX]`` and
            ``[B, T-1, NOB_CONVEX, NOA_CONVEX, 3]``.
        """
        fixed_point_sets = self.fixed_point_sets  # [NOA_CONVEX, NP_FIXED, 3]

        # Sweep each moving piece over one segment by merging its points at t and
        # t+1. GJK/EPA works on the point set; explicit hull faces are not needed.
        swept_point_sets = jnp.concatenate(
            [moving_point_sets[:, :-1], moving_point_sets[:, 1:]], axis=-2
        )  # [B, T-1, NOB_CONVEX, 2 * NP_MAX, 3]

        # One query per (batch, segment, moving piece, fixed piece).
        pair_shape = swept_point_sets.shape[:-2] + fixed_point_sets.shape[:-2]  # [B, T-1, NOB_CONVEX, NOA_CONVEX]
        fixed_pairs = jnp.broadcast_to(fixed_point_sets, pair_shape + fixed_point_sets.shape[-2:])
        moving_pairs = jnp.broadcast_to(
            swept_point_sets[..., None, :, :], pair_shape + swept_point_sets.shape[-2:]
        )

        # Flatten all pair axes into a single batch axis for the FFI.
        fixed_pairs = fixed_pairs.reshape((-1,) + fixed_pairs.shape[-2:])
        moving_pairs = moving_pairs.reshape((-1,) + moving_pairs.shape[-2:])
        num_pairs = moving_pairs.shape[0]

        # NOTE(review): evaluated through jax.pure_callback instead of calling
        # compute_penetration directly inside the trace; presumably to run the FFI
        # target on the host — confirm before changing.
        penetration_depth, penetration_dir, _ = jax.pure_callback(
            compute_penetration,
            (
                jax.ShapeDtypeStruct((num_pairs,), jnp.float32),  # depth
                jax.ShapeDtypeStruct((num_pairs, 3), jnp.float32),  # penetration_dir
                jax.ShapeDtypeStruct((num_pairs, 3), jnp.float32),  # contact_point
            ),
            fixed_pairs,
            moving_pairs,
            vmap_method="broadcast_all",
        )
        return (
            penetration_depth.reshape(pair_shape),
            penetration_dir.reshape(pair_shape + (3,)),
        )

    def __call__(self, moving_obj, moving_obj_pqs, fixed_obj):
        """Collision cost for a batch of trajectories.

        Args:
            moving_obj: not used by this implementation.
            moving_obj_pqs: ``[B, T, NOB, 7]`` link poses, forwarded to ``place_meshes``.
            fixed_obj: not used by this implementation.

        Returns:
            ``(col_coef * loss, {"collision_logits": loss})``, where ``loss`` is the
            penetration depth summed over segments and convex pairs (shape ``[B]``).
        """
        @jax.custom_gradient
        def convex_mesh_collision_loss(moving_meshes):
            penetration_depth, penetration_dir = self.ccd(moving_meshes)
            penetration_dir = -penetration_dir
            # [B, T-1, NOB_CONVEX, NOA_CONVEX], [B, T-1, NOB_CONVEX, NOA_CONVEX, 3]
            # Pad the segment axis to T+1 and average neighbours. Each waypoint then
            # receives half of the direction of each adjacent segment.
            penetration_dir_pad = jnp.pad(penetration_dir, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)))  # [B, T + 1, NOB_CONVEX, NOA_CONVEX, 3]
            grad_val = (penetration_dir_pad[:, :-1] + penetration_dir_pad[:, 1:]) / 2.0  # [B, T, NOB_CONVEX, NOA_CONVEX, 3]
            # Sum over fixed pieces, then apply the same vector to every point of a
            # moving piece.
            # NOTE(review): every pair contributes regardless of depth; presumably the
            # FFI returns a zero direction for separated pairs — confirm.
            grad_val = grad_val.sum(axis=-2, keepdims=True)  # [B, T, NOB_CONVEX, 1, 3]
            grad_val = jnp.broadcast_to(grad_val, moving_meshes.shape)  # [B, T, NOB_CONVEX, MAX_POINTS, 3]
            collision_loss = jnp.sum(penetration_depth, axis=(-1, -2, -3))  # [B]

            def grad_fn(upstream_grad):
                # Only the cotangent of collision_loss is used; the cotangent of the
                # per-pair depth output (upstream_grad[1]) is ignored.
                return (upstream_grad[0][..., None, None, None, None] * grad_val,)

            return (collision_loss, penetration_depth), grad_fn

        moving_meshes = self.place_meshes(moving_obj_pqs)
        collision_loss, _ = convex_mesh_collision_loss(moving_meshes)
        collision_logits = collision_loss
        collision_loss = self.col_coef * collision_loss
        return collision_loss, {
            "collision_logits": collision_logits,
        }

if __name__ == "__main__":
    # Debug/visualization script. It builds a table scene, enrolls the scene and
    # robot meshes, places the robot meshes along an interpolated joint trajectory,
    # and shows every convex point set with Open3D. Because this module uses a
    # relative import, run it as part of its package (e.g. `python -m ...`).

    from modules import shakey_module
    import util.model_util as mutil
    import util.scene_util as scene_util
    import pybullet as pb

    # pb.connect(pb.GUI)

    # TrajOptCCD requires col_coef. Its value does not matter here because the
    # cost itself is never evaluated in this script.
    trajopt_cls = TrajOptCCD(col_coef=1.0)
    models = mutil.Models().load_pretrained_models()
    # Alternative URDF: "assets/ur5/urdf/shakey_open_rg6.urdf"
    urdf_dirs = "assets/ur5/urdf/shakey_open.urdf"
    models = models.load_self_collision_model('shakey')
    shakey = shakey_module.load_urdf_kinematics(
        urdf_dirs=urdf_dirs,
        models=models,
    )
    base_se2 = np.array([0, -0.15, -np.pi/2.0])
    robot_pb_uid = shakey.create_pb(se2=base_se2, visualize=True)
    base_pqc = tutil.SE2h2pq(base_se2, np.array(shakey.robot_height))
    robot_base_pqc = jnp.concat(base_pqc, axis=-1)

    pb.resetDebugVisualizerCamera(
        cameraDistance=1.79,
        cameraYaw=-443.20,
        cameraPitch=0.26,
        cameraTargetPosition=[-0.05, 0.07, 0.30]
    )

    @jax.jit
    def fk(q):
        """End-effector pose of `q`, composed with the robot base pose."""
        return tutil.pq_multi(robot_base_pqc, shakey.FK(q, oriCORN_out=False)[shakey.ee_idx])

    # Currently unused in this script.
    ik = shakey.get_IK_jit_func((robot_base_pqc[:3], robot_base_pqc[3:]), grasp_center_coordinate=False)

    def sample_valid_q(z_range, jkey):
        """Rejection-sample a 6-DoF configuration for the PyBullet robot.

        A sample is accepted when PyBullet reports no contact points for the robot,
        the FK pose component [2] (z) lies strictly inside `z_range`, and
        component [1] (y) exceeds 0.2.
        """
        lower_bound = np.array([-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -3.0718])
        upper_bound = np.array([2.8973, 1.7628, 2.8973, 0.0698, 2.8973, 0.0698])
        while True:
            # Split before every draw so no key is used both for sampling and splitting.
            jkey, sample_key = jax.random.split(jkey)
            q_random = jax.random.uniform(sample_key, (6,), minval=lower_bound, maxval=upper_bound)
            shakey.set_q_pb(robot_pb_uid, q_random)
            pb.performCollisionDetection()
            if len(pb.getContactPoints(robot_pb_uid)) != 0:
                continue  # reject configurations in contact
            pqc = fk(q_random)
            if z_range[0] < pqc[2] < z_range[1] and pqc[1] > 0.2:
                return q_random

    jkey = jax.random.PRNGKey(0)
    random_obj, environment_obj, pybullet_scene = scene_util.create_table_sampled_scene(models=models, num_objects=0, seed=0)
    # init, goal sampling with independent keys (the two searches must not share a stream)
    init_key, goal_key = jax.random.split(jkey)
    init_q = sample_valid_q([0.4, 1.0], init_key)
    goal_q = sample_valid_q([0, 0.4], goal_key)
    if random_obj is None:
        fixed_obj = environment_obj
    else:
        fixed_obj = environment_obj.concat(random_obj, axis=0)

    # Fixed scene meshes.
    path_list = []
    scale_list = []
    pqc_list = []
    for obj in pybullet_scene.fixed_objects:
        path_list.append(obj.mesh_path)
        scale_list.append(obj.scale)
        pqc_list.append(np.concatenate(obj.base_pose, axis=-1))

    # FIXME - link pqc is wrong
    fixed_obj_idxs = [-1 for _ in range(len(path_list))]
    robot_link_paths = [
        models.asset_path_util.obj_paths[i] for i in shakey.canonical_obj_idx
    ]
    scales = [
        shakey.mesh_scale[i] for i in shakey.canonical_obj_idx
    ]
    # FIXME - pqc is not correct, must consider base pose
    # Identity pose per robot link: zero position followed by a quaternion whose
    # last component is 1.
    pqc_lists = jnp.concatenate([
        jnp.zeros((len(shakey.canonical_obj_idx), 6)),
        jnp.ones((len(shakey.canonical_obj_idx), 1)),
    ], axis=1)

    link_idxs = list(range(len(robot_link_paths)))

    path_list.extend(robot_link_paths)
    scale_list.extend(scales)
    pqc_list.extend(pqc_lists)

    assert len(pybullet_scene.movable_objects) <= 1, 'more than one movable object is not supported'
    if len(pybullet_scene.movable_objects) == 1:
        obj = pybullet_scene.movable_objects[0]
        path_list.append(obj.mesh_path)
        scale_list.append(obj.scale)
        pqc_list.append(np.concatenate(obj.base_pose, axis=-1))
        link_idxs.append(len(link_idxs))

    trajopt_cls.enroll_meshes(path_list, scale_list, pqc_list, fixed_obj_idxs + link_idxs)

    num_trajectory_points_particle = 3
    num_seed = 1
    init_q = init_q[None].repeat(num_seed, 0)  # [num_seed, 6]
    # Linear joint-space interpolation: [num_seed, num_trajectory_points_particle, 6]
    x = (goal_q - init_q)[..., None, :] * jnp.linspace(0, 1, num_trajectory_points_particle)[..., None] + init_q[..., None, :]

    # Prepend the fixed base SE(2) pose to every waypoint: [..., 3 + 6]
    x = jnp.concatenate([
        jnp.broadcast_to(base_se2, (x.shape[:-1] + (3,))),
        x,
    ], axis=-1)
    moving_obj_pqs = shakey.FK(x, oriCORN_out=False)

    # NOTE(review): the placed meshes presumably still miss the link-to-visual-mesh
    # offset (see the FIXMEs above); a candidate correction is
    # moving_obj_pqs = tutil.pq_multi(moving_obj_pqs, shakey.link_to_mesh_pq[None, None])

    # Compare link frames with visual-mesh frames in PyBullet at one waypoint.
    q_idx = 2
    shakey.set_q_pb(robot_pb_uid, x[0, q_idx, 3:])

    visual_shapes = pb.getVisualShapeData(robot_pb_uid)
    for shape in visual_shapes:
        # Extract link index and mesh center (localVisualFramePosition)
        link_index = shape[1]
        # local mesh center in the link's coordinate frame
        local_mesh_center = shape[5]
        local_mesh_orientation = shape[6]
        print(f"Link {link_index} has local mesh center: {local_mesh_center} / orientation: {local_mesh_orientation}")

        link_state = pb.getLinkState(robot_pb_uid, link_index)
        world_link_pos = link_state[4]      # World position of the link frame
        world_link_orient = link_state[5]   # World orientation of the link frame (quaternion)

        global_mesh_center, global_mesh_orientation = pb.multiplyTransforms(world_link_pos, world_link_orient,
                                                    local_mesh_center, local_mesh_orientation)
        print(f"Link {link_index} center: {world_link_pos} / orientation: {world_link_orient}")
        print(f"Link {link_index} mesh center: {global_mesh_center} / orientation: {global_mesh_orientation}")

    moving_meshes = trajopt_cls.place_meshes(moving_obj_pqs)

    # One randomly colored point cloud per convex piece (moving pieces at q_idx + fixed pieces).
    pcds = []
    for mesh_set in [
        moving_meshes[0, q_idx],
        trajopt_cls.fixed_point_sets,
    ]:
        for mesh in mesh_set:
            color = np.random.rand(3)
            pcd = o3d.geometry.PointCloud()
            # Vector3dVector takes a float64 (N, 3) array; convert the JAX array explicitly.
            pcd.points = o3d.utility.Vector3dVector(np.asarray(mesh, dtype=np.float64))
            pcd.paint_uniform_color(color)
            pcds.append(pcd)
    o3d.visualization.draw_geometries(pcds)

    # Recorded output of the visual-shape loop above, kept for reference. It shows
    # that several links have visual-mesh frames offset from their link frames.
    """
    Link 0 has local mesh center: (0.0, 0.0, 0.0) / orientation: (0.0, 0.0, 0.0, 1.0)
    Link 1 has local mesh center: (0.0, 0.0, 0.0) / orientation: (0.0, 0.0, 1.0, 6.123233995736766e-17)
    Link 2 has local mesh center: (0.0, 0.0, 0.13585) / orientation: (0.5, -0.4999999999999999, -0.5, 0.5000000000000001)
    Link 3 has local mesh center: (0.0, 0.0, 0.0165) / orientation: (0.5, -0.4999999999999999, -0.5, 0.5000000000000001)
    Link 4 has local mesh center: (0.0, 0.0, -0.093) / orientation: (0.7071067811865475, 0.0, 0.0, 0.7071067811865476)
    Link 5 has local mesh center: (0.0, 0.0, -0.095) / orientation: (0.0, 0.0, 0.0, 1.0)
    Link 6 has local mesh center: (0.0, 0.0, 0.0) / orientation: (0.0, 0.0, 0.0, 1.0)
    
    Link 0 center: (0.0, -0.15000000596046448, 0.0) / orientation: (0.0, 0.0, 0.7071067690849304, 0.7071067690849304)
    Link 1 center: (0.0, -0.15000000596046448, 0.08915899693965912) / orientation: (0.0, 0.0, 0.9114270806312561, -0.4114616811275482)
    Link 2 center: (0.0, -0.15000000596046448, 0.08915895223617554) / orientation: (0.664384126663208, -0.24206148087978363, -0.6584102511405945, -0.2578682601451874)
    Link 3 center: (-0.0067169442772865295, -0.15761710703372955, 0.5140376687049866) / orientation: (-0.43311598896980286, 0.5589370131492615, 0.6945320963859558, -0.13275983929634094)
    Link 4 center: (0.14267592132091522, 0.17682544887065887, 0.6918120384216309) / orientation: (-0.66881263256073, -0.22954228520393372, 0.34981364011764526, 0.6145163774490356)
    Link 5 center: (0.1543075442314148, 0.19001585245132446, 0.7848138809204102) / orientation: (-0.08666413277387619, 0.034582775086164474, -0.33484676480293274, 0.9376412034034729)
    Link 6 center: (0.20549306273460388, 0.2526242434978485, 0.7695324420928955) / orientation: (0.532841145992279, 0.5558058619499207, 0.6361031532287598, -0.05032851919531822)

    Link 0 mesh center: (0.0, -0.15000000596046448, 0.0) / orientation: (0.0, 0.0, 0.7071067690849304, 0.7071067690849304)
    Link 1 mesh center: (0.0, -0.15000000596046448, 0.08915899693965912) / orientation: (0.0, 0.0, 0.4114616811275482, 0.9114270806312561)
    Link 2 mesh center: (-0.10189219564199448, -0.060148999094963074, 0.08915895968675613) / orientation: (0.004916451871395111, -0.010890326462686062, 0.41143229603767395, 0.9113619923591614)
    Link 3 mesh center: (-0.019092515110969543, -0.1467040330171585, 0.5140376687049866) / orientation: (-0.21514035761356354, 0.4765564501285553, 0.3507353961467743, 0.7769125699996948)
    Link 4 mesh center: (0.21242913603782654, 0.11531537771224976, 0.6918120384216309) / orientation: (-0.03839322552084923, 0.08504468947649002, 0.409666508436203, 0.9074506163597107)
    Link 5 mesh center: (0.14263291656970978, 0.17677666246891022, 0.6914681196212769) / orientation: (-0.08666413277387619, 0.034582775086164474, -0.33484676480293274, 0.9376412034034729)
    Link 6 mesh center: (0.20549306273460388, 0.2526242434978485, 0.7695324420928955) / orientation: (0.532841145992279, 0.5558058619499207, 0.6361031532287598, -0.05032851919531822)
    """