
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
"""CUDA kernels implemented in warp-lang for computing signed distance to meshes."""
# Third Party
import os, sys
import jax
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"
import copy
import jax.numpy as jnp
import warp as wp
from typing import Tuple
import jax.debug as jdb
from warp.jax import get_jax_device
# if not hasattr(jnp, 'float8_e3m4'):
#     # Define a placeholder or appropriate fallback.
#     jnp.float8_e3m4 = jnp.float32  # or create a custom type if needed
#     jnp.float8_e4m3 = jnp.float32  # or create a custom type if needed
# from warp.jax_experimental.ffi import register_ffi_callback
from warp.jax_experimental import jax_kernel
# from warp.jax_experimental.ffi import jax_kernel
import yaml as ym
import numpy as np
import open3d as o3d
import einops
import time

# Make the project root (three directories above this file) importable so the
# `util` / `modules` packages imported below resolve from any working directory.
BASEDIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

import util.transform_util as tutil
from modules.ccd.base import ContinuousCollisionCostBase
from modules.ccd.sphere_decomposition import decompose_mesh_to_spheres
import modules.shakey_module as shakey_module

def debug_print(*args):
    """Block until every array in ``args`` is ready (printing is currently disabled)."""
    jax.block_until_ready(args)

@wp.func
def mesh_query_point_fn(
    idx: wp.uint64,
    point: wp.vec3,
    max_distance: float,
):
    """Return the ``wp.mesh_query_point`` result for ``point`` on mesh ``idx`` within ``max_distance``."""
    return wp.mesh_query_point(idx, point, max_distance)


