#%%
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.datasets import make_blobs, make_moons
import matplotlib.pyplot as plt
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import fetch_openml
from collections import defaultdict
from time import time
from tqdm import tqdm
from sklearn import metrics
from datetime import datetime
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.preprocessing import StandardScaler
import pandas as pd
import time
import os
from copy import deepcopy
import argparse


# MARK: -lazy Kernel init
def initialize_kernel_kmeans_plusplus_lazy(X_data, k, kernel_function, random_gen=None):
    """Kernel k-means++ seeding that evaluates only the kernel entries it needs.

    Only the diagonal of the kernel matrix and the rows/columns of the
    selected centers are computed; every other entry of the returned matrix
    is left as NaN.

    Args:
        X_data: Data matrix; rows are the points (indexed as ``X_data[i, :]``).
        k: Number of centers to select.
        kernel_function: Callable ``f(A, B)`` returning the kernel matrix
            between the rows of ``A`` and the rows of ``B``.
        random_gen: ``numpy.random.Generator``; a fresh default generator is
            created when None.

    Returns:
        label_vector: For every point, the position (0..k-1) in ``C`` of its
            nearest center.
        X_to_Cs_distance_squared: Squared feature-space distance of every
            point to its nearest center.
        C: Data indices of the selected centers, in selection order.
        K: ``n x n`` kernel matrix filled on the diagonal and on the
            rows/columns listed in ``C`` (NaN elsewhere).

    Raises:
        ValueError: If the sampling distribution degenerates (non-positive
            total cost, or negative / non-finite probabilities).
    """
    if random_gen is None:
        random_gen = np.random.default_rng()

    # The number of points must come from X_data (this used to read a global `X`).
    n = X_data.shape[0]
    # Entries that are never evaluated stay NaN, so accidental use is visible.
    K = np.full((n, n), np.nan)
    C = []  # data indices of the selected centers, in selection order

    # First center: uniform over all points.
    rand_idx = random_gen.choice(n, 1, replace=False)
    K[:, rand_idx] = kernel_function(X_data, X_data[rand_idx])
    K[rand_idx, :] = K[:, rand_idx].T
    C.append(rand_idx.item())

    # Diagonal K(x, x), evaluated point by point so no other entries are computed.
    np.fill_diagonal(K, np.array([kernel_function(X_data[i, :][np.newaxis, :], X_data[i, :][np.newaxis, :]).item() for i in range(n)]))
    diag = K.diagonal()[:, np.newaxis]

    # ||phi(x) - phi(c)||^2 = K(x, x) + K(c, c) - 2 K(x, c)
    X_to_Cs_distance_squared = (diag + (K[rand_idx, rand_idx] - 2 * K[:, rand_idx])).flatten()
    # Data index of each point's nearest center selected so far.
    label_vector = np.full(n, rand_idx.item(), dtype=int)

    for _ in range(1, k):
        cost = X_to_Cs_distance_squared.sum()
        if cost <= 0:
            raise ValueError("Sum of squared distances is non-positive, check kernel calculations.")

        probs = X_to_Cs_distance_squared / cost
        if np.any(probs < 0) or not np.isfinite(probs).all():
            raise ValueError("Computed probabilities contain non-positive or non-finite values.")
        # D^2 sampling: probability proportional to squared distance to the nearest center.
        rand_idx = random_gen.choice(n, 1, p=probs)
        C.append(rand_idx.item())

        # Kernel entries for the new center only.
        K[:, rand_idx] = kernel_function(X_data, X_data[rand_idx])
        K[rand_idx, :] = K[:, rand_idx].T

        # Keep, per point, the smaller of the old nearest distance and the new one.
        new_center_distance = (diag + K[rand_idx, rand_idx] - 2 * K[:, rand_idx]).flatten()
        closer = new_center_distance < X_to_Cs_distance_squared
        label_vector[closer] = rand_idx.item()
        X_to_Cs_distance_squared[closer] = new_center_distance[closer]

    # Map center data indices to their positions 0..k-1 in C.
    label_vector = np.array([C.index(l) for l in label_vector])

    return label_vector, X_to_Cs_distance_squared, C, K


def initialize_kernel_kmeans_plusplus(K, k, random_gen=None):
    """Kernel k-means++ seeding on a precomputed kernel (Gram) matrix.

    Args:
        K: Square kernel matrix of the data.
        k: Number of centers to draw.
        random_gen: ``numpy.random.Generator``; a new default generator is
            created when None.

    Returns:
        labels: For every point, the position (0..k-1) of its closest center.
        distances: Squared feature-space distance of every point to that
            closest center.
        C: Indices of the drawn centers, in the order they were drawn.
    """
    rng = np.random.default_rng() if random_gen is None else random_gen
    n_points = K.shape[0]
    self_sim = K.diagonal()

    # Column j will hold ||phi(x) - phi(c_j)||^2 = K(x,x) + (K(c,c) - 2 K(x,c)).
    per_center = np.repeat(self_sim[:, np.newaxis], k, axis=1)

    first = rng.choice(n_points, 1, replace=False).item()
    centers = [first]
    per_center[:, 0] += K[first, first] - 2 * K[:, first]
    closest = per_center[:, 0]

    for col in range(1, k):
        # D^2 sampling: draw proportionally to the current closest distance.
        probs = closest / closest.sum()
        nxt = rng.choice(n_points, 1, p=probs).item()
        centers.append(nxt)
        per_center[:, col] += K[nxt, nxt] - 2 * K[:, nxt]
        closest = np.minimum(closest, per_center[:, col])

    return per_center.argmin(axis=1), closest, centers

