import sys
import os

# Resolve this script's own directory so the import path set up below does not
# depend on the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Parent of the script's directory.
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
# Sibling "new" directory, i.e. <parent_dir>/new.
new_dir = os.path.join(parent_dir, "new")

# Prepend the sibling "new" directory (not the parent itself) to sys.path so
# that the `algorithms.algorithms` import below is resolved from it first.
sys.path.insert(0, new_dir)
from algorithms.algorithms import labels_from_distances_to_implied_coreset_centers

import warnings
# Silence this one warning text. Presumably it is emitted by the spectral
# embedding inside SpectralClustering when the k-NN graph is disconnected
# (verify).
warnings.filterwarnings("ignore", message="Graph is not fully connected, spectral embedding may not work as expected.")


import numpy as np
import fast_kernel_coreset_sampling as fcp
from scipy import sparse

from sklearn.datasets import make_blobs, make_moons
from sklearn.neighbors import kneighbors_graph
from sklearn.cluster import SpectralClustering

import hdf5storage
from sklearn.metrics import adjusted_rand_score
from sklearn.preprocessing import StandardScaler
from sklearn import datasets
from sklearn.datasets import fetch_openml
import faiss
from scipy.sparse import csc_matrix
import scipy.io
import scipy.sparse
import time
import datetime
import os
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from tqdm import tqdm



def _subsample_rows(X, y, max_points, seed=0):
    """Return a reproducible random subset of at most ``max_points`` rows of ``(X, y)``.

    Rows are drawn without replacement using ``np.random.RandomState(seed)``,
    so the same rows are selected on every run.
    """
    num_points = int(min(max_points, X.shape[0]))
    rng = np.random.RandomState(seed)
    indices = rng.choice(X.shape[0], num_points, replace=False)
    return X[indices], y[indices]


def load_and_preprocess_data(dataset_name,downsampled_data_set_size,dataset_size=None,dataset_clusters=None, dataset_dimension=None, noise=None, factor=0.5, center_box=(-10.0, 10.0), **kwargs):
    """
    Load (or generate) a dataset, cap its size, and standardise its features.

    Parameters
    ----------
    dataset_name : str
        ``"blobs"``, ``"moons"`` and ``"circles"`` are generated with sklearn.
        ``"fashion_mnist"`` is fetched from OpenML by data id 40996. Any other
        name is fetched from OpenML by name (version 1).
    downsampled_data_set_size : int
        Upper bound on the number of returned points.
    dataset_size : int, optional
        Synthetic datasets only: requested sample count, capped by
        ``downsampled_data_set_size``.
    dataset_clusters : optional
        ``blobs`` only: forwarded as ``make_blobs(centers=...)``. When None,
        make_blobs' own default is used.
    dataset_dimension : int, optional
        ``blobs`` only: number of features. When None, make_blobs' own
        default is used (previously this crashed).
    noise : float, optional
        ``moons`` / ``circles`` only: noise level.
    factor : float
        ``circles`` only: forwarded to ``make_circles(factor=...)``.
    center_box : tuple of float
        ``blobs`` only: bounding box for cluster centres.
    **kwargs
        Accepted for call-site compatibility and ignored.

    Returns
    -------
    X : ndarray
        Features standardised per column with ``StandardScaler``.
    y : ndarray
        Labels, in the same row order as ``X``.
    k : int
        Number of distinct labels in ``y``.

    Notes
    -----
    OpenML downloads are cached under ``data/``. OpenML data is subsampled
    with a fixed seed (0). The synthetic generators are not seeded.
    """
    if dataset_name in ("blobs", "moons", "circles"):
        n_samples = (downsampled_data_set_size if dataset_size is None
                     else min(downsampled_data_set_size, dataset_size))
        if dataset_name == "blobs":
            # Only forward optional arguments that were actually supplied;
            # make_blobs rejects n_features=None.
            blob_kwargs = {}
            if dataset_dimension is not None:
                blob_kwargs["n_features"] = dataset_dimension
            if dataset_clusters is not None:
                blob_kwargs["centers"] = dataset_clusters
            X, y = datasets.make_blobs(n_samples=n_samples, center_box=center_box, shuffle=True, **blob_kwargs)
        elif dataset_name == "moons":
            X, y = datasets.make_moons(n_samples=n_samples, noise=noise)
        else:
            X, y = datasets.make_circles(n_samples=n_samples, noise=noise, factor=factor)
    else:
        if dataset_name == "fashion_mnist":
            openml_query = {"data_id": 40996}
        else:
            openml_query = {"name": dataset_name, "version": 1}
        X, y = fetch_openml(**openml_query, as_frame=False, return_X_y=True, data_home="data/", cache=True, parser="liac-arff")
        X, y = _subsample_rows(X, y, downsampled_data_set_size)

    k = len(np.unique(y))

    return StandardScaler().fit_transform(X), y, k

