import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.decomposition import PCA
from scipy.linalg import orthogonal_procrustes
from sklearn.neighbors import kneighbors_graph
from scipy.sparse import csr_matrix, diags
from scipy.sparse.linalg import eigsh
from scipy.linalg import eigh
from sklearn import datasets
import itertools
import matplotlib.pyplot as plt
import time
import pandas as pd
import os
from sklearn.manifold import SpectralEmbedding
import pacmap
from sklearn.neighbors import NearestNeighbors


import ot 


def pca_embedding(X, k=None):
    """Project the rows of ``X`` onto their leading principal components.

    Parameters
    ----------
    X : array of shape (n, d)
    k : int or None
        Number of components to keep. Any falsy value (``None`` or ``0``)
        keeps all ``d`` components.

    Returns
    -------
    ndarray of shape (n, k): PCA scores of the mean-centered data.
    """
    centered = X - X.mean(axis=0)
    n_components = k if k else X.shape[1]
    model = PCA(n_components=n_components).fit(centered)
    return model.transform(centered)


def pca_embedding_with_eigvals_and_evecs(X, k=None):
    """PCA scores of ``X`` together with the fitted PCA spectrum.

    Parameters
    ----------
    X : array of shape (n, d)
    k : int or None
        Number of components; falsy keeps all ``d``.

    Returns
    -------
    AX : (n, k) array of principal-component scores of the centered data.
    eigvals : (k,) array, ``PCA.explained_variance_`` of each component.
    evecs : (d, k) array whose columns are the principal axes
        (``PCA.components_`` transposed).
    """
    centered = X - X.mean(axis=0)
    n_components = k if k else X.shape[1]
    model = PCA(n_components=n_components)
    model.fit(centered)
    scores = model.transform(centered)
    return scores, model.explained_variance_, model.components_.T

def diffusion_embedding2(X, k=5, epsilon=None, t=1, symmetrize=True, dtype=np.float32):
    """Sparse (kNN-graph) diffusion-map embedding of the rows of ``X``.

    The steps are:

    1. Build a k-nearest-neighbour distance graph with ``ceil(d * log n)``
       neighbours (capped at ``n - 1``).
    2. Turn it into a Gaussian kernel ``W``.
    3. Normalise it symmetrically as ``D^{-1/2} W D^{-1/2}``.
    4. Drop the leading (trivial) eigenvector and scale the next ``k``
       eigenvectors by ``lambda ** t``.

    Parameters
    ----------
    X : array of shape (n, d)
    k : int
        Requested number of diffusion coordinates (clamped to ``n - 1``).
    epsilon : float or None
        Kernel bandwidth. Defaults to the median squared kNN distance, or 1.0
        if that median is not positive or the graph has no edges.
    t : float
        Diffusion time, used as the exponent on the eigenvalues.
    symmetrize : bool
        Average ``W`` with its transpose. Without it the normalised matrix is
        generally not symmetric, while the eigen-solvers below assume symmetry.
    dtype : numpy floating type, or anything ``np.dtype`` accepts
        Working precision (e.g. ``np.float32`` or ``"float64"``).

    Returns
    -------
    ndarray of shape (n, min(k, n - 1)).
    """
    dtype = np.dtype(dtype).type  # accept np.float32, "float32", np.dtype(...)
    Xc = (X - X.mean(axis=0)).astype(dtype, copy=False)
    n, d = Xc.shape
    k = min(k, n - 1)

    # Neighbourhood size grows like d * log(n), capped by the number of other points.
    n_neighbors = min(int(np.ceil(d * np.log(n))), n - 1)

    A = kneighbors_graph(
        Xc,
        n_neighbors=n_neighbors,
        mode='distance',
        include_self=False,
        metric='euclidean',
        n_jobs=1,
    ).tocsr().astype(dtype)

    if epsilon is None:
        # Median squared neighbour distance. Fall back to 1.0 if it is zero
        # or the graph has no edges.
        eps_base = np.median(A.data ** 2) if A.nnz > 0 else 1.0
        epsilon = dtype(eps_base if eps_base > 0 else 1.0)

    W = A.copy()
    W.data = np.exp(-(W.data ** 2) / epsilon).astype(dtype, copy=False)

    if symmetrize:
        W = (W + W.T) * dtype(0.5)

    degree = np.asarray(W.sum(axis=1)).ravel()
    Dinv = diags((1.0 / np.sqrt(degree + 1e-12)).astype(dtype))
    P_tilde = Dinv @ W @ Dinv

    # ARPACK (eigsh) needs the number of requested eigenpairs to be < n. We ask
    # for k + 1 (the leading, trivial pair is dropped afterwards), so small
    # inputs use a dense decomposition instead of crashing.
    # NOTE(review): which='LM' ranks by magnitude. A kNN Gaussian kernel is not
    # guaranteed PSD, so large negative eigenvalues could be selected; consider 'LA'.
    if k + 1 < n:
        vals, vecs = eigsh(P_tilde, k=k + 1, which='LM')
    else:
        vals, vecs = eigh(P_tilde.toarray())
    order = np.argsort(vals)[::-1]
    vals = vals[order]
    vecs = vecs[:, order]

    lambdas = vals[1:k + 1]
    phis = vecs[:, 1:k + 1]
    return phis * (lambdas ** t)

