from joblib import Parallel, delayed
import glest
import itertools
import re
from pathlib import Path
from typing import Any, Dict, List, Tuple

import matplotlib.colors as colors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import pytest
import scienceplots
import scipy
import seaborn as sns
import statsmodels.api as sm
from cycler import cycler
from matplotlib.ticker import (
    AutoMinorLocator,
    FuncFormatter,
    LogFormatter,
    LogFormatterSciNotation,
)
from sklearn.ensemble import HistGradientBoostingClassifier, RandomForestClassifier
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import (
    DecisionTreeClassifier,
    DecisionTreeRegressor,
)
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from tueplots import bundles

from decision.xp.common import (
    # fit_predict_clf,
    get_constant_utilty,
    # get_optimal_thresholds_norecal,
    # get_optimal_thresholds_recal,
    get_threshold_from_utility,
    # recalibrate_scores,
    u_emp_from_score,
)
from decision.xp.data.hate import filter_ds
from decision.xp.model.hate import getattr_nested
from decision.xp.data.base import ForwardedMixin, ds_registry, ds_rename
from decision.xp.model.base import PretrainedMixin, model_registry, model_rename
from decision.xp.post_training import (
    GLAR,
    FineTuning,
    GLARThresholded,
    HistogramBinningRecalibration,
    Identity,
    PartitionerDict,
    PlattBinnerRecalibration,
    PlattRecalibration,
    PostTraining,
    PostTrainingDict,
    SigmoidFineTuning,
    SklearnRecalibration,
    Stacking,
)
from decision.xp.regrets import (
    compute_accuracy,
    compute_regret_CL,
    compute_regret_CL_normal,
    compute_regret_GL_LB,
    compute_regret_GL_UB,
)
from glest.helpers import calibration_curve as glest_calibration_curve
from utils.contexts import Timer
from utils.io import save_fig, save_path
from utils.plot import (
    KeepLim,
    add_axis,
    add_colorbar,
    add_horizontal_bands,
    add_label_band,
    add_legend,
    set_latex_font,
)

# Module-level plotting setup (runs at import time): select the "science"
# matplotlib style (presumably registered by the `scienceplots` import above —
# TODO confirm), then overlay the rcParams returned by tueplots'
# `bundles.neurips2024()`.
plt.style.use(["science"])
plt.rcParams.update(bundles.neurips2024())


def compute_one(
    rule,
    ds: Tuple[str, ForwardedMixin],
    model: Tuple[str, PretrainedMixin],
    recalibrate: bool,
    finetuned: bool,
    n_utilities: int = 100,
    val_prop: float = 1,
) -> Tuple[Any, Any, Any, Any, Any]:
    """Fit a ``glest.GLEstimator`` for one (dataset, model) pair and plot regrets.

    Parameters
    ----------
    rule
        Partitioner handed to ``glest.GLEstimator(partitioner=...)``.
    ds
        ``(ds_name, dataset)`` pair; the dataset must expose ``ds_name`` and
        ``get_arrays(model, finetuned, rs)``.
    model
        ``(model_name, model)`` pair.
    recalibrate
        Only echoed in the log line; no recalibration is performed here.
    finetuned
        Forwarded to ``ds.get_arrays``.
    n_utilities
        Number of utilities passed to ``get_constant_utilty``.
    val_prop
        Fraction in ``[0, 1]`` of each validation index array that is kept
        (the leading ``int(n * val_prop)`` entries).

    Returns
    -------
    tuple
        ``(fig, fig2, fig3, fig4, fig5)``: the figure returned by
        ``gle.plot``, the R_CL curve, the LB/UB curves, all three curves
        together (annotated with the AUC), and the accuracy curve.
    """
    rs = 0
    use_ts = None  # only echoed in the log line below
    ds_name, ds = ds
    model_name, model = model

    print(
        f"ds_name: {ds.ds_name}, rule: {rule}, recalibrate: {recalibrate}, use_ts: {use_ts}, finetuned: {finetuned}"
    )
    ds_name2 = ds_rename.get(ds_name, ds_name)
    model_name2 = model_rename.get(model_name, model_name)

    (X, y, S, _), (idx_val1, idx_val2, idx_test) = ds.get_arrays(model, finetuned, rs)

    def subsample(array: np.ndarray, prop: float):
        """Keep the leading ``int(n * prop)`` rows of ``array``."""
        assert 0 <= prop <= 1
        n_sub = int(array.shape[0] * prop)
        return array[:n_sub]

    idx_val1 = subsample(idx_val1, val_prop)
    idx_val2 = subsample(idx_val2, val_prop)

    X_val1, S_val1, y_val1 = X[idx_val1], S[idx_val1], y[idx_val1]
    X_val2, S_val2, y_val2 = X[idx_val2], S[idx_val2], y[idx_val2]
    S_test = S[idx_test]

    U = get_constant_utilty(n_utilities)  # (n_utilities, 2, 2)
    t = get_threshold_from_utility(U)  # (n_utilities,)

    print(X.shape, y.shape, S.shape)

    gle = glest.GLEstimator(S_val1, partitioner=rule, train_size=0.5, random_state=0)
    gle.fit(X_val1, y_val1, test_data=(X_val2, y_val2, S_val2))
    fig = gle.plot(fig_kw=dict(figsize=(2.5, 2.5)))

    V = gle.GL(reduce_bin=False) / 2

    C, _ = glest_calibration_curve(
        gle.frac_pos_, gle.counts_, gle.mean_scores_, remove_empty=False
    )

    # Map every test score to its bin using the interior bin edges.
    bins = gle.partitioner.bins_
    binids = np.searchsorted(bins[1:-1], S_test)

    a = S_test[:, None] >= t[None, :]
    RCL = compute_regret_CL(C[binids], t, a)  # (n, k)
    rcl_mean = RCL.mean(axis=0)

    bin_counts = gle.counts_.sum(axis=1)

    RGL_LB = compute_regret_GL_LB(C, V, t, bin_counts=bin_counts)  # (k,)
    RGL_UB = compute_regret_GL_UB(C, V, t, bin_counts=bin_counts)  # (k,)

    print(RCL.shape)

    # NOTE(review): accuracy and AUC use the full (y, S) arrays, not only the
    # test split — confirm this is intended.
    acc = compute_accuracy(y, S, t)  # (k,)
    auc = roc_auc_score(y, S)

    fig2, ax = plt.subplots(figsize=(3, 2))
    ax.plot(t, rcl_mean, label=r"$t^{\star}R_{CL}$")
    ax.set(xlabel=r"Threshold $t^{\star}$", ylabel=r"Normalized regret")
    add_legend(ax, ncol=3)

    fig3, ax = plt.subplots(figsize=(3, 2))
    ax.plot(t, RGL_LB, label=r"$t^{\star}\mathrm{LB}$")
    ax.plot(t, RGL_UB, label=r"$t^{\star}\mathrm{UB}$")
    ax.set(xlabel=r"Threshold $t^{\star}$", ylabel=r"Normalized regret")
    add_legend(ax, ncol=3)

    fig4, ax = plt.subplots(figsize=(3, 2))
    ax.plot(t, rcl_mean, label=r"$t^{\star}R_{CL}$")
    ax.plot(t, RGL_LB, label=r"$t^{\star}\mathrm{LB}$")
    ax.plot(t, RGL_UB, label=r"$t^{\star}\mathrm{UB}$")
    ax.set(xlabel=r"Threshold $t^{\star}$", ylabel=r"Normalized regret")
    ax.text(0.95, 0.95, f"AUC={auc:.3f}", ha="right", va="top", transform=ax.transAxes)
    ax.set_title(f"{ds_name2} — {model_name2}", y=1.15)
    add_legend(ax, ncol=3)

    fig5, ax = plt.subplots(figsize=(3, 2))
    ax.plot(t, acc, label=rf"Accuracy (AUC={auc:.4f})")
    ax.set(xlabel=r"Threshold $t^{\star}$")
    ax.set_title(f"{ds_name2} — {model_name2}", y=1.15)
    add_legend(ax, ncol=3)

    return fig, fig2, fig3, fig4, fig5


def compute_metrics(partitioner: glest.Partitioner, S, y, X, t):
    """Fit a ``glest.GLEstimator`` on ``(X, y, S)`` and gather regret metrics.

    Returns ``(gle, metrics_t, metrics_bins, metrics_one)`` where

    * ``metrics_t`` holds the per-threshold arrays ``RCL``, ``RGL_LB``,
      ``RGL_UB`` and ``acc``;
    * ``metrics_bins`` holds the per-bin arrays ``V``, ``C`` and ``H``;
    * ``metrics_one`` holds ``auc``, ``E[Y]``, ``E[C^2]`` plus everything
      returned by ``gle.metrics(reduce_bin=True)``.
    """
    estimator = glest.GLEstimator(
        S, partitioner=partitioner, train_size=0.5, random_state=0
    )
    estimator.fit(X, y)

    half_gl = estimator.GL(reduce_bin=False) / 2

    calib, hist = glest_calibration_curve(
        estimator.frac_pos_,
        estimator.counts_,
        estimator.mean_scores_,
        remove_empty=False,
    )

    # Locate each score's bin from the interior bin edges.
    interior_edges = estimator.partitioner.bins_[1:-1]
    sample_bin = np.searchsorted(interior_edges, S)

    actions = (S[:, None] >= t[None, :]).astype(int)
    # (n, k) regret averaged over samples -> (k,)
    regret_cl = compute_regret_CL(calib[sample_bin], t, actions).mean(axis=0)

    counts_per_bin = estimator.counts_.sum(axis=1)

    # Empty bins yield NaN in the calibration curve; they carry zero weight,
    # so they are zeroed (in place) before being used below.
    calib[counts_per_bin == 0] = 0
    assert not np.isnan(calib).any()

    regret_gl_lb = compute_regret_GL_LB(
        calib, half_gl, t, bin_counts=counts_per_bin
    )  # (k,)
    regret_gl_ub = compute_regret_GL_UB(
        calib, half_gl, t, bin_counts=counts_per_bin
    )  # (k,)

    assert not np.isnan(regret_gl_lb).any()
    assert not np.isnan(regret_gl_ub).any()

    print(regret_cl.shape)

    accuracy = compute_accuracy(y, S, t)  # (k,)
    auc = roc_auc_score(y, S)

    metrics_t = dict(RCL=regret_cl, RGL_LB=regret_gl_lb, RGL_UB=regret_gl_ub, acc=accuracy)
    metrics_bins = dict(V=half_gl, C=calib, H=hist)
    metrics_one = {
        "auc": auc,
        "E[Y]": np.mean(y),
        "E[C^2]": np.average(np.square(calib), weights=counts_per_bin),
    }
    metrics_one.update(estimator.metrics(reduce_bin=True))

    return estimator, metrics_t, metrics_bins, metrics_one


def compute_metrics_residuals(partitioner: glest.Partitioner, S, y, X, t):
    """Residual-based regret metrics using a tree fitted on calibration residuals.

    The data is split in half (fit / evaluation). A calibration subset of
    ``max(int(0.2 * n_fit), 4000)`` samples is then carved out of the fit half
    to train a logistic regression on the scores. A decision tree is fitted on
    the remaining fit samples' residuals ``y - c_hat`` and its leaves define
    the partition handed to ``glest.core.GLEstimatorResiduals``.

    ``partitioner`` is accepted for signature parity but is not used here.

    Returns ``(gle, metrics_t, metrics_one)``.
    """
    recalibrator = LogisticRegression()

    X_fit, X_eval, y_fit, y_eval, S_fit, S_eval = train_test_split(
        X, y, S, test_size=0.5, random_state=0
    )

    # NOTE(review): an integer test_size >= len(X_fit) makes train_test_split
    # raise, so this presumably expects more than 4000 fit samples — confirm.
    n_calibration = max(int(len(X_fit) * 0.2), 4000)
    X_fit, _, y_fit, y_cal, S_fit, S_cal = train_test_split(
        X_fit, y_fit, S_fit, test_size=n_calibration, random_state=0
    )

    recalibrator.fit(S_cal.reshape(-1, 1), y_cal)

    def positive_proba(scores):
        """Positive-class probability of the recalibrator for 1-D scores."""
        return recalibrator.predict_proba(scores.reshape(-1, 1))[:, 1]

    c_hat_fit = positive_proba(S_fit)
    c_hat_eval = positive_proba(S_eval)

    residual_tree = DecisionTreeRegressor(max_depth=None, min_samples_leaf=15)
    residual_tree.fit(X_fit, y_fit - c_hat_fit)
    eval_leaves = residual_tree.apply(X_eval)

    gle = glest.core.GLEstimatorResiduals(None, None)
    gle.fit(X_eval, y_eval, y_scores_cal=c_hat_eval, partition=eval_leaves)

    r_hat = gle.honest_tree_pred

    # Actions taken from the raw scores, regret evaluated against c_hat.
    raw_actions = (S_eval[:, None] >= t[None, :]).astype(int)
    regret_cl = compute_regret_CL(c_hat_eval, t, raw_actions).mean(axis=0)  # (k,)

    # Actions taken from c_hat, regret evaluated against c_hat + r_hat.
    cal_actions = (c_hat_eval[:, None] >= t[None, :]).astype(int)
    regret_gl = compute_regret_CL(c_hat_eval + r_hat, t, cal_actions).mean(axis=0)  # (k,)

    accuracy = compute_accuracy(y, S, t)  # (k,)
    auc = roc_auc_score(y, S)

    metrics_t = {
        "RCL": regret_cl,
        "RGL_residuals": regret_gl,
        "acc": accuracy,
    }
    metrics_one = {
        "auc": auc,
        "E[Y]": np.mean(y),
        "E[C^2]": np.average(np.square(c_hat_eval)),
    }
    metrics_one.update(gle.metrics())

    return gle, metrics_t, metrics_one


def compute_metrics_residuals_normal(partitioner: glest.Partitioner, S, y, X, t):
    """Variant of the residual metrics that uses ``compute_regret_CL_normal``.

    Splits the data in half (fit / evaluation), carves a calibration subset of
    ``max(int(0.2 * n_fit), 4000)`` samples out of the fit half to train a
    logistic regression on the scores, fits a depth-10 decision tree on the
    residuals ``y - c_hat`` and uses its leaves as the partition of
    ``glest.core.GLEstimatorResiduals``.

    ``partitioner`` is accepted for signature parity but is not used here.

    Returns ``(gle, metrics_t, metrics_one)``.
    """
    recalibrator = LogisticRegression()

    X_fit, X_eval, y_fit, y_eval, S_fit, S_eval = train_test_split(
        X, y, S, test_size=0.5, random_state=0
    )

    n_calibration = max(int(len(X_fit) * 0.2), 4000)
    X_fit, _, y_fit, y_cal, S_fit, S_cal = train_test_split(
        X_fit, y_fit, S_fit, test_size=n_calibration, random_state=0
    )

    recalibrator.fit(S_cal.reshape(-1, 1), y_cal)

    def positive_proba(scores):
        """Positive-class probability of the recalibrator for 1-D scores."""
        return recalibrator.predict_proba(scores.reshape(-1, 1))[:, 1]

    c_hat_fit = positive_proba(S_fit)
    c_hat_eval = positive_proba(S_eval)

    residual_tree = DecisionTreeRegressor(max_depth=10, min_samples_leaf=10)
    residual_tree.fit(X_fit, y_fit - c_hat_fit)
    eval_leaves = residual_tree.apply(X_eval)

    gle = glest.core.GLEstimatorResiduals(None, None)
    gle.fit(X_eval, y_eval, y_scores_cal=c_hat_eval, partition=eval_leaves)

    r_hat = gle.honest_tree_pred
    var_x = gle.var_x

    actions = c_hat_eval[:, None] >= t[None, :]
    regret_cl = compute_regret_CL(c_hat_eval, t, actions).mean(axis=0)  # (k,)

    # A constant variance of 0.01 is passed instead of the estimated `var_x`,
    # which is only printed.
    fixed_variance = np.full_like(c_hat_eval, 0.01)
    regret_gl = compute_regret_CL_normal(
        c_hat_eval, t, actions, fixed_variance, r_hat
    )  # (n, k)
    print("Variance ", var_x)
    regret_gl = regret_gl.mean(axis=0)

    accuracy = compute_accuracy(y, S, t)  # (k,)
    auc = roc_auc_score(y, S)

    metrics_t = {
        "RCL": regret_cl,
        "RGL_residuals": regret_gl,
        "acc": accuracy,
    }
    metrics_one = {
        "auc": auc,
        "E[Y]": np.mean(y),
        "E[C^2]": np.average(np.square(c_hat_eval)),
    }
    metrics_one.update(gle.metrics())

    return gle, metrics_t, metrics_one

def metrics_to_df(
    metrics_t: dict, metrics_one: dict, t: np.ndarray
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Convert metric dictionaries into two DataFrames.

    Parameters
    ----------
    metrics_t
        Mapping from metric name to a per-threshold array aligned with ``t``.
    metrics_one
        Mapping from metric name to a single scalar value.
    t
        Threshold values; used as the index (named ``"t"``) of the first frame.

    Returns
    -------
    (df_t, df_one)
        ``df_t`` has one row per threshold; ``df_one`` has a single row
        (index ``0``) with one column per scalar metric.
    """
    df_t = pd.DataFrame(metrics_t, index=t)
    df_t.index.name = "t"

    # Scalars need an explicit one-element index to build a one-row frame.
    df_one = pd.DataFrame(metrics_one, index=[0])

    return df_t, df_one


def compute_regrets(
    ds: ForwardedMixin,
    model: PretrainedMixin,
    rules: Dict[str, glest.Partitioner],
    n_utilities: int = 100,
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Compute residual regret metrics on the validation data for each rule.

    Both validation index arrays returned by ``ds.get_arrays`` are
    concatenated and ``compute_metrics_residuals`` is run once per entry of
    ``rules``.

    Parameters
    ----------
    ds
        Dataset exposing ``ds_name`` and ``get_arrays(model, finetuned, rs)``.
    model
        Model forwarded to ``ds.get_arrays``.
    rules
        Mapping from rule name to partitioner.
    n_utilities
        Number of utilities passed to ``get_constant_utilty``.

    Returns
    -------
    (dfs_t, dfs_one)
        Two dicts keyed by rule name: the per-threshold frame and the one-row
        scalar frame produced by ``metrics_to_df``.
    """
    rs = 0
    finetuned = False

    print(f"ds_name: {ds.ds_name}, rule: {rules}")

    (X, y, S, _), (idx_val1, idx_val2, _) = ds.get_arrays(model, finetuned, rs)

    idx_val = np.concatenate([idx_val1, idx_val2])

    X_val = X[idx_val]
    S_val = S[idx_val]
    y_val = y[idx_val]

    # Extra threshold values that must be present in the grid.
    t_target = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99]
    U = get_constant_utilty(n_utilities, t_target)  # (n_utilities, 2, 2)
    t = get_threshold_from_utility(U)  # (n_utilities,)

    print(X.shape, y.shape, S.shape)

    # Compute the residual metrics before post training
    dfs_t: Dict[str, pd.DataFrame] = {}
    dfs_one: Dict[str, pd.DataFrame] = {}
    for rule_name, rule in rules.items():
        _, metrics_t, metrics_one = compute_metrics_residuals(
            rule, S_val, y_val, X_val, t
        )
        dfs_t[rule_name], dfs_one[rule_name] = metrics_to_df(
            metrics_t, metrics_one, t
        )

    return dfs_t, dfs_one

    # gle, metrics_t, metrics_bins, metrics_one = compute_metrics(
    #     rule, S_val, y_val, X_val, t
    # )
    # fig = gle.plot(fig_kw=dict(figsize=(2.5, 2.5)))

    # rows = {"t": t}
    # rows.update(metrics_t)
    # df = pd.DataFrame(rows)

    # for col, val in metrics_one.items():
    #     df[col] = val

    # return df


# model_names = [
#     "cnerg1",
#     "cnerg2",
#     "cnerg3",
#     "cnerg4",
#     "cnerg5",
#     "fb_roberta1",
#     "fb_roberta2",
#     # "mistral_instruct",
# ]


# ds_names = [
#     "hate",
#     "hate_merged_en",
#     "hate_merged_no_en",
#     "merged_hate_check",
#     "hate_merged_en2",
#     "hate_dyn_gen",
#     "hate_merged_no_en2",
#     "hate_merged_large_en",
#     "hate_merged_large_no_en",
#     "hate_merged_large",
# ]


@pytest.mark.parametrize(
    "model_name",
    [
        "cnerg1",
        "cnerg2",
        "cnerg3",
        "cnerg4",
        "cnerg5",
        # "fb_roberta1",
        "fb_roberta2",
        # "mistral_instruct",
    ],
)
@pytest.mark.parametrize(
    # "ds_name",
    # [
    #     # "hate_merged_en",
    #     # "hate_merged_en2",
    #     # "hate_merged_large_en",
    # ],
    "ds_name",
    [
        # "hate",
        # "hate_merged_en",
        # "hate_merged_no_en",
        # "merged_hate_check",
        # "hate_merged_en2",
        # "hate_dyn_gen",
        # "hate_merged_no_en2",
        # "hate_merged_large_en",
        # "hate_merged_large_no_en",
        # "hate_merged_large",
        "hate_en_tweets",
        "hate_en_speech18",
        "hate_en_speech_off",
        "hate_en_davidson",
        "hate_en_gender",
        "hate_en_frenk",
        "hate_en_check",
        "hate_en_twitter",
        "hate_en_open",
    ],
)
def test_compute_one(out, model_name, ds_name):
    """Run ``compute_one`` for one (dataset, model) pair and save its five figures."""
    dataset = ds_registry[ds_name]()
    pretrained = model_registry[model_name]()
    pretty_ds = ds_rename[ds_name]
    pretty_model = model_rename[model_name]

    # Partitioner: 15 quantile bins, each fitted with a 5-leaf regression tree.
    rule = glest.Partitioner(
        estimator=DecisionTreeRegressor(random_state=0, max_leaf_nodes=5),
        predict_method="apply",
        n_bins=15,
        strategy="quantile",
        binwise_fit=True,
        verbose=10,
        raise_on_fit_error=True,
    )

    print(f"{ds_name} ({pretty_ds})", f"{model_name} ({pretty_model})")
    set_latex_font()
    figures = compute_one(
        rule, (ds_name, dataset), (model_name, pretrained), False, False
    )

    # Output names, in the order in which compute_one returns the figures.
    figure_names = ("diagram", "regret_cl", "regret_gl", "regret_both", "acc")
    fig, fig2, fig3, fig4, fig5 = figures
    for figure, figure_name in zip((fig, fig2, fig3, fig4, fig5), figure_names):
        save_fig(figure, out, _name=figure_name, ds=ds_name, m=model_name)


