import itertools
import os
from itertools import product
from os.path import join

import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.lines as lines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import seaborn as sns
import torch
from calibration._linalg import create_orthonormal_vector
from calibration._plot import (plot_ffstar_1d, plot_frac_pos_vs_scores,
                               plot_score_vs_probas2, set_latex_font)
from calibration._utils import compute_classif_metrics, save_fig, save_path
from calibration.CalibrationExample import CustomUniform, SigmoidExample, Steps
from calibration.xp_nn_calibration._utils import (bin_train_test_split,
                                                  binarize_multiclass_max,
                                                  calibrate_scores,
                                                  compute_calib_metrics,
                                                  grouping_loss_lower_bound)
from calibration.xp_nn_calibration.main import (cluster_evaluate,
                                                cluster_evaluate_max)
from calibration.xp_nn_calibration.networks import (IMAGENET_ConvNeXt,
                                                    IMAGENET_VisionTransformer)
from grouping.xp_grouping._plot import plot_simu
from grouping.xp_grouping.ClassificationExample import (BaseExample,
                                                        Link1DExample,
                                                        SigmoidExample)
from joblib import Parallel, delayed
from matplotlib.colors import LogNorm, Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm import tqdm

from ._plot import plot_metrics
from .gl_induced import CEstimator, estimate_GL_induced


def _compute_metrics(X, y_labels, S, Q, n_bins=15):
    """Compute classification, calibration and cost metrics for a 1D example.

    Parameters
    ----------
    X : array-like
        Feature values. Reshaped to a single column before clustering, so it
        is presumably 1D — TODO confirm for new callers.
    y_labels : array-like
        Observed binary labels.
    S : array-like
        Confidence scores of the model.
    Q : array-like
        True probabilities, forwarded to ``compute_classif_metrics``.
    n_bins : int, default=15
        Number of bins used both for the decision-tree clustering and for the
        calibration metrics.

    Returns
    -------
    dict
        The metrics of ``compute_classif_metrics`` and
        ``compute_calib_metrics`` plus ``'cost_error'``: the absolute
        difference between the mean cost computed with the scores and the
        mean cost computed with the observed labels.
    """

    def expected_cost(x, p):
        # A positive costs 100 (50 where x < 0) and a negative costs 10
        # (5 where x < 0). ``p`` weights the positive case: it is either a
        # label or a score.
        C_pos = np.full_like(x, 100)
        C_pos[x < 0] = 50
        C_neg = np.full_like(x, 10)
        C_neg[x < 0] = 5
        return p*C_pos + (1-p)*C_neg

    frac_pos, counts, mean_scores, *_ = cluster_evaluate(
        X.reshape(-1, 1), y_labels, S, bins=n_bins,
        clustering='decision_tree', max_clusters_bin=2)

    metrics = compute_classif_metrics(S, y_labels, Q)
    metrics.update(compute_calib_metrics(frac_pos, counts, S, y_labels, bins=n_bins))
    metrics.update({
        'cost_error': abs(np.mean(expected_cost(X, S)) - np.mean(expected_cost(X, y_labels))),
    })
    return metrics


@pytest.mark.parametrize('alpha', np.linspace(0, 1, 11))
@pytest.mark.parametrize('bayes_opt', [False, True])
def test_1d_example(alpha, bayes_opt, out):
    """Record metrics and a Q-vs-S figure for a ``CustomUniform`` example.

    The example is ``'sin2'`` when ``bayes_opt`` is true and ``'sin'``
    otherwise. One metrics row per parametrization is appended to
    ``pytest.buffer_df``, and the whole buffer is rewritten to
    ``1d_example.csv``.
    """
    n_samples = 100000
    if bayes_opt:
        name = 'sin2'
    else:
        name = 'sin'

    example = CustomUniform(name=name, alpha=alpha)
    X, y_labels = example.generate_X_y(n_samples)
    scores = example.f(X)
    true_probas = example.f_star(X)

    row = _compute_metrics(X, y_labels, scores, true_probas)
    row.update(name=name, alpha=alpha, bayes_opt=bayes_opt,
               GL=example.analytical_gl()[0])

    buffer = pd.concat([pytest.buffer_df, pd.DataFrame(row, index=[0])])
    pytest.buffer_df = buffer
    print(buffer)
    buffer.to_csv('1d_example.csv')

    figure = plot_score_vs_probas2(scores, y_labels, true_probas, max_samples=1000,
                                   height=2.5, grid_space=0, lim_margin=0.03, ncol=1,
                                   plot_first_last_bins=False)
    save_fig(figure, out, n=name, a=alpha, order=['n', 'a'])


@pytest.mark.parametrize('bayes_opt', [False, True])
@pytest.mark.parametrize('n_steps', list(range(1, 11)) + [None])
def test_steps_example(bayes_opt, n_steps, out):
    """Record metrics and a Q-vs-S figure for a uniform ``Steps`` example.

    One metrics row per ``(bayes_opt, n_steps)`` pair is appended to
    ``pytest.buffer_df``, and the whole buffer is rewritten to
    ``steps_example.csv``.
    """
    example = Steps(n_steps=n_steps, dist='uniform', bayes_opt=bayes_opt)
    X, y_labels = example.generate_X_y(100000)
    scores, true_probas = example.f(X), example.f_star(X)

    row = _compute_metrics(X, y_labels, scores, true_probas)
    row.update(n_steps=n_steps, bayes_opt=bayes_opt,
               GL=example.analytical_gl()[0])

    buffer = pd.concat([pytest.buffer_df, pd.DataFrame(row, index=[0])])
    pytest.buffer_df = buffer
    print(buffer)
    buffer.to_csv('steps_example.csv')

    figure = plot_score_vs_probas2(scores, y_labels, true_probas, max_samples=1000,
                                   height=2.5, grid_space=0, lim_margin=0.03, ncol=1,
                                   plot_first_last_bins=False)
    save_fig(figure, out, nsteps=n_steps, bayes_opt=bayes_opt)


