import os
import re
import shutil
from os.path import join
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import seaborn as sns
import torch
import torchmetrics
from cycler import cycler
from joblib import Memory

from .._linalg import gram_schmidt_orthonormalization
from .._plot import (heterogeneity_plot_wx, plot_2D_softmax, plot_fig_renditions,
                     plot_frac_pos_vs_scores, plot_K_softmax, set_latex_font, plot_grouping_bounds)
from .._utils import save_fig
from ._utils import (binarize_multiclass_max, calibrate_scores, binarize_multiclass_marginal,
                     calibration_curve, compute_calib_metrics, bin_train_test_split,
                     compute_multi_classif_metrics, merge_dataset_sublabels)
from .main import cluster_evaluate_marginals, cluster_evaluate_max, evaluate
from .networks import (ALL_IMAGENET_NETWORKS, IMAGENET_VGG,
                       IMAGENET_AlexNet, IMAGENET_MLPMixer)
from .results import cluster_marginals, cluster_max


# joblib on-disk cache rooted at the relative directory 'joblib_cache'
# (resolved against the current working directory).
# NOTE(review): not referenced by the code visible in this section — confirm
# it is still used elsewhere before removing.
memory = Memory('joblib_cache')


def test_vgg():
    """Smoke-test ``IMAGENET_VGG`` by printing its model and a few shapes.

    Forwards at most 50 batches of 100 samples and prints the shapes returned
    by ``get_w()``, ``get_intercept()`` and of the forwarded features.
    """
    vgg = IMAGENET_VGG()
    print(vgg.model)

    features, _, _ = vgg.get_Xt_y_scores_labels(batch_size=100, n_batch_max=50)

    print(vgg.get_w().shape)
    print(vgg.get_intercept().shape)
    print(features.shape)


# def test_vggcifar10():

#     net = CIFAR10_VGG()
#     print(net.model)
#     Xt, y_scores, y_labels = net.get_Xt_y_scores_labels(batch_size=100, n_batch_max=50)
#     print(net.get_w().shape)
#     print(net.get_intercept().shape)
#     print(Xt.shape)
#     print(y_scores.shape)
#     print(y_labels.shape)
#     print(y_labels)
#     print(torch.unique(y_labels))


# def test_heterogeneity_vggcifar10():
#     net = CIFAR10_VGG()
#     Xt, y_scores, y_labels = net.get_Xt_y_scores_labels(batch_size=100)

#     # print(Xt.shape)
#     # exit()

#     print(torch.unique(y_labels))

#     positive_class = 1  # Class #1 (automobile is chosen positive)
#     print(y_labels)
#     y_labels = (y_labels == positive_class).int()
#     print(y_labels)

#     w = net.get_w()
#     w = w[positive_class, :]
#     print(w.shape)

#     P = np.eye(w.shape[0])
#     P[:, 0] = w
#     P = gram_schmidt_orthonormalization(P)

#     w_perp_space = np.linspace(-10, 10, 10)
#     X_test, y_test = evaluate(Xt, y_labels, P, w_space=0.5, n_neighbors=100,
#                               alpha=10, w_perp_space=w_perp_space)

#     print(X_test.shape)
#     print(y_test.shape)

#     y_scores = net.last_layer(torch.tensor(X_test.reshape(-1, X_test.shape[-1])).float()).detach()
#     print(y_scores.shape)

#     # y_scores_softmax = torch.nn.functional.softmax(y_scores, dim=1)
#     # print(y_scores_softmax.shape)
#     # y_scores = y_scores_softmax[:, positive_class]

#     # y_scores = torch.special.expit(y_scores[:, positive_class])/10
#     # y_scores = y_scores.reshape(X_test.shape[:-1])

#     y_scores = None
#     # exit()

#     # y_test *= 10

#     fig = heterogeneity_plot_wx(X_test, y_test, P, y_scores=y_scores, n_var_max=10)
#     fig.savefig('test6.pdf', bbox_inches='tight')


def test_2D_softmax(out):
    """Plot ``plot_2D_softmax`` for random parameters and save the figure.

    A ``RandomState(0)`` generator draws a 4x2 weight matrix in [-1, 1) and
    then 4 biases in [-0.1, 0.1). The draw order matters because it fixes
    the figure produced for this seed.
    """
    rng = np.random.RandomState(0)
    n_rows = 4

    weights = rng.uniform(-1, 1, size=(n_rows, 2))
    biases = rng.uniform(-0.1, 0.1, size=(n_rows,))

    figure = plot_2D_softmax(weights, biases)
    save_fig(figure, out)


def test_K_softmax(out):
    """Render the figure built by ``plot_K_softmax`` and save it to ``out``."""
    save_fig(plot_K_softmax(), out)


# @pytest.mark.parametrize('positive_class', range(10))
# def test_kmeans_on_vgg_cifar10(positive_class, out):
#     net = CIFAR10_VGG()
#     Xt, y_scores, y_labels = net.get_Xt_y_scores_labels(batch_size=100)

#     (frac_pos,
#      counts,
#      mean_scores,
#      ) = cluster_evaluate_marginals(Xt, y_labels, y_scores, positive_class,
#                                     bins=15, n_clusters=3)

#     class_name = net.get_class_names()[positive_class]
#     fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores, ncol=1,
#                                   legend_loc='upper left',
#                                 #   bbox_to_anchor=(0.5, 1.01),
#                                   title=class_name.capitalize(),
#                                   xlim_margin=0.05,
#                                   ylim_margin=0.05,
#                                   )

#     save_fig(fig, out, positive_class=positive_class)


# def test_kmeans_on_vgg_cifar10_max(out):
#     net = CIFAR10_VGG()
#     ci = 'binomtest'
#     Xt, y_scores, y_labels = net.get_Xt_y_scores_labels(batch_size=100)

#     (frac_pos,
#      counts,
#      mean_scores,
#      ) = cluster_evaluate_max(Xt, y_labels, y_scores,
#                               bins=15, n_clusters=3)

#     frac_pos = frac_pos.reshape(frac_pos.shape[0], -1)
#     counts = counts.reshape(counts.shape[0], -1)
#     mean_scores = mean_scores.reshape(mean_scores.shape[0], -1)

#     fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores, ncol=1,
#                                   legend_loc='upper left',
#                                   xlim_margin=0.05,
#                                   ylim_margin=0.05,
#                                   min_cluster_size=10,
#                                   title='All classes',
#                                   hist=True,
#                                   ci=ci,
#                                   )

#     save_fig(fig, out, ci=ci)


# def test_kmeans_on_vgg_imagenet(out, n_jobs):
#     net = IMAGENET_VGG()
#     dirpath = net.get_default_dirpath()
#     Xt = torch.load(join(dirpath, 'Xt.pt'))
#     y_scores = torch.load(join(dirpath, 'y_scores.pt'))
#     y_labels = torch.load(join(dirpath, 'y_labels.pt'))

#     (frac_pos,
#      counts,
#      mean_scores,
#      ) = cluster_evaluate_marginals(Xt, y_labels, y_scores, positive_class=None,
#                                     bins=15, n_clusters=3, verbose=1, n_jobs=n_jobs)

#     torch.save(frac_pos, join(dirpath, 'frac_pos.pt'))
#     torch.save(counts, join(dirpath, 'counts.pt'))
#     torch.save(mean_scores, join(dirpath, 'mean_scores.pt'))

#     for k in range(frac_pos.shape[2]):
#         class_name = net.get_class_names()[k]
#         fig = plot_frac_pos_vs_scores(frac_pos[..., k],
#                                       counts[..., k],
#                                       mean_scores[..., k],
#                                       ncol=1,
#                                       legend_loc='upper left',
#                                       #   bbox_to_anchor=(0.5, 1.01),
#                                       title=class_name.capitalize(),
#                                       xlim_margin=0.05,
#                                       ylim_margin=0.05,
#                                       )

#         save_fig(fig, out, positive_class=class_name)


# @pytest.mark.parametrize('breakout', [True, False])
# @pytest.mark.parametrize('n_clusters', [2, 3, 10, 50, 100])
# @pytest.mark.parametrize('n_bins', [15])
# def test_kmeans_on_vgg_imagenet_max(breakout, n_clusters, n_bins, out, n_jobs):
#     net = IMAGENET_VGG()
#     ci = 'binomtest'
#     dirpath = net.get_default_dirpath()
#     Xt = torch.load(join(dirpath, 'Xt.pt'))
#     y_scores = torch.load(join(dirpath, 'y_scores.pt'))
#     y_labels = torch.load(join(dirpath, 'y_labels.pt'))

#     (frac_pos,
#      counts,
#      mean_scores,
#      ) = cluster_evaluate_max(Xt, y_labels, y_scores, breakout=breakout,
#                               bins=n_bins, n_clusters=n_clusters, verbose=1,
#                               n_jobs=n_jobs, clustering='kmeans')

#     fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores, ncol=1,
#                                   legend_loc='upper left',
#                                   xlim_margin=0.05,
#                                   ylim_margin=0.05,
#                                   min_cluster_size=10,
#                                   title='All classes',
#                                   hist=True,
#                                   ci=ci,
#                                   )

#     save_fig(fig, out, breakout=breakout, n_clusters=n_clusters, n_bins=n_bins,
#              ci=ci)


# @pytest.mark.parametrize('breakout', [False])
# @pytest.mark.parametrize('n_bins', [15])
# # @pytest.mark.parametrize('min_samples_leaf', [500, 250, 750])
# @pytest.mark.parametrize('min_samples_leaf', [100])
# # @pytest.mark.parametrize('min_samples_leaf', [750])
# # @pytest.mark.parametrize('min_samples_leaf', [10, 20, 100, 1000])
# def test_tree_on_vgg_imagenet_max(breakout, n_bins, out, n_jobs,
#                                   min_samples_leaf):
#     net = IMAGENET_VGG()
#     ci = 'clopper'
#     dirpath = net.get_default_dirpath()
#     Xt = torch.load(join(dirpath, 'Xt.pt'))
#     y_scores = torch.load(join(dirpath, 'y_scores.pt'))
#     y_labels = torch.load(join(dirpath, 'y_labels.pt'))

#     (frac_pos,
#      counts,
#      mean_scores,
#      ) = cluster_evaluate_max(Xt, y_labels, y_scores, breakout=breakout,
#                               bins=n_bins, verbose=2,
#                               n_jobs=8,
#                               min_samples_leaf=min_samples_leaf,
#                               clustering='decision_tree',
#                               )

#     fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores, ncol=1,
#                                   legend_loc='upper left',
#                                   xlim_margin=0.05,
#                                   ylim_margin=0.05,
#                                   min_cluster_size=10,
#                                   title='All classes',
#                                   hist=True,
#                                   ci=ci,
#                                   )

#     save_fig(fig, out, breakout=breakout, n_bins=n_bins,
#              min_samples_leaf=min_samples_leaf, ci=ci,)


def test_forward_dump_network(net, n_jobs, batch_size=1):
    """Forward and dump dataset for a network.

    Calls ``net.forward_dataset`` and then ``net.load_forwarded_dataset`` (the
    latter with ``dump=True``), both over all batches (``n_batch_max=None``)
    and with ``skip_exist=True``, presumably so that existing outputs are
    reused.
    """
    shared_kwargs = dict(
        batch_size=batch_size,
        n_batch_max=None,
        verbose=1,
        n_jobs=n_jobs,
        skip_exist=True,
    )
    net.forward_dataset(**shared_kwargs)
    net.load_forwarded_dataset(dump=True, **shared_kwargs)


