import numpy as np
import time
import torch
# Run a tiny matrix inverse on cuda:0 at import time.
# NOTE(review): presumably a warm-up that initialises CUDA / torch's linear-algebra
# backend before the kernels below run; the identical call is repeated after the
# torch.linalg imports -- confirm whether both are needed. Importing this module
# therefore requires a CUDA device.
x=torch.inverse(torch.ones((1,1), device="cuda:0"))
from torch.linalg import det as detto
from torch.linalg import norm
from torch.linalg import inv
x=torch.inverse(torch.ones((1,1), device="cuda:0"))
from scipy.interpolate import griddata
import scipy.sparse as sparse
from fvgp import GP
from scipy.interpolate import Rbf
# Regular n_points x n_points grid on [0, 1]^2.
# NOTE(review): `grid` is not referenced by any function in this file section.
n_points = 4
x_grid = np.linspace(0, 1, n_points)
y_grid = np.linspace(0, 1, n_points)
from itertools import product
grid = np.asarray(list(product(x_grid, y_grid)))
from scipy.interpolate import griddata
import gc


# Space-delimited training/test data, read at import time from ./data relative to
# the current working directory. x_train / y_train are used as module-level
# globals by the RBF_func_* griddata interpolators below; x_test / y_test are not
# used in this file section.
x_train = np.genfromtxt("./data/x_train_2dtopo.csv", delimiter=" ")
x_test = np.genfromtxt("./data/x_test_2dtopo.csv", delimiter=" ")
y_train = np.genfromtxt("./data/y_train_2dtopo.csv", delimiter=" ")
y_test = np.genfromtxt("./data/y_test_2dtopo.csv", delimiter=" ")



def RBF_func_lambda(x, hps):
    """Eigenvalue field 1 / (hps[1] * f(x) + hps[0]).

    f(x) is the training response (x_train, y_train) interpolated at ``x`` with
    scipy's griddata, plus 1e-5. Points outside the convex hull use
    fill_value 0.0 in the linear branch.

    NOTE(review): the nearest-neighbour branch is chosen whenever
    len(x) == len(x_train). This assumes such an ``x`` *is* the training set;
    any other point set of that size also takes this branch. griddata ignores
    fill_value for method="nearest".

    Parameters
    ----------
    x : array-like of points accepted by griddata.
    hps : indexable of CPU torch scalars; hps[0] is the offset, hps[1] the scale.

    Returns
    -------
    torch.Tensor with one value per row of ``x``.
    """
    method = "nearest" if len(x) == len(x_train) else "linear"
    field = griddata(x_train, y_train, x, method=method, fill_value=0.0) + 0.00001
    scale = hps[1].numpy()
    offset = hps[0].numpy()
    return torch.from_numpy(1. / ((scale * field) + offset))


def RBF_func_gamma(x, hps):
    """Constant angle field: hps[0] repeated once per point of ``x``.

    Returns a 1-D torch tensor of length len(x). It is built by adding hps[0]
    to a float64 zeros array, so the dtype follows numpy promotion with float64.
    """
    angle = hps[0].numpy()
    return torch.from_numpy(angle + np.zeros((len(x))))



def RBF_func_sigma(x, hps):
    """Signal-variance field hps[0] + hps[1] * f(x).

    f(x) is the training response (x_train, y_train) linearly interpolated at
    ``x`` with scipy's griddata. Points outside the convex hull of x_train get
    1e-5.

    Parameters
    ----------
    x : array-like of points accepted by griddata.
    hps : indexable of CPU torch scalars; hps[0] is the base value, hps[1] the slope.

    Returns
    -------
    torch.Tensor with one value per row of ``x``.
    """
    field = griddata(x_train, y_train, x, method="linear", fill_value=0.00001)
    base = hps[0].numpy()
    slope = hps[1].numpy()
    return torch.from_numpy(base + (slope * field))



