import torch
import numpy as np
import torch.nn as nn

from .base_networks import LinearGaussian
from .partial_networks import IndexNet, FactorNet, PointNet

from torch.nn import functional as F

# Names exported by a star-import of this module.
# NOTE(review): AlternativeFixedSparseNet is defined in this module but is not
# listed here - confirm whether that omission is intentional.
__all__ = ['SparseNet', 'FixedSparseNet']

# Small positive constant.
# NOTE(review): not referenced by any code visible in this module; presumably
# intended as a numerical-stability offset - confirm before relying on it or
# removing it.
JITTER = 1e-5


class SparseNet(nn.Module):
    """A fully connected neural network for parameterising a multivariate
    Gaussian distribution over inducing points. Uses a single network which
    accepts the observation value and distance to inducing point.

    Every observation contributes one Gaussian factor to each of the inducing
    points it is paired with. Factors that land on the same inducing point
    are multiplied together, which is done by summing natural parameters.

    :param in_dim: An int, dimension of the input variable.
    :param out_dim: An int, dimension of the output variable.
    :param z: A Tensor, the initial inducing point locations.
    :param hidden_dims: A list, dimensions of the hidden layers.
    :param k: An int, get parameters for the nearest k inducing points. If
    None then the inference network parameterises all inducing points.
    :param min_sigma: A float, the minimum output sigma.
    :param nonlinearity: A function, the non-linearity to apply in between
    hidden layers.
    :param contains_nan: A bool, whether the observations potentially
    contains nans.
    :param fixed_inducing: A bool, whether to fix inducing points. If False
    the inducing locations are stored as a trainable nn.Parameter.
    :param pinference_net: A str, the partial inference network to use when
    contains_nan is True: 'indexnet', 'factornet' or 'pointnet'. Any other
    value falls back to a LinearGaussian network (zero imputation). Ignored
    when contains_nan is False.
    """

    def __init__(self, in_dim, out_dim, z, hidden_dims=(64, 64), k=None,
                 min_sigma=0., nonlinearity=F.relu, contains_nan=False,
                 fixed_inducing=True, pinference_net='zeroimputation'):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.min_sigma = min_sigma
        self.fixed_inducing = fixed_inducing

        # With k=None every observation is paired with every inducing point.
        self.all_inducing = k is None
        self.k = len(z) if k is None else k

        # TODO: Option to share inducing points across latent dimensions.
        if fixed_inducing:
            # Plain attribute: not trained and not part of the state dict.
            self.z_all = z
        else:
            self.z_all = nn.Parameter(z, requires_grad=True)

        # The network input is the observation plus one distance feature,
        # hence (in_dim + 1).
        if contains_nan and pinference_net in ('indexnet', 'pointnet'):
            # Both networks are built from two stacks of hidden layers that
            # meet at a middle dimension, so split hidden_dims in half.
            middle_idx = len(hidden_dims) // 2
            dims_1 = hidden_dims[:middle_idx]
            dims_2 = hidden_dims[middle_idx:]
            middle_dim = hidden_dims[middle_idx]

            if pinference_net == 'indexnet':
                self.network = IndexNet(
                    (in_dim + 1), out_dim, middle_dim, dims_1, dims_2,
                    min_sigma=min_sigma, nonlinearity=nonlinearity)
            else:
                self.network = PointNet(
                    out_dim, middle_dim, dims_1, dims_2,
                    min_sigma=min_sigma, nonlinearity=nonlinearity)

        elif contains_nan and pinference_net == 'factornet':
            self.network = FactorNet(
                (in_dim + 1), out_dim, hidden_dims, min_sigma=min_sigma,
                nonlinearity=nonlinearity)

        else:
            # No missing values, or default to zero imputation.
            self.network = LinearGaussian(
                (in_dim + 1), out_dim, hidden_dims,
                nonlinearity=nonlinearity, min_sigma=min_sigma)

    def _distances(self, x, kernel=None):
        """Returns the [M, Z] matrix of distances between inputs and inducing
        locations.

        :param x: A Tensor, input locations of shape [M, x_dim].
        :param kernel: A callable or None. If given, kernel(x, z_all) is
        returned instead of the Euclidean distance.
        """
        if kernel is not None:
            return kernel(x, self.z_all)

        z = self.z_all
        if z.dim() == 1:
            # Treat a flat vector of locations as [Z, 1].
            z = z.unsqueeze(1)

        # Broadcast [1, Z, x_dim] against [M, 1, x_dim] and reduce over the
        # coordinate axis.
        return torch.norm(z.unsqueeze(0) - x.unsqueeze(1), dim=2)

    @staticmethod
    def _sum_by_index(values, index, num_groups):
        """Sums the rows of values that share the same entry in index.

        :param values: A Tensor whose first dimension has one row per entry
        of index.
        :param index: An int64 Tensor of group ids in [0, num_groups).
        :param num_groups: An int, the number of output rows.
        """
        totals = values.new_zeros((num_groups,) + tuple(values.shape[1:]))
        return totals.index_add_(0, index.to(values.device), values)

    def forward(self, x, y, mask=None, kernel=None):
        """Returns parameters of a multivariate Gaussian distribution at the k
        nearest inducing locations.

        :param x: A Tensor, input locations of shape [M, x_dim]. A 1D tensor
        of shape [M] is treated as [M, 1].
        :param y: A Tensor, observations of shape [M, y_dim].
        :param mask: A Tensor, the mask to apply to the observation data. A
        column of ones is appended for the distance feature before it is
        passed to the network.
        :param kernel: A callable or None. If given, kernel(x, z_all) is used
        in place of the Euclidean distance, both for selecting the nearest
        inducing points (smallest values) and as the network input feature.
        :return: A tuple (z_unique, mu, sigma) holding the inducing locations
        that were used and the mean and standard deviation of the combined
        Gaussian factor at each of them.
        """
        assert len(y.shape) == 2, 'Inputs must be shape [M, y_dim]'

        if len(x.shape) == 1:
            x = x.unsqueeze(1)

        assert len(x.shape) == 2, 'Inputs must be shape [M, x_dim]'

        dists = self._distances(x, kernel)

        if self.all_inducing:
            # Every observation is paired with every inducing point, so the
            # index pattern 0..Z-1 repeats once per observation, matching the
            # row-major flattening of dists.
            z_indices = torch.arange(
                len(self.z_all), dtype=torch.int64,
                device=dists.device).repeat(y.shape[0])
            dist_feature = dists.flatten()
        else:
            # Smallest k values per row, i.e. the nearest inducing points.
            knn = dists.topk(self.k, dim=1, largest=False)
            z_indices = knn.indices.flatten()
            dist_feature = knn.values.flatten()

        # One network input row per (observation, inducing point) pair.
        inputs = torch.cat(
            [y.repeat_interleave(self.k, dim=0),
             dist_feature.unsqueeze(1)], 1)

        # Get means and variances.
        if mask is None:
            mu, sigma = self.network(inputs)
        else:
            # Repeat the mask per pair and mark the distance feature as
            # always observed.
            mask = mask.repeat_interleave(self.k, dim=0)
            observed = torch.ones(mask.shape[0], 1, device=mask.device)
            mask = torch.cat([mask, observed], dim=1)

            mu, sigma = self.network(inputs, mask)

        # Convert to natural parameters for easy summations.
        np_1 = mu / sigma ** 2
        np_2 = -1. / (2. * sigma ** 2)

        # Find unique inducing point indices.
        z_indices_unique, mapping_idxs = z_indices.unique(return_inverse=True)
        z_unique = self.z_all[z_indices_unique]
        num_unique = len(z_indices_unique)

        # Sum the natural parameters for each unique inducing point.
        np_1 = self._sum_by_index(np_1, mapping_idxs, num_unique)
        np_2 = self._sum_by_index(np_2, mapping_idxs, num_unique)

        # Convert back to mean and standard deviation.
        sigma = (- 1. / (2. * np_2)) ** 0.5
        mu = np_1 * sigma ** 2.

        if len(z_unique.shape) == 1:
            z_unique = z_unique.unsqueeze(1)

        return z_unique, mu, sigma


