#%%
import os

# These environment variables must be set before torch is first imported:
# PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK when its operator libraries are loaded,
# so assigning it after `import torch` has no effect. They are placed above every
# import because the local modules imported below may import torch themselves.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.datasets import make_blobs, make_moons
import matplotlib.pyplot as plt
# from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import fetch_openml
from collections import defaultdict
from time import time  # NOTE: shadowed by `import time` below; the code calls time.time()
from tqdm import tqdm
from sklearn import metrics
from datetime import datetime
from sklearn.preprocessing import StandardScaler
# sklearn's MiniBatchKMeans, aliased so it does not clash with the local kmeans.MiniBatchKMeans
from sklearn.cluster import MiniBatchKMeans as MiniBatchKMeans_sklearn
import pandas as pd
import time
import os
from itertools import product
from copy import deepcopy
import argparse
from kernel_kmeans import KernelMiniBatchKMeans, initialize_kernel_kmeans_plusplus, BaseSketchKernelKMeans, SubGaussianKernelKMeans, ROSKernelKMeans, NystromKernelKMeans
from kmeans import MiniBatchKMeans
from weighted_tmbkkm  import WTKernelMiniBatchKMeans
import json
from sklearn.metrics import accuracy_score
from scipy.optimize import linear_sum_assignment
import torch

def clustering_accuracy(y_true, y_pred):
    """Return the best-match accuracy between true labels and cluster assignments.

    Cluster ids are arbitrary, so predicted clusters are matched one-to-one to
    true classes with the Hungarian algorithm (maximizing the number of
    agreements) before the accuracy is computed.

    :param y_true: array-like of true labels, convertible to non-negative
        integers (ints, or numeric strings such as '3').
    :param y_pred: array-like of predicted cluster ids (non-negative integers;
        integer-valued floats are accepted).
    :return: fraction of samples whose matched cluster equals the true label.
    :raises ValueError: if the inputs are empty or have different lengths.
    """
    y_true = np.asarray(y_true).astype(np.int64)
    # Cast predictions too: float-typed cluster ids cannot be used as indices.
    y_pred = np.asarray(y_pred).astype(np.int64)
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"y_true and y_pred must have the same length ({y_true.shape[0]} != {y_pred.shape[0]})"
        )
    if y_pred.size == 0:
        raise ValueError("clustering_accuracy requires at least one sample")

    # Square confusion matrix: rows are predicted clusters, columns are true labels.
    D = max(y_pred.max(), y_true.max()) + 1
    confusion_matrix = np.zeros((D, D), dtype=np.int64)
    # Vectorized `confusion_matrix[p, t] += 1` per sample; np.add.at accumulates
    # repeated (p, t) pairs correctly, unlike fancy-index `+=`.
    np.add.at(confusion_matrix, (y_pred, y_true), 1)

    # Hungarian algorithm on the negated counts = maximum-agreement label mapping.
    row_ind, col_ind = linear_sum_assignment(-confusion_matrix)

    return confusion_matrix[row_ind, col_ind].sum() / len(y_pred)


def compute_normalized_laplacian_inplace(adjacency_matrix, device='cpu'):
    """
    Compute the symmetric normalized Laplacian L = I - D^(-1/2) A D^(-1/2).

    The input is copied once into a float32 tensor on ``device``; every later
    step works in place on that copy, so only one n x n buffer is allocated.
    The caller's ``adjacency_matrix`` is NOT modified.

    :param adjacency_matrix: (n, n) adjacency matrix of the graph (array-like)
    :param device: torch device for the computation, e.g. 'cpu' or 'cuda'
    :return: normalized Laplacian as a float32 torch tensor on ``device``
    """
    with torch.no_grad():
        A = torch.tensor(adjacency_matrix, dtype=torch.float32, device=device)

        # Degree vector = row sums of A.
        D = torch.sum(A, dim=1)

        D_inv_sqrt = torch.pow(D, -0.5)
        # Isolated nodes have degree 0 and 0^(-1/2) = inf; zero them so their
        # rows/columns of the normalized adjacency become 0.
        D_inv_sqrt[torch.isinf(D_inv_sqrt)] = 0

        # A <- D^(-1/2) A D^(-1/2) via broadcasting, in place.
        A.mul_(D_inv_sqrt[:, None])
        A.mul_(D_inv_sqrt[None, :])

        # L = I - A_norm. Negate in place: `A = -A` would allocate a second n x n tensor.
        A.neg_()
        A.diagonal().add_(1)

        return A

def compute_heat_kernel_normalized_inplace(adjacency_matrix, t, device='cpu'):
    """
    Compute the exact heat kernel exp(-t * L) of the normalized graph Laplacian.

    L = I - D^(-1/2) A D^(-1/2) comes from compute_normalized_laplacian_inplace
    as a float32 tensor. The scaling by -t is applied in place, so apart from
    the matrix-exponential result no additional n x n buffer is allocated.

    :param adjacency_matrix: (n, n) adjacency matrix of the graph (not modified)
    :param t: diffusion time
    :param device: torch device for the computation, e.g. 'cpu' or 'cuda'
    :return: heat kernel matrix as a float32 NumPy array on the CPU
    """
    with torch.no_grad():
        L_norm = compute_normalized_laplacian_inplace(adjacency_matrix, device)

        # Scale in place: `-t * L_norm` would allocate another n x n tensor.
        L_norm.mul_(-t)
        heat_kernel_matrix = torch.matrix_exp(L_norm)

        # Release the Laplacian before the device->host copy to lower peak memory.
        # (Deleting `adjacency_matrix` here would free nothing: the caller still
        # holds a reference to it.)
        del L_norm

        return heat_kernel_matrix.cpu().numpy()


