import numpy as np
import time
import torch
from torch.linalg import det as detto
from torch.linalg import norm
from torch.linalg import inv
from scipy.interpolate import griddata
import scipy.sparse as sparse
from fvgp import GP
from scipy.interpolate import Rbf
# Regular lattice on the unit square with n_points samples per axis.
n_points = 4
x_grid = np.linspace(0, 1, n_points)
y_grid = np.linspace(0, 1, n_points)
from itertools import product
# Cartesian product of both axes, i.e. an array of shape (n_points**2, 2).
# NOTE(review): `grid` is not read anywhere in the visible part of this file.
grid = np.asarray(list(product(x_grid, y_grid)))
from scipy.interpolate import griddata
import gc


# Run label; not read anywhere in the visible part of this file.
mode = "4x25"
# Space-delimited train/test inputs and targets, loaded from ./data at import
# time (presumably the California housing data set, judging by the file
# names -- confirm).
x_train = np.genfromtxt("./data/x_train_CAhousing.csv", delimiter=" ")
x_test = np.genfromtxt("./data/x_test_CAhousing.csv", delimiter=" ")
y_train = np.genfromtxt("./data/y_train_CAhousing.csv", delimiter=" ")
y_test = np.genfromtxt("./data/y_test_CAhousing.csv", delimiter=" ")
# Module-level handle to the first CUDA device; importing this file therefore
# requires a CUDA-capable torch installation.
cuda_device = torch.device("cuda:0")
# float32 copy of the training inputs on the CUDA device (read by g_vec).
x_train_dev = torch.from_numpy(x_train).to(cuda_device, dtype=torch.float32)




# One [lower, upper] pair per input dimension (8 rows). rbf_ls() places its
# Gaussian centers between space[i, 0] and space[i, 1] and derives their width
# from the difference. Presumably these are the per-feature value ranges of
# the training inputs -- TODO confirm against the data files.
space = np.array([[0.4999, 15.0001],
                    [1.0, 52.0],
                    [0.84, 141.909],
                    [0.333 ,34.066],
                    [5.0, 35682.0],
                    [0.6923, 1243.333],
                    [32.54, 41.95],
                    [-124.35, -114.31]])

def RBF_func_sigma(x, hps):
    """Affine function of the nearest training target at every query point.

    Each row of ``x`` is assigned the target of its nearest training input
    (``griddata`` with ``method="nearest"`` over the module-level ``x_train``
    and ``y_train``), shifted by 1e-6, and then mapped through
    ``slope * value + offset``.

    Args:
        x: Query points, handed unchanged to ``scipy.interpolate.griddata``.
        hps: Indexable with ``hps[0]`` the offset and ``hps[1]`` the slope.
            Both elements must provide ``.numpy()`` (e.g. entries of a CPU
            torch tensor).

    Returns:
        A torch tensor wrapping the resulting NumPy array.
    """
    offset = hps[0]
    slope = hps[1]
    # One lookup serves every input size; no special case is needed for any
    # particular len(x). `fill_value` is kept for parity with the previous
    # call even though SciPy documents it as having no effect for "nearest".
    nearest_target = griddata(x_train, y_train, x, method = "nearest", fill_value = 0.0) + 0.000001
    res = (slope.numpy() * nearest_target) + offset.numpy()
    return torch.from_numpy(res)

def g(x, ampl, mean, sigma):
    """Unnormalized Gaussian ``ampl * exp(-(x - mean)^2 / (2 * sigma^2))``.

    Args:
        x: Tensor of evaluation points.
        ampl: Peak height, reached at ``x == mean``.
        mean: Center of the Gaussian.
        sigma: Width of the Gaussian.

    Returns:
        Tensor of the Gaussian evaluated element-wise at ``x``.
    """
    squared_offset = (x - mean) ** 2
    two_variance = 2. * sigma ** 2
    exponent = -(squared_offset / two_variance)
    return ampl * torch.exp(exponent)