# MARK: -KMBKM class
class KernelMiniBatchKMeans:
    """Mini-batch kernel k-means that never materialises the cluster centers.

    Each center ``c_j`` lives implicitly in feature space. Three caches are
    enough to compute distances from kernel evaluations alone:

    * ``X_to_C_inner_products[:, j]`` -- ``<phi(x), c_j>`` for every point ``x``;
    * ``C_to_C_inner_products[j]``    -- ``<c_j, c_j>``;
    * ``X_to_implied_centers[:, j]``  -- ``||phi(x) - c_j||^2``.

    Args:
        n_clusters: Number of clusters ``k``.
        batch_size: Points sampled without replacement per iteration.
        n_iterations: Number of mini-batch updates.
        new_lr: If True the learning rate of cluster ``j`` is ``sqrt(b / B)``
            (``b`` = batch points labelled ``j``, ``B`` = batch size);
            otherwise ``b`` divided by the running count of points assigned
            to ``j``.
        full_batch: Use all points as the batch, with learning rate 1.
        use_coreset: Accepted for backward compatibility; unused.
        random_state: Seed for ``numpy.random.default_rng``.
        lazy: Flag read by callers to choose ``fit_lazy`` over ``fit``.

    After fitting, ``labels_`` holds each point's cluster and
    ``final_inertia`` the sum of squared feature-space distances to the
    assigned centers.
    """

    def __init__(self, n_clusters, batch_size, n_iterations,  new_lr=False, full_batch =False, use_coreset=False, random_state=None, lazy=False):
        self.n_clusters = n_clusters
        self.batch_size = batch_size
        print("batch size in constructor: ", batch_size)
        self.n_iterations = n_iterations
        self.full_batch = full_batch
        self.final_inertia = 0
        self.new_lr = new_lr  # new learning rate flag
        self.random_state = random_state
        self.lazy = lazy

    def _initial_center_caches(self, K, C):
        """Build the three caches for initial centers ``phi(x_c)``, ``c in C``.

        Returns ``(X_to_C_inner_products, C_to_C_inner_products,
        X_to_implied_centers)`` with shapes ``(n, k)``, ``(k,)``, ``(n, k)``.
        """
        X_to_C_inner_products = K[:, C].copy()
        # <c_j, c_j> = K(c_j, c_j), i.e. the diagonal of the k x k block.
        # (Ravelling K[C][:, C] would put K(c_0, c_j) at index j instead.)
        C_to_C_inner_products = K[C, C].copy()
        # Real initial distances: zeros here would make every cluster that
        # receives no batch points look like distance 0 to all points.
        X_to_implied_centers = (K.diagonal()[:, np.newaxis]
                                - 2 * X_to_C_inner_products
                                + C_to_C_inner_products[np.newaxis, :])
        return X_to_C_inner_products, C_to_C_inner_products, X_to_implied_centers

    def compute_new_labels_and_distances(
        self, K, labels, batch, X_to_C_inner_products, C_to_C_inner_products, X_to_implied_centers, weight_history=None):
        """Apply one mini-batch update to every cluster that received batch points.

        For cluster ``j`` with ``b > 0`` batch points currently labelled ``j``,
        the center moves toward their feature-space mean ``m_j``:
        ``c_j <- (1 - lr) c_j + lr m_j``. The caches are updated in place via

            <phi(x), c_j> <- (1-lr) <phi(x), c_j> + lr <phi(x), m_j>
            <c_j, c_j>    <- (1-lr)^2 <c_j, c_j> + 2 (1-lr) lr <c_j, m_j> + lr^2 <m_j, m_j>
            ||phi(x) - c_j||^2 = K(x, x) - 2 <phi(x), c_j> + <c_j, c_j>

        Clusters without batch points keep their cached values.

        Args:
            K: Kernel matrix; its diagonal and the rows/columns of ``batch``
                must be filled.
            labels: Current integer label of every point.
            batch: Indices of the batch points.
            X_to_C_inner_products, C_to_C_inner_products, X_to_implied_centers:
                The caches (see the class docstring); modified in place.
            weight_history: Per-cluster running count of assigned points;
                required and modified in place when ``new_lr`` is False.

        Returns:
            ``(new_labels, X_to_closest_center)``: argmin and min of
            ``X_to_implied_centers`` over the clusters.
        """
        B = len(batch)  # all weights are 1, so this is the batch size

        # Points outside the batch get label -1 so they match no cluster below.
        labels_batch = np.full_like(labels, -1)
        labels_batch[batch] = labels[batch]

        phi_a_phi_a_term = K.diagonal()
        for j in range(self.n_clusters):
            mask = labels_batch == j
            if not mask.any():
                continue

            b = mask.sum()  # number of batch points in cluster j
            if self.new_lr:
                learning_rate = np.sqrt(b / B)
            else:
                weight_history[j] += b
                learning_rate = b / weight_history[j]
            if self.full_batch:
                learning_rate = 1

            # <phi(x), m_j> for every x, <c_j, m_j> (old center) and <m_j, m_j>.
            # These must be read before the caches are overwritten below.
            phi_a_cm_batch = K[:, mask].sum(axis=1) / b
            C_cm_batch_term = X_to_C_inner_products[mask, j].sum() / b
            cm_batch_cm_batch_term = K[mask][:, mask].sum() / (b * b)

            X_to_C_inner_products[:, j] = (1 - learning_rate) * X_to_C_inner_products[:, j] + learning_rate * phi_a_cm_batch

            C_to_C_inner_products[j] = (1 - learning_rate) ** 2 * C_to_C_inner_products[j] \
                                       + 2 * (1 - learning_rate) * learning_rate * C_cm_batch_term \
                                       + learning_rate ** 2 * cm_batch_cm_batch_term

            X_to_implied_centers[:, j] = phi_a_phi_a_term - 2 * X_to_C_inner_products[:, j] + C_to_C_inner_products[j]

        new_labels = X_to_implied_centers.argmin(axis=1)
        X_to_closest_center = X_to_implied_centers.min(axis=1)

        return new_labels, X_to_closest_center

    def get_inertia(self, distances_squared):
        """Return the clustering cost: the sum of squared distances."""
        return distances_squared.sum()

    # MARK: -Lazy fit
    def fit_lazy(self, X, kernel_sigma):
        """Fit with an RBF kernel, evaluating only the kernel entries that are used.

        Seeding fills the kernel diagonal and the center rows/columns. All
        batches are then pre-sampled so that the rows/columns of every point
        appearing in some batch are computed in a single kernel call.

        Args:
            X: Data matrix (rows are points).
            kernel_sigma: RBF bandwidth; ``gamma = 1 / (2 * kernel_sigma**2)``.
        """
        rng = np.random.default_rng(self.random_state)
        batch_size = self.batch_size
        print(f"batch size in fit: {batch_size}")
        if self.full_batch:
            batch_size = X.shape[0]

        gamma = 1 / (2 * kernel_sigma ** 2)

        def kernel_function(x, y):
            return pairwise_kernels(x, y, metric="rbf", gamma=gamma)

        labels, distances_squared, C, K = initialize_kernel_kmeans_plusplus_lazy(
            X, self.n_clusters, kernel_function, rng)

        X_to_C_inner_products, C_to_C_inner_products, X_to_implied_centers = \
            self._initial_center_caches(K, C)
        weight_history = None if self.new_lr else np.zeros(self.n_clusters)

        batches = [rng.choice(X.shape[0], batch_size, replace=False)
                   for _ in range(self.n_iterations)]
        if batches:  # nothing to compute (and no empty kernel call) for 0 iterations
            needed = np.unique(np.concatenate(batches))
            K[:, needed] = kernel_function(X, X[needed])
            K[needed, :] = K[:, needed].T

        for batch in batches:
            labels, distances_squared = self.compute_new_labels_and_distances(
                K, labels, batch, X_to_C_inner_products, C_to_C_inner_products,
                X_to_implied_centers, weight_history
            )

        self.final_inertia, self.labels_ = self.get_inertia(distances_squared), labels

    def fit(self, X, K, labels, distances_squared, C):
        """Fit on a fully precomputed kernel matrix from a given initialisation.

        Args:
            X: Data matrix; only its number of rows is used.
            K: Full ``n x n`` kernel matrix.
            labels: Initial cluster position (0..k-1) of every point.
            distances_squared: Initial squared distance of every point to its
                center; becomes the inertia when ``n_iterations`` is 0.
            C: Data indices of the initial centers.
        """
        rng = np.random.default_rng(self.random_state)
        batch_size = X.shape[0] if self.full_batch else self.batch_size

        X_to_C_inner_products, C_to_C_inner_products, X_to_implied_centers = \
            self._initial_center_caches(K, C)
        weight_history = None if self.new_lr else np.zeros(self.n_clusters)

        for _ in range(self.n_iterations):
            batch = rng.choice(X.shape[0], batch_size, replace=False)
            labels, distances_squared = self.compute_new_labels_and_distances(
                K, labels, batch, X_to_C_inner_products, C_to_C_inner_products,
                X_to_implied_centers, weight_history
            )

        # Final inertia and cluster assignments for the entire dataset.
        self.final_inertia, self.labels_ = self.get_inertia(distances_squared), labels


