import torch
import torch.nn as nn

from .base_networks import LinearGaussian
from .partial_networks import IndexNet

from torch.nn import functional as F

# Names exported by ``from <module> import *``.
__all__ = [
    'MultiInputInferenceNet',
    'GlobalGaussian',
]

# Scale of the standard-normal noise added to initial parameter values.
JITTER = 1e-05


class MultiInputInferenceNet(nn.Module):
    """Gaussian inference network that fuses several inputs.

    Each input gets its own network predicting a Gaussian ``(mu, sigma)``
    over ``out_dim`` dimensions. The per-input Gaussians are multiplied
    together (their natural parameters are summed) into a single Gaussian.

    :param in_dims: A sequence of ints, the width of each input. ``forward``
        reads the inputs as consecutive column blocks of ``y``.
    :param out_dim: An int, the dimension of the output distribution.
    :param hidden_dims: A sequence of ints, hidden layer sizes of each network.
    :param initial_sigma: Target initial sigma of the combined distribution.
        Each network receives ``initial_sigma * sqrt(len(in_dims))`` so that
        K Gaussians with that sigma combine back to ``initial_sigma``.
    :param initial_mu: Initial mean, forwarded to each network.
    :param sigma: If not None, a sigma for the combined distribution,
        forwarded to each network after the same sqrt(K) rescaling.
    :param train_sigma: Forwarded to each network.
    :param min_sigma: Minimum sigma, rescaled by sqrt(K) like
        ``initial_sigma``.
    :param contains_nans: Optional sequence of bools, one per input. Inputs
        flagged True use an ``IndexNet`` (built from ``hidden_dims`` split
        around its middle element); all others use ``LinearGaussian``.
    :param nonlinearity: Activation function forwarded to each network.
    :raises ValueError: If ``contains_nans`` does not have exactly one entry
        per element of ``in_dims``.
    """

    def __init__(self, in_dims, out_dim, hidden_dims=(64, 64),
                 initial_sigma=1., initial_mu=0., sigma=None,
                 train_sigma=False, min_sigma=0., contains_nans=None,
                 nonlinearity=F.relu):
        super().__init__()

        self.in_dims = in_dims
        self.out_dim = out_dim

        if contains_nans is None:
            contains_nans = [False] * len(in_dims)
        else:
            contains_nans = list(contains_nans)
            if len(contains_nans) != len(in_dims):
                # zip() would otherwise silently drop inputs, leaving
                # forward() with fewer networks than inputs.
                raise ValueError(
                    'contains_nans has {} entries but in_dims has {}.'.format(
                        len(contains_nans), len(in_dims)))

        # Combining K Gaussians sums their precisions (1 / sigma^2), so
        # scaling each network's sigma by sqrt(K) makes K equal sigmas
        # combine back to the requested value.
        scale = len(in_dims) ** 0.5
        initial_sigma = initial_sigma * scale
        min_sigma = min_sigma * scale
        if sigma is not None:
            sigma = sigma * scale

        if any(contains_nans):
            # Split hidden_dims around its middle element for IndexNet.
            middle_idx = len(hidden_dims) // 2
            dims_1 = hidden_dims[:middle_idx]
            dims_2 = hidden_dims[middle_idx:]
            middle_dim = hidden_dims[middle_idx]

        # One network per input, in the same order as in_dims.
        self.networks = nn.ModuleList()
        for in_dim, contains_nan in zip(in_dims, contains_nans):
            if contains_nan:
                # Inputs that may contain NaNs use the IndexNet imputation
                # class. (A FactorNet imputation class, taking hidden_dims
                # instead of the split above, is a possible alternative.)
                self.networks.append(IndexNet(
                    in_dim, out_dim, middle_dim, dims_1, dims_2,
                    initial_sigma, initial_mu, sigma, train_sigma,
                    min_sigma, nonlinearity))
            else:
                self.networks.append(LinearGaussian(
                    in_dim, out_dim, hidden_dims, initial_sigma,
                    initial_mu, sigma, train_sigma, min_sigma,
                    nonlinearity))

    def forward(self, y, mask=None):
        """Combine the per-input Gaussians into one Gaussian.

        ``y`` is split column-wise (dim 1) into consecutive chunks of widths
        ``self.in_dims``. Each chunk goes through its own network, which
        returns ``(mu, sigma)``. The results are combined by summing natural
        parameters ``mu / sigma^2`` and ``-1 / (2 sigma^2)``.

        :param y: A Tensor whose columns hold the concatenated inputs.
        :param mask: Optional Tensor, split the same way as ``y``. When
            given, every network is called as ``network(x, mask_chunk)``.
            NOTE(review): this includes LinearGaussian networks; confirm
            they accept a mask argument.
        :return: ``(mu, sigma)`` of the combined Gaussian, each of shape
            ``(y.shape[0], out_dim)``.
        """
        # Split y (and mask) column-wise into one chunk per input.
        inputs, masks = [], []
        start = 0
        for in_dim in self.in_dims:
            stop = start + in_dim
            inputs.append(y[:, start:stop])
            masks.append(None if mask is None else mask[:, start:stop])
            start = stop

        # Per-input natural parameters, stacked along dim 0. They are
        # allocated on y's device/dtype; network outputs broadcast into
        # (batch, out_dim) on assignment.
        np_1 = y.new_zeros((len(inputs), y.shape[0], self.out_dim))
        np_2 = torch.zeros_like(np_1)

        # Pass each input through its corresponding network.
        for i, (x, x_mask) in enumerate(zip(inputs, masks)):
            network = self.networks[i]
            if x_mask is None:
                mu, sigma = network(x)
            else:
                mu, sigma = network(x, x_mask)
            np_1[i] = mu / sigma ** 2
            np_2[i] = -1. / (2. * sigma ** 2)

        # Product of Gaussians: natural parameters add.
        np_1 = np_1.sum(0)
        np_2 = np_2.sum(0)

        # Invert np_2 = -1 / (2 sigma^2) and np_1 = mu / sigma^2.
        sigma = (-1. / (2. * np_2)).sqrt()
        mu = np_1 * sigma ** 2

        return mu, sigma


