import datetime
import os
import jax.numpy as jnp
import pickle
import numpy as np
import os
import glob
import open3d as o3d
import jax
import matplotlib.pyplot as plt
import einops
from functools import partial
from tqdm import tqdm
from dataclasses import replace
from pathlib import Path
# import kdtree from scipy

# Setup import path
import sys
BASEDIR = Path(__file__).parent.parent
if str(BASEDIR) not in sys.path:
    sys.path.insert(0, str(BASEDIR))

# from util.model_util import SDFDecoder
import util.model_util as mutil
import util.transform_util as tutil
import util.latent_obj_util as loutil
from modules import shakey_module, ccd
from modules.encoder import encode_mesh
import util.broad_phase as broad_phase
from util.reconstruction_util import create_scene_mesh_from_oriCORNs, create_fps_fcd_from_oriCORNs
import click
import open3d as o3d
import copy
import time
import warp as wp
from warp.jax_experimental.custom_call import jax_kernel

def uint64_to_uint32_pair(value):
    """Split a 64-bit integer into a ``[high, low]`` pair of uint32 values.

    Args:
        value: integer holding up to 64 bits (e.g. a Warp mesh id).

    Returns:
        A ``jnp.uint32`` array whose last axis is ``[value >> 32, value & 0xFFFFFFFF]``.
    """
    lower_mask = 0xFFFFFFFF
    high_half = jnp.uint32(value >> 32)
    low_half = jnp.uint32(value & lower_mask)
    halves = [high_half, low_half]  # high word first, matching the kernel's unpacking order
    return jnp.stack(halves, axis=-1).astype(jnp.uint32)

def get_mesh_query_point():
    """Build a Warp kernel computing the signed distance from points to a mesh.

    The kernel runs one thread per query point and writes, for ``points[tid]``:

    * ``distance[tid]``: the signed distance to the closest mesh point. It is
      negative when ``wp.mesh_query_point`` reports the point inside the mesh
      and positive otherwise. If no surface point is found within
      ``max_distance[0]``, the value is ``max_distance[0]``.
    * ``closest_direction[tid]``: the vector from the query point to the
      closest mesh point, or the zero vector when no surface point is found.

    The mesh is addressed by a 64-bit Warp mesh id stored in
    ``mesh_id_32[0]`` as a ``(high, low)`` uint32 pair. This is the layout
    produced by ``uint64_to_uint32_pair``.
    """
    @wp.kernel
    def mesh_query_point(
        points: wp.array(dtype=wp.vec3),
        mesh_id_32: wp.array(dtype=wp.vec2ui),
        max_distance: wp.array(dtype=wp.float32),
        # outputs
        distance: wp.array(dtype=wp.float32),
        closest_direction: wp.array(dtype=wp.vec3),
    ):
        tid = wp.tid()  # get the thread index

        point = points[tid]
        # Reassemble the 64-bit mesh id from its (high, low) uint32 halves.
        mesh_id = (wp.uint64(mesh_id_32[0][0]) << wp.uint64(32)) | wp.uint64(mesh_id_32[0][1])

        # Defaults for a point with no surface within max_distance. `delta` must be
        # defined before the branch: otherwise the write below reads an
        # uninitialized value whenever the query misses.
        dist = max_distance[0]
        delta = wp.vec3(0.0, 0.0, 0.0)
        collide_result = wp.mesh_query_point(mesh_id, point, max_distance[0])
        if collide_result.result:
            # sign (float32): a value < 0 if the query point is inside the mesh, >= 0 otherwise.
            sign = collide_result.sign

            closest_point = wp.mesh_eval_position(
                mesh_id, collide_result.face, collide_result.u, collide_result.v
            )
            delta = closest_point - point
            dist = sign * wp.length(delta)  # - if inside, + if outside

        # Write the resulting signed distance and closest-point vector.
        distance[tid] = dist
        closest_direction[tid] = delta
    return mesh_query_point

