import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   
# os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
#                            "intra_op_parallelism_threads=1")

import jax.numpy as jnp
import pickle
import numpy as np
import os
import glob
import open3d as o3d
import jax
import matplotlib.pyplot as plt
import einops
from functools import partial
from tqdm import tqdm
from dataclasses import replace
from pathlib import Path
# import kdtree from scipy

# Setup import path
import sys
BASEDIR = Path(__file__).parent.parent
if str(BASEDIR) not in sys.path:
    sys.path.insert(0, str(BASEDIR))

# from util.model_util import SDFDecoder
import util.model_util as mutil
import util.transform_util as tutil
import util.latent_obj_util as loutil
from modules import shakey_module, ccd
import util.broad_phase as broad_phase
from util.dotenv_util import REP_CKPT

from dataset.generate_swept_volume_dataset import view, get_meshes_as_oricorns
import click
import open3d as o3d
import copy
import time

from scipy.spatial import ConvexHull


# Root click command group; the sub-commands in this file register themselves
# on it with the @cli.command() decorator.
@click.group()
def cli():
    pass

def get_results_ours(
    models,
    trials,
    base_seed,
    moving_oriCORNs,
    moving_mesh_q_init,
    moving_mesh_q_goal,
    fk,
    fixed_oriCORNs,
    is_collisions,
    sign_distances,
    is_continuous,
    # debug
    shakey,
    fixed_mesh_paths,
    fixed_mesh_pqcs,
    fixed_mesh_scales,
    optimal_t,
):
    """Random-search evaluation of ``ccd.OursCCD`` against ground-truth labels.

    For every trial a fresh pair of hyperparameters (``interpolate_len`` and
    ``reduce_k``) is sampled from a PRNG key seeded with ``base_seed + trial``.
    The joint trajectory between ``moving_mesh_q_init`` and
    ``moving_mesh_q_goal`` is linearly interpolated, pushed through ``fk`` and
    fed to the detector in batches.

    Args:
        models: Passed to ``ccd.OursCCD``.
        trials: Number of hyperparameter samples to evaluate.
        base_seed: Seed offset; trial ``k`` uses ``PRNGKey(base_seed + k)``.
        moving_oriCORNs: Representation of the moving links, forwarded to the
            detector unchanged.
        moving_mesh_q_init: Start configurations, one row per sample.
        moving_mesh_q_goal: Goal configurations, one row per sample.
        fk: Callable mapping interpolated configurations to link poses.
        fixed_oriCORNs: Per-sample fixed-object representation; must support
            slicing along its first axis.
        is_collisions: Ground-truth collision labels, one per sample.
        sign_distances: Not used by the current body.
        is_continuous: Forwarded to the detector; also selects the
            ``"timeoptbf_traj"`` broad phase when true.
        shakey, fixed_mesh_paths, fixed_mesh_pqcs, fixed_mesh_scales, optimal_t:
            Debug-only inputs; not used by the current body.

    Returns:
        A list with one dict per trial holding ``hyperparameters``,
        ``accuracy``, ``precision``, ``recall`` and ``elapsed_time``.
    """
    results = []
    broad_phase_cls = broad_phase.BroadPhaseWarp()
    batch_size = 35

    # The sibling evaluators flatten the labels before comparing; without this
    # a column-shaped label array would broadcast against the 1-D predictions
    # into an N x N matrix and silently corrupt every metric.
    is_collisions = is_collisions.flatten()

    for trial in range(trials):
        jkey = jax.random.PRNGKey(base_seed + trial)
        jkey, subkey1 = jax.random.split(jkey)
        jkey, subkey2 = jax.random.split(jkey)
        hyperparameters = {
            "interpolate_len": jax.random.randint(subkey1, (1,), 2, 100).item(),
            "reduce_k": jax.random.randint(subkey2, (1,), 2, 256).item(),
        }
        print(hyperparameters)
        collision_detector = ccd.OursCCD(
            models,
            collision_threshold=0.1,
            col_coef=1,
            reduce_k=hyperparameters["reduce_k"],
            is_continuous=is_continuous,
            broad_phase_cls=broad_phase_cls,
            broadphase_type="timeoptbf_traj" if is_continuous else None,
            return_collision_loss_pair=True,
        )

        t = jnp.linspace(0, 1, hyperparameters["interpolate_len"])[None, :, None]
        interpolated_qs = moving_mesh_q_init[:, None] * (1 - t) + moving_mesh_q_goal[:, None] * t # [N, T, 6]
        moving_obj_pqs = fk(interpolated_qs) # [N, T, NOB, 7]

        # jit & warmup
        # FIXME - needs padding for interpolation to jit
        collision_detector_vmap = jax.vmap(
            partial(collision_detector, visualize=False),
            in_axes=(None, 0, 0, None, None),
        )
        collision_detector_partial = lambda pqs, fixed_oriCORNs: collision_detector_vmap(
            moving_oriCORNs,
            pqs,
            fixed_oriCORNs,
            jkey,
            hyperparameters["interpolate_len"] - 1,
        )
        collision_detector_jit = jax.jit(collision_detector_partial)

        # The full pass is repeated three times and only the last pass's
        # predictions and wall time are kept (earlier passes presumably absorb
        # JIT compilation -- TODO confirm).
        for _ in range(3):
            start = time.time()
            pred_labels = []
            for i in range(0, len(is_collisions), batch_size):
                _, aux_cost = collision_detector_jit(moving_obj_pqs[i:i+batch_size], fixed_oriCORNs[i:i+batch_size])
                collision_loss_pair = jax.block_until_ready(aux_cost["collision_loss_pair"])
                # A sample counts as colliding if any entry exceeds -0.5 across
                # all non-batch axes.
                axis = tuple(range(1, collision_loss_pair.ndim))
                pred_labels.append(jnp.any(collision_loss_pair > -0.5, axis=axis))

            pred_labels = jnp.concatenate(pred_labels, axis=0).flatten()
            elapsed_time = time.time() - start
            print(elapsed_time)

        correct = jnp.equal(pred_labels, is_collisions)
        # Indices of the misclassified samples.
        print(jnp.argwhere(correct == 0).flatten())
        accuracy = jnp.mean(correct).item()

        # Precision: TP / (TP + FP)
        true_positives = jnp.sum(jnp.logical_and(pred_labels == 1, is_collisions == 1)).item()
        predicted_positives = jnp.sum(pred_labels == 1).item()
        precision = true_positives / predicted_positives if predicted_positives > 0 else 0.0

        # Recall: TP / (TP + FN)
        actual_positives = jnp.sum(is_collisions == 1).item()
        recall = true_positives / actual_positives if actual_positives > 0 else 0.0
        print(f"trial {trial}: accuracy={accuracy}, precision={precision}, recall={recall}, elapsed_time={elapsed_time}")

        results.append({
            "hyperparameters": hyperparameters,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "elapsed_time": elapsed_time,
        })
    return results

