import torch
import heapq
from scipy.spatial import cKDTree

def build_kdtree(points):
    """Build a SciPy ``cKDTree`` over the rows of a torch tensor.

    Args:
        points: Tensor of point coordinates, expected shape ``(N, D)`` as
            ``cKDTree`` requires. It may live on any device and may require
            grad.

    Returns:
        A ``scipy.spatial.cKDTree`` indexing a CPU/NumPy copy of ``points``.
    """
    # detach() first: .numpy() raises on tensors that require grad, and the
    # KD-tree is not differentiable anyway.
    return cKDTree(points.detach().cpu().numpy())

def kdtree_batch_knn(points, k):
    """Return the ``k`` nearest *other* points for every point in ``points``.

    Args:
        points: Tensor of point coordinates, expected shape ``(N, D)`` as
            ``cKDTree`` requires. It may live on any device and may require
            grad.
        k: Number of neighbours per point (must be >= 0).

    Returns:
        Integer tensor of shape ``(N, k)`` on ``points.device``. Row ``i``
        holds the indices of the ``k`` nearest points to ``points[i]``,
        excluding ``i`` itself, in increasing distance order as returned by
        ``cKDTree.query``. When ``k >= N`` there are not enough points; the
        extra slots hold ``N``, which is scipy's marker for a missing neighbour.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Convert once: detach() avoids the requires_grad error from .numpy(),
    # and reusing the array avoids a second device->host copy.
    pts = points.detach().cpu().numpy()
    n = pts.shape[0]
    if k == 0 or n == 0:
        # scipy squeezes the output to 1-D when k=1, so handle k == 0 here.
        return torch.empty((n, k), dtype=torch.long, device=points.device)

    tree = cKDTree(pts)
    _, indices = tree.query(pts, k=k + 1, workers=-1)
    idx = torch.from_numpy(indices)

    # Remove each point's own index explicitly. With duplicate points, a
    # coincident neighbour can come before self, so column 0 is not reliably
    # self. If self is missing from a row (more than k exact duplicates),
    # drop that row's farthest result instead. Either way, exactly one entry
    # per row is removed.
    drop = idx == torch.arange(n).unsqueeze(1)
    drop[~drop.any(dim=1), -1] = True
    neighbors = idx[~drop].reshape(n, k)
    return neighbors.to(points.device)