def rbf_ls(x, hps, i):
    """Sum of three Gaussians spread evenly over the range of dimension ``i``.

    The centers are the two endpoints and the midpoint of
    ``[space[i, 0], space[i, 1]]``; every Gaussian shares a width of one fifth
    of that range.

    Args:
        x: Tensor of coordinates along dimension ``i``.
        hps: Indexable holding the three amplitudes ``hps[0]``, ``hps[1]``
            and ``hps[2]``, one per center.
        i: Row of the module-level ``space`` array to read the range from.

    Returns:
        Tensor with the summed Gaussians evaluated at ``x``.
    """
    lower, upper = space[i,0], space[i,1]
    centers = torch.linspace(lower, upper, 3)
    width = (upper - lower)/5.
    first = g(x, hps[0], centers[0], width)
    second = g(x, hps[1], centers[1], width)
    third = g(x, hps[2], centers[2], width)
    return first + second + third

def lengthscales(x, hps, i):
    """Return the value of ``rbf_ls(x, hps, i)``.

    Thin indirection used by core_kernel_prod; all arguments are forwarded
    unchanged.
    """
    return rbf_ls(x , hps, i)


##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
def kernelG(x1,x2,hps):
    """Covariance matrix: delta_kernel times a Wendland taper, computed on CUDA.

    Hyperparameter layout, as consumed by the helpers this function calls:
    ``hps[0]`` is the neighbourhood threshold read by g_vec, ``hps[1]`` the
    scale applied in delta_kernel, ``hps[2]`` the amplitude of the Wendland
    factor and ``hps[3:11]`` the eight per-dimension length scales of the
    L1 distance.

    Args:
        x1: NumPy array of input points (rows).
        x2: NumPy array of input points (rows).
        hps: NumPy array of hyperparameters.

    Returns:
        NumPy array of shape ``(len(x1), len(x2))``.

    Side effects:
        Prints a warning banner when the result contains NaN or values above
        1e16, and always prints the elapsed wall-clock time.
    """
    t1_start = time.perf_counter()
    cuda_device = torch.device("cuda:0")
    # Only the device copy of the hyperparameters is needed here.
    hps_dev = torch.from_numpy(hps).to(cuda_device, dtype=torch.float32)
    x1_dev = torch.from_numpy(x1).to(cuda_device, dtype=torch.float32)
    x2_dev = torch.from_numpy(x2).to(cuda_device, dtype=torch.float32) #could be a copy of x1 for training

    dist_dev = l1_dist(x1_dev,x2_dev, hps_dev[3:11])
    kk = delta_kernel(hps_dev, x1_dev, x2_dev, cuda_device) * wendland_anisotropic_gp2Scale_gpuDiff3(hps_dev, dist_dev)

    torch.cuda.empty_cache()
    res = kk.cpu().numpy()
    # Stop the clock only after the device-to-host copy so the GPU work is included.
    t1_stop = time.perf_counter()
    # NaN is the only value that compares unequal to itself.
    if np.any(res != res): print("NAN  ###########################################", flush = True)
    if np.any(res > 1e16): print("INF  ###########################################", flush = True)
    print("Elapsed kernel time:", t1_stop - t1_start, flush = True)
    return res

def g_vec(x, hps, device):
    """Binary neighbourhood mask between ``x`` and the training inputs.

    Entry ``[i, k]`` is 1 when the scaled L1 distance between ``x[i]`` and
    training point ``k`` (length scales ``hps[3:11]``) is below ``hps[0] / 2``
    and 0 otherwise.

    Args:
        x: Tensor of query points; compared against the module-level
            ``x_train_dev``.
        hps: Hyperparameter tensor (see above for the entries that are read).
        device: Device on which the mask is allocated.

    Returns:
        float32 tensor of shape ``(len(x), len(x_train))``.
    """
    # Allocate straight on the target device; going through the host first
    # costs an extra full-size allocation and transfer on every call.
    res = torch.zeros((len(x), len(x_train)), dtype=torch.float32, device=device)
    d = l1_dist(x, x_train_dev, hps[3:11])
    res[d < hps[0]/2.] = 1.
    return res

