
import numpy as np
import jax.numpy as jnp
import jax

# from warp.jax_experimental import jax_kernel

# import warp as wp
try:
    import warp as wp
    from warp.jax_experimental.ffi import register_ffi_callback, get_jax_device, jax_callable, jax_kernel
    warp_on = True
except:
    warp_on = False
    print('no warp')

import jax
import jax.numpy as jnp
import jax.debug as jdb
import einops
import numpy as np
from functools import partial
from typing import Sequence, Tuple
import os, sys
import time

BASEDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

import util.transform_util as tutil
import util.latent_obj_util as loutil
import util.ev_util.rotm_util as rmutil


def debug_callback(*args):
    """Print all positional arguments together as a single tuple.

    Intended as a simple host-side debugging hook; it returns nothing.
    """
    values = tuple(args)
    print(values)



def SE3_interpolation_coeffs(control_pqc):
    '''
    Fit a piecewise quintic trajectory through NT control points.

    Segment k (k = 0..NT-2) is a degree-5 polynomial in the local time
    tau in [0, dt], dt = 1/(NT-1), so the whole trajectory spans t in [0, 1]:
        p(tau) = a0 + a1*tau + a2*tau**2 + a3*tau**3 + a4*tau**4 + a5*tau**5
    Velocities / accelerations at interior control points come from finite
    differences (they are zero at the first and last point). a1, a2 match the
    segment start; a3, a4, a5 are solved so that position, velocity and
    acceleration also match at the segment end.

    control_pqc: [..., NT, 7] SE3 poses (position + quaternion), NT >= 2.
        Also supports plain vectors: any last dim D != 7 (e.g. D = 3, position
        only) is treated as a Euclidean point and differenced by subtraction.
    returns:
        a0: [..., NT-1, 7 or D] - start pose / point of every segment
        coeffs: [..., NT-1, 6 or D, 5] - per-segment coefficients [a1, a2, a3, a4, a5]
            (for SE3 they describe a 6-D twist, rotated with tutil.se3_rot)
    Evaluate the result with SE3_interpolation_eval.
    '''

    NT = control_pqc.shape[-2]
    dt = 1.0/(NT-1)
    SE3_traj = control_pqc.shape[-1] == 7

    def minus(a, b):
        # SE3 poses need the group difference; plain vectors are simply subtracted
        # (the same convention as pos_tmp below).
        if SE3_traj:
            return tutil.pqc_minus(a, b)
        return a - b

    # finite differences at the interior points: central velocity and second difference
    twist = minus(control_pqc[...,2:,:], control_pqc[...,:-2,:])/(2*dt)
    twist1 = minus(control_pqc[...,2:,:], control_pqc[...,1:-1,:])/dt
    twist2 = minus(control_pqc[...,:-2,:], control_pqc[...,1:-1,:])/dt
    acc = (twist2 + twist1)/dt

    if SE3_traj:
        # convert to the global frame
        twist = tutil.se3_rot(twist, control_pqc[...,:-2,3:])
        acc = tutil.se3_rot(acc, control_pqc[...,1:-1,3:])

    # zero velocity / acceleration at both ends -> (..., NT, 6 or D)
    twist = jnp.concatenate([jnp.zeros_like(twist[...,:1,:]), twist, jnp.zeros_like(twist[...,:1,:])], axis=-2)
    acc = jnp.concatenate([jnp.zeros_like(acc[...,:1,:]), acc, jnp.zeros_like(acc[...,:1,:])], axis=-2)

    # start-of-segment coefficients: p(0) = a0, p'(0) = a1, p''(0) = 2*a2
    a0 = control_pqc[...,:-1,:]
    a1 = twist[..., :-1, :]
    a2 = acc[..., :-1, :]*0.5

    # end-of-segment constraints (tau = dt) for x = [a3, a4, a5]:
    #   a3 dt^3  + a4 dt^4    + a5 dt^5    = p_end - a0 - a1 dt - a2 dt^2
    #   3a3 dt^2 + 4a4 dt^3   + 5a5 dt^4   = v_end - a1 - 2 a2 dt
    #   6a3 dt   + 12a4 dt^2  + 20a5 dt^3  = acc_end - 2 a2
    T_poly = np.array([dt, dt**2, dt**3, dt**4, dt**5]) # (5,)
    A_mat = np.stack([
        T_poly[2:],
        np.array([3, 4, 5]) * T_poly[1:-1],
        np.array([6, 12, 20]) * T_poly[:-2],
    ], axis=-2) # (3, 3)

    A_mat_inv = np.linalg.inv(A_mat)

    if SE3_traj:
        pos_tmp = tutil.se3_rot(tutil.pqc_minus(control_pqc[...,1:,:], a0), a0[...,3:])
    else:
        pos_tmp = control_pqc[...,1:,:] - a0

    b_mat = jnp.stack([
        pos_tmp-a1*dt-a2*dt**2,
        twist[..., 1:, :]-a1-2*a2*dt,
        acc[..., 1:, :]-2*a2,
    ], axis=-1) # (..., NT-1, 6 or D, 3)

    coeffs = jnp.einsum('...ij,...tqj->...tqi', A_mat_inv, b_mat) # [a3, a4, a5]
    coeffs = jnp.concatenate([a1[...,None], a2[...,None], coeffs], axis=-1) # (..., NT-1, 6 or D, 5)
    return a0, coeffs


def SE3_interpolation_eval(a0, coeffs, eval_t):
    '''
    Evaluate a piecewise quintic trajectory produced by SE3_interpolation_coeffs.

    a0: [..., NT-1, 7] (or [..., NT-1, D] for plain vectors) - start of each segment
    coeffs: [..., NT-1, 6, 5] (or [..., NT-1, D, 5]) - [a1, a2, a3, a4, a5] per segment
    eval_t: [..., NE] - global time in [0, 1]. Times outside [0, 1] are evaluated
        on the first / last segment (polynomial extrapolation).
    The leading dims of a0, coeffs and eval_t are broadcast together.
    return:
        eval_pqc: [..., NE, 7 or D] - pose (SE3) or point
        vel_eval, acc_eval, jerk_eval: [..., NE, 6 or D] - 1st / 2nd / 3rd time
            derivative of the per-segment polynomial
    '''

    SE3_traj = a0.shape[-1] == 7
    NT = a0.shape[-2]+1

    eval_t = jnp.asarray(eval_t)
    if not jnp.issubdtype(eval_t.dtype, jnp.floating):
        eval_t = eval_t.astype(jnp.float32)

    # broadcast all leading dims; a0 is included so take_along_axis sees equal ranks
    bc_outer_shape = jnp.broadcast_shapes(a0.shape[:-2], coeffs.shape[:-3], eval_t.shape[:-1])
    a0 = jnp.broadcast_to(a0, bc_outer_shape + a0.shape[-2:])
    coeffs = jnp.broadcast_to(coeffs, bc_outer_shape + coeffs.shape[-3:])
    eval_t = jnp.broadcast_to(eval_t, bc_outer_shape + eval_t.shape[-1:])

    # segment containing each time, and the local time tau in [0, dt] inside it
    seg_idx = jnp.floor(eval_t*(NT-1)).astype(jnp.int32) # (..., NE)
    seg_idx = jnp.clip(seg_idx, 0, NT-2)
    coeffs_eval = jnp.take_along_axis(coeffs, seg_idx[...,None,None], axis=-3) # (..., NE, 6 or D, 5)
    a0_eval = jnp.take_along_axis(a0, seg_idx[...,None], axis=-2) # (..., NE, 7 or D)
    tau = eval_t - seg_idx.astype(eval_t.dtype)/(NT-1) # (..., NE)

    # tau**0 .. tau**5, with a singleton axis for the coefficient dimension
    powers = jnp.stack([jnp.ones_like(tau), tau, tau**2, tau**3, tau**4, tau**5], axis=-1)[..., None, :] # (..., NE, 1, 6)

    eval_p = jnp.sum(coeffs_eval * powers[..., 1:], axis=-1) # (..., NE, 6 or D)

    if SE3_traj:
        eval_p = tutil.se3_rot(eval_p, tutil.qinv(a0_eval[...,3:]))
        eval_pqc = tutil.pq_multi(a0_eval, tutil.pqc_Exp(eval_p))
    else:
        eval_pqc = a0_eval + eval_p

    # d/dtau of sum_k a_k tau^k
    vel_eval = jnp.sum(coeffs_eval * jnp.array([1, 2, 3, 4, 5]) * powers[..., :5], axis=-1)
    acc_eval = jnp.sum(coeffs_eval[...,1:] * jnp.array([2, 6, 12, 20]) * powers[..., :4], axis=-1)
    jerk_eval = jnp.sum(coeffs_eval[...,2:] * jnp.array([6, 24, 60]) * powers[..., :3], axis=-1)

    return eval_pqc, vel_eval, acc_eval, jerk_eval



def time_optimization_Newton(interpolation_coeffs, fixed_points, t_prior=None):
    '''
    For every (fixed point, trajectory) pair, find the time t in [0, 1] at which the
    trajectory position is closest to the fixed point.

    f(t) = ||p(t) - x||^2 is minimized with a hybrid scheme: a Newton step where
    f''(t) > 1e-6, otherwise a gradient step with learning rate 0.01; t is clipped
    to [0, 1] after every step. The loop stops after 50 iterations or when the
    largest |delta t| over the whole batch drops below 1e-4. The optimized time is
    wrapped in stop_gradient; the final evaluation stays differentiable w.r.t. the
    trajectory coefficients.

    inputs:
        interpolation_coeffs: (a0, coeffs) from SE3_interpolation_coeffs, batched over NM
            trajectories: [(..., NM, NT-1, 7), (..., NM, NT-1, 6, 5)]
        fixed_points: (..., NF, 3)
        t_prior: (..., NF, NM) between 0 and 1 - initial time for optimization.
            Defaults to 0.5 everywhere.
    outputs:
        pairwise_dir_AB: (..., NF, NM, 3) - trajectory position at the optimized time minus the fixed point
        optimal_t: (..., NF, NM) - optimized time
        pqc_final: (..., NF, NM, 7) - SE3 pose at the optimized time
        vel_final: (..., NF, NM, 6) - velocity at the optimized time
        acc_final: (..., NF, NM, 6) - acceleration at the optimized time
    '''
    # The NM trajectories are shared by all NF fixed points: insert a size-1 NF axis.
    a0_bc = interpolation_coeffs[0][..., None, :, :, :]
    coeffs_bc = interpolation_coeffs[1][..., None, :, :, :, :]

    def eval_at(t):
        # one evaluation time per (fixed point, trajectory) pair: (..., NF, NM) -> (..., NF, NM, d)
        outs = SE3_interpolation_eval(a0_bc, coeffs_bc, t[..., None])
        return jax.tree_util.tree_map(lambda x: x.squeeze(-2), outs)

    if t_prior is None:
        NM = interpolation_coeffs[0].shape[-3]
        NF = fixed_points.shape[-2]
        # Batch dims only; (NF, NM) is appended once (fixed_points.shape[:-1] would repeat NF).
        batch_shape = jnp.broadcast_shapes(fixed_points.shape[:-2],
                                           interpolation_coeffs[0].shape[:-3],
                                           interpolation_coeffs[1].shape[:-4])
        t_prior = jnp.full(batch_shape + (NF, NM), 0.5)

    num_iters = 50   # maximum number of iterations
    lr = 0.01        # fallback learning rate for gradient descent
    tol = 1e-4       # termination tolerance for max update
    tol_hess = 1e-6  # threshold to decide if Hessian is valid for Newton update

    # (..., NF, 3) -> (..., NF, 1, 3) so it broadcasts against the NM axis
    fixed_points_expanded = fixed_points[..., :, None, :]

    # state: (iteration count, current t, max |update| of the last step)
    state_init = (0, t_prior, jnp.array(jnp.inf))

    def cond_fun(state):
        i, _, update_norm = state
        return (i < num_iters) & (update_norm > tol)

    def body_fun(state):
        i, t_current, _ = state
        pqc_eval, vel, acc, _ = eval_at(t_current)
        diff = pqc_eval[..., :3] - fixed_points_expanded  # (..., NF, NM, 3)
        lin_vel = vel[..., :3]                            # (..., NF, NM, 3)
        # f'(t) = 2 <diff, v>,  f''(t) = 2 (|v|^2 + <diff, a>)
        grad = 2 * jnp.sum(diff * lin_vel, axis=-1)
        hess = 2 * (jnp.sum(lin_vel ** 2, axis=-1) +
                    jnp.sum(diff * acc[..., :3], axis=-1))
        # Newton step only where the Hessian is sufficiently positive
        new_t = jnp.where(
            hess > tol_hess,
            t_current - grad / hess,
            t_current - lr * grad
        )
        new_t = jnp.clip(new_t, 0.0, 1.0)
        update_norm = jnp.max(jnp.abs(new_t - t_current))
        return (i + 1, new_t, update_norm)

    _, t_opt_final, _ = jax.lax.while_loop(cond_fun, body_fun, jax.lax.stop_gradient(state_init))
    t_opt_final = jax.lax.stop_gradient(t_opt_final)

    # Evaluate the final SE3 pose and derivatives at the optimized time.
    pqc_final, vel_final, acc_final, _ = eval_at(t_opt_final)
    pairwise_dir_AB = pqc_final[..., :3] - fixed_points_expanded

    return pairwise_dir_AB, t_opt_final, pqc_final, vel_final, acc_final


def visualize_in_broad_phase(reduced_A, reduced_B, line_segment_B, time_B, fixed_oriCORN_A, canonical_oriCORNs_B, pqc_path_B):
    '''
    Debug view (Open3D) of a broad-phase reduction along a path.

    For query index i it draws:
      - red line segments joining each FPS point of B at consecutive path anchors,
      - all FPS points of fixed_oriCORN_A (near-black),
      - reduced A points reduced_A.fps_tf[i] (red) and reduced B points
        reduced_B.fps_tf[i] (blue),
      - a swept-volume mesh built from reduced_B[i], line_segment_B[i], time_B[i] (cyan),
      - a mesh reconstructed from reduced_A (yellow) and a coordinate frame.

    pqc_path_B is assumed to be (NQ, NAC, NOB, 7) as in the inline comments;
    reduced_A / reduced_B presumably are batched over NQ - TODO confirm.

    Returns [line, pcd_A_entire, pcd_A, pcd_time_opt, frames].

    NOTE(review): the `return` sits inside the for-loop, so only query 0 is drawn
    and returned; rec_mesh / rec_mesh_fixed are drawn but not returned.
    '''
    import open3d as o3d
    from util.reconstruction_util import create_scene_mesh_from_oriCORNs, create_swept_volume_from_oriCORNs

    fps_tf_A = fixed_oriCORN_A.fps_tf

    # B's FPS points placed at every path anchor
    fps_path_B = tutil.pq_action(pqc_path_B[...,None,:3], pqc_path_B[...,None,3:], canonical_oriCORNs_B.fps_tf) # (NQ, NAC, NOB, NFPSB, 3)
    fps_path_B_merged = einops.rearrange(fps_path_B, ' ... i j k p -> ... i (j k) p') # (NQ, NAC, NOB*NFPSB, 3)
    # (next anchor, current anchor) pairs stacked on the last axis
    fps_path_B_merged_seq = jnp.stack([fps_path_B_merged[...,1:,:,:], fps_path_B_merged[...,:-1,:,:]], axis=-1) # (NQ, NAC-1, NOB*NFPSB, 3, 2)
    

    for i in range(fps_path_B_merged_seq.shape[0]):
        # 1-D reduced_A: one scene mesh per query; otherwise a single mesh built at i == 0
        if reduced_A.ndim == 1:
            rec_mesh_fixed = create_scene_mesh_from_oriCORNs(reduced_A[i], qp_bound=5.0, density=400, ndiv=800, visualize=False)
        else:
            if i==0:
                rec_mesh_fixed = create_scene_mesh_from_oriCORNs(reduced_A, qp_bound=5.0, density=400, ndiv=800, visualize=False)

        # flatten to (M, 2, 3) so consecutive points 2k, 2k+1 form one line segment
        link_fps_seq_original_ = fps_path_B_merged_seq[i]
        link_fps_seq_flat = link_fps_seq_original_.reshape(-1, link_fps_seq_original_.shape[-2], link_fps_seq_original_.shape[-1])
        link_fps_seq_flat = jnp.moveaxis(link_fps_seq_flat, -1, -2)
        fixed_idx = np.stack([2*np.arange(link_fps_seq_flat.shape[0]), 2*np.arange(link_fps_seq_flat.shape[0])+1], axis=-1).astype(np.int32)
        line = o3d.geometry.LineSet()
        line.points = o3d.utility.Vector3dVector(link_fps_seq_flat.reshape(-1,3))
        line.lines = o3d.utility.Vector2iVector(fixed_idx)
        line.paint_uniform_color([1, 0, 0])

        pcd_A_entire = o3d.geometry.PointCloud()
        pcd_A_entire.points = o3d.utility.Vector3dVector(fps_tf_A.reshape(-1,3))
        pcd_A_entire.paint_uniform_color([0.01, 0, 0])
        pcd_A = o3d.geometry.PointCloud()
        pcd_A.points = o3d.utility.Vector3dVector(reduced_A.fps_tf[i].reshape(-1,3))
        pcd_A.paint_uniform_color([1, 0, 0])
        pcd_time_opt = o3d.geometry.PointCloud()
        pcd_time_opt.points = o3d.utility.Vector3dVector(reduced_B.fps_tf[i].reshape(-1,3))
        pcd_time_opt.paint_uniform_color([0, 0, 1])
        frames = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.6, origin=[0, 0, 0])

        rec_mesh = create_swept_volume_from_oriCORNs(reduced_B[i], line_segment_B[i], time_B[i], qp_bound=1.8, density=200, ndiv=400, visualize=False)

        rec_mesh.paint_uniform_color([0, 1, 1])
        rec_mesh_fixed.paint_uniform_color([1, 1, 0])

        o3d.visualization.draw_geometries([line, pcd_A_entire, pcd_A, pcd_time_opt, rec_mesh, rec_mesh_fixed, frames])
        # o3d.visualization.draw_geometries([line, pcd_A_entire, pcd_A, pcd_time_opt, frames])

        # return [line, pcd_A_entire, pcd_A, pcd_time_opt, rec_mesh, rec_mesh_fixed, frames]
        # NOTE(review): returns on the first iteration (see docstring)
        return [line, pcd_A_entire, pcd_A, pcd_time_opt, frames]


