import torch
import torch.nn as nn

from .base_networks import LinearGaussian
from .partial_networks import FactorNet

from torch.nn import functional as F

# Public names exported by ``from <this module> import *``.
__all__ = ['GlobalAuxiliaryNet', 'GlobalAuxiliaryNet2']

# Small numerical constant. NOTE(review): not referenced by either class in
# this module — confirm whether other modules import it before removing.
JITTER = 1e-5


class GlobalAuxiliaryNet(nn.Module):
    """Parameterises a multivariate Gaussian distribution. The input
    includes a shared auxiliary variable.

    The auxiliary variable ``s`` is flattened to a vector, tiled across every
    row of the observations ``y`` and concatenated onto them; the result is
    passed through a single Gaussian network.

    :param in_dim: An int, dimension of the input variable.
    :param s_dim: An int, dimension of the auxiliary variable.
    :param out_dim: An int, dimension of the output variable.
    :param hidden_dims: A sequence, dimensions of hidden layers.
    :param initial_sigma: A float, sets the initial output sigma.
    :param initial_mu: A float, sets the initial output mean.
    :param sigma: A float, sets the homoscedastic sigma.
    :param train_sigma: A bool, whether to train the homoscedastic sigma.
    :param min_sigma: A float, sets the minimum output sigma.
    :param contains_nan: A bool, whether the inputs potentially contain nans.
        If True a ``FactorNet`` is used, otherwise a ``LinearGaussian``.
    :param nonlinearity: A function, the non-linearity to apply in between
        layers.
    """
    def __init__(self, in_dim, s_dim, out_dim, hidden_dims=(64, 64),
                 initial_sigma=1., initial_mu=0., sigma=None,
                 train_sigma=False, min_sigma=0., contains_nan=False,
                 nonlinearity=F.relu):
        super().__init__()

        self.in_dim = in_dim
        self.s_dim = s_dim
        self.out_dim = out_dim

        # The network sees every observation with the auxiliary variable
        # appended, hence the (in_dim + s_dim) input size.
        network_cls = FactorNet if contains_nan else LinearGaussian
        self.network = network_cls(
            (in_dim + s_dim), out_dim, hidden_dims, initial_sigma,
            initial_mu, sigma, train_sigma, min_sigma, nonlinearity)

    def forward(self, inputs, masks=None):
        """Compute the output Gaussian for every row of ``y``.

        :param inputs: A tuple ``(s, y)``. ``s`` is the shared auxiliary
            variable; any shape holding its values (e.g. ``(s_dim,)``,
            ``(1, s_dim)`` or ``(s_dim, 1)``) is accepted. ``y`` is expected
            to be 2-D with one observation per row.
        :param masks: Optional tuple ``(s_mask, y_mask)``. Only ``y_mask`` is
            used; a column of ones is appended to it for the auxiliary
            variable, which is always treated as observed.
        :return: A tuple ``(mu, sigma)`` as produced by the underlying
            network.
        """
        s, y = inputs

        # Flatten the auxiliary variable to one dimension. ``reshape(-1)``
        # rather than ``squeeze()``: squeezing turns a single-element s
        # (s_dim == 1) into a 0-d tensor, whose ``shape[0]`` does not exist.
        s = s.reshape(-1)
        num_rows = y.shape[0]

        # Append the auxiliary variable to each input row.
        y_ = torch.cat([y, s.unsqueeze(0).expand(num_rows, -1)], 1)

        if masks is not None:
            _, mask = masks
            # Build the auxiliary mask with the same device and dtype as the
            # given mask so the concatenation also works for e.g. CUDA masks.
            mask_ = torch.cat([mask, mask.new_ones(num_rows, s.shape[0])], 1)
        else:
            mask_ = None

        # Pass through network.
        mu, sigma = self.network(y_, mask_)

        return mu, sigma


