"""Code that generates the figures of the paper."""
import itertools
import os
import re
import shutil
from os.path import join
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import torch
from joblib import Memory, Parallel, delayed
from sklearn.utils import check_random_state
from tqdm import tqdm

from .._linalg import create_orthonormal_vector
from .._plot import (barplot_ece_gl, barplot_ece_gl_brier, barplots_ece_gl_cal,
                     plot_brier_acc, plot_brier_ece, plot_cost_vs_gl,
                     plot_ffstar_1d, plot_ffstar_2d_v2,
                     plot_fig_counter_example, plot_fig_renditions,
                     plot_fig_theorem, plot_fig_theorem_v2, plot_frac_pos_vs_scores,
                     plot_lower_bound_vs_acc, plot_lower_bound_vs_brier,
                     plot_lower_bound_vs_ece, plot_renditions_calibration,
                     plot_score_vs_probas2)
from .._utils import compute_classif_metrics, save_fig
from ..CalibrationExample import (CustomUnconstrained, CustomUniform,
                                  SigmoidExample)
from ._utils import (bin_train_test_split, binarize_multiclass_marginal,
                     binarize_multiclass_max, calibrate_scores,
                     calibration_curve, compute_calib_metrics,
                     compute_multi_classif_metrics)
from .main import cluster_evaluate_marginals, cluster_evaluate_max
from .networks import (ALL_IMAGENET_NETWORKS, IMAGENET_VGG, Amazon_BertBase,
                       Emotion_DistilBert, IMAGENET_AlexNet, IMAGENET_ConvNeXt,
                       IMAGENET_DenseNet, IMAGENET_EfficientNet,
                       IMAGENET_GoogLeNet, IMAGENET_Inception,
                       IMAGENET_MLPMixer, IMAGENET_MNASNet, IMAGENET_MobileNet,
                       IMAGENET_RegNet, IMAGENET_ResNet, IMAGENET_ResNext,
                       IMAGENET_ShuffleNet, IMAGENET_VisionTransformer,
                       IMAGENET_WideResNet, ZeroShotBartYahoo)
from .results import _get_out_kwargs
from .test_results import read_los_prediction_results2

# Module-level joblib cache; 'joblib_cache' is the directory passed as the
# cache location.
memory = Memory('joblib_cache')


# Maps a network class to the version string passed as ``type=`` to its
# constructor when a test asks for the "best" variant. Network classes that
# are absent from this mapping have no best variant: the tests below skip
# them (or return an empty result) when ``best`` is requested.
best_versions = {
    IMAGENET_VGG: '19_bn',
    IMAGENET_ResNet: '152',
    IMAGENET_DenseNet: '161',
    IMAGENET_ShuffleNet: '1_0',
    IMAGENET_MobileNet: 'v3L',
    IMAGENET_ResNext: '101',
    IMAGENET_WideResNet: '101',
    IMAGENET_MNASNet: '1_0',
    IMAGENET_EfficientNet: 'b7',
    IMAGENET_RegNet: 'y_32gf',
    IMAGENET_VisionTransformer: 'l_16',
    IMAGENET_ConvNeXt: 'large',
}

# Maps a raw version string to the display suffix appended to the network
# name in figure labels (the suffix carries its own leading '-' or ' ').
# The empty string maps to an empty suffix for networks without a version.
rename_versions = {
    '19_bn': '-19 BN',
    '50': '-50',
    '11': '-11',
    '18': '-18',
    '152': '-152',
    '161': '-161',
    '121': '-121',
    '1_0': ' 1.0',
    '0_5': ' 0.5',
    'v3L': ' V3L',
    # 'v3L': ' V3 Large',
    'v2': ' V2',
    '101': '-101',
    'b7': '-B7',
    'b0': '-B0',
    'l_16': ' L-16',
    'b_16': ' B-16',
    'large': ' Large',
    'y_400mf': ' y_400mf',
    'y_32gf': ' y_32gf',
    '': '',
}


