import numpy as np
import time
import torch
from scipy.interpolate import griddata
import scipy.sparse as sparse



# Station covariate tables, loaded once at import time. The paths are
# relative, so they resolve against the current working directory.
# core_kernel() uses columns 0:2 of each table as lookup coordinates and
# column 2 as the looked-up value (presumably elevation and distance to the
# coast, going by the names -- confirm against the data files).
train_station_elevation = np.load("train_station_elevation.npy")
train_station_dist2coast = np.load("train_station_dist2coast.npy")

def kernel(x1, x2, hps):
    """Evaluate the covariance between ``x1`` and ``x2`` on the GPU.

    The three NumPy inputs are copied to ``cuda:0`` as float32 tensors, the
    first 13 hyperparameters are handed to ``core_kernel`` and the result is
    copied back to the host.

    Parameters
    ----------
    x1, x2 : numpy.ndarray
        Input points; forwarded both as NumPy arrays and as device tensors.
    hps : numpy.ndarray
        Hyperparameter vector. Only ``hps[0:13]`` is used here.

    Returns
    -------
    numpy.ndarray
        The kernel values returned by ``core_kernel``, moved to host memory.

    Side effects
    ------------
    Prints the elapsed wall-clock time of the call to stdout.
    """
    tic = time.perf_counter()
    gpu = torch.device("cuda:0")

    def _to_gpu(array):
        # Host NumPy array -> float32 tensor on the selected device.
        return torch.from_numpy(array).to(gpu, dtype=torch.float32)

    x1_gpu = _to_gpu(x1)
    x2_gpu = _to_gpu(x2)
    hps_gpu = _to_gpu(hps)

    # Entries from index 13 onwards are not forwarded to core_kernel.
    cov_gpu = core_kernel(x1, x2, x1_gpu, x2_gpu, hps_gpu[0:13], gpu)
    cov = cov_gpu.cpu().numpy()

    toc = time.perf_counter()
    print("Elapsed kernel time:", toc - tic, flush = True)
    return cov


def wendland_anisotropic_gp2Scale_gpu(x1, x2, hps, cuda_device):  # pragma: no cover
    """Compactly supported anisotropic Wendland kernel.

    The scaled distance comes from ``_get_distance_matrix_gpu`` and is capped
    at 1, where the polynomial below evaluates to exactly zero. ``hps[0]``
    multiplies the result.

    Returns a tensor with one row per point in ``x1`` and one column per
    point in ``x2``.
    """
    scaled = _get_distance_matrix_gpu(x1, x2, cuda_device, hps)
    # Everything beyond the support radius is treated as sitting on it.
    scaled[scaled > 1.] = 1.
    taper = (1. - scaled) ** 8
    poly = 35. * scaled ** 3 + 25. * scaled ** 2 + 8. * scaled + 1.
    return hps[0] * taper * poly


def _get_distance_matrix_gpu(x1, x2, device, hps):  # pragma: no cover
    d = torch.zeros((len(x1), len(x2))).to(device, dtype=torch.float32)
    for i in range(3):
        d += ((x1[:, i].reshape(-1, 1) - x2[:, i]) / hps[1 + i]) ** 2
    return torch.sqrt(d)

def _get_time_distance_gpu(x1, x2):  # pragma: no cover
    d = x1[:, 2].reshape(-1, 1) - x2[:, 2]
    return torch.sqrt(d)


def wendland_time_gpu(x1, x2, r):  # pragma: no cover
    """Wendland taper applied to the output of ``_get_time_distance_gpu``.

    ``r`` is the support radius: distances larger than ``r`` are set to
    ``r``, where the taper is zero.
    """
    lag = _get_time_distance_gpu(x1, x2)
    lag[lag > r] = r
    # Work with the normalised distance in [0, 1] from here on.
    s = lag / r
    poly = 35. * s ** 3 + 25. * s ** 2 + 8. * s + 1.
    return (1. - s) ** 8 * poly