def get_results_curobo(
    models,
    trials,
    base_seed,
    moving_oriCORNs,
    moving_mesh_q_init,
    moving_mesh_q_goal,
    fk,
    fixed_mesh_paths,
    fixed_mesh_pqcs,
    fixed_mesh_scales,
    is_collisions,
    is_continuous,
    shakey,
):
    """Random-search evaluation of ``ccd.CuroboCCD`` against ground-truth labels.

    Each trial samples a hyperparameter set from ``PRNGKey(base_seed + trial)``,
    enrolls the robot and (per batch) the fixed meshes into the detector, and
    classifies every sample as colliding if any entry of the detector's
    ``collision_binary`` output is set.

    Args:
        models: Only ``models.rot_configs`` is read here.
        trials: Number of hyperparameter samples to evaluate.
        base_seed: Seed offset; trial ``k`` uses ``PRNGKey(base_seed + k)``.
        moving_oriCORNs: Not used by the current body.
        moving_mesh_q_init: Start configurations, one row per sample.
        moving_mesh_q_goal: Goal configurations, one row per sample.
        fk: Callable mapping interpolated configurations to link poses.
        fixed_mesh_paths: Mesh file path per sample.
        fixed_mesh_pqcs: Pose per fixed mesh.
        fixed_mesh_scales: Scale per fixed mesh.
        is_collisions: Ground-truth collision labels, one per sample.
        is_continuous: Not used by the current body.
        shakey: Robot description handed to ``enroll_robot``.

    Returns:
        A list with one dict per trial holding ``hyperparameters``,
        ``accuracy``, ``precision``, ``recall`` and ``elapsed_time``.
    """
    results = []
    batch_size = 35
    is_collisions = is_collisions.flatten()

    for trial in range(trials):
        jkey = jax.random.PRNGKey(base_seed + trial)
        jkey, subkey1 = jax.random.split(jkey)
        # subkey2 is reserved for sampling activation_distance (currently
        # fixed at 0); the split is kept so the later subkeys are unchanged.
        jkey, subkey2 = jax.random.split(jkey)
        jkey, subkey3 = jax.random.split(jkey)
        jkey, subkey4 = jax.random.split(jkey)
        # A dedicated key for the simplification factor: reusing subkey4 for
        # two draws would correlate it with num_surface_samples.
        jkey, subkey5 = jax.random.split(jkey)
        hyperparameters = {
            "interpolate_len": jax.random.randint(subkey1, (1,), 2, 64).item(),
            "activation_distance": 0, # jax.random.uniform(subkey2, (1,), minval=-0.2, maxval=0.2).item(),
            "voxel_intervals": jax.random.uniform(subkey3, (7,), minval=jnp.array([0.1, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]), maxval=jnp.array([1, 1, 1, 1, 1, 1, 1])).tolist(),
            "num_surface_samples": jax.random.randint(subkey4, (7,), 0, 100).tolist(),
            "mesh_simplification_voxel_size_factor": jax.random.randint(subkey5, (1,), 2, 32).item(),
        }
        collision_detector = ccd.CuroboCCD(
            col_coef=1,
            rot_configs=models.rot_configs,
            activation_distance=hyperparameters["activation_distance"],
        )
        collision_detector.enroll_robot(
            shakey,
            voxel_intervals=hyperparameters["voxel_intervals"],
            num_surface_samples=hyperparameters["num_surface_samples"],
        )
        t = jnp.linspace(0, 1, hyperparameters["interpolate_len"])[None, :, None]
        interpolated_qs = moving_mesh_q_init[:, None] * (1 - t) + moving_mesh_q_goal[:, None] * t # [N, T, 6]
        moving_obj_pqs = fk(interpolated_qs) # [N, T, NOB, 7]

        pred_labels = []
        elapsed_time = 0
        for i in range(0, len(is_collisions), batch_size):
            collision_detector.enroll_meshes_batch(
                fixed_mesh_paths[i:i+batch_size],
                fixed_mesh_scales[i:i+batch_size],
                fixed_mesh_pqcs[i:i+batch_size],
                True,
                hyperparameters["mesh_simplification_voxel_size_factor"],
            )
            collision_detector_jit = jax.jit(collision_detector.call_batch)

            # Each batch is evaluated three times; only the last run's result
            # and duration are kept (earlier runs presumably absorb JIT
            # compilation -- TODO confirm).
            for _ in range(3):
                start = time.time()
                _, aux_cost = collision_detector_jit(
                    moving_obj_pqs[i:i+batch_size],
                    collision_detector.batch_mesh_ids,
                )
                collision_binary = jax.block_until_ready(aux_cost["collision_binary"])
                collision_binary = jnp.any(collision_binary, axis=(1, 2))
                end = time.time()
                print(i, end - start)

            pred_labels.append(collision_binary)
            elapsed_time += end - start

        pred_labels = jnp.concatenate(pred_labels, axis=0).flatten()
        correct = jnp.equal(pred_labels, is_collisions)

        accuracy = jnp.mean(correct).item()
        # Precision: TP / (TP + FP)
        true_positives = jnp.sum(jnp.logical_and(pred_labels == 1, is_collisions == 1)).item()
        predicted_positives = jnp.sum(pred_labels == 1).item()
        precision = true_positives / predicted_positives if predicted_positives > 0 else 0.0

        # Recall: TP / (TP + FN)
        actual_positives = jnp.sum(is_collisions == 1).item()
        recall = true_positives / actual_positives if actual_positives > 0 else 0.0
        print(hyperparameters)
        print(f"trial {trial}: accuracy={accuracy}, precision={precision}, recall={recall}, elapsed_time={elapsed_time}")

        results.append({
            "hyperparameters": hyperparameters,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "elapsed_time": elapsed_time,
        })
    return results

