#Implement SUPER FALKON with sparse kernel implementation!
import copy
from sklearn.preprocessing import MinMaxScaler,StandardScaler
from sklearn import datasets
import numpy as np
import torch
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import falkon
from falkon import FalkonOptions
from falkon.kernels import Kernel, DiffKernel, KeopsKernelMixin,GaussianKernel
from sklearn.metrics import roc_auc_score
from falkon.hopt.objectives.exact_objectives.new_compreg import NystromCompReg
#from falkon.hopt.objectives.exact_objectives.utils import jittering_cholesky
import tqdm
import math
from torch.optim.lr_scheduler import ReduceLROnPlateau

def block_diagonal_matrix(block, num_blocks):
    """Place ``num_blocks`` copies of ``block`` along the diagonal of a square grid.

    Every off-diagonal cell is a zero array shaped and typed like ``block``;
    the grid is stitched together with ``np.block``.
    """
    zero_block = np.zeros_like(block)
    grid = []
    for diag_pos in range(num_blocks):
        row_cells = [zero_block] * num_blocks
        row_cells[diag_pos] = block
        grid.append(row_cells)
    return np.block(grid)


def transformed_features(l, r, dim):
    """Return the antisymmetric pairwise features of two row-aligned matrices.

    For every index pair ``i < j`` taken in row-major order
    ((0, 1), (0, 2), ..., (1, 2), ...) one output column holds
    ``l[:, i] * r[:, j] - l[:, j] * r[:, i]``.

    Args:
        l: Tensor (or array-like) with one sample per row and at least
            ``dim`` columns.
        r: Tensor (or array-like) shaped like ``l``.
        dim: Number of leading feature columns to pair up. Any non-negative
            value is accepted, including odd values and ``dim < 2`` (which
            yields zero output columns).

    Returns:
        A float32 tensor of shape ``(n_rows, dim * (dim - 1) // 2)`` living
        on the same device as ``l``.
    """
    l = torch.as_tensor(l, dtype=torch.float32)
    r = torch.as_tensor(r, dtype=torch.float32, device=l.device)
    # Strict upper-triangle indices enumerate the (i, j) pairs with i < j in
    # the same order as a nested ``for i ... for j in range(i + 1, dim)`` loop.
    first, second = torch.triu_indices(dim, dim, offset=1, device=l.device)
    return l[:, first] * r[:, second] - l[:, second] * r[:, first]

class NystromCompRegSHAP(NystromCompReg):
    """``NystromCompReg`` subclass that exposes the solved coefficients and loss terms.

    ``forward`` returns the individual loss terms next to their sum, and
    ``get_alpha`` returns the coefficient vector solved from the cached
    training data while printing derived quantities.
    """

    def get_alpha(self):
        """Solve for the coefficients on the centers using the cached training data.

        Prints the coefficients projected onto the pairwise antisymmetric
        features of the centers ("final weights") and the fitted values on
        the training set ("function prediction").

        Returns:
            The coefficient tensor obtained from the two triangular solves.

        Raises:
            RuntimeError: If ``forward`` has not stored training data yet.
        """
        if self.x_train is None or self.y_train is None:
            raise RuntimeError("Call forward at least once before calling predict.")
        with torch.autograd.no_grad():
            chol_kmm, _, _, chol_b, rhs = self._calc_intermediate(self.x_train, self.y_train)
            # Two back-substitutions: first against LB^T, then against L^T.
            inner = torch.linalg.solve_triangular(chol_b.T, rhs, upper=True)
            alpha = torch.linalg.solve_triangular(chol_kmm.T, inner, upper=True)
            k_centers_train = self.kernel(self.centers, self.x_train)
            # Computed but not used by the rest of this method.
            weights = torch.matmul(self.centers.t(), alpha)
            centers = self.centers
            n_features = centers.shape[1]
            half = n_features // 2
            # Each center row is split into its first and second half of columns.
            left_half = centers[:, 0:half]
            right_half = centers[:, half:n_features]
            pair_features = transformed_features(left_half, right_half, half)
            final_weights = torch.matmul(pair_features.t(), alpha)

            print("final weights = ", final_weights, final_weights.shape)
            print("function prediction = ", k_centers_train.T @ alpha)
            return alpha

    def _calc_intermediate(self, X, Y):
        """Build the Cholesky factors and projected targets shared by all computations.

        Returns:
            ``(L, A, AAT, LB, c)`` where ``L`` is the jittered Cholesky factor
            of the center-center kernel, ``A`` solves ``L A = K(centers, X)``,
            ``AAT = A A^T``, ``LB`` is the jittered Cholesky factor of
            ``AAT / variance + I`` and ``c`` solves ``LB c = A Y``.
        """
        variance = self.penalty * X.shape[0]

        k_centers_data = self.kernel(self.centers, X)
        k_centers_centers = self.kernel(self.centers, self.centers)

        L = jittering_cholesky(k_centers_centers)
        A = torch.linalg.solve_triangular(L, k_centers_data, upper=False)
        AAT = A @ A.T
        identity = torch.eye(AAT.shape[0], device=X.device, dtype=X.dtype)
        LB = jittering_cholesky(AAT / variance + identity)

        projected_targets = A @ Y
        c = torch.linalg.solve_triangular(LB, projected_targets, upper=False)  # m * p

        return L, A, AAT, LB, c

    def forward(self, X, Y):
        """Evaluate the objective on ``(X, Y)`` and cache the data for later solves.

        Returns:
            ``(loss, ndeff, datafit, trace)`` with
            ``loss = datafit + trace + ndeff``. The three terms are also
            handed to ``self._save_losses``.
        """
        self.x_train = X.detach()
        self.y_train = Y.detach()
        n_samples = X.shape[0]
        variance = self.penalty * n_samples
        print("penalty = ", self.penalty)
        sqrt_var = torch.sqrt(variance)
        kernel_diag_sum = self.kernel(X, X, diag=True).sum()

        _, A, AAT, LB, c = self._calc_intermediate(X, Y)
        C = torch.linalg.solve_triangular(LB, A, upper=False)  # m * n

        datafit = torch.square(Y).sum() - torch.square(c / sqrt_var).sum()
        ndeff = (C / sqrt_var).square().sum()
        residual_trace = kernel_diag_sum - torch.trace(AAT)
        trace = residual_trace * datafit / (variance * n_samples)

        self._save_losses(ndeff, datafit, trace)
        total = datafit + trace + ndeff
        return total, ndeff, datafit, trace