@pytest.mark.parametrize('calibrate', [
    None,
    'isotonic'
])
@pytest.mark.parametrize('net', [
    IMAGENET_AlexNet,
    IMAGENET_VGG,
    IMAGENET_ResNet,
    IMAGENET_DenseNet,
    IMAGENET_Inception,
    IMAGENET_GoogLeNet,
    IMAGENET_ShuffleNet,
    IMAGENET_MobileNet,
    IMAGENET_ResNext,
    IMAGENET_WideResNet,
    IMAGENET_MNASNet,
    IMAGENET_EfficientNet,
    IMAGENET_RegNet,
    IMAGENET_VisionTransformer,
    IMAGENET_ConvNeXt,
])
@pytest.mark.parametrize('split', [
    'test_r',
    'test_c:_merged_no_rep5',
    'test_c:snow5',
    'val',
])
@pytest.mark.parametrize('best', [
    False,
    True,
])
@pytest.mark.parametrize('clustering', [
    'kmeans',
    'decision_tree',
])
def test_fig_imagenet_calibration(calibrate, net, split, out, n_jobs, best, clustering):
    """Generate the grouping diagrams of all vision networks on all datasets.

    One figure is saved per parametrized combination. Any clustering other
    than ``'decision_tree'`` is only run on the ``'test_r'`` split; the other
    combinations are skipped.

    Parameters
    ----------
    calibrate : None or str
        Calibration method forwarded to ``calibrate_scores``. ``None`` keeps
        the raw network scores.
    net : type
        Network class (not an instance). It is instantiated here with the
        requested split and, when ``best`` is true, the matching entry of
        ``best_versions``.
    split : str
        Dataset split passed to the network constructor.
    out : path-like
        Output location handed to ``save_fig`` (presumably a pytest fixture
        -- confirm in conftest).
    n_jobs : int
        Forwarded to ``cluster_evaluate_max``.
    best : bool
        If true, use the version listed in ``best_versions`` and skip the
        networks that have none.
    clustering : {'kmeans', 'decision_tree'}
        Clustering method forwarded to ``cluster_evaluate_max``.

    Raises
    ------
    ValueError
        If ``clustering`` is not one of the two supported methods (and the
        combination was not skipped first).
    """
    if clustering != 'decision_tree' and split != 'test_r':
        pytest.skip(f'We use kmeans only on imagenet-r: {clustering} on {split}')

    if clustering == 'decision_tree':
        clustering_name = 'dt'
    elif clustering == 'kmeans':
        clustering_name = 'km'
    else:
        raise ValueError(f'Unknown {clustering}')

    n_bins = 15
    test_size = 0.5
    max_clusters_bin = 2  # for decision_tree only
    n_clusters = 2  # for kmeans only
    hist = True
    ci = 'clopper'
    min_cluster_size = 14
    capsize = 3.5
    breakout = False
    cluster_size = 30
    vary_cluster_size = False
    absolute_size_scale = (10, 1500)
    plot_cal_hist = False
    figsize = (2.2, 2.2)
    legend_n_sizes = 1
    # Debug switch: when enabled, per-cluster rendition statistics are dumped
    # next to the figure.
    _renditions = False

    # NOTE: these rc settings are global and stay in effect after the test.
    plt.rc('legend', title_fontsize=10)
    plt.rc('legend', fontsize=10)
    plt.rc('legend', handletextpad=0.01)
    plt.rc('legend', columnspacing=0.02)
    plt.rc('legend', borderpad=0.3)
    plt.rc('legend', borderaxespad=0.2)
    plt.rc('legend', handlelength=1.2)
    plt.rc('legend', labelspacing=0.1)
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)
    plt.rc('axes', labelsize=10.5)

    # The ``positions`` tables hold (plot_xlabel, plot_ylabel) flags for each
    # network. Only the y flag is used: the x label is always drawn.
    if best:
        version = best_versions.get(net, None)
        if version is None:
            pytest.skip(f'No best version found for net "{net}"')

        positions = {
            IMAGENET_VGG: (0, 1),
            IMAGENET_ResNet: (0, 0),
            IMAGENET_DenseNet: (0, 0),
            IMAGENET_ShuffleNet: (0, 0),
            IMAGENET_MobileNet: (0, 1),
            IMAGENET_ResNext: (0, 0),
            IMAGENET_WideResNet: (0, 0),
            IMAGENET_MNASNet: (0, 0),
            IMAGENET_EfficientNet: (1, 1),
            IMAGENET_RegNet: (1, 0),
            IMAGENET_VisionTransformer: (1, 0),
            IMAGENET_ConvNeXt: (1, 0),
        }
        _, plot_ylabel = positions[net]
        net = net(split=split, type=version)

    else:
        positions = {
            IMAGENET_AlexNet: (0, 1),
            IMAGENET_VGG: (0, 0),
            IMAGENET_ResNet: (0, 0),
            IMAGENET_DenseNet: (0, 0),
            IMAGENET_Inception: (0, 1),
            IMAGENET_GoogLeNet: (0, 0),
            IMAGENET_ShuffleNet: (0, 0),
            IMAGENET_MobileNet: (0, 0),
            IMAGENET_ResNext: (0, 1),
            IMAGENET_WideResNet: (0, 0),
            IMAGENET_MNASNet: (0, 0),
            IMAGENET_EfficientNet: (0, 0),
            IMAGENET_RegNet: (1, 1),
            IMAGENET_VisionTransformer: (1, 0),
            IMAGENET_ConvNeXt: (1, 0),
        }
        _, plot_ylabel = positions[net]
        net = net(split=split)

    # Load the precomputed tensors stored in the network's default directory.
    dirpath = net.get_default_dirpath()
    Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
    y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()
    y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()

    y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)
    splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1, bins=n_bins, random_state=0)
    _, test_idx = next(splitter)

    ds = net.get_dataset()

    if _renditions:
        # The rendition name is the lowercase prefix of the image file name,
        # e.g. ".../sketch_12.jpg" -> "sketch".
        rendition_pattern = re.compile(r'.*/([a-z]*)_[0-9]*\.jpg')
        img_paths = [s for s, _ in ds.imgs]
        renditions = np.array([rendition_pattern.match(s).group(1) for s in img_paths])
        paths = np.array(img_paths)

        print(renditions)
        assert len(renditions) == len(y_labels)

    if calibrate is not None:
        y_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                       test_size=test_idx, max_calibration=True)
        # Once calibrated, the labels become "was the prediction correct".
        y_labels = y_well_guess

    (frac_pos,
     counts,
     mean_scores,
     labels,
     *_
     ) = cluster_evaluate_max(Xt, y_labels, y_scores,
                              breakout=breakout,
                              bins=n_bins,
                              clustering=clustering,
                              test_size=test_idx,
                              min_samples_leaf=None,
                              max_clusters_bin=max_clusters_bin,
                              n_clusters=n_clusters,
                              verbose=1,
                              n_jobs=n_jobs,
                              binary_scores=calibrate is not None,
                              return_clustering=True,
                              )

    if _renditions:
        print(labels)
        labels_test = labels[test_idx]
        print(labels_test)
        renditions_test = renditions[test_idx]
        print(renditions_test)
        paths_test = paths[test_idx]
        y_labels_test = y_labels[test_idx]

        df = pd.DataFrame({
            'rendition': renditions_test,
            'clusters': labels_test,
            'count': 1,
            'paths': paths_test,
            'labels': y_labels_test,
        })
        print(np.unique(y_labels_test))
        print(df)
        # The string alias avoids the pandas deprecation warning raised when
        # a NumPy callable is passed as ``aggfunc``.
        df_count = df.pivot_table(index='clusters', columns='labels', values='count', aggfunc='sum')

        pd.set_option('display.max_columns', 20)
        pd.set_option('display.max_rows', 200)
        print(df_count.transpose())

        df_percent = 100*df_count/df_count.sum(axis=0)
        df_percent = df_percent.round(0)

        print(df_percent.transpose())

        df_count.transpose().to_csv(join(out, f'rendition_counts_{net}.csv'))
        df_percent.transpose().to_csv(join(out, f'rendition_percent_{net}.csv'))

        # Create one directory per (rendition, cluster) group. Copying the
        # images into these directories is intentionally left disabled.
        for index, _subdf in df.groupby(['rendition', 'clusters']):
            r, c = index
            print(index)
            group_dir = join(out, f'renditions_{net}', r, str(int(c)))
            os.makedirs(group_dir, exist_ok=True)

    xlabel = 'Confidence score'
    ylabel = 'Correct predictions'
    fig = plot_frac_pos_vs_scores(frac_pos,
                                  counts,
                                  mean_scores,
                                  y_scores=None,
                                  y_labels=None,
                                  ncol=1,
                                  legend_loc='upper left',
                                  bbox_to_anchor=(0, 1),
                                  title=None,
                                  xlim_margin=0.05,
                                  ylim_margin=0.05,
                                  min_cluster_size=min_cluster_size,
                                  hist=hist,
                                  ci=ci,
                                  legend_cluster_sizes=True,
                                  vary_cluster_size=vary_cluster_size,
                                  capsize=capsize,
                                  xlabel='Confidence score',
                                  ylabel='Correct predictions (%)',
                                  cluster_size=cluster_size,
                                  absolute_size_scale=absolute_size_scale,
                                  plot_cal_hist=plot_cal_hist,
                                  figsize=figsize,
                                  legend_n_sizes=legend_n_sizes,
                                  legend_sizes_only=True,
                                  legend_min_max=False,
                                  plot_first_last_bins=False,
                                  grid_space=0,
                                  legend_title='Sizes',
                                  )

    ax = fig.axes[0]
    ax.set_xlim((0, 1))
    ax.set_ylim((0, 1))

    # The regular axis labels are removed and replaced by annotations placed
    # in axes-fraction coordinates.
    ax.set_xlabel(None)
    ax.set_ylabel(None)

    fig.axes[1].get_xaxis().set_visible(False)
    fig.axes[2].get_yaxis().set_visible(False)

    delta = 0.06
    ax.set_xticks([0, 0.5, 1])
    ax.set_xticklabels(['0', '', '1'])
    ax.annotate(xlabel, xy=(0.5, -delta), xytext=(0.5, -delta),
                xycoords='axes fraction', ha='center', va='top',
                fontsize=plt.rcParams['axes.labelsize'],
                )

    if plot_ylabel:
        ax.set_yticks([0, 0.25, 0.5, 0.75, 1])
        ax.set_yticklabels(['0', '', '', '', '100\n(%)'])
        ax.annotate(ylabel, xy=(-delta, 0.5), xytext=(-delta, 0.5),
                    xycoords='axes fraction', ha='right', va='center',
                    fontsize=plt.rcParams['axes.labelsize'],
                    rotation=90,
                    )
    else:
        ax.yaxis.set_ticklabels([])

    out_kwargs = {
        'net': str(net),
        'cal': str(calibrate),
        'clu': clustering_name,
    }
    order = ['net', 'cal']

    save_fig(fig, out, order=order, **out_kwargs, pad_inches=0.0)


