# Module setup for a CUDA-based nonstationary kernel. `GP` is imported from
# fvgp but not referenced in the visible code -- presumably the `kernel`
# function below is handed to it elsewhere (TODO confirm).
# NOTE(review): importing this module requires a CUDA device ("cuda:0") and
# the ./data/*_2dtopo.csv files relative to the current working directory.
import numpy as np
import time
import torch
# Presumably a warm-up call that forces CUDA linear-algebra initialisation at
# import time; it is repeated a few lines below and the module-level `x` is
# not read anywhere else in the visible code. TODO confirm both are needed.
x=torch.inverse(torch.ones((1,1), device="cuda:0"))
from torch.linalg import det as detto
from torch.linalg import norm
from torch.linalg import inv
x=torch.inverse(torch.ones((1,1), device="cuda:0"))
from scipy.interpolate import griddata
import scipy.sparse as sparse
from fvgp import GP
from scipy.interpolate import Rbf
# 4 x 4 grid of points in the unit square. `grid` is only referenced by the
# commented-out griddata call inside RBF_func_gamma.
n_points = 4
x_grid = np.linspace(0, 1, n_points)
y_grid = np.linspace(0, 1, n_points)
from itertools import product
grid = np.asarray(list(product(x_grid, y_grid)))
# Duplicate of the griddata import above (harmless).
from scipy.interpolate import griddata
import gc


# Selects which compactly supported kernel `kernel()` dispatches to:
# "10x10", "4x25" or "2x50". Only kernel_sdisc4x25 is defined (not commented
# out) in this file.
mode = "4x25"
# Space-delimited tables loaded at import time. RBF_func_lambda and
# RBF_func_sigma use them as nearest-neighbour lookup tables via griddata.
x_train = np.genfromtxt("./data/x_train_2dtopo.csv", delimiter=" ")
x_test = np.genfromtxt("./data/x_test_2dtopo.csv", delimiter=" ")
y_train = np.genfromtxt("./data/y_train_2dtopo.csv", delimiter=" ")
y_test = np.genfromtxt("./data/y_test_2dtopo.csv", delimiter=" ")



def RBF_func_lambda(x, hps):
    """Return ``1 / (hps[1] * (y_nn + 1e-5) + hps[0])`` for every query point.

    ``y_nn`` is the nearest-neighbour value of the query point in one of the
    module-level data tables. Queries with fewer than 200 points use the test
    table (``x_test``/``y_test``); larger queries use the training table
    (``x_train``/``y_train``).

    NOTE(review): the 200-point switch looks like a heuristic that tells
    prediction calls apart from training calls -- confirm with the caller.

    Parameters
    ----------
    x : array-like
        Query points handed to ``scipy.interpolate.griddata``.
    hps : torch.Tensor
        CPU tensor; ``hps[0]`` is the offset and ``hps[1]`` the scale.

    Returns
    -------
    torch.Tensor
        CPU tensor wrapping the resulting numpy array.
    """
    if len(x) < 200:
        ref_points, ref_values = x_test, y_test
    else:
        ref_points, ref_values = x_train, y_train
    # fill_value has no effect for method="nearest"; kept for parity.
    nearest = griddata(ref_points, ref_values, x, method="nearest", fill_value=0.0)
    # Small positive shift applied to the looked-up values.
    shifted = nearest + 0.00001
    offset = hps[0].numpy()
    scale = hps[1].numpy()
    return torch.from_numpy(1. / (scale * shifted + offset))


def RBF_func_gamma(x, hps):
    """Return the constant angle ``hps[0]`` repeated once per point in ``x``.

    Parameters
    ----------
    x : sized
        Only ``len(x)`` is used.
    hps : torch.Tensor
        CPU tensor; only ``hps[0]`` is read.

    Returns
    -------
    torch.Tensor
        1-D tensor of length ``len(x)``, built from a float64 numpy array.
    """
    n_points = len(x)
    angle = hps[0].numpy()
    return torch.from_numpy(np.zeros(n_points) + angle)



