# Gaussian Processes
import gpytorch
import torch
import numpy as np

from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution,\
    NaturalVariationalDistribution, TrilNaturalVariationalDistribution
from models.base_approx_gp import BaseApproximateGPModel


class ApproximateGPModel(BaseApproximateGPModel):
    """Single-output variational GP using a Cholesky variational distribution.

    The model places a ``CholeskyVariationalDistribution`` over the inducing
    values and wraps it in a ``VariationalStrategy`` whose inducing locations
    are learned (``learn_inducing_locations=True``). The prior consists of a
    constant mean and the supplied ``kernel`` wrapped in a ``ScaleKernel``.

    Args:
        train_x: Training inputs; stored on the instance as ``self.train_x``.
        train_y: Training targets; stored on the instance as ``self.train_y``.
        kernel: Base kernel that gets wrapped in ``gpytorch.kernels.ScaleKernel``.
        likelihood: Likelihood object; stored on the instance as ``self.likelihood``.
        inducing_points: Tensor of initial inducing locations; its first
            dimension determines the number of inducing points.
    """

    def __init__(self, train_x, train_y, kernel, likelihood, inducing_points):
        # One variational parameter set per inducing point.
        num_inducing = inducing_points.size(0)
        q_u = CholeskyVariationalDistribution(num_inducing)

        # The strategy needs a reference to the model, so it is built before
        # the parent constructor runs and is handed to it afterwards.
        strategy = gpytorch.variational.VariationalStrategy(
            self,
            inducing_points,
            q_u,
            learn_inducing_locations=True,
        )
        super().__init__(strategy)

        self.likelihood = likelihood
        self.train_x = train_x
        self.train_y = train_y

        # Prior: constant mean + scaled version of the user-supplied kernel.
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(kernel)

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at inputs ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
    
    
class ApproximateGPModel_NGD(BaseApproximateGPModel):
    """Single-output variational GP using a natural-parameter variational distribution.

    Identical in structure to a standard variational GP (constant mean,
    ``ScaleKernel``-wrapped ``kernel``, learned inducing locations) except
    that the variational distribution is a
    ``TrilNaturalVariationalDistribution``. The ``_NGD`` suffix presumably
    stands for natural gradient descent -- confirm against the training code.

    Args:
        train_x: Training inputs; stored on the instance as ``self.train_x``.
        train_y: Training targets; stored on the instance as ``self.train_y``.
        kernel: Base kernel that gets wrapped in ``gpytorch.kernels.ScaleKernel``.
        likelihood: Likelihood object; stored on the instance as ``self.likelihood``.
        inducing_points: Tensor of initial inducing locations; its first
            dimension determines the number of inducing points.
    """

    def __init__(self, train_x, train_y, kernel, likelihood, inducing_points):
        # NOTE (original author): this distribution produces nan w/ logPred.
        # Alternative that was tried:
        #   NaturalVariationalDistribution(inducing_points.size(0), mean_init_std=0.001)
        variational_distribution = TrilNaturalVariationalDistribution(inducing_points.size(0))

        # The strategy needs a reference to the model, so it is built before
        # the parent constructor runs and is handed to it afterwards.
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)

        # Assigned exactly once (the previous version redundantly assigned
        # ``self.likelihood`` a second time at the end of ``__init__``).
        self.likelihood = likelihood
        self.train_x = train_x
        self.train_y = train_y

        # Prior: constant mean + scaled version of the user-supplied kernel.
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(kernel)

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at inputs ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