def _cluster_marginals(net, out, clustering='decision_tree',
                       n_bins=15, n_jobs=1, min_samples_leaf=None, n_clusters=3,
                       ci='clopper', hist=True, min_cluster_size=10,
                       test_size=None, max_clusters_bin=2):
    """Cluster the per-class (marginal) calibration of ``net`` and plot it.

    Loads ``Xt.pt``, ``y_scores.pt`` and ``y_labels.pt`` from
    ``net.get_default_dirpath()`` and runs ``cluster_evaluate_marginals`` for
    all classes (``positive_class=None``). The resulting ``frac_pos``,
    ``counts`` and ``mean_scores`` are saved back into the same directory.
    Then one figure per class (indexed along the last axis of ``frac_pos``)
    is drawn and saved via ``save_fig``.

    The clustering-specific hyper-parameters are added to the file-naming
    kwargs only for ``'decision_tree'`` and ``'kmeans'``.
    """
    dirpath = net.get_default_dirpath()
    Xt = torch.load(join(dirpath, 'Xt.pt'))
    y_scores = torch.load(join(dirpath, 'y_scores.pt'))
    y_labels = torch.load(join(dirpath, 'y_labels.pt'))

    frac_pos, counts, mean_scores, *_ = cluster_evaluate_marginals(
        Xt, y_labels, y_scores,
        positive_class=None,
        bins=n_bins,
        clustering=clustering,
        n_clusters=n_clusters,
        test_size=test_size,
        min_samples_leaf=min_samples_leaf,
        max_clusters_bin=max_clusters_bin,
        verbose=1,
        n_jobs=n_jobs,
    )

    for filename, tensor in (('frac_pos.pt', frac_pos),
                             ('counts.pt', counts),
                             ('mean_scores.pt', mean_scores)):
        torch.save(tensor, join(dirpath, filename))

    for k in range(frac_pos.shape[2]):
        class_name = net.get_class_names()[k]

        # Binarize the multiclass problem into "class k vs rest" for this plot.
        y_pred_scores, y_well_guess = binarize_multiclass_marginal(
            np.array(y_scores), np.array(y_labels), k)

        fig = plot_frac_pos_vs_scores(
            frac_pos[..., k],
            counts[..., k],
            mean_scores[..., k],
            y_scores=y_pred_scores,
            y_labels=y_well_guess,
            ncol=1,
            legend_loc='upper left',
            title=class_name.capitalize(),
            xlim_margin=0.05,
            ylim_margin=0.05,
            min_cluster_size=min_cluster_size,
            hist=hist,
            ci=ci,
        )

        # Insertion order of these keys is kept stable for save_fig naming.
        naming = dict(
            clustering=clustering,
            n_bins=n_bins,
            ci=ci,
            net=str(net),
            hist=hist,
            test_size=test_size,
            positive_class=class_name,
        )
        order = ['net', 'clustering', 'test_size']

        if clustering == 'decision_tree':
            naming.update(min_samples_leaf=min_samples_leaf,
                          max_clusters_bin=max_clusters_bin)
            order.append('min_samples_leaf')
        elif clustering == 'kmeans':
            naming.update(n_clusters=n_clusters,
                          min_cluster_size=min_cluster_size)
            order.append('n_clusters')

        save_fig(fig, out, order=order, **naming)


def __cluster_max(net, clustering='decision_tree', breakout=False,
                  n_bins=15, n_jobs=1, min_samples_leaf=100,
                  max_clusters_bin=None, n_clusters=3,
                  test_size=None, calibrate=False, verbose=2, return_Xy=False):
    """Load a network's dumped outputs and cluster its max-confidence calibration.

    Reads ``Xt.pt``, ``y_scores.pt`` and ``y_labels.pt`` from
    ``net.get_default_dirpath()`` as NumPy arrays. It then resolves
    ``clustering`` into the value handed to ``cluster_evaluate_max``:

    * ``None`` or any unrecognised value: passed through unchanged;
    * ``'class'``: each sample's cluster is its predicted class (arg-max of
      the scores over axis 1);
    * a string containing ``'spectral'`` or ``'dbscan'``, optionally prefixed
      with ``'<similarity>:'`` (default similarity ``'path'``): each sample's
      cluster is ``net.get_meta_classes(...)`` indexed by its predicted class.

    If ``calibrate`` is truthy, the scores are first recalibrated with
    ``calibrate_scores(..., method=calibrate)``.

    Returns
    -------
    tuple
        ``(frac_pos, counts, mean_scores)``, extended with
        ``(Xt, y_scores, y_labels)`` when ``return_Xy`` is true. The returned
        ``y_scores`` are the post-calibration scores.
    """
    dirpath = net.get_default_dirpath()

    def load(filename):
        return np.array(torch.load(join(dirpath, filename)))

    Xt = load('Xt.pt')
    y_scores = load('y_scores.pt')
    y_labels = load('y_labels.pt')

    resolved = clustering
    if clustering is not None:
        if clustering == 'class':
            resolved = np.argmax(np.array(y_scores), axis=1).astype(int)
        elif 'spectral' in clustering or 'dbscan' in clustering:
            parsed = re.match('(([a-z]*):)?(.*)', clustering)
            prefix = parsed.group(2)
            similarity = 'path' if prefix is None else prefix
            predicted = np.argmax(np.array(y_scores), axis=1)
            meta_classes = net.get_meta_classes(similarity=similarity,
                                                clustering=parsed.group(3))
            resolved = np.array([meta_classes[c] for c in predicted])

    # Calibrate the classifier's confidence
    if calibrate:
        y_scores = calibrate_scores(y_scores, y_labels, method=calibrate)

    frac_pos, counts, mean_scores, *_ = cluster_evaluate_max(
        Xt, y_labels, y_scores,
        breakout=breakout,
        bins=n_bins,
        verbose=verbose,
        n_jobs=n_jobs,
        min_samples_leaf=min_samples_leaf,
        max_clusters_bin=max_clusters_bin,
        clustering=resolved,
        n_clusters=n_clusters,
        test_size=test_size,
    )

    stats = (frac_pos, counts, mean_scores)
    return (stats + (Xt, y_scores, y_labels)) if return_Xy else stats


def _cluster_max(net, out, clustering='decision_tree', breakout=False,
                 n_bins=15, n_jobs=1, min_samples_leaf=100, n_clusters=3,
                 ci='clopper', hist=True, min_cluster_size=10,
                 max_clusters_bin=None, test_size=None,
                 calibrate=False,
                 **kwargs):
    """Helper function to plot frac_pos in clusters.

    Runs ``__cluster_max`` on the dumped outputs of ``net`` and plots the
    per-cluster fraction of correct predictions against the maximum
    confidence score with ``plot_frac_pos_vs_scores``. The figure is saved
    via ``save_fig``.

    Parameters
    ----------
    net :
        Network wrapper, forwarded to ``__cluster_max``; ``str(net)`` is used
        for file naming.
    out :
        Output location forwarded to ``save_fig``.
    clustering, breakout, n_bins, n_jobs, min_samples_leaf, n_clusters,
    max_clusters_bin, test_size, calibrate :
        Forwarded to ``__cluster_max``.
    ci, hist, min_cluster_size :
        Forwarded to ``plot_frac_pos_vs_scores``.
    **kwargs :
        Accepted for call-site compatibility and ignored.
    """
    # Only the binned statistics are plotted. The raw scores/labels
    # (return_Xy=True) and their binarization via binarize_multiclass_max were
    # previously computed here but never used, so they are no longer requested.
    frac_pos, counts, mean_scores = __cluster_max(
        net=net,
        clustering=clustering,
        breakout=breakout,
        n_bins=n_bins,
        n_jobs=n_jobs,
        min_samples_leaf=min_samples_leaf,
        max_clusters_bin=max_clusters_bin,
        n_clusters=n_clusters,
        test_size=test_size,
        calibrate=calibrate,
    )

    fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores,
                                  ncol=1,
                                  legend_loc='upper left',
                                  xlim_margin=0.05,
                                  ylim_margin=0.05,
                                  min_cluster_size=min_cluster_size,
                                  title=None,
                                  hist=hist,
                                  ci=ci,
                                  mean_only=False,
                                  xlabel='Maximum confidence score',
                                  ylabel='Fraction of correct predictions',
                                  )

    out_kwargs = {
        'clustering': clustering,
        'breakout': breakout,
        'n_bins': n_bins,
        'ci': ci,
        'net': str(net),
        'hist': hist,
        'test_size': test_size,
        'calibrate': calibrate,
        'max_clusters_bin': max_clusters_bin,
    }

    order = ['net', 'clustering', 'test_size']

    # Only record the hyper-parameters relevant to the chosen clustering.
    if clustering == 'decision_tree':
        out_kwargs['min_samples_leaf'] = min_samples_leaf
        order.append('min_samples_leaf')

    elif clustering == 'kmeans':
        out_kwargs['n_clusters'] = n_clusters
        out_kwargs['min_cluster_size'] = min_cluster_size
        order.append('n_clusters')

    save_fig(fig, out, order=order, **out_kwargs)


@pytest.mark.parametrize('max_clusters_bin', [2, 3, 4, 5, 7, 10, 15, 20, 25, 50])
@pytest.mark.parametrize('test_size', [0.1, 0.25, 0.5, 0.75, 0.9])
def test_cluster_tree_max(net, out, n_jobs, max_clusters_bin, test_size):
    """Decision-tree clustering of max-confidence calibration.

    No ``min_samples_leaf`` limit is set; tree size is bounded by
    ``max_clusters_bin`` instead.
    """
    _cluster_max(
        net=net,
        out=out,
        clustering='decision_tree',
        n_bins=15,
        n_jobs=n_jobs,
        min_samples_leaf=None,
        max_clusters_bin=max_clusters_bin,
        test_size=test_size,
    )


# @pytest.mark.parametrize('min_samples_leaf', [1, 5, 10, 25, 50, 75, 100, 250, 500, 750])
@pytest.mark.parametrize('max_clusters_bin', [2, 3, 4, 5, 7, 10, 15, 20, 25, 50])
@pytest.mark.parametrize('test_size', [0.1, 0.25, 0.5, 0.75, 0.9])
def test_cluster_tree_marginals(net, out, n_jobs, max_clusters_bin, test_size):
    """Decision-tree clustering of per-class (marginal) calibration."""
    _cluster_marginals(
        net=net,
        out=out,
        clustering='decision_tree',
        n_bins=15,
        n_jobs=n_jobs,
        max_clusters_bin=max_clusters_bin,
        test_size=test_size,
    )


@pytest.mark.parametrize('min_samples_leaf', [100])
def test_cluster_tree_max_train(net, out, n_jobs, min_samples_leaf):
    """Decision-tree clustering of max calibration with a fixed leaf size.

    Calls ``_cluster_max`` directly with ``min_samples_leaf`` and no test split
    (``test_size`` left at its default ``None``). Forwarding to
    ``test_cluster_tree_max`` is not possible: that test requires
    ``test_size`` and takes ``max_clusters_bin`` in the position where
    ``min_samples_leaf`` would land.
    """
    _cluster_max(
        net=net,
        clustering='decision_tree',
        min_samples_leaf=min_samples_leaf,
        n_jobs=n_jobs,
        out=out,
        n_bins=15,
    )


@pytest.mark.parametrize('calibrate', [None])
@pytest.mark.parametrize('n_clusters', [2, 3, 4, 5, 10, 20])
def test_cluster_kmeans_max(net, out, n_jobs, n_clusters, calibrate):
    """K-means clustering of max-confidence calibration."""
    _cluster_max(
        net=net,
        out=out,
        clustering='kmeans',
        n_bins=15,
        n_jobs=n_jobs,
        n_clusters=n_clusters,
        calibrate=calibrate,
    )


@pytest.mark.parametrize('n_clusters', [2, 3, 4, 5, 10, 100, 250, 500])
def test_cluster_kmeans_marginals(net, out, n_jobs, n_clusters):
    """K-means clustering of per-class (marginal) calibration."""
    _cluster_marginals(
        net=net,
        out=out,
        clustering='kmeans',
        n_bins=15,
        n_jobs=n_jobs,
        n_clusters=n_clusters,
    )


@pytest.mark.parametrize('n_clusters', [100])
def test_cluster_kmeans_max_train(net, out, n_jobs, n_clusters):
    """K-means clustering of max calibration with a fixed number of clusters.

    Uses the same settings as ``test_cluster_kmeans_max`` (including
    ``calibrate=None``, its only parametrized value). It calls
    ``_cluster_max`` directly because forwarding positionally to that test
    omitted its required ``calibrate`` argument.
    """
    _cluster_max(
        net=net,
        clustering='kmeans',
        n_clusters=n_clusters,
        n_jobs=n_jobs,
        out=out,
        n_bins=15,
        calibrate=None,
    )


