import numpy as np
import time
import torch
from torch.linalg import det as detto
from torch.linalg import norm
from torch.linalg import inv
from scipy.interpolate import griddata
import scipy.sparse as sparse
from fvgp import GP
from scipy.interpolate import Rbf
from itertools import product
from scipy.interpolate import griddata
import gc
import os


# Number of CUDA devices visible to this process; it selects the setup below.
devs = torch.cuda.device_count()
print("worker has GPUs: ", devs, flush = True)


if devs == 1:
    # The full matrix stays on the host in `dist`; kernelG slices it directly.
    dist = np.load("distance_l1.npy")
    cuda_device = torch.device("cuda:0")
    # `a` is a run-once marker: if it is already bound (this code was executed
    # before in the same namespace) it is printed and the upload is skipped.
    # Only the "name not defined yet" case triggers the upload; any other
    # error is no longer silently swallowed.
    try:
        print(a)
    except NameError:
        a = "initialized already"
        # Training/training part of the distance matrix, kept on the GPU as float32.
        dist_dev =      torch.from_numpy(dist[0:60000, 0:60000]).to(cuda_device, dtype=torch.float32)
        # NOTE: the test/test and train/test slices below are disabled, but the
        # prediction branches of kernelDiff1-4 and kernelG still read
        # `dist_dev_test` and `dist_dev_trte`.
        #dist_dev_test = torch.from_numpy(dist[60000:70000, 60000:70000]).to(cuda_device, dtype=torch.float32)
        #dist_dev_trte = torch.from_numpy(dist[0:60000, 60000:70000]).to(cuda_device, dtype=torch.float32)
    #dist = None
    #gc.collect()
#elif devs == 4:
    #dist = np.load("distance_l1.npy")
    #cuda_device = torch.device("cuda:0")
    #try: 
    #    print(a)
    #except:
    #    a = "initialized already"
    #    dist_dev_test = torch.from_numpy(dist[60000:70000, 60000:70000]).to(cuda_device, dtype=torch.float32)
    #    dist_dev_trte = torch.from_numpy(dist[0:60000, 60000:70000]).to(cuda_device, dtype=torch.float32)
    #dist = None
    #gc.collect()
#else:
#    print("not 1 not 4", flush = True)

import pickle
# Training images: unpickle, then flatten every image to a 784-vector and keep
# at most the first 60000 rows.
with open("./data/x_train_MNIST.pkl", "rb") as train_file:
    x_train = np.asarray(pickle.load(train_file))
x_train = x_train.reshape(len(x_train), 28 * 28)[:60000]

# Test images: same flattening, all rows kept.
with open("./data/x_test_MNIST.pkl", "rb") as test_file:
    x_test = np.asarray(pickle.load(test_file))
x_test = x_test.reshape(len(x_test), 28 * 28)


def _find_train_block_start(first_sample):
    """Return the offset of the 15000-row block of ``x_train`` that starts with ``first_sample``.

    Both sides are flattened before comparing, so the lookup works whether the
    samples are stored as 28x28 images or as 784-vectors (``x_train`` is
    flattened when it is loaded).

    Raises:
        ValueError: if no block of ``x_train`` starts with ``first_sample``.
    """
    target = np.ravel(first_sample)
    found = None
    for start in range(0, len(x_train), 15000):
        if np.array_equal(np.ravel(x_train[start]), target):
            # No early exit: as in the original scan, the last matching block wins.
            found = start
    if found is None:
        raise ValueError("first sample does not match the start of any 15000-row block of x_train")
    return found