from scipy.spatial import ConvexHull

def get_convex_hull(vertices):
    """Compute the convex hull of a point set and pack its faces.

    Args:
        vertices: Point coordinates accepted by ``scipy.spatial.ConvexHull``.

    Returns:
        A tuple ``(points, triangles, faces)``: ``points`` are the hull's
        input points, ``triangles`` its simplices (indices into ``points``),
        and ``faces`` a flat int64 array in which every triangle is written as
        ``3, i0, i1, i2``.
    """
    hull = ConvexHull(vertices)
    points, triangles = hull.points, hull.simplices

    # NOTE: the triangle winding is used exactly as scipy returns it; no
    # re-orientation towards outward-facing normals is performed.
    vertex_counts = np.full((len(triangles), 1), 3, dtype=np.int64)
    faces = np.hstack((vertex_counts, triangles)).flatten()
    return points, triangles, faces

def get_results_mesh(
    models,
    trials,
    base_seed,
    moving_oriCORNs,
    moving_mesh_q_init,
    moving_mesh_q_goal,
    fk,
    fixed_mesh_paths,
    fixed_mesh_pqcs,
    fixed_mesh_scales,
    is_collisions,
    is_continuous,
    shakey,
):
    """Random-search evaluation of a convex-decomposition + FCL baseline.

    Robot links and fixed meshes are split into convex parts with CoACD
    (results are cached on disk under ``temp/coacd_cache``). For every sample
    the moving parts are transformed along the interpolated trajectory,
    culled with an AABB test against the fixed mesh, and the remaining parts
    are distance-queried against the fixed parts with FCL. In continuous mode
    each moving part is replaced by the convex hull of its vertices at two
    consecutive time steps.

    Only broad-phase culling, hull construction, manager registration and the
    distance query are included in ``elapsed_time``; mesh loading,
    decomposition and ``fcl.Convex`` construction are not.

    Args:
        models: Only ``models.asset_path_util.obj_paths`` is read here.
        trials: Number of hyperparameter samples to evaluate.
        base_seed: Seed offset; trial ``k`` uses ``PRNGKey(base_seed + k)``.
        moving_oriCORNs: Not used by the current body.
        moving_mesh_q_init: Start configurations, one row per sample.
        moving_mesh_q_goal: Goal configurations, one row per sample.
        fk: Callable mapping interpolated configurations to link poses.
        fixed_mesh_paths: Mesh file path per sample.
        fixed_mesh_pqcs: Pose per fixed mesh.
        fixed_mesh_scales: Scale per fixed mesh.
        is_collisions: Ground-truth collision labels, one per sample.
        is_continuous: Use swept convex hulls between consecutive steps
            instead of per-step meshes.
        shakey: Provides ``canonical_obj_idx`` and ``mesh_scale`` for the
            robot links.

    Returns:
        A list with one dict per trial holding ``hyperparameters``,
        ``accuracy``, ``precision``, ``recall`` and ``elapsed_time``.
    """
    import coacd
    import fcl
    import hashlib

    max_interpolation_length = 24 if is_continuous else 256

    def digest(text):
        # Stable 8-digit integer derived from the SHA-256 of `text`.
        return int(hashlib.sha256(text.encode('utf-8')).hexdigest(), 16) % 10**8

    def cache_key(mesh_name, scale=None, pqc=None):
        # Sum of the digests of every provided component; the values match the
        # previous scheme so existing cache files remain valid.
        key = digest(mesh_name)
        if scale is not None:
            key += digest(str(scale))
        if pqc is not None:
            key += digest(str(pqc))
        return key

    def decompose(original_mesh, key):
        # Convex decomposition with an on-disk pickle cache keyed by `key`.
        # NOTE(review): the cache is read with pickle; only use trusted files.
        filename = f"temp/coacd_cache/{key}.pkl"
        if os.path.exists(filename):
            with open(filename, "rb") as f:
                return pickle.load(f)

        coacd_mesh = coacd.Mesh(np.asarray(original_mesh.vertices), np.asarray(original_mesh.triangles))
        parts = coacd.run_coacd(coacd_mesh, decimate=True, max_ch_vertex=25) # [[vertices, indices],...]
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, "wb") as f:
            pickle.dump(parts, f)
        return parts

    def signed_distance_minkowski(aabb1, aabb2):
        # Signed distance between two AABBs given as [min, max] pairs:
        # positive when separated, negative (largest per-axis overlap term)
        # when they overlap on every axis.
        c1 = (aabb1[0] + aabb1[1]) / 2.
        c2 = (aabb2[0] + aabb2[1]) / 2.
        h1 = (aabb1[1] - aabb1[0]) / 2.
        h2 = (aabb2[1] - aabb2[0]) / 2.

        # Per-axis gap; negative if the boxes overlap along that axis.
        d = np.abs(c2 - c1) - (h1 + h2)
        outside_distance = np.linalg.norm(np.maximum(d, 0))
        inside_distance = np.minimum(np.max(d), 0)
        return outside_distance + inside_distance

    def pack_faces(tris):
        # Flat face list where each triangle is written as `3, i0, i1, i2`.
        return np.concatenate((3 * np.ones((len(tris), 1), dtype=np.int64), tris), axis=1).flatten()

    results = []
    is_collisions = is_collisions.flatten()

    robot_link_paths = [
        models.asset_path_util.obj_paths[i] for i in shakey.canonical_obj_idx
    ]
    scales = shakey.mesh_scale
    canonical_moving_o3d_meshes = [
        o3d.io.read_triangle_mesh(link_path) for link_path in robot_link_paths
    ] # [NOB]
    for mesh, scale in zip(canonical_moving_o3d_meshes, scales):
        mesh.scale(scale, center=(0, 0, 0))

    canonical_parts = [
        decompose(mesh, cache_key(link_path, scale))
        for mesh, link_path, scale in zip(canonical_moving_o3d_meshes, robot_link_paths, scales)
    ]

    for trial in range(trials):
        jkey = jax.random.PRNGKey(base_seed + trial)
        jkey, subkey1 = jax.random.split(jkey)
        jkey, subkey2 = jax.random.split(jkey)
        hyperparameters = {
            "interpolate_len": jax.random.randint(subkey1, (1,), 2, max_interpolation_length).item(),
            "activation_distance": jax.random.uniform(subkey2, (1,), minval=-0.01, maxval=0.01).item(),
        }
        print(hyperparameters)
        activation_distance = hyperparameters["activation_distance"]

        t = jnp.linspace(0, 1, hyperparameters["interpolate_len"])[None, :, None]
        interpolated_qs = moving_mesh_q_init[:, None] * (1 - t) + moving_mesh_q_goal[:, None] * t # [N, T, 6]
        moving_obj_pqs = np.array(fk(interpolated_qs)) # [N, T, NOB, 7]
        num_samples, num_steps = moving_obj_pqs.shape[:2]
        pred_labels = []
        elapsed_time = 0

        # Transform every convex part of every link along the whole trajectory
        # once, for all samples.
        total_tf_parts = []
        for link_idx, parts_per_link in enumerate(canonical_parts): # [NOB]
            link_pqs = moving_obj_pqs[:, :, link_idx][:, :, None]
            for vertices, tris in parts_per_link:
                tf_vertices = tutil.pq_action(link_pqs, vertices[None, None]) # [N, T, V, 3]
                total_tf_parts.append((np.array(tf_vertices), tris))

        for i in range(num_samples):
            fixed_mesh = o3d.io.read_triangle_mesh(fixed_mesh_paths[i])
            fixed_mesh.scale(fixed_mesh_scales[i], center=(0, 0, 0))
            fixed_mesh.transform(tutil.pq2H(fixed_mesh_pqcs[i]))
            key = cache_key(fixed_mesh_paths[i], fixed_mesh_scales[i], fixed_mesh_pqcs[i])
            parts = decompose(fixed_mesh, key)
            fixed_obj_aabb = np.stack([fixed_mesh.get_min_bound(), fixed_mesh.get_max_bound()])
            fixed_fcl_parts = [
                fcl.Convex(vertices, len(tris), pack_faces(tris)) for vertices, tris in parts
            ]

            fixed_manager = fcl.DynamicAABBTreeCollisionManager()
            moving_manager = fcl.DynamicAABBTreeCollisionManager()
            start = time.time()
            fixed_manager.registerObjects([
                fcl.CollisionObject(obj) for obj in fixed_fcl_parts
            ])
            elapsed_time += time.time() - start

            # Collect (vertices, tris, faces) for every moving part that
            # survives the AABB cull against the fixed mesh.
            start = time.time()
            candidates = []
            if is_continuous:
                for step in range(num_steps - 1):
                    for part_vertices, _ in total_tf_parts:
                        # Vertices at both ends of the segment; their convex
                        # hull stands in for the volume swept in between.
                        swept_vertices = np.concatenate(
                            (part_vertices[i, step], part_vertices[i, step + 1]), axis=0
                        ) # [2*V, 3]
                        moving_obj_aabb = np.stack([swept_vertices.min(axis=0), swept_vertices.max(axis=0)])
                        if signed_distance_minkowski(moving_obj_aabb, fixed_obj_aabb) > activation_distance:
                            continue
                        candidates.append(get_convex_hull(swept_vertices))
            else:
                for step in range(num_steps):
                    for part_vertices, tris in total_tf_parts:
                        vertices = part_vertices[i, step]
                        moving_obj_aabb = np.stack([vertices.min(axis=0), vertices.max(axis=0)])
                        if signed_distance_minkowski(moving_obj_aabb, fixed_obj_aabb) > activation_distance:
                            continue
                        candidates.append((vertices, tris, pack_faces(tris)))
            elapsed_time += time.time() - start

            # fcl.Convex construction is deliberately left outside the timed
            # sections, as in the fixed-mesh setup above.
            moving_objects = [
                fcl.CollisionObject(fcl.Convex(vertices, len(tris), faces))
                for vertices, tris, faces in candidates
            ]

            start = time.time()
            moving_manager.registerObjects(moving_objects)
            req = fcl.DistanceRequest(enable_signed_distance=True)
            ddata = fcl.DistanceData(request = req)
            moving_manager.distance(fixed_manager, ddata, fcl.defaultDistanceCallback)
            min_distance = ddata.result.min_distance - activation_distance
            pred_labels.append(min_distance < 0)
            elapsed_time += time.time() - start

        pred_labels = jnp.array(pred_labels).flatten()
        correct = jnp.equal(pred_labels, is_collisions)

        accuracy = jnp.mean(correct).item()
        # Precision: TP / (TP + FP)
        true_positives = jnp.sum(jnp.logical_and(pred_labels == 1, is_collisions == 1)).item()
        predicted_positives = jnp.sum(pred_labels == 1).item()
        precision = true_positives / predicted_positives if predicted_positives > 0 else 0.0

        # Recall: TP / (TP + FN)
        actual_positives = jnp.sum(is_collisions == 1).item()
        recall = true_positives / actual_positives if actual_positives > 0 else 0.0
        print(f"trial {trial}: accuracy={accuracy}, precision={precision}, recall={recall}, elapsed_time={elapsed_time}")

        results.append({
            "hyperparameters": hyperparameters,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "elapsed_time": elapsed_time,
        })

    return results