# @pytest.mark.parametrize('calibrate', [None, 'isotonic'])
@pytest.mark.parametrize('which', [
    'small',
    'best',
    'both',
])
def test_fig_network_comparison(which, inp, out, n_jobs, nocache, append_versions=True):
    """Generate the comparison figure (ECE and grouping loss) for all vision
    networks on all datasets (figures 7 and 14).

    For each split, the metrics of every network are computed (with and
    without isotonic calibration, for the default and the best version) and
    stored in a CSV file under ``inp``. Later runs read the CSV back instead
    of recomputing, unless ``nocache`` is set.

    Parameters
    ----------
    which : {'small', 'best', 'both'}
        Which network versions to show: the default ones, the ones listed in
        ``best_versions``, or both.
    inp : path-like
        Directory holding the cached metrics CSV files; created if missing.
    out : path-like
        Output location handed to ``save_fig``.
    n_jobs : int
        Number of joblib workers used to process the networks of a split.
    nocache : bool
        If true, recompute the metrics even when a cached CSV exists.
    append_versions : bool, default=True
        If true, append the display version suffix to the network names.

    Raises
    ------
    ValueError
        If ``which`` is not one of the supported values.
    """
    # Validate up front so that an invalid value fails before any expensive
    # computation is started.
    which_filters = {
        'small': 'best == False',
        'best': 'best == True',
        'both': None,
    }
    if which not in which_filters:
        raise ValueError(f'"{which}" not known.')

    n_bins = 15
    clustering = 'decision_tree'
    test_size = 0.5
    hist = True
    ci = 'clopper'
    min_cluster_size = 10
    breakout = False
    min_samples_leaf = None
    max_clusters_bin = None
    n_clusters = None
    n_samples_per_cluster_per_bin = 30

    bins = np.linspace(0, 1, n_bins+1)

    def compute_one(split):
        """Return a DataFrame with one row of metrics per network config."""
        nets = [
            IMAGENET_VGG,
            IMAGENET_ResNet,
            IMAGENET_DenseNet,
            IMAGENET_Inception,
            IMAGENET_GoogLeNet,
            IMAGENET_ShuffleNet,
            IMAGENET_MobileNet,
            IMAGENET_ResNext,
            IMAGENET_WideResNet,
            IMAGENET_MNASNet,
            IMAGENET_EfficientNet,
            IMAGENET_RegNet,
            IMAGENET_VisionTransformer,
            IMAGENET_MLPMixer,
            IMAGENET_AlexNet,
            IMAGENET_ConvNeXt,
        ]

        def compute_one_net(net, calibrate, best):
            """Return a one-row DataFrame of metrics (empty if no best version)."""
            print(split, net, calibrate, best)
            if best:
                version = best_versions.get(net, None)
                if version is None:
                    print(f'No best version found for net "{net}"')
                    return pd.DataFrame()
                net = net(split=split, type=version)  # best version
            else:
                net = net(split=split)  # default version (ie smallest)
            dirpath = net.get_default_dirpath()
            Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
            y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()
            y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()

            y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)
            splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1, bins=n_bins, random_state=0)
            _, test_idx = next(splitter)

            # Keep the uncalibrated scores/labels for the multiclass metrics.
            raw_scores, raw_labels = y_scores, y_labels
            if calibrate is not None:
                y_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                               test_size=test_idx, max_calibration=True)
                y_labels = y_well_guess

            # NOTE(review): n_jobs is hard-coded here although this function
            # already runs inside joblib workers -- confirm this is intended.
            (frac_pos,
             counts,
             mean_scores,
             *_
             ) = cluster_evaluate_max(Xt, y_labels, y_scores, breakout=breakout,
                                      bins=n_bins, verbose=2,
                                      n_jobs=4,
                                      min_samples_leaf=min_samples_leaf,
                                      max_clusters_bin=max_clusters_bin,
                                      n_samples_per_cluster_per_bin=n_samples_per_cluster_per_bin,
                                      clustering=clustering,
                                      n_clusters=n_clusters,
                                      test_size=test_idx,
                                      binary_scores=calibrate is not None,
                                      )

            # NOTE(review): the uncalibrated case uses all samples whereas the
            # calibrated case uses only the test indices -- confirm that this
            # asymmetry is intended.
            if calibrate is None:
                y_pred_scores, y_well_guess = binarize_multiclass_max(raw_scores, raw_labels)
            else:
                y_well_guess, y_pred_scores = y_labels[test_idx], y_scores[test_idx, 1]

            extra_out_kwargs = {
                'split': split,
                'dataset': net.get_dataset_name(),
                'network': net.get_class_name(False),
                'network+version': net.get_class_name(True),
                'calibrate': str(calibrate),
                'best': best,
                'n_samples_per_cluster_per_bin': n_samples_per_cluster_per_bin,
            }
            name = str(net)
            out_kwargs, _ = _get_out_kwargs(
                clustering, n_bins, ci, name, hist, test_size,
                calibrate, max_clusters_bin, min_samples_leaf, n_clusters,
                min_cluster_size, extra_out_kwargs, [])
            out_kwargs['breakout'] = breakout

            metrics = {}
            metrics.update(out_kwargs)
            metrics.update(compute_multi_classif_metrics(raw_scores, raw_labels))
            metrics_binarized = compute_classif_metrics(y_pred_scores, y_well_guess)
            metrics_binarized.update(compute_calib_metrics(frac_pos, counts, y_pred_scores, y_well_guess, bins))
            # Prefix the binarized metrics so they cannot clash with the
            # multiclass ones.
            metrics.update({f'binarized_{k}': v for k, v in metrics_binarized.items()})

            return pd.DataFrame([metrics])

        configs = list(itertools.product(nets, [None, 'isotonic'], [False, True]))
        dfs = Parallel(n_jobs=n_jobs)(
            delayed(compute_one_net)(net, calibrate, best)
            for net, calibrate, best in tqdm(configs)
        )

        return pd.concat(dfs, axis=0, ignore_index=True)

    splits = [
        'test_c:_merged_no_rep5',
        'test_r',
        'test_c:snow5',
        'val',
    ]

    # Lookup from lowercase class name to network class (loop invariant).
    networks_by_name = {c.__name__.lower(): c for c in ALL_IMAGENET_NETWORKS}
    figsize = (3.5, 6) if which == 'both' else (3.6, 3.4)

    os.makedirs(inp, exist_ok=True)
    df_paths = {split: join(inp, f'metrics_{split}_nspcb{n_samples_per_cluster_per_bin}.csv') for split in splits}
    for split, path in df_paths.items():
        print(split)
        if nocache or not os.path.exists(path):
            df = compute_one(split)
            df.to_csv(path)

        # Always go through the CSV so cached and fresh runs see the same
        # (round-tripped) data.
        df = pd.read_csv(path, index_col=0)

        # Drop mlpmixer
        df = df.query('network != "mlpmixer"')

        which_filter = which_filters[which]
        if which_filter is not None:
            df = df.query(which_filter)

        # ``calibrate`` was stored as a string, hence the "None" literal.
        df_cal = df.query('calibrate == "isotonic"')
        df_ncal = df.query('calibrate == "None"')

        ece1 = df_ncal['binarized_msce']
        ece2 = df_cal['binarized_msce']
        glexp1 = df_ncal['binarized_lower_bound_debiased']
        glexp2 = df_cal['binarized_lower_bound_debiased']
        glexpbias1 = df_ncal['binarized_lower_bound_bias']
        glexpbias2 = df_cal['binarized_lower_bound_bias']
        glind1 = df_ncal['binarized_GL_ind']
        glind2 = df_cal['binarized_GL_ind']
        clind1 = df_ncal['binarized_CL_ind']
        clind2 = df_cal['binarized_CL_ind']
        acc1 = df_ncal['acc']
        acc2 = df_cal['acc']
        net_names1 = df_ncal['network+version']
        net_names2 = df_cal['network+version']

        print(clind1)
        print(clind2)

        # Both tables must describe the same networks, in the same order.
        assert list(net_names1) == list(net_names2)
        assert list(acc1) == list(acc2)

        # "name:version" -> (name, version); the version is '' when absent.
        names, versions = zip(*[(n.split(':')+[''])[:2] for n in net_names1])

        rename_names = {n: networks_by_name[n].__name__ for n in names}
        rename_names.update({
            'wideresnet': 'Wide ResNet',
            'resnext': 'ResNeXt',
            'visiontransformer': 'ViT',
        })

        if append_versions:
            names = [f'{rename_names[n]}{rename_versions.get(v, " "+v.capitalize())}' for n, v in zip(names, versions)]
        else:
            names = [f'{rename_names[n]}' for n in names]

        val1 = (ece1, glexp1, glexpbias1, glind1, clind1)
        val2 = (ece2, glexp2, glexpbias2, glind2, clind2)
        fig = barplots_ece_gl_cal(names, val1, val2, acc1,
                                  figsize=figsize,
                                  loc='lower right', bbox_to_anchor=(1, 0)
                                  )
        save_fig(fig, out, n='ece_gl', split=split, which=which,
                 nspcb=n_samples_per_cluster_per_bin, pad_inches=0)


