import numpy as np
import time
import torch
# NOTE(review): inverting an empty 0x0 matrix has no numerical effect;
# presumably it forces torch's linear-algebra backend to initialise at import
# time -- confirm. The identical statement is repeated a few lines below.
x=torch.inverse(torch.ones((0, 0)))
from torch.linalg import det as detto
from torch.linalg import norm
from torch.linalg import inv
x=torch.inverse(torch.ones((0, 0)))
from scipy.interpolate import griddata
import scipy.sparse as sparse
from fvgp import GP
from scipy.interpolate import Rbf
# Regular n_points x n_points lattice on the unit square [0, 1]^2.
# RBF_func uses the rows of `grid` as interpolation nodes, so every
# hyperparameter slice handed to RBF_func needs n_points**2 (= 16) values.
n_points = 4
x_grid = np.linspace(0, 1, n_points)
y_grid = np.linspace(0, 1, n_points)
from itertools import product
# Shape (n_points**2, 2): all (x, y) combinations of the two axes.
grid = np.asarray(list(product(x_grid, y_grid)))

# Selects which kernel_sdisc* variant kernel() evaluates. Only the "4x25"
# variant is currently defined in this file; the "10x10" and "2x50" variants
# are commented out.
mode = "4x25"

def RBF_func(x, hps):
    """Interpolate grid-node values onto the points ``x`` and floor them.

    A multiquadric ``scipy.interpolate.Rbf`` interpolant is fitted through the
    values ``hps`` located at the rows of the module-level ``grid`` and then
    evaluated at the first two columns of ``x``.

    Args:
        x: torch tensor with one point per row; columns 0 and 1 are used.
        hps: 1-D torch tensor holding one value per row of ``grid``.

    Returns:
        1-D torch tensor of interpolated values, with every value below
        0.001 replaced by 0.001.
    """
    node_x, node_y = grid[:, 0], grid[:, 1]
    interpolant = Rbf(node_x, node_y, hps.numpy(), function='multiquadric')
    values = interpolant(x[:, 0].numpy(), x[:, 1].numpy())
    # Keep the field strictly positive.
    floored = np.maximum(values, 0.001)
    return torch.from_numpy(floored)


def kernel(x1, x2, hps):
    """Evaluate the full covariance matrix between two point sets.

    The result is the element-wise product of a compactly supported factor
    (chosen by the module-level ``mode``) and the non-stationary
    ``core_kernel``. Timing information is printed on every call.

    Args:
        x1: numpy array with one point per row.
        x2: numpy array with one point per row.
        hps: 1-D numpy array of hyperparameters. For mode "4x25" the entries
            0..79 are read by ``core_kernel`` and 80..82 are passed to the
            Wendland factor.

    Returns:
        numpy array of shape ``(len(x1), len(x2))``.

    Raises:
        ValueError: if ``mode`` is not one of "10x10", "4x25" or "2x50".
    """
    t_start = time.perf_counter()
    x1 = torch.from_numpy(x1)
    x2 = torch.from_numpy(x2)
    hps = torch.from_numpy(hps)

    # NOTE: kernel_sdisc10x10 and kernel_sdisc2x50 are currently commented
    # out in this file; only the "4x25" branch has a live implementation.
    if mode == "10x10":
        k1 = kernel_sdisc10x10(x1, x2, hps)
    elif mode == "4x25":
        k1 = kernel_sdisc4x25(x1, x2, hps)
    elif mode == "2x50":
        k1 = kernel_sdisc2x50(x1, x2, hps)
    else:
        # Previously an unknown mode fell through every branch and crashed
        # later with an UnboundLocalError on ``k1``.
        raise ValueError(
            f"unknown kernel mode {mode!r}; expected '10x10', '4x25' or '2x50'"
        )

    if torch.all(k1 == 0.):
        # The compact-support factor vanishes everywhere, so skip the
        # expensive core_kernel evaluation and return it directly.
        k = k1.numpy()
        print("ready to return early after ", time.perf_counter() - t_start, flush = True)
        return k

    kk = k1 * core_kernel(x1, x2, hps)

    print("Elapsed kernel time:", time.perf_counter() - t_start, flush = True)
    return kk.numpy()

#def kernel_sdisc2x50(x1,x2, hps):
#    k = torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[0:50],  hps[100:150]]),hps[200:250],hps[300:350]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[0:50],  hps[100:150]]),hps[200:250],hps[300:350])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[50:100],hps[150:200]]),hps[250:300],hps[350:400]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[50:100],hps[150:200]]),hps[250:300],hps[350:400]))
#    k[time_distance != 0.] = 0.
#    kk = k + wendland_anisotropic_gp2Scale_gpu(x1,x2, hps[400:404])
#    return kk