@pytest.mark.parametrize('delta_width', [None, 0] + list(np.logspace(-3, 2, 11)))
@pytest.mark.parametrize('bayes_opt', [False, True])
def test_sigmoid_example(delta_width, out, bayes_opt):
    """Record classification metrics and a Q-vs-S figure for a 2D sigmoid example.

    One metrics row per ``(delta_width, bayes_opt)`` pair is appended to
    ``pytest.buffer_df``, and the whole buffer is rewritten to
    ``sigmoid_example.csv``.
    """
    # NOTE(review): ``SigmoidExample`` is imported both from
    # ``calibration.CalibrationExample`` and from
    # ``grouping.xp_grouping.ClassificationExample``. The latter import comes
    # last and shadows the former; confirm which class is intended here.
    n = 10000
    d = 2

    w = np.ones(d)
    w_perp = create_orthonormal_vector(w)

    ex = SigmoidExample(w, w_perp, bayes_opt=bayes_opt, delta_width=delta_width, lbd=2)
    X, y_labels = ex.generate_X_y(n)

    S = ex.f(X)
    Q = ex.f_star(X)

    metrics = compute_classif_metrics(S, y_labels, Q)
    metrics.update({
        'delta_width': delta_width,
        'bayes_opt': bayes_opt,
    })
    metrics = pd.DataFrame(metrics, index=[0])

    pytest.buffer_df = pd.concat([pytest.buffer_df, metrics])
    print(pytest.buffer_df)
    pytest.buffer_df.to_csv('sigmoid_example.csv')

    fig = plot_score_vs_probas2(S, y_labels, Q, max_samples=1000,
                                height=2.5, grid_space=0, lim_margin=0.03, ncol=1,
                                plot_first_last_bins=False)
    # Include bayes_opt in the filename, as the other example tests do.
    # Otherwise the two bayes_opt runs sharing a delta_width may be saved
    # under the same name, depending on how ``out`` is built.
    save_fig(fig, out, dw=delta_width, bayes_opt=bayes_opt)


@pytest.mark.parametrize('bayes_opt', [False, True])
@pytest.mark.parametrize('params', [
    ('steps_example.csv', 'n_steps', True),
    # 'sigmoid_example.csv',
    # ('1d_example.csv', 'alpha'),
])
def test_plot(out, bayes_opt, params):
    """Plot the metrics accumulated in a CSV written by one of the example tests.

    ``params`` is a ``(csv filename, swept column, rev)`` triple. The swept
    column is passed as ``var`` and ``rev`` is forwarded unchanged to
    ``plot_metrics``.
    """
    csv_name, swept_var, reverse = params
    set_latex_font()
    plt.rc('axes', labelsize=12)
    metrics_df = pd.read_csv(csv_name, index_col=0)
    figure = plot_metrics(metrics_df, var=swept_var, bayes_opt=bayes_opt, rev=reverse)
    save_fig(figure, out, opt=bayes_opt, f=csv_name)


@pytest.mark.parametrize('bayes_opt', [
    False,
    True,
])
@pytest.mark.parametrize('n_steps', list(range(1, 21)) + [None])
def test_fig_examples(n_steps, bayes_opt, out):
    """Save two figures for a uniform ``Steps`` example.

    1. ``n='X'``: S, Q and P evaluated on a grid over [-2, 2] and drawn with
       ``plot_ffstar_1d``.
    2. ``n='QS'``: Q against S for sampled points, drawn with
       ``plot_score_vs_probas2``.

    ``plt.rc`` mutates the global rcParams, so the styling set here persists
    into tests that run afterwards.
    """
    n = 100000
    figsize = (2, 2.2)
    max_samples = 1000
    ex = Steps(n_steps, dist='uniform', bayes_opt=bayes_opt)

    # Styling for the first figure.
    plt.rc('legend', fontsize=10, title_fontsize=12, handletextpad=0.5,
           columnspacing=1.3, borderpad=0.2, borderaxespad=0.2,
           handlelength=1, labelspacing=0.1)
    plt.rc('xtick', labelsize=9)
    plt.rc('ytick', labelsize=9)
    plt.rc('axes', labelsize=12)

    m = 100
    x_min = -2
    x_max = 2
    XX = np.linspace(x_min, x_max, m)
    Q = ex.f_star(XX)
    S = ex.f(XX)
    P = ex.p(XX)

    fig = plot_ffstar_1d(S, Q, P, x_min=x_min, x_max=x_max, figsize=figsize,
                         lw=1.5, disc_gap=0)

    ax = fig.axes[1]
    ax.set_xticks([-2, 2])
    ax.set_xticklabels(['$-2$', '2'])
    # Replace the axis label with an annotation placed just below the axes.
    delta = 0.04
    ax.set_xlabel(None)
    ax.annotate('$X$', xy=(0.5, -delta), xytext=(0.5, -delta),
                xycoords='axes fraction', ha='center', va='top',
                fontsize=plt.rcParams['axes.labelsize'],
                )

    save_fig(fig, out, n='X', bayes_opt=bayes_opt, nsteps=n_steps, order=['n', 'nsteps'], pad_inches=0.02)

    # Styling for the second figure.
    plt.rc('legend', title_fontsize=10, fontsize=10, handletextpad=0.01,
           columnspacing=0.02, borderpad=0.2, borderaxespad=0.2,
           handlelength=0.8, labelspacing=0.1)
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)
    plt.rc('axes', labelsize=10.5)

    X, y_labels = ex.generate_X_y(n=n)

    y_scores = ex.f(X)
    y_true_probas = ex.f_star(X)
    fig = plot_score_vs_probas2(y_scores, y_labels, y_true_probas, max_samples=max_samples,
                                height=2.5, grid_space=0, lim_margin=0.03, ncol=1,
                                plot_first_last_bins=False)

    ax = fig.axes[0]
    ax.set_xlabel(None)
    ax.set_ylabel(None)

    # Only the end ticks are labelled; the axis titles are drawn as
    # annotations just outside the axes.
    delta = 0.04
    ax.set_xticks([0, 0.5, 1])
    ax.set_xticklabels(['0', '', '1'])
    ax.annotate('Confidence score $S$', xy=(0.5, -delta), xytext=(0.5, -delta),
                xycoords='axes fraction', ha='center', va='top',
                fontsize=plt.rcParams['axes.labelsize'],
                )

    delta = 0.02
    ax.set_yticks([0, 0.25, 0.5, 0.75, 1])
    ax.set_yticklabels(['0', '', '', '', '1'])
    ax.annotate('True probability $Q$', xy=(-delta, 0.5), xytext=(-delta, 0.5),
                xycoords='axes fraction', ha='right', va='center',
                fontsize=plt.rcParams['axes.labelsize'],
                rotation=90,
                )
    save_fig(fig, out, n='QS', bayes_opt=bayes_opt, nsteps=n_steps, order=['n', 'nsteps'], pad_inches=0.02)



