import jax
import jax.numpy as jnp
from flax import struct
from functools import partial
import numpy as np

def depth_calculate(points, min_leaf_size):
    """Choose the number of KD-tree split levels for a point set.

    The depth is ``ceil(log2(Npnts / min_leaf_size))``, clamped to at least 1,
    where ``Npnts`` is ``points.shape[-2]``.

    NOTE: because of the ``ceil``, ``2 ** depth >= Npnts / min_leaf_size``, so
    when every level halves the point set (as ``build_tree`` does) the leaves
    end up with roughly ``min_leaf_size`` points or fewer.

    Args:
        points: array whose second-to-last dimension is the number of points.
        min_leaf_size: target leaf size; must be positive.

    Returns:
        The tree depth as a Python ``int`` (>= 1).

    Raises:
        ValueError: if ``min_leaf_size`` is not positive.
    """
    if min_leaf_size <= 0:
        raise ValueError(f"min_leaf_size must be positive, got {min_leaf_size}")
    Npnts = points.shape[-2]
    if Npnts <= min_leaf_size:
        # Ratio <= 1 always clamps to depth 1; returning early also avoids
        # log2(0) = -inf (and the OverflowError from int(-inf)) on empty input.
        return 1
    # Ratio > 1 here, so the logarithm is positive and the result is >= 1.
    return int(np.ceil(np.log2(Npnts / min_leaf_size)))

def bruth_force_search(points, query_points):
    """Exhaustive nearest-neighbour search by squared Euclidean distance.

    Args:
        points: reference points, ``[..., Np, D]``.
        query_points: query points, ``[..., Nq, D]``.

    Returns:
        ``(best_idx, minval)``: for every query, the index (along the ``Np``
        axis) of the closest reference point and the squared distance to it,
        both shaped ``[..., Nq]``.

    The full ``[..., Np, Nq, D]`` difference tensor is materialised.
    """
    ref = jnp.expand_dims(points, axis=-2)        # [..., Np, 1, D]
    qry = jnp.expand_dims(query_points, axis=-3)  # [..., 1, Nq, D]
    sq_dist = jnp.square(ref - qry).sum(axis=-1)  # [..., Np, Nq]
    nearest = jnp.argmin(sq_dist, axis=-2)
    # Read the winning distance back out of the table at the argmin position.
    nearest_sq_dist = jnp.take_along_axis(
        sq_dist, jnp.expand_dims(nearest, axis=-2), axis=-2
    )
    return nearest, jnp.squeeze(nearest_sq_dist, axis=-2)

@struct.dataclass
class KDTree:
    """Array-only container for a median-split KD-tree.

    Dimensions of points always: 3
    Each tree level splits along one axis, cycling through ``axis_order``
    (a permutation of x -> y -> z, i.e. 0 -> 1 -> 2).
    """
    # reference points [Npnts, 3]
    points: jnp.ndarray
    # split (median) values of all nodes, concatenated level by level;
    # build_tree stores 2**d values for level d
    median_vals: jnp.ndarray
    # indices into `points` of the members of each leaf; build_tree produces
    # shape [2] * depth + [leaf_size] (one binary axis per level)
    leaf_group_idx: jnp.ndarray
    # the 3 split axes; level d splits along axis_order[d % 3]
    axis_order: jnp.ndarray

    def visualize(self):
        """Show the points in an Open3D window, coloured by leaf membership.

        Every leaf gets a random colour. Negative leaf indices are skipped.
        A point listed in several leaves keeps the colour of the last one,
        and a point listed in none stays black.
        """
        import open3d as o3d

        leaf_size = self.leaf_group_idx.shape[-1]
        leaf_count = np.prod(self.leaf_group_idx.shape[:-1])
        # One random RGB triple per leaf.
        palette = np.random.rand(leaf_count, 3)

        # Move everything to NumPy and flatten the per-level axes of the
        # leaf table into a single leaf axis.
        cloud_xyz = np.array(self.points)
        leaves = np.array(self.leaf_group_idx).reshape(-1, leaf_size)

        per_point_rgb = np.zeros((cloud_xyz.shape[0], 3))
        for leaf_rgb, members in zip(palette, leaves):
            per_point_rgb[members[members >= 0]] = leaf_rgb

        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(cloud_xyz)
        cloud.colors = o3d.utility.Vector3dVector(per_point_rgb)
        o3d.visualization.draw_geometries([cloud])


