import numpy as np
import torch


class KCentersGreedyAcquisition:
    """Stateful k-center greedy selector driven by a precomputed kernel matrix.

    Pairwise distances are derived from the kernel as
    ``d(i, j) = sqrt(k_ii - 2 * k_ij + k_jj)``. Each call to :meth:`next`
    appends one more selected index (via ``k_centers_greedy_acquisition``)
    and returns it.
    """

    def __init__(self, k_vv, sample_size):
        """Precompute the distance matrix from the kernel.

        Args:
            k_vv: Square torch tensor of kernel values; presumably a symmetric
                PSD Gram matrix of shape (n, n) — TODO confirm with callers.
            sample_size: Maximum number of unselected candidates scored per
                call to :meth:`next`.
        """
        square_norms = k_vv.diag().unsqueeze(1)
        sq_dist = square_norms - 2 * k_vv + square_norms.t()
        # Rounding error can make tiny squared distances (e.g. between
        # near-duplicate points) slightly negative; the square root would turn
        # those into NaN and corrupt the min/argmax selection. Clamp first.
        self._dist_mat = sq_dist.clamp(min=0) ** 0.5
        self._sample_size = sample_size
        self._xi = []

    def next(self):
        """Select one more index, record it, and return it."""
        self._xi = k_centers_greedy_acquisition(self._dist_mat, self._xi, self._sample_size)
        return self._xi[-1]

    @property
    def xi(self):
        """List of indices selected so far, in selection order."""
        return self._xi


def k_centers_greedy_acquisition(dist_mat, xi, sample_size):
    """Return ``xi`` extended by one index chosen with the k-center greedy rule.

    With an empty selection, a uniformly random index is returned (as a
    one-element list). Otherwise, among the not-yet-selected indices (randomly
    subsampled to at most ``sample_size`` of them), pick the one whose distance
    to its nearest already-selected index is largest.

    Args:
        dist_mat: Square torch tensor of pairwise distances, shape (n, n).
        xi: List of already-selected indices. It is not mutated.
        sample_size: Maximum number of candidates to score. When at least as
            many candidates remain unselected, all of them are scored.

    Returns:
        A new list: ``xi`` followed by the newly selected index.

    Raises:
        AssertionError: If every index has already been selected.
    """
    n = dist_mat.shape[0]

    if len(xi) == 0:
        return torch.randint(n, [1]).tolist()

    selected = set(xi)
    # Ascending index order keeps subsampling and argmax tie-breaking
    # deterministic, independent of set iteration order.
    pool = [i for i in range(n) if i not in selected]
    assert len(pool) > 0, "all points have already been selected"
    # Compare against the remaining pool, not n: once points are selected the
    # pool can be smaller than sample_size, and sampling without replacement
    # more items than exist raises ValueError.
    if sample_size < len(pool):
        pool = np.random.choice(pool, sample_size, replace=False).tolist()

    # For each candidate, the distance to its closest already-selected index.
    score, _ = dist_mat[pool, :][:, xi].min(dim=1)
    return xi + [pool[int(torch.argmax(score))]]