@cli.command()
@click.option("--urdf_dirs", type=str, default="assets/ur5/urdf/shakey_open.urdf")
@click.option("--data_path", type=str, default="temp/swept_volume_dataset/v1.pkl")
@click.option("--evaluation_type", required=True, type=click.Choice([
    'ours_ccd',
    'ours_stamp',
    'ours_global_stamp',
    'curobo_ccd',
    # 'curobo_stamp',
    'trajopt_ccd',
    'trajopt_stamp',
]))
@click.option("--result_dir", type=str, default="temp/swept_volume_evaluation_result/")
@click.option("--trials", type=int, default=100)
@click.option("--seed", type=int, default=0)
def evaluate(
    urdf_dirs,
    data_path,
    evaluation_type,
    result_dir,
    trials,
    seed,
):
    """Run one collision-detection baseline on a swept-volume dataset.

    Loads the pretrained models, the robot kinematics and the pickled dataset,
    dispatches to the evaluator selected by --evaluation_type and appends the
    per-trial results to <result_dir>/<evaluation_type>.pkl (existing results
    in that file are kept and extended).
    """
    if evaluation_type == "ours_global_stamp":
        assert "global_baseline" in REP_CKPT, f"REP_CKPT must include global_baseline in name if testing global + stamp current: {REP_CKPT}"
    os.makedirs(result_dir, exist_ok=True)
    models = mutil.Models().load_pretrained_models()
    shakey = shakey_module.load_urdf_kinematics(
        urdf_dirs=urdf_dirs,
        models=models,
    )

    # NOTE(review): the dataset is read with pickle; only load trusted files.
    with open(data_path, 'rb') as f:
        data = pickle.load(f)

    sign_distances = jnp.array(data["min_distance"])
    is_collisions = jnp.array(data["is_collisions"])
    optimal_t = jnp.array(data["optimal_t"])

    data_num = len(is_collisions)

    moving_oriCORNs = shakey.link_canonical_oriCORN
    moving_mesh_q_init = jnp.array(data["moving_mesh"]["q_init"]) # [N, 6]
    moving_mesh_q_goal = jnp.array(data["moving_mesh"]["q_goal"]) # [N, 6]
    @jax.jit
    def fk(q):
        return shakey.FK(q, oriCORN_out=False) # tutil.pq_multi(robot_base_pqc, shakey.FK(q, oriCORN_out=False))

    # Mesh paths were stored as absolute paths on the machines that generated
    # the dataset; strip those prefixes and re-root under the local asset dir.
    fixed_mesh_paths = [
        os.path.join(
            models.asset_path_util.asset_base_dir,
            fixed_mesh_path
                .replace("/mnt/ssd2/hojin/object_dataset/", "")
                .replace("/home/rogga/research/efficient_planning/dataset/", ""),
        )
        for fixed_mesh_path in data["fixed_mesh"]["paths"]
    ]
    fixed_mesh_pqcs = jnp.array(data["fixed_mesh"]["pqcs"])
    fixed_mesh_scales = jnp.array(data["fixed_mesh"]["scales"])

    print(f"loaded {data_num} datas")

    if evaluation_type in ["ours_ccd", "ours_stamp", "ours_global_stamp"]:
        canonical_fixed_oriCORNs = get_meshes_as_oricorns(models, fixed_mesh_paths, model_id=None)

        fixed_oriCORNs = canonical_fixed_oriCORNs.apply_scale(fixed_mesh_scales, center=jnp.array([[0,0,0]])).apply_pq_z(fixed_mesh_pqcs, models.rot_configs)
        results = get_results_ours(
            models,
            trials,
            seed,
            moving_oriCORNs,
            moving_mesh_q_init,
            moving_mesh_q_goal,
            fk,
            fixed_oriCORNs,
            is_collisions,
            sign_distances,
            is_continuous=(evaluation_type == "ours_ccd"),

            shakey=shakey,
            fixed_mesh_paths=fixed_mesh_paths,
            fixed_mesh_pqcs=np.array(fixed_mesh_pqcs),
            fixed_mesh_scales=np.array(fixed_mesh_scales),
            optimal_t=optimal_t,
        )
    elif evaluation_type == "curobo_ccd":
        results = get_results_curobo(
            models,
            trials,
            seed,
            moving_oriCORNs,
            moving_mesh_q_init,
            moving_mesh_q_goal,
            fk,
            fixed_mesh_paths,
            np.array(fixed_mesh_pqcs),
            np.array(fixed_mesh_scales),
            is_collisions,
            is_continuous=False,
            shakey=shakey,
        )
    elif evaluation_type in ["trajopt_ccd", "trajopt_stamp"]:
        results = get_results_mesh(
            models,
            trials,
            seed,
            moving_oriCORNs,
            moving_mesh_q_init,
            moving_mesh_q_goal,
            fk,
            fixed_mesh_paths,
            np.array(fixed_mesh_pqcs),
            np.array(fixed_mesh_scales),
            is_collisions,
            is_continuous=(evaluation_type == "trajopt_ccd"),
            shakey=shakey,
        )
    else:
        raise ValueError(f"Unknown evaluation type: {evaluation_type}")

    # Results of the global variant are stored per checkpoint.
    if "ours_global" in evaluation_type:
        evaluation_type += f"_{models.pretrain_ckpt_id}"
    filename = f"{result_dir}/{evaluation_type}.pkl"
    # Append to any results already saved under the same name.
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            prev_results = pickle.load(f)
        results = prev_results + results
    print(len(results))
    with open(filename, 'wb') as f:
        pickle.dump(results, f)