def kernel(x1,x2,hps):
    """Evaluate core_kernel for two point sets and return a numpy matrix.

    Parameters
    ----------
    x1, x2 : numpy.ndarray
        Point sets. Each is copied to cuda:0 as float32 (used for the distance
        part of core_kernel) and also wrapped as a CPU tensor (used by the
        griddata-based hyperparameter fields).
    hps : numpy.ndarray
        Hyperparameters; core_kernel reads hps[0:7].

    Returns
    -------
    numpy.ndarray of shape (len(x1), len(x2)).

    Prints the elapsed wall time of each call.
    """
    t1_start = time.perf_counter()
    cuda_device = torch.device("cuda:0")
    x1_dev = torch.from_numpy(x1).to(cuda_device, dtype=torch.float32)
    x2_dev = torch.from_numpy(x2).to(cuda_device, dtype=torch.float32)  # could be a copy of x1 for training
    x1 = torch.from_numpy(x1)
    x2 = torch.from_numpy(x2)
    # core_kernel reads hps via .numpy(), so only a CPU view is needed; the
    # previous GPU copy of hps was never used.
    hps = torch.from_numpy(hps)

    k = core_kernel(x1, x2, x1_dev, x2_dev, hps, cuda_device)

    t1_stop = time.perf_counter()
    torch.cuda.empty_cache()
    print("Elapsed kernel time:", t1_stop - t1_start, flush=True)
    return k.cpu().numpy()


def kernel_bump(x1,x2,hps):
    """Evaluate core_kernel plus a weighted bump/Wendland kernel as a numpy matrix.

    Parameters
    ----------
    x1, x2 : numpy.ndarray
        Point sets. Each is copied to cuda:0 as float32 and also wrapped as a
        CPU tensor for the griddata-based hyperparameter fields.
    hps : numpy.ndarray
        Hyperparameters (at least 410 entries are read):
        hps[0:7]   -> core_kernel,
        hps[7:409] -> kernel_sdisc4x25 (sent to the GPU as float32),
        hps[409]   -> weight of the kernel_sdisc4x25 term.

    Returns
    -------
    numpy.ndarray of shape (len(x1), len(x2)):
    core_kernel + hps[409] * kernel_sdisc4x25, or the all-zero
    kernel_sdisc4x25 block when that block is identically zero.
    """
    t1_start = time.perf_counter()
    cuda_device = torch.device("cuda:0")
    x1_dev = torch.from_numpy(x1).to(cuda_device, dtype=torch.float32)
    x2_dev = torch.from_numpy(x2).to(cuda_device, dtype=torch.float32)  # could be a copy of x1 for training
    x1 = torch.from_numpy(x1)
    x2 = torch.from_numpy(x2)

    hps_dev = torch.from_numpy(hps).to(cuda_device, dtype=torch.float32)
    hps = torch.from_numpy(hps)

    k1 = kernel_sdisc4x25(x1_dev, x2_dev, hps_dev[7:409], cuda_device)

    # NOTE(review): when the bump kernel k1 is zero everywhere, this returns
    # zeros for the whole block and never evaluates core_kernel. core_kernel
    # uses a different (Q-based) distance, so it is not obviously zero here;
    # confirm this early exit is an intended sparsity shortcut.
    if torch.all(k1 == 0.):
        k = k1.cpu().numpy()
        t1_stop = time.perf_counter()
        print("ready to return early after ", t1_stop - t1_start, flush=True)
        return k

    k2 = core_kernel(x1, x2, x1_dev, x2_dev, hps[0:7], cuda_device)
    kk = k2 + hps[409] * k1

    torch.cuda.empty_cache()
    return kk.cpu().numpy()


def kernel_sdisc4x25(x1,x2, hps, device):
    """Sum of 4 bump-group outer products plus an anisotropic Wendland term.

    Layout of ``hps`` (indices relative to this slice):
        [0:100]   first coordinate of the bump centres
        [100:200] second coordinate of the bump centres
        [200:300] bump radii
        [300:400] bump amplitudes
        [400:402] Wendland length scales
    Group g (0..3) uses the 25 entries [25*g, 25*g + 25) of each of the first
    four ranges. For each group, the bump sums at x1 and at x2 are combined
    with torch.outer. The groups are added in order 0, 1, 2, 3.
    """
    total = None
    for group in range(4):
        lo = 25 * group
        hi = lo + 25
        centres = torch.column_stack([hps[lo:hi], hps[100 + lo:100 + hi]])
        radii = hps[200 + lo:200 + hi]
        amplitudes = hps[300 + lo:300 + hi]
        term = torch.outer(f_gpu(x1, centres, radii, amplitudes, device),
                           f_gpu(x2, centres, radii, amplitudes, device))
        total = term if total is None else total + term
    return total + wendland_anisotropic_gp2Scale_gpu(x1, x2, hps[400:402], device)

