# System/Library imports
import logging
from typing import *

# Common data science imports
import torch

# GPytorch
import gpytorch
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel, GridInterpolationKernel
from gpytorch.distributions import MultivariateNormal


class SKIP(gpytorch.models.ExactGP):
    """Exact GP with a product of one-dimensional grid-interpolated RBF kernels.

    The covariance is a ``ProductStructureKernel`` over a ``ScaleKernel``
    wrapping a 1-D ``GridInterpolationKernel`` whose base kernel is an
    ``RBFKernel``; the mean is a ``ConstantMean``.
    """

    def __init__(
        self,
        train_x: torch.Tensor,
        train_y: torch.Tensor,
        likelihood,
        grid_size: int = -1,
        use_ard: bool = True,
    ) -> None:
        """Build the mean and covariance modules.

        Args:
            train_x: Training inputs; the size of the last dimension is used
                as the number of input dimensions.
            train_y: Training targets, forwarded to ``ExactGP``.
            likelihood: Likelihood forwarded to ``ExactGP``.
            grid_size: Number of grid points per input dimension. ``-1``
                (default) selects it automatically from ``train_x``.
            use_ard: If true, the RBF base kernel is created with
                ``ard_num_dims`` equal to the number of input dimensions.

        Raises:
            ValueError: If the (given or chosen) grid size is ``<= 2``.
        """
        super().__init__(train_x, train_y, likelihood)

        self.mean_module = ConstantMean()

        num_input_dims = train_x.shape[-1]

        # Define a grid for inducing points
        if grid_size == -1:
            # Every interpolation grid here is one-dimensional (product
            # decomposition), so the grid size must not be shrunk by the
            # Kronecker rule n ** (1 / d); GPyTorch documents
            # kronecker_structure=False for additive/product priors.
            grid_size = gpytorch.utils.grid.choose_grid_size(
                train_x, kronecker_structure=False
            )
        logging.getLogger(__name__).debug(
            "SKIP: input dims=%s, grid_size=%s", num_input_dims, grid_size
        )
        if grid_size <= 2:
            raise ValueError(f"Degenerate grid size {grid_size} ...")

        # NOTE(review): the base kernel is evaluated on 1-D projections
        # (num_dims=1 below); confirm that ard_num_dims=D yields the intended
        # per-dimension lengthscales in the GPyTorch version in use.
        if use_ard:
            self.base_covar_module = RBFKernel(ard_num_dims=num_input_dims)
        else:
            self.base_covar_module = RBFKernel()

        self.covar_module = gpytorch.kernels.ProductStructureKernel(
            ScaleKernel(
                GridInterpolationKernel(
                    self.base_covar_module,
                    grid_size=grid_size,
                    num_dims=1,
                )
            ),
            num_dims=num_input_dims,
        )

    def forward(self, x: torch.Tensor) -> MultivariateNormal:
        """Return the prior ``MultivariateNormal`` (mean and covariance) at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def get_noise(self) -> torch.Tensor:
        """Return the likelihood noise as a CPU tensor.

        Reads ``self.likelihood.noise_covar.noise``, so the likelihood must
        expose a ``noise_covar`` module. The tensor is not detached.
        """
        return self.likelihood.noise_covar.noise.cpu()

    def get_lengthscale(self) -> torch.Tensor:
        """Return the RBF base kernel's lengthscale as a CPU tensor (not detached)."""
        return self.base_covar_module.lengthscale.cpu()

    def get_outputscale(self) -> torch.Tensor:
        """Return the ``ScaleKernel`` outputscale as a CPU tensor (not detached)."""
        return self.covar_module.base_kernel.outputscale.cpu()