def get_ccd_func(max_nmesh, sweep_steps=4, enable_speed_metric=0, write_grad=1):
    """Build the warp kernel that computes swept-sphere collision cost against meshes.

    The returned kernel is a port of cuRobo's ``get_swept_closest_pt_batch_env``.
    The keyword arguments are read by the kernel from this enclosing closure.

    Args:
        max_nmesh: Stride between environments in ``mesh_id_32``. Only environment 0
            is used (``env_idx`` is fixed to 0), so this only offsets the first mesh.
        sweep_steps: Maximum number of interpolated queries toward each neighbouring
            horizon step.
        enable_speed_metric: 1 to scale cost/gradient by the sphere speed (central
            differences with a hard-coded dt of 0.01); 0 to disable.
        write_grad: 1 to write the gradient into ``closest_pt``; otherwise it is untouched.

    Returns:
        The ``@wp.kernel`` ``get_swept_closest_pt_batch_env``.
    """

    @wp.kernel
    def get_swept_closest_pt_batch_env(
        # inputs
        pt: wp.array(dtype=wp.vec4),  # xyz centre + radius, flattened (batch, horizon, nspheres)
        activation_distance: wp.array(dtype=wp.float32),  # eta threshold
        mesh_id_32: wp.array(dtype=wp.vec2ui),  # (high, low) uint32 halves of each mesh id
        n_env_mesh: wp.array(dtype=wp.int32),
        horizon_arr: wp.array(dtype=wp.int32),
        nspheres_arr: wp.array(dtype=wp.int32),
        batch_size_arr: wp.array(dtype=wp.int32),
        # outputs
        distance: wp.array(dtype=wp.float32),  # this stores the output cost
        closest_pt: wp.array(dtype=wp.vec4),  # this stores the gradient (xyz, w = 0)
        sparsity_idx: wp.array(dtype=wp.uint8),  # 1 where the cost is non-zero
        signed_distance: wp.array(dtype=wp.float32),
    ):
        """Compute signed distance between a trajectory of a sphere and world meshes.

        One thread handles one (batch, horizon step, sphere) triple. Each mesh is
        queried at the sphere centre. When the surface is closer than half the gap
        to a neighbouring step, up to ``sweep_steps`` extra queries are made along
        the segment toward the midpoint with that neighbour.

        distance is negative outside and positive inside.
        ``signed_distance`` is the max over all queries of
        ``radius - sign * |closest - query|``. It is -inf when no query found a
        surface within ``max_dist_buffer``.

        Every output slot of a thread is always written. Buffers handed over by
        jax_kernel are freshly allocated each call and may hold garbage, so none
        of them can be read back as state.
        """
        tid = int(0)
        tid = wp.tid()

        signed_distance[tid] = -wp.inf

        b_idx = int(0)
        h_idx = int(0)
        sph_idx = int(0)

        mesh = wp.uint64(0)

        horizon = horizon_arr[0]
        nspheres = nspheres_arr[0]
        batch_size = batch_size_arr[0]

        # Row-major decomposition of tid over (batch_size, horizon, nspheres).
        b_idx = tid / (horizon * nspheres)
        h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
        sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
        if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
            return
        uint_zero = wp.uint8(0)
        uint_one = wp.uint8(1)
        env_idx = int(0)
        n_mesh = int(0)
        sphere_0_distance = float(0.0)
        sphere_2_distance = float(0.0)

        sphere_0 = wp.vec3(0.0)
        sphere_2 = wp.vec3(0.0)
        sphere_int = wp.vec3(0.0)
        sphere_temp = wp.vec3(0.0)
        grad_vec = wp.vec3(0.0)
        eta = float(0.0)
        dt = float(0.0)
        k0 = float(0.0)
        sign = float(0.0)
        dist = float(0.0)
        dist_metric = float(0.0)
        euclidean_distance = float(0.0)
        cl_pt = wp.vec3(0.0)
        local_pt = wp.vec3(0.0)
        in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
        in_rad = in_sphere[3]
        # This sphere's own radius. ``in_sphere`` is overwritten below with the
        # neighbouring spheres, and ``in_rad`` is inflated by eta.
        sphere_rad = in_sphere[3]
        if in_rad < 0.0:
            # Negative radius marks a disabled sphere: zero cost and gradient.
            distance[tid] = 0.0
            sparsity_idx[tid] = uint_zero
            if write_grad == 1:
                closest_pt[tid] = wp.vec4(0.0)
            return
        dt = 0.01
        eta = activation_distance[0]
        in_rad += eta
        max_dist_buffer = float(1.0)
        if in_rad > max_dist_buffer:
            max_dist_buffer += in_rad

        in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
        # read in sphere 0 (previous horizon step):
        if h_idx > 0:
            in_sphere = pt[b_idx * horizon * nspheres + ((h_idx - 1) * nspheres) + sph_idx]
            sphere_0 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
            sphere_0_distance = wp.length(sphere_0 - in_pt) / 2.0
        # read in sphere 2 (next horizon step):
        if h_idx < horizon - 1:
            in_sphere = pt[b_idx * horizon * nspheres + ((h_idx + 1) * nspheres) + sph_idx]
            sphere_2 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
            sphere_2_distance = wp.length(sphere_2 - in_pt) / 2.0

        closest_distance = float(0.0)
        closest_point = wp.vec3(0.0)
        i = int(0)
        dis_length = float(0.0)
        jump_distance = float(0.0)
        mid_distance = float(0.0)
        # Only a single environment (env_idx = 0) is supported.
        i = max_nmesh * env_idx
        n_mesh = i + n_env_mesh[env_idx]

        # Meshes are presumably pre-transformed into world coordinates when enrolled,
        # so every mesh uses the identity pose (cuRobo read a per-mesh pose here).
        obj_w_pose = wp.transform(wp.vec3(0.0), wp.quaternion(0.0, 0.0, 0.0, 1.0))
        obj_w_pose_t = wp.transform_inverse(obj_w_pose)
        local_pt = wp.transform_point(obj_w_pose, in_pt)

        while i < n_mesh:
            # Rebuild the 64-bit id of mesh i from its (high, low) uint32 halves.
            mesh = (wp.uint64(mesh_id_32[i][0]) << wp.uint64(32)) | wp.uint64(mesh_id_32[i][1])

            collide_result = mesh_query_point_fn(mesh, local_pt, max_dist_buffer)
            if collide_result.result:
                sign = collide_result.sign
                cl_pt = wp.mesh_eval_position(
                    mesh, collide_result.face, collide_result.u, collide_result.v
                )
                delta = cl_pt - local_pt
                dis_length = wp.length(delta)
                dist = (-1.0 * dis_length * sign) + in_rad
                if signed_distance[tid] < (-1.0 * dis_length * sign) + sphere_rad:
                    signed_distance[tid] = (-1.0 * dis_length * sign) + sphere_rad
                if dist > 0.0:
                    if dist == in_rad:
                        cl_pt = sign * (delta) / (dist)
                    else:
                        cl_pt = sign * (delta) / dis_length
                    euclidean_distance = dist
                    # Quadratic cost inside eta, linear beyond it.
                    if dist > eta:
                        dist_metric = dist - 0.5 * eta
                    elif dist <= eta:
                        dist_metric = (0.5 / eta) * (dist) * dist
                        cl_pt = (1.0 / eta) * dist * cl_pt
                    grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)

                    closest_distance += dist_metric
                    closest_point += grad_vec
                else:
                    dist = -1.0 * dist
                    euclidean_distance = dist
            else:
                dist = max_dist_buffer
                euclidean_distance = dist
            dist = max(euclidean_distance - in_rad, in_rad)

            mid_distance = euclidean_distance
            # sweep toward sphere -1 (previous step)
            if h_idx > 0 and mid_distance < sphere_0_distance:
                jump_distance = mid_distance
                j = int(0)
                sphere_temp = wp.transform_point(obj_w_pose, sphere_0)
                while j < sweep_steps:
                    k0 = (
                        1.0 - 0.5 * jump_distance / sphere_0_distance
                    )  # dist could be greater than sphere_0_distance here?
                    sphere_int = k0 * local_pt + ((1.0 - k0) * sphere_temp)
                    collide_result = mesh_query_point_fn(mesh, sphere_int, max_dist_buffer)
                    if collide_result.result:
                        sign = collide_result.sign
                        cl_pt = wp.mesh_eval_position(
                            mesh, collide_result.face, collide_result.u, collide_result.v
                        )
                        delta = cl_pt - sphere_int
                        dis_length = wp.length(delta)
                        dist = (-1.0 * dis_length * sign) + in_rad
                        if signed_distance[tid] < (-1.0 * dis_length * sign) + sphere_rad:
                            signed_distance[tid] = (-1.0 * dis_length * sign) + sphere_rad
                        if dist > 0.0:
                            if dist == in_rad:
                                cl_pt = sign * (delta) / (dist)
                            else:
                                cl_pt = sign * (delta) / dis_length
                            euclidean_distance = dist
                            if dist > eta:
                                dist_metric = dist - 0.5 * eta
                            elif dist <= eta:
                                dist_metric = (0.5 / eta) * (dist) * dist
                                cl_pt = (1.0 / eta) * dist * cl_pt

                            closest_distance += dist_metric
                            grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)

                            closest_point += grad_vec
                            dist = max(euclidean_distance - in_rad, in_rad)
                            jump_distance += euclidean_distance
                        else:
                            dist = max(-dist - in_rad, in_rad)
                            euclidean_distance = dist
                            jump_distance += euclidean_distance
                    else:
                        jump_distance += max_dist_buffer
                    j += 1
                    if jump_distance >= sphere_0_distance:
                        j = int(sweep_steps)
            # sweep toward sphere +1 (next step)
            if h_idx < horizon - 1 and mid_distance < sphere_2_distance:
                jump_distance = mid_distance
                j = int(0)
                sphere_temp = wp.transform_point(obj_w_pose, sphere_2)
                while j < sweep_steps:
                    k0 = (
                        1.0 - 0.5 * jump_distance / sphere_2_distance
                    )  # dist could be greater than sphere_2_distance here?
                    sphere_int = k0 * local_pt + (1.0 - k0) * sphere_temp
                    collide_result = mesh_query_point_fn(mesh, sphere_int, max_dist_buffer)
                    if collide_result.result:
                        sign = collide_result.sign
                        cl_pt = wp.mesh_eval_position(
                            mesh, collide_result.face, collide_result.u, collide_result.v
                        )
                        delta = cl_pt - sphere_int
                        dis_length = wp.length(delta)
                        dist = (-1.0 * dis_length * sign) + in_rad
                        # Mirror the backward sweep so the forward half-segment also
                        # contributes to signed_distance.
                        if signed_distance[tid] < (-1.0 * dis_length * sign) + sphere_rad:
                            signed_distance[tid] = (-1.0 * dis_length * sign) + sphere_rad
                        if dist > 0.0:
                            euclidean_distance = dist
                            if dist == in_rad:
                                cl_pt = sign * (delta) / (dist)
                            else:
                                cl_pt = sign * (delta) / dis_length
                            if dist > eta:
                                dist_metric = dist - 0.5 * eta
                            elif dist <= eta:
                                dist_metric = (0.5 / eta) * (dist) * dist
                                cl_pt = (1.0 / eta) * dist * cl_pt
                            closest_distance += dist_metric
                            grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)

                            closest_point += grad_vec
                            dist = max(euclidean_distance - in_rad, in_rad)
                            jump_distance += dist

                        else:
                            dist = max(-dist - in_rad, in_rad)
                            jump_distance += dist
                    else:
                        jump_distance += max_dist_buffer

                    j += 1
                    if jump_distance >= sphere_2_distance:
                        j = int(sweep_steps)
            i += 1

        if closest_distance <= 0.0:
            # No contribution: write zeros unconditionally (outputs hold no prior state).
            distance[tid] = 0.0
            sparsity_idx[tid] = uint_zero
            if write_grad == 1:
                closest_pt[tid] = wp.vec4(0.0)
            return
        if enable_speed_metric == 1 and (h_idx > 0 and h_idx < horizon - 1):
            # calculate sphere velocity and acceleration:
            norm_vel_vec = wp.vec3(0.0)
            sph_acc_vec = wp.vec3(0.0)
            sph_vel = float(0.0)

            # use central difference
            norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
            sph_vel = wp.length(norm_vel_vec)
            if sph_vel > 1e-3:
                sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)

                norm_vel_vec = norm_vel_vec * (1.0 / sph_vel)

                curvature_vec = sph_acc_vec / (sph_vel * sph_vel)

                # Projector orthogonal to the velocity direction: I - v v^T.
                orth_proj = wp.mat33(0.0)
                for i in range(3):
                    for j in range(3):
                        orth_proj[i, j] = -1.0 * norm_vel_vec[i] * norm_vel_vec[j]

                orth_proj[0, 0] = orth_proj[0, 0] + 1.0
                orth_proj[1, 1] = orth_proj[1, 1] + 1.0
                orth_proj[2, 2] = orth_proj[2, 2] + 1.0

                orth_curv = wp.vec3(0.0, 0.0, 0.0)  # closest_distance * (orth_proj @ curvature_vec)
                orth_pt = wp.vec3(0.0, 0.0, 0.0)  # orth_proj @ closest_point
                for i in range(3):
                    orth_pt[i] = (
                        orth_proj[i, 0] * closest_point[0]
                        + orth_proj[i, 1] * closest_point[1]
                        + orth_proj[i, 2] * closest_point[2]
                    )

                    orth_curv[i] = closest_distance * (
                        orth_proj[i, 0] * curvature_vec[0]
                        + orth_proj[i, 1] * curvature_vec[1]
                        + orth_proj[i, 2] * curvature_vec[2]
                    )

                closest_point = sph_vel * (orth_pt - orth_curv)

                closest_distance = sph_vel * closest_distance

        distance[tid] = closest_distance
        sparsity_idx[tid] = uint_one
        if write_grad == 1:
            closest_pt[tid] = wp.vec4(closest_point[0], closest_point[1], closest_point[2], 0.0)

    return get_swept_closest_pt_batch_env