def b_gpu(x,x0, r, ampl, device):
    """Evaluate a smooth bump function centred at ``x0`` at every row of ``x``.

    With ``a = 1 - |x - x0|^2 / r^2`` the value is
    ``ampl * exp(1 - 1/a)`` where ``a > 0`` and 0 elsewhere, so the peak
    value at the centre is ``ampl`` and the function vanishes for
    ``|x - x0| >= r``.

    Parameters
    ----------
    x : tensor whose rows are points (the norm is taken over dim 1).
    x0 : centre, broadcast against ``x``.
    r : support radius.
    ampl : peak amplitude.
    device : device on which the float32 result is allocated.

    Returns
    -------
    float32 tensor with one value per row of ``x``.
    """
    dist = torch.linalg.norm(x - x0, dim=1)
    support = 1.0 - (dist**2/r**2)
    inside = torch.where(support > 0.0)
    bump = torch.zeros(support.shape).to(device, dtype = torch.float32)
    # Only evaluate the exponential inside the support; 1/a blows up at the edge.
    bump[inside] = ampl * torch.exp((-1.0/support[inside])+1.).to(device, dtype = torch.float32)
    return bump

def f_gpu(x,x0, radii, amplts, device):
    """Sum of bump functions ``b_gpu`` evaluated at the rows of ``x``.

    Bump ``i`` is centred at ``x0[i]`` with radius ``radii[i]`` and amplitude
    ``amplts[i]``; the number of bumps is ``len(radii)``.

    Returns a float32 tensor on ``device`` with one value per row of ``x``.
    The accumulator starts as a zero tensor, so an empty ``radii`` yields
    zeros of the right length instead of a bare Python float.
    """
    total = torch.zeros(len(x)).to(device, dtype = torch.float32)
    for i, radius in enumerate(radii):
        total += b_gpu(x, x0[i], radius, amplts[i], device)
    return total

def lin(x,w, elev, dist):
    """Affine function of the two covariates: ``w[0] + w[1]*elev + w[2]*dist``.

    ``x`` is accepted for signature compatibility but is not used.
    """
    intercept = w[0]
    elev_term = w[1] * elev
    dist_term = w[2] * dist
    return intercept + elev_term + dist_term

def Lambda(x,wl, elev, dist):
    """Exponential of the affine covariate function ``lin`` with weights ``wl``.

    The exponential keeps the result strictly positive.
    """
    return torch.exp(lin(x, wl, elev, dist))

def gamma(x,wg, elev, dist):
    """Map the affine covariate function ``lin`` into the interval (0, pi/2).

    Computes ``(pi/2) * exp(z) / (1 + exp(z))`` with ``z = lin(x, wg, elev, dist)``.
    It is written with ``torch.sigmoid`` because forming ``exp(z)`` explicitly
    overflows for large ``z`` and turns the ratio into inf/inf = NaN.
    """
    z = lin(x, wg, elev, dist)
    return (torch.pi / 2.) * torch.sigmoid(z)

def Q(x1,x2,outer_sum, wl3, device):
    """Pairwise quadratic form used as the squared distance of the kernel.

    For every pair ``(i, j)`` this evaluates
    ``d^T inv(outer_sum[i, j] / 2) d + (x1[i, 2] - x2[j, 2])**2 / wl3``
    where ``d`` is the difference of columns 0 and 1 of the two points.

    Parameters
    ----------
    x1, x2 : tensors; columns 0-1 enter the matrix part, column 2 the
        scalar part.
    outer_sum : tensor indexed ``[i, j, :, :]`` holding one 2x2 matrix per pair.
    wl3 : divisor of the squared column-2 difference.
    device : device for the float32 difference buffer.

    Returns
    -------
    Tensor of shape ``(len(x1), len(x2))``.
    """
    inv_half_sum = torch.linalg.inv(outer_sum/2.0)

    pair_diff = torch.zeros((len(x1),len(x2),2)).to(device, dtype=torch.float32)
    for axis in (0, 1):
        pair_diff[:, :, axis] = x1[:, axis].reshape(-1, 1) - x2[:, axis]

    sq_third = (x1[:, 2].reshape(-1, 1) - x2[:, 2]) ** 2
    matrix_part = torch.einsum('ijk,ijkl,ijl->ij', pair_diff, inv_half_sum, pair_diff)
    return matrix_part + sq_third/wl3