@pytest.mark.parametrize('calibrate', [
    None,
    # 'isotonic'
])
@pytest.mark.parametrize('net', [
    # IMAGENET_AlexNet,
    # IMAGENET_VGG,
    # IMAGENET_ResNet,
    # IMAGENET_DenseNet,
    # IMAGENET_Inception,
    # IMAGENET_GoogLeNet,
    # IMAGENET_ShuffleNet,
    # IMAGENET_MobileNet,
    # IMAGENET_ResNext,
    # IMAGENET_WideResNet,
    # IMAGENET_MNASNet,
    # IMAGENET_EfficientNet,
    # IMAGENET_RegNet,
    # IMAGENET_VisionTransformer,
    IMAGENET_ConvNeXt,
])
@pytest.mark.parametrize('split', [
    'test_r',
    # 'test_c:_merged_no_rep5',
    # 'test_c:snow5',
    # 'val',
])
@pytest.mark.parametrize('clustering', [
    # 'kmeans',
    'decision_tree',
])
def test_upperbound(calibrate, net, split, out, n_jobs, clustering):
    """Print grouping-loss bounds and the accuracy bounds derived from them.

    Steps:

    1. Load the network's cached outputs (``Xt.pt``, ``y_scores.pt``,
       ``y_labels.pt``).
    2. Binarize the scores with ``binarize_multiclass_max`` and draw a
       per-bin train/test split.
    3. Cluster the samples inside each confidence bin with
       ``cluster_evaluate_max``.
    4. Feed the resulting fractions of positives and counts to the
       grouping-loss lower bound (debiased) and upper bound estimators.
    5. Print those bounds together with the accuracy bounds derived from
       them.

    No figure is saved.
    """
    if clustering != 'decision_tree' and split != 'test_r':
        pytest.skip(f'We use kmeans only on imagenet-r: {clustering} on {split}')

    if clustering not in ('decision_tree', 'kmeans'):
        raise ValueError(f'Unknown {clustering}')

    n_bins = 15
    test_size = 0.5
    max_clusters_bin = 20  # for decision_tree only
    n_clusters = 2  # for kmeans only
    breakout = False

    # NOTE(review): this test draws no figure. However, ``plt.rc`` mutates the
    # global rcParams and tests that run afterwards inherit these values,
    # so they are kept unchanged.
    plt.rc('legend', title_fontsize=10, fontsize=10, handletextpad=0.01,
           columnspacing=0.02, borderpad=0.3, borderaxespad=0.2,
           handlelength=1.2, labelspacing=0.1)
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)
    plt.rc('axes', labelsize=10.5)

    net = net(split=split)

    dirpath = net.get_default_dirpath()
    Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
    y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()
    y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()

    y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)
    splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1, bins=n_bins, random_state=0)
    _, test_idx = next(splitter)

    if calibrate is not None:
        y_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                       test_size=test_idx, max_calibration=True)
        y_labels = y_well_guess

    frac_pos, counts, mean_scores, *_ = cluster_evaluate_max(
        Xt, y_labels, y_scores,
        breakout=breakout,
        bins=n_bins,
        clustering=clustering,
        test_size=test_idx,
        min_samples_leaf=None,
        max_clusters_bin=max_clusters_bin,
        n_clusters=n_clusters,
        verbose=1,
        n_jobs=n_jobs,
        binary_scores=calibrate is not None,
        return_clustering=True,
    )

    from calibration.xp_nn_calibration._utils import (
        calibration_curve, grouping_loss_upper_bound_c)

    LB = grouping_loss_lower_bound(frac_pos, counts, debiased=True)
    LB_bins = grouping_loss_lower_bound(frac_pos, counts, debiased=True, reduce_bin=False)
    UB = grouping_loss_upper_bound_c(frac_pos, counts)

    prob_bins, mean_bins = calibration_curve(frac_pos, counts, mean_scores, remove_empty=False)

    C = prob_bins
    S = mean_bins

    # Elementwise terms whose formula is selected by which side of 0.5 the
    # mean score S lies on.
    I = (0.5 - C)*C*(S < 0.5) + (C - 0.5)*(1 - C)*(S >= 0.5)
    I2 = (C - 0.5)*(1 - C)*(S < 0.5) + (0.5 - C)*C*(S >= 0.5)
    # Fraction of samples per bin. Assumes ``counts`` is laid out as
    # (n_bins, n_clusters) — TODO confirm.
    bin_density = np.nansum(counts, axis=1)/np.nansum(counts)
    UB_acc_bins = 1 - 2*(LB_bins - I)
    UB_acc_bins[UB_acc_bins > 1] = 1  # an accuracy bound above 1 is vacuous

    # Density-weighted averages over bins.
    I = np.nansum(bin_density*I)
    I2 = np.nansum(bin_density*I2)
    UB_acc = 1 - 2*(LB - I)
    LB_acc = 2*(LB - I2)
    LB_acc0 = 2*(0 - I2)

    UB_acc2 = np.nansum(bin_density*UB_acc_bins)

    print(net)
    print('LB', LB)
    print('UB', UB)
    print('I', I)
    print('LB_acc0', LB_acc0)
    print('LB_acc', LB_acc)
    print('UB_acc', UB_acc)
    print('UB_acc2', UB_acc2)
    print('Er 2(UB-LB)', 2*(UB-LB))
    print('Er 2UB', 2*UB)
    print('Acc(Q) 1 - 2(UB-LB)', 1 - 2*(UB-LB))
    print('Acc(Q) 1 - 2UB', 1 - 2*UB)


