import torch
import torch.nn as nn

from torch.nn import functional as F

# Names exported by a star-import of this module.
__all__ = ['LinearNN', 'ConvolutionalNN', 'DeconvolutionalNN',
           'LinearGaussian', 'AffineGaussian']

# Scale of the standard-normal noise added to the initial values of
# parameters (initial mu/sigma biases, affine weight and bias) below.
JITTER = 1e-5


class LinearNN(nn.Module):
    """A fully connected neural network.

    The non-linearity is applied after every layer except the last one.

    :param in_dim: An int, dimension of the input variable.
    :param out_dim: An int, dimension of the output variable.
    :param hidden_dims: A list, dimensions of hidden layers. May be empty,
    in which case the network is a single linear map from in_dim to out_dim.
    :param nonlinearity: A function, the non-linearity to apply in between
    layers.
    """
    def __init__(self, in_dim, out_dim, hidden_dims=(64, 64),
                 nonlinearity=F.relu):
        super().__init__()

        self.nonlinearity = nonlinearity

        # Consecutive pairs of this list are the (in, out) sizes of each
        # layer; this also covers the case of no hidden layers.
        dims = [in_dim, *hidden_dims, out_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out)
             for d_in, d_out in zip(dims[:-1], dims[1:])])

    def forward(self, x):
        """Returns output of the network.

        :param x: A Tensor, input of shape [M, in_dim].
        :return: A Tensor, the output of the final linear layer (no
        non-linearity applied to it).
        """
        for layer in self.layers[:-1]:
            x = self.nonlinearity(layer(x))

        return self.layers[-1](x)


class ConvolutionalNN(nn.Module):
    """ A convolutional neural network.

    The network has len(hidden_channels) + 1 convolutional layers; the
    non-linearity is applied after every layer except the last one.

    :param in_channels: An int, number of input channels.
    :param out_channels: An int, number of output channels.
    :param hidden_channels: A list, number of channels for convolutional
    layers. May be empty, in which case the network is a single convolution
    from in_channels to out_channels.
    :param kernel_sizes: A list, kernel size for convolutional layers. One
    entry per layer, i.e. len(hidden_channels) + 1 entries are read.
    :param strides: A list, strides to apply to convolutional layers.
    Defaults to 1 for every layer.
    :param paddings: A list, paddings to apply to convolutional layers.
    Defaults to 0 for every layer.
    :param nonlinearity: A function, the non-linearity to apply in between
    layers.
    """
    def __init__(self, in_channels, out_channels, hidden_channels,
                 kernel_sizes, strides=None, paddings=None,
                 nonlinearity=F.relu):
        super().__init__()

        num_layers = len(hidden_channels) + 1

        if strides is None:
            strides = [1] * num_layers

        if paddings is None:
            paddings = [0] * num_layers

        self.nonlinearity = nonlinearity

        # channels[i] -> channels[i + 1] are the channel counts of layer i;
        # this also covers the case of no hidden layers.
        channels = [in_channels, *hidden_channels, out_channels]
        self.layers = nn.ModuleList(
            [nn.Conv2d(channels[i], channels[i + 1], kernel_sizes[i],
                       stride=strides[i], padding=paddings[i])
             for i in range(num_layers)])

    def forward(self, x):
        """Returns output of the network.

        :param x: A Tensor, input of shape [M, in_channels, in_height,
        in_width].
        :return: A Tensor, the output of the final convolutional layer (no
        non-linearity applied to it).
        """
        assert len(x.shape) == 4, 'Input should be [M, in_channels, ' \
                                  'in_height, in_width]'

        for layer in self.layers[:-1]:
            x = self.nonlinearity(layer(x))

        return self.layers[-1](x)


