import math
from typing import Union, Tuple, Dict, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F

from vae.layers.misc import check_input_shape, forward_with_additional_batch_dimensions, Swish, BaseModule

# Name -> activation-module registry, looked up via the ``nonlin`` argument
# of the blocks below. Each entry is ONE module instance that is shared by
# every layer that selects it.
# NOTE(review): sharing is harmless only for parameter-free activations;
# nn.LeakyReLU and nn.GELU are, Swish presumably is too — TODO confirm.
nonlinearities = dict(
    leaky_relu=nn.LeakyReLU(),
    swish=Swish(),
    gelu=nn.GELU(),
)


class ConvLayer(BaseModule):
    """ Convolutional Layer (2-D) that records its input and output shapes. """

    def __init__(self,
                 input_shape: tuple,
                 c_out: int,
                 kernel: int = 3,
                 stride: int = 1,
                 padding: Union[int, str] = 'same',
                 bias=True):
        """
        :param input_shape: C x H x W
        :param c_out: output channels
        :param kernel: we assume symmetric kernel
        :param stride: stride along both spatial dimensions
        :param padding: zero-padding per side, either an int or a mode:
            - 'same': spatial dimensions of input and output are identical
              (``_get_same_padding`` requires an odd kernel and stride 1)
            - 'valid': no padding, equivalent to 0
        :param bias: whether the convolution learns an additive bias
        :raises ValueError: if ``padding`` is an unsupported string
        """
        super().__init__()
        c_in, d = check_input_shape(input_shape)
        # String modes must become an int so the shape formula below applies.
        padding = self._resolve_padding(padding, kernel, stride)
        self._input_shape = input_shape
        self._output_shape = self._compute_output_shape(c_out, d, kernel,
                                                        stride, padding)
        self.conv = nn.Conv2d(c_in, c_out,
                              kernel_size=kernel,
                              stride=stride,
                              padding=padding,
                              bias=bias)

    @staticmethod
    def _resolve_padding(padding, kernel, stride):
        """ Translate a padding mode ('same' / 'valid') into an integer.

        Integer paddings are returned unchanged.
        :raises ValueError: for any other string
        """
        if padding == 'same':
            return _get_same_padding(kernel, stride)
        if padding == 'valid':
            return 0
        if isinstance(padding, str):
            raise ValueError(f"Unsupported padding mode {padding!r}; "
                             "expected 'same', 'valid' or an int.")
        return padding

    @staticmethod
    def _compute_output_shape(c_out, d, kernel, stride, padding):
        """ Output shape of ``nn.Conv2d`` (dilation 1) for a square input.

        :param d: spatial dimension (identical across H and W)
        :param padding: integer padding per side
        :return: (c_out, d_out, d_out)
        """
        d = math.floor((d + 2 * padding - (kernel - 1) - 1) / stride + 1)
        return c_out, d, d

    def forward(self, x):
        """ Apply the convolution.

        Delegates to ``forward_with_additional_batch_dimensions``, which (by
        its name) accepts inputs with extra leading batch dims — TODO confirm.
        """
        x = forward_with_additional_batch_dimensions(layer=self.conv, x=x)
        return x


class ConvTransposedLayer(BaseModule):
    """ Transposed convolutional layer (2-D) that records its input and
    output shapes. With the default arguments the spatial dimension d is
    upsampled to 2 * d. """

    def __init__(self,
                 input_shape: Tuple[int, ...],
                 c_out: int,
                 kernel: int = 3,
                 stride: int = 2,
                 padding: int = 1,
                 output_padding=1,
                 bias=True):
        """
        :param input_shape: C x H x W
        :param c_out: output channels
        :param kernel: we assume symmetric kernel
        :param stride: stride along both spatial dimensions
        :param padding: implicit padding as defined by ``nn.ConvTranspose2d``
        :param output_padding: specifies output spatial dimensions d
            (for the default kernel / stride / padding)
            - 1: typically d*2
            - 0: typically d*2 - 1
            PyTorch requires it to be smaller than ``stride`` (dilation is 1).
        :param bias: whether the layer learns an additive bias; defaults to
            True, the ``nn.ConvTranspose2d`` default
        """
        super().__init__()
        c_in, d = check_input_shape(input_shape)
        self._input_shape = input_shape
        self._output_shape = self._compute_output_shape(
            c_out, d, kernel, stride, padding, output_padding)

        self.conv = nn.ConvTranspose2d(in_channels=c_in,
                                       out_channels=c_out,
                                       kernel_size=kernel,
                                       padding=padding,
                                       stride=stride,
                                       output_padding=output_padding,
                                       bias=bias)

    @staticmethod
    def _compute_output_shape(c_out, d, kernel, stride, padding,
                              output_padding):
        """ Output shape of ``nn.ConvTranspose2d`` (dilation 1).

        :param d: spatial dimension (identical across H and W)
        :return: (c_out, d_out, d_out)
        """
        d = (d - 1) * stride - 2 * padding + (kernel - 1) + output_padding + 1
        return c_out, d, d

    def forward(self, x):
        """ Apply the transposed convolution via
        ``forward_with_additional_batch_dimensions``. """
        x = forward_with_additional_batch_dimensions(layer=self.conv, x=x)
        return x