def _precomputed_distance_kernel(x1, x2, hps, kernel_fn):
    """Evaluate ``kernel_fn`` on the slice of the precomputed distance matrix selected by ``x1``/``x2``.

    Parameters:
    x1, x2: sample arrays; only their lengths and their first rows are used
    hps: numpy array of hyperparameters, copied to cuda:0 as float32
    kernel_fn: callable ``(hps_dev, distances) -> torch.Tensor``

    Returns:
    numpy array with the kernel values

    Raises:
    ValueError if a 15000-row input cannot be located in ``x_train``;
    Exception("NO KERNEL") if no branch applies.
    """
    t1_start = time.perf_counter()
    cuda_device = torch.device("cuda:0")
    hps_dev = torch.from_numpy(hps).to(cuda_device, dtype=torch.float32)
    if len(x1) == len(x2) == 15000:
        # Training block: locate both inputs inside x_train and slice dist_dev.
        i_start = _find_train_block_start(x1[0])
        j_start = _find_train_block_start(x2[0])
        i_end = i_start + 15000
        j_end = j_start + 15000

        print(i_start, i_end, j_start, j_end, flush = True)

        kk = kernel_fn(hps_dev, dist_dev[i_start:i_end, j_start:j_end])
    elif len(x1) == len(x2) == 10000:
        # NOTE(review): `dist_dev_test` is only created in commented-out setup
        # code in this file; this branch needs it as a module-level global.
        print("predicting kk", flush = True)
        kk = kernel_fn(hps_dev, dist_dev_test)
        print("shape: ", kk.shape)
    elif len(x1) != len(x2):
        # NOTE(review): same caveat for `dist_dev_trte`.
        print("predicting k", flush = True)
        kk = kernel_fn(hps_dev, dist_dev_trte)
        print("shape k: ", kk.shape, flush = True)
    else:
        raise Exception("NO KERNEL")
    # The timer stops before the result is copied to the host, so the
    # device-to-host transfer is not part of the reported time.
    t1_stop = time.perf_counter()
    torch.cuda.empty_cache()
    res = kk.cpu().numpy()
    # Sanity diagnostics on the host copy.
    if np.isnan(res).any(): print("NAN  ###########################################", flush = True)
    if np.any(res > 1e16): print("INF  ###########################################", flush = True)
    if np.all(res == 0.):  print("ZERO ###########################################", flush = True)
    print("Elapsed kernel time:", t1_stop - t1_start, flush = True)
    return res


def kernelDiff1(x1,x2,hps):
    """Kernel callback using ``wendland_anisotropic_gp2Scale_gpuDiff1`` on precomputed distances.

    ``x1``/``x2`` select the distance slice (see ``_precomputed_distance_kernel``);
    ``hps`` is a numpy array passed to the profile as (amplitude, support radius).
    Returns a numpy array.
    """
    return _precomputed_distance_kernel(x1, x2, hps, wendland_anisotropic_gp2Scale_gpuDiff1)


def kernelDiff2(x1,x2,hps):
    """Kernel callback using ``wendland_anisotropic_gp2Scale_gpuDiff2`` on precomputed distances.

    Same slicing, arguments and return value as ``kernelDiff1``.
    """
    return _precomputed_distance_kernel(x1, x2, hps, wendland_anisotropic_gp2Scale_gpuDiff2)


def kernelDiff3(x1,x2,hps):
    """Kernel callback using ``wendland_anisotropic_gp2Scale_gpuDiff3`` on precomputed distances.

    Same slicing, arguments and return value as ``kernelDiff1``.
    """
    return _precomputed_distance_kernel(x1, x2, hps, wendland_anisotropic_gp2Scale_gpuDiff3)


def kernelDiff4(x1,x2,hps):
    """Kernel callback using ``kernel_cos`` on precomputed distances.

    Same slicing, arguments and return value as ``kernelDiff1``.
    """
    return _precomputed_distance_kernel(x1, x2, hps, kernel_cos)

