import numpy as np
from .strategy import Strategy
from sklearn.cluster import KMeans

from torch.utils.data import Subset

class KMeansSampling(Strategy):
	"""Query strategy that picks one representative sample per k-means cluster.

	The unlabeled pool is embedded via ``self.predict(..., return_embedding=True)``,
	clustered with k-means into as many clusters as the query budget, and for
	each cluster the unlabeled sample closest to its centroid is queried.
	"""

	def __init__(self, train_dataset, train_raw_dataset, test_dataset, idxs_lb, net, args):
		super(KMeansSampling, self).__init__(train_dataset, train_raw_dataset, test_dataset, idxs_lb, net, args)

	def query(self, n):
		"""Select up to ``n`` unlabeled pool indices, one nearest-to-centroid point per cluster.

		Args:
			n: Query budget; also used as the number of k-means clusters.

		Returns:
			1-D numpy array of indices into the full training pool, ordered by
			cluster id. Normally it holds ``n`` indices. It holds fewer when
			``n <= 0`` (empty result), when the pool has at most ``n`` unlabeled
			samples (all of them are returned), or when k-means leaves some
			clusters empty (e.g. duplicate embeddings).
		"""
		idxs_unlabeled = np.arange(self.n_pool)[~self.idxs_lb]

		# KMeans rejects n_clusters < 1, and there is nothing to cluster anyway.
		if n <= 0 or len(idxs_unlabeled) == 0:
			return idxs_unlabeled[:0]
		# KMeans also rejects n_clusters > n_samples; a budget that covers the
		# whole unlabeled pool simply queries all of it (no embedding needed).
		if n >= len(idxs_unlabeled):
			return idxs_unlabeled

		embedding = self.predict(Subset(self.train_raw_dataset, idxs_unlabeled), return_prob=False, return_embedding=True)
		embedding = embedding.numpy()

		cluster_learner = KMeans(n_clusters=n)
		cluster_idxs = cluster_learner.fit_predict(embedding)

		# Squared Euclidean distance from each sample to its own cluster centroid.
		centers = cluster_learner.cluster_centers_[cluster_idxs]
		dis = ((embedding - centers) ** 2).sum(axis=1)

		# Sort by (cluster id, distance). lexsort is stable, so distance ties keep
		# the lowest sample index, which matches argmin. The first entry of each
		# cluster's run is then its closest member. np.unique reports only clusters
		# that actually have members, so an empty cluster is skipped instead of
		# raising on an empty argmin.
		order = np.lexsort((dis, cluster_idxs))
		_, first = np.unique(cluster_idxs[order], return_index=True)
		q_idxs = order[first]

		return idxs_unlabeled[q_idxs]