@pytest.mark.parametrize('clustering', ['spectral2', 'spectral3', 'spectral5', 'spectral10', 'spectral25', 'spectral50', 'spectral100'])
@pytest.mark.parametrize('similarity', ['path', 'lch', 'wup'])
def test_cluster_spectral_max(net, out, n_jobs, clustering, similarity):
    """Max calibration grouped by spectral meta-classes.

    The clustering spec is passed as ``'<similarity>:<clustering>'``. In that
    form ``__cluster_max`` splits it into a similarity name and a clustering
    name.
    """
    spec = f'{similarity}:{clustering}'
    _cluster_max(
        net=net,
        out=out,
        clustering=spec,
        n_bins=15,
        n_jobs=n_jobs,
        test_size=None,
    )


def test_cluster_class_max(net, out, n_jobs):
    """Max calibration grouped by predicted class (``clustering='class'``)."""
    _cluster_max(
        net=net,
        out=out,
        clustering='class',
        n_bins=15,
        n_jobs=n_jobs,
        test_size=None,
    )


def test_samples():
    """Check that the ImageNet-R sublabel splits add up to the full split.

    Sums ``len(dataset)`` over every per-sublabel split. Sublabels whose
    ``get_dataset()`` raises ``ValueError`` are reported and skipped. The
    test then asserts that the total equals the size of the ``test_r`` split.
    """
    # NOTE(review): other ImageNet-R tests in this file address sublabel splits
    # as 'test_r:<sublabel>' (colon), while this one builds 'test_r_<sublabel>'
    # (underscore). Confirm which form IMAGENET_AlexNet accepts; with the
    # wrong form every sublabel may be reported "not found".
    sublabels = [
        'art',
        'cartoon',
        'graffiti',
        'embroidery',
        'graphic',
        'origami',
        'painting',
        'sculpture',
        'tattoo',
        'toy',
        'deviantart',
        'misc',
        'videogame',
        'sketch',
        'sticker',
    ]

    total = 0
    for sublabel in sublabels:
        sub_net = IMAGENET_AlexNet(split=f'test_r_{sublabel}')
        try:
            sub_ds = sub_net.get_dataset()
            print(f'\nSublabel {sublabel} done')
            print(f'Dataset:\t{sub_net.get_dataset_name()}')
            print(f'N samples:\t{len(sub_ds)}')
            total += len(sub_ds)
        except ValueError:
            print(f'\nSublabel {sublabel} not found.')

    full_ds = IMAGENET_AlexNet(split='test_r').get_dataset()
    assert total == len(full_ds)


def test_imagenet_r_sublabels(net, out, n_jobs):
    """Overlay each ImageNet-R sublabel's reliability curve on the overall one.

    The mean curve of ``net`` on its current split is drawn first, in black.
    Then, for each sublabel, ``net.split`` is temporarily set to
    ``'<split>:<sublabel>'`` and that sublabel's mean curve is added to the
    same axes with a legend entry. ``net.split`` is restored afterwards, even
    if one of the steps raises, so a shared ``net`` fixture is not left
    pointing at a sublabel split.
    """
    sublabels = sorted([
        'art',
        'cartoon',
        'graffiti',
        'embroidery',
        'graphic',
        'origami',
        'painting',
        'sculpture',
        'tattoo',
        'toy',
        'deviantart',
        'misc',
        'videogame',
        'sketch',
        'sticker',
    ])

    clustering = None
    breakout = False
    n_bins = 15
    calibrate = False
    hist = False
    ci = 'clopper'

    split = net.split

    frac_pos, counts, mean_scores = __cluster_max(
        net=net,
        clustering=clustering,
        breakout=breakout,
        n_bins=n_bins,
        n_jobs=n_jobs,
        calibrate=calibrate,
    )

    fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores, ncol=1,
                                  legend_loc='upper left',
                                  xlim_margin=0.05,
                                  ylim_margin=0.05,
                                  min_cluster_size=None,
                                  title=None,
                                  hist=hist,
                                  ci=ci,
                                  mean_only=True,
                                  ax=None,
                                  mean_label=None,
                                  color_cycler=cycler(color=['black']),
                                  )
    ax = fig.axes[0]

    # Palette for the sublabel curves. It is passed only on the first call and
    # later calls get None, presumably so they keep advancing through the
    # cycle already installed on ``ax``.
    sublabel_cycler = cycler(color=sns.color_palette('husl', len(sublabels)).as_hex())

    try:
        for i, sublabel in enumerate(sublabels):
            net.split = f'{split}:{sublabel}'

            frac_pos, counts, mean_scores = __cluster_max(
                net=net,
                clustering=clustering,
                breakout=breakout,
                n_bins=n_bins,
                n_jobs=n_jobs,
                calibrate=calibrate,
            )

            fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores, ncol=1,
                                          legend_loc='upper left',
                                          xlim_margin=0.05,
                                          ylim_margin=0.05,
                                          min_cluster_size=None,
                                          title=None,
                                          hist=hist,
                                          ci=ci,
                                          mean_only=True,
                                          ax=ax,
                                          mean_label=sublabel.capitalize(),
                                          color_cycler=sublabel_cycler if i == 0 else None,
                                          )
            ax = fig.axes[0]
    finally:
        # Restore the caller's split even on failure; it is also what
        # str(net) reports below for file naming.
        net.split = split

    ax.legend(loc='upper left', fontsize='small')

    out_kwargs = {
        'n_bins': n_bins,
        'ci': ci,
        'net': str(net),
        'hist': hist,
        'calibrate': calibrate,
    }

    order = ['net']

    save_fig(fig, out, order=order, **out_kwargs)


def test_imagenet_r_sublabels_diff(out, n_jobs, with_type=False):
    """Compare each ImageNet-R sublabel's calibration with the pooled curve.

    For every network class in ``ALL_IMAGENET_NETWORKS`` and every sublabel,
    the binned statistics of split ``'test_r:<sublabel>'`` are computed with
    ``__cluster_max``. They are concatenated along axis 1 and passed to
    ``calibration_curve``. The plotted value is the mean over axis 0 of
    ``frac_pos - prob_bins``, giving one point per (sublabel, network) in a
    seaborn ``catplot`` with one row per sublabel and one hue per network.
    """
    set_latex_font()
    sublabels = sorted([
        'art',
        'cartoon',
        'graffiti',
        'embroidery',
        'graphic',
        'origami',
        'painting',
        'sculpture',
        'tattoo',
        'toy',
        'deviantart',
        'misc',
        'videogame',
        'sketch',
        'sticker',
    ])
    n_sublabels = len(sublabels)

    clustering = None
    breakout = False
    n_bins = 15
    calibrate = False

    def compute_diff_one(net_cls):
        """Return a DataFrame of per-sublabel calibration differences for one network class."""
        net = net_cls()
        res = []
        for sublabel in sublabels:
            print(sublabel)
            net.split = f'test_r:{sublabel}'

            res.append(__cluster_max(
                net=net,
                clustering=clustering,
                breakout=breakout,
                n_bins=n_bins,
                n_jobs=n_jobs,
                calibrate=calibrate,
                verbose=0
            ))

        # Transpose the list of (frac_pos, counts, mean_scores) tuples.
        L_frac_pos, L_counts, L_mean_scores = zip(*res)

        frac_pos = np.concatenate(L_frac_pos, axis=1)
        counts = np.concatenate(L_counts, axis=1)
        mean_scores = np.concatenate(L_mean_scores, axis=1)

        prob_bins, _ = calibration_curve(frac_pos, counts, mean_scores)

        diff = np.mean(frac_pos - prob_bins[:, None], axis=0)

        df = pd.DataFrame({
            'diff': diff,
            'sublabels': [s.capitalize() for s in sublabels],
            'net': net.get_class_name(with_type=with_type),
        })

        print(df)

        return df

    df = pd.concat([compute_diff_one(net_cls) for net_cls in ALL_IMAGENET_NETWORKS], axis=0)

    g = sns.catplot(data=df, x='diff', y='sublabels', hue='net')
    fig = g.figure
    ax = fig.axes[0]

    ax.axvline(0, color='darkgray', lw=1, zorder=0)
    # Symmetric x-range so that 0 (no difference) is centred.
    xmin, xmax = ax.get_xlim()
    xabs = max(abs(xmin), abs(xmax))
    ax.set_xlim((-xabs, xabs))
    # Categories sit at y = 0..n_sublabels-1. Keep the axis inverted (first
    # sublabel on top) with half a row of padding at both ends. An upper
    # limit of 0.5 would clip the first row, which sits at y = 0.
    ax.set_ylim((n_sublabels - 0.5, -0.5))

    # Add gray layouts in the background every other rows
    for k in range(1, n_sublabels, 2):
        ax.axhspan(k - 0.5, k + 0.5, color='.93', zorder=-1)

    ax.set_xlabel('Sublabel calibration $-$ Mean calibration')
    ax.set_ylabel('Sublabels')
    g.legend.set_title('Network')
    save_fig(fig, out, with_type=with_type)


@pytest.mark.parametrize('net', ALL_IMAGENET_NETWORKS)
def test_imagenet_r_sublabels_clustering(net, out, n_jobs):
    """Reliability diagram where each ImageNet-R sublabel forms one group.

    ``net`` is a network class and is instantiated here. For each sublabel,
    the split ``'test_r:<sublabel>'`` is evaluated with ``__cluster_max``.
    The per-sublabel statistics are then concatenated along axis 1
    (presumably the cluster axis) and plotted together, with histograms.
    """
    model = net()
    set_latex_font()
    sublabels = sorted([
        'art',
        'cartoon',
        'graffiti',
        'embroidery',
        'graphic',
        'origami',
        'painting',
        'sculpture',
        'tattoo',
        'toy',
        'deviantart',
        'misc',
        'videogame',
        'sketch',
        'sticker',
    ])

    clustering = None
    breakout = False
    n_bins = 15
    calibrate = False
    hist = True
    min_cluster_size = 10
    ci = 'clopper'

    per_sublabel = []
    for sublabel in sublabels:
        print(sublabel)
        model.split = f'test_r:{sublabel}'
        per_sublabel.append(__cluster_max(
            net=model,
            clustering=clustering,
            breakout=breakout,
            n_bins=n_bins,
            n_jobs=n_jobs,
            calibrate=calibrate,
            verbose=0
        ))

    frac_pos, counts, mean_scores = (
        np.concatenate(parts, axis=1) for parts in zip(*per_sublabel)
    )

    fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores, ncol=1,
                                  legend_loc='upper left',
                                  xlim_margin=0.05,
                                  ylim_margin=0.05,
                                  min_cluster_size=min_cluster_size,
                                  title=None,
                                  hist=hist,
                                  ci=ci,
                                  mean_only=False,
                                  )

    # NOTE(review): model.split still holds the last sublabel's split here, so
    # str(model) in the file name reflects that split — confirm intended.
    out_kwargs = {
        'clustering': clustering,
        'breakout': breakout,
        'n_bins': n_bins,
        'ci': ci,
        'net': str(model),
        'hist': hist,
        'calibrate': calibrate,
    }

    save_fig(fig, out, order=['net', 'clustering'], **out_kwargs)


def test_accuracy(net):
    """Print the top-1 accuracy of ``net`` on its dumped predictions.

    Loads ``y_scores.pt`` and ``y_labels.pt`` from
    ``net.get_default_dirpath()``. The predicted label is the arg-max of the
    scores over dim 1; accuracy is the fraction of predictions equal to the
    labels.
    """
    dirpath = net.get_default_dirpath()
    y_scores = torch.load(join(dirpath, 'y_scores.pt'))
    y_labels = torch.load(join(dirpath, 'y_labels.pt'))

    # Computed directly instead of via torchmetrics.Accuracy(): its
    # constructor requires a `task` argument from torchmetrics 0.11 onward.
    y_pred = torch.argmax(y_scores, dim=1)
    # reshape (not broadcast) so a size mismatch raises instead of silently
    # comparing every prediction with every label.
    acc = (y_pred == y_labels.reshape(y_pred.shape)).float().mean().item()
    print(f'\n{net}\tAcc={100*acc:.1f}%')