def kernel_sdisc4x25(x1, x2, hps):
    """Compactly supported kernel factor used when ``mode == "4x25"``.

    Only the anisotropic Wendland term is evaluated, with ``hps[80:83]`` as
    its hyperparameters. An earlier formulation that additionally summed four
    outer products of ``f_gpu`` bump-function sums (25 bumps each, taken from
    ``hps[0:400]``) is disabled.

    Args:
        x1: torch tensor with one point per row.
        x2: torch tensor with one point per row.
        hps: 1-D torch tensor of hyperparameters; entries 80..82 are used.

    Returns:
        torch tensor of shape ``(len(x1), len(x2))``.
    """
    wendland_hps = hps[80:83]
    return wendland_anisotropic_gp2Scale_gpu(x1, x2, wendland_hps)

#def kernel_sdisc10x10(x1,x2, hps):
#    k = torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[0:10],  hps[100:110]]),hps[200:210],hps[300:310]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[0:10],  hps[100:110]]),hps[200:210],hps[300:310])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[10:20], hps[110:120]]),hps[210:220],hps[310:320]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[10:20], hps[110:120]]),hps[210:220],hps[310:320])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[20:30], hps[120:130]]),hps[220:230],hps[320:330]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[20:30], hps[120:130]]),hps[220:230],hps[320:330])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[30:40], hps[130:140]]),hps[230:240],hps[330:340]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[30:40], hps[130:140]]),hps[230:240],hps[330:340])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[40:50], hps[140:150]]),hps[240:250],hps[340:350]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[40:50], hps[140:150]]),hps[240:250],hps[340:350])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[50:60], hps[150:160]]),hps[250:260],hps[350:360]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[50:60], hps[150:160]]),hps[250:260],hps[350:360])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[60:70], hps[160:170]]),hps[260:270],hps[360:370]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[60:70], hps[160:170]]),hps[260:270],hps[360:370])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[70:80], hps[170:180]]),hps[270:280],hps[370:380]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[70:80], hps[170:180]]),hps[270:280],hps[370:380])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[80:90], hps[180:190]]),hps[280:290],hps[380:390]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[80:90], hps[180:190]]),hps[280:290],hps[380:390])) + \
#        torch.outer(f_gpu(x1[:,0:2],torch.column_stack([hps[90:100],hps[190:200]]),hps[290:300],hps[390:400]),
#                    f_gpu(x2[:,0:2],torch.column_stack([hps[90:100],hps[190:200]]),hps[290:300],hps[390:400]))
#    k[time_distance != 0.] = 0.
#    kk = k + wendland_anisotropic_gp2Scale_gpu(x1,x2, hps[400:404])
#    return kk

def wendland_anisotropic_gp2Scale_gpu(x1, x2, hps):  # pragma: no cover
    """Anisotropic Wendland kernel with compact support.

    Scaled distances ``d`` come from ``_get_distance_matrix`` and are capped
    at 1, where the kernel is exactly zero. The value is
    ``(1 - d)**8 * (35 d**3 + 25 d**2 + 8 d + 1)``; the amplitude ``hps[0]``
    is currently not applied.

    Args:
        x1: torch tensor with one point per row.
        x2: torch tensor with one point per row.
        hps: hyperparameters forwarded to ``_get_distance_matrix``.

    Returns:
        torch tensor of shape ``(len(x1), len(x2))``.
    """
    r = torch.clamp(_get_distance_matrix(x1, x2, hps), max=1.)
    taper = (1. - r) ** 8
    polynomial = 35. * r ** 3 + 25. * r ** 2 + 8. * r + 1.
    return taper * polynomial


def _get_distance_matrix(x1, x2, hps):  # pragma: no cover
    d = torch.zeros((len(x1), len(x2)))
    for i in range(2):
        d += ((x1[:, i].reshape(-1, 1) - x2[:, i]) / hps[1 + i]) ** 2
    return torch.sqrt(d)



def b_gpu(x, x0, r, ampl):
    """Evaluate one compactly supported bump function at the rows of ``x``.

    With ``d = ||x - x0||`` and ``a = 1 - d**2 / r**2`` the value is
    ``ampl * exp(1 - 1/a)`` where ``a > 0`` (inside radius ``r``) and 0
    elsewhere.

    Args:
        x: 2-D torch tensor, one point per row.
        x0: centre of the bump, broadcastable against the rows of ``x``.
        r: radius of the support.
        ampl: amplitude; the value attained at the centre.

    Returns:
        1-D torch tensor with one value per row of ``x``, in the same dtype
        as the intermediate ``a`` (i.e. following the input precision).
    """
    d = norm(x - x0, dim=1)
    a = 1.0 - (d**2 / r**2)
    inside = a > 0.0
    # Allocate in the dtype of ``a``: a default-dtype (float32) buffer makes
    # the masked assignment below raise for float64 inputs, because
    # advanced-index assignment requires matching dtypes.
    bump = torch.zeros_like(a)
    values = ampl * torch.exp((-1.0 / a[inside]) + 1.)
    bump[inside] = values.to(bump.dtype)
    return bump

