import argparse
from collections import defaultdict
import json
import os
import time
import numpy as np
import hdf5storage
import warnings
from sklearn.neighbors import kneighbors_graph
import torch

from tqdm import tqdm
warnings.filterwarnings("ignore", message="Graph is not fully connected, spectral embedding may not work as expected.")
warnings.filterwarnings("ignore", message="Array is not symmetric, and will be converted to symmetric by average with its transpose.")

import sys
sys.path.append("rust/")
import fast_kernel_coreset_sampling as fcp

from sklearn.datasets import fetch_openml


from scipy import sparse
import math

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from sklearn.metrics.pairwise import pairwise_kernels

import networkx as nx
import graph_tool.all as gt
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*adjacency_matrix will return a scipy.sparse array.*")
from sklearn.preprocessing import StandardScaler
import stag.random

# Friendster: V=65,608,366, E=1,806,067,135, d_avg = 27.5 #communities = 5000+
# LiveJournal: V=3,997,962, E=34,681,189, d_avg = 8.7 #communities = 5000+
# Wikipedia topcats: V=1,791,489, E=28,511,807, d_avg = 15.9 #communities = 17,364


def compute_normalized_cut(adjacency_matrix, labels,k):
    """Return the mean over the ``k`` clusters of ``cut / volume``.

    Args:
        adjacency_matrix: scipy sparse matrix in compressed (CSR/CSC) storage.
            ``indptr``/``indices``/``data`` are read directly and slice ``i``
            of that storage is treated as the neighbour list of node ``i``
            (assumes the matrix is symmetric -- TODO confirm with callers).
        labels: numpy integer array with one cluster id in ``[0, k)`` per node.
        k: number of clusters.

    Returns:
        The average of ``cluster_cut / cluster_volume`` over all ``k``
        clusters. Clusters with zero volume contribute 0 to the average.
    """
    cluster_volumes = np.zeros(k)
    cluster_cuts = np.zeros(k)
    degrees = adjacency_matrix.sum(axis=1).A1
    indptr = adjacency_matrix.indptr
    indices = adjacency_matrix.indices
    data = adjacency_matrix.data
    n = adjacency_matrix.shape[0]
    for i in range(n):
        node_label = labels[i]
        cluster_volumes[node_label] += degrees[i]
        start, end = indptr[i], indptr[i + 1]
        neighbour_indices = indices[start:end]
        neighbour_values = data[start:end]
        neighbour_labels = labels[neighbour_indices]
        # NOTE(review): the `> i` filter counts every cut edge once and credits
        # it only to the cluster of the lower-indexed endpoint. Kept as-is so
        # previously recorded numbers stay comparable; confirm this is intended.
        mask = np.logical_and(neighbour_labels != node_label, neighbour_indices > i)
        cluster_cuts[node_label] += neighbour_values[mask].sum()
    # Masked division: np.where(v > 0, c / v, 0) would still evaluate 0/0 for
    # empty clusters and emit a RuntimeWarning.
    ncuts = np.divide(
        cluster_cuts,
        cluster_volumes,
        out=np.zeros(k),
        where=cluster_volumes > 0,
    )
    return ncuts.mean()

def gpu_pairwise_kernels_batch(X, Y=None, metric='rbf', gamma=None, batch_size=1024):
    """
    Evaluate an RBF kernel matrix block by block with PyTorch.

    The computation runs on 'mps' if available, otherwise on 'cuda' if
    available, otherwise on the CPU. Each block is copied back into a
    host-side float32 numpy array as soon as it is computed.

    Args:
        X (np.ndarray): Left-hand data, one sample per row.
        Y (np.ndarray or None): Right-hand data. When None, X is used.
        metric (str): Kernel name; only 'rbf' is accepted.
        gamma (float or None): RBF coefficient. Defaults to 1 / n_features.
        batch_size (int): Maximum number of rows/columns per block.

    Returns:
        np.ndarray: float32 matrix of shape (len(X), len(Y)) holding
        exp(-gamma * ||x - y||^2).

    Raises:
        ValueError: If ``metric`` is not 'rbf'.
    """
    if torch.backends.mps.is_available():
        device = torch.device('mps')
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device('cpu')

    with torch.no_grad():
        lhs = torch.tensor(X, dtype=torch.float32, device=device)
        rhs = lhs if Y is None else torch.tensor(Y, device=device, dtype=torch.float32)

        if metric != 'rbf':
            raise ValueError(f"Unsupported metric '{metric}'. Currently only 'rbf' is supported.")

        if gamma is None:
            gamma = 1.0 / lhs.shape[1]

        n_rows, n_cols = lhs.shape[0], rhs.shape[0]

        # The result lives on the host; only one block at a time is on the device.
        kernel = np.zeros((n_rows, n_cols), dtype=np.float32)

        for row_start in range(0, n_rows, batch_size):
            row_stop = min(row_start + batch_size, n_rows)
            lhs_block = lhs[row_start:row_stop]
            lhs_sq_norms = (lhs_block ** 2).sum(axis=-1).view(-1, 1)

            for col_start in range(0, n_cols, batch_size):
                col_stop = min(col_start + batch_size, n_cols)
                rhs_block = rhs[col_start:col_stop]
                rhs_sq_norms = (rhs_block ** 2).sum(axis=-1).view(1, -1)

                # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>
                block = torch.exp(-gamma * (lhs_sq_norms + rhs_sq_norms - 2 * torch.mm(lhs_block, rhs_block.T)))
                kernel[row_start:row_stop, col_start:col_stop] = block.cpu().numpy()

        return kernel



def compute_knn_graph_torch_cluster(data, n_neighbors=5, device='cpu'):
    """Build a symmetrised k-nearest-neighbour adjacency matrix with torch.

    A dense pairwise Euclidean distance matrix is computed on ``device`` and
    the ``n_neighbors`` smallest entries of every row are taken as neighbours
    (a point is normally its own nearest neighbour).

    For every (point, neighbour) pair a unit entry is added in both
    directions, and a further unit entry is added on the diagonal when the
    neighbour is the point itself. Duplicate entries are summed when the
    sparse matrix is assembled, so mutual neighbours end up with weight 2 and
    a self-neighbour with weight 3.

    Args:
        data: array of shape (n_samples, n_features).
        n_neighbors: number of nearest neighbours taken per row.
        device: torch device used for the distance computation.

    Returns:
        scipy.sparse.csc_matrix of shape (n_samples, n_samples).
    """
    points = torch.tensor(data, dtype=torch.float32).to(device)
    pairwise = torch.cdist(points, points, p=2)
    _, nearest = torch.topk(pairwise, n_neighbors, largest=False)
    nearest = nearest.cpu().numpy()

    n_samples = data.shape[0]

    # One (source, target) pair per neighbour slot, flattened row by row.
    sources = np.repeat(np.arange(n_samples), n_neighbors)
    targets = nearest.reshape(-1)
    # Slots where a point picked itself get one extra diagonal entry.
    self_hits = sources[sources == targets]

    rows = np.concatenate([sources, targets, self_hits])
    cols = np.concatenate([targets, sources, self_hits])
    vals = np.ones(rows.shape[0], dtype=int)

    # Duplicate (row, col) pairs are summed by the sparse constructor.
    return sparse.csc_matrix((vals, (rows, cols)), shape=(n_samples, n_samples))


def signless_laplacian_and_D_dense(A):
    """Dense version of the matrix used by the power-method clustering.

    With D the diagonal matrix of row sums of A and
    N = D^(-1/2) (D - A) D^(-1/2), this returns M = I - N / 2, which equals
    (I + D^(-1/2) A D^(-1/2)) / 2.

    Args:
        A: square dense numpy array (adjacency / affinity matrix). Every row
            sum must be non-zero, otherwise the inverse square root is
            infinite.

    Returns:
        Tuple ``(M, D)`` of dense numpy arrays.
    """
    # `n` was previously undefined here, so every call raised NameError.
    n = A.shape[0]
    D = np.diag(A.sum(axis=1))
    D_inv_half = np.diag(1 / np.sqrt(D.diagonal()))
    L = D - A
    N = D_inv_half @ L @ D_inv_half
    M = np.eye(n) - (0.5)*N
    return M, D

