import math
import torch
import gpytorch
from botorch.models.gpytorch import GPyTorchModel
from src.kernels import TemporalKernel, EmpiricalKernel, WienerKernel, TemporalKernelLearned


class BaseGPModel(gpytorch.models.ExactGP, GPyTorchModel):
    """Single-output exact GP over D spatial input columns plus one trailing time column.

    ``train_x`` is expected to be N x (D+1). The first D columns are used by ``spatial_kernel``,
    which is an RBF kernel by default or an ``EmpiricalKernel`` when ``empirical_kernel`` is given.
    The last column is used by ``temporal_kernel``. ``out_kernel`` wraps the spatial kernel in a
    ``ScaleKernel``. This class does not define ``forward``; subclasses decide how the kernels
    are combined.
    """
    _num_outputs = 1

    def __init__(self,
                 train_x: torch.Tensor,  # dimensions of train_x: N x (D+1)
                 train_y: torch.Tensor,  # dimensions of train_y: N x 1
                 lengthscale_hyperprior: gpytorch.priors.Prior = None,
                 lengthscale_constraint: gpytorch.constraints.Interval = None,
                 type_of_forgetting=None,
                 empirical_kernel=None,
                 forgetting_factor: float = 0.03,
                 prior_mean: float = 0.):
        """Build the likelihood, the constant mean and the spatial/temporal kernels.

        Args:
            train_x: N x (D+1) training inputs; the last column is the one given to the
                temporal kernel (``active_dims=self.D``).
            train_y: N x 1 training targets (squeezed to shape N), or None.
            lengthscale_hyperprior: optional prior on the RBF lengthscale.
            lengthscale_constraint: optional constraint on the RBF lengthscale.
            type_of_forgetting: ``'UI'`` selects ``WienerKernel``, ``'TVLearned'`` selects
                ``TemporalKernelLearned``; any other value selects ``TemporalKernel``.
            empirical_kernel: if given, it is wrapped in ``EmpiricalKernel`` and used instead
                of the RBF kernel.
            forgetting_factor: passed to the temporal kernel (as ``epsilon``, or as the Wiener
                variance for ``'UI'``).
            prior_mean: value of the constant mean, which is frozen (not trained).
        """
        # ExactGP expects single-output targets of shape N rather than N x 1.
        if train_y is not None:
            train_y = train_y.squeeze(-1)
        # All columns except the last are spatial; the last one is time.
        self.D = train_x.shape[-1] - 1

        # Lower bound on the noise for numerical stability.
        noise_constraint = gpytorch.constraints.GreaterThan(1e-6)
        likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_constraint=noise_constraint)

        super(BaseGPModel, self).__init__(train_x, train_y, likelihood)

        # specify model
        self.mean_module = gpytorch.means.ConstantMean()

        spatial_dims = tuple(range(self.D))
        if empirical_kernel is not None:
            # NOTE(review): lengthscale_hyperprior / lengthscale_constraint are not forwarded
            # to EmpiricalKernel; confirm this is intended.
            self.spatial_kernel = EmpiricalKernel(emp_kernel=empirical_kernel,
                                                  ard_num_dims=self.D,
                                                  active_dims=spatial_dims)
        else:
            self.spatial_kernel = gpytorch.kernels.RBFKernel(ard_num_dims=self.D,
                                                             active_dims=spatial_dims,
                                                             lengthscale_prior=lengthscale_hyperprior,
                                                             lengthscale_constraint=lengthscale_constraint)

        if type_of_forgetting == 'UI':  # wiener process kernel
            self.forgetting_factor = forgetting_factor
            sigma_w_squared = self.forgetting_factor
            # c0 = -1 / sigma_w^2 (about -33.3 for the default forgetting_factor of 0.03).
            # NOTE(review): presumably the start time of the Wiener process; confirm in WienerKernel.
            c0 = - 1 / sigma_w_squared
            self.temporal_kernel = WienerKernel(c0=c0,
                                                sigma_hat_squared=sigma_w_squared,
                                                active_dims=self.D)
        elif type_of_forgetting == 'TVLearned':  # temporal kernel with learned forgetting
            self.temporal_kernel = TemporalKernelLearned(epsilon=forgetting_factor,
                                                         active_dims=self.D)
        else:
            self.temporal_kernel = TemporalKernel(epsilon=forgetting_factor,
                                                  active_dims=self.D)

        # NOTE(review): the lengthscale is set to the prior mean only when an empirical kernel
        # is used. With the default RBF kernel, a given hyperprior does not change the initial
        # lengthscale. Confirm this condition is intended.
        if lengthscale_hyperprior is not None and empirical_kernel is not None:
            self.spatial_kernel.lengthscale = lengthscale_hyperprior.mean

        if prior_mean != 0:
            self.mean_module.initialize(constant=prior_mean)
        # Freeze the constant mean so it is excluded from gradient-based training.
        self.mean_module.constant.requires_grad = False

        # The output scale starts at 1 (it is not tied to a prior here).
        self.out_kernel = gpytorch.kernels.ScaleKernel(self.spatial_kernel)
        self.out_kernel.outputscale = 1.0

    def get_max_Kxx_dx2(self):
        """Return ``1 / lengthscale**2`` for each of the D spatial dimensions.

        For an RBF kernel with unit output scale, this is the value of the second derivative
        ``d^2 k(x, x') / dx_i dx'_i`` at ``x = x'``, which is its maximum over the domain.

        Returns:
            Tensor of shape (D,), on the lengthscale's device and dtype. The lengthscale is
            detached, so no gradient flows through it.

        NOTE(review): the output scale is taken as 1 (the value set in ``__init__``); a
        trained ``out_kernel.outputscale`` is ignored here.
        """
        lengthscale = self.spatial_kernel.lengthscale[0].detach()
        # Allocate on the lengthscale's device/dtype so this also works for models on a GPU.
        ones = torch.ones(self.D, dtype=lengthscale.dtype, device=lengthscale.device)
        return ones / lengthscale ** 2

    def approximate_Lf(self, r, delta_L):
        """Upper-bound the Lipschitz constant of a sample from a zero-mean GP with this kernel.

        Follows Lederer et al. (2019). For each spatial dimension i:

            L_i = sqrt(2 log(4 D / delta_L)) * max_dk_i
                  + 12 sqrt(6 D) * max(max_dk_i, sqrt(r * dLk_i))

        Here ``max_dk_i = sqrt(get_max_Kxx_dx2()_i)`` and ``dLk_i = get_dLk()_i``. The result
        is the Euclidean norm of (L_1, ..., L_D).

        Args:
            r: presumably the diameter of the (compact) input domain; TODO confirm with callers.
            delta_L: failure probability of the bound.

        Returns:
            0-dim tensor with the Lipschitz bound.

        NOTE(review): the log term uses ``4 * D``; confirm this constant against the reference.
        """
        max_dk = torch.sqrt(self.get_max_Kxx_dx2())
        dLks = self.get_dLk()

        # Concentration (Borell-TIS) term. It scales with the maximal std of the derivative
        # process, so it must be multiplied by max_dk rather than added as a bare constant.
        term1 = math.sqrt(2 * math.log(4 * self.D / delta_L)) * max_dk
        # Expected-supremum (chaining) term.
        term2 = 12 * math.sqrt(6 * self.D) * torch.max(max_dk, torch.sqrt(r * dLks))
        return torch.linalg.vector_norm(term1 + term2)

    def get_dLk(self):
        """Return the per-dimension constant used as the derivative-kernel Lipschitz term.

        ``approximate_Lf`` uses this value as ``dLk`` in its bound. It computes
        ``(sqrt(6) - 2) / l**2 * exp(-(3 - sqrt(6)) / 2)``, where ``l`` is the detached
        spatial lengthscale and the output scale is taken as 1.

        NOTE(review): for the RBF derivative kernel (1/l^2)(1 - s^2) exp(-s^2/2), s = delta/l,
        the slope w.r.t. delta scales as 1/l^3, while this expression scales as 1/l^2.
        Verify the closed form against the reference derivation.
        """
        lengthscales = self.spatial_kernel.lengthscale[0].clone().detach()
        return (math.sqrt(6) - 2) / (lengthscales ** 2) * (1 * math.exp(-1 / 2 * (3 - math.sqrt(6))))


