import os
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from scipy.spatial.kdtree import distance_matrix

import paramrepulsor.utils.data as data_utils
from sklearn.svm import SVC, LinearSVC
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
from sklearn.preprocessing import scale
from sklearn.kernel_approximation import Nystroem
from numpy.random import default_rng

    
def knn_eval(X, y, n_neighbors=5, n_splits=10):
    '''
    Score a low-dimensional embedding with a k-nearest-neighbor classifier.
    The data is split with stratified K-fold cross validation; a fresh
    KNeighborsClassifier is fitted on each training fold and scored on the
    matching held-out fold.
    Input:
        X: A numpy array with the shape [N, k]. The lower dimension embedding
           of some dataset. Expected to have some clusters.
        y: A numpy array with the shape [N, 1]. The labels of the original
           dataset.
        n_neighbors: Number of neighbors considered by the classifier.
        n_splits: Number of cross-validation folds.
    Output:
        acc: The mean of the per-fold accuracies.
    '''
    folds = StratifiedKFold(n_splits=n_splits)
    fold_scores = []
    for train_idx, test_idx in folds.split(X, y):
        model = KNeighborsClassifier(n_neighbors=n_neighbors)
        model.fit(X[train_idx], y[train_idx])
        fold_scores.append(model.score(X[test_idx], y[test_idx]))
    return sum(fold_scores) / n_splits


def knn_eval_large(X, y, n_neighbors=5, n_splits=10, seed=0, sample_size=10000):
    '''
    This is a function that is used to evaluate the lower dimension embedding.
    An accuracy is calculated by a k-nearest neighbor classifier, defined over
    a small **sample** of data: in every cross-validation fold only up to
    `sample_size` test points are classified.
    Input:
        X: A numpy array with the shape [N, k]. The lower dimension embedding
           of some dataset. Expected to have some clusters.
        y: A numpy array with the shape [N] or [N, 1]. The labels of the
           original dataset; np.bincount below requires them to be
           non-negative integers.
        n_neighbors: Number of neighbors considered by the classifier.
        n_splits: Number of splits used in the cross validation.
        seed: Seed for the test-point sampling and for breaking voting ties.
        sample_size: Maximum number of test points classified per fold.
    Output:
        acc: Fraction of sampled test points classified correctly, pooled
             over all folds.
    '''
    # Flatten [N, 1] labels so y_neighbor is 1-D for np.bincount.
    y = np.asarray(y).ravel()
    rng = np.random.default_rng(seed=seed)
    skf = StratifiedKFold(n_splits=n_splits)
    correct_cnt = 0  # Number of correctly classified sampled points
    cnt = 0  # Number of sampled points evaluated
    for train_index, test_index in skf.split(X, y):
        X_train, y_train = X[train_index], y[train_index]
        X_test, y_test = X[test_index], y[test_index]
        # Clamp per fold into a separate variable: overwriting `sample_size`
        # would let an earlier (smaller) fold cap every later fold.
        fold_sample_size = min(X_test.shape[0], sample_size)
        indices = rng.choice(np.arange(X_test.shape[0]),
                             size=fold_sample_size, replace=False)
        for i in indices:
            # calculate_neighbors returns n + 1 indices (it expects the query
            # to be in the data). The test point is not in X_train, so asking
            # for n_neighbors - 1 yields exactly n_neighbors training neighbors.
            index_list_neighbors = calculate_neighbors(X_train, X_test[i], n_neighbors - 1)
            y_neighbor = y_train[index_list_neighbors]
            # Majority vote; if several labels tie, pick one of them randomly.
            y_cnt = np.bincount(y_neighbor)
            mode = np.amax(y_cnt)
            y_pred = np.arange(y_cnt.shape[0])[y_cnt == mode]
            y_pred = rng.choice(y_pred)
            if y_pred == y_test[i]:
                correct_cnt += 1
            cnt += 1

    return correct_cnt / cnt