def signless_laplacian_and_D_sparse(A,D=None):
    """Sparse counterpart of ``signless_laplacian_and_D_dense``.

    Computes M = I - N / 2 where N = D^(-1/2) (D - A) D^(-1/2).

    Args:
        A: square scipy sparse matrix.
        D: optional sparse diagonal degree matrix. When omitted it is built
            from the row sums of ``A``.

    Returns:
        Tuple ``(M, D)``; ``D`` is the matrix that was passed in, or the one
        built here.
    """
    if D is None:
        D = sparse.diags(A.sum(axis=1).A1)
    inv_sqrt_degrees = sparse.diags(1 / np.sqrt(D.diagonal()))
    normalised_laplacian = inv_sqrt_degrees @ (D - A) @ inv_sqrt_degrees
    identity = sparse.eye(A.shape[0])
    return identity - (0.5)*normalised_laplacian, D

def fast_spectral_cluster(M,D, k: int):
    """Cluster with a short power iteration followed by k-means.

    Starting from a random Gaussian block of min(k, ceil(log2 k)) vectors,
    ``M`` is applied 10 * ceil(log2(n / k)) times. After every application
    each column is projected away from the normalised vector sqrt(D 1), and
    the resulting rows are clustered with sklearn's KMeans.

    Args:
        M: (n, n) matrix to iterate with; presumably the first output of
            ``signless_laplacian_and_D_sparse`` / ``_dense`` -- verify against
            callers.
        D: (n, n) degree matrix matching ``M``.
        k: number of clusters.

    Returns:
        Array of n cluster labels (``KMeans.labels_``).
    """
    n = M.shape[0]
    n_vectors = min(k, math.ceil(math.log(k, 2)))
    n_power_steps = 10 * math.ceil(math.log(n / k, 2))
    embedding = np.random.normal(size=(n, n_vectors))

    # sqrt(D 1) is a known eigenvector that carries no cluster information,
    # so the iterate is kept orthogonal to it.
    trivial_direction = np.sqrt(D @ np.full((n,), 1))
    trivial_norm = np.linalg.norm(trivial_direction)
    if trivial_norm > 0:
        trivial_direction /= trivial_norm

    for _ in range(n_power_steps):
        embedding = M @ embedding

        # Remove the component along the trivial direction, column by column.
        for col in range(n_vectors):
            embedding[:, col] -= (trivial_direction.transpose() @ embedding[:, col]) * trivial_direction

    model = KMeans(n_clusters=k, n_init='auto')
    model.fit(embedding)
    return model.labels_


def _k_init_metric(K,W,k,random_gen=None):
    """Weighted k-means++ seeding carried out in kernel space.

    Squared distances are obtained from the kernel alone:
    ||phi(x) - phi(c)||^2 = K(x, x) + K(c, c) - 2 K(x, c).

    Args:
        K: (n, n) scipy sparse kernel matrix.
        W: length-n array of point weights.
        k: number of centers to pick.
        random_gen: optional ``numpy.random.Generator``; a fresh default
            generator is created when omitted.

    Returns:
        Tuple ``(labels, distances, C)``: index of the closest chosen center
        for every point, the matching squared distance, and the list of
        chosen point ids.
    """
    if random_gen is None:
        random_gen = np.random.default_rng()
    n = K.shape[0]
    all_points = list(range(n))
    # Ids of the points selected as centers so far.
    chosen = []

    def offset_to(center):
        # K(c, c) - 2 K(c, x) for every x; K(x, x) is already stored.
        return (K[center, center] - 2*K[center, all_points]).A1

    # The first center is drawn uniformly.
    picked = random_gen.choice(n, 1, replace=False)
    chosen.append(picked.item())

    # Column j will hold the squared distance of every point to center j.
    sq_dist_to_center = np.tile(K.diagonal()[:, np.newaxis], (1, k))
    sq_dist_to_center[:, 0] += offset_to(picked)
    sq_dist_to_closest = sq_dist_to_center[:, 0]

    for slot in range(1, k):
        # Draw the next center proportionally to weight * squared distance.
        weighted = sq_dist_to_closest*W
        probs = weighted/weighted.sum()

        picked = random_gen.choice(n, 1, p=probs)
        chosen.append(picked.item())

        sq_dist_to_center[:, slot] += offset_to(picked)
        sq_dist_to_closest = np.minimum(sq_dist_to_closest, sq_dist_to_center[:, slot])

    labels = sq_dist_to_center.argmin(axis=1)
    return labels, sq_dist_to_closest, chosen


def get_inertia(distances_squared,W):
    """Return the weighted clustering cost: the sum of ``W[i] * distances_squared[i]``."""
    weighted = distances_squared*W
    return weighted.sum()


def compute_new_labels_and_distances_to_centers(K,W,labels,k,X_to_implied_centers):
    """One assignment step of weighted kernel k-means.

    For every non-empty cluster j, with implied center
    c_j = sum_{y in j} W[y] phi(y) / sum_{y in j} W[y], column j of
    ``X_to_implied_centers`` is overwritten with
    ||phi(x) - c_j||^2 = K(x, x) - 2 <phi(x), c_j> + <c_j, c_j>.
    Columns of empty clusters are left untouched, so they keep whatever the
    caller's scratch matrix already contained.

    Args:
        K: (n, n) scipy sparse kernel matrix.
        W: length-n array of point weights.
        labels: length-n array with the current cluster id of each point.
        k: number of clusters.
        X_to_implied_centers: (n, k) scratch array, modified in place.

    Returns:
        Tuple ``(new_labels, X_to_closest_center)``: the row-wise argmin and
        the row-wise minimum of ``X_to_implied_centers``.
    """
    self_similarity = K.diagonal()

    for cluster in range(k):
        members = labels == cluster
        if not members.any():
            continue

        member_weights = W[members]
        total_weight = member_weights.sum()

        # <phi(x), c_j> for every point x.
        cross_term = ((K[:, members].multiply(member_weights)).sum(axis=1)/total_weight).A1
        # <c_j, c_j>
        center_norm_squared = (
            K[members][:, members].multiply(np.outer(member_weights, member_weights))
        ).sum()/(total_weight*total_weight)

        X_to_implied_centers[:, cluster] = self_similarity - 2*cross_term + center_norm_squared

    return X_to_implied_centers.argmin(axis=1), X_to_implied_centers.min(axis=1)


def KKmeans_as_a_function(K,W,n_clusters,n_iters = 100, tol=1e-5, random_gen=None,verbose=False, early_stop=True):
    """Run weighted kernel k-means and return the final labels.

    Centers are never stored explicitly; they are implied by the current
    labels. Seeding is done by ``_k_init_metric``, after which up to
    ``n_iters`` assignment steps are performed.

    Args:
        K: (n, n) scipy sparse kernel matrix.
        W: length-n array of point weights.
        n_clusters: number of clusters.
        n_iters: maximum number of assignment steps.
        tol: relative improvement threshold; with ``early_stop`` the loop
            ends as soon as the new inertia exceeds ``(1 - tol)`` times the
            previous one.
        random_gen: optional ``numpy.random.Generator`` used for seeding.
        verbose: print the inertia after every step.
        early_stop: enable the ``tol`` based stopping rule.

    Returns:
        Length-n array of cluster labels.
    """
    if random_gen is None:
        random_gen = np.random.default_rng()

    labels, sq_dists, _ = _k_init_metric(K, W, n_clusters, random_gen)
    inertia = get_inertia(sq_dists, W)

    # Scratch (n, k) distance matrix, allocated once and reused by every step.
    scratch = np.zeros((K.shape[0], n_clusters))

    for _ in range(n_iters):
        labels, sq_dists = compute_new_labels_and_distances_to_centers(K, W, labels, n_clusters, scratch)
        new_inertia = get_inertia(sq_dists, W)
        if verbose:
            print(f"inertia: {new_inertia:.12f}")
        if early_stop and new_inertia > (1-tol)*inertia:
            break
        inertia = new_inertia

    return labels


def labels_from_distances_to_implied_coreset_centers(big_k, coreset, coreset_labels, coreset_weights,k):
    """Label every point with the index of its closest implied coreset center.

    The distances come from ``distances_to_centers``; the label of a point is
    the argmin of its row.
    """
    return np.argmin(
        distances_to_centers(big_k, coreset, coreset_labels, coreset_weights, k),
        axis=1,
    )


def fast_labels_from_distances_to_implied_coreset_centers(data,indices,indptr,nnz_per_col, coreset, coreset_labels, coreset_weights,k,n):
    """Forward the labelling of all ``n`` points to the compiled extension.

    The kernel is passed as raw compressed-sparse arrays (``data``,
    ``indices``, ``indptr``, ``nnz_per_col``). The result of
    ``fcp.fast_distances_to_centers`` is returned unchanged.
    """
    return fcp.fast_distances_to_centers(
        data, indices, indptr, nnz_per_col,
        coreset, coreset_labels, coreset_weights,
        k, n,
    )

