import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_mutual_info_score, roc_auc_score
import scanpy
import pynndescent
from joblib import Parallel, delayed
from sklearn.neighbors import LocalOutlierFactor
from sklearn.metrics import cluster 
import numpy as np
import operator
from matplotlib import pyplot as plt
from sklearn.ensemble import IsolationForest
from pyod.models.ecod import ECOD 

"""
Basic utilities for compressibility analysis.
"""

def get_distance_matrix(X):
    """
    Compute the pairwise Euclidean distance matrix for a set of points.

    Each row of the result is computed with a single vectorized NumPy call,
    instead of a Python-level double loop over all n^2 pairs.

    :param X: array-like of shape (n_samples, n_features). A 1-D array is
        treated as n_samples points with a single feature.
    :return: numpy array of shape (n_samples, n_samples) where entry (i, j) is
        the Euclidean distance between X[i] and X[j]; the diagonal is exactly 0.
    """
    X = np.asarray(X)
    if X.ndim == 1:
        # Keep supporting 1-D input (one feature per point).
        X = X[:, None]
    n = X.shape[0]
    D = np.zeros((n, n))
    for i in range(n):
        # Distances from point i to every point at once.
        D[i] = np.linalg.norm(X - X[i], axis=1)
    return D

def get_compression_matrix(data, pca_dim, reduce_dim=False, D_pre=None):
    """
    Return the pairwise compression matrix for a dataset.

    C[i, j] = D_pre[i, j] / D[i, j], where D_pre is the distance matrix before
    PCA and D is the distance matrix after projecting `data` onto `pca_dim`
    principal components. Pairs with zero post-PCA distance (always the
    diagonal) yield NaN (0/0) or inf.

    :param data: numpy array of shape (n_samples, n_features).
    :param pca_dim: dimensionality of PCA
    :param reduce_dim: when D_pre is not given, first reduce the data with PCA to
        at most 1000 components (bounded by min(n_samples, n_features)) before
        computing the pre-PCA distances, to speed up large datasets
    :param D_pre: precomputed pre-PCA distance matrix of data. If None, it is
        computed from data (and reduce_dim applies).

    :return: compression matrix of shape (n_samples, n_samples)
    """
    # Compute the pre-PCA distance matrix only when it was not supplied.
    if D_pre is None:
        pre_data = data
        if reduce_dim:
            # PCA cannot keep more components than min(n_samples, n_features).
            red_dim_n = min(1000, data.shape[0], data.shape[1])
            pre_data = PCA(n_components=red_dim_n).fit_transform(data)
        D_pre = get_distance_matrix(pre_data)

    # Compute the post-PCA distance matrix.
    post_data = PCA(n_components=pca_dim).fit_transform(data)
    D = get_distance_matrix(post_data)

    # Zero post-PCA distances (always the diagonal) are expected, so silence the
    # divide-by-zero / invalid-value warnings rather than emit one per call.
    with np.errstate(divide="ignore", invalid="ignore"):
        C = D_pre / D

    return C