@pytest.mark.parametrize('split', [
    'test_unseen',
    # 'test_seen',
])
@pytest.mark.parametrize('calibrate', [
    None,
    # 'isotonic',
])
def test_fig_nlp(calibrate, out, n_jobs, split):
    """Generate the grouping diagrams of the NLP networks on YahooAnswers (figure 8).

    The scores of the positive class (index 1) of ``ZeroShotBartYahoo`` are
    clustered with ``cluster_evaluate_marginals`` and the resulting diagram
    is saved with ``save_fig``.
    """
    n_bins = 15
    test_size = 0.5
    positive_class = 1  # Plot the positive class

    # Global matplotlib style for this figure.
    plt.rc(
        'legend',
        title_fontsize=10,
        fontsize=10,
        handletextpad=0.01,
        columnspacing=0.02,
        borderpad=0.3,
        borderaxespad=0.2,
        handlelength=1.2,
        labelspacing=0.1,
    )
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)
    plt.rc('axes', labelsize=13)

    # Load the precomputed tensors of the network.
    net = ZeroShotBartYahoo(split=split)
    data_dir = net.get_default_dirpath()
    Xt, y_scores, y_labels = (
        torch.load(join(data_dir, fname)).numpy()
        for fname in ('Xt.pt', 'y_scores.pt', 'y_labels.pt')
    )

    splitter = bin_train_test_split(
        y_scores[:, positive_class],
        test_size=test_size,
        n_splits=1,
        bins=n_bins,
        random_state=0,
    )
    _, test_idx = next(splitter)

    if calibrate is not None:
        y_scores, _ = calibrate_scores(
            y_scores, y_labels, method=calibrate,
            test_size=test_idx, max_calibration=False)

    class_name = net.get_class_names()[positive_class]
    frac_pos, counts, mean_scores, *_ = cluster_evaluate_marginals(
        Xt, y_labels, y_scores,
        positive_class=positive_class,
        bins=n_bins,
        clustering='decision_tree',
        test_size=test_idx,
        min_samples_leaf=None,
        max_clusters_bin=2,
        verbose=2,
        n_jobs=n_jobs,
    )

    plot_options = {
        'y_scores': None,
        'y_labels': None,
        'ncol': 1,
        'legend_loc': 'upper left',
        'bbox_to_anchor': (0, 1),
        'title': None,
        'xlim_margin': 0.05,
        'ylim_margin': 0.05,
        'min_cluster_size': 10,
        'hist': True,
        'ci': 'clopper',
        'legend_cluster_sizes': True,
        'vary_cluster_size': False,
        'capsize': 3.5,
        'xlabel': 'Confidence score',
        'ylabel': 'Fraction of positives (%)',
        'cluster_size': 30,
        'absolute_size_scale': (10, None),
        'plot_cal_hist': False,
        'figsize': (2.2, 2.2),
        'legend_n_sizes': 1,
        'legend_sizes_only': True,
        'legend_min_max': False,
        'plot_first_last_bins': False,
        'grid_space': 0,
        'legend_title': 'Sizes',
    }
    fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores, **plot_options)

    main_ax, top_ax, right_ax = fig.axes[0], fig.axes[1], fig.axes[2]
    main_ax.set_xlim((0, 1))
    main_ax.set_ylim((0, 1))

    main_ax.set_xticks([0, 0.5, 1])
    main_ax.set_xticklabels(['0', '0.5', '1'])
    main_ax.set_yticks([0, 0.25, 0.5, 0.75, 1])
    main_ax.set_yticklabels(['0', '25', '50', '75', '100'])

    top_ax.get_xaxis().set_visible(False)
    right_ax.get_yaxis().set_visible(False)

    # The y label and tick labels are only kept for the uncalibrated
    # 'test_seen' figure.
    keep_ylabel = calibrate is None and net.split == 'test_seen'
    if not keep_ylabel:
        main_ax.set_ylabel(None)
        main_ax.yaxis.set_ticklabels([])

    save_fig(fig, out, order=['net', 'cal'],
             net=str(net), cal=calibrate, pos=class_name)