class GlobalAuxiliaryNet2(nn.Module):
    """Parameterises a multivariate Gaussian distribution. The input
    includes a shared auxiliary variable.

    Two Gaussian networks are used: one maps the auxiliary variable ``s`` to
    a Gaussian over the output, the other maps the observations ``y``. The
    returned distribution is the (normalised) product of the two, obtained
    by summing their natural parameters.

    :param in_dim: An int, dimension of the input variable.
    :param s_dim: An int, dimension of the auxiliary variable.
    :param out_dim: An int, dimension of the output variable.
    :param hidden_dims: A sequence, dimensions of hidden layers.
    :param initial_sigma: A float, sets the initial output sigma.
    :param initial_mu: A float, sets the initial output mean.
    :param sigma: A float, sets the homoscedastic sigma.
    :param train_sigma: A bool, whether to train the homoscedastic sigma.
    :param min_sigma: A float, sets the minimum output sigma.
    :param contains_nan: A bool, whether the inputs potentially contain nans.
        If True the observation network is a ``FactorNet``, otherwise a
        ``LinearGaussian``.
    :param nonlinearity: A function, the non-linearity to apply in between
        layers.
    """

    def __init__(self, in_dim, s_dim, out_dim, hidden_dims=(64, 64),
                 initial_sigma=1., initial_mu=0., sigma=None,
                 train_sigma=False, min_sigma=0., contains_nan=False,
                 nonlinearity=F.relu):
        super().__init__()

        self.in_dim = in_dim
        self.s_dim = s_dim
        self.out_dim = out_dim

        # Rescale for multiple outputs: the product of two Gaussians with
        # equal variance 2 * v has variance v. Scaling each factor's sigma by
        # sqrt(2) therefore makes the combined sigma match the requested
        # value when both factors agree, and keeps it >= min_sigma.
        scale = 2 ** 0.5
        initial_sigma = initial_sigma * scale
        min_sigma = min_sigma * scale
        if sigma is not None:
            sigma = sigma * scale

        obs_cls = FactorNet if contains_nan else LinearGaussian
        self.networks = nn.ModuleList([
            # Auxiliary network (index 0): s -> Gaussian factor.
            LinearGaussian(
                s_dim, out_dim, hidden_dims, initial_sigma, initial_mu,
                sigma, train_sigma, min_sigma, nonlinearity),
            # Observation network (index 1): y -> Gaussian factor.
            obs_cls(
                in_dim, out_dim, hidden_dims, initial_sigma, initial_mu,
                sigma, train_sigma, min_sigma, nonlinearity),
        ])

    def forward(self, inputs, masks=None):
        """Combine the auxiliary and observation Gaussian factors.

        :param inputs: A tuple ``(s, y)``. ``s`` is the shared auxiliary
            variable; any shape holding its values (e.g. ``(s_dim,)``,
            ``(1, s_dim)`` or ``(s_dim, 1)``) is accepted. ``y`` is passed
            unchanged to the observation network.
        :param masks: Optional tuple ``(s_mask, y_mask)``. Only ``y_mask`` is
            used; it is forwarded to the observation network.
        :return: A tuple ``(mu, sigma)`` of the product Gaussian. The
            auxiliary network is evaluated on a batch of one
            (``s.unsqueeze(0)``), so its output is expected to broadcast
            against the observation factor.
        """
        s, y = inputs

        if masks is not None:
            _, mask = masks
        else:
            mask = None

        # Flatten the auxiliary variable to one dimension. ``reshape(-1)``
        # rather than ``squeeze(1)``: the latter leaves a (1, s_dim) row
        # vector 2-D whenever s_dim > 1, which then broadcast the output
        # into an extra leading dimension.
        s = s.reshape(-1)

        # Auxiliary factor. Call the modules (not ``.forward``) so that any
        # registered hooks run.
        mu_s, sigma_s = self.networks[0](s.unsqueeze(0))
        np_1_s = mu_s / sigma_s ** 2
        np_2_s = -1. / (2. * sigma_s ** 2)

        # Observation factor.
        mu_y, sigma_y = self.networks[1](y, mask)
        np_1_y = mu_y / sigma_y ** 2
        np_2_y = -1. / (2. * sigma_y ** 2)

        # Multiplying Gaussians adds their natural parameters; convert the
        # sum back to a mean and standard deviation.
        np_1 = np_1_s + np_1_y
        np_2 = np_2_s + np_2_y
        sigma = (-1. / (2. * np_2)) ** 0.5
        mu = np_1 * sigma ** 2.

        return mu, sigma
