import torch

import acquire


class LcBeamPWeightedRelConfCoreset(acquire.LcBeamPWeightedCoreset):
    """Variant of ``acquire.LcBeamPWeightedCoreset`` using relative ("relconf") scores.

    The only behavioural difference from the base class visible here is the
    override of :meth:`compute_relative_unlabelled_scores`. That method
    measures each unlabelled score against the score of one matched labelled
    item instead of using it on its own.
    """

    NAME = "lc-beam-pweighted-relconf-coreset"

    def compute_relative_unlabelled_scores(self, score_unlabelled, score_labelled, min_deltas_indices):
        """Score each unlabelled item relative to its matched labelled item.

        Args:
            score_unlabelled: 1-D tensor holding one score per unlabelled item
                (length ``U``).
            score_labelled: 1-D tensor holding one score per labelled item.
            min_deltas_indices: tensor of size ``(U,)``. Entry ``i`` indexes
                into ``score_labelled`` and selects the labelled score that
                unlabelled item ``i`` is compared against. Presumably this is
                the labelled item with minimal "delta" (per the name); that
                choice is made by the caller.

        Returns:
            A new tensor of size ``(U,)`` whose entry ``i`` is
            ``max(score_unlabelled[i] - score_labelled[min_deltas_indices[i]], 0)``.

        Raises:
            ValueError: if either score tensor is not 1-D (raised by the
                shape unpacking).
            AssertionError: if ``min_deltas_indices`` does not have size ``(U,)``.
        """
        # Unpacking the shape doubles as a rank check: anything but 1-D raises.
        [n_unlabelled] = score_unlabelled.shape
        [_] = score_labelled.shape
        assert min_deltas_indices.shape == (n_unlabelled,)

        matched_labelled = score_labelled[min_deltas_indices]

        # Scores are -log P, so this gap equals log(P_labelled / P_unlabelled).
        # It is positive when the unlabelled item is less probable than its match.
        gap = score_unlabelled - matched_labelled

        # Clip negative gaps to zero. Doing it in place is safe because `gap`
        # is a fresh temporary, not one of the caller's tensors.
        return torch.relu_(gap)