class DeconvolutionalNN(nn.Module):
    """ A deconvolutional neural network.

    The network has len(hidden_channels) + 1 transposed-convolution layers;
    the non-linearity is applied after every layer except the last one.

    :param in_channels: An int, number of input channels.
    :param out_channels: An int, number of output channels.
    :param hidden_channels: A list, number of channels for deconvolutional
    layers. May be empty, in which case the network is a single transposed
    convolution from in_channels to out_channels.
    :param kernel_sizes: A list, kernel size for deconvolutional layers. One
    entry per layer, i.e. len(hidden_channels) + 1 entries are read.
    :param strides: A list, strides to apply to deconvolutional layers.
    Defaults to 1 for every layer.
    :param paddings: A list, paddings to apply to deconvolutional layers.
    Defaults to 0 for every layer.
    :param output_paddings: A list, output paddings to apply to
    deconvolutional layers. Defaults to 0 for every layer.
    :param nonlinearity: A function, the non-linearity to apply in between
    layers.
    """
    def __init__(self, in_channels, out_channels, hidden_channels,
                 kernel_sizes, strides=None, paddings=None,
                 output_paddings=None, nonlinearity=F.relu):
        super().__init__()

        num_layers = len(hidden_channels) + 1

        if strides is None:
            strides = [1] * num_layers

        if paddings is None:
            paddings = [0] * num_layers

        if output_paddings is None:
            output_paddings = [0] * num_layers

        self.nonlinearity = nonlinearity

        # channels[i] -> channels[i + 1] are the channel counts of layer i;
        # this also covers the case of no hidden layers.
        channels = [in_channels, *hidden_channels, out_channels]
        self.layers = nn.ModuleList(
            [nn.ConvTranspose2d(channels[i], channels[i + 1],
                                kernel_sizes[i],
                                stride=strides[i],
                                padding=paddings[i],
                                output_padding=output_paddings[i])
             for i in range(num_layers)])

    def forward(self, x):
        """Returns output of the network.

        :param x: A Tensor, input of shape [M, in_channels, in_height,
        in_width].
        :return: A Tensor, the output of the final deconvolutional layer (no
        non-linearity applied to it).
        """
        assert len(x.shape) == 4, 'Input should be [M, in_channels, ' \
                                  'in_height, in_width]'

        for layer in self.layers[:-1]:
            x = self.nonlinearity(layer(x))

        return self.layers[-1](x)


