from abc import ABC, abstractmethod


class Entropy(ABC):
    """Abstract base for callable entropy measures.

    Subclasses implement ``_compute`` (the raw value) and ``entropy_max``
    (the value used as the divisor when normalizing). Calling an instance
    returns either the raw value or the raw value divided by that maximum.
    """

    def __init__(self):
        """Create the measure; the base class stores no state."""

    def __call__(self, logits, normalize=False):
        """Evaluate the measure on ``logits``.

        Args:
            logits: Input handed unchanged to ``_compute`` and, when
                normalizing, to ``entropy_max``.
            normalize: When truthy, divide the raw value by
                ``entropy_max(logits)``.

        Returns:
            The result of ``_compute(logits)``, optionally divided by
            ``entropy_max(logits)``.
        """
        raw = self._compute(logits)
        if not normalize:
            return raw
        # The maximum is only evaluated when normalization is requested.
        # NOTE(review): no guard against a zero maximum -- confirm that
        # subclasses never return 0 here.
        return raw / self.entropy_max(logits)

    @abstractmethod
    def _compute(self, logits):
        """Return the raw (un-normalized) value for ``logits``."""

    @abstractmethod
    def entropy_max(self, logits):
        """Return the value used as the normalization divisor for ``logits``."""