@pytest.mark.parametrize('calibrate', [
    None,
    # 'isotonic'
])
@pytest.mark.parametrize('net', [
    IMAGENET_ConvNeXt,
])
@pytest.mark.parametrize('split', [
    'test_r',
])
@pytest.mark.parametrize('clustering', [
    # 'kmeans',
    'decision_tree',
])
def test_gl_induced(calibrate, net, split, out, n_jobs, clustering):
    """Estimate the induced grouping loss on two uniform bins and plot ``c_hat``.

    The network's binarized max-scores and correctness labels are used to
    fit a ``CEstimator``. The induced grouping loss is printed, and a scatter
    of the correctness labels against the scores is saved, overlaid with the
    estimator's curve.

    ``calibrate``, ``n_jobs`` and ``clustering`` only mirror the
    parametrization of the neighbouring tests; they are not used.
    """
    net = net(split=split)

    dirpath = net.get_default_dirpath()
    y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()
    y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()

    y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)

    est = CEstimator(y_pred_scores, y_well_guess)
    SS = np.linspace(0, 1, 2000)
    _c_hat = est._c_hat(SS)  # curve values on a fine grid, for plotting
    c_hat = est.c_hat()

    # Two equal-width bins on [0, 1].
    bins = np.linspace(0, 1, 3)
    GL_ind = estimate_GL_induced(c_hat, y_pred_scores, bins)

    print(GL_ind)

    fig, ax = plt.subplots(1, 1)
    ax.scatter(y_pred_scores, y_well_guess)
    ax.plot(SS, _c_hat, color='black')
    ax.set_aspect('equal')
    save_fig(fig, out)


def _make_bins(scores, n_bins, strategy):
    """Return ``n_bins + 1`` bin edges for ``scores``.

    ``'quantile'`` puts the edges at evenly spaced percentiles of ``scores``.
    ``'uniform'`` puts equally spaced edges on [0, 1].

    Raises
    ------
    ValueError
        If ``strategy`` is neither ``'quantile'`` nor ``'uniform'``.
    """
    if strategy == 'quantile':
        quantiles = np.linspace(0, 1, n_bins + 1)
        return np.percentile(scores, quantiles * 100)
    if strategy == 'uniform':
        return np.linspace(0, 1, n_bins + 1)
    raise ValueError(f'Unknown strategy {strategy}.')


@pytest.mark.parametrize('calibrate', [
    None,
    # 'isotonic'
])
@pytest.mark.parametrize('net', [
    IMAGENET_ConvNeXt,
])
@pytest.mark.parametrize('split', [
    'test_r',
])
@pytest.mark.parametrize('clustering', [
    # 'kmeans',
    'decision_tree',
])
def test_n_bins(calibrate, net, split, out, n_jobs, clustering):
    """Study how grouping-loss estimates vary with the number of bins.

    For each number of bins from 1 to 50, with both quantile and uniform
    binning, this computes:

    - the debiased and biased grouping-loss lower bounds;
    - the bias term;
    - the induced grouping loss;
    - the mean number of test samples per bin.

    The rows are cached in ``<out>/data.csv``; an existing CSV is reused
    instead of recomputing. The estimates are then plotted against the
    number of bins.
    """
    test_size = 0.5
    max_clusters_bin = 2  # for decision_tree only
    n_clusters = 2  # for kmeans only
    breakout = False
    net = net(split=split)

    path = os.path.join(out, 'data.csv')
    if not os.path.exists(path):
        dirpath = net.get_default_dirpath()
        Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
        y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()
        y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()

        y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)

        # The estimator only sees the fixed binarized scores and labels, so
        # fit it once rather than once per (n_bins, strategy) pair.
        # NOTE(review): assumes CEstimator fitting is deterministic — confirm.
        c_hat = CEstimator(y_pred_scores, y_well_guess).c_hat()

        rows = []
        for n_bins, strategy in itertools.product(range(1, 51), ['quantile', 'uniform']):
            print(n_bins, strategy)
            bins = _make_bins(y_pred_scores, n_bins, strategy)

            splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1, bins=bins, random_state=0)
            _, test_idx = next(splitter)

            # Calibrate from the raw arrays on every iteration. Reassigning
            # y_scores / y_labels here would compound calibration across
            # iterations.
            iter_scores, iter_labels = y_scores, y_labels
            if calibrate is not None:
                iter_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                                  test_size=test_idx, max_calibration=True)
                iter_labels = y_well_guess

            frac_pos, counts, *_ = cluster_evaluate_max(
                Xt, iter_labels, iter_scores,
                breakout=breakout,
                bins=bins,
                clustering=clustering,
                test_size=test_idx,
                min_samples_leaf=None,
                max_clusters_bin=max_clusters_bin,
                n_clusters=n_clusters,
                verbose=1,
                n_jobs=n_jobs,
                binary_scores=calibrate is not None,
                return_clustering=True,
            )

            LB_debiased, bias = grouping_loss_lower_bound(frac_pos, counts, debiased=True, return_bias=True)
            LB_biased = LB_debiased + bias
            GL_ind = estimate_GL_induced(c_hat, y_pred_scores, bins)

            # Number of test samples per bin. Clipping sends scores equal
            # to the last edge into the last bin.
            y_bins = np.digitize(y_pred_scores, bins=bins) - 1
            y_bins = np.clip(y_bins, a_min=None, a_max=n_bins-1)
            _, bin_counts = np.unique(y_bins[test_idx], return_counts=True)
            # np.unique omits empty bins; pad with zeros up to n_bins. Only
            # the mean is used below, so the position of the padding does not
            # matter.
            bin_counts = np.pad(bin_counts, pad_width=(0, n_bins - len(bin_counts)))
            assert len(bin_counts) == n_bins
            n_samples_per_bin = np.mean(bin_counts)

            rows.append({
                'n_bins': n_bins,
                'LB_debiased': LB_debiased,
                'LB_biased': LB_biased,
                'bias': bias,
                'GL_ind': GL_ind,
                'n_samples_per_bin': n_samples_per_bin,
                'strategy': strategy,
            })
        df = pd.DataFrame(rows)
        df.to_csv(path)

    df = pd.read_csv(path)

    print(df)

    set_latex_font()
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    df['LB_est'] = df['LB_biased'] - df['bias'] - df['GL_ind']

    # Long format: one row per (n_bins, strategy, quantity) for seaborn.
    _df = df.melt(id_vars=['n_bins', 'strategy'], value_vars=['LB_biased', 'bias', 'GL_ind', 'LB_est'])
    print(_df)
    _df['strategy'] = _df['strategy'].str.capitalize()
    _df['variable'] = _df['variable'].replace({
        'LB_biased': r'$\mathrm{GL}_{lb}(S_B)$',
        'bias': 'bias',
        'GL_ind': r'$\mathrm{GL}_{ind}(S, B)$',
        'LB_est': r'$\hat{\mathrm{GL}}_{lb}(S, B)$',
    })

    _df.rename({'strategy': 'Strategy'}, axis=1, inplace=True)

    sns.lineplot(data=_df, x='n_bins', y='value', hue='variable', style='Strategy',
                 ax=ax, style_order=['Uniform', 'Quantile'])

    # Secondary panel above the main axes showing the samples per bin.
    divider = make_axes_locatable(ax)
    ax_top = divider.append_axes("top", size="30%", pad=0.05)

    ax_top.plot(df['n_bins'], df['n_samples_per_bin'], label='# samples per bin', color='purple')
    ax_top.get_xaxis().set_visible(False)
    ax_top.set_yscale('log')
    ax_top.set(ylabel='Samples per bin', title='ConvNeXt ImageNet-R')

    # Drop the first legend entry (the hue section header).
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles[1:], labels=labels[1:], ncol=2)
    ax.set(xlabel='Number of bins', ylabel='GL')
    save_fig(fig, out)


