import torch

import acquire


class ConfidenceBased(acquire.ModelBased):
    """Model-based acquisition whose tensor ``z`` is the model's log-softmax output.

    NOTE(review): the class name suggests samples are ranked by prediction
    confidence, but no scoring happens in this class. Confirm how ``z`` is
    consumed against ``acquire.ModelBased`` and any subclasses.
    """

    def process_labelled_datapool(self, model, labelled_datapool):
        """Skip processing of the labelled data pool.

        This override ignores both arguments and does no work. It always
        returns ``(None, None)``. The intent, per the original note, is to save
        computation where processing the labelled pool is unnecessary.
        Subclasses may override it again if they need that information.

        Args:
            model: Unused here.
            labelled_datapool: Unused here.

        Returns:
            The tuple ``(None, None)``. Presumably the base class treats this
            as "nothing computed". TODO confirm against ``acquire.ModelBased``.
        """
        return None, None

    def compute_z(self, model, X):
        """Return the log-probabilities that ``model`` assigns to ``X``.

        Args:
            model: Callable applied directly to ``X``. Its output must be a
                2-D tensor of logits, presumably shaped (batch, num_classes).
                TODO confirm.
            X: Input passed unchanged to ``model``.

        Returns:
            A tensor with the same shape as the logits, normalized with
            log-softmax along dimension 1.

        Raises:
            ValueError: If the model output is not 2-dimensional.
        """
        logits = model(X)
        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with -O, which would let malformed output through.
        if logits.dim() != 2:
            raise ValueError(
                f"expected 2-D logits from model, got shape {tuple(logits.shape)}"
            )
        return torch.nn.functional.log_softmax(logits, dim=1)