def get_inertia(points, centers, normalized=True):
    """Assign each point to its nearest center and return the clustering cost.

    Args:
        points: 2-D array of samples (one per row).
        centers: 2-D array of cluster centers (one per row).
        normalized: If True, divide the summed squared distance by the number
            of points.

    Returns:
        ``(inertia, nearest)``: ``nearest[i]`` is the index of the center
        closest to ``points[i]``; ``inertia`` is the (optionally averaged) sum
        of squared Euclidean distances to those centers.
    """
    assignment = euclidean_distances(points, centers).argmin(axis=1)
    residuals = points - centers[assignment]
    # The Frobenius norm squared equals the sum of all squared residuals.
    total = np.linalg.norm(residuals) ** 2
    if normalized:
        total /= points.shape[0]
    return total, assignment

class MiniBatchKMeans:
    """Euclidean mini-batch k-means with two learning-rate schedules.

    NOTE: this definition shadows ``sklearn.cluster.MiniBatchKMeans`` imported
    at the top of the file.

    Args:
        n_clusters: Number of centers ``k``.
        batch_size: Points sampled (without replacement) per iteration.
        n_iterations: Number of mini-batch updates.
        dynamic_batching, lr_decay, use_var, use_beta, jl: Accepted for
            interface compatibility; currently unused.
        new_lr: If True cluster ``j`` moves with rate ``sqrt(b / batch_size)``
            (``b`` = batch points assigned to ``j``); otherwise with ``b``
            divided by the running count of points assigned to ``j``.
        random_state: Seed for ``numpy.random.default_rng``.
    """

    def __init__(self, n_clusters, batch_size, n_iterations, dynamic_batching=True, new_lr=False,lr_decay = False, use_var=False, use_beta=False, jl="None", random_state=None):
        self.n_clusters = n_clusters
        self.batch_size = batch_size
        self.n_iterations = n_iterations
        self.centroids = None
        self.final_inertia = 0
        self.new_lr = new_lr  # new learning rate flag
        self.random_state = random_state

    # needed for consistency with sklearn (for running experiments)
    def set_params(self, **params):
        """Set attributes from keyword arguments (sklearn style); return self."""
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def fit(self, X, init_centroids=None):
        """Run mini-batch updates starting from ``init_centroids``.

        Sets ``centroids``, ``labels_`` (nearest centroid of every row of
        ``X``) and ``final_inertia`` (mean squared distance to it).

        Raises:
            ValueError: If ``init_centroids`` is None.
        """
        if init_centroids is None:
            raise ValueError("MiniBatchKMeans.fit requires init_centroids")
        self.total_sampled = 0
        rng = np.random.default_rng(self.random_state)
        k = self.n_clusters
        n_features = X.shape[1]
        self.centroids = init_centroids
        self.counts = np.zeros(k)
        # Running count of points assigned per cluster (old learning rate).
        # Sized from self.n_clusters; it used to read a module-level `n_clusters`.
        weight_history = np.zeros(k)

        for _ in range(self.n_iterations):
            minibatch = X[rng.choice(X.shape[0], self.batch_size, replace=False)]

            # Assign batch points to their nearest current centroid.
            _, nearest = get_inertia(minibatch, self.centroids)

            # Per-cluster sums and counts of the batch points (vectorised).
            sums = np.zeros((k, n_features))
            np.add.at(sums, nearest, minibatch)
            cluster_sizes = np.bincount(nearest, minlength=k).astype(float)

            new_centroids = np.copy(self.centroids)
            for c_idx in np.flatnonzero(cluster_sizes):
                center_of_mass = sums[c_idx] / cluster_sizes[c_idx]
                if self.new_lr:
                    learning_rate = np.sqrt(cluster_sizes[c_idx] / self.batch_size)
                else:
                    weight_history[c_idx] += cluster_sizes[c_idx]
                    learning_rate = cluster_sizes[c_idx] / weight_history[c_idx]
                learning_rate = min(1, learning_rate)
                new_centroids[c_idx] = (1 - learning_rate) * self.centroids[c_idx] + learning_rate * center_of_mass

            self.centroids = new_centroids

        # Final inertia and cluster assignments for the entire dataset.
        self.final_inertia, self.labels_ = get_inertia(X, self.centroids)

    def predict(self, X):
        """Return the index of the nearest centroid for every row of ``X``."""
        return np.argmin(euclidean_distances(X, self.centroids), axis=1)


