import copy
import logging

import torch
import torch.nn as nn

import vae.misc as misc
from vae.layers import BottomUp, TopDown
from vae.layers.misc import BaseModule
from vae.layers.stochastic import ScaleRegularizer

# Module-level logger. It uses the shared name 'custom' instead of __name__
# (presumably configured by the project's logging setup — TODO confirm).
logger = logging.getLogger(name='custom')


class UniModalVae(BaseModule):
    """ Variational Autoencoder that consists of a bottom-up and a top-down
    pass. Represents one modality.

    The unconditional top-level prior is not built here: subclasses are
    expected to assign ``pg`` (a distribution constructor called as
    ``pg(mean, scale)``) and ``_pg_params`` (its ``(mean, scale)``).

    Semantics:
        - stage: connects two hierarchical levels
    """

    def __init__(self,
                 shapes: dict,
                 deterministic_layer: dict,
                 stochastic_layer: dict,
                 reconstruction_layer: dict,
                 generic_layer: dict,
                 modality: str = 'x1',
                 **kwargs):
        """
        Find detailed hyperparameter descriptions in hyperparams/README.md

        NOTE: ``reconstruction_layer['dim']`` and
        ``deterministic_layer['bu_hidden_shapes']`` are written into the
        caller's dicts (no copy is made) before they are handed to the
        layer constructors.

        :param shapes:
            - input: shape of input tensor
        :param deterministic_layer: describes layers before stochastic or
        reconstruction layers
        :param stochastic_layer: parameterizes latent distribution
        :param reconstruction_layer: parameterizes reconstruction distribution
        :param generic_layer: describes single layer
        :param modality: name of modality that the VAE represents
        :param kwargs: accepted and ignored
        """
        super().__init__()
        logger.info(
            f'\n====> Building Unimodal VAE For Modality {modality}:\n')
        self.modality = modality
        self.modalities = [modality]  # for compatibility with multimodal models
        self.stoc_dim = stochastic_layer['dim']
        self.num_levels = len(self.stoc_dim)

        # Prior of the top-level latent; defined in subclasses.
        self.pg = None
        self.prior = None
        self._pg_params = None
        # A learned prior gets its scale regularized on every access of
        # ``pg_params``.
        if stochastic_layer['learn_prior']:
            self.prior_scale_regularizer = ScaleRegularizer(reg='sum_d')
        else:
            self.prior_scale_regularizer = None

        self._check_input_consistency(deterministic_layer, stochastic_layer)

        input_shape = shapes['input']
        if len(input_shape) == 1:
            # vector input: reconstruction size equals the input length
            reconstruction_layer['dim'] = input_shape[-1]
        else:
            # output channel dimensions are equivalent to input channel
            # dimensions (assumes a (..., C, H, W) layout — TODO confirm)
            reconstruction_layer['dim'] = input_shape[-3]
        self._input_shape = input_shape

        # build layers; the top-down pass needs the bottom-up hidden shapes
        self.bottom_up = BottomUp(input_shape,
                                  deterministic_layer,
                                  stochastic_layer,
                                  generic_layer)
        deterministic_layer['bu_hidden_shapes'] = self.bottom_up.hidden_shapes
        self.top_down = TopDown(input_shape=self.bottom_up.output_shape,
                                deterministic_layer=deterministic_layer,
                                stochastic_layer=stochastic_layer,
                                reconstruction_layer=reconstruction_layer,
                                generic_layer=generic_layer)

        logger.debug(
            f'\nFinished building unimodal VAE for modality {modality}. '
            f'Trainable params: {misc.get_trainable_params(self):,}')

    def _check_input_consistency(self,
                                 deterministic_layer: dict,
                                 stochastic_layer: dict):
        """ Checks that all per-level hyperparameter lists have the same
        length as ``stochastic_layer['dim']``.

        :raises AssertionError: on a length mismatch
        """
        msg = 'All hyperparameters must describe the same number ' \
              'of hierarchical levels.'
        assert self.num_levels == len(deterministic_layer['specs_bu']), msg
        assert self.num_levels == len(deterministic_layer['specs_td']), msg
        assert self.num_levels == len(stochastic_layer['specs']), msg

    @property
    def pg_params(self):
        """ ``(mean, scale)`` of the unconditional top-level prior.

        The scale is passed through ``prior_scale_regularizer`` when the
        prior is learned; otherwise the stored value is returned as-is.
        """
        if self.prior_scale_regularizer:
            var = self.prior_scale_regularizer(self._pg_params[1])
        else:
            var = self._pg_params[1]
        return self._pg_params[0], var

    def forward(self, *args, **kwargs):
        return self._forward(*args, **kwargs)

    def _forward(self, *args, cond_gen=False, **kwargs):
        """
        :param cond_gen: reconstruct data by passing top-level sample
        through generative network
        :return: tuple ``(output, bu_tensors)``. ``output`` is a dict with:
            - posterior: list of per-level posteriors, top level last
            - prior: list of per-level priors, top level last (``None``,
              the top-level prior is unconditional)
            - reconstruction: reconstruction distribution/samples
            - ancestral_samples (only if ``cond_gen``): list
              ``[reconstruction, *priors, None]``
            ``bu_tensors`` is ``None`` unless ``return_bu_tensors=True``.
        """
        posterior, prior, reconstruction, bu_tensors = self._normal_pass(
            *args, **kwargs)

        output = {'posterior': posterior,
                  'prior': prior,
                  'reconstruction': reconstruction}

        if cond_gen:
            ancestral_samples = self.generate(posterior[-1]['samples'])
            # top-level posterior already defined elsewhere
            output['ancestral_samples'] = ancestral_samples + [None]

        return output, bu_tensors

    def _normal_pass(self,
                     x, k,
                     inference_mode: bool = True,
                     return_bu_tensors: bool = False,
                     **kwargs):
        """ Conventional pass through the model.

        :param x: data
        :param k: number of importance samples from top-level latent variable
        :param inference_mode: forwarded to the top-down pass
        :param return_bu_tensors: Whether to return hidden states from
        bottom-up pass. This is useful when using top-down pass from different
        modality
        :param kwargs: forwarded to the top-down pass
        :return: posterior, priors and reconstruction distributions, plus the
        bottom-up tensors (or ``None``)
        """
        bu_tensors, posterior, prior = self.bottom_up_wrapper(x, k)

        cur_post, cur_prior, reconstruction = self.top_down(
            posterior[-1]['samples'],
            bu_tensors=bu_tensors,
            inference_mode=inference_mode,
            **kwargs)
        # lower levels first, top level last
        posterior = cur_post + posterior
        prior = cur_prior + prior

        if not return_bu_tensors:
            bu_tensors = None

        return posterior, prior, reconstruction, bu_tensors

    def bottom_up_wrapper(self, x, k):
        """
        :param x: data
        :param k: number of importance samples from top-level latent variable
        :return: ``(bu_tensors, [top_level_posterior], [None])``
        """
        bu_tensors, cur_post = self.bottom_up(x, k=k)
        posterior = [cur_post]
        prior = [None]  # top-level prior is unconditional

        return bu_tensors, posterior, prior

    def generate(self, samples: torch.Tensor, **kwargs):
        """ Passes top-level sample solely through generative network.

        This is computationally slightly inefficient, because first
        conditional prior is computed twice (instead of amortized). However,
        this practice simplifies the code:
            1. Only top-level stochastic block has to be cut open
            2. All necessary data is at one place (instead of distributed
            across dictionaries)

        :param samples: for example from conditional prior below top-level
        to amortize computation
        :param kwargs: forwarded to the top-down pass
        :return ancestral_samples: List[distribution, samples]
            - list ranges from reconstruction to sample_level-1
        """
        _, cur_prior, reconstruction = self.top_down(samples,
                                                     inference_mode=False,
                                                     **kwargs)

        return [reconstruction] + cur_prior

    @torch.no_grad()
    def sample_from_prior(self, k=100, temp=1.):
        """ Sample from unconditional prior.

        :param k: number of samples drawn
        :param temp: temperature; the prior scale is multiplied by it
        :return: dict with the prior distribution (``'dist'``) and the drawn
            samples (``'samples'``)
        """
        mean, scale = self.pg_params
        if temp != 1.:
            # out-of-place, so the stored prior parameters are never touched
            # (no deepcopy needed)
            scale = scale * temp
        pg = self.pg(mean, scale)
        g = pg.rsample((k,)).squeeze(1)

        return {'dist': pg, 'samples': g}

    def ancestral_sampling_from_prior(self, **kwargs):
        """
        Sample from unconditional prior and pass samples through generative network.

        Switches the module to eval mode (and leaves it there). ``kwargs``
        (e.g. ``k``, ``temp``) are forwarded to both ``sample_from_prior``
        and ``generate``.

        :return: values along the way, ending with the unconditional prior
        """
        self.eval()
        unconditional_prior = self.sample_from_prior(**kwargs)
        ancestral_samples = self.generate(unconditional_prior['samples'], **kwargs)
        return ancestral_samples + [unconditional_prior]