def cholesky(M, upper=False, check_errors=True):
    """Cholesky-factorize ``M`` and return a lower or upper triangular factor.

    Args:
        M: Square matrix (or batch of square matrices) to factorize.
        upper: When True return ``U`` with ``U^H U = M``; otherwise return
            ``L`` with ``L L^H = M``.
        check_errors: When True, raise if ``torch.linalg.cholesky_ex``
            reports a positive ``info`` code for any matrix.

    Returns:
        The triangular factor, shaped like ``M``.

    Raises:
        RuntimeError: If ``check_errors`` is True and the factorization
            failed; the message carries the largest reported ``info`` code.
    """
    # The upper factor of M is the conjugate transpose of the lower factor
    # of M^H, so both paths share a single lower factorization.
    target = M.transpose(-2, -1).conj() if upper else M
    factor, info = torch.linalg.cholesky_ex(target, check_errors=False)
    # ``info`` holds one code per matrix for batched input, hence ``any``.
    if check_errors and bool((info > 0).any()):
        raise RuntimeError("Cholesky failed on row %d" % int(info.max()))
    return factor.transpose(-2, -1).conj() if upper else factor
def jittering_cholesky(mat, upper=False):
    """Factorize ``mat + jitter * I`` with ``cholesky``.

    Each jitter value in the (currently single-entry) schedule is tried in
    turn; the first successful factor is returned. If every attempt raises
    ``RuntimeError``, the last such error is re-raised.
    """
    identity = torch.eye(mat.shape[0], device=mat.device, dtype=mat.dtype)
    jitter_schedule = (1e-4,)
    failure = None
    for jitter in jitter_schedule:
        try:
            return cholesky(mat + identity * jitter, upper=upper, check_errors=True)
        except RuntimeError as err:  # noqa: PERF203
            failure = err
    raise failure

def accuracy(true,pred):
    """Fraction of entries where the sign-thresholded prediction matches the label.

    Labels are clipped into ``[0, 1]`` and predictions are treated as
    positive when strictly greater than zero.
    """
    labels = np.clip(true.cpu().numpy(), 0, 1)
    predicted_positive = pred.cpu().numpy() > 0
    n_matches = (labels == predicted_positive).sum()
    return n_matches / labels.shape[0]

def auc(true,pred):
    """ROC AUC between clipped labels and sign-thresholded predictions.

    Labels are clipped into ``[0, 1]``; predictions are converted to
    booleans via ``pred > 0`` before being passed to ``roc_auc_score``.
    """
    labels = np.clip(true.cpu().numpy(), 0, 1)
    # NOTE(review): the score receives hard 0/1 decisions rather than the raw
    # prediction values -- confirm that this is intended.
    predicted_positive = pred.cpu().numpy() > 0
    return roc_auc_score(labels, predicted_positive)

def rmse(true, pred):
    """Root-mean-square error between two tensors, compared as column vectors."""
    residual = true.reshape(-1, 1) - pred.reshape(-1, 1)
    mean_sq = torch.mean(residual ** 2)
    return torch.sqrt(mean_sq)

def calc_2(true, pred):
    """Return ``1 - mean((true - pred)^2) / mean(true)`` as a Python float."""
    # NOTE(review): the squared error is normalised by the mean of ``true``,
    # not its variance -- confirm whether an R^2 score was intended.
    squared_error = (true - pred) ** 2
    ratio = squared_error.mean() / true.mean()
    return (1 - ratio).item()