def load_and_preprocess_data(dataset, n_sample=None):
    """Load a benchmark dataset together with its RBF kernel bandwidth.

    "moons" is generated with ``make_moons``; every other dataset is fetched
    from OpenML (cached under ``data/``).

    Args:
        dataset: One of "pendigits", "har", "moons", "mnist_784", "letter".
        n_sample: If given, keep a random subset of this many points, drawn
            with the global ``np.random`` state.

    Returns:
        ``(X, y, kernel_sigma)``.

    Raises:
        ValueError: If ``dataset`` is not supported.
    """
    # dataset -> (kernel_sigma, standardize features?)
    # Alternative bandwidths noted earlier: pendigits 3.5 / 5.6496,
    # har 5.5819 / 33.4194, letter 2.23606797749979 / 2.3351.
    settings = {
        "pendigits": (2.23606797749979, True),
        "har": (10.414010843544517, False),  # value for unnormalized features
        "moons": (1.3369, False),
        "mnist_784": (5.8301, False),
        "letter": (3.0, True),
    }
    if dataset not in settings:
        raise ValueError(f"Unknown dataset {dataset!r}; expected one of {sorted(settings)}")
    kernel_sigma, normalize = settings[dataset]

    if dataset == "moons":
        # Synthetic data; no download needed.
        X, y = make_moons(n_samples=5000, noise=0.1, random_state=42)
    else:
        # Let download/parse errors propagate instead of failing later on X=None.
        X, y = fetch_openml(name=dataset, version=1, as_frame=False, return_X_y=True, data_home="data/", cache=True, parser="liac-arff")
        if dataset == "mnist_784":
            X = X / 255  # presumably 0-255 pixel intensities -> [0, 1]

    if n_sample is not None:
        keep = np.random.permutation(X.shape[0])[:n_sample]
        X, y = X[keep], y[keep]
    if normalize:
        X = StandardScaler().fit_transform(X)
    return X, y, kernel_sigma


