import numpy as np
import torch
import torch.nn as nn

from .base_networks import *

from torch.nn import functional as F

__all__ = [
    "ImageToLatentGaussian",
    "LatentToImageGaussian",
]

# Scale of the small Gaussian noise added to the output-layer biases when an
# ``initial_sigma`` is requested.
JITTER = 1.0e-5


class ImageToLatentGaussian(nn.Module):
    """ A convolutional neural network for parameterising a multivariate
    Gaussian distribution. The input is an image, the output is a
    one-dimensional latent space.

    :param in_channels: An int, number of input channels.
    :param in_height: An int, the height of the input image.
    :param in_width: An int, the width of the input image.
    :param final_conv_dim: An int, the flattened dimension of output of the
    convolutional neural network.
    :param out_dim: An int, the output dimension.
    :param hidden_dims: A list, dimensions of the hidden layers of the fully
    connected neural network.
    :param hidden_channels: A list, number of channels for convolutional
    layers. The last entry is the number of output channels of the
    convolutional neural network.
    :param kernel_sizes: A list, kernel size for convolutional layers.
    :param strides: A list, strides to apply to convolutional layers.
    :param paddings: A list, paddings to apply to convolutional layers.
    :param nonlinearity: A function, the non-linearity to apply in between
    layers.
    :param initial_sigma: A positive float, (approximately) sets the initial
    output sigma. The softplus part of sigma starts near this value, so the
    returned sigma starts near ``min_sigma + (1 - min_sigma) * initial_sigma``.
    Only used when ``sigma`` is None.
    :param sigma: A float or None. If given, the output sigma is this fixed
    value (rescaled by ``min_sigma``) instead of being predicted by the
    network.
    :param train_sigma: A bool, if True and ``sigma`` is given, ``sigma`` is
    registered as a learnable parameter.
    :param min_sigma: A float, the returned sigma is
    ``min_sigma + (1 - min_sigma) * raw_sigma``.
    """
    def __init__(self, in_channels, in_height, in_width, final_conv_dim,
                 out_dim, hidden_dims, hidden_channels, kernel_sizes,
                 strides=None, paddings=None, nonlinearity=F.relu,
                 initial_sigma=None, sigma=None, train_sigma=False,
                 min_sigma=0.):
        super().__init__()

        self.in_channels = in_channels
        self.in_height = in_height
        self.in_width = in_width
        self.out_dim = out_dim
        self.min_sigma = min_sigma
        self.conv = ConvolutionalNN(in_channels,
                                    hidden_channels[-1],
                                    hidden_channels[:-1],
                                    kernel_sizes,
                                    strides,
                                    paddings,
                                    nonlinearity)

        # With a fixed sigma the network only predicts the mean; otherwise it
        # predicts the mean and the pre-softplus sigma side by side.
        n_outputs = out_dim if sigma is not None else 2 * out_dim
        self.linear = LinearNN(final_conv_dim, n_outputs, hidden_dims,
                               nonlinearity)

        if sigma is not None and train_sigma:
            # Cast to a floating dtype: an integer sigma (e.g. ``1``) would
            # otherwise give an integer tensor, which cannot require grad.
            # NOTE(review): nothing constrains a trained sigma to stay
            # positive.
            sigma = nn.Parameter(
                torch.tensor(sigma, dtype=torch.get_default_dtype()))
        self.sigma = sigma

        if sigma is None and initial_sigma is not None:
            self._set_sigma_bias(self.linear.layers[-1].bias, out_dim,
                                 initial_sigma, JITTER)

    @staticmethod
    def _set_sigma_bias(bias, dim, initial_sigma, jitter):
        """Fills ``bias`` (length ``2 * dim``) in place so that the first
        ``dim`` entries are ~0 and softplus of the last ``dim`` entries is
        ~``initial_sigma``, each perturbed by ``jitter``-scaled noise.

        The existing dtype and device of ``bias`` are preserved.

        :raises ValueError: If ``initial_sigma`` is not positive.
        """
        if not initial_sigma > 0:
            raise ValueError('initial_sigma must be positive.')
        # Numerically stable inverse softplus, log(exp(s) - 1), as a Python
        # float so it does not promote the tensor arithmetic to float64.
        inv_softplus = float(initial_sigma +
                             np.log(-np.expm1(-initial_sigma)))
        init = torch.cat([jitter * torch.randn(dim),
                          inv_softplus + jitter * torch.randn(dim)])
        with torch.no_grad():
            bias.copy_(init)

    def forward(self, x):
        """Returns parameters of a multivariate Gaussian distribution.

        :param x: A Tensor, input of shape [M, in_channels * in_height *
        in_width]; [M, in_channels, in_height, in_width] also works.
        :return: A tuple ``(mu, sigma)`` of Tensors, each of shape
        [M, out_dim].
        """
        # reshape (rather than view) so non-contiguous inputs also work.
        x = x.reshape(-1, self.in_channels, self.in_height, self.in_width)
        x = self.conv(x)
        x = x.flatten(1, -1)
        x = self.linear(x)
        mu = x[:, :self.out_dim]
        if self.sigma is not None:
            sigma = (self.min_sigma +
                     (1 - self.min_sigma) * self.sigma * torch.ones_like(mu))
        else:
            sigma = (self.min_sigma +
                     (1 - self.min_sigma) * F.softplus(x[:, self.out_dim:]))

        return mu, sigma