def construct_knn_graph_naive(X, k, weighted=False, gamma=None):
    """
    Construct a symmetric k-nearest-neighbour graph using sklearn's kneighbors_graph.

    Points i and j are joined when either is among the other's ``k`` nearest
    neighbours. Each point counts itself as one of its ``k`` neighbours
    (``include_self=True``), so every diagonal entry is stored.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
    k : int
        Neighbours per point, including the point itself.
    weighted : bool
        False: every stored entry has weight 1.
        True: entry (i, j) has the Gaussian weight
        ``exp(-gamma * ||x_i - x_j||^2)``, so self-loops get weight 1.
    gamma : float, optional
        Kernel bandwidth. Required when ``weighted`` is True.

    Returns
    -------
    scipy.sparse.csc_matrix of shape (n_samples, n_samples)

    Raises
    ------
    ValueError
        If ``weighted`` is True and ``gamma`` is None.
    """
    if weighted and gamma is None:
        raise ValueError("gamma must be provided when weighted=True")

    connectivity = kneighbors_graph(X, k, mode='connectivity', include_self=True).tocsc()
    # Union of the two directed neighbour relations, then binarise.
    # Otherwise mutual neighbours would carry weight 2 after the addition.
    adj_matrix = (connectivity + connectivity.T).tocsc()
    adj_matrix.data = np.ones(adj_matrix.nnz)
    if not weighted:
        return adj_matrix

    # The connectivity graph carries no distances, so compute the squared
    # Euclidean distance of every stored edge directly. Work in chunks to
    # avoid materialising an (nnz x n_features) array at once.
    points = np.asarray(X, dtype=np.float64)
    edges = adj_matrix.tocoo()
    sq_dists = np.empty(edges.nnz)
    chunk = 1 << 16
    for start in range(0, edges.nnz, chunk):
        stop = start + chunk
        diff = points[edges.row[start:stop]] - points[edges.col[start:stop]]
        sq_dists[start:stop] = np.einsum('ij,ij->i', diff, diff)
    edges.data = np.exp(-gamma * sq_dists)
    return edges.tocsc()

def _faiss_knn_search(data, k, nlist, nprobe):
    """Search ``data`` against itself with a faiss IVF-PQ index (on GPU 0 if available).

    ``data`` must be a C-contiguous float32 array of shape (n, d). Returns
    faiss's ``(D, I)`` arrays of shape (n, k). For L2 indexes ``D`` holds
    squared distances; ``I`` holds neighbour ids, with -1 where fewer than k
    results were found.
    """
    n, d = data.shape
    # IVF training needs at least as many training points as inverted lists.
    nlist = min(nlist, n)
    # Product quantisation splits vectors into m equal sub-vectors, so m must
    # divide d. Keep the previous m=8 whenever d allows it.
    m = max(cand for cand in range(1, 9) if d % cand == 0)

    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)

    # Keep a reference to the GPU resources for as long as the index is used.
    gpu_resources = None
    num_gpus = faiss.get_num_gpus() if hasattr(faiss, "get_num_gpus") else 0
    if num_gpus > 0:
        try:
            gpu_resources = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(gpu_resources, 0, index)
        except (AttributeError, RuntimeError) as err:
            print(f"No GPU available ({err}); falling back to CPU")
    else:
        print("No GPU available")

    index.train(data)
    index.add(data)
    index.nprobe = nprobe
    return index.search(data, k)


def _symmetric_adjacency_from_knn(D, I, weighted, gamma):
    """Turn (n, k) faiss results into a symmetric CSC adjacency matrix with unit diagonal."""
    n, k = I.shape
    rows = np.repeat(np.arange(n, dtype=np.int64), k)
    cols = I.reshape(-1).astype(np.int64)
    if weighted:
        # faiss L2 distances are already squared, so this is the Gaussian
        # kernel exp(-gamma * ||x_i - x_j||^2). Squaring D again would use
        # the 4th power of the distance.
        vals = np.exp(-gamma * D.reshape(-1).astype(np.float64))
    else:
        vals = np.ones(rows.shape[0], dtype=np.float64)

    # faiss pads a row with id -1 when it finds fewer than k neighbours.
    short_rows = np.count_nonzero((I < 0).any(axis=1))
    if short_rows:
        print(f"negative index: {short_rows} of {n} points returned fewer than {k} neighbours")

    # Drop padding and self-matches; the diagonal is added once below.
    keep = (cols >= 0) & (cols != rows)
    directed = csc_matrix((vals[keep], (rows[keep], cols[keep])), shape=(n, n))
    # Symmetrise with an element-wise max rather than a sum, so each entry
    # is a single edge weight (1 for unweighted graphs), not a count of how
    # many times the pair was found.
    undirected = directed.maximum(directed.T)
    return (undirected + sparse.identity(n, format="csc")).tocsc()