def delta_kernel(hps, x1, x2, device):
    """Scaled count of training points that neighbour both arguments.

    With the 0/1 masks returned by g_vec, entry ``[i, j]`` of the product is
    the number of training points lying in the neighbourhood of both
    ``x1[i]`` and ``x2[j]``. The result is multiplied by ``hps[1]`` and
    divided by the number of training points.

    Args:
        hps: Hyperparameter tensor; forwarded to g_vec, ``hps[1]`` is the scale.
        x1: Tensor of points (rows).
        x2: Tensor of points (rows).
        device: Device forwarded to g_vec.

    Returns:
        Tensor of shape ``(len(x1), len(x2))``.

    Side effects:
        Prints the maximum of the unscaled product.
    """
    members1 = g_vec(x1, hps, device)  # (len(x1), len(x_train))
    members2 = g_vec(x2, hps, device)  # (len(x2), len(x_train))
    G = members1 @ members2.T
    print("max G: ",    torch.max(G), flush = True)
    scaled = hps[1] * G
    return scaled/float(len(x_train))


def wendland_anisotropic_gp2Scale_gpuDiff3(hps, d):  # pragma: no cover
    """Compactly supported Wendland-type kernel on a distance matrix.

    Distances above 1 are capped at 1, where the kernel is zero. The caller's
    ``d`` is left untouched because the cap is applied to a detached copy.

    Args:
        hps: Hyperparameter tensor; ``hps[2]`` is the amplitude.
        d: Tensor of (already scaled) distances.

    Returns:
        ``hps[2] * (1 - a)^8 * (35 a^3 + 25 a^2 + 8 a + 1)`` with
        ``a = min(d, 1)``, element-wise.
    """
    capped = d.detach().clone()
    beyond_support = capped > 1.
    capped[beyond_support] = 1.
    taper = (1.-capped)**8
    polynomial = 35.*capped**3 + 25.*capped**2 + 8.*capped + 1.
    return hps[2] * taper * polynomial


def l1_dist(x1, x2, hps):  # pragma: no cover
    """Pairwise length-scaled L1 distance over the first eight columns.

    Args:
        x1: Tensor of points (rows) with at least 8 columns.
        x2: Tensor of points (rows) with at least 8 columns.
        hps: Indexable with at least 8 length scales; column ``i`` is divided
            by ``hps[i]``.

    Returns:
        float32 tensor of shape ``(len(x1), len(x2))`` on ``x1``'s device with
        ``sum_i |x1[:, i] - x2[:, i]| / hps[i]``.
    """
    # Create the accumulator on x1's device directly rather than on the host
    # followed by a transfer.
    d = torch.zeros((len(x1), len(x2)), dtype=torch.float32, device=x1.device)
    for i in range(8):
        d += (abs(x1[:, i].reshape(-1, 1) - x2[:, i]) / hps[i])
    return d