class ConvBlock(BaseModule):
    """ A single ConvLayer, optionally followed by a nonlinearity. """

    def __init__(self,
                 input_shape: tuple,
                 specs: Dict[str, int],
                 activation=True,
                 nonlin='swish'):
        """
        :param input_shape: C x H x W of the incoming tensor
        :param specs: convolution hyper-parameters, read from the keys
            'c' (output channels), 'k' (kernel size), 's' (stride) and
            'p' (padding, forwarded unchanged to ConvLayer)
        :param activation: whether to append a nonlinearity
        :param nonlin: key into the module-level ``nonlinearities`` registry
        """
        super().__init__()
        self._input_shape = input_shape

        conv = ConvLayer(input_shape=input_shape,
                         c_out=specs['c'],
                         kernel=specs['k'],
                         stride=specs['s'],
                         padding=specs['p'])
        modules = [conv, nonlinearities[nonlin]] if activation else [conv]
        self.layers = nn.Sequential(*modules)
        # The activation is elementwise, so the conv determines the shape.
        self._output_shape = conv.output_shape

    def forward(self, x):
        return self.layers(x)


class ConvTransposedBlock(BaseModule):
    """ A single ConvTransposedLayer, optionally followed by a
    nonlinearity. """

    def __init__(self,
                 input_shape: tuple,
                 specs: dict,
                 activation=True,
                 nonlin='swish'):
        """
        :param input_shape: C x H x W of the incoming tensor
        :param specs: layer hyper-parameters, read from the keys
            'c' (output channels), 'k' (kernel size), 's' (stride),
            'p' (padding) and 'op' (output padding)
        :param activation: whether to append a nonlinearity
        :param nonlin: key into the module-level ``nonlinearities`` registry
        """
        super().__init__()
        self._input_shape = input_shape

        upconv = ConvTransposedLayer(input_shape=input_shape,
                                     c_out=specs['c'],
                                     kernel=specs['k'],
                                     stride=specs['s'],
                                     padding=specs['p'],
                                     output_padding=specs['op'])
        stages = [upconv]
        if activation:
            stages.append(nonlinearities[nonlin])
        self.layers = nn.Sequential(*stages)
        self._output_shape = upconv.output_shape

    def forward(self, x):
        return self.layers(x)


