import logging

import torch
from torch import nn

import utils
from methods.misc import parse_unimodal_vae_params
from vae import MultiModalVae, UniModalVae
from vae.misc import get_trainable_params

# Module logger; all messages in this file go to the logger named 'custom'.
logger = logging.getLogger('custom')
# Short alias for utils.log_mean_exp.
lme = utils.log_mean_exp


class MultiModalMethod(MultiModalVae):
    """Multimodal model built from one ``UniModalVae`` per modality.

    The constructor creates one uni-modal VAE for every entry of
    ``self.modalities`` and a pair of prior parameters (loc and scale) of
    shape ``(1, self.stoc_dim)``. The remaining methods combine the uni-modal
    VAEs: each modality is encoded by its own VAE, and the top-level posterior
    sample of one modality can be decoded by the VAE of another modality.

    NOTE(review): ``self.modalities``, ``self.vaes``, ``self.stoc_dim``,
    ``self.n_modalities`` and ``self.sample_from_prior`` are not defined in
    this class; they are presumably provided by ``MultiModalVae``.
    """

    def __init__(self, args):
        """Build the per-modality VAEs and the prior parameters.

        :param args: configuration object; this constructor reads
            ``args.stoc_dist`` and ``args.learn_prior`` and forwards ``args``
            to the parent constructor and to ``parse_unimodal_vae_params``.
        """
        super().__init__(args)
        # one uni-modal VAE per modality, stored in modality order
        for m in self.modalities:
            self.vaes.append(
                UniModalVae(**parse_unimodal_vae_params(args, m)))

        # define prior
        self.pg = utils.get_dist(args.stoc_dist)
        # scale activation:
        #   if requires_grad=True: scale is activated by softplus
        #   if requires_grad=False: scale is fixed
        # (the activation itself is applied outside of this constructor)
        if args.learn_prior:
            scale_init = torch.zeros(1, self.stoc_dim)
        else:
            scale_init = torch.ones(1, self.stoc_dim)
        self._pg_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, self.stoc_dim),
                         requires_grad=False),  # loc
            nn.Parameter(scale_init, requires_grad=args.learn_prior)  # scale
        ])

        self.print_number_of_parameters()

    def print_number_of_parameters(self):
        """Log the trainable-parameter count (in millions) of every uni-modal
        VAE and of the whole model.

        VAEs are labelled ``x1``, ``x2``, ... in the order they are stored in
        ``self.vaes``, so any number of modalities is supported.
        """
        params = {f'VAE (x{i})': get_trainable_params(vae)
                  for i, vae in enumerate(self.vaes, start=1)}
        params['Total'] = get_trainable_params(self)

        logger.info('\nParameter overview:')
        for k, v in params.items():
            logger.info(f'- {k}: {v / 10 ** 6:.1f}M')

    @torch.no_grad()
    def ancestral_sampling_from_prior(self, k, **kwargs):
        """
        Sample from unconditional prior and pass samples through the
        generative network of every modality.

        Note: this switches the module to eval mode via ``self.eval()`` and
        does not switch it back; callers that continue training must call
        ``train()`` themselves.

        :param k: forwarded to ``self.sample_from_prior``
            (presumably the number of samples -- confirm against the parent)
        :param kwargs: forwarded to both ``self.sample_from_prior`` and
            ``vae.generate``
        :return: nested dict ``ancestral_samples[modality]['g']`` holding the
            values returned by ``vae.generate`` with a trailing ``None``
        """
        self.eval()
        unconditional_prior = self.sample_from_prior(k, **kwargs)
        g = unconditional_prior['samples']

        # sample in every modality-direction
        ancestral_samples = utils.rec_defaultdict()
        for m, vae in zip(self.modalities, self.vaes):
            cur_as = vae.generate(g, **kwargs)
            # unconditional prior implicitly defined elsewhere, hence the
            # trailing None placeholder
            ancestral_samples[m]['g'] = cur_as + [None]

        return ancestral_samples

    def _unimodal_passes(self, x, **kwargs):
        """Run every modality through its own VAE.

        :param x: indexable collection; ``x[i]`` is the input passed to the
            i-th VAE
        :param kwargs: forwarded to every VAE call
        :return: tuple ``(output, bu_tensors)`` where
            ``output[key][m][m]`` stores each entry of the VAE's output dict
            for modality ``m`` and ``bu_tensors`` is a list with the second
            return value of every VAE call, in modality order
        """
        output = utils.rec_defaultdict()
        bu_tensors = []

        for i, (m, vae) in enumerate(zip(self.modalities, self.vaes)):
            cur_output, cur_bu_tensors = vae(
                x[i], return_bu_tensors=True, **kwargs)
            # uni-modal results live on the "diagonal": target == condition
            for k, v in cur_output.items():
                output[k][m][m] = v
            bu_tensors.append(cur_bu_tensors)

        return output, bu_tensors

    def _crossmodal_passes(self, output, **kwargs):
        """Call ``_crossmodal_pass`` for every ordered pair of distinct
        modalities (conditioning modality in the outer loop).

        :param output: nested dict as produced by ``_unimodal_passes``
        :param kwargs: forwarded to ``_crossmodal_pass``
            (it expects ``bu_tensors``)
        :return: ``output`` extended with the cross-modal entries
        """
        for c_i in range(self.n_modalities):
            for t_i in range(self.n_modalities):
                if c_i == t_i:
                    continue
                output = self._crossmodal_pass(
                    output=output,
                    modalities={'c': self.modalities[c_i],
                                't': self.modalities[t_i]},
                    **kwargs)
        return output

    def _crossmodal_pass(self, output, modalities, bu_tensors):
        """ Pass sample from top-level posterior of one modality through generative
        network of other modality.

        :param output: nested dict; ``output['posterior'][c][c][-1]['samples']``
            must already be present
        :param modalities
            keys from 'c' for conditioning and 't' for target
            values from ['x1', 'x2']
        :param bu_tensors: list indexed like ``self.vaes``; the entry of the
            target modality is handed to ``vae.top_down``
        :return: ``output`` with ``'posterior'``, ``'prior'`` and
            ``'reconstruction'`` filled in at ``[t][c]``
        """
        c = modalities['c']
        t = modalities['t']
        t_i = self.modalities.index(t)

        # Conditioning samples: last entry of the conditioning modality's
        # own posterior list
        samples = output['posterior'][c][c][-1]['samples']

        # Pass through generative network of other modality
        vae = self.vaes[t_i]
        posterior, prior, reconstruction = vae.top_down(
            samples,
            bu_tensors=bu_tensors[t_i])

        # Top-level latent space already defined, hence the trailing None
        output['posterior'][t][c] = posterior + [None]
        output['prior'][t][c] = prior + [None]
        output['reconstruction'][t][c] = reconstruction

        return output

    def _crossmodal_generations(self, output):
        """Call ``_crossmodal_generation`` for every ordered pair of distinct
        modalities (conditioning modality in the outer loop).

        :param output: nested dict as produced by ``_unimodal_passes``
        :return: ``output`` extended with ``'ancestral_samples'`` entries
        """
        for i in range(self.n_modalities):
            for j in range(self.n_modalities):
                if i == j:
                    continue
                output = self._crossmodal_generation(
                    output, modalities={'c': self.modalities[i],
                                        't': self.modalities[j]})
        return output

    def _crossmodal_generation(self, output, modalities):
        """ Pass samples from x_c to generative net of x_t.

        :param output: nested dict; ``output['posterior'][c][c][-1]['samples']``
            must already be present
        :param modalities: dict with keys 'c' (conditioning modality) and
            't' (target modality)
        :return: ``output`` with ``output['ancestral_samples'][t][c]`` set to
            the values returned by ``vae.generate`` plus a trailing ``None``
        """
        c = modalities['c']
        t = modalities['t']

        # conditioning samples
        samples = output['posterior'][c][c][-1]['samples']

        # pass through generative network of other modality
        idx = self.modalities.index(t)
        ancestral_samples = self.vaes[idx].generate(samples)
        # top level already defined elsewhere, hence the trailing None
        output['ancestral_samples'][t][c] = ancestral_samples + [None]

        return output
