import torch

def farthest_point_sampling(data: torch.Tensor, num_samples: int):
    """Pick ``num_samples`` points from ``data`` using k-means++ seeding.

    Despite the function name, the selection is randomised: the first point
    is drawn uniformly, and every further point is drawn with probability
    proportional to its squared Euclidean distance to the nearest point
    already chosen [1]_.

    Parameters
    ----------
    data : torch.Tensor
        Expect a rank 1 or 2 tensor. Rank 1 is assumed to describe 1-D
        data, rank 2 multidimensional data, in which case one
        row is one observation.
    num_samples : int
        Number of samples to generate.

    Returns
    -------
    init : torch.Tensor
        A 'num_samples' by 'dims' tensor containing the initial centroids,
        on the same device as ``data``. Floating-point input keeps its
        dtype; any other input is converted to the default float dtype.

    References
    ----------
    .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
       careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
       on Discrete Algorithms, 2007.
    .. [2] scipy/cluster/vq.py: _kpp
    """
    # Rank-1 input is N observations of a single feature -> (N, 1).
    points = data.reshape(-1, 1) if data.dim() == 1 else data
    # Distances and probabilities need a floating dtype.
    if not points.is_floating_point():
        points = points.to(torch.get_default_dtype())

    num_points, dims = points.shape[0], points.shape[1]
    init = points.new_zeros((num_samples, dims))

    min_d2 = None   # squared distance of every point to its nearest centre
    center = None   # most recently chosen centre, shape (1, dims)
    for i in range(num_samples):
        if i == 0:
            idx = torch.randint(0, num_points, (1,))
        else:
            # Only the newest centre can lower the running minimum, so there
            # is no need to revisit the earlier centres.
            d2 = ((points - center) ** 2).sum(dim=1)
            min_d2 = d2 if min_d2 is None else torch.minimum(min_d2, d2)

            # If every point coincides with a chosen centre the weights are
            # all zero; fall back to a uniform draw instead of dividing 0/0.
            weights = torch.where(
                min_d2.sum() > 0, min_d2, torch.ones_like(min_d2))
            cdf = torch.cumsum(weights, dim=0)

            # Inverse-CDF sampling on the unnormalised weights.
            u = torch.rand(1).to(device=cdf.device, dtype=cdf.dtype)
            # right=True picks the first entry strictly above the threshold,
            # so zero-weight points are skipped; the clamp guards against
            # rounding pushing the threshold up to cdf[-1].
            idx = torch.searchsorted(cdf, u * cdf[-1], right=True)
            idx = idx.clamp(max=num_points - 1)

        center = points[idx]
        init[i] = center[0]
    return init


# Registry mapping an initialisation-method name to the function implementing it.
init_methods = dict(fps=farthest_point_sampling)