# The most general-purpose multitask model is the Linear Model of Coregionalization (LMC),
# which assumes that each output dimension (task) is a linear combination of some latent functions.
# https://docs.gpytorch.ai/en/latest/examples/04_Variational_and_Approximate_GPs/SVGP_Multitask_GP_Regression.html
class MultitaskApproximateGPModel(BaseApproximateGPModel):
    """Multitask variational GP built on an LMC variational strategy.

    ``num_latents`` batched latent GPs (each with its own inducing points,
    variational distribution, constant mean and scaled RBF kernel) are
    combined by ``LMCVariationalStrategy`` into ``num_tasks`` outputs.

    Args:
        train_x: Training inputs; stored on the instance as ``self.train_x``.
        train_y: Training targets; stored on the instance as ``self.train_y``.
        kernel: NOTE(review): currently unused -- the covariance is always a
            batched ``ScaleKernel(RBFKernel)``. Kept for signature parity
            with the single-task models.
        likelihood: Likelihood object; stored on the instance as ``self.likelihood``.
        inducing_points: Tensor whose first dimension gives the number of
            inducing points per latent and, when it has more than one
            dimension, whose last dimension gives the input feature
            dimension. Its values are NOT used: inducing locations are
            re-initialised uniformly at random in ``[0, 1)``.
        num_latents: Number of latent functions.
        num_tasks: Number of output tasks.
    """

    def __init__(self, train_x, train_y, kernel, likelihood, inducing_points, num_latents, num_tasks):
        latent_batch = torch.Size([num_latents])

        # Learn a separate set of inducing points for each latent function:
        # shape [num_latents, num_inducing_points, input_dim].
        # The feature dimension is taken from the supplied inducing points
        # instead of being hard-coded to 1, so multi-dimensional inputs work;
        # a 1-D tensor of inducing points keeps the previous input_dim of 1.
        num_inducing = inducing_points.size(0)
        input_dim = inducing_points.size(-1) if inducing_points.dim() > 1 else 1
        inducing_points = torch.rand(num_latents, num_inducing, input_dim)

        # The CholeskyVariationalDistribution is marked as batch so that a
        # separate variational distribution is learned for each latent
        # (implementation-related solution).
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=latent_batch
        )

        # The VariationalStrategy is wrapped in an LMCVariationalStrategy so
        # that the output is a MultitaskMultivariateNormal rather than a
        # batch output (implementation-related solution).
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )

        super().__init__(variational_strategy)

        self.likelihood = likelihood
        self.train_x = train_x
        self.train_y = train_y

        # The mean and covariance modules are marked as batch so that a
        # different set of hyperparameters is learned for each latent.
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=latent_batch)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=latent_batch),
            batch_shape=latent_batch
        )

    def forward(self, x):
        """Return the batched latent prior ``MultivariateNormal`` at inputs ``x``.

        Written as if each latent function were handled in batch; the
        wrapping LMC strategy is what produces the multitask output.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


"""
Not working..
class HeteroskedasticGPModel(ExactGPModel):
    ""
    An Exact GP model using heteroskedastic noise.

    This model internally wraps another GP (an ExactGPModel) to model the
    observation noise. This allows the likelihood to make out-of-sample
    predictions for the observation noise levels. (Inspiration: BoTorch)
    ""
    def __init__(self, train_x, train_y, train_y_var, kernel, likelihood=None, num_tasks=1):
        #self._validate_tensor_args(x=train_x, y=train_y, yvar=train_yvar)
        #validate_input_scaling(train_x=train_x, train_y=train_y, train_yvar=train_yvar)
        #self._set_dimensions(train_x=train_x, train_y=train_y)
        noise_likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_prior=gpytorch.priors.SmoothedBoxPrior(-3, 5, 0.5, transform=torch.log),
            noise_constraint=gpytorch.constraints.GreaterThan(1e-4, transform=None, initial_value=1.0)
        )
        noise_model = ExactGPModel(
            train_x=train_x,
            train_y=torch.log(train_y_var),  # ensure positive noise
            kernel=kernel,
            likelihood=noise_likelihood,
            num_tasks=num_tasks
        )
        # This is where all the magic happens
        likelihood = gpytorch.likelihoods.gaussian_likelihood._GaussianLikelihoodBase(
            gpytorch.likelihoods.noise_models.HeteroskedasticNoise(noise_model))
        super().__init__(
            train_x=train_x,
            train_y=train_y,
            kernel=kernel,
            likelihood=likelihood,
            num_tasks=num_tasks
        )
        self.register_added_loss_term("noise_added_loss")
        self.update_added_loss_term(
            "noise_added_loss", gpytorch.mlls.NoiseModelAddedLossTerm(noise_model)
        )
        self.to(train_x)
"""