def kernelG(x1,x2,hps):
    """Product of a delta kernel and the Diff3 Wendland profile on precomputed distances.

    Parameters:
    x1, x2: sample arrays; only their lengths and first rows are used to pick
            the slice of the precomputed distance matrix
    hps: numpy array; hps[0] is the delta-kernel threshold parameter and
         hps[1:3] are passed to the Wendland profile

    Returns:
    numpy array with the kernel values

    Raises:
    ValueError if a 15000-row input cannot be located in ``x_train``;
    Exception("NO KERNEL") if no branch applies.
    """
    t1_start = time.perf_counter()
    cuda_device = torch.device("cuda:0")
    hps_dev = torch.from_numpy(hps).to(cuda_device, dtype=torch.float32)
    hps = torch.from_numpy(hps)  # host copy, used by the CPU-side delta_kernel_pred calls
    if len(x1) == len(x2) == 15000:
        # x_train rows are flattened at load time, so compare flattened samples
        # (comparing against a 28x28 reshape can never match a 784-vector).
        x1_first = np.ravel(x1[0])
        x2_first = np.ravel(x2[0])
        i_start = None
        j_start = None
        for i in range(0, len(x_train), 15000):
            block_first = np.ravel(x_train[i])
            if np.array_equal(block_first, x1_first): i_start = i
            if np.array_equal(block_first, x2_first): j_start = i
        if i_start is None or j_start is None:
            raise ValueError("x1/x2 do not start at a 15000-row block boundary of x_train")

        i_end = i_start + 15000
        j_end = j_start + 15000

        print(i_start, i_end, j_start, j_end, flush = True)

        # delta_kernel(hps, d1, d2, device) expects one row per sample of x1/x2
        # and one column per training point, i.e. full-width row slabs of dist_dev.
        dk = delta_kernel(hps_dev[0], dist_dev[i_start:i_end, :], dist_dev[j_start:j_end, :], cuda_device)
        kk = dk * wendland_anisotropic_gp2Scale_gpuDiff3(hps_dev[1:3], dist_dev[i_start:i_end,j_start:j_end])

    elif len(x1) == len(x2) == 10000:
        print("kk ....", flush = True)
        # NOTE(review): `delta_kernel_pred` and `dist_dev_test` are not defined
        # in the visible part of this file; this branch needs both as globals.
        dk = delta_kernel_pred(hps[0], dist_dev_test.cpu().T, torch.from_numpy(dist[0:60000,0:60000]), "cpu").to(cuda_device, dtype = torch.float32) 
        wk = wendland_anisotropic_gp2Scale_gpuDiff3(hps_dev[1:3], dist_dev_test)
        print(dk.shape, flush = True)
        print(wk.shape, flush = True)
        # The product was missing here, which left `kk` unbound further down.
        kk = dk * wk
        print("DONE", flush = True)
    elif len(x1) != len(x2):
        print("k ....", flush = True)
        # NOTE(review): same caveat for `delta_kernel_pred` and `dist_dev_trte`.
        dk = delta_kernel_pred(hps[0], dist_dev_trte.cpu().T, torch.from_numpy(dist[0:60000,0:60000]), "cpu").to(cuda_device, dtype = torch.float32)
        wk = wendland_anisotropic_gp2Scale_gpuDiff3(hps_dev[1:3], dist_dev_trte)
        print(dk.shape, flush = True)
        print(wk.shape, flush = True)
        kk = dk * wk
        print("DONE", flush = True)
    else:
        raise Exception("NO KERNEL")
    # The timer stops before the device-to-host copy below.
    t1_stop = time.perf_counter()
    torch.cuda.empty_cache()
    res = kk.cpu().numpy() 
    if np.isnan(res).any(): print("NAN  ###########################################", flush = True)
    if np.any(res > 1e16): print("INF  ###########################################", flush = True)
    if np.all(res == 0.):  print("ZERO ###########################################", flush = True)
    print("Elapsed kernel time:", t1_stop - t1_start, flush = True)
    return res




def _train_row_block(first_row, x_train_dev):
    """Return the 15000-row slab of ``dist_dev`` whose ``x_train`` block starts with ``first_row``.

    Raises Exception("NO DISTANCE MATRIX") when ``first_row`` is not the first
    sample of any of the four 15000-row training blocks.
    """
    for start in range(0, 60000, 15000):
        if torch.all(first_row == x_train_dev[start]):
            return dist_dev[start:start + 15000, :]
    raise Exception("NO DISTANCE MATRIX")