def get_acc_auc(
    ds: ForwardedMixin,
    model: PretrainedMixin,
    n_utilities: int = 100,
    finetuned: bool = False,
    return_acc: bool = True,
):
    """Compute the ROC AUC (and optionally the accuracy curve) of ``model`` on ``ds``.

    Both metrics use the full ``y`` and ``S`` arrays returned by
    ``ds.get_arrays``.

    Returns ``(t, acc, auc)`` when ``return_acc`` is true, otherwise ``auc``
    alone.
    """
    rs = 0
    print(f"ds_name: {ds.ds_name}")

    utilities = get_constant_utilty(n_utilities)  # (n_utilities, 2, 2)
    thresholds = get_threshold_from_utility(utilities)  # (n_utilities,)

    (_, y, S, _), (_, _, _) = ds.get_arrays(model, finetuned, rs)

    auc = roc_auc_score(y, S)

    if not return_acc:
        return auc

    accuracy = compute_accuracy(y, S, thresholds)  # (k,)
    return thresholds, accuracy, auc


def plot_accuracy(ds, model):
    """Plot accuracy against the threshold for one (dataset, model) pair.

    ``ds`` and ``model`` are ``(name, object)`` pairs. The names are mapped
    through ``ds_rename`` / ``model_rename`` (falling back to the raw name)
    for the title. Returns the matplotlib figure.
    """
    ds_key, dataset = ds
    model_key, pretrained = model

    title = (
        f"{ds_rename.get(ds_key, ds_key)} — {model_rename.get(model_key, model_key)}"
    )

    thresholds, accuracy, auc = get_acc_auc(dataset, pretrained)

    fig, ax = plt.subplots(figsize=(3, 2))
    ax.plot(thresholds, accuracy, label=rf"Accuracy (AUC={auc:.4f})")
    ax.set(xlabel=r"Threshold $t^{\star}$")
    ax.set_title(title, y=1.15)
    add_legend(ax, ncol=3)

    print(auc)

    return fig


@pytest.mark.parametrize(
    "model_name",
    [
        "cnerg1",
        "cnerg2",
        "cnerg3",
        # # "cnerg4",
        "cnerg5",
        "fb_roberta1",
        # "fb_roberta2",
        # "mistral_instruct",
    ],
)
@pytest.mark.parametrize(
    "ds_name",
    [
        "hate_merged_en",
        "hate_merged_en2",
        "hate_merged_large_en",
    ],
)
def test_plot_accuracy(out, model_name, ds_name):
    """Build and save the accuracy-vs-threshold figure for one (dataset, model) pair."""
    dataset = ds_registry[ds_name]()
    pretrained = model_registry[model_name]()
    pretty_ds = ds_rename[ds_name]
    pretty_model = model_rename[model_name]

    set_latex_font()
    print(f"{ds_name} ({pretty_ds})", f"{model_name} ({pretty_model})")

    figure = plot_accuracy((ds_name, dataset), (model_name, pretrained))
    save_fig(figure, out, ds=ds_name, m=model_name)


def compute_auc_table(ds_names: List[str] = None, model_names: List[str] = None):
    """Build a (dataset x model) table of ROC AUC values.

    When a list of names is omitted, every key of the corresponding registry
    is used. Cells stay NaN for datasets that are not ``ForwardedMixin``
    instances and for pairs whose files are missing (``FileNotFoundError``).

    Returns a ``pd.DataFrame`` indexed by dataset name with one column per
    model name.
    """
    ds_names = ds_registry.keys() if ds_names is None else ds_names
    model_names = model_registry.keys() if model_names is None else model_names

    auc_values = np.full((len(ds_names), len(model_names)), np.nan)

    for row, ds_name in enumerate(ds_names):
        for col, model_name in enumerate(model_names):
            model = model_registry[model_name]()
            ds = ds_registry[ds_name]()

            if not isinstance(ds, ForwardedMixin):
                print(f"{ds_name} is not a ForwardedMixin")
                continue

            pretty_ds = ds_rename.get(ds_name, ds_name)
            pretty_model = model_rename.get(model_name, model_name)
            print(f"{ds_name} ({pretty_ds})", f"{model_name} ({pretty_model})")

            try:
                auc = get_acc_auc(ds, model, return_acc=False)
            except FileNotFoundError:  # Not forwarded yet
                print(f"File not found for {ds_name} — {model_name}")
                continue

            auc_values[row, col] = auc
            print(auc)

    return pd.DataFrame(auc_values, index=ds_names, columns=model_names)


def test_auc_all(out):
    """Compute the AUC table for every registered dataset/model and write it as CSV."""
    auc_table = compute_auc_table()
    csv_path = save_path(out, "csv", "auc")
    auc_table.to_csv(csv_path)


# def plot_auc_all(df: pd.DataFrame):
#     # Flattening the DataFrame
#     df_flat = df.reset_index().melt(
#         id_vars="index", var_name="Model", value_name="Value"
#     )
#     df_flat["Dataset_Model"] = df_flat["index"] + " x " + df_flat["Model"]

#     df_flat.sort_values(by=["index", "Model"], inplace=True)

#     # unique_datasets = df_flat["index"].unique()
#     # color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
#     # colors = {
#     #     dataset: color_cycle[i % len(color_cycle)]
#     #     for i, dataset in enumerate(unique_datasets)
#     # }
#     # df_flat["Color"] = df_flat["index"].map(colors)

#     # Plotting
#     # fig, ax = plt.subplots(figsize=(2, 8))
#     # ax.axvline(0.5, color="black", lw=0.5)
#     # ax.scatter(df_flat["Value"], df_flat["Dataset_Model"], c=df_flat["Color"])
#     # ax.set_xlabel("ROC AUC")
#     # ax.set_ylabel("Dataset x Model")
#     # # ax.set_title("Scatter Plot of Values by Dataset and Model")

#     # Prepare data for plotting
#     x_values = df.values.flatten()  # Flatten the DataFrame values for x-axis
#     y_models = np.repeat(df.columns, len(df.index))  # Repeat model names
#     y_datasets = np.tile(
#         df.index, len(df.columns)
#     )  # Repeat dataset names for each model

#     # Create plot
#     fig, ax1 = plt.subplots(figsize=(4, 10))

#     # Scatter plot for models
#     ax1.scatter(df_flat["Value"], df_flat["Dataset_Model"], color="b", label="Models")
#     ax1.set_yticks(range(len(y_models)))
#     ax1.set_yticklabels(y_models)
#     ax1.set_ylabel("Model Names")
#     ax1.set_xlabel("Values")

#     # Create twinx for datasets
#     ax2 = ax1.twinx()
#     ax2.scatter(
#         x_values, range(len(x_values)), color="r", label="Datasets", alpha=0
#     )  # Alpha=0 makes them invisible
#     ax2.set_yticks(range(len(y_datasets)))
#     ax2.set_yticklabels(y_datasets)
#     ax2.set_ylabel("Dataset Names")

#     return fig


def plot_auc_all(df: pd.DataFrame, rev: bool = False):
    """Scatter every (dataset, model) AUC on one row each.

    ``df`` is a dataset-by-model table. It is melted to long form, sorted by
    dataset then model (or by model then dataset when ``rev`` is true), and
    rows with NaN are dropped. The left axis is labelled with dataset names
    and a twin right axis with model names. Returns the figure.
    """
    long_df = df.reset_index().melt(
        id_vars="index", var_name="Model", value_name="Value"
    )
    long_df["Dataset_Model"] = long_df["index"] + " x " + long_df["Model"]

    # The primary sort key is also the column used for colouring.
    hue, secondary = ("Model", "index") if rev else ("index", "Model")
    long_df = long_df.sort_values(by=[hue, secondary])

    # Drop NaN values
    long_df = long_df.dropna()

    fig, ax = plt.subplots(figsize=(4, 10))

    # enable grid
    ax.grid(True, axis="both")

    sns.scatterplot(
        data=long_df,
        x="Value",
        y="Dataset_Model",
        ax=ax,
        hue=hue,
        legend=False,
        palette="tab10",
        zorder=10,
    )
    ax.axvline(0.5, color="black", lw=0.5)

    n_rows = len(long_df["Dataset_Model"])
    ax.set_ylim(-0.5, n_rows - 0.5)
    ax.set_yticklabels(long_df["index"])

    # Twin axis mirrors the ticks/limits so model names line up on the right.
    right_ax = ax.twinx()
    right_ax.set_yticks(ax.get_yticks())
    right_ax.set_ylim(ax.get_ylim())
    right_ax.set_yticklabels(long_df["Model"])

    ax.set(xlabel="ROC AUC", ylabel="Dataset")
    right_ax.set(ylabel="Model")

    ax.invert_yaxis()
    right_ax.invert_yaxis()

    return fig


@pytest.mark.parametrize("rev", [False, True])
def test_plot_auc_all(out, inp, rev):
    """Load the saved AUC table, keep a subset of models and plot it."""
    selected_models = [
        "cnerg1",
        "cnerg2",
        "cnerg3",
        "cnerg5",
        "fb_roberta1",
    ]
    csv_path = Path(inp).parent / "auc_all/forwarded.csv"
    auc_table = pd.read_csv(csv_path, index_col=0)

    # Keep only the selected model columns.
    auc_table = auc_table[selected_models]

    # Swap raw identifiers for display names, falling back to the raw name.
    pretty_columns = [model_rename.get(name, name) for name in auc_table.columns]
    auc_table.columns = pretty_columns
    pretty_index = [ds_rename.get(name, name) for name in auc_table.index]
    auc_table.index = pretty_index

    set_latex_font()
    figure = plot_auc_all(auc_table, rev=rev)
    save_fig(figure, out, _name="auc", r=rev)


@pytest.mark.parametrize(
    "model_name",
    [
        "cnerg1",
        "cnerg2",
        "cnerg3",
        "cnerg4",
        "cnerg5",
        "fb_roberta2",
    ],
)
@pytest.mark.parametrize(
    "ds_name",
    [
        "hate_en_tweets",
        "hate_en_speech18",
        "hate_en_speech_off",
        "hate_en_davidson",
        "hate_en_gender",
        "hate_en_frenk",
        "hate_en_check",
        "hate_en_twitter",
        "hate_en_open",
        "hate",
        "hate_merged_en",
        # "hate_merged_no_en",
        # "merged_hate_check",
        "hate_merged_en2",
        "hate_dyn_gen",
        # "hate_merged_no_en2",
        "hate_merged_large_en",
        # "hate_merged_large_no_en",
        # "hate_merged_large",
    ],
)
def test_regret_table(out, inp, model_name, ds_name, skip_existing=False):
    """Compute the regret metrics of one (dataset, model) pair and store them as CSV.

    One ``metrics_t`` and one ``metrics_one`` file is written per partitioner.
    With ``skip_existing`` set, partitioners whose two files already exist are
    left out, and the test is skipped when none remain.
    """
    dataset = ds_registry[ds_name]()
    pretrained = model_registry[model_name]()
    pretty_ds = ds_rename[ds_name]
    pretty_model = model_rename[model_name]

    print(f"{ds_name} ({pretty_ds})", f"{model_name} ({pretty_model})")

    kwargs = dict(ds=ds_name, m=model_name)

    partitioners_dict = PartitionerDict()

    def csv_path(kind, partitioner_name):
        """Output path of the ``kind`` CSV for the given partitioner."""
        return save_path(out, "csv", _name=kind, **kwargs, p=partitioner_name)

    def is_pending(partitioner_name):
        """True unless skipping is enabled and both output files exist."""
        if not skip_existing:
            return True
        return not (
            Path(csv_path("metrics_t", partitioner_name)).exists()
            and Path(csv_path("metrics_one", partitioner_name)).exists()
        )

    candidates = [
        "depth10",
        # "unconstrained",
    ]
    # filter by non existing files
    partitioner_names = [name for name in candidates if is_pending(name)]

    if not partitioner_names:
        pytest.skip("All files already exist")

    partitioners = {name: partitioners_dict[name] for name in partitioner_names}

    dfs_t, dfs_one = compute_regrets(dataset, pretrained, partitioners)

    for partitioner_name, df_t in dfs_t.items():
        df_t.to_csv(csv_path("metrics_t", partitioner_name))

    for partitioner_name, df_one in dfs_one.items():
        df_one.to_csv(csv_path("metrics_one", partitioner_name), index=None)


@pytest.mark.parametrize(
    "color_on",
    [
        # "auc",
        # "t",
        # "dt",
        # "CL",
        # "GL",
        "diff_RGL_CL",
    ],
)
def test_plot_rcl_rgl(out, inp, color_on):
    """Scatter ``RCL`` against ``RGL_UB`` from the saved regret tables.

    Reads every ``regrets__ds=...:m=....csv`` file in the sibling
    ``regret_table`` directory of ``inp``, keeps a handful of thresholds and
    the rows whose AUC reaches ``auc_threshold``, and colours the points
    according to ``color_on``.

    Raises
    ------
    ValueError
        If ``color_on`` is not one of the supported colourings.
    FileNotFoundError
        If no matching CSV file is found.
    """
    # (vmin, vmax, value selector) for each supported colouring. Selectors are
    # lazy so that only the columns actually needed are accessed.
    color_specs = {
        "auc": (0.5, 1, lambda d: d["auc"]),
        "t": (0, 1, lambda d: d["t"]),
        "dt": (0, 0.5, lambda d: d["dt"]),
        "CL": (0, None, lambda d: d["CL"]),
        "GL": (0, None, lambda d: d["GL"]),
        "diff_RGL_CL": (None, None, lambda d: d["RGL_UB"] - d["CL"]),
    }
    if color_on not in color_specs:
        raise ValueError(
            f"Unknown color_on {color_on!r}; expected one of {sorted(color_specs)}"
        )

    dirpath = Path(inp).parent / "regret_table"

    # keep the csv files of dirpath whose name matches the pattern
    pattern = re.compile(r"regrets__ds=(?P<ds_name>.*):m=(?P<model_name>.*)\.csv")
    csv_files = [f for f in dirpath.glob("*.csv") if pattern.match(f.name)]

    print(csv_files)

    if not csv_files:
        # pd.concat would otherwise fail with an opaque "No objects to concatenate".
        raise FileNotFoundError(f"No regret csv files found in {dirpath}")

    # read all csv files and concatenate them into a single dataframe
    df = pd.concat([pd.read_csv(f) for f in csv_files], ignore_index=True)

    print(df)

    # sorted unique threshold values
    t = np.sort(df["t"].unique())
    t_target = [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]

    # For each target, take the first available threshold >= target. The index
    # is clipped because searchsorted returns len(t) when a target exceeds
    # every available threshold, which would be out of bounds.
    t_idx = np.clip(np.searchsorted(t, t_target), 0, len(t) - 1)
    t_selected = t[t_idx]

    # Keep only the selected thresholds; copy so that the column assignment
    # below acts on an independent frame rather than on a slice.
    df = df[df["t"].isin(t_selected)].copy()

    print(df)

    df["dt"] = np.abs(df["t"] - 0.5)

    # Keep only rows with AUC >= auc_threshold
    auc_threshold = 0.8
    df = df[df["auc"] >= auc_threshold]

    fig, ax = plt.subplots(figsize=(3, 2))

    vmin, vmax, select_values = color_specs[color_on]
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    c = select_values(df)

    cmap = "RdYlGn"
    ax.scatter(
        df["RGL_UB"],
        df["RCL"],
        c=c,
        cmap=cmap,
        norm=norm,
        edgecolors="white",
        linewidths=0.5,
        s=20,
    )

    # add colorbar and title it after the colouring variable
    cbar = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
    cax = cbar.ax
    title_dict = {
        "auc": "AUC",
        "dt": r"$|t^{\star} - 0.5|$",
        "t": r"$t^{\star}$",
        "diff_RGL_CL": r"$R_{GL}^{UB} - \mathrm{CL}$",
    }
    cax.set_title(title_dict.get(color_on, color_on))
    if color_on == "auc":
        # Mark the AUC cut-off on the colorbar.
        cax.axhline(auc_threshold, color="black", lw=0.5)

    ax.set(xlabel=r"$R_{GL}^{UB}$", ylabel=r"$R_{CL}$")

    # Draw the diagonal inside KeepLim so the axis limits are left unchanged.
    with KeepLim(ax):
        ax.plot([0, 1], [0, 1], color="black", lw=0.5, ls="--")

    save_fig(fig, out, c=color_on)