##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
##############################################TEST#########################
def kernel(x1,x2,hps):
    """Covariance matrix from core_kernel_prod, computed on CUDA.

    The inputs and hyperparameters are converted to torch tensors twice: a
    float32 copy on the CUDA device and a CPU view of the original arrays.
    Both are handed to core_kernel_prod.

    Args:
        x1: NumPy array of input points (rows).
        x2: NumPy array of input points (rows).
        hps: NumPy array of hyperparameters, forwarded to core_kernel_prod.

    Returns:
        NumPy array with the covariance between the rows of ``x1`` and ``x2``.

    Side effects:
        Prints a warning banner when the result contains NaN, contains values
        above 1e16, or is identically zero, and always prints the elapsed
        wall-clock time.
    """
    t1_start = time.perf_counter()
    cuda_device = torch.device("cuda:0")
    x1_dev = torch.from_numpy(x1).to(cuda_device, dtype=torch.float32)
    x2_dev = torch.from_numpy(x2).to(cuda_device, dtype=torch.float32) #could be a copy of x1 for training
    x1 = torch.from_numpy(x1)
    x2 = torch.from_numpy(x2)

    hps_dev = torch.from_numpy(hps).to(cuda_device, dtype=torch.float32)
    hps = torch.from_numpy(hps)

    kk = core_kernel_prod(x1,x2, x1_dev, x2_dev, hps, hps_dev, cuda_device)

    torch.cuda.empty_cache()
    res = kk.cpu().numpy()
    # CUDA work is queued asynchronously; the copy to the host waits for it.
    # Stopping the clock after the copy makes the reported time cover the
    # actual computation, matching what kernelG measures.
    t1_stop = time.perf_counter()
    # NaN is the only value that compares unequal to itself.
    if np.any(res != res): print("NAN  ###########################################", flush = True)
    if np.any(res > 1e16): print("INF  ###########################################", flush = True)
    if np.all(res == 0.):  print("ZERO ###########################################", flush = True)

    print("Elapsed kernel time:", t1_stop - t1_start, flush = True)
    return res

def wendland_anisotropic_gp2Scale_gpu(x1, x2, hps, cuda_device):  # pragma: no cover
    """Wendland-type kernel of the length-scaled Euclidean distance.

    Args:
        x1: Tensor of points (rows).
        x2: Tensor of points (rows).
        hps: Per-dimension length scales, forwarded to
            _get_distance_matrix_gpu.
        cuda_device: Device forwarded to _get_distance_matrix_gpu.

    Returns:
        ``(1 - r)^8 * (35 r^3 + 25 r^2 + 8 r + 1)`` element-wise, where ``r``
        is the distance capped at 1 (the kernel is zero there).
    """
    r = _get_distance_matrix_gpu(x1, x2, cuda_device, hps)
    outside = r > 1.
    r[outside] = 1.
    taper = (1. - r) ** 8
    polynomial = 35. * r ** 3 + 25. * r ** 2 + 8. * r + 1.
    return taper * polynomial



def wendland_anisotropic_gp2Scale_gpu13(x1, x2, hps, cuda_device):  # pragma: no cover
    """Lower-order Wendland-type kernel of the length-scaled Euclidean distance.

    Args:
        x1: Tensor of points (rows).
        x2: Tensor of points (rows).
        hps: Per-dimension length scales, forwarded to
            _get_distance_matrix_gpu.
        cuda_device: Device forwarded to _get_distance_matrix_gpu.

    Returns:
        ``(1 - r)^3 * (3 r + 1)`` element-wise, where ``r`` is the distance
        capped at 1 (the kernel is zero there).
    """
    r = _get_distance_matrix_gpu(x1, x2, cuda_device, hps)
    outside = r > 1.
    r[outside] = 1.
    taper = (1. - r) ** 3
    linear = (3. * r)  + 1.
    return taper * linear


def _get_distance_matrix_gpu(x1, x2, device, hps):  # pragma: no cover
    """Pairwise length-scaled Euclidean distance over the first eight columns.

    Args:
        x1: Tensor of points (rows) with at least 8 columns.
        x2: Tensor of points (rows) with at least 8 columns.
        device: Device on which the distance matrix is accumulated.
        hps: Indexable with at least 8 length scales; column ``i`` is divided
            by ``hps[i]``.

    Returns:
        float32 tensor of shape ``(len(x1), len(x2))`` with
        ``sqrt(sum_i ((x1[:, i] - x2[:, i]) / hps[i])^2)``.
    """
    # Create the accumulator on the target device directly rather than on the
    # host followed by a transfer.
    d = torch.zeros((len(x1), len(x2)), dtype=torch.float32, device=device)
    for i in range(8):
        d += ((x1[:, i].reshape(-1, 1) - x2[:, i]) / hps[i]) ** 2
    return torch.sqrt(d)