def diffusion_embedding(X, k=5, epsilon=None, t=1):
    """Dense diffusion-map embedding of the rows of ``X``.

    The method builds the full Gaussian kernel
    ``K = exp(-||x_i - x_j||^2 / epsilon)`` and normalises it symmetrically as
    ``D^{-1/2} K D^{-1/2}``. It then drops the leading (trivial) eigenvector
    and scales the next ``k`` eigenvectors by ``lambda ** t``.

    Parameters
    ----------
    X : array of shape (n, d)
    k : int
        Number of diffusion coordinates. If ``k >= n - 1``, a full dense
        eigendecomposition is used and at most ``n - 1`` coordinates come back.
    epsilon : float or None
        Kernel bandwidth. Defaults to the median pairwise squared distance,
        falling back to 1.0 when that median is not positive. This happens,
        e.g., when all points coincide, and would otherwise divide by zero
        and yield NaNs.
    t : float
        Diffusion time, the exponent applied to the eigenvalues.

    Returns
    -------
    ndarray of shape (n, min(k, n - 1)).
    """
    Xc = X - X.mean(axis=0)
    n = Xc.shape[0]
    D_sq = cdist(Xc, Xc, 'sqeuclidean')
    if epsilon is None:
        epsilon = np.median(D_sq)
        if not epsilon > 0:
            epsilon = 1.0
    K = np.exp(-D_sq / epsilon)
    d_inv_sqrt = 1.0 / np.sqrt(K.sum(axis=1))
    # Same as diag(d^-1/2) @ K @ diag(d^-1/2), without two dense O(n^3) matmuls.
    P_tilde = d_inv_sqrt[:, None] * K * d_inv_sqrt[None, :]

    if k >= n - 1:
        # eigsh cannot return n eigenpairs; use the full dense decomposition.
        vals, vecs = eigh(P_tilde)
    else:
        vals, vecs = eigsh(P_tilde, k=k + 1, which='LM')
    order = np.argsort(vals)[::-1]
    vals = vals[order]
    vecs = vecs[:, order]

    # Index 0 is the trivial eigenvector (eigenvalue 1); skip it.
    lambdas = vals[1:k + 1]
    phis = vecs[:, 1:k + 1]
    return phis * (lambdas ** t)



def get_sorted_columns(A):
    """Sort every column of an (n, k) array independently.

    Returns a list of ``k`` one-dimensional arrays. Each holds the ``n``
    values of one column of ``A`` in ascending order.
    """
    return [np.sort(column) for column in A.T]

from ot.lp import emd_1d_sorted

# Calls POT's lower-level emd_1d_sorted directly: our inputs are already sorted, so this skips the re-sort done by ot.emd2_1d.
def emd2_1d_nosort(x_sorted, y_sorted, a_sorted=None, b_sorted=None, metric="sqeuclidean", p=2.0):
    """Exact 1-D optimal-transport cost between two already-sorted samples.

    Parameters
    ----------
    x_sorted, y_sorted : 1-D arrays, ascending
        Support points of the two distributions.
    a_sorted, b_sorted : 1-D arrays or None
        Weights aligned with ``x_sorted`` / ``y_sorted``. ``None`` means
        uniform weights.
    metric, p : forwarded to POT's ``emd_1d_sorted`` ground cost.

    Returns
    -------
    The transport cost reported by ``emd_1d_sorted``. The transport plan and
    its support indices are discarded.
    """
    if a_sorted is None:
        a_sorted = np.ones(len(x_sorted), dtype=float) / len(x_sorted)
    if b_sorted is None:
        b_sorted = np.ones(len(y_sorted), dtype=float) / len(y_sorted)

    # The compiled solver expects float64 buffers.
    as_f64 = [arr.astype(np.float64) for arr in (a_sorted, b_sorted, x_sorted, y_sorted)]
    _plan, _support, cost = emd_1d_sorted(*as_f64, metric=metric, p=p)
    return cost


def compute_1d_wasserstein_sorted(u_sorted, v_sorted):
    """Squared 2-Wasserstein distance W_2^2 between two sorted 1-D samples.

    Both samples carry uniform weights. For equal sizes, the optimal 1-D
    coupling pairs the i-th smallest values, so the cost is the mean squared
    difference. For unequal sizes, the exact solver ``emd2_1d_nosort`` is used.
    """
    n, m = len(u_sorted), len(v_sorted)
    if n == m:
        return np.mean((u_sorted - v_sorted) ** 2)
    a = np.ones(n) / n
    b = np.ones(m) / m
    return emd2_1d_nosort(u_sorted, v_sorted, a, b, metric="sqeuclidean", p=2.0)

def assignment_sliced_wasserstein(X, Y, AX, BY, eigvals_X=None, eigvals_Y=None, eigvecs_X=None, eigvecs_Y=None, print_pairings=False):
    """RISWIE distance between two embeddings (sign-invariant, assignment-matched).

    For every pair of embedding columns (column i of ``AX``, column j of
    ``BY``), the squared 1-D Wasserstein distance is computed for both ``b``
    and ``-b``, and the smaller value is kept, so each coordinate's sign is
    ignored. A Hungarian assignment then pairs columns at minimum total cost.
    The result is ``sqrt((1/k) * matched cost)``, expressed in data units.

    ``X``, ``Y``, the eigen* arguments and ``print_pairings`` are accepted for
    a uniform signature but are not used.
    """
    k = AX.shape[1]
    # Sort each column once up front so every 1-D distance can reuse it.
    cols_x = get_sorted_columns(AX)
    cols_y = get_sorted_columns(BY)
    cost = np.zeros((k, k))
    for i, a_col in enumerate(cols_x):
        for j in range(k):
            b_col = cols_y[j]
            # Negating an ascending array and reversing it keeps it ascending.
            flipped = -b_col[::-1]
            cost[i, j] = min(
                compute_1d_wasserstein_sorted(a_col, b_col),
                compute_1d_wasserstein_sorted(a_col, flipped),
            )
    rows, cols = linear_sum_assignment(cost)
    return np.sqrt((1 / k) * cost[rows, cols].sum())