def construct_knn_graph(X, k, nlist=1024, nprobe=20, weighted=False, gamma=None):
    """
    Construct an approximate, symmetric k-nearest-neighbour graph with a faiss IVF-PQ index.

    Every point is queried against the whole dataset, so a point is normally
    returned as one of its own neighbours. Points i and j are joined when
    either appears in the other's result list. Every point also gets a
    self-loop of weight 1.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
    k : int
        Number of neighbours requested per point from faiss.
    nlist : int
        Number of IVF inverted lists, capped at n_samples.
    nprobe : int
        Number of inverted lists visited per query.
    weighted : bool
        False: every edge has weight 1.
        True: edge (i, j) has weight ``exp(-gamma * ||x_i - x_j||^2)`` using
        faiss's approximate squared L2 distance. When i and j are in each
        other's lists, the larger of the two weights is kept.
    gamma : float, optional
        Kernel bandwidth. Required when ``weighted`` is True.

    Returns
    -------
    scipy.sparse.csc_matrix of shape (n_samples, n_samples), dtype float64.

    Raises
    ------
    ValueError
        If ``weighted`` is True and ``gamma`` is None.
    """
    if weighted and gamma is None:
        raise ValueError("gamma must be provided when weighted=True")
    # faiss requires C-contiguous float32 input. A plain astype would keep a
    # Fortran layout.
    data = np.ascontiguousarray(X, dtype=np.float32)
    D, I = _faiss_knn_search(data, k, nlist, nprobe)
    return _symmetric_adjacency_from_knn(D, I, weighted, gamma)

def coreset_and_save(c,K,W,coreset_size=1000, save=False, new=True):
    """
    Run the fast_kernel_coreset_sampling sampler on a sparse matrix and time the call.

    Parameters
    ----------
    c : int
        Number of clusters passed to the sampler.
    K : scipy.sparse matrix of shape (n, n)
        Graph / kernel matrix. It is converted to CSC, and its
        (data, indices, indptr, nnz_per_col) arrays are passed to the sampler
        as float32 / uint64.
    W : array-like of length n
        Per-node weights, passed to the sampler as float32.
    coreset_size : int
        Requested coreset size.
    save : bool
        If True, write the converted arrays (plus shape and ``c``) to
        ``sparse_matrix.npz`` in the current directory before sampling.
    new : bool
        Use ``fcp.coreset`` when True and ``fcp.old_coreset`` otherwise.

    Returns
    -------
    result
        Whatever the selected ``fcp`` function returns.
    elapsed : datetime.timedelta
        Wall-clock time of the sampler call only, excluding conversion and
        saving.
    """
    # The arrays below are interpreted column-wise (nnz_per_col), so make
    # sure they really are CSC arrays. tocsc() is free for CSC input.
    K = K.tocsc()
    n = K.shape[0]
    data = K.data.astype(np.float32)
    indices = K.indices.astype(np.uint64)
    indptr = K.indptr.astype(np.uint64)
    nnz_per_col = np.diff(indptr).astype(np.uint64)
    W = np.asarray(W).astype(np.float32)

    if save:
        np.savez('sparse_matrix.npz', data=data, indices=indices, indptr=indptr, nnz_per_col=nnz_per_col, W=W, shape=np.array([n,n],dtype=np.uint64), num_clusters=np.array([c],dtype=np.uint64))

    sampler = fcp.coreset if new else fcp.old_coreset
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted.
    t0 = time.perf_counter()
    result = sampler(c, n, coreset_size, data, indices, indptr, nnz_per_col, W)
    elapsed = datetime.timedelta(seconds=time.perf_counter() - t0)
    return result, elapsed


def census_data(c=100, save=True, coreset_size=1000):
    """
    Build an approximate k-NN graph (k=200) on US Census (1990) data and compute a coreset.

    Reads ``../coreset_paper/data.txt`` (comma-separated, one header row),
    relative to the current working directory.

    Parameters
    ----------
    c : int
        Number of clusters passed to the coreset sampler.
    save : bool
        If True, write the adjacency matrix's CSC arrays to ``sparse_matrix.npz``.
    coreset_size : int
        Requested coreset size. Defaults to ``coreset_and_save``'s default.

    Returns
    -------
    (result, elapsed) exactly as returned by ``coreset_and_save``.
    """
    k = 200

    path = os.path.join("..", "coreset_paper", "data.txt")
    X = np.loadtxt(path, delimiter=',', skiprows=1)

    adj_matrix = construct_knn_graph(X, k, nlist=1000, nprobe=20, weighted=False, gamma=None)

    if save:
        np.savez('sparse_matrix.npz', data=adj_matrix.data, indices=adj_matrix.indices, indptr=adj_matrix.indptr)

    # Node weights: one per vertex (length n), not one per stored edge.
    node_weights = np.ones(adj_matrix.shape[0])
    return coreset_and_save(c, adj_matrix, node_weights, coreset_size=coreset_size, save=False)