class TimeOptimization:
    """Signed-distance queries between moving robot links and one fixed mesh.

    The robot configuration is linearly interpolated between ``q_init`` and
    ``q_goal``. At each interpolation time ``t``, every link's canonical query
    points are placed with the forward kinematics ``fk``. They are then queried
    against the Warp mesh registered through :meth:`enroll_meshes`.
    """

    def __init__(
            self,
            canonical_moving_o3d_meshes,
            fk,
            num_surface_samples=2000,
        ):
        """Precompute per-link query points and build the Warp query kernel.

        Args:
            canonical_moving_o3d_meshes: one open3d triangle mesh per robot link,
                in the link's canonical frame.
            fk: callable mapping configurations ``[NT, DOF]`` to link poses
                ``[NT, NOB, 7]``, in the pq layout consumed by ``tutil.pq_action``.
            num_surface_samples: number of points sampled uniformly on each link
                surface, in addition to the link's vertices.
        """
        self.canonical_moving_o3d_meshes = canonical_moving_o3d_meshes
        # Per-link query points: all mesh vertices plus uniform surface samples.
        self.canonical_moving_mesh_points = [
            jnp.concatenate([
                jnp.array(mesh.vertices).reshape(-1, 3),
                jnp.array(
                    mesh.sample_points_uniformly(number_of_points=num_surface_samples).points
                ).reshape(-1, 3),
            ], axis=0)
            for mesh in canonical_moving_o3d_meshes
        ]
        self.fk = fk
        self.mesh_query_point = jax_kernel(get_mesh_query_point())

    def enroll_meshes(self, o3d_meshs, refit=False):
        """Merge ``o3d_meshs`` into one Warp mesh and store its packed id.

        Args:
            o3d_meshs: open3d triangle meshes. Face indices of later meshes are
                offset so that they point into the concatenated vertex buffer.
            refit: if True, overwrite the existing vertex/face buffers in place
                and refit the existing mesh's BVH. This requires a previous call
                with ``refit=False``. It presumably also requires matching
                vertex/face counts (``wp.array.assign``); verify against the Warp docs.
                If False, a new ``wp.Mesh`` is built.
        """
        vertices = []
        faces = []
        vtx_idx_offset = 0
        for o3d_mesh in o3d_meshs:
            vertices.append(np.array(o3d_mesh.vertices))
            faces.append(np.array(o3d_mesh.triangles).reshape(-1) + vtx_idx_offset)
            vtx_idx_offset += len(vertices[-1])
        vertices = np.concatenate(vertices, axis=0)
        faces = np.concatenate(faces, axis=0)

        if refit:
            self.vertices.assign(vertices)
            self.faces.assign(faces)
            self.mesh_wp.refit()
        else:
            self.vertices = wp.array(vertices, dtype=wp.vec3)
            self.faces = wp.array(faces, dtype=int)
            self.mesh_wp = wp.Mesh(
                points=self.vertices,
                velocities=None,
                indices=self.faces,
            )

        # [1, 2] uint32 (high, low): the layout read by the mesh_query_point kernel.
        self.mesh_id = uint64_to_uint32_pair(self.mesh_wp.id)[None]

    def evaluate_sign_distance(self, points, mesh_id=None, max_distance=1000.0):
        """Return, per time step, the point with the smallest signed distance.

        Args:
            points: query points of shape ``[NT, NPB, 3]``.
            mesh_id: packed mesh id; defaults to ``self.mesh_id``.
            max_distance: search radius passed to the kernel. Points with no
                surface within it report this value as their distance.

        Returns:
            ``(min_distances [NT], closest_directions [NT, 3])``. Each direction
            is the vector from the selected point to its closest mesh point.
        """
        if mesh_id is None:
            mesh_id = self.mesh_id
        nt, npb, _ = points.shape
        outputs = self.mesh_query_point(
            points.reshape(-1, 3),                           # wp.array(dtype=wp.vec3)
            mesh_id,                                         # wp.array(dtype=wp.vec2ui)
            jnp.array([max_distance], dtype=jnp.float32),    # wp.array(dtype=wp.float32)
        )
        min_distances = outputs[0].reshape(nt, npb)
        closest_directions = outputs[1].reshape(nt, npb, 3)

        rows = jnp.arange(nt)
        idx = jnp.argmin(min_distances, axis=1)  # [NT]
        return min_distances[rows, idx], closest_directions[rows, idx]  # [NT], [NT, 3]

    def get_jitted_query(self):
        """Return a jitted ``query(t, q_init, q_goal)`` bound to the current mesh.

        The mesh id is passed as a traced argument and read at call time. If it
        were closed over, it would be baked into the compiled trace, and JAX may
        reuse that trace for the same bound method after a new mesh is enrolled.
        """
        jitted = jax.jit(self.query)

        def jitted_query(t, q_init, q_goal):
            return jitted(t, q_init, q_goal, self.mesh_id)

        return jitted_query

    def query(
        self,
        t,  # [NT]
        q_init,  # [DOF]
        q_goal,  # [DOF]
        mesh_id=None,
    ):
        """Find the closest link and signed distance at each interpolation time.

        Args:
            t: interpolation times ``[NT]``; ``q(t) = (1 - t) * q_init + t * q_goal``.
            q_init, q_goal: start and goal configurations.
            mesh_id: packed mesh id; defaults to ``self.mesh_id``.

        Returns:
            ``(min_distances [NT], closest_directions [NT, 3], min_link_idx [NT])``.
            For each ``t``, these are the minimum signed distance over all links'
            query points, the vector from that point to the closest mesh point,
            and the index of the link that attains it.
        """
        if mesh_id is None:
            mesh_id = self.mesh_id
        nt = t.shape[0]
        q_t = (1 - t)[:, None] * q_init[None, :] + t[:, None] * q_goal[None, :]  # [NT, DOF]
        moving_pqs = self.fk(q_t)  # [NT, NOB, 7]

        per_link = [
            self.evaluate_sign_distance(
                tutil.pq_action(moving_pqs[:, i:i + 1], link_points),  # [NT, N, 3]
                mesh_id=mesh_id,
            )
            for i, link_points in enumerate(self.canonical_moving_mesh_points)
        ]
        min_distances = jnp.stack([dist for dist, _ in per_link], axis=1)  # [NT, NOB]
        closest_directions = jnp.stack([vec for _, vec in per_link], axis=1)  # [NT, NOB, 3]

        rows = jnp.arange(nt)
        min_link_idx = jnp.argmin(min_distances, axis=1)  # [NT]
        return (
            min_distances[rows, min_link_idx],        # [NT]
            closest_directions[rows, min_link_idx],   # [NT, 3]
            min_link_idx,                             # [NT]
        )