def warp_func(inputs, outputs, attrs, ctx):
    """Launch the swept-sphere CCD kernel on the CUDA stream supplied by JAX/XLA.

    Presumably meant to be registered with ``register_ffi_callback``; that call is
    currently commented out below.

    Args:
        inputs: Kernel inputs in the order expected by ``get_swept_closest_pt_batch_env``.
            ``inputs[0]`` (the sphere array) determines the launch dimension.
        outputs: Kernel outputs (distance, closest_pt, sparsity_idx, signed_distance).
        attrs: Callback attributes. Not needed: horizon / nspheres / batch_size are
            read by the kernel from its own input arrays.
        ctx: Call context; ``ctx.stream`` is the CUDA stream to launch on.
    """
    device = wp.device_from_jax(get_jax_device())
    stream = wp.Stream(device, cuda_stream=ctx.stream)

    with wp.ScopedStream(stream):
        # get_ccd_func only accepts max_nmesh / sweep_steps / enable_speed_metric /
        # write_grad. Passing horizon/nspheres/batch_size raised TypeError.
        kernel = get_ccd_func(max_nmesh=1)
        wp.launch(kernel, dim=inputs[0].shape[0], inputs=inputs, outputs=outputs)

# register_ffi_callback("warp_func", warp_func) # for new version


def uint64_to_uint32_pair(value):
    """Split a 64-bit integer (e.g. a warp mesh id) into a ``[high, low]`` uint32 array.

    The kernel rebuilds the id as ``(high << 32) | low``.
    """
    high_word = jnp.uint32(value >> 32)
    low_word = jnp.uint32(value & 0xFFFFFFFF)
    pair = jnp.stack([high_word, low_word], axis=-1)
    return pair.astype(jnp.uint32)

