# import time
# import copy
# from math import factorial

import numpy as np
# import matplotlib
# matplotlib.use("Agg")
# import matplotlib.pyplot as plt
from scipy import integrate
from scipy.stats import multivariate_normal
from scipy.stats import mvn
from scipy.stats import norm
from scipy import optimize
import GPy


def get_kernel_info(kernel):
    """
    Extracts information from a GPy kernel object.

    Parameters:
    kernel (GPy.kern.Kern): The GPy kernel object. Only RBF, Matern32 and
        Matern52 kernels are supported.

    Returns:
    tuple: ``(kernel_name, input_dim, lengthscale, variance, nu)`` where
        ``kernel_name`` is the kernel's class name, ``input_dim`` is
        ``kernel.input_dim``, ``lengthscale`` and ``variance`` are the
        ``.values`` of the corresponding kernel parameters, and ``nu`` is the
        Matern smoothness (1.5 for Matern32, 2.5 for Matern52, ``np.inf`` for RBF).

    Raises:
    ValueError: If ``kernel`` is not a ``GPy.kern.Kern`` instance, or its class
        is not one of the supported kernels.
    """
    if not isinstance(kernel, GPy.kern.Kern):
        raise ValueError("The kernel must be an instance of GPy.kern.Kern.")

    # Smoothness nu of each supported kernel; RBF is treated as the nu -> inf limit.
    smoothness_by_name = {'RBF': np.inf, 'Matern32': 1.5, 'Matern52': 2.5}

    kernel_name = kernel.__class__.__name__
    if kernel_name not in smoothness_by_name:
        raise ValueError(f"Unsupported kernel type: {kernel_name}. Only RBF, Matern32, and Matern52 are supported.")

    input_dim = kernel.input_dim
    # NOTE(review): confirm whether `.values` is a copy or a view of the live
    # parameter (matters if the kernel is optimised after this call).
    lengthscale = kernel.lengthscale.values
    variance = kernel.variance.values
    nu = smoothness_by_name[kernel_name]
    return kernel_name, input_dim, lengthscale, variance, nu