def view(moving_meshes_q_init, moving_meshes_q_goal, interpolation_num, fk, canonical_meshes, fixed_mesh, min_t=None, closest_direction=None, shakey=None, fixed_oricorn=None):
    """Show the interpolated robot links (and optional obstacle) in open3d.

    Link meshes are drawn at ``interpolation_num`` evenly spaced times in
    ``[0, 1]``, plus ``min_t`` if given. They are colored from green (start)
    toward red (end). ``fixed_mesh`` is painted and drawn as well, and it is
    modified in place. When ``closest_direction`` is also given, a red segment
    is drawn from its center to ``center - closest_direction``.
    ``fixed_oricorn`` additionally draws its decoded scene mesh and FPS point
    cloud; this requires ``shakey``.
    """
    alphas = jnp.linspace(0, 1, interpolation_num)[:, None]  # [T, 1]
    if min_t is not None:
        alphas = jnp.concatenate([alphas, jnp.array([[min_t]])], axis=0)
    alphas = jnp.sort(alphas, axis=0)

    q_path = (1 - alphas) * moving_meshes_q_init[None] + alphas * moving_meshes_q_goal[None]  # [T, DOF]
    link_h = tutil.pq2H(fk(q_path))  # [T, NOB, 4, 4]
    num_steps, num_links = link_h.shape[0], link_h.shape[1]

    geometries = []
    for step in range(num_steps):
        ratio = step / num_steps
        for link in range(num_links):
            link_mesh = copy.deepcopy(canonical_meshes[link])
            link_mesh.compute_vertex_normals()
            link_mesh.transform(link_h[step, link])
            link_mesh.paint_uniform_color([ratio, 1 - ratio, 0])
            geometries.append(link_mesh)

    if fixed_oricorn is not None:
        decoder = jax.jit(shakey.models.occ_prediction)
        geometries.append(create_scene_mesh_from_oriCORNs(fixed_oricorn, decoder, visualize=False, qp_bound=0.2, density=400))
        geometries.append(create_fps_fcd_from_oriCORNs(fixed_oricorn, visualize=False))

    if fixed_mesh is not None:
        fixed_mesh.compute_vertex_normals()
        fixed_mesh.paint_uniform_color([1, 0.706, 0])
        geometries.append(fixed_mesh)

        if closest_direction is not None:
            offset = np.array(closest_direction).reshape(3)
            anchor = fixed_mesh.get_center()
            segment = o3d.geometry.LineSet()
            segment.points = o3d.utility.Vector3dVector([anchor, anchor - offset])
            segment.lines = o3d.utility.Vector2iVector([[0, 1]])
            segment.colors = o3d.utility.Vector3dVector([[1, 0, 0]])
            geometries.append(segment)

    o3d.visualization.draw_geometries(geometries)