class CuroboCCD(ContinuousCollisionCostBase):
    def __init__(self, col_coef, activation_distance, rot_configs):
        """Set up the swept-sphere CCD cost.

        Args:
            col_coef: Stored as ``self.col_coef`` (not used in the methods shown here).
            activation_distance: Activation margin (eta) passed to the kernel.
            rot_configs: Stored as ``self.rot_configs``.
        """
        self.activation_distance = activation_distance
        # Warp kernel wrapped as a JAX-callable op (single environment mesh slot).
        self.ccd_func = jax_kernel(get_ccd_func(max_nmesh=1))

        # Filled in by enroll_meshes().
        self.mesh_wp = None

        self.rot_configs = rot_configs
        self.col_coef = col_coef
        # signed_distance at or above this value marks a sphere as "hit" in visualisation.
        self.col_threshold = -0.005
        
    def place_spheres(self, moving_obj_pqs, sphere_list=None):
        '''
        Transform each link's canonical spheres into the world frame.

        moving_obj_pqs: link poses (... NLINK, PQC), one per entry of ``sphere_list``
        sphere_list: per-link arrays (NSPH_link, 4) of centre xyz + radius;
            defaults to ``self.sphere_list``
        output: (... NSPHERE, 4), links concatenated in order
        '''
        per_link = self.sphere_list if sphere_list is None else sphere_list

        assert len(per_link) == moving_obj_pqs.shape[-2]

        placed = []
        for link_idx, link_spheres in enumerate(per_link):
            link_pq = moving_obj_pqs[..., link_idx:link_idx + 1, :]
            centers = tutil.pq_action(link_pq, link_spheres[..., :3])
            # Broadcast the (unchanged) radii to the batch shape of the centres.
            radii = jnp.zeros_like(centers[..., :1]) + link_spheres[..., -1:]
            placed.append(jnp.concatenate([centers, radii], axis=-1))

        return jnp.concatenate(placed, axis=-2)

    def enroll_robot(
            self,
            shakey: shakey_module.Shakey,
            voxel_intervals=(0.1,),
            num_surface_samples=(10,),
            visualize=False,
            grasping_objects=(),
        ):
        """Approximate every robot link (and optional grasped objects) with spheres.

        Fills ``self.sphere_list`` with one sphere array per mesh. The order is the
        robot links (``shakey.canonical_obj_idx``) followed by ``grasping_objects``.

        Args:
            shakey: Robot model providing link mesh paths, scales and FK.
            voxel_intervals: Per-mesh voxel size for ``decompose_mesh_to_spheres``.
                A scalar or single-element sequence is applied to every mesh.
            num_surface_samples: Per-mesh surface sample count, broadcast like
                ``voxel_intervals``.
            visualize: If True, show each grasped object's decomposition and two
                random robot configurations with their spheres.
            grasping_objects: Objects with ``mesh_path``, ``scale`` and ``base_pose``;
                their meshes are appended after the robot links.

        Raises:
            ValueError: If a per-mesh parameter has fewer entries than meshes and
                is not a single value that can be broadcast.
        """
        models = shakey.models
        robot_link_paths = [
            models.asset_path_util.obj_paths[i] for i in shakey.canonical_obj_idx
        ]
        scales = shakey.mesh_scale
        assert len(robot_link_paths) == len(scales)

        n_meshes = len(robot_link_paths) + len(grasping_objects)

        def _per_mesh(values, name):
            # Broadcast a scalar / single value to every mesh. Longer sequences are
            # accepted; extra entries are ignored by zip() as before.
            values = [values] if np.isscalar(values) else list(values)
            if len(values) == 1:
                values = values * n_meshes
            if len(values) < n_meshes:
                raise ValueError(
                    f"{name} has {len(values)} entries but {n_meshes} meshes are enrolled"
                )
            return values

        voxel_intervals = _per_mesh(voxel_intervals, "voxel_intervals")
        num_surface_samples = _per_mesh(num_surface_samples, "num_surface_samples")

        canonical_moving_o3d_meshes = [
            o3d.io.read_triangle_mesh(link_path) for link_path in robot_link_paths
        ]
        for mesh, scale in zip(canonical_moving_o3d_meshes, scales):
            mesh.scale(scale, center=(0, 0, 0))

        for obj in grasping_objects:
            mesh = o3d.io.read_triangle_mesh(obj.mesh_path)
            scale = obj.scale[0] if isinstance(obj.scale, tuple) else obj.scale
            mesh.scale(scale, center=(0, 0, 0))
            pqc = np.concatenate(obj.base_pose, axis=-1)
            H_mat = tutil.pq2H(pqc)
            mesh.transform(H_mat)
            if visualize:
                mesh_idx = len(canonical_moving_o3d_meshes)
                decompose_mesh_to_spheres(
                    mesh, voxel_intervals[mesh_idx], num_surface_samples[mesh_idx], visualize=True
                )
            canonical_moving_o3d_meshes.append(mesh)

        self.sphere_list = [
            decompose_mesh_to_spheres(mesh, voxel_interval, num_surface_sample)
            for mesh, voxel_interval, num_surface_sample in zip(
                canonical_moving_o3d_meshes, voxel_intervals, num_surface_samples
            )
        ]

        if visualize:
            jkey = jax.random.PRNGKey(1)
            for _trial in range(2):
                _, jkey = jax.random.split(jkey)
                test_q = shakey.random_q(jkey, ())
                shakey_mesh_list = shakey.show_in_o3d_gt(test_q, visualize=False)
                sphere_list = self.visualize_spheres(test_q, shakey, color=(0,0.5,0), visualize=False)

                # Same joints with a 3-vector prepended (presumably a base SE(2) pose).
                # np.concatenate (not np.concat) so NumPy < 2.0 also works.
                test_q_se2 = np.concatenate([np.array([0, 1, 0]), test_q], axis=0)
                sphere_list2 = self.visualize_spheres(test_q_se2, shakey, color=(0,0.5,0), visualize=False)
                o3d.visualization.draw_geometries(shakey_mesh_list + sphere_list + sphere_list2)

        print(f"Enroll robot done, num spheres: {[len(sphere) for sphere in self.sphere_list]}")

    def visualize_spheres_and_meshes(self, spheres_tf, grads, vertices, faces, signed_distance=None, visualize=False):
        """Build (and optionally show) Open3D geometry for a swept-sphere CCD query.

        Args:
            spheres_tf: (T, S, 4) sphere centres (xyz) and radii per time step.
            grads: (T, S, >=3) per-sphere gradient, drawn as a segment from each
                centre to ``centre - 0.05 * grad``.
            vertices: Environment mesh vertices, convertible to an (V, 3) array.
            faces: Environment mesh triangle indices, reshapeable to (F, 3).
            signed_distance: Optional (T, S) array. Spheres with a value
                >= ``self.col_threshold`` are drawn; when None no spheres are drawn.
            visualize: If True, open a viewer and return None; otherwise return the
                list of geometries.
        """
        geometries = []
        ntime = spheres_tf.shape[0]
        nsphere = spheres_tf.shape[1]

        # Polyline of every sphere centre through time.
        traj_lines = o3d.geometry.LineSet()
        traj_lines.points = o3d.utility.Vector3dVector(spheres_tf[..., :3].reshape(-1, 3))
        line_indices = [
            [li + nsphere * ti, li + nsphere * (ti + 1)]
            for li in range(nsphere)
            for ti in range(ntime - 1)
        ]
        # reshape keeps an (0, 2) shape when there is a single time step.
        traj_lines.lines = o3d.utility.Vector2iVector(
            np.array(line_indices, dtype=np.int32).reshape(-1, 2)
        )
        geometries.append(traj_lines)

        for t in range(ntime):
            color = (1 - t / ntime, 0, t / ntime)
            # Hit spheres of this time step only (they used to be re-added every step).
            if signed_distance is not None:
                for j in range(nsphere):
                    if signed_distance[t, j] < self.col_threshold:
                        continue
                    sphere = o3d.geometry.TriangleMesh.create_sphere(radius=np.maximum(spheres_tf[t, j, -1], 0.003), resolution=10)
                    sphere.compute_vertex_normals()
                    sphere.translate(spheres_tf[t, j, :3])
                    sphere.paint_uniform_color(color)
                    geometries.append(sphere)

            # Gradient direction for every sphere at this step.
            lineset = o3d.geometry.LineSet()
            lineset.points = o3d.utility.Vector3dVector(
                np.concatenate([
                    spheres_tf[t, :, :3],
                    spheres_tf[t, :, :3] - grads[t, :, :3] * 0.05
                ], axis=0)
            )
            lineset.lines = o3d.utility.Vector2iVector(
                np.array([[i, i + nsphere] for i in range(nsphere)])
            )
            lineset.paint_uniform_color(color)
            geometries.append(lineset)

        mesh_o3d = o3d.geometry.TriangleMesh()
        mesh_o3d.vertices = o3d.utility.Vector3dVector(np.array(vertices))
        mesh_o3d.triangles = o3d.utility.Vector3iVector(np.array(faces).reshape(-1, 3))
        mesh_o3d.compute_vertex_normals()
        mesh_o3d.paint_uniform_color((0.5, 0.5, 0.5))
        geometries.append(mesh_o3d)

        if visualize:
            o3d.visualization.draw_geometries(geometries)
            return None
        return geometries


    def visualize_spheres(self, q, shakey: shakey_module.Shakey, color=None, visualize=False):
        """Render the collision spheres of configuration ``q`` as Open3D meshes.

        Only the first ``NLINK`` entries of ``self.sphere_list`` are used (one per FK
        link); extra entries such as grasped objects are ignored. Radii are clamped
        to at least 0.005 for display.

        Returns:
            The list of sphere meshes when ``visualize`` is False; otherwise shows
            them and returns None.
        """
        link_pqs = shakey.FK(q, oriCORN_out=False)
        n_links = link_pqs.shape[-2]
        placed = self.place_spheres(link_pqs, self.sphere_list[:n_links])  # (NSPHERE, 4)

        meshes = []
        for center_radius in placed:
            ball = o3d.geometry.TriangleMesh.create_sphere(
                radius=np.maximum(center_radius[-1], 0.005), resolution=10
            )
            ball.compute_vertex_normals()
            ball.translate(center_radius[:3])
            if color is not None:
                ball.paint_uniform_color(color)
            meshes.append(ball)

        if not visualize:
            return meshes
        o3d.visualization.draw_geometries(meshes)


    def enroll_meshes(self, mesh_names, scales, pqcs=None, vtx_fcs=None):
        """Merge the environment meshes into one warp ``Mesh`` and store its id.

        Either load meshes from disk (``mesh_names`` / ``scales`` / ``pqcs``) or pass
        already-merged geometry via ``vtx_fcs``.

        Args:
            mesh_names: Mesh file paths readable by ``o3d.io.read_triangle_mesh``.
            scales: Per-mesh uniform scale; for a tuple the first element is used.
            pqcs: Optional per-mesh pose for ``tutil.pq2H``. ``None`` (or a ``None``
                entry) means the identity pose.
            vtx_fcs: Optional ``(vertices, faces)`` with (V, 3) vertices and a flat
                triangle-index array; bypasses file loading.

        Returns:
            (1, 2) uint32 array holding the (high, low) halves of the mesh id (also
            stored as ``self.mesh_ids``), or None when there is nothing to enroll.
        """
        enroll_start_time = time.time()
        if vtx_fcs is None:
            if len(mesh_names) == 0:
                print("Warn: no meshes to enroll")
                return
            if pqcs is None:
                # zip() below needs an iterable; None means identity for every mesh.
                pqcs = [None] * len(mesh_names)

            vertices = []
            faces = []
            vtx_idx_offset = 0
            for mesh_name, scale, pqc in zip(mesh_names, scales, pqcs):
                mesh_o3d = o3d.io.read_triangle_mesh(mesh_name)
                if isinstance(scale, tuple):
                    scale = scale[0]
                mesh_o3d.scale(scale, center=(0, 0, 0))
                H_mat = np.eye(4) if pqc is None else tutil.pq2H(pqc)
                mesh_o3d.transform(np.asarray(H_mat, dtype=np.float64))
                mesh_o3d.compute_vertex_normals()

                mesh_vertices = np.array(mesh_o3d.vertices)
                vertices.append(mesh_vertices)
                # Shift indices so they address the concatenated vertex buffer.
                faces.append(np.array(mesh_o3d.triangles).reshape(-1) + vtx_idx_offset)
                vtx_idx_offset += len(mesh_vertices)
            vertices = np.concatenate(vertices, axis=0)
            faces = np.concatenate(faces, axis=0)
        else:
            vertices, faces = vtx_fcs

        self.vertices = wp.array(vertices, dtype=wp.vec3)
        self.faces = wp.array(faces, dtype=int)
        # Drop the previous mesh (if any) before building its replacement.
        self.mesh_wp = None
        self.mesh_wp = wp.Mesh(
            points=self.vertices,
            velocities=None,
            indices=self.faces,
        )

        self.mesh_ids = uint64_to_uint32_pair(self.mesh_wp.id)[None]

        enroll_end_time = time.time()
        # faces is a flat index array: three indices per triangle.
        print(f"Enroll time: {enroll_end_time - enroll_start_time}, num vertices: {len(vertices)}, num faces: {np.size(faces) // 3}")
        return self.mesh_ids

    def enroll_meshes_batch(self, mesh_names, scales, pqcs=None, simplify=False, voxel_size_factor=16):
        """Build one warp ``Mesh`` per file (without merging) and return their ids.

        Args:
            mesh_names: Mesh file paths readable by ``o3d.io.read_triangle_mesh``.
            scales: Per-mesh uniform scale; for a tuple the first element is used.
            pqcs: Optional per-mesh pose for ``tutil.pq2H``. ``None`` (or a ``None``
                entry) means the identity pose.
            simplify: If True, decimate each mesh with vertex clustering.
            voxel_size_factor: The clustering voxel size is the largest bounding-box
                extent divided by this value.

        Returns:
            (N, 2) uint32 array of (high, low) mesh-id halves, one row per mesh, or
            shape (0, 2) when ``mesh_names`` is empty. It is also stored in
            ``self.batch_mesh_ids``. The meshes and their buffers are kept in
            ``self.batch_mesh_wps`` / ``batch_vertices`` / ``batch_faces``,
            presumably so the ids stay valid.
        """
        enroll_start_time = time.time()
        self.batch_mesh_ids = []
        self.batch_mesh_wps = []
        self.batch_vertices = []
        self.batch_faces = []

        if pqcs is None:
            # zip() below needs an iterable; None means identity for every mesh.
            pqcs = [None] * len(mesh_names)

        for mesh_name, scale, pqc in zip(mesh_names, scales, pqcs):
            mesh_o3d = o3d.io.read_triangle_mesh(mesh_name)
            if isinstance(scale, tuple):
                scale = scale[0]
            mesh_o3d.scale(scale, center=(0, 0, 0))
            H_mat = np.eye(4) if pqc is None else tutil.pq2H(pqc)
            mesh_o3d.transform(np.asarray(H_mat, dtype=np.float64))
            mesh_o3d.compute_vertex_normals()
            if simplify:
                voxel_size = max(mesh_o3d.get_max_bound() - mesh_o3d.get_min_bound()) / voxel_size_factor
                mesh_smp = mesh_o3d.simplify_vertex_clustering(
                    voxel_size=voxel_size,
                    contraction=o3d.geometry.SimplificationContraction.Average)
                mesh_smp.compute_vertex_normals()
                print(len(np.array(mesh_o3d.vertices)), "->", len(np.array(mesh_smp.vertices)))
                mesh_o3d = mesh_smp

            vertices = wp.array(np.array(mesh_o3d.vertices), dtype=wp.vec3)
            faces = wp.array(np.array(mesh_o3d.triangles).reshape(-1), dtype=int)
            mesh_wp = wp.Mesh(
                points=vertices,
                velocities=None,
                indices=faces,
            )
            self.batch_mesh_ids.append(uint64_to_uint32_pair(mesh_wp.id))
            self.batch_mesh_wps.append(mesh_wp)
            self.batch_vertices.append(vertices)
            self.batch_faces.append(faces)

        if self.batch_mesh_ids:
            self.batch_mesh_ids = jnp.stack(self.batch_mesh_ids, axis=0)
        else:
            # jnp.stack rejects an empty list; return an empty id table instead.
            self.batch_mesh_ids = jnp.zeros((0, 2), dtype=jnp.uint32)
        enroll_end_time = time.time()
        print(f"Enroll time: {enroll_end_time - enroll_start_time}")

        return self.batch_mesh_ids

    def ccd_curobo(self, fixed_obj_ids, spheres_tf, visualize=False) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        '''Run the warp sphere-vs-mesh collision kernel and mask its outputs.

        Args:
            fixed_obj_ids: mesh-id table of the fixed (environment) meshes. It is
                cast to uint32, and the length of its second-to-last axis is passed
                to the kernel as the number of environment meshes. Presumably rows
                of uint32 pairs encoding warp uint64 mesh ids (as built by
                enroll_meshes) -- TODO confirm.
            spheres_tf: array of shape (..., NSEG, NSPHERE, 4). Any leading batch
                dims are flattened for the kernel and restored on the outputs.
                The last axis presumably holds (x, y, z, radius).
            visualize: if True, the kernel is invoked through ``jax.jit``, and the
                result is drawn with ``self.visualize_spheres_and_meshes`` against
                ``self.vertices`` / ``self.faces``. The drawing always uses those
                meshes, regardless of ``fixed_obj_ids``.

        Returns:
            ``(col_cost, gradient_dir, signed_dist)``, with shapes
            ``(..., NSEG, NSPHERE)``, ``(..., NSEG, NSPHERE, 4)`` and
            ``(..., NSEG, NSPHERE)``. Entries the kernel did not flag (its third
            output != 1) are replaced by 0 (cost, gradient) or -10 (signed
            distance).
        '''
        # Both inputs and outputs are treated as constants for autodiff. The
        # callers in this class supply gradients via jax.custom_gradient.
        fixed_obj_ids = jax.lax.stop_gradient(jnp.array(fixed_obj_ids))
        spheres_tf = jax.lax.stop_gradient(jnp.array(spheres_tf)).astype(jnp.float32)

        # Flatten any leading batch dims into a single batch axis for the kernel.
        outer_shape = spheres_tf.shape[:-3]
        spheres_tf = spheres_tf.reshape(-1, *spheres_tf.shape[-3:])
        batch_size, horizon, nspheres, _ = spheres_tf.shape

        # Positional arguments, in the order the warp kernel behind self.ccd_func
        # expects them. Scalars are passed as length-1 device arrays.
        args = [
            spheres_tf.reshape(-1, 4),                                 # spheres (vec4)
            jnp.array([self.activation_distance], dtype=jnp.float32),  # activation distance
            fixed_obj_ids.astype(jnp.uint32),                          # mesh ids
            jnp.array([fixed_obj_ids.shape[-2]], dtype=jnp.int32),     # number of env meshes
            jnp.array([horizon], dtype=jnp.int32),
            jnp.array([nspheres], dtype=jnp.int32),
            jnp.array([batch_size], dtype=jnp.int32),
        ]

        if visualize:
            outputs = jax.jit(self.ccd_func)(*args)
        else:
            outputs = self.ccd_func(*args)
        outputs = jax.lax.stop_gradient(jax.block_until_ready(outputs))

        grid = (batch_size, horizon, nspheres)
        # The kernel's third output marks which spheres carry valid results.
        valid = outputs[2].reshape(grid) == 1
        col_cost = jnp.where(valid, outputs[0].reshape(grid), 0)
        gradient_dir = jnp.where(valid[..., None], outputs[1].reshape(grid + (4,)), 0)
        signed_dist = jnp.where(valid, outputs[3].reshape(grid), -10)

        # Restore the caller's leading batch dims.
        col_cost = col_cost.reshape(outer_shape + col_cost.shape[-2:])
        signed_dist = signed_dist.reshape(outer_shape + signed_dist.shape[-2:])
        gradient_dir = gradient_dir.reshape(outer_shape + gradient_dir.shape[-3:])

        if visualize:
            self.visualize_spheres_and_meshes(
                spheres_tf.reshape(-1, spheres_tf.shape[-2], spheres_tf.shape[-1]),
                gradient_dir.reshape(-1, spheres_tf.shape[-2], spheres_tf.shape[-1]),
                self.vertices.numpy(),
                self.faces.numpy(),
                signed_dist.reshape(-1, col_cost.shape[-1]),
                True,
            )

        return col_cost, gradient_dir, signed_dist


    def __call__(self, moving_obj, moving_obj_pqs, fixed_obj, moving_mesh_idxs, fixed_mesh_idx, moving_spheres=None, mesh_ids=None, visualize=False):
        """Collision cost between the moving objects' spheres and a fixed mesh set.

        Args:
            moving_obj: not read by this method.
            moving_obj_pqs: poses handed to ``self.place_spheres`` together with
                the selected sphere list.
            fixed_obj: not read by this method.
            moving_mesh_idxs: indices into the sphere list. Used only when
                ``fixed_mesh_idx`` is given, to keep just those entries.
            fixed_mesh_idx: ``None`` to query ``mesh_ids`` (or ``self.mesh_ids``)
                as-is. Otherwise an index (Python int or array scalar) selecting
                one entry of ``mesh_ids`` (or ``self.batch_mesh_ids``).
            moving_spheres: optional list appended to a copy of
                ``self.sphere_list``.
            mesh_ids: optional override for the mesh-id table.
            visualize: forwarded to ``ccd_curobo``. Additionally dumps the
                visualization inputs to ``sphere_log.pkl`` in the current
                directory and calls ``self.visualize_spheres_and_meshes``.

        Returns:
            ``(collision_loss, info)``.
            - ``collision_loss = col_coef * L + 100 * n_hit``. ``L`` is 500 times
              the kernel cost summed over the last two axes. ``n_hit`` counts
              spheres whose signed distance is ``>= self.col_threshold``.
            - ``info["collision_logits"]`` is ``L``.
            - ``info["collision_binary"]`` is True where any sphere is at or
              above that threshold.

        The gradient w.r.t. the sphere placement is a custom rule:
        ``500 * gradient_dir`` scaled by the cotangent of ``L``. The hit-count
        term is a sum of booleans and contributes no gradient.
        """
        sphere_list = copy.deepcopy(self.sphere_list)
        if moving_spheres is not None:
            sphere_list += moving_spheres
        if fixed_mesh_idx is None:
            if mesh_ids is None:
                mesh_ids = self.mesh_ids
        else:
            source_ids = self.batch_mesh_ids if mesh_ids is None else mesh_ids
            mesh_ids = source_ids[fixed_mesh_idx][None]
            sphere_list = [sphere_list[moving_mesh_idx] for moving_mesh_idx in moving_mesh_idxs]

        @jax.custom_gradient
        def curobo_collision_loss(spheres_tf):
            collision_loss_batch, grad_val, signed_dist = self.ccd_curobo(mesh_ids, spheres_tf, visualize=visualize)
            # Loss and gradient are scaled by the same factor so the custom VJP
            # stays consistent with the returned loss.
            scaled_grad = grad_val * 500.0
            collision_loss = jnp.sum(collision_loss_batch, axis=(-1, -2)) * 500.0

            def grad_fn(upstream_grad):
                # Only the cotangent of the summed loss (first output) is
                # propagated. The other outputs are diagnostics.
                return (upstream_grad[0][..., None, None, None] * scaled_grad,)

            # The unscaled gradient direction is returned for visualization.
            return (collision_loss, collision_loss_batch, grad_val, signed_dist), grad_fn

        spheres_tf = self.place_spheres(moving_obj_pqs, sphere_list)

        collision_loss, _, collision_grad, signed_dist = curobo_collision_loss(spheres_tf)
        in_collision = signed_dist >= self.col_threshold
        collision_logits = collision_loss
        collision_loss = self.col_coef * collision_loss + 100 * jnp.sum(in_collision, axis=(-1, -2))
        collision_binary = jnp.any(in_collision, axis=(-1, -2))

        if visualize:
            if fixed_mesh_idx is not None:
                # Accept both array scalars (.item()) and plain Python ints.
                mesh_idx = fixed_mesh_idx.item() if hasattr(fixed_mesh_idx, "item") else int(fixed_mesh_idx)
                vertices = self.batch_vertices[mesh_idx].numpy()
                faces = self.batch_faces[mesh_idx].numpy()
            else:
                vertices = self.vertices.numpy()
                faces = self.faces.numpy()
            vis_args = (
                spheres_tf.reshape(-1, spheres_tf.shape[-2], spheres_tf.shape[-1]),
                collision_grad.reshape(-1, spheres_tf.shape[-2], spheres_tf.shape[-1]),
                vertices,
                faces,
                signed_dist.reshape(-1, signed_dist.shape[-1]),
            )

            import pickle
            # Debug snapshot; this module's __main__ block replays it.
            with open('sphere_log.pkl', 'wb') as f:
                pickle.dump(vis_args, f)

            self.visualize_spheres_and_meshes(*vis_args, True)

        return collision_loss, {
            "collision_logits": collision_logits,
            "collision_binary": collision_binary,
        }

    def call_batch(self, moving_obj_pqs, mesh_ids):
        """Collision cost of the spheres placed by ``self.place_spheres`` against ``mesh_ids``.

        Args:
            moving_obj_pqs: poses passed to ``self.place_spheres``, which is
                called without an explicit sphere list.
            mesh_ids: mesh-id table forwarded unchanged to ``self.ccd_curobo``.

        Returns:
            ``(collision_loss, info)``.
            - ``info["collision_logits"]`` is 500 times the per-sphere cost,
              summed over the last two axes.
            - ``info["collision_binary"]`` is an int mask, 1 where a sphere's
              cost exceeds ``0.5 * self.activation_distance``. It is kept per
              sphere, unlike ``__call__``, which reduces it with ``any``.
            - ``collision_loss = col_coef * logits + 500 * max(mask)``.

        The custom gradient is 500 times the kernel's gradient direction, zeroed
        for spheres at or below the threshold, then scaled by the cotangent of
        the summed loss.
        """
        threshold = 0.5 * self.activation_distance

        @jax.custom_gradient
        def curobo_collision_loss(spheres_tf):
            # The third output of ccd_curobo is the signed distance; it is unused here.
            per_sphere_cost, raw_grad, _signed_dist = self.ccd_curobo(mesh_ids, spheres_tf)
            active = per_sphere_cost[..., None] > threshold
            scaled_grad = jnp.where(active, raw_grad, 0.0) * 500.0
            total_cost = jnp.sum(per_sphere_cost, axis=(-1, -2)) * 500.0

            def grad_fn(upstream_grad):
                # Only the summed loss (first output) carries a gradient.
                return (upstream_grad[0][..., None, None, None] * scaled_grad,)

            return (total_cost, per_sphere_cost), grad_fn

        placed = self.place_spheres(moving_obj_pqs)  # (NSEG, NSPHERE, 4)
        logits, per_sphere = curobo_collision_loss(placed)
        hit_mask = jnp.where(per_sphere > threshold, 1, 0)
        penalised = self.col_coef * logits + 500 * jnp.max(hit_mask, axis=(-1, -2))
        return penalised, {
            "collision_logits": logits,
            "collision_binary": hit_mask,
        }