def compute_post_trained(
    ds: ForwardedMixin,
    model: PretrainedMixin,
    post_training: PostTraining,
    rules: Dict[str, glest.Partitioner],
    n_utilities: int = 100,
) -> tuple[pd.DataFrame, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Apply a post-training method and evaluate the resulting scores.

    The post-training method is fitted on the two validation splits returned
    by ``ds.get_arrays`` (concatenated) and evaluated on the test split.

    Args:
        ds: Dataset providing ``get_arrays`` and ``ds_name``.
        model: Pretrained model forwarded to ``ds.get_arrays``.
        post_training: Method exposing ``fit`` and ``predict_proba``. When it
            is a ``FineTuning`` instance, the scores returned by the dataset
            are used as-is and nothing is fitted here.
        rules: Partitioners, keyed by name, used for the residual metrics.
        n_utilities: Number of utilities passed to ``get_constant_utilty``.

    Returns:
        A tuple ``(df, dfs_t, dfs_one)`` where ``df`` holds the mean empirical
        test utility indexed by threshold ``t``, and ``dfs_t`` / ``dfs_one``
        map each rule name to the two tables built by ``metrics_to_df``.
    """
    rs = 0
    finetuned = isinstance(post_training, FineTuning)

    print(f"ds_name: {ds.ds_name}, post training {post_training}, rules {rules}")

    (X, y, S, _G), (idx_val1, idx_val2, idx_test) = ds.get_arrays(model, finetuned, rs)

    idx_val = np.concatenate([idx_val1, idx_val2])

    X_val = X[idx_val]
    S_val = S[idx_val]
    y_val = y[idx_val]
    X_test = X[idx_test]
    S_test = S[idx_test]
    y_test = y[idx_test]

    # We add some extra values of t we would like in the plots
    t_target = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99]
    U = get_constant_utilty(n_utilities, t_target)  # (n_utilities, 2, 2)
    t = get_threshold_from_utility(U)  # (n_utilities,)

    # Timing entries merged into every "one" metrics table. Left empty when
    # nothing is fitted here, so no fake timings are reported.
    timings: Dict[str, Any] = {}

    if finetuned:
        # S is already post trained when finetuned is True: use it directly.
        # Previously this branch left Sp_test and the timers undefined.
        Sp_test = S_test
    else:
        with Timer("post_training_fit") as timer_post_training_fit:
            post_training.fit(S=S_val, y=y_val, X=X_val)
        with Timer("post_training_predict") as timer_post_training_predict:
            Sp_test = post_training.predict_proba(S_test, X_test)
        timings.update(timer_post_training_fit.to_dict())
        timings.update(timer_post_training_predict.to_dict())

    # Compute the utility associated with the post trained scores
    u_test_emp = u_emp_from_score(Sp_test, y_test, t, U, return_action=False)
    u_test_emp = u_test_emp.mean(axis=0)  # (n_utilities,)

    df = pd.DataFrame({"u_test_emp_mean": u_test_emp}, index=t)
    df.index.name = "t"

    # The residual metrics take 1-d scores: keep the first column of 2-d scores.
    scores_1d = Sp_test[:, 0] if np.ndim(Sp_test) == 2 else Sp_test

    # Compute the residual metrics after post training, one table per rule
    dfs_t = {}
    dfs_one = {}
    for rule_name, rule in rules.items():
        _gle, metrics_t, metrics_one = compute_metrics_residuals(
            rule, scores_1d, y_test, X_test, t
        )
        metrics_one.update(timings)
        df_t, df_one = metrics_to_df(metrics_t, metrics_one, t)
        dfs_t[rule_name] = df_t
        dfs_one[rule_name] = df_one

    return df, dfs_t, dfs_one


@pytest.mark.parametrize(
    "model_name",
    [
        "cnerg1",
        "cnerg2",
        "cnerg3",
        "cnerg4",
        "cnerg5",
        "fb_roberta2",
    ],
)
@pytest.mark.parametrize(
    "ds_name",
    [
        "hate_en_tweets",
        "hate_en_speech18",
        "hate_en_speech_off",
        "hate_en_davidson",
        "hate_en_gender",
        "hate_en_frenk",
        "hate_en_check",
        "hate_en_twitter",
        "hate_en_open",
        "hate",
        "hate_merged_en",
        # "hate_merged_no_en",
        # "merged_hate_check",
        "hate_merged_en2",
        "hate_dyn_gen",
        # "hate_merged_no_en2",
        "hate_merged_large_en",
        # "hate_merged_large_no_en",
        # "hate_merged_large",
    ],
)
@pytest.mark.parametrize(
    "post_training_name",
    [
        # "none",
        # "recal_isotonic",
        # "recal_sigmoid",
        # # "recal_hist10",
        # "recal_hist15",
        # # "recal_scalbin10",
        # "recal_scalbin15",
        # "recal_metacal_mis05",
        # "stack_logistic",
        # "stack_rf",
        # "stack_hgb",
        # "stack_naive_bayes",
        # "stack_gaussian_process",
        # "finetuning_sigmoid",
        # "glar_leaves5",
        # "glar_unconstrained",
        "glart_both2_leaves5",
        "glart_bw2_leaves5",
    ],
)
def test_utility_table(
    out,
    model_name,
    ds_name,
    post_training_name,
    skip_existing=True,
):
    """Compute and store the utility/metrics CSVs of one configuration.

    One configuration is a (dataset, model, post-training) triple. Three kinds
    of CSV are written through ``save_path``: ``utility`` (one file), then one
    ``metrics_t`` and one ``metrics_one`` file per partitioner. When
    ``skip_existing`` is true and the utility file is already on disk, the
    test is skipped.
    """
    dataset = ds_registry[ds_name]()
    pretrained = model_registry[model_name]()
    ds_label = ds_rename[ds_name]
    model_label = model_rename[model_name]

    method = PostTrainingDict()[post_training_name]
    available_partitioners = PartitionerDict()
    partitioner_names = [
        "depth10",
        # "unconstrained",
    ]
    partitioners = {name: available_partitioners[name] for name in partitioner_names}

    # Fields identifying this configuration in every output file name.
    kwargs = {"ds": ds_name, "m": model_name, "t": post_training_name}

    utility_path = Path(save_path(out, "csv", _name="utility", **kwargs))
    if skip_existing and utility_path.exists():
        pytest.skip(f"File {utility_path} already exists")

    print(f"{ds_name} ({ds_label})", f"{model_name} ({model_label})")
    utility, tables_t, tables_one = compute_post_trained(
        dataset, pretrained, method, partitioners
    )
    utility.to_csv(utility_path)

    # Per-partitioner tables: all "metrics_t" files first, then "metrics_one"
    # (the latter written without its index).
    for stem, tables, csv_options in (
        ("metrics_t", tables_t, {}),
        ("metrics_one", tables_one, {"index": None}),
    ):
        for partitioner_name, table in tables.items():
            target = save_path(out, "csv", _name=stem, **kwargs, p=partitioner_name)
            table.to_csv(target, **csv_options)


def merge_table_utility(
    filepaths: List[str],
    keys: List[tuple],
    key_names: List[str],
    ref_post_training: str = "none",
) -> pd.DataFrame:
    """Stack utility CSVs and compare each of them with a reference method.

    Each file is read with its first column as index (the threshold ``t``)
    and stacked under the matching entry of ``keys``, whose components are
    named by ``key_names`` (one of which must be ``"post_training"``).

    The returned frame has four top-level column groups: ``abs`` (raw
    values), ``diff`` (value minus reference), ``diff_norm`` (``diff``
    multiplied by ``t``) and ``rel_diff`` (value divided by reference).
    """
    frames = [pd.read_csv(path, index_col=0) for path in filepaths]
    absolute = pd.concat(frames, keys=keys)
    absolute.index.names = key_names + ["t"]

    # Rows of the reference method, with the "post_training" level removed.
    reference = absolute.xs(ref_post_training, level="post_training")

    delta = absolute - reference
    ratio = absolute / reference
    delta_scaled = delta.mul(delta.index.get_level_values("t"), axis=0)

    groups = {
        "abs": absolute,
        "diff": delta,
        "diff_norm": delta_scaled,
        "rel_diff": ratio,
    }
    return pd.concat(list(groups.values()), keys=list(groups), axis=1)


def concat_df_t(filepaths: List[str], keys: List[tuple], key_names: List[str]):
    """Stack per-threshold CSV tables into a single multi-indexed frame.

    Every file is read with its first column as index. The files are stacked
    under the corresponding tuple of ``keys``; the resulting index levels are
    named ``key_names`` followed by ``"t"``.
    """
    frames = []
    for path in filepaths:
        frames.append(pd.read_csv(path, index_col=0))

    stacked = pd.concat(frames, keys=keys)
    level_names = key_names + ["t"]
    stacked.index = stacked.index.set_names(level_names)
    return stacked


def merge_table_metrics_t(
    filepaths: List[str],
    keys: List[tuple],
    key_names: List[str],
    ref_post_training: str = "none",
) -> pd.DataFrame:
    """Stack per-threshold metric CSVs and compare them with a reference method.

    The files are stacked with ``concat_df_t``. The result has four top-level
    column groups: ``abs`` (raw values), ``diff`` (value minus the
    ``ref_post_training`` rows), ``diff_norm`` (``diff`` multiplied by the
    threshold ``t``) and ``rel_diff`` (value divided by the reference).
    """
    absolute = concat_df_t(filepaths, keys, key_names)

    # Reference rows; selecting them drops the "post_training" level.
    baseline = absolute.xs(ref_post_training, level="post_training")

    difference = absolute - baseline
    thresholds = difference.index.get_level_values("t")
    scaled_difference = difference.mul(thresholds, axis=0)
    ratio = absolute / baseline

    parts = [absolute, difference, scaled_difference, ratio]
    labels = ["abs", "diff", "diff_norm", "rel_diff"]
    return pd.concat(parts, keys=labels, axis=1)


def concat_df_one(filepaths: List[str], keys: List[tuple], key_names: List[str]):
    """Stack threshold-independent metric CSVs into one frame indexed by ``keys``.

    Each file is read without an index column. After stacking, the row-number
    level added by the concatenation is dropped so that the index only
    carries the ``key_names`` levels.
    """
    frames = []
    for path in filepaths:
        frames.append(pd.read_csv(path))
    merged = pd.concat(frames, keys=keys)

    # Some CSV files were saved with their index, which is read back as an
    # unnamed column: discard it when present.
    merged = merged.drop(columns="Unnamed: 0", errors="ignore")

    # The level right after the key levels is the per-file row number.
    row_level = len(key_names)
    merged.index = merged.index.droplevel(row_level)
    merged.index.names = key_names
    return merged


def merge_table_metrics_one(
    filepaths: List[str],
    keys: List[tuple],
    key_names: List[str],
    ref_post_training: str = "none",
) -> pd.DataFrame:
    """Stack threshold-independent metric CSVs and compare them with a reference.

    Total CPU and wall times (fit + predict) are added as new columns. The
    result has three top-level column groups: ``abs`` (all columns), ``diff``
    and ``rel_diff`` (numeric columns only, respectively minus and divided by
    the ``ref_post_training`` rows).
    """
    absolute = concat_df_one(filepaths, keys, key_names)

    # Total post-training time = fit time + predict time, for each clock.
    for clock in ("cpu", "wall"):
        absolute[f"post_training_total_{clock}_time"] = (
            absolute[f"post_training_fit_{clock}_time"]
            + absolute[f"post_training_predict_{clock}_time"]
        )

    baseline = absolute.xs(ref_post_training, level="post_training")

    # Arithmetic comparisons only make sense on numerical columns.
    numeric = absolute.select_dtypes(include=[np.number])
    numeric_baseline = baseline.select_dtypes(include=[np.number])
    difference = numeric - numeric_baseline
    ratio = numeric / numeric_baseline

    return pd.concat(
        [absolute, difference, ratio],
        axis=1,
        keys=["abs", "diff", "rel_diff"],
    )


def merge_tables_utility(
    dirpath,
    ref_post_training: str = "none",
    ref_partitioner: str | None = None,
    merge_both: bool = True,
):
    """Aggregate the utility and metric CSVs found in ``dirpath``.

    Three families of files are recognised from their names:
    ``utility__ds=..:m=..:t=...csv``, ``metrics_t__ds=..:m=..:t=..:p=...csv``
    and ``metrics_one__ds=..:m=..:t=..:p=...csv``.

    Args:
        dirpath: Directory holding the CSV files (``str`` or ``Path``).
        ref_post_training: Post-training method used as reference when
            computing differences.
        ref_partitioner: When given, only this partitioner is kept in the
            metric tables and the ``partitioner`` level is dropped.
        merge_both: When false, the three tables are returned separately.

    Returns:
        ``(df_utility, df_metrics_t, df_metrics_one)`` when ``merge_both`` is
        false, otherwise ``(df_utility, df_metrics)`` where ``df_metrics`` is
        the outer merge of the two metric tables.

    Raises:
        ValueError: If no file matches one of the three name patterns.
    """
    # Callers may pass a plain string; ``glob`` needs a Path.
    dirpath = Path(dirpath)

    # get all csv filenames within dirpath
    print(dirpath)
    csv_files = list(dirpath.glob("*.csv"))
    print(csv_files)

    def _select(pattern):
        """Return the files matching ``pattern`` and their captured keys."""
        matched = [
            (f, tuple(r.groupdict().values()))
            for f in csv_files
            if (r := pattern.match(f.name))
        ]
        if not matched:
            raise ValueError(
                f"No CSV file in {dirpath} matches pattern {pattern.pattern!r}"
            )
        filepaths, keys = zip(*matched)
        return filepaths, list(keys)

    # filter those that match the regex pattern p
    p = re.compile(
        r"utility__ds=(?P<ds_name>.*):m=(?P<model_name>.*):t=(?P<post_training_name>.*)\.csv"
    )
    print(p)
    filepaths, keys = _select(p)

    df_utility = merge_table_utility(
        filepaths, keys, ["ds", "model", "post_training"], ref_post_training
    )

    p = re.compile(
        r"metrics_t__ds=(?P<ds_name>.*):m=(?P<model_name>.*):t=(?P<post_training_name>.*):p=(?P<partitioner_name>.*)\.csv"
    )
    filepaths, keys = _select(p)

    df_metrics_t = merge_table_metrics_t(
        filepaths,
        keys,
        ["ds", "model", "post_training", "partitioner"],
        ref_post_training,
    )

    p = re.compile(
        r"metrics_one__ds=(?P<ds_name>.*):m=(?P<model_name>.*):t=(?P<post_training_name>.*):p=(?P<partitioner_name>.*)\.csv"
    )
    filepaths, keys = _select(p)

    df_metrics_one = merge_table_metrics_one(
        filepaths,
        keys,
        ["ds", "model", "post_training", "partitioner"],
        ref_post_training,
    )

    if ref_partitioner is not None:
        # Filter on the reference partitioner
        df_metrics_t = df_metrics_t.xs(
            ref_partitioner, level="partitioner", drop_level=True
        )
        df_metrics_one = df_metrics_one.xs(
            ref_partitioner, level="partitioner", drop_level=True
        )
        on = ["ds", "model", "post_training"]
    else:
        on = ["ds", "model", "post_training", "partitioner"]

    if not merge_both:
        return df_utility, df_metrics_t, df_metrics_one

    df_metrics = pd.merge(
        df_metrics_t.reset_index(),
        df_metrics_one.reset_index(),
        on=on,
        how="outer",
    )
    df_metrics = df_metrics.set_index(df_metrics_t.index.names)

    # NOTE(review): absolute, environment-specific output path; this fails if
    # the directory does not exist. Kept as-is because other tooling may
    # read this file.
    df_utility.to_csv("/decision_suboptimal_classifiers/aggregated_utility.csv")
    return df_utility, df_metrics


def merge_tables_init(
    dirpath: str | Path, ref_partitioner: str | None = None, merge_both: bool = True
):
    """Aggregate the metric CSVs in ``dirpath`` (files without a ``t=`` field).

    Two families of files are recognised from their names:
    ``metrics_t__ds=..:m=..:p=...csv`` and ``metrics_one__ds=..:m=..:p=...csv``.

    Args:
        dirpath: Directory holding the CSV files (``str`` or ``Path``).
        ref_partitioner: When given, only this partitioner is kept and the
            ``partitioner`` index level is dropped.
        merge_both: When false, the two tables are returned separately.

    Returns:
        ``(df_metrics_t, df_metrics_one)`` when ``merge_both`` is false,
        otherwise their outer merge indexed like ``df_metrics_t``.

    Raises:
        ValueError: If no file matches one of the two name patterns.
    """
    # The annotation historically said ``str`` while ``glob`` needs a Path:
    # accept both.
    dirpath = Path(dirpath)

    # get all csv filenames within dirpath
    csv_files = list(dirpath.glob("*.csv"))

    def _select(pattern):
        """Return the files matching ``pattern`` and their captured keys."""
        matched = [
            (f, tuple(r.groupdict().values()))
            for f in csv_files
            if (r := pattern.match(f.name))
        ]
        if not matched:
            raise ValueError(
                f"No CSV file in {dirpath} matches pattern {pattern.pattern!r}"
            )
        filepaths, keys = zip(*matched)
        return filepaths, list(keys)

    # filter those that match the regex pattern p
    p = re.compile(
        r"metrics_t__ds=(?P<ds_name>.*):m=(?P<model_name>.*):p=(?P<partitioner_name>.*)\.csv"
    )
    filepaths, keys = _select(p)
    df_metrics_t = concat_df_t(filepaths, keys, ["ds", "model", "partitioner"])

    p = re.compile(
        r"metrics_one__ds=(?P<ds_name>.*):m=(?P<model_name>.*):p=(?P<partitioner_name>.*)\.csv"
    )
    filepaths, keys = _select(p)
    df_metrics_one = concat_df_one(filepaths, keys, ["ds", "model", "partitioner"])

    if ref_partitioner is not None:
        df_metrics_t = df_metrics_t.xs(
            ref_partitioner, level="partitioner", drop_level=True
        )
        df_metrics_one = df_metrics_one.xs(
            ref_partitioner, level="partitioner", drop_level=True
        )
        on = ["ds", "model"]
    else:
        on = ["ds", "model", "partitioner"]

    if not merge_both:
        return df_metrics_t, df_metrics_one

    df_metrics = pd.merge(
        df_metrics_t.reset_index(),
        df_metrics_one.reset_index(),
        on=on,
        how="outer",
    )
    df_metrics = df_metrics.set_index(df_metrics_t.index.names)

    return df_metrics


def merge_tables(
    dirpath_utility: str,
    dirpath_init: str,
    ref_post_training: str,
    ref_partitioner: str | None = None,
    t_target: List[float] | None = None,
    collapse_cols_index: bool = True,
    auc_threshold: float | None = None,
):
    """Join the initial metrics with the post-training utilities.

    The tables produced by ``merge_tables_init`` (from ``dirpath_init``) and
    ``merge_tables_utility`` (from ``dirpath_utility``) are outer-merged on
    their index. ``post_training`` is appended as the last index level, and
    several derived ``("abs", ...)`` columns are added.

    Args:
        dirpath_utility: Directory of the utility tables. The merged frame is
            also written there as ``full_dataframe.csv``.
        dirpath_init: Directory of the initial metric tables.
        ref_post_training: Reference post-training method for the differences.
        ref_partitioner: Partitioner to select in the metric tables, if any.
        t_target: When given, only thresholds selected from these targets are
            kept: for each target, the first available threshold that is not
            smaller (or the largest one if the target exceeds them all).
        collapse_cols_index: Flatten two-level columns into ``"a_b"`` strings.
        auc_threshold: When given, keep rows whose ``("abs", "auc")`` is at
            least this value.

    Returns:
        ``(df, df_residual)``: the merged frame and the post-training metric
        table returned by ``merge_tables_utility``.
    """
    df_utility, df_residual = merge_tables_utility(
        dirpath_utility,
        ref_post_training=ref_post_training,
        ref_partitioner=ref_partitioner,
        merge_both=True,
    )
    df_metrics = merge_tables_init(dirpath_init, ref_partitioner, merge_both=True)

    # add outer level to df_metrics columns index
    df_metrics.columns = pd.MultiIndex.from_product([["abs"], df_metrics.columns])

    # in df_utility, reset the "post_training" level from the index and put it as a column
    df_utility = df_utility.reset_index()
    df_utility = df_utility.set_index(["ds", "model", "t"])

    print(df_residual.columns)

    df = pd.merge(
        df_metrics,
        df_utility,
        left_index=True,
        right_index=True,
        how="outer",
    )

    df = df.set_index(("post_training", ""), append=True)
    df.index = df.index.rename("post_training", level=-1)

    if t_target is not None:
        t_unique = np.sort(df.index.get_level_values("t").unique())

        if t_unique.size:
            # First available threshold >= each target. The position is capped
            # so that a target above every threshold maps to the largest one
            # instead of raising an IndexError.
            positions = np.searchsorted(t_unique, t_target)
            positions = np.minimum(positions, t_unique.size - 1)
            t_selected = t_unique[positions]
        else:
            t_selected = t_unique

        # filter the dataframe to only include the selected threshold values in the index level "t"
        print(df)
        print(t_selected)
        print(df.shape)
        print(np.mean(df.index.get_level_values("t").isin(t_selected)))
        df = df[df.index.get_level_values("t").isin(t_selected)]
        print(df.shape)
        print(df)

    if auc_threshold is not None:
        # filter on auc >= auc_threshold
        df = df[df[("abs", "auc")] >= auc_threshold]

    # Work on an independent frame: the filters above return selections of
    # the merged frame, and adding columns to them triggers pandas'
    # chained-assignment warning.
    df = df.copy()

    df[("abs", "dt")] = np.abs(df.index.get_level_values("t") - 0.5)
    df[("abs", "t")] = df.index.get_level_values("t")
    df[("abs", "RGL_mean")] = df[("abs", "RGL_residuals")]
    df[("abs", "RCL_RGL_mean")] = df[("abs", "RCL")] + df[("abs", "RGL_mean")]
    df[("abs", "RCL_RGL_mean_ratio")] = (
        df[("abs", "RGL_mean")] / df[("abs", "RCL_RGL_mean")]
    )
    df[("abs", "event_rate_dist")] = (df[("abs", "E[Y]")] - df[("abs", "t")]).abs()
    df[("abs", "event_rate2")] = df[("abs", "E[C^2]")] / df[("abs", "E[Y]")]
    df[("abs", "event_rate2_dist")] = (
        df[("abs", "event_rate2")] - df[("abs", "t")]
    ).abs()

    if collapse_cols_index:
        # collapse the two levels of the columns by joining them with "_"
        df.columns = ["_".join(col) for col in df.columns]
        df_residual.columns = ["_".join(col) for col in df_residual.columns]

    df.to_csv(Path(dirpath_utility) / "full_dataframe.csv")
    return df, df_residual


# Display name of each post-training method, keyed by its registry name.
# Several variants deliberately share a label (e.g. both histogram binnings).
method_rename = {
    "none": "No post-training",
    "recal_isotonic": "Isotonic",  # long form: "Recalibration (isotonic)"
    "recal_sigmoid": "Platt",  # long form: r"Recalibration ($\sigma$)"
    "recal_hist15": "Histogram",
    "recal_hist10": "Histogram",
    "recal_scalbin15": "Scal.-Bin.",  # long form: "Scaling-Binning"
    "recal_scalbin10": "Scal.-Bin.",
    "recal_metacal_mis05": "Meta-Cal",
    "finetuning_sigmoid": "Fine-tuning",  # alternative: "Fine-tuning (sklearn)"
    "stack_logistic": "Stack Logistic",
    "stack_rf": "Stack RF",
    "stack_hgb": "Stack HGB",
    "stack_naive_bayes": "Stack NaiveBayes",
    "stack_gaussian_process": "Stack GaussianProcess",
    "glar_leaves5": "GLAR",
    "glar_unconstrained": "GLAR unconstrained",
    "glart_both2_leaves5": "GLAR",  # "GLART",
}

# Methods shown in the figures, in plotting order ("glar_unconstrained" is
# currently left out).
methods_selected = [
    "recal_isotonic",
    "recal_sigmoid",
    "recal_hist15",
    "recal_scalbin15",
    "recal_metacal_mis05",
    "finetuning_sigmoid",
    "stack_hgb",
    "stack_rf",
    "glart_both2_leaves5",
    "glar_leaves5",
]

# Endless colour sequence: black first, then the tab10 palette repeated.
_cycler = itertools.chain(["black"], itertools.cycle(plt.cm.tab10.colors))

# One colour per selected method, then the same colours keyed by display name.
method_colors = dict(zip(methods_selected, _cycler))
method_rename_colors = {
    method_rename.get(name, name): color for name, color in method_colors.items()
}


@pytest.mark.parametrize(
    "color_on",
    [
        # "auc",
        # "t",
        # "dt",
        # "CL",
        # "GL",
        # "diff_RGL_RCL",
        "RGL_UB",
        # "RCL",
    ],
)
def test_merge_tables(out, inp, color_on):
    """Scatter the gain difference between two methods against a metric.

    The x axis is the metric selected by ``color_on``. The y axis is the
    difference of normalised utility gain between ``method_y`` and
    ``method_x``. Points are coloured by ``abs_dt``.
    """
    ref_partitioner = "leaves5"
    ref_post_training = "none"
    method_x = "recal_isotonic"
    method_y = "finetuning_sigmoid"
    auc_threshold = 0.75

    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, _df_residual = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
    )

    print(df)

    t = df.index.get_level_values("t")

    # merge_tables flattens the columns by default. Joining them again would
    # insert "_" between every character of each name (for instance
    # "abs_auc" would become "a_b_s___a_u_c"), so only collapse real
    # two-level columns.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = ["_".join(col) for col in df.columns]
    print(df)

    df["abs_dt"] = np.abs(t - 0.5)
    df["abs_t"] = t

    print(df)

    # filter on abs_auc >= auc_threshold
    df = df[df["abs_auc"] >= auc_threshold]

    # count unique values in level "post_training" of the index
    print(df.index.get_level_values("post_training").value_counts())

    df1 = df.xs(method_x, axis=0, level="post_training")
    df2 = df.xs(method_y, axis=0, level="post_training")

    print(df1)
    print(df2)

    # Compare the two methods on their common (ds, model, t) rows only.
    inter = df1.index.intersection(df2.index)
    df1 = df1.loc[inter]
    df2 = df2.loc[inter]

    set_latex_font()
    fig, ax = plt.subplots(figsize=(4, 3))

    cmap = "Reds_r"

    # color_on -> (vmin, vmax, selector of the plotted metric). Selectors are
    # lazy so that only the requested columns need to exist.
    specs = {
        "auc": (0.5, 1, lambda d: d["abs_auc"]),
        "t": (0, 1, lambda d: d["abs_t"]),
        "dt": (0, 0.5, lambda d: d["abs_dt"]),
        "CL": (0, None, lambda d: d["abs_CL"]),
        "GL": (0, None, lambda d: d["abs_GL"]),
        "diff_RGL_RCL": (None, None, lambda d: d["abs_RGL_UB"] - d["abs_RCL"]),
        "RGL_UB": (None, None, lambda d: d["abs_RGL_UB"]),
        "RCL": (None, None, lambda d: d["abs_RCL"]),
    }
    vmin, vmax, select = specs.get(
        color_on, (None, None, lambda d: d["abs_RGL_UB"])
    )
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    c = select(df1)

    title_dict = {
        "auc": "AUC",
        "dt": r"$|t^{\star} - 0.5|$",
        "t": r"$t^{\star}$",
        "diff_RGL_RCL": r"$R_{GL}^{UB} - R_{CL}$",
        "RGL_UB": r"$R_{GL}^{UB}$",
        "RCL": r"$R_{CL}$",
    }

    # NOTE(review): the points are coloured by abs_dt while `norm` is set up
    # for the `color_on` metric. This only matches when the norm is left
    # unbounded (as for "RGL_UB"); confirm the intent for the other choices.
    ax.scatter(
        c,
        df2["diff_norm_u_test_emp_mean"] - df1["diff_norm_u_test_emp_mean"],
        c=df1["abs_dt"],
        cmap=cmap,
        norm=norm,
        edgecolors="white",
        linewidths=0.5,
        s=30,
    )

    ax.set(
        xlabel=title_dict.get(color_on, color_on),
        ylabel=f"{method_rename.get(method_y, method_y)} - {method_rename.get(method_x, method_x)}",
    )
    plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)

    ax.set_title(
        f"Reference utility: {method_rename.get(ref_post_training, ref_post_training)}"
    )

    save_fig(fig, out, c=color_on)


def test_plot_gain_vs_regret(out, inp):
    """Plot the normalised utility gain of two methods against several metrics.

    Five figures are saved: the isotonic recalibration gain against ``abs_RCL``,
    then the fine-tuning gain against ``abs_RCL + abs_RGL_UB`` (coloured by
    AUC), ``abs_RCL``, ``abs_auc`` and ``abs_acc``.
    """
    ref_partitioner = "leaves5"
    ref_post_training = "none"
    auc_threshold = 0.5
    t_target = [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    _df, _df_residual = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    gain = "diff_norm_u_test_emp_mean"

    def add_identity_line(ax):
        """Draw the y = x line behind the points without changing the limits."""
        with KeepLim(ax):
            ax.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.5, zorder=-1)

    # --- Recalibration gain vs calibration regret
    m = "recal_isotonic"
    df = _df.xs(m, axis=0, level="post_training")

    set_latex_font()

    fig, ax = plt.subplots(figsize=(3, 2))
    ax.scatter(df["abs_RCL"], df[gain], edgecolors="white", linewidths=0.5)
    ax.set(xlabel=r"$R_{CL}$", ylabel=f"{method_rename.get(m, m)} gain")
    add_identity_line(ax)
    save_fig(fig, out, "recal_rcl")

    # --- Fine-tuning gain vs total regret, coloured by AUC
    m = "finetuning_sigmoid"
    df = _df.xs(m, axis=0, level="post_training")
    fig, ax = plt.subplots(figsize=(3, 2))
    cmap = "Reds_r"
    norm = plt.Normalize()
    ax.scatter(
        df["abs_RCL"] + df["abs_RGL_UB"],
        df[gain],
        edgecolors="white",
        c=df["abs_auc"],
        cmap=cmap,
        norm=norm,
        linewidths=0.1,
    )
    cbar = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
    # Raw string: "\m" is an invalid escape sequence in a regular literal.
    cbar.ax.set_title(r"$\mathrm{AUC}$")
    ax.set(xlabel=r"$R_{CL} + R_{GL}^{UB}$", ylabel=f"{method_rename.get(m, m)} gain")
    add_identity_line(ax)
    save_fig(fig, out, "finetuning_rcl")

    # --- Control: fine-tuning gain vs calibration regret only
    fig, ax = plt.subplots(figsize=(3, 2))
    ax.scatter(df["abs_RCL"], df[gain], edgecolors="white", linewidths=0.5)
    ax.set(xlabel=r"$R_{CL}$", ylabel=f"{method_rename.get(m, m)} gain")
    add_identity_line(ax)
    save_fig(fig, out, "finetuning_rcl_control")

    # --- Fine-tuning gain vs AUC
    fig, ax = plt.subplots(figsize=(3, 2))
    ax.scatter(df["abs_auc"], df[gain], edgecolors="white", linewidths=0.1)
    ax.set(xlabel=r"$\mathrm{AUC}$", ylabel=f"{method_rename.get(m, m)} gain")
    ax.set_xlim(None, 1)
    save_fig(fig, out, "finetuning_auc")

    # --- Fine-tuning gain vs accuracy
    fig, ax = plt.subplots(figsize=(3, 2))
    ax.scatter(df["abs_acc"], df[gain], edgecolors="white", linewidths=0.1)
    ax.set(xlabel=r"$\mathrm{Acc}$", ylabel=f"{method_rename.get(m, m)} gain")
    ax.set_xlim(None, 1)
    save_fig(fig, out, "finetuning_acc")


# @pytest.mark.parametrize(
#     "post_training",
#     [
#         "recal_isotonic",
#         "recal_sigmoid",
#         "stack_logistic",
#         "stack_rf",
#         "stack_hgb",
#         "stack_naive_bayes",
#         "stack_gaussian_process",
#         "finetuning_sigmoid",
#         "glar_leaves5",
#         "glar_unconstrained",
#     ],
# )
# @pytest.mark.parametrize(
#     "against",
#     [
#         ("abs_RCL", r"$R_{CL}$"),
#     ],
# )
@pytest.mark.parametrize(
    "ref_post_training",
    [
        "none",
        # "recal_isotonic",
    ],
)
def test_plot_gain_vs_post_training(out, inp, ref_post_training):
    """Scatter the gain of post-training methods against selected metrics.

    One figure is saved per (post-training method, metric) pair. The gain is
    the ``diff_norm_u_test_emp_mean`` column, computed relative to
    ``ref_post_training``.
    """
    ref_partitioner = "leaves5"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    tables_root = Path(inp).parent
    merged, _residual = merge_tables(
        tables_root / "utility_table",
        tables_root / "regret_table",
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    def plot_one(m, x):
        """Save the scatter of method ``m``'s gain against column ``x``."""
        per_method = merged.xs(m, axis=0, level="post_training")

        set_latex_font()
        fig, ax = plt.subplots(figsize=(3, 2))

        # "abs_RCL_RGL" is not a stored column: it is built on the fly.
        if x != "abs_RCL_RGL":
            abscissa = per_method[x]
        else:
            abscissa = per_method["abs_RCL"] + per_method["abs_RGL_UB"]

        ax.scatter(
            abscissa,
            per_method["diff_norm_u_test_emp_mean"],
            edgecolors="white",
            linewidths=0.5,
            color=method_colors.get(m, "tab:blue"),
        )
        ax.set(
            xlabel=metrics_rename.get(x, x),
            ylabel=f"Gain of {method_rename.get(m, m)}",
        )
        ax.set_title(
            f"Reference utility: {method_rename.get(ref_post_training, ref_post_training)}"
        )

        # Regret-like abscissae get the y = x line as a visual reference.
        if "abs_R" in x:
            with KeepLim(ax):
                ax.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.5, zorder=-1)

        save_fig(fig, out, r=ref_post_training, y=m, x=x, pad_inches=0.01)

    # Methods to plot (uncomment to enable).
    post_trainings = [
        "recal_isotonic",
        # "recal_sigmoid",
        # "stack_logistic",
        # "stack_rf",
        # "stack_hgb",
        # "stack_naive_bayes",
        # "stack_gaussian_process",
        # "finetuning_sigmoid",
        # "glar_leaves5",
        # "glar_unconstrained",
    ]
    # Metrics used as abscissa (uncomment to enable).
    against = [
        "abs_auc",
        # "abs_acc",
        # "abs_t",
        # "abs_dt",
        # "abs_CL",
        # "abs_GL",
        # "abs_RGL_UB",
        # "abs_RCL",
        # "abs_RCL_RGL",
    ]

    for post_training in post_trainings:
        for x in against:
            plot_one(post_training, x)


