from __future__ import print_function

import itertools

from pixyz.losses import KullbackLeibler, Expectation as E, LogProb
from pixyz.losses import Parameter

from .base import Base
from pixyz.distributions.poe import ProductOfNormal
from pixyz.distributions.moe import MixtureOfNormal


class MoPoE(Base):
    """Model whose loss combines product-of-experts posteriors over modality subsets.

    For every non-empty subset of ``self.modality_id`` the per-modality
    encoders ``q_z_<id>`` are fused into one ``ProductOfNormal``; the fused
    posteriors are then pooled into a single ``MixtureOfNormal`` that is used
    for the reconstruction term.

    The attributes ``modality_id``, ``dist_dict`` and ``rec_weight_all`` are
    read from the instance; they are presumably populated by ``Base`` --
    confirm against that class.
    """

    def get_aggregated_inference(self, q_z, name="q_poe"):
        """Fuse several inference distributions into one ``ProductOfNormal``.

        Args:
            q_z: list of distributions handed to ``ProductOfNormal``.
            name: name assigned to the resulting distribution.

        Returns:
            The ``ProductOfNormal`` built from ``q_z``.
        """
        return ProductOfNormal(q_z, name=name)

    def get_loss(self):
        """Assemble the training objective as a symbolic loss expression.

        The result is ``-E_{q_moe}[sum_k w_k * log p_x_k] + beta * KL_total``
        where ``KL_total`` is the sum, over all non-empty modality subsets,
        of ``coef * KL(q_subset || prior_z)``, and ``beta`` is a
        ``Parameter`` named ``"beta"``.

        Returns:
            The loss expression composed from the reconstruction and KL terms.
        """
        n_modalities = len(self.modality_id)
        beta = Parameter("beta")
        # NOTE(review): the weight is 1 / 2**M although only the 2**M - 1
        # non-empty subsets are enumerated below -- confirm this is intended.
        subset_weight = 1. / (2. ** n_modalities)

        # Lazily walk every non-empty subset, smallest subsets first.
        subsets = itertools.chain.from_iterable(
            itertools.combinations(self.modality_id, size)
            for size in range(1, n_modalities + 1)
        )

        fused_posteriors = []
        weighted_klds = []
        for subset in subsets:
            members = [self.dist_dict["q_z_%d" % idx] for idx in subset]
            fused = self.get_aggregated_inference(members, name="q_part_poe")

            fused_posteriors.append(fused)
            weighted_klds.append(
                subset_weight * KullbackLeibler(fused, self.dist_dict["prior_z"])
            )

        mixture = MixtureOfNormal(fused_posteriors, name="q_moe")
        kld_total = sum(weighted_klds)

        # Weighted sum of the decoders' log-likelihoods, one term per modality.
        recon_log_prob = 0
        for modality in self.modality_id:
            decoder = self.dist_dict["p_x_%d" % modality]
            recon_log_prob += self.rec_weight_all[modality] * LogProb(decoder)

        return - E(mixture, recon_log_prob) + kld_total * beta