class GlobalGaussian(nn.Module):
    """Holds the parameters of a global Gaussian distribution.

    The distribution does not depend on any input: ``forward`` ignores its
    arguments and returns the learned ``(mu, sigma)``. With scalar initial
    values both have shape ``(out_dim, 1)``.

    :param out_dim: An int, the dimension of the global variable.
    :param initial_sigma: A Tensor or float, sets the initial value of the
        softplus output. The sigma returned by ``forward`` equals it only
        when ``min_sigma`` is 0. A Tensor must broadcast against
        ``(out_dim, 1)``.
    :param initial_mu: A Tensor or float, sets the initial mean. A Tensor
        must broadcast against ``(out_dim, 1)``.
    :param min_sigma: A float, the minimum sigma. NOTE(review): sigma is
        ``min_sigma + (1 - min_sigma) * softplus(raw)``, so it stays above
        ``min_sigma`` only when ``min_sigma < 1``.
    """
    def __init__(self, out_dim, initial_sigma=1., initial_mu=1.,
                 min_sigma=0.):
        super().__init__()

        self.out_dim = out_dim
        self.min_sigma = min_sigma

        # as_tensor accepts both floats and Tensors (torch.tensor warns when
        # copy-constructing from a Tensor) and promotes ints to float.
        dtype = torch.get_default_dtype()
        initial_mu = torch.as_tensor(initial_mu, dtype=dtype)
        initial_sigma = torch.as_tensor(initial_sigma, dtype=dtype)

        # Initialise the mean, plus a small JITTER-scaled random offset.
        self.mu = nn.Parameter(
            initial_mu + JITTER * torch.randn(self.out_dim, 1))

        # Inverse softplus, log(exp(s) - 1), written in the stable form
        # s + log(1 - exp(-s)). This neither overflows for large s nor
        # rounds to log(0) for tiny s.
        raw_sigma = initial_sigma + torch.log(-torch.expm1(-initial_sigma))
        self.raw_sigma = nn.Parameter(
            raw_sigma + JITTER * torch.randn(self.out_dim, 1))

    def forward(self, *args, **kwargs):
        """Return ``(mu, sigma)``; all arguments are ignored."""
        sigma = F.softplus(self.raw_sigma)
        # Affine map of the (positive) softplus output towards min_sigma.
        sigma = self.min_sigma + (1 - self.min_sigma) * sigma

        return self.mu, sigma