def initialize_kmeans_plusplus(X, k, rng):
    """Choose ``k`` initial centroids from the rows of ``X`` with k-means++.

    The first centroid is drawn uniformly. Each further centroid is drawn with
    probability proportional to the squared Euclidean distance from a point
    to its nearest already-chosen centroid. If every point coincides with a
    chosen centroid (zero total distance), the draw falls back to uniform.

    Args:
        X: 2-D data array (one sample per row).
        k: Number of centroids.
        rng: ``numpy.random.Generator``.

    Returns:
        Array of shape ``(k, n_features)`` with the chosen rows of ``X``.
    """
    n_samples, n_features = X.shape
    centroids = np.empty((k, n_features), dtype=X.dtype)

    centroids[0] = X[rng.integers(n_samples)]
    # Squared distance of every point to its nearest chosen centroid: the
    # distance is summed over features first, then minimised over centroids.
    closest_sq = np.sum((X - centroids[0]) ** 2, axis=1)

    for c in range(1, k):
        total = closest_sq.sum()
        if total > 0:
            idx = rng.choice(n_samples, p=closest_sq / total)
        else:
            idx = rng.integers(n_samples)
        centroids[c] = X[idx]
        closest_sq = np.minimum(closest_sq, np.sum((X - centroids[c]) ** 2, axis=1))

    return centroids


