import torch

from mhvae_vasco.likelihood.likelihood import LikelihoodEstimator


class LikelihoodEstimatorImages(LikelihoodEstimator):
    def reconstruction(self, x, output):
        """
        Sum the unimodal reconstruction terms of the modalities 'x1' and 'x2'.

        As noted for Vasco et al., only unimodal reconstructions are scored,
        i.e. p(x_i|z_i) where the latent comes from q(z_i|x_i); this is why
        only the ``output[m][m]`` entries are read and no cross-modal ones.

        :param x: indexable holding the targets; position 0 is used for
            'x1' and position 1 for 'x2'.
        :param output: nested mapping read as ``output[m][m]['rec']``; the
            'samples' entry is used for 'x1', the 'dist' entry for 'x2'.
        :return: the two per-modality terms stacked along a new leading
            dimension and summed over it.
        :raises ValueError: if a modality name other than 'x1' / 'x2' is
            encountered.

        NOTE(review): ``_compute_rec_sigmoid`` and ``_compute_rec_gaussian``
        are not defined in this class; presumably they are inherited from
        ``LikelihoodEstimator`` -- confirm there what each one returns.
        """
        modalities = ('x1', 'x2')
        terms = []
        for idx, name in enumerate(modalities):
            observed = x[idx]
            if name not in ('x1', 'x2'):
                raise ValueError('Modality must be "x1" or "x2".')
            if name == 'x1':
                # 'x1' is scored from the reconstructed samples.
                score_fn = self._compute_rec_sigmoid
                samples = output[name][name]['rec']['samples']
                term = score_fn(observed, rec_samples=samples)
            else:
                # 'x2' is scored from the reconstruction distribution.
                score_fn = self._compute_rec_gaussian
                dist = output[name][name]['rec']['dist']
                term = score_fn(observed, rec_dist=dist)
            terms.append(term)
        # Stack along a new dim 0 (one slice per modality) and reduce it.
        return torch.stack(terms).sum(0)