def compute_pareto_front(xs, ys):
    """Return the Pareto-optimal points for minimising x and maximising y.

    Args:
        xs: x coordinates (smaller is better).
        ys: y coordinates (larger is better), paired element-wise with ``xs``.

    Returns:
        A list of ``(x, y)`` tuples ordered by increasing x, containing only
        the points whose y strictly exceeds that of every point before them
        in that order.
    """
    # Ascending x; for equal x the larger y comes first so it wins the tie.
    ordered = sorted(zip(xs, ys), key=lambda pt: (pt[0], -pt[1]))

    front = []
    best_y = float("-inf")
    for x, y in ordered:
        if y > best_y:
            front.append((x, y))
            best_y = y
    return front

@cli.command()
@click.option("--result_dir", type=str, default="temp/swept_volume_evaluation_result/")
@click.option("--xlabel", type=str, default="time")
@click.option("--ylabels", type=str, multiple=True)
@click.option("--draw_pareto", is_flag=True, show_default=True, default=False)
def plot(
    result_dir,
    xlabel,
    ylabels,
    draw_pareto,
):
    """Plot every result file in RESULT_DIR, one subplot per --ylabels value.

    Each *.pkl file becomes one labelled series (scatter, or the Pareto front
    when --draw_pareto is set). The figure is saved as a PNG inside
    RESULT_DIR.
    """
    # Axis names that can be selected with --xlabel / --ylabels.
    extractors = {
        "interpolate length": lambda result: result["hyperparameters"]["interpolate_len"],
        "time": lambda result: result["elapsed_time"],
        "accuracy": lambda result: result["accuracy"],
        "precision": lambda result: result["precision"],
        "recall": lambda result: result["recall"],
    }

    # Fail with a readable message instead of a matplotlib/KeyError traceback.
    if not ylabels:
        raise click.UsageError("at least one --ylabels value is required")
    for axis_name in (xlabel, *ylabels):
        if axis_name not in extractors:
            raise click.UsageError(
                f"unknown axis {axis_name!r}; choose from {sorted(extractors)}"
            )

    pickle_files = glob.glob(f"{result_dir}/*.pkl")

    excludes = []
    # squeeze=False always yields a 2-D array of axes, so a single subplot
    # needs no special case.
    fig, axs = plt.subplots(1, len(ylabels), figsize=(5 * len(ylabels), 5), squeeze=False)
    axs = axs.flatten()

    for pickle_file in pickle_files:
        if any(exclude in pickle_file for exclude in excludes):
            continue

        label = os.path.basename(pickle_file).split(".")[0]
        with open(pickle_file, 'rb') as f:
            results = pickle.load(f)
        print(label, len(results))
        if not results:
            # Nothing to draw; the Pareto path would also fail on empty data.
            continue

        data = {
            name: [extract(result) for result in results]
            for name, extract in extractors.items()
        }
        xs = data[xlabel]
        for i, ylabel in enumerate(ylabels):
            ys = data[ylabel]

            # The Pareto front assumes x is minimised (smaller is better) and
            # the metric is maximised.
            if draw_pareto:
                pf = compute_pareto_front(xs, ys)
                pareto_x, pareto_y = (list(values) for values in zip(*pf))
                # Extend the front horizontally to the largest x so the last
                # level stays visible.
                pareto_x.append(max(xs))
                pareto_y.append(pareto_y[-1])
                axs[i].plot(pareto_x, pareto_y, '-o', label=label)
            else:
                axs[i].scatter(xs, ys, label=label)

    for i, ylabel in enumerate(ylabels):
        axs[i].set_title(ylabel)
        axs[i].set_xscale('log')
        axs[i].set_xlabel(xlabel)
        axs[i].set_ylabel(ylabel)
        axs[i].legend()

    fig.suptitle("Swept Volume Evaluation", fontsize=16)
    fig.tight_layout()
    image_path = os.path.join(result_dir, f"result_{xlabel}_{'_'.join(ylabels)}{'_pareto' if draw_pareto else ''}.png")
    plt.savefig(image_path)