def build_tree(axis_order_type, points, depth):
    """Build a median-split KD-tree with ``depth`` levels.

    Every level sorts each current group along that level's axis and cuts it
    at the median into a left and a right half. For an odd-sized group the
    median point is placed in BOTH halves so the halves have equal length
    and all arrays keep a regular shape.

    Args:
        axis_order_type: row index (0..5) into the table of the six
            permutations of the axes (0, 1, 2).
        points: reference points, ``[Npnts, 3]``.
        depth: number of split levels (used in ``range``, so a Python int).

    Returns:
        A ``KDTree`` holding the input points, the per-node median values
        concatenated level by level, the original point indices of every
        leaf (one binary axis per level, then the leaf members), and the
        selected axis permutation.
    """
    split_axes = jnp.array([[0,1,2],[2,0,1],[1,0,2],[0,2,1],[1,2,0],[2,1,0]])[axis_order_type]

    level_medians = []
    group_ids = jnp.arange(points.shape[-2])  # original index of each group member
    group_pts = points                        # coordinates of each group member
    for level in range(depth):
        axis = split_axes[level % 3]
        group_size = group_pts.shape[-2]
        coords = group_pts[..., axis]
        order = jnp.argsort(coords)
        ids_sorted = jnp.take_along_axis(group_ids, order, axis=-1)
        coords_sorted = jnp.take_along_axis(coords, order, axis=-1)

        half = group_size // 2
        if group_size % 2:
            # Odd size: the middle element is the median and is shared by
            # both halves (left takes one element more than `half`).
            split_value = coords_sorted[..., half]
            left_stop = half + 1
        else:
            # Even size: median is the mean of the two middle elements.
            split_value = (coords_sorted[..., half - 1] + coords_sorted[..., half]) / 2
            left_stop = half
        right_start = half

        left_ids = ids_sorted[..., :left_stop]
        right_ids = ids_sorted[..., right_start:]
        assert left_ids.shape[-1] == right_ids.shape[-1]
        # Positions (within the current group) of the members of each half,
        # used to gather the matching coordinates.
        child_order = jnp.stack(
            [order[..., :left_stop], order[..., right_start:]], axis=-2
        )

        # A new axis of size 2 (left/right) is inserted before the member axis.
        group_ids = jnp.stack([left_ids, right_ids], axis=-2)
        group_pts = jnp.take_along_axis(
            group_pts[..., None, :, :], child_order[..., None], axis=-2
        )
        level_medians.append(split_value.reshape(-1))

    # Debug view of the divided points; disabled by default.
    visualize = False
    if visualize:
        import open3d as o3d
        leaf_len = group_pts.shape[-2]
        leaf_clouds = group_pts.reshape(-1, leaf_len, 3)
        leaf_rgb = np.random.rand(leaf_clouds.shape[0], 3)
        # Repeat each leaf colour once per member of that leaf.
        point_rgb = np.repeat(leaf_rgb, leaf_len, axis=0)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(leaf_clouds.reshape(-1, 3))
        pcd.colors = o3d.utility.Vector3dVector(point_rgb)
        o3d.visualization.draw_geometries([pcd])

    all_medians = jnp.concatenate(level_medians, axis=-1)
    return KDTree(
        points=points,
        median_vals=all_medians,
        leaf_group_idx=group_ids,
        axis_order=split_axes,
    )

def left_right_traj_to_idx(left_right_traj):
    """Convert a root-to-node branch sequence into the node's index in its level.

    ``left_right_traj`` holds one flag per level, first level first; a falsy
    flag selects the first child and a truthy flag the second. The loop
    computes the heap-style index (children of node ``i`` are ``2*i + 1`` and
    ``2*i + 2``); subtracting the offset of the level, ``2**n - 1``, leaves
    the position within that level. An empty sequence gives 0.
    """
    heap_idx = 0
    for went_right in left_right_traj:
        heap_idx = 2 * heap_idx + went_right + 1
    level_offset = 2 ** len(left_right_traj) - 1
    return heap_idx - level_offset