def wendland_anisotropic_gp2Scale_gpu(x1, x2, hps, cuda_device):  # pragma: no cover
    """Wendland-type compactly supported kernel on an anisotropically scaled distance.

    The distance is computed by _get_an_distance_matrix_gpu with per-axis
    length scales hps[0], hps[1]. Distances above 1 are clamped to 1, where the
    polynomial below is 0, so the kernel vanishes outside unit scaled distance.

    NOTE(review): the standard Wendland C^6 function psi_{3,3} uses 32*d**3.
    With 35 the d**3 Taylor coefficient at d=0 is 3 instead of 0. Confirm 35 is
    intentional.
    """
    dist = _get_an_distance_matrix_gpu(x1, x2, cuda_device, hps)
    dist = dist.masked_fill(dist > 1., 1.)
    complement = 1. - dist
    return complement ** 8 * (35. * dist ** 3 + 25. * dist ** 2 + 8. * dist + 1.)


def _get_distance_matrix_gpu(x1, x2, device, hps):  # pragma: no cover
    d = torch.zeros((len(x1), len(x2))).to(device, dtype=torch.float32)
    for i in range(2):
        d += ((x1[:, i].reshape(-1, 1) - x2[:, i]) / hps[i]) ** 2
    return torch.sqrt(d)

def _get_an_distance_matrix_gpu(x1, x2, device, hps):  # pragma: no cover
    d = torch.zeros((len(x1), len(x2))).to(device, dtype=torch.float32)
    for i in range(2):
        d += ((x1[:, i].reshape(-1, 1) - x2[:, i]) / hps[i]) ** 2
    return torch.sqrt(d)



def b_gpu(x,x0, r, ampl, device):
    """Evaluate one smooth compactly supported bump at every row of ``x``.

    bump(x) = ampl * exp(1 - 1 / a), with a = 1 - |x - x0|**2 / r**2, where
    a > 0 (inside the open ball of radius ``r`` around ``x0``), and 0 elsewhere.
    At the centre, a = 1 and the bump equals ``ampl``.

    Parameters
    ----------
    x : torch.Tensor, 2-D (points as rows).
    x0 : bump centre, broadcastable against the rows of ``x``.
    r : bump radius.
    ampl : bump amplitude.
    device : target device of the float32 result.

    Returns
    -------
    1-D float32 tensor with one value per row of ``x``.
    """
    d = norm(x - x0, dim=1)
    a = 1.0 - (d**2 / r**2)
    inside = a > 0.0
    bump = torch.zeros(a.shape).to(device, dtype=torch.float32)
    e = torch.exp((-1.0 / a[inside]) + 1.).to(device, dtype=torch.float32)
    bump[inside] = ampl * e
    return bump

def f_gpu(x,x0, radii, amplts, device):
    """Sum of bumps b_gpu(x, x0[i], radii[i], amplts[i], device) over i.

    ``x0`` holds one centre per row; ``radii`` and ``amplts`` hold one value per
    bump. Returns the Python float 0. when ``radii`` is empty; otherwise a
    tensor with one value per row of ``x``.
    """
    total = 0.
    for i, radius in enumerate(radii):
        total += b_gpu(x, x0[i], radius, amplts[i], device)
    return total


def Lambda(x, hps):
    """Eigenvalue field at ``x``; delegates to RBF_func_lambda(x, hps)."""
    return RBF_func_lambda(x, hps)

def gamma(x, hps):
    """Rotation-angle field at ``x``; delegates to RBF_func_gamma(x, hps)."""
    return RBF_func_gamma(x, hps)

def Q(x1,x2,outer_sum, device):
    """Pairwise quadratic form (x1[i] - x2[j])^T M_ij (x1[i] - x2[j]).

    M_ij = inv(outer_sum[i, j] / 2). Only the first two coordinates of the
    points are used.

    Parameters
    ----------
    x1, x2 : torch.Tensor with at least 2 columns.
    outer_sum : torch.Tensor of shape (len(x1), len(x2), 2, 2), float32.
    device : device for the float32 difference tensor.

    Returns
    -------
    torch.Tensor of shape (len(x1), len(x2)).
    """
    avg_inv = inv(outer_sum / 2.0)
    deltas = torch.stack(
        [x1[:, c].reshape(-1, 1) - x2[:, c] for c in range(2)], dim=-1
    ).to(device, dtype=torch.float32)
    return torch.einsum('ijk,ijkl,ijl->ij', deltas, avg_inv, deltas)

