import gpytorch
from models.base_fbgp import BaseFBGP_gpytorch
import torch


class FBGP_gpytorch(BaseFBGP_gpytorch):
    """GP model with a zero mean and an ARD RBF kernel with a log-normal lengthscale prior.

    The kernel is a bare ``RBFKernel`` (not wrapped in a ``ScaleKernel``), so
    there is no learnable output scale; only the per-dimension lengthscales
    are parameterised, and they carry ``self.length_prior``.

    Args:
        args: Configuration object; stored as ``self.args`` and not otherwise
            used in this class.
        train_x: Training inputs. The last dimension is treated as the feature
            dimension; a 1-D tensor is treated as a single feature.
        train_y: Training targets.
        likelihood: GPyTorch likelihood, passed through to
            ``BaseFBGP_gpytorch.__init__``.
    """

    def __init__(self, args, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)

        # Kept as plain attributes for code that reads them directly
        # (presumably the base class or training loop -- verify before removing).
        self.train_x = train_x
        self.train_y = train_y
        self.args = args

        self.mean_module = gpytorch.means.ZeroMean()

        # Lengthscale prior LogNormal(0, sqrt(3)), i.e. log(lengthscale) ~ N(0, 3).
        # References for the prior choice:
        #   https://botorch.org/api/_modules/botorch/models/gp_regression.html#SingleTaskGP
        #   https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Priors
        # Alternatives tried previously: GammaPrior(3.0, 6.0) or GammaPrior(1, 1) on
        # the lengthscale, and a ScaleKernel with GammaPrior(2.0, 0.15) on the
        # outputscale.
        # The scale stays a 1-element tensor (same value and shape as before); the
        # float literal just avoids relying on sqrt's integer->float promotion.
        self.length_prior = gpytorch.priors.LogNormalPrior(0, torch.sqrt(torch.tensor([3.0])))

        # One lengthscale per input feature (ARD). Read the feature count from the
        # last dimension so batched (..., n, d) inputs work, and treat 1-D inputs
        # (n,) as a single feature: GPyTorch's ExactGP unsqueezes those to (n, 1)
        # (assuming the base class derives from ExactGP), whereas the previous
        # ``train_x.shape[1]`` raised IndexError for them.
        num_dims = train_x.shape[-1] if train_x.dim() > 1 else 1
        self.covar_module = gpytorch.kernels.RBFKernel(
            lengthscale_prior=self.length_prior,
            ard_num_dims=num_dims,
        )

    def forward(self, x):
        """Return the latent (prior) distribution at ``x``: zero mean, RBF covariance."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