@pytest.mark.parametrize('calibrate', [
    None,
    # 'isotonic'
])
@pytest.mark.parametrize('net', [
    IMAGENET_ConvNeXt,
])
@pytest.mark.parametrize('split', [
    'test_r',
])
@pytest.mark.parametrize('clustering', [
    # 'kmeans',
    'decision_tree',
])
def test_n_clusters(calibrate, net, split, out, n_jobs, clustering):
    """Sweep the number of clusters per bin and plot the GL lower-bound terms.

    The network outputs stored in ``net``'s default directory are binarized
    with ``binarize_multiclass_max``. For ``n_clusters`` in 2..10 and both
    binning strategies ('quantile' and 'uniform', with a fixed ``n_bins``),
    the samples are clustered with ``cluster_evaluate_max``. The following
    columns are written to ``<out>/data.csv``:

    - ``LB_debiased``, ``bias`` and ``LB_biased = LB_debiased + bias`` from
      ``grouping_loss_lower_bound``;
    - ``GL_ind`` from ``estimate_GL_induced``;
    - ``n_samples_per_cluster``: mean size of the non-empty clusters.

    An existing CSV is reused as-is. The figure plots these terms (and
    ``LB_est = LB_biased - bias - GL_ind``) against the number of clusters.
    A top panel shows the mean cluster size per strategy.
    """
    n_bins = 10
    test_size = 0.5
    breakout = False
    net = net(split=split)

    path = os.path.join(out, 'data.csv')
    if not os.path.exists(path):
        dirpath = net.get_default_dirpath()
        Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
        y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()
        y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()

        y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)

        # The inputs of the estimator do not change during the sweep: fit it once.
        c_hat = CEstimator(y_pred_scores, y_well_guess).c_hat()

        rows = []
        for n_clusters, strategy in itertools.product(range(2, 11), ['quantile', 'uniform']):
            print(n_clusters, strategy)
            max_clusters_bin = n_clusters

            if strategy == 'quantile':
                quantiles = np.linspace(0, 1, n_bins + 1)
                bins = np.percentile(y_pred_scores, quantiles * 100)
            elif strategy == 'uniform':
                bins = np.linspace(0, 1, n_bins + 1)
            else:
                raise ValueError(f'Unknown strategy {strategy}.')

            splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1, bins=bins, random_state=0)
            _, test_idx = next(splitter)

            # Work on per-iteration copies of the references: calibration must
            # start from the raw loaded scores at every iteration, not from the
            # previous iteration's calibrated output.
            eval_scores, eval_labels = y_scores, y_labels
            if calibrate is not None:
                eval_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                                  test_size=test_idx, max_calibration=True)
                eval_labels = y_well_guess

            frac_pos, counts, *_ = cluster_evaluate_max(
                Xt, eval_labels, eval_scores,
                breakout=breakout,
                bins=bins,
                clustering=clustering,
                test_size=test_idx,
                min_samples_leaf=None,
                max_clusters_bin=max_clusters_bin,
                n_clusters=n_clusters,
                verbose=1,
                n_jobs=n_jobs,
                binary_scores=calibrate is not None,
                return_clustering=True,
            )

            LB_debiased, bias = grouping_loss_lower_bound(frac_pos, counts, debiased=True, return_bias=True)
            LB_biased = LB_debiased + bias
            GL_ind = estimate_GL_induced(c_hat, y_pred_scores, bins)

            rows.append({
                'n_clusters': n_clusters,
                'LB_debiased': LB_debiased,
                'LB_biased': LB_biased,
                'bias': bias,
                'GL_ind': GL_ind,
                'n_samples_per_cluster': np.mean(counts, where=counts > 0),
                'strategy': strategy,
            })
        df = pd.DataFrame(rows)
        os.makedirs(out, exist_ok=True)
        df.to_csv(path)

    df = pd.read_csv(path)

    print(df)

    set_latex_font()
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    df['LB_est'] = df['LB_biased'] - df['bias'] - df['GL_ind']

    _df = df.melt(id_vars=['n_clusters', 'strategy'], value_vars=['LB_biased', 'bias', 'GL_ind', 'LB_est'])
    print(_df)
    _df['strategy'] = _df['strategy'].str.capitalize()
    _df['variable'] = _df['variable'].replace({
        'LB_biased': r'$\mathrm{GL}_{lb}(S_B)$',
        'bias': 'bias',
        'GL_ind': r'$\mathrm{GL}_{ind}(S, B)$',
        'LB_est': r'$\hat{\mathrm{GL}}_{lb}(S, B)$',
    })

    _df.rename({'strategy': 'Strategy'}, axis=1, inplace=True)

    sns.lineplot(data=_df, x='n_clusters', y='value', hue='variable', style='Strategy',
                 ax=ax, style_order=['Uniform', 'Quantile'])

    divider = make_axes_locatable(ax)
    ax_top = divider.append_axes("top", size="30%", pad=0.05)

    # df interleaves both strategies for each n_clusters. Plotting all rows as
    # one line would zigzag between them, so draw one line per strategy.
    # Line styles follow the main plot's style_order (first solid, second dashed).
    for strategy, linestyle in [('uniform', '-'), ('quantile', '--')]:
        df_strategy = df[df['strategy'] == strategy]
        ax_top.plot(df_strategy['n_clusters'], df_strategy['n_samples_per_cluster'],
                    label='# samples per cluster', color='purple', linestyle=linestyle)
    ax_top.get_xaxis().set_visible(False)
    ax_top.set_yscale('log')
    ax_top.set(ylabel='Samples per cluster', title='ConvNeXt ImageNet-R')

    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles[1:], labels=labels[1:], ncol=2)
    ax.set(xlabel='Number of clusters', ylabel='GL')
    save_fig(fig, out)
    plt.close(fig)



