class Objective:
    """Loss callable that returns the mean of the negated total ELBO.

    Subclasses are expected to assign ``self.likelihood_estimator`` a
    callable that accepts a ``beta`` keyword argument and returns an object
    exposing ``get_total_elbo(model, x, output)``. The base class leaves it
    as ``None``, so calling a bare ``Objective`` instance fails.
    """

    def __init__(self):
        # Placeholder; a subclass must replace this with an estimator factory.
        self.likelihood_estimator = None

    def __call__(self, model, inp, output, beta=1):
        """Return the loss ``mean(-elbo)`` for one batch.

        Args:
            model: Passed through unchanged to ``get_total_elbo``.
            inp: A two-item sequence. Only the first item is used, as the
                data ``x``; the second item is ignored.
            output: Passed through unchanged to ``get_total_elbo``.
            beta: Forwarded as the ``beta`` keyword when building the
                estimator. Defaults to 1.

        Returns:
            ``(-elbo).mean()``, where ``elbo`` is whatever ``get_total_elbo``
            returns. It is presumably an array/tensor; confirm this against
            the estimator implementation.
        """
        data, _unused = inp
        elbo = self.likelihood_estimator(beta=beta).get_total_elbo(
            model, data, output
        )
        # Original note: "mean over N and K outside log".
        # NOTE(review): which axes are reduced depends on the ELBO's shape,
        # which is not visible here.
        return (-elbo).mean()