def simple_search(query_func, num_samples, k):
    """Find the interpolation time with the smallest distance, coarse-to-fine.

    A uniform grid of ``num_samples`` times in ``[0, 1]`` is queried first.
    Around each of the ``k`` best grid points, the bracket between its grid
    neighbors is re-sampled with ``num_samples`` points. The overall minimum
    is kept.

    Args:
        query_func: callable ``t [N] -> (distance [N], direction [N, 3], link_index [N])``.
        num_samples: samples per coarse grid and per refinement bracket.
        k: number of coarse minima to refine; clamped to ``num_samples``.

    Returns:
        ``(min_distance, closest_direction, min_t, min_link_index)``.
    """
    k = min(k, num_samples)  # lax.top_k requires k <= number of candidates
    # The context manager guarantees the progress bar is closed.
    with tqdm(total=k + 1) as progress_bar:
        t = jnp.linspace(0, 1, num_samples)
        distance, direction, link_index = query_func(t)
        top_k_idx = jax.lax.top_k(-distance, k=k)[1]

        best = int(top_k_idx[0])
        min_distance = distance[best]
        closest_direction = direction[best]
        min_t = t[best]
        min_link_index = link_index[best]
        progress_bar.update(1)

        for coarse_idx in top_k_idx.tolist():
            t_init = t[max(coarse_idx - 1, 0)]
            t_end = t[min(coarse_idx + 1, num_samples - 1)]
            t_range = jnp.linspace(t_init, t_end, num_samples)
            distance, direction, link_index = query_func(t_range)
            fine_idx = int(jnp.argmin(distance))
            if distance[fine_idx] < min_distance:
                min_distance = distance[fine_idx]
                closest_direction = direction[fine_idx]
                min_t = t_range[fine_idx]
                min_link_index = link_index[fine_idx]
            progress_bar.update(1)

    return min_distance, closest_direction, min_t, min_link_index