def b_gpu(x,x0, r, ampl, device):
    """Smooth bump of radius ``r`` centered at ``x0``.

    With ``a = 1 - |x - x0|^2 / r^2`` the value is ``ampl * exp(1 - 1/a)``
    where ``a > 0`` (strictly inside the radius) and 0 elsewhere, so the peak
    value at ``x == x0`` is ``ampl``.

    Args:
        x: Tensor of points, one per row.
        x0: Center, broadcast against the rows of ``x``.
        r: Radius of the support.
        ampl: Peak amplitude.
        device: Device on which the result is allocated.

    Returns:
        float32 tensor with one bump value per row of ``x``.
    """
    d = norm(x - x0, dim = 1)
    a = 1.0 - (d**2/r**2)
    inside = torch.where(a > 0.0)
    # Only entries inside the support are evaluated; this keeps 1/a finite.
    bump = torch.zeros(a.shape, dtype = torch.float32, device = device)
    e = torch.exp((-1.0/a[inside])+1.).to(device, dtype = torch.float32)
    bump[inside] = ampl * e
    return bump

def f_gpu(x,x0, radii, amplts, device):
    """Sum of b_gpu bumps, one per entry of ``radii``.

    Args:
        x: Tensor of points, one per row.
        x0: Indexable of bump centers; ``x0[i]`` belongs to ``radii[i]``.
        radii: Sequence of bump radii; its length sets the number of bumps.
        amplts: Indexable of bump amplitudes; ``amplts[i]`` belongs to
            ``radii[i]``.
        device: Device forwarded to b_gpu.

    Returns:
        Tensor with the summed bump values per row of ``x``, or the float
        ``0.`` when ``radii`` is empty.
    """
    b = 0.
    for i, radius in enumerate(radii):
        b += b_gpu(x, x0[i], radius, amplts[i], device)
    return b



def wendland_q(d):
    """Wendland-type kernel ``(1 - d)^8 * (35 d^3 + 25 d^2 + 8 d + 1)``.

    Entries of ``d`` above 1 are set to 1, where the kernel is zero.

    NOTE: the cap is written into ``d`` itself, so the caller's tensor is
    modified in place.

    Args:
        d: Tensor of (already scaled) distances.

    Returns:
        Tensor of kernel values with the shape of ``d``.
    """
    beyond_support = d > 1.
    d[beyond_support] = 1.
    taper = (1. - d) ** 8
    polynomial = 35. * d ** 3 + 25. * d ** 2 + 8. * d + 1.
    return taper * polynomial



def core_kernel(x1,x2,x1_dev,x2_dev, hps, hps_dev, cuda_device):
    """Sum of two separable signal terms, tapered by a Wendland-type kernel.

    Each signal term is the outer product of RBF_func_sigma evaluated at
    ``x1`` and at ``x2``; the first uses ``hps[0:2]``, the second ``hps[2:4]``.
    Their sum is multiplied element-wise by
    wendland_anisotropic_gp2Scale_gpu with length scales ``hps_dev[4:12]``.

    Args:
        x1: Host-side points, passed to RBF_func_sigma.
        x2: Host-side points, passed to RBF_func_sigma.
        x1_dev: Device copy of ``x1``, used for the distance-based factor.
        x2_dev: Device copy of ``x2``, used for the distance-based factor.
        hps: Host-side hyperparameters (signal terms).
        hps_dev: Device copy of the hyperparameters (length scales).
        cuda_device: Device on which the result is assembled.

    Returns:
        Tensor with the covariance between the rows of ``x1`` and ``x2``.
    """
    def signal_outer(signal_hps):
        # Outer product of the signal function at x1 and at x2.
        left = RBF_func_sigma(x1, signal_hps).to(cuda_device, dtype=torch.float32)
        right = RBF_func_sigma(x2, signal_hps).to(cuda_device, dtype=torch.float32)
        return torch.outer(left, right)

    first_term = signal_outer(hps[0:2])
    second_term = signal_outer(hps[2:4])
    amplitude = first_term + second_term

    taper = wendland_anisotropic_gp2Scale_gpu(x1_dev, x2_dev, hps_dev[4:12], cuda_device)
    return amplitude * taper