def query(kdtree: "KDTree", points, depth):
    '''
    Find, for every query point, the closest reference point inside the
    leaf that the query point falls into.

    kdtree: for one set of kdtree (a single, un-batched tree with 1-D
        median_vals laid out level by level)
    points: query points, (Npnts, 3)
    depth: number of levels to descend; must equal the depth the tree was
        built with

    Returns (best_idx, min_val):
        best_idx: (Npnts,) index into kdtree.points of the selected point
        min_val:  (Npnts,) squared Euclidean distance to that point

    NOTE: only the query's own leaf is searched; neighbouring leaves are
    never visited, so the result can differ from the true nearest neighbour
    when that neighbour lies on the other side of a split.
    '''
    num_queries = points.shape[0]
    # Position of each query's current node within its tree level (root = 0).
    node_in_level = jnp.zeros((num_queries,), dtype=jnp.int32)
    for dep in range(depth):
        level_start = 2 ** dep - 1  # offset of this level inside median_vals
        axis = kdtree.axis_order[dep % 3]
        split_value = kdtree.median_vals[level_start + node_in_level]
        # Values equal to the median go to the second (right) child.
        go_right = points[..., axis] >= split_value
        node_in_level = 2 * node_in_level + go_right.astype(jnp.int32)

    # After the last level, node_in_level is the leaf number. Flatten the
    # per-level binary axes of the leaf table and gather each query's leaf
    # once, instead of replicating the whole table for every query.
    leaf_size = kdtree.leaf_group_idx.shape[-1]
    leaf_table = kdtree.leaf_group_idx.reshape(-1, leaf_size)
    leaf_group_idx = leaf_table[node_in_level]   # (Npnts, leaf_size)
    leaf_points = kdtree.points[leaf_group_idx]  # (Npnts, leaf_size, 3)

    # Exhaustive search restricted to the leaf members.
    sq_dists = jnp.sum((leaf_points - points[:, None, :]) ** 2, axis=-1)
    best_in_leaf = jnp.argmin(sq_dists, axis=-1)
    min_val = jnp.take_along_axis(sq_dists, best_in_leaf[..., None], axis=-1).squeeze(-1)

    # recover to original index
    best_idx = jnp.take_along_axis(leaf_group_idx, best_in_leaf[..., None], axis=-1).squeeze(-1)
    return best_idx, min_val

def batch_kdtree_query(kdtrees, query_points, depth):
    """Query a stack of KD-trees and keep, per query point, the best answer.

    ``query`` is mapped over the leading axis of ``kdtrees`` (the query
    points and depth are shared by all trees). For every query point the
    tree reporting the smallest squared distance wins.

    Returns:
        ``(best_idx, min_dists_sq)``: the winning tree's point index and
        squared distance for each query point.
    """
    query_every_tree = jax.vmap(query, in_axes=(0, None, None))
    cand_idx, cand_sq_dist = query_every_tree(kdtrees, query_points, depth)

    # Index of the winning tree per query, with a leading axis for the gather.
    winner = jnp.argmin(cand_sq_dist, axis=0)[None]
    best_idx = jnp.take_along_axis(cand_idx, winner, axis=0).squeeze(0)
    min_dists_sq = jnp.take_along_axis(cand_sq_dist, winner, axis=0).squeeze(0)
    return best_idx, min_dists_sq

def batch_nearest_neighbor(sorce_points, query_points, min_leaf_size):
    """Nearest-neighbour lookup of ``query_points`` among ``sorce_points``.

    The tree depth is derived from ``min_leaf_size``. At depth 1 the search
    is done exhaustively. Otherwise three KD-trees are built, one for each
    of the first three axis orderings, every tree is queried, and the
    closest candidate per query point is returned.

    Returns:
        ``(best_idx, min_dists_sq)``: index of the selected source point and
        its squared distance, per query point.
    """
    depth = depth_calculate(sorce_points, min_leaf_size)
    if depth == 1:
        return bruth_force_search(sorce_points, query_points)

    # One tree per axis-order type 0, 1, 2; points and depth are shared.
    axis_order_types = jnp.arange(3)
    build_every_tree = jax.vmap(build_tree, in_axes=(0, None, None))
    kdtrees = build_every_tree(axis_order_types, sorce_points, depth)
    return batch_kdtree_query(kdtrees, query_points, depth)


if __name__ == '__main__':
    # Example usage: run the KD-tree search on random data and compare it
    # against the exhaustive search.
    np.random.seed(0)

    points = jnp.array(np.random.rand(1000, 3))
    query_points = jnp.array(np.random.rand(4, 3))
    min_leaf_size = 10

    # min_leaf_size determines the tree depth (a Python int used for loop
    # counts and shapes), so it has to be a static argument under jit.
    nearest_neighbor = jax.jit(batch_nearest_neighbor, static_argnums=(2,))
    best_idx, min_dists_sq = nearest_neighbor(points, query_points, min_leaf_size)

    best_idx_bf, min_dists_sq_bf = bruth_force_search(points, query_points)

    print("Nearest neighbor indices:", best_idx, best_idx_bf)
    print("Minimum distances squared:", min_dists_sq, min_dists_sq_bf)

    # Queries whose KD-tree distance deviates from the brute-force result.
    invalid_mask = jnp.abs(min_dists_sq - min_dists_sq_bf) > 0.001
    print("Mismatch vs brute force (> 0.001):", invalid_mask)

    print(query_points[invalid_mask])