def pairwise_distances(x, y=None):
    '''
    Squared Euclidean distances between the rows of two matrices.

    Input: x is a Nxd matrix
           y is an optional Mxd matrix; when omitted, x is compared with itself
    Output: a NxM matrix whose entry [i,j] is ||x[i,:]-y[j,:]||^2,
            clamped from below at zero.
    '''
    x_sq = (x ** 2).sum(1).view(-1, 1)
    if y is None:
        other_t = torch.transpose(x, 0, 1)
        other_sq = x_sq.view(1, -1)
    else:
        other_t = torch.transpose(y, 0, 1)
        other_sq = (y ** 2).sum(1).view(1, -1)

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
    cross = torch.mm(x, other_t)
    sq_dist = x_sq + other_sq - 2.0 * cross
    return torch.clamp(sq_dist, 0.0, np.inf)

def _to_numpy(tensor):
    """Return ``tensor`` as a host NumPy array detached from autograd."""
    return tensor.cpu().detach().numpy()


def _save_prediction_scatter(y_true, y_pred, out_path):
    """Save a true-vs-predicted scatter plot to ``out_path`` and close the figure."""
    true_np = _to_numpy(y_true)
    pred_np = _to_numpy(y_pred)
    correlation = np.corrcoef(true_np.flatten(), pred_np.flatten())[0,1]

    fig = plt.figure(figsize=(6, 6))
    plt.scatter(true_np, pred_np, alpha=0.5, label="Predicted Values", color="blue")
    plt.xlabel("True Values")
    plt.ylabel("Predicted Values")
    plt.title(f"Kernel Ridge Regression: Predictions vs. True Values_{correlation}")
    plt.legend()
    plt.savefig(out_path)
    # Without this every call leaves an open figure behind.
    plt.close(fig)


def _finish_epoch_axis(ax, title, ylabel):
    """Apply the shared title/labels/legend/grid styling to a per-epoch plot."""
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


