import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from .strategy import Strategy
from copy import deepcopy
import random

class TwoStage(Strategy):
	"""Two-stage active-learning query strategy.

	Stage 1 computes a Mahalanobis-based score for every unlabeled sample
	(class-conditional Gaussians fitted on the labeled set's intermediate
	features) and keeps only samples whose score is above a threshold.
	Stage 2 ranks the kept samples by predictive entropy and returns the top n.
	"""

	def __init__(self, X, Y, idxs_lb, net, handler, args):
		super(TwoStage, self).__init__(X, Y, idxs_lb, net, handler, args)

	def query(self, n):
		"""Return the pool indices of up to ``n`` unlabeled samples to label next.

		A sample is eligible when its inverted Mahalanobis score (see ``maha``)
		exceeds the midpoint between the minimum and the mean score. Among
		eligible samples, those with the highest entropy come first. Ineligible
		samples get a combined score of 0, so they can still fill the remaining
		slots.
		"""
		idxs_unlabeled = np.arange(self.n_pool)[~self.idxs_lb]
		if idxs_unlabeled.size == 0:
			# Nothing left to query; min/max over an empty score array would raise.
			return idxs_unlabeled
		X = self.X[idxs_unlabeled]
		Y = self.Y[idxs_unlabeled]
		U = self.entropy(X, Y)
		maha_score = self.maha(X, Y)

		threshold = (np.average(maha_score) + np.min(maha_score)) / 2.0
		pre_select = torch.from_numpy((maha_score > threshold).astype(np.float32)).to(U.device)
		combine = torch.mul(pre_select, U)
		order = combine.sort(descending=True)[1][:n].cpu().numpy()
		return idxs_unlabeled[order]

	def entropy(self, X, Y):
		"""Predictive entropy of each sample (larger means more uncertain).

		Returns a 1-D tensor with one entropy per row of ``X``.
		"""
		probs = self.predict_prob(X, Y)
		# By convention 0 * log(0) = 0. torch.log(0) is -inf, and the plain
		# product would be NaN, which would poison the ranking in ``query``.
		plogp = torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs))
		return -plogp.sum(1)

	def maha(self, X, Y):
		"""Inverted layer-averaged Mahalanobis confidence for each sample of ``X``.

		Class means and a shared precision per layer are estimated from the
		labeled pool. Labeled samples with a negative label are excluded.
		For each layer, ``get_Mahalanobis_score`` yields the max-over-classes
		Gaussian score (higher = closer to some class mean). The scores are
		averaged over layers and returned as ``max(avg) - avg``. The sample
		closest to the class means therefore gets 0 and farther samples get
		larger values.

		Returns:
			np.float32 array with one value per row of ``X``.
		"""
		# Labeled training data, dropping samples whose label is negative.
		idxs_train = np.arange(self.n_pool)[self.idxs_lb]
		X_train_full = self.X[idxs_train]
		Y_train_full = self.Y[idxs_train]
		keep = ~(Y_train_full < 0)
		Y_train = Y_train_full[keep]
		if isinstance(X_train_full, np.ndarray):
			X_train = X_train_full[keep.cpu().numpy().astype(bool)]
		else:
			X_train = X_train_full[keep]

		train_loader = DataLoader(self.handler(X_train, Y_train, transform=self.args['transform_train']),
							**self.args['loader_tr_args'])

		model = self.get_model()
		# Eval mode before the probe so the random input cannot update
		# batch-norm running statistics. sample_estimator switches to eval anyway.
		model.eval()
		# Probe the model once to learn the width of every intermediate feature map.
		# Assumes samples are stored as (H, W, C) and the model takes (N, C, H, W) — TODO confirm.
		sample_shape = self.X[0].shape
		probe = torch.rand(2, sample_shape[2], sample_shape[0], sample_shape[1]).to(self.device)
		with torch.no_grad():
			probe_features = model.feature_list(probe)[1]
		feature_list = np.array([out.size(1) for out in probe_features], dtype=np.float64)

		num_class = self.args['num_class']
		sample_mean, sample_cov = self.sample_estimator(model, num_class, feature_list, train_loader)

		test_loader = DataLoader(self.handler(X, Y, transform=self.args['transform']),
							shuffle=False, **self.args['loader_te_args'])
		layer_scores = [
			self.get_Mahalanobis_score(model, test_loader, num_class, sample_mean, sample_cov, i)
			for i in range(len(feature_list))
		]
		# (num_samples, num_layers); higher = closer to a class mean.
		maha_score = np.asarray(np.stack(layer_scores, axis=1), dtype=np.float32)
		maha_avg_score = np.mean(maha_score, axis=1)
		return np.max(maha_avg_score) - maha_avg_score

	@staticmethod
	def _pooled_features(features):
		"""Average a (B, C, ...) feature map over all trailing dims, giving (B, C)."""
		return torch.mean(features.reshape(features.size(0), features.size(1), -1), 2)

	@staticmethod
	def _class_gaussian_scores(features, class_means, precision):
		"""Gaussian log-score (up to a constant) of each sample under each class.

		Args:
			features: (B, D) pooled features.
			class_means: (C, D) class means. Rows containing NaN mark classes with
				no labeled samples.
			precision: (D, D) shared precision matrix.

		Returns:
			(B, C) tensor of ``-0.5 * (f - mu_c)^T P (f - mu_c)``. Absent classes
			are set to -inf so a max over classes never selects them.
		"""
		diff = features.unsqueeze(1) - class_means.unsqueeze(0)  # (B, C, D)
		# Row-wise quadratic form; avoids materializing a (B, B) matrix per class.
		scores = -0.5 * (torch.matmul(diff, precision) * diff).sum(-1)
		absent = torch.isnan(class_means).any(dim=1)
		return scores.masked_fill(absent.unsqueeze(0), float('-inf'))

	def get_Mahalanobis_score(self, model, test_loader, num_classes, sample_mean, precision, layer_index,
							magnitude = 0.01, std = (0.2023, 0.1994, 0.2010)):
		'''
		Compute the Mahalanobis confidence score of every sample in ``test_loader``
		using the features of layer ``layer_index``.

		For each batch, the closest class is determined. The input then receives
		a small step (``magnitude`` times the gradient sign, divided per channel
		by ``std``) that decreases the Mahalanobis distance to that class. The
		returned score is the max-over-classes Gaussian score of the perturbed input.

		Args:
			num_classes: number of leading rows of ``sample_mean[layer_index]`` to use.
			std: per-channel divisors for the gradient sign, applied to the first
				``len(std)`` channels. The default matches the previously hard-coded values.

		return: np.float32 array with the Mahalanobis score of each sample, in loader order
		'''
		model.eval()
		class_means = sample_mean[layer_index][:num_classes]
		layer_precision = precision[layer_index]
		n_std = len(std)
		scores = []

		for data, _target, _idx in test_loader:
			data = data.to(self.device).requires_grad_(True)
			out_features = self._pooled_features(model.intermediate_forward(data, layer_index))

			# Closest class per sample (no gradient needed for the choice itself).
			sample_pred = self._class_gaussian_scores(out_features.detach(), class_means, layer_precision).max(1)[1]

			# Input processing: gradient of the distance to the chosen class w.r.t. the input.
			zero_f = out_features - class_means.index_select(0, sample_pred)
			pure_gau = -0.5 * (torch.matmul(zero_f, layer_precision) * zero_f).sum(1)
			loss = torch.mean(-pure_gau)
			# autograd.grad (rather than backward) avoids accumulating into the model's parameter .grad.
			data_grad = torch.autograd.grad(loss, data)[0]

			gradient = (torch.ge(data_grad, 0).float() - 0.5) * 2
			std_t = torch.tensor(std, dtype=gradient.dtype, device=gradient.device)
			std_t = std_t.view(1, n_std, *([1] * (gradient.dim() - 2)))
			gradient[:, :n_std] = gradient[:, :n_std] / std_t
			temp_inputs = data.detach() - magnitude * gradient

			with torch.no_grad():
				noise_features = self._pooled_features(model.intermediate_forward(temp_inputs, layer_index))
				noise_score = self._class_gaussian_scores(noise_features, class_means, layer_precision).max(1)[0]
			scores.append(noise_score.cpu().numpy())

		if not scores:
			return np.empty(0, dtype=np.float32)
		return np.concatenate(scores)

	def sample_estimator(self, model, num_classes, feature_list, train_loader):
		"""
		compute sample mean and precision (inverse of covariance)

		For every layer returned by ``model.feature_list``, the pooled features
		of all training samples are collected and the per-class means computed.
		An ``EmpiricalCovariance`` is then fitted on the class-centred features
		pooled over all classes (a shared covariance).

		Args:
			num_classes: labels must lie in ``[0, num_classes)``.
			feature_list: feature width of each layer (one entry per layer).
			train_loader: yields ``(data, target, idx)`` batches.

		return: sample_class_mean: list (per layer) of (num_classes, width) tensors;
					rows of classes without labeled samples are NaN
				precision: list (per layer) of (width, width) float tensors
		raises: ValueError if the loader yields no samples or a label is out of range
		"""
		import sklearn.covariance

		model.eval()
		group_lasso = sklearn.covariance.EmpiricalCovariance(assume_centered=False)
		num_output = len(feature_list)

		# Collect pooled features once per batch and concatenate at the end,
		# instead of growing a tensor sample by sample (which is quadratic).
		layer_chunks = [[] for _ in range(num_output)]
		label_chunks = []
		with torch.no_grad():
			for data, target, _idx in train_loader:
				_, out_features = model.feature_list(data.to(self.device))
				for k in range(num_output):
					layer_chunks[k].append(self._pooled_features(out_features[k]))
				label_chunks.append(target.reshape(-1).long().cpu())

		labels = torch.cat(label_chunks) if label_chunks else torch.empty(0, dtype=torch.long)
		if labels.numel() == 0:
			raise ValueError('sample_estimator: train_loader yielded no samples')
		if bool((labels < 0).any()) or bool((labels >= num_classes).any()):
			raise ValueError('sample_estimator: labels must lie in [0, num_classes)')
		counts = torch.bincount(labels, minlength=num_classes)
		present = [c for c in range(num_classes) if int(counts[c]) > 0]
		labels = labels.to(self.device)

		sample_class_mean = []
		precision = []
		for k, num_feature in enumerate(feature_list):
			feats = torch.cat(layer_chunks[k])
			# NaN marks classes with no labeled samples; _class_gaussian_scores masks them out.
			class_mean = torch.full((num_classes, int(num_feature)), float('nan'), device=self.device)
			centred = []
			for c in present:
				members = feats[labels == c]
				class_mean[c] = torch.mean(members, 0)
				centred.append(members - class_mean[c])
			sample_class_mean.append(class_mean)

			# Shared covariance of the class-centred features; its inverse is the precision.
			group_lasso.fit(torch.cat(centred, 0).cpu().numpy())
			precision.append(torch.from_numpy(group_lasso.precision_).float().to(self.device))

		return sample_class_mean, precision
