from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import Constants
from vae.layers.convolution import ConvLayer
from vae.layers.dense import DenseBlock
from vae.layers.distributions import get_distribution, Normal
from vae.layers.misc import check_input_shape, BaseModule
from vae.misc import get_trainable_params


def get_stochastic_layer(input_shape, stoc_type, stochastic_layer):
    """
    Factory for stochastic layers.

    :param input_shape: shape of the hidden state fed into the layer
    :param stoc_type: one of 'dense', 'conv_flat' or 'conv_spatial'
    :param stochastic_layer: layer configuration dict forwarded unchanged to
        the selected class
    :return: stochastic layer instance
    :raises ValueError: if ``stoc_type`` is not one of the supported types
    """
    if stoc_type == 'dense':
        return StocDense(input_shape, stochastic_layer)
    if stoc_type == 'conv_flat':
        return StocConvFlat(input_shape=input_shape,
                            stochastic_layer=stochastic_layer)
    if stoc_type == 'conv_spatial':
        return StocConvSpatial(input_shape=input_shape,
                               stochastic_layer=stochastic_layer)
    raise ValueError('Please specify valid type for stochastic layer.')


class StochasticLayer(BaseModule):
    """ Base class for layers that parameterize a distribution from a
    hidden state.

    Concrete subclasses overwrite ``_distribution_shape`` and ``_logit_dim``
    and implement ``forward``. ``input_shape`` / ``output_shape`` are read in
    ``__str__``; presumably they are exposed by ``BaseModule`` (subclasses
    here set ``_input_shape`` / ``_output_shape``) -- TODO confirm.
    """

    def __init__(self):
        super().__init__()
        # Placeholders; concrete subclasses assign the real values.
        self._logit_dim = None
        self._distribution_shape = None

    def forward(self, *args, **kwargs):
        """
        Must be implemented by subclasses.

        :return:
            - torch distribution object
            - (processed) samples from this distribution
        """
        raise NotImplementedError()

    @property
    def distribution_shape(self):
        """ Shape of a sample drawn from the distribution (the output-shape
        property may instead describe upsampled samples). """
        return self._distribution_shape

    @property
    def logit_dim(self):
        """
        Size of the dimension along which the logits are split when the
        distribution has several parameters.
        """
        return self._logit_dim

    def __str__(self, *args, **kwargs):
        parts = [
            f'{super().__str__(additional_info=False)}',
            f'Trainable parameters: {get_trainable_params(self):,}',
            f'Input shape: {self.input_shape}',
            f'Logit Dimension: {self._logit_dim}',
            f'Distribution shape: {self.distribution_shape}',
            f'Output shape: {self.output_shape}',
        ]
        return '\n'.join(parts) + '\n'

    @staticmethod
    def _get_logit_dim(dist_type, stoc_dim):
        """ Number of logits needed for ``stoc_dim`` latent units.

        Two-parameter families (mean and scale) need twice as many logits;
        single-parameter families need one per latent unit.

        :raises ValueError: for an unknown distribution type
        """
        if dist_type in ('normal', 'laplace'):
            return stoc_dim * 2
        if dist_type in ('categorical', 'bernoulli', 'sigmoid'):
            return stoc_dim
        raise ValueError('Illegal distribution type.')


