import numpy as np
from .strategy import Strategy
import faiss
import random
class KMeansSamplingGPUShuffle(Strategy):
	"""Active-learning strategy: cluster the unlabeled pool's embeddings with
	faiss k-means (on GPU) and query the point closest to each centroid.

	The unlabeled indices are shuffled before embedding/clustering. Presumably
	this is meant to randomize which points faiss samples for initialization
	(faiss's own k-means seed is fixed by default) — TODO confirm intent.
	"""

	def __init__(self, X, Y, idxs_lb, net, handler, args):
		super(KMeansSamplingGPUShuffle, self).__init__(X, Y, idxs_lb, net, handler, args)

	def query(self, n):
		"""Select up to ``n`` unlabeled pool indices, one per k-means cluster.

		Args:
			n: number of points to query (used as the number of clusters).

		Returns:
			1-D array of pool indices (positions into ``self.X``), unique,
			one representative per non-empty cluster. If ``n`` is at least the
			number of unlabeled points, all unlabeled indices are returned;
			if ``n <= 0`` an empty array is returned.
		"""
		# self.idxs_lb is a boolean mask over the pool; keep the unlabeled positions.
		idxs_unlabeled = np.arange(self.n_pool)[~self.idxs_lb]

		if n <= 0:
			return idxs_unlabeled[:0]
		if n >= len(idxs_unlabeled):
			# k-means cannot train more centroids than there are points; every
			# unlabeled point would be its own cluster anyway.
			return idxs_unlabeled

		random.shuffle(idxs_unlabeled)
		embedding = self.get_embedding(self.X[idxs_unlabeled], self.Y[idxs_unlabeled]).numpy()

		cluster_learner = FaissKmeans(n_clusters=n, gpu=True)
		cluster_learner.fit(embedding)

		# Nearest centroid (and distance to it) for every embedded point.
		dist, assign = cluster_learner.predict(embedding)
		dist = dist[:, 0]
		assign = assign[:, 0]

		# Sort by cluster id, then by distance: the first entry of each
		# cluster's run is that cluster's member closest to its centroid.
		order = np.lexsort((dist, assign))
		sorted_assign = assign[order]
		is_first = np.ones(len(order), dtype=bool)
		is_first[1:] = sorted_assign[1:] != sorted_assign[:-1]
		q_idxs = order[is_first]

		# q_idxs are positions within the shuffled unlabeled array; map back to pool indices.
		return idxs_unlabeled[q_idxs]


class FaissKmeans:
	"""Small sklearn-style wrapper around ``faiss.Kmeans``.

	Args:
		n_clusters: number of centroids (faiss ``k``).
		gpu: forwarded unchanged to ``faiss.Kmeans(gpu=...)``.
		n_init: number of clustering restarts (faiss ``nredo``).
		max_iter: iterations per restart (faiss ``niter``).

	After ``fit``, ``cluster_centers_`` holds the trained centroids and
	``inertia_`` the last entry of faiss's per-iteration objective array.
	"""

	def __init__(self, n_clusters=8, gpu=True, n_init=10, max_iter=300):
		self.n_clusters = n_clusters
		self.n_init = n_init
		self.max_iter = max_iter
		self.kmeans = None  # faiss.Kmeans instance, set by fit()
		self.cluster_centers_ = None
		self.inertia_ = None
		self.gpu = gpu

	def fit(self, X):
		"""Train k-means on the rows of the 2-D array ``X``.

		Returns:
			self, so calls can be chained (sklearn convention).
		"""
		# faiss requires C-contiguous float32; this avoids a copy when X already is.
		X = np.ascontiguousarray(X, dtype=np.float32)
		self.kmeans = faiss.Kmeans(
			d=X.shape[1],
			k=self.n_clusters,
			niter=self.max_iter,
			nredo=self.n_init,
			gpu=self.gpu,
		)
		self.kmeans.train(X)
		self.cluster_centers_ = self.kmeans.centroids
		self.inertia_ = self.kmeans.obj[-1]
		return self

	def predict(self, X):
		"""Find the nearest centroid for each row of ``X``.

		Returns:
			``(D, I)`` as produced by the trained faiss index's ``search`` with
			one neighbour: both have shape ``(len(X), 1)``; ``I`` holds centroid
			ids and ``D`` the corresponding distances.

		Raises:
			AttributeError: if called before ``fit``.
		"""
		if self.kmeans is None:
			raise AttributeError("FaissKmeans is not fitted yet; call fit() first")
		X = np.ascontiguousarray(X, dtype=np.float32)
		D, I = self.kmeans.index.search(X, 1)
		return D, I