def S(x,wg,wl1,wl2, elev, dist, device):
    """Per-point matrix ``G @ L @ G^T`` built from the siblings ``G`` and ``L``.

    Returns a tensor with one square matrix per row of ``x``.
    """
    rotation = G(x, wg, elev, dist, device)
    scales = L(x, wl1, wl2, elev, dist, device)
    # Swap the two trailing axes to transpose every matrix in the batch.
    rotation_t = rotation.permute(0, 2, 1)
    return rotation @ scales @ rotation_t

def L(x,wl1,wl2, elev, dist, device):
    """Per-point diagonal matrix with ``Lambda`` values on the diagonal.

    Entry ``[:, 0, 0]`` uses weights ``wl1`` and ``[:, 1, 1]`` uses ``wl2``;
    every other entry stays zero. The result is float32 on ``device`` with
    shape ``(len(x), len(x[0]), len(x[0]))``.
    """
    n_points = len(x)
    n_dims = len(x[0])
    diag = torch.zeros((n_points, n_dims, n_dims)).to(device, dtype=torch.float32)
    for axis, weights in enumerate((wl1, wl2)):
        diag[:, axis, axis] = Lambda(x, weights, elev, dist)
    return diag

def G(x,wg, elev, dist, device):
    """Per-point 2-D rotation matrix by the angle ``gamma(x, wg, elev, dist)``.

    The top-left 2x2 block of every matrix is
    ``[[cos a, -sin a], [sin a, cos a]]``; the result is float32 on
    ``device`` with shape ``(len(x), len(x[0]), len(x[0]))``.

    The angle and its sine/cosine are evaluated once and reused for all four
    entries.
    """
    angle = gamma(x, wg, elev, dist)
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)

    res = torch.zeros((len(x),len(x[0]),len(x[0]))).to(device, dtype=torch.float32)
    res[:,0,0] = cos_a
    res[:,0,1] = -sin_a
    res[:,1,0] = sin_a
    res[:,1,1] = cos_a
    return res


def outer_sum_S(S1,S2):
    """Pairwise sum of two batches of matrices.

    Parameters
    ----------
    S1 : tensor of shape ``(n1, ...)``.
    S2 : tensor of shape ``(n2, ...)`` with the same trailing shape.

    Returns
    -------
    Tensor of shape ``(n1, n2, ...)`` with ``out[i, j] = S1[i] + S2[j]``.

    Uses broadcasting, so no Python lists of repeated tensors are stacked
    and an empty batch gives an empty result.
    """
    return S1[:, None] + S2[None, :]

def outer(S1,S2):
    """Pairwise element-wise product of two batches of matrices.

    For ``S1`` indexed ``[i, j, k]`` and ``S2`` indexed ``[l, j, k]`` the
    result is indexed ``[i, l, j, k]`` with
    ``out[i, l, j, k] = S1[i, j, k] * S2[l, j, k]``.
    """
    # No index is summed over: this is a batched Hadamard product.
    subscripts = 'ijk,ljk->iljk'
    pairwise = torch.einsum(subscripts, S1, S2)
    return pairwise

def matern_kernel_diffGPU(distance):
    """Matern-type kernel ``(1 + sqrt(3) d) * exp(-sqrt(3) d)``.

    The length scale is fixed to 1 inside the function.
    """
    length = 1.
    root3 = torch.sqrt(torch.tensor(3.0))
    scaled = (root3 * distance) / length
    return (1.0 + scaled) * torch.exp(-scaled)

def exp_kernel_diffGPU(distance):
    """Exponential kernel ``exp(-d)`` with unit length scale."""
    negated = -distance
    return torch.exp(negated)