def test_merge_imagenet_c():
    """Merge the ImageNet-C corruption folders with ``merge_dataset_sublabels``."""
    corruptions = [
        'snow',
        'spatter',
        'shot_noise',
        'speckle_noise',
        'zoom_blur',
        'saturate',
        'pixelate',
        'motion_blur',
        'jpeg_compression',
        'impulse_noise',
        'gaussian_blur',
        'gaussian_noise',
        'glass_blur',
        'frost',
        'fog',
        'elastic_transform',
        'defocus_blur',
        'brightness',
        'contrast',
    ]

    merge_dataset_sublabels('datasets/imagenet-c/', corruptions, verbose=1)


@pytest.mark.parametrize('severity', range(1, 6))
def test_merge_imagenet_c_no_rep(severity):
    """Build an ImageNet-C validation set with one corruption per image.

    Each image of the ImageNet ``val`` split is assigned one corruption. The
    assignment is balanced: counts differ by at most one, and the first
    ``n_samples % n_corruptions`` corruptions get the extra image. It is
    shuffled with ``RandomState(0)``. The corrupted image at the given
    ``severity`` is then copied to
    ``datasets/imagenet-c/_merged_no_rep/<severity>/``, keeping the class
    sub-folder and prefixing the file name with the corruption name.
    """
    corruptions = [
        'snow',
        'spatter',
        'shot_noise',
        'speckle_noise',
        'zoom_blur',
        'saturate',
        'pixelate',
        'motion_blur',
        'jpeg_compression',
        'impulse_noise',
        'gaussian_blur',
        'gaussian_noise',
        'glass_blur',
        'frost',
        'fog',
        'elastic_transform',
        'defocus_blur',
        'brightness',
        'contrast',
    ]
    n_corruptions = len(corruptions)

    ds = IMAGENET_AlexNet(split='val').get_dataset()
    n_samples = len(ds)

    base, extra = divmod(n_samples, n_corruptions)
    n_samples_per_corruption = np.full(n_corruptions, base, dtype=int)
    n_samples_per_corruption[:extra] += 1

    assert np.all(n_samples_per_corruption > 0)
    assert np.sum(n_samples_per_corruption) == n_samples

    corruption_idx = np.repeat(np.arange(n_corruptions), n_samples_per_corruption)
    assert len(corruption_idx) == n_samples

    corruption_idx = np.random.RandomState(0).permutation(corruption_idx)

    uniques, counts = np.unique(corruption_idx, return_counts=True)
    assert np.sum(counts) == n_samples
    assert np.array_equal(uniques, np.arange(n_corruptions))

    val_root = join(ds.root, 'val')
    out_rootpath = f'datasets/imagenet-c/_merged_no_rep/{severity}/'

    for i, (imgpath, _) in enumerate(ds.imgs):
        corruption_name = corruptions[corruption_idx[i]]
        relpath = os.path.relpath(imgpath, val_root)
        subfolders, imgname = os.path.split(relpath)

        in_imgpath = join(f'datasets/imagenet-c/{corruption_name}/{severity}/', relpath)
        out_dirpath = join(out_rootpath, subfolders)
        os.makedirs(out_dirpath, exist_ok=True)
        out_imgpath = join(out_dirpath, f'{corruption_name}_{imgname}')

        print(f'{i+1}/{n_samples}: Copying {in_imgpath} to {out_imgpath}...')
        shutil.copyfile(in_imgpath, out_imgpath)


# Relative paths of the MIMIC length-of-stay prediction experiment folders.
MIMIC_XP_PATHS = [
    'mimic/2022-04-15-17-24-15_los_prediction_horizon7',
    'mimic/2022-04-21-16-10-54_los_prediction_quantile10_discharged',
    'mimic/2022-04-21-16-10-54_los_prediction_horizon3_discharged',
    'mimic/2022-04-21-16-10-54_los_prediction_quantile4_discharged',
    'mimic/2022-04-21-16-10-54_los_prediction_horizon3_discharged_or_deceased',
    'mimic/2022-04-21-16-10-54_los_prediction_quantile4_discharged_or_deceased',
]


def read_los_prediction_results(path2experience: Path):
    """Load every run stored in a length-of-stay prediction experiment folder.

    A run is discovered through its log file, i.e. any entry whose path ends
    with ``'log.csv'``. The run id is the part of that file's stem before the
    first underscore. The run's artifacts are then read from
    ``{run_id}_log.csv``, ``{run_id}_X.npy``, ``{run_id}_X_imputed.npy``,
    ``{run_id}_y.npy``, ``{run_id}_y_pred.npy`` and
    ``{run_id}_stratification_features.csv`` inside ``path2experience``.

    Parameters
    ----------
    path2experience : Path
        Folder containing the run artifacts.

    Returns
    -------
    list of dict
        One dict per run with keys ``'log'`` (DataFrame), ``'X'``,
        ``'X_imputed'``, ``'y'``, ``'y_pred'`` (arrays loaded with ``np.load``)
        and ``'stratification_features'`` (DataFrame). The list follows the
        order of ``Path.iterdir``.
    """
    runs = []
    for entry in path2experience.iterdir():
        if not str(entry).endswith('log.csv'):
            continue

        prefix = entry.stem.partition("_")[0] + "_"
        runs.append({
            "log": pd.read_csv(path2experience / f"{prefix}log.csv"),
            "X": np.load(path2experience / f"{prefix}X.npy"),
            "X_imputed": np.load(path2experience / f"{prefix}X_imputed.npy"),
            "y": np.load(path2experience / f"{prefix}y.npy"),
            "y_pred": np.load(path2experience / f"{prefix}y_pred.npy"),
            "stratification_features": pd.read_csv(
                path2experience / f"{prefix}stratification_features.csv"
            ),
        })
    return runs


def read_los_prediction_results2(path2experience: Path):
    """Load the runs of an experiment folder, keyed by estimator and split seed.

    A run is discovered through its log file: any entry whose *file name* ends
    with ``'log.csv'``. Only the first row (index 0) of that CSV is kept, as a
    dict. The other artifacts of the run are read from ``{run_str}_X.npy``,
    ``{run_str}_X_imputed.npy``, ``{run_str}_y.npy``, ``{run_str}_y_pred.npy``
    and ``{run_str}_stratification_features.csv``. Here ``run_str`` is
    ``estimator={estimator_name}__split_rs={split_rs}__run_id={run_id}`` when
    the log has a non-missing ``split_rs``, and plain ``{run_id}`` otherwise.

    Parameters
    ----------
    path2experience : Path
        Folder containing the run artifacts.

    Returns
    -------
    dict
        Maps ``'estimator={estimator_name}__split_rs={split_rs}'`` (``split_rs``
        renders as ``nan`` when missing) to a dict with keys ``'log'`` (dict),
        ``'X'``, ``'X_imputed'``, ``'y'``, ``'y_pred'`` (arrays loaded with
        ``np.load``) and ``'stratification_features'`` (DataFrame).
        Entries are visited in sorted order. When several runs share a key, the
        last one in that order deterministically wins.
    """
    experience_dict = {}
    # Sorted because Path.iterdir yields entries in arbitrary order, which made
    # the winner of a key collision filesystem-dependent.
    for ff in sorted(path2experience.iterdir()):
        # Match on the file name only. A regex search for "log.csv" over the
        # full path also scanned parent folders and treated '.' as a wildcard,
        # so unrelated files could be mistaken for logs.
        if not ff.name.endswith("log.csv"):
            continue

        model_log = pd.read_csv(ff).to_dict(orient="index")[0]
        model_name = model_log.get("estimator_name")
        run_id = model_log.get("run_id")
        split_rs = model_log.get("split_rs", np.nan)

        # pd.isna also accepts non-float values, where np.isnan would raise.
        if pd.isna(split_rs):
            run_str = run_id
        else:
            run_str = (
                f"estimator={model_name}__split_rs={split_rs}__run_id={run_id}"
            )
        key_run = f"estimator={model_name}__split_rs={split_rs}"
        experience_dict[key_run] = {
            "log": model_log,
            "X": np.load(path2experience / f"{run_str}_X.npy"),
            "X_imputed": np.load(path2experience / f"{run_str}_X_imputed.npy"),
            "y": np.load(path2experience / f"{run_str}_y.npy"),
            "y_pred": np.load(path2experience / f"{run_str}_y_pred.npy"),
            "stratification_features": pd.read_csv(
                path2experience / f"{run_str}_stratification_features.csv"
            ),
        }

    return experience_dict


def test_mimic(out, n_jobs):
    """Plot per-class calibration curves of MIMIC length-of-stay models.

    For every selected experiment folder under ``mimic_dist`` and every selected
    estimator key (as produced by ``read_los_prediction_results2``), samples are
    grouped into ``n`` quantile bins of ``n_codes_missing``
    (``clustering='na_<n>'``). ``cluster_evaluate_marginals`` is then run for
    each class (only class 1 when scores have two columns). The resulting
    fraction-of-positives vs. confidence-score figure is written with
    ``save_fig``.

    ``out`` and ``n_jobs`` are fixtures not defined in this section (presumably
    provided by a conftest).
    """
    dirpath = 'mimic_dist'

    calibrate = None  # or e.g. 'isotonic'
    test_size = 0.5 if calibrate is not None else None
    n_bins = 15
    max_clusters_bin = 2
    clustering = 'na_2'
    hist = True
    ci = 'clopper'
    min_cluster_size = 10
    capsize = 3.5
    cluster_size = 30
    vary_cluster_size = False
    absolute_size_scale = (10, None)
    plot_cal_hist = False
    figsize = (2.2, 2.2)
    legend_n_sizes = 1

    plt.rc('legend', title_fontsize=10)
    plt.rc('legend', fontsize=10)
    plt.rc('legend', handletextpad=0.01)
    plt.rc('legend', columnspacing=0.02)
    plt.rc('legend', borderpad=0.3)
    plt.rc('legend', borderaxespad=0.2)
    plt.rc('legend', handlelength=1.2)
    plt.rc('legend', labelspacing=0.1)
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)
    plt.rc('axes', labelsize=13)

    # Number of missingness groups, e.g. 'na_2' -> 2. The digits must be
    # captured as a whole: a repeated one-character group such as
    # r'na_([A-z,0-9])*' keeps only its last repetition ('na_12' -> 2).
    match = re.fullmatch(r'na_(\d+)', clustering)
    if match is None:
        raise ValueError(f'Unknown clustering {clustering}.')
    n_groups = int(match.group(1))

    # next() raises StopIteration if dirpath does not exist.
    root = next(os.walk(dirpath))[0]

    folders = [
        # '_los_prediction_horizon7_discharged_shift_is_medicare_or_medicaid',
        # '_los_prediction_horizon7_discharged_shift_is_first_unit_surgical',
        # '_los_prediction_quantile9_discharged',
        # '_los_prediction_quantile4_discharged',
        '_los_prediction_horizon7_discharged',
        # '_los_prediction_horizon3_discharged',
        # '_los_prediction_quantile2_discharged',
    ]

    # Other keys: estimator in {hist_gradient_boosting_classifier,
    # random_forest_classifier, logistic_regression, lstm} x split_rs in 0..4.
    estimators = [
        'estimator=lstm__split_rs=0',
    ]

    for folder in folders:
        xp_path = join(root, folder)
        print(xp_path)
        xps = read_los_prediction_results2(Path(xp_path))

        for estimator in estimators:
            xp = xps[estimator]
            X = xp['X_imputed']
            y_labels = xp['y']
            y_scores = xp['y_pred']
            n_missing = xp['stratification_features']['n_codes_missing']

            print(X.shape)
            print(y_labels.shape)
            print(y_scores.shape)

            # Quantile bins of the number of missing codes. The outer edges are
            # padded by 1 so every sample gets a label in 1..n_groups.
            qs = np.linspace(0, 1, n_groups + 1)[1:-1]
            edges = ([np.min(n_missing) - 1]
                     + list(np.quantile(n_missing, qs))
                     + [np.max(n_missing) + 1])
            groups = np.digitize(n_missing, bins=edges)

            if y_scores.shape[1] == 2:
                classes = [1]  # binary task: plot the positive class only
            else:
                classes = np.arange(y_scores.shape[1])

            for k in classes:
                if test_size is not None:
                    splitter = bin_train_test_split(y_scores[:, k], test_size=test_size,
                                                    n_splits=1, bins=n_bins, random_state=0)
                    _, test_idx = next(splitter)
                else:
                    test_idx = None

                # Calibrate into a separate variable. Overwriting y_scores would
                # make the following classes split on, and re-calibrate,
                # already calibrated scores.
                k_scores = y_scores
                if calibrate is not None:
                    k_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                                   test_size=test_idx, max_calibration=False)

                frac_pos, counts, mean_scores, *_ = cluster_evaluate_marginals(
                    X, y_labels, k_scores,
                    positive_class=k,
                    bins=n_bins,
                    clustering=groups,
                    test_size=test_idx,
                    min_samples_leaf=None,
                    max_clusters_bin=max_clusters_bin,
                    verbose=2,
                    n_jobs=n_jobs,
                )

                fig = plot_frac_pos_vs_scores(
                    frac_pos,
                    counts,
                    mean_scores,
                    y_scores=None,
                    y_labels=None,
                    ncol=1,
                    legend_loc='upper left',
                    bbox_to_anchor=(0, 1),
                    title=None,
                    xlim_margin=0.05,
                    ylim_margin=0.05,
                    min_cluster_size=min_cluster_size,
                    hist=hist,
                    ci=ci,
                    legend_cluster_sizes=True,
                    vary_cluster_size=vary_cluster_size,
                    capsize=capsize,
                    xlabel=f'Confidence score class {k}',
                    ylabel='Fraction of positives (%)',
                    cluster_size=cluster_size,
                    absolute_size_scale=absolute_size_scale,
                    plot_cal_hist=plot_cal_hist,
                    figsize=figsize,
                    legend_n_sizes=legend_n_sizes,
                    legend_sizes_only=True,
                    legend_min_max=False,
                    plot_first_last_bins=False,
                    grid_space=0,
                    legend_title='Sizes',
                    plot_cluster_id=True,
                )

                ax = fig.axes[0]
                ax.set_xlim((0, 1))
                ax.set_ylim((0, 1))
                ax.set_xticks([0, 0.5, 1])
                ax.set_xticklabels(['0', '0.5', '1'])
                ax.set_yticks([0, 0.25, 0.5, 0.75, 1])
                # Fractions of positives are displayed as percentages.
                ax.set_yticklabels(['0', '25', '50', '75', '100'])

                fig.axes[1].get_xaxis().set_visible(False)
                fig.axes[2].get_yaxis().set_visible(False)

                out_kwargs = {
                    'xp': folder,
                    'est': estimator,
                    'cal': calibrate,
                    'pos': k,
                    'clu': clustering,
                }
                save_fig(fig, out, order=['xp', 'est', 'cal'], **out_kwargs)
                # One figure is created per (folder, estimator, class): release it.
                plt.close(fig)