# a weighted version - not used in experiments but left for completeness
def assignment_weighted_after(X, Y, AX, BY, eigvals_X, eigvals_Y, eigvecs_X=None, eigvecs_Y=None, weight_type='geometric', print_pairings=False):
    """Eigenvalue-weighted variant of the sign-invariant assignment distance.

    The column-pair cost matrix and Hungarian matching are the same as in
    ``assignment_sliced_wasserstein``. The matched costs are then averaged
    with weights derived from the eigenvalues of the matched columns:

    * ``'geometric'``: ``sqrt(eigvals_X[i] * eigvals_Y[j])``
    * ``'arithmetic'``: ``(eigvals_X[i] + eigvals_Y[j]) / 2``

    Weights are normalised to sum to 1. If they sum to <= 0, uniform weights
    are used. ``X``, ``Y``, eigvecs and ``print_pairings`` are unused.

    Raises
    ------
    ValueError
        If ``weight_type`` is not one of the two values above. This is checked
        before any work; otherwise ``weights`` would be unbound.
    """
    if weight_type not in ('geometric', 'arithmetic'):
        raise ValueError(
            f"Unknown weight_type {weight_type!r}; expected 'geometric' or 'arithmetic'"
        )
    k = AX.shape[1]
    AX_sorted = get_sorted_columns(AX)
    BY_sorted = get_sorted_columns(BY)
    C = np.zeros((k, k))
    for i in range(k):
        a_sorted = AX_sorted[i]
        for j in range(k):
            b_sorted = BY_sorted[j]
            d_pos = compute_1d_wasserstein_sorted(a_sorted, b_sorted)
            # Reverse-then-negate keeps the flipped column ascending.
            d_neg = compute_1d_wasserstein_sorted(a_sorted, -b_sorted[::-1])
            C[i, j] = min(d_pos, d_neg)
    row_ind, col_ind = linear_sum_assignment(C)
    costs = C[row_ind, col_ind]
    ev_x = np.asarray(eigvals_X)[row_ind]
    ev_y = np.asarray(eigvals_Y)[col_ind]
    if weight_type == 'geometric':
        # NOTE(review): negative eigenvalues (possible for some spectra) make this NaN.
        weights = np.sqrt(ev_x * ev_y)
    else:
        weights = 0.5 * (ev_x + ev_y)
    weights_sum = weights.sum()
    weights = weights / weights_sum if weights_sum > 0 else np.ones_like(weights) / k
    return np.sqrt((weights * costs).sum())

def softmin_sign(C_pos, C_neg, beta=10.0):
    """Smooth elementwise minimum of two cost arrays (soft choice of sign).

    Returns ``w * C_pos + (1 - w) * C_neg`` with
    ``w = 1 / (1 + exp(beta * (C_pos - C_neg)))``. Hence ``w -> 1`` where
    ``C_pos`` is much smaller, and the result tends to ``min(C_pos, C_neg)``
    as ``beta -> inf``.

    ``w`` is evaluated as ``exp(-logaddexp(0, x))``. This is the same
    quantity, but it cannot overflow for large cost gaps; the direct
    ``np.exp`` form raises overflow RuntimeWarnings there.
    """
    x = beta * (C_pos - C_neg)
    w = np.exp(-np.logaddexp(0.0, x))
    return w * C_pos + (1 - w) * C_neg

# "soft" riswie
def assignment_sliced_wasserstein_soft_sinkhorn(X, Y, AX, BY, epsilon=0.08, beta=5.0, n_iter=50000):
    """Entropic, sign-smoothed relaxation of the RISWIE assignment distance.

    For each column pair, the squared 1-D Wasserstein cost is computed for
    ``b`` and for ``-b``. The two are blended with ``softmin_sign``
    (sharpness ``beta``) instead of a hard min. The hard assignment is
    replaced by an entropic OT plan from ``ot.sinkhorn`` between uniform
    weights on the ``k`` columns, with regularisation ``epsilon`` and at most
    ``n_iter`` iterations. ``X`` and ``Y`` are unused.

    Returns
    -------
    float
        ``sqrt(sum(P * C_soft))``. The plan ``P`` has uniform marginals of
        total mass 1, so the sum is already the average matched cost (a
        permutation plan has entries ``1/k``). This puts it on the same scale
        as the hard variant's ``(1/k) * sum of matched costs``. Dividing by
        ``k`` again would shrink the distance by ``sqrt(k)``.
    """
    k = AX.shape[1]
    AX_sorted = get_sorted_columns(AX)
    BY_sorted = get_sorted_columns(BY)
    C_pos = np.zeros((k, k))
    C_neg = np.zeros((k, k))
    for i in range(k):
        a_sorted = AX_sorted[i]
        for j in range(k):
            b_sorted = BY_sorted[j]
            C_pos[i, j] = compute_1d_wasserstein_sorted(a_sorted, b_sorted)
            # Reverse-then-negate keeps the flipped column ascending.
            C_neg[i, j] = compute_1d_wasserstein_sorted(a_sorted, -b_sorted[::-1])

    C_soft = softmin_sign(C_pos, C_neg, beta=beta)

    a = np.ones(k) / k
    b = np.ones(k) / k

    # Entropic OT plan between the two uniform column distributions.
    P = ot.sinkhorn(a, b, C_soft, reg=epsilon, numItermax=n_iter)
    total_cost = np.sum(P * C_soft)
    return np.sqrt(total_cost)

