from methods.mm_method import MultiModalMethod
from .estimator import MoeLikelihoodEstimator


class Model(MultiModalMethod):
    """Multimodal model built on the pass helpers of ``MultiModalMethod``.

    The heavy lifting is delegated to methods inherited from the base class
    (``_unimodal_passes``, ``_crossmodal_passes``,
    ``_crossmodal_generations``). Their exact semantics live there.
    """

    def __init__(self, args):
        """Initialise the base method with ``args`` and mark it multimodal."""
        super().__init__(args)
        # NOTE(review): presumably consulted by the base class or the
        # training/eval loop to enable multimodal handling. Confirm there.
        self.multimodal = True

    def forward(self, x, eval=False, **kwargs):
        """Run unimodal then cross-modal passes over ``x``.

        Args:
            x: Model input, handed unchanged to ``_unimodal_passes``.
            eval: When truthy, the cross-modal output is additionally passed
                through ``_crossmodal_generations``. The name shadows the
                builtin ``eval`` but is kept because callers may pass it as
                a keyword.
            **kwargs: Forwarded only to ``_unimodal_passes``.

        Returns:
            The result of ``_crossmodal_passes``, or of
            ``_crossmodal_generations`` applied to it when ``eval`` is truthy.
        """
        unimodal_out, bu = self._unimodal_passes(x, **kwargs)
        fused = self._crossmodal_passes(unimodal_out, bu_tensors=bu)
        # Generations are produced only on the evaluation path.
        return self._crossmodal_generations(fused) if eval else fused

    def evaluate_likelihood(self, inp, output, **kwargs):
        """Compute the evaluation likelihood with a fresh estimator.

        Args:
            inp: A two-item iterable. Only the first item is used. Unpacking
                raises ``ValueError`` if it does not hold exactly two items.
            output: Model output, passed through to the estimator.
            **kwargs: Forwarded to
                ``MoeLikelihoodEstimator.get_evaluation_likelihood``.

        Returns:
            Whatever ``get_evaluation_likelihood`` returns.
        """
        # Unpack first so a malformed ``inp`` fails before the estimator is built.
        data, _ = inp
        return MoeLikelihoodEstimator().get_evaluation_likelihood(
            model=self, x=data, output=output, **kwargs)