def train_krr(kernel,Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,block,fold,pen,m_fac=1.0):
    """Train a ``NystromCompRegSHAP`` model with Adam and write diagnostics to disk.

    The number of centers is ``round(sqrt(n_train) * m_fac)``; they are drawn
    from ``Xtrain`` without replacement. After every optimisation step the
    model is scored on the train/val/test splits: ``auc`` is tried first and,
    if it raises, ``rmse`` is used instead (in which case per-epoch RMSE and
    predictions are recorded and the LR scheduler is stepped on the
    validation RMSE). The model with the best validation metric is kept.

    Args:
        kernel: Kernel object handed to ``NystromCompRegSHAP``.
        Xtrain, Ytrain, Xval, Yval, Xtest, Ytest: Data tensors for the splits.
        block: Tag inserted into every output file name.
        fold: Tag inserted into the ``.txt`` / ``.npy`` output file names.
        pen: Initial penalty value.
        m_fac: Multiplier on ``sqrt(n_train)`` for the number of centers.

    Returns:
        ``(best_model, best_tr, best_val, best_test)`` -- the selected model
        and its train/val/test metric at the selected epoch.

    Side effects:
        Writes scatter plots, target/prediction text files, per-epoch
        prediction arrays and a combined metrics figure into the current
        working directory, and prints progress information.
    """
    penalty_init = torch.tensor(pen, dtype=torch.float32)
    n_train = Xtrain.shape[0]
    M = int(round(n_train ** 0.5 * m_fac))
    print("M , Xtrain.shape[0] = ", M, n_train)
    centers_init = Xtrain[np.random.choice(n_train, size=(M,), replace=False)].clone()
    model = NystromCompRegSHAP(
        kernel=kernel, penalty_init=penalty_init, centers_init=centers_init,  # The initial hp values
        opt_penalty=True, opt_centers=True,  # Whether the various hps are to be optimized
    )

    # Use the first GPU when there is one; otherwise stay on the CPU instead
    # of failing on the device transfer.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    Xtrain, Ytrain = Xtrain.to(device), Ytrain.to(device)
    Xval, Yval = Xval.to(device), Yval.to(device)
    Xtest, Ytest = Xtest.to(device), Ytest.to(device)

    model = model.to(device)
    opt_hp = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(opt_hp, mode='min', factor=0.8, patience=10, verbose=True)

    best_val = np.inf
    best_test = np.inf
    best_tr = np.inf
    best_score = np.inf  # validation metric mapped to "lower is better"
    best_model = copy.deepcopy(model)
    best_loop = 0
    num_epochs = 300
    pbar = tqdm.tqdm(list(range(num_epochs)))
    forward_loss = []
    training_rmse = []
    val_rmse = []
    test_rmse = []
    ndeff_vals = []
    datafit_vals = []
    trace_vals = []
    train_prediction = []
    val_prediction = []
    test_prediction = []

    for i in pbar:
        print(i)
        opt_hp.zero_grad()

        loss, ndeff, datafit, trace = model(Xtrain, Ytrain)

        forward_loss.append(loss.item())
        ndeff_vals.append(ndeff.item())
        datafit_vals.append(datafit.item())
        trace_vals.append(trace.item())
        loss.backward()
        opt_hp.step()

        higher_is_better = True
        try:
            tr_err = auc(Ytrain, model.predict(Xtrain))
            val_err = auc(Yval, model.predict(Xval))
            ts_err = auc(Ytest, model.predict(Xtest))
        except Exception:
            # AUC is not computable for these targets: fall back to RMSE.
            higher_is_better = False
            p_train = model.predict(Xtrain)
            tr_err = rmse(Ytrain, p_train)
            print("Xtrain, Ytrain, model.predict(Xtrain) = ", Xtrain.t(), Ytrain.t(), p_train.t())
            p_val = model.predict(Xval)
            val_err = rmse(Yval, p_val)
            p_test = model.predict(Xtest)
            ts_err = rmse(Ytest, p_test)
            training_rmse.append(tr_err.item())
            val_rmse.append(val_err.item())
            test_rmse.append(ts_err.item())

            train_prediction.append(p_train.cpu().numpy())
            val_prediction.append(p_val.cpu().numpy())
            test_prediction.append(p_test.cpu().numpy())

            scheduler.step(val_err.item())  # Reduce the LR when validation RMSE plateaus

        # AUC improves upwards while RMSE improves downwards, so compare on a
        # common scale; otherwise the AUC path would keep the worst model.
        score = -val_err if higher_is_better else val_err
        if score < best_score:
            best_model = copy.deepcopy(model)
            best_score = score
            best_val = val_err
            best_test = ts_err
            best_tr = tr_err
            best_loop = i
        pbar.set_description(f"Tr err: {tr_err} Val err: {val_err} Test err: {ts_err}")
        print("best val, best loop = ", best_val, best_loop)
        best_model.get_alpha()

    y_pred_best_train = best_model.predict(Xtrain)
    ytrain_np = _to_numpy(Ytrain)
    pred_train_np = _to_numpy(y_pred_best_train)

    print("Ytrain.cpu().detach().numpy() variance = ", np.var(ytrain_np), ytrain_np)
    print("y_pred_best_train.cpu().detach().numpy() variance = ", np.var(pred_train_np), pred_train_np)
    print("train correlation = ", np.corrcoef(ytrain_np.flatten(), pred_train_np.flatten())[0,1])
    _save_prediction_scatter(Ytrain, y_pred_best_train, f"scatterplot_train_{block}.png")

    y_pred_best_val = best_model.predict(Xval)
    _save_prediction_scatter(Yval, y_pred_best_val, f"scatterplot_val_{block}.png")

    y_pred_best_test = best_model.predict(Xtest)
    _save_prediction_scatter(Ytest, y_pred_best_test, f"scatterplot_test_{block}.png")

    np.savetxt(f"Ytest_{block}_fold_{fold}.txt", _to_numpy(Ytest))
    np.savetxt(f"Ytrain_{block}_fold_{fold}.txt", ytrain_np)
    np.savetxt(f"Yval_{block}_fold_{fold}.txt", _to_numpy(Yval))

    np.savetxt(f"y_pred_best_test_{block}_fold_{fold}.txt", _to_numpy(y_pred_best_test))
    np.savetxt(f"y_pred_best_val_{block}_fold_{fold}.txt", _to_numpy(y_pred_best_val))
    np.savetxt(f"y_pred_best_train_{block}_fold_{fold}.txt", pred_train_np)

    np.save(f"train_prediction_{block}_fold_{fold}.npy", np.array(train_prediction))
    np.save(f"val_prediction_{block}_fold_{fold}.npy", np.array(val_prediction))
    np.save(f"test_prediction_{block}_fold_{fold}.npy", np.array(test_prediction))

    print("np.array(test_prediction) shape = ", np.array(test_prediction).shape)

    fig, axes = plt.subplots(3, 2, figsize=(15, 12))  # 3 rows x 2 cols = 6 plots
    axes = axes.flatten()  # Flatten to index easily

    # Each series is plotted against its own length: the RMSE histories are
    # only filled on the RMSE path and may be shorter than ``num_epochs``.
    for values, label in (
        (training_rmse, "Train RMSE"),
        (val_rmse, "Val RMSE"),
        (test_rmse, "Test RMSE"),
    ):
        axes[0].plot(range(len(values)), values, label=label)
    _finish_epoch_axis(axes[0], "RMSE", "RMSE")

    axes[1].plot(range(len(forward_loss)), forward_loss, label="Forward Loss")
    _finish_epoch_axis(axes[1], "Forward Loss", "Loss")

    axes[2].plot(range(len(ndeff_vals)), ndeff_vals, label="Ndeff", color="purple")
    _finish_epoch_axis(axes[2], "Effective Dimensionality (Ndeff)", "Ndeff")

    axes[3].plot(range(len(trace_vals)), trace_vals, label="Trace", color="green")
    _finish_epoch_axis(axes[3], "Trace of Kernel Matrix", "Trace")

    axes[4].plot(range(len(datafit_vals)), datafit_vals, label="Datafit", color="darkorange")
    _finish_epoch_axis(axes[4], "Datafit Term", "Datafit")

    # The sixth panel only carries the block tag.
    axes[5].axis('off')
    axes[5].text(0.5, 0.5, f"Block: {block}", ha='center', va='center', fontsize=12)

    plt.tight_layout()
    plt.savefig(f"combined_training_metrics_block_{block}.png")
    plt.close(fig)

    return best_model, best_tr,best_val, best_test