# Build a dense k-nearest-neighbour adjacency matrix directly from the data points.
def construct_knn_graph(X, k, device):
    """
    Constructs a k-nearest neighbors graph as a dense matrix with 0,1 entries,
    including self-distances.

    Row i has a 1 in the columns of the k points closest to X[i] (Euclidean
    distance, X[i] itself included), so the matrix is generally not symmetric.

    Parameters:
    X (np.ndarray): Input data array of shape (n_samples, n_features).
    k (int): Number of nearest neighbors to consider, including self-distance.
        Values larger than n_samples are clamped to n_samples.
    device (str or torch.device): Device to perform computation ('cuda', 'mps' or 'cpu').

    Returns:
    np.ndarray: float32 adjacency matrix of shape (n_samples, n_samples).
    """
    device = torch.device(device)
    X_tensor = torch.tensor(X, dtype=torch.float32, device=device)
    n_samples = X_tensor.size(0)
    # torch.topk requires k <= n_samples; with fewer points, every point is a neighbour.
    k = min(k, n_samples)

    # Full (n, n) distance matrix; this is the peak-memory step.
    distances = torch.cdist(X_tensor, X_tensor)
    _, indices = torch.topk(distances, k=k, largest=False)

    # Free the distance matrix before allocating the adjacency matrix.
    del distances, X_tensor
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    adjacency_matrix = torch.zeros(n_samples, n_samples, dtype=torch.float32, device=device)
    # Row i gets a 1 at each of its k neighbour indices.
    adjacency_matrix.scatter_(1, indices, 1.0)

    return adjacency_matrix.cpu().numpy()

def get_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        backend = 'cuda'
    elif torch.backends.mps.is_available():
        backend = 'mps'
    else:
        backend = 'cpu'
    return torch.device(backend)


# MARK: -Load and preprocess data
def load_and_preprocess_data(dataset, n_sample=None, random_state=None):
    """Load a benchmark dataset together with its clustering hyper-parameters.

    The "_knn" / "_heat" suffixes only select a graph kernel elsewhere in the
    pipeline; they are stripped here because the underlying data is the same.

    :param dataset: "pendigits", "har", "mnist_784", "EMNIST_Balanced",
        "letter" (fetched from OpenML, cached under data/) or "moons"
        (generated locally), optionally followed by "_knn" or "_heat".
    :param n_sample: if given, keep a random subset of at most this many samples.
    :param random_state: seed for the subsampling permutation. ``None`` uses
        NumPy's global random state.
    :return: tuple ``(X, y, kernel_sigma, knn, heat_knn)``.
    :raises ValueError: if ``dataset`` has no configured settings.
    :raises RuntimeError: if the dataset cannot be fetched from OpenML.
    """
    if dataset.endswith("_knn"):
        dataset = dataset[:-4]
    if dataset.endswith("_heat"):
        dataset = dataset[:-5]

    # Per-dataset settings: (kernel_sigma, normalize, knn, heat_knn).
    settings = {
        # sigma alternatives tried: 3.5, 5.6496
        "pendigits": (2.23606797749979, True, 1000, 500),
        # sigma alternatives tried: 5.5819; 10.414010843544517 on unnormalized data
        "har": (4.25, False, 2000, 500),
        "moons": (1.3369, False, 2000, 250),
        "mnist_784": (5.8301, False, 2000, 200),
        # 3.5 is good for n=100000; 11.510648727416992 was noted as "wang"
        # (presumably the Wang heuristic)
        "EMNIST_Balanced": (3.5, False, 2000, 250),
        # sigma alternatives tried: 2.23606797749979, 2.3351
        "letter": (3.0, True, 1000, 500),
    }
    if dataset not in settings:
        raise ValueError(f"no kernel settings configured for dataset {dataset!r}")
    kernel_sigma, normalize, knn, heat_knn = settings[dataset]

    if dataset == "moons":
        # Synthetic dataset: generated locally, nothing to download.
        X, y = make_moons(n_samples=5000, noise=0.1, random_state=42)
    else:
        try:
            X, y = fetch_openml(name=dataset, version=1, as_frame=False, return_X_y=True, data_home="data/", cache=True, parser="auto")
        except Exception as e:
            raise RuntimeError(f"could not load dataset {dataset!r} from OpenML") from e

    if dataset in ("mnist_784", "EMNIST_Balanced"):
        # Pixel intensities 0..255 -> [0, 1].
        X = X / 255

    if dataset == "letter":
        # Letter labels -> integer ids (index into the sorted unique labels).
        _, y = np.unique(y, return_inverse=True)

    if n_sample is not None:
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        keep = rng.permutation(X.shape[0])[:n_sample]
        X = X[keep]
        y = y[keep]

    if normalize:
        X = StandardScaler().fit_transform(X)
    return X, y, kernel_sigma, knn, heat_knn