class LinearGaussian(nn.Module):
    """A fully connected neural network for parameterising a multivariate
    Gaussian distribution.

    :param in_dim: An int, dimension of the input variable.
    :param out_dim: An int, dimension of the output variable.
    :param hidden_dims: A list, dimensions of hidden layers.
    :param initial_sigma: A float, (approximately) sets the initial output
    sigma before the min_sigma rescaling applied in forward. Only used when
    sigma is None.
    :param initial_mu: A float, (approximately) sets the initial output mean.
    Only used when sigma is None.
    :param sigma: A float or None. If given, the output sigma does not depend
    on the input and the network only outputs the mean. If None, the network
    outputs both the mean and the (pre-softplus) sigma.
    :param train_sigma: A bool, whether a given sigma is a trainable
    parameter. Ignored when sigma is None.
    :param min_sigma: A float, added to the output sigma; the remaining part
    of the output sigma is scaled by (1 - min_sigma).
    :param nonlinearity: A function, the non-linearity to apply in between
    layers.
    """
    def __init__(self, in_dim, out_dim, hidden_dims=(64, 64),
                 initial_sigma=1., initial_mu=0., sigma=None,
                 train_sigma=False, min_sigma=0., nonlinearity=F.relu):
        super().__init__()

        self.out_dim = out_dim
        self.min_sigma = min_sigma

        if sigma is not None:
            # Input-independent sigma: the network only outputs the mean.
            self.network = LinearNN(in_dim, out_dim, hidden_dims, nonlinearity)

            sigma_value = torch.tensor(sigma)
            # Integer (or bool) tensors cannot require gradients, so an
            # argument such as sigma=1 is promoted to the default float dtype.
            if not (sigma_value.is_floating_point()
                    or sigma_value.is_complex()):
                sigma_value = sigma_value.to(torch.get_default_dtype())

            self.sigma = nn.Parameter(sigma_value,
                                      requires_grad=bool(train_sigma))
        else:
            self.sigma = None

            # The network outputs out_dim means followed by out_dim
            # pre-softplus sigmas.
            self.network = LinearNN(in_dim, 2*out_dim, hidden_dims,
                                    nonlinearity)

            # Initial output sigma and mu.
            initial_sigma = (torch.tensor(initial_sigma)
                             + JITTER * torch.randn(out_dim))
            initial_mu = (torch.tensor(initial_mu)
                          + JITTER * torch.randn(out_dim))

            # log(exp(s) - 1) is the inverse of softplus, so that the
            # softplus of the bias equals initial_sigma. Only the bias of the
            # last layer is set, hence the initial values are approximate.
            self.network.layers[-1].bias.data = torch.cat(
                [initial_mu, torch.log(torch.exp(initial_sigma) - 1)])

    def forward(self, x, *args, **kwargs):
        """Returns parameters of a multivariate Gaussian distribution.

        Additional positional and keyword arguments are accepted and ignored.

        :param x: A Tensor, input of shape [M, in_dim].
        :return: A tuple (mu, sigma) of Tensors with the same shape, whose
        last dimension is out_dim.
        """
        x = self.network(x)
        mu = x[..., :self.out_dim]
        if self.sigma is not None:
            sigma = (self.min_sigma +
                     (1 - self.min_sigma) * self.sigma * torch.ones_like(mu))
        else:
            sigma = (self.min_sigma +
                     (1 - self.min_sigma) * F.softplus(x[..., self.out_dim:]))

        return mu, sigma


class AffineGaussian(nn.Module):
    """The mean of the output Gaussian is an affine transformation of the
    input.

    :param in_dim: An int, the dimension of the input variable.
    :param out_dim: An int, the dimension of the output variable.
    :param sigma: A float, sets the initial output sigma (shared across all
    output dimensions). It is stored as a trainable parameter holding its
    logarithm.
    :param initial_weight: A float, sets the initial weight. If None, every
    entry of the weight is initialised to 1 / in_dim.
    :param initial_bias: A float, sets the initial bias. If None, the bias is
    initialised to zero.
    """

    def __init__(self, in_dim, out_dim, sigma=1., initial_weight=None,
                 initial_bias=None):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim

        if initial_weight is None:
            initial_weight = torch.ones(out_dim, in_dim) / in_dim
        else:
            initial_weight = torch.tensor(initial_weight)

        if initial_bias is None:
            initial_bias = torch.zeros(out_dim)
        else:
            initial_bias = torch.tensor(initial_bias)

        # Initial weight and bias of the affine transformation, perturbed by
        # a small amount of noise. The weight has shape [out_dim, in_dim].
        self.weight = nn.Parameter(initial_weight + JITTER * torch.randn(
            out_dim, in_dim), requires_grad=True)
        self.bias = nn.Parameter(initial_bias + JITTER * torch.randn(out_dim),
                                 requires_grad=True)

        # Sigma is parameterised through its logarithm; forward applies exp.
        self.raw_sigma = nn.Parameter(torch.tensor(sigma).log(),
                                      requires_grad=True)

    def forward(self, x):
        """Returns parameters of a Gaussian distribution.

        :param x: A Tensor, input of shape [M, in_dim]. Any number of
        leading batch dimensions is accepted, i.e. [..., in_dim].
        :return: A tuple (mu, sigma) of Tensors of shape [..., out_dim].
        """
        # Treat the last axis of x as a column vector so that the weight
        # matrix broadcasts over all leading dimensions.
        mu = self.weight.matmul(x.unsqueeze(-1)).squeeze(-1) + self.bias
        sigma = torch.ones_like(mu) * self.raw_sigma.exp()

        return mu, sigma