def wendland_q(d):
    """Wendland taper ``(1 - d)^8 (35 d^3 + 25 d^2 + 8 d + 1)`` on a unit support.

    Distances above 1 are treated as 1, where the taper is exactly zero.

    The cap is applied to a clamped copy, so the caller's tensor ``d`` is
    left untouched.
    """
    capped = torch.clamp(d, max=1.0)
    kernel = (1.0 - capped) ** 8 * (35. * capped ** 3 + 25. * capped ** 2 + 8. * capped + 1.)
    return kernel

def core_kernel(x1,x2,x1_dev,x2_dev,hps_dev, cuda_device):
    """Non-stationary kernel built from station covariates.

    The result is the element-wise product of three ``(len(x1), len(x2))``
    factors:

    * ``A`` -- outer product of per-point signal variances
      ``exp(e0 + e1*elev + e2*dist)``;
    * ``B`` -- ``det(S1)^(1/4) det(S2)^(1/4) / sqrt(det((S1 + S2)/2))``
      evaluated for every pair of points;
    * ``C`` -- the Wendland taper ``wendland_q`` of ``sqrt(Q)``.

    Parameters
    ----------
    x1, x2 : numpy.ndarray
        Host copies of the points; columns 0:2 are used to look up the
        station covariates.
    x1_dev, x2_dev : torch.Tensor
        Device copies of the same points.
    hps_dev : torch.Tensor
        Hyperparameters, read as: ``[0:3]`` and ``[3:6]`` weights for the
        two ``Lambda`` terms, ``[6:9]`` weights for ``gamma``, ``[9:12]``
        signal-variance weights, ``[12]`` divisor passed to ``Q`` as ``wl3``.
    cuda_device : torch.device
        Device on which all tensors are created.

    Returns
    -------
    torch.Tensor
        Kernel values on ``cuda_device``.
    """
    def _nearest_station_value(table, points):
        # Nearest-neighbour lookup of column 2 of `table` at the first two
        # coordinates of `points`, returned as a float32 device tensor.
        looked_up = griddata(table[:,0:2], table[:,2], points[:,0:2], method='nearest', fill_value=0.)
        return torch.from_numpy(looked_up).to(cuda_device, dtype=torch.float32)

    lamda1_hps, lamda2_hps, gamma_hps = hps_dev[0:3], hps_dev[3:6], hps_dev[6:9]

    elev1 = _nearest_station_value(train_station_elevation, x1)
    dist1 = _nearest_station_value(train_station_dist2coast, x1)
    elev2 = _nearest_station_value(train_station_elevation, x2)
    dist2 = _nearest_station_value(train_station_dist2coast, x2)

    S1 = S(x1_dev[:, 0:2], gamma_hps, lamda1_hps, lamda2_hps, elev1, dist1, cuda_device)
    S2 = S(x2_dev[:, 0:2], gamma_hps, lamda1_hps, lamda2_hps, elev2, dist2, cuda_device)

    e0, e1, e2, hps_time = hps_dev[9], hps_dev[10], hps_dev[11], hps_dev[12]

    signal_variance1 = torch.exp(e0 + e1 * elev1 + e2 * dist1)
    signal_variance2 = torch.exp(e0 + e1 * elev2 + e2 * dist2)

    # Factor A: amplitude.
    amplitude = torch.outer(signal_variance1, signal_variance2)

    # Factor B: determinant-based normalisation.
    pair_sum = outer_sum_S(S1, S2)
    det_root1 = torch.linalg.det(S1)**(1/4)
    det_root2 = torch.linalg.det(S2)**(1/4)
    normaliser = torch.outer(det_root1, det_root2) / torch.sqrt(torch.linalg.det(pair_sum/2.0))

    # Factor C: compactly supported correlation of the quadratic-form distance.
    taper = wendland_q(torch.sqrt(Q(x1_dev, x2_dev, pair_sum, hps_time, cuda_device)))

    return amplitude * normaliser * taper

def my_noise(x,hps):
    """Constant noise vector: one entry per element of ``x``, each ``hps[13]``."""
    n_points = len(x)
    noise_level = hps[13]
    return np.ones(n_points) * noise_level