class FixedSparseNet(nn.Module):
    """A fully connected neural network for parameterising a Gaussian
    distribution over the nearest k inducing locations, which are at fixed
    intervals.

    :param in_dim: An int, dimension of the input variable.
    :param out_dim: An int, dimension of the output variable.
    :param hidden_dims: A list, dimensions of hidden layers.
    :param inducing_spacing: A float, the spacing of inducing locations in
    latent space.
    :param k: An int, the number of inducing locations that each observation
    'touches'. Even values are rounded up to the next odd number.
    :param min_sigma: A float, the minimum output sigma.
    :param nonlinearity: A function, the non-linearity to apply in between
    hidden layers.
    :param contains_nan: A bool, whether the observations potentially
    contains nans.
    """

    def __init__(self, in_dim, out_dim, hidden_dims=(64, 64),
                 inducing_spacing=1., k=4, min_sigma=0.,
                 nonlinearity=F.relu, contains_nan=False):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.inducing_spacing = inducing_spacing
        self.min_sigma = min_sigma

        # Round up k to the nearest odd number: each observation touches the
        # grid point nearest to it plus k // 2 neighbours on either side, so
        # the window must be symmetric.
        self.k = (k // 2) * 2 + 1

        # The network input is the observation plus one distance feature,
        # hence (in_dim + 1).
        if contains_nan:
            # Split hidden_dims for defining a IndexNet class.
            middle_idx = len(hidden_dims) // 2
            dims_1 = hidden_dims[:middle_idx]
            dims_2 = hidden_dims[middle_idx:]
            middle_dim = hidden_dims[middle_idx]

            self.network = IndexNet(
                (in_dim + 1), out_dim, middle_dim, dims_1, dims_2,
                min_sigma=min_sigma, nonlinearity=nonlinearity)

        else:
            self.network = LinearGaussian(
                (in_dim + 1), out_dim, hidden_dims,
                nonlinearity=nonlinearity, min_sigma=min_sigma)

    @staticmethod
    def _sum_by_index(values, index, num_groups):
        """Sums the rows of values that share the same entry in index.

        :param values: A Tensor whose first dimension has one row per entry
        of index.
        :param index: An int64 Tensor of group ids in [0, num_groups).
        :param num_groups: An int, the number of output rows.
        """
        totals = values.new_zeros((num_groups,) + tuple(values.shape[1:]))
        return totals.index_add_(0, index.to(values.device), values)

    def forward(self, x, y, mask=None):
        """Returns parameters of a multivariate Gaussian distribution at the k
        nearest inducing locations.

        :param x: A Tensor, input locations of shape [M, 1].
        :param y: A Tensor, observations of shape [M, in_dim].
        :param mask: A Tensor, the mask to apply to the observation data. A
        column of ones is appended for the distance feature before it is
        passed to the network.
        :return: A tuple (z_unique, mu, sigma) holding the inducing locations
        that were touched, as a [U, 1] Tensor, and the mean and standard
        deviation of the combined Gaussian factor at each of them.
        """
        assert len(y.shape) == 2, 'Inputs must be shape [M, y_dim]'
        assert x.shape[1] == 1, 'Only works on 1D inputs for now.'

        spacing = self.inducing_spacing
        num_pairs = y.shape[0] * self.k

        # Store the indices of inducing locations.
        z_indices = torch.zeros(num_pairs, dtype=torch.int64, device=x.device)
        # Stores the [observation, distance] pairs to pass through the network.
        inputs = torch.zeros(num_pairs, y.shape[1] + 1, device=y.device)

        # Find nearest k inducing locations for each datapoint.
        # Can exploit the fact that inducing locations are at fixed intervals.
        # Indices are relative to minimum inducing location, which is the
        # left edge of the window around the smallest input.
        z_min = ((x.min(0)[0] / spacing).round()
                 - (self.k - 1) / 2) * spacing

        # Grid point nearest to each input, and its index relative to z_min.
        z_mid = (x / spacing).round() * spacing
        idx_mid = ((z_mid - z_min) / spacing).round().squeeze(1)

        for i in range(self.k):
            # Offset, in grid steps, from the nearest grid point.
            j = i - self.k // 2
            # Store inducing point indices and distances. Rows are
            # interleaved so the k pairs of one observation are adjacent.
            z_indices[i::self.k] = idx_mid + j
            inputs[i::self.k, :] = torch.cat(
                [y, (z_mid - x) + j * spacing], dim=1)

        # Get means and variances.
        if mask is None:
            mu, sigma = self.network(inputs)
        else:
            # Repeat the mask per pair and mark the distance feature as
            # always observed.
            mask = mask.repeat_interleave(self.k, dim=0)
            observed = torch.ones(mask.shape[0], 1, device=mask.device)
            mask = torch.cat([mask, observed], dim=1)

            mu, sigma = self.network(inputs, mask)

        # Convert to natural parameters for easy summations.
        np_1 = mu / sigma ** 2
        np_2 = -1. / (2. * sigma ** 2)

        # Find unique inducing point indices.
        z_indices_unique, mapping_idxs = z_indices.unique(
            return_inverse=True)
        z_unique = z_indices_unique * spacing + z_min
        num_unique = len(z_indices_unique)

        # Sum the natural parameters for each unique inducing point.
        np_1 = self._sum_by_index(np_1, mapping_idxs, num_unique)
        np_2 = self._sum_by_index(np_2, mapping_idxs, num_unique)

        # Convert back to mean and standard deviation.
        sigma = (- 1. / (2. * np_2)) ** 0.5
        mu = np_1 * sigma ** 2.

        if len(z_unique.shape) == 1:
            z_unique = z_unique.unsqueeze(1)

        return z_unique, mu, sigma


class AlternativeFixedSparseNet(nn.Module):
    """A fully connected neural network for parameterising a Gaussian
    distribution over the nearest k inducing locations l(u|y).

    The network maps each observation to one (mu, sigma) pair per touched
    inducing location in a single pass; the input location is only used to
    decide which inducing locations those are.

    :param in_dim: An int, dimension of the input variable.
    :param out_dim: An int, dimension of the output variable.
    :param hidden_dims: A list, dimensions of hidden layers.
    :param inducing_spacing: A float, the spacing of inducing locations in
    latent space.
    :param k: An int, the number of inducing locations that each observation
    'touches'. Even values are rounded up to the next odd number.
    :param initial_sigma: A float, (approximately) sets the initial output
    variance.
    :param initial_mu: A float, (approximately) sets the initial output mean.
    :param min_sigma: A float, the minimum output sigma.
    :param nonlinearity: A function, the non-linearity to apply in between
    hidden layers.
    :param contains_nan: A bool, whether the observations potentially
    contains nans.
    """
    def __init__(self, in_dim, out_dim, hidden_dims=(64, 64),
                 inducing_spacing=.1, k=4, initial_sigma=1.,
                 initial_mu=0., min_sigma=0., nonlinearity=F.relu,
                 contains_nan=False):
        super().__init__()

        # Round up k to the nearest odd number so the window of touched
        # inducing locations is symmetric about the nearest one.
        self.k = (k // 2) * 2 + 1
        self.out_dim = out_dim
        self.inducing_spacing = inducing_spacing

        # NOTE(review): the network is asked for 2 * k * out_dim outputs but
        # forward only reads the first k * out_dim columns of each returned
        # tensor - confirm against the network classes whether the factor of
        # two is required.
        if contains_nan:
            # Use IndexNet for inputs that potentially
            # contain nans.
            middle_idx = len(hidden_dims) // 2
            dims_1 = hidden_dims[:middle_idx]
            dims_2 = hidden_dims[middle_idx:]
            middle_dim = hidden_dims[middle_idx]
            self.network = IndexNet(
                in_dim, 2 * self.k * out_dim, middle_dim, dims_1, dims_2,
                initial_sigma, initial_mu, min_sigma=min_sigma,
                nonlinearity=nonlinearity)
        else:
            self.network = LinearGaussian(
                in_dim, 2 * self.k * self.out_dim, hidden_dims,
                initial_sigma, initial_mu, min_sigma=min_sigma,
                nonlinearity=nonlinearity)

    @staticmethod
    def _sum_by_index(values, index, num_groups):
        """Sums the rows of values that share the same entry in index.

        :param values: A Tensor whose first dimension has one row per entry
        of index.
        :param index: An int64 Tensor of group ids in [0, num_groups).
        :param num_groups: An int, the number of output rows.
        """
        totals = values.new_zeros((num_groups,) + tuple(values.shape[1:]))
        return totals.index_add_(0, index.to(values.device), values)

    def forward(self, x, y, mask=None):
        """Returns parameters of a multivariate Gaussian distribution at the k
        nearest inducing locations.

        :param x: A Tensor, input locations of shape [M, 1].
        :param y: A Tensor, observations of shape [M, in_dim].
        :param mask: A Tensor, the mask to apply to the observation data.
        :return: A tuple (z, mu, sigma) holding the touched inducing
        locations in ascending order, as a [U, 1] Tensor, and the mean and
        standard deviation of the combined Gaussian factor at each of them,
        each of shape [U, out_dim].
        """
        assert len(y.shape) == 2, 'Inputs must be shape [M, in_dim]'
        assert x.shape[1] == 1, 'Only works on 1D inputs for now.'

        num_points = x.shape[0]
        width = self.k * self.out_dim

        # Grid index of the inducing location nearest to each input. Working
        # with integer indices rather than float locations guarantees that a
        # location shared by several observations compares equal below.
        # squeeze(1) (not squeeze()) keeps a single observation 1D.
        idx_mid = (x / self.inducing_spacing).round().squeeze(1).long()

        if mask is None:
            out = self.network(y)
        else:
            out = self.network(y, mask)

        # Output slot i of an observation belongs to the inducing location
        # (i - k // 2) grid steps below its nearest one, so slot 0 is the
        # right-most location in the window.
        offsets = torch.arange(self.k, device=idx_mid.device) - self.k // 2
        z_idxs = (idx_mid.unsqueeze(1) - offsets).flatten()

        # Lay the per-slot outputs out as one row per (observation, slot)
        # pair, in the same observation-major order as z_idxs.
        mu = out[0][:, :width].reshape(num_points * self.k, self.out_dim)
        sigma = out[1][:, :width].reshape(num_points * self.k, self.out_dim)

        # Convert to natural parameters.
        np_1 = mu / sigma ** 2
        np_2 = -1. / (2. * sigma ** 2)

        # Find unique inducing points (sorted ascending).
        z_idxs_unique, mapping_idxs = z_idxs.unique(
            sorted=True, return_inverse=True)
        num_unique = len(z_idxs_unique)

        # Sum the natural parameters for each unique inducing point.
        np_1 = self._sum_by_index(np_1, mapping_idxs, num_unique)
        np_2 = self._sum_by_index(np_2, mapping_idxs, num_unique)

        # Convert back to mean and standard deviation.
        sigma = (- 1. / (2. * np_2)) ** 0.5
        mu = np_1 * sigma ** 2.

        # Map grid indices back to locations.
        z = z_idxs_unique * self.inducing_spacing

        return z.unsqueeze(1), mu, sigma