def blobs_test_num_clusters(d=100, cluster_start=2, cluster_end=1000, cluster_steps=20, coreset_size=10000, n=1_000_000, k=250, rounds=5):
    """
    Time coreset construction while varying the number of clusters.

    A single blobs dataset (``n`` points, ``d`` features, ``cluster_end``
    centres, seed 42) is turned into a weighted approximate k-NN graph
    (gamma=0.25) once. ``coreset_and_save`` is then timed ``rounds`` times,
    with unit node weights, for each cluster count in
    ``np.linspace(cluster_start, cluster_end, cluster_steps, dtype=int)``.
    The mean time +/- one standard deviation is written to
    ``results/blobs_test_num_clusters.html``.
    """
    cluster_counts = np.linspace(cluster_start, cluster_end, cluster_steps, dtype=int)
    timings = np.zeros((len(cluster_counts), rounds))

    X, _ = make_blobs(n_samples=n, n_features=d, centers=cluster_end, random_state=42)
    adj_matrix = construct_knn_graph(X, k, nlist=100, nprobe=20, weighted=True, gamma=0.25).tocsc()
    node_weights = np.ones(adj_matrix.shape[0])
    for i in tqdm(range(cluster_counts.shape[0])):
        c = cluster_counts[i]
        for r in range(rounds):
            _, elapsed = coreset_and_save(c, adj_matrix, node_weights, coreset_size=coreset_size, save=False)
            timings[i, r] = elapsed.total_seconds()

    means = np.mean(timings, axis=1)
    stds = np.std(timings, axis=1)

    # Shaded +/- 1 std band: the lower edge fills up to the upper edge
    # ('tonexty'). Both edges are invisible and kept out of the legend.
    band_fill = 'rgba(0,100,80,0.2)'
    band_line = dict(color='rgba(0,100,80,0)')
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=cluster_counts, y=means, mode='lines', name='Mean'))
    fig.add_trace(go.Scatter(x=cluster_counts, y=means+stds, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line))
    fig.add_trace(go.Scatter(x=cluster_counts, y=means-stds, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line))

    fig.update_layout(title=f"N: {n}, d: {d}, k: {k}, coreset_size: {coreset_size}, ",
                    xaxis_title='Number of clusters',
                    yaxis_title='Time (s)')

    # write_html does not create missing directories.
    os.makedirs('results', exist_ok=True)
    fig.write_html('results/blobs_test_num_clusters.html')


