from typing import Union

import gpytorch as gp
import torch
from gpytorch_lattice_kernel import RBFLattice


class SimplexGPModel(gp.models.ExactGP):
    """Exact GP regression model with a constant mean and an ``RBFLattice`` kernel.

    The kernel is ``gpytorch_lattice_kernel.RBFLattice``; ``order`` is forwarded
    to it unchanged (its exact meaning is defined by that package).

    Args:
        train_x: Training inputs. When ``use_ard`` is True, ``train_x.size(-1)``
            is used as the number of ARD lengthscales.
        train_y: Training targets.
        order: Forwarded as ``RBFLattice(order=order)``.
        min_noise: Lower bound on the Gaussian likelihood noise, enforced with a
            ``gp.constraints.GreaterThan`` constraint.
        use_ard: If True, learn one lengthscale per input dimension.
        use_scale: If True, wrap the base kernel in a ``gp.kernels.ScaleKernel``
            so an output scale is learned as well.
    """

    def __init__(self, train_x, train_y, order=1, min_noise=1e-4, use_ard=True, use_scale=True):
        likelihood = gp.likelihoods.GaussianLikelihood(
            noise_constraint=gp.constraints.GreaterThan(min_noise)
        )
        super().__init__(train_x, train_y, likelihood)

        self.mean_module = gp.means.ConstantMean()
        if use_ard:
            self.base_covar_module = RBFLattice(ard_num_dims=train_x.size(-1), order=order)
        else:
            self.base_covar_module = RBFLattice(order=order)
        self.use_scale = use_scale
        # ``base_covar_module`` is kept as its own attribute so the lengthscale
        # can be read the same way with or without the ScaleKernel wrapper; it is
        # the same object as ``covar_module.base_kernel`` / ``covar_module``.
        if self.use_scale:
            self.covar_module = gp.kernels.ScaleKernel(self.base_covar_module)
        else:
            self.covar_module = self.base_covar_module

    def forward(self, x):
        """Return the GP prior at ``x`` as a ``MultivariateNormal``.

        Args:
            x: Input points to evaluate the mean and covariance modules on.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gp.distributions.MultivariateNormal(mean_x, covar_x)

    def get_noise(self) -> torch.Tensor:
        """Return the current likelihood noise as a detached CPU tensor.

        Detached so callers can log it or call ``.numpy()`` on it without the
        result being tied to (and keeping alive) the autograd graph.
        """
        return self.likelihood.noise_covar.noise.detach().cpu()

    def get_lengthscale(self) -> torch.Tensor:
        """Return the base kernel lengthscale(s) as a detached CPU tensor."""
        # base_covar_module is covar_module.base_kernel when use_scale is True and
        # covar_module itself otherwise, so no branch on use_scale is needed.
        return self.base_covar_module.lengthscale.detach().cpu()

    def get_outputscale(self) -> Union[torch.Tensor, float]:
        """Return the learned output scale.

        Returns:
            A detached CPU tensor holding the ScaleKernel ``outputscale`` when
            ``use_scale`` is True; the Python float ``1.`` otherwise, since an
            unscaled kernel is equivalent to an output scale of one.
        """
        if not self.use_scale:
            return 1.
        return self.covar_module.outputscale.detach().cpu()
    