def f_gpu(x, x0, radii, amplts):
    """Sum of bump functions evaluated at the rows of ``x``.

    Bump ``i`` is ``b_gpu(x, x0[i], radii[i], amplts[i])``; the number of
    bumps is ``len(radii)``.

    Args:
        x: 2-D torch tensor, one point per row.
        x0: indexable collection of bump centres.
        radii: sequence of bump radii.
        amplts: indexable collection of bump amplitudes.

    Returns:
        1-D torch tensor with one value per row of ``x``. When ``radii`` is
        empty a zero tensor is returned (previously a bare Python float,
        which cannot be used as a tensor downstream).
    """
    total = None
    for i, radius in enumerate(radii):
        bump = b_gpu(x, x0[i], radius, amplts[i])
        total = bump if total is None else total + bump
    if total is None:
        return x.new_zeros(len(x))
    return total


def Lambda(x, hps):
    """Return the field ``RBF_func(x, hps)``; used by ``L`` for its diagonal.

    Args:
        x: torch tensor with one point per row.
        hps: 1-D torch tensor of values at the interpolation nodes.

    Returns:
        1-D torch tensor as produced by ``RBF_func``.
    """
    return RBF_func(x, hps)

def gamma(x, hps):
    """Return the field ``RBF_func(x, hps)``; used by ``G`` as an angle.

    Args:
        x: torch tensor with one point per row.
        hps: 1-D torch tensor of values at the interpolation nodes.

    Returns:
        1-D torch tensor as produced by ``RBF_func``.
    """
    return RBF_func(x, hps)

def Q(x1, x2, outer_sum):
    """Quadratic form ``(x1_i - x2_j)^T ((S_i + S_j)/2)^-1 (x1_i - x2_j)``.

    Args:
        x1: torch tensor with one point per row; columns 0 and 1 are used.
        x2: torch tensor with one point per row; columns 0 and 1 are used.
        outer_sum: tensor of shape ``(len(x1), len(x2), 2, 2)`` holding the
            pairwise matrix sums (not yet halved).

    Returns:
        torch tensor of shape ``(len(x1), len(x2))``.
    """
    mean_inverse = inv(outer_sum / 2.0)
    delta = torch.zeros((len(x1), len(x2), 2))
    for axis in (0, 1):
        delta[..., axis] = x1[:, axis].unsqueeze(1) - x2[:, axis]
    # Contract delta^T M delta independently for every (i, j) pair.
    return torch.einsum('ijk,ijkl,ijl->ij', delta, mean_inverse, delta)

def S(x, gamma_hps, lambda_hps1, lambda_hps2):
    """Per-point matrix ``G @ L @ G^T``.

    Args:
        x: torch tensor with one point per row.
        gamma_hps: hyperparameters passed to ``G``.
        lambda_hps1: hyperparameters for the first diagonal entry of ``L``.
        lambda_hps2: hyperparameters for the second diagonal entry of ``L``.

    Returns:
        torch tensor holding one square matrix per row of ``x``.
    """
    rotation = G(x, gamma_hps)
    scales = L(x, lambda_hps1, lambda_hps2)
    rotated = torch.matmul(rotation, scales)
    return torch.matmul(rotated, rotation.transpose(1, 2))

def L(x, hps1, hps2):
    """Per-point diagonal matrix built from two ``Lambda`` fields.

    Args:
        x: torch tensor with one point per row.
        hps1: hyperparameters for entry ``[0, 0]``.
        hps2: hyperparameters for entry ``[1, 1]``.

    Returns:
        torch tensor of shape ``(len(x), len(x[0]), len(x[0]))``; all entries
        other than ``[0, 0]`` and ``[1, 1]`` are zero.
    """
    n_points_x, width = len(x), len(x[0])
    diagonal = torch.zeros((n_points_x, width, width))
    for k, field_hps in enumerate((hps1, hps2)):
        diagonal[:, k, k] = Lambda(x, field_hps)
    return diagonal