def kernelL1(x1,x2,hps):
    """Kernel callback: Wendland profile of the pairwise L1 distance times a scaled delta kernel.

    Parameters:
    x1, x2: numpy arrays of flattened samples, one row per sample
    hps: numpy array; hps[0] is the delta-kernel threshold parameter,
         hps[1:3] are passed to the Wendland profile and hps[3] scales the
         delta kernel

    Returns:
    numpy array of shape (len(x1), len(x2))

    Raises:
    Exception("NO DISTANCE MATRIX") if a 15000-row input does not start at a
    training-block boundary.
    """
    t1_start = time.perf_counter()
    cuda_device = torch.device("cuda:0")
    hps_dev = torch.from_numpy(hps).to(cuda_device, dtype=torch.float32)
    hps = torch.from_numpy(hps)
    x1_dev = torch.from_numpy(x1).to(cuda_device, dtype=torch.float32)
    x2_dev = torch.from_numpy(x2).to(cuda_device, dtype=torch.float32)

    # Distances between the two inputs themselves; fed to the radial profile.
    d = pairwise_l1_distance_batched_torch(x1_dev, x2_dev, cuda_device)

    # Radial profile; RBF or EXP have the same (hps, d) signature.
    wendland = wendland_anisotropic_gp2Scale_gpuDiff3

    if len(x1) == len(x2) == 15000:
        # Training block: distances to all training points come from dist_dev.
        x_train_dev = torch.from_numpy(x_train).to(cuda_device, dtype=torch.float32)
        d_full1 = _train_row_block(x1_dev[0], x_train_dev)
        d_full2 = _train_row_block(x2_dev[0], x_train_dev)

        kd = hps_dev[3] * delta_kernel(hps_dev[0], d_full1, d_full2, cuda_device)
        kk = wendland(hps_dev[1:3], d) * kd
        res = kk.cpu().numpy()
    elif len(x1) == len(x2):
        # Equal lengths other than 15000: distances to the training set are
        # computed on the fly for both inputs.
        print("kappa ...", flush = True)
        x_train_dev = torch.from_numpy(x_train).to(cuda_device, dtype=torch.float32)
        d_full2 = pairwise_l1_distance_batched_torch(x_train_dev, x2_dev, cuda_device).T
        d_full1 = pairwise_l1_distance_batched_torch(x_train_dev, x1_dev, cuda_device).T
        kk = wendland(hps_dev[1:3], d) * hps_dev[3] * delta_kernel(hps_dev[0],  d_full1, d_full2, cuda_device)
        res = kk.cpu().numpy() 
    else:
        # len(x1) != len(x2): the delta kernel is built on a second GPU.
        # NOTE(review): delta_kernel(d_full, d2) has 60000 rows, so this branch
        # presumably expects x1 to be the full training set — confirm.
        print("k ...", flush = True)
        x_train_dev1 = torch.from_numpy(x_train).to("cuda:1", dtype=torch.float32)
        x2_dev1 = torch.from_numpy(x2).to("cuda:1", dtype=torch.float32)
        dist = np.load("distance_l1.npy")
        d_full =      torch.from_numpy(dist[0:60000, 0:60000]).to("cuda:1", dtype=torch.float32)
        d2 =     pairwise_l1_distance_batched_torch(x_train_dev1, x2_dev1,      "cuda:1", batch_size = 1).T
        hps_dev1 = hps.to("cuda:1", dtype=torch.float32)

        # Both factors are moved to the host before multiplying because they
        # live on different GPUs.
        kd = hps[3] * delta_kernel(hps_dev1[0],  d_full, d2, "cuda:1").cpu()
        kw = wendland(hps_dev[1:3], d).cpu()
        kk = kw * kd
        res = kk.numpy()

    t1_stop = time.perf_counter()
    torch.cuda.empty_cache()
    # Sanity diagnostics on the host copy.
    if np.isnan(res).any(): print("NAN  ###########################################", flush = True)
    if np.any(res > 1e16): print("INF  ###########################################", flush = True)
    if np.all(res == 0.):  print("ZERO ###########################################", flush = True)
    print("Elapsed kernel time:", t1_stop - t1_start, flush = True)
    return res