def S(x, gamma_hps, lambda_hps1, lambda_hps2, device):
    """Per-point matrices G(x) @ L(x) @ G(x)^T.

    G is the rotation built from gamma(x, gamma_hps); L is the diagonal built
    from Lambda(x, lambda_hps1) and Lambda(x, lambda_hps2).
    Returns a float32 tensor of shape (len(x), d, d) on ``device``.
    """
    rotation = G(x, gamma_hps, device)
    eigenvalues = L(x, lambda_hps1, lambda_hps2, device)
    return rotation @ eigenvalues @ rotation.transpose(1, 2)

def L(x, hps1, hps2, device):
    """Per-point diagonal matrices diag(Lambda(x, hps1), Lambda(x, hps2)).

    The result has shape (len(x), d, d) with d = len(x[0]). Only entries [0, 0]
    and [1, 1] are filled; the rest stay zero. Returned as float32 on ``device``.
    """
    n_points = len(x)
    n_dims = len(x[0])
    diag = torch.zeros((n_points, n_dims, n_dims))
    diag[:, 0, 0] = Lambda(x, hps1)
    diag[:, 1, 1] = Lambda(x, hps2)
    return diag.to(device, dtype=torch.float32)


def G(x, hps, device):
    """Per-point 2-D rotation matrices [[cos t, -sin t], [sin t, cos t]], t = gamma(x, hps).

    The result has shape (len(x), d, d) with d = len(x[0]). Only the top-left
    2x2 block is filled. Returned as float32 on ``device``.
    """
    # Evaluate the angle field once instead of once per matrix entry.
    angle = gamma(x, hps)
    cos_t = torch.cos(angle)
    sin_t = torch.sin(angle)
    res = torch.zeros((len(x), len(x[0]), len(x[0])))
    res[:, 0, 0] = cos_t
    res[:, 0, 1] = -sin_t
    res[:, 1, 0] = sin_t
    res[:, 1, 1] = cos_t
    return res.to(device, dtype=torch.float32)



def outer_sum_S(S1,S2):
    """Return all pairwise sums res[i, j] = S1[i] + S2[j].

    Parameters
    ----------
    S1 : torch.Tensor of shape (n1, 2, 2)
    S2 : torch.Tensor of shape (n2, 2, 2)

    Returns
    -------
    torch.Tensor of shape (n1, n2, 2, 2). Empty inputs give an empty result.
    """
    # Broadcasting gives the same values as stacking n2 copies of S1 and n1
    # copies of S2, but never materialises those two n1*n2 temporaries. That
    # also removes the need for gc.collect()/empty_cache() here.
    return S1.unsqueeze(1) + S2.unsqueeze(0)


def outer(S1,S2):
    """Element-wise products of every matrix pair: res[i, l] = S1[i] * S2[l].

    S1 has shape (n1, a, b) and S2 has shape (n2, a, b); the result has shape
    (n1, n2, a, b).
    """
    left = S1[:, None, :, :]
    right = S2[None, :, :, :]
    return left * right



def matern_kernel_diffGPU(distance, length=1.):
    """Matérn nu=3/2 correlation (1 + sqrt(3) d / l) * exp(-sqrt(3) d / l).

    Parameters
    ----------
    distance : torch.Tensor of distances d.
    length : length scale l (default 1.0, the previously hard-coded value).

    Returns
    -------
    torch.Tensor of the same shape as ``distance``; equals 1 at d = 0.
    """
    sqrt3 = torch.sqrt(torch.tensor(3.0))
    scaled = (sqrt3 * distance) / length
    return (1.0 + scaled) * torch.exp(-scaled)


def exp_kernel_diffGPU(distance):
    """Exponential correlation exp(-d), applied element-wise."""
    negated = torch.neg(distance)
    return negated.exp()