def get_column_neighbours_and_values(csc_mat, x):
    """Return the stored row indices and values of column ``x`` of a CSC matrix.

    Args:
        csc_mat: matrix exposing CSC-style ``indptr``, ``indices`` and ``data``.
        x: column index.

    Returns:
        Tuple ``(neighbours, values)`` of array slices taken from
        ``csc_mat.indices`` and ``csc_mat.data``.
    """
    # indptr[x]:indptr[x + 1] delimits column x inside indices/data.
    column = slice(csc_mat.indptr[x], csc_mat.indptr[x + 1])
    return csc_mat.indices[column], csc_mat.data[column]


def get_coreset_weights_labels_indices_and_values_of_neighbours(neighbours, neighbour_values, coreset, coreset_labels, coreset_weights):
    """Collect coreset information for the neighbours that belong to the coreset.

    All four returned arrays are parallel: entry ``t`` of each describes the
    same coreset point. They are ordered by position in ``coreset``.

    Previously the values were returned in ``neighbours`` order while the
    other three arrays were in ``coreset`` order, so they only lined up when
    the coreset happened to be sorted like the neighbour list.

    Args:
        neighbours: array of point ids adjacent to some point.
        neighbour_values: kernel values parallel to ``neighbours``.
        coreset: array of coreset point ids.
        coreset_labels: numpy array of labels parallel to ``coreset``.
        coreset_weights: numpy array of weights parallel to ``coreset``.

    Returns:
        Tuple ``(weights_of_neighbours, labels_of_neighbours,
        coreset_neighbour_indices, values_of_neighbours)``.
    """
    neighbours = np.asarray(neighbours)
    neighbour_values = np.asarray(neighbour_values)
    coreset = np.asarray(coreset)

    # Positions inside the coreset of the points that are also neighbours.
    coreset_neighbour_indices = np.nonzero(np.isin(coreset, neighbours))[0]
    labels_of_neighbours = coreset_labels[coreset_neighbour_indices]
    weights_of_neighbours = coreset_weights[coreset_neighbour_indices]

    # Look up, for each matched coreset point, where it sits in `neighbours`
    # so its value can be fetched in coreset order.
    order = np.argsort(neighbours, kind="stable")
    matched_ids = coreset[coreset_neighbour_indices]
    positions = order[np.searchsorted(neighbours, matched_ids, sorter=order)]
    values_of_neighbours = neighbour_values[positions]

    return weights_of_neighbours, labels_of_neighbours, coreset_neighbour_indices, values_of_neighbours



def fast_distances_to_centers(big_k,coreset,coreset_labels,coreset_weights,k):
    """Python reference for labelling every point by its closest implied center.

    (The same computation was reimplemented in Rust for speed.)

    Each implied center is c_i = sum_{y in Pi_i} w(y) phi(y) / sum_{y in Pi_i} w(y).
    Since ||phi(x) - c_i||^2 = ||phi(x)||^2 + ||c_i||^2 - 2 <phi(x), c_i> and
    the first term does not depend on i, the label of x is the argmin over i
    of <c_i, c_i> - 2 <phi(x), c_i>, where

        <phi(x), c_i> = sum_{a in Pi_i, a ~ x} w(a) K(a, x) / sum_{a in Pi_i} w(a)
        <c_i, c_i>    = sum_{a, b in Pi_i} w(a) w(b) K(a, b) / (sum_{a in Pi_i} w(a))^2

    Args:
        big_k: (n, n) scipy sparse kernel matrix in CSC format (column i is
            read through ``indptr``/``indices``/``data``).
        coreset: array of coreset point ids (indices into ``big_k``).
        coreset_labels: numpy array with the cluster id of each coreset point.
        coreset_weights: numpy array with the weight of each coreset point.
        k: number of clusters.

    Returns:
        Integer array of length n with the chosen cluster for every point.
    """
    n = big_k.shape[0]
    coreset = np.asarray(coreset)

    labels = np.ones(n,dtype=int) * -1

    # Empty clusters have no center: an infinite norm keeps them from ever
    # being selected (a norm of 0 used to make them the default winner).
    center_norms = np.full(k, np.inf)
    center_denoms = np.zeros(k)
    t0 = time.time()
    for i in range(k):
        mask = coreset_labels == i
        if not mask.any():
            continue

        member_weights = coreset_weights[mask]
        center_denoms[i] = member_weights.sum()

        # `mask` indexes the coreset, not the rows of big_k, so translate it
        # to point ids before slicing the kernel.
        member_ids = coreset[mask].astype(np.intp)
        big_k_submatrix = big_k[member_ids][:, member_ids]

        # w^T K_sub w is the weighted sum over all pairs of members.
        center_norm_sum = member_weights @ (big_k_submatrix @ member_weights)
        center_norms[i] = center_norm_sum / (center_denoms[i] ** 2)
    print(f"Time taken to compute norms: {time.time()-t0:.2f}")

    # Fallback for points that have no coreset neighbour at all: every inner
    # product is 0, so the center with the smallest norm wins.
    smallest_center_by_norm = center_norms.argmin()
    smallest_center_by_norm_value = center_norms[smallest_center_by_norm]

    # point id -> (label, weight) for O(1) membership tests.
    coreset_info = {
        core: (label, weight)
        for core, label, weight in zip(
            coreset.tolist(),
            np.asarray(coreset_labels).tolist(),
            np.asarray(coreset_weights).tolist(),
        )
    }

    indptr = big_k.indptr
    indices = big_k.indices
    data = big_k.data

    for i in range(n):
        x_to_c_i_inner_prod = defaultdict(float)

        # Stored rows / values of column i of the kernel.
        start, end = indptr[i], indptr[i + 1]
        neighbours = indices[start:end].tolist()
        values = data[start:end].tolist()

        for neighbour, value in zip(neighbours, values):
            info = coreset_info.get(neighbour)
            if info is None:
                continue
            label, weight = info
            x_to_c_i_inner_prod[label] += value * weight / center_denoms[label]

        # Minimise <c_i, c_i> - 2 <phi(x), c_i> over the centers.
        best_center = smallest_center_by_norm
        best_center_value = smallest_center_by_norm_value

        for center, inner_prod in x_to_c_i_inner_prod.items():
            value = center_norms[center] - 2 * inner_prod
            if value < best_center_value:
                best_center = center
                best_center_value = value

        labels[i] = best_center

    return labels

def aggregate_coreset(coreset, coreset_weights):
    """Merge repeated coreset points by summing their weights.

    Args:
        coreset: iterable of point ids, possibly with repeats.
        coreset_weights: iterable of weights parallel to ``coreset``.

    Returns:
        Tuple ``(unique_coreset, aggregated_weights)`` of numpy arrays, in
        order of first appearance of each id.
    """
    totals = {}
    for point_id, weight in zip(coreset, coreset_weights):
        totals[point_id] = totals.get(point_id, 0.0) + weight

    return np.array(list(totals)), np.array(list(totals.values()))

def labels_from_distance_to_implied_centers_lazy(coreset,X,D_inv,coreset_labels,coreset_weights,k,gamma):
    """Label every point by its closest implied coreset center, lazily.

    Only the (n, |coreset|) slice of the kernel is materialised: the RBF
    kernel between all points and the coreset points, scaled on both sides by
    ``D_inv``. Squared distances to the implied centers are
    K(x, x) - 2 <phi(x), c_i> + <c_i, c_i>, with K(x, x) = D_inv[x]^2.

    Fixes relative to the earlier version:
      * empty clusters are skipped and keep an infinite distance (they used
        to divide by zero and produce NaN columns, which ``argmin`` selects);
      * the member-by-member block is indexed through ``coreset`` so its rows
        follow coreset order, matching the order of the weights (the old
        boolean row mask yielded rows in ascending point-id order).

    Args:
        coreset: integer array of coreset point ids (rows of ``X``).
        X: (n, d) data array.
        D_inv: length-n array of scaling factors applied to both kernel sides.
        coreset_labels: numpy array with the cluster id of each coreset point.
        coreset_weights: numpy array with the weight of each coreset point.
        k: number of clusters.
        gamma: RBF coefficient forwarded to ``gpu_pairwise_kernels_batch``.

    Returns:
        Length-n integer array of cluster labels.
    """
    coreset = np.asarray(coreset)
    X_to_coreset_inner_product = gpu_pairwise_kernels_batch(X, X[coreset,:], metric='rbf', gamma=gamma)
    X_to_coreset_inner_product = D_inv[:,np.newaxis] * X_to_coreset_inner_product * D_inv[coreset][np.newaxis,:]

    X_to_implied_centers = np.full((X.shape[0],k), np.inf)
    # The RBF kernel of a point with itself is 1, so K(x, x) = D_inv[x]^2.
    phi_a_phi_a_term = D_inv**2
    for i in range(k):
        mask = coreset_labels == i
        if not mask.any():
            continue
        member_weights = coreset_weights[mask]
        denom = member_weights.sum()
        denom_squared = denom* denom
        phi_a_implied_c_term = (X_to_coreset_inner_product[:,mask]*member_weights).sum(axis=1)/denom
        # Rows selected by point id in coreset order, columns by coreset position.
        member_block = X_to_coreset_inner_product[coreset[mask]][:,mask]
        implied_c_implied_c_term = (member_block*np.outer(member_weights,member_weights)).sum()/denom_squared
        X_to_implied_centers[:,i] = phi_a_phi_a_term - 2*phi_a_implied_c_term + implied_c_implied_c_term

    labels = X_to_implied_centers.argmin(axis=1)
    return labels