# LaTeX fragment (without the surrounding "$") shared by every label that
# involves the calibration regret.
_RCL_TEX = r"\smash{\hat{R}}_{f\!,t^{\star}}^{\mathrm{CL}}"

# Axis label of each (flattened) metric column.
metrics_rename = {
    # Discrimination and threshold
    "abs_auc": r"$\mathrm{AUC}$",
    "abs_acc": r"$\mathrm{Acc}$",
    "abs_t": r"$t^{\star}$",
    "abs_dt": r"$|t^{\star} - 0.5|$",
    # Calibration and grouping metrics
    "abs_CL": r"$\mathrm{CL}$",
    "abs_ECE": r"$\mathrm{ECE}$",
    "abs_MCE": r"$\mathrm{MCE}$",
    "abs_RMSCE": r"$\mathrm{RMSCE}$",
    "abs_GL": r"$\mathrm{GL}$",
    # Regrets
    "abs_RCL": "$" + _RCL_TEX + "$",
    "abs_RCL_RGL": "$" + _RCL_TEX + r" + R_{GL}^{UB}$",
    "abs_L": "Brier",
    "abs_RGL_UB": r"$R_{\mathrm{GL}}^{\mathrm{UB}}$",
    "abs_RGL_LB": r"$R_{\mathrm{GL}}^{\mathrm{LB}}$",
    "abs_RGL_mean": r"$\hat{R}^{\mathrm{GL}}_{f}$",
    "abs_RCL_RGL_mean": "$" + _RCL_TEX + r" + \smash{\hat{R}}^{\mathrm{GL}}_{f}$",
    "abs_RCL_RGL_LB": "$" + _RCL_TEX + r" + L_h^{\mathrm{GL}}$",
    "abs_RCL_RGL_UB": "$" + _RCL_TEX + r" + U_h^{\mathrm{GL}}$",
    # Event rate
    "abs_event_rate_dist": r"$|\mathbb{E}[Y] - t^{\star}|$",
}


def test_fig_gain_recal_vs_cal_metrics(out, inp):
    """FIG1: recalibration gain of isotonic regression against a metric.

    For every selected metric column, the normalised utility gain
    (``diff_norm_u_test_emp_mean``) is scattered against the metric and
    annotated with the squared Spearman correlation. One figure is saved per
    metric.
    """
    ref_partitioner = "leaves5"  # doesn't affect the result
    ref_post_training = "none"
    post_training = "recal_isotonic"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    tables_root = Path(inp).parent
    merged, _residual = merge_tables(
        tables_root / "utility_table",
        tables_root / "regret_table",
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    # Select a post-training method
    df = merged.xs(post_training, axis=0, level="post_training")
    print(df)

    def plot_one(m, x):
        """Save the figure of the gain against column ``x``, styled for ``m``."""
        gain_col = "diff_norm_u_test_emp_mean"
        x_label = f"a. {metrics_rename.get(x, x)}"

        set_latex_font()
        fig, ax = plt.subplots(figsize=(2.2, 1.5))

        # Threshold-like metrics are averaged per value of t before plotting.
        if x not in ["abs_t", "abs_dt"]:
            data = df
        else:
            data = df[[x, gain_col]].groupby("t").mean()
        xs = data[x]
        ys = data[gain_col]

        ax.scatter(
            xs,
            ys,
            edgecolors="white",
            linewidths=0.5,
            color=method_colors.get(m, "tab:blue"),
            label=f"{method_rename.get(m, m)}",
        )

        print(xs.isna().sum(), ys.isna().sum())

        rank_corr = scipy.stats.spearmanr(xs, ys)

        if "abs_RCL" in x:
            # Inset zooming on the region close to the origin.
            zoom_lo, zoom_hi = -0.005, 0.05
            axins = ax.inset_axes(
                [0.55, 0.05, 0.4, 0.4],
                xlim=(zoom_lo, zoom_hi),
                ylim=(zoom_lo, zoom_hi),
                xticklabels=[],
                yticklabels=[],
            )
            axins.scatter(
                xs,
                data["diff_norm_u_test_emp_mean"],
                edgecolors="white",
                linewidths=0.5,
                color=method_colors.get(m, "tab:blue"),
            )
            axins.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.75, zorder=1)
            ax.indicate_inset_zoom(axins, edgecolor="black")

            with KeepLim(ax):
                ax.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.5, zorder=1)

            # Give both axes the same range (the union of the current ones).
            x_lo, x_hi = ax.get_xlim()
            y_lo, y_hi = ax.get_ylim()
            lower = min(x_lo, y_lo)
            upper = max(x_hi, y_hi)
            ax.set_xlim(lower, upper)
            ax.set_ylim(lower, upper)

        ax.set(xlabel=x_label, ylabel="Recalibration gain")

        # Correlation label: top-left corner, except for AUC (top-right).
        if "abs_auc" in x:
            ha, xpos = "right", 0.96
        else:
            ha, xpos = "left", 0.04
        ax.text(
            xpos,
            0.96,
            f"Correlation: $r^2 = {np.square(rank_corr.correlation):.2f}$",
            ha=ha,
            va="top",
            transform=ax.transAxes,
        )

        # Scale annotation next to the y axis (first) and the x axis (second).
        for anchor, text_pos, h_align, v_align in (
            ((0, 1), (-0.02, 1.0), "right", "center"),
            ((1, 0), (0.97, -0.03), "left", "top"),
        ):
            ax.annotate(
                r"$\times\!U_{\!\!\Delta}$",
                xy=anchor,
                xycoords="axes fraction",
                xytext=text_pos,
                textcoords="axes fraction",
                ha=h_align,
                va=v_align,
            )

        save_fig(fig, out, r=ref_post_training, y=post_training, x=x, pad_inches=0.01)

    # Metrics to plot (uncomment to enable).
    selected_metrics = [
        # "abs_auc",
        # "abs_acc",
        # "abs_ECE",
        # "abs_MCE",
        # "abs_RMSCE",
        # "abs_CL",
        "abs_RCL",
        # "abs_L",
        # "abs_event_rate_dist",
    ]
    for x in selected_metrics:
        plot_one(post_training, x)


def test_fig_gain_recal_vs_cal_metrics2(out, inp):
    """FIG1: scatter the recalibration gain against a calibration metric.

    Loads the merged utility/regret tables, keeps the rows of the
    ``recal_isotonic`` post-training method and, for each selected metric
    column, scatters that metric against the ``diff_norm_u_test_emp_mean``
    column (shown on the y axis as "Recalibration gain"). The squared
    Spearman correlation is written on the axes and the figure is handed to
    ``save_fig``.

    Args:
        out: Output location, forwarded untouched to ``save_fig``
            (presumably a pytest fixture -- confirm in conftest).
        inp: Input path; tables are read from ``utility_table`` and
            ``regret_table`` under its parent directory.
    """
    gain_col = "diff_norm_u_test_emp_mean"
    ref_partitioner = "leaves5"  # doesn't affect the result
    ref_post_training = "none"
    post_training = "recal_isotonic"

    tables_root = Path(inp).parent
    merged, _ = merge_tables(
        tables_root / "utility_table",
        tables_root / "regret_table",
        ref_post_training,
        ref_partitioner,
        t_target=[0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975],
        auc_threshold=0.75,
    )

    # Restrict the table to the post-training method under study.
    df = merged.xs(post_training, axis=0, level="post_training")
    print(df)

    def annotate_unit(ax, anchor, text_pos, ha, va):
        """Place the ``x U_Delta`` annotation at ``text_pos`` (axes fraction)."""
        ax.annotate(
            r"$\times\!U_{\!\!\Delta}$",
            xy=anchor,
            xycoords="axes fraction",
            xytext=text_pos,
            textcoords="axes fraction",
            ha=ha,
            va=va,
        )

    def plot_one(m, x):
        """Draw and save the scatter of metric column ``x`` for method ``m``."""
        set_latex_font()
        fig, ax = plt.subplots(figsize=(1.7, 1.5))

        # For the threshold columns, average metric and gain per value of "t".
        data = (
            df[[x, gain_col]].groupby("t").mean()
            if x in ("abs_t", "abs_dt")
            else df
        )
        xs, ys = data[x], data[gain_col]
        dot_color = method_colors.get(m, "tab:blue")

        ax.scatter(
            xs,
            ys,
            edgecolors="white",
            linewidths=0.3,
            color=dot_color,
            label=f"{method_rename.get(m, m)}",
        )

        print(xs.isna().sum(), ys.isna().sum())

        rho = scipy.stats.spearmanr(xs, ys)

        if "abs_RCL" in x:
            # Zoomed inset on the region close to the origin.
            zoom_lo, zoom_hi = -0.005, 0.05
            axins = ax.inset_axes(
                [0.55, 0.05, 0.4, 0.4],
                xlim=(zoom_lo, zoom_hi),
                ylim=(zoom_lo, zoom_hi),
                xticklabels=[],
                yticklabels=[],
            )
            axins.scatter(
                xs,
                ys,
                edgecolors="white",
                linewidths=0.3,
                color=dot_color,
            )
            axins.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.75, zorder=1)
            ax.indicate_inset_zoom(axins, edgecolor="black")

            # Identity line on the main axes, drawn without moving its limits.
            with KeepLim(ax):
                ax.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.5, zorder=1)

            # Give both axes the same range (union of the current limits).
            (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
            lower, upper = min(x_lo, y_lo), max(x_hi, y_hi)
            ax.set_xlim(lower, upper)
            ax.set_ylim(lower, upper)

        ax.set(
            xlabel=f"a. {metrics_rename.get(x, x)}",
            ylabel="Recalibration gain",
        )

        # Correlation text sits top-left, except for AUC plots (top-right).
        on_right = "abs_auc" in x
        ax.text(
            0.96 if on_right else 0.04,
            0.96,
            rf"$\begin{{array}}{{l}}\text{{Correlation:}}\\[-1mm] r^2 = {np.square(rho.correlation):.2f}\end{{array}}$",
            ha="right" if on_right else "left",
            va="top",
            transform=ax.transAxes,
        )

        annotate_unit(ax, (0, 1), (-0.02, 1.0), "right", "center")
        annotate_unit(ax, (1, 0), (1, -0.15), "right", "top")

        save_fig(fig, out, r=ref_post_training, y=post_training, x=x, pad_inches=0.01)

    # Metric columns to plot; other candidates seen in this file include
    # abs_auc, abs_acc, abs_ECE, abs_MCE, abs_RMSCE, abs_CL, abs_L.
    for metric in ("abs_RCL",):
        plot_one(post_training, metric)


def test_fig_correlation(out, inp):
    """FIG1b: squared Spearman correlation of each metric with the gain.

    For every recalibration method in ``methods`` and every metric column in
    ``metrics``, computes the squared Spearman correlation between the metric
    and the ``diff_norm_u_test_emp_mean`` column, then draws one dot per
    (metric, method) pair with metrics on the y axis. Rows are ordered by
    ascending mean correlation across methods. The figure is saved through
    ``save_fig`` under the name ``recal_correlation``.

    Args:
        out: Output location, forwarded untouched to ``save_fig``
            (presumably a pytest fixture -- confirm in conftest).
        inp: Input path; tables are read from ``utility_table`` and
            ``regret_table`` under its parent directory.
    """
    ref_partitioner = "leaves5"  # doesn't affect the result
    ref_post_training = "none"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]
    gain_col = "diff_norm_u_test_emp_mean"

    methods = [
        "recal_isotonic",
        "recal_sigmoid",
        "recal_hist15",
        "recal_scalbin15",
        "recal_metacal_mis05",
    ]
    metrics = [
        "abs_auc",
        "abs_acc",
        "abs_ECE",
        "abs_MCE",
        "abs_RMSCE",
        "abs_CL",
        "abs_RCL",
        "abs_L",
    ]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    def squared_spearman(method_df, x):
        """Return the squared Spearman correlation of column ``x`` with the gain."""
        r = scipy.stats.spearmanr(method_df[x], method_df[gain_col])
        return np.square(r.correlation)

    # One row per method, one column per metric. The per-method slice is
    # taken once instead of once per metric.
    rows = []
    for m in methods:
        method_df = df.xs(m, axis=0, level="post_training")
        rows.append({x: squared_spearman(method_df, x) for x in metrics})

    corrs = pd.DataFrame(rows, index=methods)
    corrs.index.name = "post_training"

    # Metrics as rows, methods as columns.
    corrs = corrs.T

    # Sort the rows by ascending average value across methods.
    corrs = corrs.reindex(corrs.mean(axis=1).sort_values().index)

    print(corrs)

    corrs.index = [metrics_rename.get(k, k) for k in corrs.index]

    # Long format: one row per (metric label, method) with its r^2.
    corrs = corrs.reset_index().melt(
        id_vars="index", var_name="post_training", value_name="r2"
    )

    print(corrs)

    set_latex_font()
    fig, ax = plt.subplots(figsize=(2.7, 1.5))

    # One scatter series per method; sort=False keeps the order of `methods`.
    for m, group in corrs.groupby("post_training", sort=False):
        ax.scatter(
            group["r2"],
            group["index"],
            edgecolors="white",
            linewidths=0.5,
            label=method_rename.get(m, m),
            color=method_colors.get(m, "tab:blue"),
        )

    n_metrics = corrs["index"].nunique()
    with KeepLim(ax):
        # Shade every other metric row, starting with the bottom one. The
        # range stops at n_metrics so the top row is also shaded when the
        # number of metrics is odd.
        for i in range(0, n_metrics, 2):
            # facecolor (not color): Matplotlib patches let `color` override
            # `edgecolor` with a warning, so edgecolor="none" was ignored.
            ax.axhspan(
                i - 0.5,
                i + 0.5,
                facecolor=".93",
                edgecolor="none",
                lw=0,
                zorder=-1,
            )

    ax.set_ylim(-0.5, n_metrics - 0.5)

    ax.set(xlabel=r"b. Correlation $r^2$ to recalibration gain")
    ax.legend(loc="lower right", title="Recalibration")

    # Disable minor y ticks (the y axis is categorical).
    ax.yaxis.set_minor_locator(plt.NullLocator())

    save_fig(fig, out, "recal_correlation", pad_inches=0.01)


def test_fig_correlation2(out, inp):
    """FIG1b (compact variant): correlation of each metric with the gain.

    For every recalibration method in ``methods`` and every metric column in
    ``metrics``, computes the squared Spearman correlation between the metric
    and the ``diff_norm_u_test_emp_mean`` column, then draws one dot per
    (metric, method) pair on a 1.8 x 1.5 figure with a framed, small-font
    legend and a two-line x label. The figure is saved through ``save_fig``
    under the name ``recal_correlation``.

    NOTE(review): ``test_fig_correlation`` passes the same figure name to
    ``save_fig``; whether the two outputs collide depends on ``save_fig`` --
    confirm.

    Args:
        out: Output location, forwarded untouched to ``save_fig``
            (presumably a pytest fixture -- confirm in conftest).
        inp: Input path; tables are read from ``utility_table`` and
            ``regret_table`` under its parent directory.
    """
    ref_partitioner = "leaves5"  # doesn't affect the result
    ref_post_training = "none"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]
    gain_col = "diff_norm_u_test_emp_mean"

    methods = [
        "recal_isotonic",
        "recal_sigmoid",
        "recal_hist15",
        "recal_scalbin15",
        "recal_metacal_mis05",
    ]
    metrics = [
        "abs_auc",
        "abs_acc",
        "abs_ECE",
        "abs_MCE",
        "abs_RMSCE",
        "abs_CL",
        "abs_RCL",
        "abs_L",
    ]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    def squared_spearman(method_df, x):
        """Return the squared Spearman correlation of column ``x`` with the gain."""
        r = scipy.stats.spearmanr(method_df[x], method_df[gain_col])
        return np.square(r.correlation)

    # One row per method, one column per metric. The per-method slice is
    # taken once instead of once per metric.
    rows = []
    for m in methods:
        method_df = df.xs(m, axis=0, level="post_training")
        rows.append({x: squared_spearman(method_df, x) for x in metrics})

    corrs = pd.DataFrame(rows, index=methods)
    corrs.index.name = "post_training"

    # Metrics as rows, methods as columns.
    corrs = corrs.T

    # Sort the rows by ascending average value across methods.
    corrs = corrs.reindex(corrs.mean(axis=1).sort_values().index)

    print(corrs)

    corrs.index = [metrics_rename.get(k, k) for k in corrs.index]

    # Long format: one row per (metric label, method) with its r^2.
    corrs = corrs.reset_index().melt(
        id_vars="index", var_name="post_training", value_name="r2"
    )

    print(corrs)

    set_latex_font()
    fig, ax = plt.subplots(figsize=(1.8, 1.5))

    # One scatter series per method; sort=False keeps the order of `methods`.
    for m, group in corrs.groupby("post_training", sort=False):
        ax.scatter(
            group["r2"],
            group["index"],
            edgecolors="white",
            linewidths=0.5,
            label=method_rename.get(m, m),
            color=method_colors.get(m, "tab:blue"),
        )

    n_metrics = corrs["index"].nunique()
    with KeepLim(ax):
        # Shade every other metric row, starting with the bottom one. The
        # range stops at n_metrics so the top row is also shaded when the
        # number of metrics is odd.
        for i in range(0, n_metrics, 2):
            # facecolor (not color): Matplotlib patches let `color` override
            # `edgecolor` with a warning, so edgecolor="none" was ignored.
            ax.axhspan(
                i - 0.5,
                i + 0.5,
                facecolor=".93",
                edgecolor="none",
                lw=0,
                zorder=-1,
            )

    ax.set_ylim(-0.5, n_metrics - 0.5)

    ax.set(
        xlabel=r"$\begin{array}{c}\text{b. Correlation } r^2 \\[-1mm] \text{to recalibration gain}\end{array}$"
    )

    # Legend sizing is passed as keyword arguments instead of plt.rc(...):
    # the result for this legend is the same, but the global rcParams are
    # left untouched so the settings do not leak into later figures.
    legend = ax.legend(
        loc="lower right",
        title="Recalibration",
        frameon=True,
        fontsize=7,
        title_fontsize=7,
        handletextpad=0.1,
    )
    frame = legend.get_frame()
    frame.set_facecolor("none")  # Transparent background
    frame.set_edgecolor("black")  # Black border
    frame.set_alpha(0.3)  # Frame transparency
    frame.set_linewidth(0.5)  # Thin border line

    # Move the x label up/down relative to the axes (axes coordinates).
    ax.xaxis.set_label_coords(0.5, -0.15)

    # Disable minor y ticks (the y axis is categorical).
    ax.yaxis.set_minor_locator(plt.NullLocator())

    save_fig(fig, out, "recal_correlation", pad_inches=0.01)