def get_average_compression(C, cluster_sizes, k):
    """
    Compute the average intracluster and intercluster compression for each cluster.

    Rows/columns of C must be grouped by cluster (all points of cluster 0,
    then cluster 1, ...) in the same order as `cluster_sizes`. Diagonal entries
    of C are never read, so NaN self-ratios are harmless.

    :param C: numpy array of pairwise compression ratios (n_samples, n_samples)
    :param cluster_sizes: sequence of cluster sizes; order must match grouping of C
    :param k: number of clusters

    :return: (avg_intercluster_compression, avg_intracluster_compression), two
        arrays of length k. A cluster with fewer than two points has no
        intracluster pairs and gets NaN; a cluster containing every point has no
        intercluster pairs and gets NaN.
    """
    C = np.asarray(C)
    n = C.shape[0]

    avg_intracluster_compression = np.zeros(k)
    avg_intercluster_compression = np.zeros(k)

    start = 0  # index of the first point of cluster i
    for i in range(k):
        size = int(cluster_sizes[i])
        end = start + size
        rows = C[start:end]

        # Intracluster: the cluster's diagonal block, excluding self-pairs.
        off_diagonal = ~np.eye(size, dtype=bool)
        intra_sum = rows[:, start:end][off_diagonal].sum()

        # Intercluster: the same rows, restricted to columns outside the cluster.
        inter_sum = rows[:, :start].sum() + rows[:, end:].sum()

        # Divide by the number of ordered pairs; empty pair sets give NaN.
        with np.errstate(divide="ignore", invalid="ignore"):
            avg_intracluster_compression[i] = intra_sum / (size * (size - 1))
            avg_intercluster_compression[i] = inter_sum / (size * (n - size))

        start = end

    return avg_intercluster_compression, avg_intracluster_compression

