import torch

import acquire


class Top2EntropyBeamCoreset(acquire.LcBeamCoreset):
    """Beam-coreset acquisition scored by the entropy of each row's top-2 logits.

    For every row only the two largest logits are kept. They are renormalised
    with a softmax, and the Shannon entropy (natural log) of that two-way
    distribution is the score, so finite scores lie in [0, ln 2].
    """

    NAME = "top2entropy-beam-coreset"

    def score_logits(self, logits):
        """Return a 1D float tensor of scores given logits.

        Parameters:
        ===========
        logits: float tensor of size (batch, classes).

        Returns:
        ========
        Float tensor of size (batch,) holding the entropy (natural log) of the
        softmax over each row's two largest logits. If there is only one
        class, that single class is used and the score is 0.

        Raises:
        =======
        ValueError: if `logits` is not 2-dimensional.
        """
        if len(logits.shape) != 2:
            raise ValueError(
                "expected logits of shape (batch, classes), got shape %s"
                % (tuple(logits.shape),)
            )
        # topk(2) raises if there is only one class; fall back to what exists.
        k = min(2, logits.shape[1])
        top_logits, _ = logits.topk(k, dim=1)
        top_prob = torch.nn.functional.softmax(top_logits, dim=1)
        top_logp = torch.nn.functional.log_softmax(top_logits, dim=1)
        # p * log(p) -> 0 as p -> 0, but a -inf logit (e.g. a masked class)
        # gives 0 * -inf = NaN. Zero log(p) wherever p == 0 so that term
        # contributes 0. masked_fill (rather than torch.where) also keeps
        # gradients NaN-free.
        safe_logp = top_logp.masked_fill(top_prob == 0, 0.0)
        return -(top_prob * safe_logp).sum(dim=1)