def SGD_KRR_base(Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,block,ls,pen=1e-5,m_fac=1.0,fold=0):
    """Train with a Gaussian kernel that has one lengthscale per input column.

    Args:
        Xtrain, Ytrain, Xval, Yval, Xtest, Ytest: Data splits passed to ``train_krr``.
        block: Tag used in the output file names.
        ls: Initial value for every lengthscale entry.
        pen: Initial penalty.
        m_fac: Multiplier on the number of centers.
        fold: Fold tag used in the output file names (new, defaults to 0).

    Returns:
        Whatever ``train_krr`` returns.
    """
    lengthscale_init = torch.tensor([ls]*(Xtrain.shape[1])).requires_grad_()
    kernel = GaussianKernel(sigma=lengthscale_init,opt=FalkonOptions(keops_active="yes"))
    # ``train_krr`` takes (block, fold, pen, m_fac); omitting ``fold`` would
    # shift ``pen`` into ``fold`` and ``m_fac`` into ``pen``.
    return train_krr(kernel,Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,block,fold,pen,m_fac)
def SGD_KRR(Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,block,fold,ls,pen=1e-5,m_fac=1.0):
    """Train with the ``diffrentiable_FALKON_GPGP`` kernel.

    One lengthscale entry, initialised to ``ls``, is created per column of a
    half of the input (``Xtrain.shape[1] // 2`` entries).
    """
    n_half_features = Xtrain.shape[1] // 2
    lengthscale_init = torch.tensor([ls] * n_half_features).requires_grad_()
    gpgp_kernel = diffrentiable_FALKON_GPGP(
        lengthscale=lengthscale_init,
        options=falkon.FalkonOptions(),
    )
    return train_krr(
        gpgp_kernel, Xtrain, Ytrain, Xval, Yval, Xtest, Ytest, block, fold, pen, m_fac
    )