def sample_new_indices(jkey, Aidx, nfps, n_new):
    """
    Draw `n_new` distinct indices from [0, nfps) that are not already in `Aidx`
    and append them to `Aidx`.

    Args:
        jkey: JAX PRNG key; it is split once and one half is consumed here.
        Aidx: 1-D integer array of already-selected indices (batch via jax.vmap).
        nfps: size of the index range to sample from.
        n_new: number of fresh indices to draw (static).

    Returns:
        (jkey, Aidx_updated): the carried-over key and `Aidx` with the new
        indices concatenated at the end (length len(Aidx) + n_new).
    """
    next_key, draw_key = jax.random.split(jkey)

    # Weight 1 for indices still available, 0 for those already taken; relies on
    # jax.random.choice never picking zero-probability entries while enough
    # available ones remain.
    is_free = jnp.ones((nfps,), dtype=bool).at[Aidx].set(False)
    weights = is_free.astype(jnp.float32)
    weights = weights / jnp.sum(weights)

    drawn = jax.random.choice(draw_key, jnp.arange(nfps), shape=(n_new,), p=weights, replace=False)
    return next_key, jnp.concatenate([Aidx, drawn], axis=0)

def _augment_with_random_indices(jkey, idx, nfps):
    """Append nfps//8 random, not-yet-selected indices to every row of `idx` (..., k)."""
    outer_shape = idx.shape[:-1]
    # int(): np.prod(()) is the float 1.0, which jax.random.split rejects as a count
    nbatch = int(np.prod(outer_shape))
    _, idx_aug = jax.vmap(partial(sample_new_indices, nfps=nfps, n_new=nfps//8))(
        jax.random.split(jkey, nbatch), idx.reshape(-1, idx.shape[-1]))
    return idx_aug.reshape(*outer_shape, idx_aug.shape[-1])


def _show_reduce_fps_debug(fps_tf_A, fps_tf_B, fps_tf_reducedA, fps_tf_reducedB):
    """Open an Open3D window with the full and reduced FPS points of A and B."""
    import open3d as o3d

    def to_points(x):
        # batched input: only the first outer element is shown
        x = np.asarray(x)
        if x.ndim >= 3:
            x = x[0]
        return o3d.utility.Vector3dVector(x.reshape(-1, 3).astype(np.float64))

    geometries = []
    for pnts, color in ((fps_tf_A, [1, 0, 0]), (fps_tf_B, [0, 1, 0]),
                        (fps_tf_reducedA, [1, 0, 1]), (fps_tf_reducedB, [0, 1, 1])):
        pcd = o3d.geometry.PointCloud()
        pcd.points = to_points(pnts)
        pcd.paint_uniform_color(color)
        geometries.append(pcd)
    geometries.append(o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.6, origin=[0, 0, 0]))
    o3d.visualization.draw_geometries(geometries)