class LatentToImageGaussian(nn.Module):
    """ A deconvolutional neural network for parameterising a multivariate
    Gaussian distribution. The input is a one-dimensional latent space,
    the output is an image.

    :param in_dim: An int, input dimension.
    :param final_linear_dim: An int, flattened dimension of input to the
    deconvolutional neural network. Must equal
    ``hidden_channels[0] * d ** 2`` for a positive integer ``d``; the fully
    connected output is reshaped to [M, hidden_channels[0], d, d].
    :param out_channels: An int, number of output channels.
    :param out_height: An int, height of output image.
    :param out_width: An int, width of output image.
    :param hidden_dims: A list, dimensions of the hidden layers of the fully
    connected neural network.
    :param hidden_channels: A list, number of channels for deconvolutional
    layers. The first entry is the number of input channels of the
    deconvolutional neural network.
    :param kernel_sizes: A list, kernel size for deconvolutional layers.
    :param strides: A list, strides to apply to deconvolutional layers.
    :param paddings: A list, paddings to apply to deconvolutional layers.
    :param output_paddings: A list, output paddings to apply to
    deconvolutional layers.
    :param nonlinearity: A function, the non-linearity to apply in between
    layers (both fully connected and deconvolutional).
    :param initial_sigma: A positive float, (approximately) sets the initial
    output sigma. The softplus part of sigma starts near this value, so the
    returned sigma starts near ``min_sigma + (1 - min_sigma) * initial_sigma``.
    Only used when ``sigma`` is None.
    :param sigma: A float or None. If given, the output sigma is this fixed
    value (rescaled by ``min_sigma``) instead of being predicted by the
    network.
    :param train_sigma: A bool, if True and ``sigma`` is given, ``sigma`` is
    registered as a learnable parameter.
    :param min_sigma: A float, the returned sigma is
    ``min_sigma + (1 - min_sigma) * raw_sigma``.
    :raises ValueError: If ``final_linear_dim`` is incompatible with
    ``hidden_channels[0]``.
    """
    def __init__(self, in_dim, final_linear_dim, out_channels, out_height,
                 out_width, hidden_dims, hidden_channels, kernel_sizes,
                 strides=None, paddings=None, output_paddings=None,
                 nonlinearity=F.relu, initial_sigma=None, sigma=None,
                 train_sigma=False, min_sigma=0.):
        super().__init__()

        # Validate before building any layers.
        self.initial_conv_dim = self._initial_conv_dim(final_linear_dim,
                                                       hidden_channels[0])
        self.initial_conv_channels = hidden_channels[0]
        self.out_channels = out_channels
        self.out_height = out_height
        self.out_width = out_width
        self.min_sigma = min_sigma
        # The configured nonlinearity is used between the fully connected
        # layers too, not only between the deconvolutional ones.
        self.linear = LinearNN(in_dim, final_linear_dim, hidden_dims,
                               nonlinearity)

        # With a fixed sigma only the mean channels are produced; otherwise
        # the mean and pre-softplus sigma channels are stacked.
        n_channels = out_channels if sigma is not None else 2 * out_channels
        self.deconv = DeconvolutionalNN(hidden_channels[0],
                                        n_channels,
                                        hidden_channels[1:],
                                        kernel_sizes,
                                        strides,
                                        paddings,
                                        output_paddings,
                                        nonlinearity)

        if sigma is not None and train_sigma:
            # Cast to a floating dtype: an integer sigma (e.g. ``1``) would
            # otherwise give an integer tensor, which cannot require grad.
            # NOTE(review): nothing constrains a trained sigma to stay
            # positive.
            sigma = nn.Parameter(
                torch.tensor(sigma, dtype=torch.get_default_dtype()))
        self.sigma = sigma

        if sigma is None and initial_sigma is not None:
            self._set_sigma_bias(self.deconv.layers[-1].bias, out_channels,
                                 initial_sigma, JITTER)

    @staticmethod
    def _initial_conv_dim(final_linear_dim, channels):
        """Returns the side length ``d`` of the square feature map such that
        ``final_linear_dim == channels * d ** 2``, using exact integer
        arithmetic.

        :raises ValueError: If no positive integer ``d`` satisfies this.
        """
        area, remainder = divmod(final_linear_dim, channels)
        side = int(round(area ** 0.5)) if area > 0 else 0
        if remainder or side < 1 or side * side != area:
            raise ValueError(
                'Incompatible parameters: final_linear_dim must equal '
                'hidden_channels[0] * d ** 2 for a positive integer d.')
        return side

    @staticmethod
    def _set_sigma_bias(bias, dim, initial_sigma, jitter):
        """Fills ``bias`` (length ``2 * dim``) in place so that the first
        ``dim`` entries are ~0 and softplus of the last ``dim`` entries is
        ~``initial_sigma``, each perturbed by ``jitter``-scaled noise.

        The existing dtype and device of ``bias`` are preserved.

        :raises ValueError: If ``initial_sigma`` is not positive.
        """
        if not initial_sigma > 0:
            raise ValueError('initial_sigma must be positive.')
        # Numerically stable inverse softplus, log(exp(s) - 1), as a Python
        # float so it does not promote the tensor arithmetic to float64.
        inv_softplus = float(initial_sigma +
                             np.log(-np.expm1(-initial_sigma)))
        init = torch.cat([jitter * torch.randn(dim),
                          inv_softplus + jitter * torch.randn(dim)])
        with torch.no_grad():
            bias.copy_(init)

    def forward(self, x):
        """Returns parameters of a multivariate Gaussian distribution.

        :param x: A Tensor, input of shape [M, in_dim].
        :return: A tuple ``(mu, sigma)`` of Tensors, each of shape
        [M, out_channels * out_height * out_width].
        """
        x = self.linear(x)
        x = x.view(x.shape[0], self.initial_conv_channels,
                   self.initial_conv_dim, self.initial_conv_dim)
        x = self.deconv(x)

        # Sanity checks that the deconvolution settings yield the requested
        # image size.
        assert x.shape[2] == self.out_height, 'Incorrect output height'
        assert x.shape[3] == self.out_width, 'Incorrect output width'

        mu = x[:, :self.out_channels, ...].flatten(1, -1)
        if self.sigma is not None:
            sigma = (self.min_sigma +
                     (1 - self.min_sigma) * self.sigma * torch.ones_like(mu))
        else:
            sigma = (self.min_sigma +
                     (1 - self.min_sigma) *
                     F.softplus(x[:, self.out_channels:, ...].flatten(1, -1)))

        return mu, sigma