def knn_eval_series(X, y, n_neighbors_list=(1, 3, 5, 10, 15, 20, 25, 30)):
    '''
    This is a function that is used to evaluate the lower dimension embedding.
    An accuracy is calculated by a k-nearest neighbor classifier (via
    `knn_eval`, i.e. stratified 10-fold cross validation) for every value of
    `n_neighbors` in `n_neighbors_list`.
    Input:
        X: A numpy array with the shape [N, k]. The lower dimension embedding
           of some dataset. Expected to have some clusters.
        y: A numpy array with the shape [N, 1]. The labels of the original
           dataset.
        n_neighbors_list: An iterable of ints. The default is an immutable
           tuple so it cannot be accidentally shared/mutated across calls.
    Output:
        accs: A list with one average cross-validated accuracy per entry of
              `n_neighbors_list`, in the same order.
    '''
    return [knn_eval(X, y, n_neighbors) for n_neighbors in n_neighbors_list]


def svm_eval(X, y, n_splits=10, **kwargs):
    '''
    Score a low-dimensional embedding with an exact kernel SVM.
    Features are standardized with `scale`; then, for each stratified
    cross-validation fold, an `SVC(**kwargs)` (rbf kernel unless overridden)
    is trained on the training part and scored on the held-out part.
    Input:
        X: A numpy array with the shape [N, k]. The lower dimension embedding
           of some dataset. Expected to have some clusters.
        y: A numpy array with the shape [N, 1]. The labels of the original
           dataset.
        n_splits: Number of cross-validation folds.
        kwargs: Any keyword argument that is sent into the SVM.
    Output:
        acc: The mean of the per-fold accuracies.
    '''
    X_scaled = scale(X)
    splitter = StratifiedKFold(n_splits=n_splits)
    fold_accs = []
    for tr, te in splitter.split(X_scaled, y):
        svm = SVC(**kwargs)
        svm.fit(X_scaled[tr], y[tr])
        fold_accs.append(svm.score(X_scaled[te], y[te]))
    return sum(fold_accs) / n_splits


def svm_eval_large(X, y, n_splits=10, sample_size=100000,
                   seed=20200202, **kwargs):
    '''
    This is an accelerated version of the SVM function.
    Training an SVM is infeasible over huge dataset. We therefore only sample
    a portion of data to perform the training and testing, and approximate
    the rbf kernel with a 300-component Nystroem feature map followed by a
    LinearSVC.
    Input:
        X: A numpy array with the shape [N, k]. The lower dimension embedding.
        y: A numpy array with N labels.
        n_splits: Number of stratified cross-validation folds.
        sample_size: Maximum number of points (sampled without replacement)
                     used for the whole cross validation.
        seed: Seed for the subsampling and for the Nystroem component
              sampling, so repeated calls give the same result.
        kwargs: Any keyword argument that is sent into the LinearSVC.
    Output:
        acc: The mean of the per-fold accuracies.
    '''
    X = X.astype(float)
    X = scale(X)
    # Subsampling X and y
    rng = np.random.default_rng(seed=seed)
    sample_size = min(X.shape[0], sample_size)  # prevent overflow
    indices = rng.choice(np.arange(X.shape[0]), size=sample_size, replace=False)
    X, y = X[indices], y[indices]

    # The kernel width depends only on the (subsampled) data, not on the
    # fold, so compute it once: gamma = 1 / (X.var() * n_features + eps).
    gamma = 1 / (X.var() * X.shape[1] + 1e-3)

    # Perform standard evaluation
    skf = StratifiedKFold(n_splits=n_splits)
    fold_accs = []
    for train_index, test_index in skf.split(X, y):
        # Seed Nystroem too: its landmark sampling is random and would
        # otherwise make the score vary between runs despite `seed`.
        feature_map_nystroem = Nystroem(gamma=gamma, n_components=300,
                                        random_state=seed)
        data_transformed = feature_map_nystroem.fit_transform(X[train_index])
        clf = LinearSVC(tol=1e-5, **kwargs)
        clf.fit(data_transformed, y[train_index])
        test_transformed = feature_map_nystroem.transform(X[test_index])
        fold_accs.append(clf.score(test_transformed, y[test_index]))
    return sum(fold_accs) / n_splits