import itertools

# Experiment folders (under 'mimic_dist') crossed with ``estimators`` by the
# ``fold_est`` parametrization of ``test_mimic_one``.
folders = [
    '_los_prediction_horizon7_discharged_shift_is_medicare_or_medicaid',
    '_los_prediction_horizon7_discharged_shift_is_first_unit_surgical',
    '_los_prediction_horizon7_discharged',
    '_los_prediction_horizon3_discharged',
]

# LSTM runs for split seeds 0 to 4. Other estimator keys previously listed
# here (commented out) were hist_gradient_boosting_classifier,
# random_forest_classifier and logistic_regression, each with split_rs 0..4.
estimators = [f'estimator=lstm__split_rs={split_rs}' for split_rs in range(5)]

# y-axis label for each cost-error metric supported by test_mimic_one.
_MIMIC_METRIC_YLABELS = {
    'ind_mse': r'$MSE(C^{\star}, \hat C)_{bin}$',
    'tot_abs': r'$|C^{\star}_{bin} - \hat C_{bin}|$',
    'tot_abs_norm': r'$|C^{\star}_{bin} - \hat C_{bin}|$ normalized',
    'tot_mape': r'$MAPE(C^{\star}_{bin}, \hat C_{bin})$ (%)',
}


def _set_mimic_metric_ylabel(ax, metric):
    """Set the y label of ``ax`` for a cost-error ``metric``.

    Raises
    ------
    ValueError
        If ``metric`` is not a key of ``_MIMIC_METRIC_YLABELS``.
    """
    try:
        ylabel = _MIMIC_METRIC_YLABELS[metric]
    except KeyError:
        raise ValueError(f'Unknown metric {metric}.') from None
    ax.set_ylabel(ylabel)


def _mimic_cost_binary(y_scores_pos, y_labels, is_na_pos, cost_pos=1000, cost_neg=10,
                       na_mult=2, pos_label=1):
    """Return (predicted, oracle) per-sample costs with a missingness multiplier.

    A sample costs ``cost_pos`` if positive and ``cost_neg`` otherwise. Both
    costs are multiplied by ``na_mult`` where ``is_na_pos`` is 1. The predicted
    cost weights the two outcomes by the positive-class score. The oracle cost
    uses the true label (``y_labels == pos_label``).
    """
    assert y_scores_pos.ndim == 1
    assert is_na_pos.ndim == 1
    y_labels = np.array(y_labels == pos_label, dtype=int)
    na_factor = 1 + (na_mult - 1)*is_na_pos

    predicted_costs = y_scores_pos*cost_pos*na_factor \
        + (1 - y_scores_pos)*cost_neg*na_factor
    oracle_costs = y_labels*cost_pos*na_factor \
        + (1 - y_labels)*cost_neg*na_factor

    return predicted_costs, oracle_costs


def _mimic_cost_binary2(y_scores_pos, y_labels, is_na_pos,
                        cost_pos_na_pos=1,
                        cost_pos_na_neg=1,
                        cost_neg_na_pos=1,
                        cost_neg_na_neg=1,
                        pos_label=1):
    """Return (predicted, oracle) per-sample costs from a 2x2 cost table.

    ``cost_{label}_na_{group}`` is the cost of a sample whose label is positive
    (``pos``) or negative (``neg``) and whose ``is_na_pos`` is 1 (``pos``) or
    0 (``neg``). The predicted cost weights the two label outcomes by the
    positive-class score. The oracle cost uses the true label.
    """
    assert y_scores_pos.ndim == 1
    assert is_na_pos.ndim == 1
    y_labels = np.array(y_labels == pos_label, dtype=int)
    cost_if_pos = cost_pos_na_pos*is_na_pos + cost_pos_na_neg*(1 - is_na_pos)
    cost_if_neg = cost_neg_na_pos*is_na_pos + cost_neg_na_neg*(1 - is_na_pos)

    predicted_costs = y_scores_pos*cost_if_pos + (1 - y_scores_pos)*cost_if_neg
    oracle_costs = y_labels*cost_if_pos + (1 - y_labels)*cost_if_neg

    return predicted_costs, oracle_costs