def initialize_kmeans_plusplus(X, k, rng):
    """Pick ``k`` initial centroids from the rows of ``X`` with k-means++ seeding.

    The first centroid is drawn uniformly. Each subsequent one is drawn with
    probability proportional to the squared Euclidean distance from a point to
    its NEAREST already-chosen centroid (D^2 sampling). If every point coincides
    with a chosen centroid (all distances zero), the next one is drawn uniformly.

    :param X: (n_samples, n_features) data array.
    :param k: number of centroids to select.
    :param rng: numpy.random.Generator used for every random draw.
    :return: (k, n_features) array of centroids with the dtype of ``X``.
    """
    n_samples, n_features = X.shape
    centroids = np.empty((k, n_features), dtype=X.dtype)
    if k <= 0:
        return centroids

    def _sq_dist_to(point):
        # Squared Euclidean distance of every row of X to `point`, in float64.
        diff = (X - point).astype(np.float64, copy=False)
        return np.einsum('ij,ij->i', diff, diff)

    # Step 1: Randomly choose the first centroid using the provided rng
    centroids[0] = X[rng.integers(n_samples)]

    # Squared distance of each point to its closest centroid so far. It is updated
    # incrementally, so memory stays O(n) instead of O(n * k * d).
    closest_sq = _sq_dist_to(centroids[0])

    # Step 2 and 3: Choose the remaining centroids
    for c in range(1, k):
        total = closest_sq.sum()
        if total > 0:
            idx = rng.choice(n_samples, p=closest_sq / total)
        else:
            # All points sit on existing centroids; D^2 weights are undefined.
            idx = rng.integers(n_samples)
        centroids[c] = X[idx]
        np.minimum(closest_sq, _sq_dist_to(centroids[c]), out=closest_sq)

    return centroids


# MARK: -Evaluation
def evaluate(kms, X, K, labels, num_iters, n_clusters, batch_size, tau, n_runs=50, kernel_construction_time=0, K_graph=None, W=None):
    """Fit each estimator ``n_runs`` times (seeds 0..n_runs-1) and summarize the results.

    For every seed the estimator is initialized and fitted. The wall-clock time
    (fit plus the estimator-specific k-means++ initialization) and the NMI, ARI
    and ACC scores against ``labels`` are recorded.

    :param kms: mapping display name -> estimator instance.
    :param X: data matrix, passed to every non-sketch estimator's fit.
    :param K: precomputed kernel matrix used by the kernel estimators.
    :param labels: ground-truth labels.
    :param num_iters: recorded in the result records.
    :param n_clusters: used for initialization and recorded in the results.
    :param batch_size: recorded in the result records.
    :param tau: recorded in the result records.
    :param n_runs: number of seeds per estimator.
    :param kernel_construction_time: accepted for API compatibility; it is not
        added to the recorded train times (plot_results_bars draws it separately).
    :param K_graph: optional graph kernel, used only by WTKernelMiniBatchKMeans
        when ``W`` is also given.
    :param W: optional weights, used only together with ``K_graph``.
    :return: list of dicts, one per estimator, with the parameters,
        ``train_time_mean/_std`` and ``<score>_mean/_std`` entries.
    :raises ValueError: if ``K`` contains NaN and a kernel mini-batch estimator uses it.
    """
    evaluations = []
    # K is validated at most once per call and never inside a timed region: the
    # check is O(n^2) and would otherwise inflate every recorded train time.
    kernel_checked = False

    for name, km in kms.items():
        train_times = []
        print(f"Evaluating {name}")
        scores = defaultdict(list)

        is_kernel_mbk = isinstance(km, (KernelMiniBatchKMeans, WTKernelMiniBatchKMeans))
        use_graph = isinstance(km, WTKernelMiniBatchKMeans) and K_graph is not None and W is not None

        if is_kernel_mbk and not use_graph and not kernel_checked:
            if np.isnan(K).any():
                raise ValueError("kernel matrix K contains NaN values")
            kernel_checked = True

        for seed in tqdm(range(n_runs)):
            rng = np.random.default_rng(seed)  # per-seed RNG for the initializations
            km.random_state = seed

            t0 = time.time()
            if is_kernel_mbk:
                if use_graph:
                    init_labels, init_distances_squared, init_C = initialize_kernel_kmeans_plusplus(K_graph, n_clusters, rng)
                    km.fit(X, K_graph, init_labels, init_distances_squared, init_C, W)
                else:
                    init_labels, init_distances_squared, init_C = initialize_kernel_kmeans_plusplus(K, n_clusters, rng)
                    km.fit(X, K, init_labels, init_distances_squared, init_C)
            elif isinstance(km, BaseSketchKernelKMeans):
                km.fit(K)
            else:
                init_centroids = None
                if not isinstance(km, MiniBatchKMeans_sklearn):
                    # sklearn's fit treats its second argument as `y` and ignores it,
                    # so it runs its own initialization instead.
                    init_centroids = initialize_kmeans_plusplus(X, n_clusters, rng)
                km.fit(X, init_centroids)
            train_times.append(time.time() - t0)

            scores["NMI"].append(metrics.normalized_mutual_info_score(labels, km.labels_))
            scores["ARI"].append(metrics.adjusted_rand_score(labels, km.labels_))
            scores["ACC"].append(clustering_accuracy(labels, km.labels_))

        train_times = np.asarray(train_times)
        summary = {score_name: (np.mean(values), np.std(values)) for score_name, values in scores.items()}

        evaluation = {
            "estimator": name,
            "num_iters": num_iters,
            "n_clusters": n_clusters,
            "batch_size": batch_size,
            "tau": tau,
            "train_time_mean": train_times.mean(),
            "train_time_std": train_times.std()
        }
        for score_name, (mean_score, std_score) in summary.items():
            evaluation[score_name + "_mean"] = mean_score
            evaluation[score_name + "_std"] = std_score

        evaluations.append(evaluation)

        print(f"\n {name}, num_iters: {num_iters}, n_clusters: {n_clusters}, batch size: {batch_size}, tau: {tau}")
        for score_name, (mean_score, std_score) in summary.items():
            print(f"{score_name}: {mean_score:.3f} ± {std_score:.3f}")

    return evaluations



