import utils
from disentanglement_vae.models.model import DisentanglementVae


class MoeDisentanglementVae(DisentanglementVae):
    """DisentanglementVae subclass whose ``forward`` runs one pass per
    modality and then a cross-modal generation step.

    NOTE(review): the cross-modal step is delegated to
    ``_crossmodal_generation_eval`` / ``_crossmodal_generation_training``.
    Neither is defined in this class, so they are presumably inherited
    from ``DisentanglementVae``; confirm there.
    """

    def forward(self, x, eval=False, k=1):
        """Run every unimodal VAE, then the cross-modal generation step.

        :param x: per-modality inputs, indexed by position in
            ``self.modalities``. ``x[i]`` is fed to the VAE of modality
            ``self.modalities[i]``.
        :param eval:
            True: z_i ~ p(z_i) [evaluation]
            False: z_i ~ q(z_i|x_1) [training]
            The name shadows the builtin ``eval``. It is kept so that
            callers passing ``eval=...`` keep working.
        :param k: passed unchanged to every unimodal VAE call. Presumably
            the number of samples; TODO confirm against the VAE modules.
        :return: nested mapping created by ``utils.rec_defaultdict()``.
            Its "diagonal" entries ``output[type][m][m]`` hold modality
            ``m``'s unimodal results. It is then passed through (and
            returned by) the selected cross-modal generation method.
        """
        output = utils.rec_defaultdict()
        output = self._unimodal_passes(x, output, k)
        if eval:
            output = self._crossmodal_generation_eval(output)
        else:
            output = self._crossmodal_generation_training(output)
        return output

    def _unimodal_passes(self, x, output, k):
        """Run each modality's VAE on its own input and record the results.

        Input: ``x``, the per-modality inputs aligned with ``self.modalities``.
        Output: ``output``, with ``output[cur_type][m][m]`` set for every
        modality ``m`` in ``self.vaes`` and every key ``cur_type`` in that
        VAE's returned mapping.

        ``output`` is mutated in place and also returned. It must
        auto-create nested levels; ``forward`` passes a
        ``utils.rec_defaultdict()`` for this.

        :raises ValueError: if ``self.modalities`` is a list/tuple and a key
            of ``self.vaes`` is not in it.
        """
        for m, vae in self.vaes.items():
            # Look up this modality's position to select its input x[i].
            i = self.modalities.index(m)
            cur_output = vae(x[i], k)
            # Unimodal results go on the "diagonal": source and target are both m.
            for cur_type, v in cur_output.items():
                output[cur_type][m][m] = v
        return output