def blobs_test_coreset_size(d=100, cluster=100, coreset_start=1000, coreset_end=10000, coreset_steps=10, n=1_000_000, k=250, rounds=5):
    """
    Measure coreset-based spectral clustering time and ARI while varying the coreset size.

    Despite the name, the data is Fashion-MNIST (downsampled to at most
    ``n`` points), not blobs. The ``d`` and ``cluster`` arguments are
    overwritten by the loaded data's feature count and number of distinct
    labels, and ``n`` by the actual number of points.

    For each coreset size in
    ``np.linspace(coreset_start, coreset_end, coreset_steps, dtype=int)``,
    repeated ``rounds`` times:
      1. sample a coreset with ``fcp.improved_coreset`` from the normalised kernel K;
      2. run SpectralClustering on the weight-scaled coreset sub-kernel;
      3. extend the labels to all points with
         ``labels_from_distances_to_implied_coreset_centers``.
    The total time of 1-3 and the ARI against the true labels are recorded.
    Spectral clustering on the full k-NN graph is run once as a baseline.
    The plot is written to ``results/blobs_test_coreset_size.html``.
    """
    coreset_sizes = np.linspace(coreset_start, coreset_end, coreset_steps, dtype=int)
    times = np.zeros((len(coreset_sizes), rounds))
    aris = np.zeros((len(coreset_sizes), rounds))

    X, y, cluster = load_and_preprocess_data("fashion_mnist", n)
    d = X.shape[1]
    n = X.shape[0]

    print(f"{X.shape[0]} samples, {X.shape[1]} features, {len(np.unique(y))} clusters")
    adj_matrix = construct_knn_graph(X, k, nlist=128, nprobe=20, weighted=False)
    # Column sums (degrees) and the normalised kernel K[i, j] = A[i, j] / (W[i] * W[j]).
    W = adj_matrix.sum(axis=0).A1
    K = adj_matrix.multiply((1/W).reshape(-1,1)).multiply((1/W).reshape(1,-1)).tocsc()
    # Shift the diagonal of K by -0.001 / W.
    # NOTE(review): the reason for this small negative shift is not visible here; confirm.
    K = K -0.001*sparse.diags(1/W, 0, format='csc')
    data, indices, indptr = K.data, K.indices, K.indptr

    W = W.astype(np.float32)
    data = data.astype(np.float32)
    indices = indices.astype(np.uint64)
    indptr = indptr.astype(np.uint64)
    nnz_per_col = np.diff(indptr).astype(np.uint64)

    pbar = tqdm(range(coreset_sizes.shape[0]), ncols=160)
    for i in pbar:
        coreset_size = coreset_sizes[i]
        for r in range(rounds):
            t0 = time.perf_counter()
            (coreset,coreset_weights) = fcp.improved_coreset(cluster,n, coreset_size, data, indices, indptr, nnz_per_col, W)
            coreset_submatrix = K[coreset, :][:, coreset]
            # Multiply each row and each column of the coreset sub-kernel by the
            # corresponding coreset point's weight.
            coreset_graph = coreset_submatrix.multiply(coreset_weights.reshape(-1,1)).multiply(coreset_weights.reshape(1,-1)).tocsc()
            coreset_labels = SpectralClustering(n_clusters=cluster, affinity='precomputed').fit_predict(coreset_graph)

            full_labels = labels_from_distances_to_implied_coreset_centers(K,coreset, coreset_labels,coreset_weights,cluster)
            times[i,r] = time.perf_counter() - t0

            aris[i,r] = adjusted_rand_score(y, full_labels)
        # Loop variables are named `v`, so the `time` module is not shadowed.
        formatted_aris = "/".join(f"{v:.2f}" for v in aris[i,:])
        formatted_times = "/".join(f"{v:.2f}" for v in times[i,:])
        pbar.set_description(f"ARIs: {formatted_aris} avg({np.mean(aris[i,:]):.2f}), Times: {formatted_times}  avg({np.mean(times[i,:]):.1f})")

    # Baseline: spectral clustering on the full k-NN graph.
    t0 = time.perf_counter()
    full_labels = SpectralClustering(n_clusters=cluster, affinity='precomputed').fit_predict(adj_matrix.tocsr())
    full_time = time.perf_counter() - t0
    full_ari = adjusted_rand_score(y, full_labels)
    print(f"Full data time: {full_time:.2f}, ARI: {full_ari:.2f}")

    # Time (left, log axis) and ARI (right axis) with +/- 1 std bands.
    mean_times = np.mean(times, axis=1)
    std_times = np.std(times, axis=1)
    mean_aris = np.mean(aris, axis=1)
    std_aris = np.std(aris, axis=1)

    band_fill = 'rgba(0,100,80,0.2)'
    band_line = dict(color='rgba(0,100,80,0)')

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(go.Scatter(x=coreset_sizes, y=mean_times, mode='lines', name='Mean Time', line=dict(color='blue', dash="solid")), secondary_y=False)
    fig.add_trace(go.Scatter(x=coreset_sizes, y=mean_aris, mode='lines', name='Mean ARI', line=dict(color='red', dash="dash")), secondary_y=True)

    # Band traces: no legend entry and an invisible edge line.
    fig.add_trace(go.Scatter(x=coreset_sizes, y=mean_times+std_times, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line), secondary_y=False)
    fig.add_trace(go.Scatter(x=coreset_sizes, y=mean_times-std_times, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line), secondary_y=False)

    fig.add_trace(go.Scatter(x=coreset_sizes, y=mean_aris+std_aris, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line), secondary_y=True)
    fig.add_trace(go.Scatter(x=coreset_sizes, y=mean_aris-std_aris, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line), secondary_y=True)

    # Full-data baselines as black horizontal lines: time solid, ARI dashed.
    # Plotly's shape yref accepts "y" (primary) and "y2" (secondary); "y1" is rejected.
    fig.add_shape(type="line", x0=coreset_sizes[0], y0=full_time, x1=coreset_sizes[-1], y1=full_time, line=dict(color="black", width=1, dash="solid"), yref="y")
    fig.add_shape(type="line", x0=coreset_sizes[0], y0=full_ari, x1=coreset_sizes[-1], y1=full_ari, line=dict(color="black", width=1, dash="dash"), yref="y2")

    fig.update_yaxes(type="log", secondary_y=False)

    fig.update_layout(title=f"N: {n}, d: {d}, k: {k}, cluster: {cluster}, ",
                    xaxis_title='Coreset Size',
                    yaxis_title='Time (s)',
                    yaxis2_title='ARI',
                    yaxis2=dict(overlaying='y', side='right'))

    # write_html does not create missing directories.
    os.makedirs('results', exist_ok=True)
    fig.write_html('results/blobs_test_coreset_size.html')