def gromov_wasserstein(X, Y):
    """Gromov-Wasserstein discrepancy between point clouds ``X`` and ``Y``.

    Uses (non-squared) Euclidean intra-cloud distance matrices, uniform
    weights and POT's square-loss GW solver. Returns the square root of the
    GW objective, clamped at 0 to absorb tiny negative round-off.
    """
    intra_x = cdist(X, X)
    intra_y = cdist(Y, Y)
    weights_x = np.ones(len(X)) / len(X)
    weights_y = np.ones(len(Y)) / len(Y)
    objective = ot.gromov.gromov_wasserstein2(
        intra_x, intra_y, weights_x, weights_y, 'square_loss', max_iter=100000000
    )
    return np.sqrt(np.maximum(objective, 0.0))


def standard_ot(X, Y):
    """Exact 2-Wasserstein distance W2 between two uniformly weighted point clouds.

    Both clouds must live in the same feature space. The discrete OT problem
    is solved with POT's network-simplex ``ot.emd`` on a squared-Euclidean
    ground cost, and the square root of the optimal cost is returned.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` have different numbers of columns.
    """
    if X.shape[1] != Y.shape[1]:  # need to live in same dimensional space
        raise ValueError(f"OT: Feature dimension mismatch: {X.shape[1]} vs {Y.shape[1]}")
    # Squared ground cost, so that sqrt(optimal cost) is W2 in data units
    # (consistent with ot_wasserstein_distance). With a plain Euclidean cost,
    # the optimum is W1, and its square root has units of sqrt(distance).
    C = cdist(X, Y, 'sqeuclidean')
    p = np.ones(len(X)) / len(X)
    q = np.ones(len(Y)) / len(Y)
    T = ot.emd(p, q, C, numItermax=100000000)
    ot_cost = (C * T).sum()
    return np.sqrt(ot_cost)



def _quantile_wp_p(xs, ys, p):
    """Exact W_p^p between two sorted, uniformly weighted 1-D samples of any sizes."""
    n, m = len(xs), len(ys)
    if n == m:
        # Equal sizes: the monotone coupling pairs the i-th order statistics.
        return np.mean(np.abs(xs - ys) ** p)
    if n == 0 or m == 0:
        return np.nan
    # Unequal sizes: both quantile functions are step functions with jumps at
    # i/n and j/m. Between consecutive merged breakpoints both are constant,
    # so the integral of |F^-1(t) - G^-1(t)|^p is a finite weighted sum.
    right = np.union1d(np.arange(1, n + 1) / n, np.arange(1, m + 1) / m)
    widths = np.diff(right, prepend=0.0)
    mids = right - 0.5 * widths
    # For t in (i/n, (i+1)/n], the quantile of xs is xs[i]. Evaluate at
    # interval midpoints to stay clear of the breakpoints.
    ix = np.minimum((mids * n).astype(int), n - 1)
    iy = np.minimum((mids * m).astype(int), m - 1)
    return np.sum(widths * np.abs(xs[ix] - ys[iy]) ** p)


def sliced_wasserstein(X, Y, n_proj=100, p=2, seed=None):
    """Monte-Carlo sliced p-Wasserstein distance between two point clouds.

    Draws ``n_proj`` random unit directions from ``np.random.default_rng(seed)``
    and projects both clouds onto each one. It computes the exact 1-D
    ``W_p^p`` between the uniformly weighted projections and returns
    ``mean(W_p^p) ** (1/p)``.

    Clouds of different sizes are handled exactly via their quantile
    functions. Truncating the sorted projections to a common length would
    instead drop the largest values of the bigger cloud and bias the result.
    For equal sizes, the result is the usual mean over matched order
    statistics.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` have different numbers of columns.
    """
    if X.shape[1] != Y.shape[1]:
        raise ValueError(f"SW: Feature dimension mismatch: {X.shape[1]} vs {Y.shape[1]}")
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    dists = []
    for _ in range(n_proj):
        v = rng.normal(size=d)
        v /= np.linalg.norm(v)  # unit vector
        X_proj_sorted = np.sort(np.dot(X, v))
        Y_proj_sorted = np.sort(np.dot(Y, v))
        dists.append(_quantile_wp_p(X_proj_sorted, Y_proj_sorted, p))
    return np.mean(dists) ** (1 / p)

## START OF NON-BOILERPLATE CODE

import os
import numpy as np
from plyfile import PlyData
from scipy.optimize import linear_sum_assignment
from scipy.stats import wasserstein_distance
from sklearn.decomposition import PCA
from tqdm import tqdm
import plotly.graph_objects as go
import ot  


# Location and file-name pattern of the MPI-FAUST registered training meshes.
data_root = "MPI-FAUST/training/registrations"
file_template = "tr_reg_{:03d}.ply"

