from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from vae.layers.misc import BaseModule, Swish, reshape

# Activation modules selectable by name via DenseBlock's ``nonlin`` argument.
# NOTE: each entry is a single module instance that is shared by every block
# selecting it; this is harmless for stateless activations.
nonlinearities = dict(
    leaky_relu=nn.LeakyReLU(),
    swish=Swish(),
    gelu=nn.GELU(),
)


class DenseBlock(BaseModule):
    """ Wraps a dense layer with an optional activation and output reshape.

    :param input_shape: shape of one input sample; forwarded to the dense layer.
    :param specs: layer specification. ``specs['out']`` is the number of
        output features; the optional ``specs['reshape']`` is a target shape
        applied with ``reshape`` after the layers.
    :param activation: if True, append the activation named by ``nonlin``.
    :param nonlin: key into the module-level ``nonlinearities`` table.
    :raises KeyError: if ``activation`` is set and ``nonlin`` is unknown.
    """

    def __init__(self,
                 input_shape: tuple,
                 specs: dict,
                 activation=True,
                 nonlin='swish'):
        super().__init__()
        self._input_shape = input_shape
        output_shape = self._makes_layers(input_shape, specs, activation, nonlin)
        self.reshape = self._register_reshape_info(specs)
        # A requested reshape defines the block's reported output shape.
        self._output_shape = self.reshape if self.reshape else output_shape

    @staticmethod
    def _register_reshape_info(specs):
        """ Return ``specs['reshape']`` if present and truthy, else None. """
        return specs.get('reshape') or None

    def _makes_layers(self, input_shape, specs, activation, nonlin):
        """ Build ``self.layers`` and return the dense layer's output shape. """
        dense = self._make_linear_layer(input_shape, specs)
        layers = [dense]
        if activation:
            if nonlin not in nonlinearities:
                raise KeyError(f"Unknown nonlinearity {nonlin!r}; "
                               f"expected one of {sorted(nonlinearities)}.")
            # NOTE: this is the shared instance from ``nonlinearities``, not a
            # fresh module per block.
            layers.append(nonlinearities[nonlin])
        self.layers = nn.Sequential(*layers)
        return dense.output_shape

    def _make_linear_layer(self, cur_input_shape, specs):
        """ Create the dense layer; subclasses override this to swap it out. """
        return DenseLayer(input_shape=cur_input_shape,
                          output_shape=(specs['out'],))

    def forward(self, x):
        x = self.layers(x)
        if self.reshape:
            x = reshape(x, self.reshape)
        return x


class DenseBlockCategorical(DenseBlock):
    """ DenseBlock whose dense layer is a DenseLayerCategorical (an embedding).

    The constructor is inherited unchanged from DenseBlock. ``forward`` expects
    integer category indices instead of feature vectors.
    """

    def _make_linear_layer(self, cur_input_shape, specs):
        return DenseLayerCategorical(input_shape=cur_input_shape,
                                     output_shape=(specs['out'],))


class DenseLayer(BaseModule):
    """ Fully connected layer accepting flat (D,) or 3-D (e.g. C, H, W) input.

    For 3-D input shapes the last three dimensions of ``x`` are flattened
    before the linear map is applied.
    """

    def __init__(self,
                 input_shape: Tuple[int, ...],
                 output_shape: Tuple[int, ...]):
        super().__init__()
        self._input_shape = input_shape
        self._output_shape = output_shape
        # Set by _find_shape to the flattened feature count for 3-D input;
        # None means the input is already flat.
        self.reshape = None
        inp = self._find_shape(input_shape)
        self.linear = nn.Linear(in_features=inp,
                                out_features=self._output_shape[-1])

    def _find_shape(self, input_shape):
        """ Return the number of input features for the linear map.

        Input might come from a convolutional model; in that case the last
        three dimensions are flattened and ``self.reshape`` is set.

        :raises ValueError: if ``input_shape`` is neither 1-D nor 3-D.
        """
        if len(input_shape) == 1:
            # dense input
            return input_shape[-1]
        if len(input_shape) == 3:
            # convolutional input; cast so nn.Linear gets a Python int rather
            # than a numpy scalar.
            inp = int(np.prod(input_shape[-3:]))
            self.reshape = inp
            return inp
        raise ValueError(
            f"DenseLayer expects a 1-D or 3-D input shape, got {input_shape!r}.")

    def forward(self, x):
        # Compare with None so a flattened size of 0 still triggers flattening.
        if self.reshape is not None:
            x = x.flatten(-3)
        return self.linear(x)


class DenseLayerCategorical(BaseModule):
    """ Dense layer for a categorical input (e.g. labels).

    ``input_shape[-1]`` is the number of categories. ``forward`` takes a
    tensor of integer category indices and returns one learned vector of size
    ``output_shape[-1]`` per index.
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 output_shape: Tuple[int]):
        super().__init__()
        self._input_shape = input_shape
        self._output_shape = output_shape
        n_categories = self._input_shape[-1]
        embed_dim = self._output_shape[-1]
        # Stored as ``linear`` (like DenseLayer); renaming it would change
        # the module's state_dict keys.
        self.linear = nn.Embedding(num_embeddings=n_categories,
                                   embedding_dim=embed_dim)

    def forward(self, x):
        return self.linear(x)


class DenseMergeLayer(BaseModule):
    """ Merges two incoming tensors.

    The tensors are concatenated along their last (feature) axis. A
    DenseBlock then maps the 2*D merged features back to D features.
    """

    def __init__(self, input_shape: Tuple[int, ...],
                 **kwargs):
        """
        :param input_shape: feature shape of EACH incoming tensor; its last
            entry is the feature dimension D. The layer's output shape is
            ``input_shape``.
        :param kwargs: forwarded to DenseBlock (e.g. ``activation``, ``nonlin``).
        """
        super().__init__()
        output_shape = input_shape
        input_shape = (input_shape[-1] * 2,)
        self.layers = DenseBlock(input_shape,
                                 specs={'out': output_shape[-1]},
                                 **kwargs)

        self._input_shape = input_shape
        self._output_shape = output_shape

    def forward(self, x: Tensor, y: Tensor):
        """ Concatenate ``x`` and ``y`` on the last axis and merge them.

        Both tensors are assumed to either have shape (K x N x D) or (N x D).
        If the ranks differ, the lower-rank tensor is broadcast along the
        leading sample axis of the 3-D one.

        :raises ValueError: if the ranks differ and neither tensor is 3-D.
        """
        # One of the tensors may have K samples along a leading axis.
        if x.dim() != y.dim():
            # expand() returns a broadcast view; repeat() would copy the
            # tensor K times just to feed it into torch.cat.
            if x.dim() == 3:
                y = y.unsqueeze(0).expand(x.size(0), -1, -1)
            elif y.dim() == 3:
                x = x.unsqueeze(0).expand(y.size(0), -1, -1)
            else:
                raise ValueError('Please provide known data sizes.')

        merged = torch.cat((x, y), dim=-1)
        return self.layers(merged)