def test_old_and_new(d=100, coreset_size=1000, cluster_start=10, cluster_end=1000, cluster_steps=10, n=1_000_000, k=250, rounds=5):
    """
    Compare ``fcp.coreset`` (new) and ``fcp.old_coreset`` (old) run times while varying the number of clusters.

    A blobs dataset (seed 42) is turned into an unweighted approximate k-NN
    graph A. With degrees W, both samplers run on the normalised kernel
    K[i, j] = A[i, j] / (W[i] * W[j]). Two plots are written:
    ``results/old_vs_new.html`` (mean time +/- std for both) and
    ``results/old_vs_new_ratio.html`` (speed-up = old time / new time).
    """
    cluster_counts = np.linspace(cluster_start, cluster_end, cluster_steps, dtype=int)
    times_old = np.zeros((len(cluster_counts), rounds))
    times_new = np.zeros((len(cluster_counts), rounds))

    X, _ = make_blobs(n_samples=n, n_features=d, centers=cluster_end, random_state=42)
    adj_matrix = construct_knn_graph(X, k, nlist=100, nprobe=20, weighted=False)
    W = adj_matrix.sum(axis=0).A1
    K = adj_matrix.multiply((1/W).reshape(-1,1)).multiply((1/W).reshape(1,-1)).tocsc()
    print("constructed Kernel matrix")
    pbar = tqdm(range(cluster_counts.shape[0]))
    for i in pbar:
        cluster = cluster_counts[i]
        for r in range(rounds):
            _, new_elapsed = coreset_and_save(cluster, K, W, coreset_size=coreset_size, save=False, new=True)
            _, old_elapsed = coreset_and_save(cluster, K, W, coreset_size=coreset_size, save=False, new=False)
            times_old[i,r] = old_elapsed.total_seconds()
            times_new[i,r] = new_elapsed.total_seconds()

        pbar.set_description(f"Old: {np.mean(times_old[i,:]):.2f}, New: {np.mean(times_new[i,:]):.2f}")

    old_means = np.mean(times_old, axis=1)
    old_stds = np.std(times_old, axis=1)
    new_means = np.mean(times_new, axis=1)
    new_stds = np.std(times_new, axis=1)

    # Band traces: no legend entry and an invisible edge line.
    band_fill = 'rgba(0,100,80,0.2)'
    band_line = dict(color='rgba(0,100,80,0)')

    os.makedirs('results', exist_ok=True)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=cluster_counts, y=old_means, mode='lines', name='Old'))
    fig.add_trace(go.Scatter(x=cluster_counts, y=old_means+old_stds, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line))
    fig.add_trace(go.Scatter(x=cluster_counts, y=old_means-old_stds, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line))

    fig.add_trace(go.Scatter(x=cluster_counts, y=new_means, mode='lines', name='New'))
    fig.add_trace(go.Scatter(x=cluster_counts, y=new_means+new_stds, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line))
    fig.add_trace(go.Scatter(x=cluster_counts, y=new_means-new_stds, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line))

    fig.update_layout(title=f"N: {n}, d: {d}, k: {k}, coreset_size: {coreset_size}, ",
                    xaxis_title='Number of clusters',
                    yaxis_title='Time (s)')

    fig.write_html('results/old_vs_new.html')

    # Speed-up of the new implementation: old time / new time (> 1 means new
    # is faster). The band combines the extreme +/- 1 std values. Where
    # new_mean - new_std <= 0 there is no finite upper bound, so that point
    # is left as a gap (NaN).
    upper_den = new_means - new_stds
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio_upper = np.where(upper_den > 0, (old_means + old_stds) / upper_den, np.nan)
    ratio_lower = (old_means - old_stds) / (new_means + new_stds)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=cluster_counts, y=old_means/new_means, mode='lines', name='Old/New (speed-up)'))
    fig.add_trace(go.Scatter(x=cluster_counts, y=ratio_upper, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line))
    fig.add_trace(go.Scatter(x=cluster_counts, y=ratio_lower, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line))

    fig.update_layout(title=f"N: {n}, d: {d}, k: {k}, coreset_size: {coreset_size}, ",
                    xaxis_title='Number of clusters',
                    yaxis_title='Speed-up (old time / new time)')

    fig.write_html('results/old_vs_new_ratio.html')