def faster_svm_eval(X, y, n_splits=10, **kwargs):
    '''
    This is an accelerated version of the svm_eval function.
    Instead of an exact rbf-kernel SVC, the rbf kernel is approximated with a
    300-component Nystroem feature map and a LinearSVC is trained on the
    mapped features, fold by fold, under stratified cross validation.
    Input:
        X: A numpy array with the shape [N, k]. The lower dimension embedding
           of some dataset. Expected to have some clusters.
        y: A numpy array with the shape [N, 1]. The labels of the original
           dataset.
        n_splits: Number of cross-validation folds.
        kwargs: Any keyword argument that is sent into the LinearSVC.
    Output:
        acc: The mean of the per-fold accuracies.
    NOTE(review): Nystroem is not seeded here, so results can vary slightly
    between runs.
    '''
    X = scale(X.astype(float))
    # Loop-invariant: the kernel width depends only on the scaled data.
    # After `scale` the columns are centered, so X.var() * n_features is the
    # total variance of the features.
    gamma = 1 / (X.var() * X.shape[1] + 1e-3)
    skf = StratifiedKFold(n_splits=n_splits)
    fold_accs = []
    for train_index, test_index in skf.split(X, y):
        feature_map_nystroem = Nystroem(gamma=gamma, n_components=300)
        data_transformed = feature_map_nystroem.fit_transform(X[train_index])
        clf = LinearSVC(tol=1e-5, **kwargs)
        clf.fit(data_transformed, y[train_index])
        test_transformed = feature_map_nystroem.transform(X[test_index])
        fold_accs.append(clf.score(test_transformed, y[test_index]))
    return sum(fold_accs) / n_splits


def random_triplet_eval(X, X_new, num_triplets=5, seed=None):
    '''
    This is a function that is used to evaluate the lower dimension embedding.
    A triplet satisfaction score is calculated by evaluating how many randomly
    selected triplets keep their ordering. Each point i is used as the anchor
    of `num_triplets` triplets (i, j, l); a triplet is satisfied when
    "i is closer to j than to l" has the same truth value in both spaces.
    Input:
        X: A numpy array with the shape [N, p]. The higher dimension embedding
           of some dataset. Expected to have some clusters.
        X_new: A numpy array with the shape [N, k]. The lower dimension embedding
               of some dataset. Expected to have some clusters as well.
        num_triplets: Number of triplets drawn per anchor point.
        seed: Optional seed for the triplet sampling. None (default) keeps the
              previous, non-reproducible behavior.
    Output:
        acc: Fraction of satisfied triplets.
    '''
    # Work in float: subtracting unsigned-integer rows (e.g. uint8 pixels)
    # would wrap around and corrupt the distances.
    n_points = X.shape[0]
    X = np.asarray(X, dtype=float).reshape(n_points, -1)
    X_new = np.asarray(X_new, dtype=float).reshape(n_points, -1)

    # Sample (j, l) pairs with replacement for every anchor.
    rng = default_rng(seed)
    triplets = rng.choice(np.arange(n_points), (n_points, num_triplets, 2))

    def _anchor_distances(data):
        # dist[i, t, c] = ||data[i] - data[triplets[i, t, c]]||, computed one
        # (t, c) slice at a time so memory stays proportional to data.size.
        dist = np.empty(triplets.shape)
        for t in range(num_triplets):
            for c in range(2):
                dist[:, t, c] = np.linalg.norm(data - data[triplets[:, t, c]], axis=1)
        return dist

    dist_high = _anchor_distances(X)
    dist_low = _anchor_distances(X_new)
    labels = dist_high[:, :, 0] < dist_high[:, :, 1]
    pred_vals = dist_low[:, :, 0] < dist_low[:, :, 1]

    # Compare the labels and return the accuracy
    correct = np.sum(pred_vals == labels)
    return correct / n_points / num_triplets


