#!/usr/bin/env python3
import torch
import uimnet
from uimnet.measures.measure import Measure
from uimnet import utils


class Entropy(Measure):
  """Predictive entropy of the algorithm's softmax output.

  For every row of collected logits this computes
  ``H(p) = -sum_k p_k * log(p_k)`` with ``p = softmax(logits, dim=1)``.
  The natural logarithm is used, so the result is in nats. Entries with
  ``p_k == 0`` contribute exactly 0, following the ``0 * log(0) = 0``
  convention.
  """

  def __init__(self, algorithm):
    super().__init__(algorithm=algorithm)

  def forward(self, loader):
    """Compute the entropy of the predictive distribution for each row.

    Args:
      loader: data source passed unchanged to
        ``self.algorithm.collect_logits_and_targets``.

    Returns:
      Tensor of ``-sum(p * log p)`` reduced over dim 1 of the collected
      logits, i.e. one value per row. (Presumably logits are laid out as
      (num_examples, num_classes) — TODO confirm against the algorithm.)
    """
    logits = self.algorithm.collect_logits_and_targets(loader)['logits']
    utils.message('All logits collected')
    sm = logits.softmax(1)
    utils.message('Softmax applied')
    # Replace exact zeros by 1 *before* the log, so log() never sees 0:
    # p * log(1) == 0 reproduces the 0 * log(0) = 0 convention exactly.
    # Unlike torch.where(sm == 0, 0, log(sm)), this never materialises -inf,
    # so gradients stay finite (no 0 * inf = NaN in backward). The fill value
    # also inherits sm's dtype and device, so no separately-typed zeros
    # tensor is needed.
    safe_log = sm.masked_fill(sm == 0, 1.0).log()
    return (sm * safe_log).sum(1).neg()


if __name__ == '__main__':
  # Nothing to run when this module is executed as a script.
  ...