def test_friendster_clusters(cluster_start,cluster_end, cluster_steps, coreset_size,rounds=10):
    """
    Compare ``fcp.old_coreset`` and ``fcp.improved_coreset`` run times on the Friendster graph.

    Loads ``friendster_A.mat`` (keys ``data``, ``indices``, ``indptr``,
    ``W``) from the current directory. Each sampler is then timed ``rounds``
    times for every cluster count in
    ``np.linspace(cluster_start, cluster_end, cluster_steps, dtype=int)``.
    Raw timings go to ``results/friendster_data_small.npz``. Plots go to
    ``results/friendster_old_vs_new_small.html`` and
    ``results/friendster_old_vs_new_ratio_small.html``
    (speed-up = old time / improved time).
    """
    print("loading friendster graph")
    t0 = time.perf_counter()
    mat = hdf5storage.loadmat("friendster_A.mat")
    data, indices, indptr, W = mat["data"][:], mat["indices"][:], mat["indptr"][:], mat["W"][:]

    W = np.array(W, dtype=np.float32).flatten()
    n = W.shape[0]

    data = data.astype(np.float32)
    indices = indices.astype(np.uint64)
    indptr = indptr.astype(np.uint64)
    nnz_per_col = np.diff(indptr).astype(np.uint64)

    elapsed = time.perf_counter() - t0
    print(f"loaded friendster graph in {elapsed:.2f} seconds")

    cluster_counts = np.linspace(cluster_start, cluster_end, cluster_steps, dtype=int)
    times_old = np.zeros((len(cluster_counts), rounds))
    times_improved = np.zeros((len(cluster_counts), rounds))

    pbar = tqdm(range(cluster_counts.shape[0]), ncols=160)
    for i in pbar:
        cluster = cluster_counts[i]
        for r in range(rounds):
            t0 = time.perf_counter()
            _ = fcp.old_coreset(cluster,n, coreset_size, data, indices, indptr, nnz_per_col, W, True)
            times_old[i,r] = time.perf_counter() - t0

            t0 = time.perf_counter()
            _ = fcp.improved_coreset(cluster,n, coreset_size, data, indices, indptr, nnz_per_col, W, True)
            times_improved[i,r] = time.perf_counter() - t0

            pbar.set_description(f"round {r+1} of {rounds}, Old: {times_old[i,r]:.2f}, Improved: {times_improved[i,r]:.2f}\t\t")

        pbar.set_description(f"Old avg: {np.mean(times_old[i,:]):.2f}, Improved avg: {np.mean(times_improved[i,:]):.2f}\t\t\t")

    # np.savez and write_html do not create missing directories.
    os.makedirs('results', exist_ok=True)
    np.savez('results/friendster_data_small.npz', times_old=times_old, times_improved=times_improved, xs=cluster_counts)

    old_means = np.mean(times_old, axis=1)
    old_stds = np.std(times_old, axis=1)
    improved_means = np.mean(times_improved, axis=1)
    improved_stds = np.std(times_improved, axis=1)

    # Band traces: no legend entry and an invisible edge line.
    band_fill = 'rgba(0,100,80,0.2)'
    band_line = dict(color='rgba(0,100,80,0)')

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=cluster_counts, y=old_means, mode='lines', name='Old'))
    fig.add_trace(go.Scatter(x=cluster_counts, y=old_means+old_stds, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line))
    fig.add_trace(go.Scatter(x=cluster_counts, y=old_means-old_stds, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line))

    fig.add_trace(go.Scatter(x=cluster_counts, y=improved_means, mode='lines', name='Improved'))
    fig.add_trace(go.Scatter(x=cluster_counts, y=improved_means+improved_stds, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line))
    fig.add_trace(go.Scatter(x=cluster_counts, y=improved_means-improved_stds, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line))

    fig.update_layout(title=f"Friendster graph, coreset_size: {coreset_size}, ",
                    xaxis_title='Number of clusters',
                    yaxis_title='Time (s)')

    fig.write_html('results/friendster_old_vs_new_small.html')

    # Speed-up: old time / improved time (> 1 means improved is faster).
    # Where improved_mean - improved_std <= 0 the band's upper bound is
    # undefined, so that point is left as a gap (NaN).
    upper_den = improved_means - improved_stds
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio_upper = np.where(upper_den > 0, (old_means + old_stds) / upper_den, np.nan)
    ratio_lower = (old_means - old_stds) / (improved_means + improved_stds)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=cluster_counts, y=old_means/improved_means, mode='lines', name='Old/Improved (speed-up)'))
    fig.add_trace(go.Scatter(x=cluster_counts, y=ratio_upper, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line))
    fig.add_trace(go.Scatter(x=cluster_counts, y=ratio_lower, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line))

    fig.update_layout(title=f"Friendster graph, coreset_size: {coreset_size}, ",
                    xaxis_title='Number of clusters',
                    yaxis_title='Speed-up (old time / improved time)')

    fig.write_html('results/friendster_old_vs_new_ratio_small.html')

def save_friendster_graph_as_npz(mat_path="friendster_A.mat", out_path="friendster_graph.npz", num_clusters=5000):
    """
    Convert the Friendster graph from a MATLAB file into an ``.npz`` of sampler-ready arrays.

    Parameters
    ----------
    mat_path : str
        Input file, read with ``hdf5storage.loadmat``. Must contain the keys
        ``data``, ``indices``, ``indptr`` and ``W``.
    out_path : str
        Output ``.npz`` path.
    num_clusters : int
        Stored as the ``num_clusters`` array (uint64, length 1).

    The output contains ``data`` (float32), ``indices``, ``indptr``,
    ``nnz_per_col`` (uint64), ``W`` (float32, flattened), ``shape`` = [n, n]
    with n = len(W), and ``num_clusters``.
    """
    mat = hdf5storage.loadmat(mat_path)
    data, indices, indptr, W = mat["data"][:], mat["indices"][:], mat["indptr"][:], mat["W"][:]

    W = np.array(W, dtype=np.float32).flatten()
    data = data.astype(np.float32)
    indices = indices.astype(np.uint64)
    indptr = indptr.astype(np.uint64)
    nnz_per_col = np.diff(indptr).astype(np.uint64)

    num_clusters = np.array([num_clusters], dtype=np.uint64)
    shape = np.array([W.shape[0], W.shape[0]], dtype=np.uint64)

    np.savez(out_path, data=data, indices=indices, indptr=indptr, nnz_per_col=nnz_per_col, W=W, shape=shape, num_clusters=num_clusters)