def G(x, hps):
    """Per-point 2-D rotation matrix with angle ``gamma(x, hps)``.

    Args:
        x: torch tensor with one point per row.
        hps: hyperparameters of the angle field.

    Returns:
        torch tensor of shape ``(len(x), len(x[0]), len(x[0]))`` whose
        upper-left 2x2 block is ``[[cos, -sin], [sin, cos]]``.
    """
    res = torch.zeros((len(x), len(x[0]), len(x[0])))
    # Evaluate the angle field once: each gamma() call builds a new RBF
    # interpolator, and the original evaluated it four times per call.
    angle = gamma(x, hps)
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    res[:,0,0] = cos_a
    res[:,0,1] = -sin_a
    res[:,1,0] = sin_a
    res[:,1,1] = cos_a
    return res



def outer_sum_S(S1, S2):
    """Pairwise sums of two batches of matrices.

    Args:
        S1: tensor of shape ``(n1, 2, 2)``.
        S2: tensor of shape ``(n2, 2, 2)``.

    Returns:
        Tensor ``out`` of shape ``(n1, n2, 2, 2)`` with
        ``out[i, j] = S1[i] + S2[j]``.
    """
    # Broadcasting replaces stacking n1 + n2 repeated copies, and also works
    # for empty batches, where torch.stack on an empty list raises.
    return S1.unsqueeze(1) + S2.unsqueeze(0)


def outer(S1, S2):
    """Pairwise element-wise products of two batches of matrices.

    Args:
        S1: tensor of shape ``(n1, j, k)``.
        S2: tensor of shape ``(n2, j, k)``.

    Returns:
        Tensor ``out`` of shape ``(n1, n2, j, k)`` with
        ``out[i, l] = S1[i] * S2[l]`` (element-wise).
    """
    pairwise_product = 'ijk,ljk->iljk'
    return torch.einsum(pairwise_product, S1, S2)


def matern_kernel_diffGPU(distance):
    """Matern-3/2 kernel with unit length scale.

    Computes ``(1 + sqrt(3) d) * exp(-sqrt(3) d)`` element-wise.

    Args:
        distance: torch tensor of non-negative distances.

    Returns:
        torch tensor of the same shape as ``distance``.
    """
    length = 1.
    root_three = torch.sqrt(torch.tensor(3.0))
    scaled = (root_three * distance) / length
    return (1.0 + scaled) * torch.exp(-scaled)


def exp_kernel_diffGPU(distance):
    """Exponential kernel ``exp(-d)`` with unit length scale.

    Args:
        distance: torch tensor of distances.

    Returns:
        torch tensor of the same shape as ``distance``.
    """
    return torch.exp(torch.neg(distance))


def core_kernel(x1, x2, hps):
    """Non-stationary kernel with spatially varying anisotropy and amplitude.

    For every pair ``(i, j)`` the value is ``A * B * C`` with

    * ``A = s1(x1_i) * s2(x2_j)``, amplitudes from ``RBF_func``,
    * ``B = det(S_i)**(1/4) * det(S_j)**(1/4) / sqrt(det((S_i + S_j) / 2))``,
    * ``C = matern_kernel_diffGPU(sqrt(Q_ij))``,

    where ``S`` is the per-point matrix returned by ``S()``.

    Hyperparameter layout (16 values per field):
    ``hps[0:16]`` and ``hps[16:32]`` feed the two ``Lambda`` fields,
    ``hps[32:48]`` the ``gamma`` field, ``hps[48:64]`` the amplitude at
    ``x1`` and ``hps[64:80]`` the amplitude at ``x2``.

    Args:
        x1: torch tensor with one point per row.
        x2: torch tensor with one point per row.
        hps: 1-D torch tensor with at least 80 entries.

    Returns:
        torch tensor of shape ``(len(x1), len(x2))``.
    """
    lambda1_hps, lambda2_hps, gamma_hps = hps[0:16], hps[16:32], hps[32:48]
    sigma_x1 = S(x1, gamma_hps, lambda1_hps, lambda2_hps)
    sigma_x2 = S(x2, gamma_hps, lambda1_hps, lambda2_hps)

    # NOTE(review): the amplitudes at x1 and x2 use different hyperparameter
    # slices, so the result is not symmetric under swapping x1 and x2 --
    # confirm this is intended.
    amplitude_x1 = RBF_func(x1, hps[48:64])
    amplitude_x2 = RBF_func(x2, hps[64:80])
    amplitude = torch.outer(amplitude_x1, amplitude_x2)

    pair_sum = outer_sum_S(sigma_x1, sigma_x2)

    root_det_x1 = detto(sigma_x1) ** (1/4)
    root_det_x2 = detto(sigma_x2) ** (1/4)
    mean_det = detto(pair_sum / 2.0)
    prefactor = torch.outer(root_det_x1, root_det_x2) / torch.sqrt(mean_det)

    # exp_kernel_diffGPU takes the same argument and can be swapped in here.
    correlation = matern_kernel_diffGPU(torch.sqrt(Q(x1, x2, pair_sum)))

    return amplitude * prefactor * correlation