class TimeInvariantGP(BaseGPModel):
    """GP whose prior covariance is the scaled spatial kernel alone.

    ``forward`` uses only ``out_kernel``; the temporal kernel built by the base class is not
    part of the covariance.
    """

    def forward(self, x):
        """Return the prior at ``x``: constant mean with ``out_kernel`` covariance."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.out_kernel(x),
        )


class TimeVariantGP(BaseGPModel):
    """GP whose prior covariance is the scaled spatial kernel times the temporal kernel."""

    def forward(self, x):
        """Return the prior at ``x``: constant mean with product-kernel covariance."""
        mean = self.mean_module(x)
        k_time = self.temporal_kernel(x)
        k_space = self.out_kernel(x)
        # Elementwise product: spatial similarity modulated by temporal similarity.
        return gpytorch.distributions.MultivariateNormal(mean, k_space * k_time)


class ExactGPModel(gpytorch.models.ExactGP, GPyTorchModel):
    """Single-output exact GP: constant mean and a scaled (non-ARD) RBF kernel on all inputs."""
    # BoTorch's GPyTorchModel.num_outputs property returns ``self._num_outputs`` (the attribute
    # BaseGPModel sets). Define it here as well; ``num_outputs`` is kept for existing callers.
    _num_outputs = 1
    num_outputs = 1

    def __init__(self, train_x, train_y):
        """Build the model.

        Args:
            train_x: training inputs; the RBF kernel uses every column (``self.D`` of them).
            train_y: training targets, passed to ``ExactGP`` unchanged (unlike ``BaseGPModel``,
                no squeeze is applied).
        """
        likelihood = gpytorch.likelihoods.GaussianLikelihood()

        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)

        self.D = train_x.shape[-1]
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        # Alias to the RBF kernel so lengthscale-based helpers can read it the same way as in BaseGPModel.
        self.spatial_kernel = self.covar_module.base_kernel

    def forward(self, x):
        """Return the prior at ``x``: constant mean with scaled RBF covariance."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def get_max_Kxx_dx2(self):
        """Return ``1 / lengthscale**2`` repeated for each of the D input dimensions.

        For an RBF kernel with unit output scale, this is the value of the second derivative
        ``d^2 k(x, x') / dx_i dx'_i`` at ``x = x'``. The kernel here is not ARD, so all D
        entries are equal.

        Returns:
            Tensor of shape (D,), on the lengthscale's device and dtype. The lengthscale is
            detached, so no gradient flows through it.

        NOTE(review): the output scale of ``covar_module`` is taken as 1 and ignored here.
        """
        lengthscale = self.spatial_kernel.lengthscale[0].detach()
        # Allocate on the lengthscale's device/dtype so this also works for models on a GPU.
        ones = torch.ones(self.D, dtype=lengthscale.dtype, device=lengthscale.device)
        return ones / lengthscale ** 2