@pytest.mark.parametrize('calibrate', [
    None,
    # 'isotonic'
])
@pytest.mark.parametrize('clustering', [
    # 'kmeans',
    'decision_tree',
])
def test_n_bins_simu(calibrate, out, n_jobs, clustering):
    """Sweep the number of bins on a simulated example and plot the GL terms.

    10000 samples are drawn from ``Link1DExample(link='poly')``. For
    ``n_bins`` in 1..50 and both binning strategies ('quantile' and
    'uniform'), the samples are clustered with ``cluster_evaluate``. The
    following columns are written to ``<out>/data.csv``:

    - ``LB_debiased``, ``bias`` and ``LB_biased = LB_debiased + bias`` from
      ``grouping_loss_lower_bound``;
    - ``GL_ind`` from ``estimate_GL_induced``;
    - ``GL``: the example's ``GL()``;
    - ``n_samples_per_bin``: mean number of test samples per bin.

    An existing CSV is reused as-is. The figure plots these terms (and
    ``LB_est = LB_biased - bias - GL_ind``) against the number of bins.
    """
    test_size = 0.5
    max_clusters_bin = 2  # for decision_tree only
    n_clusters = 2  # for kmeans only

    ex = Link1DExample(link='poly')
    assert isinstance(ex, BaseExample)

    path = os.path.join(out, 'data.csv')
    if not os.path.exists(path):
        Xt, y_labels = ex.generate_X_y(n=10000, random_state=0)
        y_scores = ex.S(Xt)

        if y_scores.ndim > 1:
            y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)
        else:
            y_pred_scores = y_scores
            y_well_guess = y_labels

        GL = ex.GL()

        # The inputs of the estimator do not depend on the binning: fit it once.
        c_hat = CEstimator(y_pred_scores, y_well_guess).c_hat()

        rows = []
        for n_bins, strategy in itertools.product(range(1, 51), ['quantile', 'uniform']):
            print(n_bins, strategy)
            if strategy == 'quantile':
                quantiles = np.linspace(0, 1, n_bins + 1)
                bins = np.percentile(y_pred_scores, quantiles * 100)
            elif strategy == 'uniform':
                bins = np.linspace(0, 1, n_bins + 1)
            else:
                raise ValueError(f'Unknown strategy {strategy}.')

            splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1, bins=bins, random_state=0)
            _, test_idx = next(splitter)

            # Work on per-iteration copies of the references: calibration must
            # start from the raw scores at every iteration, not from the
            # previous iteration's calibrated output.
            eval_scores, eval_labels = y_scores, y_labels
            if calibrate is not None:
                eval_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                                  test_size=test_idx, max_calibration=True)
                eval_labels = y_well_guess

            frac_pos, counts, *_ = cluster_evaluate(
                Xt, eval_labels, eval_scores,
                bins=bins,
                clustering=clustering,
                test_size=test_idx,
                min_samples_leaf=None,
                max_clusters_bin=max_clusters_bin,
                n_clusters=n_clusters,
                verbose=1,
                n_jobs=n_jobs,
                return_clustering=True,
            )

            LB_debiased, bias = grouping_loss_lower_bound(frac_pos, counts, debiased=True, return_bias=True)
            LB_biased = LB_debiased + bias
            GL_ind = estimate_GL_induced(c_hat, y_pred_scores, bins)

            # Bin index of each sample. Scores outside [bins[0], bins[-1]]
            # are assigned to the edge bins.
            y_bins = np.digitize(y_pred_scores, bins=bins) - 1
            y_bins = np.clip(y_bins, 0, n_bins - 1)
            bin_counts = np.bincount(y_bins[test_idx], minlength=n_bins)
            assert len(bin_counts) == n_bins
            n_samples_per_bin = np.mean(bin_counts)

            rows.append({
                'n_bins': n_bins,
                'LB_debiased': LB_debiased,
                'LB_biased': LB_biased,
                'bias': bias,
                'GL_ind': GL_ind,
                'GL': GL,
                'n_samples_per_bin': n_samples_per_bin,
                'strategy': strategy,
            })
        df = pd.DataFrame(rows)
        os.makedirs(out, exist_ok=True)
        df.to_csv(path)

    df = pd.read_csv(path)

    print(df)

    set_latex_font()
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    df['LB_est'] = df['LB_biased'] - df['bias'] - df['GL_ind']

    _df = df.melt(id_vars=['n_bins', 'strategy'], value_vars=['LB_biased', 'bias', 'GL_ind', 'LB_est', 'GL'])
    print(_df)
    _df['variable'] = _df['variable'].replace({
        'LB_biased': r'$\widehat{\mathrm{LB}}(S_B)$',
        'bias': 'bias',
        'GL_ind': r'$\mathrm{GL}_{ind}(S, B)$',
        'LB_est': r'$\widehat{\mathrm{LB}}_{corrected}(S, B)$',
        'GL': r'$\mathrm{GL}$',
    })
    _df['strategy'] = _df['strategy'].replace({
        'uniform': 'Equal-width',
        'quantile': 'Equal-mass',
    })

    _df.rename({'strategy': 'Strategy'}, axis=1, inplace=True)

    sns.lineplot(data=_df, x='n_bins', y='value', hue='variable', style='Strategy',
                 ax=ax, style_order=['Equal-width', 'Equal-mass'])

    divider = make_axes_locatable(ax)
    ax_top = divider.append_axes("top", size="30%", pad=0.05)

    ax_top.plot(df['n_bins'], df['n_samples_per_bin'], label='# samples per bin', color='purple')
    ax_top.get_xaxis().set_visible(False)
    ax_top.set_yscale('log')
    ax_top.set(ylabel='Samples per bin', title='Simulation')

    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles[1:], labels=labels[1:], ncol=2)
    ax.set(xlabel='Number of bins', ylabel='GL')
    save_fig(fig, out)
    plt.close(fig)