def wendland_anisotropic_gp2Scale_gpuDiff1(hps, d):  # pragma: no cover
    """Compactly supported profile ``hps[0] * (1 - r)**2`` with ``r = min(d, hps[1]) / hps[1]``.

    ``d`` is not modified; the value is exactly zero wherever ``d >= hps[1]``.
    """
    support = hps[1]
    # Capping the distance at the support radius makes (1 - r) vanish outside it.
    scaled = torch.clamp(d.detach(), max=support) / support
    return hps[0] * (1.0 - scaled) ** 2


def RBF(hps, d):  # pragma: no cover
    """Squared-exponential profile ``hps[0] * exp(-d**2 / hps[1])``."""
    amplitude, width = hps[0], hps[1]
    return amplitude * torch.exp(-(d ** 2) / width)

def EXP(hps, d):  # pragma: no cover
    """Exponential profile ``hps[0] * exp(-d / hps[1])``."""
    amplitude, length = hps[0], hps[1]
    return amplitude * torch.exp(-d / length)


def wendland_anisotropic_gp2Scale_gpuDiff2(hps, d):  # pragma: no cover
    """Compactly supported profile ``hps[0] * (1 - r)**4 * (4 r + 1)`` with ``r = min(d, hps[1]) / hps[1]``.

    Values below 1e-4 are set to exactly zero. ``d`` is not modified.
    """
    support = hps[1]
    r = torch.clamp(d.detach(), max=support) / support
    values = hps[0] * ((1. - r) ** 4) * ((4. * r) + 1.)
    # Drop tiny entries so the result contains exact zeros.
    return torch.where(values < 1e-4, torch.zeros_like(values), values)

def wendland_anisotropic_gp2Scale_gpuDiff3(hps, d):  # pragma: no cover
    """Compactly supported profile ``hps[0] * (1 - r)**8 * (35 r**3 + 25 r**2 + 8 r + 1)``.

    ``r = min(d, hps[1]) / hps[1]``; ``d`` is not modified and the value is
    exactly zero wherever ``d >= hps[1]``.
    """
    support = hps[1]
    r = torch.clamp(d.detach(), max=support) / support
    taper = (1. - r) ** 8
    polynomial = 35. * r**3 + 25. * r**2 + 8. * r + 1.
    return hps[0] * taper * polynomial


def kernel_cos(hps, d):  # pragma: no cover
    """Raised-cosine profile ``hps[0] * 0.5 * (cos(pi * r**(2/2.5)) + 1)`` with ``r = min(d, hps[1]) / hps[1]``.

    ``d`` is not modified.
    """
    cutoff = hps[1]
    ratio = torch.clamp(d.detach(), max=cutoff) / cutoff
    phase = (ratio ** (2 / 2.5)) * torch.pi
    return hps[0] * 0.5 * (torch.cos(phase) + 1.)


def g_vec(d, hps, device):
    """Indicator of ``d < hps / 2`` as a float32 tensor on ``device``.

    Parameters:
    d: torch.Tensor of distances
    hps: scalar (number or 0-dim tensor); the threshold is ``hps / 2``
    device: device the result is placed on

    Returns:
    float32 tensor with the shape of ``d``: 1.0 where ``d < hps / 2``, else 0.0
    """
    # Cast the comparison mask directly instead of allocating a zero matrix on
    # the host, copying it to the device and filling it there.
    return (d < hps / 2.).to(device=device, dtype=torch.float32)

def delta_kernel(hps, d1, d2, device):
    """Normalised overlap of the ``g_vec`` indicator rows of ``d1`` and ``d2``.

    Parameters:
    hps: threshold parameter forwarded to ``g_vec``
    d1: distances with len(x1) rows and len(x_data) columns
    d2: distances with len(x2) rows and len(x_data) columns
    device: device the indicator matrices are placed on

    Returns:
    tensor of shape (len(x1), len(x2)): the product of the two indicator
    matrices divided by the number of columns of ``d1``
    """
    n_reference = d1.shape[1]
    indicators_1 = g_vec(d1, hps, device)
    indicators_2 = g_vec(d2, hps, device)
    normalised = torch.matmul(indicators_1, indicators_2.T) / float(n_reference)
    print("max G: ",    torch.max(normalised), n_reference, flush = True)
    return normalised


