"""
Builds the bottom-up and top-down passes for variational autoencoders.
"""
import copy
import logging
from typing import Tuple, List, Optional

import numpy as np
import torch.nn as nn
from torch import Tensor

from vae.layers.dense import DenseBlock
from .convolution import ConvMergeLayer, ConvBlock
from .dense import DenseMergeLayer
from .load import get_layer
from .misc import BaseModule, reshape
from .stochastic import get_stochastic_layer, StocDense

# Module-level logger shared by all builders in this file; uses the logger
# registered under the name 'custom'.
logger = logging.getLogger('custom')


class BottomUp(BaseModule):
    """ Deterministic bottom-up pass topped by the top-level posterior.

    One deterministic ``Block`` is built per hierarchy level; the output of
    the last block is fed into a single stochastic layer.
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 deterministic_layer: dict,
                 stochastic_layer: dict,
                 generic_layer):
        """
        :param input_shape: shape of the data fed into the first block
        :param deterministic_layer: config holding per-level block specs
        under the key 'specs_bu'
        :param stochastic_layer: config holding per-level lists under the
        keys 'dim' and 'specs'
        :param generic_layer: settings forwarded unchanged to every block
        """
        super().__init__()
        # the number of latent dimensions defines the depth of the hierarchy
        self.num_levels = len(stochastic_layer['dim'])
        self._input_shape = input_shape
        self._hidden_shapes = []

        logger.info(f'\n{"-" * 100}\nBuilding Bottom-Up Stages:')
        self.deterministic_layers = nn.ModuleList()
        self.stochastic_layer = None
        self._build_layers(deterministic_layer, stochastic_layer, generic_layer)

        assert isinstance(self.stochastic_layer, BaseModule)
        self._output_shape = self.stochastic_layer.output_shape

    def _build_layers(self,
                      deterministic_layer,
                      stochastic_layer,
                      generic_layer):
        """ Build all deterministic blocks, then the stochastic layer on top. """
        top_shape = self._build_deterministic_layer(
            deterministic_layer, generic_layer)
        self._build_stochastic_layer(top_shape, stochastic_layer)

    def _build_deterministic_layer(self, deterministic_layer, generic_layer):
        """ Chain one block per level and record each block's output shape.

        :return: output shape of the last block
        """
        shape = self._input_shape
        for level in range(self.num_levels):
            block = Block(input_shape=shape,
                          specs=deterministic_layer['specs_bu'][level],
                          generic_layer=generic_layer)
            shape = block.output_shape
            self._hidden_shapes.append(shape)
            self.deterministic_layers.append(block)
            logger.debug(
                f'\n====> Deterministic Bottom-Up Layer '
                f'(Level {level + 1}/{self.num_levels}):'
                f'\n{block}')
        # replace the temporary list container by a sequential container
        self.deterministic_layers = nn.Sequential(*self.deterministic_layers)
        return shape

    def _build_stochastic_layer(self, input_shape, stochastic_layer):
        """ Build the stochastic layer from the last 'dim'/'specs' entries. """
        # work on a copy so that the caller's config keeps its per-level lists
        top_cfg = copy.deepcopy(stochastic_layer)
        top_cfg['dim'] = top_cfg['dim'][-1]
        top_cfg['specs'] = top_cfg['specs'][-1]
        self.stochastic_layer = get_stochastic_layer(
            input_shape=input_shape,
            stoc_type=top_cfg['specs']['t'],
            stochastic_layer=top_cfg)
        logger.debug(f'\n====> Stochastic Top-Level Layer:'
                     f'\n{self.stochastic_layer}')

    def forward(self, x: Tensor, k=None, **kwargs):
        """
        :param x: input data
        :param k: forwarded to the stochastic layer; 1 is used if None
        :return: hidden states for levels [1..L], posterior
        """
        last = len(self.deterministic_layers) - 1
        hidden_states = []
        for level, block in enumerate(self.deterministic_layers):
            x = block(x)
            # the output of the last block only feeds the stochastic layer
            # below, so its slot in the list is filled with None
            hidden_states.append(x if level < last else None)

        # top-level variable should always have importance sampling dimension
        posterior = self.stochastic_layer(
            x, k=1 if k is None else k, **kwargs)
        return hidden_states, posterior

    @property
    def hidden_shapes(self) -> List[int]:
        """ Output shapes of the deterministic blocks, one entry per level. """
        return self._hidden_shapes


class TopDown(BaseModule):
    """ Top-down pass from top-level variable to data space.

    ``self.stages`` is ordered bottom to top: index 0 holds the
    ``LowerStage`` that produces the reconstruction, every higher index
    holds an ``IntermediateStage``.
    """

    def __init__(self, stochastic_layer, **kwargs):
        """
        :param stochastic_layer: config holding per-level lists under the
        keys 'dim' and 'specs' (and optionally 'upsampling')
        :param kwargs: remaining arguments of ``_build_layers``
        """
        super().__init__()
        self.num_levels = len(stochastic_layer['dim'])

        logger.info(f'\n{"-" * 100}\nBuilding Top-Down Stages:')
        self.stages = nn.ModuleList()
        self._build_layers(stochastic_layer=stochastic_layer, **kwargs)

        # information coming from above
        self._input_shape = self.stages[-1].input_shape
        # information leaving below
        self._output_shape = self.stages[0].output_shape

    def _build_layers(self,
                      input_shape: Tuple[int],
                      deterministic_layer: dict,
                      stochastic_layer: dict,
                      reconstruction_layer: dict,
                      generic_layer: dict):
        """ Build the stages from the top level down to the data level.

        :param input_shape: shape of the samples entering the topmost stage
        :param deterministic_layer: config with per-level specs in 'specs_td'
        :param stochastic_layer: config with per-level 'dim' and 'specs'
        :param reconstruction_layer: config for the reconstruction layer
        :param generic_layer: settings forwarded unchanged to every stage
        """
        cur_input_shape = input_shape
        for idx_b in reversed(range(self.num_levels)):
            # per-stage copies: each stage receives only its own level's
            # entries and cannot alter the caller's configuration
            sl = copy.deepcopy(stochastic_layer)
            dl = copy.deepcopy(deterministic_layer)
            dl['specs_td'] = dl['specs_td'][idx_b]
            # 'upsampling' is optional; the stages read it with .get() too
            if up := sl.get('upsampling'):
                sl['upsampling'] = up[idx_b]  # upsample latent sample

            if idx_b == 0:
                logger.debug(
                    f'\n====> Top-Down Stage (Level 1/{self.num_levels}):')
                stage = LowerStage(
                    input_shape=cur_input_shape,
                    deterministic_layer=dl,
                    stochastic_layer=sl,
                    reconstruction_layer=reconstruction_layer,
                    generic_layer=generic_layer)
            else:
                logger.debug(
                    f'\n====> Top-Down Stage '
                    f'(Level {idx_b + 1}/{self.num_levels}):')
                # stage idx_b parameterizes the latent variable one level
                # below it
                sl['dim'] = sl['dim'][idx_b - 1]
                sl['specs'] = sl['specs'][idx_b - 1]
                stage = IntermediateStage(
                    input_shape=cur_input_shape,
                    deterministic_layer=dl,
                    stochastic_layer=sl,
                    generic_layer=generic_layer)
                cur_input_shape = stage.output_shape
            # stages are created top-down but stored bottom-up
            self.stages.insert(0, stage)

    def forward(self,
                samples: Tensor,
                bu_tensors: Optional[List[Tensor]] = None,
                inference_mode: bool = True,
                mode_layer: Optional[int] = None,
                **kwargs):
        """
        :param samples: samples from upper-level latent distribution
        :param bu_tensors: deterministic states from bottom-up pass
        :param inference_mode: do inference and generation, else only do
        generation
        :param mode_layer: Sample from mode from this level onwards.
        For example:
            mode_layer=1: Sample mode for latent space 1
            mode_layer=2: Sample mode for latent spaces 1 and 2
        :param kwargs: forwarded to every intermediate stage
        :return: posteriors, priors, reconstruction. Both lists are ordered
        bottom to top and hold one entry per intermediate stage.
        """
        self._check_input_consistency(bu_tensors, inference_mode)

        # iterate through top-down blocks
        posteriors, priors, reconstruction = [], [], None
        for idx_b in reversed(range(self.num_levels)):
            is_bottom, use_mode = self._prepare_stage_construction(
                idx_b, mode_layer, inference_mode)

            if is_bottom:
                reconstruction = self.stages[idx_b](samples)
                continue

            # inject information from bottom-up stage that maps onto same
            # latent space (if such information is present)
            bottom_up = None if bu_tensors is None else bu_tensors[idx_b - 1]
            samples, cur_post, cur_prior = self.stages[idx_b](
                samples, bottom_up, inference_mode, use_mode=use_mode,
                **kwargs)
            posteriors.insert(0, cur_post)
            priors.insert(0, cur_prior)

        self._check_output_consistency(posteriors, inference_mode)
        return posteriors, priors, reconstruction

    @staticmethod
    def _prepare_stage_construction(idx_b, mode_layer, inference_mode):
        """ Derive the per-stage flags for one step of the top-down pass.

        :return: (is_bottom, use_mode); use_mode is only ever True for
        non-bottom stages with ``idx_b <= mode_layer``
        """
        is_bottom = idx_b == 0

        use_mode = False
        if mode_layer:
            msg = 'Do not sample from mode during inference mode'
            assert inference_mode is False, msg
            use_mode = bool(idx_b <= mode_layer and not is_bottom)

        return is_bottom, use_mode

    @staticmethod
    def _check_input_consistency(bu_tensors, inference_mode):
        """ For total top-down pass. """
        if inference_mode:
            assert bu_tensors is not None, 'posterior can only be inferred with ' \
                                           'bottom-up info'
        else:
            assert bu_tensors is None, 'generative model cannot have access to ' \
                                       'bottom-up info'

    @staticmethod
    def _check_output_consistency(posterior, inference_mode):
        """ For total top-down pass. """
        if not inference_mode:
            for v in posterior:
                assert v is None, 'posterior can only exist in inference mode.'


class IntermediateStage(BaseModule):
    """ Maps between two latent spaces in top-down pass.

    Consists of an optional upsampling layer, a shared deterministic block,
    a generative branch (deterministic layer + stochastic layer) that
    parameterizes the prior, and an inference branch (merge layer +
    stochastic layer) that parameterizes the posterior.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs: arguments of ``_build_layers``
        """
        super().__init__()
        # placeholders; the modules themselves are created in _build_layers
        self.upsample = None
        self.upsampling_specs = {}
        self.deterministic_layer = None
        self.det_gen = None
        self.stochastic_gen = None
        self.merger = None
        self.stochastic_inf = None

        self._build_layers(**kwargs)

        assert isinstance(self.deterministic_layer, BaseModule)
        assert isinstance(self.stochastic_inf, BaseModule)
        self._input_shape = self.deterministic_layer.input_shape
        self._output_shape = self.stochastic_inf.output_shape

    def _build_layers(self,
                      input_shape,
                      deterministic_layer,
                      stochastic_layer,
                      generic_layer):
        """ Build upsampling, shared backbone, posterior and prior branch.

        The posterior branch must be built before the prior branch: for
        convolutional stochastic layers it writes k/s/p into
        ``stochastic_layer['specs']`` and the prior branch is handed the
        same dict afterwards.
        """
        output_shape = self._build_upsample_layer(
            input_shape, deterministic_layer, stochastic_layer, generic_layer)
        self._build_shared_backbone(output_shape, deterministic_layer, generic_layer)
        self._build_posterior_inference(stochastic_layer, generic_layer)
        self._build_prior_generation(stochastic_layer, generic_layer)

    def _build_upsample_layer(self, input_shape,
                              deterministic_layer,
                              stochastic_layer,
                              generic_layer):
        """ Build the optional layer that upsamples the incoming sample.

        :return: shape of the tensor entering the deterministic block
        :raises ValueError: if the upsampling type is neither 'dense' nor a
        type containing 'conv'
        """
        up_specs = stochastic_layer.get('upsampling')
        if not up_specs:
            self.upsample = None
            return input_shape

        if up_specs['t'] == 'dense':
            # transform to flattened previous shape
            self.upsampling_specs['type'] = 'dense'
            self.upsampling_specs['shape'] = tuple(up_specs['reshape'])
            self.upsample = DenseBlock(input_shape,
                                       specs={'out': np.prod(self.upsampling_specs['shape'])},
                                       **generic_layer)
            output_shape = self.upsampling_specs['shape']
        elif 'conv' in up_specs['t']:
            self.upsampling_specs['type'] = 'conv'
            specs = dict(k=1, s=1, p=0)
            # Find upcoming channel dimension
            # (note that spatial upsampling layers do not carry this information)
            # The deterministic block applies 'specs_td' in reversed order,
            # so the search starts at the end of the list and stops at the
            # first layer that defines a channel size (same as LowerStage).
            for layer_specs in reversed(deterministic_layer['specs_td']):
                if c := layer_specs.get('c'):
                    specs['c'] = c
                    break
            self.upsample = ConvBlock(
                input_shape=input_shape,
                specs=specs,
                **generic_layer)
            output_shape = self.upsample.output_shape
        else:
            raise ValueError(f"Unknown upsampling type: {up_specs['t']!r}")
        logger.debug(f'\nUpsampling Layer:\n{self.upsample}')
        return output_shape

    def _build_shared_backbone(self,
                               input_shape,
                               deterministic_layer,
                               generic_layer):
        """ Build the deterministic block shared by prior and posterior. """
        # the top-down specs are applied in reversed order
        self.deterministic_layer = Block(input_shape=input_shape,
                                         specs=deterministic_layer['specs_td'][::-1],
                                         generic_layer=generic_layer)
        logger.debug(f'\nDeterministic layer:\n{self.deterministic_layer}')

    def _build_posterior_inference(self, stochastic_layer, generic_layer):
        """ Build merge layer and stochastic layer of the inference branch.

        :raises Exception: if no merge layer is configured
        :raises ValueError: if the merge layer type or the stochastic layer
        type is neither 'conv' nor 'dense'
        """
        # posterior: merge bottom-up and top-down info
        merge_specs = stochastic_layer['merge_layer']
        if not merge_specs:
            raise Exception('Provide merge layer for hierarchical methods.')
        if merge_specs['t'] == 'conv':
            c, d1, d2 = self.deterministic_layer.output_shape
            # the merge layer is built for twice the channel count of the
            # deterministic block
            self.merger = ConvMergeLayer(
                input_shape=(c * 2, d1, d2),
                **generic_layer)
        elif merge_specs['t'] == 'dense':
            self.merger = DenseMergeLayer(
                input_shape=self.deterministic_layer.output_shape,
                **generic_layer)
        else:
            raise ValueError(f"Unknown merge layer type: {merge_specs['t']!r}")
        logger.debug(f'\nMerge Layer (Inference Network):'
                     f'\n{self.merger}')

        # posterior: stochastic layer
        stoc_type = stochastic_layer['specs']['t']
        if stoc_type == 'conv':
            # in-place on purpose: _build_prior_generation is called
            # afterwards with the same dict and therefore sees these values
            stochastic_layer['specs'].update(dict(k=1, s=1, p=0))
            self.stochastic_inf = get_stochastic_layer(
                input_shape=self.merger.output_shape,
                stoc_type='conv_spatial',
                stochastic_layer=stochastic_layer)
        elif stoc_type == 'dense':
            self.stochastic_inf = StocDense(
                input_shape=self.merger.output_shape,
                stochastic_layer=stochastic_layer)
        else:
            raise ValueError(f"Unknown stochastic layer type: {stoc_type!r}")
        logger.debug(f'\nStochastic Layer (Inference Model):'
                     f'\n{self.stochastic_inf}')

    def _build_prior_generation(self, stochastic_layer, generic_layer):
        """ Build deterministic and stochastic layer of the generative branch.

        :raises ValueError: if the stochastic layer type is neither 'conv'
        nor 'dense'
        """
        # --- Mimic merge-layer on the inference side ---

        # Inside this class, the following holds for the deterministic parts:
        #   - Convolutional blocks have same output channel size
        #   - Dense blocks have same output dimension
        # We extract this information from the layer above:
        basedim = self.deterministic_layer.output_shape[0]
        stoc_type = stochastic_layer['specs']['t']
        if stoc_type == 'dense':
            specs = dict(t='dense', out=basedim)
        elif stoc_type == 'conv':
            specs = dict(t='dconv', c=basedim)
        else:
            raise ValueError(f"Unknown stochastic layer type: {stoc_type!r}")

        self.det_gen = Block(
            input_shape=self.deterministic_layer.output_shape,
            specs=[specs],
            generic_layer=generic_layer)
        logger.debug(f'\nDeterministic Layer (Generative Model):'
                     f'\n{self.det_gen}')

        # --- Stochastic layer for parameterizing prior ---
        # NOTE(review): built for the output shape of the shared block
        # although forward() feeds it the output of det_gen; this relies on
        # det_gen preserving that shape - TODO confirm
        if stoc_type == 'conv':
            self.stochastic_gen = get_stochastic_layer(
                input_shape=self.deterministic_layer.output_shape,
                stoc_type='conv_spatial',
                stochastic_layer=stochastic_layer)
        else:
            self.stochastic_gen = StocDense(
                input_shape=self.deterministic_layer.output_shape,
                stochastic_layer=stochastic_layer)
        logger.debug(f'\nStochastic Layer (Generative Model):'
                     f'\n{self.stochastic_gen}')

    def forward(self,
                x: Tensor,
                bu_tensor: Optional[Tensor] = None,
                inference_mode: bool = False,
                **kwargs):
        """ Infer lower latent distribution from upper latent distribution.
        :param x: samples from upper latent distribution
        :param bu_tensor: information from bottom-up pass
        :param inference_mode: if True, also compute the posterior and
        return its samples instead of the prior's samples
        :param kwargs: forwarded to the stochastic layer of the prior
        :return: samples, posterior (None outside inference mode), prior
        """
        x = upsampling_wrapper(x, self.upsample, self.upsampling_specs)
        x = self.deterministic_layer(x)

        # the prior is computed in both modes
        h_gen = self.det_gen(x)
        prior = self.stochastic_gen(h_gen, **kwargs)
        samples = prior['samples']
        posterior = None

        if inference_mode:
            x = self.merger(x, bu_tensor)
            posterior = self.stochastic_inf(x)
            samples = posterior['samples']
        else:
            assert bu_tensor is None

        return samples, posterior, prior


class LowerStage(BaseModule):
    """ Maps from latent to data space in top-down pass.

    Consists of an optional upsampling layer, a deterministic block and a
    reconstruction layer.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs: arguments of ``_build_layers``
        """
        super().__init__()
        # placeholders; the modules themselves are created in _build_layers
        self.deterministic_layer = None
        self.reconstruction = None
        self.stochastic_gen = None
        self.merger = None
        self.stochastic_inf = None
        self.upsample = None
        self.upsampling_specs = {}

        self._build_layers(**kwargs)

        self._input_shape = self.deterministic_layer.input_shape
        self._output_shape = self.reconstruction.output_shape

    def _build_layers(self,
                      input_shape: Tuple[int],
                      deterministic_layer: dict,
                      stochastic_layer: dict,
                      reconstruction_layer: dict,
                      generic_layer: dict):
        """ Build upsampling layer, deterministic block and reconstruction. """
        output_shape = self._build_upsample_layer(
            input_shape, deterministic_layer, stochastic_layer, generic_layer)
        self._build_deterministic_block(
            output_shape, deterministic_layer, generic_layer)
        self._build_reconstruction_layer(reconstruction_layer)

    def _build_upsample_layer(self,
                              input_shape,
                              deterministic_layer,
                              stochastic_layer,
                              generic_layer):
        """ Upsample latent sample

        :return: shape of the tensor entering the deterministic block
        :raises ValueError: if the upsampling type is neither 'dense' nor a
        type containing 'conv'
        """
        up_specs = stochastic_layer.get('upsampling')
        if not up_specs:
            self.upsample = None
            return input_shape

        if up_specs['t'] == 'dense':
            # transform to flattened previous shape
            self.upsampling_specs['type'] = 'dense'
            self.upsampling_specs['shape'] = tuple(up_specs['reshape'])
            self.upsample = DenseBlock(input_shape,
                                       specs={'out': np.prod(self.upsampling_specs['shape'])},
                                       **generic_layer)
            output_shape = self.upsampling_specs['shape']
        elif 'conv' in up_specs['t']:
            # Upsample channel-dimension
            self.upsampling_specs['type'] = 'conv'
            # transform to upcoming channel size and previous spatial dimensions
            if len(input_shape) == 1:
                # flat hierarchical VAE, i.e., input to lower stage comes from
                # top-level latent space
                input_shape = (input_shape[0], 1, 1)
            specs = dict(k=1, s=1, p=0)
            # Find upcoming channel dimension
            # (note that spatial upsampling layers do not carry this information)
            # The deterministic block applies 'specs_td' in reversed order,
            # so the search starts at the end of the list.
            for layer_specs in reversed(deterministic_layer['specs_td']):
                if c := layer_specs.get('c'):
                    specs['c'] = c
                    break
            self.upsample = ConvBlock(
                input_shape=input_shape,
                specs=specs,
                **generic_layer)
            output_shape = self.upsample.output_shape
        else:
            raise ValueError(f"Unknown upsampling type: {up_specs['t']!r}")
        logger.debug(f'\nUpsampling Layer:\n{self.upsample}')
        return output_shape

    def _build_deterministic_block(self,
                                   input_shape,
                                   deterministic_layer,
                                   generic_layer):
        """ Build the main deterministic block. """
        # the top-down specs are applied in reversed order
        self.deterministic_layer = Block(
            input_shape=input_shape,
            specs=deterministic_layer['specs_td'][::-1],
            generic_layer=generic_layer)
        logger.debug(f'\nDeterministic layer:\n{self.deterministic_layer}')

    def _build_reconstruction_layer(self, reconstruction_layer):
        """ Build the layer that parameterizes the reconstruction.

        The passed configuration is left untouched.

        :raises ValueError: if the layer type neither contains 'conv' nor
        equals 'dense'
        """
        # reconstruction layer requires separate convolutional layer which is
        # not activated
        layer_type = reconstruction_layer['specs']['t']
        if 'conv' in layer_type:
            # stride and padding are set on a copy, so the dict owned by the
            # caller is not modified as a side effect of building this stage
            rec_layer = copy.deepcopy(reconstruction_layer)
            k = rec_layer['specs']['k']
            rec_layer['specs'].update({'s': 1, 'p': k // 2})
            self.reconstruction = get_stochastic_layer(
                input_shape=self.deterministic_layer.output_shape,
                stoc_type='conv_spatial',
                stochastic_layer=rec_layer)
        elif layer_type == 'dense':
            self.reconstruction = StocDense(
                input_shape=self.deterministic_layer.output_shape,
                stochastic_layer={
                    'dim': reconstruction_layer['dim'],
                    'dist_type': reconstruction_layer['dist_type']})
        else:
            raise ValueError(
                f"Unknown reconstruction layer type: {layer_type!r}")
        logger.debug(f'\nReconstruction Layer:\n{self.reconstruction}')

    def forward(self, x) -> dict:
        """
        :param x: samples
        :return: reconstruction distribution
        """
        x = upsampling_wrapper(x, self.upsample, self.upsampling_specs)
        x = self.deterministic_layer(x)
        x = self.reconstruction(x)
        return x


class Block(BaseModule):
    """ Sequential stack of layers created from a list of layer specs. """

    def __init__(self,
                 input_shape: tuple,
                 specs: List[dict],
                 generic_layer):
        """
        :param input_shape: shape of the tensor entering the first layer
        :param specs: one spec dict per layer, in order of application
        :param generic_layer: settings handed to ``get_layer`` for every layer
        """
        super().__init__()
        self._input_shape = input_shape

        # each layer is built for the output shape of its predecessor
        shape = input_shape
        stack = []
        for layer_specs in specs:
            layer = get_layer(shape, layer_specs, generic_layer)
            stack.append(layer)
            shape = layer.output_shape

        self.layers = nn.Sequential(*stack)
        # equals input_shape when specs is empty
        self._output_shape = shape

    def forward(self, x):
        """ Apply all layers in order. """
        return self.layers(x)


def upsampling_wrapper(x, layer, specs):
    """ Apply an optional upsampling layer to x.

    :param x: tensor to upsample
    :param layer: upsampling layer, or None to return x unchanged
    :param specs: upsampling specs; 'type' is read whenever a layer is
    given, and the output is reshaped to 'shape' if that key is present
    :return: upsampled (and possibly reshaped) tensor
    """
    # Compare against None explicitly: a module that defines __len__ (e.g.
    # an empty container) is falsy and would otherwise be skipped silently.
    if layer is None:
        return x

    if specs['type'] == 'conv' and x.dim() in (2, 3):
        # flat hierarchical VAE, i.e., input to lower stage comes from
        # top-level latent space; append two singleton trailing dimensions
        x = x[..., None, None]
    x = layer(x)
    if 'shape' in specs:
        x = reshape(x, specs['shape'])
    return x