def test_friendster_clusters_just_improved(cluster_start,cluster_end, cluster_steps, coreset_size,rounds=10):
    """
    Time ``fcp.improved_coreset`` on the Friendster graph while varying the number of clusters.

    Loads ``friendster_A.mat`` (keys ``data``, ``indices``, ``indptr``,
    ``W``) from the current directory. The sampler is timed ``rounds`` times
    per cluster count in
    ``np.linspace(cluster_start, cluster_end, cluster_steps, dtype=int)``.
    Raw timings go to ``results/friendster_data_improved.npz`` and the mean
    +/- std plot to ``results/friendster_improved.html``.
    """
    print("loading friendster graph")
    t0 = time.perf_counter()
    mat = hdf5storage.loadmat("friendster_A.mat")
    data, indices, indptr, W = mat["data"][:], mat["indices"][:], mat["indptr"][:], mat["W"][:]

    W = np.array(W, dtype=np.float32).flatten()
    n = W.shape[0]

    data = data.astype(np.float32)
    indices = indices.astype(np.uint64)
    indptr = indptr.astype(np.uint64)
    nnz_per_col = np.diff(indptr).astype(np.uint64)

    elapsed = time.perf_counter() - t0
    print(f"loaded friendster graph in {elapsed:.2f} seconds")

    cluster_counts = np.linspace(cluster_start, cluster_end, cluster_steps, dtype=int)
    times_improved = np.zeros((len(cluster_counts), rounds))
    pbar = tqdm(range(cluster_counts.shape[0]), ncols=160)
    for i in pbar:
        cluster = cluster_counts[i]
        for r in range(rounds):
            t0 = time.perf_counter()
            _ = fcp.improved_coreset(cluster,n, coreset_size, data, indices, indptr, nnz_per_col, W, True)
            times_improved[i,r] = time.perf_counter() - t0

            pbar.set_description(f"round {r+1} of {rounds}, time: \t{times_improved[i,r]:.2f}")

        pbar.set_description(f"avg: \t{np.mean(times_improved[i,:]):.2f}")

    # np.savez and write_html do not create missing directories.
    os.makedirs('results', exist_ok=True)
    np.savez('results/friendster_data_improved.npz', times_improved=times_improved, xs=cluster_counts)

    improved_means = np.mean(times_improved, axis=1)
    improved_stds = np.std(times_improved, axis=1)

    # Band traces: no legend entry and an invisible edge line.
    band_fill = 'rgba(0,100,80,0.2)'
    band_line = dict(color='rgba(0,100,80,0)')

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=cluster_counts, y=improved_means, mode='lines', name='improved'))
    fig.add_trace(go.Scatter(x=cluster_counts, y=improved_means+improved_stds, mode='lines', fillcolor=band_fill, showlegend=False, line=band_line))
    fig.add_trace(go.Scatter(x=cluster_counts, y=improved_means-improved_stds, mode='lines', fill='tonexty', fillcolor=band_fill, showlegend=False, line=band_line))

    fig.update_layout(title=f"Friendster graph, coreset_size: {coreset_size}, ",
                    xaxis_title='Number of clusters',
                    yaxis_title='Time (s)')

    fig.write_html('results/friendster_improved.html')


if __name__ == '__main__':
    # Experiment currently enabled: sweep over coreset sizes.
    blobs_test_coreset_size(
        d=50,
        cluster=10,
        coreset_start=100,
        coreset_end=3000,
        coreset_steps=20,
        n=100_000,
        k=100,
        rounds=5,
    )

    # Other experiments; uncomment one to run it.
    # blobs_test_num_clusters(d=5,cluster_start=2, cluster_end=1000, cluster_steps=200, coreset_size=2000, n=5000, k=100, rounds=20)
    # test_old_and_new(d=16,cluster_start=2, cluster_end=2000, cluster_steps=20, coreset_size=40_000, n=1_000_000, k=200, rounds=10)
    # test_friendster_clusters_just_improved(10,5000,30,1_000_000,3)
    # test_friendster_clusters(2,50,48,1_000_000,20)
    # save_friendster_graph_as_npz()