def distances_to_centers(big_k, coreset, coreset_labels, coreset_weights,k):
    """Squared kernel-space distances from every point to each implied center.

    The center of cluster i is the weighted mean (in feature space) of the
    coreset points labelled i, so
    ||phi(x) - c_i||^2 = K(x, x) - 2 <phi(x), c_i> + <c_i, c_i>.

    Columns of clusters without any coreset point are filled with ``inf``.
    They used to stay at 0, which made an empty cluster the ``argmin`` for
    every point with a positive distance to the real centers.

    Args:
        big_k: (n, n) scipy sparse kernel matrix.
        coreset: integer array of coreset point ids (assumed free of
            duplicates, e.g. after ``aggregate_coreset``).
        coreset_labels: cluster id of each coreset point.
        coreset_weights: weight of each coreset point.
        k: number of clusters.

    Returns:
        (n, k) float array of squared distances.
    """
    n = big_k.shape[0]
    X_to_implied_centers = np.full((n,k), np.inf)

    # Scatter labels and weights to full length; -1 marks non-coreset points.
    coreset_mask = np.full(n, -1, dtype=np.int64)
    coreset_mask[coreset] = coreset_labels
    coreset_weights_full = np.zeros(n)
    coreset_weights_full[coreset] = coreset_weights

    phi_a_phi_a_term = big_k.diagonal()
    for i in range(k):
        mask = coreset_mask == i

        if not mask.any():
            continue
        member_weights = coreset_weights_full[mask]
        denom = member_weights.sum()
        denom_squared = denom* denom

        phi_a_implied_c_term = ((big_k[:,mask].multiply(member_weights)).sum(axis=1)/denom).A1
        implied_c_implied_c_term = (big_k[mask][:,mask].multiply(np.outer(member_weights,member_weights))).sum()/denom_squared
        X_to_implied_centers[:,i] = phi_a_phi_a_term - 2*phi_a_implied_c_term + implied_c_implied_c_term

    return X_to_implied_centers




def load_K_W_from_mat_file(path):
    """Load a compressed-sparse kernel and its weight vectors from a .mat file.

    The file is read with ``hdf5storage.loadmat`` and must provide the keys
    "data", "indices", "indptr", "W" and "W_inv".

    Returns:
        Tuple ``((data, indices, indptr, nnz_per_col), W, W_inv, n)`` where
        ``data`` is float32, the index arrays are uint64, ``nnz_per_col`` is
        the difference of consecutive ``indptr`` entries, ``W`` / ``W_inv``
        are squeezed float32 arrays and ``n`` is ``W.shape[0]``.
    """
    mat = hdf5storage.loadmat(path)

    def as_uint64(key):
        # Index arrays are widened to uint64 so very large matrices fit.
        return mat[key][:].astype(np.uint64)

    def as_float_vector(key):
        # squeeze() drops singleton dimensions of the stored vector.
        return np.asarray(mat[key]).astype(np.float32).squeeze()

    data = mat["data"][:].astype(np.float32)
    indices = as_uint64("indices")
    indptr = as_uint64("indptr")
    W = as_float_vector("W")
    W_inv = as_float_vector("W_inv")

    n = W.shape[0]
    nnz_per_col = np.diff(indptr).astype(np.uint64)
    return (data, indices, indptr, nnz_per_col), W, W_inv, n

def coreset_runtime_experiment(
        dataset_name,output_dir,num_vertices,
        K_indices,K_indptr,K_data,K_nnz_per_col,W,
        xs, 
        coreset_size, rounds):
    """Time ``fcp.old_coreset`` against ``fcp.improved_coreset``.

    For every entry of ``xs`` (passed as the first argument of both coreset
    routines) each routine is called ``rounds`` times and its wall-clock
    duration recorded. The timings and the experiment settings are written to
    ``<output_dir>/<dataset_name>.json``.

    NOTE(review): ``exp_1_cluster_start``, ``exp1_cluster_end`` and
    ``exp1_cluster_steps`` are not parameters; they are presumably
    module-level globals defined elsewhere in the file -- confirm (the
    inconsistent underscore in the first name looks suspicious).

    Args:
        dataset_name: name used for the output file and stored in the results.
        output_dir: directory that receives the JSON file.
        num_vertices: number of points, forwarded to the coreset routines.
        K_indices, K_indptr, K_data, K_nnz_per_col: compressed-sparse arrays
            of the kernel matrix.
        W: weight vector forwarded to the coreset routines.
        xs: numpy array of values for the first coreset argument.
        coreset_size: requested coreset size.
        rounds: number of repetitions per entry of ``xs``.
    """
    path = os.path.join(output_dir, f"{dataset_name}.json")
    timing_shape = (len(xs), rounds)
    times_old = np.zeros(timing_shape)
    times_improved = np.zeros(timing_shape)

    pbar = tqdm(range(xs.shape[0]),ncols=140)
    for i in pbar:
        cluster = xs[i]
        for r in range(rounds):
            started = time.time()
            _ = fcp.old_coreset(cluster, num_vertices, coreset_size, K_data, K_indices, K_indptr, K_nnz_per_col, W, True)
            times_old[i,r] = time.time() - started

            started = time.time()
            _ = fcp.improved_coreset(cluster, num_vertices, coreset_size, K_data, K_indices, K_indptr, K_nnz_per_col, W, True)
            times_improved[i,r] = time.time() - started

            pbar.set_description(f"round {r+1}/{rounds}")

    # Persist settings and timings as JSON.
    results = {
        "dataset_name": dataset_name,
        "num_vertices": num_vertices,
        "cluster_start": exp_1_cluster_start,
        "cluster_end": exp1_cluster_end,
        "cluster_steps": exp1_cluster_steps,
        "coreset_size": coreset_size,
        "rounds": rounds,
        "times_old": times_old.tolist(),
        "times_improved": times_improved.tolist()
    }

    with open(path, "w") as f:
        json.dump(results, f)


def SC(sparse_adj_csc, k):
    """Run sklearn's SpectralClustering on a precomputed affinity matrix.

    Returns the array of ``k``-cluster labels produced by ``fit_predict``.
    """
    model = SpectralClustering(n_clusters=k, affinity='precomputed')
    return model.fit_predict(sparse_adj_csc)