if __name__ == '__main__':
    # Manual debug entry point. It replays the sphere/mesh snapshot that the
    # visualize path of __call__ writes to 'sphere_log.pkl', reruns the kernel
    # on it and visualizes the result.
    import pickle

    # NOTE: pickle.load can execute arbitrary code stored in the file. Only load
    # snapshots produced locally by this module.
    with open('sphere_log.pkl', 'rb') as f:
        sphere_traj, _grads, vertices, faces, _signed_dist = pickle.load(f)

    curobo_ccd = CuroboCCD(col_coef=5.0, activation_distance=0.0, rot_configs=None)
    curobo_ccd.enroll_meshes(None, None, None, (vertices, faces))

    # Alternative manual setup: enroll a mesh from disk and query a synthetic
    # grid of small spheres instead of the replayed snapshot.
    # dishholder_obj_filename = '/home/dongwon/research/object_set/GoogleScannedObjects/cvx/coacd/Poppin_File_Sorter_White.obj'
    # curobo_ccd.enroll_meshes([dishholder_obj_filename], [1.0], [np.array([0,0,0,0,0,0,1.])])
    # xgrid, ygrid = np.meshgrid(np.linspace(-0.1, 0.1, 70), np.linspace(-0.1, 0.1, 70))
    # zgrid = -0.16*np.ones_like(xgrid)
    # sphere_traj = np.stack([zgrid, xgrid, ygrid], axis=-1).reshape(-1, 3)
    # sphere_traj = np.stack([sphere_traj, np.array([0.07,0,-0.030])+sphere_traj], axis=0)
    # sphere_traj = jnp.concat([sphere_traj, np.ones_like(sphere_traj[..., :1]) * 0.001], axis=-1)

    col_cost, gradient_dir, signed_dist = curobo_ccd.ccd_curobo(curobo_ccd.mesh_ids, sphere_traj, visualize=True)

    print("col_cost", col_cost.shape, "gradient_dir", gradient_dir.shape, "signed_dist", signed_dist.shape)