def neighbor_kept_ratio_eval(X, X_new, n_neighbors=30):
    '''
    This is a function that evaluates the local structure preservation.
    A k-nearest-neighbor graph is constructed on both the high dimensional
    space and the low dimensional space, and the score is the fraction of
    (point, neighbor) edges shared by both graphs, self-edges excluded.
    Input:
        X: A numpy array with the shape [N, p]. The higher dimension embedding
           of some dataset. Expected to have some clusters.
        X_new: A numpy array with the shape [N, k]. The lower dimension embedding
               of some dataset. Expected to have some clusters as well.
        n_neighbors: Number of neighbors (excluding the point itself).
    Output:
        acc: The score generated by the algorithm.
    '''
    n_points = X.shape[0]
    nn_hd = NearestNeighbors(n_neighbors=n_neighbors + 1)
    nn_ld = NearestNeighbors(n_neighbors=n_neighbors + 1)
    nn_hd.fit(X)
    nn_ld.fit(X_new)
    # Sparse k-neighbor connectivity graphs (1 = neighbor). They are kept
    # sparse: densifying them costs O(n^2) memory and fails on large data.
    graph_hd = nn_hd.kneighbors_graph(X)
    graph_ld = nn_ld.kneighbors_graph(X_new)
    # Exactly equal to sum((A - I) * (B - I)) on the dense graphs, i.e.
    # sum(A * B) - trace(A) - trace(B) + n, which removes the diagonal.
    shared = graph_hd.multiply(graph_ld).sum()
    neighbor_kept = float(shared - graph_hd.diagonal().sum()
                          - graph_ld.diagonal().sum() + n_points)
    return neighbor_kept / n_neighbors / n_points


def neighbor_kept_ratio_eval_new(X, X_new, n_neighbors=30):
    '''
    Measure local-structure preservation with neighbor index lists.
    For every point, its n_neighbors nearest neighbors are found in both the
    high and the low dimensional space, and the overlap of the two lists is
    counted.
    Input:
        X: A numpy array with the shape [N, p]. The higher dimension embedding
           of some dataset. Expected to have some clusters.
        X_new: A numpy array with the shape [N, k]. The lower dimension embedding
               of some dataset. Expected to have some clusters as well.
        n_neighbors: Number of neighbors compared per point.
    Output:
        acc: Total overlap divided by n_neighbors * N.
    '''
    finder_high = NearestNeighbors(n_neighbors=n_neighbors + 1)
    finder_low = NearestNeighbors(n_neighbors=n_neighbors + 1)
    finder_high.fit(X)
    finder_low.fit(X_new)

    # One extra neighbor is requested and column 0 (normally the query point
    # itself) is dropped.
    high_lists = finder_high.kneighbors(X, return_distance=False)[:, 1:]
    low_lists = finder_low.kneighbors(X_new, return_distance=False)[:, 1:]
    n_points = X.shape[0]
    overlap = sum(
        len(np.intersect1d(high_lists[row], low_lists[row]))
        for row in range(n_points)
    )
    return overlap / n_neighbors / n_points


def neighbor_kept_ratio_eval_large(X, X_new, n_neighbors=30, sample_size=10000, seed=0):
    '''
    Sampled variant of the neighbor-kept ratio for large datasets.
    Building a full neighbor graph can run out of memory, so only a random
    subset of points is evaluated; for each sampled point its neighbors are
    found by brute-force distance computation in both spaces.
    Input:
        X: A numpy array with the shape [N, p]. The higher dimension embedding
           of some dataset. Expected to have some clusters.
        X_new: A numpy array with the shape [N, k]. The lower dimension embedding
               of some dataset. Expected to have some clusters as well.
        n_neighbors: Number of neighbors considered by the algorithm.
        sample_size: Number of points evaluated (clamped to N).
        seed: The random seed used by the random number generator.
    Output:
        acc: Shared neighbors (self excluded) / (n_neighbors * samples).
    '''
    rng = np.random.default_rng(seed=seed)
    n_samples = min(X.shape[0], sample_size)
    sampled = rng.choice(np.arange(X.shape[0], dtype=int), size=n_samples, replace=False)
    # Each neighbor list holds n_neighbors + 1 entries including the point
    # itself, which the final subtraction removes once per sample.
    shared = sum(
        intersection(calculate_neighbors(X, idx, n_neighbors),
                     calculate_neighbors(X_new, idx, n_neighbors))
        for idx in sampled
    )
    shared -= n_samples
    return shared / n_neighbors / n_samples