N_PEOPLE = 10  # number of subjects to load
N_POSES = 10   # number of poses to load per subject

# File index = p * 10 + s. The literal 10 is the per-subject stride of the
# file numbering, not N_POSES (NOTE(review): assumes 10 poses per subject on
# disk). The comprehension is subject-major, so entry k of mesh_paths belongs
# to subject k // N_POSES and pose k % N_POSES.
mesh_paths = [os.path.join(data_root, file_template.format(p*10 + s)) for p in range(N_PEOPLE) for s in range(N_POSES)]


def load_mesh(ply_path):
    """Read a PLY mesh and return ``(verts, faces)``.

    ``verts`` is an (n_vertices, 3) array built from the ``x``, ``y`` and
    ``z`` vertex properties. ``faces`` stacks each face's ``vertex_indices``
    row-wise.
    """
    mesh = PlyData.read(ply_path)
    vertex_element = mesh['vertex']
    coordinate_rows = [vertex_element[axis] for axis in ('x', 'y', 'z')]
    verts = np.vstack(coordinate_rows).T
    faces = np.vstack(mesh['face']['vertex_indices'])
    return verts, faces

# Load every mesh listed in mesh_paths, keeping the list order (subject-major).
# NOTE: the loop variables V and F stay in module scope; F is rebound later
# to a distance matrix.
verts_list, faces_list = [], []
for path in tqdm(mesh_paths, desc="Loading meshes"):
    V, F = load_mesh(path)
    verts_list.append(V)
    faces_list.append(F)
# np.stack requires every vertex array to have the same shape, i.e. all
# registrations share one vertex count. Result: (N_PEOPLE * N_POSES, N_VERTS, 3).
verts_array = np.stack(verts_list, axis=0)  # (100, N_VERTS, 3)
# Only the first mesh's faces are kept, assuming all meshes share the same topology.
faces = faces_list[0]  # all meshes have the same faces nicely

def deterministic_subsample(X, n, seed):
    """Return ``n`` distinct rows of ``X`` chosen reproducibly from ``seed``.

    Uses a legacy ``np.random.RandomState``, so a given (seed, n, len(X))
    always selects the same rows in the same order.

    ``n`` is clamped to ``X.shape[0]``. ``choice(..., replace=False)``
    cannot draw more items than exist, so asking for too many rows now
    returns all rows in a seeded random order instead of raising.
    """
    n = min(n, X.shape[0])
    rng = np.random.RandomState(seed)
    idx = rng.choice(X.shape[0], size=n, replace=False)
    return X[idx]

def ot_wasserstein_distance(X, Y):
    """Exact 2-Wasserstein distance W2 between two uniformly weighted point clouds.

    Solves the discrete OT problem with POT's ``ot.emd2`` on a squared
    Euclidean cost matrix and returns the square root of the optimal cost.
    """
    weights_x = np.ones(X.shape[0]) / X.shape[0]
    weights_y = np.ones(Y.shape[0]) / Y.shape[0]
    cost = ot.dist(X, Y, metric='euclidean') ** 2
    squared_w2 = ot.emd2(weights_x, weights_y, cost, numItermax=1_000_000)
    return np.sqrt(squared_w2)

def ot_wasserstein_distance_subsampled(X, Y, n_subsample=500, seed_i=None, seed_j=None):
    """W2 between (optionally) subsampled copies of ``X`` and ``Y``.

    Each cloud is reduced to at most ``n_subsample`` rows by
    ``deterministic_subsample``, seeded with ``seed_i`` / ``seed_j``. A
    ``None`` seed means that cloud is used in full.
    """
    def _maybe_subsample(points, seed):
        if seed is None:
            return points
        return deterministic_subsample(points, min(n_subsample, points.shape[0]), seed=int(seed))

    return ot_wasserstein_distance(_maybe_subsample(X, seed_i), _maybe_subsample(Y, seed_j))

def riswie_distance_subsampled(X, Y, k=3, n_subsample=500, seed_i=None, seed_j=None):
    """RISWIE distance on (optionally) subsampled point clouds.

    Each cloud is reduced to at most ``n_subsample`` rows via
    ``deterministic_subsample`` (a ``None`` seed keeps the whole cloud). Each
    cloud is then embedded with ``pca_embedding(..., k=k)``, and the two
    embeddings are scored with ``assignment_sliced_wasserstein``.
    """
    size_x = min(n_subsample, X.shape[0])
    size_y = min(n_subsample, Y.shape[0])
    sub_x = X if seed_i is None else deterministic_subsample(X, size_x, seed=int(seed_i))
    sub_y = Y if seed_j is None else deterministic_subsample(Y, size_y, seed=int(seed_j))
    return assignment_sliced_wasserstein(sub_x, sub_y, pca_embedding(sub_x, k=k), pca_embedding(sub_y, k=k))