@pytest.mark.parametrize(
    "post_training",
    [
        "finetuning_sigmoid",
        # "stack_hgb",
        # "stack_rf",
        # "glar_leaves5",
    ],
)
def test_fig_gain_post_training_vs_metrics(out, inp, post_training):
    """FIG3: scatter the post-training gain against a regret metric.

    Loads the merged utility/regret tables, keeps the rows of
    ``post_training`` and, for each selected metric column, scatters that
    metric against the ``diff_norm_u_test_emp_mean`` column. Points are
    colored by the ``abs_RCL_RGL_mean_ratio`` column through a truncated
    "Reds" colormap with a colorbar. The squared Spearman correlation is
    written on the axes and the figure is handed to ``save_fig``.

    Args:
        out: Output location, forwarded untouched to ``save_fig``
            (presumably a pytest fixture -- confirm in conftest).
        inp: Input path; tables are read from ``utility_table`` and
            ``regret_table`` under its parent directory.
        post_training: Name of the post-training method to select in the
            ``post_training`` index level.
    """
    ref_partitioner = "leaves5"  # doesn't affect the result
    ref_post_training = "none"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    # Select a post-training method
    df = df.xs(post_training, axis=0, level="post_training")

    print(df)
    print(df.columns)

    print(
        df[["abs_u_test_emp_mean", "diff_u_test_emp_mean", "diff_norm_u_test_emp_mean"]]
    )

    alpha = 0.9

    def plot_one(m, x):
        """Draw and save the scatter of metric column ``x`` for method ``m``."""
        x_label = f"a. {metrics_rename.get(x, x)}"
        y = "diff_norm_u_test_emp_mean"
        ratio_col = "abs_RCL_RGL_mean_ratio"

        set_latex_font()
        fig, ax = plt.subplots(figsize=(2.3, 1.5))

        if x in ["abs_t", "abs_dt"]:
            # group _df by t and average the values of x, y and the ratio
            _df = df[[x, y, ratio_col]].groupby("t").mean()
        else:
            _df = df

        v = _df[ratio_col]
        _df_x = _df[x]
        _df_y = _df[y]

        # plt.get_cmap replaces plt.cm.get_cmap, which was removed from
        # matplotlib.cm in Matplotlib 3.9.
        original_cmap = plt.get_cmap("Reds")

        # Keep the 10%-100% part of the colormap, dropping the palest shades.
        new_cmap_range = np.linspace(0.1, 1.0, 256)
        new_colors = original_cmap(new_cmap_range)
        new_cmap = colors.LinearSegmentedColormap.from_list(
            "truncated_reds", new_colors
        )

        # Map ratio values in [0, 1] to colors. Named `mappable` rather than
        # `sm` so the module-level statsmodels alias is not shadowed.
        norm = plt.Normalize(0, 1)
        mappable = plt.cm.ScalarMappable(norm=norm, cmap=new_cmap)

        c = mappable.to_rgba(v)
        cax = add_colorbar(mappable, ax)

        # Colorbar label: fraction built from two metric labels with their
        # first and last characters stripped (presumably the "$" delimiters).
        cax.set_ylabel(
            rf"$\frac{{{metrics_rename['abs_RGL_mean'][1:-1]}}}{{{metrics_rename['abs_RCL_RGL_mean'][1:-1]}}}$"
        )

        ax.scatter(
            _df_x,
            _df_y,
            edgecolors="white",
            linewidths=0.5,
            color=c,
            alpha=alpha,
        )

        print(_df_x.isna().sum(), _df_y.isna().sum())

        r = scipy.stats.spearmanr(_df_x, _df_y)

        if "abs_RCL" in x:
            # Zoomed inset on the region close to the origin.
            x1, x2, y1, y2 = (
                -0.005,
                0.05,
                -0.005,
                0.05,
            )  # subregion of the original image
            axins = ax.inset_axes(
                [0.55, 0.05, 0.4, 0.4],
                xlim=(x1, x2),
                ylim=(y1, y2),
                xticklabels=[],
                yticklabels=[],
            )
            axins.scatter(
                _df_x,
                _df_y,
                edgecolors="white",
                linewidths=0.5,
                color=c,
                alpha=alpha,
            )
            axins.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.75, zorder=1)
            ax.indicate_inset_zoom(axins, edgecolor="black")

            # Identity line on the main axes, drawn without moving its limits.
            with KeepLim(ax):
                ax.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.5, zorder=1)

            # Give both axes the same range (union of the current limits).
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            xmin = min(xlim[0], ylim[0])
            xmax = max(xlim[1], ylim[1])
            ax.set_xlim(xmin, xmax)
            ax.set_ylim(xmin, xmax)

        ax.set(xlabel=x_label, ylabel=f"{method_rename.get(m, m)} gain")

        # Correlation text sits top-left, except for AUC plots (top-right).
        ha = "left"
        va = "top"
        xpos = 0.04
        ypos = 0.96
        if "abs_auc" in x:
            ha = "right"
            xpos = 0.96

        ax.text(
            xpos,
            ypos,
            f"Correlation: $r^2 = {np.square(r.correlation):.2f}$",
            ha=ha,
            va=va,
            transform=ax.transAxes,
        )

        # "x U_Delta" annotations next to the top of the y axis and the
        # right end of the x axis.
        ax.annotate(
            r"$\times\!U_{\!\!\Delta}$",
            xy=(0, 1),
            xycoords="axes fraction",
            xytext=(-0.02, 1.0),
            textcoords="axes fraction",
            ha="right",
            va="center",
        )
        ax.annotate(
            r"$\times\!U_{\!\!\Delta}$",
            xy=(1, 0),
            xycoords="axes fraction",
            xytext=(0.97, -0.03),
            textcoords="axes fraction",
            ha="left",
            va="top",
        )

        save_fig(fig, out, r=ref_post_training, y=post_training, x=x, pad_inches=0.01)

    for x in [
        # "abs_auc",
        # "abs_acc",
        # "abs_ECE",
        # "abs_MCE",
        # "abs_RMSCE",
        # "abs_CL",
        # "abs_RCL",
        # "abs_L",
        # "abs_RGL_UB",
        # "abs_RGL_LB",
        # "abs_RGL_mean",
        # "abs_RCL_RGL_UB",
        # "abs_RCL_RGL_LB",
        "abs_RCL_RGL_mean",
        # "abs_event_rate_dist",
    ]:
        plot_one(post_training, x)


@pytest.mark.parametrize(
    "post_training",
    [
        "finetuning_sigmoid",
        # "stack_hgb",
        # "stack_rf",
        # "glar_leaves5",
    ],
)
def test_fig_gain_post_training_vs_metrics2(out, inp, post_training):
    """FIG3 (negated variant): post-training gain against a regret metric.

    Loads the merged utility/regret tables, keeps the rows of
    ``post_training`` and, for each selected metric column, scatters the
    negated metric against the negated ``diff_norm_u_test_emp_mean`` column.
    Points are colored by the ``abs_RCL_RGL_mean_ratio`` column through a
    truncated "Reds" colormap with a colorbar. The squared Spearman
    correlation is written in a boxed label and the figure is handed to
    ``save_fig``.

    NOTE(review): unlike the sibling figure tests, the tables are read from
    directly under ``inp`` rather than from its parent directory, and the
    reference partitioner is ``depth10`` -- confirm both are intended.

    Args:
        out: Output location, forwarded untouched to ``save_fig``
            (presumably a pytest fixture -- confirm in conftest).
        inp: Input directory containing ``utility_table`` and
            ``regret_table``.
        post_training: Name of the post-training method to select in the
            ``post_training`` index level.
    """
    ref_partitioner = "depth10"
    ref_post_training = "none"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp) / "utility_table"
    dirpath_init = Path(inp) / "regret_table"

    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    # Select a post-training method
    df = df.xs(post_training, axis=0, level="post_training")

    print(df)
    print(df.columns)

    print(
        df[["abs_u_test_emp_mean", "diff_u_test_emp_mean", "diff_norm_u_test_emp_mean"]]
    )

    alpha = 0.9

    def plot_one(m, x):
        """Draw and save the negated scatter of metric column ``x`` for ``m``."""
        # Raw string: "\w" and "\m" are invalid escape sequences in a normal
        # string literal (SyntaxWarning on recent Python versions). The
        # resulting text is unchanged.
        x_label = r"$\widehat{\mathcal{R}}^{CL}_{f,t^*}+\widehat{\mathcal{R}}^{GL}_f$"
        y = "diff_norm_u_test_emp_mean"
        ratio_col = "abs_RCL_RGL_mean_ratio"

        set_latex_font()
        fig, ax = plt.subplots(figsize=(2.3, 1.5))

        if x in ["abs_t", "abs_dt"]:
            # group _df by t and average the values of x, y and the ratio
            _df = df[[x, y, ratio_col]].groupby("t").mean()
        else:
            _df = df

        v = _df[ratio_col]

        # Both axes show the negated columns.
        _df_x = -_df[x]
        _df_y = -_df[y]

        # plt.get_cmap replaces plt.cm.get_cmap, which was removed from
        # matplotlib.cm in Matplotlib 3.9.
        original_cmap = plt.get_cmap("Reds")

        # Keep the 10%-100% part of the colormap, dropping the palest shades.
        new_cmap_range = np.linspace(0.1, 1.0, 256)
        new_colors = original_cmap(new_cmap_range)
        new_cmap = colors.LinearSegmentedColormap.from_list(
            "truncated_reds", new_colors
        )

        # Map ratio values in [0, 1] to colors. Named `mappable` rather than
        # `sm` so the module-level statsmodels alias is not shadowed.
        norm = plt.Normalize(0, 1)
        mappable = plt.cm.ScalarMappable(norm=norm, cmap=new_cmap)

        c = mappable.to_rgba(v)
        cax = add_colorbar(mappable, ax)

        cax.set_ylabel(
            r"$\frac{\widehat{\mathcal{R}}^{GL}_{f}}{\widehat{\mathcal{R}}^{CL}_{f,t^*}+\widehat{\mathcal{R}}^{GL}_{f}}$"
        )

        ax.scatter(
            _df_x,
            _df_y,
            edgecolors="white",
            linewidths=0.5,
            color=c,
            alpha=alpha,
        )

        print(_df_x.isna().sum(), _df_y.isna().sum())

        r = scipy.stats.spearmanr(_df_x, _df_y)

        if "abs_RCL" in x:
            # Zoomed inset on the region close to the origin (negative side,
            # since the plotted values are negated).
            x1, x2, y1, y2 = (
                -0.05,
                0.005,
                -0.05,
                0.005,
            )  # subregion of the original image
            axins = ax.inset_axes(
                [0.55, 0.05, 0.4, 0.4],
                xlim=(x1, x2),
                ylim=(y1, y2),
                xticklabels=[],
                yticklabels=[],
            )
            axins.scatter(
                _df_x,
                _df_y,
                edgecolors="white",
                linewidths=0.5,
                color=c,
                alpha=alpha,
            )
            axins.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.75, zorder=1)
            ax.indicate_inset_zoom(axins, edgecolor="black")

            # Identity line on the main axes, drawn without moving its limits.
            with KeepLim(ax):
                ax.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.5, zorder=1)

            # Give both axes the same range (union of the current limits).
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            xmin = min(xlim[0], ylim[0])
            xmax = max(xlim[1], ylim[1])
            ax.set_xlim(xmin, xmax)
            ax.set_ylim(xmin, xmax)

        ax.set(xlabel=x_label, ylabel=f"{method_rename.get(m, m)} gain")

        # Correlation text sits top-left, except for AUC plots (top-right).
        ha = "left"
        va = "top"
        xpos = 0.04
        ypos = 0.96
        if "abs_auc" in x:
            ha = "right"
            xpos = 0.96

        ax.text(
            xpos,
            ypos,
            f"$r^2 = {np.square(r.correlation):.2f}$",
            ha=ha,
            va=va,
            transform=ax.transAxes,
            bbox=dict(
                boxstyle="round,pad=0.3",
                facecolor="white",
                edgecolor="black",
                alpha=0.8,
            ),
        )

        # "x Lambda_Delta" annotations next to the top of the y axis and the
        # right end of the x axis.
        ax.annotate(
            r"$\times\!\Lambda_{\!\!\Delta}$",
            xy=(0, 1),
            xycoords="axes fraction",
            xytext=(-0.02, 1.04),
            textcoords="axes fraction",
            ha="right",
            va="center",
        )
        ax.annotate(
            r"$\times\!\Lambda_{\!\!\Delta}$",
            xy=(1, 0),
            xycoords="axes fraction",
            xytext=(0.99, -0.03),
            textcoords="axes fraction",
            ha="left",
            va="top",
        )

        save_fig(fig, out, r=ref_post_training, y=post_training, x=x, pad_inches=0.01)

    for x in [
        # "abs_auc",
        # "abs_acc",
        # "abs_ECE",
        # "abs_MCE",
        # "abs_RMSCE",
        # "abs_CL",
        # "abs_RCL",
        # "abs_L",
        # "abs_RGL_UB",
        # "abs_RGL_LB",
        # "abs_RGL_mean",
        # "abs_RCL_RGL_UB",
        # "abs_RCL_RGL_LB",
        "abs_RCL_RGL_mean",
        # "abs_event_rate_dist",
    ]:
        plot_one(post_training, x)


@pytest.mark.parametrize(
    "post_training",
    [
        # "recal_isotonic",
        "finetuning_sigmoid",
        # "stack_hgb",
        # "stack_rf",
        # "glar_leaves5",
    ],
)
@pytest.mark.parametrize(
    "ref_post_training",
    [
        "none",
        "recal_isotonic",
    ],
)
def test_fig_gain_post_training_vs_t(out, inp, post_training, ref_post_training):
    """Plot the per-``t`` mean gain of a post-training method.

    Loads the merged utility/regret tables relative to ``ref_post_training``,
    keeps the rows of ``post_training`` and, for ``abs_t`` and ``abs_dt``,
    draws a line of the ``diff_norm_u_test_emp_mean`` column against that
    column, both averaged per value of ``t``. Each figure is handed to
    ``save_fig``.

    Args:
        out: Output location, forwarded untouched to ``save_fig``
            (presumably a pytest fixture -- confirm in conftest).
        inp: Input path; tables are read from ``utility_table`` and
            ``regret_table`` under its parent directory.
        post_training: Method selected in the ``post_training`` index level.
        ref_post_training: Reference method passed to ``merge_tables``; it
            also decides the y-axis label.
    """
    gain_col = "diff_norm_u_test_emp_mean"

    tables_root = Path(inp).parent
    merged, _ = merge_tables(
        tables_root / "utility_table",
        tables_root / "regret_table",
        ref_post_training,
        "leaves5",  # reference partitioner; doesn't affect the result
        t_target=None,
        auc_threshold=0.75,
    )

    # Restrict the table to the post-training method under study.
    df = merged.xs(post_training, axis=0, level="post_training")

    def plot_one(m, x):
        """Draw and save the gain-versus-``x`` line for method ``m``."""
        set_latex_font()
        fig, ax = plt.subplots(figsize=(3, 2))

        # Average both columns per value of "t" before plotting.
        per_t = df[[x, gain_col]].groupby("t").mean()

        ax.plot(
            per_t[x],
            per_t[gain_col],
            color="black",
            marker=".",
            markeredgecolor="white",
            markeredgewidth=0.5,
        )

        # Axis labels are only set for the two known reference methods.
        y_labels = {
            "none": f"Post-training gain ({method_rename.get(m, m)})",
            "recal_isotonic": "Excess gain over recalibration",
        }
        if ref_post_training in y_labels:
            ax.set(
                xlabel=metrics_rename.get(x, x),
                ylabel=y_labels[ref_post_training],
            )

        save_fig(fig, out, r=ref_post_training, y=post_training, x=x)

    for column in ("abs_t", "abs_dt"):
        plot_one(post_training, column)


@pytest.mark.parametrize(
    "ref_post_training",
    [
        # "none",
        "recal_isotonic",
    ],
)
@pytest.mark.parametrize(
    "method_type",
    [
        # "both",
        "recal_only",
        "nonrecal_only",
    ],
)
def test_fig_gain_post_training_vs_t2(out, inp, ref_post_training, method_type):
    """Plot the mean gain of several post-training methods against ``abs_t``.

    For every selected method the merged table is averaged over rows sharing
    the same ``t`` index value and drawn as one line; all methods share a
    single axes and the figure is saved through ``save_fig``.

    Args:
        out: Output location forwarded to ``save_fig``.
        inp: Input path; tables are read from the ``utility_table`` and
            ``regret_table`` directories next to it.
        ref_post_training: Reference method forwarded to ``merge_tables``.
            It is removed from the plotted methods and selects the axis
            labels.
        method_type: ``"recal_only"`` or ``"nonrecal_only"`` restricts the
            plotted methods; any other value keeps the full list.
    """
    ref_partitioner = "leaves5"  # doesn't affect the result
    auc_threshold = 0.75
    t_target = None
    y = "diff_norm_u_test_emp_mean"

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    def plot_one(m, x, ax, c):
        """Draw the per-``t`` mean gain curve of method ``m`` on ``ax``."""
        _df = df.xs(m, axis=0, level="post_training")
        # Diagnostic: number of non-null x values in the first "t" group.
        print(m, _df[x].groupby("t").count().iloc[0])

        _df = _df[[x, y]].groupby("t").mean()
        ax.plot(_df[x], _df[y], label=method_rename.get(m, m), color=c)

    def decorate(ax, x):
        """Set axis labels and the unit annotation once per figure."""
        x_label = metrics_rename.get(x, x)
        if ref_post_training == "none":
            ax.set(xlabel=x_label, ylabel="Post-training gain")
        elif ref_post_training == "recal_isotonic":
            ax.set(
                xlabel=x_label,
                ylabel="Excess gain over isotonic recalibration",
            )

        ax.annotate(
            r"$\times\!U_{\!\!\Delta}$",
            xy=(0, 1),
            xycoords="axes fraction",
            xytext=(-0.02, 1.02),
            textcoords="axes fraction",
            ha="right",
            va="center",
        )

    all_methods = [
        "recal_isotonic",
        "recal_sigmoid",
        "recal_hist15",
        "recal_scalbin15",
        "recal_metacal_mis05",
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        "glar_leaves5",
        "glart_both2_leaves5",
    ]

    # Colors are assigned on the full list (black first, then tab10) so that
    # a method keeps the same color whatever subset is plotted.
    palette = itertools.chain(["black"], itertools.cycle(plt.cm.tab10.colors))
    method_color = dict(zip(all_methods, palette))

    title = "Post-training"
    if method_type == "recal_only":
        methods = [
            "recal_isotonic",
            "recal_sigmoid",
            "recal_hist15",
            "recal_scalbin15",
            "recal_metacal_mis05",
        ]
        title = "Recalibration"
    elif method_type == "nonrecal_only":
        methods = [
            "finetuning_sigmoid",
            "stack_hgb",
            "stack_rf",
            # "glar_leaves5",
            "glart_both2_leaves5",
        ]
    else:
        methods = list(all_methods)

    # The reference method would be compared against itself: drop it.
    if ref_post_training in methods:
        methods.remove(ref_post_training)

    set_latex_font()
    for x in [
        "abs_t",
        # "abs_dt",
    ]:
        fig, ax = plt.subplots(figsize=(2.5, 2.5))

        for m in methods:
            plot_one(m, x, ax, method_color[m])

        decorate(ax, x)
        add_legend(ax, title=title)

        # Zero-gain baseline.
        ax.axhline(0, color="black", ls="--", lw=0.5)

        save_fig(fig, out, r=ref_post_training, x=x, type=method_type)


@pytest.mark.parametrize(
    "ref_post_training",
    [
        "none",
        # "recal_isotonic",
    ],
)
@pytest.mark.parametrize(
    "method_type",
    [
        # "both",
        "recal_only",
        # "nonrecal_only",
    ],
)
def test_fig_gain_post_training_vs_event_rate(out, inp, ref_post_training, method_type):
    """Plot LOWESS-smoothed gain against ``abs_event_rate_dist`` per method.

    One smoothed curve is drawn per selected post-training method on a shared
    axes, then the figure is saved through ``save_fig``.

    Args:
        out: Output location forwarded to ``save_fig``.
        inp: Input path; tables are read from the ``utility_table`` and
            ``regret_table`` directories next to it.
        ref_post_training: Reference method forwarded to ``merge_tables``.
            It is removed from the plotted methods and selects the axis
            labels.
        method_type: ``"recal_only"`` or ``"nonrecal_only"`` restricts the
            plotted methods; any other value keeps the full list.
    """
    tables_root = Path(inp).parent
    merged, _residual = merge_tables(
        tables_root / "utility_table",
        tables_root / "regret_table",
        ref_post_training,
        "leaves5",  # reference partitioner: doesn't affect the result
        t_target=None,
        auc_threshold=0.75,
    )

    gain_col = "diff_norm_u_test_emp_mean"
    recal_only = method_type == "recal_only"

    def plot_one(m, x, ax, c):
        """Draw the smoothed gain curve of method ``m`` and set axis labels."""
        x_label = metrics_rename.get(x, x)
        rows = merged.xs(m, axis=0, level="post_training")

        # LOWESS fit of the gain on x; frac controls the smoothing parameter.
        smoothed = sm.nonparametric.lowess(rows[gain_col], rows[x], frac=0.2)
        fit_x, fit_y = smoothed[:, 0], smoothed[:, 1]
        ax.plot(fit_x, fit_y, color=c, label=method_rename.get(m, m))

        if ref_post_training == "none":
            if recal_only:
                ylabel = "Recalibration gain"
            else:
                ylabel = "Post-training gain"
            ax.set(xlabel=x_label, ylabel=ylabel)
        elif ref_post_training == "recal_isotonic":
            ax.set(
                xlabel=x_label,
                ylabel="Excess gain over isotonic recalibration",
            )

    recal_methods = [
        "recal_isotonic",
        "recal_sigmoid",
        "recal_hist15",
        "recal_scalbin15",
        "recal_metacal_mis05",
    ]
    other_methods = [
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        "glar_leaves5",
    ]
    every_method = recal_methods + other_methods

    # Black first, then the tab10 palette; assigned on the full list so a
    # method keeps its color whatever subset ends up plotted.
    palette = itertools.chain(["black"], itertools.cycle(plt.cm.tab10.colors))
    color_of = {name: next(palette) for name in every_method}

    subsets = {"recal_only": recal_methods, "nonrecal_only": other_methods}
    methods = list(subsets.get(method_type, every_method))
    title = "Recalibration" if recal_only else "Post-training"

    # The reference method is not plotted against itself.
    if ref_post_training in methods:
        methods.remove(ref_post_training)

    set_latex_font()
    fig, ax = plt.subplots(figsize=(2.5, 2.5))

    x = "abs_event_rate_dist"
    # x = "abs_event_rate2_dist"

    for name in methods:
        plot_one(name, x, ax, color_of[name])

    add_legend(ax, title=title)

    # Zero-gain baseline.
    ax.axhline(0, color="black", ls="--", lw=0.5)

    save_fig(fig, out, r=ref_post_training, x=x, type=method_type)


