# System/Library imports
from typing import *

# Common data science imports
import torch

# GPytorch
import gpytorch
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel, GridInterpolationKernel
from gpytorch.distributions import MultivariateNormal


class SKIGPModel(gpytorch.models.ExactGP):
    """Exact GP whose covariance is a scaled, grid-interpolated RBF kernel.

    The covariance module is built as
    ``ScaleKernel(GridInterpolationKernel(RBFKernel))`` and the mean is a
    ``ConstantMean``.
    """

    def __init__(self, train_x, train_y, likelihood, grid_size=-1, use_ard=True):
        """Build the mean and covariance modules.

        Args:
            train_x: Training inputs; the size of the last dimension is used
                as the number of input dimensions.
            train_y: Training targets, forwarded to ``ExactGP``.
            likelihood: Likelihood object, forwarded to ``ExactGP``.
            grid_size: Grid size passed to ``GridInterpolationKernel``. The
                sentinel ``-1`` means "pick one with
                ``gpytorch.utils.grid.choose_grid_size(train_x)``".
            use_ard: If true, the RBF kernel gets one lengthscale per input
                dimension (``ard_num_dims``); otherwise a single lengthscale.

        Raises:
            ValueError: If the (given or chosen) grid size is ``<= 2``.
        """
        super().__init__(train_x, train_y, likelihood)

        num_dims = train_x.shape[-1]

        if grid_size == -1:
            grid_size = gpytorch.utils.grid.choose_grid_size(train_x)

        # Validate before building any sub-modules so a bad grid fails fast.
        if grid_size <= 2:
            raise ValueError(f"Degenerate grid size {grid_size} ...")

        self.mean_module = gpytorch.means.ConstantMean()

        print("SHAPE", train_x.shape[-1], "GRID_SIZE", grid_size)

        # Base kernel: ARD gives every input dimension its own lengthscale.
        if use_ard:
            base_kernel = RBFKernel(ard_num_dims=num_dims)
        else:
            base_kernel = RBFKernel()

        # NOTE(review): the fixed (-0.1, 1.1) bounds presumably expect inputs
        # normalised to [0, 1] with a small margin -- confirm with callers.
        self.covar_module = ScaleKernel(
            GridInterpolationKernel(
                base_kernel,
                grid_size=grid_size,
                grid_bounds=[(-0.1, 1.1) for _ in range(num_dims)],
                num_dims=num_dims,
            )
        )

    def forward(self, x):
        """Return the ``MultivariateNormal`` given by the mean and covariance at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def get_noise(self) -> torch.Tensor:
        """Return ``likelihood.noise_covar.noise`` moved to the CPU.

        NOTE(review): assumes the likelihood exposes ``noise_covar`` (as a
        Gaussian likelihood does) -- confirm for other likelihoods.
        """
        return self.likelihood.noise_covar.noise.cpu()

    def get_lengthscale(self) -> torch.Tensor:
        """Return the RBF kernel's lengthscale moved to the CPU."""
        # covar_module is ScaleKernel -> GridInterpolationKernel -> RBFKernel,
        # so the RBF kernel sits two ``base_kernel`` hops down.
        return self.covar_module.base_kernel.base_kernel.lengthscale.cpu()

    def get_outputscale(self) -> torch.Tensor:
        """Return the ``ScaleKernel`` output scale moved to the CPU."""
        return self.covar_module.outputscale.cpu()