@pytest.mark.parametrize('metric', [
    # 'ind_mse',
    'tot_abs',
    # 'tot_abs_norm',
    # 'tot_mape',
])
@pytest.mark.parametrize('fold_est', itertools.product(folders, estimators))
def test_mimic_one(metric, fold_est, out, n_jobs):
    """Compare the per-bin grouping-loss estimate with a decision-cost error.

    For one (experiment folder, estimator) pair, positive-class scores are
    isotonic-calibrated and evaluated per score bin, clustered by missingness
    group ('na_2' quantile split of ``n_codes_missing``). For each bin, the
    grouping-loss lower bound is paired with the error (``metric``) between
    predicted and oracle decision costs.

    Points are accumulated across parametrizations in ``pytest.shared_fig`` and
    ``pytest.mimic_shared_df``. The accumulated dataframe is re-plotted with a
    LOWESS fit at the end of every call. Runs whose estimator key is absent from
    the folder are skipped.
    """
    from ._utils import grouping_loss_lower_bound

    dirpath = 'mimic_dist'
    folder, estimator = fold_est

    n_bins = 15
    max_clusters_bin = 2
    clustering = 'na_2'
    test_size = None
    hist = True
    ci = 'clopper'
    capsize = 3.5
    cluster_size = 30
    vary_cluster_size = False
    absolute_size_scale = (10, None)
    plot_cal_hist = False
    figsize = (2.2, 2.2)
    legend_n_sizes = 1

    xps = read_los_prediction_results2(Path(join(dirpath, folder)))
    if estimator not in xps:
        pytest.skip(f'estimator {estimator} not in xps {xps.keys()} for folder {folder}.')

    xp = xps[estimator]
    X = xp['X_imputed']
    y_labels = xp['y']
    raw_scores = xp['y_pred']
    n_missing = xp['stratification_features']['n_codes_missing']

    # Number of missingness groups, e.g. 'na_2' -> 2. The digits must be
    # captured as a whole: a repeated one-character group such as
    # r'na_([A-z,0-9])*' keeps only its last repetition ('na_12' -> 2).
    match = re.fullmatch(r'na_(\d+)', clustering)
    if match is None:
        raise ValueError(f'Unknown clustering {clustering}.')
    n_groups = int(match.group(1))

    # Quantile bins of the number of missing codes. The outer edges are padded
    # by 1 so every sample gets a label in 1..n_groups.
    qs = np.linspace(0, 1, n_groups + 1)[1:-1]
    edges = ([np.min(n_missing) - 1]
             + list(np.quantile(n_missing, qs))
             + [np.max(n_missing) + 1])
    groups = np.digitize(n_missing, bins=edges)

    set_latex_font()

    def one(calibrate, debiased=True, min_cluster_size=5):
        """Run the per-bin grouping-loss vs. cost-error analysis once.

        ``calibrate`` is the calibration method passed to ``calibrate_scores``
        (``None`` keeps the raw scores). ``debiased`` selects the debiased
        grouping-loss estimate. ``min_cluster_size`` is the minimal cluster
        count for a bin to be kept.
        Returns ``(ece_l2, gl, error_bin)``, all per score bin.
        """
        if calibrate is not None:
            y_scores, _ = calibrate_scores(raw_scores, y_labels, method=calibrate,
                                           test_size=None, max_calibration=False)
        else:
            y_scores = raw_scores

        k = 1  # positive class

        frac_pos, counts, mean_scores, *_ = cluster_evaluate_marginals(
            X, y_labels, y_scores,
            positive_class=k,
            bins=n_bins,
            clustering=groups,
            test_size=test_size,
            min_samples_leaf=None,
            max_clusters_bin=max_clusters_bin,
            verbose=2,
            n_jobs=n_jobs,
        )

        plt.rc('legend', title_fontsize=10)
        plt.rc('legend', fontsize=10)
        plt.rc('legend', handletextpad=0.01)
        plt.rc('legend', columnspacing=0.02)
        plt.rc('legend', borderpad=0.3)
        plt.rc('legend', borderaxespad=0.2)
        plt.rc('legend', handlelength=1.2)
        plt.rc('legend', labelspacing=0.1)
        plt.rc('xtick', labelsize=10)
        plt.rc('ytick', labelsize=10)
        plt.rc('axes', labelsize=13)
        fig = plot_frac_pos_vs_scores(
            frac_pos,
            counts,
            mean_scores,
            y_scores=None,
            y_labels=None,
            ncol=1,
            legend_loc='upper left',
            bbox_to_anchor=(0, 1),
            title=None,
            xlim_margin=0,
            ylim_margin=0,
            min_cluster_size=14,
            hist=hist,
            ci=ci,
            legend_cluster_sizes=True,
            vary_cluster_size=vary_cluster_size,
            capsize=capsize,
            xlabel=f'Confidence score class {k}',
            ylabel='Fraction of positives (%)',
            cluster_size=cluster_size,
            absolute_size_scale=absolute_size_scale,
            plot_cal_hist=plot_cal_hist,
            figsize=figsize,
            legend_n_sizes=legend_n_sizes,
            legend_sizes_only=True,
            legend_min_max=False,
            plot_first_last_bins=False,
            grid_space=0,
            legend_title='Sizes',
            plot_cluster_id=True,
        )
        save_fig(fig, out, n='curve', cal=calibrate, folder=folder, est=estimator, order=['n', 'cal'])
        plt.close(fig)

        # Undo the legend/tick rc tweaks above before the scatter plots.
        import matplotlib as mpl
        mpl.rcParams.update(mpl.rcParamsDefault)
        set_latex_font()

        prob_bins, mean_bins = calibration_curve(frac_pos, counts, mean_scores, remove_empty=False)

        gl_bin_debiased, _ = grouping_loss_lower_bound(
            frac_pos, counts, reduce_bin=False, debiased=True, return_bias=True
        )
        gl_bin = grouping_loss_lower_bound(
            frac_pos, counts, reduce_bin=False, debiased=False
        )

        print('biased', gl_bin)
        print('debiased', gl_bin_debiased)

        y_scores_pos = y_scores[:, 1]
        # With two missingness groups, labels 1/2 become 0 (fewer missing
        # codes) / 1 (more missing codes).
        is_na_pos = groups - 1
        assert np.array_equal(np.unique(y_labels), [0, 1])
        assert np.array_equal(np.unique(is_na_pos), [0, 1])

        # Whole-dataset costs, printed for inspection only.
        total_predicted, total_oracle = _mimic_cost_binary(
            y_scores_pos, y_labels, is_na_pos, cost_pos=10, cost_neg=20, na_mult=100)
        print(total_predicted)
        print(total_oracle)
        print(np.mean(total_predicted))
        print(np.mean(total_oracle))

        # Equal-width score bins on [0, 1]; a score of exactly 1 goes to the last bin.
        score_edges = np.linspace(0, 1, n_bins + 1)
        y_bins = np.clip(np.digitize(y_scores_pos, bins=score_edges) - 1,
                         a_min=None, a_max=n_bins - 1)

        predicted_costs_bin = np.full(n_bins, np.nan)
        oracle_costs_bin = np.full(n_bins, np.nan)
        mse_costs_bin = np.full(n_bins, np.nan)
        for i in range(n_bins):
            in_bin = y_bins == i
            predicted_costs, oracle_costs = _mimic_cost_binary2(
                y_scores_pos[in_bin], y_labels[in_bin], is_na_pos[in_bin],
                cost_pos_na_pos=100,
                cost_pos_na_neg=100,
                cost_neg_na_pos=10,
                cost_neg_na_neg=1,
            )
            predicted_costs_bin[i] = np.mean(predicted_costs)
            oracle_costs_bin[i] = np.mean(oracle_costs)
            mse_costs_bin[i] = np.mean(np.square(predicted_costs - oracle_costs))

        print(predicted_costs_bin)
        print(oracle_costs_bin)
        print(predicted_costs_bin - oracle_costs_bin)

        # gl_bin is the biased estimate, gl_bin_debiased the debiased one.
        gl = gl_bin_debiased if debiased else gl_bin

        # Rescale each bin by total_count / bin_count. NOTE(review): presumably
        # undoes a bin-weighting applied by grouping_loss_lower_bound; confirm.
        gl = gl*np.divide(np.sum(counts), np.sum(counts, axis=1))

        big_enough = np.any(counts >= min_cluster_size, axis=1)
        idx_pos = gl >= 0
        idx = np.logical_and(idx_pos, big_enough)

        if metric == 'ind_mse':
            error_bin = mse_costs_bin
        elif metric == 'tot_abs':
            error_bin = np.absolute(predicted_costs_bin - oracle_costs_bin)
        elif metric == 'tot_abs_norm':
            error_bin = np.absolute(predicted_costs_bin - oracle_costs_bin)
            error_bin /= np.nanmax(error_bin[idx_pos])
        elif metric == 'tot_mape':
            error_bin = 100*np.absolute(np.divide(oracle_costs_bin - predicted_costs_bin, oracle_costs_bin))
        else:
            raise ValueError(f'Unknown metric {metric}.')

        print(error_bin)

        # Figure and dataframe shared across parametrizations; a new figure
        # resets the accumulated dataframes.
        shared_fig = getattr(pytest, 'shared_fig', None)
        if shared_fig is None:
            shared_fig, _ = plt.subplots(1, 1, figsize=(3, 2))
            pytest.shared_fig = shared_fig
            pytest.mimic_shared_df = []
        shared_dfs = getattr(pytest, 'mimic_shared_df', None)
        if shared_dfs is None:
            shared_dfs = []
            pytest.mimic_shared_df = shared_dfs
        shared_ax = shared_fig.axes[0]

        ece_l2 = np.square(prob_bins - mean_bins)

        color = None
        mappable = None
        if calibrate is None:
            from matplotlib.colors import LogNorm
            import matplotlib.cm as cm

            vmin = np.nanmin(ece_l2)
            vmax = np.nanmax(ece_l2)
            print('vmax', vmax)
            mappable = cm.ScalarMappable(norm=LogNorm(vmin, vmax), cmap='RdBu_r')
            colors = np.array([mappable.to_rgba(v) for v in ece_l2])
            color = colors[idx]
            print(colors)
        print(ece_l2)

        n_points = int(np.sum(idx))
        df = pd.DataFrame({
            'gl': gl[idx],
            'error_bin': error_bin[idx],
            'metric': np.full(n_points, metric),
            'estimator': np.full(n_points, estimator),
            'task': np.full(n_points, folder),
            'task_estimator': np.full(n_points, f'{folder}_{estimator}'),
        })
        shared_dfs.append(df)

        def plot_one(ax, colorbar=True):
            """Scatter the kept bins' (grouping loss, cost error) on ``ax``."""
            ax.scatter(gl[idx], error_bin[idx], color=color if colorbar else None,
                       edgecolor='black', lw=0.5)
            if calibrate is None and colorbar:
                cb = plt.colorbar(mappable, ax=ax)
                cb.ax.set_title(r'$\ell^2$-$ECE$')
            ax.set_xlabel(r'$\widehat \mathcal{L}_{GL,bin}$')
            _set_mimic_metric_ylabel(ax, metric)

        plot_one(shared_ax, colorbar=False)
        save_fig(shared_fig, out, n='shared_error_gl', cal='')

        return ece_l2, gl, error_bin

    calibrate = 'isotonic'
    one(calibrate=calibrate, debiased=True)

    df = pd.concat(pytest.mimic_shared_df, axis=0, ignore_index=True)
    print(df)

    fig, ax = plt.subplots(1, 1, figsize=(3, 2))
    sns.scatterplot(data=df, x='gl', y='error_bin', ax=ax, hue='task_estimator', legend='full')
    ax.set_xlabel(r'$\widehat \mathcal{L}_{GL,bin}$')
    _set_mimic_metric_ylabel(ax, metric)

    import statsmodels.api as sm
    trend = sm.nonparametric.lowess(df['error_bin'], df['gl'], frac=0.75)
    ax.plot(trend[:, 0], trend[:, 1], lw=1.5, color='black')
    save_fig(fig, out, n='shared_error_gl2', cal=calibrate)
    plt.close(fig)

@pytest.fixture
def mimic_xps(request):
    """Indirect fixture returning the runs of the folder given in ``request.param``.

    ``request.param`` is a path (one of ``MIMIC_XP_PATHS`` via indirect
    parametrization). The result is the list of run dicts produced by
    ``read_los_prediction_results``.
    """
    return read_los_prediction_results(Path(request.param))


# Keyword -> group id used to coarsen the ``ethnicity`` stratification column
# for the ``meta_ethnicity_<variant>`` clusterings. A label gets the group of
# the FIRST keyword (in dict order) it contains; labels matching none get 0.
_META_ETHNICITY_MAPS = {
    'BW': {
        'BLACK': 1,
        'ASIAN': 0,
        'WHITE': 2,
        'HISPANIC': 0,
        'AMERICAN': 0,
    },
    'B': {
        'WHITE': 0,
        'ASIAN': 0,
        'BLACK': 1,
        'HISPANIC': 0,
        'AMERICAN': 0,
    },
    'BWAH': {
        'BLACK': 1,
        'WHITE': 2,
        'ASIAN': 3,
        'HISPANIC': 4,
    },
}


def _quantile_group_ids(values, n_groups):
    """Return a 1-based quantile-group id in ``1..n_groups`` for each value.

    Interior bin edges are the ``k / n_groups`` quantiles of ``values``
    (``k = 1..n_groups-1``); the outer edges are ``min - 1`` and ``max + 1`` so
    every value lies strictly inside them. Tied quantiles may leave some ids
    unused.
    """
    qs = np.linspace(0, 1, n_groups + 1)[1:-1]
    bins = [np.min(values) - 1] + list(np.quantile(values, qs)) + [np.max(values) + 1]
    return np.digitize(values, bins=bins)


def _meta_ethnicity_groups(ethnicities, variant, clustering):
    """Map each ethnicity label to its meta-ethnicity group id.

    Parameters
    ----------
    ethnicities : iterable of str
        Distinct ethnicity labels.
    variant : str
        Key of ``_META_ETHNICITY_MAPS`` (e.g. ``'BW'``).
    clustering : str
        Full clustering name, only used in the error message.

    Returns
    -------
    dict
        ``{label: group_id}``. The first keyword contained in a label decides
        its group; labels with no keyword get 0. First-match matters for labels
        containing several keywords: e.g. a label like 'BLACK/AFRICAN AMERICAN'
        must not be reset to 0 by the later 'AMERICAN' key.

    Raises
    ------
    ValueError
        If ``variant`` is unknown.
    """
    try:
        mapping = _META_ETHNICITY_MAPS[variant]
    except KeyError:
        raise ValueError(f'Unkown meta_ethnicity {clustering}.') from None

    return {
        eth: next((group for key, group in mapping.items() if key in eth), 0)
        for eth in ethnicities
    }


def _resolve_clustering(clustering, strat, X):
    """Turn a clustering name into the ``clustering`` argument of cluster_max/marginals.

    Returns per-sample group ids derived from ``strat`` for the
    stratification-based names, the name itself for ``'kmeans'`` /
    ``'decision_tree'``, or ``None`` when this experiment must be skipped
    (NaNs in ``X`` for the feature-based methods, or an unknown column).
    """
    if clustering in strat:
        return strat[clustering]

    if clustering.startswith('age_'):
        # 'age_<n>': n quantile groups of age.
        return _quantile_group_ids(strat['age'], int(clustering[len('age_'):]))

    if clustering.startswith('na_'):
        # 'na_<n>': n quantile groups of the number of missing codes.
        return _quantile_group_ids(strat['n_codes_missing'], int(clustering[len('na_'):]))

    if clustering.startswith('age'):
        # 'age<t>': binary split at age threshold t.
        return (strat['age'] >= int(clustering[len('age'):])).astype(int)

    if clustering == 'ethnicity':
        return strat['ethnicity']

    if clustering.startswith('meta_ethnicity_'):
        ethnicities = strat['ethnicity']
        replace = _meta_ethnicity_groups(np.unique(ethnicities),
                                         clustering[len('meta_ethnicity_'):],
                                         clustering)
        return ethnicities.replace(replace)

    if clustering in ['kmeans', 'decision_tree']:
        # Feature-based methods cannot handle missing values.
        return None if np.isnan(X).any() else clustering

    print(f'{clustering} not in strat.')
    return None


@pytest.mark.parametrize('mimic_xps', MIMIC_XP_PATHS, indirect=True)
@pytest.mark.parametrize('clustering', [
    # 'kmeans', 'decision_tree', 'gender', 'age_2', 'age_4', 'age70',
    'gender', 'age_2', 'age_4', 'age70',
    'ethnicity', 'meta_ethnicity_BW', 'meta_ethnicity_B',
    'meta_ethnicity_BWAH', 'admission_type', 'first_careunit', 'last_careunit',
    'first_wardid', 'last_wardid', 'admission_location', 'discharge_location',
    'insurance', 'icd9_chapter_title', 'main_diagnosis_icd9_chapter_title',
    'na_2', 'na_3', 'na_4', 'na_5', 'deceased_status'])