def get_device():
    """Return the preferred torch device (CUDA, then MPS, then CPU).

    NOTE: this re-binds the module name ``get_device`` with the same behaviour
    as the earlier definition of the same name.
    """
    candidates = (
        ('cuda', torch.cuda.is_available),
        ('mps', torch.backends.mps.is_available),
    )
    for backend, is_available in candidates:
        if is_available():
            return torch.device(backend)
    return torch.device('cpu')

def wang_heuristic_gpu(X, sample=None, **kwargs):
    """Estimate an RBF bandwidth as the root-mean-square pairwise distance.

    sigma = sqrt( mean_{s in S} mean_{x in X} ||x - s||^2 ), where S is all of
    X or ``sample`` rows drawn without replacement via ``np.random.choice``.

    The exact identity mean_x ||x - s||^2 = mean_x ||x - mu||^2 + ||s - mu||^2
    (mu = mean of X) is used. The cost is therefore O(n * d) in total rather
    than one full pass over X per sampled point. Centring on mu first also avoids
    the cancellation of the expanded ||x||^2 - 2 x.s + ||s||^2 form.

    :param X: (n_samples, n_features) data array.
    :param sample: optional number of rows to use as S.
    :param kwargs: ignored; accepted for call-site compatibility.
    :return: sigma as a Python float.
    """
    device = get_device()
    print(f"Using device: {device}")

    X_gpu = torch.tensor(X, device=device, dtype=torch.float32)

    with torch.no_grad():
        if sample is None:
            S = X_gpu
        else:
            indices = np.random.choice(range(X_gpu.shape[0]), sample, replace=False)
            S = X_gpu[torch.as_tensor(indices, device=device)]

        mu = X_gpu.mean(dim=0)
        # mean_x ||x - mu||^2: the spread of the data around its mean.
        spread = ((X_gpu - mu) ** 2).sum(dim=1).mean()
        # mean_s ||s - mu||^2 over the reference points.
        offset = ((S - mu) ** 2).sum(dim=1).mean()

        sigma = torch.sqrt(spread + offset)

    return sigma.cpu().item()


# Bar fill colours; colors_map below refers to them by index.
colors = [hex_code for _label, hex_code in (
    ("muted blue", '#1f77b4'),
    ("safety orange", '#ff7f0e'),
    ("muted purple", '#9467bd'),
    ("blue-teal", '#17becf'),
    ("cooked asparagus green", '#2ca02c'),
    ("brick red", '#d62728'),
    ("chestnut brown", '#8c564b'),
    ("middle gray", '#7f7f7f'),
    ("curry yellow-green", '#bcbd22'),
    # Add more as needed
)]

# Hatch patterns in positional form (not referenced by name elsewhere in view).
hatches = ["x", "", "//", "//", "", "//"]

# Estimator display name (without its "<n>." ordering prefix) -> bar hatch pattern.
hatches_map = dict([
    ("Kernel", "x"),
    ("MiniBatch Kernel", ""),
    ("$\\beta$-MiniBatch", "//"),
    ("Truncated $\\beta$-MiniBatch Kernel", "//"),
    ("MiniBatch", ""),
    ("$\\beta$-MiniBatch Kernel", "//"),
    ("ROS Kernel", ""),
    ("Nystrom", ""),
    ("SubGaussian", ""),
])

# Estimator display name -> bar colour. Variants of the same method share a
# colour and are told apart by hatch.
colors_map = dict([
    ("Kernel", colors[0]),
    ("Truncated $\\beta$-MiniBatch Kernel", colors[3]),
    ("$\\beta$-MiniBatch Kernel", colors[1]),
    ("MiniBatch Kernel", colors[1]),
    ("$\\beta$-MiniBatch", colors[2]),
    ("MiniBatch", colors[2]),
    ("ROS Kernel", colors[-3]),
    ("Nystrom", colors[-2]),
    ("SubGaussian", colors[-1]),
])