# def test_fig_theorem(out):
#     """Generate the intuition figure (fig 2)."""
#     # plt.rc('legend', title_fontsize=10)
#     plt.rc('legend', fontsize=13)
#     # plt.rc('legend', handletextpad=0.01)
#     plt.rc('legend', columnspacing=0.4)
#     plt.rc('legend', borderpad=0.3)
#     plt.rc('legend', borderaxespad=0.2)
#     # plt.rc('legend', handlelength=1.2)
#     plt.rc('legend', labelspacing=0.3)
#     legend_right = True
#     fig = plot_fig_theorem(isoline_right=False, squared=False, legend_right=legend_right)
#     save_fig(fig, out, legend_right=legend_right)


def test_fig_theorem_v2(out):
    """Render and save the intuition figure of the paper (figure 2).

    ``out`` is forwarded unchanged to ``save_fig``.
    """
    # Legend styling is set globally through matplotlib's rc parameters.
    plt.rc(
        'legend',
        fontsize=13,
        columnspacing=0.4,
        borderpad=0.3,
        borderaxespad=0.2,
        labelspacing=0.3,
    )

    place_legend_right = True
    figure = plot_fig_theorem_v2(
        isoline_right=False,
        squared=False,
        legend_right=place_legend_right,
    )
    save_fig(figure, out, legend_right=place_legend_right, pad_inches=0)


def test_fig_counter_example(out):
    """Render and save the accurate and calibrated 1D example (figure 1)."""
    save_fig(plot_fig_counter_example(), out, pad_inches=0)


