from __future__ import print_function

import itertools
import sys

from pixyz.losses import KullbackLeibler, Expectation as E, LogProb
from pixyz.losses import Parameter

from .base import Base
from pixyz.distributions import ProductOfNormal

# NOTE(review): this executes after the imports above, so it cannot affect how
# they resolve. A relative entry like '../' is resolved against the process's
# current working directory (not this file's location) at import time of later
# modules — confirm this path hack is still needed.
sys.path.append('../')


class MVAE(Base):
    """Multimodal VAE whose posterior is a product of per-modality experts.

    ``get_loss`` combines:
      * the negative ELBO using the product-of-experts posterior built from
        every modality's inference distribution, and
      * one negative ELBO per modality, using a product-of-experts built from
        that modality's inference distribution alone.
    """

    def __init__(self, params, device="cpu"):
        super().__init__(params, device)

        # Sub-sampling bookkeeping. With M modalities, 2**M - M - 2 is the
        # number of modality subsets that are not empty, not a single
        # modality and not the full set. These attributes are not read in
        # this class; presumably Base or a caller uses them — TODO confirm.
        self.sub_list = range(2**len(self.modality_id) - len(self.modality_id) - 2)
        self.sub_dict = {"sub_%d" % j: 0 for j in self.sub_list}

    def get_aggregated_inference(self, q_z, name="q_poe"):
        """Wrap inference distribution(s) in a ``ProductOfNormal``.

        Args:
            q_z: The expert distribution(s) to combine. ``get_loss`` passes
                either a list of distributions or a single distribution
                (assumes ``ProductOfNormal`` accepts both — TODO confirm).
            name: Name given to the resulting distribution.

        Returns:
            The ``ProductOfNormal`` built from ``q_z``.
        """
        return ProductOfNormal(q_z, name=name)

    def get_loss(self):
        """Build the MVAE training loss as a pixyz loss expression.

        Each negative ELBO term is ``-E_q[weighted log p(x|z)] + beta * KL(q || prior)``.
        The reconstruction terms are weighted by ``self.rec_weight_all[i]``.
        ``beta`` is a pixyz ``Parameter`` placeholder (presumably supplied
        when the loss is evaluated). The all-modality loss and the summed
        per-modality loss are printed for inspection.

        Returns:
            The all-modality negative ELBO plus the sum of the per-modality
            negative ELBOs.
        """
        q_z = [self.dist_dict["q_z_%d" % i] for i in self.modality_id]

        q_z_poe = self.get_aggregated_inference(q_z, name="q_poe")
        q_z_poe_each_modality = {
            i: self.get_aggregated_inference(self.dist_dict["q_z_%d" % i], name="q_poe")
            for i in self.modality_id
        }

        prior = self.dist_dict["prior_z"]
        beta = Parameter("beta")

        def neg_elbo(q, log_prob):
            # Reconstruction error plus beta-weighted KL to the shared prior.
            return -E(q, log_prob) + beta * KullbackLeibler(q, prior)

        # ELBO (all modality): every modality is reconstructed from the joint
        # product-of-experts posterior.
        log_probs = sum([self.rec_weight_all[i] * LogProb(self.dist_dict["p_x_%d" % i])
                         for i in self.modality_id])
        loss = neg_elbo(q_z_poe, log_probs)
        print("all modality:")
        print(loss)

        # ELBO (each modality): each modality is reconstructed from its own
        # single-expert posterior.
        loss_modality = [
            neg_elbo(q_z_poe_each_modality[i],
                     self.rec_weight_all[i] * LogProb(self.dist_dict["p_x_%d" % i]))
            for i in self.modality_id
        ]

        # Build the summed per-modality loss once so the printed expression is
        # exactly the object added to the total.
        loss_each_modality = sum(loss_modality)
        print("each modality:")
        print(loss_each_modality)
        loss += loss_each_modality

        return loss