def calculate_neighbors(X, i, n_neighbors):
    '''A helper function that calculates the neighbors of a sample in a dataset.

    Input:
        X: A numpy array with the shape [N, p].
        i: Either an integer row index into X, or a numpy array holding the
           query point itself.
        n_neighbors: The function returns the indices of the n_neighbors + 1
           rows of X closest (Euclidean) to the query, in no particular order
           (the +1 accounts for the query itself when it is a row of X).
    Output:
        index_list: A numpy integer array of at most n_neighbors + 1 indices.
    '''
    X = np.asarray(X)
    query = i if isinstance(i, np.ndarray) else X[i]
    if np.issubdtype(X.dtype, np.floating):
        diff_mat = X - query
    else:
        # Integer (especially unsigned) data would wrap around on subtraction.
        diff_mat = np.subtract(X, query, dtype=np.float64)
    dists = np.linalg.norm(diff_mat, axis=1).reshape(-1)
    # argpartition needs kth < len(dists); clamp it so a dataset with only
    # n_neighbors + 1 points does not raise. The first n_neighbors + 1
    # entries are the smallest either way.
    kth = min(n_neighbors + 1, dists.shape[0] - 1)
    return np.argpartition(dists, kth)[:n_neighbors + 1]


def intersection(index_list1, index_list2):
    '''Count how many entries of `index_list2` also appear in `index_list1`.

    Membership is checked against a hash set built from `index_list1`, so the
    cost is linear in the combined length. Repeated entries of `index_list2`
    are each counted.'''
    lookup = set(index_list1)
    return sum(1 for idx in index_list2 if idx in lookup)


def neighbor_kept_ratio_series_eval(X, X_news, n_neighbors=30):
    '''Graph-based neighbor-kept ratio of X against each embedding in X_news.

    Input:
        X: A numpy array with the shape [N, p], the original data.
        X_news: An iterable of embeddings of X, each with N rows.
        n_neighbors: Number of neighbors (excluding the point itself).
    Output:
        nk_ratios: A list with one ratio per embedding: the fraction of
                   k-neighbor edges (self-edges excluded) shared with X.
    '''
    n_points = X.shape[0]
    nn_hd = NearestNeighbors(n_neighbors=n_neighbors + 1)
    nn_hd.fit(X)
    # Keep the connectivity graphs sparse; densifying them is O(n^2) memory.
    graph_hd = nn_hd.kneighbors_graph(X)
    trace_hd = graph_hd.diagonal().sum()
    nk_ratios = []
    for X_new in X_news:
        nn_ld = NearestNeighbors(n_neighbors=n_neighbors + 1)
        nn_ld.fit(X_new)
        graph_ld = nn_ld.kneighbors_graph(X_new)
        # Equals sum((A - I) * (B - I)) = sum(A * B) - tr(A) - tr(B) + n.
        shared = graph_hd.multiply(graph_ld).sum()
        neighbor_kept = float(shared - trace_hd
                              - graph_ld.diagonal().sum() + n_points)
        nk_ratios.append(neighbor_kept / n_neighbors / n_points)
    return nk_ratios


def neighbor_kept_ratio_series_eval_fast(X, X_news, n_neighbors=30):
    '''Neighbor-kept ratio of X against each embedding in X_news.

    The neighbor lists of X are computed once and compared (per point, by
    set intersection) with the neighbor lists of every embedding.
    Input:
        X: A numpy array with the shape [N, p], the original data.
        X_news: An iterable of embeddings of X.
        n_neighbors: Number of neighbors compared per point.
    Output:
        A list with one ratio (overlap / (n_neighbors * N)) per embedding.
    '''
    def _neighbor_lists(data):
        # One extra neighbor is requested and column 0 (normally the query
        # point itself) is discarded.
        finder = NearestNeighbors(n_neighbors=n_neighbors + 1)
        finder.fit(data)
        return finder.kneighbors(data, return_distance=False)[:, 1:]

    high_lists = _neighbor_lists(X)
    n_points = X.shape[0]
    ratios = []
    for X_new in X_news:
        low_lists = _neighbor_lists(X_new)
        kept = sum(
            len(np.intersect1d(row, low_lists[pos]))
            for pos, row in enumerate(high_lists)
        )
        ratios.append(kept / n_neighbors / n_points)
    return ratios


