# Gaussian process models with exact inference (built on GPyTorch).
import gpytorch
import torch
from models.base_exact_gp import BaseExactGPModel


class ExactGPModel(BaseExactGPModel):
    """
    Single-task (single-output) GP model with exact inference.

    The mean module is a ``ZeroMean`` and the supplied ``kernel`` is used
    directly as the covariance module. It is *not* wrapped in a
    ``ScaleKernel`` here, so any output scaling has to be part of the kernel
    the caller passes in.

    Args:
        train_x: Training inputs, forwarded unchanged to ``BaseExactGPModel``.
        train_y: Training targets, forwarded unchanged to ``BaseExactGPModel``.
        kernel: Object stored as ``covar_module`` and called on the inputs in
            ``forward`` (presumably a ``gpytorch.kernels.Kernel``; not
            validated here).
        likelihood: Likelihood, forwarded unchanged to ``BaseExactGPModel``.
    """

    def __init__(self, train_x, train_y, kernel, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # Zero mean: this module contributes no learnable mean parameters.
        self.mean_module = gpytorch.means.ZeroMean()
        # Use the caller's kernel as-is (no ScaleKernel wrapper).
        self.covar_module = kernel

    def forward(self, x):
        """Evaluate the mean and covariance modules at ``x``.

        Returns:
            A ``gpytorch.distributions.MultivariateNormal`` built from the
            two evaluations.
        """
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


class FancyGPWithPriors(BaseExactGPModel):
    """
    Single-output exact GP with a zero mean and an ARD RBF kernel whose
    lengthscales carry a log-normal prior.

    The lengthscale prior is ``LogNormalPrior(0, sqrt(3))``. The RBF kernel
    is used directly (no ``ScaleKernel`` wrapper), and the lengthscale is
    left at the kernel's own default initial value.

    Background on the choice of priors:
        https://botorch.org/api/_modules/botorch/models/gp_regression.html#SingleTaskGP
        https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Priors

    Args:
        train_x: Training inputs (a tensor), forwarded to
            ``BaseExactGPModel``. The size of its last axis is used as the
            number of ARD lengthscales; a 1-D tensor is treated as having a
            single feature.
        train_y: Training targets, forwarded to ``BaseExactGPModel``.
        likelihood: Likelihood, forwarded to ``BaseExactGPModel``.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()

        # Float literal: keeps the prior scale floating point without relying
        # on torch.sqrt promoting an integer tensor.
        lengthscale_prior = gpytorch.priors.LogNormalPrior(
            0, torch.sqrt(torch.tensor([3.0]))
        )

        # One lengthscale per input feature. Features sit on the last axis,
        # so index from the end (shape[1] would pick the sample axis for
        # batched inputs). A 1-D input has no feature axis and counts as one
        # feature -- GPyTorch's ExactGP unsqueezes such inputs; assumes
        # BaseExactGPModel follows that behaviour (TODO confirm).
        num_features = train_x.shape[-1] if train_x.dim() > 1 else 1
        self.covar_module = gpytorch.kernels.RBFKernel(
            lengthscale_prior=lengthscale_prior,
            ard_num_dims=num_features,
        )

    def forward(self, x):
        """Evaluate the mean and covariance modules at ``x``.

        Returns:
            A ``gpytorch.distributions.MultivariateNormal`` built from the
            two evaluations.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class MultitaskGPModel(BaseExactGPModel):
    """
    A multitask (multi-output) GP model with exact inference.

    The mean is a ``MultitaskMean`` wrapping a ``ConstantMean``; the
    covariance is a ``MultitaskKernel`` wrapping the supplied ``kernel``.

    Args:
        train_x: Training inputs, forwarded to ``BaseExactGPModel``.
        train_y: Training targets, forwarded to ``BaseExactGPModel``.
        kernel: Data kernel handed to ``MultitaskKernel`` (presumably a
            ``gpytorch.kernels.Kernel``; not validated here).
        likelihood: Likelihood, forwarded to ``BaseExactGPModel``.
        num_tasks: Number of tasks, passed to both the multitask mean and
            the multitask kernel. Defaults to 2.
        rank: ``rank`` argument passed to ``MultitaskKernel``. Defaults to 1,
            the value that was previously hard-coded.
    """

    def __init__(self, train_x, train_y, kernel, likelihood, num_tasks=2, rank=1):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=num_tasks
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            kernel, num_tasks=num_tasks, rank=rank
        )

    def forward(self, x):
        """Evaluate the mean and covariance modules at ``x``.

        Returns:
            A ``gpytorch.distributions.MultitaskMultivariateNormal`` built
            from the two evaluations.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


class BatchIndependentMultitaskGPModel(BaseExactGPModel):
    """
    Multi-output exact GP built from a batch of GPs.

    The mean (``ConstantMean``) and the covariance (``ScaleKernel`` around an
    ARD ``RBFKernel``) are created with ``batch_shape = [batch_size]``.
    ``forward`` builds a batched ``MultivariateNormal`` and converts it to a
    ``MultitaskMultivariateNormal`` via ``from_batch_mvn``.

    Priors: ``Gamma(3.0, 6.0)`` on the lengthscale and ``Gamma(2.0, 0.15)``
    on the output scale.

    Args:
        train_x: Training inputs (a tensor), forwarded to
            ``BaseExactGPModel``. The size of its last axis is used as the
            number of ARD lengthscales; a 1-D tensor is treated as having a
            single feature.
        train_y: Training targets, forwarded to ``BaseExactGPModel``.
        likelihood: Likelihood, forwarded to ``BaseExactGPModel``.
        batch_size: Size of the batch dimension of the mean and kernel
            modules (presumably the number of outputs -- TODO confirm against
            callers). Defaults to 2.
    """

    def __init__(self, train_x, train_y, likelihood, batch_size=2):
        super().__init__(train_x, train_y, likelihood)

        lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 6.0)
        outputscale_prior = gpytorch.priors.GammaPrior(2.0, 0.15)

        # Shared by the mean, the RBF kernel and the scale kernel.
        batch_shape = torch.Size([batch_size])

        # One lengthscale per input feature. Features sit on the last axis,
        # so index from the end (shape[1] would pick the sample axis for
        # batched inputs). A 1-D input has no feature axis and counts as one
        # feature -- GPyTorch's ExactGP unsqueezes such inputs; assumes
        # BaseExactGPModel follows that behaviour (TODO confirm).
        num_features = train_x.shape[-1] if train_x.dim() > 1 else 1

        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(
                batch_shape=batch_shape,
                lengthscale_prior=lengthscale_prior,
                ard_num_dims=num_features,
            ),
            batch_shape=batch_shape,
            outputscale_prior=outputscale_prior,
        )

        # Fixed initial hyperparameter values. These are deliberately NOT the
        # prior means (Gamma(3, 6) has mean 0.5; Gamma(2, 0.15) has mean
        # ~13.3).
        self.covar_module.base_kernel.lengthscale = .1
        self.covar_module.outputscale = 1

    def forward(self, x):
        """Evaluate the mean and covariance modules at ``x``.

        Returns:
            A ``gpytorch.distributions.MultitaskMultivariateNormal`` obtained
            by passing the batched ``MultivariateNormal`` through
            ``from_batch_mvn``.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn(
            gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        )