@pytest.mark.parametrize('calibration_type', ['max', 'marginals'])
@pytest.mark.parametrize('max_clusters_bin', [2])#, 6, 12])
@pytest.mark.parametrize('test_size', [0.5, 0.75, 0.9])
# @pytest.mark.parametrize('test_size', [None, 0.5, 0.75, 0.9])
def test_mimic_hetero_plot(mimic_xps, clustering, calibration_type, max_clusters_bin, test_size, out, n_jobs):
    """Run cluster_max / cluster_marginals on each MIMIC experiment for one clustering.

    For every experiment in ``mimic_xps`` the clustering name is resolved into
    groups (see ``_resolve_clustering``). The returned metrics DataFrame is
    appended to ``pytest.mimic_shared_df``. All accumulated frames are then
    written to ``<out>/metrics.csv``.
    """
    if clustering != 'decision_tree' and max_clusters_bin != 2:
        pytest.skip(f'Test only one max_clusters_bin value when clustering is '
                    f'{clustering}!="decision_tree"')

    if clustering != 'kmeans' and test_size is None:
        pytest.skip('Test_size=None only for kmeans')

    for xp in mimic_xps:
        log = xp['log']
        X = xp['X_imputed']
        y_scores = xp['y_pred']
        y_labels = xp['y']
        strat = xp['stratification_features']
        estimator_name = log.iloc[0]['estimator_name']
        label_type = log.iloc[0]['label_type']

        # NOTE(review): `eval` executes whatever text the results log holds;
        # only acceptable for trusted, locally produced logs.
        horizon = np.array(eval(str(log.iloc[0]['horizon'])))
        if horizon.ndim == 1:
            horizon = f'quantile{len(horizon)}'
        elif horizon.ndim == 0:
            horizon = f'horizon{horizon}'
        else:
            raise ValueError(f'Unknown horizon {horizon}.')

        if len(X) != len(strat):
            print(f'Skip {estimator_name} because shape mismatch.')
            continue

        _clustering = _resolve_clustering(clustering, strat, X)
        if _clustering is None:
            continue

        plot_cluster_id = True
        estimator_name = estimator_name[:10]  # keep output file names short

        extra_out_kwargs = {
            'est': estimator_name,
            'lab': label_type,
            'hor': horizon,
            'clu': clustering,
            'cid': plot_cluster_id,
        }
        order = ['hor', 'lab', 'est', 'clu', 'cal']

        if calibration_type == 'max':
            extra_out_kwargs['cal'] = 'max'
            df = cluster_max(X, y_labels, y_scores, name=None, out_dir=out, n_jobs=n_jobs,
                             clustering=_clustering, test_size=test_size,
                             max_clusters_bin=max_clusters_bin,
                             extra_out_kwargs=extra_out_kwargs, order=order,
                             plot_cluster_id=plot_cluster_id,)

        elif calibration_type == 'marginals':
            extra_out_kwargs['cal'] = 'mar'
            order.append('class')
            df = cluster_marginals(X, y_labels, y_scores, name=None, out_dir=out, n_jobs=n_jobs,
                                   clustering=_clustering, test_size=test_size,
                                   max_clusters_bin=max_clusters_bin,
                                   extra_out_kwargs=extra_out_kwargs, order=order,
                                   plot_cluster_id=plot_cluster_id,)

        else:
            raise ValueError(f'Unknown calibration_type {calibration_type}.')

        # Accumulate across parametrized runs (shared through the pytest module).
        if getattr(pytest, 'mimic_shared_df', None) is None:
            pytest.mimic_shared_df = [df]
        else:
            pytest.mimic_shared_df.append(df)

    shared_dfs = getattr(pytest, 'mimic_shared_df', None)
    if shared_dfs:  # nothing to write when every experiment was skipped
        df = pd.concat(shared_dfs, axis=0, ignore_index=True)
        df.to_csv(join(out, 'metrics.csv'))


def test_mimic_lower_bound(out):
    """Plot grouping-loss lower/upper bounds for each (horizon, estimator, label, class).

    Reads ``metrics.csv`` from the sibling ``mimic_hetero_plot`` directory of
    ``out``, presumably the output of ``test_mimic_hetero_plot``. It draws one
    ``plot_grouping_bounds`` figure per group, comparing the clusterings in
    column ``clu``, and saves each figure under ``out``.
    """
    metrics_path = os.path.join(os.path.dirname(out), 'mimic_hetero_plot', 'metrics.csv')
    df = pd.read_csv(metrics_path, index_col=0)
    print(df)
    print(df.columns)

    # Rows without a class (max calibration) are grouped under the label 'max'.
    df['class'] = df['class'].fillna('max') if 'class' in df.columns else 'max'

    grouped = df.groupby(['hor', 'est', 'lab', 'class'], dropna=False)
    for (horizon, estimator, label, class_), group in grouped:
        print(group)

        fig = plot_grouping_bounds(group['lower_bound'], group['upper_bound'], group['clu'])
        fig.axes[0].set_title(f'{horizon} {label} {estimator} class={class_}')
        save_fig(fig, out, horizon=horizon, estimator=estimator, label=label, class_=class_,
                 order=['horizon', 'label', 'estimator', 'class_'])
        # break


def test_mean_only(out, n_jobs):
    """Plot a mean-only calibration diagram for a synthetic 1-D example.

    Draws 1,000,000 samples from ``CustomUniform(name='poly', dist='uniform',
    half=True)`` and uses ``ex.f(X)`` as the positive-class score. Per-cluster
    fractions of positives are evaluated with k-means (``n_clusters=2``) on all
    samples, since ``test_size=None`` means no train/test split. The figure from
    ``plot_frac_pos_vs_scores(..., mean_only=True)`` is saved as PNG.
    """
    from ..CalibrationExample import CustomUniform

    n_bins = 10
    clustering = 'kmeans'
    max_clusters_bin = 1
    positive_class = 1  # Plot the positive class

    # Plot settings.
    hist = False
    ci = None
    min_cluster_size = 10
    vary_cluster_size = False
    capsize = 4
    cluster_size = 80
    absolute_size_scale = (10, 5000)
    plot_cal_hist = False
    figsize = (2.5, 2.5)

    ex = CustomUniform(name='poly', dist='uniform', half=True)
    X, y_labels = ex.generate_X_y(n=1000000)
    y_scores = ex.f(X)
    X = X[:, None]
    # Two-column scores: column 1 holds the positive-class score.
    y_scores = np.stack([1 - y_scores, y_scores], axis=1)

    frac_pos, counts, mean_scores, *_ = cluster_evaluate_marginals(
        X, y_labels, y_scores,
        positive_class=positive_class,
        bins=n_bins,
        clustering=clustering,
        test_size=None,
        n_clusters=2,
        min_samples_leaf=None,
        max_clusters_bin=max_clusters_bin,
        verbose=2,
        n_jobs=n_jobs,
    )

    fig = plot_frac_pos_vs_scores(
        frac_pos,
        counts,
        mean_scores,
        y_scores=None,
        y_labels=None,
        ncol=1,
        legend_loc='upper left',
        title=None,
        xlim_margin=0.05,
        ylim_margin=0.05,
        min_cluster_size=min_cluster_size,
        hist=hist,
        ci=ci,
        legend_cluster_sizes=False,
        vary_cluster_size=vary_cluster_size,
        capsize=capsize,
        xlabel='Confidence score of positive class',
        ylabel='Fraction of positives',
        cluster_size=cluster_size,
        absolute_size_scale=absolute_size_scale,
        plot_cal_hist=plot_cal_hist,
        figsize=figsize,
        mean_only=True,
    )

    save_fig(fig, out, order=[], name=ex.__class__.__name__, ext='png')


def test_GL_debiased():
    """Print the grouping-loss lower bound with and without debiasing.

    The inputs are two bins with two clusters each. The first bin has large,
    opposite clusters (fractions 0 and 1, 10000 samples each). The second has
    tiny, identical clusters (fractions 0.5, 2 samples each). Prints the
    per-bin debiased bound and bias, then the bin-reduced debiased bound and
    bias, then the bin-reduced biased bound.
    """
    from ._utils import grouping_loss_lower_bound

    frac_pos = np.array([[0., 1.], [0.5, 0.5]])
    counts = np.array([[10000, 10000], [2, 2]])

    per_bin_lb, per_bin_bias = grouping_loss_lower_bound(
        frac_pos, counts, reduce_bin=False, debiased=True, return_bias=True)
    print()
    print(per_bin_lb)
    print(per_bin_bias)

    debiased_lb, total_bias = grouping_loss_lower_bound(
        frac_pos, counts, reduce_bin=True, debiased=True, return_bias=True)
    print()
    print('Debiased', debiased_lb)
    print('Bias', total_bias)

    biased_lb = grouping_loss_lower_bound(frac_pos, counts, reduce_bin=True, debiased=False)
    print()
    print('Biased:', biased_lb)


def test_calibration():
    """Smoke-test sigmoid CalibratedClassifierCV on top of a prefit WrapperClassifier.

    Builds random features and random score matrices with a fixed seed.
    Training labels are the argmax of the softmaxed training scores. The test
    calibrates on the training set and prints the difference between the
    calibrated validation probabilities and the raw validation scores.
    """
    from sklearn.calibration import CalibratedClassifierCV
    from ._utils import WrapperClassifier

    n_samples, n_classes = 10, 3
    rng = np.random.RandomState(0)
    mean, cov = np.zeros(n_classes), np.eye(n_classes)

    # Draw order matters for reproducibility with the fixed seed.
    X_train = rng.multivariate_normal(mean, cov, size=n_samples)
    y_train = rng.uniform(size=(n_samples, n_classes))
    X_val = rng.multivariate_normal(mean, cov, size=n_samples)
    y_val = rng.uniform(size=(n_samples, n_classes))

    train_probs = torch.softmax(torch.Tensor(y_train), dim=1)
    y_label_train = np.array(train_probs.argmax(dim=1), dtype=int)

    base_clf = WrapperClassifier(X_train, y_train, X_val, y_val)
    base_clf.fit(X_train, y_label_train)

    calibrated = CalibratedClassifierCV(base_clf, method='sigmoid', cv='prefit')
    print(X_train.shape)
    print(y_label_train.shape, y_label_train)
    calibrated.fit(X_train, y_label_train)

    print(calibrated.predict_proba(X_val) - y_val)
    # print(y_val)


@pytest.mark.parametrize('stratss', [False, True])
def test_startifiedss(stratss, n=1000):
    """Print per-bin train/test sample counts produced by ``bin_train_test_split``.

    ``n`` uniform scores are drawn with a fixed seed and assigned to 15
    equal-width bins. The test makes 10 splits with ``test_size=0.5``. For each
    split it prints the bin counts of both sides and the largest per-bin
    train/test count gap, to compare the ``stratss`` modes.
    """
    from ._utils import scores_to_id_bins

    test_size = 0.5
    n_bins = 15

    rs = np.random.RandomState(0)
    y_scores = rs.uniform(size=n)
    groups = scores_to_id_bins(y_scores, np.linspace(0, 1, n_bins + 1))

    split = bin_train_test_split(y_scores, test_size=test_size, n_splits=10, bins=n_bins,
                                 random_state=0, stratss=stratss)

    for train_idx, test_idx in split:
        unique_train, counts_train = np.unique(groups[train_idx], return_counts=True)
        unique_test, counts_test = np.unique(groups[test_idx], return_counts=True)
        print('train', unique_train, counts_train)
        print('test', unique_test, counts_test)
        print('max', np.max(np.abs(counts_train - counts_test)))
        print()


def test_mlpmixer():
    """Print the top-1 accuracy of stored MLP-Mixer predictions on samples 40000-49999.

    Loads ``y_scores.pt`` and ``y_labels.pt`` from the network's default
    directory and reports the accuracy of ``argmax(y_scores)`` on the
    ``[40000, 50000)`` slice.
    """
    net = IMAGENET_MLPMixer()
    dirpath = net.get_default_dirpath()
    y_scores_stored = torch.load(join(dirpath, 'y_scores.pt'))
    y_labels_stored = torch.load(join(dirpath, 'y_labels.pt'))

    start, stop = 40000, 50000
    y_pred = torch.argmax(y_scores_stored, dim=1)
    correct = (y_labels_stored == y_pred).float()
    acc_expected = torch.mean(correct[start:stop]).item()
    print(f'Accuracy expected: {100*acc_expected:.1f}%')