@pytest.mark.parametrize(
    "ref_post_training",
    [
        "none",
        # "recal_isotonic",
    ],
)
def test_check_result_count(out, inp, ref_post_training):
    """Print diagnostics on how many results the merged tables contain.

    Nothing is asserted: the test prints the number of rows available for
    ``recal_isotonic`` in the first ``t`` group, the rows at the first ``t``
    value, and the per-method count of non-null
    ``abs_post_training_fit_cpu_time`` entries of the residual table.

    Args:
        out: Unused; kept for the shared test signature.
        inp: Input path; tables are read from the ``utility_table`` and
            ``regret_table`` directories next to it.
        ref_post_training: Reference method forwarded to ``merge_tables``.
    """
    ref_partitioner = "leaves5"  # doesn't affect the result
    m = "recal_isotonic"
    t_target = None

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, df_residual = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=None,  # no AUC threshold is passed for this check
    )

    _df = df.xs(m, axis=0, level="post_training")

    # Number of rows in the first "t" group (groups are sorted by t).
    # size() is used rather than count(): once "t" is the only column it is
    # consumed as the grouping key, so count() has nothing left to count.
    print(m, _df.groupby(level="t").size().iloc[0])

    # Rows at the first "t" value appearing in the index.
    t0 = _df.index.get_level_values("t").unique()[0]
    subdf = _df.xs(t0, level="t")
    # Temporarily raise the row display limit so the frame is not truncated.
    with pd.option_context("display.max_rows", 1000):
        print(subdf)

    # Count the non-NaN "abs_post_training_fit_cpu_time" values per
    # post-training method.
    print(df_residual.columns)
    print(df_residual)
    count = (
        df_residual[["abs_post_training_fit_cpu_time"]]
        .reset_index()
        .groupby("post_training")
        .count()
    )
    print(count)


def test_fig_correlation_post(out, inp):
    """FIG3b: squared Spearman correlation between metrics and gain.

    For each post-training method and each metric column, computes the
    squared Spearman correlation with ``diff_norm_u_test_emp_mean`` and draws
    the values as a dot plot (one row per metric, one color per method).
    Rows are ordered by ascending mean correlation across methods.

    Args:
        out: Output location forwarded to ``save_fig``.
        inp: Input path; tables are read from the ``utility_table`` and
            ``regret_table`` directories next to it.
    """
    ref_partitioner = "unconstrained"
    ref_post_training = "none"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    def get_corr(m, x):
        """Return the squared Spearman correlation of ``x`` with the gain."""
        _df = df.xs(m, axis=0, level="post_training")
        r = scipy.stats.spearmanr(_df[x], _df["diff_norm_u_test_emp_mean"])
        return np.square(r.correlation)

    methods = [
        # "recal_isotonic",
        # "recal_sigmoid",
        # "recal_hist15",
        # "recal_scalbin15",
        # "recal_metacal_mis05",
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        "glart_both2_leaves5",
        # "glar_leaves5",
        # "glar_unconstrained",
    ]
    metrics = [
        "abs_auc",
        "abs_acc",
        "abs_ECE",
        "abs_MCE",
        "abs_RMSCE",
        "abs_CL",
        "abs_RCL",
        "abs_L",
        # "abs_RGL_UB",
        # "abs_RGL_LB",
        # "abs_RGL_mean",
        # "abs_RCL_RGL_UB",
        # "abs_RCL_RGL_LB",
        "abs_RCL_RGL_mean",
    ]

    # Rows are metrics, columns are post-training methods.
    corrs = pd.DataFrame(
        {m: {x: get_corr(m, x) for x in metrics} for m in methods},
        index=metrics,
        columns=methods,
    )
    corrs.columns.name = "post_training"

    # Sort the rows by ascending average value across methods.
    corrs = corrs.reindex(corrs.mean(axis=1).sort_values().index)

    print(corrs)

    corrs.index = [metrics_rename.get(k, k) for k in corrs.index]

    # Long format: one row per (metric label, method) pair.
    corrs = corrs.reset_index().melt(
        id_vars="index", var_name="post_training", value_name="r2"
    )

    print(corrs)

    set_latex_font()
    fig, ax = plt.subplots(figsize=(2.7, 1.5))

    # One scatter series per method; sort=False keeps the `methods` order.
    for m, group in corrs.groupby("post_training", sort=False):
        ax.scatter(
            group["r2"],
            group["index"],
            edgecolors="white",
            linewidths=0.5,
            label=method_rename.get(m, m),
            color=method_colors.get(m, "tab:blue"),
        )

    # Shade every other metric row. facecolor is used instead of color:
    # passing color together with edgecolor makes matplotlib warn and
    # discard the edgecolor.
    n_metrics = corrs["index"].nunique()
    with KeepLim(ax):
        for i in range(0, n_metrics, 2):
            ax.axhspan(
                i - 0.5,
                i + 0.5,
                facecolor=".93",
                edgecolor="none",
                lw=0,
                zorder=-1,
            )

    ax.set_ylim(-0.5, n_metrics - 0.5)

    xlabel = ax.set_xlabel(r"b. Correlation $r^2$ to gain of post-training")
    ax.xaxis.set_label_coords(0.4, -0.15)
    xlabel.set_x(0.44)
    ax.legend(loc="lower right", title="Post-training")

    # Disable minor y ticks.
    ax.yaxis.set_minor_locator(plt.NullLocator())

    save_fig(fig, out, "recal_correlation", pad_inches=0.01, ref_gl=ref_partitioner)


def test_fig_correlation_post2(out, inp):
    """FIG3b variant: squared Spearman correlation between metrics and gain.

    For each post-training method and each metric column, computes the
    squared Spearman correlation with ``diff_norm_u_test_emp_mean`` and draws
    the values as a dot plot (one row per metric, one color per method).
    Rows are ordered by ascending mean correlation across methods. Uses the
    ``leaves5`` reference partitioner and a framed legend.

    Args:
        out: Output location forwarded to ``save_fig``.
        inp: Input path; tables are read from the ``utility_table`` and
            ``regret_table`` directories next to it.
    """
    ref_partitioner = "leaves5"
    # ref_partitioner = "unconstrained"
    ref_post_training = "none"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    def get_corr(m, x):
        """Return the squared Spearman correlation of ``x`` with the gain."""
        _df = df.xs(m, axis=0, level="post_training")
        r = scipy.stats.spearmanr(_df[x], _df["diff_norm_u_test_emp_mean"])
        return np.square(r.correlation)

    methods = [
        # "recal_isotonic",
        # "recal_sigmoid",
        # "recal_hist15",
        # "recal_scalbin15",
        # "recal_metacal_mis05",
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        # "glar_leaves5",
        "glart_both2_leaves5",
        # "glar_unconstrained",
    ]
    metrics = [
        "abs_auc",
        "abs_acc",
        "abs_ECE",
        "abs_MCE",
        "abs_RMSCE",
        "abs_CL",
        "abs_RCL",
        "abs_L",
        # "abs_RGL_UB",
        # "abs_RGL_LB",
        # "abs_RGL_mean",
        # "abs_RCL_RGL_UB",
        # "abs_RCL_RGL_LB",
        "abs_RCL_RGL_mean",
    ]

    # Rows are metrics, columns are post-training methods.
    corrs = pd.DataFrame(
        {m: {x: get_corr(m, x) for x in metrics} for m in methods},
        index=metrics,
        columns=methods,
    )
    corrs.columns.name = "post_training"

    # Sort the rows by ascending average value across methods.
    corrs = corrs.reindex(corrs.mean(axis=1).sort_values().index)

    print(corrs)

    corrs.index = [metrics_rename.get(k, k) for k in corrs.index]

    # Long format: one row per (metric label, method) pair.
    corrs = corrs.reset_index().melt(
        id_vars="index", var_name="post_training", value_name="r2"
    )

    print(corrs)

    set_latex_font()
    fig, ax = plt.subplots(figsize=(2.7, 1.5))

    # One scatter series per method; sort=False keeps the `methods` order.
    for m, group in corrs.groupby("post_training", sort=False):
        ax.scatter(
            group["r2"],
            group["index"],
            edgecolors="white",
            linewidths=0.5,
            label=method_rename.get(m, m),
            color=method_colors.get(m, "tab:blue"),
        )

    # Shade every other metric row. facecolor is used instead of color:
    # passing color together with edgecolor makes matplotlib warn and
    # discard the edgecolor.
    n_metrics = corrs["index"].nunique()
    with KeepLim(ax):
        for i in range(0, n_metrics, 2):
            ax.axhspan(
                i - 0.5,
                i + 0.5,
                facecolor=".93",
                edgecolor="none",
                lw=0,
                zorder=-1,
            )

    ax.set_ylim(-0.5, n_metrics - 0.5)

    xlabel = ax.set_xlabel(r"b. Correlation $r^2$ to gain of post-training")
    xlabel.set_x(0.44)
    legend = ax.legend(loc="lower right", title="Post-training", frameon=True)
    frame = legend.get_frame()
    frame.set_facecolor("none")  # Transparent background
    frame.set_edgecolor("black")  # Black border
    frame.set_alpha(0.3)  # Frame transparency
    frame.set_linewidth(0.5)  # Thin border
    ax.xaxis.set_label_coords(0.4, -0.15)

    # Disable minor y ticks.
    ax.yaxis.set_minor_locator(plt.NullLocator())

    save_fig(fig, out, "recal_correlation", pad_inches=0.01, ref_gl=ref_partitioner)


@pytest.mark.parametrize(
    "post_training",
    [
        "finetuning_sigmoid",
        # "stack_hgb",
        # "stack_rf",
        # "glar_leaves5",
    ],
)
def test_fig_gain_post_training_vs_metrics_excess(out, inp, post_training):
    """FIG2: scatter of the gain over isotonic recalibration against a metric.

    The merged tables (reference ``recal_isotonic``) are restricted to
    ``post_training`` and one scatter figure is saved per x-axis metric
    (currently only ``abs_RGL_mean``), annotated with the squared Spearman
    correlation.

    Args:
        out: Output location forwarded to ``save_fig``.
        inp: Input path; tables are read from the ``utility_table`` and
            ``regret_table`` directories next to it.
        post_training: Name of the post-training method to plot.
    """
    ref_post_training = "recal_isotonic"
    tables_root = Path(inp).parent

    merged, _residual = merge_tables(
        tables_root / "utility_table",
        tables_root / "regret_table",
        ref_post_training,
        "leaves5",  # reference partitioner
        t_target=[0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975],
        auc_threshold=0.75,
    )

    # Keep only the rows of the requested post-training method.
    rows = merged.xs(post_training, axis=0, level="post_training")

    alpha = 0.9
    unit_text = r"$\times\!U_{\!\!\Delta}$"

    def plot_one(m, x):
        """Save the scatter figure of the gain against metric ``x``."""
        x_label = f"a. {metrics_rename.get(x, x)}"

        set_latex_font()
        fig, ax = plt.subplots(figsize=(2.2, 1.5))

        xs = rows[x]
        ys = rows["diff_norm_u_test_emp_mean"]

        ax.scatter(
            xs,
            ys,
            edgecolors="white",
            linewidths=0.5,
            color=method_colors.get(m, "tab:blue"),
            alpha=alpha,
            label=f"{method_rename.get(m, m)}",
        )

        # Diagnostic: number of missing values on each axis.
        print(xs.isna().sum(), ys.isna().sum())

        r = scipy.stats.spearmanr(xs, ys)

        if any(tag in x for tag in ("abs_RCL", "abs_RGL")):
            # Identity line, drawn without letting it change the limits.
            with KeepLim(ax):
                ax.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.5, zorder=1)
            # Give both axes the same range: the union of the current limits.
            (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
            lo = min(x_lo, y_lo)
            hi = max(x_hi, y_hi)
            ax.set_xlim(lo, hi)
            ax.set_ylim(lo, hi)

        ax.set(xlabel=x_label)
        label = ax.set_ylabel(r"Gain over recalibration")
        label.set_y(0.4)

        # Correlation text: top-left by default, top-right for the AUC metric.
        on_right = "abs_auc" in x
        ax.text(
            0.96 if on_right else 0.04,
            0.96,
            f"Correlation: $r^2 = {np.square(r.correlation):.2f}$",
            ha="right" if on_right else "left",
            va="top",
            transform=ax.transAxes,
        )

        # Unit marker next to the top of the y axis ...
        ax.annotate(
            unit_text,
            xy=(0, 1),
            xycoords="axes fraction",
            xytext=(-0.02, 1.0),
            textcoords="axes fraction",
            ha="right",
            va="center",
        )
        # ... and next to the right end of the x axis.
        ax.annotate(
            unit_text,
            xy=(1, 0),
            xycoords="axes fraction",
            xytext=(0.97, -0.03),
            textcoords="axes fraction",
            ha="left",
            va="top",
        )

        save_fig(fig, out, r=ref_post_training, y=post_training, x=x, pad_inches=0.01)

    for x_metric in [
        # "abs_auc",
        # "abs_acc",
        # "abs_ECE",
        # "abs_MCE",
        # "abs_RMSCE",
        # "abs_CL",
        # "abs_RCL",
        # "abs_L",
        # "abs_RGL_UB",
        # "abs_RGL_LB",
        "abs_RGL_mean",
        # "abs_RCL_RGL_UB",
        # "abs_RCL_RGL_LB",
        # "abs_RCL_RGL_mean",
    ]:
        plot_one(post_training, x_metric)


@pytest.mark.parametrize(
    "post_training",
    [
        "finetuning_sigmoid",
        # "stack_hgb",
        # "stack_rf",
        # "glar_leaves5",
    ],
)
def test_fig_gain_post_training_vs_metrics_excess2(out, inp, post_training):
    """FIG2 variant: scatter of the negated gain against a negated metric.

    The merged tables (reference ``recal_sigmoid``, partitioner ``depth10``)
    are restricted to ``post_training``. Both the x metric and
    ``diff_norm_u_test_emp_mean`` are negated before plotting, and the figure
    is annotated with the squared Spearman correlation.

    Side effect: the selected rows are dumped to ``before_plot.csv`` inside
    ``inp``.

    Args:
        out: Output location forwarded to ``save_fig``.
        inp: Input directory; unlike the sibling tests, the tables are read
            from ``inp`` itself, not from its parent.
        post_training: Name of the post-training method to plot.
    """
    ref_partitioner = "depth10"
    # ref_partitioner = "unconstrained"
    ref_post_training = "recal_sigmoid"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp) / "utility_table"
    dirpath_init = Path(inp) / "regret_table"

    _df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    # Select a post-training method
    _df = _df.xs(post_training, axis=0, level="post_training")

    alpha = 0.9

    # Dump of the plotted data; note that it is written into the input
    # directory.
    _df.to_csv(Path(inp) / "before_plot.csv")

    def plot_one(m, x):
        """Save the scatter figure of the negated gain against ``-x``."""
        # Raw strings: "\w", "\m" and "\!" are not valid Python escapes.
        x_label = r"$\widehat{\mathcal{R}}^{GL}$"

        set_latex_font()
        fig, ax = plt.subplots(figsize=(1.7, 1.5))

        # Both axes are sign-flipped relative to the stored columns.
        _df_x = -_df[x]
        _df_y = -_df["diff_norm_u_test_emp_mean"]
        print("y axis", _df_y)
        ax.scatter(
            _df_x,
            _df_y,
            edgecolors="white",
            linewidths=0.3,
            color=method_colors.get(m, "tab:blue"),
            alpha=alpha,
            label=f"{method_rename.get(m, m)}",
        )

        # Diagnostic: number of missing values on each axis.
        print(_df_x.isna().sum(), _df_y.isna().sum())

        r = scipy.stats.spearmanr(_df_x, _df_y)

        if "abs_RCL" in x or "abs_RGL" in x:
            # Identity line, drawn without letting it change the limits.
            with KeepLim(ax):
                ax.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.5, zorder=1)
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            xmin = min(xlim[0], ylim[0])
            xmax = max(xlim[1], ylim[1])
            # The lower x limit is hard-coded; the other bounds use the
            # union of the current limits.
            ax.set_xlim(-0.035, xmax)
            ax.set_ylim(xmin, xmax)

        ax.set(xlabel=x_label)
        label = ax.set_ylabel("Cost reduction (FT vs recal.)", fontsize=6.8)
        label.set_y(0.4)

        # Squared Spearman correlation in a boxed label, bottom-right corner.
        ax.text(
            0.96,
            0.06,
            rf"$r^2\!=\!{np.square(r.correlation):.2f}$",
            ha="right",
            va="bottom",
            fontsize=7,
            transform=ax.transAxes,
            bbox=dict(
                boxstyle="round,pad=0.3",
                facecolor="white",
                edgecolor="black",
                alpha=0.8,
            ),
        )

        # Unit marker next to the top of the y axis ...
        ax.annotate(
            r"$\times\!\Lambda_{\!\!\Delta}$",
            xy=(0, 1),
            xycoords="axes fraction",
            xytext=(-0.02, 1.0),
            textcoords="axes fraction",
            ha="right",
            va="center",
        )
        # ... and below the right end of the x axis.
        ax.annotate(
            r"$\times\!\Lambda_{\!\!\Delta}$",
            xy=(1, 0),
            xycoords="axes fraction",
            xytext=(1, -0.15),
            textcoords="axes fraction",
            ha="right",
            va="top",
        )

        save_fig(fig, out, r=ref_post_training, y=post_training, x=x, pad_inches=0.01)

    for x in [
        # "abs_auc",
        # "abs_acc",
        # "abs_ECE",
        # "abs_MCE",
        # "abs_RMSCE",
        # "abs_CL",
        # "abs_RCL",
        # "abs_L",
        # "abs_RGL_UB",
        # "abs_RGL_LB",
        "abs_RGL_mean",
        # "abs_RCL_RGL_UB",
        # "abs_RCL_RGL_LB",
        # "abs_RCL_RGL_mean",
    ]:
        plot_one(post_training, x)


def test_fig_correlation_post_excess(out, inp):
    """FIG2b: squared Spearman correlation between metrics and excess gain.

    With ``recal_isotonic`` as the reference, computes for each post-training
    method and each metric column the squared Spearman correlation with
    ``diff_norm_u_test_emp_mean`` and draws the values as a dot plot (one row
    per metric, one color per method). Rows are ordered by ascending mean
    correlation across methods.

    Args:
        out: Output location forwarded to ``save_fig``.
        inp: Input path; tables are read from the ``utility_table`` and
            ``regret_table`` directories next to it.
    """
    ref_partitioner = "leaves5"
    ref_post_training = "recal_isotonic"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    def get_corr(m, x):
        """Return the squared Spearman correlation of ``x`` with the gain."""
        _df = df.xs(m, axis=0, level="post_training")
        r = scipy.stats.spearmanr(_df[x], _df["diff_norm_u_test_emp_mean"])
        return np.square(r.correlation)

    methods = [
        # "recal_isotonic",
        # "recal_sigmoid",
        # "recal_hist15",
        # "recal_scalbin15",
        # "recal_metacal_mis05",
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        "glart_both2_leaves5",
        # "glar_leaves5",
        # "glar_unconstrained",
    ]
    metrics = [
        "abs_auc",
        "abs_acc",
        "abs_ECE",
        "abs_MCE",
        "abs_RMSCE",
        "abs_CL",
        "abs_RCL",
        "abs_L",
        # "abs_RGL_UB",
        # "abs_RGL_LB",
        "abs_RGL_mean",
        # "abs_RCL_RGL_UB",
        # "abs_RCL_RGL_LB",
        # "abs_RCL_RGL_mean",
    ]

    # Rows are metrics, columns are post-training methods.
    corrs = pd.DataFrame(
        {m: {x: get_corr(m, x) for x in metrics} for m in methods},
        index=metrics,
        columns=methods,
    )
    corrs.columns.name = "post_training"

    # Sort the rows by ascending average value across methods.
    corrs = corrs.reindex(corrs.mean(axis=1).sort_values().index)

    print(corrs)

    corrs.index = [metrics_rename.get(k, k) for k in corrs.index]

    # Long format: one row per (metric label, method) pair.
    corrs = corrs.reset_index().melt(
        id_vars="index", var_name="post_training", value_name="r2"
    )

    print(corrs)

    set_latex_font()
    fig, ax = plt.subplots(figsize=(2.7, 1.5))

    # One scatter series per method; sort=False keeps the `methods` order.
    for m, group in corrs.groupby("post_training", sort=False):
        ax.scatter(
            group["r2"],
            group["index"],
            edgecolors="white",
            linewidths=0.5,
            label=method_rename.get(m, m),
            color=method_colors.get(m, "tab:blue"),
        )

    # Shade every other metric row. facecolor is used instead of color:
    # passing color together with edgecolor makes matplotlib warn and
    # discard the edgecolor.
    n_metrics = corrs["index"].nunique()
    with KeepLim(ax):
        for i in range(0, n_metrics, 2):
            ax.axhspan(
                i - 0.5,
                i + 0.5,
                facecolor=".93",
                edgecolor="none",
                lw=0,
                zorder=-1,
            )

    ax.set_ylim(-0.5, n_metrics - 0.5)

    ax.set(xlabel=r"b. Correlation $r^2$ to gain over recalibration")
    ax.xaxis.set_label_coords(0.45, -0.15)
    ax.legend(loc="lower right", title="Post-training")

    # Disable minor y ticks.
    ax.yaxis.set_minor_locator(plt.NullLocator())

    save_fig(fig, out, "correlation", pad_inches=0.01, ref_gl=ref_partitioner)