def plot_results(to_plot, dataset_sizes, body_tau=100, body_batch_size=1024):
    """Render all result figures into the ``results/`` directory (created if missing).

    Produces:
      * ``results/paper_body_results.png``: one panel per plain dataset (the
        ``_knn`` / ``_heat`` variants are excluded) for a single
        (``body_tau``, ``body_batch_size``) configuration;
      * ``results/<dataset>_results.png``: for every dataset, a grid with one
        row per tau and one column per batch size.

    The tau / batch-size grid is taken from the first DataFrame in ``to_plot``.

    :param to_plot: list of per-dataset result DataFrames with (at least) the
        columns 'dataset', 'tau', 'batch_size', 'estimator', 'train_time_mean'
        and '<metric>_mean' / '<metric>_std'.
    :param dataset_sizes: base dataset name -> number of samples.
    :param body_tau: tau shown in the paper-body figure.
    :param body_batch_size: batch size shown in the paper-body figure.
    """
    plt.rcParams.update({'font.size': 24})
    os.makedirs("results", exist_ok=True)

    # Unique batch sizes and taus define the per-dataset grid.
    batch_sizes = sorted(to_plot[0]['batch_size'].unique())
    taus = sorted(to_plot[0]['tau'].unique())

    _plot_paper_body_figure(to_plot, dataset_sizes, body_tau, body_batch_size)

    print(f"num batches: {len(batch_sizes)}, num taus: {len(taus)}")
    for dataset_group in to_plot:
        _plot_dataset_grid(dataset_group, dataset_sizes, taus, batch_sizes)


def _plot_paper_body_figure(to_plot, dataset_sizes, tau, batch_size):
    """Plot one panel per plain dataset (no _knn/_heat suffix) for one (tau, batch size)."""
    groups = [df for df in to_plot if not df['dataset'].iloc[0].endswith(("_knn", "_heat"))]
    if not groups:
        return

    # Two panels per row (a 2x2 grid for four datasets). squeeze=False keeps
    # `axes` 2-D even for a single panel, so flatten() always works.
    ncols = min(2, len(groups))
    nrows = (len(groups) + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(10 * ncols, 7.5 * nrows), sharey=True, squeeze=False)
    axes = axes.flatten()

    for i, data_set_group in enumerate(groups):
        print(i, data_set_group)

        dataset_name = data_set_group['dataset'].iloc[0]
        df_tau_bs = data_set_group[(data_set_group['tau'] == tau) & (data_set_group['batch_size'] == batch_size)]

        ax1, ax2 = plot_results_bars(df_tau_bs, axes[i], dataset_sizes, i == 0)

        # Score label on the left column, time label on the right column.
        column = i % ncols
        if column == 0:
            ax1.set_ylabel('Score')
        if column == ncols - 1:
            ax2.set_ylabel('Time (s) log scale')

        ax1.set_title(f"{dataset_name}")

    # Hide grid cells left over when the dataset count is odd.
    for ax in axes[len(groups):]:
        ax.set_visible(False)

    # Legend below the panels.
    fig.legend(loc='upper center', bbox_to_anchor=(0.5, 0.02), ncol=3, fontsize=24, frameon=False)
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    plt.savefig("results/paper_body_results.png", bbox_inches='tight')
    plt.close()


def _plot_dataset_grid(dataset_group, dataset_sizes, taus, batch_sizes):
    """Plot a taus x batch-sizes grid of bar charts for one dataset and save it."""
    dataset_name = dataset_group['dataset'].iloc[0]
    num_taus, num_batches = len(taus), len(batch_sizes)

    # squeeze=False always yields a 2-D array, so axes[i, j] works for any grid shape.
    fig, axes = plt.subplots(num_taus, num_batches, figsize=(7 * num_batches, 6 * num_taus), sharey=True, squeeze=False)
    fig.suptitle(f"Dataset: {dataset_name}", fontsize=40, y=0.99)

    # Common upper limit for every time axis so panels are comparable.
    maximum_time = dataset_group['train_time_mean'].max() * 1.08

    for i, tau in enumerate(taus):
        for j, batch_size in enumerate(batch_sizes):
            df_tau_bs = dataset_group[(dataset_group['tau'] == tau) & (dataset_group['batch_size'] == batch_size)]
            ax = axes[i, j]

            ax1, ax2 = plot_results_bars(df_tau_bs, ax, dataset_sizes, i == 0 and j == 0)

            # Batch size above each column.
            if i == 0:
                ax.set_title(f"Batch size: {batch_size}")

            # Tau label and score axis label on the first column of each row.
            if j == 0:
                ax1.text(-0.33, 0.5, f"$\\tau$: {tau}", va='center', ha='center',
                         rotation=0, transform=ax1.transAxes, fontsize=26)
                ax1.set_ylabel('Score')
            if j == num_batches - 1:
                ax2.set_ylabel('Time (s) log scale')

            ax2.set_ylim(0.1, maximum_time)

    # Legend below the grid.
    fig.legend(loc='upper center', bbox_to_anchor=(0.5, 0.0), ncol=5, fontsize=28, frameon=False)
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    plt.savefig(f"results/{dataset_name}_results.png", bbox_inches='tight')
    plt.close()