def generate_row(
        time_optimization: TimeOptimization,
        canonical_fixed_o3d_mesh,
        init_fixed_mesh_h,
        moving_meshes_q_init,
        moving_meshes_q_goal,
        jkey,

        # for debug
        canonical_fixed_oricorn: loutil.LatentObjects,
        fixed_mesh_scale,
        call_collision_detector,
        shakey,
        visualize=False,
    ):
    """Generate one near-contact sample for a straight-line robot motion.

    1. Place the fixed mesh at ``init_fixed_mesh_h``. Find the interpolation
       time at which the robot is closest to it, and take the vector from the
       closest robot point to the mesh.
    2. Translate the fixed mesh by minus that vector plus Gaussian noise
       (std 0.03). This moves it back toward the robot.
    3. Re-run a coarse-to-fine search (:func:`simple_search`) against the moved
       mesh. The sample is a collision when the minimum signed distance is
       negative.

    Args:
        time_optimization: query object whose enrolled mesh is replaced here.
        canonical_fixed_o3d_mesh: obstacle mesh in its canonical frame; not mutated.
        init_fixed_mesh_h: initial 4x4 pose of the obstacle.
        moving_meshes_q_init, moving_meshes_q_goal: robot start/goal configurations.
        jkey: PRNG key for the translation noise.
        canonical_fixed_oricorn, fixed_mesh_scale, call_collision_detector:
            debug-only inputs, currently unused.
        shakey: forwarded to :func:`view` when ``visualize`` is True.
        visualize: open an open3d viewer on the final configuration.

    Returns:
        ``(is_collision, min_distance, fixed_mesh_h, min_t, min_link_index, elapsed_time)``,
        where ``elapsed_time`` is the wall time of the final search in seconds.
    """
    batch_size = 1024

    fixed_o3d_mesh = copy.deepcopy(canonical_fixed_o3d_mesh)
    fixed_o3d_mesh.transform(init_fixed_mesh_h)
    time_optimization.enroll_meshes([fixed_o3d_mesh], refit=False)
    query = time_optimization.get_jitted_query()

    def query_func(t_batch):
        # Block so that the elapsed time measured around simple_search covers the device work.
        return jax.block_until_ready(query(t_batch, moving_meshes_q_init, moving_meshes_q_goal))

    # Step 1: one dense pass on the initial pose; only the closest-point vector is needed.
    coarse_t = jnp.linspace(0, 1, batch_size)
    distance, direction, _ = query_func(coarse_t)
    best = jnp.argmin(distance)
    closest_direction = direction[best] if distance[best] < jnp.inf else jnp.zeros((3,))

    # Step 2: move the obstacle toward the robot along that vector, with noise.
    noise = jax.random.normal(jkey, shape=(3,), dtype=jnp.float32) * 0.03
    perturbed_closest_direction = closest_direction + noise

    init_fixed_mesh_pq = tutil.H2pq(init_fixed_mesh_h, concat=False)
    fixed_mesh_h = tutil.pq2H(init_fixed_mesh_pq[0] - perturbed_closest_direction, init_fixed_mesh_pq[1])

    fixed_o3d_mesh = copy.deepcopy(canonical_fixed_o3d_mesh)
    fixed_o3d_mesh.transform(fixed_mesh_h)
    time_optimization.enroll_meshes([fixed_o3d_mesh], refit=True)

    # Step 3: label the moved configuration with a timed coarse-to-fine search.
    k = 3
    start = time.perf_counter()
    min_distance, _, min_t, min_link_index = simple_search(query_func, batch_size, k)
    elapsed_time = time.perf_counter() - start
    is_collision = (min_distance < 0)

    print("gt:", is_collision, "min_distance:", min_distance, "min_t:", min_t, "min_link_index:", min_link_index)

    if visualize:
        view(
            moving_meshes_q_init,
            moving_meshes_q_goal,
            8,
            time_optimization.fk,
            time_optimization.canonical_moving_o3d_meshes,
            fixed_o3d_mesh,
            min_t,
            closest_direction,
            shakey,
            None,
        )

    return is_collision, min_distance, fixed_mesh_h, min_t, min_link_index, elapsed_time

def get_meshes_as_oricorns(models, mesh_paths, train_set = False, load_cached_objects_only = False, model_id=None):
    """Return oriCORN latent objects aligned with ``mesh_paths``.

    Each unique path is resolved once, in this order:

    1. The model's own canonical object, if ``asset_path_util`` knows the path.
    2. A cached pickle at ``assets_oriCORNs/{model_id}/{dataset}/{stem}.pkl``.
    3. A fresh ``encode_mesh`` call, unless ``load_cached_objects_only`` is set.
       In that case the path is skipped.

    Args:
        models: pretrained models container.
        mesh_paths: mesh file paths; duplicates are allowed.
        train_set: if True, return ``models.mesh_aligned_canonical_obj`` directly.
        load_cached_objects_only: skip meshes that would require encoding.
        model_id: cache sub-directory; defaults to ``models.pretrain_ckpt_id``.

    Returns:
        Batched latent objects with one row per entry of ``mesh_paths``. Rows
        for skipped (uncached) meshes repeat the first loaded object, and a
        warning names how many were skipped.

    Raises:
        ValueError: if no mesh could be loaded at all.
    """
    if train_set:
        return models.mesh_aligned_canonical_obj
    if model_id is None:
        model_id = models.pretrain_ckpt_id

    # dict.fromkeys de-duplicates while keeping first-occurrence order, unlike set().
    unique_mesh_paths = list(dict.fromkeys(mesh_paths))

    oriCORNs: loutil.LatentObjects = None
    row_of_path = {}  # mesh path -> row index in the stacked oriCORNs
    skipped_paths = []
    for mesh_path in tqdm(unique_mesh_paths):
        mesh_idx = models.asset_path_util.get_obj_id(mesh_path)
        if mesh_idx == -1:
            dataset_name = os.path.basename(os.path.dirname(os.path.dirname(mesh_path)))
            output_filename = f'assets_oriCORNs/{model_id}/{dataset_name}/{os.path.basename(mesh_path).split(".")[0]}.pkl'
            if os.path.exists(output_filename):
                # NOTE: pickle.load can execute arbitrary code; only trusted local cache files belong here.
                with open(output_filename, 'rb') as f:
                    oriCORN = pickle.load(f)
            elif load_cached_objects_only:
                skipped_paths.append(mesh_path)
                continue
            else:
                oriCORN = encode_mesh(models, mesh_path, nfps_multiplier=1, niter=10000)
        else:
            oriCORN = models.mesh_aligned_canonical_obj[mesh_idx]

        # Record the row actually occupied; skipped paths must not shift later rows.
        row_of_path[mesh_path] = len(row_of_path)
        if oriCORNs is None:
            oriCORNs = oriCORN[None]
        else:
            oriCORNs = oriCORNs.concat(oriCORN[None], axis=0)

    if oriCORNs is None:
        raise ValueError(
            f"no oriCORNs could be loaded for {len(unique_mesh_paths)} mesh path(s) "
            f"(load_cached_objects_only={load_cached_objects_only}, model_id={model_id})"
        )
    if skipped_paths:
        print(f"warning: {len(skipped_paths)} mesh(es) have no cached oriCORN; their rows repeat row 0")

    mesh_path_ids = [row_of_path.get(mesh_path, 0) for mesh_path in mesh_paths]
    return oriCORNs[jnp.array(mesh_path_ids, dtype=jnp.int32)]

