import numpy as np
import torch
from .strategy import Strategy
from .builder import STRATEGIES


@STRATEGIES.register_module()
class BALDDropout(Strategy):
    """Active-learning query strategy based on BALD with dropout sampling.

    Samples are scored from several stochastic (dropout) predictions. The
    score is the mean per-pass entropy minus the entropy of the mean
    prediction, i.e. the negative of the BALD mutual-information estimate.

    Args:
        dataset, net, args, logger, timestamp: Forwarded unchanged to the
            ``Strategy`` base-class constructor.
        n_drop: Number of dropout passes requested from ``self.predict``.
    """

    def __init__(self, dataset, net, args, logger, timestamp, n_drop=5):
        super().__init__(dataset, net, args, logger, timestamp)
        self.n_drop = n_drop

    def query(self, n):
        """Return the positions of the ``n`` lowest-scoring unlabeled samples.

        The score is ``mean-of-entropies - entropy-of-mean``, so the lowest
        scores correspond to the largest mutual information.

        Args:
            n: Maximum number of positions to return.

        Returns:
            A 1-D index tensor (from ``Tensor.sort``) with at most ``n``
            entries, ordered by ascending score. The indices refer to
            positions along the sample axis of the ``self.predict`` output.
        """
        # NOTE(review): the reductions below assume `probs` is laid out as
        # (dropout pass, sample, class) -- confirm against Strategy.predict.
        probs = self.predict(
            self.clf,
            split=self.get_ulb_list(),
            metric='prob',
            n_drop=self.n_drop,
            dropout_split=True,
        )
        mean_probs = probs.mean(0)

        # torch.xlogy(p, p) equals p * log(p) but is defined as 0 at p == 0.
        # The naive `p * torch.log(p)` yields 0 * -inf = NaN there, which
        # would poison the scores and make the ranking below meaningless.
        entropy_of_mean = -torch.xlogy(mean_probs, mean_probs).sum(1)
        mean_of_entropies = -torch.xlogy(probs, probs).sum(2).mean(0)

        scores = (mean_of_entropies - entropy_of_mean).cpu()
        # Ascending sort: most negative score == highest mutual information.
        return scores.sort()[1][:n]