def plot_results_bars(df, ax1, dataset_sizes, set_labels=True):
    """Draw grouped bars (ARI, NMI, ACC, train time) for every estimator in ``df``.

    Score bars go on ``ax1``. Training time goes on a twin log-scale axis. For
    kernel-based estimators, the kernel construction time is drawn as a black
    segment underneath the time bar. For the sketching methods on plain (non
    _knn/_heat) datasets, that time is scaled by 150 / n, where
    n = ``dataset_sizes[dataset]``.

    :param df: results for one dataset and one (tau, batch size). Estimator names
        carry a "<digit>." prefix that is stripped for lookup and display.
    :param ax1: matplotlib Axes for the score bars.
    :param dataset_sizes: base dataset name -> number of samples.
    :param set_labels: attach legend labels (only needed once per figure).
    :return: ``(ax1, ax2)``, the score axis and the twin time axis.
    """
    metric_names = ["ARI", "NMI", "ACC"]
    time_metric = "train_time"

    # Fixed left-to-right bar order inside every metric group.
    estimators_ordered = [
        "1.Kernel",
        "9.Truncated $\\beta$-MiniBatch Kernel",
        "2.$\\beta$-MiniBatch Kernel",
        "3.MiniBatch Kernel",
        "4.$\\beta$-MiniBatch",
        "5.MiniBatch",
        "6.ROS Kernel",
        "7.Nystrom",
        "8.SubGaussian",
    ]
    # Methods whose time bar includes the kernel construction time.
    kernel_methods = {
        "Truncated $\\beta$-MiniBatch Kernel",
        "$\\beta$-MiniBatch Kernel",
        "Kernel",
        "MiniBatch Kernel",
        "SubGaussian",
        "ROS Kernel",
        "Nystrom",
    }
    sketch_methods = {"SubGaussian", "Nystrom", "ROS Kernel"}
    legend_names = {
        "SubGaussian": "Kernel sketching (Sub-Gaussian)",
        "Nystrom": "Kernel sketching (Nystrom)",
        "ROS Kernel": "Kernel sketching (ROS)",
    }

    ax2 = ax1.twinx()  # second y-axis for train_time
    ax2.set_yscale('log')

    n_metrics = len(metric_names) + 1  # including train_time
    bar_width = 0.4
    positions = np.arange(n_metrics) * (len(estimators_ordered) * bar_width + 0.5)
    positions[-1] += 0.1
    # Bar j of metric group i is centred at positions[i] + group_offset + j * bar_width.
    group_offset = -0.5

    present = set(df['estimator'].values)

    # Base dataset name, and whether it is a plain dataset (no _knn/_heat suffix).
    dataset, is_plain_dataset = None, False
    if not df.empty:
        dataset = df["dataset"].iloc[0]
        is_plain_dataset = True
        if dataset.endswith("_knn"):
            dataset = dataset[:-4]
            is_plain_dataset = False
        if dataset.endswith("_heat"):
            dataset = dataset[:-5]
            is_plain_dataset = False

    for i, metric in enumerate(metric_names + [time_metric]):
        ax = ax2 if metric == time_metric else ax1

        for j, name in enumerate(estimators_ordered):
            if name not in present:
                continue

            rows = df[df['estimator'] == name]
            position = positions[i] + j * bar_width + group_offset
            alg_name = name[2:]  # drop the "<digit>." ordering prefix

            kernel_construction_time = 0
            if metric == time_metric and alg_name in kernel_methods:
                kernel_construction_time = df["kernel_construction_time"].iloc[0]
                if alg_name in sketch_methods and is_plain_dataset:
                    kernel_construction_time = kernel_construction_time * 150 / dataset_sizes[dataset]

            # Kernel construction time: black segment at the bottom of the bar.
            if kernel_construction_time > 0:
                ax.bar(position, kernel_construction_time, bar_width, color="black", edgecolor='black', linewidth=1)

            label = legend_names.get(alg_name, alg_name) if i == 0 and set_labels else None
            ax.bar(position, rows[metric + "_mean"].iloc[0], bar_width, color=colors_map[alg_name],
                   label=label,
                   yerr=rows[metric + "_std"].iloc[0],
                   capsize=5, hatch=hatches_map[alg_name], bottom=kernel_construction_time, edgecolor='black', linewidth=1)

    # Centre each tick under its whole group of bars (first to last bar centre).
    ax1.set_xticks(positions + group_offset + (len(estimators_ordered) - 1) * bar_width / 2)
    ax1.set_xticklabels(metric_names + ["time"])

    return ax1, ax2

    
#%%
# MARK: -Main