def RBF_func_sigma(x, hps):
    """Return ``hps[0] + hps[1] * y_nn`` for every query point.

    ``y_nn`` is the nearest-neighbour value of the query point in the
    module-level training table ``x_train``/``y_train``. ``core_kernel`` uses
    the result as the per-point signal variance.

    Parameters
    ----------
    x : array-like
        Query points handed to ``scipy.interpolate.griddata``.
    hps : torch.Tensor
        CPU tensor; ``hps[0]`` is the offset and ``hps[1]`` the scale.

    Returns
    -------
    torch.Tensor
        CPU tensor wrapping the resulting numpy array.
    """
    # fill_value has no effect for method="nearest"; kept for parity.
    nearest = griddata(x_train, y_train, x, method="nearest", fill_value=0.00001)
    base = hps[0].numpy()
    slope = hps[1].numpy()
    return torch.from_numpy(base + slope * nearest)



def kernel(x1,x2,hps):
    """Evaluate the kernel matrix between point sets ``x1`` and ``x2`` on ``cuda:0``.

    The result is the element-wise product of two factors:

    - the compactly supported kernel chosen by the module-level ``mode``
      (``kernel_sdisc*``);
    - ``core_kernel``.

    When the first factor is identically zero, ``core_kernel`` is skipped and
    the zero matrix is returned immediately.

    Parameters
    ----------
    x1, x2 : numpy.ndarray
        Point sets (converted with ``torch.from_numpy``). During training
        ``x2`` may be a copy of ``x1``.
    hps : numpy.ndarray
        Hyperparameter vector. A CPU tensor copy is passed to ``core_kernel``;
        a float32 copy on the GPU is passed to the ``kernel_sdisc*`` function.

    Returns
    -------
    numpy.ndarray
        The kernel matrix, copied back to host memory.

    Raises
    ------
    ValueError
        If ``mode`` is not one of "10x10", "4x25" or "2x50".
    """
    t1_start = time.perf_counter()
    cuda_device = torch.device("cuda:0")
    x1_dev = torch.from_numpy(x1).to(cuda_device, dtype=torch.float32)
    x2_dev = torch.from_numpy(x2).to(cuda_device, dtype=torch.float32)  # could be a copy of x1 for training
    x1 = torch.from_numpy(x1)
    x2 = torch.from_numpy(x2)

    hps_dev = torch.from_numpy(hps).to(cuda_device, dtype=torch.float32)
    hps = torch.from_numpy(hps)

    # Exactly one branch must run; otherwise k1 would be unbound below.
    # NOTE(review): kernel_sdisc10x10 / kernel_sdisc2x50 are commented out in
    # this file, so those modes raise NameError unless defined elsewhere.
    if mode == "10x10":
        k1 = kernel_sdisc10x10(x1_dev, x2_dev, hps_dev, cuda_device)
    elif mode == "4x25":
        k1 = kernel_sdisc4x25(x1_dev, x2_dev, hps_dev, cuda_device)
    elif mode == "2x50":
        k1 = kernel_sdisc2x50(x1_dev, x2_dev, hps_dev, cuda_device)
    else:
        raise ValueError(f"unsupported kernel mode: {mode!r}")

    # A zero compact-support factor makes the product zero, so the expensive
    # core kernel can be skipped.
    if torch.all(k1 == 0.):
        k = k1.cpu().numpy()
        t1_stop = time.perf_counter()
        print("ready to return early after ", t1_stop - t1_start, flush = True)
        return k

    k2 = core_kernel(x1, x2, x1_dev, x2_dev, hps, cuda_device)
    kk = k2 * k1
    # Drop the two factors so empty_cache can actually release their blocks.
    del k1, k2
    torch.cuda.empty_cache()
    return kk.cpu().numpy()

#def kernel_sdisc2x50(x1,x2, hps, device):
#    k = torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[0:50],  hps[100:150]]),hps[200:250],hps[300:350], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[0:50],  hps[100:150]]),hps[200:250],hps[300:350], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[50:100],hps[150:200]]),hps[250:300],hps[350:400], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[50:100],hps[150:200]]),hps[250:300],hps[350:400], device))
#    k[time_distance != 0.] = 0.
#    kk = k + wendland_anisotropic_gp2Scale_gpu(x1,x2, hps[400:404], device)
#    return kk