def eval_reduction(dataset_name, methods):
    '''Run all unsupervised and supervised metrics on saved embeddings.

    For each method, loads '../outputs/npys/{method}_{dataset_name}.npy'
    (skipping methods whose file is missing) and saves neighbor-kept ratios,
    random-triplet accuracies, KNN accuracies and SVM accuracies under
    '../results/'.
    '''
    print(f'Evaluating {dataset_name}')
    X, y = data_utils.data_prep(dataset_name)
    X = X.astype(float)
    for method in methods:
        print(method)
        embedding_path = f'../outputs/npys/{method}_{dataset_name}.npy'
        if not os.path.exists(embedding_path):
            print("Results Not Exist")
            continue
        X_lows = np.load(embedding_path, allow_pickle=True)

        # Unsupervised: neighbor kept ratio
        print('Nearest Neighbor Kept')
        nk_ratios = np.array(neighbor_kept_ratio_series_eval_fast(X, X_lows))
        np.save(f'../results/{dataset_name}_{method}_nkratios.npy', nk_ratios)

        # Unsupervised: random triplet accuracy
        print('Random Triplet Accuracy')
        rte_ratios = np.array([random_triplet_eval(X, X_low) for X_low in X_lows])
        np.save(f'../results/{dataset_name}_{method}_rteratios.npy', rte_ratios)

        # Supervised: KNN accuracy
        print('KNN Accuracy')
        knn_accs = [knn_eval(X_low, y) for X_low in X_lows]
        np.save(f'../results/{dataset_name}_{method}_knnaccs.npy', knn_accs)

        # Supervised: SVM accuracy
        print('SVM Accuracy')
        svm_accs = [svm_eval(X_low, y) for X_low in X_lows]
        np.save(f'../results/{dataset_name}_{method}_svmaccs.npy', svm_accs)
        print('---------')
    print('Finished Successfully')


def eval_reduction_unsupervised(dataset_name, methods):
    '''Run the unsupervised metrics (neighbor kept ratio, random triplet
    accuracy) on saved embeddings.

    For each method, loads '../outputs/npys/{method}_{dataset_name}.npy'
    (skipping methods whose file is missing) and saves the results under
    '../results/'.
    '''
    print(f'Evaluating {dataset_name}')
    X, y = data_utils.data_prep(dataset_name)
    # Cast to float like eval_reduction does: integer (e.g. uint8) data would
    # wrap around when random_triplet_eval subtracts rows.
    X = X.astype(float)
    for method in methods:
        # Check if the file exists
        print(method)
        if not os.path.exists(f'../outputs/npys/{method}_{dataset_name}.npy'):
            print("Results Not Exist")
            continue
        X_lows = np.load(
            f'../outputs/npys/{method}_{dataset_name}.npy', allow_pickle=True)
        # Unsupervised eval
        # NK Ratio
        print('Nearest Neighbor Kept')
        nk_ratios = np.array(neighbor_kept_ratio_series_eval_fast(X, X_lows))
        np.save(f'../results/{dataset_name}_{method}_nkratios.npy', nk_ratios)

        # Random Triplet Accuracy
        print('Random Triplet Accuracy')
        rte_ratios = np.array([random_triplet_eval(X, X_low) for X_low in X_lows])
        np.save(f'../results/{dataset_name}_{method}_rteratios.npy', rte_ratios)

        print('---------')
    print('Finished Successfully')


