from __future__ import print_function

from pixyz.losses import KullbackLeibler, Expectation as E, LogProb
from pixyz.losses import Parameter

from .base import Base
from pixyz.distributions import MixtureOfNormal


class MMVAE(Base):
    """Multimodal VAE whose unimodal posteriors are combined as a mixture.

    The per-modality posteriors ``q_z_<id>``, the prior ``prior_z`` and the
    decoders ``p_x_<id>`` are read from ``self.dist_dict``; the modality ids
    come from ``self.modality_id`` and the per-modality reconstruction
    weights from ``self.rec_weight_all`` (presumably all prepared by ``Base``
    -- confirm there). The loss is returned as a pixyz loss expression, not
    as an evaluated number.
    """

    def get_aggregated_inference(self, q_z, name="q_moe"):
        """Wrap a list of unimodal posteriors in a ``MixtureOfNormal``.

        :param q_z: list of unimodal posterior distributions to combine.
        :param name: name assigned to the resulting mixture distribution.
        :return: the ``MixtureOfNormal`` built from ``q_z``.
        """
        mixture = MixtureOfNormal(q_z, name=name)
        return mixture

    def get_loss(self):
        """Assemble the MMVAE objective as a pixyz loss expression.

        The returned expression is::

            -E_{q_moe}[ sum_k rec_weight_all[k] * log p_x_k ]
            + beta * (1 / M) * sum_m KL(q_z_m || prior_z)

        where ``M = len(self.modality_id)`` and ``beta`` is a pixyz
        ``Parameter`` named "beta" (presumably bound when the loss is
        evaluated). The KL is taken per unimodal posterior and averaged; it
        is not computed against the mixture ``q_moe`` itself.

        NOTE(review): an empty ``self.modality_id`` raises
        ``ZeroDivisionError`` when computing the averaging weight.
        """
        beta = Parameter("beta")
        # Equal weight per modality; computed before any dict lookups so an
        # empty modality list fails here first.
        weight = 1. / len(self.modality_id)

        # Walk the modalities once: accumulate the averaged KL terms and keep
        # each unimodal posterior for the mixture built below.
        posteriors = []
        kld_all = 0
        for idx in self.modality_id:
            posterior = self.dist_dict["q_z_%d" % idx]
            kld_all += weight * KullbackLeibler(posterior, self.dist_dict["prior_z"])
            posteriors.append(posterior)

        q_moe = self.get_aggregated_inference(posteriors, name="q_part_moe")

        # Weighted sum of the per-modality decoder log-likelihoods.
        log_probs = sum(
            self.rec_weight_all[mod] * LogProb(self.dist_dict["p_x_%d" % mod])
            for mod in self.modality_id
        )

        # Negative expected log-likelihood under the mixture, plus the
        # beta-scaled KL regularizer.
        reconstruction = -E(q_moe, log_probs)
        return reconstruction + kld_all * beta