@pytest.mark.parametrize('calibrate', [
    None,
    # 'isotonic'
])
@pytest.mark.parametrize('clustering', [
    # 'kmeans',
    'decision_tree',
])
@pytest.mark.parametrize('grid', [
    list(product([15], np.linspace(1, 150, 50))),  # Clusters
    list(product(range(1, 31), [30])),  # Bins
])
@pytest.mark.parametrize('n_trials', [
    # 1,
    # 10,
    100,
])
@pytest.mark.parametrize('n', [
    10000,
    # 100000,
])
@pytest.mark.parametrize('s', [
    # ('1d', 'bowl'),
    # ('1d', 'sin'),
    ('2d', 'expit'),
    # ('2d', 'expit', 10),
    # ('2d', 'sin'),
    # ('2d', 'signsin'),
    # ('2d', 'sin', 0.5),
])
def test_n_clusters_simu(calibrate, inp, out, n_jobs, clustering, grid, n_trials, s,
                         nocache, n):
    """Plot the GL lower-bound terms of a simulated example over a grid.

    ``grid`` is a list of ``(n_bins, n_samples_per_cluster_per_bin)`` pairs in
    which one coordinate must be constant:

    - constant ``n_bins``: sweep over the cluster size (``which='cluster'``);
    - constant ``n_samples_per_cluster_per_bin``: sweep over the number of
      bins (``which='bin'``).

    ``s`` is ``(name, version[, alpha])`` and selects a ``SigmoidExample``
    ('2d') or a ``Link1DExample`` ('1d'); ``alpha`` defaults to 1.

    For every trial, grid point and binning strategy, ``n`` samples are drawn
    (seeded by the trial index) and the lower-bound terms are computed in
    parallel. The rows are cached in a CSV under ``inp``; set ``nocache`` to
    recompute. The sweep figure is saved to ``out``. When the CSVs of both
    sweeps exist, a combined two-panel figure is also saved.

    Raises
    ------
    ValueError
        If neither grid coordinate is constant, or ``s`` has an unknown name.
    """
    test_size = 0.5

    grid = list(grid)

    if len(np.unique([x[0] for x in grid])) == 1:
        which = 'cluster'
    elif len(np.unique([x[1] for x in grid])) == 1:
        which = 'bin'
    else:
        raise ValueError(f'grid must have an unique value for the first or '
                         f'second axis. Got {grid}.')

    if len(s) == 2:
        s = (s[0], s[1], 1)

    name, version, alpha = s
    if name == '2d':
        w = np.ones(2)
        w_perp = create_orthonormal_vector(w)
        ex = SigmoidExample(w, w_perp, bayes_opt=False, alpha=alpha, psi=version)

    elif name == '1d':
        ex = Link1DExample(link='poly', s=version, alpha=alpha)

    else:
        raise ValueError(f'Unknown name {name}.')

    assert isinstance(ex, BaseExample)

    kwargs = dict(
        clustering=clustering,
        trials=n_trials,
        which=which,
        ex=repr(ex),
        n=n,
    )

    print('started test with:\n')
    print('kwargs', kwargs)
    print('grid', grid)
    print()

    path = save_path(inp, ext='csv', **kwargs)
    if nocache or not os.path.exists(path):
        # GL_emp draws 10M samples. Only pay for it on a cache miss; its value
        # is stored in the CSV ('GL' column) together with the other results.
        GL = ex.GL_emp(N=10000000)
        print('GL', GL)

        def compute_one(trial, n_bins, n_samples_per_cluster_per_bin, strategy):
            """Evaluate one (trial, grid point, strategy) and return a result row."""
            n_samples_per_cluster_per_bin = int(n_samples_per_cluster_per_bin)
            n_bins = int(n_bins)
            Xt, y_labels = ex.generate_X_y(n=n, random_state=trial)
            y_scores = ex.S(Xt)

            if y_scores.ndim > 1:
                y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)
            else:
                y_pred_scores = y_scores
                y_well_guess = y_labels

            if strategy == 'quantile':
                quantiles = np.linspace(0, 1, n_bins + 1)
                bins = np.percentile(y_pred_scores, quantiles * 100)
            elif strategy == 'uniform':
                bins = np.linspace(0, 1, n_bins + 1)
            else:
                raise ValueError(f'Unknown strategy {strategy}.')

            splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1, bins=bins, random_state=0)
            _, test_idx = next(splitter)

            if calibrate is not None:
                y_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                               test_size=test_idx, max_calibration=True)
                y_labels = y_well_guess

            frac_pos, counts, *_ = cluster_evaluate(
                Xt, y_labels, y_scores,
                bins=bins,
                clustering=clustering,
                test_size=test_idx,
                n_samples_per_cluster_per_bin=n_samples_per_cluster_per_bin,
                verbose=0,
                n_jobs=1,
                return_clustering=True,
            )

            c_hat = CEstimator(y_pred_scores, y_well_guess).c_hat()

            LB_debiased, bias = grouping_loss_lower_bound(frac_pos, counts, debiased=True, return_bias=True)
            LB_biased = LB_debiased + bias
            GL_ind = estimate_GL_induced(c_hat, y_pred_scores, bins)

            return {
                'n_samples_per_cluster_per_bin': n_samples_per_cluster_per_bin,
                'LB_debiased': LB_debiased,
                'LB_biased': LB_biased,
                'bias': bias,
                'GL_ind': GL_ind,
                'GL': GL,
                'n_samples_per_cluster': np.mean(counts, where=counts > 0),
                'n_size_one_clusters': np.sum(counts == 1),
                'n_nonzero_clusters': np.sum(counts > 0),
                'strategy': strategy,
                'trial': trial,
                'n_bins': n_bins,
            }

        rows = Parallel(n_jobs=n_jobs)(
            delayed(compute_one)(trial, n_bins, n_samples_per_cluster_per_bin, strategy)
            for trial, (n_bins, n_samples_per_cluster_per_bin), strategy
            in tqdm(list(product(range(n_trials), grid, ['quantile', 'uniform'])))
        )

        print('Computation ended')

        df = pd.DataFrame(rows)
        print('Merged df')
        print(df)
        os.makedirs(inp, exist_ok=True)
        df.to_csv(path)

    df = pd.read_csv(path)

    # Figure of the current sweep.
    if which == 'cluster':
        df = df.query('n_samples_per_cluster_per_bin <= 100')
        fig = plot_simu(df, x='n_samples_per_cluster_per_bin', legend=True)
        ax = fig.axes[0]
        ax.set(xlabel=r'Ratio $\frac{\mathrm{number~of~samples}}{\mathrm{number~of~clusters}}$ per bin')
        ax.set(ylabel=None)
        ax.set_yticklabels([])
        xmin = df['n_samples_per_cluster_per_bin'].min()
        xmax = df['n_samples_per_cluster_per_bin'].max()
        ax.set_xlim((xmin, xmax))

    elif which == 'bin':
        fig = plot_simu(df, x='n_bins', legend=False)
        ax = fig.axes[0]
        ax.set(xlabel='Number of bins')
        xmin = df['n_bins'].min()
        xmax = df['n_bins'].max()
        ax.set_xlim((xmin, xmax))

    ax.set_title(' ', fontsize=1)  # For both figures to have same height

    # Every row holds the same GL. Use positional access: the query above may
    # have dropped the row labelled 0, which would make df['GL'][0] fail.
    ymax = float(f"{2*df['GL'].iloc[0]:.1g}")
    print('ymax', ymax)
    ax.set_ylim((0, ymax))

    save_fig(fig, out, **kwargs)
    plt.close(fig)

    # Combined two-panel figure, drawn once both sweeps have been cached.
    kwargs['which'] = 'bin'
    path_bins = save_path(inp, ext='csv', **kwargs)
    kwargs['which'] = 'cluster'
    path_clusters = save_path(inp, ext='csv', **kwargs)
    kwargs['which'] = 'both'

    if not (os.path.exists(path_bins) and os.path.exists(path_clusters)):
        return

    from matplotlib.ticker import MultipleLocator

    # Scope these rc tweaks to this figure instead of mutating the global
    # rcParams for every figure created later in the session.
    rc = {
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.handlelength': 1.3,
        'legend.handletextpad': 0.4,
        'axes.titlepad': 4,
    }
    with plt.rc_context(rc):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(3, 1.5), gridspec_kw=dict(wspace=0.03))

        df_bins = pd.read_csv(path_bins)
        df_clusters = pd.read_csv(path_clusters)

        # Left panel: sweep over the number of samples per region and bin.
        df_clusters = df_clusters.query('n_samples_per_cluster_per_bin <= 100')
        plot_simu(df_clusters, x='n_samples_per_cluster_per_bin', legend=False, ax=ax1)
        ax1.set(xlabel=r'$\frac{\mathrm{\#~samples}}{\mathrm{\#~regions}}$ per bin')
        xmin = df_clusters['n_samples_per_cluster_per_bin'].min()
        xmax = df_clusters['n_samples_per_cluster_per_bin'].max()
        ax1.set_xlim((xmin, xmax))

        # Right panel: sweep over the number of bins. It shares the left
        # panel's y range, so its y label is cleared *after* plot_simu has
        # drawn on it (clearing it before would be overwritten).
        plot_simu(df_bins, x='n_bins', legend=True, ax=ax2)
        ax2.set(xlabel=r'\# bins', ylabel=None)
        xmin = df_bins['n_bins'].min()
        xmax = df_bins['n_bins'].max()
        ax2.set_xlim((xmin, xmax))
        ax1.set_yticks([0, 0.005, 0.01])
        ax2.set_xticks([5, 15, 25])
        ax2.set_yticks([])

        # Unlabelled minor ticks every 5 bins (default NullFormatter).
        ax2.xaxis.set_minor_locator(MultipleLocator(5))

        ymax = float(f"{2*df_bins['GL'].iloc[0]:.1g}")
        print('ymax', ymax)
        ax1.set_ylim((0, ymax))
        ax2.set_ylim((0, ymax))

        ax1.set(title='a.')
        ax2.set(title='b.')

        save_fig(fig, out, pad_inches=0, **kwargs)
        plt.close(fig)