def compression_grouping(C, cluster_sizes):
    """
    Return all pairwise compression ratios tagged by same vs. different cluster.

    Used to generate the compression curve. Points of C must be grouped by
    cluster in the order given by `cluster_sizes`; clusters of size 0 are
    allowed and simply contribute no points.

    :param C: numpy array of pairwise compression ratios (n_samples, n_samples)
    :param cluster_sizes: sequence of cluster sizes, in the grouping order of C
    :return: numpy array of shape (n_samples * (n_samples - 1) / 2, 2). Column 0
        is C[i, j] for every pair i < j; column 1 is 1 if both points are in the
        same cluster and 0 otherwise. Rows are sorted by column 0 ascending.
    """
    C = np.asarray(C)
    n = len(C)

    # Cluster id of every point, derived from the (possibly zero) cluster sizes.
    point_cluster = np.repeat(np.arange(len(cluster_sizes)), cluster_sizes)

    res = np.zeros((n * (n - 1) // 2, 2))
    idx = 0
    for i in range(n):
        # All pairs (i, j) with j > i, filled one row-slice at a time.
        count = n - i - 1
        res[idx:idx + count, 0] = C[i, i + 1:]
        res[idx:idx + count, 1] = point_cluster[i + 1:n] == point_cluster[i]
        idx += count

    return res[np.argsort(res[:, 0])]

"""
Simulation utilities.
"""

def generate_points(x, d, n, k, noise_lower, noise_upper, p=0.3, equal=True):
    """
    Generate points around k cluster centers using symmetric Bernoulli-style noise.

    For each point, a magnitude m is drawn uniformly from the cluster's noise
    range; each coordinate is then independently shifted by +m or -m (each with
    probability (1 - p) / 2) or left unchanged (probability p).

    :param x: (k, d) array of cluster centers
    :param d: number of dimensions
    :param n: sequence of per-cluster point counts; n[i] points are generated for cluster i
    :param k: number of clusters
    :param noise_lower: Lower bound of noise (per point).
    :param noise_upper: Upper bound of noise (per point).
    :param p: probability that a coordinate receives no noise
    :param equal: if False, cluster 0 uses twice the noise range of the others

    :return: (sum(n[:k]), d) array of points, grouped by cluster in order 0..k-1
    """
    rows = []
    for i in range(k):
        low, high = noise_lower, noise_upper
        if not equal and i == 0:
            # Cluster 0 is the elevated-noise cluster.
            low, high = noise_lower * 2, noise_upper * 2

        for _ in range(n[i]):
            m = np.random.uniform(low, high)
            noise = np.random.choice([m, 0, -m], size=(1, d), p=[(1-p)/2., p, (1-p)/2.])
            rows.append(x[i] + noise)

    # Stack once at the end: np.vstack inside the loop copies Y every time (quadratic).
    # The empty float block fixes the output dtype/shape even when no points are generated.
    return np.vstack([np.zeros((0, d))] + rows)

def generate_mixture_outliers(d, o, num_clusters, Y, labels, centers, noise_lower, noise_upper, p=0.3): 

    """
    Generate outliers as random convex combinations of cluster centers plus noise.

    For each outlier, one weight is drawn from [0.5, 1) and the remainder is
    split among the other clusters so that all weights are non-negative and sum
    to 1. The weights are shuffled across clusters, the weighted average of the
    centers is taken, and symmetric Bernoulli-style noise is added. The outlier
    is labelled with the cluster of largest weight (ties broken at random).

    :param d: number of dimensions
    :param o: number of outliers
    :param num_clusters: number of clusters (number of centers that are mixed)
    :param Y: (n, d) array of points
    :param labels: array of cluster labels
    :param centers: (k, d) array of cluster centers
    :param noise_lower: Lower bound of noise (per point).
    :param noise_upper: Upper bound of noise (per point). 
    :param p: probability that a coordinate receives no noise

    :return: (Y, labels) with the o outliers and their labels appended. When
        o == 0 the inputs are returned unchanged.
    """
    centers = np.asarray(centers)
    new_points = []
    new_labels = []

    for _ in range(o):
        # Random mixture weights: the first one dominates (drawn from [0.5, 1)).
        alpha = np.zeros(num_clusters)
        alpha[0] = np.random.uniform(0.5, 1)
        # Each further weight is bounded by what is still unassigned, so no
        # weight can become negative.
        remaining = 1 - alpha[0]
        for j in range(1, num_clusters - 1):
            alpha[j] = np.random.uniform(0, remaining)
            remaining -= alpha[j]
        # The rest goes to the last slot (with one cluster this tops alpha[0] up to 1).
        alpha[num_clusters - 1] += remaining

        # Shuffle weights against cluster centers
        np.random.shuffle(alpha)
        np.testing.assert_approx_equal(np.sum(alpha), 1)

        # Weighted average of the cluster centers
        point = alpha @ centers[:num_clusters]

        # Add noise
        m = np.random.uniform(noise_lower, noise_upper)
        noise = np.random.choice([m, 0, -m], size=(1, d), p=[(1-p)/2., p, (1-p)/2.])
        point = point + noise

        # Label with the cluster of largest weight, breaking ties at random.
        candidates = np.flatnonzero(alpha == np.max(alpha))
        new_labels.append(np.random.choice(candidates))
        new_points.append(point)

    if not new_points:
        return Y, labels

    # Append everything at once instead of re-stacking Y for each outlier.
    Y = np.vstack([Y] + new_points)
    labels = np.append(labels, new_labels)
    return Y, labels

def generate_variance_outliers(o, x, c, Y, noise_lower, noise_upper, p=0.3, equal=True): 
    """
    Generate o outlier points and append them to Y.

    Each outlier picks a uniformly random cluster center and adds symmetric
    Bernoulli-style noise whose magnitude is scaled by c (or by c[center] when
    equal is False).

    :param o: number of outliers
    :param x: (k, d) array of cluster centers
    :param c: scaling factor of noise. Array of size k if equal is False, else a single value.
    :param Y: (n, d) array of points
    :param noise_lower: Lower bound of noise (per point).
    :param noise_upper: Upper bound of noise (per point).
    :param p: probability that a coordinate receives no noise
    :param equal: if True, every center uses scale c; otherwise center i uses c[i]

    :return: (n + o, d) array of points (Y itself when o == 0)
    """
    new_points = []
    for _ in range(o):
        center = np.random.randint(0, len(x))
        m = np.random.uniform(noise_lower, noise_upper)

        # Factor by which to multiply the noise magnitude
        mult = c if equal else c[center]

        noise = np.random.choice([m * mult, 0, -m * mult], size=(1, len(x[0])), p=[(1-p)/2., p, (1-p)/2.])
        new_points.append(x[center] + noise)

    if not new_points:
        return Y

    # Stack once: np.vstack inside the loop copies Y for every outlier (quadratic).
    return np.vstack([Y] + new_points)

def _scores_in_point_order(ranked, sign=1.0):
    """
    Undo a detector's ranking: return its scores re-ordered by point index.

    :param ranked: sequence of (score, index) pairs, as returned by variance_list,
        lof, KNN_dist, isolation_forest and ecod
    :param sign: multiplier applied to the scores (use -1 to flip their direction)
    :return: 1-D numpy array whose i-th entry is sign * (score of point i)
    """
    ranked = np.asarray(ranked, dtype=float)
    order = np.argsort(ranked[:, 1], kind="stable")
    return ranked[order, 0] * sign


def detection_auroc(pca_dim, non_outliers, total_points, Y): 

    """
    Compute the AUROC of each outlier detection approach on a dataset.

    Points 0..non_outliers-1 are inliers (label 1) and the rest are outliers
    (label 0). Every score is therefore oriented so that higher means "more
    normal"; scores where larger means "more outlying" (KNN distance, ECOD)
    are negated.

    :param pca_dim: dimensionality of PCA
    :param non_outliers: number of non-outliers
    :param total_points: total number of points
    :param Y: (n, d) array of points; outliers are appended at the end

    :return: tuple of AUROCs in the order (compression, LOF, PCA+LOF, KNN-dist,
        PCA+KNN-dist, isolation forest, PCA+isolation forest, ECOD, PCA+ECOD)
    """

    # Outliers are appended to the end of the matrix.
    true = np.ones(total_points)
    true[non_outliers:total_points] = 0

    # Compressibility approach: variance of compressibility, used as-is.
    C = get_compression_matrix(Y, pca_dim, False)
    compression_res = roc_auc_score(true, _scores_in_point_order(variance_list(C)))

    Y_pca = PCA(n_components=pca_dim).fit_transform(Y)

    # LOF / PCA + LOF: negative_outlier_factor_ used as-is.
    lof_res = roc_auc_score(true, _scores_in_point_order(lof(Y, 10)))
    pca_lof_res = roc_auc_score(true, _scores_in_point_order(lof(Y_pca, 10)))

    # KNN-dist / PCA + KNN-dist: large distance = more likely outlier, so negate.
    knn_dist_res = roc_auc_score(true, _scores_in_point_order(KNN_dist(Y, 20), -1))
    pca_knn_dist_res = roc_auc_score(true, _scores_in_point_order(KNN_dist(Y_pca, 20), -1))

    # Isolation forest / PCA + isolation forest: decision_function used as-is.
    isolation_forest_res = roc_auc_score(true, _scores_in_point_order(isolation_forest(Y)))
    isolation_forest_pca_res = roc_auc_score(true, _scores_in_point_order(isolation_forest(Y_pca)))

    # ECOD / PCA + ECOD: large score = more likely outlier, so negate.
    ecod_res = roc_auc_score(true, _scores_in_point_order(ecod(Y), -1))
    ecod_pca_res = roc_auc_score(true, _scores_in_point_order(ecod(Y_pca), -1))

    return (compression_res, lof_res, pca_lof_res, knn_dist_res, pca_knn_dist_res,
            isolation_forest_res, isolation_forest_pca_res, ecod_res, ecod_pca_res)

"""
Real-world experiment utilities.
"""

def parse_h5ad(anndf, cluster_label_obs): 
    """
    Return the data, cluster sizes, and cluster labels from an AnnData object,
    with rows grouped by cluster.

    Rows are stably sorted by the integer code of their categorical cluster
    label: all cells of category 0 first, then category 1, and so on. Within a
    cluster, the original row order is kept.

    :param anndf: AnnData object (uses .to_df() and .obs)
    :param cluster_label_obs: name of a categorical column in anndf.obs

    :return: data (numpy array, rows sorted by cluster), cluster_sizes (list
        with one count per category, including 0 for unused categories),
        cluster_labels (numpy array of sorted category codes)
    :raises ValueError: if any cell has a missing cluster label
    """
    data = anndf.to_df().to_numpy()
    categorical = anndf.obs[cluster_label_obs].values
    codes = np.asarray(categorical.codes)

    # pandas encodes missing categorical values as code -1; such cells cannot
    # be assigned to any cluster.
    if np.any(codes < 0):
        raise ValueError(f"missing values in cluster label column '{cluster_label_obs}'")

    # Stable sort keeps the original order inside each cluster.
    order = np.argsort(codes, kind="stable")
    cluster_sizes = np.bincount(codes, minlength=len(categorical.categories)).tolist()

    return data[order], cluster_sizes, codes[order]

def initiate(dsname, dspath, idx):
    """
    Load one dataset and return its log1p-transformed expression matrix.

    The file at ``dspath[idx]`` is read with scanpy and split by its
    ``'phenoid'`` cluster annotation via parse_h5ad. The data is then
    transformed with log(1 + x). `dsname` is accepted but not used.

    :param dsname: name of the dataset (unused)
    :param dspath: indexable collection of dataset paths
    :param idx: key/index of the dataset in dspath

    :return: log1p-transformed data, cluster sizes, cluster labels
    """
    adata = scanpy.read_h5ad(dspath[idx])
    matrix, sizes, cluster_labels = parse_h5ad(adata, 'phenoid')
    return np.log1p(matrix), sizes, cluster_labels

def purity_score(y_true, y_pred):
    """
    Compute the purity of a clustering.

    Purity is the fraction of points that belong to the dominant true class of
    their predicted cluster.

    :param y_true: true cluster labels
    :param y_pred: predicted cluster labels

    :return: purity in [0, 1]
    """
    # Rows: true classes, columns: predicted clusters.
    table = cluster.contingency_matrix(y_true, y_pred)
    # For each predicted cluster, count the points of its most frequent true class.
    dominant_counts = table.max(axis=0)
    return dominant_counts.sum() / table.sum()

def kmeans_nmi_purity(data, n_clusters, labels): 
    """
    Run k-means once and score the result against reference labels.

    Despite the function name, the first score is the *adjusted* mutual
    information (sklearn's adjusted_mutual_info_score), not plain NMI. KMeans
    is created without a random_state, so repeated calls give different runs.

    :param data: numpy array of shape (n_samples, n_features)
    :param n_clusters: number of clusters
    :param labels: reference cluster labels

    :return: (adjusted mutual information, purity) of the k-means clustering
    """
    predicted = KMeans(n_clusters=n_clusters, n_init=10).fit(data).labels_
    ami = adjusted_mutual_info_score(labels, predicted)
    purity = purity_score(labels, predicted)
    return ami, purity

def variance_list(C): 

    """
    Compute, for each point, the variance of its compression ratios to all other points.

    The diagonal of C is ignored. NaN entries (e.g. the 0/0 diagonal of a
    compression matrix) are treated as 0 and infinities are replaced by the
    largest finite float, as done by ``np.nan_to_num``. The caller's C is not
    modified.

    :param C: square compressibility matrix (n_samples, n_samples)
    :return: list of (variance, index in data) tuples sorted by variance in ascending order
    """

    # Clean a copy so the caller's matrix is left untouched.
    C = np.nan_to_num(np.asarray(C, dtype=float), nan=0.0)
    n = len(C)

    # Row-wise variance with the diagonal masked out, in one vectorized call.
    off_diagonal = ~np.eye(n, dtype=bool)
    comp_var = np.var(C, axis=1, where=off_diagonal)

    # Pair each variance with its point index and sort (stable) by variance.
    return sorted(zip(comp_var, range(n)), key=lambda pair: pair[0])

def lof(PX, ng=20):

    """
    Score every point of PX with the Local Outlier Factor and rank them.

    :param PX: numpy array of shape (n_samples, n_features)
    :param ng: number of neighbors to use for LOF
    :return: float array of rows (negative outlier factor, index in PX), sorted
        by the score in ascending order (most negative first)
    """
    detector = LocalOutlierFactor(n_neighbors=ng)
    detector.fit_predict(PX)
    # negative_outlier_factor_: close to -1 for inliers, more negative for outliers.
    scores = detector.negative_outlier_factor_

    # Stable ordering of point indices by their score.
    ranking = sorted(range(PX.shape[0]), key=lambda point: scores[point])
    return np.array([[scores[point], point] for point in ranking], dtype=float)

def KNN_dist(PX,kchoice):

    """
    Score every point of PX by the distance to its neighbours and rank them.

    Neighbours come from an approximate pynndescent index. For each point, the
    Euclidean distances to its kchoice neighbours are computed and the
    smallest one is kept as the score.

    NOTE(review): because the minimum is taken, the score is the distance to
    the nearest neighbour, and kchoice only sets how many candidates are
    examined. Confirm this (rather than the k-th distance) is intended.

    :param PX: numpy array of shape (n_samples, n_features)
    :param kchoice: number of neighbors to use for KNN
    :return: float array of rows (KNN distance, index in PX), sorted by distance
        in descending order
    """
    nn_index = pynndescent.NNDescent(PX)
    nn_index.prepare()
    # One extra neighbour is requested and the first column is dropped; the
    # first hit is assumed to be the query point itself.
    neighbor_ids = np.array(nn_index.query(PX, k=kchoice + 1)[0][:, 1:])

    rows = []
    for point in range(PX.shape[0]):
        nearest = min(np.linalg.norm(PX[point, :] - PX[neighbor_ids[point, j], :]) for j in range(kchoice))
        rows.append([nearest, point])

    rows.sort(key=operator.itemgetter(0), reverse=True)
    return np.array(rows, dtype=float)

def isolation_forest(data, n_estimators=100, contamination=0.1):
    """
    Fit an isolation forest to `data` and rank its points by decision score.

    :param data: numpy array of shape (n_samples, n_features)
    :param n_estimators: number of trees in the forest
    :param contamination: proportion of outliers in the dataset

    :return: numpy array of rows (decision_function score, index in data),
        sorted by score in ascending order (for scikit-learn's IsolationForest,
        lower scores are more abnormal)
    """
    model = IsolationForest(n_estimators=n_estimators, contamination=contamination)
    model.fit(data)
    scores = model.decision_function(data)

    ranked = sorted(zip(scores, range(len(scores))), key=lambda pair: pair[0])
    return np.array(ranked)

def ecod(PX):

    """
    Score every point of PX with ECOD and rank them.

    :param PX: numpy array of shape (n_samples, n_features)
    :return: float array of rows (ECOD score, index in PX), sorted by score in
        descending order (higher score is treated as more likely outlier)
    """
    detector = ECOD()
    detector.fit(PX)
    scores = detector.decision_scores_

    # Stable descending ordering of point indices by ECOD score.
    ranking = sorted(range(PX.shape[0]), key=lambda point: scores[point], reverse=True)
    return np.array([[scores[point], point] for point in ranking], dtype=float)

def remove_pca_kmeans(data, cluster_sizes, labels, pca_dim, removal_rate=0.2, C=None, method='compression'):
    
    """
    Cluster the data with k-means after removing likely outliers and applying PCA.

    The ``int(removal_rate * sum(cluster_sizes))`` points ranked as most likely
    outliers by `method` are removed. The remaining data is projected onto
    `pca_dim` principal components (random_state=1). k-means with
    ``len(cluster_sizes)`` clusters is then run 50 times (5 parallel jobs).

    :param data: numpy array of shape (n_samples, n_features), must be sorted by cluster
    :param cluster_sizes: array of cluster sizes
    :param labels: numpy array of cluster labels, must be sorted by cluster
    :param pca_dim: dimensionality of PCA
    :param removal_rate: fraction of points to remove
    :param C: compressibility matrix (required when method == 'compression')
    :param method: outlier ranking method. One of 'compression', 'lof',
        'lof_pca', 'knn', 'knn_pca', 'isolation_forest',
        'isolation_forest_pca', 'ecod', 'ecod_pca'.

    :return: (ami, purity): two arrays of length 50 holding the adjusted mutual
        information and the purity of each k-means trial on the cleaned data
    :raises ValueError: if method is unknown, or C is missing for 'compression'
    """

    # NOTE(review): zero-size entries in cluster_sizes still count toward n_clusters.
    n_clusters = len(cluster_sizes)

    def project(X):
        # Fixed random_state keeps every projection in this call reproducible.
        return PCA(n_components=pca_dim, random_state=1).fit_transform(X)

    # (outlier metric, index in data) pairs, ordered so that the points each
    # method considers most likely to be outliers come first.
    if method == 'compression':
        if C is None:
            raise ValueError('C must be provided for compression method')
        combined = variance_list(C)
    elif method == 'lof':
        combined = lof(data, 10)
    elif method == 'lof_pca':
        combined = lof(project(data), 10)
    elif method == 'knn':
        combined = KNN_dist(data, 20)
    elif method == 'knn_pca':
        combined = KNN_dist(project(data), 20)
    elif method == 'isolation_forest':
        combined = isolation_forest(data)
    elif method == 'isolation_forest_pca':
        combined = isolation_forest(project(data), contamination=removal_rate)
    elif method == 'ecod':
        combined = ecod(data)
    elif method == 'ecod_pca':
        combined = ecod(project(data))
    else:
        raise ValueError('Invalid outlier removal method')

    # Remove the top-ranked points from data and labels
    num_to_remove = int(removal_rate * sum(cluster_sizes))
    mask = np.ones(len(data), dtype=bool)
    mask[[int(combined[i][1]) for i in range(num_to_remove)]] = False
    data_removed = data[mask, :]
    labels_removed = labels[mask]

    # Run k-means 50 times (5 jobs in parallel) on the post-PCA data; all trials are returned.
    data_removed_pca = project(data_removed)
    trials = Parallel(n_jobs=5)(
        delayed(kmeans_nmi_purity)(data_removed_pca, n_clusters, labels_removed) for _ in range(50)
    )

    ami = np.array([trial[0] for trial in trials])
    purity = np.array([trial[1] for trial in trials])

    return ami, purity

"""
Visualization and miscellaneous utilities.
"""

def multi_bar_graph_error_bars(xlabels,data,title,ylabel,ncols=3,legend_loc='upper right', fig_width=5, fig_height=5):
    """
    Plot a grouped bar chart of medians with min-to-max error bars.

    For every key of `data`, one bar per x label is drawn at the median of the
    corresponding measurement. An asymmetric error bar spans from that
    measurement's minimum to its maximum.

    :param xlabels: list of x labels (one group of bars per label)
    :param data: dictionary mapping a series name to a sequence (one entry per
        x label) of measurement collections
    :param title: title of graph
    :param ylabel: y axis label
    :param ncols: number of bars per group; expected to equal the number of keys in data
    :param legend_loc: location of the legend, or 'outside' to place it right of the axes
    :param fig_width: figure width in inches
    :param fig_height: figure height in inches
    """

    x = np.arange(len(xlabels))  # the label locations
    width = 0.75 / ncols  # the width of the bars
    colors = ['darkorange', 'red', 'brown', 'gray', 'darkturquoise', 'pink', 'olive', 'purple', 'blue']

    fig, ax = plt.subplots(layout='constrained')

    for multiplier, (attribute, measurement) in enumerate(data.items()):
        # Use the offset to move the bar within its group
        offset = width * multiplier

        # Median bar height, with error bars reaching down to the min and up to the max
        med = np.array([np.median(m) for m in measurement])
        lower = np.array([np.min(m) for m in measurement])
        upper = np.array([np.max(m) for m in measurement])
        yerr = [np.abs(med - lower), np.abs(upper - med)]

        # Draw each bar once; colors wrap around if there are more series than colors.
        ax.bar(x + offset, med, width, label=attribute, color=colors[multiplier % len(colors)])
        ax.errorbar(x + offset, med, yerr=yerr, fmt='none', ecolor='black', capsize=2)

    # Add text for labels, title and custom x-axis tick labels, etc.
    ax.set_title(title)
    ax.set_ylabel(ylabel, fontsize="15")
    # Centre each tick under its group of ncols bars.
    ax.set_xticks(x + width * (ncols - 1) / 2, xlabels, fontsize="15")

    # One legend entry per series (the error bars carry no label).
    if legend_loc == 'outside': 
        ax.legend(loc='upper right', bbox_to_anchor=(1.55, 1), fontsize="15")
    else:
        ax.legend(loc=legend_loc, borderaxespad=0.)

    # Set figure size
    fig.set_size_inches((fig_width, fig_height))

    # Set the y-axis limits with 15% headroom around the extreme values
    min_val = min(np.min(m) for measurement in data.values() for m in measurement)
    max_val = max(np.max(m) for measurement in data.values() for m in measurement)
    ax.set_ylim(min(-0.005, min_val - (abs(min_val) * 0.15)), max_val + (abs(max_val) * 0.15))

    plt.show()

def multi_bar_graph(xlabels,data,title,ylabel,ncols=3,legend_loc='upper right', fig_width=5, fig_height=5): 

    """
    Plot a grouped bar chart with one bar per key of `data` in each group.

    :param xlabels: list of dataset names (one group of bars per name)
    :param data: dictionary mapping a series name to a sequence of values (one per x label)
    :param title: title of graph
    :param ylabel: y axis label
    :param ncols: number of bars per group; expected to equal the number of keys in data
    :param legend_loc: location of the legend (anchored at (1.53, 1), right of the axes)
    :param fig_width: figure width in inches
    :param fig_height: figure height in inches
    """

    x = np.arange(len(xlabels))  # the label locations
    width = 1 / ncols  # the slot width per bar
    # Bars fill 60% of their slot so bars in a group sit closer together.
    bar_width = width - (width * 0.4)
    colors = ['darkorange', 'red', 'brown', 'gray', 'darkturquoise', 'pink', 'olive', 'purple', 'blue']

    fig, ax = plt.subplots(layout='constrained')

    for multiplier, (attribute, measurement) in enumerate(data.items()):
        # Colors wrap around if there are more series than colors.
        ax.bar(x + bar_width * multiplier, measurement, bar_width, label=attribute, color=colors[multiplier % len(colors)])

    # Add text for labels, title and custom x-axis tick labels, etc.
    ax.set_title(title, fontsize="12")
    ax.set_ylabel(ylabel, fontsize="12")
    # Centre each tick under its group of ncols bars (bars are bar_width apart).
    ax.set_xticks(x + bar_width * (ncols - 1) / 2, xlabels, fontsize="12")

    # Create the legend
    ax.legend(loc=legend_loc, borderaxespad=0., fontsize="12", bbox_to_anchor=(1.53, 1))

    # Set the y-axis limits with 15% headroom around the extreme values.
    all_values = [value for values in data.values() for value in values]
    min_val = min(all_values)
    max_val = max(all_values)
    # NOTE(review): the upper limit never drops below 0.23; this looks tuned for a specific figure.
    ax.set_ylim(min(-0.005, min_val - (abs(min_val) * 0.15)), max(0.23, max_val + (abs(max_val) * 0.15)))

    fig.set_size_inches((fig_width, fig_height))

    plt.show()