def gromov_wasserstein_distance(X, Y, n_subsample=500, i=None, j=None, loss_fun='square_loss', numItermax=50000, epsilon=1e-9):
    """Gromov-Wasserstein distance between (optionally) subsampled point clouds.

    Parameters
    ----------
    X, Y : arrays of shape (n, d) and (m, d')
    n_subsample : int
        Maximum number of rows kept from each cloud.
    i, j : int or None
        Subsampling seeds for ``X`` and ``Y``. ``None`` uses the whole cloud,
        as the other *_subsampled helpers do (``int(None)`` used to crash).
    loss_fun : str
        GW loss forwarded to POT (default ``'square_loss'``).
    numItermax : int
        Passed as ``max_iter`` to the GW solver.
    epsilon : float
        Forwarded unchanged as a keyword to ``ot.gromov.gromov_wasserstein2``.

    Returns
    -------
    The square root of the GW objective, computed on squared-Euclidean
    intra-cloud cost matrices and clamped at 0 to absorb negative round-off.
    """
    X_sub = X if i is None else deterministic_subsample(X, min(n_subsample, X.shape[0]), seed=int(i))
    Y_sub = Y if j is None else deterministic_subsample(Y, min(n_subsample, Y.shape[0]), seed=int(j))
    nx, ny = X_sub.shape[0], Y_sub.shape[0]
    a = np.ones(nx) / nx
    b = np.ones(ny) / ny
    C1 = ot.dist(X_sub, X_sub, metric='euclidean') ** 2
    C2 = ot.dist(Y_sub, Y_sub, metric='euclidean') ** 2
    gw_dist = ot.gromov.gromov_wasserstein2(
        C1, C2, a, b, loss_fun,
        max_iter=numItermax, epsilon=epsilon, verbose=False
    )
    return np.sqrt(np.maximum(gw_dist, 0.0))

def euclidean_random_pairing_distance_subsampled(
    X, Y, n_subsample=1000, seed_i=None, seed_j=None, pair_seed=None, use_rms=True
):
    """Baseline distance over a random one-to-one pairing of points.

    Each cloud is optionally subsampled (``seed_i`` / ``seed_j``; ``None``
    keeps it whole). Then ``m = min(len)`` distinct rows are drawn from each,
    and the k-th draws are paired.

    Parameters
    ----------
    pair_seed : int or None
        Seed for the random pairing. ``None`` keeps the historical fixed
        seed 42. Previously the argument was accepted but silently ignored.
    use_rms : bool
        If True, return the root-mean-square pair distance; otherwise the
        mean Euclidean pair distance.

    Returns
    -------
    float
    """
    Xi = deterministic_subsample(X, min(n_subsample, X.shape[0]), seed=int(seed_i)) if seed_i is not None else X
    Yj = deterministic_subsample(Y, min(n_subsample, Y.shape[0]), seed=int(seed_j)) if seed_j is not None else Y
    m = min(Xi.shape[0], Yj.shape[0])

    rng = np.random.RandomState(42 if pair_seed is None else int(pair_seed))

    idx_i = rng.choice(Xi.shape[0], size=m, replace=False)
    idx_j = rng.choice(Yj.shape[0], size=m, replace=False)
    diffs = Xi[idx_i] - Yj[idx_j]

    if use_rms:
        return float(np.sqrt(np.mean(np.sum(diffs**2, axis=1))))
    return float(np.mean(np.linalg.norm(diffs, axis=1)))

def sliced_wasserstein_distance_subsampled(
    X, Y, n_subsample=1000, n_slices=64, seed=None, seed_i=None, seed_j=None
):
    """Sliced 2-Wasserstein distance between (optionally) subsampled clouds.

    Parameters
    ----------
    n_subsample : int
        Maximum rows kept per cloud when ``seed_i`` / ``seed_j`` is given
        (``None`` keeps that cloud whole).
    n_slices : int
        Number of random projection directions.
    seed : int or None
        Seed for the projection directions. ``None`` keeps the historical
        fixed seed 42. Previously the argument was silently ignored.

    Returns
    -------
    float
        ``sqrt(mean over slices of W_2^2 between the 1-D projections)``.
    """
    Xi = deterministic_subsample(X, min(n_subsample, X.shape[0]), seed=int(seed_i)) if seed_i is not None else X
    Yj = deterministic_subsample(Y, min(n_subsample, Y.shape[0]), seed=int(seed_j)) if seed_j is not None else Y

    rng = np.random.RandomState(42 if seed is None else int(seed))

    d = Xi.shape[1]
    total = 0.0
    used = 0
    for _ in range(n_slices):
        theta = rng.normal(size=d)
        nrm = np.linalg.norm(theta)
        if nrm == 0:  # degenerate direction (only realistic when d == 0)
            continue
        theta /= nrm

        # Fresh arrays from the matmul, so sorting in place is safe.
        u = Xi @ theta
        v = Yj @ theta
        u.sort()
        v.sort()

        if u.shape[0] == v.shape[0]:
            total += float(np.mean((u - v) ** 2))  # equal size: match order statistics
        else:
            total += float(emd2_1d_nosort(u, v))   # exact 1-D OT for unequal sizes
        used += 1

    # Average only over the slices that were actually evaluated.
    return float(np.sqrt(total / max(1, used)))

N = N_PEOPLE * N_POSES
D = np.zeros((N, N))  # RISWIE (filled below)
E = np.zeros((N, N))  # Gromov (not filled here)
F = np.zeros((N, N))  # OT (not filled here)
G = np.zeros((N, N))  # Euclidean random pairing (not filled here)
H = np.zeros((N, N))  # Sliced (not filled here)