def wendland_q(d):
    """Wendland-type compactly supported correlation of a distance tensor.

    Computes (1 - d)**8 * (35 d**3 + 25 d**2 + 8 d + 1) after clamping d to at
    most 1. The value is 1 at d = 0 and 0 for d >= 1. The caller's tensor is
    not modified.

    NOTE(review): the standard Wendland C^6 function psi_{3,3} uses 32*d**3.
    With 35 the d**3 Taylor coefficient at d=0 is 3 instead of 0. Confirm 35 is
    intentional.
    """
    # Out-of-place clamp: the previous `d[d > 1.0] = 1.0` overwrote the caller's tensor.
    d = d.masked_fill(d > 1.0, 1.0)
    return (1.0 - d) ** 8 * (35. * d ** 3 + 25. * d ** 2 + 8. * d + 1.)


def _nonstationary_kernel(x1, x2, x1_dev, x2_dev, hps, cuda_device, correlation):
    """Shared body of core_kernel and core_kernel_bumps.

    With per-point 2x2 matrices S(x) = G diag(Lambda1, Lambda2) G^T, returns

        sigma(x1_i) * sigma(x2_j)
        * |S1_i|^(1/4) * |S2_j|^(1/4) / sqrt(|(S1_i + S2_j) / 2|)
        * correlation(sqrt(Q_ij)),

    where Q_ij is the quadratic form of x1_i - x2_j with inv((S1_i + S2_j) / 2).
    This has the structure of the Paciorek-Schervish nonstationary kernel.

    Parameters
    ----------
    x1, x2 : torch.Tensor
        CPU point sets, used by the griddata-based hyperparameter fields.
    x1_dev, x2_dev : torch.Tensor
        The same points on ``cuda_device``, used for Q.
    hps : torch.Tensor
        CPU tensor; entries 0..6 are read:
        hps[0:2] -> first eigenvalue field (via Lambda / RBF_func_lambda),
        hps[2:4] -> second eigenvalue field,
        hps[4:5] -> rotation angle (via gamma / RBF_func_gamma),
        hps[5:7] -> signal-variance field (RBF_func_sigma).
    cuda_device : device for all intermediate float32 tensors.
    correlation : callable mapping a distance tensor to correlation values.

    Returns
    -------
    torch.Tensor of shape (len(x1), len(x2)) on ``cuda_device``.
    """
    lambda1_hps = hps[0:2]
    lambda2_hps = hps[2:4]
    gamma_hps = hps[4:5]
    hps_signal = hps[5:7]

    S1 = S(x1, gamma_hps, lambda1_hps, lambda2_hps, cuda_device)
    S2 = S(x2, gamma_hps, lambda1_hps, lambda2_hps, cuda_device)

    signal_variance1 = RBF_func_sigma(x1, hps_signal).to(cuda_device, dtype=torch.float32)
    signal_variance2 = RBF_func_sigma(x2, hps_signal).to(cuda_device, dtype=torch.float32)
    A = torch.outer(signal_variance1, signal_variance2)

    outer_sum = outer_sum_S(S1, S2)
    det1 = detto(S1) ** (1 / 4)
    det2 = detto(S2) ** (1 / 4)
    det3 = detto(outer_sum / 2.0)
    B = torch.outer(det1, det2) / torch.sqrt(det3)

    C = correlation(torch.sqrt(Q(x1_dev, x2_dev, outer_sum, cuda_device)))
    return A * B * C


def core_kernel(x1,x2,x1_dev,x2_dev,hps, cuda_device):
    """Nonstationary kernel with the Wendland-type correlation wendland_q.

    See _nonstationary_kernel for the formula and the hps layout (hps[0:7]).
    After evaluation, runs gc.collect() and torch.cuda.empty_cache(). All
    intermediates, including the (len(x1), len(x2), 2, 2) outer_sum, are freed
    when the helper returns, so the cache release covers them.
    """
    k = _nonstationary_kernel(x1, x2, x1_dev, x2_dev, hps, cuda_device, wendland_q)
    gc.collect()
    torch.cuda.empty_cache()
    return k


def core_kernel_bumps(x1,x2,x1_dev,x2_dev,hps, cuda_device):
    """Nonstationary kernel with the Matérn nu=3/2 correlation matern_kernel_diffGPU.

    See _nonstationary_kernel for the formula and the hps layout (hps[0:7]).
    """
    return _nonstationary_kernel(x1, x2, x1_dev, x2_dev, hps, cuda_device,
                                 matern_kernel_diffGPU)

#