@click.command()
@click.option("--urdf_dirs", type=str, default="assets/ur5/ur5/urdf/shakey_open.urdf")
@click.option("--data_num", type=int, default=1000)
@click.option("--mesh_dir", type=str, default="EGAD/modified")
@click.option("--save_data_dir", type=str, default="temp/swept_volume_dataset")
@click.option("--seed", type=int, default=42)
@click.option("--in_domain", is_flag=True, show_default=True, default=False)
def main(
    urdf_dirs,
    data_num,
    mesh_dir,
    save_data_dir,
    seed,
    in_domain,
):
    """Generate a swept-volume collision dataset and pickle it to SAVE_DATA_DIR.

    Samples are balanced across collision / non-collision labels and across
    the closest robot link (link 0 versus all other links combined).
    """
    models = mutil.Models().load_pretrained_models()
    mesh_dir = os.path.join(models.asset_path_util.asset_base_dir, mesh_dir)
    if in_domain:
        print("In domain, using fixed meshes from models, mesh_dir is ignored")
        fixed_mesh_paths = []
        for fixed_mesh_path in models.asset_path_util.obj_paths:
            if "GoogleScannedObjects" not in fixed_mesh_path:
                continue
            mesh = o3d.io.read_triangle_mesh(fixed_mesh_path)
            # Only meshes with fewer than 3000 vertices are used, at most 100 of them.
            if len(mesh.vertices) < 3000:
                fixed_mesh_paths.append(fixed_mesh_path)
                print(fixed_mesh_path)
            if len(fixed_mesh_paths) >= 100:
                break
    else:
        print(f"Out of domain, using fixed meshes from mesh_dir {mesh_dir}")
        fixed_mesh_paths = []
        # sorted(): glob order is filesystem-dependent, which would break seed reproducibility.
        for fixed_mesh_path in sorted(glob.glob(os.path.join(mesh_dir, "*.obj"))):
            if o3d.io.read_triangle_mesh(fixed_mesh_path).is_watertight():
                fixed_mesh_paths.append(fixed_mesh_path)
            if len(fixed_mesh_paths) >= 25:
                break

    if not fixed_mesh_paths:
        raise click.ClickException("no usable fixed meshes found")

    mesh_path_max_len = len(fixed_mesh_paths)
    print(f"using {len(fixed_mesh_paths)} fixed meshes")

    shakey = shakey_module.load_urdf_kinematics(
        urdf_dirs=urdf_dirs,
        models=models,
    )

    @jax.jit
    def fk(q):
        return shakey.FK(q, oriCORN_out=False)

    # Joint bounds: base x, y, yaw followed by the arm joints.
    lower_bound = np.array([-5, -5, 0, -2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -3.0718])
    upper_bound = np.array([5, 5, np.pi * 2, 2.8973, 1.7628, 2.8973, 0.0698, 2.8973, 0.0698])

    robot_link_paths = [
        models.asset_path_util.obj_paths[i] for i in shakey.canonical_obj_idx
    ]
    scales = [
        shakey.mesh_scale[i] for i in shakey.canonical_obj_idx
    ]
    canonical_moving_o3d_meshes = [  # [NOB]
        o3d.io.read_triangle_mesh(link_path) for link_path in robot_link_paths
    ]
    for mesh, scale in zip(canonical_moving_o3d_meshes, scales):
        mesh.scale(scale, center=(0, 0, 0))

    moving_meshes_q_init = []  # [N, DOF]
    moving_meshes_q_goal = []  # [N, DOF]
    fixed_mesh_pqcs = []
    fixed_mesh_scales = []
    fixed_mesh_idxs = []
    is_collisions = []
    min_distances = []
    optimal_t = []
    min_link_indexs = []
    elapsed_times = []

    time_optimization = TimeOptimization(
        canonical_moving_o3d_meshes,
        fk,
    )
    collision_detector = ccd.OursCCD(
        models,
        collision_threshold=0.1,
        col_coef=1,
        reduce_k=32,
        is_continuous=True,
        broad_phase_cls=broad_phase.BroadPhaseWarp(),
        broadphase_type="timeoptbf_traj",
        return_collision_loss_pair=True,
    )
    moving_oriCORNs = shakey.link_canonical_oriCORN
    fixed_oriCORNs = get_meshes_as_oricorns(models, fixed_mesh_paths, load_cached_objects_only=True)

    num_links = len(canonical_moving_o3d_meshes)
    cnt = 0
    # data_counts[collision_label][link_index]
    data_counts = [[0 for _ in range(num_links)] for _ in range(2)]
    limit_per_data = data_num // (num_links * 2)
    print("limit per data: ", limit_per_data)
    if limit_per_data == 0:
        raise click.BadParameter(
            f"must be at least {num_links * 2} (2 labels x {num_links} links)",
            param_hint="'--data_num'",
        )
    # The quotas (link 0: limit; other links combined: limit * (num_links - 1); per
    # label) hold at most this many samples. Stopping at data_num instead would loop
    # forever whenever data_num is not a multiple of 2 * num_links.
    target_num = limit_per_data * num_links * 2
    if target_num < data_num:
        print(f"warning: per-link quotas cap the dataset at {target_num} samples (requested {data_num})")

    while True:
        cnt += 1
        print(cnt)
        jkey = jax.random.PRNGKey(seed + cnt)
        q_init = jax.random.uniform(jkey, (upper_bound.shape), minval=lower_bound, maxval=upper_bound)
        jkey, _ = jax.random.split(jkey)
        q_goal = jax.random.uniform(jkey, (upper_bound.shape), minval=lower_bound, maxval=upper_bound)
        q_distance = jnp.linalg.norm(q_goal[3:] - q_init[3:])
        base_distance = jnp.linalg.norm(q_goal[:2] - q_init[:2]) * 5 + jnp.abs(q_goal[2] - q_init[2])
        print("q distance:", q_distance, "base distance:", base_distance)
        if base_distance < 30 or q_distance < 3:
            continue

        jkey, _ = jax.random.split(jkey)
        fixed_mesh_index = jax.random.randint(jkey, (), minval=0, maxval=mesh_path_max_len).item()
        jkey, _ = jax.random.split(jkey)

        if in_domain:
            fixed_mesh_scale = jax.random.uniform(jkey, (), minval=0.2, maxval=1.0)
        else:
            fixed_mesh_scale = jax.random.uniform(jkey, (), minval=1, maxval=2)

        jkey, _ = jax.random.split(jkey)
        init_t = jax.random.uniform(jkey, (), minval=0.1, maxval=0.9)
        jkey, _ = jax.random.split(jkey)
        # Obstacle placed near the base position at a random point along the path.
        center = jnp.concatenate([(q_goal[:2] * init_t + q_init[:2] * (1-init_t)), jnp.array([0])])

        fixed_mesh_pos = jax.random.uniform(jkey, (3,), minval=center+jnp.array([-1.5, -1.5, 0]), maxval=center+jnp.array([1.5, 1.5, 1]))
        jkey, _ = jax.random.split(jkey)
        fixed_mesh_quat = tutil.qrand(outer_shape=(), jkey=jkey)

        fixed_mesh_path = fixed_mesh_paths[fixed_mesh_index]
        jkey, _ = jax.random.split(jkey)
        init_fixed_mesh_h = tutil.pq2H(fixed_mesh_pos, fixed_mesh_quat)
        fixed_o3d_mesh: o3d.geometry.TriangleMesh = o3d.io.read_triangle_mesh(fixed_mesh_path)
        fixed_o3d_mesh.scale(fixed_mesh_scale, center=(0, 0, 0))

        def call_collision_detector(moving_obj_pqs, fixed_oriCORN, visualize):

            return collision_detector(# [7], [T, NOB, 7]
                moving_oriCORNs,
                moving_obj_pqs[None],
                fixed_oriCORN,
                jkey,
                moving_obj_pqs.shape[0],
                visualize=visualize,
            )
        is_collision, min_distance, fixed_mesh_h, min_t, min_link_index, elapsed_time = generate_row(
            time_optimization,
            fixed_o3d_mesh,
            init_fixed_mesh_h,
            q_init,
            q_goal,
            jkey,

            fixed_oriCORNs[fixed_mesh_index],
            fixed_mesh_scale,
            call_collision_detector,
            shakey,
            visualize = False,
        )
        min_t = min_t.item()
        print(min_t)

        if min_t < 0.1 or min_t > 0.9:
            print("min_t is too close to init or goal, skipping")
            continue

        collision_idx = 1 if is_collision else 0
        print(data_counts)
        if min_link_index == 0:
            if data_counts[collision_idx][0] >= limit_per_data:
                print(f"data num reached at {'collision' if is_collision else 'non-collision'} {min_link_index} {data_counts[collision_idx][min_link_index]}")
                continue
        else:
            if sum(data_counts[collision_idx][1:]) >= limit_per_data * (num_links - 1):
                print(f"data num reached at {'collision' if is_collision else 'non-collision'} {min_link_index} {data_counts[collision_idx][min_link_index]}")
                continue

        data_counts[collision_idx][min_link_index] += 1

        moving_meshes_q_init.append(q_init)
        moving_meshes_q_goal.append(q_goal)
        fixed_mesh_pqcs.append(tutil.H2pq(fixed_mesh_h, concat=True))
        fixed_mesh_scales.append(fixed_mesh_scale)
        fixed_mesh_idxs.append(fixed_mesh_index)
        is_collisions.append(is_collision)
        min_distances.append(min_distance)
        optimal_t.append(min_t)
        min_link_indexs.append(min_link_index)
        elapsed_times.append(elapsed_time)

        print("data num: ", len(fixed_mesh_idxs), "data_counts: ", data_counts, "elapsed_time: ", elapsed_time)
        if len(fixed_mesh_idxs) >= target_num:
            break

    print("mean elapsed time: ", sum(elapsed_times) / len(elapsed_times))
    moving_meshes_q_init = np.array(jnp.stack(moving_meshes_q_init))
    moving_meshes_q_goal = np.array(jnp.stack(moving_meshes_q_goal))
    fixed_mesh_pqcs = np.array(jnp.stack(fixed_mesh_pqcs))
    fixed_mesh_scales = np.array(jnp.stack(fixed_mesh_scales))
    fixed_mesh_idxs = np.array(fixed_mesh_idxs, dtype=jnp.int32)
    is_collisions = np.array(is_collisions, dtype=bool)
    min_distances = np.array(min_distances)
    optimal_t = np.array(optimal_t)
    min_link_indexs = np.array(min_link_indexs)

    print('collision ratio: ', sum(is_collisions) / len(is_collisions))
    datas = {
        "urdf_dirs": urdf_dirs,
        "moving_mesh": {
            "q_init": moving_meshes_q_init,
            "q_goal": moving_meshes_q_goal,
        },
        "fixed_mesh": {
            "paths": [fixed_mesh_paths[i].replace(f"{models.asset_path_util.asset_base_dir}/", "") for i in fixed_mesh_idxs],
            "pqcs": fixed_mesh_pqcs,
            "scales": fixed_mesh_scales,
        },
        "is_collisions": is_collisions,
        "min_distance": min_distances,
        "optimal_t": optimal_t,
        "min_link_indexs": min_link_indexs,
    }

    save_filename = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S.pkl')
    save_filepath = os.path.join(save_data_dir, save_filename)
    os.makedirs(save_data_dir, exist_ok=True)
    with open(save_filepath, "wb") as f:
        pickle.dump(datas, f)

if __name__ == "__main__":
    main()