def test_fig_visualization(out):
    """Draw the explanatory grouping diagram on synthetic data (figure 5).

    Two synthetic clusters per bin are placed on either side of the
    diagonal, drawn with ``plot_frac_pos_vs_scores``, annotated with the
    notation of the paper and handed to ``save_fig`` together with ``out``.
    """
    n_bins = 10
    n_clusters = 2

    # Bin centers, replicated column-wise: one column per cluster.
    bin_centers = (np.arange(n_bins) + 0.5)/n_bins
    mean_scores = np.tile(bin_centers[:, None], (1, n_clusters))
    counts = np.full((n_bins, n_clusters), 50, dtype=float)

    def bend(x, alpha):
        # Quadratic distortion of the identity: on (0, 1), alpha < 0 lifts
        # the curve above the diagonal and alpha > 0 pushes it below.
        return alpha*x*(x-1) + x

    centers = mean_scores[:, 0]
    frac_pos = np.stack([bend(centers, alpha=-0.3),
                         bend(centers, alpha=0.9)], axis=1)

    # Evaluated while the counts are still uniform; they are rescaled below.
    prob_bins = calibration_curve(frac_pos, counts, mean_scores,
                                  return_mean_bins=False)

    # Square-root ramp applied to the counts of each bin (first bin forced
    # to a factor of 1 instead of 0).
    size_gradient = np.power(np.linspace(0, 100, n_bins), 1/2)
    size_gradient[0] = 1
    counts *= size_gradient[:, None]

    annotation_size = 15
    xlabel = 'Confidence score'
    ylabel = 'Fraction of positives  '

    plt.rc(
        'legend',
        title_fontsize=10,
        fontsize=10,
        handletextpad=0.01,
        columnspacing=0.02,
        borderpad=0.3,
        borderaxespad=0.2,
        handlelength=1.2,
        labelspacing=0.1,
    )
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)
    plt.rc('axes', labelsize=10.5)

    plot_options = dict(
        y_scores=None,
        y_labels=None,
        ncol=1,
        legend_loc='upper left',
        bbox_to_anchor=(0, 1),
        title=None,
        xlim_margin=0.05,
        ylim_margin=0.05,
        min_cluster_size=14,
        hist=True,
        ci='clopper',
        legend_cluster_sizes=False,
        vary_cluster_size=False,
        capsize=3.5,
        xlabel=None,
        ylabel=None,
        cluster_size=50,
        absolute_size_scale=(10, None),
        plot_cal_hist=False,
        figsize=(2.2, 2.2),
        legend_n_sizes=0,
        legend_sizes_only=True,
        legend_min_max=False,
        plot_first_last_bins=False,
        grid_space=0,
        legend_title='Sizes',
    )
    fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores,
                                  **plot_options)

    ax = fig.axes[0]
    ax.set_xlim((0, 1))
    ax.set_ylim((0, 1))

    # Only the extreme ticks are labelled; the y axis reads as a percentage.
    ax.set_xticks([0, 0.5, 1])
    ax.set_xticklabels(['0', '', '1'])
    ax.set_yticks([0, 0.25, 0.5, 0.75, 1])
    ax.set_yticklabels(['0', '', '', '', '100\n(%)'])

    fig.axes[1].get_xaxis().set_visible(False)
    fig.axes[2].get_yaxis().set_visible(False)

    # Axis labels are drawn as annotations placed just outside the axes.
    delta = 0.06
    label_size = plt.rcParams['axes.labelsize']
    ax.annotate(xlabel, xy=(0.5, -delta), xytext=(0.5, -delta),
                xycoords='axes fraction', ha='center', va='top',
                fontsize=label_size)
    ax.annotate(ylabel, xy=(-delta, 0.5), xytext=(-delta, 0.5),
                xycoords='axes fraction', ha='right', va='center',
                fontsize=label_size, rotation=90)

    # All arrows point at the same bin (row 4) of the first cluster.
    row, col = 4, 0
    anchor_x = mean_scores[row, col]
    arrow_specs = [
        # (text, arrow target, text position, horizontal alignment, arrow)
        (r"$\hat{\mu}^{(s)}_j$",
         (anchor_x, frac_pos[row, col]),
         (0.29, 0.55),
         'right',
         dict(arrowstyle="->",
              shrinkB=4,
              patchA=None,
              shrinkA=7,
              connectionstyle="arc3,rad=.22")),
        (r"$\hat{c}^{(s)}$",
         (anchor_x, prob_bins[row]),
         (0.6, 0.10),
         'left',
         dict(arrowstyle="->",
              shrinkB=2,
              connectionstyle="arc3,rad=-.22")),
        (r"$n^{(s)}$",
         (anchor_x, 1),
         (0.33, 0.8),
         'right',
         dict(arrowstyle="->",
              shrinkB=1,
              patchA=None,
              shrinkA=10,
              connectionstyle="arc3,rad=.22")),
    ]
    for text, target, text_pos, h_align, arrow in arrow_specs:
        ax.annotate(text,
                    xy=target, xycoords='data',
                    xytext=text_pos, textcoords='data',
                    va='bottom',
                    ha=h_align,
                    fontsize=annotation_size,
                    arrowprops=arrow,
                    )

    # Plain "s" marker on the x axis, below the annotated bin.
    ax.annotate(r"s",
                xy=(anchor_x, 0), xycoords='data',
                xytext=(anchor_x, 0), textcoords='data',
                va='bottom',
                ha='center',
                fontsize=annotation_size,
                )

    save_fig(fig, out, pad_inches=0)