def _get_distance_matrix_gpu(x1, x2, device, hps):  # pragma: no cover
    """Scaled Euclidean distance matrix between the rows of ``x1`` and ``x2``.

    Parameters:
    x1: torch.Tensor of shape (n, dim)
    x2: torch.Tensor of shape (m, dim)
    device: device on which the result is accumulated
    hps: per-dimension length scales; ``hps[i]`` divides coordinate ``i``

    Returns:
    float32 tensor of shape (n, m) with
    ``sqrt(sum_i ((x1[:, i] - x2[:, i]) / hps[i])**2)``
    """
    d = torch.zeros((len(x1), len(x2)), dtype=torch.float32, device=device)
    # Iterate over the coordinate axis; `range(x1.shape)` raised a TypeError.
    for i in range(x1.shape[1]):
        d += ((x1[:, i].reshape(-1, 1) - x2[:, i]) / hps[i]) ** 2
    return torch.sqrt(d)


def matern_kernel_diffGPU(distance):
    """Return ``(1 + sqrt(3) * distance) * exp(-sqrt(3) * distance)`` with a fixed length scale of 1."""
    length = 1.
    root3 = torch.sqrt(torch.tensor(3.0))
    scaled = (root3 * distance) / length
    return (1.0 + scaled) * torch.exp(-scaled)


def exp_kernel_diffGPU(distance):
    """Return ``exp(-distance)`` elementwise."""
    return torch.exp(torch.neg(distance))

def pairwise_l1_distance(A, B):
    """
    Pairwise L1 distances between the rows of two tensors, in one broadcast step.

    Parameters:
    A: torch.Tensor of shape (n, d)
    B: torch.Tensor of shape (m, d)

    Returns:
    D: torch.Tensor of shape (n, m) where D[i, j] = sum_k |A[i, k] - B[j, k]|

    The intermediate difference tensor has shape (n, m, d).
    """
    differences = A[:, None, :] - B[None, :, :]  # (n, 1, d) - (1, m, d) -> (n, m, d)
    return torch.abs(differences).sum(dim=2)



def pairwise_l1_distance_lowmem(A, B, device):
    """
    Compute pairwise L1 distances between A and B one row of A at a time.

    Parameters:
    A: torch.Tensor of shape (n, d)
    B: torch.Tensor of shape (m, d)
    device: device on which the result matrix is allocated

    Returns:
    D: float32 torch.Tensor of shape (n, m) where D[i, j] = L1 distance
       between A[i] and B[j]
    """
    n = A.shape[0]
    m = B.shape[0]
    # Allocate directly on the target device instead of on the host followed by a copy.
    D = torch.empty((n, m), dtype=torch.float32, device=device)

    for i in range(n):
        # Only an (m, d) temporary exists per iteration.
        D[i, :] = torch.sum(torch.abs(B - A[i]), dim=1)

    return D


def pairwise_l1_distance_batched_torch(A, B, device=None,  batch_size=100):
    """
    Pairwise L1 distances between the rows of A and B, computed in row batches of A.

    A: torch.Tensor of shape (n, d)
    B: torch.Tensor of shape (m, d)
    device: device for the result matrix; defaults to the device of A
    batch_size: number of rows of A handled per step

    Returns:
    D: float32 torch.Tensor of shape (n, m) on the chosen device
    """
    target = A.device if device is None else device

    n_rows, n_cols = A.shape[0], B.shape[0]
    D = torch.empty((n_rows, n_cols), dtype=torch.float32, device=target)

    B_wide = B.unsqueeze(0)  # (1, m, d)
    for start in range(0, n_rows, batch_size):
        stop = start + batch_size  # slicing clamps this to n_rows
        chunk = A[start:stop].unsqueeze(1)  # (batch, 1, d)
        # Broadcast to (batch, m, d) and reduce over the feature axis.
        D[start:stop] = (chunk - B_wide).abs().sum(dim=2)

    return D
