import torch
import torch.nn as nn

from .base_networks import LinearNN, LinearGaussian

from torch.nn import functional as F

# Public API of this module (the names exported by ``from ... import *``).
__all__ = ['IndexNet', 'FactorNet', 'PointNet', 'ResNet']

# Small numerical constant. NOTE(review): not referenced by any class in this
# section; presumably a stability jitter used elsewhere — confirm or remove.
JITTER = 1e-5


class FactorNet(nn.Module):
    """A fully connected neural network for each input dimension, whose
    Gaussian outputs are combined into a single multivariate Gaussian.

    Input dimension ``d`` has its own ``LinearGaussian`` network mapping the
    scalar ``x[:, d]`` to a Gaussian factor ``(mu_d, sigma_d)`` over the
    ``out_dim`` outputs. The factors are multiplied together by summing their
    natural parameters.

    :param in_dim: An int, dimension of the input variable.
    :param out_dim: An int, dimension of the output variable.
    :param hidden_dims: A list, dimensions of hidden layers.
    :param initial_sigma: A float, (approximately) sets the initial output
    sigma.
    :param initial_mu: A float, forwarded to each ``LinearGaussian``.
    :param sigma: A float or None. When given, it is rescaled like
    ``initial_sigma`` and forwarded to each ``LinearGaussian`` together with
    ``train_sigma``.
    :param train_sigma: A bool, forwarded to each ``LinearGaussian`` only when
    ``sigma`` is given.
    :param min_sigma: A float, rescaled like ``initial_sigma`` and forwarded
    to each ``LinearGaussian``.
    :param nonlinearity: A function, the non-linearity to apply in between
    layers.
    """

    def __init__(self, in_dim, out_dim, hidden_dims=(64, 64),
                 initial_sigma=1., initial_mu=0., sigma=None,
                 train_sigma=False,  min_sigma=0., nonlinearity=F.relu):
        super().__init__()

        self.out_dim = out_dim

        # Rescale sigmas for multiple outputs: the product of in_dim factors
        # that share sigma s has sigma s / sqrt(in_dim), so scaling by
        # sqrt(in_dim) keeps the combined output near the requested value.
        initial_sigma = initial_sigma * in_dim ** 0.5
        min_sigma = min_sigma * in_dim ** 0.5

        if sigma is not None:
            sigma = sigma * in_dim ** 0.5

        # A network for each dimension.
        self.networks = nn.ModuleList()
        for _ in range(in_dim):
            if sigma is None:
                network = LinearGaussian(1, out_dim, hidden_dims,
                                         initial_sigma,
                                         initial_mu,
                                         min_sigma=min_sigma,
                                         nonlinearity=nonlinearity)
            else:
                network = LinearGaussian(1, out_dim, hidden_dims,
                                         initial_sigma,
                                         initial_mu, sigma,
                                         train_sigma, min_sigma,
                                         nonlinearity)
            self.networks.append(network)

    def forward(self, x, mask=None):
        """Returns parameters of a multivariate Gaussian distribution.

        :param x: A Tensor, input of shape [M, in_dim].
        :param mask: A Tensor of shape [M, in_dim], or None. True entries are
        passed through their dimension's network; False entries are ignored.
        If None, every entry is used.
        :return: A tuple ``(mu, sigma)``, each of shape [M, out_dim].

        NOTE(review): a row whose mask is entirely False gets no factors, so
        its summed ``np_2`` is zero and ``mu``/``sigma`` come out non-finite.
        """
        # Allocate on the input's device and dtype. A bare torch.zeros lives
        # on the CPU in float32, which breaks CUDA inputs and silently
        # downcasts double-precision results.
        np_1 = x.new_zeros(x.shape[0], x.shape[1], self.out_dim)
        np_2 = torch.zeros_like(np_1)

        # Pass through individual networks.
        for dim, x_dim in enumerate(x.transpose(0, 1)):
            if mask is None:
                idx = slice(None)
            else:
                idx = torch.where(mask[:, dim])[0]
                # Don't pass through if no inputs.
                if idx.numel() == 0:
                    continue

            x_in = x_dim[idx].unsqueeze(1)
            mu, sigma = self.networks[dim](x_in)
            # Natural parameters of this factor: mu / s^2 and -1 / (2 s^2).
            np_1[idx, dim, :] = mu / sigma ** 2
            np_2[idx, dim, :] = -1. / (2. * sigma ** 2)

        # Product of Gaussian factors == sum of their natural parameters.
        np_1 = torch.sum(np_1, 1)
        np_2 = torch.sum(np_2, 1)
        sigma = (- 1. / (2. * np_2)) ** 0.5
        mu = np_1 * sigma ** 2.

        return mu, sigma


