from disentanglement_vae.likelihoods import MoeLikelihoodEstimator


class MoeObjective:
    """Callable objective returning the negated joint likelihood.

    The likelihood itself is computed by ``MoeLikelihoodEstimator``; this
    class only builds the estimator, forwards the arguments and flips the
    sign so the result can be minimised.
    """

    def __init__(self):
        # Keep a reference to the estimator *class* (not an instance): a new
        # estimator is constructed on every call with that call's ``beta``.
        self.likelihood_estimator = MoeLikelihoodEstimator

    def __call__(self, model, input, output, beta=1):
        """Compute the loss for one ``(input, output)`` pair.

        Args:
            model: Forwarded unchanged to the estimator's
                ``get_joint_likelihood``.
            input: A two-element iterable. Only the first element is used;
                the second is discarded (presumably labels/targets --
                TODO confirm against the data loader).
            output: Forwarded unchanged to ``get_joint_likelihood``.
            beta: Passed as the ``beta`` keyword to the estimator
                constructor. Defaults to 1.

        Returns:
            The negation of whatever ``get_joint_likelihood`` returns.
        """
        # Exactly two items are expected; anything else fails on unpacking.
        sample, _discarded = input
        return -self.likelihood_estimator(beta=beta).get_joint_likelihood(
            x=sample, model=model, output=output)