def core_kernel58(x1,x2,x1_dev,x2_dev, hps, hps_dev, cuda_device):
    """Separable signal term times a Wendland-type kernel, 66 hyperparameters.

    ``hps_dev[0:58]`` parameterizes the signal function and
    ``hps_dev[58:66]`` holds the eight length scales of the distance-based
    factor. ``x1``, ``x2`` and ``hps`` are accepted but not read.

    NOTE(review): ``RBF_func_sigma58`` is not defined in the visible part of
    this file; unless it is provided elsewhere, calling this function raises
    NameError -- confirm.
    """


    signal_hps = hps_dev[0:58]
    length_scales = hps_dev[58:66]
    # Signal function evaluated on the device copies of both point sets.
    signal_variance1 = RBF_func_sigma58(x1_dev, signal_hps)
    signal_variance2 = RBF_func_sigma58(x2_dev, signal_hps)
    A = torch.outer(signal_variance1,signal_variance2)
    C = wendland_anisotropic_gp2Scale_gpu(x1_dev,x2_dev, length_scales, cuda_device)
    k = A * C
    
    return k


def core_kernel_prod(x1,x2,x1_dev,x2_dev, hps, hps_dev, cuda_device):
    """Signal term times a per-dimension product of input-dependent kernels.

    The signal term is the sum of two outer products of RBF_func_sigma
    (hyperparameters ``hps[0:2]`` and ``hps[2:4]``). For every input dimension
    ``i`` three consecutive entries of ``hps_dev[4:]`` parameterize the
    position-dependent length scale returned by ``lengthscales``; the
    dimension contributes ``sqrt(2 l1 l2 / (l1^2 + l2^2))`` times
    ``wendland_q(|x1_i - x2_i| / (l1^2 + l2^2))``.

    NOTE(review): the separation is divided by ``l1^2 + l2^2`` itself, not by
    a square root of it -- confirm this scaling is intended.

    Args:
        x1: Host-side points, passed to RBF_func_sigma; its column count sets
            the number of dimensions.
        x2: Host-side points, passed to RBF_func_sigma.
        x1_dev: Device copy of ``x1``.
        x2_dev: Device copy of ``x2``.
        hps: Host-side hyperparameters (signal terms).
        hps_dev: Device copy of the hyperparameters (length-scale part).
        cuda_device: Device on which the result is assembled.

    Returns:
        Tensor of shape ``(len(x1), len(x2))``.
    """
    # Signal part: two separable terms, each an outer product over x1 and x2.
    hps_first, hps_second = hps[0:2], hps[2:4]
    left = RBF_func_sigma(x1, hps_first).to(cuda_device, dtype=torch.float32)
    right = RBF_func_sigma(x2, hps_first).to(cuda_device, dtype=torch.float32)
    S1 = torch.outer(left, right)

    left = RBF_func_sigma(x1, hps_second).to(cuda_device, dtype=torch.float32)
    right = RBF_func_sigma(x2, hps_second).to(cuda_device, dtype=torch.float32)
    S2 = torch.outer(left, right)
    S = S1 + S2

    lhps = hps_dev[4:]

    # Correlation part: running product over the input dimensions.
    k = torch.ones((len(x1), len(x2))).to(cuda_device, dtype=torch.float32)
    for i in range(x1.shape[1]):
        dim_hps = lhps[i*3 : (i+1)*3]
        column1 = x1_dev[:, i]
        column2 = x2_dev[:, i]
        l1 = lengthscales(column1, dim_hps, i)
        l2 = lengthscales(column2, dim_hps, i)
        squares_sum = l1[:,None]**2 + l2[None,:]**2
        prefactor = torch.sqrt(2.0 * torch.outer(l1,l2) / squares_sum)
        separation = abs((column1.reshape(-1, 1) - column2))
        taper = wendland_q(separation/squares_sum)
        k = k * taper * prefactor

    return S * k