def SGD_KRR_PGP(Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,ls,pen=1e-5,m_fac=1.0,block=0,fold=0):
    """Train with the ``diffrentiable_FALKON_PGP`` kernel.

    Args:
        Xtrain, Ytrain, Xval, Yval, Xtest, Ytest: Data splits passed to ``train_krr``.
        ls: Initial value for each of the ``Xtrain.shape[1] // 2`` lengthscales.
        pen: Initial penalty.
        m_fac: Multiplier on the number of centers.
        block: Tag used in the output file names (new, defaults to 0).
        fold: Fold tag used in the output file names (new, defaults to 0).

    Returns:
        Whatever ``train_krr`` returns.
    """
    lengthscale_init = torch.tensor([ls]*(Xtrain.shape[1]//2)).requires_grad_()
    kernel = diffrentiable_FALKON_PGP(lengthscale=lengthscale_init, options=falkon.FalkonOptions())
    # ``train_krr`` requires (block, fold, pen); passing only (pen, m_fac)
    # leaves ``pen`` unfilled and raises a TypeError.
    return train_krr(kernel,Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,block,fold,pen,m_fac)

def SGD_UKRR(Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,dat,pen=1e-5,m_fac=1.0,block=0,fold=0):
    """Train with the ``diffrentiable_FALKON_UGPGP`` kernel.

    Args:
        Xtrain, Ytrain, Xval, Yval, Xtest, Ytest: Data splits passed to ``train_krr``.
        dat: Triple ``(ls_u, ls_i, user_dim)``: initial user lengthscale,
            initial item lengthscale, and number of leading user columns.
        pen: Initial penalty.
        m_fac: Multiplier on the number of centers.
        block: Tag used in the output file names (new, defaults to 0).
        fold: Fold tag used in the output file names (new, defaults to 0).

    Returns:
        Whatever ``train_krr`` returns.
    """
    ls_u, ls_i, user_dim = dat
    lengthscale_i = torch.tensor([ls_i]*((Xtrain.shape[1]-user_dim)//2)).requires_grad_()
    lengthscale_u = torch.tensor([ls_u]*user_dim).requires_grad_()
    kernel = diffrentiable_FALKON_UGPGP(lengthscale_items=lengthscale_i,lengthscale_users=lengthscale_u,user_dim=user_dim, options=falkon.FalkonOptions())
    # ``train_krr`` requires (block, fold, pen); passing only (pen, m_fac)
    # leaves ``pen`` unfilled and raises a TypeError.
    return train_krr(kernel,Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,block,fold,pen,m_fac)


def SGD_UKRR_PGP(Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,dat,pen=1e-5,m_fac=1.0,block=0,fold=0):
    """Train with the ``diffrentiable_FALKON_UPGP`` kernel.

    Args:
        Xtrain, Ytrain, Xval, Yval, Xtest, Ytest: Data splits passed to ``train_krr``.
        dat: Triple ``(ls_u, ls_i, user_dim)``: initial user lengthscale,
            initial item lengthscale, and number of leading user columns.
        pen: Initial penalty.
        m_fac: Multiplier on the number of centers.
        block: Tag used in the output file names (new, defaults to 0).
        fold: Fold tag used in the output file names (new, defaults to 0).

    Returns:
        Whatever ``train_krr`` returns.
    """
    ls_u, ls_i, user_dim = dat
    lengthscale_i = torch.tensor([ls_i]*((Xtrain.shape[1]-user_dim)//2)).requires_grad_()
    lengthscale_u = torch.tensor([ls_u]*user_dim).requires_grad_()
    kernel = diffrentiable_FALKON_UPGP(lengthscale_items=lengthscale_i,lengthscale_users=lengthscale_u,user_dim=user_dim, options=falkon.FalkonOptions())
    # ``train_krr`` requires (block, fold, pen); passing only (pen, m_fac)
    # leaves ``pen`` unfilled and raises a TypeError.
    return train_krr(kernel,Xtrain,Ytrain,Xval,Yval,Xtest,Ytest,block,fold,pen,m_fac)

class diffrentiable_FALKON_GPGP(DiffKernel):
    """Antisymmetric pair kernel with a learnable lengthscale.

    Every input row is split column-wise into two halves ``(a, b)``. After
    dividing by the lengthscale, rows ``(a, b)`` and ``(c, d)`` are compared as
    ``exp(-(|a-c|^2 + |b-d|^2) / 2) - exp(-(|a-d|^2 + |b-c|^2) / 2)``.
    """

    def __init__(self, lengthscale, options):
        # No core_fn is supplied; ``lengthscale`` is passed as the kernel's
        # hyperparameter.
        super().__init__(
            "diffrentiable_FALKON_GPGP",
            options,
            core_fn=None,
            lengthscale=lengthscale,
        )

    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool):
        """Evaluate the kernel and copy the result into ``out``, which is returned.

        With ``diag`` set, rows of ``X1`` and ``X2`` are compared elementwise;
        otherwise every row of ``X1`` is compared with every row of ``X2``.
        """
        scale = self.lengthscale
        first_a, first_b = torch.chunk(X1, dim=1, chunks=2)
        second_a, second_b = torch.chunk(X2, dim=1, chunks=2)

        a = first_a.div(scale)
        b = first_b.div(scale)
        c = second_a.div(scale)
        d = second_b.div(scale)

        if diag:
            aligned = (-((a - c) ** 2 + (b - d) ** 2) / 2).sum(-1).exp()
            swapped = (-((a - d) ** 2 + (b - c) ** 2) / 2).sum(-1).exp()
        else:
            aligned = (-(pairwise_distances(a, c) + pairwise_distances(b, d)) / 2).exp()
            swapped = (-(pairwise_distances(a, d) + pairwise_distances(b, c)) / 2).exp()

        out.copy_(aligned - swapped)
        return out

    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):
        """Same kernel as ``compute`` but returned as a new tensor (no in-place write).

        The lengthscale is first moved to the device and dtype of ``X1``.
        """
        scale = self.lengthscale.to(device=X1.device, dtype=X1.dtype)
        first_a, first_b = torch.chunk(X1, dim=1, chunks=2)
        second_a, second_b = torch.chunk(X2, dim=1, chunks=2)

        a = first_a.div(scale)
        b = first_b.div(scale)
        c = second_a.div(scale)
        d = second_b.div(scale)

        if diag:
            aligned = (-((a - c) ** 2 + (b - d) ** 2) / 2).sum(-1).exp()
            swapped = (-((a - d) ** 2 + (b - c) ** 2) / 2).sum(-1).exp()
            return aligned - swapped

        aligned = (-(pairwise_distances(a, c) + pairwise_distances(b, d)) / 2).exp()
        swapped = (-(pairwise_distances(a, d) + pairwise_distances(b, c)) / 2).exp()
        return aligned - swapped

    def detach(self):
        """Return a copy of this kernel whose lengthscale is detached from autograd."""
        detached_scale = self.lengthscale.detach()
        return diffrentiable_FALKON_GPGP(lengthscale=detached_scale, options=self.params)

    def compute_sparse(self, X1, X2, out, diag, **kwargs) -> torch.Tensor:
        """Sparse inputs are not supported by this kernel."""
        raise NotImplementedError("Sparse not implemented")

class diffrentiable_FALKON_UGPGP(DiffKernel):
    """Antisymmetric item-pair kernel multiplied by an RBF kernel on user columns.

    The first ``user_dim`` columns of every row are user features; the
    remaining columns are split into two halves ``(a, b)``. After scaling by
    the respective lengthscales, rows ``(u1, a, b)`` and ``(u2, c, d)`` give
    ``[exp(-(|a-c|^2 + |b-d|^2) / 2) - exp(-(|a-d|^2 + |b-c|^2) / 2)] * exp(-|u1-u2|^2)``.
    """

    def __init__(self, lengthscale_items,lengthscale_users,user_dim, options):
        # No core_fn is supplied; both lengthscales are passed as the
        # kernel's hyperparameters.
        super().__init__("diffrentiable_FALKON_UGPGP", options,core_fn=None,
                         lengthscale_items=lengthscale_items,
                         lengthscale_users=lengthscale_users
                         )
        self.user_dim = user_dim

    def _pair_user_kernel(self, X1, X2, diag):
        """Evaluate the kernel as a fresh tensor; shared by ``compute`` and ``compute_diff``."""
        ls = self.lengthscale_items
        ls_u = self.lengthscale_users
        u_1 = X1[:,:self.user_dim]
        u_2 = X2[:,:self.user_dim]

        xa,xb = torch.chunk(X1[:,self.user_dim:],dim=1,chunks=2)
        xc,xd = torch.chunk(X2[:,self.user_dim:],dim=1,chunks=2)

        xa_ = xa.div(ls)
        xb_ = xb.div(ls)
        xc_ = xc.div(ls)
        xd_ = xd.div(ls)

        u_1_ = u_1.div(ls_u)
        u_2_ = u_2.div(ls_u)

        if diag:
            K = (-((xa_ - xc_) ** 2 + (xb_ - xd_) ** 2) / 2).sum(-1).exp() - (
                    -((xa_ - xd_) ** 2 + (xb_ - xc_) ** 2) / 2).sum(-1).exp()
            return K * (-(u_1_ - u_2_) ** 2).sum(-1).exp()

        K = (-(pairwise_distances(xa_, xc_) + pairwise_distances(xb_, xd_)) / 2).exp() - (
                -(pairwise_distances(xa_, xd_) + pairwise_distances(xb_, xc_)) / 2).exp()
        return K * (-pairwise_distances(u_1_, u_2_)).exp()

    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool):
        """Evaluate the kernel and copy the result into ``out``, which is returned."""
        out.copy_(self._pair_user_kernel(X1, X2, diag))
        return out

    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):
        """Evaluate the kernel without in-place writes and return the new tensor."""
        return self._pair_user_kernel(X1, X2, diag)

    def detach(self):
        """Return a copy of this kernel with both lengthscales detached from autograd."""
        # Must rebuild *this* class from its own hyperparameters: there is no
        # ``self.lengthscale`` here, and the GPGP kernel ignores user columns.
        return diffrentiable_FALKON_UGPGP(
            lengthscale_items=self.lengthscale_items.detach(),
            lengthscale_users=self.lengthscale_users.detach(),
            user_dim=self.user_dim,
            options=self.params
        )

    def compute_sparse(self, X1, X2, out, diag, **kwargs) -> torch.Tensor:
        """Sparse inputs are not supported by this kernel."""
        raise NotImplementedError("Sparse not implemented")


class diffrentiable_FALKON_UPGP(DiffKernel):
    """Additive antisymmetric item-pair kernel multiplied by an RBF kernel on user columns.

    The first ``user_dim`` columns of every row are user features; the
    remaining columns are split into two halves ``(a, b)``. After scaling by
    the respective lengthscales, rows ``(u1, a, b)`` and ``(u2, c, d)`` give
    ``[exp(-|a-c|^2/2) + exp(-|b-d|^2/2) - exp(-|a-d|^2/2) - exp(-|b-c|^2/2)] * exp(-|u1-u2|^2)``.
    """

    def __init__(self, lengthscale_items,lengthscale_users,user_dim, options):
        # No core_fn is supplied; both lengthscales are passed as the
        # kernel's hyperparameters.
        super().__init__("diffrentiable_FALKON_UPGP", options,core_fn=None,
                         lengthscale_items=lengthscale_items,
                         lengthscale_users=lengthscale_users
                         )
        self.user_dim = user_dim

    def _pair_user_kernel(self, X1, X2, diag):
        """Evaluate the kernel as a fresh tensor; shared by ``compute`` and ``compute_diff``."""
        ls = self.lengthscale_items
        ls_u = self.lengthscale_users
        u_1 = X1[:,:self.user_dim]
        u_2 = X2[:,:self.user_dim]

        xa,xb = torch.chunk(X1[:,self.user_dim:],dim=1,chunks=2)
        xc,xd = torch.chunk(X2[:,self.user_dim:],dim=1,chunks=2)

        xa_ = xa.div(ls)
        xb_ = xb.div(ls)
        xc_ = xc.div(ls)
        xd_ = xd.div(ls)

        u_1_ = u_1.div(ls_u)
        u_2_ = u_2.div(ls_u)

        if diag:
            K = (-(xa_-xc_)**2/2).sum(-1).exp() + (-(xb_-xd_)**2/2).sum(-1).exp() - (-(xa_-xd_)**2/2).sum(-1).exp() - (-(xb_-xc_)**2/2).sum(-1).exp()
            return K * (-(u_1_ - u_2_) ** 2).sum(-1).exp()

        K = (-pairwise_distances(xa_,xc_)/2).exp()+(-pairwise_distances(xb_,xd_)/2).exp()-(-pairwise_distances(xa_,xd_)/2).exp()-(-pairwise_distances(xb_,xc_)/2).exp()
        return K * (-pairwise_distances(u_1_, u_2_)).exp()

    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool):
        """Evaluate the kernel and copy the result into ``out``, which is returned."""
        out.copy_(self._pair_user_kernel(X1, X2, diag))
        return out

    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):
        """Evaluate the kernel without in-place writes and return the new tensor."""
        return self._pair_user_kernel(X1, X2, diag)

    def detach(self):
        """Return a copy of this kernel with both lengthscales detached from autograd."""
        # Must rebuild *this* class from its own hyperparameters: there is no
        # ``self.lengthscale`` here, and the GPGP kernel is a different kernel.
        return diffrentiable_FALKON_UPGP(
            lengthscale_items=self.lengthscale_items.detach(),
            lengthscale_users=self.lengthscale_users.detach(),
            user_dim=self.user_dim,
            options=self.params
        )

    def compute_sparse(self, X1, X2, out, diag, **kwargs) -> torch.Tensor:
        """Sparse inputs are not supported by this kernel."""
        raise NotImplementedError("Sparse not implemented")