def test_fig_correlation_post_excess2(out, inp):
    """FIG2b

    For a handful of post-training methods, compute the squared Spearman
    correlation between each selected metric column and the
    ``diff_norm_u_test_emp_mean`` column of the merged table, then draw one
    scatter series per method (metric names on the y axis, r^2 on the x axis)
    and save the figure under the name "correlation".

    Legend styling is passed directly to ``ax.legend`` rather than through
    ``plt.rc``, so running this test does not alter the global rcParams seen
    by whatever runs after it in the same session.
    """
    ref_partitioner = "leaves5"
    ref_post_training = "recal_isotonic"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    # Only the first of the two tables returned by merge_tables is used here.
    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    def get_corr(m, x):
        """Return the squared Spearman correlation of column ``x`` with the
        ``diff_norm_u_test_emp_mean`` column, restricted to method ``m``."""
        _df = df.xs(m, axis=0, level="post_training")
        r = scipy.stats.spearmanr(_df[x], _df["diff_norm_u_test_emp_mean"])
        return np.square(r.correlation)

    methods = [
        # "recal_isotonic",
        # "recal_sigmoid",
        # "recal_hist15",
        # "recal_scalbin15",
        # "recal_metacal_mis05",
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        # "glar_leaves5",
        "glart_both2_leaves5",
        # "glar_unconstrained",
    ]

    metrics = [
        "abs_auc",
        "abs_acc",
        "abs_ECE",
        "abs_MCE",
        "abs_RMSCE",
        "abs_CL",
        "abs_RCL",
        "abs_L",
        # "abs_RGL_UB",
        # "abs_RGL_LB",
        "abs_RGL_mean",
        # "abs_RCL_RGL_UB",
        # "abs_RCL_RGL_LB",
        # "abs_RCL_RGL_mean",
    ]

    # One single-row frame per method (row label = method, columns = metrics).
    corrs = pd.concat(
        [
            pd.DataFrame({x: get_corr(m, x) for x in metrics}, index=[m])
            for m in methods
        ],
        axis=0,
    )
    corrs.index.name = "post_training"

    # Metrics as rows, methods as columns; rows sorted by ascending mean r^2.
    corrs = corrs.T
    corrs = corrs.reindex(corrs.mean(axis=1).sort_values().index)

    print(corrs)

    corrs.index = [metrics_rename.get(k, k) for k in corrs.index]

    # Long format: one row per (metric, method) pair. The unnamed index
    # becomes the "index" column after reset_index.
    corrs = corrs.reset_index().melt(
        id_vars="index", var_name="post_training", value_name="r2"
    )

    print(corrs)

    set_latex_font()
    fig, ax = plt.subplots(figsize=(1.8, 1.6))

    # One scatter series per post-training method, in the order of `methods`.
    for m, group in corrs.groupby("post_training", sort=False):
        ax.scatter(
            group["r2"],
            group["index"],
            edgecolors="white",
            linewidths=0.5,
            label=method_rename.get(m, m),
            color=method_colors.get(m, "tab:blue"),
        )

    # Shade every other metric row; KeepLim wraps the drawing so the shading
    # is added inside that context manager, as in the sibling figures.
    n_metrics = corrs["index"].nunique()
    with KeepLim(ax):
        for i in range(0, n_metrics, 2):
            ax.axhspan(i - 0.5, i + 0.5, color=".93", zorder=-1, edgecolor="none", lw=0)

    ax.set_ylim(-0.5, n_metrics - 0.5)

    ax.set(
        xlabel=r"$\begin{array}{c}\text{b. Correlation } r^2 \text{ to}\\[-1mm] \text{gain over recalibration}\end{array}$"
    )

    # Same values that used to be pushed into the global "legend" rc group,
    # now scoped to this legend only.
    legend = ax.legend(
        loc="lower right",
        title="Post-training",
        frameon=True,
        fontsize=7,
        title_fontsize=7,
        handlelength=1.6,
        handletextpad=0.1,
    )
    frame = legend.get_frame()
    frame.set_facecolor("none")  # Transparent background
    frame.set_edgecolor("black")  # Black border
    frame.set_alpha(0.3)  # Frame transparency
    frame.set_linewidth(0.5)  # Thin border line
    ax.xaxis.set_label_coords(0.5, -0.15)

    # disable minor y ticks
    ax.yaxis.set_minor_locator(plt.NullLocator())

    save_fig(fig, out, "correlation", pad_inches=0.01, ref_gl=ref_partitioner)


def test_fig_gain_magnitudes(out, inp):
    """FIG4

    Scatter the ``diff_norm_u_test_emp_mean`` value of each recalibration
    method against the same column for the "stack_hgb" method, on the rows
    the two methods have in common, and save the figure as "magnitudes".
    """
    ref_partitioner = "leaves5"
    ref_post_training = "recal_isotonic"
    gain_col = "diff_norm_u_test_emp_mean"
    m_ref = "stack_hgb"

    table_root = Path(inp).parent
    df, df_residual = merge_tables(
        table_root / "utility_table",
        table_root / "regret_table",
        ref_post_training,
        ref_partitioner,
        t_target=[0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975],
        auc_threshold=0.75,
    )

    methods = [
        # "recal_isotonic",
        "recal_sigmoid",
        "recal_hist15",
        "recal_scalbin15",
        "recal_metacal_mis05",
        # "finetuning_sigmoid",
        # "stack_hgb",
        # "stack_rf",
        # "glar_leaves5",
        # "glar_unconstrained",
    ]

    fig, ax = plt.subplots(figsize=(3, 2))

    # The reference slice is the same for every compared method.
    ref_rows = df.xs(m_ref, axis=0, level="post_training")

    for m in methods:
        method_rows = df.xs(m, axis=0, level="post_training")

        print(ref_rows.shape, method_rows.shape)

        # Side-by-side frame keyed by method name; the inner join drops
        # index entries that are not present for both methods.
        paired = pd.concat(
            [ref_rows, method_rows], keys=[m_ref, m], axis=1, join="inner"
        )

        # Reference index entries that have no counterpart for this method.
        dropped = ref_rows.index.difference(method_rows.index)
        print(dropped)

        print(paired)

        ax.scatter(
            paired[(m_ref, gain_col)],
            paired[(m, gain_col)],
            label=method_rename.get(m, m),
            edgecolors="white",
            linewidths=0.5,
            alpha=0.8,
        )

    ax.axhline(0, color="black", ls="--", lw=0.5, zorder=10)

    ax.set(
        xlabel=f"Gain of {method_rename.get(m_ref, m_ref)} over isotonic recalibration",
        ylabel="Gain over isotonic recalibration",
    )
    ax.legend(loc="lower right", title="Recalibration method")

    save_fig(fig, out, "magnitudes", ref_gl=ref_partitioner)


def test_fig_gain_magnitudes2(out, inp):
    """FIG4

    Scatter, for each post-training method, the ``abs_RCL_RGL_mean`` column
    against the ``diff_norm_u_test_emp_mean`` column, one tab10 color per
    method, and save the figure as "magnitudes2".
    """
    ref_partitioner = "leaves5"  # doesn't affect the result
    ref_post_training = "recal_isotonic"

    table_root = Path(inp).parent
    df, df_residual = merge_tables(
        table_root / "utility_table",
        table_root / "regret_table",
        ref_post_training,
        ref_partitioner,
        t_target=[0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975],
        auc_threshold=0.75,
    )

    methods = [
        # "recal_isotonic",
        "recal_sigmoid",
        "recal_hist15",
        "recal_scalbin15",
        "recal_metacal_mis05",
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        "glar_leaves5",
        # "glar_unconstrained",
    ]

    fig, ax = plt.subplots(figsize=(3, 3))

    x = "abs_RCL_RGL_mean"

    # Pair each method with the next tab10 color, wrapping around if needed.
    tab10_colors = itertools.cycle(plt.cm.tab10.colors)
    for m, color in zip(methods, tab10_colors):
        method_rows = df.xs(m, axis=0, level="post_training")

        print(method_rows)

        ax.scatter(
            method_rows[x],
            method_rows["diff_norm_u_test_emp_mean"],
            label=method_rename.get(m, m),
            edgecolors="white",
            linewidths=0.5,
            alpha=0.8,
            color=color,
        )

    ax.axhline(0, color="black", ls="--", lw=0.5, zorder=10)

    ax.set(
        xlabel=metrics_rename.get(x, x),
        ylabel="Gain over isotonic recalibration",
    )
    add_legend(ax, ncol=2, title="Post-training method")

    save_fig(fig, out, "magnitudes2")


def plot_fig_gain_magnitudes3(inp, ref_partitioner, ax: plt.Axes | None = None):
    """Draw box and strip plots of each method's gain relative to a reference.

    The merged table is restricted to a fixed list of post-training methods;
    one horizontal box (plus jittered points) is drawn per method for the
    ``rel_diff_u_test_emp_mean`` column, together with a dashed vertical line
    at 1 marking the reference method and two bracketed group labels on the
    left of the y axis.

    Args:
        inp: Path whose ``utility_table`` and ``regret_table`` children are
            handed to ``merge_tables``.
        ref_partitioner: Forwarded unchanged to ``merge_tables``.
        ax: Axes to draw on. A new 3x2 inch figure is created when ``None``.

    Returns:
        The figure that owns ``ax``.
    """
    rel = True
    # One of None, "below", "above". Named so as not to shadow the builtin
    # ``filter``.
    rcl_filter = None
    ref_post_training = "recal_sigmoid"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    # NOTE(review): the tables are looked up directly under `inp` here, while
    # the other figure functions in this file use `Path(inp).parent` --
    # confirm which location is intended.
    dirpath_utility = Path(inp) / "utility_table"
    dirpath_init = Path(inp) / "regret_table"

    print(dirpath_utility)
    # Only the first of the two tables returned by merge_tables is used here.
    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    methods = [
        # "recal_isotonic",
        "recal_sigmoid",
        "recal_hist15",
        "recal_scalbin15",
        "recal_metacal_mis05",
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        # "glar_leaves5",
        "glart_both2_leaves5",
        # "glar_unconstrained",
    ]

    # select the subdf that has its index level "post_training" in methods
    df = df[df.index.get_level_values("post_training").isin(methods)]

    if ax is None:
        fig, ax = plt.subplots(figsize=(3, 2))
    else:
        fig = ax.get_figure()

    x = "rel_diff_u_test_emp_mean" if rel else "diff_u_test_emp_mean"

    df = df.reset_index()

    # Lower x limit: smallest value of column `x` among the 'recal_sigmoid'
    # rows. Anything smaller from other methods falls outside the axis.
    vmin = df.query("post_training == 'recal_sigmoid'")[x].min()

    # Median of "abs_RCL" over the reference method's rows; used as the
    # split point of the optional below/above filter.
    rcl_med = df.query(f"post_training == '{ref_post_training}'")["abs_RCL"].median()

    # Raw strings: "\m" and "\g" are not valid Python escape sequences.
    if rcl_filter == "below":
        df = df[df["abs_RCL"] < rcl_med]
        title = r"Low miscalibration ($R_{\mathrm{CL}} < \mathrm{median}$)"
    elif rcl_filter == "above":
        df = df[df["abs_RCL"] >= rcl_med]
        title = r"High miscalibration ($R_{\mathrm{CL}} \geq \mathrm{median}$)"
    elif rcl_filter is not None:
        raise ValueError("Invalid filter")
    else:
        title = None

    def rename_method(name):
        """Map a method key to its display name, falling back to the key."""
        return method_rename.get(name, name)

    sns.boxplot(
        data=df,
        x=x,
        y="post_training",
        order=methods,
        hue="post_training",
        hue_order=methods,
        ax=ax,
        palette=method_colors,
        width=0.9,
        gap=0.2,
        formatter=rename_method,
        showfliers=False,
        legend=False,
    )
    sns.stripplot(
        data=df,
        x=x,
        y="post_training",
        order=methods,
        hue="post_training",
        hue_order=methods,
        palette=method_colors,
        jitter=0.15,
        alpha=0.6,
        size=2,
        ax=ax,
        legend=False,
        dodge=False,
        edgecolor="white",
        linewidth=0.25,
        formatter=rename_method,
    )

    ax.set_xlim(vmin, None)

    # Shade every other method row.
    n_metrics = df["post_training"].nunique()
    with KeepLim(ax):
        for i in range(0, n_metrics - 1, 2):
            ax.axhspan(i - 0.5, i + 0.5, color=".93", zorder=-1, edgecolor="none", lw=0)

    # disable minor y ticks
    ax.yaxis.set_minor_locator(plt.NullLocator())

    # The reference sits at ratio 1 for relative gains, at 0 otherwise.
    xline = 1 if rel else 0

    ref_color = method_colors[ref_post_training]
    ax.axvline(
        xline,
        color=ref_color,
        ls="--",
        lw=1,
        zorder=10,
        label=method_rename.get(ref_post_training, ref_post_training).capitalize()
        + " recalibration",
    )

    if title is not None:
        ax.set_title(title)

    # NOTE(review): both labels mention isotonic recalibration although
    # `ref_post_training` is "recal_sigmoid" -- confirm the wording.
    if rel:
        xlabel = "a. Relative gain to Isotonic recalibration"
    else:
        xlabel = "a. Gain over isotonic recalibration"

    ax.set(ylabel=None)
    ax.set_xlabel(xlabel, fontsize=8)

    ax.legend(loc="upper right", title=None)

    # Reuse the y tick label font size for the group labels.
    fontsize = ax.yaxis.get_ticklabels()[0].get_fontsize()

    spine_pos = -0.20
    tip_pos = -0.17
    add_label_band(
        ax,
        3.45,
        -0.5,
        "Recalibration",
        spine_pos=spine_pos,
        tip_pos=tip_pos,
        fontsize=fontsize,
    )
    add_label_band(
        ax, 7.5, 3.55, "Other", spine_pos=spine_pos, tip_pos=tip_pos, fontsize=fontsize
    )

    # Prefix every x tick value with a multiplication sign.
    ax.xaxis.set_major_formatter(r"$\times${x:.1f}")

    return fig


@pytest.mark.parametrize(
    "filter",
    [
        None,
        # "below",
        # "above",
    ],
)
def test_fig_gain_magnitudes3(out, inp, filter):
    """FIG4

    Render ``plot_fig_gain_magnitudes3`` with the "unconstrained" reference
    partitioner on a 3x2 inch figure and save it as "magnitudes3".

    The parametrized ``filter`` value is accepted but not forwarded to the
    plotting function.
    """
    partitioner = "unconstrained"  # alternative: "leaves5"
    set_latex_font()
    _, axes = plt.subplots(figsize=(3, 2))
    figure = plot_fig_gain_magnitudes3(inp, ref_partitioner=partitioner, ax=axes)
    save_fig(figure, out, "magnitudes3", pad_inches=0.01, ref_gl=partitioner)


def plot_fig_times(inp, ax: plt.Axes | None = None, ylabels: bool = True):
    """Draw box and strip plots of each method's relative CPU time.

    The residual table returned by ``merge_tables`` is restricted to a fixed
    list of post-training methods and to a single value of its "t" index
    level, then the ``rel_diff_post_training_total_cpu_time`` column is
    plotted per method on a logarithmic x axis, with a dashed vertical line
    at 1 in the reference method's color.

    Args:
        inp: Path whose parent directory holds the ``utility_table`` and
            ``regret_table`` entries handed to ``merge_tables``.
        ax: Axes to draw on. A new 3x2 inch figure is created when ``None``.
        ylabels: When false, the y tick labels are blanked and the bracketed
            group labels are not drawn.

    Returns:
        The figure that owns ``ax``.
    """
    ref_partitioner = "leaves5"  # doesn't affect the result
    ref_post_training = "recal_isotonic"
    auc_threshold = 0.75
    t_target = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    df, df_residual = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    methods = [
        # "recal_isotonic",
        "recal_sigmoid",
        "recal_hist15",
        "recal_scalbin15",
        "recal_metacal_mis05",
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        # "glar_leaves5",
        "glart_both2_leaves5",
        # "glar_unconstrained",
    ]
    # select the subdf that has its index level "post_training" in methods
    df_residual = df_residual[
        df_residual.index.get_level_values("post_training").isin(methods)
    ]

    if ax is None:
        fig, ax = plt.subplots(figsize=(3, 2))
    else:
        fig = ax.get_figure()

    x = "rel_diff_post_training_total_cpu_time"

    print(df_residual)
    print(df_residual.columns)

    # drop rows with nans in col x
    df_residual = df_residual.dropna(subset=[x])

    print(df_residual[x])
    print(df_residual[x].min(), df_residual[x].max())

    # Keep a single value of the "t" level: the first one in index order
    # (not necessarily the smallest).
    t0 = df_residual.index.get_level_values("t").unique()[0]
    df_residual = df_residual.xs(t0, level="t")

    df = df.reset_index()
    df_residual = df_residual.reset_index()

    def rename_method(name):
        """Map a method key to its display name, falling back to the key."""
        return method_rename.get(name, name)

    sns.boxplot(
        data=df_residual,
        x=x,
        y="post_training",
        order=methods,
        hue="post_training",
        hue_order=methods,
        ax=ax,
        palette=method_colors,
        width=0.9,
        gap=0.2,
        formatter=rename_method,
        showfliers=False,
        legend=False,
        log_scale=True,
    )
    sns.stripplot(
        data=df_residual,
        x=x,
        y="post_training",
        order=methods,
        hue="post_training",
        # Same hue ordering as the box plot above.
        hue_order=methods,
        palette=method_colors,
        jitter=0.15,
        alpha=0.6,
        size=2,
        ax=ax,
        legend=False,
        dodge=False,
        edgecolor="white",
        linewidth=0.25,
        formatter=rename_method,
    )

    ref_color = method_colors[ref_post_training]

    ax.axvline(1, color=ref_color, ls="--", lw=1, zorder=10)

    # get default ticklabels fontsize
    fontsize = plt.rcParams["xtick.labelsize"]

    # Minor x ticks at powers of ten only.
    ax.xaxis.set_minor_locator(plt.LogLocator(base=10.0, subs=[1.0]))

    # Shade every other method row.
    # NOTE(review): the row count comes from `df`, which is not restricted to
    # `methods` (only `df_residual` is) -- confirm this is intended.
    n_metrics = df["post_training"].nunique()
    with KeepLim(ax):
        for i in range(0, n_metrics - 1, 2):
            ax.axhspan(i - 0.5, i + 0.5, color=".93", zorder=-1, edgecolor="none", lw=0)

    # disable minor y ticks
    ax.yaxis.set_minor_locator(plt.NullLocator())

    def custom_log_formatter(x, pos):
        """Label a tick as a multiple written as a power of ten."""
        if x <= 0:
            # log10 is undefined here; leave such a tick unlabelled.
            return ""
        # Round instead of truncating so a log10 result that lands a hair
        # below an integer still yields the right exponent.
        exponent = int(round(float(np.log10(x))))
        return rf"$\times 10^{{{exponent}}}$"

    # Set the custom formatter
    ax.xaxis.set_major_formatter(FuncFormatter(custom_log_formatter))

    xlabel = "b. Relative CPU time (s)"

    ax.set(ylabel=None)
    ax.set_xlabel(xlabel, fontsize=8)

    if not ylabels:
        # disable yticklabels
        ax.set_yticklabels([])

    if ylabels:
        add_label_band(
            ax,
            3.45,
            -0.5,
            "Recalibration",
            spine_pos=-0.35,
            tip_pos=-0.32,
            fontsize=fontsize,
        )
        add_label_band(
            ax, 7.5, 3.55, "Other", spine_pos=-0.35, tip_pos=-0.32, fontsize=fontsize
        )

    return fig


@pytest.mark.parametrize(
    "filter",
    [
        None,
        # "below",
        # "above",
    ],
)
def test_fig_times(out, inp, filter):
    """Render ``plot_fig_times`` without y tick labels on a 3x2 inch figure
    and save it as "times", passing the parametrized ``filter`` to
    ``save_fig`` as ``f``."""
    _, axes = plt.subplots(figsize=(3, 2))
    figure = plot_fig_times(inp, ax=axes, ylabels=False)
    save_fig(figure, out, "times", f=filter, pad_inches=0.01)


def test_fig_gain_times(out, inp):
    """Draw the gain panel and the CPU-time panel side by side on one figure
    and save it as "gain_times"."""
    fig, gain_ax = plt.subplots(figsize=(6, 1.5))
    # Second axes attached to the right of the first one.
    times_ax = add_axis(
        "right", gain_ax, size="100%", pad=0.03, align_lim=False, transfer=False
    )

    # Left panel: gains, using the "depth10" reference partitioner.
    plot_fig_gain_magnitudes3(inp, "depth10", ax=gain_ax)
    # Right panel: CPU times, without its own y tick labels.
    plot_fig_times(inp, ax=times_ax, ylabels=False)

    save_fig(fig, out, "gain_times", pad_inches=0.01)


