import numpy as np
from .strategy import Strategy
from .builder import STRATEGIES


@STRATEGIES.register_module()
class MarginSampling(Strategy):
	"""Query strategy that ranks unlabeled samples by their 'margin' score.

	Scores come from ``self.predict(..., 'margin', n_drop=...)`` and are
	returned highest-first. The highest-first order follows from negating
	them and then sorting ascending.
	"""

	def __init__(self, dataset, net, args, logger, timestamp, n_drop=1):
		"""Build the strategy.

		Args:
			dataset, net, args, logger, timestamp: forwarded unchanged to
				``Strategy.__init__``.
			n_drop: forwarded to ``self.predict`` as ``n_drop`` on every query.
				Presumably the number of stochastic (e.g. MC-dropout) passes;
				TODO confirm against ``Strategy.predict``.
		"""
		super().__init__(dataset, net, args, logger, timestamp)
		self.n_drop = n_drop

	def query(self, n):
		"""Return the positions of the ``n`` highest-scoring unlabeled samples.

		Args:
			n: number of samples to select. If it exceeds the pool size,
				every position is returned.

		Returns:
			Index tensor (from ``Tensor.sort``) into the score vector produced
			by ``self.predict``. Presumably aligned with the order of
			``self.get_ulb_list()``; verify against the caller.
		"""
		# NOTE(review): textbook margin sampling picks the *smallest*
		# top-1/top-2 gap. The highest-first ordering here is only correct
		# if predict(..., 'margin') returns a value where larger means
		# more uncertain. Confirm the sign convention in Strategy.predict.
		model = self.clf
		pool = self.get_ulb_list()
		scores = self.predict(model, pool, 'margin', n_drop=self.n_drop)
		# Negate (on CPU) and sort ascending, putting the largest raw scores first.
		ranking = (-scores.cpu()).sort()
		order = ranking[1]
		return order[:n]
