from __future__ import print_function

import itertools
import sys

from pixyz.losses import KullbackLeibler, Expectation as E, LogProb
from pixyz.losses import Parameter

from .base import Base
from pixyz.distributions.poe import ProductOfNormal

# NOTE(review): this executes after all of the imports above, so it has no
# effect on how they resolve. The relative entry '../' is interpreted relative
# to the current working directory, not this file's location. Presumably it is
# kept for imports performed later elsewhere — confirm it is still needed.
sys.path.append('../')


class CRMVAE(Base):
    """Multimodal VAE whose loss mixes product-of-experts reconstruction terms
    with an averaged set of KL regularizers.

    Reads attributes that are presumably populated by ``Base.__init__`` (not
    visible here — confirm there): ``modality_id``, ``dist_dict``,
    ``forward_kl``, ``kl_set``, ``reconst_unimodal`` and ``rec_weight_all``.
    """

    def __init__(self, params, device="cpu"):
        super().__init__(params, device)

        # Sub-sampling bookkeeping: one zero-initialised counter per "sub_<j>"
        # key. For M modalities, 2**M - M - 2 is the number of modality subsets
        # that are neither empty, a single modality, nor the full set; the
        # range is empty when M <= 2.
        n_mod = len(self.modality_id)
        self.sub_list = range(2 ** n_mod - n_mod - 2)
        self.sub_dict = {"sub_%d" % j: 0 for j in self.sub_list}

    def get_aggregated_inference(self, q_z, name="q_poe"):
        """Fuse several inference distributions into one product-of-experts.

        Args:
            q_z: list of distributions to combine (the ``q_z_<i>`` entries of
                ``dist_dict`` in this class).
            name: name given to the resulting distribution.

        Returns:
            A ``ProductOfNormal`` over ``q_z``.
        """
        return ProductOfNormal(q_z, name=name)

    def get_loss(self):
        """Build the training objective as a pixyz loss expression.

        The objective is ``rec_error + beta * kld_all``, where:

        * ``kld_all`` is the mean of a list of KL terms. Each term compares
          the joint posterior ``q_all`` (product-of-experts over every
          modality) with a product-of-experts over one subset of modalities.
          All subset sizes are used when ``self.kl_set`` is true; only single
          modalities are used otherwise. One further term compares ``q_all``
          with ``dist_dict["prior_z"]``. ``self.forward_kl`` selects the
          orientation: ``KL(q_all, other)`` when true, ``KL(other, q_all)``
          otherwise.
        * ``rec_error`` is the mean, over the modality subsets used, of the
          negative expectation (under that subset's product-of-experts) of the
          ``rec_weight_all``-weighted sum of ``LogProb(p_x_<i>)`` for the
          modalities in the subset. The subsets are the full modality set,
          plus every single modality when ``self.reconst_unimodal`` is true.

        ``beta`` is a pixyz ``Parameter`` named ``"beta"``. Its value is
        presumably supplied when the loss is evaluated.

        The composed loss is printed before it is returned.

        Returns:
            The composed pixyz loss.
        """
        beta = Parameter("beta")
        modalities = self.modality_id
        n_mod = len(modalities)

        def poe(subset, name):
            # Product-of-experts over the q_z distributions of `subset`.
            return self.get_aggregated_inference(
                [self.dist_dict["q_z_%d" % i] for i in subset], name=name)

        q_all = poe(modalities, "q_all")

        def kl_with_joint(other):
            # KL between the joint posterior and `other`, oriented by forward_kl.
            if self.forward_kl:
                return KullbackLeibler(q_all, other)
            return KullbackLeibler(other, q_all)

        kld_prior = kl_with_joint(self.dist_dict["prior_z"])

        # KL regularizers: joint vs. each subset PoE first, prior term last.
        # NOTE(review): with kl_set, subset size n_mod (the full set) is
        # included. Its PoE is built from the same experts as q_all, so that
        # term is presumably zero but still counts toward the average —
        # confirm this is intended.
        kl_sizes = range(1, n_mod + 1) if self.kl_set else [1]
        kld_terms = [
            kl_with_joint(poe(subset, "q_part_poe"))
            for size in kl_sizes
            for subset in itertools.combinations(modalities, size)
        ]
        kld_terms.append(kld_prior)
        kld_all = sum(kld_terms) / len(kld_terms)

        # Reconstruction terms: the full modality set, plus each single
        # modality when reconst_unimodal is enabled.
        rec_sizes = [1, n_mod] if self.reconst_unimodal else [n_mod]
        rec_terms = []
        for size in rec_sizes:
            for subset in itertools.combinations(modalities, size):
                q_part_poe = poe(subset, "q_part_poe")
                log_probs = sum(
                    self.rec_weight_all[i] * LogProb(self.dist_dict["p_x_%d" % i])
                    for i in subset
                )
                rec_terms.append(-E(q_part_poe, log_probs))

        # Average by the number of terms actually built: n_mod + 1 with
        # reconst_unimodal, 1 otherwise. The divisor therefore stays in sync
        # with the subset selection above.
        rec_error = sum(rec_terms) / len(rec_terms)

        loss = rec_error + beta * kld_all
        # Prints the composed loss object (inspection output).
        print(loss)

        return loss