@pytest.mark.parametrize('agg', [
    # 'bin',
    # 'mean',
    'weighted_common',
    # 'weighted_each',
])
def test_fig_imagenet_r_renditions_calibration(agg, out, n_jobs):
    """Generate the figures on renditions on imagenet-r (fig 13).

    For every network of ``ALL_IMAGENET_NETWORKS`` except
    ``IMAGENET_MLPMixer``, the stored ``test_r`` outputs are split by
    rendition. For each rendition, the gap between its fraction of positives
    and the calibration curve computed on all renditions together is
    evaluated. The gaps are plotted with ``plot_renditions_calibration`` and
    saved with ``save_fig``.

    Parameters
    ----------
    agg : str
        How the per-bin gaps are aggregated for each rendition:

        - ``'weighted_common'``: sum over bins of the count-weighted gaps,
          divided by the total count of the rendition.
        - ``'weighted_each'``: count-weighted mean of the fraction of
          positives of the rendition minus the count-weighted mean of the
          pooled calibration curve.
        - ``'mean'``: unweighted mean of the gaps over bins.
        - ``'bin'``: no aggregation, one row per (bin, rendition) pair.
    out
        Forwarded to ``save_fig``.
    n_jobs
        Forwarded to ``cluster_evaluate_max``.

    Raises
    ------
    ValueError
        If ``agg`` is not one of the values listed above.
    """
    n_bins = 15
    max_clusters_bin = 2
    renditions = sorted([
        'art',
        'cartoon',
        'graffiti',
        'embroidery',
        'graphic',
        'origami',
        'painting',
        'sculpture',
        'tattoo',
        'toy',
        'deviantart',
        'misc',
        'videogame',
        'sketch',
        'sticker',
    ])
    n_renditions = len(renditions)

    # NOTE(review): compute_one reads renditions, n_bins, max_clusters_bin
    # and n_jobs from the enclosing scope; presumably these are not part of
    # the joblib cache key since they are not arguments -- confirm.
    @memory.cache()
    def compute_one(net, with_type=False, agg='mean'):
        """Return the per-rendition calibration gaps of one network."""
        net = net(split='test_r')
        dirpath = net.get_default_dirpath()
        ds = net.get_dataset()

        Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
        y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()
        y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()

        # The image paths do not change across renditions: list them once.
        img_paths = [path for path, _ in ds.imgs]

        res = []
        for rendition in renditions:
            print(rendition)
            net.split = f'test_r:{rendition}'

            # Select the samples whose file name starts with the rendition.
            pattern = re.compile(f'.*/{rendition}_.*')
            selection = np.array(
                [pattern.match(path) is not None for path in img_paths])

            Xt_r = Xt[selection, :]
            y_scores_r = y_scores[selection, :]
            y_labels_r = y_labels[selection]

            (frac_pos,
             counts,
             mean_scores,
             *_,
             ) = cluster_evaluate_max(Xt_r, y_labels_r, y_scores_r,
                                      breakout=False,
                                      bins=n_bins,
                                      clustering=None,
                                      test_size=None,
                                      min_samples_leaf=None,
                                      max_clusters_bin=max_clusters_bin,
                                      verbose=1,
                                      n_jobs=n_jobs,
                                      binary_scores=False,
                                      return_clustering=False,
                                      )

            res.append((frac_pos, counts, mean_scores))

        # Zip results from list of tuples to tuple of lists
        L_frac_pos, L_counts, L_mean_scores = list(zip(*res))

        # Renditions are stacked as columns; rows are bins.
        frac_pos = np.concatenate(L_frac_pos, axis=1)
        counts = np.concatenate(L_counts, axis=1)
        mean_scores = np.concatenate(L_mean_scores, axis=1)

        # Calibration curve pooled over all the renditions.
        prob_bins, mean_bins = calibration_curve(frac_pos, counts, mean_scores)

        if agg == 'weighted_each':
            diff = np.sum(counts*frac_pos, axis=0)/np.sum(counts, axis=0) - np.sum(np.sum(counts, axis=1)*prob_bins/np.sum(counts))

        elif agg == 'weighted_common':
            diff = np.sum(counts*(frac_pos - prob_bins[:, None]), axis=0)/np.sum(counts, axis=0)

        elif agg == 'mean':
            diff = np.mean(frac_pos - prob_bins[:, None], axis=0)

        elif agg == 'bin':
            diff = frac_pos - prob_bins[:, None]

            # ``diff.flatten()`` is row-major: all the renditions of bin 0,
            # then all the renditions of bin 1, etc. The rendition labels
            # therefore cycle while the bin index must stay constant over
            # each run of ``n_renditions`` entries (np.repeat, not np.tile).
            # Assumes one column per rendition -- TODO confirm.
            df = pd.DataFrame({
                'diff': diff.flatten(),
                'rendition': np.tile([s.capitalize() for s in renditions], (n_bins, 1)).flatten(),
                'bin': np.repeat(np.arange(n_bins), n_renditions),
                'net': net.get_class_name(with_type=with_type),
            })

            return df

        else:
            raise ValueError(f'Unknown {agg}')

        df = pd.DataFrame({
            'diff': diff,
            'rendition': [s.capitalize() for s in renditions],
            'net': net.get_class_name(with_type=with_type),
        })

        return df

    with_type = True
    dfs = [
        compute_one(net, with_type=with_type, agg=agg)
        for net in ALL_IMAGENET_NETWORKS
        if not net == IMAGENET_MLPMixer
    ]

    df = pd.concat(dfs, axis=0)

    print(df)

    # Rename networks and versions
    names, versions = zip(*[(n.split(':')+[''])[:2] for n in df['net']])
    networks_by_name = {c.__name__.lower(): c for c in ALL_IMAGENET_NETWORKS}
    rename_names = {n: networks_by_name[n].__name__ for n in names}
    rename_names.update({
        'wideresnet': 'Wide ResNet',
        'resnext': 'ResNeXt',
        'visiontransformer': 'ViT',
    })
    # NOTE(review): ``rename_versions`` is not defined in this function;
    # presumably a module-level mapping -- confirm it exists.
    df['net'] = [f'{rename_names[n]}{rename_versions.get(v, " "+v.capitalize())}' for n, v in zip(names, versions)]

    fig = plot_renditions_calibration(df, x='diff', y='rendition', hue='net')
    save_fig(fig, out, with_type=with_type, agg=agg, pad_inches=0.01)


@pytest.mark.parametrize('name', [
    'poly',
    '2x',
    'step4',
    'constant',
])
def test_fig_examples(name, out):
    """Plot the examples of calibrated classifiers (figures 9 and 10).

    Two figures are saved for each ``name``: the curves of the example as a
    function of X (saved with ``n='X'``) and the true probabilities against
    the confidence scores on simulated samples (saved with ``n='QS'``).
    """
    n_samples = 1000000

    if name not in ['poly', '2x', 'step4']:
        ex = CustomUnconstrained(name=name, x_min=-3, x_max=3)
        max_samples = 5000
    else:
        ex = CustomUniform(name=name, dist='gaussian')
        max_samples = 1000

    # ---- First figure: curves of the example over X ----
    plt.rc(
        'legend',
        fontsize=10,
        title_fontsize=12,
        handletextpad=0.5,
        columnspacing=1.3,
        borderpad=0.2,
        borderaxespad=0.2,
        handlelength=1,
        labelspacing=0.1,
    )
    plt.rc('xtick', labelsize=9)
    plt.rc('ytick', labelsize=9)
    plt.rc('axes', labelsize=12)

    x_min, x_max = -2, 2
    grid = np.linspace(x_min, x_max, 100)
    Q = ex.f_star(grid)
    S = ex.f(grid)
    P = ex.p(grid)
    fig = plot_ffstar_1d(S, Q, P, x_min=x_min, x_max=x_max, figsize=(2, 2.2),
                         lw=1.5)

    ax = fig.axes[1]
    ax.set_xticks([-2, 2])
    ax.set_xticklabels(['$-2$', '2'])
    ax.set_xlabel(None)
    # The x label is drawn as an annotation placed just below the axis.
    offset = 0.04
    ax.annotate('$X$', xy=(0.5, -offset), xytext=(0.5, -offset),
                xycoords='axes fraction', ha='center', va='top',
                fontsize=plt.rcParams['axes.labelsize'],
                )

    save_fig(fig, out, link=name, n='X', order=['link', 'n'], pad_inches=0.02)

    # ---- Second figure: true probabilities vs confidence scores ----
    plt.rc(
        'legend',
        title_fontsize=10,
        fontsize=10,
        handletextpad=0.01,
        columnspacing=0.02,
        borderpad=0.2,
        borderaxespad=0.2,
        handlelength=0.8,
        labelspacing=0.1,
    )
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)
    plt.rc('axes', labelsize=10.5)

    X, y_labels = ex.generate_X_y(n=n_samples)
    y_scores = ex.f(X)
    y_true_probas = ex.f_star(X)
    fig = plot_score_vs_probas2(y_scores, y_labels, y_true_probas,
                                max_samples=max_samples,
                                height=2.5, grid_space=0, lim_margin=0.03,
                                ncol=1, plot_first_last_bins=False)

    ax = fig.axes[0]
    ax.set_xlabel(None)
    ax.set_ylabel(None)

    # Switches to drop the labels of an axis altogether.
    show_xlabel = True
    show_ylabel = True

    if not show_xlabel:
        ax.xaxis.set_ticklabels([])
    else:
        offset = 0.04
        ax.set_xticks([0, 0.5, 1])
        ax.set_xticklabels(['0', '', '1'])
        ax.annotate('Confidence score $S$',
                    xy=(0.5, -offset), xytext=(0.5, -offset),
                    xycoords='axes fraction', ha='center', va='top',
                    fontsize=plt.rcParams['axes.labelsize'],
                    )

    if not show_ylabel:
        ax.yaxis.set_ticklabels([])
    else:
        offset = 0.02
        ax.set_yticks([0, 0.25, 0.5, 0.75, 1])
        ax.set_yticklabels(['0', '', '', '', '1'])
        ax.annotate('True probability $Q$',
                    xy=(-offset, 0.5), xytext=(-offset, 0.5),
                    xycoords='axes fraction', ha='right', va='center',
                    fontsize=plt.rcParams['axes.labelsize'],
                    rotation=90,
                    )

    save_fig(fig, out, link=name, n='QS', order=['link', 'n'], pad_inches=0.02)