def _bump_outer_sum_4x25(x1, x2, hps, device):
    """Sum of four outer products of bump sums (4 groups x 25 bumps); currently unused.

    Group ``g`` (``lo = 25 * g``) reads these slices of ``hps``:

    - bump centres from ``hps[lo:lo+25]`` and ``hps[100+lo:100+lo+25]``;
    - radii from ``hps[200+lo:...]``;
    - amplitudes from ``hps[300+lo:...]``.

    It is kept so the previously computed bump term can be re-enabled in
    ``kernel_sdisc4x25``.
    """
    total = 0.
    for g in range(4):
        lo, hi = 25 * g, 25 * (g + 1)
        centres = torch.column_stack([hps[lo:hi], hps[100 + lo:100 + hi]])
        radii = hps[200 + lo:200 + hi]
        amplitudes = hps[300 + lo:300 + hi]
        total = total + torch.outer(f_gpu(x1[:, 0:2], centres, radii, amplitudes, device),
                                    f_gpu(x2[:, 0:2], centres, radii, amplitudes, device))
    return total


def kernel_sdisc4x25(x1,x2, hps, device):
    """Return the compactly supported kernel factor used when ``mode == "4x25"``.

    The factor is currently just the Wendland kernel on the anisotropic
    distance, scaled by ``hps[7:9]``.

    Parameters
    ----------
    x1, x2 : torch.Tensor
        Point sets on ``device``.
    hps : torch.Tensor
        Hyperparameter vector on ``device``; only ``hps[7:9]`` is read.
    device : torch.device
        Device for intermediate allocations.

    Returns
    -------
    torch.Tensor
        Matrix of shape ``(len(x1), len(x2))``.
    """
    # The bump-sum term (_bump_outer_sum_4x25) used to be computed here and
    # then discarded. It is no longer evaluated, which saves the GPU work;
    # the returned value is unchanged.
    return wendland_anisotropic_gp2Scale_gpu(x1, x2, hps[7:9], device)

#def kernel_sdisc10x10(x1,x2, hps, device):
#    k = torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[0:10],  hps[100:110]]),hps[200:210],hps[300:310], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[0:10],  hps[100:110]]),hps[200:210],hps[300:310], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[10:20], hps[110:120]]),hps[210:220],hps[310:320], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[10:20], hps[110:120]]),hps[210:220],hps[310:320], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[20:30], hps[120:130]]),hps[220:230],hps[320:330], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[20:30], hps[120:130]]),hps[220:230],hps[320:330], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[30:40], hps[130:140]]),hps[230:240],hps[330:340], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[30:40], hps[130:140]]),hps[230:240],hps[330:340], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[40:50], hps[140:150]]),hps[240:250],hps[340:350], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[40:50], hps[140:150]]),hps[240:250],hps[340:350], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[50:60], hps[150:160]]),hps[250:260],hps[350:360], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[50:60], hps[150:160]]),hps[250:260],hps[350:360], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[60:70], hps[160:170]]),hps[260:270],hps[360:370], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[60:70], hps[160:170]]),hps[260:270],hps[360:370], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[70:80], hps[170:180]]),hps[270:280],hps[370:380], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[70:80], hps[170:180]]),hps[270:280],hps[370:380], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[80:90], hps[180:190]]),hps[280:290],hps[380:390], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[80:90], hps[180:190]]),hps[280:290],hps[380:390], device)) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[90:100],hps[190:200]]),hps[290:300],hps[390:400], device),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[90:100],hps[190:200]]),hps[290:300],hps[390:400], device))
#    k[time_distance != 0.] = 0.
#    kk = k + wendland_anisotropic_gp2Scale_gpu(x1,x2, hps[400:404], device)
#    return kk

