import torch

import acquire
import utils


class DistanceLogitsBased(acquire.ModelBased):
    """Model-based acquisition whose z packs a logit score with hidden features.

    ``compute_z`` builds, for every row of the input, a vector whose first
    column is the score produced by ``score_logits`` and whose remaining
    columns are the model's hidden features. ``extract_hidden_features_and_scores``
    splits such a tensor back into those two parts. Subclasses must
    implement ``score_logits``.
    """

    def compute_z(self, model, X):
        """Extract the necessary tensor z from X using the model.

        Parameters:
        ===========
        model: object exposing ``extract_features`` and
            ``classify_hidden_features``.
        X: model input, passed unchanged to ``model.extract_features``.

        Returns:
        ========
        Tensor of size (N, H + 1): column 0 holds the scores returned by
        ``score_logits``; columns 1..H hold the hidden features.

        Raises:
        =======
        ValueError: if the hidden features are not 2-D, or if
            ``score_logits`` does not return a 1-D tensor with exactly one
            score per row of the hidden features.
        """
        hidden = model.extract_features(X)
        if hidden.dim() != 2:
            raise ValueError(
                f"expected 2-D hidden features (N, H), "
                f"got shape {tuple(hidden.shape)}"
            )
        logits = model.classify_hidden_features(hidden)
        scores = self.score_logits(logits)
        # Explicit check instead of ``assert``: this enforces the contract on
        # subclass implementations of ``score_logits`` and must survive
        # ``python -O``, which strips assert statements.
        if scores.dim() != 1 or scores.size(0) != hidden.size(0):
            raise ValueError(
                f"score_logits must return a 1-D tensor with "
                f"{hidden.size(0)} entries, got shape {tuple(scores.shape)}"
            )
        # (N,) -> (N, 1) so the scores become the first column of z; the
        # result is therefore (N, H + 1) by construction.
        return torch.cat([scores.unsqueeze(-1), hidden], dim=1)

    @utils.abstract
    def score_logits(self, logits):
        """Return a 1D float tensor of scores given logits.

        Parameters:
        ===========
        logits: float tensor of size (batch, classes).

        Returns:
        ========
        1-D tensor of length ``batch``, one score per row of ``logits``.
        ``compute_z`` rejects any other shape with ``ValueError``.
        """

    def extract_hidden_features_and_scores(self, z):
        """Split a tensor produced by ``compute_z`` into its two parts.

        Parameters:
        ===========
        z: tensor of size (N, H + 1) as returned by ``compute_z``.

        Returns:
        ========
        ``(scores, hidden)``. Despite the method name, scores come first.
        ``scores`` is column 0 with shape (N,); ``hidden`` is columns 1..H
        with shape (N, H). Both are slices (views) of ``z``.
        """
        scores = z[:, 0]
        hidden = z[:, 1:]
        return scores, hidden
