import GPy
import numpy as np
import scipy.stats as stats
from sklearn.metrics import mean_squared_error

class GP:
    """Thin wrapper around a GPy Gaussian-process regression model.

    The constructor fits (or restores) a homoscedastic
    ``GPy.models.GPRegression`` and keeps its kernel and Gaussian noise
    variance. The remaining methods reuse that kernel to score other
    datasets, possibly with per-point (heteroscedastic) training noise.
    """

    def __init__(self, X, y, kernel=None, noise=None):
        """Build the regression model, optimising it unless already trained.

        Parameters
        ----------
        X : array-like
            Training inputs; ``X.shape[1]`` is used as the input dimension of
            the default kernel.
        y : array-like
            Training targets, passed straight to ``GPy.models.GPRegression``.
        kernel : GPy kernel, optional
            Kernel to use. Defaults to an ARD RBF kernel with unit variance
            and unit lengthscales.
        noise : float, optional
            Gaussian noise variance. When given it is fixed; otherwise it is
            initialised to 1.0 and learned by ``optimize()``.

        When both ``kernel`` and ``noise`` are supplied (e.g. from
        ``from_dict``) the model is treated as already trained and
        ``optimize()`` is skipped.
        """
        # Explicit None checks: truthiness of a kernel object, of a zero noise,
        # or of an array-valued noise is ambiguous or misleading.
        is_trained = kernel is not None and noise is not None
        if kernel is None:
            kernel = GPy.kern.RBF(input_dim=X.shape[1], variance=1.0, lengthscale=1.0, ARD=True)
        gpr = GPy.models.GPRegression(X, y, kernel)
        if noise is None:
            gpr.Gaussian_noise.variance = 1.0
        else:
            gpr.Gaussian_noise.variance = noise
            gpr.Gaussian_noise.variance.fix()
        if not is_trained:
            gpr.optimize()
        self.kern = gpr.kern  # kernel of the fitted model (shared object)
        self.noise = gpr.Gaussian_noise.variance.values[0]  # learned/fixed noise variance
        self.gpr = gpr
        self.X = X
        self.y = y

    def to_dict(self):
        """Return a dict with the kernel (via ``kern.to_dict()``), noise, X and y.

        ``X`` and ``y`` are stored as-is (no conversion to plain lists).
        """
        return {
            'kern': self.kern.to_dict(),
            'noise': self.noise,
            'X': self.X,
            'y': self.y,
        }

    @classmethod
    def from_dict(cls, gp_dict):
        """Rebuild an instance from ``to_dict()`` output without re-optimising."""
        kern = GPy.kern.Kern.from_dict(gp_dict['kern'])
        noise = gp_dict['noise']
        return cls(gp_dict['X'], gp_dict['y'], kern, noise)

    def mi(self, X, noise=None):
        """Return ``0.5 * log det(I + K(X, X) diag(1 / noise))``.

        This is the usual Gaussian-channel information-gain expression (the
        method name suggests mutual information; confirm with callers).

        Parameters
        ----------
        X : array-like
            Points at which the kernel matrix is evaluated.
        noise : scalar or array-like, optional
            Noise variance: a scalar, or one value per row of ``X`` with shape
            ``(n,)`` or ``(n, 1)``. Defaults to ``self.noise``.
        """
        n = len(X)
        if noise is None:
            noise = self.noise
        noise = np.broadcast_to(np.asarray(noise, dtype=float).reshape(-1), (n,))
        # K @ diag(1/noise) scales column j by 1/noise_j; broadcasting over the
        # last axis does the same without building a dense n x n diagonal.
        _, logdet = np.linalg.slogdet(np.eye(n) + self.kern.K(X, X) / noise)
        return 0.5 * logdet

    @staticmethod
    def _as_column(values):
        """Return ``values`` as a float array, turning a 1-D vector into a column.

        Scalars and 2-D arrays are returned unchanged (as arrays), so they
        broadcast row-wise against ``(n, 1)`` / ``(n, d)`` quantities instead
        of forming an ``(n, n)`` outer broadcast.
        """
        arr = np.asarray(values, dtype=float)
        return arr[:, None] if arr.ndim == 1 else arr

    @staticmethod
    def _mean_per_output(values):
        """Average over samples: a scalar for single-output data, else per column."""
        values = np.asarray(values)
        if values.ndim < 2 or values.shape[1] == 1:
            return values.mean()
        return values.mean(axis=0)

    def gen_gp(self, X, y, train_noise):
        """Build a heteroscedastic GP on (X, y) with this kernel and fixed noise.

        A copy of ``self.kern`` is used: linking the shared kernel into a new
        GPy model would detach it from ``self.gpr``, and ``fix()`` would then
        freeze the stored kernel as well.

        ``train_noise`` is a scalar or per-point noise variance; a 1-D vector
        is reshaped to a column.
        """
        gpr = GPy.models.GPHeteroscedasticRegression(X, y, self.kern.copy())
        gpr.het_Gauss.variance = self._as_column(train_noise)
        gpr.kern.fix()
        gpr.het_Gauss.variance.fix()
        return gpr

    def mse(self, X, y, train_noise, X_test, y_test, test_noise):
        """Mean squared error of the noiseless posterior mean on the test set.

        With no training data the prediction is zero (the prior mean).
        Returns a scalar for single-output targets and a per-output vector
        otherwise. ``test_noise`` is accepted for signature symmetry with
        ``mnlp`` but not used.
        """
        y_test = self._as_column(y_test)
        if len(X) == 0 and len(y) == 0:
            y_pred = np.zeros_like(y_test)
        else:
            gpr = self.gen_gp(X, y, train_noise)
            y_pred, _ = gpr.predict_noiseless(X_test)
        return self._mean_per_output((y_pred - y_test) ** 2)

    def mnlp(self, X, y, train_noise, X_test, y_test, test_noise):
        """Mean negative log predictive density of ``y_test``.

        The predictive variance is the noiseless posterior variance plus
        ``test_noise``; with no training data the prior (zero mean, kernel
        diagonal variance) is used. ``test_noise=None`` falls back to
        ``self.noise``. Returns a scalar for single-output targets and a
        per-output vector otherwise.
        """
        y_test = self._as_column(y_test)
        if test_noise is None:
            test_noise = self.noise
        test_noise = self._as_column(test_noise)
        if len(X) == 0 and len(y) == 0:
            mean = np.zeros_like(y_test)
            # Column vector so it pairs row-wise with y_test (m, d).
            var = np.reshape(self.kern.Kdiag(X_test), (-1, 1))
        else:
            gpr = self.gen_gp(X, y, train_noise)
            mean, var = gpr.predict_noiseless(X_test)
        nlp = -stats.norm.logpdf(y_test, mean, np.sqrt(var + test_noise))
        return self._mean_per_output(nlp)

    def evaluation(self, X, y, train_noise, X_test, y_test, test_noise=None):
        """Return ``{'mse': ..., 'mnlp': ...}`` for the given train/test split.

        ``test_noise=None`` makes ``mnlp`` use ``self.noise``.
        """
        return {
            'mse': self.mse(X, y, train_noise, X_test, y_test, test_noise),
            'mnlp': self.mnlp(X, y, train_noise, X_test, y_test, test_noise),
        }
 