def test_splits(out, n_jobs):
    """Copy ImageNet-R training images of selected confidence bins into per-quantile folders.

    For the large ConvNeXt on ``test_r``, each row of ``<out>/splits.csv``
    (index = confidence-bin id, columns ``feature``, ``split``, ``ordered``)
    selects the training samples in that bin. The test prints how many fall
    on each side of ``split`` for that feature, then copies the images into
    ``<out>/quantiles_<net>/bin<id>/<q>/`` according to the feature's quantile
    interval, presumably for visual inspection.

    NOTE(review): ``n_jobs`` is unused here.
    """
    from .networks import IMAGENET_ConvNeXt
    from ._utils import bin_train_test_split
    n_bins = 15
    test_size = 0.5
    net = IMAGENET_ConvNeXt(split='test_r', type='large')
    # net = IMAGENET_ConvNeXt(split='test_r')
    # net = IMAGENET_VisionTransformer(split='test_r')
    dirpath = net.get_default_dirpath()
    Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
    y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()
    y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()

    # Same single train/test split (random_state=0) as the other ImageNet-R tests.
    y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)
    splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1, bins=n_bins, random_state=0)
    train_idx, test_idx = next(splitter)

    df = pd.read_csv(join(out, 'splits.csv'), index_col=0)

    print(df)

    ds = net.get_dataset()
    # NOTE(review): `renditions` is never used below; it also raises if a path
    # does not match the pattern (re.match returns None).
    renditions = [re.match(f'.*/([a-z]*)_[0-9]*.jpg', s).group(1) for s, _ in ds.imgs]
    renditions = np.array(renditions)
    paths = np.array([s for s, _ in ds.imgs])

    n_quantile = 2

    Xt_train = Xt[train_idx, :]
    paths_train = paths[train_idx]

    # 0-based confidence-bin id; the clip folds scores equal to 1.0 into the last bin.
    bins = np.linspace(0, 1, n_bins+1)
    y_bins = np.digitize(y_pred_scores, bins=bins) - 1
    y_bins = np.clip(y_bins, a_min=None, a_max=n_bins-1)
    y_bins_train = y_bins[train_idx]

    for index, row in df.iterrows():
        print('index', index)
        bin_id = index
        # bin_id = row['bin']
        feature_id = row['feature']
        split = row['split']
        ordered = row['ordered']
        print(y_bins_train.shape, bin_id, y_bins_train == bin_id)
        X_bin = Xt_train[y_bins_train == bin_id, :]
        paths_bin = paths_train[y_bins_train == bin_id]
        x = X_bin[:, feature_id]
        # Sizes of the two sides of the stored split threshold.
        print(np.sum(x <= split))
        print(np.sum(x > split))

        xmin = np.min(x)
        xmax = np.max(x)
        # Quantile edges: 1%, 5%, the interior n_quantile-quantiles, 95%, 99%, 100%,
        # padded with xmin-1 / xmax+1.
        qs = [0.01, 0.05] + list(np.linspace(0, 1, n_quantile + 1)[1:-1]) + [0.95, 0.99, 1]
        bins = [xmin - 1] + list(np.quantile(x, qs)) + [xmax + 1]
        if not ordered:
            bins = list(reversed(list(bins)))
        print(bins)
        x_bin_id = np.digitize(x, bins=bins)

        print('unique', np.unique(x_bin_id))

        print('qs', len(qs), qs)
        print('bins', len(bins), bins)


        # Group id i+1 is written to a folder named after qs[i].
        # NOTE(review): ids run 1..len(bins)-1 but this loop stops at len(bins)-2,
        # so the last interval is never copied (x == xmax when ascending, the
        # lowest values when reversed). With reversed bins the ids also count
        # from the top, so the qs[i] folder names may not match; confirm intent.
        for i in range(len(bins)-2):
            q = qs[i]
            idx_quantile = np.where(x_bin_id == i+1)[0]

            # Note: reuses `dirpath` (the input directory is no longer needed).
            dirpath = join(out, f'quantiles_{net}', f'bin{bin_id}', str(q))
            for p in paths_bin[idx_quantile]:
                os.makedirs(dirpath, exist_ok=True)
                dest_path = join(dirpath, p.replace('/', '_'))
                shutil.copyfile(p, dest_path)


@memory.cache(ignore=['n_jobs'])
def _compute_renditions_metrics(net_cls, renditions, calibrate, n_bins, test_size,
                                clustering, max_clusters_bin, breakout, n_jobs):
    """Compute per-rendition calibration metrics of one network on ImageNet-R.

    Every setting is an explicit argument so that joblib's cache key covers
    it. joblib hashes only the arguments and this function's own source, so
    values captured from an enclosing scope would silently return stale
    results when changed. ``n_jobs`` does not affect the result and is
    excluded from the key.

    A common clustering is fitted on the whole ``test_r`` split. It is then
    reused to evaluate each rendition subset. The per-rendition
    ``lower_bound`` is compared with the common one (``diff_rel`` columns).

    Returns
    -------
    pandas.DataFrame
        One row per rendition with the ``compute_calib_metrics`` outputs plus
        identification columns and ``common_*`` / ``diff_rel*`` columns.
    """
    net = net_cls(split='test_r')
    dirpath = net.get_default_dirpath()
    ds = net.get_dataset()

    Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
    y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()
    y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()

    y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)
    splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1,
                                    bins=n_bins, random_state=0)
    _, test_idx = next(splitter)

    if calibrate is not None:
        y_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                       test_size=test_idx, max_calibration=True)
        y_labels = y_well_guess

    # Common clustering fitted on all renditions at once.
    (frac_pos,
     counts,
     _,
     labels,
     _,
     _,
     _,
     labels_train,
     ) = cluster_evaluate_max(Xt, y_labels, y_scores,
                              breakout=breakout,
                              bins=n_bins,
                              clustering=clustering,
                              test_size=test_idx,
                              min_samples_leaf=None,
                              max_clusters_bin=max_clusters_bin,
                              verbose=1,
                              n_jobs=n_jobs,
                              binary_scores=calibrate is not None,
                              return_clustering=True,
                              )

    if calibrate is not None:
        # NOTE(review): here the arrays cover only `test_idx`, while the loop
        # below indexes them with `selection`, a mask over *all* images; the
        # lengths look inconsistent — confirm before using calibrate != None.
        y_pred_scores, y_well_guess = y_scores[test_idx, 1], y_labels[test_idx]

    common_metrics = compute_calib_metrics(frac_pos, counts, y_pred_scores, y_well_guess, n_bins)

    # Cluster ids of every sample: training assignment, overridden on the test part.
    common_clustering = np.copy(labels_train)
    common_clustering[test_idx] = labels[test_idx]

    test_mask = np.zeros(Xt.shape[0], dtype=bool)
    test_mask[test_idx] = True

    img_paths = [s for s, _ in ds.imgs]

    dfs = []
    for rendition in renditions:
        print(rendition)
        selection = np.array([bool(re.match(f'.*/{rendition}_.*', s)) for s in img_paths])
        net.split = f'test_r:{rendition}'

        (frac_pos_r,
         counts_r,
         mean_scores_r,
         *_,
         ) = cluster_evaluate_max(Xt[selection, :], y_labels[selection], y_scores[selection, :],
                                  breakout=breakout,
                                  bins=n_bins,
                                  clustering=common_clustering[selection],
                                  test_size=np.where(test_mask[selection])[0],
                                  min_samples_leaf=None,
                                  max_clusters_bin=max_clusters_bin,
                                  verbose=1,
                                  n_jobs=n_jobs,
                                  binary_scores=calibrate is not None,
                                  return_clustering=False,
                                  )

        metrics = compute_calib_metrics(frac_pos_r, counts_r, y_pred_scores[selection],
                                        y_well_guess[selection], n_bins)
        metrics.update({
            'net': str(net),
            'split': net.split,
            'dataset': net.get_dataset_name(),
            'network': net.get_class_name(False),
            'network+version': net.get_class_name(True),
            'calibrate': str(calibrate),
            'rendition': rendition,
            'frac_pos': frac_pos_r,
            'counts': counts_r,
            'mean_scores': mean_scores_r,
        })
        dfs.append(pd.DataFrame([metrics]))

    df = pd.concat(dfs, axis=0)

    print(df)
    df['common_lower_bound'] = common_metrics['lower_bound']
    df['common_lower_bound_debiased'] = common_metrics['lower_bound_debiased']
    df['diff_rel'] = df['lower_bound'] - common_metrics['lower_bound']
    df['diff_rel_debiased'] = df['lower_bound_debiased'] - common_metrics['lower_bound_debiased']

    return df


def test_imagenet_r_renditions(out, n_jobs, calibrate=None):
    """Plot, per network, how each ImageNet-R rendition's grouping-loss bound differs.

    For every network in ``ALL_IMAGENET_NETWORKS`` except MLP-Mixer, it
    computes per-rendition metrics (cached with joblib) and saves a
    ``plot_fig_renditions`` figure of ``diff_rel`` (rendition lower bound minus
    the common one) per rendition.
    """
    n_bins = 15
    clustering = 'decision_tree'
    test_size = 0.5
    max_clusters_bin = 2
    breakout = False
    figsize = (3, 3)
    renditions = sorted([
        'art',
        'cartoon',
        'graffiti',
        'embroidery',
        'graphic',
        'origami',
        'painting',
        'sculpture',
        'tattoo',
        'toy',
        'deviantart',
        'misc',
        'videogame',
        'sketch',
        'sticker',
    ])

    dfs = [
        _compute_renditions_metrics(net, renditions, calibrate, n_bins, test_size,
                                    clustering, max_clusters_bin, breakout, n_jobs)
        for net in ALL_IMAGENET_NETWORKS
        if net != IMAGENET_MLPMixer
    ]
    df = pd.concat(dfs, axis=0)
    df['rendition'] = df['rendition'].str.capitalize()

    # Scoped rcParams so the figure size does not leak into later tests.
    with plt.rc_context({'figure.figsize': figsize}):
        fig = plot_fig_renditions(df, x='diff_rel', y='rendition', hue='network',
                                  z='lower_bound', figsize=figsize)
        save_fig(fig, out)

    # clustering = None
    # breakout = False
    # n_bins = 15
    # n_jobs = n_jobs
    # calibrate = False

    # def compute_diff_one(net):
    #     net = net()
    #     res = []
    #     for i, sublabel in enumerate(sublabels):
    #         print(sublabel)
    #         net.split = f'test_r:{sublabel}'

    #         r = __cluster_max(
    #             net=net,
    #             clustering=clustering,
    #             breakout=breakout,
    #             n_bins=n_bins,
    #             n_jobs=n_jobs,
    #             calibrate=calibrate,
    #             verbose=0
    #         )
    #         res.append(r)

    #     # Zip results from list of tuples to tuple of lists
    #     L_frac_pos, L_counts, L_mean_scores = list(zip(*res))

    #     frac_pos = np.concatenate(L_frac_pos, axis=1)
    #     counts = np.concatenate(L_counts, axis=1)
    #     mean_scores = np.concatenate(L_mean_scores, axis=1)

    #     prob_bins, mean_bins = calibration_curve(frac_pos, counts, mean_scores)

    #     diff = np.mean(frac_pos - prob_bins[:, None], axis=0)

    #     df = pd.DataFrame({
    #         'diff': diff,
    #         'sublabels': [s.capitalize() for s in sublabels],
    #         'net': net.get_class_name(with_type=with_type),
    #     })

    #     print(df)

    #     return df

    # dfs = []
    # for net in ALL_IMAGENET_NETWORKS:
    #     df = compute_diff_one(net)
    #     dfs.append(df)

    # df = pd.concat(dfs, axis=0)