class diffrentiable_FALKON_PGP(DiffKernel):
    """Additive antisymmetric pair kernel with a learnable lengthscale.

    Every input row is split column-wise into two halves ``(a, b)``. After
    dividing by the lengthscale, rows ``(a, b)`` and ``(c, d)`` are compared as
    ``exp(-|a-c|^2/2) + exp(-|b-d|^2/2) - exp(-|a-d|^2/2) - exp(-|b-c|^2/2)``.
    """

    def __init__(self, lengthscale, options):
        # No core_fn is supplied; ``lengthscale`` is passed as the kernel's
        # hyperparameter.
        super().__init__(
            "diffrentiable_FALKON_PGP",
            options,
            core_fn=None,
            lengthscale=lengthscale,
        )

    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool):
        """Evaluate the kernel and copy the result into ``out``, which is returned.

        With ``diag`` set, rows of ``X1`` and ``X2`` are compared elementwise;
        otherwise every row of ``X1`` is compared with every row of ``X2``.
        """
        scale = self.lengthscale
        first_a, first_b = torch.chunk(X1, dim=1, chunks=2)
        second_a, second_b = torch.chunk(X2, dim=1, chunks=2)
        a = first_a.div(scale)
        b = first_b.div(scale)
        c = second_a.div(scale)
        d = second_b.div(scale)

        if diag:
            k_ac = (-(a - c) ** 2 / 2).sum(-1).exp()
            k_bd = (-(b - d) ** 2 / 2).sum(-1).exp()
            k_ad = (-(a - d) ** 2 / 2).sum(-1).exp()
            k_bc = (-(b - c) ** 2 / 2).sum(-1).exp()
        else:
            k_ac = (-pairwise_distances(a, c) / 2).exp()
            k_bd = (-pairwise_distances(b, d) / 2).exp()
            k_ad = (-pairwise_distances(a, d) / 2).exp()
            k_bc = (-pairwise_distances(b, c) / 2).exp()

        out.copy_(k_ac + k_bd - k_ad - k_bc)
        return out

    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):
        """Same kernel as ``compute`` but returned as a new tensor (no in-place write)."""
        scale = self.lengthscale
        first_a, first_b = torch.chunk(X1, dim=1, chunks=2)
        second_a, second_b = torch.chunk(X2, dim=1, chunks=2)
        a = first_a.div(scale)
        b = first_b.div(scale)
        c = second_a.div(scale)
        d = second_b.div(scale)

        if diag:
            k_ac = (-(a - c) ** 2 / 2).sum(-1).exp()
            k_bd = (-(b - d) ** 2 / 2).sum(-1).exp()
            k_ad = (-(a - d) ** 2 / 2).sum(-1).exp()
            k_bc = (-(b - c) ** 2 / 2).sum(-1).exp()
            return k_ac + k_bd - k_ad - k_bc

        k_ac = (-pairwise_distances(a, c) / 2).exp()
        k_bd = (-pairwise_distances(b, d) / 2).exp()
        k_ad = (-pairwise_distances(a, d) / 2).exp()
        k_bc = (-pairwise_distances(b, c) / 2).exp()
        return k_ac + k_bd - k_ad - k_bc

    def detach(self):
        """Return a copy of this kernel whose lengthscale is detached from autograd."""
        detached_scale = self.lengthscale.detach()
        return diffrentiable_FALKON_PGP(lengthscale=detached_scale, options=self.params)

    def compute_sparse(self, X1, X2, out, diag, **kwargs) -> torch.Tensor:
        """Sparse inputs are not supported by this kernel."""
        raise NotImplementedError("Sparse not implemented")