def wendland_anisotropic_gp2Scale_gpu(x1, x2, hps, cuda_device):  # pragma: no cover
    """Wendland kernel ``(1 - r)^8 (35 r^3 + 25 r^2 + 8 r + 1)`` on scaled distances.

    ``r`` is the distance from ``_get_distance_matrix_gpu``, which divides the
    two coordinates by ``hps[0]`` and ``hps[1]``. ``r`` is capped at 1, so the
    kernel is exactly zero wherever ``r >= 1``.
    """
    r = torch.clamp(_get_distance_matrix_gpu(x1, x2, cuda_device, hps), max=1.)
    polynomial = 35. * r ** 3 + 25. * r ** 2 + 8. * r + 1.
    return (1. - r) ** 8 * polynomial


def _get_distance_matrix_gpu(x1, x2, device, hps):  # pragma: no cover
    d = torch.zeros((len(x1), len(x2))).to(device, dtype=torch.float32)
    for i in range(2):
        #d += ((x1[:, i].reshape(-1, 1) - x2[:, i]) / hps[1 + i]) ** 2
        d += ((x1[:, i].reshape(-1, 1) - x2[:, i]) / hps[i]) ** 2
    return torch.sqrt(d)



def b_gpu(x,x0, r, ampl, device):
    """Evaluate one compactly supported bump ``ampl * exp(1 - 1 / a)``.

    Here ``a = 1 - |x - x0|^2 / r^2``. The bump equals ``ampl`` at the centre
    ``x0`` and is exactly zero wherever ``a <= 0``, i.e. ``|x - x0| >= r``.

    Parameters
    ----------
    x : torch.Tensor
        2-D tensor whose rows are points (the norm is taken over ``dim=1``).
    x0 : torch.Tensor
        Bump centre, broadcast against the rows of ``x``.
    r : scalar tensor or float
        Bump radius.
    ampl : scalar tensor or float
        Bump amplitude.
    device : torch.device
        Device of the returned float32 tensor.

    Returns
    -------
    torch.Tensor
        1-D float32 tensor with one value per row of ``x``.
    """
    dist = norm(x - x0, dim=1)
    a = 1.0 - (dist ** 2 / r ** 2)
    inside = a > 0.0
    bump = torch.zeros(a.shape, device=device, dtype=torch.float32)
    # exp is evaluated only inside the support: 1/a diverges as a -> 0.
    e = torch.exp(1.0 - 1.0 / a[inside]).to(device, dtype=torch.float32)
    bump[inside] = ampl * e
    return bump

def f_gpu(x,x0, radii, amplts, device):
    """Sum ``len(radii)`` bumps ``b_gpu(x, x0[i], radii[i], amplts[i], device)``.

    Parameters
    ----------
    x : torch.Tensor
        Points passed unchanged to ``b_gpu``.
    x0 : indexable
        Bump centres; ``x0[i]`` is the centre of bump ``i``.
    radii, amplts : indexable
        Per-bump radius and amplitude; ``len(radii)`` sets the bump count.
    device : torch.device
        Forwarded to ``b_gpu``.

    Returns
    -------
    torch.Tensor or float
        The summed bumps, or the float ``0.`` when ``radii`` is empty.
    """
    total = 0.
    for i in range(len(radii)):
        total += b_gpu(x, x0[i], radii[i], amplts[i], device)
    return total


def Lambda(x, hps):
    """Return ``RBF_func_lambda(x, hps)`` unchanged (thin naming wrapper)."""
    return RBF_func_lambda(x, hps)

def gamma(x, hps):
    """Return ``RBF_func_gamma(x, hps)`` unchanged (thin naming wrapper)."""
    return RBF_func_gamma(x, hps)