def CKKM(K,K_data, K_indices, K_indptr, K_nnz_per_col, W, n, k, coreset_size,pbar,k_factor=1.0, initial_shift=0.0, coreset_labels=False):
    """Coreset kernel k-means.

    A coreset is drawn with ``fcp.old_coreset``, duplicates are merged, and
    weighted kernel k-means is run on the coreset sub-kernel. Unless
    ``coreset_labels`` is set, the coreset clustering is then extended to all
    ``n`` points through the compiled labelling routine.

    If kernel k-means raises (in practice because a negative squared distance
    yields invalid sampling probabilities), a multiple of the identity is
    added to the coreset kernel and the run is repeated: a negative shift is
    halved until it is within ``tol`` of zero, then the shift is set to
    ``tol`` and doubled on every further failure.

    Args:
        K: (n, n) scipy sparse kernel matrix.
        K_data, K_indices, K_indptr, K_nnz_per_col: compressed-sparse arrays
            of ``K`` for the compiled routines.
        W: weight vector forwarded to the coreset routine.
        n: number of points.
        k: number of clusters.
        coreset_size: requested coreset size.
        pbar: progress bar; its description is updated on every retry.
        k_factor: the coreset routine is called with ``int(k * k_factor)``.
        initial_shift: initial diagonal shift.
        coreset_labels: when true, return ``(coreset labels, coreset)``
            instead of labels for all points.

    Returns:
        Labels for all ``n`` points, or ``(coreset_kkmeans_labels, coreset)``
        when ``coreset_labels`` is true.

    Raises:
        RuntimeError: if the shift overflows to a non-finite value, i.e. the
            failure is not cured by any diagonal shift. The retry loop used
            to spin forever in that situation.
    """
    # The trailing True flag is passed through unchanged; its meaning is
    # defined by the extension.
    (coreset, coreset_weights) = fcp.old_coreset(
        int(k*k_factor),
        n,
        coreset_size,
        K_data,
        K_indices,
        K_indptr,
        K_nnz_per_col,
        W,
        True
    )
    coreset, coreset_weights = aggregate_coreset(coreset, coreset_weights)

    coreset_K = K[coreset][:,coreset]
    identity = sparse.eye(coreset_K.shape[0])

    tol = 1e-12
    shift = initial_shift
    shifted_coreset_K = coreset_K + shift*identity
    while True:
        try:
            coreset_kkmeans_labels = KKmeans_as_a_function(shifted_coreset_K,coreset_weights,k, n_iters=100, tol=1e-12, verbose=False, early_stop=True)
            break
        except Exception as err:
            # Deliberately broad: any failure triggers a retry with a larger
            # diagonal shift. The finiteness check below bounds the retries.
            pbar.set_description(f"negative distance squared detected: sigma: {shift}")

            if shift> -tol and shift<=0:
                shift = tol
            elif shift<=-tol:
                shift /= 2
            else:
                shift *= 2
            if not np.isfinite(shift):
                raise RuntimeError(
                    "Kernel k-means on the coreset kept failing for every diagonal shift"
                ) from err
            shifted_coreset_K = coreset_K + shift*identity
    if coreset_labels:
        return coreset_kkmeans_labels, coreset

    labels = fast_labels_from_distances_to_implied_coreset_centers(
        K_data, K_indices, K_indptr, K_nnz_per_col,
        coreset, coreset_kkmeans_labels.astype(np.uint64), coreset_weights, k, n
    )

    return labels

def CSC(K,K_data, K_indices, K_indptr, K_nnz_per_col, W, n, k, coreset_size,k_factor=1.0, coreset_labels=False):
    """Coreset spectral clustering using sklearn's SpectralClustering.

    A coreset is drawn with ``fcp.improved_coreset`` (up to three attempts),
    duplicates are merged, and spectral clustering is run on the coreset
    affinity ``diag(w) K[coreset, coreset] diag(w)``. Unless
    ``coreset_labels`` is set, the labels are extended to all ``n`` points.

    Args:
        K: (n, n) scipy sparse kernel matrix.
        K_data, K_indices, K_indptr, K_nnz_per_col: compressed-sparse arrays
            of ``K`` for the compiled routines.
        W: weight vector forwarded to the coreset routine.
        n: number of points.
        k: number of clusters.
        coreset_size: requested coreset size.
        k_factor: the coreset routine is called with ``int(k_factor * k)``.
        coreset_labels: when true, return ``(coreset labels, coreset)``.

    Returns:
        Labels for all ``n`` points, or ``(labels, coreset)`` when
        ``coreset_labels`` is true.

    Raises:
        Exception: if all three coreset construction attempts fail.
    """
    (coreset, coreset_weights) = (None, None)
    for attempt in range(1, 4):
        try:
            candidate, candidate_weights = fcp.improved_coreset(
                int(k_factor*k),
                n,
                coreset_size,
                K_data,
                K_indices,
                K_indptr,
                K_nnz_per_col,
                W,
                True
            )
            # The extension reports some failures by returning a string.
            if isinstance(candidate,str):
                raise Exception(f"{candidate}: {candidate_weights}")
        except Exception as e:
            print(f"Error: {e}. Retrying {attempt}")
        else:
            # Commit only a successful result. The failed string result used
            # to be kept, which bypassed the `is None` guard below.
            coreset, coreset_weights = candidate, candidate_weights
            break
    if coreset is None:
        raise Exception("Failed to construct coreset due to numerical instability")
    coreset, coreset_weights = aggregate_coreset(coreset, coreset_weights)
    coreset_D = coreset_weights
    coreset_A = K[coreset][:,coreset].multiply(coreset_D.reshape(-1,1)).multiply(coreset_D.reshape(1,-1))
    labels = SpectralClustering(n_clusters=k, affinity='precomputed').fit_predict(coreset_A).astype(np.uint64)

    if coreset_labels:
        return labels, coreset

    full_labels = fast_labels_from_distances_to_implied_coreset_centers(
        K_data, K_indices, K_indptr, K_nnz_per_col, 
        coreset, labels, coreset_weights, k,n
        )

    return full_labels


def CSC_fast(K,K_data, K_indices, K_indptr, K_nnz_per_col, W, n, k, coreset_size,k_factor=1.0,coreset_labels=False):
    """Coreset spectral clustering using the power-method clusterer.

    A coreset is drawn with ``fcp.improved_coreset`` (up to three attempts),
    duplicates are merged, and the coreset affinity
    ``diag(w) K[coreset, coreset] diag(w)`` is clustered with
    ``fast_spectral_cluster``. Unless ``coreset_labels`` is set, the labels
    are extended to all ``n`` points.

    Args:
        K: (n, n) scipy sparse kernel matrix.
        K_data, K_indices, K_indptr, K_nnz_per_col: compressed-sparse arrays
            of ``K`` for the compiled routines.
        W: weight vector forwarded to the coreset routine.
        n: number of points.
        k: number of clusters.
        coreset_size: requested coreset size.
        k_factor: the coreset routine is called with ``int(k_factor * k)``.
        coreset_labels: when true, return ``(coreset labels, coreset)``.

    Returns:
        Labels for all ``n`` points, or ``(labels, coreset)`` when
        ``coreset_labels`` is true.

    Raises:
        Exception: if all three coreset construction attempts fail.
    """
    (coreset, coreset_weights) = (None, None)

    for attempt in range(1, 4):
        try:
            candidate, candidate_weights = fcp.improved_coreset(
                int(k_factor*k),
                n,
                coreset_size,
                K_data,
                K_indices,
                K_indptr,
                K_nnz_per_col,
                W,
                True
            )
            # The extension reports some failures by returning a string.
            if isinstance(candidate,str):
                raise Exception(f"{candidate}: {candidate_weights}")
        except Exception as e:
            print(f"Error: {e}. Retrying {attempt}")
        else:
            # Commit only a successful result. The failed string result used
            # to be kept, which bypassed the `is None` guard below.
            coreset, coreset_weights = candidate, candidate_weights
            break
    if coreset is None:
        raise Exception("Failed to construct coreset due to numerical instability")

    coreset, coreset_weights = aggregate_coreset(coreset, coreset_weights)
    coreset_D = coreset_weights
    coreset_A = K[coreset][:,coreset].multiply(coreset_D.reshape(-1,1)).multiply(coreset_D.reshape(1,-1))
    D = coreset_A.sum(axis=1).A1
    M_coreset, D_coreset =signless_laplacian_and_D_sparse(coreset_A,sparse.diags(D))
    labels = fast_spectral_cluster(M_coreset, D_coreset, k).astype(np.uint64)

    if coreset_labels:
        return labels,coreset

    full_labels = fast_labels_from_distances_to_implied_coreset_centers(
        K_data, K_indices, K_indptr, K_nnz_per_col, 
        coreset, labels, coreset_weights, k,n
        )

    return full_labels


