import numpy as np
import torch
from .strategy import Strategy

from torch.utils.data import Subset

class EntropySamplingDropout(Strategy):
	"""Query strategy that selects the unlabeled samples with the highest
	predictive entropy.

	Class probabilities come from ``self.predict_prob_dropout`` (defined on
	``Strategy``, not shown here). It is presumably an MC-dropout estimate
	averaged over ``n_drop`` stochastic forward passes; confirm against
	``Strategy``.
	"""

	def __init__(self, train_dataset, train_raw_dataset, test_dataset, idxs_lb, net, args, n_drop=10):
		"""Initialise the base strategy and remember the dropout pass count.

		Args:
			train_dataset, train_raw_dataset, test_dataset, idxs_lb, net, args:
				Forwarded unchanged to ``Strategy.__init__``.
			n_drop: Value passed to ``predict_prob_dropout`` at query time
				(presumably the number of dropout forward passes).
		"""
		super(EntropySamplingDropout, self).__init__(train_dataset, train_raw_dataset, test_dataset, idxs_lb, net, args)
		self.n_drop = n_drop

	def query(self, n):
		"""Return pool indices of up to ``n`` unlabeled samples, most uncertain first.

		Args:
			n: Number of samples to select. Standard slicing applies, so if
				``n`` exceeds the number of unlabeled samples, every unlabeled
				index is returned.

		Returns:
			A numpy array of indices into the training pool, ordered from
			highest to lowest entropy. The array is empty when nothing is
			left unlabeled.
		"""
		# idxs_lb is used as a mask over the pool; ``~`` assumes it is boolean.
		idxs_unlabeled = np.arange(self.n_pool)[~self.idxs_lb]
		if len(idxs_unlabeled) == 0:
			# Nothing left to query, so skip running inference on an empty subset.
			return idxs_unlabeled

		probs = self.predict_prob_dropout(Subset(self.train_raw_dataset, idxs_unlabeled), self.n_drop)
		# p * log(p) evaluates to NaN when p == 0 (0 * -inf). Entropy uses the
		# convention 0 * log 0 = 0. Without it, any sample with a
		# zero-probability class gets a NaN score, which corrupts the ranking.
		plogp = torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs))
		# U is the negative entropy, so an ascending sort puts the most
		# uncertain samples first.
		U = plogp.sum(1)
		order = U.sort()[1][:n]
		# Convert explicitly so numpy indexing works regardless of tensor device.
		return idxs_unlabeled[order.cpu().numpy()]
