import torch

import acquire


class LcBeamPWeightedCoreset(acquire.LcBeamCoreset):
    """Variant of ``acquire.LcBeamCoreset`` that weights deltas by ``1 - p``.

    Only ``combine_min_deltas_with_score`` is overridden; all other behavior is
    inherited from ``acquire.LcBeamCoreset``.
    """

    NAME = "lc-beam-pweighted-coreset"

    def combine_min_deltas_with_score(self, min_deltas, scores):
        """Combine per-candidate deltas with scores as ``d * (1 - p)``.

        Args:
            min_deltas: tensor of per-candidate deltas (presumably the minimum
                distance to the current coreset -- confirm against the base
                class).
            scores: tensor of the same size as ``min_deltas``, holding
                negative log-probabilities, i.e. ``-log(p)``.

        Returns:
            Tensor of the same size equal to ``min_deltas * (1 - exp(-scores))``.
        """
        assert min_deltas.size() == scores.size()
        # 1 - p == 1 - exp(-scores) == -expm1(-scores). expm1 avoids the
        # catastrophic cancellation of ``1 - exp(-x)`` when x is near 0
        # (p close to 1); the naive form can round to exactly 0 in float32.
        # Dividing by p instead (d * exp(scores)) was rejected: it is not
        # numerically stable and explodes when p is small (-log p large).
        return min_deltas * -torch.expm1(-scores)