def gpu_pairwise_kernels_batch(X, Y=None, metric='rbf', gamma=None, batch_size=1024):
    """
    Compute the RBF kernel matrix exp(-gamma * ||x - y||^2) block-wise with PyTorch.

    Args:
        X (np.ndarray): Input data array of shape (n_x, n_features).
        Y (np.ndarray or None): Optional second input array of shape (n_y, n_features).
            If None, the kernel of X with itself is computed.
        metric (str): The kernel metric to use. Currently only 'rbf' is supported.
        gamma (float or None): Kernel coefficient for 'rbf'. If None, it defaults to 1 / n_features.
        batch_size (int): Edge length of the square blocks computed on the device.

    Returns:
        np.ndarray: float32 kernel matrix of shape (n_x, n_y), assembled on the host.

    Raises:
        ValueError: if ``metric`` is not 'rbf'.
    """
    if metric != 'rbf':
        raise ValueError(f"Unsupported metric '{metric}'. Currently only 'rbf' is supported.")

    # Same preference order as get_device(): CUDA, then MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    with torch.no_grad():
        X_t = torch.tensor(X, dtype=torch.float32, device=device)
        Y_t = X_t if Y is None else torch.tensor(Y, dtype=torch.float32, device=device)

        if gamma is None:
            gamma = 1.0 / X_t.shape[1]

        n_samples_X = X_t.shape[0]
        n_samples_Y = Y_t.shape[0]

        # Only one block lives on the device at a time; the full matrix is on the host.
        K = np.zeros((n_samples_X, n_samples_Y), dtype=np.float32)

        # Squared norms of Y are shared by every row block: compute them once.
        Y_sq = (Y_t ** 2).sum(dim=-1).view(1, -1)

        for i in range(0, n_samples_X, batch_size):
            end_i = min(i + batch_size, n_samples_X)
            X_batch = X_t[i:end_i]
            X_sq = (X_batch ** 2).sum(dim=-1).view(-1, 1)

            for j in range(0, n_samples_Y, batch_size):
                end_j = min(j + batch_size, n_samples_Y)
                Y_batch = Y_t[j:end_j]

                # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y. Rounding can make this
                # slightly negative (e.g. on the diagonal), which would produce
                # kernel values > 1, so clamp at 0.
                sq_dist = X_sq + Y_sq[:, j:end_j] - 2 * torch.mm(X_batch, Y_batch.T)
                sq_dist.clamp_(min=0)

                K[i:end_i, j:end_j] = torch.exp(-gamma * sq_dist).cpu().numpy()

        return K