def test_fig_gain_magnitudes4(out, inp):
    """FIG4

    Split violin plot of the ``rel_diff_u_test_emp_mean`` column per
    post-training method. Each violin is split by whether the row's ratio
    ``abs_RGL_mean / abs_RCL`` is below the median of that ratio over the
    reference method's rows. The figure is saved as "magnitudes4".
    """
    rel = True
    ref_partitioner = "leaves5"  # doesn't affect the result
    ref_post_training = "recal_isotonic"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    # Only the first of the two tables returned by merge_tables is used here.
    df, _ = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    methods = [
        # "recal_isotonic",
        "recal_sigmoid",
        "recal_hist15",
        "recal_scalbin15",
        "recal_metacal_mis05",
        "finetuning_sigmoid",
        "stack_hgb",
        "stack_rf",
        "glar_leaves5",
        # "glar_unconstrained",
    ]

    fig, ax = plt.subplots(figsize=(3, 3))

    x = "rel_diff_u_test_emp_mean" if rel else "diff_u_test_emp_mean"

    df = df.reset_index()

    # Lower x limit: smallest value of column `x` among 'recal_sigmoid' rows.
    vmin = df.query("post_training == 'recal_sigmoid'")[x].min()

    df["abs_RGL_RCL_ratio"] = df["abs_RGL_mean"] / df["abs_RCL"]

    # Median of the ratio over the reference method's rows.
    ratio_med = df.query(f"post_training == '{ref_post_training}'")[
        "abs_RGL_RCL_ratio"
    ].median()

    df["abs_RGL_RCL_ratio_low"] = df["abs_RGL_RCL_ratio"] < ratio_med

    sns.violinplot(
        data=df,
        x=x,
        y="post_training",
        order=methods,
        hue="abs_RGL_RCL_ratio_low",
        ax=ax,
        formatter=lambda name: method_rename.get(name, name),
        # `density_norm` is the current name of the deprecated `scale`.
        density_norm="width",
        split=True,
    )

    # Hue levels are False then True, i.e. ratio >= median first.
    labels = [
        r"$\geq \mathrm{med}$",
        r"$< \mathrm{med}$",
    ]

    title = r"$\frac{R_{\mathrm{GL}^{\mathrm{mean}}}}{R_{\mathrm{CL}}}$"
    sns.move_legend(ax, "upper right", labels=labels, title=title)

    ax.set_xlim(vmin, None)

    # Shade every other method row.
    n_metrics = df["post_training"].nunique()
    with KeepLim(ax):
        for i in range(0, n_metrics - 1, 2):
            ax.axhspan(i - 0.5, i + 0.5, color=".93", zorder=-1, edgecolor="none", lw=0)

    # disable minor y ticks
    ax.yaxis.set_minor_locator(plt.NullLocator())

    if rel:
        # Prefix each tick value with a multiplication sign through a
        # formatter. Rewriting the tick label texts would not work here:
        # they are still empty before the figure is drawn, and setting them
        # would freeze the labels.
        ax.xaxis.set_major_formatter(r"$\times${x:g}")
        xline = 1
    else:
        xline = 0

    ax.axvline(xline, color="black", ls="--", lw=0.5, zorder=10)

    if rel:
        xlabel = "Relative gain to isotonic recalibration"
    else:
        xlabel = "Gain over isotonic recalibration"

    ax.set(
        xlabel=xlabel,
        ylabel="Post-training method",
    )

    save_fig(fig, out, "magnitudes4", rel=rel)


def test_fig_gain_vs_metrics(out, inp):
    """FIG Appendix: gain of a post-training method against reference metrics.

    Merges the utility and regret tables located next to ``inp`` and, for every
    (post-training method, metric) pair selected below, saves one scatter plot
    of the gain column ``diff_norm_u_test_emp_mean`` against the metric, with a
    LOWESS trend line and the squared Spearman correlation as annotation.

    Args:
        out: Output location forwarded to ``save_fig``.
        inp: Input path; tables are read from the ``utility_table`` and
            ``regret_table`` directories of its parent.
    """
    ref_partitioner = "leaves5"
    ref_post_training = "none"
    auc_threshold = 0.75
    t_target = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975]

    dirpath_utility = Path(inp).parent / "utility_table"
    dirpath_init = Path(inp).parent / "regret_table"

    # The second table returned by merge_tables is not used in this figure.
    df, _df_residual = merge_tables(
        dirpath_utility,
        dirpath_init,
        ref_post_training,
        ref_partitioner,
        t_target=t_target,
        auc_threshold=auc_threshold,
    )

    def plot_one(m, x):
        """Plot and save the gain of post-training method ``m`` against ``x``."""
        # Select the rows of the requested method. Use the parameter ``m``
        # rather than a variable captured from the enclosing loop, so the
        # helper does not depend on how its caller names its loop variable.
        _df = df.xs(m, axis=0, level="post_training")
        x_label = metrics_rename.get(x, x)
        y = "diff_norm_u_test_emp_mean"

        set_latex_font()
        fig, ax = plt.subplots(figsize=(2.1, 1.35))

        if x in ["abs_t", "abs_dt"]:
            # Threshold-based metrics: average x and y per threshold t.
            _df = _df[[x, y]].groupby("t").mean()

        _df_x = _df[x]
        _df_y = _df[y]

        color = method_colors.get(m, "tab:blue")
        method_label = method_rename.get(m, m)

        ax.scatter(
            _df_x,
            _df_y,
            edgecolors="white",
            linewidths=0.5,
            color=color,
            label=f"{method_label}",
        )

        # LOWESS trend; frac controls the smoothing parameter.
        lowess = sm.nonparametric.lowess(_df_y, _df_x, frac=0.3)
        ax.plot(lowess[:, 0], lowess[:, 1], color="black", label=method_label)

        # Report the number of missing values in x and y.
        print(_df_x.isna().sum(), _df_y.isna().sum())

        r = scipy.stats.spearmanr(_df_x, _df_y)

        if "abs_RCL" in x:
            # Inset axes zooming on the region close to the origin.
            x1, x2, y1, y2 = (-0.005, 0.05, -0.005, 0.05)
            axins = ax.inset_axes(
                [0.55, 0.05, 0.4, 0.4],
                xlim=(x1, x2),
                ylim=(y1, y2),
                xticklabels=[],
                yticklabels=[],
            )
            axins.scatter(
                _df_x,
                _df_y,
                edgecolors="white",
                linewidths=0.5,
                color=color,
            )
            # Identity line in the inset.
            axins.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.75, zorder=1)
            ax.indicate_inset_zoom(axins, edgecolor="black")

            # Identity line in the main axes, drawn under KeepLim.
            with KeepLim(ax):
                ax.plot([-1, 1], [-1, 1], color="black", ls="--", lw=0.5, zorder=1)

            # Use the same range on both axes.
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            xmin = min(xlim[0], ylim[0])
            xmax = max(xlim[1], ylim[1])
            ax.set_xlim(xmin, xmax)
            ax.set_ylim(xmin, xmax)

        ax.set(xlabel=x_label)
        ylabel = ax.set_ylabel(f"Gain of {method_label}")
        # Manual vertical placement of the y label for specific methods.
        if m == "recal_scalbin15":
            ylabel.set_y(0.35)
        if m in ["stack_hgb", "finetuning_sigmoid"]:
            ylabel.set_y(0.45)

        # Annotate the squared Spearman correlation in the top-left corner of
        # the axes (top-right for the AUC metric).
        ha = "left"
        va = "top"
        xpos = 0.04
        ypos = 0.96
        if "abs_auc" in x:
            ha = "right"
            xpos = 0.96

        ax.text(
            xpos,
            ypos,
            f"Correlation: $r^2 = {np.square(r.correlation):.2f}$",
            ha=ha,
            va=va,
            transform=ax.transAxes,
        )
        save_fig(fig, out, r=ref_post_training, y=m, x=x, pad_inches=0.05)

    post_trainings = [
        # "recal_isotonic",
        # "recal_sigmoid",
        # "recal_hist15",
        # "recal_scalbin15",
        # "recal_metacal_mis05",
        # "stack_rf",
        # "stack_hgb",
        # "finetuning_sigmoid",
        # "glar_leaves5",
        "glart_both2_leaves5",
        #
        # "stack_logistic",
        # "stack_naive_bayes",
        # "stack_gaussian_process",
        # "glar_unconstrained",
    ]

    against = [
        "abs_acc",
        "abs_auc",
        "abs_CL",
        "abs_ECE",
        "abs_MCE",
        "abs_RMSCE",
        "abs_RCL",
        "abs_L",
        "abs_RCL_RGL_mean",
        #
        # "abs_t",
        # "abs_dt",
        # "abs_GL",
        # "abs_RCL_RGL",
        # "abs_RGL_UB",
        # "abs_RGL_LB",
        # "abs_RGL_mean",
        # "abs_RCL_RGL_LB",
        # "abs_RCL_RGL_UB",
        # "abs_event_rate_dist",
    ]

    for post_training, x in itertools.product(post_trainings, against):
        plot_one(post_training, x)

    # for x in [
    #     # "abs_auc",
    #     # "abs_acc",
    #     # "abs_ECE",
    #     # "abs_MCE",
    #     # "abs_RMSCE",
    #     # "abs_CL",
    #     "abs_RCL",
    #     # "abs_L",
    #     # "abs_event_rate_dist",
    # ]:
    #     plot_one(post_training, x)


# Names of the models to evaluate. Each name is looked up in `model_registry`
# (see test_plot_gain_vs_threshold and evaluate_glar_all). Commented-out
# entries are disabled.
model_names = [
    "cnerg1",
    "cnerg2",
    "cnerg3",
    "cnerg4",
    "cnerg5",
    # "fb_roberta1",
    "fb_roberta2",
    # "mistral_instruct",
]
# Names of the datasets to evaluate. Each name is looked up in `ds_registry`
# and is crossed with `model_names` (pytest parametrization and
# itertools.product in evaluate_glar_all). Commented-out entries are disabled.
ds_names = [
    # "hate",
    # "hate_merged_en",
    # "hate_merged_en2",
    # "hate_merged_no_en",
    # "hate_merged_large_en",
    # "hate_dyn_gen",
    # "merged_hate_check",
    # "hate_merged_no_en2",
    # "hate_merged_large_no_en",
    # "hate_merged_large",
    "hate_en_tweets",
    "hate_en_speech18",
    "hate_en_speech_off",
    "hate_en_davidson",
    "hate_en_gender",
    "hate_en_frenk",
    "hate_en_check",
    "hate_en_twitter",
    "hate_en_open",
]


@pytest.mark.parametrize("model_name", model_names)
@pytest.mark.parametrize("ds_name", ds_names)
def test_plot_gain_vs_threshold(out, model_name, ds_name):
    """Save the gain-vs-threshold figure for one (dataset, model) pair."""
    # Instantiate the dataset first, then the model, from their registries.
    dataset = ds_registry[ds_name]()
    pretrained = model_registry[model_name]()
    save_fig(
        plot_gain_vs_threshold(dataset, pretrained),
        out,
        "EU_emp",
        ds=ds_name,
        model=model_name,
    )


def plot_gain_vs_threshold(ds: ForwardedMixin, model: PretrainedMixin) -> plt.Figure:
    """Plot the relative empirical-utility gain over raw scores per threshold.

    Fits GLAR, GLARThresholded and a recalibration of the scores on the
    validation splits, computes the empirical utility of each set of test
    scores for the utilities returned by ``get_constant_utilty(100)``, and plots
    the gain relative to the raw scores (in percent) against the associated
    thresholds. A dashed horizontal line marks the mean gain of each method.

    Args:
        ds: Dataset exposing ``get_arrays(model, finetuned, rs)``.
        model: Pretrained model forwarded to ``ds.get_arrays``.

    Returns:
        The matplotlib figure containing the gain curves.
    """
    rs = 0
    finetuned = False

    plt.style.use(["science"])
    plt.rcParams.update(bundles.icml2024())

    # The fourth array returned by get_arrays is not needed here.
    (X, y, S, _), (idx_val1, idx_val2, idx_test) = ds.get_arrays(model, finetuned, rs)

    # Full validation set (both halves), used for plain recalibration.
    idx_val = np.concatenate([idx_val1, idx_val2])
    S_val, y_val = S[idx_val], y[idx_val]

    X_val1, S_val1, y_val1 = X[idx_val1], S[idx_val1], y[idx_val1]
    X_val2, S_val2, y_val2 = X[idx_val2], S[idx_val2], y[idx_val2]
    X_test, S_test, y_test = X[idx_test], S[idx_test], y[idx_test]

    binwise_fit = True

    partitioner_est = DecisionTreeRegressor(
        random_state=0,
        max_leaf_nodes=5,
    )
    partitioner = glest.Partitioner(
        estimator=partitioner_est,
        predict_method="apply",
        n_bins=10,
        strategy="quantile",
        binwise_fit=binwise_fit,
        verbose=0,
        raise_on_fit_error=True,
    )

    # NOTE(review): `fit_predict_clf` and `recalibrate_scores` are commented
    # out of the `decision.xp.common` import at the top of the file -- confirm
    # they are defined or imported elsewhere.
    S_test = fit_predict_clf(
        X_train=None,
        y_train=None,
        X_val1=X_val1,
        y_val1=y_val1,
        X_val2=X_val2,
        y_val2=y_val2,
        X_test=X_test,
        classifier=(S_val1, S_val2, S_test),
        rule=None,
    )

    glar = GLAR(partitioner)
    glar.fit(S_val1, y_val1, X_val1, val_data=(S_val2, y_val2, X_val2))
    Rp_test = glar.predict_proba(S_test, X_test)

    # NOTE(review): other call sites in this file configure GLARThresholded
    # with `mode=...`; confirm that `binwise` is still an accepted argument.
    glart = GLARThresholded(partitioner, GL_min=0.01, binwise=False)
    glart.fit(S_val1, y_val1, X_val1, val_data=(S_val2, y_val2, X_val2))
    R2p_test = glart.predict_proba(S_test, X_test)

    [SCp_test] = recalibrate_scores(S_val, y_val, [S_test])

    U = get_constant_utilty(100)
    t = get_threshold_from_utility(U)

    EU_s = u_emp_from_score(S_test, y_test, t, U, return_action=False)
    EU_r = u_emp_from_score(Rp_test, y_test, t, U, return_action=False)
    EU_r2 = u_emp_from_score(R2p_test, y_test, t, U, return_action=False)
    EU_c = u_emp_from_score(SCp_test, y_test, t, U, return_action=False)

    # One row per threshold; utilities are averaged over axis 0.
    metrics = pd.DataFrame(
        {
            "t": t,
            "EU_s_emp": EU_s.mean(axis=0),
            "EU_r_emp": EU_r.mean(axis=0),
            "EU_r2_emp": EU_r2.mean(axis=0),
            "EU_c_emp": EU_c.mean(axis=0),
        }
    )

    set_latex_font()

    baseline = metrics["EU_s_emp"]

    def rel(EU):
        """Gain of ``EU`` relative to the raw-score utility, in percent."""
        return 100 * np.divide(EU - baseline, baseline)

    # (column, color, legend label) of each curve, in drawing order.
    curves = [
        ("EU_r_emp", "tab:red", r"$\mathrm{GLAR}$"),
        ("EU_r2_emp", "tab:green", r"$\mathrm{GLAR} T$"),
        ("EU_c_emp", "tab:blue", r"Recalibration"),
    ]

    fig, ax = plt.subplots(figsize=(3, 2))
    ax.axhline(0, color="black", lw=0.5)

    for column, color, label in curves:
        gain = rel(metrics[column])
        ax.plot(metrics["t"], gain, color=color, label=label)
        # Dashed line at the mean gain over all thresholds.
        ax.axhline(np.mean(gain), color=color, ls="--", lw=0.5)

    ax.legend()
    ax.set(
        xlabel=r"Threshold $t^{\star}$ with varying utilities $U$",
        ylabel=r"Utility gain vs raw (\%)",
    )
    return fig


def evaluate_glar(
    ds: ForwardedMixin,
    model: PretrainedMixin,
    glar_methods: List[Tuple[str, PostTraining, dict]],
) -> Dict[str, Any]:
    """Compute empirical utilities of raw, recalibrated and GLAR-corrected scores.

    Args:
        ds: Dataset exposing ``get_arrays(model, finetuned, rs)``.
        model: Pretrained model forwarded to ``ds.get_arrays``.
        glar_methods: ``(name, cls, params)`` triples. Each ``cls`` is
            instantiated as ``cls(partitioner=partitioner, **params)``, fitted
            on the validation splits and evaluated on the test split.

    Returns:
        A dict with the mean empirical utility of the raw scores (``"EU_s"``),
        of GLAR (``"EU_r"``) and of the recalibrated scores (``"EU_c"``), the
        entries of ``GLEstimator.metrics()``, and one ``"EU_g_<name>"`` entry
        per element of ``glar_methods``.
    """
    rs = 0
    finetuned = False

    # NOTE: this changes global matplotlib settings although nothing is
    # plotted here; kept so in-process callers see the same rcParams as before.
    plt.style.use(["science"])
    plt.rcParams.update(bundles.icml2024())

    # The fourth array returned by get_arrays is not needed here.
    (X, y, S, _), (idx_val1, idx_val2, idx_test) = ds.get_arrays(model, finetuned, rs)

    # Full validation set (both halves), used for plain recalibration.
    idx_val = np.concatenate([idx_val1, idx_val2])
    S_val, y_val = S[idx_val], y[idx_val]

    X_val1, S_val1, y_val1 = X[idx_val1], S[idx_val1], y[idx_val1]
    X_val2, S_val2, y_val2 = X[idx_val2], S[idx_val2], y[idx_val2]
    X_test, S_test, y_test = X[idx_test], S[idx_test], y[idx_test]

    binwise_fit = True

    partitioner_est = DecisionTreeRegressor(
        random_state=0,
        max_leaf_nodes=5,
    )
    partitioner = glest.Partitioner(
        estimator=partitioner_est,
        predict_method="apply",
        n_bins=10,
        strategy="quantile",
        binwise_fit=binwise_fit,
        verbose=0,
        raise_on_fit_error=True,
    )

    # The same partitioner instance is shared by every estimator below, so the
    # fitting order is kept as is.
    gle = glest.GLEstimator(S_val1, partitioner=partitioner, random_state=0)
    gle.fit(X_val1, y_val1, test_data=(X_val2, y_val2, S_val2))

    glar = GLAR(partitioner)
    glar.fit(S_val1, y_val1, X_val1, val_data=(S_val2, y_val2, X_val2))
    R_test = glar.predict_proba(S_test, X_test)

    # NOTE(review): `recalibrate_scores` is commented out of the
    # `decision.xp.common` import at the top of the file -- confirm it is
    # defined or imported elsewhere.
    [C_test] = recalibrate_scores(S_val, y_val, [S_test])

    U = get_constant_utilty(100)
    t = get_threshold_from_utility(U)

    EU_s = u_emp_from_score(S_test, y_test, t, U, return_action=False)
    EU_c = u_emp_from_score(C_test, y_test, t, U, return_action=False)
    EU_r = u_emp_from_score(R_test, y_test, t, U, return_action=False)

    # Key order is preserved: it determines the column order of the table
    # built from these dicts.
    metrics = {
        "EU_s": EU_s.mean(),
        "EU_r": EU_r.mean(),
        "EU_c": EU_c.mean(),
    }
    metrics.update(gle.metrics())

    for name, method_class, method_params in glar_methods:
        method = method_class(partitioner=partitioner, **method_params)
        method.fit(S_val1, y_val1, X_val1, val_data=(S_val2, y_val2, X_val2))
        # These methods receive the thresholds at prediction time.
        G_test = method.predict_proba(S_test, X_test, t=t)

        EU_g = u_emp_from_score(G_test, y_test, t, U, return_action=False)
        metrics[f"EU_g_{name}"] = EU_g.mean()

    return metrics


def test_evaluate_glar_all(out, inp, recompute=False):
    """Plot the GLAR benchmark, reusing the cached CSV table when available.

    The table is recomputed (and the CSV rewritten) when ``recompute`` is true
    or when no cached file exists at the path given by ``save_path``.
    """
    cache_file = Path(save_path(inp, "csv", "table"))

    if recompute or not cache_file.exists():
        results = evaluate_glar_all()
        results.to_csv(cache_file, index=False)
    else:
        results = pd.read_csv(cache_file)

    save_fig(plot_benchmark_glar_all(results), out)


def evaluate_glar_all():
    """Run ``evaluate_glar`` on every (dataset, model) pair in parallel.

    Returns a DataFrame with one row per pair: the metrics returned by
    ``evaluate_glar`` plus the ``"ds"`` and ``"model"`` names.
    """
    # (name, class, constructor kwargs) of the thresholded variants to compare.
    glar_methods = [
        # ("glar_t_all", GLARThresholded, {"GL_min": 0.01, "mode": "all"}),
        # ("glart_rgl_bw", GLARThresholded, {"GL_min": 0.01, "mode": "binwise"}),
        ("glart_rgl_bw2", GLARThresholded, {"GL_min": 0.02, "mode": "binwise"}),
        # ("glart_rgl_both", GLARThresholded, {"GL_min": 0.01, "mode": "both"}),
        ("glart_rgl_both2", GLARThresholded, {"GL_min": 0.02, "mode": "both"}),
        # ("glart_rgl_bw2", GLARThresholded, {"GL_min": 0.02, "mode": "binwise"}),
        # ("glar_t_both", GLARThresholded, {"GL_min": 0.01, "mode": "both"}),
    ]

    def run_one(ds_name, model_name):
        # Instantiate the dataset first, then the model, from their registries.
        dataset = ds_registry[ds_name]()
        pretrained = model_registry[model_name]()
        result = evaluate_glar(dataset, pretrained, glar_methods)
        result["ds"] = ds_name
        result["model"] = model_name
        return result

    # Materialize the pairs so tqdm knows the total number of jobs.
    pairs = list(itertools.product(ds_names, model_names))
    jobs = (delayed(run_one)(*pair) for pair in tqdm(pairs))
    records = Parallel(n_jobs=-1)(jobs)

    return pd.DataFrame(records)


def plot_benchmark_glar_all(df: pd.DataFrame) -> plt.Figure:
    """Scatter the utility gain of each method, one row per (dataset, model).

    Rows are sorted by the ``"GL"`` column. The x axis is the relative gain over
    the raw scores (``"EU_s"``, in percent) minus the same quantity for the
    recalibrated scores (``"EU_c"``), so recalibration sits at zero.
    """
    n_results = len(df)

    # sort_values returns a new frame, so the caller's frame is left untouched.
    df = df.sort_values("GL").reset_index()

    # Tick label: "<ds>_<model> (GL=<rounded GL>)".
    gl_text = df["GL"].round(4).astype(str)
    labels = df["ds"] + "_" + df["model"] + " (GL=" + gl_text + ")"

    set_latex_font()
    fig, ax = plt.subplots(1, 1, figsize=(5, 4))

    def _rel(EU):
        # Gain relative to the raw-score utility, in percent.
        raw = df["EU_s"]
        return 100 * (np.divide(EU - raw, raw))

    def rel(EU):
        # Gain in excess of the recalibration gain.
        return _rel(EU) - _rel(df["EU_c"])

    rows = np.arange(n_results)
    marker_style = {"marker": "o", "edgecolors": "white", "linewidths": 0.5}

    ax.scatter(
        rel(df["EU_c"]), rows, color="black", label="Recalibration", **marker_style
    )
    ax.scatter(rel(df["EU_r"]), rows, color="tab:blue", label="GLAR", **marker_style)

    rename = {
        "EU_g_glart_rgl_bw2": "GLART binwise",
        "EU_g_glart_rgl_both2": "GLART both",
    }

    # Restart the color cycle right after tab10's first color (blue, already
    # used for GLAR), so the remaining methods start at orange.
    ax.set_prop_cycle(cycler(color=plt.cm.tab10.colors[1:]))

    # One scatter per column whose name starts with "EU_g_".
    for col in df.columns[df.columns.str.startswith("EU_g_")]:
        ax.scatter(rel(df[col]), rows, label=rename.get(col, col), **marker_style)

    add_legend(ax)
    add_horizontal_bands(ax, n_results)

    ax.set(xlabel="Gain vs isotonic (%)")

    # One y tick per result row, labelled with the dataset/model description.
    ax.set_yticks(range(n_results))
    ax.set_yticklabels(labels, fontsize=4)

    return fig