class IndexNet(nn.Module):
    """A fully connected neural network for each dimension for
    parameterising a multivariate Gaussian distribution using Jonny's
    approach. Mask is applied before passing through network.

    Each input dimension has its own ``LinearNN`` mapping the scalar input to
    a ``middle_dim`` vector. The vectors are summed over dimensions and the
    sum is passed through a shared ``LinearGaussian``.

    :param in_dim: An int, dimension of the input variable.
    :param out_dim: An int, dimension of the output variable.
    :param middle_dim: An int, dimension of the aggregation layer.
    :param hidden_dims: A list, dimensions of hidden layers of the
    per-dimension networks.
    :param shared_hidden_dims: A list, dimensions of hidden layers of shared
    network.
    :param initial_sigma: A float, (approximately) sets the initial output
    sigma.
    :param initial_mu: A float, forwarded to the shared ``LinearGaussian``.
    :param sigma: Forwarded to the shared ``LinearGaussian``.
    :param train_sigma: A bool, forwarded to the shared ``LinearGaussian``.
    :param min_sigma: A float, forwarded to the shared ``LinearGaussian``.
    :param nonlinearity: A function, the non-linearity to apply in between
    layers.
    """

    def __init__(self, in_dim, out_dim, middle_dim,
                 hidden_dims=(64, 64), shared_hidden_dims=(64, 64),
                 initial_sigma=1., initial_mu=0., sigma=None,
                 train_sigma=False, min_sigma=0., nonlinearity=F.relu):
        super().__init__()

        self.out_dim = out_dim
        self.middle_dim = middle_dim

        # A network for each dimension.
        self.networks = nn.ModuleList()
        for _ in range(in_dim):
            self.networks.append(LinearNN(1, middle_dim, hidden_dims,
                                          nonlinearity))

        # Takes the aggregation of the outputs from self.networks.
        self.final_network = LinearGaussian(
            middle_dim, out_dim, shared_hidden_dims, initial_sigma,
            initial_mu, sigma, train_sigma, min_sigma, nonlinearity)

    def forward(self, x, mask=None):
        """Returns parameters of a multivariate Gaussian distribution.

        :param x: A Tensor, input of shape [M, in_dim].
        :param mask: A Tensor of shape [M, in_dim], or None. True entries are
        passed through their dimension's network; False entries contribute
        zeros to the aggregation. If None, every entry is used.
        :return: The ``(mu, sigma)`` pair produced by the shared network.
        """
        # Allocate on the input's device and dtype. A bare torch.zeros lives
        # on the CPU in float32, which breaks CUDA inputs and silently
        # downcasts double-precision results.
        out = x.new_zeros(x.shape[0], x.shape[1], self.middle_dim)

        # Pass through individual networks.
        for dim, x_dim in enumerate(x.transpose(0, 1)):
            if mask is None:
                idx = slice(None)
            else:
                idx = torch.where(mask[:, dim])[0]
                # Don't pass through if no inputs.
                if idx.numel() == 0:
                    continue

            out[idx, dim, :] = self.networks[dim](x_dim[idx].unsqueeze(1))

        # Aggregation layer (permutation invariant sum over dimensions).
        out = torch.sum(out, 1)

        # Pass through shared network.
        mu, sigma = self.final_network(out)
        return mu, sigma


class PointNet(nn.Module):
    """A fully connected neural network for parameterising a multivariate
    Gaussian distribution. The input data potentially contains NaNs (which
    should be excluded via ``mask``), so we use the deep set approach as in
    EDDI. The first network maps each (value, dimension index) pair to a
    middle layer. The outputs at the middle layer are aggregated in a
    permutation invariant manner (summation), then passed through the second
    network which parameterises a Gaussian.

    :param out_dim: An int, dimension of the output variable.
    :param middle_dim: An int, dimension of the aggregation layer.
    :param first_hidden_dims: A list, dimensions of hidden layers for first
    network.
    :param second_hidden_dims: A list, dimensions of hidden layers for
    second network.
    :param nonlinearity: A function, the non-linearity to apply in between
    layers.
    :param initial_sigma: A float, (approximately) set the initial output
    sigma.
    :param initial_mu: A float, forwarded to the second network.
    :param min_sigma: A float, forwarded to the second network.
    """

    def __init__(self, out_dim, middle_dim, first_hidden_dims=(64, 64),
                 second_hidden_dims=(64, 64), nonlinearity=F.relu,
                 initial_sigma=1., initial_mu=0., min_sigma=0.):
        super().__init__()

        self.out_dim = out_dim
        self.middle_dim = middle_dim

        # Takes an observed value and the index of its dimension.
        self.first_network = LinearNN(
            2, middle_dim, first_hidden_dims, nonlinearity)

        # Takes the aggregation of the outputs from self.first_network.
        self.second_network = LinearGaussian(
            middle_dim, out_dim, second_hidden_dims, initial_sigma,
            initial_mu, min_sigma=min_sigma, nonlinearity=nonlinearity)

    def forward(self, x, mask=None):
        """Returns parameters of a multivariate Gaussian distribution.

        :param x: A Tensor, input of shape [M, in_dim].
        :param mask: A Tensor of shape [M, in_dim], or None. True (or 1)
        entries are passed through the first network; False (or 0) entries
        are ignored. If None, every entry is used.
        :return: The ``(mu, sigma)`` pair produced by the second network.
        """
        # Allocate on the input's device and dtype. A bare torch.zeros lives
        # on the CPU in float32, which breaks CUDA inputs and silently
        # downcasts double-precision results.
        out = x.new_zeros(x.shape[0], x.shape[1], self.middle_dim)

        # Pass through first network.
        for dim, x_dim in enumerate(x.transpose(0, 1)):
            if mask is None:
                idx = slice(None)
            else:
                idx = torch.where(mask[:, dim])[0]
                # Don't pass through if no inputs.
                if idx.numel() == 0:
                    continue

            x_in = x_dim[idx].unsqueeze(1)
            # The first network expects 2 features: (value, dimension index).
            # The concatenation must be assigned back in both the masked and
            # unmasked paths.
            x_in = torch.cat([x_in, torch.ones_like(x_in) * dim], 1)
            out[idx, dim, :] = self.first_network(x_in)

        # Aggregation layer.
        out = torch.sum(out, 1)

        # Pass through second network.
        mu, sigma = self.second_network(out)

        return mu, sigma


