import torch
import torch.nn as nn

from torch.nn import functional as F

__all__ = [
    "MeanFieldNet",
    "MeanFieldSparseNet",
]

# Standard deviation of the small Gaussian noise that perturbs the initial
# variational parameters (presumably to break symmetry between entries).
JITTER = 1.0e-5


class MeanFieldNet(nn.Module):
    """Hold the variational parameters for performing mean-field variational
    inference.

    One mean and one unconstrained sigma parameter are stored for every
    training datum and output dimension. Both ``mu`` and ``raw_sigma`` have
    shape ``(n, out_dim)``. The constrained sigma is
    ``min_sigma + (1 - min_sigma) * softplus(raw_sigma)``.

    :param n: An int, the number of training data.
    :param out_dim: An int, dimension of the output variable.
    :param initial_mu: A float or Tensor broadcastable to ``(n, out_dim)``,
        sets the initial output mean.
    :param initial_sigma: A positive float or Tensor broadcastable to
        ``(n, out_dim)``, sets the initial value of ``softplus(raw_sigma)``.
        When ``min_sigma != 0`` the sigma returned by :meth:`forward`
        therefore starts at ``min_sigma + (1 - min_sigma) * initial_sigma``.
    :param min_sigma: A float, the minimum output sigma (a lower bound on
        the returned sigma as long as ``min_sigma <= 1``).
    """
    def __init__(self, n, out_dim, initial_mu=0., initial_sigma=1.,
                 min_sigma=0.):
        super().__init__()

        self.n = n
        self.out_dim = out_dim
        self.min_sigma = min_sigma

        init_mu = self._as_float_tensor(initial_mu)
        init_sigma = self._as_float_tensor(initial_sigma)

        # Initialise the mean of the variational distribution, perturbed by
        # a small amount of Gaussian noise.
        self.mu = nn.Parameter(
            init_mu + JITTER * torch.randn(self.n, self.out_dim),
            requires_grad=True)

        # Inverse softplus, written in its numerically stable form:
        # log(exp(y) - 1) == y + log(-expm1(-y)). The naive form overflows
        # to inf for large y (about y > 88 in float32) and loses precision
        # for small y.
        raw_sigma = init_sigma + torch.log(-torch.expm1(-init_sigma))
        self.raw_sigma = nn.Parameter(
            raw_sigma + JITTER * torch.randn(self.n, self.out_dim),
            requires_grad=True)

    @staticmethod
    def _as_float_tensor(value):
        """Convert *value* to a detached, floating-point Tensor.

        ``torch.as_tensor`` is used instead of ``torch.tensor`` so that
        Tensor inputs do not trigger the copy-construct warning. Integer
        inputs are promoted to the default floating dtype because
        ``exp``/``log`` and gradients require floating point.
        """
        value = torch.as_tensor(value).detach()
        if not value.is_floating_point():
            value = value.to(torch.get_default_dtype())
        return value

    def forward(self, idx):
        """Returns parameters of a multivariate Gaussian distribution.

        :param idx: A Tensor, the indices of the inputs. It indexes the first
            (per-datum) dimension of the stored parameters.
        :return: A tuple ``(mu, sigma)``. For a 1-D ``idx`` each has shape
            ``(len(idx), out_dim)``.
        """
        mu = self.mu[idx, :]
        sigma = F.softplus(self.raw_sigma[idx, :])
        # Affine map so that sigma == min_sigma when softplus(...) == 0.
        sigma = self.min_sigma + (1 - self.min_sigma) * sigma

        return mu, sigma


class MeanFieldSparseNet(MeanFieldNet):
    """Hold the variational parameters for performing VFE mean-field
    variational inference.

    One mean/sigma row is stored per inducing point, because the parent is
    constructed with ``n = num_inducing``. The inducing point locations are
    also a learnable parameter.

    :param z: A Tensor (or array-like), the initial inducing point locations,
        of shape ``(num_inducing,)`` or ``(num_inducing, dim)``. A 1-D input
        is treated as ``(num_inducing, 1)``. The values are copied, so
        optimising the model never modifies the caller's tensor.
    :param out_dim: An int, dimension of the output variable.
    :param initial_mu: A float or Tensor, sets the initial output mean.
    :param initial_sigma: A positive float or Tensor, sets the initial value
        of ``softplus(raw_sigma)``.
    :param min_sigma: A float, the minimum output sigma.
    """
    def __init__(self, z, out_dim, initial_mu=0., initial_sigma=1.,
                 min_sigma=0.):
        # Normalise the inducing points first. Array-likes are accepted, and
        # integer input is promoted because only floating-point tensors can
        # require gradients.
        z = torch.as_tensor(z)
        if not z.is_floating_point():
            z = z.to(torch.get_default_dtype())

        # Ensure inducing points are two-dimensional.
        if z.dim() == 1:
            z = z.unsqueeze(1)

        super().__init__(z.shape[0], out_dim, initial_mu, initial_sigma,
                         min_sigma)

        self.num_inducing = z.shape[0]
        # nn.Parameter shares storage with the tensor it wraps. Detach and
        # clone so in-place optimiser updates do not write through to the
        # caller's tensor.
        self.z = nn.Parameter(z.detach().clone(), requires_grad=True)

    def forward(self, *args, **kwargs):
        """Return the inducing locations and their variational parameters.

        Any positional or keyword arguments are accepted and ignored. The
        parameters for all inducing points are always returned.

        :return: A tuple ``(z, mu, sigma)``. ``z`` has shape
            ``(num_inducing, dim)``; ``mu`` and ``sigma`` have shape
            ``(num_inducing, out_dim)``.
        """
        mu = self.mu
        sigma = F.softplus(self.raw_sigma)
        # Affine map so that sigma == min_sigma when softplus(...) == 0.
        sigma = self.min_sigma + (1 - self.min_sigma) * sigma
        z = self.z

        return z, mu, sigma