def exp2_evaluate(
        dataset_name,alg_names,output_dir, knn, coreset_sizes, X=None, y=None, K=None, W=None, rounds= 10):
    """Benchmark the clustering algorithms on a kNN graph built from ``X``.

    For every algorithm in ``alg_names`` ("coreset KKmeans", "SC", "CSC",
    "fast CSC"), every coreset size and every round, the running time, ARI,
    NMI and normalised cut are recorded and finally written to
    ``<output_dir>/<dataset_name>.json``.

    Special cases:
      * "SC" does not depend on the coreset size: it is only run for the
        first size and that row is copied to all others.
      * "SC" is skipped entirely for the dataset "mnist_784" and its entries
        are removed from the saved results.

    The ``K`` and ``W`` arguments are overwritten by values derived from the
    kNN graph before they are used.

    Args:
        dataset_name: name used for the output file and log messages.
        alg_names: iterable of algorithm names to run.
        output_dir: directory that receives the JSON file.
        knn: number of neighbours for the kNN graph.
        coreset_sizes: numpy array of coreset sizes.
        X: (n, d) data array.
        y: ground-truth labels; the number of clusters is the number of
            distinct values.
        K, W: ignored (recomputed internally).
        rounds: repetitions per algorithm and coreset size.
    """
    path = os.path.join(output_dir, f"{dataset_name}.json")

    k = len(np.unique(y))
    n = X.shape[0]
    print()
    print(f"Running experiment for {dataset_name}")

    grid = (len(coreset_sizes), rounds)

    def blank_metrics():
        # times, ARIs, NMIs, normalised cuts
        return np.zeros(grid), np.zeros(grid), np.zeros(grid), np.zeros(grid)

    ckkmeans_times, ckkmeans_aris, ckkmeans_nmis, ckkmeans_ncs = blank_metrics()
    csc_times, csc_aris, csc_nmis, csc_ncs = blank_metrics()
    fast_csc_times, fast_csc_aris, fast_csc_nmis, fast_csc_ncs = blank_metrics()
    sc_times, sc_aris, sc_nmis, sc_ncs = blank_metrics()

    # kNN graph, degree weights and the degree-normalised kernel.
    adj_csc = compute_knn_graph_torch_cluster(X, knn, device='cpu')
    W = adj_csc.sum(axis=1).A1.astype(np.float64)
    K = adj_csc.multiply((1/W).reshape(-1,1)).multiply((1/W).reshape(1,-1)).tocsc()

    # Raw CSC arrays for the compiled routines.
    K_data = K.data
    K_indices = K.indices.astype(np.uint64)
    K_indptr = K.indptr.astype(np.uint64)
    K_nnz_per_col = np.diff(K_indptr)

    # Algorithm name -> (times, aris, nmis, ncs) arrays to fill.
    recorders = {
        "coreset KKmeans": (ckkmeans_times, ckkmeans_aris, ckkmeans_nmis, ckkmeans_ncs),
        "SC": (sc_times, sc_aris, sc_nmis, sc_ncs),
        "CSC": (csc_times, csc_aris, csc_nmis, csc_ncs),
        "fast CSC": (fast_csc_times, fast_csc_aris, fast_csc_nmis, fast_csc_ncs),
    }
    # Algorithm name -> zero-argument runner. `coreset_size` and `pbar` are
    # looked up when the runner is called, i.e. they are the current loop values.
    runners = {
        "coreset KKmeans": lambda: CKKM(K,K_data, K_indices, K_indptr, K_nnz_per_col, W, n, k, coreset_size,pbar),
        "SC": lambda: SC(adj_csc, k),
        "CSC": lambda: CSC(K,K_data, K_indices, K_indptr, K_nnz_per_col, W, n, k, coreset_size, k_factor=2.0),
        "fast CSC": lambda: CSC_fast(K,K_data, K_indices, K_indptr, K_nnz_per_col, W, n, k, coreset_size, k_factor=2.0),
    }

    for alg_name in alg_names:
        print(f"Running {alg_name}")
        is_sc = alg_name == "SC"
        run = runners.get(alg_name)

        for j,coreset_size in enumerate(coreset_sizes):
            print(f"Running for coreset size {coreset_size}")
            pbar = tqdm(range(rounds),ncols=80)
            if is_sc and j > 0:
                # SC ignores the coreset size: reuse the first row everywhere.
                sc_times[:] = sc_times[0]
                sc_aris[:] = sc_aris[0]
                sc_nmis[:] = sc_nmis[0]
                sc_ncs[:] = sc_ncs[0]
                break
            for i in pbar:
                if run is not None:
                    if is_sc and dataset_name == "mnist_784":
                        continue
                    times, aris, nmis, ncs = recorders[alg_name]
                    t0 = time.time()
                    labels = run()
                    elapsed = time.time() - t0
                    ari = adjusted_rand_score(y, labels)
                    nmi = normalized_mutual_info_score(y, labels)
                    times[j,i] = elapsed
                    aris[j,i] = ari
                    nmis[j,i] = nmi
                    ncs[j,i] = compute_normalized_cut(adj_csc, labels, k)

                pbar.set_description(f"round {i+1}/{rounds}")

    # Collect everything for the JSON dump.
    results = {
        "dataset_name": dataset_name,
        "coreset_sizes": coreset_sizes.tolist(),
        "knn": knn,
        "rounds": rounds,
        "coreset KKmeans_times": ckkmeans_times.tolist(),
        "coreset KKmeans_aris": ckkmeans_aris.tolist(),
        "coreset KKmeans_nmis": ckkmeans_nmis.tolist(),
        "coreset KKmeans_ncs": ckkmeans_ncs.tolist(),
        "CSC_times": csc_times.tolist(),
        "CSC_aris": csc_aris.tolist(),
        "CSC_nmis": csc_nmis.tolist(),
        "CSC_ncs": csc_ncs.tolist(),
        "fast CSC_times": fast_csc_times.tolist(),
        "fast CSC_aris": fast_csc_aris.tolist(),
        "fast CSC_nmis": fast_csc_nmis.tolist(),
        "fast CSC_ncs": fast_csc_ncs.tolist(),
        "SC_times": sc_times.tolist(),
        "SC_aris": sc_aris.tolist(),
        "SC_nmis": sc_nmis.tolist(),
        "SC_ncs": sc_ncs.tolist(),
    }

    if dataset_name == "mnist_784":
        # SC was never run on this dataset, so drop its all-zero entries.
        for sc_key in ("SC_times", "SC_aris", "SC_nmis", "SC_ncs"):
            del results[sc_key]

    with open(path, "w") as f:
        json.dump(results, f)
            


def exp3_evaluate(alg_names,output_dir, ks, coreset_fraction, n_per_cluster,p,q, runs= 10):
    """Run the synthetic block-model experiment and dump the measurements to JSON.

    For each ``k`` in ``ks`` one graph is drawn with ``stochastic_block_model``
    and the kernel ``K[i, j] = A[i, j] / (W[i] * W[j])`` is built from its
    adjacency matrix ``A`` and row sums ``W``. Every algorithm named in
    ``alg_names`` is then run ``runs`` times on that graph with a coreset of
    ``int(coreset_fraction * n)`` points. Wall-clock time, ARI and NMI are
    recorded; ARI/NMI are computed against ``y[coreset]``, i.e. only on the
    points returned by the algorithm.

    Args:
        alg_names: Iterable of algorithm names. Recognised names are
            "coreset KKmeans", "CSC" and "fast CSC"; other names are ignored.
        output_dir: Directory receiving ``synthetic.json``; created if missing.
        ks: Sequence (list or NumPy array) of cluster counts to evaluate.
        coreset_fraction: Fraction of the ``n`` vertices used as coreset size.
        n_per_cluster: Forwarded to ``stochastic_block_model``.
        p: Forwarded to ``stochastic_block_model``.
        q: Forwarded to ``stochastic_block_model``.
        runs: Number of repetitions per (k, algorithm) pair.

    Returns:
        None. The results are written to ``<output_dir>/synthetic.json``; each
        metric is stored as a ``len(ks) x runs`` nested list.
    """
    dataset_name = "synthetic"
    tracked_algs = ("coreset KKmeans", "CSC", "fast CSC")
    metric_names = ("times", "aris", "nmis")

    # Create the output directory up front so a long run cannot lose its
    # results at the final write because the directory does not exist.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, f"{dataset_name}.json")

    print("Running synthetic experiment")
    # One (len(ks), runs) table per algorithm and metric; entries of
    # algorithms that are not run stay at zero.
    scores = {
        name: {metric: np.zeros((len(ks), runs)) for metric in metric_names}
        for name in tracked_algs
    }

    for j,k in enumerate(ks):
        q_k = q/k
        # NOTE(review): the log line reports q/k, but the graph below is
        # generated with the unscaled q -- confirm which one is intended.
        print(f"Running for k={k} and q={q_k:.3e}")
        # construct the stochastic block model
        A, y = stochastic_block_model(n_per_cluster, k, p, q)
        n = A.shape[0]

        adj_csc = A
        D = adj_csc.sum(axis=1).A1
        W = D.astype(np.float64)
        # K = W^-1 A W^-1, stored column-compressed
        K = adj_csc.multiply((1/W).reshape(-1,1)).multiply((1/W).reshape(1,-1)).tocsc()

        K_data = K.data
        K_indices = K.indices.astype(np.uint64)
        K_indptr = K.indptr.astype(np.uint64)
        K_nnz_per_col = np.diff(K_indptr)
        coreset_size = int(coreset_fraction*n)

        for alg_name in alg_names:
            print(f"Running {alg_name}")
            pbar = tqdm(range(runs),ncols=120)
            for i in pbar:
                t0 = time.time()
                if alg_name == "coreset KKmeans":
                    labels,coreset = CKKM(K,K_data, K_indices, K_indptr, K_nnz_per_col, W, n, k, coreset_size,pbar, initial_shift = 1/(n),coreset_labels=True)
                elif alg_name == "CSC":
                    labels,coreset = CSC(K,K_data, K_indices, K_indptr, K_nnz_per_col, W, n, k, coreset_size, k_factor=2.0, coreset_labels=True)
                elif alg_name == "fast CSC":
                    labels,coreset = CSC_fast(K,K_data, K_indices, K_indptr, K_nnz_per_col, W, n, k, coreset_size, k_factor=5.0, coreset_labels=True)
                else:
                    # unrecognised algorithm names are ignored
                    continue
                elapsed = time.time() - t0

                alg_scores = scores[alg_name]
                alg_scores["times"][j,i] = elapsed
                alg_scores["aris"][j,i] = adjusted_rand_score(y[coreset], labels)
                alg_scores["nmis"][j,i] = normalized_mutual_info_score(y[coreset], labels)

    # save the results to a json file
    results = {
        "dataset_name": dataset_name,
        "coreset_fraction": coreset_fraction,
        "n_per_cluster": n_per_cluster,
        "p": p,
        "q": q,
        # np.asarray lets callers pass either a list or an array of ks
        "ks": np.asarray(ks).tolist(),
        "rounds": runs,
    }
    for name in tracked_algs:
        for metric in metric_names:
            results[f"{name}_{metric}"] = scores[name][metric].tolist()

    with open(path, "w") as f:
        json.dump(results, f)