def eval_reduction_onepercent(dataset_name, methods):
    '''Compute the neighbor kept ratio with a neighborhood of 1% of the data.

    For each method, loads '../outputs/npys/{method}_{dataset_name}.npy'
    (skipping methods whose file is missing) and saves the ratios to
    '../results/{dataset_name}_{method}_nkratios_onepercent.npy'.
    '''
    print(f'Evaluating {dataset_name}')
    X, y = data_utils.data_prep(dataset_name)
    # 1% of the data, but at least one neighbor: for fewer than 100 points
    # int(N * 0.01) is 0 and the ratio would divide by zero. It does not
    # depend on the method, so compute it once.
    n_neighbors = max(1, int(X.shape[0] * 0.01))
    for method in methods:
        # Check if the file exists
        print(method)
        if not os.path.exists(f'../outputs/npys/{method}_{dataset_name}.npy'):
            print("Results Not Exist")
            continue
        X_lows = np.load(
            f'../outputs/npys/{method}_{dataset_name}.npy', allow_pickle=True)
        # Unsupervised eval
        # NK Ratio
        print('Nearest Neighbor Kept')
        nk_ratios = neighbor_kept_ratio_series_eval_fast(X, X_lows, n_neighbors=n_neighbors)
        nk_ratios = np.array(nk_ratios)
        np.save(f'../results/{dataset_name}_{method}_nkratios_onepercent.npy', nk_ratios)

    print('Finished Successfully')


def spearman_correlation_eval(X, X_new, n_points=1000, random_seed=100):
    '''Evaluate the global structure of an embedding via spearman correlation in
    distance matrix, following https://www.nature.com/articles/s41467-019-13056-x

    Input:
        X: A numpy array with the shape [N, p], the original data.
        X_new: A numpy array with the shape [N, k], the embedding.
        n_points: Number of points sampled without replacement; clamped to N
                  so small datasets do not raise.
        random_seed: Seed for the sampling, so repeated calls use the same points.
    Output:
        dist_high, dist_low: The flattened n x n distance matrices of the
                             sampled points in both spaces.
        corr, pval: The result of scipy.stats.spearmanr on the two.
    '''
    # Fix the random seed to ensure reproducibility
    rng = np.random.default_rng(seed=random_seed)
    dataset_size = X.shape[0]
    # Sampling more points than exist without replacement would raise.
    n_points = min(n_points, dataset_size)

    # Sample n_points points from the dataset randomly
    sample_index = rng.choice(np.arange(dataset_size), size=n_points, replace=False)

    # Generate the distance matrix in high dim and low dim.
    # NOTE: the full flattened matrices include the zero diagonal and both
    # symmetric halves of every pair.
    dist_high = distance_matrix(X[sample_index], X[sample_index]).reshape([-1])
    dist_low = distance_matrix(X_new[sample_index], X_new[sample_index]).reshape([-1])

    # Calculate the correlation
    corr, pval = scipy.stats.spearmanr(dist_high, dist_low)
    return dist_high, dist_low, corr, pval


def spearman_correlation_series_eval(X, X_news, n_points=1000, random_seed=100):
    '''Apply spearman_correlation_eval to every embedding in X_news.

    The same n_points / random_seed are used for each embedding, so all of
    them are compared on the same sampled points.
    Output:
        corrs, pvals, dist_highs, dist_lows: numpy arrays stacking the
        per-embedding results in the order of X_news.
    '''
    results = [
        spearman_correlation_eval(X, X_new, n_points, random_seed)
        for X_new in X_news
    ]
    dist_highs = np.array([res[0] for res in results])
    dist_lows = np.array([res[1] for res in results])
    corrs = np.array([res[2] for res in results])
    pvals = np.array([res[3] for res in results])
    return corrs, pvals, dist_highs, dist_lows


