import logging

import torch
from torch import nn
from torch.distributions.normal import Normal

import utils
from vae.misc import get_trainable_params

# Shared named logger; handlers/levels for 'custom' are presumably configured
# elsewhere in the project.
logger = logging.getLogger(name='custom')


class MHVAE(nn.Module):
    """Two-modality VAE (Vasco et al.) with a shared latent ``g``.

    A shared latent ``g`` is inferred from the backbone hidden states of x1,
    x2, or both. Each modality ``xi`` also has its own latent ``zi``. The
    conditional prior over ``zi`` is decoded from ``g`` (``g_to_zi``), and
    ``xi`` is reconstructed from ``zi`` (``zi_to_xi``). Subclasses supply the
    networks via ``build_encoder`` and ``build_decoder``.
    """

    def __init__(self, args, device):
        """Build encoder, decoder and the fixed prior over ``g``.

        :param args: configuration object; ``args.stoc_dim['g']`` sets the
            dimensionality of ``g``. It is also forwarded to
            ``build_encoder`` / ``build_decoder``.
        :param device: stored as ``self.device``.
        """
        super().__init__()
        print('\n====> Building MHVAE (Vasco et al.):')
        self.device = device

        self.encoder = self.build_encoder(args)
        self.decoder = self.build_decoder(args)

        # Prior over g: zero loc / unit scale held as frozen parameters, so
        # they follow the module across devices and into the state_dict.
        # Presumably consumed as Normal(*pg_params) elsewhere.
        self.pg = Normal
        self._pg_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, args.stoc_dim['g']),
                         requires_grad=False),
            nn.Parameter(torch.ones(1, args.stoc_dim['g']),
                         requires_grad=False)
        ])

        logger.info(self)
        self.print_number_of_parameters()

    def print_number_of_parameters(self):
        """Log trainable-parameter counts per modality and for the model.

        The shared ``encoder_g`` consumes both modalities' hidden states, so
        its count is split evenly between the two per-modality totals.
        """
        shared_g = get_trainable_params(self.encoder['encoder_g']) / 2
        p1 = [
            get_trainable_params(self.encoder['backbone_x1']),
            shared_g,
            get_trainable_params(self.encoder['encoder_z1']),
            get_trainable_params(self.decoder['z1_to_x1']),
            get_trainable_params(self.decoder['g_to_z1']),
        ]
        p2 = [
            get_trainable_params(self.encoder['backbone_x2']),
            shared_g,
            get_trainable_params(self.encoder['encoder_z2']),
            get_trainable_params(self.decoder['z2_to_x2']),
            get_trainable_params(self.decoder['g_to_z2']),
        ]

        params = {'VAE (x1)': sum(p1),
                  'VAE (x2)': sum(p2),
                  'Total': get_trainable_params(self)}

        logger.info('\nParameter overview:')
        for k, v in params.items():
            logger.info(f'- {k}: {v / 10 ** 6:.1f}M')

    @staticmethod
    def build_encoder(args):
        """Return the encoder networks (to be implemented by subclasses).

        Must be indexable by 'backbone_x1', 'backbone_x2', 'encoder_g',
        'encoder_z1' and 'encoder_z2'. Presumably an ``nn.ModuleDict``, so
        the parameters get registered -- confirm in subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def build_decoder(args):
        """Return the decoder networks (to be implemented by subclasses).

        Must be indexable by 'g_to_z1', 'g_to_z2', 'z1_to_x1' and 'z2_to_x2'.
        Presumably an ``nn.ModuleDict`` -- confirm in subclasses.
        """
        raise NotImplementedError

    @property
    def pg_params(self):
        """Frozen parameters (zeros, ones) of the prior over ``g``."""
        return self._pg_params

    def forward(self, x, eval=False):
        """Encode/decode for every combination of available modalities.

        :param x: pair of inputs; ``x[0]`` is modality x1, ``x[1]`` is x2.
        :param eval: if True, also run the crossmodal passes x1 -> x2 and
            x2 -> x1 (see ``crossmodal_pass``). The name shadows the builtin
            but is kept for backward compatibility.
        :return: nested dict indexed as
            [target_mode][cond_mod][latent_var][dist_kind][dist/samples].
        """
        output = utils.rec_defaultdict()

        # Infer backbone hidden states
        h1 = self.encoder['backbone_x1'](x[0])
        h2 = self.encoder['backbone_x2'](x[1])

        # Domain dropout: an absent modality is represented by zeros in its
        # slot. Each padding tensor matches the hidden state it replaces
        # (shape, dtype, device), so the unimodal inputs to encoder_g have
        # the same width as the joint [h1, h2] input even when the two
        # backbones differ in width.
        h1_zeros = torch.zeros_like(h1)
        h2_zeros = torch.zeros_like(h2)

        # Posterior over g for each combination: x1 only, x2 only, both
        output['x1']['x1']['g'], g1 = self.encoder['encoder_g'](
            h=torch.cat([h1, h2_zeros], dim=-1))
        output['x2']['x2']['g'], g2 = self.encoder['encoder_g'](
            h=torch.cat([h1_zeros, h2], dim=-1))
        output['joint']['joint']['g'], g_joint = self.encoder['encoder_g'](
            h=torch.cat([h1, h2], dim=-1))

        # Unimodal posteriors over z (they do not depend on g)
        output['x1']['x1']['z'], z1 = self.encoder['encoder_z1'](h1)
        output['x2']['x2']['z'], z2 = self.encoder['encoder_z2'](h2)

        # Unimodal and crossmodal conditional priors over z: for each target
        # t and conditioning modality c, decode the g obtained from c alone.
        # For t == c this merges the prior into the posterior entry above.
        g = {'x1': g1, 'x2': g2}
        for t_num, t in enumerate(('x1', 'x2'), start=1):
            for c in ('x1', 'x2'):
                prior, _ = self.decoder[f'g_to_z{t_num}'](g[c])
                output[t][c]['z'].update(prior)
        # Multimodal conditional priors (g inferred from both modalities)
        output['x1']['joint']['z'], _ = self.decoder['g_to_z1'](g_joint)
        output['x2']['joint']['z'], _ = self.decoder['g_to_z2'](g_joint)

        # Reconstruct each modality from its own z (independent of g)
        rec, _ = self.decoder['z1_to_x1'](z1)
        output['x1']['x1'].update(rec)
        rec, _ = self.decoder['z2_to_x2'](z2)
        output['x2']['x2'].update(rec)

        if eval:
            for c, t in (('x1', 'x2'), ('x2', 'x1')):
                output = self.crossmodal_pass(output, x, c, t)

        return output

    def crossmodal_pass(self, output, x, c: str, t: str):
        """Generate target modality ``t`` from conditioning modality ``c``.

        Evaluation path: ``g`` is inferred from ``c`` alone (the other slot
        is zero-filled), then decoded through ``g_to_z`` and ``z_to_x`` of
        the target modality.

        :param output: nested output dict; the reconstruction entries are
            merged into ``output[t][c]``.
        :param x: pair of inputs; ``x[0]`` is modality x1, ``x[1]`` is x2.
        :param c: conditioning modality, from ['x1', 'x2']
        :param t: target modality, from ['x1', 'x2']
        :return: ``output``, updated in place.
        :raises ValueError: if ``c`` or ``t`` is not 'x1' or 'x2'.
        """
        modalities = ('x1', 'x2')
        # Validate before any lookup, so an unknown modality raises
        # ValueError instead of a KeyError on the backbone lookup, or being
        # silently treated as 'x2'.
        if c not in modalities or t not in modalities:
            raise ValueError('Modalities must be "x1" or "x2".')

        # Infer backbone hidden state of the conditioning modality and
        # zero-fill the other slot.
        # NOTE(review): the padding has the shape of `h`, so this assumes
        # both backbones produce hidden states of equal width.
        h = self.encoder[f'backbone_{c}'](x[modalities.index(c)])
        h_zeros = torch.zeros_like(h)
        parts = [h, h_zeros] if c == 'x1' else [h_zeros, h]
        h = torch.cat(parts, dim=-1)

        # Generate target modality
        _, g = self.encoder['encoder_g'](h)
        k = modalities.index(t) + 1
        _, z = self.decoder[f'g_to_z{k}'](g)
        rec, _ = self.decoder[f'z{k}_to_x{k}'](z)
        output[t][c].update(rec)

        return output