# MARK: -Evaluation
def evaluate(kms, X,K, labels, num_iters, n_clusters, batch_size, n_runs=50,kernel_construction_time=0, kernel_sigma=None):
    """Fit every estimator ``n_runs`` times; summarise runtime, NMI and ARI.

    For run ``seed`` the estimator gets ``random_state = seed`` and its
    k-means++ initialisation draws from ``np.random.default_rng(seed)``. The
    measured train time includes initialisation; for non-lazy kernel
    estimators ``kernel_construction_time`` is added as well.

    Args:
        kms: Mapping display name -> estimator (``KernelMiniBatchKMeans`` or
            the Euclidean ``MiniBatchKMeans``).
        X: Data matrix.
        K: Precomputed kernel matrix (used by non-lazy kernel estimators).
        labels: Ground-truth labels for NMI / ARI.
        num_iters, batch_size: Settings recorded in the output.
        n_clusters: Number of centers to initialise; also recorded.
        n_runs: Number of seeds.
        kernel_construction_time: Seconds spent building ``K``.
        kernel_sigma: RBF bandwidth for lazy kernel estimators. When None, the
            module-level ``kernel_sigma`` (set by the ``__main__`` section) is
            used, as before.

    Returns:
        One dict per estimator with the settings, ``train_time_mean/std``,
        ``NMI_mean/std`` and ``ARI_mean/std``.

    Raises:
        ValueError: If a lazy estimator is evaluated and no bandwidth exists.
    """
    if kernel_sigma is None:
        # Backward compatibility with the previous implicit global lookup.
        kernel_sigma = globals().get("kernel_sigma")

    evaluations = []
    for name, km in kms.items():
        print(f"Evaluating {name}")
        is_kernel = isinstance(km, KernelMiniBatchKMeans)
        # Only the eager kernel estimators use the precomputed K, so only they
        # are charged for building it.
        uses_full_kernel = is_kernel and not km.lazy
        train_times = []
        scores = defaultdict(list)

        for seed in tqdm(range(n_runs)):
            rng = np.random.default_rng(seed)
            km.random_state = seed

            t0 = time.perf_counter()
            if uses_full_kernel:
                init_labels, init_distances_squared, init_C = initialize_kernel_kmeans_plusplus(K, n_clusters, rng)
                km.fit(X, K, init_labels, init_distances_squared, init_C)
            elif is_kernel:
                if kernel_sigma is None:
                    raise ValueError(f"{name}: lazy kernel estimators need kernel_sigma")
                km.fit_lazy(X, kernel_sigma)
            else:
                init_centroids = initialize_kmeans_plusplus(X, n_clusters, rng)
                km.fit(X, init_centroids)
            elapsed = time.perf_counter() - t0
            if uses_full_kernel:
                elapsed += kernel_construction_time
            train_times.append(elapsed)

            scores["NMI"].append(metrics.normalized_mutual_info_score(labels, km.labels_))
            scores["ARI"].append(metrics.adjusted_rand_score(labels, km.labels_))

        train_times = np.asarray(train_times)
        evaluation = {
            "estimator": name,
            "num_iters": num_iters,
            "n_clusters": n_clusters,
            "batch_size": batch_size,
            "train_time_mean": train_times.mean(),
            "train_time_std": train_times.std(),
        }
        summary = {score_name: (np.mean(values), np.std(values)) for score_name, values in scores.items()}
        for score_name, (mean_score, std_score) in summary.items():
            evaluation[score_name + "_mean"] = mean_score
            evaluation[score_name + "_std"] = std_score
        evaluations.append(evaluation)

        print(f"\n {name}, num_iters: {num_iters}, n_clusters: {n_clusters}, batch size: {batch_size} ")
        for score_name, (mean_score, std_score) in summary.items():
            print(f"{score_name}: {mean_score:.3f} ± {std_score:.3f}")

    return evaluations

def wang_heuristic(X, sample = None, **kwargs):
    """Estimate an RBF bandwidth ``sigma`` with the heuristic of Wang et al.

    heuristic from
    @article{JMLR:v20:17-517,
        author  = {Shusen Wang and Alex Gittens and Michael W. Mahoney},
        title   = {Scalable Kernel K-Means Clustering with Nystrom Approximation: Relative-Error Bounds},
        journal = {Journal of Machine Learning Research},
        year    = {2019},
        volume  = {20},
        number  = {12},
        pages   = {1--49},
        url     = {http://jmlr.org/papers/v20/17-517.html}
    }

    sigma = sqrt( mean_{i in S} mean_{j} ||x_j - x_i||^2 ), where S is all
    rows of ``X`` or, if ``sample`` is given, that many rows drawn without
    replacement using the global ``np.random`` state.

    The inner mean uses the exact identity
    ``mean_j ||x_j - x_i||^2 = ||x_i - mu||^2 + mean_j ||x_j - mu||^2``
    (``mu`` = column mean). This is O(n d) instead of looping over all pairs.

    Extra keyword arguments are accepted and ignored.
    """
    indices = np.arange(X.shape[0])
    if sample is not None:
        indices = np.random.choice(indices, sample, replace=False)

    # Centering first keeps the identity numerically stable.
    centered = X - X.mean(axis=0)
    sq_norms = np.sum(centered ** 2, axis=1)
    Y = sq_norms[indices] + sq_norms.mean()

    return np.sqrt(Y.mean())






# Bar colours, indexed by an estimator's position in sorted-name order.
# Orange and purple appear twice in a row; together with `hatches` below this
# presumably lets paired variants share a colour and differ only by hatch.
colors = [
    '#1f77b4',          # muted blue
    *['#ff7f0e'] * 2,   # safety orange (x2)
    *['#9467bd'] * 2,   # muted purple (x2)
    '#8c564b',          # chestnut brown
    '#e377c2',          # raspberry yogurt pink
    '#7f7f7f',          # middle gray
    '#bcbd22',          # curry yellow-green
    '#17becf',          # blue-teal
    # Add more as needed
]

# Hatch pattern per bar, indexed like `colors`: alternating plain / '//'.
hatches = ['', '//'] * 2 + ['']

