from __future__ import print_function

from pixyz.losses import KullbackLeibler, Expectation as E, LogProb
from pixyz.losses import Parameter

from .base import Base
from pixyz.distributions import MixtureOfNormal
from pixyz.distributions import ProductOfNormal


class MMJSD(Base):
    """Multimodal model whose loss aggregates the unimodal posteriors two ways.

    The unimodal posteriors ``q_z_<id>`` are combined in two forms:

    * as a mixture of normals, used for the reconstruction term;
    * as a product of normals (``q_all``), the target of the weighted KL term.

    NOTE(review): the class name suggests the MMJSD (multimodal
    Jensen-Shannon divergence) objective. Confirm against the reference
    formulation.

    Reads ``self.modality_id``, ``self.dist_dict`` and ``self.rec_weight_all``.
    These are presumably populated by ``Base`` (not visible here).
    """

    def get_aggregated_inference(self, q_z, name="q_moe"):
        """Combine the distributions in ``q_z`` into a ``MixtureOfNormal``.

        Args:
            q_z: List of unimodal posterior distributions to mix.
            name: Name given to the resulting mixture distribution.

        Returns:
            A ``MixtureOfNormal`` built over ``q_z``.
        """
        mixture = MixtureOfNormal(q_z, name=name)
        return mixture

    def get_loss(self):
        """Build the symbolic training loss.

        The loss is::

            -E_{q_moe}[ sum_k rec_weight_all[k] * log p_x_k ] + beta * D

            D = 1/(M+1) * ( sum_i KL(q_z_i || q_all) + KL(prior_z || q_all) )

        Here ``M = len(self.modality_id)``, ``q_all`` is the product of
        normals of the unimodal posteriors, and ``q_moe`` is their mixture.
        ``beta`` is a pixyz ``Parameter`` placeholder, presumably supplied
        when the loss is evaluated.

        Returns:
            A pixyz loss expression.
        """
        beta = Parameter("beta")
        # Every KL term (M unimodal posteriors plus the prior) gets the same weight.
        weight = 1. / (len(self.modality_id) + 1)

        posteriors = [self.dist_dict["q_z_%d" % idx] for idx in self.modality_id]
        q_all = ProductOfNormal(posteriors, name="q_all")

        divergence = sum(weight * KullbackLeibler(posterior, q_all)
                         for posterior in posteriors)
        divergence += weight * KullbackLeibler(self.dist_dict["prior_z"], q_all)

        # Pass a copy so the mixture does not share the list object handed
        # to ProductOfNormal above.
        q_moe = self.get_aggregated_inference(list(posteriors), name="q_part_moe")

        reconstruction = sum(
            self.rec_weight_all[idx] * LogProb(self.dist_dict["p_x_%d" % idx])
            for idx in self.modality_id
        )

        rec_error = -E(q_moe, reconstruction)
        return rec_error + divergence * beta
