from .strategy import Strategy
from scipy import stats
import numpy as np
from sklearn.metrics import pairwise_distances
from .builder import STRATEGIES




def init_centers(X, K):
    ind = np.argmax([np.linalg.norm(s, 2) for s in X])
    mu = [X[ind]]
    indsAll = [ind]
    centInds = [0.] * len(X)
    cent = 0
    print('
    while len(mu) < K:
        if len(mu) == 1:
            D2 = pairwise_distances(X, mu).ravel().astype(float)
        else:
            newD = pairwise_distances(X, [mu[-1]]).ravel().astype(float)
            for i in range(len(X)):
                if D2[i] > newD[i]:
                    centInds[i] = cent
                    D2[i] = newD[i]
        print(str(len(mu)) + '\t' + str(sum(D2)), flush=True)
        
        D2 = D2.ravel().astype(float)
        Ddist = (D2 ** 2) / sum(D2 ** 2)
        customDist = stats.rv_discrete(name='custm', values=(np.arange(len(D2)), Ddist))
        ind = customDist.rvs(size=1)[0]
        while ind in indsAll: ind = customDist.rvs(size=1)[0]
        mu.append(X[ind])
        indsAll.append(ind)
        cent += 1
    return indsAll


@STRATEGIES.register_module()
class BadgeSampling(Strategy):
    """Query strategy that seeds k-means++-style centers over gradient embeddings.

    Registered in ``STRATEGIES``; all construction arguments are forwarded
    unchanged to ``Strategy``.
    """

    def __init__(self, dataset, net, args, logger, timestamp):
        super().__init__(dataset, net, args, logger, timestamp)

    def query(self, n):
        """Pick ``n`` samples from the unlabeled pool.

        Gradient embeddings (``embed_type='grad'``) of the unlabeled split are
        fetched from ``self.clf`` and handed to ``init_centers``.

        Returns:
            The indices produced by ``init_centers`` — positions within the
            list returned by ``get_ulb_list()``. NOTE(review): confirm the
            caller maps these back to dataset indices if it needs them.
        """
        unlabeled = self.get_ulb_list()
        # presumably a torch tensor, given the .cpu().numpy() conversion
        grad_embeddings = self.get_embedding(self.clf, split=unlabeled, embed_type='grad')
        return init_centers(grad_embeddings.cpu().numpy(), n)