def plot_results(to_plot):
    """Plot results tables as a grid of bar charts and save the figure as PNG.

    The grid has one row per batch size (taken from the first table) and one
    column per table; each panel is drawn by ``plot_results_bars``. The figure
    is written to ``results/<timestamp>_results.png``.

    Args:
        to_plot: Non-empty list of result DataFrames, each with ``dataset``,
            ``batch_size`` and ``estimator`` columns plus the metric columns
            read by ``plot_results_bars``.

    Raises:
        ValueError: If ``to_plot`` is empty.
    """
    if not to_plot:
        raise ValueError("plot_results needs at least one results DataFrame")
    plt.rcParams.update({'font.size': 24})

    num_res = len(to_plot)  # number of columns in the grid
    # Assumes all tables share the batch sizes of the first one.
    batch_sizes = to_plot[0]['batch_size'].unique()
    num_batches = len(batch_sizes)
    # squeeze=False keeps `axes` 2-D even with a single row and/or column
    # (otherwise `axes[j][i]` fails when there is only one batch size).
    fig, axes = plt.subplots(num_batches, num_res, figsize=(7 * num_res, 6 * num_batches), squeeze=False)

    for j, b in enumerate(batch_sizes):
        for i, df_all in enumerate(to_plot):
            name = df_all['dataset'].iloc[0]
            df = df_all[df_all['batch_size'] == b]

            ax = axes[j, i]
            ax1, ax2 = plot_results_bars(df, ax, i == 0 and j == 0)
            if i == 0:
                ax1.set_ylabel('Score')
            if i == num_res - 1:
                ax2.set_ylabel('Time (s) log scale')
            ax.set_title(f"{name} (batch size: {b})")

    fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.04), ncol=5, fontsize=34)

    plt.tight_layout()
    # write to results directory
    os.makedirs("results", exist_ok=True)
    formatted_datetime = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    plt.savefig(f"results/{formatted_datetime}_results.png", bbox_inches='tight')



def plot_results_bars(df, ax1, set_labels=True):
    """Draw grouped bars: ARI and NMI on ``ax1``, train time on a log twin axis.

    Each metric forms one group with a bar per estimator (sorted by name). The
    metric's ``_std`` column is used as the error bar.

    Args:
        df: Results for one panel, with an ``estimator`` column and
            ``<metric>_mean`` / ``<metric>_std`` columns for ``ARI``, ``NMI``
            and ``train_time``.
        ax1: Axes for the score bars; a twin axes is created for the times.
        set_labels: Attach legend labels to the bars of the first group. The
            labels are the estimator names without a leading ``"<digits>."``
            ordering prefix.

    Returns:
        ``(ax1, ax2)``: the score axes and the log-scaled time axes.
    """
    metric_names = ["ARI", "NMI"]
    time_metric = "train_time"
    all_metrics = metric_names + [time_metric]
    estimators = sorted(df['estimator'].unique())

    ax2 = ax1.twinx()  # second y-axis for train_time
    ax2.set_yscale('log')

    bar_width = 0.4
    positions = np.arange(len(all_metrics)) * (len(estimators) * bar_width + 0.5)
    offset = -0.5  # shift of each group's first bar relative to `positions`

    for i, metric in enumerate(all_metrics):
        ax = ax2 if metric == time_metric else ax1
        for j, name in enumerate(estimators):
            row = df[df['estimator'] == name].iloc[0]
            prefix, sep, rest = name.partition('.')
            alg_name = rest if sep and prefix.isdigit() else name
            ax.bar(positions[i] + offset + j * bar_width, row[metric + "_mean"], bar_width,
                   color=colors[j % len(colors)],
                   label=alg_name if i == 0 and set_labels else "",
                   yerr=row[metric + "_std"], capsize=5,
                   hatch=hatches[j % len(hatches)], edgecolor='black', linewidth=1)

    # Bars are centred on their x position, so a group's centre lies midway
    # between its first and last bar.
    ax1.set_xticks(positions + offset + (len(estimators) - 1) * bar_width / 2)
    ax1.set_xticklabels(metric_names + ["runtime"])
    return ax1, ax2
    