# MARK:- Block model

def stochastic_block_model(n_per_cluster, k, p, q):
    """Draw a graph with ``stag.random.sbm`` and return it with its labels.

    The graph has ``int(n_per_cluster * k)`` vertices and ``int(k)`` clusters.
    ``p`` and ``q`` are converted to ``np.float64`` and forwarded to STAG
    (presumably the intra-/inter-cluster edge probabilities -- confirm against
    the STAG documentation).

    Returns:
        A pair ``(adjacency, labels)``: the SciPy sparse adjacency matrix with
        the identity added (one self loop per vertex), and the result of
        ``stag.random.sbm_gt_labels`` for the same vertex and cluster counts.
    """
    num_vertices = int(n_per_cluster * k)
    num_clusters = int(k)

    graph = stag.random.sbm(
        num_vertices, num_clusters, np.float64(p), np.float64(q), False
    )

    # add self loops:
    adjacency = graph.adjacency().to_scipy() + sparse.eye(num_vertices)

    return adjacency, stag.random.sbm_gt_labels(num_vertices, num_clusters)


# MARK:- Dataloader
def real_dataset_loader(dataset_name):
    """Fetch version 1 of an OpenML dataset together with its preset ``knn``.

    The download is cached under ``data/``.

    Returns:
        ``(X, y, knn)`` on success, where ``knn`` is the preset for
        "pendigits", "mnist_784", "har" or "letter" and ``None`` for any other
        name. If fetching raises, the exception and a failure message are
        printed and a bare ``None`` (not a tuple) is returned, so callers must
        check the result before unpacking it.
    """
    knn_presets = {
        "pendigits": 250,
        "mnist_784": 500,
        "har": 200,
        "letter": 300,
    }

    try:
        features, targets = fetch_openml(
            name=dataset_name,
            version=1,
            as_frame=False,
            return_X_y=True,
            data_home="data/",
            cache=True,
            parser="auto",
        )
    except Exception as err:
        print(err)
        print(f"Failed to load dataset {dataset_name}")
        return None

    return features, targets, knn_presets.get(dataset_name)


        
# Maps each exp1 graph dataset name to the file name of its preprocessed .mat
# file; the entry point joins it onto the "graph_datasets" directory.
dataset_file_map = {
    "friendster": "com-Friendster_preprocessed.mat",
    "wikipedia": "wiki-topcats_preprocessed.mat",
    "livejournal": "com-LiveJournal_preprocessed.mat",
}




# Legend labels used by the plotting modes, keyed by the algorithm names that
# also prefix the result keys in the JSON files (e.g. "CSC_times").
pretty_names = {
    "CSC": "Coreset Spectral Clustering",
    "fast CSC": "Fast Coreset Spectral Clustering",
    "coreset KKmeans": "Coreset Kernel k-means",
    "SC": "Spectral Clustering",
}

# Matplotlib colour used for each algorithm's line and shaded band in the
# plotting modes; keys match those of pretty_names.
colours = {
    "CSC": "blue",
    "fast CSC": "orange",
    "coreset KKmeans": "green",
    "SC": "black",

}