def _first_crossing_x(xs, ys, target):
    """Return the x where the polyline (xs, ys) first touches or crosses y == target.

    Walks consecutive segments in order and linearly interpolates inside the
    first one whose endpoints lie on opposite sides of (or exactly on)
    ``target``. A flat segment sitting on the target yields its left x.
    Returns ``None`` when no segment reaches the target.
    """
    for x0, y0, x1, y1 in zip(xs, ys, xs[1:], ys[1:]):
        if (y0 - target) * (y1 - target) <= 0:
            if y1 != y0:
                return x0 + (target - y0) * (x1 - x0) / (y1 - y0)
            return x0
    return None


@cli.command()
@click.option("--result_dirs", type=str, multiple=True)
@click.option("--subtitles", type=str, multiple=True)
@click.option("--xlabel", type=str, default="time")
@click.option("--ylabel", type=str, default="accuracy")
@click.option("--draw_pareto", is_flag=True, show_default=True, default=False)
def plot_multiple(
    result_dirs,
    subtitles,
    xlabel,
    ylabel,
    draw_pareto,
):
    """Plot one subplot per result directory comparing the known methods.

    For every (subtitle, result_dir) pair, loads ``<result_dir>/<label>.pkl``
    for each known method label that exists on disk, and plots ``ylabel``
    against ``xlabel`` either as a scatter or, with ``--draw_pareto``, as a
    Pareto front. In Pareto mode the x at which each front reaches 90 is
    printed, together with its ratio to the ``ours_ccd`` crossing of the same
    subplot when that is available. The figure is saved as a PNG inside the
    last result directory.

    Raises ``click.UsageError`` when no result directory is given or when the
    number of subtitles differs from the number of result directories.
    """
    if not result_dirs:
        raise click.UsageError("at least one --result_dirs value is required")
    if len(subtitles) != len(result_dirs):
        # zip() would otherwise silently drop the unmatched entries and leave
        # blank axes (or no output directory at all).
        raise click.UsageError(
            f"got {len(result_dirs)} --result_dirs but {len(subtitles)} --subtitles; "
            "pass exactly one subtitle per result directory"
        )

    fig, axs = plt.subplots(1, len(result_dirs), figsize=(5 * len(result_dirs), 5))
    # plt.subplots returns a bare Axes for a single subplot; normalise to a sequence.
    axs = axs.flatten() if len(result_dirs) > 1 else [axs]

    excludes = ["ours_global_stamp_global_baseline_v2"]
    # Pickle stem -> legend label; the insertion order is the plotting order.
    label_alias = {
        "ours_ccd": "OURS-CONTINUOUS",
        "ours_stamp": "OURS-DISCRETE",
        "ours_global_stamp": "OURS(GLOBAL)-DISCRETE",
        "curobo_ccd": "SPHERE-CONTINUOUS",
        "trajopt_ccd": "CONVEXHULL-CONTINUOUS",
        "trajopt_stamp": "CONVEXHULL-DISCRETE",
    }
    label_orders = list(label_alias)
    target = 90  # The horizontal line y = 90

    for ax, subtitle, result_dir in zip(axs, subtitles, result_dirs):
        # Reference crossing of "ours_ccd" for THIS subplot only; reset so a
        # value from a previous result directory is never reused.
        ours_x_i = None

        print(label_orders)
        pickle_files = [f"{result_dir}/{label}.pkl" for label in label_orders]
        pickle_files = [pickle_file for pickle_file in pickle_files if os.path.exists(pickle_file)]
        print(pickle_files)

        for pickle_file in pickle_files:
            if any(exclude in pickle_file for exclude in excludes):
                continue

            label = os.path.basename(pickle_file).split(".")[0]
            with open(pickle_file, 'rb') as f:
                results = pickle.load(f)
            print(label, len(results))
            if not results:
                # Nothing to plot; the Pareto branch would fail on empty input.
                print(f"{label}: no results in {pickle_file}, skipping")
                continue

            data = {
                "interpolate length": [result["hyperparameters"]["interpolate_len"] for result in results],
                # NOTE(review): 700 is presumably the number of queries timed per
                # trial (the x axis is labelled "per SVCD query") -- confirm.
                "time": [result["elapsed_time"] / 700 for result in results],
                "accuracy": [result["accuracy"] * 100 for result in results],
                # NOTE(review): precision/recall are not scaled by 100, yet the
                # y axis below is labelled "(%)" and clipped to 49..101 -- confirm
                # whether these should be scaled as well.
                "precision": [result["precision"] for result in results],
                "recall": [result["recall"] for result in results],
            }
            xs = data[xlabel]
            ys = data[ylabel]

            if draw_pareto:
                pf = compute_pareto_front(xs, ys)
                pareto_x, pareto_y = zip(*pf)
                pareto_x, pareto_y = list(pareto_x), list(pareto_y)
                # Extend the front horizontally to the largest observed x.
                pareto_x.append(max(xs))
                pareto_y.append(pareto_y[-1])

                if label in ["ours_ccd", "ours_stamp"]:
                    print(pareto_x)
                    print(pareto_y)

                xi = _first_crossing_x(pareto_x, pareto_y, target)
                if xi is not None:
                    if label == "ours_ccd":
                        ours_x_i = xi
                    if ours_x_i:
                        print(f"{label} y={target} x = {xi}, faster {xi / ours_x_i} times")
                    else:
                        # No usable "ours_ccd" reference in this subplot (file
                        # missing, front never reaches the target, or x == 0).
                        print(f"{label} y={target} x = {xi}")
                ax.plot(pareto_x, pareto_y, '-o', label=label_alias[label])
            else:
                ax.scatter(xs, ys, label=label_alias[label])

        # Hard-coded reference points for the exact method on the two known splits.
        if subtitle == "ID":
            ax.scatter([2.3348825366156443], [100], label="EXACT", color="black")
        elif subtitle == "OOD":
            ax.scatter([2.0292290265219552], [100], label="EXACT", color="black")

        ax.axhline(y=100, color='grey', linestyle='--')
        ax.set_title(subtitle)
        ax.set_xscale('log')
        ax.set_ylabel(ylabel.title() + " (%)")
        ax.set_xlabel("Time per SVCD query (s)")
        # ax.set_xlabel(xlabel.title() + " (s)")
        ax.set_ylim(49, 101)
        ax.legend()

    fig.tight_layout()
    # The figure is written next to the results of the last directory given.
    image_path = os.path.join(
        result_dirs[-1],
        f"result_multiple_{xlabel}_{ylabel}_{'_'.join(subtitles)}{'_pareto' if draw_pareto else ''}.png",
    )
    print("saving", image_path)
    fig.savefig(image_path, dpi=400)

# Expose every sub-command registered on the `cli` group through a single
# click entry point.
commands = click.CommandCollection(sources=[cli])

if __name__ == "__main__":
    # Run the CLI when this file is executed as a script.
    commands()