def Q(x1,x2,outer_sum, device):
    """Pairwise quadratic form ``d_ij^T ((outer_sum[i, j]) / 2)^{-1} d_ij``.

    ``d_ij`` is the difference of the first two coordinates of ``x1[i]`` and
    ``x2[j]``, stored as float32 on ``device``.

    Parameters
    ----------
    x1, x2 : torch.Tensor
        Point sets; only columns 0 and 1 are used.
    outer_sum : torch.Tensor
        Shape ``(len(x1), len(x2), 2, 2)``; halved and inverted per entry.
    device : torch.device
        Device of the difference tensor.

    Returns
    -------
    torch.Tensor
        Shape ``(len(x1), len(x2))``.
    """
    inverse_half = inv(outer_sum / 2.0)
    diffs = torch.zeros((len(x1), len(x2), 2), device=device, dtype=torch.float32)
    for axis in range(2):
        diffs[:, :, axis] = x1[:, axis].unsqueeze(1) - x2[:, axis]
    return torch.einsum('abk,abkl,abl->ab', diffs, inverse_half, diffs)

def S(x, gamma_hps, lambda_hps1, lambda_hps2, device):
    """Per-point 2x2 matrix ``G @ L @ G^T``.

    ``G`` is the rotation matrix from ``G(x, gamma_hps, device)`` and ``L`` is
    the diagonal matrix from ``L(x, lambda_hps1, lambda_hps2, device)``; the
    products are batched over points.
    """
    rotation = G(x, gamma_hps, device)
    scales = L(x, lambda_hps1, lambda_hps2, device)
    return rotation @ scales @ rotation.transpose(1, 2)

def L(x, hps1, hps2, device):
    """Per-point diagonal matrix ``diag(Lambda(x, hps1), Lambda(x, hps2))``.

    Returns a float32 tensor of shape ``(len(x), d, d)`` with ``d = len(x[0])``,
    moved to ``device``. Only entries ``[:, 0, 0]`` and ``[:, 1, 1]`` are set.
    """
    n_dims = len(x[0])
    diag = torch.zeros((len(x), n_dims, n_dims))
    diag[:, 0, 0] = Lambda(x, hps1)
    diag[:, 1, 1] = Lambda(x, hps2)
    return diag.to(device, dtype=torch.float32)


def G(x, hps, device):
    """Per-point 2x2 rotation matrix by the angle ``gamma(x, hps)``.

    Returns a float32 tensor of shape ``(len(x), d, d)`` with ``d = len(x[0])``,
    moved to ``device``. The top-left 2x2 block holds
    ``[[cos, -sin], [sin, cos]]``.
    """
    # Evaluate the angle once and reuse it for all four entries.
    theta = gamma(x, hps)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    res = torch.zeros((len(x), len(x[0]), len(x[0])))
    res[:, 0, 0] = cos_t
    res[:, 0, 1] = -sin_t
    res[:, 1, 0] = sin_t
    res[:, 1, 1] = cos_t
    return res.to(device, dtype=torch.float32)



def outer_sum_S(S1,S2):
    """Pairwise sums of two stacks of matrices.

    Parameters
    ----------
    S1 : torch.Tensor
        Shape ``(n1, *m)``; in this file ``m = (2, 2)``.
    S2 : torch.Tensor
        Shape ``(n2, *m)``.

    Returns
    -------
    torch.Tensor
        Shape ``(n1, n2, *m)`` with ``res[i, j] = S1[i] + S2[j]``. Empty
        operands give an empty result instead of an error.
    """
    # Broadcasting produces the result directly, without first materialising
    # two stacked (n1, n2, *m) copies, so no temporaries need collecting.
    return S1.unsqueeze(1) + S2.unsqueeze(0)


def outer(S1,S2):
    """Broadcast element-wise product ``res[a, b] = S1[a] * S2[b]``.

    With ``S1`` of shape ``(n1, j, k)`` and ``S2`` of shape ``(n2, j, k)`` the
    result has shape ``(n1, n2, j, k)``; no index is summed.
    """
    subscripts = 'ajk,bjk->abjk'
    return torch.einsum(subscripts, S1, S2)



def matern_kernel_diffGPU(distance):
    """Matern-3/2 kernel ``(1 + sqrt(3) r / l) * exp(-sqrt(3) r / l)`` with ``l = 1``."""
    length = 1.
    scaled = (torch.sqrt(torch.tensor(3.0)) * distance) / length
    return (1.0 + scaled) * torch.exp(-scaled)


