import torch
import numpy as np
from .strategy import Strategy
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

class LossPredictionLoss(Strategy):
	"""Active-learning strategy that queries the samples with the highest predicted loss.

	An auxiliary network, run on the features returned by the classifier's
	forward pass, produces one predicted-loss score per sample. The unlabeled
	samples with the largest predicted loss are proposed for labeling.
	"""

	def __init__(self, X, Y, idxs_lb, net, handler, args, net_lpl):
		# All setup is delegated to Strategy. net_lpl is the loss-prediction
		# network; it is presumably exposed there as self.clf_lpl (used by
		# unc_lpl below) -- NOTE(review): confirm in Strategy.
		super().__init__(X, Y, idxs_lb, net, handler, args, net_lpl)

	def query(self, n):
		"""Return pool indices of the ``n`` unlabeled samples with the highest predicted loss.

		Args:
			n: number of samples to query. If fewer than ``n`` samples are
				unlabeled, all of them are returned.

		Returns:
			1-D numpy array of pool indices, ordered from highest to lowest
			predicted loss (always an array, even when ``n == 1``).
		"""
		idxs_unlabeled = np.arange(self.n_pool)[~self.idxs_lb]
		scores = self.unc_lpl(self.X[idxs_unlabeled], self.Y[idxs_unlabeled])
		# Convert the ordering to numpy before indexing. NumPy treats a
		# one-element torch tensor as a scalar index (via __index__), so
		# indexing with the raw tensor would return a bare integer when n == 1.
		order = torch.argsort(scores, descending=True)[:n].numpy()
		return idxs_unlabeled[order]

	def unc_lpl(self, X, Y):
		"""Score every sample in ``X`` with the loss-prediction network.

		Args:
			X: samples to score, in the format accepted by ``self.handler``.
			Y: labels matching ``X``; they are only used to build the dataset
				and do not influence the scores.

		Returns:
			1-D CPU tensor with one predicted-loss value per sample, in the same
			order as ``X`` (the loader is built with ``shuffle=False``). An empty
			tensor is returned when ``X`` is empty.
		"""
		loader_te = DataLoader(self.handler(X, Y, transform=self.args['transform']),
							shuffle=False, **self.args['loader_te_args'])
		# Switch both networks to eval mode for inference; the previous mode is
		# not restored here.
		self.clf.eval()
		self.clf_lpl.eval()
		batch_scores = []
		with torch.no_grad():
			for x, _, _ in loader_te:
				# Labels and indices are not needed for scoring, so only x is
				# moved to the device.
				x = x.to(self.device)
				_, feature = self.clf(x)
				pred_loss = self.clf_lpl(feature)
				# One score per sample; view() fails loudly if the network
				# returns more than one value per sample.
				batch_scores.append(pred_loss.view(pred_loss.size(0)))
		if not batch_scores:
			return torch.tensor([])
		# Concatenate once instead of re-allocating the result on every batch.
		return torch.cat(batch_scores, 0).cpu()