@pytest.mark.parametrize('bayes_opt', [
    False,
    True,
])
@pytest.mark.parametrize('delta_width', [
    3,
    None,
])
def test_fig_examples_realistic(bayes_opt, delta_width, out):
    """Plot the calibrated-classifier examples built on the output of a
    neural network (figures 11 and 12).

    The first three figures returned by ``plot_ffstar_2d_v2`` are saved,
    followed by the plot of the true probabilities against the confidence
    scores on simulated samples.
    """
    plt.rc(
        'legend',
        fontsize=10,
        title_fontsize=12,
        handletextpad=0.5,
        columnspacing=1,
        borderpad=0.2,
        borderaxespad=0.1,
        handlelength=1.6,
        labelspacing=0.1,
    )
    plt.rc('xtick', labelsize=9)
    plt.rc('ytick', labelsize=9)
    plt.rc('axes', labelsize=12, titlesize=10)

    n_samples = 1000000
    max_samples = 1000
    dim = 2

    # Random unit direction (fixed seed) and a vector orthonormal to it.
    rng = check_random_state(0)
    w = rng.uniform(size=dim)
    w = w/np.linalg.norm(w)
    w_perp = create_orthonormal_vector(w)

    ex = SigmoidExample(w, w_perp, bayes_opt=bayes_opt, delta_width=delta_width)

    # Four figures are returned; the last one is not saved.
    fig1, fig2, fig3, _ = plot_ffstar_2d_v2(
        ex.f, ex.f_1d, ex.psi, ex.delta, ex.delta_max, w, w_perp,
        ex.mean, ex.cov, trim=True, figsize=(2.2, 2.2))

    if bayes_opt:
        tick_labels = [r'$-\frac{1}{4}$', '0', r'$\frac{1}{4}$']
    else:
        tick_labels = [r'$-\frac{1}{2}$', '0', r'$\frac{1}{2}$']
    fig2.axes[1].set_yticklabels(tick_labels)

    for index, figure in enumerate((fig1, fig2, fig3), start=1):
        save_fig(figure, out, f=index, bayes_opt=bayes_opt,
                 delta_width=delta_width, order=['delta_width'],
                 pad_inches=0.01)

    # ---- True probabilities vs confidence scores ----
    plt.rc(
        'legend',
        title_fontsize=10,
        fontsize=10,
        handletextpad=0.01,
        columnspacing=0.02,
        borderpad=0.2,
        borderaxespad=0.2,
        handlelength=0.8,
        labelspacing=0.1,
    )
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)
    plt.rc('axes', labelsize=10.5)

    X, y_labels = ex.generate_X_y(n=n_samples)
    y_scores = ex.f(X)
    y_true_probas = ex.f_star(X)
    fig = plot_score_vs_probas2(y_scores, y_labels, y_true_probas,
                                max_samples=max_samples,
                                height=2.5, grid_space=0, lim_margin=0.03,
                                ncol=1, plot_first_last_bins=False)

    ax = fig.axes[0]
    ax.set_xlabel(None)
    ax.set_ylabel(None)

    # Switches to drop the labels of an axis altogether.
    show_xlabel = True
    show_ylabel = True

    if not show_xlabel:
        ax.xaxis.set_ticklabels([])
    else:
        offset = 0.04
        ax.set_xticks([0, 0.5, 1])
        ax.set_xticklabels(['0', '', '1'])
        ax.annotate('Confidence score $S$',
                    xy=(0.5, -offset), xytext=(0.5, -offset),
                    xycoords='axes fraction', ha='center', va='top',
                    fontsize=plt.rcParams['axes.labelsize'],
                    )

    if not show_ylabel:
        ax.yaxis.set_ticklabels([])
    else:
        offset = 0.02
        ax.set_yticks([0, 0.25, 0.5, 0.75, 1])
        ax.set_yticklabels(['0', '', '', '', '1'])
        ax.annotate('True probability $Q$',
                    xy=(-offset, 0.5), xytext=(-offset, 0.5),
                    xycoords='axes fraction', ha='right', va='center',
                    fontsize=plt.rcParams['axes.labelsize'],
                    rotation=90,
                    )

    save_fig(fig, out, link='', bayes_opt=bayes_opt, n='QS',
             delta_width=delta_width,
             order=['delta_width', 'bayes_opt', 'link', 'n'],
             pad_inches=0.02)