class ResNet(nn.Module):
    """Analogous to a proper ResNet (init): a FactorNet-style product of
    per-dimension Gaussian factors, plus a learned residual Gaussian factor.

    Each input dimension has its own ``LinearGaussian`` (``hs``). Their
    natural parameters are summed, concatenated, and passed through ``rho``.
    ``rho``'s output Gaussian is converted to natural parameters and added to
    the sum before converting back to ``(mu, sigma)``.

    :param in_dim: An int, dimension of the input variable.
    :param out_dim: An int, dimension of the output variable.
    :param h_dims: A list, dimensions of hidden layers of the per-dimension
    networks.
    :param rho_dims: A list, dimensions of hidden layers of the residual
    network.
    :param initial_sigma: A float, (approximately) sets the initial output
    sigma.
    :param initial_mu: A float, forwarded to the per-dimension networks.
    :param min_sigma: A float, rescaled like ``initial_sigma`` and forwarded
    to all networks.
    :param nonlinearity: A function, the non-linearity to apply in between
    layers.
    """

    def __init__(self, in_dim, out_dim, h_dims=(64, 64), rho_dims=(64, 64),
                 initial_sigma=1., initial_mu=0., min_sigma=0.,
                 nonlinearity=F.relu):
        super().__init__()

        self.out_dim = out_dim

        # Rescale sigmas for multiple outputs: the product of in_dim factors
        # that share sigma s has sigma s / sqrt(in_dim).
        initial_sigma = initial_sigma * in_dim ** 0.5
        min_sigma = min_sigma * in_dim ** 0.5

        # A network for each dimension.
        self.hs = nn.ModuleList()
        for _ in range(in_dim):
            self.hs.append(LinearGaussian(
                1, out_dim, h_dims, initial_sigma=initial_sigma,
                initial_mu=initial_mu, min_sigma=min_sigma,
                nonlinearity=nonlinearity))

        # The residual network: maps the concatenated summed natural
        # parameters [np_1, np_2] to an extra Gaussian factor.
        self.rho = LinearGaussian(
            2*out_dim, out_dim, rho_dims, min_sigma=min_sigma,
            nonlinearity=nonlinearity)

    def forward(self, x, mask=None):
        """Returns parameters of a multivariate Gaussian distribution.

        :param x: A Tensor, input of shape [M, in_dim].
        :param mask: A Tensor of shape [M, in_dim], or None. True entries are
        passed through their dimension's network; False entries are ignored.
        If None, every entry is used.
        :return: A tuple ``(mu, sigma)``, each of shape [M, out_dim].
        """
        # Allocate on the input's device and dtype. A bare torch.zeros lives
        # on the CPU in float32, which breaks CUDA inputs and silently
        # downcasts double-precision results.
        np_1 = x.new_zeros(x.shape[0], x.shape[1], self.out_dim)
        np_2 = torch.zeros_like(np_1)

        # Pass through individual networks.
        for dim, x_dim in enumerate(x.transpose(0, 1)):
            if mask is None:
                idx = slice(None)
            else:
                idx = torch.where(mask[:, dim])[0]
                # Don't pass through if no inputs.
                if idx.numel() == 0:
                    continue

            x_in = x_dim[idx].unsqueeze(1)
            mu, sigma = self.hs[dim](x_in)
            # Natural parameters of this factor: mu / s^2 and -1 / (2 s^2).
            np_1[idx, dim, :] = mu / sigma ** 2
            np_2[idx, dim, :] = -1. / (2. * sigma ** 2)

        # Sum natural parameters (product of the per-dimension factors).
        np_1 = torch.sum(np_1, 1)
        np_2 = torch.sum(np_2, 1)

        # Pass through to get residuals.
        np_cat = torch.cat([np_1, np_2], 1)
        res_mu, res_sigma = self.rho(np_cat)

        # Get residual natural parameters.
        res_np_1 = res_mu / res_sigma ** 2
        res_np_2 = -1. / (2. * res_sigma ** 2)

        # Combine all natural parameters (out-of-place, so tensors already
        # used in the graph are not mutated).
        np_1 = np_1 + res_np_1
        np_2 = np_2 + res_np_2

        sigma = (- 1. / (2. * np_2)) ** 0.5
        mu = np_1 * sigma ** 2.

        return mu, sigma