def multivariate_t_rvs(loc, scale, df, size=1, rng=None):
    """
    Generate random samples from a multivariate t-distribution.

    Each sample is ``loc + Z * sqrt(df / u)`` with ``Z ~ N(0, scale**2 * I)`` and a
    single ``u ~ chi2(df)`` shared by all coordinates of that sample.

    Parameters:
    loc (array-like): d-dimensional mean vector, shape (d,) or (d, 1).
    scale (float): Standard deviation of the Gaussian component; an isotropic,
        diagonal scale matrix ``scale**2 * I`` is assumed.
    df (float): Degrees of freedom.
    size (int): Number of samples to generate.
    rng (numpy.random.Generator, optional): Random generator to draw from; a
        fresh ``np.random.default_rng()`` is used when None.

    Returns:
    ndarray: Array of shape (d, size); each column is one sample.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Force loc into a (d, 1) column so it broadcasts across the `size` sample
    # columns. A flat (d,) vector would otherwise align with the sample axis.
    loc = np.asarray(loc, dtype=float).reshape(-1, 1)
    d = loc.shape[0]
    Z = rng.normal(loc=0, scale=scale, size=(d, size))
    # One chi-square draw per sample (column), shared by all d coordinates.
    chi2_samples = rng.chisquare(df=df, size=(1, size))
    return loc + Z * np.sqrt(df / chi2_samples)


class RFF(object):
    """
    Random Fourier feature approximation of an RBF / Matern GPy kernel.

    The kernel is approximated as ``k(X, X') ~= transform(X) @ transform(X').T`` with
    ``phi(x) = sqrt(variance / m) * [cos((x / l) @ W), sin((x / l) @ W)]``.
    ``W`` holds ``m = basis_dim // 2`` frequency columns. For RBF they are drawn
    from a standard normal. For Matern kernels they are drawn from a multivariate
    t with ``2 * nu`` degrees of freedom.
    """

    def __init__(self, kernel, rng=None, basis_dim=1000):
        """
        Parameters:
        kernel (GPy.kern.Kern): RBF, Matern32 or Matern52 kernel to approximate.
        rng (numpy.random.Generator, optional): Source of randomness; a fresh
            ``np.random.default_rng()`` is used when None.
        basis_dim (int): Requested number of features. Features come in cos/sin
            pairs, so ``2 * (basis_dim // 2)`` features are produced.

        Raises:
        ValueError: If ``basis_dim < 2`` (no frequency could be drawn), or if
            the kernel is rejected by ``get_kernel_info``.
        """
        if basis_dim < 2:
            raise ValueError(f"basis_dim must be at least 2, but got {basis_dim}.")
        if rng is None:
            rng = np.random.default_rng()

        self.kernel = kernel
        self.basis_dim = basis_dim

        self.kernel_name, self.input_dim, self.lengthscale, self.variance, self.nu = get_kernel_info(kernel)

        num_freqs = self.basis_dim // 2
        if self.kernel_name == 'RBF':
            self.W = rng.normal(loc=0, scale=1, size=(self.input_dim, num_freqs))
        else:  # Matern32 / Matern52 -- get_kernel_info rejects any other kernel.
            self.W = multivariate_t_rvs(loc=np.zeros((self.input_dim, 1)), scale=1., df=2*self.nu, size=num_freqs, rng=rng)

        # Random phases. transform() uses cos/sin pairs and does not read these,
        # but they are still drawn to keep the public attribute and the rng stream.
        self.b = np.atleast_2d(rng.uniform(0, 2 * np.pi, size=self.basis_dim))

    def _feature_scale(self):
        # sqrt(variance / m) for m frequency columns: then phi(x) @ phi(x).T equals
        # the variance exactly and the estimate of k(x, x') is unbiased. Dividing
        # by m (not basis_dim / 2) keeps this correct when basis_dim is odd.
        return np.sqrt(self.variance / self.W.shape[1])

    def transform(self, X):
        """
        Map inputs to random Fourier features.

        Parameters:
        X (ndarray): Inputs of shape (n, input_dim).

        Returns:
        ndarray: Features of shape (n, 2 * m): the cos block followed by the sin block.
        """
        linear_transformed_X = (X / self.lengthscale) @ self.W
        return self._feature_scale() * np.c_[np.cos(linear_transformed_X), np.sin(linear_transformed_X)]

    def transform_grad(self, x: np.ndarray):
        """
        Jacobian of ``transform`` at a single input point.

        Parameters:
        x (ndarray): One input of shape (1, input_dim).

        Returns:
        ndarray: Array of shape (input_dim, 2 * m); row i is d phi(x) / d x_i.

        Raises:
        ValueError: If ``x`` does not have shape (1, input_dim).
        """
        if x.shape != (1, self.input_dim):
            raise ValueError("x must be a (1 times d) array, but got shape {} in transform_grad".format(x.shape))

        linear_transformed_x = (x / self.lengthscale) @ self.W
        # Derivatives of cos/sin w.r.t. their argument: -sin and cos. Shape (1, 2m).
        transformed_x = self._feature_scale() * np.c_[-np.sin(linear_transformed_x), np.cos(linear_transformed_x)]
        # d/dx_i of (x / l) @ W[:, j] is W[i, j] / l_i. The lengthscale is reshaped
        # to a column so it divides the rows of W. Shape (input_dim, 2m).
        grad_coefficient = np.tile(self.W / np.c_[self.lengthscale], (1, 2))
        return transformed_x * grad_coefficient

    def _test_approximation_error(self, X):
        """Return the max absolute difference between ``kernel.K(X)`` and its RFF estimate."""
        kernel_matrix = self.kernel.K(X)
        transformed_X = self.transform(X)
        kernel_matrix_approximation = transformed_X @ transformed_X.T
        return np.max(np.abs(kernel_matrix - kernel_matrix_approximation))

def main():
    """Script entry point; currently a no-op placeholder."""
    return None


if __name__ == '__main__':
    main()