# Inspired by: https://github.com/addtt/ladder-vae-pytorch
class ConvMergeLayer(BaseModule):
    """
    Merges two feature maps: concatenates them along the channel axis and
    maps the result through a DeepConvBlock with half as many output
    channels, followed by a nonlinearity.

    Assumes incoming tensors to have convolutional dimension of identical size.
    """

    def __init__(self,
                 input_shape: tuple,
                 nonlin: str):
        """
        :param input_shape: C x H x W of the *concatenated* input, i.e. C is
            the sum of the channel counts of the two merged tensors
        :param nonlin: key into the module-level ``nonlinearities`` registry,
            applied after the DeepConvBlock
        """
        super().__init__()
        c, _ = check_input_shape(input_shape)

        self._input_shape = input_shape
        self.layer = DeepConvBlock(input_shape, specs=dict(c=c // 2))
        self._output_shape = self.layer.output_shape
        # NOTE(review): DeepConvBlock.forward already ends with F.gelu, so
        # this is a second activation on top of it — confirm intended.
        self.nonlin = nonlinearities[nonlin]

    def forward(self, x, y):
        """ Concatenate ``x`` and ``y`` along the channel axis and merge them.

        One of the inputs may carry an additional leading sample dimension
        (5-D, e.g. from importance sampling); the other input is then
        broadcast along it. Otherwise, shapes are assumed to be identical.

        :raises ValueError: if the ranks differ and neither input is 5-D
        """
        # One of the tensors may have k samples in a leading dimension.
        if x.dim() != y.dim():
            # expand() yields the same values/gradients as stacking k copies,
            # but as a view; torch.cat below makes the only real copy.
            if x.dim() == 5:
                y = y.expand(x.size(0), *y.shape)
            elif y.dim() == 5:
                x = x.expand(y.size(0), *x.shape)
            else:
                raise ValueError('Please provide known data sizes.')

        # Assumes that x is top-down information
        x = torch.cat((x, y), dim=-3)

        x = self.layer(x)
        x = self.nonlin(x)
        return x


class DeepConvBlock(BaseModule):
    """ Bottleneck block of three bias-free convolutions
    (1x1 -> 3x3 -> 1x1) with GELU activations, an optional residual
    connection and optional 2x2 average-pool downsampling. """

    def __init__(self,
                 input_shape: tuple,
                 specs: Optional[dict] = None):
        """ ConvBlock with multiple layers

        :param input_shape: C x H x W
        :param specs: optional dict with the keys
            * downsample: whether to downsample spatial-dim by factor of two
              (2x2 average pooling after the convolutions)
            * residual: whether to build residual connection; forces the
              output channel dimension to equal the input's
            * c: output channel dimension (ignored when residual is set)
            * f: fraction of the output channels used inside the bottleneck
              (default 0.25)
        """
        super().__init__()
        self._input_shape = input_shape
        self.residual = False
        self.downsample = False
        c_red, c_out = self._parse_input(input_shape, specs)
        self.c1 = ConvLayer(input_shape, c_red, 1, 1, 0, bias=False)
        self.c2 = ConvLayer(self.c1.output_shape, c_red, 3, 1, 1, bias=False)
        self.c3 = ConvLayer(self.c2.output_shape, c_out, 1, 1, 0, bias=False)
        if self.downsample:
            c, d1, d2 = self.c3.output_shape
            assert d1 == d2
            # avg_pool2d(kernel=2, stride=2) floors odd sizes.
            d = d1 // 2
            self._output_shape = (c, d, d)
        else:
            self._output_shape = self.c3.output_shape

    def _parse_input(self, input_shape, specs):
        """ Read ``specs`` into ``self.residual`` / ``self.downsample``.

        :return: (bottleneck channels, output channels)
        """
        if not specs:
            specs = {}
        self.residual = specs.get('residual', False)
        self.downsample = specs.get('downsample', False)
        c_out, _ = check_input_shape(input_shape)
        if not self.residual:
            # Input and output channel dim can be different
            c_out = specs.get('c', c_out)
        # Downscale channel-dimension inside block.
        # NOTE(review): with the default f=0.25 this is 0 for c_out < 4,
        # which builds zero-channel convolutions — confirm callers avoid it.
        c_red = int(specs.get('f', 0.25) * c_out)
        return c_red, c_out

    @staticmethod
    def _pool(x):
        """ Halve both spatial dims with non-overlapping 2x2 average pooling. """
        return F.avg_pool2d(x, kernel_size=2, stride=2)

    def forward(self, x):
        # Assume that input is already activated
        xhat = self.c1(x)
        xhat = self.c2(F.gelu(xhat))
        xhat = self.c3(F.gelu(xhat))
        x = x + xhat if self.residual else xhat
        # Drop the extra reference early, presumably to let memory be freed.
        del xhat
        if self.downsample:
            x = forward_with_additional_batch_dimensions(self._pool, x)
        x = F.gelu(x)
        return x

    def __str__(self, *args, **kwargs):
        # NOTE(review): appends directly to BaseModule's string; verify that
        # it ends with a newline, otherwise the two run together.
        s = super().__str__(*args, **kwargs)
        s += f'Residual connection: {self.residual}\n'
        return s


def _get_same_padding(kernel: int, stride: int) -> int:
    """ Get padding that preserves input shape.

    A guide to convolution arithmetic for deep learning (2.2.1)
    https://arxiv.org/abs/1603.07285

    Feature request for PyTorch:
    https://github.com/pytorch/pytorch/issues/3867

    :param kernel: assumes symmetric two-dimensional kernel
    :return: padding
    """
    msg = 'Please pass height or width of an symmetric kernel.'
    assert isinstance(kernel, int), msg
    assert kernel % 2 != 0, 'Please provide uneven kernel.'
    msg = '(1) Stride must be one for formula to hold, ' \
          '(2) We avoid higher strides as this can result in kernel that ' \
          'sometimes solely slides over padded zeros.'
    assert stride == 1, msg

    pad = kernel // 2
    return pad