class MultiModalVae(BaseModule):
    """ Variational Autoencoder that maximizes likelihood of multiple
    modalities x_{1:M}.

    Abstract base: ``forward``, ``_unimodal_passes``,
    ``_crossmodal_generation`` and ``ancestral_sampling_from_prior`` raise
    ``NotImplementedError``. Subclasses presumably fill ``vaes`` and set
    ``pg`` / ``_pg_params`` (the unconditional top-level prior) — TODO
    confirm.
    """

    def __init__(self, args):
        """
        :param args: config with ``learn_prior``, ``n_modalities`` and
            ``stoc_dim`` (per-modality latent sizes keyed ``'x1'``,
            ``'x2'``, ...; the last entry is the top level)
        """
        super().__init__()
        self.vaes = nn.ModuleList()
        self.pg = None
        self._pg_params = None
        # A learned prior gets its scale regularized on every access of
        # ``pg_params``.
        if args.learn_prior:
            self.prior_scale_regularizer = ScaleRegularizer(reg='sum_d')
        else:
            self.prior_scale_regularizer = None
        self.modalities = [f'x{i + 1}' for i in range(args.n_modalities)]
        self.n_modalities = len(self.modalities)
        # top-level latent space sizes are identical across all modalities
        self.stoc_dim = args.stoc_dim['x1'][-1]

        self._check_input_consistency(args)

    def _check_input_consistency(self, args):
        """ Validates the modality configuration.

        :raises NotImplementedError: if more than two modalities are requested
        :raises AssertionError: if the top-level latent sizes differ between
            the configured modalities
        """
        if len(self.modalities) > 2:
            raise NotImplementedError(
                'Currently, only one or two modalities are supported')

        msg = 'Top-level latent space must have identical size across ' \
              'unimodal VAEs'
        # Only configured modalities are compared, so a single-modality
        # setup does not require an 'x2' entry in ``stoc_dim``.
        top_dim = args.stoc_dim['x1'][-1]
        for modality in self.modalities:
            assert args.stoc_dim[modality][-1] == top_dim, msg

    @property
    def pg_params(self):
        """ ``(mean, scale)`` of the unconditional top-level prior.

        The scale is passed through ``prior_scale_regularizer`` when the
        prior is learned; otherwise the stored value is returned as-is.
        """
        if self.prior_scale_regularizer:
            var = self.prior_scale_regularizer(self._pg_params[1])
        else:
            var = self._pg_params[1]
        return self._pg_params[0], var

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def _unimodal_passes(self, *args, **kwargs):
        """ Passes for unimodal VAEs. """
        raise NotImplementedError

    def _crossmodal_generation(self, *args, **kwargs):
        """ Pass samples from x_c to generative net of x_t."""
        raise NotImplementedError

    @torch.no_grad()
    def sample_from_prior(self, k=1, temp=1.):
        """ Sample from unconditional prior.

        :param k: number of samples drawn
        :param temp: temperature; the prior scale is multiplied by it
        :return: dict with the prior distribution (``'dist'``) and the drawn
            samples (``'samples'``)
        """
        mean, scale = self.pg_params
        if temp != 1.:
            # out-of-place, so the stored prior parameters are never touched
            # (no deepcopy needed)
            scale = scale * temp
        pg = self.pg(mean, scale)
        g = pg.rsample((k,)).squeeze(1)

        return {'dist': pg, 'samples': g}

    def ancestral_sampling_from_prior(self, *args, **kwargs):
        """
        Sample from unconditional prior and pass samples through generative network.
        :return: values along the way
        """
        raise NotImplementedError