# A mesh's k=3 PCA embedding does not depend on its partner, so compute it
# once per mesh (N calls) instead of twice per pair (N * (N - 1) calls).
pca_embeddings = [pca_embedding(verts_array[idx], k=3) for idx in range(N)]
for i in tqdm(range(N), desc="RISWIE distance matrix"):
    for j in range(i+1, N):
        D[i, j] = assignment_sliced_wasserstein(verts_array[i], verts_array[j], pca_embeddings[i], pca_embeddings[j])
        D[j, i] = D[i, j]
        # do other distances as you wish


import numpy as np
import pandas as pd

from sklearn.cluster import KMeans, SpectralClustering, AgglomerativeClustering
from sklearn.metrics import v_measure_score, adjusted_rand_score, normalized_mutual_info_score
from sklearn.manifold import MDS, TSNE
from sklearn.metrics.cluster import contingency_matrix
from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import StandardScaler
from scipy.optimize import linear_sum_assignment


from sklearn_extra.cluster import KMedoids

def clustering_accuracy(y_true, y_pred):
    """Best-match clustering accuracy.

    Builds the contingency table between true and predicted labels. It then
    finds the one-to-one label mapping that maximises agreement (Hungarian
    algorithm on the negated counts) and returns the fraction of samples that
    mapping gets right.
    """
    table = contingency_matrix(y_true, y_pred)
    true_idx, pred_idx = linear_sum_assignment(-table)
    return table[true_idx, pred_idx].sum() / table.sum()

def run_kmeans_trials(X, n_clusters, trials=50, n_init_per_trial=10, seed=12345):
    """Run KMeans under ``trials`` seeds and keep the lowest-inertia fit.

    Seeds are drawn reproducibly from ``np.random.default_rng(seed)``.
    Returns ``(labels, inertia)`` of the best run, or ``(None, inf)`` when
    ``trials == 0``.
    """
    trial_seeds = np.random.default_rng(seed).integers(0, 2**31 - 1, size=trials)
    best_labels, best_inertia = None, np.inf
    for trial_seed in trial_seeds:
        model = KMeans(n_clusters=n_clusters, n_init=n_init_per_trial, random_state=int(trial_seed))
        labels = model.fit_predict(X)
        if model.inertia_ < best_inertia:
            best_labels, best_inertia = labels, model.inertia_
    return best_labels, best_inertia

def run_kmedoids_trials(X_or_D, n_clusters, metric, trials=20, seed=12345):
    """Run PAM k-medoids (k-medoids++ init) under several seeds; keep the best fit.

    ``X_or_D`` is a feature matrix or, with ``metric='precomputed'``, a
    distance matrix. Seeds come from ``np.random.default_rng(seed)``.
    Returns ``(labels, inertia)`` of the lowest-inertia run.
    """
    trial_seeds = np.random.default_rng(seed).integers(0, 2**31 - 1, size=trials)
    best_labels, best_inertia = None, np.inf
    for trial_seed in trial_seeds:
        model = KMedoids(
            n_clusters=n_clusters,
            metric=metric,
            init='k-medoids++',
            method='pam',
            random_state=int(trial_seed),
        )
        labels = model.fit_predict(X_or_D)
        if model.inertia_ < best_inertia:
            best_labels, best_inertia = labels, model.inertia_
    return best_labels, best_inertia

def add_row(results, method, inertia, y_true, y_pred):
    """Append ``(method, inertia, accuracy, V-measure, ARI, NMI)`` to ``results``."""
    scorers = (
        clustering_accuracy,
        v_measure_score,
        adjusted_rand_score,
        normalized_mutual_info_score,
    )
    scores = tuple(score(y_true, y_pred) for score in scorers)
    results.append((method, inertia) + scores)

def rbf_affinity_from_dist(D):
    """Gaussian (RBF) affinity matrix from a distance matrix.

    The bandwidth ``sigma`` is the median of the strictly positive entries of
    ``D`` (1.0 if there are none), floored at 1e-12 to avoid division by zero.
    Entries are ``exp(-D**2 / (2 * sigma**2))``, and the diagonal is forced
    to 1.
    """
    positive_entries = D[D > 0]
    if positive_entries.size:
        sigma = float(np.median(positive_entries))
    else:
        sigma = 1.0
    sigma = max(sigma, 1e-12)
    affinity = np.exp(-(D ** 2) / (2 * sigma ** 2))
    np.fill_diagonal(affinity, 1.0)
    return affinity

all_results = {}

# Display order of the rows in each per-distance results table (loop-invariant).
method_order = [
    "KMeans (dist rows)",
    "KMedoids (precomputed dist)",
    "Spectral (RBF of dist)",
    "Agglomerative (avg, precomp)",
    "MDS-2D + KMeans", "MDS-2D + KMedoids", "MDS-2D + Spectral",
    "MDS-3D + KMeans", "MDS-3D + KMedoids", "MDS-3D + Spectral",
    "t-SNE-2D + KMeans", "t-SNE-2D + KMedoids", "t-SNE-2D + Spectral",
    "t-SNE-3D + KMeans", "t-SNE-3D + KMedoids", "t-SNE-3D + Spectral",
]