class StocDense(StochasticLayer):
    """ Dense stochastic layer: a linear map from a flat hidden state to the
    logits of a one-dimensional distribution. """

    def __init__(self,
                 input_shape: Tuple[int],
                 stochastic_layer):
        """
        :param input_shape: input features
        :param stochastic_layer:
            - 'dim': latent variable sizes
            - 'dist_type': distribution type
        """
        super().__init__()
        latent_size, family = stochastic_layer['dim'], stochastic_layer['dist_type']

        # Registered before `logit_net` (same order as before), so module
        # iteration order is unchanged.
        self.dist = get_distribution(
            dist_type=family,
            scale_regularizer=ScaleRegularizer(stoc_type='flat'))

        self._input_shape = input_shape
        self._distribution_shape = (latent_size,)
        self._output_shape = (latent_size,)

        self._logit_dim = self._get_logit_dim(family, latent_size)
        # Purely linear head: no activation on the logits.
        self.logit_net = DenseBlock(input_shape=input_shape,
                                    specs={'out': self._logit_dim},
                                    activation=False)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        :param x: flat hidden state for each sample in batch
        :return: one-dimensional latent distribution for each sample in batch
        """
        logits = self.logit_net(x)
        return self.dist(logits, *args, **kwargs)


class StocConv(StochasticLayer):
    """ Generic convolutional stochastic layer.

    Validates the input shape and stores the latent size (``stoc_dim``) and
    the number of logit channels (``_logit_dim``). Subclasses build the
    convolution and the distribution and implement ``forward``.
    """

    def __init__(self,
                 input_shape,
                 stochastic_layer: dict):
        """
        :param input_shape: shape of the hidden state, checked with
            ``check_input_shape``
        :param stochastic_layer:
            - 'dim': latent variable size
            - 'dist_type': distribution type
        """
        super().__init__()
        check_input_shape(input_shape)  # result intentionally discarded
        latent_size = stochastic_layer['dim']
        family = stochastic_layer['dist_type']

        self._input_shape = input_shape
        self.stoc_dim = latent_size
        self._logit_dim = self._get_logit_dim(family, latent_size)

    def forward(self, *args, **kwargs):
        """ Implemented by concrete subclasses. """
        raise NotImplementedError()


class StocConvSpatial(StocConv):
    """ Spatial stochastic layer
        - Input: 3d tensor
        - Output: 3d distribution

    A single convolution maps the hidden state to distribution logits. The
    distribution keeps the spatial dimensions of that convolution's output.
    """

    def __init__(self,
                 input_shape: tuple,
                 stochastic_layer: dict):
        """
        :param input_shape: shape of the hidden state (validated by
            ``check_input_shape`` in ``StocConv``)
        :param stochastic_layer:
            - 'dim': number of latent channels
            - 'dist_type': distribution type
            - 'specs': conv specs with keys 'k' (kernel), 's' (stride) and
              'p' (padding)
        """
        super().__init__(input_shape, stochastic_layer)
        specs = stochastic_layer['specs']
        self.layer = ConvLayer(input_shape,
                               c_out=self._logit_dim,
                               kernel=specs['k'],
                               stride=specs['s'],
                               padding=specs['p'])
        c, d1, d2 = self.layer.output_shape

        self.dist = get_distribution(
            dist_type=stochastic_layer['dist_type'],
            scale_regularizer=ScaleRegularizer(stoc_type='spatial'))

        # The conv emits `_logit_dim` channels. `_get_logit_dim` sized these
        # as stoc_dim * (parameters per latent channel): 2 for the mean/scale
        # families ('normal', 'laplace'), 1 otherwise. Derive the latent
        # channel count from that same rule. An isinstance(self.dist, Normal)
        # check would leave 'laplace' doubled unless its class subclasses
        # Normal.
        n_params = self._logit_dim // self.stoc_dim
        output = (c // n_params, d1, d2)
        self._distribution_shape = output
        self._output_shape = output

    def forward(self, *args, **kwargs):
        return self._forward(*args, **kwargs)

    def _forward(self, x, *args, **kwargs):
        """
        :param x: hidden state
        :return: result of the distribution module applied to the logits
        """
        x = self.layer(x) if self.layer else x  # logits
        distribution = self.dist(x, *args, **kwargs)
        return distribution


class StocConvFlat(StocConv):
    """ Flat stochastic layer
        - Input: 3d tensor
        - Output: 1d distribution

    A convolution downsamples the hidden state to spatial size (1, 1). The
    resulting logit channels are flattened into a vector that parameterizes
    the distribution.
    """

    def __init__(self, input_shape, stochastic_layer):
        """
        :param input_shape: shape of the hidden state (validated by
            ``check_input_shape`` in ``StocConv``)
        :param stochastic_layer:
            - 'dim': latent variable size
            - 'dist_type': distribution type
            - 'specs': conv specs with keys 'k' (kernel), 's' (stride) and
              'p' (padding); must reduce the spatial size to (1, 1)
        """
        super().__init__(input_shape, stochastic_layer)
        specs = stochastic_layer['specs']
        self.downsample = ConvLayer(input_shape,
                                    c_out=self._logit_dim,
                                    kernel=specs['k'],
                                    stride=specs['s'],
                                    padding=specs['p'])

        self.dist = get_distribution(
            dist_type=stochastic_layer['dist_type'],
            scale_regularizer=ScaleRegularizer(stoc_type='flat'))

        # Logits hold `n_params` values per latent unit (see
        # `_get_logit_dim`): 2 for 'normal'/'laplace', 1 for single-parameter
        # families. Halving unconditionally would give stoc_dim // 2 for
        # 'categorical', 'bernoulli' and 'sigmoid'.
        n_params = self._logit_dim // self.stoc_dim
        flat_dim = self.downsample.output_shape[-3] // n_params
        self._distribution_shape = (flat_dim,)
        self._output_shape = (flat_dim,)

    def forward(self, x, *args, **kwargs):
        """
        :param x: hidden state
        :return: result of the distribution module applied to the flattened
            logits
        """
        x = self.downsample(x)  # logits

        # remove spatial dimensions of 1
        msg = 'Please ensure that spatial dimensions are (1, 1) after ' \
              f'downsampling. They are {(x.size(-2), x.size(-1))}.'
        assert x.size(-2) == x.size(-1) == 1, msg
        x = x.flatten(-3)

        distribution = self.dist(x, *args, **kwargs)

        return distribution


class ScaleRegularizer(nn.Module):
    """ Regularizes scale parameter for distribution object.

    Raw scale values are mapped through softplus and offset by
    ``Constants.eta`` (presumably a small epsilon keeping scales away from
    zero). With ``reg == 'sum_d'`` they are additionally renormalized with a
    softmax so that they sum to the number of dimensions.
    """

    def __init__(self, reg=None, stoc_type: str = 'flat'):
        """
        :param reg:
            - None: only apply softplus + eta
            - 'sum_d' restrict scale to sum over dimensions
        :param stoc_type: whether stochastic variable is flat ('flat':
            normalize over the last dimension) or spatial ('spatial':
            normalize jointly over the last three dimensions)
        """
        super().__init__()
        msg = 'stochastic variable must be flat or spatial.'
        assert stoc_type in ['flat', 'spatial'], msg
        self.reg = reg
        self.stoc_type = stoc_type
        self.layers = self._make_layers()

    @staticmethod
    def _make_layers():
        """ Build the positivity transform applied to raw scale values. """
        return nn.Sequential(nn.Softplus())

    def forward(self, scale: torch.Tensor):
        """
        :param scale: raw (unconstrained) scale values
        :return: positive scale values, optionally renormalized ('sum_d')
        """
        scale = self.layers(scale) + Constants.eta
        if self.reg == 'sum_d':
            scale = self._sum_over_dims(scale)
        return scale

    def _sum_over_dims(self, scale):
        """ Rescale so the values over the normalized dimensions sum to their
        count (softmax over those dimensions times the number of elements).
        """
        if self.stoc_type == 'flat':
            scale = F.softmax(scale, dim=-1) * scale.size(-1)
        elif self.stoc_type == 'spatial':
            s = scale.size()
            flattened = scale.flatten(-3)
            # Softmax jointly over all C*H*W entries; using the un-flattened
            # tensor here would normalize over the last (width) axis only.
            flattened = F.softmax(flattened, dim=-1) * flattened.size(-1)
            scale = flattened.view(s)
        else:
            raise ValueError(f'Unknown stoc_type {self.stoc_type!r}; '
                             'expected "flat" or "spatial".')
        return scale