def reduce_fps(latent_obj_A:loutil.LatentObjects, latent_obj_B:loutil.LatentObjects, 
               line_segment_B:jnp.ndarray, pq_transform_B:jnp.ndarray, jkey, 
               reduce_k, rot_configs=None, merge=False, debug=False, train=False
               )->Tuple[loutil.LatentObjects, loutil.LatentObjects, jnp.ndarray, jnp.ndarray,
                        "jnp.ndarray | None", jnp.ndarray, "jnp.ndarray | None"]:
    '''
    Keep only the FPS points of A and B that are closest to each other.

    For every (A point, B point) pair a distance is computed - to the B point itself or,
    when line_segment_B is given, to the segment starting at the B point with direction
    line_segment_B - and each object's mean FPS distance is subtracted. Up to
    reduce_k//2 A points and up to reduce_k//2 B points with the smallest distances are kept.

    Args:
        latent_obj_A, latent_obj_B: objects to reduce; leading dims are padded with
            singleton dims to a common rank.
        line_segment_B: optional segment vector per B point (None -> point distances).
        pq_transform_B: optional pose (position + quaternion) applied to B's FPS points
            and used to rotate B's latents.
        jkey: PRNG key, only used when train=True.
        reduce_k: selection budget; each side keeps min(reduce_k//2, nfps) points.
        rot_configs: forwarded to rmutil.apply_rot.
        merge: if True the object axis is flattened, so points are selected across objects.
        debug: if True, open an Open3D window with the full and reduced points.
        train: if True, nfps//8 extra random (not yet selected) indices are appended
            to each side's selection.

    Returns:
        latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx,
        t_clamped_pairwise (None when line_segment_B is None),
        pairwise_line_AB_reduced (A point minus closest B point, per reduced pair),
        line_segment_B (reduced to Bidx; None when not given)
    '''
    nfps_B_original = latent_obj_B.nfps
    # broadcast latent_obj_A and latent_obj_B to a common number of outer dims
    if merge:
        lo_outer_shape = jnp.broadcast_shapes(latent_obj_A.outer_shape[:-1], latent_obj_B.outer_shape[:-1])
        if pq_transform_B is not None:
            lo_outer_shape = jnp.broadcast_shapes(lo_outer_shape, pq_transform_B.shape[:-2])
        nouter_dim = len(lo_outer_shape) + 1
    else:
        lo_outer_shape = jnp.broadcast_shapes(latent_obj_A.outer_shape, latent_obj_B.outer_shape)
        if pq_transform_B is not None:
            lo_outer_shape = jnp.broadcast_shapes(lo_outer_shape, pq_transform_B.shape[:-1])
        nouter_dim = len(lo_outer_shape)

    for _ in range(nouter_dim - latent_obj_A.ndim):
        latent_obj_A = latent_obj_A[None]
    for _ in range(nouter_dim - latent_obj_B.ndim):
        latent_obj_B = latent_obj_B[None]

    # each object's mean FPS distance repeated per point; treated as a constant
    fps_dist_A = latent_obj_A.mean_fps_dist[...,None].repeat(latent_obj_A.nfps, -1) # (NOA, nfpsA)
    fps_dist_B = latent_obj_B.mean_fps_dist[...,None].repeat(latent_obj_B.nfps, -1) # (NOB, nfpsB)
    fps_dist_A = jax.lax.stop_gradient(fps_dist_A)
    fps_dist_B = jax.lax.stop_gradient(fps_dist_B)

    fps_tf_A = latent_obj_A.fps_tf
    if pq_transform_B is not None:
        fps_tf_B = tutil.pq_action(pq_transform_B[...,None,:], latent_obj_B.fps_tf)
    else:
        fps_tf_B = latent_obj_B.fps_tf
    if merge:
        fps_tf_A = einops.rearrange(fps_tf_A, '... n d i -> ... (n d) i')
        fps_tf_B = einops.rearrange(fps_tf_B, '... n d i -> ... (n d) i')
        fps_dist_A = einops.rearrange(fps_dist_A, '... n i -> ... (n i)')
        fps_dist_B = einops.rearrange(fps_dist_B, '... n i -> ... (n i)')

    # pairwise distances (non-squared despite the name) between A points and B points / segments
    if line_segment_B is not None:
        if line_segment_B.ndim != fps_tf_B.ndim:
            line_segment_B = line_segment_B[...,None,:]
        line_segment_B = jnp.broadcast_to(line_segment_B, fps_tf_B.shape)
        pairwise_dist_sq, pairwise_pnt_on_line_B, t_clamped, pairwise_line_AB = segment_point_min_distance_and_points(fps_tf_A, fps_tf_B, line_segment_B)
    else:
        pairwise_line_AB = fps_tf_A[...,:,None,:] - fps_tf_B[...,None,:,:]
        pairwise_dist_sq = jnp.linalg.norm(pairwise_line_AB, axis=-1)
        t_clamped = None

    pairwise_dist_sq = pairwise_dist_sq - fps_dist_A[...,None] - fps_dist_B[...,None,:]

    _, Aidx = jax.lax.top_k(-jnp.min(pairwise_dist_sq, axis=-1), np.minimum(reduce_k//2, latent_obj_A.nfps))
    if line_segment_B is not None:
        argmin_on_A = jnp.argmin(pairwise_dist_sq, axis=-2)
        min_dist_A = jnp.take_along_axis(pairwise_dist_sq, argmin_on_A[...,None,:], axis=-2).squeeze(-2)
        _, Bidx = jax.lax.top_k(-min_dist_A, np.minimum(reduce_k//2, latent_obj_B.nfps))
    else:
        _, Bidx = jax.lax.top_k(-jnp.min(pairwise_dist_sq, axis=-2), np.minimum(reduce_k//2, latent_obj_B.nfps))

    if train:
        # random augmentation indices, not already included in Aidx / Bidx
        nfps_A = fps_tf_A.shape[-2]
        nfps_B = fps_tf_B.shape[-2]
        if nfps_A != 1:
            Aidx = _augment_with_random_indices(jkey, Aidx, nfps_A)
            jkey, _ = jax.random.split(jkey)
        if nfps_B != 1:
            Bidx = _augment_with_random_indices(jkey, Bidx, nfps_B)

    # lazy transformation for efficiency
    fps_tf_reducedA = jnp.take_along_axis(fps_tf_A, Aidx[...,None], axis=-2)
    if line_segment_B is not None:
        t_clamped_pairwise = jnp.take_along_axis(t_clamped, Aidx[...,None], axis=-2)
        t_clamped_pairwise = jnp.take_along_axis(t_clamped_pairwise, Bidx[...,None,:], axis=-1)

        # each B point is represented by its closest point on the segment to its nearest A point
        fps_tf_reducedB = jnp.take_along_axis(pairwise_pnt_on_line_B, argmin_on_A[...,None,:,None], axis=-3).squeeze(-3)
        fps_tf_reducedB = jnp.take_along_axis(fps_tf_reducedB, Bidx[...,None], axis=-2)

        line_segment_B = jnp.take_along_axis(line_segment_B, Bidx[...,None], axis=-2)
    else:
        t_clamped_pairwise = None
        fps_tf_reducedB = jnp.take_along_axis(fps_tf_B, Bidx[...,None], axis=-2)

    pairwise_line_AB_reduced = jnp.take_along_axis(pairwise_line_AB, Aidx[...,None,None], axis=-3)
    pairwise_line_AB_reduced = jnp.take_along_axis(pairwise_line_AB_reduced, Bidx[...,None,:,None], axis=-2)

    if merge:
        z_A = einops.rearrange(latent_obj_A.z, '... n d i f -> ... (n d) i f')
        z_B = einops.rearrange(latent_obj_B.z, '... n d i f -> ... (n d) i f')
    else:
        z_A = latent_obj_A.z
        z_B = latent_obj_B.z

    latent_obj_A_reduced:loutil.LatentObjects = latent_obj_A.replace(z=jnp.take_along_axis(z_A, Aidx[...,None,None], axis=-3),
                                                                     pos=jnp.mean(fps_tf_reducedA, axis=-2)).set_fps_tf(fps_tf_reducedA)
    z_B = jnp.take_along_axis(z_B, Bidx[...,None,None], axis=-3)
    if pq_transform_B is not None:
        if merge:
            # merged index -> object index, to pick each point's object pose
            pq_transform_B = jnp.take_along_axis(pq_transform_B, Bidx[...,None]//nfps_B_original, axis=-2)
        z_B = rmutil.apply_rot(z_B, tutil.q2R(pq_transform_B[...,3:]), rot_configs)
    latent_obj_B_reduced:loutil.LatentObjects = latent_obj_B.replace(z=z_B, pos=jnp.mean(fps_tf_reducedB, axis=-2)).set_fps_tf(fps_tf_reducedB)

    if debug:
        _show_reduce_fps_debug(fps_tf_A, fps_tf_B, fps_tf_reducedA, fps_tf_reducedB)

    return latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, t_clamped_pairwise, pairwise_line_AB_reduced, line_segment_B

def segment_point_min_distance_and_points(query_points,
                                          line_starts,
                                          line_segments):
    """
    Closest point on every line segment for every query point.

    Segment j runs from line_starts[j] to line_starts[j] + line_segments[j].
    Leading batch dims of the inputs are broadcast.

    Args:
      query_points:  (..., Q1, 3)
      line_starts:   (..., Q2, 3) start of each segment
      line_segments: (..., Q2, 3) vector from start to end of each segment

    Returns:
      min_distances: (..., Q1, Q2) Euclidean (non-squared) distance to the segment
      min_points:    (..., Q1, Q2, 3) closest point on the segment
      t_clamped:     (..., Q1, Q2) segment parameter in [0, 1] of that point
                     (0 for degenerate, zero-length segments)
      closest_dif:   (..., Q1, Q2, 3) query point minus closest point
    """
    q = query_points[..., :, None, :]      # (Q1, 1, 3)
    start = line_starts[..., None, :, :]   # (1, Q2, 3)
    seg = line_segments[..., None, :, :]   # (1, Q2, 3)

    offset = q - start
    seg_len_sq = jnp.sum(seg * seg, axis=-1)

    # zero-length segments: divide by 1 instead, then force t to 0
    degenerate = jnp.isclose(seg_len_sq, 0.)
    safe_len_sq = jnp.where(degenerate, jnp.ones_like(seg_len_sq), seg_len_sq)
    projection = jnp.sum(offset * seg, axis=-1) / safe_len_sq
    t_clamped = jnp.where(degenerate, 0.0, jnp.clip(projection, 0.0, 1.0))

    min_points = start + t_clamped[..., None] * seg
    closest_dif = q - min_points
    return jnp.linalg.norm(closest_dif, axis=-1), min_points, t_clamped, closest_dif

def pairwise_closest_pnts_path(fps_pnts_A, fps_anchor_pnts_B):
    '''
    For every (A point, B point) pair, find where the piecewise-linear path of the
    B point (through its positions at consecutive anchors) comes closest to the A point.

    fps_pnts_A: (..., nfpsA, 3)
    fps_anchor_pnts_B: (..., nac, nfpsB, 3) - each B point's position at each anchor

    Returns (each taken at the anchor segment with the smallest distance):
        pairwise_min_dist: (..., nfpsA, nfpsB) - distance to that segment
        pairwise_min_points: (..., nfpsA, nfpsB, 3) - closest point on that segment
        pairwise_t_clamped: (..., nfpsA, nfpsB) - segment parameter in [0, 1]
        pairwise_closest_dif: (..., nfpsA, nfpsB, 3) - A point minus closest point
        pairwise_min_anchor_idx: (..., nfpsA, nfpsB) - index of the segment's start anchor
    '''
    n_b = fps_anchor_pnts_B.shape[-2]

    seg_start = fps_anchor_pnts_B[..., :-1, :, :]
    seg_vec = fps_anchor_pnts_B[..., 1:, :, :] - seg_start

    # all (nac-1)*nfpsB segments at once: outputs are (..., nfpsA, (nac-1)*nfpsB[, 3])
    dist, pts, t, dif = segment_point_min_distance_and_points(
        fps_pnts_A,
        einops.rearrange(seg_start, '... i j k -> ... (i j) k'),
        einops.rearrange(seg_vec, '... i j k -> ... (i j) k'))

    dist_per_anchor = einops.rearrange(dist, '... (i j) -> ... i j', j=n_b)  # (..., nfpsA, nac-1, nfpsB)
    best_anchor = jnp.argmin(dist_per_anchor, axis=-2)                     # (..., nfpsA, nfpsB)

    def pick_scalar(flat):
        per_anchor = einops.rearrange(flat, '... (i j) -> ... i j', j=n_b)
        return jnp.take_along_axis(per_anchor, best_anchor[..., None, :], axis=-2).squeeze(-2)

    def pick_vector(flat):
        per_anchor = einops.rearrange(flat, '... (i j) k -> ... i j k', j=n_b)
        return jnp.take_along_axis(per_anchor, best_anchor[..., None, :, None], axis=-3).squeeze(-3)

    return pick_scalar(dist), pick_vector(pts), pick_scalar(t), pick_vector(dif), best_anchor

def reduce_fps_path(fixed_oriCORNs_A:loutil.LatentObjects, canonical_oriCORNs_B:loutil.LatentObjects, pqc_path_B, reduce_k, rot_config, visualize=False):
    '''
    Select the FPS points of a fixed object set A and of a moving object set B
    (moved along a piecewise-linear path of anchor poses) that come closest to each other.

    fixed_oriCORNs_A: (NOA, ) - merged into one object before pairing
    canonical_oriCORNs_B: (NOB, ) - canonical pose; placed along the path by pqc_path_B
    pqc_path_B: (NQ, NAC, NOB, 7) pos and quat at each of the NAC anchors
    reduce_k: at most reduce_k//2 points are kept from A and from B
    rot_config: forwarded to rmutil.apply_rot when rotating B's latents
    visualize: currently unused

    Distances are taken from each A point to each B point's path segments, minus each
    object's mean FPS distance. Reduced B points are placed at their closest point on the
    path; their latents are rotated by the quaternion interpolated (qExp of scaled qLog)
    at the clamped segment parameter.

    returns:
        latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx,
        t_clamped_pairwise (NQ, NFPSA_reduced, NFPSB_reduced),
        pairwise_line_AB_reduced (A point minus closest path point, per reduced pair),
        line_segment_B (NQ, NFPSB_reduced, 3) - reduced B point motion over its closest segment
    '''
    nfps_B = canonical_oriCORNs_B.nfps

    # each object's mean FPS distance repeated per point; treated as a constant
    fps_dist_A = fixed_oriCORNs_A.mean_fps_dist[...,None].repeat(fixed_oriCORNs_A.nfps, -1) # (NOA, nfpsA)
    fps_dist_B = canonical_oriCORNs_B.mean_fps_dist[...,None].repeat(canonical_oriCORNs_B.nfps, -1) # (NOB, nfpsB)
    fps_dist_A = jax.lax.stop_gradient(fps_dist_A)
    fps_dist_B = jax.lax.stop_gradient(fps_dist_B)

    fixed_oriCORNs_A_merged = fixed_oriCORNs_A.merge() # (,)
    if fps_dist_A.ndim >= 2:
        fps_dist_A_merged = einops.rearrange(fps_dist_A, '... o f -> ... (o f)') # (NOA*nfpsA,)
    else:
        fps_dist_A_merged = fps_dist_A

    # Pad inputs with leading singleton dims to a common number of outer dims.
    # Outer dims: merged A -> ndim, B -> ndim-1, pqc_path_B -> ndim-3.
    outer_dim = np.maximum(pqc_path_B.ndim - 3, fixed_oriCORNs_A_merged.ndim)
    outer_dim = np.maximum(outer_dim, canonical_oriCORNs_B.ndim-1)

    for _ in range(outer_dim - fixed_oriCORNs_A_merged.ndim):
        fixed_oriCORNs_A_merged = fixed_oriCORNs_A_merged[None]
        fps_dist_A_merged = fps_dist_A_merged[None]
    for _ in range(outer_dim - (canonical_oriCORNs_B.ndim - 1)):
        canonical_oriCORNs_B = canonical_oriCORNs_B[None]
        fps_dist_B = fps_dist_B[None]
    for _ in range(outer_dim - (pqc_path_B.ndim - 3)):
        pqc_path_B = pqc_path_B[None]

    fps_tf_A = fixed_oriCORNs_A_merged.fps_tf

    # calculate pairwise informations
    fps_path_B = tutil.pq_action(pqc_path_B[...,None,:3], pqc_path_B[...,None,3:], canonical_oriCORNs_B.fps_tf) # (NQ, NAC, NOB, NFPSB, 3)
    fps_path_B_merged = einops.rearrange(fps_path_B, ' ... i j k p -> ... i (j k) p') # (NQ, NAC, NOB*NFPSB, 3)

    pairwise_dist_sq, pairwise_pnt_on_line_B, t_clamped, pairwise_line_AB, pairwise_min_anchor_idx \
          = pairwise_closest_pnts_path(fps_tf_A, fps_path_B_merged)
    # (NQ, NFPSA, NOB*NFPSB); a non-squared distance despite the name
    pairwise_dist_sq = pairwise_dist_sq - fps_dist_A_merged[...,None]
    pairwise_dist_sq = pairwise_dist_sq - einops.rearrange(fps_dist_B[...,None,:,:], '... i j -> ... (i j)')

    nfpsA, nofpsB = pairwise_dist_sq.shape[-2:]

    _, Aidx = jax.lax.top_k(-jnp.min(pairwise_dist_sq, axis=-1), np.minimum(reduce_k//2, nfpsA))
    argmin_on_A = jnp.argmin(pairwise_dist_sq, axis=-2)
    min_dist_A = jnp.take_along_axis(pairwise_dist_sq, argmin_on_A[...,None,:], axis=-2).squeeze(-2)
    _, Bidx = jax.lax.top_k(-min_dist_A, np.minimum(reduce_k//2, nofpsB))

    # closest path segment of each reduced B point (w.r.t. its nearest A point)
    pairwise_min_anchor_idx_reduced_B = jnp.take_along_axis(pairwise_min_anchor_idx, argmin_on_A[...,None,:], axis=-2).squeeze(-2)
    pairwise_min_anchor_idx_reduced_B = jnp.take_along_axis(pairwise_min_anchor_idx_reduced_B, Bidx, axis=-1) # (NQ, NFPSB_reduced)

    # lazy transformation for efficiency
    fps_tf_reducedA = jnp.take_along_axis(fps_tf_A, Aidx[...,None], axis=-2)

    t_clamped_pairwise = jnp.take_along_axis(t_clamped, Aidx[...,None], axis=-2)
    t_clamped_pairwise = jnp.take_along_axis(t_clamped_pairwise, Bidx[...,None,:], axis=-1)

    t_clamped_reduce_B = jnp.take_along_axis(t_clamped, argmin_on_A[...,None,:], axis=-2).squeeze(-2)
    t_clamped_reduce_B = jnp.take_along_axis(t_clamped_reduce_B, Bidx, axis=-1) # (NQ, NFPSB_reduced)

    fps_tf_reducedB = jnp.take_along_axis(pairwise_pnt_on_line_B, argmin_on_A[...,None,:,None], axis=-3).squeeze(-3)
    fps_tf_reducedB = jnp.take_along_axis(fps_tf_reducedB, Bidx[...,None], axis=-2)

    pairwise_line_AB_reduced = jnp.take_along_axis(pairwise_line_AB, Aidx[...,None,None], axis=-3)
    pairwise_line_AB_reduced = jnp.take_along_axis(pairwise_line_AB_reduced, Bidx[...,None,:,None], axis=-2)

    # object pose at the start / end anchor of each reduced B point's closest segment
    pqc_path_B_reduced = jnp.take_along_axis(pqc_path_B, Bidx[...,None,:,None]//nfps_B, axis=-2) # (NQ, NAC, NFPSB_reduced, 7)
    pqc_path_B_reduced_end = jnp.take_along_axis(pqc_path_B_reduced, pairwise_min_anchor_idx_reduced_B[...,None,:,None]+1, axis=-3).squeeze(-3) # (NQ, NFPSB_reduced, 7)
    pqc_path_B_reduced = jnp.take_along_axis(pqc_path_B_reduced, pairwise_min_anchor_idx_reduced_B[...,None,:,None], axis=-3).squeeze(-3) # (NQ, NFPSB_reduced, 7)

    # orientation interpolated at t_clamped along the segment
    qdir = tutil.qmulti(tutil.qinv(pqc_path_B_reduced[...,3:]), pqc_path_B_reduced_end[...,3:])
    qdir = tutil.qExp(t_clamped_reduce_B[...,None]*tutil.qLog(qdir))
    quat_path_B_reduced_mid = tutil.qmulti(pqc_path_B_reduced[...,3:], qdir)

    # (NOB*NFPSB, 3) / (NQ, NFPSB_reduced)
    fps_start_pnts_on_B = jnp.take_along_axis(einops.rearrange(canonical_oriCORNs_B.fps_tf, '... i j k -> ... (i j) k'), Bidx[...,None], axis=-2)
    fps_end_pnts_on_B = fps_start_pnts_on_B
    fps_end_pnts_on_B = tutil.pq_action(pqc_path_B_reduced_end[...,:3], pqc_path_B_reduced_end[...,3:], fps_end_pnts_on_B) # (NQ, NFPSB_reduced, 3)
    fps_start_pnts_on_B = tutil.pq_action(pqc_path_B_reduced[...,:3], pqc_path_B_reduced[...,3:], fps_start_pnts_on_B) # (NQ, NFPSB_reduced, 3)
    line_segment_B = fps_end_pnts_on_B - fps_start_pnts_on_B

    selected_z_B = einops.rearrange(canonical_oriCORNs_B.z, '... i k j p -> ... (i k) j p') # (NOB*NFPSB, NZ1 NZ2)
    selected_z_B = jnp.take_along_axis(selected_z_B, Bidx[...,None,None], axis=-3) # (NQ, NFPSB_reduced, NZ1 NZ2)
    selected_z_B = rmutil.apply_rot(selected_z_B, tutil.q2R(quat_path_B_reduced_mid), rot_config, feature_axis=-2)

    latent_obj_A_reduced:loutil.LatentObjects = fixed_oriCORNs_A_merged.replace(z=jnp.take_along_axis(fixed_oriCORNs_A_merged.z, Aidx[...,None,None], axis=-3),
                                                rel_fps=fps_tf_reducedA-fixed_oriCORNs_A_merged.pos[...,None,:])
    latent_obj_B_reduced:loutil.LatentObjects = loutil.LatentObjects(z=selected_z_B, pos=None, rel_fps=fps_tf_reducedB).init_pos_zero()

    return latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, t_clamped_pairwise, pairwise_line_AB_reduced, line_segment_B




def compute_curvature(p_prime, p_double_prime):
    """
    Curvature of a 3D curve from its first and second derivatives.

    kappa = |p' x p''| / |p'|^3, with a small epsilon added to the
    denominator so a zero velocity does not divide by zero.

    p_prime: first derivative (velocity) at t, shape (NB, D)
    p_double_prime: second derivative (acceleration) at t, shape (NB, D)

    Returns:
        curvature, shape (NB,)
    """
    numerator = jnp.linalg.norm(jnp.cross(p_prime, p_double_prime), axis=-1)
    speed = jnp.linalg.norm(p_prime, axis=-1)
    denominator = speed**3 + 1e-8  # epsilon keeps the division finite when speed == 0
    return numerator / denominator

def adaptive_segment_length(p_prime, p_double_prime, L_max=0.1, alpha=10.0):
    """
    Shrink the linear-approximation length as the local curvature grows.

    L(t) = L_max / (1 + alpha * kappa(t)), where kappa comes from compute_curvature.

    p_prime: velocity at t (NB, D)
    p_double_prime: acceleration at t (NB, D)
    L_max: length used where the path is straight (kappa == 0)
    alpha: sensitivity; a larger alpha shrinks the length faster as curvature increases

    Returns:
        adaptive length L(t) (NB,)
    """
    shrink_factor = 1 + alpha * compute_curvature(p_prime, p_double_prime)
    return L_max / shrink_factor


def reduce_fps_path_zoom(fixed_oriCORNs_A:loutil.LatentObjects, canonical_oriCORNs_B:loutil.LatentObjects,
                         pqc_path_B, reduce_k, rot_config, jkey, broadphase_func=None, visualize=False):
    '''
    Broad phase between static objects A and objects B moving along a path, followed by
    a per-pair time refinement (time_optimization_Newton) on an SE3 interpolation of the path.

    Keeps up to reduce_k//2 FPS points of A and up to reduce_k//2 FPS points of B with the
    smallest (margin-adjusted) pairwise path distances, then builds reduced latent objects
    and a local line segment on B's trajectory for every selected (A, B) pair.

    Args:
        fixed_oriCORNs_A: (NOA, ) static objects; merged into a single object internally.
        canonical_oriCORNs_B: (NOB, ) objects in their canonical frame.
        pqc_path_B: (NQ, NAC, NOB, 7) pos and quat of each B object at NAC path anchors.
        reduce_k: total point budget; reduce_k//2 points are kept on each side.
        rot_config: forwarded to rmutil.apply_rot to rotate B's latent features.
        jkey: JAX PRNG key; only split here (the randomized selection that used it is commented out).
        broadphase_func: optional replacement for the built-in selection; it must return
            (fps_path_B_pqc, fps_tf_reducedA, Aidx, Bidx, t_eval).
        visualize: if True, the last returned value is the result of visualize_in_broad_phase
            instead of line_segment_B.

    Returns:
        (latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, t_clamped_pairwise,
         pairwise_line_AB_reduced, line_segment_B or visualization result)
    '''

    nfps_B = canonical_oriCORNs_B.nfps
    NAC = pqc_path_B.shape[-3]

    # fps_dst
    # fps_dist_A = fixed_oriCORNs_A.mean_fps_dist[...,None].repeat(fixed_oriCORNs_A.nfps, -1) # (NOA, nfpsA)
    # fps_dist_B = canonical_oriCORNs_B.mean_fps_dist[...,None].repeat(canonical_oriCORNs_B.nfps, -1) # (NOB, nfpsB)

    # per-FPS-point margins (0.2 * mean FPS spacing), subtracted from the pairwise distances below
    fps_dist_A = 0.2*fixed_oriCORNs_A.mean_fps_dist[...,None].repeat(fixed_oriCORNs_A.nfps, -1) # (NOA, nfpsA)
    fps_dist_B = 0.2*canonical_oriCORNs_B.mean_fps_dist[...,None].repeat(canonical_oriCORNs_B.nfps, -1) # (NOB, nfpsB)

    fixed_oriCORNs_A_merged = fixed_oriCORNs_A.merge() # (,)
    if fps_dist_A.ndim >= 2:
        fps_dist_A_merged = einops.rearrange(fps_dist_A, '... o f -> ... (o f)') # (NOA*nfpsA,)
    else:
        fps_dist_A_merged = fps_dist_A

    # Align leading (batch) dims. Non-batch trailing rank: A_merged -> 0, B -> 1 (NOB),
    # pqc_path_B -> 3 (NAC, NOB, 7).
    outer_dim = np.maximum(pqc_path_B.ndim - 3, fixed_oriCORNs_A_merged.ndim)
    outer_dim = np.maximum(outer_dim, canonical_oriCORNs_B.ndim-1)

    for _ in range(outer_dim - fixed_oriCORNs_A_merged.ndim):
        fixed_oriCORNs_A_merged = fixed_oriCORNs_A_merged[None]
        fps_dist_A_merged = fps_dist_A_merged[None]
    for _ in range(outer_dim - canonical_oriCORNs_B.ndim+1):
        canonical_oriCORNs_B = canonical_oriCORNs_B[None]
        fps_dist_B = fps_dist_B[None]
    # The path's batch rank is ndim - 3; the previous `outer_dim - pqc_path_B.ndim - 3`
    # subtracted the trailing dims instead of adding them back, so no dims were ever added.
    for _ in range(outer_dim - (pqc_path_B.ndim - 3)):
        pqc_path_B = pqc_path_B[None]

    # calculate pairwise informations
    # broadphase_func = None
    if broadphase_func is None:
        fps_tf_A = fixed_oriCORNs_A_merged.fps_tf
        fps_path_B = tutil.pq_action(pqc_path_B[...,None,:3], pqc_path_B[...,None,3:], canonical_oriCORNs_B.fps_tf) # (NQ, NAC, NOB, NFPSB, 3)
        fps_path_B_merged = einops.rearrange(fps_path_B, ' ... i j k p -> ... i (j k) p') # (NQ, NAC, NOB*NFPSB, 3)

        pairwise_dist_sq, pairwise_pnt_on_line_B, t_clamped, pairwise_line_AB, pairwise_min_anchor_idx \
            = pairwise_closest_pnts_path(fps_tf_A, fps_path_B_merged)
        
        # subtract per-point margins; selection below is non-differentiable
        pairwise_dist_sq = pairwise_dist_sq - fps_dist_A_merged[...,None]
        pairwise_dist_sq = pairwise_dist_sq - einops.rearrange(fps_dist_B[...,None,:,:], '... i j -> ... (i j)')
        pairwise_dist_sq = jax.lax.stop_gradient(pairwise_dist_sq)
        t_clamped = jax.lax.stop_gradient(t_clamped)
        pairwise_min_anchor_idx = jax.lax.stop_gradient(pairwise_min_anchor_idx)

        nfpsA, nofpsB = pairwise_dist_sq.shape[-2:]

        # A points: smallest distance to any B point
        min_A = jnp.min(pairwise_dist_sq, axis=-1)
        # min_A = jnp.where(min_A < 0, min_A-10+jax.random.normal(jkey, shape=min_A.shape)*0.040, min_A)
        _, Aidx = jax.lax.top_k(-min_A, np.minimum(reduce_k//2, nfpsA))
        # B points: smallest distance to any A point
        argmin_on_A = jnp.argmin(pairwise_dist_sq, axis=-2)
        min_dist_A = jnp.take_along_axis(pairwise_dist_sq, argmin_on_A[...,None,:], axis=-2).squeeze(-2)
        jkey, _ = jax.random.split(jkey)
        # min_dist_A = jnp.where(min_dist_A < 0, min_dist_A-10+jax.random.normal(jkey, shape=min_dist_A.shape)*0.040, min_dist_A)
        _, Bidx = jax.lax.top_k(-min_dist_A, np.minimum(reduce_k//2, nofpsB)) # (NQ, NFPSB_reduced)
        Aidx = jax.lax.stop_gradient(Aidx)
        Bidx = jax.lax.stop_gradient(Bidx)

        pairwise_min_anchor_idx_reduced = jnp.take_along_axis(pairwise_min_anchor_idx, Aidx[...,None], axis=-2)
        pairwise_min_anchor_idx_reduced = jnp.take_along_axis(pairwise_min_anchor_idx_reduced, Bidx[...,None,:], axis=-1)  # (NQ, NFPSA_reduced, NFPSB_reduced)

        # lazy transformation for efficiency
        fps_tf_reducedA = jnp.take_along_axis(fps_tf_A, Aidx[...,None], axis=-2)

        t_clamped_pairwise = jnp.take_along_axis(t_clamped, Aidx[...,None], axis=-2)
        t_clamped_pairwise = jnp.take_along_axis(t_clamped_pairwise, Bidx[...,None,:], axis=-1) # (NQ, NFPSA_reduced, NFPSB_reduced)

        # Bidx indexes the merged (NOB*NFPSB) axis; //nfps_B recovers the owning object
        pqc_path_B_reduced = jnp.take_along_axis(pqc_path_B, Bidx[...,None,:,None]//nfps_B, axis=-2) # (NQ, NAC, NFPSB_reduced, 7)
        pqc_path_B_quat = pqc_path_B_reduced[...,3:]

        # compensation with interpolation
        # fps_path_B_merged : (NQ, NAC, NOB*NFPSB, 3)
        # Bidx : (NQ, NFPSB_reduced)
        fps_path_B_reduced = jnp.take_along_axis(fps_path_B_merged, Bidx[...,None,:,None], axis=-2) # (NQ, NAC, NFPSB_reduced, 3)
        fps_path_B_pqc = jnp.concat([fps_path_B_reduced, pqc_path_B_quat], axis=-1)
        # initial guess on a [0, 1] path parameter: (anchor index + local t) / (NAC - 1)
        t_eval = pairwise_min_anchor_idx_reduced.astype(jnp.float32)/(NAC-1) + t_clamped_pairwise.astype(jnp.float32)/(NAC-1) # (NQ, NFPSA_reduced, NFPSB_reduced)
    else:
        fps_path_B_pqc, fps_tf_reducedA, Aidx, Bidx, t_eval = broadphase_func(pqc_path_B, canonical_oriCORNs_B, reduce_k//2, fixed_oriCORNs_A, visualize=visualize) # (NQ, NFPSA, NOB*NFPSB)

    fps_path_B_control = jnp.moveaxis(fps_path_B_pqc, -3, -2) # (NQ, NFPSB_reduced, NAC, 7)
    pqc_path_B_coeffs = SE3_interpolation_coeffs(fps_path_B_control) # (NQ, NFPSB_reduced, NAC-1, ...)
    pairwise_line_AB_reduced, t_clamped_pairwise_normalized, fps_pqc_compensated, twist, acc = \
        time_optimization_Newton(pqc_path_B_coeffs, fps_tf_reducedA, t_eval) # (NQ, NFPSA_reduced, NFPSB_reduced, ...)

    # # add mid node for gradient compensation
    # NT = fps_path_B_pqc.shape[-3]
    # seg_idx = jnp.floor(t_clamped_pairwise_normalized*(NT-1)).astype(jnp.int32) # (... NFPSA_reduced, NFPSB_reduced)
    # seg_idx = jnp.clip(seg_idx, 0, NT-2)
    # seg_idx = jnp.moveaxis(seg_idx, -1, -2) # (... NFPSB_reduced, NFPSA_reduced)
    # fps_node_start = jnp.take_along_axis(fps_path_B_control, seg_idx[...,None], axis=-2)
    # fps_node_end = jnp.take_along_axis(fps_path_B_control, seg_idx[...,None]+1, axis=-2)
    # t_in_dt = (t_clamped_pairwise_normalized - seg_idx)/(1.0/(NT-1))
    # t_in_dt = jnp.clip(t_in_dt, 0, 1)
    # fps_node_mid = fps_node_end*(t_in_dt[...,None]) + fps_node_start*(1-t_in_dt[...,None])
    # fps_node_mid = jnp.moveaxis(fps_node_mid, -2, -3)[...,:3] # (... NFPSA_reduced, NFPSB_reduced, 7)
    # pairwise_line_AB_reduced = jax.lax.stop_gradient(pairwise_line_AB_reduced) + fps_node_mid - jax.lax.stop_gradient(fps_node_mid)
    # fps_pqc_compensated = fps_pqc_compensated.at[...,:3].set(jax.lax.stop_gradient(fps_pqc_compensated[...,:3]) + fps_node_mid - jax.lax.stop_gradient(fps_node_mid))

    pairwise_line_AB_reduced = -pairwise_line_AB_reduced # this should be actually B to A

    fps_tf_reducedB = fps_pqc_compensated[...,:3] # (NQ, NFPSA_reduced, NFPSB_reduced, 3)

    # generate line segment and times
    # segment length shrinks with path curvature (see adaptive_segment_length)
    finite_len = adaptive_segment_length(twist[...,:3], acc[...,:3], L_max=0.08, alpha=2.0)[...,None]
    projection_t_len = 0.1
    # finite_len = 0.03
    # projection_t_len = 0.07
    # finite_len = 0.05
    # projection_t_len = 0.1
    start_end_pnts = SE3_interpolation_eval(*pqc_path_B_coeffs, jnp.array([0.,1.]))[0][...,:3] # (NQ, NFPSB_reduced, 2, 3)
    line_segment_B = twist[...,:3]*finite_len
    # line_segment_B = tutil.normalize(twist[...,:3])*finite_len
    line_segment_B_norm = jnp.linalg.norm(line_segment_B, axis=-1)
    # near the path end (t > 1 - projection_t_len), snap the segment end to the path's end point
    end_pnts = fps_tf_reducedB+line_segment_B*0.5
    cur_to_end_norm = jnp.linalg.norm(fps_tf_reducedB - start_end_pnts[...,None,:,1,:], axis=-1)
    end_replace_mask = jnp.logical_and(cur_to_end_norm < line_segment_B_norm*0.5, t_clamped_pairwise_normalized > 1-projection_t_len)
    end_pnts = jnp.where(end_replace_mask[...,None], start_end_pnts[...,None,:,1,:], end_pnts)

    # near the path start (t < projection_t_len), snap the segment start to the path's start point
    start_pnts = fps_tf_reducedB-line_segment_B*0.5
    cur_to_start_norm = jnp.linalg.norm(fps_tf_reducedB - start_end_pnts[...,None,:,0,:], axis=-1)
    start_replace_mask = jnp.logical_and(cur_to_start_norm < line_segment_B_norm*0.5, t_clamped_pairwise_normalized < projection_t_len)
    start_pnts = jnp.where(start_replace_mask[...,None], start_end_pnts[...,None,:,0,:], start_pnts)

    # relative position of the current point on [start_pnts, end_pnts], clamped to [0, 1]
    t_clamped_pairwise = jnp.linalg.norm(fps_tf_reducedB - start_pnts, axis=-1)/jnp.linalg.norm(end_pnts - start_pnts, axis=-1).clip(1e-6)
    t_clamped_pairwise = jnp.where(jnp.isfinite(t_clamped_pairwise), t_clamped_pairwise, 0.5)
    t_clamped_pairwise = jnp.clip(t_clamped_pairwise, 0, 1)
    line_segment_B = end_pnts - start_pnts

    # t_clamped_pairwise = 0.5*jnp.ones_like(t_clamped_pairwise_normalized)
    # line_segment_B = tutil.normalize(twist[...,:3])
    # t_clamped_pairwise = jax.lax.stop_gradient(t_clamped_pairwise)
    # line_segment_B = jax.lax.stop_gradient(line_segment_B)

    # rotate each selected B latent by the compensated orientation of its pair
    selected_z_B = einops.rearrange(canonical_oriCORNs_B.z, '... i k j p -> ... (i k) j p') # (NOB*NFPSB, NZ1 NZ2)
    selected_z_B = jnp.take_along_axis(selected_z_B, Bidx[...,None,None], axis=-3) # (NQ, NFPSB_reduced, NZ1 NZ2)
    selected_z_B = rmutil.apply_rot(selected_z_B[...,None,:,:,:], tutil.q2R(fps_pqc_compensated[...,-4:]), rot_config, feature_axis=-2)

    # (4, 11, 64) -> (4*11*64) / (11, 64)
    latent_obj_A_reduced:loutil.LatentObjects = fixed_oriCORNs_A_merged.replace(z=jnp.take_along_axis(fixed_oriCORNs_A_merged.z, Aidx[...,None,None], axis=-3),
                                                rel_fps=fps_tf_reducedA-fixed_oriCORNs_A_merged.pos[...,None,:])
    latent_obj_A_reduced = latent_obj_A_reduced.broadcast_outershape(fps_tf_reducedA.shape[:-2])
    latent_obj_B_reduced:loutil.LatentObjects = loutil.LatentObjects(z=selected_z_B, rel_fps=fps_tf_reducedB).init_pos_zero()

    if visualize:
        res = visualize_in_broad_phase(latent_obj_A_reduced, latent_obj_B_reduced, line_segment_B, t_clamped_pairwise, fixed_oriCORNs_A_merged, canonical_oriCORNs_B, pqc_path_B)
        return latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, t_clamped_pairwise, pairwise_line_AB_reduced, res

    return latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, t_clamped_pairwise, pairwise_line_AB_reduced, line_segment_B

def reduce_fps_path_pair(fixed_oriCORNs_A:loutil.LatentObjects, canonical_oriCORNs_B:loutil.LatentObjects, 
                         pqc_path_B, reduce_k, rot_config, broadphase_func=None, visualize=False):
    '''
    Broad phase between static objects A and objects B moving along a piecewise path.

    Keeps up to reduce_k//2 FPS points of A and up to reduce_k//2 FPS points of B with the
    smallest (margin-adjusted) pairwise path distances. For every selected (A, B) pair, it
    builds B's orientation interpolated inside the closest path segment, and the line segment
    swept by the B point over that segment.

    Args:
        fixed_oriCORNs_A: (NOA, ) static objects; merged into a single object internally.
        canonical_oriCORNs_B: (NOB, ) objects in their canonical frame.
        pqc_path_B: (NQ, NAC, NOB, 7) pos and quat of each B object at NAC path anchors.
        reduce_k: total point budget; reduce_k//2 points are kept on each side.
        rot_config: forwarded to rmutil.apply_rot to rotate B's latent features.
        broadphase_func: optional replacement for the built-in selection; it must return
            (fps_path_B_pqc, fps_tf_reducedA, Aidx, Bidx, t_eval).
        visualize: if True, also calls visualize_in_broad_phase (its result is discarded).

    Returns:
        (latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, t_clamped_pairwise,
         pairwise_line_AB_reduced, line_segment_B)
    '''

    nfps_B = canonical_oriCORNs_B.nfps

    # fps_dst
    # fps_dist_A = fixed_oriCORNs_A.fps_dist # (NOA, nfpsA)
    # fps_dist_B = canonical_oriCORNs_B.fps_dist # (NOB, nfpsB)
    # per-FPS-point margins (mean FPS spacing), subtracted from the pairwise distances below
    fps_dist_A = fixed_oriCORNs_A.mean_fps_dist[...,None].repeat(fixed_oriCORNs_A.nfps, -1) # (NOA, nfpsA)
    fps_dist_B = canonical_oriCORNs_B.mean_fps_dist[...,None].repeat(canonical_oriCORNs_B.nfps, -1) # (NOB, nfpsB)

    fixed_oriCORNs_A_merged = fixed_oriCORNs_A.merge() # (,)
    if fps_dist_A.ndim >= 2:
        fps_dist_A_merged = einops.rearrange(fps_dist_A, '... o f -> ... (o f)') # (NOA*nfpsA,)
    else:
        fps_dist_A_merged = fps_dist_A

    # Align leading (batch) dims. Non-batch trailing rank: A_merged -> 0, B -> 1 (NOB),
    # pqc_path_B -> 3 (NAC, NOB, 7).
    outer_dim = np.maximum(pqc_path_B.ndim - 3, fixed_oriCORNs_A_merged.ndim)
    outer_dim = np.maximum(outer_dim, canonical_oriCORNs_B.ndim-1)

    for _ in range(outer_dim - fixed_oriCORNs_A_merged.ndim):
        fixed_oriCORNs_A_merged = fixed_oriCORNs_A_merged[None]
        fps_dist_A_merged = fps_dist_A_merged[None]
    for _ in range(outer_dim - canonical_oriCORNs_B.ndim+1):
        canonical_oriCORNs_B = canonical_oriCORNs_B[None]
        fps_dist_B = fps_dist_B[None]
    # The path's batch rank is ndim - 3; the previous `outer_dim - pqc_path_B.ndim - 3`
    # subtracted the trailing dims instead of adding them back, so no dims were ever added.
    for _ in range(outer_dim - (pqc_path_B.ndim - 3)):
        pqc_path_B = pqc_path_B[None]

    fps_tf_A = fixed_oriCORNs_A_merged.fps_tf

    if broadphase_func is not None:
        fps_path_B_pqc, fps_tf_reducedA, Aidx, Bidx, t_eval \
            = broadphase_func(pqc_path_B, canonical_oriCORNs_B, reduce_k//2, fixed_oriCORNs_A, visualize=visualize) # (NQ, NFPSA, NOB*NFPSB)
        
        pairwise_dist_sq, fps_tf_reducedB, t_clamped_pairwise, pairwise_line_AB_reduced, pairwise_min_anchor_idx_reduced \
          = pairwise_closest_pnts_path(fps_tf_reducedA, fps_path_B_pqc[...,:3])
    else:

        # calculate pairwise informations
        fps_path_B = tutil.pq_action(pqc_path_B[...,None,:3], pqc_path_B[...,None,3:], canonical_oriCORNs_B.fps_tf) # (NQ, NAC, NOB, NFPSB, 3)
        fps_path_B_merged = einops.rearrange(fps_path_B, ' ... i j k p -> ... i (j k) p') # (NQ, NAC, NOB*NFPSB, 3)

        pairwise_dist_sq, pairwise_pnt_on_line_B, t_clamped, pairwise_line_AB, pairwise_min_anchor_idx \
            = pairwise_closest_pnts_path(fps_tf_A, fps_path_B_merged)
        # (NQ, NFPSA, NOB*NFPSB)
        # t_clamped = jax.lax.stop_gradient(t_clamped)
        pairwise_min_anchor_idx = jax.lax.stop_gradient(pairwise_min_anchor_idx)
        pairwise_dist_sq = pairwise_dist_sq - fps_dist_A_merged[...,None]
        pairwise_dist_sq = pairwise_dist_sq - einops.rearrange(fps_dist_B[...,None,:,:], '... i j -> ... (i j)')
        pairwise_dist_sq = jax.lax.stop_gradient(pairwise_dist_sq)

        nfpsA, nofpsB = pairwise_dist_sq.shape[-2:]

        # A points: smallest distance to any B point; B points: smallest distance to any A point
        _, Aidx = jax.lax.top_k(-jnp.min(pairwise_dist_sq, axis=-1), np.minimum(reduce_k//2, nfpsA))
        argmin_on_A = jnp.argmin(pairwise_dist_sq, axis=-2)
        min_dist_A = jnp.take_along_axis(pairwise_dist_sq, argmin_on_A[...,None,:], axis=-2).squeeze(-2)
        _, Bidx = jax.lax.top_k(-min_dist_A, np.minimum(reduce_k//2, nofpsB))
        Aidx = jax.lax.stop_gradient(Aidx)
        Bidx = jax.lax.stop_gradient(Bidx)

        pairwise_min_anchor_idx_reduced = jnp.take_along_axis(pairwise_min_anchor_idx, Aidx[...,None], axis=-2)
        pairwise_min_anchor_idx_reduced = jnp.take_along_axis(pairwise_min_anchor_idx_reduced, Bidx[...,None,:], axis=-1)  # (NQ, NFPSA_reduced, NFPSB_reduced)

        pairwise_min_anchor_idx_reduced_B = jnp.take_along_axis(pairwise_min_anchor_idx, argmin_on_A[...,None,:], axis=-2).squeeze(-2)
        pairwise_min_anchor_idx_reduced_B = jnp.take_along_axis(pairwise_min_anchor_idx_reduced_B, Bidx, axis=-1) # (NQ, NFPSB_reduced)

        # lazy transformation for efficiency
        fps_tf_reducedA = jnp.take_along_axis(fps_tf_A, Aidx[...,None], axis=-2)

        t_clamped_pairwise = jnp.take_along_axis(t_clamped, Aidx[...,None], axis=-2)
        t_clamped_pairwise = jnp.take_along_axis(t_clamped_pairwise, Bidx[...,None,:], axis=-1)

        fps_tf_reducedB = jnp.take_along_axis(pairwise_pnt_on_line_B, Aidx[...,None,None], axis=-3)
        fps_tf_reducedB = jnp.take_along_axis(fps_tf_reducedB, Bidx[...,None,:,None], axis=-2)

        pairwise_line_AB_reduced = jnp.take_along_axis(pairwise_line_AB, Aidx[...,None,None], axis=-3)
        pairwise_line_AB_reduced = jnp.take_along_axis(pairwise_line_AB_reduced, Bidx[...,None,:,None], axis=-2)

    # Bidx indexes the merged (NOB*NFPSB) axis; //nfps_B recovers the owning object.
    # Gather the poses at the closest segment's start (anchor) and end (anchor + 1).
    pqc_path_B_reduced = jnp.take_along_axis(pqc_path_B, Bidx[...,None,:,None]//nfps_B, axis=-2) # (NQ, NAC, NFPSB_reduced, 7)
    pqc_path_B_reduced_end = jnp.take_along_axis(pqc_path_B_reduced, pairwise_min_anchor_idx_reduced[...,None]+1, axis=-3) # (NQ, NFPSA_reduced, NFPSB_reduced, 7)
    pqc_path_B_reduced = jnp.take_along_axis(pqc_path_B_reduced, pairwise_min_anchor_idx_reduced[...,None], axis=-3) # (NQ, NFPSA_reduced, NFPSB_reduced, 7)

    # orientation at parameter t inside the segment: q_start * exp(t * log(q_start^-1 * q_end))
    qdir = tutil.qmulti(tutil.qinv(pqc_path_B_reduced[...,3:]), pqc_path_B_reduced_end[...,3:])
    qdir = tutil.qExp(t_clamped_pairwise[...,None]*tutil.qLog(qdir))
    quat_path_B_reduced_mid = tutil.qmulti(pqc_path_B_reduced[...,3:], qdir)

    # (NOB*NFPSB, 3) / (NQ, NFPSB_reduced)
    fps_start_pnts_on_B = jnp.take_along_axis(einops.rearrange(canonical_oriCORNs_B.fps_tf, '... i j k -> ... (i j) k'), Bidx[...,None], axis=-2)
    fps_end_pnts_on_B = fps_start_pnts_on_B
    fps_end_pnts_on_B = tutil.pq_action(pqc_path_B_reduced_end[...,:3], pqc_path_B_reduced_end[...,3:], fps_end_pnts_on_B[...,None,:,:]) # (NQ, NFPSA_reduced, NFPSB_reduced, 3)
    fps_start_pnts_on_B = tutil.pq_action(pqc_path_B_reduced[...,:3], pqc_path_B_reduced[...,3:], fps_start_pnts_on_B[...,None,:,:]) # (NQ, NFPSA_reduced, NFPSB_reduced, 3)
    line_segment_B = fps_end_pnts_on_B - fps_start_pnts_on_B

    # rotate each selected B latent by its interpolated orientation
    selected_z_B = einops.rearrange(canonical_oriCORNs_B.z, '... i k j p -> ... (i k) j p') # (NOB*NFPSB, NZ1 NZ2)
    selected_z_B = jnp.take_along_axis(selected_z_B, Bidx[...,None,None], axis=-3) # (NQ, NFPSB_reduced, NZ1 NZ2)
    selected_z_B = rmutil.apply_rot(selected_z_B[...,None,:,:,:], tutil.q2R(quat_path_B_reduced_mid), rot_config, feature_axis=-2)
    
    # (4, 11, 64) -> (4*11*64) / (11, 64)
    latent_obj_A_reduced:loutil.LatentObjects = fixed_oriCORNs_A_merged.replace(z=jnp.take_along_axis(fixed_oriCORNs_A_merged.z, Aidx[...,None,None], axis=-3),
                                                rel_fps=fps_tf_reducedA-fixed_oriCORNs_A_merged.pos[...,None,:])
    latent_obj_B_reduced:loutil.LatentObjects = loutil.LatentObjects(z=selected_z_B, pos=None, rel_fps=fps_tf_reducedB).init_pos_zero()

    if visualize:
        visualize_in_broad_phase(latent_obj_A_reduced, latent_obj_B_reduced, line_segment_B, t_clamped_pairwise, fixed_oriCORNs_A_merged, canonical_oriCORNs_B, pqc_path_B)

    return latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, t_clamped_pairwise, pairwise_line_AB_reduced, line_segment_B



def reduce_fps_path_aabb(fixed_oriCORNs_A:loutil.LatentObjects, canonical_oriCORNs_B:loutil.LatentObjects, pqc_path_B, reduce_k, rot_config, visualize=False):
    '''
    AABB-based broad phase between static objects A and objects B moving through waypoints.

    For each consecutive waypoint pair, builds the AABB of every B object's FPS points over
    both waypoints and compares it with every A object's AABB. It keeps the top quarter of
    (waypoint, B object, A object) triples ranked by union volume, then selects FPS points
    from point-to-segment distances inside those triples.

    Args:
        fixed_oriCORNs_A: (NOA, ) static objects.
        canonical_oriCORNs_B: (NOB, ) objects in their canonical frame.
        pqc_path_B: (NQ, NOB, 7) pos and quat of each B object at NQ waypoints.
        reduce_k: reduce_k//2 points are kept on each side.
        rot_config: forwarded to rmutil.apply_rot.
        visualize: currently unused.

    Returns:
        (latent_obj_A_reduced, latent_obj_B_reduced, t_clamped_pairwise,
         pairwise_line_AB_reduced, line_segment_B)
    '''
    fps_tf_A = fixed_oriCORNs_A.fps_tf # (NOA, NFPSA, 3)

    # calculate pairwise informations
    fps_path_B = tutil.pq_action(pqc_path_B[...,None,:3], pqc_path_B[...,None,3:], canonical_oriCORNs_B.fps_tf) # (NQ, NOB, NFPSB, 3)

    # AABB of each B object swept between waypoints q and q+1
    fps_path_B_max = jnp.max(fps_path_B, axis=-2) # [NQ, NOB, 3]
    fps_path_B_min = jnp.min(fps_path_B, axis=-2) # [NQ, NOB, 3]
    fps_path_B_aabb_min = jnp.min(jnp.stack([fps_path_B_min[:-1], fps_path_B_min[1:]], axis=-2), axis=-2) # [NQ - 1, NOB, 3]
    fps_path_B_aabb_max = jnp.max(jnp.stack([fps_path_B_max[:-1], fps_path_B_max[1:]], axis=-2), axis=-2) # [NQ - 1, NOB, 3]
    fps_tf_A_aabb_min = jnp.min(fps_tf_A, axis=-2) # [NOA, 3]
    fps_tf_A_aabb_max = jnp.max(fps_tf_A, axis=-2) # [NOA, 3]

    overlap_min = jnp.maximum(fps_path_B_aabb_min[..., None, :], fps_tf_A_aabb_min[None, None, :, :])
    overlap_max = jnp.minimum(fps_path_B_aabb_max[..., None, :], fps_tf_A_aabb_max[None, None, :, :])
    overlap_dims = jnp.maximum(0.0, overlap_max - overlap_min)
    intersection = jnp.prod(overlap_dims, axis=-1) # [NQ - 1, NOB, NOA]

    vol1 = jnp.prod(fps_path_B_aabb_max - fps_path_B_aabb_min, axis=-1)  # [NQ - 1, NOB]
    vol2 = jnp.prod(fps_tf_A_aabb_max - fps_tf_A_aabb_min, axis=-1)  # [NOA]
    union = vol1[..., None] + vol2[None, None, :] - intersection # [NQ - 1, NOB, NOA]

    # NOTE(review): ranking by union volume; IoU (intersection / union) or the intersection
    # volume itself were noted as alternatives.
    targets = union
    flat_targets = targets.ravel()

    # keep the top quarter of triples, but at least one so top_k never receives k == 0
    num_keep = max(1, flat_targets.shape[0] // 4)
    _, flat_indices = jax.lax.top_k(flat_targets, num_keep)
    index_q, index_nob, index_noa = jnp.unravel_index(flat_indices, targets.shape)
    # index_q: waypoint (segment start) index, index_nob: B object index, index_noa: A object index
    line_start = fps_path_B[index_q, index_nob] # [K, NFPSB, 3]
    line_end = fps_path_B[index_q+1, index_nob] # [K, NFPSB, 3]
    line_segment = line_end - line_start # [K, NFPSB, 3]
    points = fps_tf_A[index_noa] # [K, NFPSA, 3]
    
    pairwise_dist_point_line, min_pts, t_clamped, closest_dif = segment_point_min_distance_and_points(points, line_start, line_segment)
    # min_dists: [K, NFPSA, NFPSB]
    # min_pts: [K, NFPSA, NFPSB, 3] : closest point on line on global frame
    # t_clamped: [K, NFPSA, NFPSB]
    # closest_dif: [K, NFPSA, NFPSB, 3] : vector from query point to closest point on the line

    group_num, num_fps_a, num_fps_b = pairwise_dist_point_line.shape
    pairwise_dist_point_line = einops.rearrange(pairwise_dist_point_line, 'k a b -> (k a) b')
    _, index_fps_a = jax.lax.top_k(-jnp.min(pairwise_dist_point_line, axis=-1), np.minimum(reduce_k//2, group_num * num_fps_a))
    argmin_on_A = jnp.argmin(pairwise_dist_point_line, axis=-2)
    index_group, index_fps_a = jnp.unravel_index(index_fps_a, (group_num, num_fps_a))
    # index_group: triple (group) index, index_fps_a: FPS index on A within that group
    min_dist_A = jnp.take_along_axis(pairwise_dist_point_line, argmin_on_A[...,None,:], axis=-2).squeeze(-2)
    _, index_fps_b = jax.lax.top_k(-min_dist_A, np.minimum(reduce_k//2, num_fps_b))

    # NOTE(review): index_group/index_fps_a have length min(reduce_k//2, group_num*num_fps_a)
    # while index_fps_b has length min(reduce_k//2, num_fps_b); the joint indexing below only
    # broadcasts when those lengths agree (or one is 1).
    fps_tf_reducedA = fixed_oriCORNs_A.fps_tf[index_noa[index_group], index_fps_a] # [reduced_k, 3]

    t_clamped_pairwise = t_clamped[index_group, index_fps_a, index_fps_b, None, None] # [reduced_k, 1, 1]
    pairwise_line_AB_reduced = closest_dif[index_group, index_fps_a, index_fps_b, None, None, :] # [reduced_k, 1, 1, 3]

    fps_tf_reducedB = min_pts[index_group, None, index_fps_a, index_fps_b] # [reduced_k, 1, 3]
    line_segment_B = line_segment[index_group, index_fps_b, None, :] # [reduced_k, 3]

    pqc_path_B_reduced = pqc_path_B[index_q[index_group], index_nob[index_group]] # [reduced_k, 7]
    selected_z_B = canonical_oriCORNs_B.z[index_nob[index_group], index_fps_b] # [reduced_k, NZ1, NZ2]
    selected_z_B = rmutil.apply_rot(selected_z_B, tutil.q2R(pqc_path_B_reduced[...,3:]), rot_config, feature_axis=-2)[:, None] # [reduced_k, 1, NZ1, NZ2]

    fixed_oriCORNs_A_merged = fixed_oriCORNs_A.merge() # (,)
    # (4, 11, 64) -> (4*11*64) / (11, 64)

    # number of selected A points; z and rel_fps below share this leading dim, so pos must too
    # (it was previously repeated reduce_k times, which never matched).
    num_selected = index_group.shape[0]
    latent_obj_A_reduced:loutil.LatentObjects = loutil.LatentObjects(
        z=fixed_oriCORNs_A.z[index_noa[index_group], None, index_fps_a], # [reduced_k, 1, NZ1, NZ2]
        rel_fps=(fps_tf_reducedA-fixed_oriCORNs_A_merged.pos[...,None,:])[:, None], # [reduced_k, 1, NFPSA, 3]
        pos=einops.repeat(fixed_oriCORNs_A_merged.pos, "n -> k n", k=num_selected) # [reduced_k, 3]
    )
    latent_obj_B_reduced:loutil.LatentObjects = loutil.LatentObjects(
        z=selected_z_B,
        pos=fps_tf_reducedB,
        rel_fps=jnp.zeros_like(fps_tf_reducedB)
    )
    if visualize:
        pass

    return latent_obj_A_reduced, latent_obj_B_reduced, t_clamped_pairwise, pairwise_line_AB_reduced, line_segment_B



def get_segment_broad_phase_func(bvh_size):
    '''
    Build a Warp kernel that tests radius-expanded segments against a BVH of spheres.

    Args:
        bvh_size: length of each per-segment `pairwise_dist` vector; BVH candidate
            indices at or beyond it are not written to `pairwise_dist`.

    Returns:
        the `bvh_query_segment` Warp kernel.
    '''

    @wp.kernel
    def bvh_query_segment(segment_start: wp.array(dtype=wp.vec3),
                        segment_end: wp.array(dtype=wp.vec3),
                        radius: wp.array(dtype=wp.vec(1, wp.float32)),
                        bvh_center: wp.array(dtype=wp.vec3),
                        bvh_radius: wp.array(dtype=wp.vec(1, wp.float32)),
                        bvh_id_32: wp.array(dtype=wp.vec2ui),
                        batch_size: wp.array(dtype=wp.int32),
                        # outputs
                        cloest_bvh_idx_per_segment: wp.array(dtype=wp.int32),
                        closest_distance_per_segment: wp.array(dtype=wp.float32),
                        closest_distance_per_bvh: wp.array(dtype=wp.float32),
                        pairwise_dist: wp.array(dtype=wp.vec(bvh_size, wp.float32)),
                        ):
        """
        Given a line segment and a set of BVH points, find the closest intersection.
        Supports batch dimensions.
        Only outputs the closest BVH candidate per segment.
        
        inputs:
            segment_start: (B*N, 3)
            segment_end: (B*N, 3)
            radius: (B*N, 1)
            bvh_center: (M, 3)
            bvh_radius: (M, 1)
            bvh_id_32: (1, 2)  -- packed BVH identifier (high, low) (common for all segments)
            batch_size: (1, )
        outputs:
            cloest_bvh_idx_per_segment: (B*N, ) -- closest BVH index per segment (or -1 if none found)
            closest_distance_per_segment: (B*N, ) -- corresponding candidate distance for each segment
            closest_distance_per_bvh: (B*M, ) -- per-batch, per-BVH minimum distance (atomically updated)
            pairwise_dist: (B*N, bvh_size) -- signed distance (center distance minus both radii)
                to every BVH candidate returned by the AABB query, 1000.0 otherwise
        """
        # Unpack the BVH id (common for all segments)
        bvh_id = (wp.uint64(bvh_id_32[0][0]) << wp.uint64(32)) | wp.uint64(bvh_id_32[0][1])
        
        # Each thread corresponds to one segment.
        seg_idx = wp.tid()
        total_segments = segment_start.shape[0]  # should equal B * N
        n_segments_per_batch = total_segments // batch_size[0]
        batch_id = seg_idx // n_segments_per_batch
        M = bvh_center.shape[0]  # total number of BVH elements

        # Initialize best candidate for this segment.
        best_distance = wp.float32(1000.0)
        best_bvh_idx = int(-1)

        cloest_bvh_idx_per_segment[seg_idx] = int(-1)
        closest_distance_per_segment[seg_idx] = wp.float32(1000.0)
        # NOTE(review): this initialises the per-BVH buffer (B*M entries) at the *segment* index,
        # and can overwrite atomic_min results written by other threads; initialising it on the
        # host before launch would be safer.
        closest_distance_per_bvh[seg_idx] = wp.float32(1000.0)
        for i in range(bvh_size):
            pairwise_dist[seg_idx][i] = wp.float32(1000.0)

        p0 = segment_start[seg_idx]
        p1 = segment_end[seg_idx]
        seg_radius = radius[seg_idx][0]

        # Build the AABB for the segment (expanded by its radius).
        seg_min = wp.min(p0, p1) - wp.vec3(seg_radius, seg_radius, seg_radius)
        seg_max = wp.max(p0, p1) + wp.vec3(seg_radius, seg_radius, seg_radius)
        query = wp.bvh_query_aabb(bvh_id, seg_min, seg_max)

        # Segment vector is loop-invariant. Its squared length is clamped so a degenerate
        # (zero-length) segment projects every point to t = 0 instead of producing NaN/inf.
        seg_dir = p1 - p0
        seg_len_sq = wp.max(wp.dot(seg_dir, seg_dir), wp.float32(1e-12))

        candidate = int(0)  # variable to hold candidate index from the BVH query
        while wp.bvh_query_next(query, candidate):
            # Let P0 = segment_start, P1 = segment_end, Q = bvh_center[candidate].
            diff = bvh_center[candidate] - p0
            t = wp.dot(diff, seg_dir) / seg_len_sq  # normalized projection parameter
            t = wp.max(t, wp.float32(0.0))
            t = wp.min(t, wp.float32(1.0))
            closest_dir = diff - t * seg_dir
            distance_to_segment = wp.sqrt(wp.dot(closest_dir, closest_dir))
            penetration_depth = distance_to_segment - bvh_radius[candidate][0] - seg_radius

            # guard the fixed-size vector against candidates beyond bvh_size
            if candidate < bvh_size:
                pairwise_dist[seg_idx][candidate] = penetration_depth

            # Reject if candidate is too far (i.e. beyond the sum of BVH radius and segment radius)
            if penetration_depth > 0:
                continue

            # Update the best candidate for this segment.
            if penetration_depth < best_distance:
                best_distance = penetration_depth
                best_bvh_idx = candidate

            # Update the global best for this candidate BVH element in the current batch.
            global_bvh_idx = batch_id * M + candidate
            wp.atomic_min(closest_distance_per_bvh, global_bvh_idx, penetration_depth)
        
        # Write out the per-segment best candidate.
        cloest_bvh_idx_per_segment[seg_idx] = best_bvh_idx
        closest_distance_per_segment[seg_idx] = best_distance

    return bvh_query_segment

def get_aabb_broad_phase_func(bvh_dim):
    '''
    Build Warp kernels (and a host-side wrapper) for BVH broad-phase queries.

    Args:
        bvh_dim: length of the per-query intersection mask; BVH box indices at or
            beyond it are not recorded.

    Returns:
        (bvh_query_segment kernel, bvh_query_aabb kernel, callable_func) where callable_func
        builds a BVH from (lower, upper) and fills an AABB-overlap mask per query box.
    '''
    @wp.kernel
    def bvh_query_aabb(lower: wp.array(dtype=wp.vec3), upper: wp.array(dtype=wp.vec3), 
                        bvh_id_32: wp.array(dtype=wp.vec2ui), bounds_intersected: wp.array(dtype=wp.vec(bvh_dim, wp.uint8))):
        '''
        Mark, for each query box idx, which BVH boxes overlap [lower[idx], upper[idx]].

        bvh_id_32[0] holds the 64-bit BVH id as a (high, low) uint32 pair.
        '''
        bvh_id = (wp.uint64(bvh_id_32[0][0]) << wp.uint64(32)) | wp.uint64(bvh_id_32[0][1])

        idx = wp.tid()
        query = wp.bvh_query_aabb(bvh_id, lower[idx], upper[idx])
        bounds_nr = int(0)

        # initialize the bounds_intersected to 0
        for i in range(bvh_dim):
            bounds_intersected[idx][i] = wp.uint8(0)

        while wp.bvh_query_next(query, bounds_nr):
            # narrow phase; skip indices that do not fit in the fixed-size mask
            if bounds_nr < bvh_dim:
                bounds_intersected[idx][bounds_nr] = wp.uint8(1)

    num_candidates_per_query = 10
    @wp.kernel
    def bvh_query_segment(segment_start: wp.array(dtype=wp.vec3), segment_end: wp.array(dtype=wp.vec3), radius:wp.array(dtype=wp.vec(1, wp.float32)),
                        bvh_center: wp.array(dtype=wp.vec3), bvh_radius: wp.array(dtype=wp.vec(1, wp.float32)),
                        bvh_id_32: wp.array(dtype=wp.vec2ui), 
                        col_idx: wp.array(dtype=wp.vec(num_candidates_per_query, wp.int32)), 
                        closest_dist_out: wp.array(dtype=wp.vec(num_candidates_per_query, wp.float32)),
                        closest_distance_per_bvh: wp.array(dtype=wp.vec(1, wp.float32)),
                        ):
        '''
        given line segment and set of points built by bvh, find the closest points on the line segment to the points in the bvh
        inputs:
            segment_start: (N, 3)
            segment_end: (N, 3)
            radius: (N, 1)
            bvh_center: (M, 3)
            bvh_radius: (M, 1)
            bvh_id_32: (1, 2) -- only entry 0 is read, as a (high, low) uint32 pair
        outputs:
            col_idx: (N, num_candidates_per_query) -- first accepted candidates, -1 padded
            closest_dist_out: (N, num_candidates_per_query) -- distance to the segment, 1000.0 padded
            closest_distance_per_bvh: (M, 1) -- NOTE(review): not written by this kernel
        '''
        bvh_id = (wp.uint64(bvh_id_32[0][0]) << wp.uint64(32)) | wp.uint64(bvh_id_32[0][1])

        idx = wp.tid()

        p0 = segment_start[idx]
        p1 = segment_end[idx]
        seg_radius = radius[idx][0]
        query = wp.bvh_query_aabb(bvh_id, wp.min(p0, p1) - wp.vec3(seg_radius, seg_radius, seg_radius), 
                                  wp.max(p0, p1) + wp.vec3(seg_radius, seg_radius, seg_radius))
        bounds_nr = int(0)

        # initialize the bounds_intersected to 0
        for i in range(num_candidates_per_query):
            col_idx[idx][i] = wp.int32(-1)
            closest_dist_out[idx][i] = wp.float32(1000.0)

        # Segment direction and clamped squared length (same 1e-6 length floor as before,
        # squared); a zero-length segment projects every point to t = 0.
        seg_dir = p1 - p0
        seg_len_sq = wp.max(wp.dot(seg_dir, seg_dir), wp.float32(1e-12))

        idx_cnt = int(0)
        while wp.bvh_query_next(query, bounds_nr):
            # narrow phase: distance from the bvh center to the *segment* (projection clamped to
            # [0, 1]). The previous infinite-line formula sqrt(|d|^2 - proj^2) ignored the segment
            # ends and could take sqrt of a slightly negative value (NaN).
            diff = bvh_center[bounds_nr] - p0
            t = wp.clamp(wp.dot(diff, seg_dir) / seg_len_sq, wp.float32(0.0), wp.float32(1.0))
            closest_vec = diff - t * seg_dir
            closest_dist = wp.sqrt(wp.dot(closest_vec, closest_vec))

            if closest_dist > bvh_radius[bounds_nr][0] + seg_radius:
                continue

            col_idx[idx][idx_cnt] = wp.int32(bounds_nr)
            closest_dist_out[idx][idx_cnt] = closest_dist
            idx_cnt += 1
            if idx_cnt >= num_candidates_per_query:
                break


    def callable_func(
            # inputs
            query_lower: wp.array(dtype=wp.vec3), 
            query_upper: wp.array(dtype=wp.vec3), 
            lower: wp.array(dtype=wp.vec3), 
            upper: wp.array(dtype=wp.vec3), 
            # outputs
            bounds_intersected: wp.array(dtype=wp.vec(bvh_dim, wp.uint8))
        ):
        '''Build a BVH over (lower, upper) and run bvh_query_aabb for every query box.'''
        bvh = wp.Bvh(lower, upper)
        # bvh_query_aabb expects (lower, upper, bvh_id_32) with the 64-bit id packed as a
        # (high, low) uint32 pair; the raw bvh.id was previously passed in the first slot.
        bvh_id = int(bvh.id)
        bvh_id_32 = wp.array(np.array([[bvh_id >> 32, bvh_id & 0xFFFFFFFF]], dtype=np.uint32),
                             dtype=wp.vec2ui, device=query_lower.device)
        wp.launch(bvh_query_aabb, dim=int(query_lower.shape[0]),
                  inputs=[query_lower, query_upper, bvh_id_32], outputs=[bounds_intersected],
                  device=query_lower.device)
        # keep `bvh` alive until the kernel has finished
        wp.synchronize()
        wp.synchronize_device(wp.device_from_jax(get_jax_device()))
        

    return bvh_query_segment, bvh_query_aabb, callable_func

def warp_func(inputs, outputs, attrs, ctx):
    '''
    FFI-style callback: build a BVH from two bound arrays and launch a
    broad-phase kernel against two query arrays on the caller's stream.

    inputs: sequence of 4 arrays [bvh lowers, bvh uppers, query lowers, query uppers].
            The element type is presumably wp.vec3 (TODO confirm).
    outputs: forwarded unchanged to wp.launch as the kernel outputs.
    attrs: mapping; attrs['query_dim'] is used both as the kernel-factory
           argument and as the launch dimension.
    ctx: provides `stream`, the CUDA stream the launch is scoped to.

    NOTE(review): this callback is not registered (the register_ffi_callback
    line after this function is commented out). Elsewhere in this file the
    result of get_aabb_broad_phase_func(...) is indexed ([0]/[1]) and called
    positionally with the BVH size, so passing it whole to wp.launch with a
    `query_dim=` keyword looks stale. Confirm before re-enabling.
    NOTE(review): `bvh` is a local. If the launch runs asynchronously on
    `stream`, confirm the BVH stays alive until the kernel completes.
    '''

    device = wp.device_from_jax(get_jax_device())
    # run all Warp work on the stream supplied by the caller
    stream = wp.Stream(device, cuda_stream=ctx.stream)

    with wp.ScopedStream(stream):
        # lowers = wp.array(inputs[0], dtype=wp.vec3)
        # uppers = wp.array(inputs[1], dtype=wp.vec3)
        # query_lower = wp.array(inputs[2], dtype=wp.vec3)
        # query_upper = wp.array(inputs[3], dtype=wp.vec3)

        # Build the BVH with the given bounds
        bvh = wp.Bvh(inputs[0], inputs[1])
        # kernel inputs: BVH handle followed by the query lower/upper bounds
        inputs = [bvh.id, inputs[2], inputs[3]]

        wp.launch(get_aabb_broad_phase_func(query_dim=int(attrs['query_dim'])), dim=int(attrs['query_dim']), inputs=inputs, outputs=outputs)

# register_ffi_callback("warp_func", warp_func) # for new version


def uint64_to_uint32_pair(value):
    '''
    Split a 64-bit unsigned integer (e.g. a Warp BVH handle) into two uint32 words.

    The returned array holds (high word, low word) along a new trailing axis,
    e.g. 0x0000000500000007 -> [5, 7].
    '''
    # (upper 32 bits, lower 32 bits)
    halves = (value >> 32, value & 0xFFFFFFFF)
    words = [jnp.uint32(half) for half in halves]
    return jnp.stack(words, axis=-1).astype(jnp.uint32)


def get_broad_phase_jax(bvh_dim):
    '''
    Build a JAX-callable AABB broad phase backed by a Warp callable.

    bvh_dim: number of reference boxes; forwarded to get_aabb_broad_phase_func.

    Returns broad_phase_jax(query_lower, query_upper, lower, upper). Each
    argument is flattened to rows of 3 and passed, with gradients stopped, to
    the Warp callable. The callable's first output is reshaped to
    query_lower.shape[:-1] + (its last axis,) and returned without gradient.
    '''
    # alternative batching rule: vmap_method="broadcast_all"
    warp_callable = jax_callable(get_aabb_broad_phase_func(bvh_dim)[1], num_outputs=1, vmap_method='sequential')

    def broad_phase_jax(query_lower:jnp.ndarray, query_upper:jnp.ndarray, lower:jnp.ndarray, upper:jnp.ndarray):
        leading_shape = query_lower.shape[:-1]
        rows = [jax.lax.stop_gradient(arr.reshape(-1, 3))
                for arr in (query_lower, query_upper, lower, upper)]
        mask = warp_callable(*rows)[0]
        mask = mask.reshape(leading_shape + mask.shape[-1:])
        return jax.lax.stop_gradient(mask)

    return broad_phase_jax

class BroadPhaseWarp:
    '''
    Warp-BVH-backed broad phase between swept FPS points of moving links and
    the FPS points of a fixed scene.

    Call `enroll_bvh(fixed_obj)` first (and again whenever the fixed scene
    changes) to build or refit the BVH and compile the Warp kernel. Then query
    with `path_broad_phase(...)`, or use `__call__` for raw kernel access.
    '''
    def __init__(self):
        # Populated by enroll_bvh().
        self.bvh = None
        self.bvh_dim = None

    def enroll_bvh(self, fixed_obj:loutil.LatentObjects):
        '''
        Build (or refit) the BVH over the FPS points of `fixed_obj`.

        Each FPS point becomes an axis-aligned cube of half-width equal to its
        object's `mean_fps_dist`. A new BVH is built, and the segment
        broad-phase kernel compiled for that size, in two cases: on the first
        call, and whenever the number of bounds differs from the enrolled BVH.
        The Warp bound arrays and the kernel are size-specific, so they cannot
        simply be refitted. Otherwise the existing bound arrays are overwritten
        in place and the BVH is refitted.

        Also stores `bvh_center`, `bvh_radius` and `fixed_obj_merged`, which
        `segment()` reads.
        '''
        enrol_start_time = time.time()

        # one radius per FPS point, flattened to (N, 1)
        radius = fixed_obj.mean_fps_dist
        radius = radius[...,None].repeat(fixed_obj.nfps, axis=-1)
        radius = radius.reshape(-1, 1).astype(jnp.float32)
        fixed_obj_merged = fixed_obj.merge()
        lower_bound_bvh = (fixed_obj_merged.fps_tf - radius).astype(jnp.float32)
        upper_bound_bvh = lower_bound_bvh + 2*radius

        # read by segment()
        self.bvh_center = fixed_obj_merged.fps_tf.astype(jnp.float32)
        self.bvh_radius = radius
        self.fixed_obj_merged = fixed_obj_merged

        num_bounds = lower_bound_bvh.shape[0]
        # compare against the previously enrolled size before overwriting it
        needs_rebuild = self.bvh is None or num_bounds != self.bvh_dim
        self.bvh_dim = num_bounds

        if needs_rebuild:
            self.lowers = wp.from_numpy(np.array(lower_bound_bvh), dtype=wp.vec3)
            self.uppers = wp.from_numpy(np.array(upper_bound_bvh), dtype=wp.vec3)
            self.bvh = wp.Bvh(self.lowers, self.uppers)
            # the 64-bit BVH handle is passed to the kernel as two uint32 words
            self.bvh_id = uint64_to_uint32_pair(self.bvh.id)[None].astype(jnp.uint32)
            self.jax_kernel = jax_kernel(get_segment_broad_phase_func(self.bvh_dim), num_outputs=4)
        else:
            self.lowers.assign(np.array(lower_bound_bvh))
            self.uppers.assign(np.array(upper_bound_bvh))
            self.bvh.refit()

        print(f"BVH enroll time: {time.time()-enrol_start_time}")

    def __call__(self, *args):
        '''
        Run the enrolled kernel directly.

        Every argument is flattened to (-1, last_dim), and the BVH handle is
        appended as the final input. Each kernel output is reshaped to
        args[0].shape[:-1] + (its last axis,). Gradients are stopped on inputs
        and outputs.
        '''
        original_shape =args[0].shape[:-1]

        flat_args = jax.tree_util.tree_map(lambda x: x.reshape(-1, x.shape[-1]), args)
        flat_args = jax.lax.stop_gradient(flat_args)

        kernel_out = self.jax_kernel(*flat_args, self.bvh_id)

        kernel_out = jax.tree_util.tree_map(lambda x: x.reshape(original_shape+x.shape[-1:]), kernel_out)
        kernel_out = jax.lax.stop_gradient(kernel_out)
        return kernel_out

    def segment(self, moving_obj_pqs, moving_obj:loutil.LatentObjects, fixed_oriCORNs, models, select_num=100, visualize=False):
        '''
        Legacy segment broad phase followed by narrow-phase collision decoding.

        moving_obj_pqs: (..., NSEQ, NOB, 7)
        moving_obj: (NOB, )
        fixed_oriCORNs: currently unused; the scene stored by enroll_bvh() is used instead
        models: provides `rot_configs` and `apply('col_decoder', ...)`
        select_num: number of closest candidate pairs kept
        visualize: open an open3d window (batch index 1 when batched)

        NOTE(review): this method appears to target an older kernel API. The
        kernel built in enroll_bvh() is called in path_broad_phase() with 7
        inputs (including a batch-size array) and returns 4 outputs. Here it
        receives 6 inputs, only 2 outputs are unpacked, and the first output is
        treated as a per-query candidate list. Expect failures until this is
        ported; confirm against get_segment_broad_phase_func.
        '''
        moving_obj_radius = moving_obj.mean_fps_dist[...,None].repeat(moving_obj.nfps, axis=-1)
        link_fps_seq = tutil.pq_action(moving_obj_pqs[...,None,:], moving_obj.fps_tf) # (..., NSEQ, NOB, NFPSB, 3)
        link_fps_seq = jnp.stack([link_fps_seq[...,:-1,:,:,:], link_fps_seq[...,1:,:,:,:]], axis=-1).astype(jnp.float32)
        moving_obj_radius = jnp.broadcast_to(moving_obj_radius, link_fps_seq.shape[:-2]).astype(jnp.float32)[...,None]

        bp_pairs, closest_dist = self.__call__(link_fps_seq[...,0], link_fps_seq[...,1], moving_obj_radius, self.bvh_center, self.bvh_radius)
        # (..., NSEQ, NOB, NFPS, 10)

        # select best N pair
        NC = bp_pairs.shape[-1]
        outer_shape = bp_pairs.shape[:-4]
        bp_pairs = bp_pairs.reshape((int(np.prod(outer_shape)), -1))
        bp_pairs = jnp.where(bp_pairs!=-1, bp_pairs, jax.random.permutation(jax.random.PRNGKey(0), bp_pairs.shape[-1]))
        closest_dist = closest_dist.reshape((int(np.prod(outer_shape)), -1))
        _, moving_idx = jax.lax.top_k(-closest_dist, min(select_num, closest_dist.shape[-1]))
        fixed_idx = jnp.take_along_axis(bp_pairs, moving_idx, axis=-1)
        moving_idx = moving_idx//NC

        link_fps_seq_flat = link_fps_seq.reshape(int(np.prod(outer_shape)), -1, link_fps_seq.shape[-2], link_fps_seq.shape[-1])
        moving_fps_seq_filtered = jnp.take_along_axis(link_fps_seq_flat, moving_idx[...,None,None], axis=-3)
        moving_z = jnp.take_along_axis(moving_obj.z.reshape(-1, *moving_obj.latent_shape[1:])[None], 
                                       moving_idx[...,None,None]%(moving_obj.nfps*moving_obj.nobj), axis=-3)
        moving_obj_pqs_flat = moving_obj_pqs[...,:-1,:,:].reshape(int(np.prod(outer_shape)), -1, moving_obj_pqs.shape[-1]) # (..., NSEQ, NOB, 7)
        moving_obj_pqs_filtered = jnp.take_along_axis(moving_obj_pqs_flat, moving_idx[...,None]//moving_obj.nfps, axis=-2)
        moving_z = rmutil.apply_rot(moving_z, tutil.q2R(moving_obj_pqs_filtered[...,3:]), models.rot_configs, feature_axis=-2)

        moving_oriCORN = loutil.LatentObjects(z=moving_z[:,:,None], rel_fps=moving_fps_seq_filtered[:,:,None,...,0])
        moving_oriCORN = moving_oriCORN.init_pos_zero()
        moving_segment_line = moving_fps_seq_filtered[...,1] - moving_fps_seq_filtered[...,0]
        moving_segment_line = moving_segment_line[:,:,None]
        fixed_fps_filtered = jnp.take_along_axis(self.fixed_obj_merged.fps_tf[None], fixed_idx[...,None], axis=-2)
        fixed_z = jnp.take_along_axis(self.fixed_obj_merged.z[None], fixed_idx[...,None,None], axis=-3)
        fixed_oriCORN = loutil.LatentObjects(z=fixed_z[:,:,None], rel_fps=fixed_fps_filtered[:,:,None])
        fixed_oriCORN = fixed_oriCORN.init_pos_zero()

        moving_oriCORN = moving_oriCORN.reshape_outer_shape(outer_shape + (select_num,))
        fixed_oriCORN = fixed_oriCORN.reshape_outer_shape(outer_shape + (select_num,))
        moving_segment_line = moving_segment_line.reshape(outer_shape + (select_num, 1, 3))


        collision_loss_pair = models.apply('col_decoder', fixed_oriCORN, moving_oriCORN, 
                                            line_segment_B=moving_segment_line,
                                            reduce_k=16, jkey=jax.random.PRNGKey(0), pairwise_out=True)
        collision_loss_pair = collision_loss_pair.squeeze(-2)

        if visualize:
            import open3d as o3d

            if len(outer_shape) >= 1:
                batch_idx = 1
                link_fps_seq = link_fps_seq[batch_idx]
                moving_oriCORN = moving_oriCORN[batch_idx]
                fixed_oriCORN = fixed_oriCORN[batch_idx]
                moving_segment_line = moving_segment_line[batch_idx]

            # visualize lines and AABBs in open3d
            # create AABB linesets
            lower_bound_bvh = self.bvh_center - self.bvh_radius
            upper_bound_bvh = self.bvh_center + self.bvh_radius
            aabb_linesets = []
            for i in range(self.bvh_center.shape[0]):
                aabb = o3d.geometry.AxisAlignedBoundingBox(lower_bound_bvh[i], upper_bound_bvh[i])
                # red: a selected fixed point sits at this box's center
                if np.any(np.linalg.norm(fixed_oriCORN.fps_tf.squeeze(-2) - self.bvh_center[i], axis=-1) < 1e-3):
                    aabb.color = (1,0,0)
                else:
                    aabb.color = (0,1,0)
                aabb_linesets.append(aabb)

            # draw lines from link_fps_seq # (NSEQ, NLINK, NFPSB, 3, 2)
            link_fps_seq_flat = link_fps_seq.reshape(-1, link_fps_seq.shape[-2], link_fps_seq.shape[-1])
            link_fps_seq_flat = jnp.moveaxis(link_fps_seq_flat, -1, -2)
            line_sets = []
            for i in range(link_fps_seq_flat.shape[0]):
                line = o3d.geometry.LineSet()
                line.points = o3d.utility.Vector3dVector(link_fps_seq_flat[i])
                line.lines = o3d.utility.Vector2iVector([[0,1]])
                line.colors = o3d.utility.Vector3dVector(np.array([[0,1,0]]))
                line_sets.append(line)

            # selected moving segments in red
            for i in range(moving_oriCORN.shape[0]):
                line = o3d.geometry.LineSet()
                line_pnts = [moving_oriCORN.fps_tf.squeeze(-2)[i], moving_oriCORN.fps_tf.squeeze(-2)[i] + moving_segment_line[i].squeeze(-2)]
                line_pnts = np.array(line_pnts)
                line.points = o3d.utility.Vector3dVector(line_pnts)
                line.lines = o3d.utility.Vector2iVector([[0,1]])
                line.colors = o3d.utility.Vector3dVector(np.array([[1,0,0]]))
                line_sets.append(line)

            frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
            o3d.visualization.draw_geometries(aabb_linesets+line_sets + [frame])

        return bp_pairs, closest_dist, collision_loss_pair

    def path_broad_phase(self, moving_obj_pqs, moving_obj, K, fixed_obj, visualize=False):
        '''
        Find the K moving FPS points and the K fixed FPS points that come
        closest along a path, using the enrolled BVH.

        inputs:
            moving_obj_pqs: (..., NSEQ, NOB, 7) pose (position + quaternion) of
                every moving link at every path step; NSEQ >= 2
            moving_obj: moving LatentObjects (NOB links with NFPSB FPS points each)
            K: number of points kept on each side (clamped to the available count)
            fixed_obj: fixed LatentObjects. Its FPS centers and radii are passed
                to the kernel, so it should be the scene given to enroll_bvh().
            visualize: if True, jit the kernel call and draw the first batch
                element with open3d (requires a leading batch axis)
        outputs:
            linke_pq_traj_filtered: (..., NSEQ, K, 7) trajectories of the K
                selected moving FPS points, concatenated with their link's quaternion
            fixed_fps_filtered: (..., K, 3) centers of the K selected fixed FPS points
            topk_idx_A: (..., K) indices of the selected fixed FPS points
            topk_idx_B: (..., K) indices into the flattened (NOB*NFPSB) moving FPS points
            pairwise_t: (..., K, K) path parameter in [0, 1] for each
                (fixed, moving) pair: the midpoint of the segment with the
                smallest signed distance
        '''
        bvh_radius = fixed_obj.mean_fps_dist
        bvh_radius = bvh_radius[...,None].repeat(fixed_obj.nfps, axis=-1)
        bvh_radius = bvh_radius.reshape(-1, 1).astype(jnp.float32)
        bvh_center = fixed_obj.fps_tf.reshape(-1,3)

        NSEQ = moving_obj_pqs.shape[-3]
        original_batch_shape = moving_obj_pqs.shape[:-3]

        # per-FPS-point radius of the moving links, inflated by a hard-coded factor of 1.8
        moving_obj_radius = moving_obj.mean_fps_dist[...,None].repeat(moving_obj.nfps, axis=-1)*1.8
        link_fps_seq = tutil.pq_action(moving_obj_pqs[...,None,:], moving_obj.fps_tf) # (..., NSEQ, NOB, NFPSB, 3)
        link_fps_traj = link_fps_seq # (..., NSEQ, NOB, NFPSB, 3)
        NOB, NFPSB = link_fps_seq.shape[-3:-1]
        # segments between consecutive steps: (..., NSEQ-1, NOB, NFPSB, 3, 2); [..., 0]=start, [..., 1]=end
        link_fps_seq = jnp.stack([link_fps_seq[...,:-1,:,:,:], link_fps_seq[...,1:,:,:,:]], axis=-1).astype(jnp.float32)
        link_fps_seq_original = link_fps_seq
        moving_obj_radius = jnp.broadcast_to(moving_obj_radius, link_fps_seq.shape[:-2]).astype(jnp.float32)[...,None]

        original_moving_shape = link_fps_seq.shape[:-2]
        original_fixed_shape = original_batch_shape + (self.bvh_dim,)
        # one row per segment for the kernel
        link_fps_seq = link_fps_seq.reshape(-1, 3, 2)
        moving_obj_radius = moving_obj_radius.reshape(-1, 1)

        # NOTE(review): the kernel call is jitted only when visualize=True; confirm this is intended.
        if visualize:
            bp_jax_kernel = jax.jit(self.jax_kernel)
        else:
            bp_jax_kernel = self.jax_kernel

        batch_size = int(np.prod(original_batch_shape))
        kernel_args = (link_fps_seq[...,0], link_fps_seq[...,1], moving_obj_radius, bvh_center, bvh_radius, self.bvh_id, jnp.array([batch_size]).astype(jnp.int32))
        kernel_args = jax.lax.stop_gradient(kernel_args)
        res = bp_jax_kernel(*kernel_args)
        res = jax.lax.stop_gradient(res)
        closest_idx_moving = res[0].reshape(original_moving_shape) # (..., NSEQ-1, NOB, NFPSB)
        closest_dist_moving = res[1].reshape(original_moving_shape) # (..., NSEQ-1, NOB, NFPSB)
        closest_dist_fixed = res[2][:np.prod(original_fixed_shape)].reshape(original_fixed_shape) # (..., NOA*NFPSA)
        pairwise_signed_dist = res[3].reshape(original_moving_shape + (self.bvh_dim,)) # (..., NSEQ-1, NOB, NFPSB, NOA*NFPSA) - negative: penetration
        pairwise_signed_dist = einops.rearrange(pairwise_signed_dist, '... nseq nob nfpsb nfpsa -> ... nseq (nob nfpsb) nfpsa')
        pairwise_signed_dist = jnp.moveaxis(pairwise_signed_dist, -1, -2) # (..., NSEQ-1, NOA*NFPSA, NOB*NFPSB)

        # K moving points whose swept segments come closest to the fixed scene
        closest_dist_moving_traj = einops.rearrange(closest_dist_moving, '... nseq nob nfpsb -> ... (nob nfpsb) nseq')
        closest_dist_moving_traj_idx = jnp.argmin(closest_dist_moving_traj, axis=-1) # (..., NOB*NFPSB)
        closest_dist_moving_traj = jnp.take_along_axis(closest_dist_moving_traj, closest_dist_moving_traj_idx[...,None], axis=-1).squeeze(-1) # (..., NOB*NFPSB)
        # values > 999 are presumably the kernel's "no hit" sentinel; small fixed-seed
        # noise keeps top_k from resolving ties among them purely by index
        closest_dist_moving_traj = jnp.where(closest_dist_moving_traj>999, 
                                             closest_dist_moving_traj+jax.random.uniform(jax.random.PRNGKey(0), shape=closest_dist_moving_traj.shape)*1e-3,
                                             closest_dist_moving_traj)
        _, topk_idx_B = jax.lax.top_k(-closest_dist_moving_traj, min(K, closest_dist_moving_traj.shape[-1])) # (..., K)
        linke_fps_traj_filtered = einops.rearrange(link_fps_traj, '... nseq nob nfpsb i  -> ... (nob nfpsb) nseq i')
        linke_fps_traj_filtered = jnp.take_along_axis(linke_fps_traj_filtered, topk_idx_B[..., None, None], axis=-3) # (..., K, NSEQ, 3)

        # K fixed points closest to any moving segment (same sentinel handling)
        closest_dist_fixed = jnp.where(closest_dist_fixed>999,
                                             closest_dist_fixed+jax.random.uniform(jax.random.PRNGKey(1), shape=closest_dist_fixed.shape)*1e-3,
                                             closest_dist_fixed)
        _, topk_idx_A = jax.lax.top_k(-closest_dist_fixed, min(K, closest_dist_fixed.shape[-1])) # (..., K)

        # add one leading axis per batch dim so take_along_axis broadcasts over the batch
        fixed_fps_filtered = bvh_center
        for _ in range(len(original_batch_shape)):
            fixed_fps_filtered = fixed_fps_filtered[None]
        fixed_fps_filtered = jnp.take_along_axis(fixed_fps_filtered, topk_idx_A[...,None], axis=-2) # (..., K, 3)

        # segment index of the smallest signed distance for each (fixed, moving) pair
        pairwise_t = jnp.argmin(pairwise_signed_dist, axis=-3) # (..., NOA*NFPSA, NOB*NFPSB)
        pairwise_t = jnp.take_along_axis(pairwise_t, topk_idx_A[...,None], axis=-2)
        pairwise_t = jnp.take_along_axis(pairwise_t, topk_idx_B[...,None,:], axis=-1) # (..., K, K)
        # segment midpoint in normalized path time
        pairwise_t = (pairwise_t.astype(jnp.float32)+0.5)/(NSEQ-1) # (..., K, K)
        pairwise_t = pairwise_t.clip(0,1)

        linke_fps_traj_filtered = jnp.moveaxis(linke_fps_traj_filtered, -2, -3) # (..., NSEQ, K, 3)
        # pose of the link each selected FPS point belongs to
        pqc_path_B_reduced = jnp.take_along_axis(moving_obj_pqs, topk_idx_B[...,None,:,None]//moving_obj.nfps, axis=-2) # (..., NSEQ, K, 7)
        pqc_path_B_quat = pqc_path_B_reduced[...,3:]
        linke_pq_traj_filtered = jnp.concat([linke_fps_traj_filtered, pqc_path_B_quat], axis=-1) # (..., NSEQ, K, 7)

        if visualize:
            import open3d as o3d

            for batch_idx in range(linke_pq_traj_filtered.shape[0]):

                link_fps_seq_original_ = link_fps_seq_original[batch_idx]
                closest_idx_moving_ = closest_idx_moving[batch_idx]
                closest_dist_moving_ = closest_dist_moving[batch_idx]
                closest_dist_fixed_ = closest_dist_fixed[batch_idx]
                topk_idx_A_ = topk_idx_A[batch_idx]
                topk_idx_B_ = topk_idx_B[batch_idx]
                linke_pq_traj_filtered_ = linke_pq_traj_filtered[batch_idx]

                # visualize lines and AABBs in open3d
                # create AABB linesets: blue = selected, red = hit, green = no hit
                lower_bound_bvh = bvh_center - bvh_radius
                upper_bound_bvh = bvh_center + bvh_radius
                aabb_linesets = []
                for i in range(bvh_center.shape[0]):
                    aabb = o3d.geometry.AxisAlignedBoundingBox(lower_bound_bvh[i], upper_bound_bvh[i])
                    if i in topk_idx_A_.tolist():
                        aabb.color = (0,0,1)
                    elif closest_dist_fixed_[i] < 100:
                        aabb.color = (1,0,0)
                    else:
                        aabb.color = (0,1,0)
                    aabb_linesets.append(aabb)

                # moving segments: blue = selected point, green = no hit, red = hit
                link_fps_seq_flat = link_fps_seq_original_.reshape(-1, link_fps_seq_original_.shape[-2], link_fps_seq_original_.shape[-1])
                link_fps_seq_flat = jnp.moveaxis(link_fps_seq_flat, -1, -2)
                closest_idx_moving_flat = closest_idx_moving_.reshape(-1)
                line_sets = []
                for i in range(link_fps_seq_flat.shape[0]):
                    line = o3d.geometry.LineSet()
                    line.points = o3d.utility.Vector3dVector(link_fps_seq_flat[i])
                    line.lines = o3d.utility.Vector2iVector([[0,1]])

                    if i%(NOB*NFPSB) in topk_idx_B_.tolist():
                        line.colors = o3d.utility.Vector3dVector(np.array([[0,0,1]]))
                    elif not (closest_idx_moving_flat[i]!=-1).any():
                        line.colors = o3d.utility.Vector3dVector(np.array([[0,1,0]]))
                    else:
                        line.colors = o3d.utility.Vector3dVector(np.array([[1,0,0]]))
                    line_sets.append(line)

                # selected trajectories in yellow
                pq_pnts_flat = jnp.stack([linke_pq_traj_filtered_[...,1:, :, :], linke_pq_traj_filtered_[...,:-1,:,:]], axis=-1).reshape(-1, 7, 2)
                pq_pnts_flat = einops.rearrange(pq_pnts_flat, '... i j k -> ... i k j')[...,:3]

                for i in range(pq_pnts_flat.shape[0]):
                    line = o3d.geometry.LineSet()
                    line.points = o3d.utility.Vector3dVector(pq_pnts_flat[i])
                    line.lines = o3d.utility.Vector2iVector([[0,1]])
                    line.colors = o3d.utility.Vector3dVector(np.array([[1,1,0]]))
                    line_sets.append(line)

                frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
                o3d.visualization.draw_geometries(aabb_linesets+line_sets + [frame])
                # only the first batch element is drawn
                break

        return linke_pq_traj_filtered, fixed_fps_filtered, topk_idx_A, topk_idx_B, pairwise_t


if __name__ == '__main__':
    import open3d as o3d
    import matplotlib.pyplot as plt

    # --- sanity check: compare the two SE(3) interpolation schemes on random control poses ---
    for _ in range(10):
        pos_random = np.random.uniform(-1, 1, (7, 3,))
        quat_random = tutil.qrand((7,))

        # pos_random = np.linspace(-1,1,4)
        # pos_random = jnp.stack([pos_random, jnp.zeros_like(pos_random), jnp.zeros_like(pos_random)], axis=-1)
        # quat_random = np.random.normal(size=(4, 3))*0.0
        # quat_random = tutil.qExp(quat_random)

        control_pnts = np.concatenate([pos_random, quat_random], axis=-1)[None]
        # control_pnts = pos_random[None]

        twist, acc = derivatives(control_pnts)

        eval_t = np.linspace(0, 1, 200)

        coeffs = SE3_interpolation_coeffs(control_pnts)
        eval_pqc = SE3_interpolation_eval(*coeffs, eval_t)

        coeffs2 = SE3_interpolation_coeffs2(control_pnts)
        eval_pqc2 = SE3_interpolation_eval2(*coeffs2, eval_t)

        for cp, eval1, eval2 in zip((control_pnts, twist, acc, None), eval_pqc, eval_pqc2):

            plt.figure()
            for i in range(3):
                plt.subplot(3,1,i+1)
                # plt.plot(np.linspace(0,1,cp.shape[1]), cp[0,:,i])
                plt.plot(np.linspace(0,1,eval1.shape[1]), eval1[0,:,i])
                plt.plot(np.linspace(0,1,eval2.shape[1]), eval2[0,:,i])
                # one label per plotted curve (the control plot above is disabled)
                plt.legend(['eval1', 'eval2'])
            plt.show()

            # dif = jnp.abs(eval1-eval2)
            # assert jnp.max(dif) < 1e-6

        eval_pqc = eval_pqc.squeeze(0)
        control_pnts = control_pnts.squeeze(0)

        plt.figure()
        plt.plot(eval_pqc[...,0])
        plt.show()

        # draw coordinates
        intp_frames = []
        for i in range(eval_pqc.shape[-2]):
            intp_frames.append(o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1).transform(tutil.pq2H(eval_pqc[i])))
        for i in range(control_pnts.shape[-2]):
            intp_frames.append(o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3).transform(tutil.pq2H(control_pnts[i])))
        o3d.visualization.draw_geometries(intp_frames)


    # --- broad-phase demo on a UR5 in a sampled table scene ---
    from tqdm import tqdm
    import pybullet as pb
    import util.model_util as mutil
    import modules.shakey_module as shakey_module
    import util.scene_util as scene_util

    # load oriCORN model
    models = mutil.Models().load_pretrained_models()

    # robot
    models = models.load_self_collision_model('ur5')
    shakey = shakey_module.load_urdf_kinematics(
        urdf_dirs="assets/ur5/urdf/ur5.urdf",
        models=models,
    )

    # robot
    # models = models.load_self_collision_model('shakey')
    # shakey = shakey_module.load_urdf_kinematics(
    #     urdf_dirs="assets/ur5/urdf/shakey_open_rg6.urdf",
    #     models=models,
    # )

    pb.connect(pb.DIRECT)

    # base_se2 = np.array([0, -0.15, np.pi/2.0])
    base_se2 = np.array([0, -0.15, -np.pi/2.0])
    robot_pb_uid = shakey.create_pb(se2=base_se2)

    base_pqc = tutil.SE2h2pq(base_se2, np.array(shakey.robot_height))
    robot_base_pqc = jnp.concat(base_pqc, axis=-1)

    q_zero = np.zeros(6)

    broad_phase_cls = BroadPhaseWarp()

    @jax.jit
    def FK_jit(q):
        '''World-frame end-effector pose for joint configuration q.'''
        return tutil.pq_multi(robot_base_pqc, shakey.FK(q, oriCORN_out=False)[shakey.ee_idx])


    def sample_valid_q(z_range, jkey):
        '''
        Rejection-sample a joint configuration with no pybullet contacts whose
        end-effector z lies strictly inside z_range and whose y exceeds 0.2.
        '''
        # turn on self collision
        # lower_bound = np.array([-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -3.0718])
        # upper_bound = np.array([2.8973, 1.7628, 2.8973, 0.0698, 2.8973, 0.0698])
        lower_bound = np.array([-2.8973+np.pi/2, -1.7628, -2.8973, -3.0718, -2.8973, -3.0718])
        upper_bound = np.array([2.8973+np.pi/2, 1.7628, 2.8973, 0.0698, 2.8973, 0.0698])
        while True:
            # sample with a fresh subkey; never reuse the key that is split next
            jkey, sample_key = jax.random.split(jkey)
            q_random = jax.random.uniform(sample_key, (6,), minval=lower_bound, maxval=upper_bound)
            shakey.set_q_pb(robot_pb_uid, q_random)
            pb.performCollisionDetection()
            col_res = pb.getContactPoints(robot_pb_uid)
            if len(col_res) != 0:
                continue
            pqc = FK_jit(q_random)
            if pqc[2] > z_range[0] and pqc[2] < z_range[1]:
                if pqc[1] > 0.2:
                    break
        return q_random


    pb.resetDebugVisualizerCamera(
        cameraDistance=1.79, 
        cameraYaw=-443.20,
        cameraPitch=0.26,
        cameraTargetPosition=[-0.05, 0.07, 0.30]
    )

    # hyperparameter
    success_list = []
    elapsed_time_list = []
    for itr, seed in enumerate(range(100)):
        jkey = jax.random.PRNGKey(seed)
        random_obj, environment_obj, pybullet_scene = scene_util.create_table_sampled_scene(models=models, num_objects=0, seed=seed)
        fixed_obj = environment_obj

        if itr==0:
            broad_phase_cls.enroll_bvh(fixed_obj)

        moving_obj = shakey.link_canonical_oriCORN

        # init, goal sampling (independent keys so the two samplers do not share a random stream)
        init_key, goal_key = jax.random.split(jkey)
        init_q = sample_valid_q([0.4, 1.0], init_key)
        goal_q = sample_valid_q([0, 0.4], goal_key)

        # 10 noisy copies of a 5-step straight-line joint path
        interpolated_trajectory = jnp.linspace(0, 1, 5)[...,None]*(goal_q - init_q)[...,None,:] + init_q[...,None,:]
        interpolated_trajectory = np.random.normal(size=(10,)+interpolated_trajectory.shape)*0.1 + interpolated_trajectory
        interpolated_trajectory = jnp.concatenate([
                jnp.broadcast_to(base_se2, (interpolated_trajectory.shape[:-1] + (3,))),
                interpolated_trajectory
            ], axis=-1)
        moving_obj_pqs = shakey.FK(interpolated_trajectory, oriCORN_out=False)

        out = broad_phase_cls.path_broad_phase(moving_obj_pqs, moving_obj, 2, fixed_obj, visualize=True)
        # out = broad_phase_cls.segment(moving_obj_pqs, moving_obj, fixed_obj, models, 1000, visualize=True)


    # rng = np.random.default_rng(123)
    # num_bounds = 200
    # lowers = rng.random(size=(num_bounds, 3)) * 5.0
    # uppers = lowers + rng.random(size=(num_bounds, 3)) * 5.0

    # query_dim = 64*7*50*50
    # query_lower = rng.random(size=(query_dim, 3)) * 5.0
    # query_upper = query_lower + rng.random(size=(query_dim, 3)) * 5.0

    # broad_phase_jax = get_broad_phase_jax(num_bounds)
    # broad_phase_jax = jax.jit(broad_phase_jax)

    # args = [jnp.array(query_lower).astype(jnp.float32), jnp.array(query_upper).astype(jnp.float32), jnp.array(lowers).astype(jnp.float32), jnp.array(uppers).astype(jnp.float32)]
    # args = [args[0][None].repeat(10, axis=0), args[1][None].repeat(10, axis=0), args[2], args[3]]

    # res = broad_phase_jax(*args)
    # # time.sleep(0.2)
    # print('start run')
    # start_t = time.time()
    # for _ in tqdm(range(100)):
    #     res = broad_phase_jax(*args)
    #     # time.sleep(0.2)
    # print('time', (time.time()-start_t)/100)