import numba as nb
import numpy as np
from scipy.spatial import distance


@nb.njit(fastmath=True)
def euclidian_distances(vec):
    """Compute all pairwise Euclidean distances between the rows of ``vec``.

    Only columns 0 and 1 of ``vec`` enter the distance; any further columns
    are ignored.

    Parameters
    ----------
    vec : 2-D array, shape (n, d) with d >= 2
        One point per row.

    Returns
    -------
    1-D float64 array of length n * (n - 1) // 2
        Condensed distance vector: the distance between rows ``i < j`` is
        stored in row-major order of ``(i, j)``, i.e. ``(0, 1), (0, 2), ...,
        (0, n-1), (1, 2), ...`` -- the same layout as
        ``scipy.spatial.distance.pdist``.

    Raises
    ------
    ValueError
        If ``vec`` has fewer than two columns.
    """
    if vec.shape[1] < 2:
        # numba does not bounds-check array accesses by default, so reading
        # column 1 of a single-column array would silently read garbage.
        raise ValueError("vec must have at least two columns")

    n_points = vec.shape[0]
    # Always float64: with an integer ``vec`` the sqrt result would otherwise
    # be truncated when stored into an integer-typed output array.
    res = np.empty(n_points * (n_points - 1) // 2, dtype=np.float64)

    # NOTE: fastmath=True lets the compiler assume finite inputs; results for
    # NaN/inf coordinates are not guaranteed.
    k = 0
    for i in range(n_points):
        for j in range(i + 1, n_points):
            dx = vec[i, 0] - vec[j, 0]
            dy = vec[i, 1] - vec[j, 1]
            res[k] = np.sqrt(dx * dx + dy * dy)
            k += 1
    return res


def filter_close_points(points, min_dist):
    """Filter out points that are too close to another point of the batch.

    Point ``i`` is kept only if every *later* point ``j > i`` lies at least
    ``min_dist`` away from it, with distances taken from
    ``euclidian_distances`` (columns 0 and 1 only). Consequently:

    * no two kept points are closer than ``min_dist``;
    * the last point is always kept;
    * the selection is greedy and depends on input order, so it is not
      necessarily the largest well-separated subset.

    Important note: time and memory are O(n^2) in the number of points, since
    all n * (n - 1) / 2 pairwise distances are materialized. It is therefore
    much cheaper to call this function on each thread's points separately and
    then filter the aggregated survivors once more, rather than gathering
    every point first and filtering them all at once.

    Parameters
    ----------
    points : array-like, shape (n, d) with d >= 2
        One point per row.
    min_dist : float
        Minimum allowed distance between two kept points.

    Returns
    -------
    numpy.ndarray, shape (k, d)
        The kept rows of ``points``, in their original order and dtype. An
        empty ``(0, d)`` input yields an empty ``(0, d)`` result.
    """
    points = np.asarray(points)
    n_points = len(points)
    distances = euclidian_distances(points)

    keep = np.empty(n_points, dtype=bool)
    # In the condensed vector, the distances from point i to points
    # i+1 .. n-1 form a contiguous run of n - 1 - i entries that starts right
    # after the run belonging to point i - 1.
    start = 0
    for i in range(n_points):
        stop = start + (n_points - 1 - i)
        keep[i] = not np.any(distances[start:stop] < min_dist)
        start = stop
    # Boolean indexing keeps the (k, d) shape even when k == 0.
    return points[keep]