#%%
# MARK: -Main
if __name__ == "__main__":
    # Command-line entry point.
    #   run  -> for every dataset: load it, build the full RBF kernel matrix,
    #           evaluate the estimators for every batch size and write one CSV
    #           of results per dataset into ./results/.
    #   plot -> read every CSV in ./results/ and save a bar-chart figure there.
    # NOTE: this code runs at module scope, so names such as `kernel_sigma`
    # become module globals; evaluate() reads `kernel_sigma` that way when it
    # fits lazy kernel estimators.

    # read arguments using argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("mode", type=str, help="Mode of operation: 'run' or 'plot'", choices=["run","plot"])
    
    args = parser.parse_args()

    mode = args.mode
    if mode == "run":

        # When False, full-batch kernel k-means is also evaluated (see below).
        skip_full = False
        dataset_names = [
             "pendigits",
             "har",
             "mnist_784",
            "letter"
            ]
        #dataset_names = ["pendigits"] 
        for dataset_name in dataset_names:
            # None = use the whole dataset; an int subsamples that many points.
            n = None #10000
            X, Y, kernel_sigma = load_and_preprocess_data(dataset_name, n_sample = n) #get_dataset(dataset_name)
            if n is None:
                n = X.shape[0]
            
            print(dataset_name)
            print("dataset size", X.shape, "kernel sigma", kernel_sigma)
            #kernel_sigma = wang_heuristic(X, sample=None)
            #print("wang gamma", kernel_sigma)

            # One cluster per distinct ground-truth label.
            n_clusters = np.unique(Y).shape[0]
            n_runs = 10
            # Define parameter ranges
            num_iters_values = [200]
            n_clusters_values = [n_clusters]
            batch_size_values = [1024, 256, 64, 16]
            
            evaluations = []
            current_datetime = datetime.now()

            # Format the date and time
            formatted_datetime = current_datetime.strftime('%Y-%m-%d %H:%M:%S')
            print(formatted_datetime)
            print(f"dataset: {dataset_name}")
            evaluations = []  # duplicate reset; harmless
            gamma =1/(2*kernel_sigma**2)


            # time how long it takes to compute the kernel matrix completely;
            # evaluate() adds this time to the non-lazy kernel estimators.
            t0 = time.time()
            kernel_matrix = pairwise_kernels(X,metric="rbf",gamma=gamma)#rbf_kernel(X, gamma=gamma)
            kernel_construction_time = time.time() - t0

            K = kernel_matrix
            for num_iters in num_iters_values:
                for n_clusters in n_clusters_values:
                    for batch_size in batch_size_values:
                        print("#"*20)
                        print(f"batch size should be {batch_size}")
                        # The lazy variants and `kk` are built here but only
                        # evaluated if enabled in `mbks` below.
                        mbkk_newlr = KernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, new_lr=True)
                        mbkk_oldlr = KernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, new_lr=False)
                        mbkk_newlr_lazy = KernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, new_lr=True,lazy=True)
                        mbkk_oldlr_lazy = KernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, new_lr=False, lazy=True)

                        kk = KernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, new_lr=True, full_batch=True)
                        mbk_newlr = MiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, dynamic_batching=False, new_lr=True)
                        mbk_oldlr = MiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, dynamic_batching=False, new_lr=False)

                        # The "<n>." prefixes fix the sorted order (and thus
                        # the colour/hatch) of each estimator in the plots.
                        mbks = {
                                "4.$\\beta$-MiniBatch": mbk_newlr,
                                 "5.MiniBatch": mbk_oldlr,
                                #"6. Lazy $\\beta$-MiniBatch KKm": mbkk_newlr_lazy,
                                # "7. Lazy MiniBatch Kkm": mbkk_oldlr_lazy,
                                "2.$\\beta$-MiniBatch Kernel": mbkk_newlr,
                                 "3.MiniBatch Kernel": mbkk_oldlr,
                                #"1.Kernel": kk
                                }
                        evaluations+=evaluate(mbks,X,K, Y, num_iters, n_clusters, batch_size, n_runs, kernel_construction_time)
            if not skip_full:      
                # Full-batch kernel k-means. `batch_size` and `num_iters` are
                # leftovers from the loops above; batch_size is ignored because
                # full_batch=True makes fit() use all points. The single result
                # is copied once per batch size so every plot panel gets it.
                kk = KernelMiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, n_iterations=num_iters, new_lr=True, full_batch=True)
                kk_eval=evaluate({"1.Kernel": kk},X,K, Y, num_iters, n_clusters, None, n_runs, kernel_construction_time)
                for b in batch_size_values:
                    kk_eval1 = deepcopy(kk_eval)
                    kk_eval1[0]["batch_size"] = b
                    evaluations+=kk_eval1
            # Convert evaluations to DataFrame
            df = pd.DataFrame(evaluations)
            # NOTE(review): metric_names is not used below.
            metric_names = ["Homogeneity", "Completeness", "V-measure", "ARI", "Silhouette Coefficient", "NMI"]
            param_vals = {'num_iters': num_iters_values, 'n_clusters': n_clusters_values, 'batch_size': batch_size_values, 'n_runs': n_runs, 'n': n}
            

            if not os.path.exists("results"):
                os.makedirs("results")

            #  clear the directory
            #for filename in os.listdir("results"):
            #    if filename.endswith(".csv"):
            #        os.remove(f"results/{filename}")

            #add dataset name to df
            df['dataset'] = dataset_name
            #add kernel construction time to df
            df['kernel_construction_time'] = kernel_construction_time
            # NOTE(review): the file name embeds ':' and str(param_vals)
            # (quotes, braces, spaces); fine on POSIX, not valid on Windows.
            df.to_csv(f"results/{dataset_name}_{formatted_datetime}_{str(param_vals)}results.csv", index=False)

    elif mode == "plot":
        #plot_results_bars(df)
        #directory = "results/final/"
        directory = "results/"
        # Every CSV in `directory` becomes one column of panels in the figure.
        to_plot = []
        for filename in os.listdir(directory):
            if filename.endswith(".csv"):
                df = pd.read_csv(directory+filename)
                to_plot.append(df)
        
        plot_results(to_plot)

# %%