if __name__ == "__main__":
    # Command-line entry point:
    #   run  - for every dataset variant (kNN-graph kernel "_knn", heat kernel "_heat",
    #          plain RBF kernel) build the kernel matrix, evaluate all estimators over
    #          the batch-size / tau grid and write one CSV per variant into results/.
    #   plot - read result CSVs (from -f, or every *.csv in results/) and draw figures.

    parser = argparse.ArgumentParser()
    parser.add_argument("mode", type=str, help="Mode of operation: 'run' or 'plot'", choices=["run","plot"])
    # plot mode: CSV files to plot. run mode: the first name is used as the output
    # file name inside results/, suffixed with the dataset variant (see below).
    parser.add_argument("-f", "--files", nargs="+", help="List of result files to plot", default=[])

    args = parser.parse_args()

    mode = args.mode
    result_files = args.files

    # Experiment grid (plot mode also uses n_iters to select rows).
    n_runs = 10
    n_iters = [200]
    batch_size_values = [256, 512, 1024, 2048]
    taus = [50, 100, 200, 300]

    # dataset variant -> sqrt(max diagonal entry of its kernel matrix)
    dataset_gammas = {}

    if mode == "run":

        skip_full = False  # set True to skip the full-batch kernel k-means baseline
        base_dataset_names = [
            "pendigits",
            "har",
            "mnist_784",
            "letter",
        ]
        # Every base dataset is run as its "_knn", "_heat" and plain variant, in that order.
        dataset_names = [name + suffix for name in base_dataset_names for suffix in ("_knn", "_heat", "")]

        os.makedirs("results", exist_ok=True)

        for dataset_name in dataset_names:
            n = None
            X, Y, kernel_sigma, knn, heat_knn = load_and_preprocess_data(dataset_name, n_sample=n)
            if n is None:
                n = X.shape[0]

            print(dataset_name)
            print("dataset size", X.shape, "kernel sigma", kernel_sigma)
            n_clusters = np.unique(Y).shape[0]
            print(f"num clusters: {n_clusters}")
            n_clusters_values = [n_clusters]

            formatted_datetime = datetime.now().strftime('%Y_%m_%d__%H_%M_%S')
            print(formatted_datetime)
            print(f"dataset: {dataset_name}")
            evaluations = []
            # RBF bandwidth sigma -> gamma in exp(-gamma * ||x - y||^2)
            gamma = 1 / (2 * kernel_sigma**2)

            # Build the kernel matrix and time its construction.
            K = None
            kernel_construction_time = 0.0

            if dataset_name.endswith("_heat"):
                t0 = time.time()
                # downsample mnist to fit in memory
                if dataset_name == "mnist_784_heat":
                    n = 20000
                    X = X[:n]
                    Y = Y[:n]
                    n = X.shape[0]
                K = construct_knn_graph(X, heat_knn, get_device())
                torch.cuda.empty_cache()
                K = compute_heat_kernel_normalized_inplace(K, t=8.0, device=get_device())
                torch.cuda.empty_cache()
                # The kNN adjacency is built row-wise and is generally not symmetric; symmetrize.
                K = (K + K.T) / 2

                kernel_construction_time = time.time() - t0

            elif dataset_name.endswith("_knn"):
                t0 = time.time()
                K = construct_knn_graph(X, knn, get_device())
                # Construct the k-nearest neighbour graph using D^(-1)AD^(-1) normalization with shift
                D_inv = 1 / K.sum(axis=1)
                K *= D_inv[:, None]
                K *= D_inv[None, :]

                # subtract a small value to the diagonal to avoid numerical issues for mnist_784_knn
                if dataset_name == "mnist_784_knn":
                    K[np.diag_indices_from(K)] -= 0.000005
                kernel_construction_time = time.time() - t0
            else:
                t0 = time.time()
                K = gpu_pairwise_kernels_batch(X, metric='rbf', gamma=gamma, batch_size=7000)
                kernel_construction_time = time.time() - t0
                print(f"Computed kernel matrix in {kernel_construction_time:.3f} s")

            kernel_gamma = np.sqrt((K.diagonal().max())).item()
            dataset_gammas[dataset_name] = kernel_gamma
            print(f"Kernel gamma: {kernel_gamma}")

            for num_iters, n_clusters, batch_size in product(n_iters, n_clusters_values, batch_size_values):
                print("#"*20)
                mbkk_newlr = KernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, new_lr=True)
                mbkk_oldlr = KernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, new_lr=False)
                mbk_newlr = MiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, dynamic_batching=False, new_lr=True)
                mbk_oldlr = MiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, dynamic_batching=False, new_lr=False)
                nykkm = NystromKernelKMeans(n_clusters=n_clusters, n_iterations=num_iters)
                sgkkm = SubGaussianKernelKMeans(n_clusters=n_clusters, n_iterations=num_iters)
                roskkm = ROSKernelKMeans(n_clusters=n_clusters, n_iterations=num_iters)

                # Only the truncated estimator depends on tau.
                for tau in taus:
                    wtmbkk_newlr = WTKernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, tau=tau, lazy=False)
                    evaluations += evaluate({"9.Truncated $\\beta$-MiniBatch Kernel": wtmbkk_newlr}, X, K, Y, num_iters, n_clusters, batch_size, tau, n_runs, kernel_construction_time, K_graph=None, W=None)

                mbks = {
                        "4.$\\beta$-MiniBatch": mbk_newlr,
                        "5.MiniBatch": mbk_oldlr,
                        "2.$\\beta$-MiniBatch Kernel": mbkk_newlr,
                        "3.MiniBatch Kernel": mbkk_oldlr,
                        "6.ROS Kernel": roskkm,
                        "7.Nystrom": nykkm,
                        "8.SubGaussian": sgkkm,
                        }

                # The remaining estimators ignore tau: evaluate once and copy the
                # records for every tau so each (batch size, tau) cell is complete.
                temp_evals = evaluate(mbks, X, K, Y, num_iters, n_clusters, batch_size, taus[0], n_runs, kernel_construction_time, K_graph=None, W=None)

                for evaluation in temp_evals:
                    for tau in taus:
                        evaluation_copy = deepcopy(evaluation)
                        evaluation_copy["tau"] = tau
                        evaluations += [evaluation_copy]

            if not skip_full:
                # Full-batch kernel k-means baseline; batch_size / n_iterations are taken
                # from the last grid entry (the values the grid loop ends on).
                kk = KernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size_values[-1], n_iterations=n_iters[-1], new_lr=True, full_batch=True)
                kk_eval = evaluate({"1.Kernel": kk}, X, K, Y, 200, n_clusters, batch_size=None, tau=None, n_runs=3, kernel_construction_time=kernel_construction_time)

                # duplicate the kernel results for all batch sizes and taus (since they are not relevant)
                for b in batch_size_values:
                    kk_eval1 = deepcopy(kk_eval)
                    kk_eval1[0]["batch_size"] = b
                    for tau in taus:
                        kk_eval2 = deepcopy(kk_eval1)
                        kk_eval2[0]["tau"] = tau
                        evaluations += kk_eval2

            # Convert evaluations to DataFrame
            df = pd.DataFrame(evaluations)
            df['dataset'] = dataset_name
            df['kernel_construction_time'] = kernel_construction_time

            if result_files:
                # One -f name is shared by every dataset variant of the run; suffix it
                # with the variant so each CSV does not overwrite the previous one.
                root, ext = os.path.splitext(result_files[0])
                path = os.path.join("results", f"{root}_{dataset_name}{ext or '.csv'}")
            else:
                path = os.path.join("results", f"{dataset_name}_{formatted_datetime}_results.csv")
            df.to_csv(path, index=False)

            # Drop the (possibly n x n, multi-GB) kernel before loading the next dataset.
            K = None

        # save the gammas for each dataset to a json file
        with open(os.path.join("results", "dataset_gammas.json"), "w") as f:
            json.dump(dataset_gammas, f)

    elif mode == "plot":
        datasets = ["pendigits", "har", "mnist_784", "letter"]

        # Sample count per base dataset, used to rescale sketching kernel times.
        dataset_sizes = {
            d: load_and_preprocess_data(d, n_sample=None)[0].shape[0] for d in datasets
        }

        if result_files:
            filepaths = result_files
        else:
            directory = "results/"
            filepaths = [os.path.join(directory, filename) for filename in os.listdir(directory) if filename.endswith(".csv")]

        # Only rows for the single configured num_iters value are plotted.
        num_iters = n_iters[0]
        to_plot_filtered = []
        for filename in filepaths:
            df = pd.read_csv(filename)
            df_filtered = df[df['num_iters'] == num_iters]
            if not df_filtered.empty:
                to_plot_filtered.append(df_filtered)

        # Only plot if there is data to plot
        if to_plot_filtered:
            plot_results(to_plot_filtered, dataset_sizes)