# MARK:- Main
if __name__ == "__main__":
    # Script entry point. The positional `mode` argument selects one action:
    # exp1/exp2/exp3 run an experiment with "results" as the output directory,
    # plotexp1/plotexp2/plotexp3 read the JSON files in results/ and save a
    # PNG figure next to them.

    #MARK:- exp1 config:

    exp1_datasets= [
        "wikipedia",
        "livejournal",
        "friendster",
        ]
    exp1_n_runs = 5
    exp_1_cluster_start = 2
    exp1_cluster_end = 5000
    exp1_cluster_steps = 50
    exp1_coreset_size = 100_000
    # Integer truncation of the geometric grid repeats small values;
    # np.unique drops the repeats (and returns the grid sorted).
    exp1_xs = np.unique(np.geomspace(exp_1_cluster_start, exp1_cluster_end, exp1_cluster_steps, dtype=int))


    #MARK:- exp2 config:


    exp2_dataset_names = [
        "pendigits",
        "mnist_784",
        "har",
        "letter",
    ]
    exp2_alg_names = [
        "CSC", 
        "fast CSC",
        "coreset KKmeans",
        "SC",
        ]
    
    exp2_n_runs = 20
    exp2_coreset_size_start = 100
    exp2_coreset_size_end = 3000
    exp2_coreset_size_steps = 50
    exp2_coreset_sizes = np.unique(np.linspace(exp2_coreset_size_start, exp2_coreset_size_end, exp2_coreset_size_steps, dtype=int))


    #MARK:- exp3 config:

    exp3_alg_names = [
        # "CSC", 
        "fast CSC",
        "coreset KKmeans",
        # "SC",
        ]
    
    exp3_p = 0.5
    exp3_q = 0.001
    exp3_n_runs = 20
    exp3_coreset_fraction = 0.01
    n_per_cluster = 1000
    exp3_k_start = 10
    exp3_k_end = 250
    exp3_k_steps = 50
    exp3_ks = np.unique(np.linspace(exp3_k_start, exp3_k_end, exp3_k_steps, dtype=int))



    parser = argparse.ArgumentParser()
    parser.add_argument("mode", type=str, help="Action to perform: run an experiment (exp1, exp2, exp3) or plot its saved results (plotexp1, plotexp2, plotexp3)", choices=["exp1","exp2","exp3","plotexp1","plotexp2","plotexp3"])
    # optional dataset argument: restricts exp2 and plotexp2 to a single dataset
    parser.add_argument("--dataset", type=str, help="Name of the dataset to run the experiment on", choices=exp2_dataset_names)
    args = parser.parse_args()

    mode = args.mode

    # Experiments write into results/; make sure it exists before the
    # (potentially very long) runs start.
    if mode in ("exp1", "exp2", "exp3"):
        os.makedirs("results", exist_ok=True)

    # A single if/elif chain: previously exp1 and exp2 were separate `if`
    # statements, so after finishing they fell into the final `else` and the
    # script reported "Invalid mode" and exited with status 1.
    if mode == "exp1":
        # experiment 1: compare coreset algorithm running times on three large graphs
        for dataset in exp1_datasets:
            print(f"Running experiment for {dataset}")
            
            path = os.path.join("graph_datasets", dataset_file_map[dataset])
            (K_data, K_indices, K_indptr, K_nnz_per_col), W, W_inv, n = load_K_W_from_mat_file(path)
            coreset_runtime_experiment(
                dataset, "results", n, K_indices, K_indptr, K_data, K_nnz_per_col, W,
                exp1_xs, exp1_coreset_size, exp1_n_runs
            )

    elif mode == "exp2":
        # experiment 2: evaluate the algorithms on the OpenML datasets
        # (all of them, or only the one selected with --dataset)
        dataset_names = exp2_dataset_names if args.dataset is None else [args.dataset]

        for dataset_name in dataset_names:
            loaded = real_dataset_loader(dataset_name)
            if loaded is None:
                # The loader prints the failure and returns None; skip the
                # dataset instead of crashing on the tuple unpack.
                print(f"Skipping {dataset_name}")
                continue
            X, y, knn = loaded
            exp2_evaluate(
                dataset_name,exp2_alg_names,"results", knn, exp2_coreset_sizes, X=X, y=y, rounds= exp2_n_runs
            )

    elif mode == "exp3":
        # experiment 3: synthetic stochastic block model graphs
        exp3_evaluate(
            exp3_alg_names,"results", exp3_ks, exp3_coreset_fraction, n_per_cluster, exp3_p, exp3_q, runs= exp3_n_runs
            )


    elif mode == "plotexp1":

        fontsize = 16
        fig, ax = plt.subplots(1, len(exp1_datasets), figsize=(len(exp1_datasets)*6, 6), sharey=True)

        for i, dataset in enumerate(exp1_datasets):
            # Load the json results
            path = os.path.join("results", f"{dataset}.json")
            with open(path, "r") as f:
                results = json.load(f)
            
            # Prepare the x values and times data
            times_old = np.array(results["times_old"])
            times_improved = np.array(results["times_improved"])
            
            # Calculate means and stds
            times_old_means = times_old.mean(axis=1)
            times_old_stds = times_old.std(axis=1)
            times_improved_means = times_improved.mean(axis=1)
            times_improved_stds = times_improved.std(axis=1)
            
            # Plot mean lines for both methods
            ax[i].plot(exp1_xs, times_old_means, label="Jiang et al. coreset construction", color="blue")
            ax[i].plot(exp1_xs, times_improved_means, label="Our coreset construction", color="green")
            
            # Fill between mean ± std for shaded areas
            ax[i].fill_between(exp1_xs, times_old_means - times_old_stds, times_old_means + times_old_stds, color="blue", alpha=0.3)
            ax[i].fill_between(exp1_xs, times_improved_means - times_improved_stds, times_improved_means + times_improved_stds, color="green", alpha=0.3)
            
            # Customize plot for each dataset
            ax[i].set_title(f"{dataset}", fontsize=fontsize)
            ax[i].set_xlabel("Clusters", fontsize=fontsize)
            ax[i].set_ylabel("seconds (log scale)", fontsize=fontsize)
            ax[i].tick_params(axis='y', which='both', labelleft=True)
            
            # Use log scale for y axis
            ax[i].set_yscale("log")

        # Create a shared legend below the plots
        handles, labels = ax[0].get_legend_handles_labels()  # Get the legend info from the first subplot
        fig.legend(handles, labels, loc="lower center", ncol=2, fontsize=fontsize)

        # Adjust layout to leave space for the shared legend
        plt.tight_layout(rect=[0, 0.1, 1, 1])

        # Save the plot
        plt.savefig("results/coreset_runtime_comparison.png")
        plt.close(fig)



    elif mode == "plotexp2":
        fontsize = 16
        dataset_names = exp2_dataset_names if args.dataset is None else [args.dataset]

        for dataset_name in dataset_names:
            # Load the json results
            path = os.path.join("results", f"{dataset_name}.json")
            with open(path, "r") as f:
                results = json.load(f)
            coreset_sizes = np.array(results["coreset_sizes"])
            # three panels: running time, ARI, normalised cut
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))

            for alg_name in exp2_alg_names:
                # the SC entries are removed from the mnist_784 results file
                if dataset_name == "mnist_784" and alg_name == "SC":
                    continue
                times = np.array(results[f"{alg_name}_times"])
                aris = np.array(results[f"{alg_name}_aris"])
                ncs = np.array(results[f"{alg_name}_ncs"])
                times_means = times.mean(axis=1)
                aris_means = aris.mean(axis=1)
                times_stds = times.std(axis=1)
                aris_stds = aris.std(axis=1)
                ncs_means = ncs.mean(axis=1)
                ncs_stds = ncs.std(axis=1)

                axes[0].plot(coreset_sizes, times_means, label=pretty_names[alg_name], color=colours[alg_name])
                axes[0].fill_between(coreset_sizes, times_means - times_stds, times_means + times_stds, alpha=0.3, color=colours[alg_name])
                axes[1].plot(coreset_sizes, aris_means, label=pretty_names[alg_name], color=colours[alg_name])
                axes[1].fill_between(coreset_sizes, aris_means - aris_stds, aris_means + aris_stds, alpha=0.3, color=colours[alg_name])
                axes[2].plot(coreset_sizes, ncs_means, label=pretty_names[alg_name], color=colours[alg_name])
                axes[2].fill_between(coreset_sizes, ncs_means - ncs_stds, ncs_means + ncs_stds, alpha=0.3, color=colours[alg_name])

            # Titles and axis labels do not depend on the algorithm, so they
            # are set once per figure.
            axes[0].set_title("Running time",fontsize=fontsize)
            axes[1].set_title("ARI",fontsize=fontsize)
            axes[2].set_title("Normalised cut", fontsize=fontsize)
            axes[0].set_xlabel("Coreset size",fontsize=fontsize)
            axes[1].set_xlabel("Coreset size", fontsize=fontsize)
            axes[2].set_xlabel("Coreset size", fontsize=fontsize)
            axes[0].set_ylabel("seconds", fontsize=fontsize)
            axes[1].set_ylabel("ARI", fontsize=fontsize)
            axes[2].set_ylabel("Normalised cut", fontsize=fontsize)

            # time log scale if not mnist:
            if dataset_name != "mnist_784":
                axes[0].set_yscale("log")

            # Add a shared legend below the plots
            handles, labels = axes[0].get_legend_handles_labels()  # Get legend handles and labels from one axis
            fig.legend(handles, labels, loc="lower center", ncol=len(exp2_alg_names),fontsize=fontsize-1)  # Create a shared legend


            plt.tight_layout(rect=[0, 0.1, 1, 1])  # Adjust layout to leave space for legend

            # Save the plot, then release the figure so figures do not pile up
            # across datasets
            plt.savefig(f"results/exp2_{dataset_name}_experiment_results.png")
            plt.close(fig)

    elif mode == "plotexp3":
        fontsize = 16
        path = os.path.join("results", "synthetic.json")
        with open(path, "r") as f:
            results = json.load(f)

        ks = np.array(results["ks"])
        # two panels: running time and ARI
        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        for alg_name in exp3_alg_names:
            times = np.array(results[f"{alg_name}_times"])
            aris = np.array(results[f"{alg_name}_aris"])
            times_means = times.mean(axis=1)
            aris_means = aris.mean(axis=1)
            times_stds = times.std(axis=1)
            aris_stds = aris.std(axis=1)
            axes[0].plot(ks, times_means, label=pretty_names[alg_name], color=colours[alg_name])
            axes[0].fill_between(ks, times_means - times_stds, times_means + times_stds, alpha=0.3, color=colours[alg_name])
            axes[1].plot(ks, aris_means, label=pretty_names[alg_name], color=colours[alg_name])
            axes[1].fill_between(ks, aris_means - aris_stds, aris_means + aris_stds, alpha=0.3, color=colours[alg_name])

        axes[0].set_title("Running time", fontsize=fontsize)
        axes[1].set_title("ARI", fontsize=fontsize)
        axes[0].set_xlabel("Number of clusters", fontsize=fontsize)
        axes[1].set_xlabel("Number of clusters", fontsize=fontsize)
        axes[0].set_ylabel("seconds", fontsize=fontsize)
        axes[1].set_ylabel("ARI", fontsize=fontsize)

        # Create a shared legend below both plots
        handles, labels = axes[0].get_legend_handles_labels()  # Get legend info from one axis
        fig.legend(handles, labels, loc="lower center", ncol=len(exp3_alg_names), fontsize=fontsize)

        # Adjust layout to leave space for the shared legend
        plt.tight_layout(rect=[0, 0.1, 1, 1])

        # Save the plot
        plt.savefig("results/exp3_synthetic_experiment_results.png")
        plt.close(fig)

    else:
        # Not reachable through argparse (`choices` rejects other values);
        # kept as a guard in case the list of modes and this chain diverge.
        print("Invalid mode")
        sys.exit(1)