for name, mat in [
    ("RISWIE", D.copy()),
    ("Gromov", E.copy()),
    ("OT", F.copy()),
    ("Euclidean", G.copy()),
    ("Sliced", H.copy()),
]:
    # A distance matrix that was never filled is identically zero. Clustering
    # it yields degenerate labels (plus solver warnings) and would feed
    # meaningless rows into the RISWIE comparison, so skip it.
    if not np.any(mat):
        print(f"\n===== Skipping {name}: distance matrix is all zeros (not computed) =====")
        continue

    print(f"\n===== Running clustering for {name} =====")
    X = mat
    N = X.shape[0]
    n_clusters = 10
    # Ground-truth label = index % 10. mesh_paths is built subject-major
    # (file index p*10 + s), so this is the pose index s.
    # NOTE(review): confirm that pose (not subject, i // N_POSES) is the
    # intended ground truth.
    true_labels = np.array([i % n_clusters for i in range(N)])

    results = []

    # Clustering directly on the distance matrix / its rows.
    pred, inertia = run_kmeans_trials(X, n_clusters, trials=50, n_init_per_trial=50)
    add_row(results, "KMeans (dist rows)", inertia, true_labels, pred)

    pred, inertia = run_kmedoids_trials(X, n_clusters, metric='precomputed', trials=20)
    add_row(results, "KMedoids (precomputed dist)", inertia, true_labels, pred)

    pred = SpectralClustering(
        n_clusters=n_clusters, affinity='precomputed', n_init=50, random_state=42
    ).fit_predict(rbf_affinity_from_dist(X))
    add_row(results, "Spectral (RBF of dist)", np.nan, true_labels, pred)

    pred = AgglomerativeClustering(
        n_clusters=n_clusters, linkage='average', metric='precomputed'
    ).fit_predict(X)
    add_row(results, "Agglomerative (avg, precomp)", np.nan, true_labels, pred)

    # Clustering in a 2-D / 3-D MDS embedding of the distances.
    for d in (2, 3):
        X_mds = MDS(n_components=d, dissimilarity='precomputed', random_state=0).fit_transform(X)
        X_mds = StandardScaler().fit_transform(X_mds)

        pred, inertia = run_kmeans_trials(X_mds, n_clusters, trials=50, n_init_per_trial=50)
        add_row(results, f"MDS-{d}D + KMeans", inertia, true_labels, pred)

        pred, inertia = run_kmedoids_trials(X_mds, n_clusters, metric='euclidean', trials=20)
        add_row(results, f"MDS-{d}D + KMedoids", inertia, true_labels, pred)

        D_mds = pairwise_distances(X_mds, metric='euclidean')
        A_mds = rbf_affinity_from_dist(D_mds)
        pred = SpectralClustering(
            n_clusters=n_clusters, affinity='precomputed', n_init=50, random_state=42
        ).fit_predict(A_mds)
        add_row(results, f"MDS-{d}D + Spectral", np.nan, true_labels, pred)

    # Clustering in a 2-D / 3-D t-SNE embedding of the distances.
    for d in (2, 3):
        tsne = TSNE(
            n_components=d,
            perplexity=10,
            metric='precomputed',
            init='random',
            random_state=0,
            method="barnes_hut",
        )
        X_tsne = tsne.fit_transform(X)
        X_tsne = StandardScaler().fit_transform(X_tsne)

        pred, inertia = run_kmeans_trials(X_tsne, n_clusters, trials=50, n_init_per_trial=50)
        add_row(results, f"t-SNE-{d}D + KMeans", inertia, true_labels, pred)

        pred, inertia = run_kmedoids_trials(X_tsne, n_clusters, metric='euclidean', trials=20)
        add_row(results, f"t-SNE-{d}D + KMedoids", inertia, true_labels, pred)

        D_tsne = pairwise_distances(X_tsne, metric='euclidean')
        A_tsne = rbf_affinity_from_dist(D_tsne)
        pred = SpectralClustering(
            n_clusters=n_clusters, affinity='precomputed', n_init=50, random_state=42
        ).fit_predict(A_tsne)
        add_row(results, f"t-SNE-{d}D + Spectral", np.nan, true_labels, pred)

    df = pd.DataFrame(results, columns=["Method", "Inertia", "Accuracy", "V-measure", "ARI", "NMI"])
    df["__order"] = df["Method"].apply(lambda m: method_order.index(m) if m in method_order else 999)
    df = df.sort_values("__order").drop(columns="__order")
    print(df.to_string(index=False, float_format="%.4f"))

    all_results[name] = df.copy()

metrics = ["Accuracy", "V-measure", "ARI", "NMI"]
base_name = "RISWIE"

# Head-to-head summary: for each other distance, the fraction of clustering
# configurations (methods present in both result tables) on which the base
# distance scores at least as well, per metric.
if base_name not in all_results:
    print(f"\nNo results for {base_name}; skipping comparison.")
else:
    df_base = all_results[base_name][["Method"] + metrics].set_index("Method")
    comparisons = []
    for other_name, df_other in all_results.items():
        if other_name == base_name:
            continue
        df_o = df_other[["Method"] + metrics].set_index("Method")
        common = df_base.index.intersection(df_o.index)
        if len(common) == 0:
            continue
        row = {"Against": other_name}
        for m in metrics:
            # ">=" counts ties as wins for the base distance.
            row[m] = float((df_base.loc[common, m] >= df_o.loc[common, m]).mean())
        comparisons.append(row)

    if comparisons:
        comp_df = pd.DataFrame(comparisons, columns=["Against"] + metrics)
        print(f"\n===== Proportion of configs where {base_name} is better-or-equal =====")
        disp = comp_df.copy()
        for m in metrics:
            disp[m] = (disp[m] * 100).map(lambda x: f"{x:.1f}%")
        print(disp.to_string(index=False))