def exp_kernel_diffGPU(distance):
    """Exponential kernel ``exp(-distance)``, applied element-wise."""
    return torch.exp(-distance)

def wendland_q(d):
    """Wendland kernel ``(1 - d)^8 (35 d^3 + 25 d^2 + 8 d + 1)`` with ``d`` capped at 1.

    Values of ``d`` above 1 are treated as 1, which makes the kernel exactly
    zero there. The input is left unmodified.

    Parameters
    ----------
    d : torch.Tensor or numpy.ndarray
        Non-negative distances.

    Returns
    -------
    Same type as ``d``
        Kernel values.
    """
    # clip() returns a new tensor/array, so the caller's ``d`` is not modified.
    d = d.clip(max=1.0)
    return (1. - d) ** 8 * (35. * d ** 3 + 25. * d ** 2 + 8. * d + 1.)



def core_kernel(x1,x2,x1_dev,x2_dev,hps, cuda_device):
    """Nonstationary anisotropic core kernel.

    Inputs to the formula, for point ``i`` of ``x1`` and point ``j`` of ``x2``:

    - ``S1[i]``, ``S2[j]``: 2x2 matrices from ``S``;
    - ``s1[i]``, ``s2[j]``: signal variances from ``RBF_func_sigma``;
    - ``Q_ij``: the quadratic form from ``Q``.

    The kernel is::

        k[i, j] = s1[i] * s2[j]
                  * |S1[i]|^(1/4) * |S2[j]|^(1/4) / sqrt(|(S1[i] + S2[j]) / 2|)
                  * wendland_q(sqrt(Q_ij))

    Hyperparameter layout (slices of ``hps``):

    - ``hps[0:2]``: parameters of the first diagonal entry of ``L``;
    - ``hps[2:4]``: parameters of the second diagonal entry of ``L``;
    - ``hps[4:5]``: rotation-angle parameter for ``G``;
    - ``hps[5:7]``: signal-variance parameters.

    Parameters
    ----------
    x1, x2 : torch.Tensor
        Point sets used for the per-point lookups (CPU tensors when called
        from ``kernel``).
    x1_dev, x2_dev : torch.Tensor
        The same points as float32 on ``cuda_device``; used by ``Q``.
    hps : torch.Tensor
        Hyperparameter vector (CPU tensor when called from ``kernel``).
    cuda_device : torch.device
        Device on which the kernel matrix is assembled.

    Returns
    -------
    torch.Tensor
        Shape ``(len(x1), len(x2))``.
    """
    lamda1_hps = hps[0:2]
    lamda2_hps = hps[2:4]
    gamma_hps = hps[4:5]
    hps_signal = hps[5:7]

    S1 = S(x1, gamma_hps, lamda1_hps, lamda2_hps, cuda_device)
    S2 = S(x2, gamma_hps, lamda1_hps, lamda2_hps, cuda_device)

    signal_variance1 = RBF_func_sigma(x1, hps_signal).to(cuda_device, dtype=torch.float32)
    signal_variance2 = RBF_func_sigma(x2, hps_signal).to(cuda_device, dtype=torch.float32)

    A = torch.outer(signal_variance1, signal_variance2)
    outer_sum = outer_sum_S(S1, S2)

    det1 = detto(S1) ** (1 / 4)
    det2 = detto(S2) ** (1 / 4)
    det3 = detto(outer_sum / 2.0)
    B = torch.outer(det1, det2) / torch.sqrt(det3)
    C = wendland_q(torch.sqrt(Q(x1_dev, x2_dev, outer_sum, cuda_device)))

    k = A * B * C
    # Release every intermediate -- including outer_sum, the largest one at
    # (n1, n2, 2, 2) -- before handing cached GPU memory back.
    del S1, S2, A, B, C, outer_sum, det1, det2, det3
    del signal_variance1, signal_variance2
    gc.collect()
    torch.cuda.empty_cache()
    return k