def centroid_knn_eval(X, X_new, y, k):
    '''Evaluate the global structure of an embedding via the KNC metric:
    neighborhood preservation for cluster centroids, following
    https://www.nature.com/articles/s41467-019-13056-x

    Input:
        X: A numpy array with the shape [N, p], the original data.
        X_new: A numpy array with the shape [N, k'], the embedding.
        y: Class labels with N entries (shape [N] or [N, 1]); any values
           accepted by np.unique, not only 0..C-1.
        k: Number of neighboring centroids compared per class.
    Output:
        The fraction of each centroid's k nearest centroids in the original
        space that are also among its k nearest centroids in the embedding.
    '''
    # Map labels to 0..num_cat-1 so they can index the centroid arrays even
    # when the raw labels are not contiguous or not zero-based.
    categories, label_idx = np.unique(np.ravel(y), return_inverse=True)
    label_idx = label_idx.reshape(-1)
    num_cat = len(categories)
    counts = np.bincount(label_idx, minlength=num_cat)[:, None]

    # Class centroids in both spaces (unbuffered scatter-add, then average).
    cluster_mean_ori = np.zeros((num_cat, X.shape[1]))
    cluster_mean_new = np.zeros((num_cat, X_new.shape[1]))
    np.add.at(cluster_mean_ori, label_idx, X)
    np.add.at(cluster_mean_new, label_idx, X_new)
    cluster_mean_ori /= counts
    cluster_mean_new /= counts

    # k nearest centroids in each space; column 0 is the centroid itself.
    nbrs = NearestNeighbors(n_neighbors=k + 1).fit(cluster_mean_ori)
    indices = nbrs.kneighbors(cluster_mean_ori, return_distance=False)[:, 1:]
    nbr_low = NearestNeighbors(n_neighbors=k + 1).fit(cluster_mean_new)
    indices_low = nbr_low.kneighbors(cluster_mean_new, return_distance=False)[:, 1:]

    # Each row holds distinct indices, so the intersection size counts how
    # many high-dim neighbors are kept in the embedding.
    both_nbrs = sum(
        np.intersect1d(high_row, low_row).size
        for high_row, low_row in zip(indices, indices_low)
    )
    return both_nbrs / (k * num_cat)


def centroid_knn_series_eval(X, X_news, y, k):
    '''Apply centroid_knn_eval to every embedding in X_news.

    Output:
        A numpy array with one KNC score per embedding, in order.
    '''
    scores = [centroid_knn_eval(X, X_new, y, k) for X_new in X_news]
    return np.array(scores)


def centroid_corr_eval(X, X_new, y, k):
    '''Evaluate the global structure of an embedding via the Spearman
    correlation between the pairwise distances of class centroids in the
    original space and in the embedding, in the spirit of
    https://www.nature.com/articles/s41467-019-13056-x

    Input:
        X: A numpy array with the shape [N, p], the original data.
        X_new: A numpy array with the shape [N, k'], the embedding.
        y: Class labels with N entries (shape [N] or [N, 1]); any values
           accepted by np.unique, not only 0..C-1.
        k: Unused; kept so the signature matches centroid_knn_eval.
    Output:
        dist_high, dist_low: The flattened C x C centroid distance matrices
                             (classes ordered as np.unique(y)).
        corr, pval: The result of scipy.stats.spearmanr on the two.
    '''
    # Map labels to 0..num_cat-1 so they can index the centroid arrays even
    # when the raw labels are not contiguous or not zero-based.
    categories, label_idx = np.unique(np.ravel(y), return_inverse=True)
    label_idx = label_idx.reshape(-1)
    num_cat = len(categories)
    counts = np.bincount(label_idx, minlength=num_cat)[:, None]

    # Class centroids in both spaces (unbuffered scatter-add, then average).
    cluster_mean_ori = np.zeros((num_cat, X.shape[1]))
    cluster_mean_new = np.zeros((num_cat, X_new.shape[1]))
    np.add.at(cluster_mean_ori, label_idx, X)
    np.add.at(cluster_mean_new, label_idx, X_new)
    cluster_mean_ori /= counts
    cluster_mean_new /= counts

    # Generate the distance matrix in high dim and low dim
    dist_high = distance_matrix(cluster_mean_ori, cluster_mean_ori).reshape([-1])
    dist_low = distance_matrix(cluster_mean_new, cluster_mean_new).reshape([-1])

    # Calculate the correlation
    corr, pval = scipy.stats.spearmanr(dist_high, dist_low)
    return dist_high, dist_low, corr, pval


