import itertools
from copy import copy

import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
import statsmodels.api as sm
import torch
from matplotlib.colors import ListedColormap, to_hex, to_rgb
from matplotlib.patches import Ellipse, Patch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sklearn.calibration import calibration_curve as sklearn_calibration_curve
from statsmodels.stats.proportion import proportion_confint

from ._utils import compute_classif_metrics
from .xp_nn_calibration._utils import (calibration_curve,
                                       compute_calib_metrics,
                                       compute_multi_classif_metrics,
                                       grouping_loss_lower_bound,
                                       grouping_loss_upper_bound)


def set_latex_font(math=True, normal=True, extra_preamble=None):
    """Configure matplotlib fonts and (when available) LaTeX text rendering.

    Parameters
    ----------
    math : bool
        If True, use the 'stix' mathtext fontset; otherwise restore the
        matplotlib default fontset.
    normal : bool
        If True, use the 'STIXGeneral' font family for regular text;
        otherwise restore the matplotlib default family.
    extra_preamble : sequence of str, optional
        Extra LaTeX preamble lines appended after the default
        ``\\usepackage{amsfonts}``. ``None`` (default) means no extra lines.

    Notes
    -----
    This mutates the global ``plt.rcParams`` and prints whether usetex was
    enabled.
    """
    # Avoid a shared mutable default: build a fresh list on every call.
    extra_preamble = [] if extra_preamble is None else list(extra_preamble)

    if math:
        plt.rcParams['mathtext.fontset'] = 'stix'
    else:
        plt.rcParams['mathtext.fontset'] = plt.rcParamsDefault['mathtext.fontset']

    if normal:
        plt.rcParams['font.family'] = 'STIXGeneral'
    else:
        plt.rcParams['font.family'] = plt.rcParamsDefault['font.family']

    # `matplotlib.checkdep_usetex` was deprecated and later removed from
    # matplotlib; use it when present and otherwise probe for the external
    # programs usetex needs (TeX, dvipng, Ghostscript).
    checkdep_usetex = getattr(mpl, 'checkdep_usetex', None)
    if checkdep_usetex is not None:
        usetex = checkdep_usetex(True)
    else:
        import shutil
        ghostscript_names = ('gs', 'gswin64c', 'gswin32c', 'mgs')
        usetex = (
            shutil.which('tex') is not None
            and shutil.which('dvipng') is not None
            and any(shutil.which(name) for name in ghostscript_names)
        )
    print(f'usetex: {usetex}')
    plt.rc('text', usetex=usetex)

    default_preamble = [
        r'\usepackage{amsfonts}',
    ]
    preamble = ''.join(default_preamble + extra_preamble)
    plt.rc('text.latex', preamble=preamble)

def separating_line2D(X, beta, beta0):
    """Return the ordinate of the line ``beta[0]*x + beta[1]*y + beta0 = 0``.

    Parameters
    ----------
    X : scalar or array
        Abscissa value(s) at which the line is evaluated.
    beta : array with ``beta.shape[0] == 2``
        Coefficients of the two coordinates.
    beta0 : scalar
        Intercept.

    Returns
    -------
    Ordinate value(s) ``y`` solving the line equation for each ``x`` in X.
    """
    assert beta.shape[0] == 2
    coef_x, coef_y = beta[0], beta[1]
    linear_part = X*coef_x + beta0
    return -1/coef_y*linear_part


def separating_line2D_missing(beta, beta0, mean, cov, which, return_coefs=False):
    """Decision threshold on the observed feature when the other is missing.

    The missing coordinate is replaced by its conditional expectation given
    the observed one,
    ``E[X_m | X_o = x] = mean[m] + cov[m, o]/cov[o, o] * (x - mean[o])``,
    and the linear score ``beta[m]*X_m + beta[o]*X_o + beta0`` is solved for
    zero. This yields ``a*x + b = 0`` with

    * ``a = beta[m]*cov[m, o]/cov[o, o] + beta[o]``
    * ``b = beta[m]*mean[m] + beta0 - beta[m]*mean[o]*cov[m, o]/cov[o, o]``

    Parameters
    ----------
    beta : array of 2 coefficients
    beta0 : scalar intercept
    mean : array of 2 means
    cov : 2x2 covariance matrix
    which : {1, 2}
        1-based index of the missing feature.
    return_coefs : bool
        If True, also return ``np.array([a, b])``.

    Returns
    -------
    ``-b/a``, or ``(-b/a, np.array([a, b]))`` when ``return_coefs`` is True.

    Raises
    ------
    ValueError
        If ``which`` is neither 1 nor 2.
    """
    if which == 1:
        missing, observed = 0, 1
    elif which == 2:
        missing, observed = 1, 0
    else:
        raise ValueError('which should be 1 or 2')

    # Regression slope of the missing feature on the observed one.
    reg_slope = cov[missing, observed]/cov[observed, observed]

    a = beta[missing]*reg_slope + beta[observed]
    # The conditional-mean correction is weighted by the coefficient of the
    # *missing* feature (it was previously weighted by the observed one,
    # which is only correct when mean[observed] == 0).
    b = beta[missing]*mean[missing] + beta0 - mean[observed]*beta[missing]*reg_slope

    # Solution of ax+b = 0
    root = -b/a
    if return_coefs:
        return root, np.array([a, b])
    return root


def plot_orthogonal_line(ax, w, b=0, xmin=None, xmax=None, **kwargs):
    """Draw the line ``w[0]*x + w[1]*y + b = 0`` (orthogonal to ``w``) on ``ax``.

    Parameters
    ----------
    ax : axes object
        Target axes; its current y-limits are restored after plotting so the
        line does not rescale the view.
    w : array with ``w.shape[0] == 2``
        Normal vector of the line.
    b : scalar
        Offset of the line.
    xmin, xmax : scalar, optional
        Horizontal extent of the drawn segment. Any bound left to None is
        taken from the current x-limits of ``ax``.
    **kwargs
        Forwarded to ``ax.plot`` (or ``ax.axvline`` for a vertical line).

    Returns
    -------
    The ``ax`` that was passed in.
    """
    assert w.shape[0] == 2
    if xmin is None or xmax is None:
        # Only fill in the bound(s) the caller did not provide.
        cur_xmin, cur_xmax = ax.get_xlim()
        xmin = cur_xmin if xmin is None else xmin
        xmax = cur_xmax if xmax is None else xmax

    w0 = w[0]
    w1 = w[1]
    ylim = ax.get_ylim()
    if w1 == 0:
        # The line is vertical (x = -b/w0) and cannot be written as y(x).
        ax.axvline(-b/w0, **kwargs)
    else:
        X = np.linspace(xmin, xmax, 2)
        Y = -w0/w1*X - 1/w1*b
        ax.plot(X, Y, **kwargs)
    ax.set_ylim(ylim)

    return ax


def plot_covariance2D(mean, cov, ax, n_std=3.0, facecolor='none',
                      edgecolor='black', linestyle='--', label=r'$\Sigma$',
                      **kwargs):
    """
    Draw the ellipse associated with a 2D mean and covariance matrix.

    Parameters
    ----------
    mean : array, shape (2,)
        Center of the ellipse.

    cov : 2-dimensional array
        Covariance matrix; entries [0, 0], [1, 1] and [0, 1] are used.

    ax : matplotlib.axes.Axes
        The axes object to draw the ellipse into.

    n_std : float
        The number of standard deviations to determine the ellipse's radiuses.

    facecolor, edgecolor, linestyle, label
        Styling passed to the ellipse patch.

    **kwargs
        Forwarded to `~matplotlib.patches.Ellipse`

    Returns
    -------
    The patch returned by ``ax.add_patch`` for the created ellipse.
    """
    assert cov.ndim == 2
    assert mean.shape == (2,)

    var_x = cov[0, 0]
    var_y = cov[1, 1]
    pearson = cov[0, 1]/np.sqrt(var_x * var_y)

    # For unit variances the eigenvalues of the correlation matrix are
    # 1 + pearson and 1 - pearson; their square roots are the half-axes.
    half_axis_x = np.sqrt(1 + pearson)
    half_axis_y = np.sqrt(1 - pearson)
    ellipse = Ellipse(
        (0, 0),
        width=half_axis_x * 2,
        height=half_axis_y * 2,
        facecolor=facecolor,
        edgecolor=edgecolor,
        linestyle=linestyle,
        label=label,
        **kwargs,
    )

    # Standard deviation along each axis times the requested number of
    # standard deviations.
    scale_x = np.sqrt(var_x) * n_std
    scale_y = np.sqrt(var_y) * n_std

    # Rotate the unit ellipse by 45 degrees, stretch it, then move it to the
    # mean. Affine2D methods update the transform in place.
    transf = transforms.Affine2D()
    transf.rotate_deg(45)
    transf.scale(scale_x, scale_y)
    transf.translate(mean[0], mean[1])

    ellipse.set_transform(transf + ax.transData)
    return ax.add_patch(ellipse)


def plot_2D_classif(X, y, beta, mean, cov, X_unmasked, predictor,
                    split_fig=True, figsize=None, xlim=None, ylim=None,
                    legend_loc='best', complete_only=False, beta_learned=None):
    """Plot predictor outputs and samples of a 2D problem with missing values.

    Up to three panels are drawn: complete samples, samples whose first
    feature is NaN, and samples whose second feature is NaN. Each panel shows
    a filled contour of the second output of ``predictor.predict`` on a grid,
    the matching samples, and reference lines built from ``beta``.

    Parameters
    ----------
    X : array, shape (n, 2)
        Features, with NaN marking missing entries.
    y : array, shape (n,)
        Values used to color the scattered samples.
    beta : array of 3 values
        ``beta[0]`` is used as intercept and ``beta[1:]`` as the two
        coefficients of the reference ('Bayes') line.
    mean, cov : array
        Passed to `plot_covariance2D` and `separating_line2D_missing`.
    X_unmasked : array, shape (n, 2)
        Features without missing entries; used for the axis range and drawn
        faded on the panels with missing values.
    predictor : object
        Must expose ``predict(grid)`` returning a 2-tuple; the second element
        is plotted with contour levels in [0, 1] (presumably probabilities
        -- confirm against the predictor implementation).
    split_fig : bool
        If True, one figure per panel; otherwise one figure with 3 axes.
    figsize : tuple, optional
    xlim, ylim : tuple, optional
        Grid range. If either is None, both ranges are derived from
        ``X_unmasked`` with a 0.5 margin.
    legend_loc : str
    complete_only : bool
        If True, only the panel of complete samples is drawn.
    beta_learned : array of 3 values, optional
        Same layout as ``beta``; drawn in orange and labelled 'Learned'.

    Returns
    -------
    A single figure if ``split_fig`` is False, otherwise the list of figures.
    """

    # x_min, x_max = np.nanmin(X[:, 0]) - .5, np.nanmax(X[:, 0]) + .5
    # y_min, y_max = np.nanmin(X[:, 1]) - .5, np.nanmax(X[:, 1]) + .5
    if xlim is None or ylim is None:
        x_min, x_max = np.min(X_unmasked[:, 0]) - .5, np.max(X_unmasked[:, 0]) + .5
        y_min, y_max = np.min(X_unmasked[:, 1]) - .5, np.max(X_unmasked[:, 1]) + .5
    else:
        x_min, x_max = xlim
        y_min, y_max = ylim

    # XX, YY = np.meshgrid(np.linspace(-5, 5, 10),
    #                      np.linspace(-5, 5, 10))

    # print(x_min, x_max)
    # exit()
    # Number of grid points per axis.
    h = 100
    # c_min = min(x_min, y_min)
    # c_max = max(x_max, y_max)
    XX0, YY0 = np.meshgrid(np.linspace(x_min, x_max, h),
                           np.linspace(y_min, y_max, h))

    # Grid with the first feature set to NaN (missing) everywhere.
    XX1 = np.nan*np.zeros_like(XX0)
    YY1 = YY0.copy()

    # Grid with the second feature set to NaN (missing) everywhere.
    XX2 = XX0.copy()
    YY2 = np.nan*np.zeros_like(YY0)

    idx_complete = ~np.isnan(X).any(axis=1)
    idx_missing1 = np.isnan(X[:, 0]) & ~np.isnan(X[:, 1])
    idx_missing2 = np.isnan(X[:, 1]) & ~np.isnan(X[:, 0])

    figs = []
    if not split_fig:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
        figs.append(fig)

    cm = plt.cm.RdBu_r
    cm_bright = ListedColormap(['#0000FF', '#FF0000'])

    if complete_only:
        iters = [(XX0, YY0)]
    else:
        iters = [(XX0, YY0), (XX1, YY1), (XX2, YY2)]

    # i == 0: complete data, i == 1: first feature missing,
    # i == 2: second feature missing.
    for i, (XX, YY) in enumerate(iters):
        _, Z = predictor.predict(np.c_[XX.ravel(), YY.ravel()])
        Z = Z.reshape(XX.shape)

        if split_fig:
            fig = plt.figure(figsize=figsize)
            figs.append(fig)
            ax = plt.gca()
        else:
            ax = axes[i]
        # ax.axis('equal')
        # ax.set_aspect('equal')

        # Always drawn on the complete grid (XX0, YY0), even when the
        # predictions were computed on a grid with a NaN coordinate.
        crf = ax.contourf(XX0, YY0, Z, levels=np.linspace(0, 1, 11), cmap=cm, alpha=.8, vmin=0, vmax=1)
        cbar = fig.colorbar(crf, ax=ax)
        cbar.ax.set_title(r'$\mathbb{P}(Y=1|\tilde{X})$')
        # Remember the limits set by the contour plot so they can be restored
        # after the overlays below (this rebinds the xlim/ylim parameters).
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        if i == 0:
            X_selected = X[idx_complete, :]
            c = y[idx_complete]
            X_selected_unmasked = None#X_unmasked[~idx_complete, :]
            c_unmasked = y[~idx_complete]
        elif i == 1:
            # NaN entries are replaced by 0 so the points lie on an axis.
            X_selected = np.nan_to_num(X[idx_missing1, :])
            c = y[idx_missing1]
            X_selected_unmasked = X_unmasked[idx_missing1, :]
            c_unmasked = c
        elif i == 2:
            X_selected = np.nan_to_num(X[idx_missing2, :])
            c = y[idx_missing2]
            X_selected_unmasked = X_unmasked[idx_missing2, :]
            c_unmasked = c

        ax.scatter(X_selected[:, 0], X_selected[:, 1], c=c, cmap=cm_bright,
                   edgecolors='k', marker='.', label=None, zorder=10)
        if X_selected_unmasked is not None:
            # Faded points at the positions the samples had before masking.
            ax.scatter(X_selected_unmasked[:, 0], X_selected_unmasked[:, 1], c=c_unmasked, cmap=cm_bright,
                       edgecolors='k', marker='.', alpha=0.2, zorder=9)

        ax.axvline(0, color='black', lw=0.5)
        ax.axhline(0, color='black', lw=0.5)

        if i == 0:
            plot_covariance2D(mean, cov, ax, n_std=3, edgecolor='black', linestyle='--', label=r'$\Sigma$')

        label = r'Bayes ($\nu = 0$)'
        if i == 0:  # no missing values

            if beta[2] != 0:
                X1 = np.linspace(x_min, x_max, 100)
                X2 = separating_line2D(X1, beta[1:], beta[0])
                p_line, = ax.plot(X1, X2, color='black', linestyle=':', label=label)  # label=r'$\langle X, \beta^{\star} \rangle + \beta_0^{\star} = 0$')
            else:
                # The line cannot be written as X2 = g(X1) when beta[2] == 0.
                ax.axvline(0, color='black', linestyle=':', label=label)
            # ax.arrow(0, 0, beta[1], beta[2], color='black', length_includes_head=True)
            ax.annotate("", xy=(beta[1], beta[2]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=15))

        elif i == 1:
            # Here `c` is rebound to the threshold on the observed feature.
            c, coef = separating_line2D_missing(beta[1:], beta[0], mean, cov, i, return_coefs=True)
            p_line = ax.axhline(c, ls=':', c='black', label=label)
            # ax.arrow(0, 0, beta[1], beta[2], color='black')
            ax.annotate("", xy=(coef[1], coef[0]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=20))

        elif i == 2:
            c, coef = separating_line2D_missing(beta[1:], beta[0], mean, cov, i, return_coefs=True)
            p_line = ax.axvline(c, ls=':', c='black', label=label)
            ax.annotate("", xy=(coef[0], coef[1]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=20 ))
            # ax.arrow(0, 0, beta[1], beta[2], color='black')


        # The learned line is drawn identically on every panel.
        if beta_learned is not None:
            label = 'Learned'
            if beta_learned[2] != 0:
                X1 = np.linspace(x_min, x_max, 100)
                X2 = separating_line2D(X1, beta_learned[1:], beta_learned[0])
                p_line, = ax.plot(X1, X2, color='tab:orange', linestyle=':', label=label, zorder=11)  # label=r'$\langle X, \beta^{\star} \rangle + \beta_0^{\star} = 0$')
            else:
                ax.axvline(0, color='tab:orange', linestyle=':', label=label, zorder=11)
            # ax.arrow(0, 0, beta_learned[1], beta_learned[2], color='tab:orange', length_includes_head=True, zorder=11)
            ax.annotate("", xy=(beta_learned[1], beta_learned[2]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='tab:orange', mutation_scale=15), zorder=11)

        # Restore the limits recorded right after the contour plot.
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        ax.set_xlabel('$X_1$')
        ax.set_ylabel('$X_2$')
        ax.legend(loc=legend_loc).set_zorder(100)

    if not split_fig:
        return figs[0]
    return figs


def plot_score_vs_probas(y_scores, y_labels, y_true_probas, n_bins=15,
                         samples_with_mv=None, legend_loc='best', ncol=3,
                         lim_margin=0.15, max_samples=None):
    """Scatter classifier confidence scores against true probabilities.

    Points are colored by label on a seaborn JointGrid with KDE marginals.
    The calibration curve, the bin edges, a per-bin annotation and a list of
    metrics are overlaid.

    Parameters
    ----------
    y_scores, y_labels, y_true_probas : array-like
        Confidence scores, labels and true probabilities of the samples.
    n_bins : int
        Number of bins of the calibration curve and of the vertical grid.
    samples_with_mv : array-like of bool, optional
        Mask of the samples drawn with the 'Missing' style.
    legend_loc : str
    ncol : int
        Number of legend columns.
    lim_margin : float or None
        Margin added around [0, 1] on both axes; None keeps the limits.
    max_samples : int, optional
        Only the first ``max_samples`` samples are scattered. The calibration
        curve and the metrics are computed beforehand on all samples.

    Returns
    -------
    The figure holding the joint plot.
    """
    set_latex_font(normal=False)

    y_scores, y_true_probas, y_labels = (
        np.array(values) for values in (y_scores, y_true_probas, y_labels)
    )

    # Both are computed on the full data, before any subsampling.
    prob_bins, mean_bins = sklearn_calibration_curve(y_labels, y_scores, n_bins=n_bins)
    metrics = compute_classif_metrics(y_scores, y_labels, y_true_probas)

    if max_samples is not None:
        head = slice(max_samples)
        y_scores = y_scores[head]
        y_true_probas = y_true_probas[head]
        y_labels = y_labels[head]

    df = pd.DataFrame({
        'y_scores': y_scores,
        'y_true_probas': y_true_probas,
        'y_labels': y_labels,
        })
    g = sns.JointGrid(data=df, x='y_scores', y='y_true_probas', hue='y_labels')

    def scatter_with_mv(x, y, hue, missing):
        # Complete samples get the regular style and the legend entries;
        # samples with missing values are drawn with pastel 'X' markers.
        joint_ax = plt.gca()
        complete = ~missing
        complete_style = pd.Series('Complete', index=x[complete].index)
        sns.scatterplot(x=x[complete], y=y[complete], hue=hue, alpha=1,
                        ax=joint_ax, style=complete_style,
                        style_order=['Complete', 'Missing'])
        sns.scatterplot(x=x[missing], y=y[missing], hue=hue, alpha=1,
                        ax=joint_ax, legend=False, palette='pastel',
                        marker='X')

    if samples_with_mv is None:
        g.plot_joint(sns.scatterplot)
    else:
        samples_with_mv = np.array(samples_with_mv)
        g.plot_joint(scatter_with_mv, missing=samples_with_mv)

    # Filled per-label densities first, then an outline drawn after
    # overriding the grid's hue attribute.
    g.plot_marginals(sns.kdeplot, fill=True, zorder=5)
    g.hue = False
    g.plot_marginals(sns.kdeplot, fill=False, color='black', palette=['gray'], lw=1)
    g.set_axis_labels(xlabel='Confidence score', ylabel='True probability')
    ax = g.figure.axes[0]
    ax.legend(title='Label')

    # Vertical dashed lines at the bin edges.
    for edge in np.linspace(0, 1, n_bins+1):
        ax.axvline(edge, lw=1, ls='--', color='grey', zorder=-1)

    # The calibration curve may have fewer points than bins, so walk through
    # the bins and consume a curve point each time one falls into the bin.
    next_point = 0
    for bin_idx in range(n_bins):
        if next_point >= len(prob_bins):
            break
        frac_pos = prob_bins[next_point]
        mean_score = mean_bins[next_point]
        if not (bin_idx/n_bins <= mean_score <= (bin_idx+1)/n_bins):
            continue
        ax.annotate(f'{int(100*frac_pos)}', xy=((bin_idx+.5)/n_bins, -.075),
                    ha='center', va='center', color='grey', size='x-small')
        next_point += 1

    ax.plot([0, 1], [0, 1], ls='--', lw=1, color='black')
    ax.plot(mean_bins, prob_bins, marker='s', color='black', label='Calibration curve')

    rename = {
        'acc': 'Acc',
        'ece': 'ECE',
        'mce': 'MCE',
        'brier': 'Brier',
        'auroc': 'AUC',
        'mse': 'MSE',
        'rmsce': 'RMSCE',
        'acc_bayes': 'Bayes',
        'brier_bayes': 'Brier(B)',
        'nll': 'NLL',
        'kl': 'KL',
    }

    # One metric per line, stacked downwards in axes-fraction coordinates.
    ax_right = g.figure.axes[2]
    for row, (name, val) in enumerate(metrics.items()):
        display_name = rename.get(name, name)
        ax_right.annotate(f'{display_name}: {val:.3g}',
                          xy=(0.04, 1.2-0.02*row), ha='left', va='center',
                          color='grey', size='x-small',
                          xycoords='axes fraction')

    ax.legend(loc=legend_loc, ncol=ncol)
    if lim_margin is not None:
        limits = (-lim_margin, 1 + lim_margin)
        ax.set_xlim(limits)
        ax.set_ylim(limits)
    return ax.figure


def plot_score_vs_probas2(y_scores, y_labels, y_true_probas, n_bins=15,
                          samples_with_mv=None, legend_loc='best', ncol=3,
                          lim_margin=0.15, max_samples=None, grid_space=0.2,
                          height=6, plot_first_last_bins=True):
    """Scatter classifier confidence scores against true probabilities.

    Variant of `plot_score_vs_probas` with 'Negative'/'Positive' hue labels,
    histogram marginals sharing the calibration bins, and no metric
    annotations.

    Parameters
    ----------
    y_scores, y_labels, y_true_probas : array-like
        Confidence scores, labels and true probabilities of the samples.
        Samples whose label equals 1 are shown as 'Positive', all others as
        'Negative'.
    n_bins : int
        Number of bins of the calibration curve, the marginal histograms and
        the vertical grid.
    samples_with_mv : array-like of bool, optional
        Mask of the samples drawn with the 'Missing' style.
    legend_loc : str
    ncol : int
        Number of legend columns.
    lim_margin : float or None
        Margin added around [0, 1] on both axes; None keeps the limits.
    max_samples : int, optional
        Only the first ``max_samples`` samples are plotted. The calibration
        curve is computed beforehand on all samples.
    grid_space, height
        Forwarded to ``sns.JointGrid`` as ``space`` and ``height``.
    plot_first_last_bins : bool
        If False, the vertical lines at the outermost bin edges are skipped.

    Returns
    -------
    The figure holding the joint plot.
    """
    set_latex_font()

    y_scores = np.array(y_scores)
    y_true_probas = np.array(y_true_probas)
    y_labels = np.array(y_labels)

    # Computed on the full data, before any subsampling.
    prob_bins, mean_bins = sklearn_calibration_curve(y_labels, y_scores, n_bins=n_bins)

    if max_samples is not None:
        y_scores = y_scores[:max_samples]
        y_true_probas = y_true_probas[:max_samples]
        y_labels = y_labels[:max_samples]

    # Human-readable hue labels.
    _y_labels = np.full(y_labels.shape, 'Negative')
    _y_labels[y_labels == 1] = 'Positive'
    hue_order = ['Negative', 'Positive']
    df = pd.DataFrame({
        'y_scores': y_scores,
        'y_true_probas': y_true_probas,
        'y_labels': y_labels,
        '_y_labels': _y_labels,
        })
    g = sns.JointGrid(data=df, x='y_scores', y='y_true_probas', hue='_y_labels',
                      ratio=10, space=grid_space, height=height, hue_order=hue_order)

    def scatter_with_mv(x, y, hue, missing):
        # Complete samples get the regular style and the legend entries;
        # samples with missing values are drawn with pastel 'X' markers.
        ax = plt.gca()
        style = pd.Series('Complete', index=x[~missing].index)
        sns.scatterplot(x=x[~missing], y=y[~missing], hue=hue, alpha=1, ax=ax,
                        style=style, style_order=['Complete', 'Missing'])
        sns.scatterplot(x=x[missing], y=y[missing], hue=hue, alpha=1, ax=ax,
                        legend=False, palette='pastel', marker='X')

    if samples_with_mv is not None:
        samples_with_mv = np.array(samples_with_mv)
        g.plot_joint(scatter_with_mv, missing=samples_with_mv)
    else:
        g.plot_joint(sns.scatterplot, s=15)

    bins = np.linspace(0, 1, n_bins + 1)

    def histplot_with_size(x, vertical, hue):
        # `hue` is accepted because the grid passes it to marginal functions,
        # but the histogram is deliberately drawn in a single color.
        if vertical:
            x, y = None, x
        else:
            x, y = x, None
        sns.histplot(
            x=x,
            y=y,
            bins=list(bins),
            legend=False,
            color='silver'
        )

    g.plot_marginals(histplot_with_size)

    g.set_axis_labels(xlabel='Confidence score', ylabel='True probability')
    ax = g.figure.axes[0]
    ax.legend(title='Label')

    if not plot_first_last_bins:
        bins = bins[1:-1]
    for x in bins:
        ax.axvline(x, lw=0.5, ls='--', color='grey', zorder=-1)

    ax.plot([0, 1], [0, 1], ls='--', lw=1, color='black')
    ax.plot(mean_bins, prob_bins, marker='.', markersize=5, color='black')

    ax.legend(loc=legend_loc, ncol=ncol)
    if lim_margin is not None:
        ax.set_xlim((-lim_margin, 1 + lim_margin))
        ax.set_ylim((-lim_margin, 1 + lim_margin))
    return ax.figure


def plot_calibration_curves(y_scores_list, y_labels_list, y_true_probas=None,
                            methods=None, n_bins=15):
    """Overlay the calibration curves of several methods on one axes.

    Parameters
    ----------
    y_scores_list : list of array-like
        Confidence scores, one entry per method.
    y_labels_list : list of array-like
        Labels, one entry per method. The first entry is used as hue of the
        marginal plot when ``y_true_probas`` is given.
    y_true_probas : array-like, optional
        If given, the curves are drawn on a seaborn JointGrid whose marginals
        show the density of these values; otherwise a new figure is created.
    methods : list, optional
        Names shown in the legend; defaults to None for every method.
    n_bins : int
        Number of bins of each calibration curve.
    """
    first_labels = y_labels_list[0]  # all entries are expected to be equal
    if y_true_probas is None:
        plt.figure()
        ax = plt.gca()
    else:
        g = sns.JointGrid(x=np.zeros_like(y_true_probas), y=y_true_probas,
                          hue=first_labels)
        # Filled per-label densities, then an outline drawn after overriding
        # the grid's hue attribute.
        g.plot_marginals(sns.kdeplot, fill=True, zorder=5)
        g.hue = False
        g.plot_marginals(sns.kdeplot, fill=False, color='black', palette=['gray'], lw=1)
        ax = g.figure.axes[0]

    if methods is None:
        methods = [None]*len(y_scores_list)

    # Diagonal of perfect calibration.
    ax.plot([0, 1], [0, 1], ls='--', lw=1, color='black')
    for scores, labels, method in zip(y_scores_list, y_labels_list, methods):
        prob_bins, mean_bins = sklearn_calibration_curve(labels, scores, n_bins=n_bins)
        metrics = compute_classif_metrics(scores, labels, y_true_probas)
        curve_label = f"{method} (ECE={metrics['ece']:.2g})"
        ax.plot(mean_bins, prob_bins, marker='s', label=curve_label)

    ax.legend(title='Method')
    ax.set(xlabel='Confidence scores', ylabel='Fraction of positives')


def insert_nan_at_discontinuities(X, Y, min_gap=0.1):
    """Break a curve at its jumps by inserting NaN into copies of X and Y.

    A jump is a pair of consecutive values of Y whose absolute difference
    exceeds ``min_gap``; a NaN is inserted between the two values in both
    returned arrays. The inputs are left untouched.
    """
    is_jump = np.abs(np.diff(Y)) > min_gap
    # Insertion index is the position of the value right after each jump.
    break_idx = np.nonzero(is_jump)[0] + 1
    if len(break_idx) == 0:
        return X.copy(), Y.copy()
    X_broken = np.insert(X.copy(), break_idx, np.nan)
    Y_broken = np.insert(Y.copy(), break_idx, np.nan)
    return X_broken, Y_broken


def plot_ffstar_1d(f, f_star, p, x_min, x_max, disc_gap=0.1, figsize=(2, 2),
                   bbox_to_anchor=(1, 0), loc='lower right', lw=1, frameon=True):
    """Plot two 1D curves under a thin strip showing a density.

    Parameters
    ----------
    f : array
        Values plotted in black with label ``$S(X)$``; the line is broken
        where consecutive values differ by more than ``disc_gap``.
    f_star : array
        Values plotted in red with label ``$Q(X)$``.
    p : array
        Values plotted in the top strip with label ``$\\mathbb{P}_X$``.
    x_min, x_max : float
        Range over which each array is spread evenly.
    disc_gap : float
        Jump size above which the curve of ``f`` is interrupted.
    figsize, bbox_to_anchor, loc, lw, frameon
        Figure size, legend placement, line width and legend frame.

    Returns
    -------
    The created figure.
    """
    set_latex_font()

    fig = plt.figure(figsize=figsize)
    grid = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[1, 10], hspace=0)
    ax_top = plt.subplot(grid[0])
    ax_main = plt.subplot(grid[1])
    plt.subplots_adjust(hspace=0.075)

    # Top strip: filled curve with only the bottom spine left visible.
    x_density = np.linspace(x_min, x_max, len(p))
    ax_top.plot(x_density, p, label='$\\mathbb{P}_X$', lw=0.5)
    ax_top.fill_between(x_density, p, alpha=0.2)
    ax_top.get_yaxis().set_visible(False)
    for side in ('right', 'top', 'left'):
        ax_top.spines[side].set_visible(False)
    ax_top.set_xticklabels([])
    _, density_top = ax_top.get_ylim()
    ax_top.set_ylim((0, density_top))

    # Main panel: reference level 1/2 and the two curves.
    ax_main.axhline(0.5, color='black', lw=0.5)
    x_f, y_f = insert_nan_at_discontinuities(
        np.linspace(x_min, x_max, len(f)), f, min_gap=disc_gap)
    ax_main.plot(x_f, y_f, label='$S(X)$', color='black', lw=lw)
    x_star = np.linspace(x_min, x_max, len(f_star))
    ax_main.plot(x_star, f_star, label='$Q(X)$', color='tab:red', ls='-', lw=lw)
    ax_main.set_xlabel('$X$')
    ax_main.set_yticks([0, 0.5, 1])
    ax_main.set_yticklabels(['0', '$\\frac{1}{2}$', '1'])

    # Single legend gathering the entries of both axes (main panel first).
    top_handles, top_labels = ax_top.get_legend_handles_labels()
    main_handles, main_labels = ax_main.get_legend_handles_labels()
    ax_main.legend(handles=main_handles+top_handles,
                   labels=main_labels+top_labels,
                   bbox_to_anchor=bbox_to_anchor, loc=loc, frameon=frameon)
    ax_main.spines['top'].set_visible(False)

    return fig


def plot_ffstar_2d(f, phi, psi=None, delta=None, delta_max=None, w=None, w_orth=None, mean=None, cov=None,
                   w_learned=None):
    """Plot a 2D function f, a perturbation delta and their sum.

    Everything is evaluated on a 100x100 grid over [-2, 2] x [-2, 2].

    Parameters
    ----------
    f : callable
        Called on an (n, 2) array of grid points; values are drawn with
        contour levels in [0, 1].
    phi : callable
        1D function plotted below f (labelled ``varphi(w^T X)``).
    psi, delta, delta_max : callable, optional
        ``delta`` is called on the (n, 2) grid; ``psi`` and ``delta_max`` are
        1D functions. If any of the three is None only the first figure is
        built.
    w, w_orth : array of 2 values, optional
        Directions drawn as arrows from the origin.
    mean, cov : array, optional
        If both are given, the covariance ellipse is drawn on the first
        figure.
    w_learned : array of 3 values, optional
        Drawn on the fourth figure; ``w_learned[0]`` is used as intercept and
        ``w_learned[1:]`` as coefficients.

    Returns
    -------
    ``fig1`` alone if ``psi``, ``delta`` or ``delta_max`` is None, otherwise
    the tuple ``(fig1, fig2, fig3, fig4)``.
    """
    set_latex_font()

    h = 100
    x_min = -2
    x_max = 2
    y_min = -2
    y_max = 2
    XX0, YY0 = np.meshgrid(np.linspace(x_min, x_max, h),
                           np.linspace(y_min, y_max, h))

    cm = plt.cm.RdBu_r

    # ---- Figure 1: f on the grid (top) and phi (bottom) ----
    fig1 = plt.figure(figsize=(4.5, 6))
    gs = gridspec.GridSpec(2, 1, figure=fig1, height_ratios=[2, 1])
    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])
    Z1 = f(np.c_[XX0.ravel(), YY0.ravel()])
    Z1 = Z1.reshape(XX0.shape)
    crf = ax1.contourf(XX0, YY0, Z1, levels=np.linspace(0, 1, 101), cmap=cm, alpha=.8, vmin=0, vmax=1)
    cbar = fig1.colorbar(crf, ax=ax1)
    cbar.ax.set_title(r'$f(X)$')
    cbar.set_ticks([0, 0.5, 1])
    cbar.ax.set_yticklabels(['0', r'$\frac{1}{2}$', '1'])

    if w is not None:
        ax1.annotate("", xy=(w[0], w[1]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=20))
        ax1.text(w[0]/2, w[1]/2, r'$w$', va='top', ha='left')
        X1 = np.linspace(x_min, x_max, 2)
        X2 = separating_line2D(X1, w, 0)
        ax1.plot(X1, X2, color='black', linestyle=':', label=r'$f(X) = \frac{1}{2}$')

    if mean is not None and cov is not None:
        plot_covariance2D(mean, cov, ax1, n_std=1, edgecolor='black', linestyle='--', label=r'$\Sigma$')

    ax1.set_xlim((x_min, x_max))
    ax1.set_ylim((y_min, y_max))
    ax1.set_xticks([x_min, 0, x_max])
    ax1.set_yticks([y_min, 0, y_max])
    ax1.set_title(r'$f(X) := \varphi(w^TX)$')
    ax1.set_aspect('equal')
    ax1.legend()

    X1_min = -2
    X1_max = 2
    X1 = np.linspace(X1_min, X1_max, 500)
    Y1 = phi(X1)
    ax2.plot(X1, Y1, label=r'$\varphi(w^TX)$')
    ax2.set_xlabel(r'$w^TX$')
    ax2.set_xticks([X1_min, 0, X1_max])
    ax2.set_yticks([0, 0.5, 1])
    ax2.set_yticklabels(['0', '$\\frac{1}{2}$', '1'])
    ax2.axhline(0.5, color='black', lw=0.5)
    ax2.legend()

    # The remaining figures need the full perturbation description.
    if psi is None or delta is None or delta_max is None:
        return fig1

    # ---- Figure 2: delta on the grid, then psi and delta_max ----
    fig2 = plt.figure(figsize=(4.5, 8))
    gs = gridspec.GridSpec(3, 1, figure=fig2, height_ratios=[2.4, 1, 1], hspace=0.3)
    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])
    ax3 = plt.subplot(gs[2])
    Z2 = delta(np.c_[XX0.ravel(), YY0.ravel()])
    Z2 = Z2.reshape(XX0.shape)
    Z2_max = np.max(Z2)
    Z2_min = np.min(Z2)
    # Symmetric color range so that zero sits at the middle of the colormap.
    Z2_lim = max(np.abs(Z2_max), np.abs(Z2_min))
    crf = ax1.contourf(XX0, YY0, Z2, levels=np.linspace(-Z2_lim, Z2_lim, 101), cmap=cm, alpha=.8, vmin=-Z2_lim, vmax=Z2_lim)
    cbar = fig2.colorbar(crf, ax=ax1)
    cbar.ax.set_title(r'$\Delta(X)$')
    cbar.set_ticks([-Z2_lim, 0, Z2_lim])
    cbar.ax.set_yticklabels([f'{-Z2_lim:.2g}', '0', f'{Z2_lim:.2g}'])

    if w is not None:
        ax1.annotate("", xy=(w[0], w[1]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=20))
        ax1.text(w[0]/2, w[1]/2, r'$w$', va='top', ha='left')

    if w_orth is not None:
        ax1.annotate("", xy=(w_orth[0], w_orth[1]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=20))
        ax1.text(w_orth[0]/2, w_orth[1]/2, r'$w_{\perp}$', va='top', ha='right')
        X1 = np.linspace(x_min, x_max, 2)
        X2 = separating_line2D(X1, w_orth, 0)
        ax1.plot(X1, X2, color='black', linestyle=':', label=r'$w_{\perp}^TX = 0$')

    ax1.set_xlim((x_min, x_max))
    ax1.set_ylim((y_min, y_max))
    ax1.set_xticks([x_min, 0, x_max])
    ax1.set_yticks([y_min, 0, y_max])
    ax1.set_title(r'$\Delta(X) := \psi(w_{\perp}^TX)\Delta_{max}(w^TX)$')
    ax1.set_aspect('equal')
    ax1.legend()

    X2_min = -10
    X2_max = 10
    X2 = np.linspace(X2_min, X2_max, 501)
    Y2 = psi(X2)
    ax2.plot(X2, Y2, label=r'$\psi(w_{\perp}^TX)$', color='tab:orange')
    ax2.set_xlabel(r'$w_{\perp}^TX$')
    ax2.set_xticks([X2_min, 0, X2_max])
    ax2.set_yticks([-1, 0, 1])
    ax2.axhline(0, color='black', lw=0.5)
    ax2.legend()

    X2_min = -10
    X2_max = 10
    X2 = np.linspace(X2_min, X2_max, 501)
    Y2 = delta_max(X2)
    ax3.plot(X2, Y2, label=r'$\Delta_{max}(w^TX)$', color='tab:green')
    ax3.set_xlabel(r'$w^TX$')
    ax3.set_xticks([X2_min, 0, X2_max])
    ax3.set_yticks([0, Z2_lim])
    ax3.set_yticklabels(['0', f'{Z2_lim:.2g}'])
    ax3.legend()

    # ---- Figure 3: f + delta on the grid and 1D slices ----
    fig3 = plt.figure(figsize=(4.5, 6))
    gs = gridspec.GridSpec(2, 1, figure=fig3, height_ratios=[2, 1])
    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])
    Z3 = Z1 + Z2
    crf = ax1.contourf(XX0, YY0, Z3, levels=np.linspace(0, 1, 101), cmap=cm, alpha=.8, vmin=0, vmax=1)
    cbar = fig3.colorbar(crf, ax=ax1)
    cbar.ax.set_title(r'$f^{\star}(X)$')
    cbar.set_ticks([0, 0.5, 1])
    cbar.ax.set_yticklabels(['0', r'$\frac{1}{2}$', '1'])

    if w is not None:
        ax1.annotate("", xy=(w[0], w[1]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=20))
        ax1.text(w[0]/2, w[1]/2, r'$w$', va='top', ha='left')
        X1 = np.linspace(x_min, x_max, 2)
        X2 = separating_line2D(X1, w, 0)
        ax1.plot(X1, X2, color='black', linestyle=':', label=r'$f(X) = \frac{1}{2}$')

    if w_orth is not None:
        ax1.annotate("", xy=(w_orth[0], w_orth[1]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=20))
        ax1.text(w_orth[0]/2, w_orth[1]/2, r'$w_{\perp}$', va='bottom', ha='left')

    ax1.set_xlim((x_min, x_max))
    ax1.set_ylim((y_min, y_max))
    ax1.set_xticks([x_min, 0, x_max])
    ax1.set_yticks([y_min, 0, y_max])
    ax1.set_title(r'$f^{\star}(X) := f(X) + \Delta(X)$')
    ax1.set_aspect('equal')
    ax1.legend()

    # One slice per value of w^T X: a dotted level phi(x) and the perturbed
    # curve phi(x) + psi(.) * delta_max(x), colored by phi(x).
    X1_min = -3
    X1_max = 3
    X2_min = -10
    X2_max = 10
    X1 = np.linspace(X1_min, X1_max, 8)
    X2 = np.linspace(X2_min, X2_max, 501)
    for i, x in enumerate(X1):
        y = phi(x)
        color = cm(y)
        ax2.plot([X2_min, X2_max], [y, y], ls=':', lw=0.75, color=color)
        _delta_max = delta_max(x)
        Y2 = y + np.multiply(psi(X2), _delta_max)
        ax2.plot(X2, Y2, color=color)
        if i == 0:
            # Empty black proxies so the legend is not tied to a slice color.
            ax2.plot([], [], color='black', ls=':', label=r'$f(X)$')
            ax2.plot([], [], color='black', label=r'$f^{\star}(X)$')

    ax2.set_xlabel(r'$w_{\perp}^TX$')
    ax2.set_xticks([X2_min, 0, X2_max])
    ax2.set_yticks([0, 0.5, 1])
    ax2.set_yticklabels(['0', '$\\frac{1}{2}$', '1'])
    ax2.axhline(0.5, color='black', lw=0.5)
    ax2.legend()

    # ---- Figure 4: f + delta with the learned direction ----
    # Only the top cell of the 2-row grid is used.
    fig4 = plt.figure(figsize=(4.5, 6))
    gs = gridspec.GridSpec(2, 1, figure=fig4, height_ratios=[2, 1])
    ax1 = plt.subplot(gs[0])
    crf = ax1.contourf(XX0, YY0, Z3, levels=np.linspace(0, 1, 101), cmap=cm, alpha=.8, vmin=0, vmax=1)
    # The colorbar must be attached to the figure owning ax1 (fig4); it was
    # previously created on fig3.
    cbar = fig4.colorbar(crf, ax=ax1)
    cbar.ax.set_title(r'$f^{\star}(X)$')
    cbar.set_ticks([0, 0.5, 1])
    cbar.ax.set_yticklabels(['0', r'$\frac{1}{2}$', '1'])

    if w is not None:
        ax1.annotate("", xy=(w[0], w[1]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=20))
        ax1.text(w[0]/2, w[1]/2, r'$w$', va='top', ha='left')
        X1 = np.linspace(x_min, x_max, 2)
        X2 = separating_line2D(X1, w, 0)
        ax1.plot(X1, X2, color='black', linestyle=':', label=r'$f(X) = \frac{1}{2}$')

    if w_learned is not None:
        label = 'Learned'
        if w_learned[2] != 0:
            X1 = np.linspace(x_min, x_max, 100)
            X2 = separating_line2D(X1, w_learned[1:], w_learned[0])
            ax1.plot(X1, X2, color='tab:orange', linestyle=':', label=label, zorder=11)
        else:
            # The line cannot be written as X2 = g(X1) when w_learned[2] == 0.
            ax1.axvline(0, color='tab:orange', linestyle=':', label=label, zorder=11)
        ax1.annotate("", xy=(w_learned[1], w_learned[2]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='tab:orange', mutation_scale=15), zorder=11)
        ax1.text(w_learned[1]/2, w_learned[2]/2, r'$\beta$', va='bottom', ha='right', color='tab:orange')

    ax1.set_xlim((x_min, x_max))
    ax1.set_ylim((y_min, y_max))
    ax1.set_xticks([x_min, 0, x_max])
    ax1.set_yticks([y_min, 0, y_max])
    ax1.set_title(r'$f^{\star}(X) := f(X) + \Delta(X)$')
    ax1.set_aspect('equal')
    ax1.legend()

    return fig1, fig2, fig3, fig4


def _ffstar2d_new_figure(trim, trim_figsize, stacked_figsize, height_ratios,
                         **gridspec_kwargs):
    """Create a figure with one axes (``trim``) or one axes per stacked row.

    Returns ``(fig, axes)`` where ``axes`` has ``len(height_ratios)`` entries.
    When ``trim`` is true only the first entry is an axes; the others are
    ``None`` so that callers can skip the auxiliary 1D panels.
    """
    n_rows = len(height_ratios)
    if trim:
        fig, ax = plt.subplots(1, 1, figsize=trim_figsize)
        return fig, [ax] + [None]*(n_rows - 1)

    fig = plt.figure(figsize=stacked_figsize)
    gs = gridspec.GridSpec(n_rows, 1, figure=fig, height_ratios=height_ratios,
                           **gridspec_kwargs)
    return fig, [plt.subplot(gs[row]) for row in range(n_rows)]


def _ffstar2d_draw_vector(ax, vec, label, ha, fontsize=None):
    """Draw ``vec`` as an arrow from the origin and label it at its midpoint."""
    ax.annotate("", xy=(vec[0], vec[1]), xytext=(0, 0),
                # No shrink and no patches: the arrow spans exactly
                # from the origin to the tip of the vector.
                arrowprops=dict(arrowstyle="-|>",
                                shrinkA=0,
                                shrinkB=0,
                                patchA=None,
                                patchB=None,
                                color='black',
                                ),
                )
    text_kwargs = {} if fontsize is None else {'fontsize': fontsize}
    ax.text(vec[0]/2, vec[1]/2, label, va='top', ha=ha, **text_kwargs)


def _ffstar2d_draw_boundary(ax, normal, x_lim, label):
    """Plot the thin black line returned by ``separating_line2D`` for ``normal``."""
    xs = np.linspace(x_lim[0], x_lim[1], 2)
    ys = separating_line2D(xs, normal, 0)
    ax.plot(xs, ys, color='black', linestyle='-', lw=0.5, label=label)


def _ffstar2d_frame(ax, x_lim, y_lim, show_y_ticklabels=True):
    """Fix the limits of ``ax`` and keep ticks only at the limits."""
    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)
    ax.set_xticks(list(x_lim))
    ax.set_yticks(list(y_lim))
    if not show_y_ticklabels:
        ax.set_yticklabels(['', ''])


def _ffstar2d_axis_names(ax, with_y=True):
    """Write the ``$X_1$`` (and optionally ``$X_2$``) names next to ``ax``."""
    label_size = plt.rcParams['axes.labelsize']
    d = 0.04
    ax.annotate('$X_1$', xy=(0.5, -d), xytext=(0.5, -d),
                xycoords='axes fraction', ha='center', va='top',
                fontsize=label_size,
                )
    if with_y:
        d = 0.02
        ax.annotate('$X_2$', xy=(-d, 0.5), xytext=(-d, 0.5),
                    xycoords='axes fraction', ha='right', va='center',
                    fontsize=label_size,
                    rotation=90,
                    )


def _ffstar2d_feature_space_tag(ax, fontsize, offset=0.025):
    """Draw the boxed calligraphic X tag in the upper-left corner of ``ax``."""
    x = offset
    y = 1 - offset
    # The blank annotation only provides the white square box; the symbol is
    # drawn on top of it with a larger font and no box of its own.
    ax.annotate('    ', xy=(x, y), xycoords='axes fraction',
                bbox=dict(boxstyle='square', ec='black', fc='white', alpha=1,
                          linewidth=0.7),
                ha='left', va='top', fontsize=fontsize)
    ax.annotate(r'$\mathcal{X}~$', xy=(x, y), xycoords='axes fraction',
                ha='left', va='top', fontsize=fontsize+3)


def _ffstar2d_heatmap(fig, ax, xx, yy, zz, v_min, v_max, cmap, title):
    """Fill ``ax`` with contours of ``zz`` and attach a titled colorbar.

    The colorbar lives in a dedicated axes appended to the right of ``ax``.
    Returns the colorbar so that the caller can set its ticks.
    """
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    filled = ax.contourf(xx, yy, zz, levels=np.linspace(v_min, v_max, 101),
                         cmap=cmap, alpha=.8, vmin=v_min, vmax=v_max)
    cbar = fig.colorbar(filled, cax=cax)
    cbar.ax.set_title(title)
    return cbar


def plot_ffstar_2d_v2(f, phi, psi=None, delta=None, delta_max=None, w=None, w_orth=None, mean=None, cov=None,
                   w_learned=None, trim=False, figsize=(4.5, 6)):
    """Plot a 2D score function, its perturbation and their sum.

    All maps are evaluated on a 100 x 100 grid over [-2, 2] x [-2, 2].

    Parameters
    ----------
    f : callable
        Called with the (n, 2) array of grid points; its result is reshaped
        to the grid and drawn as ``S(X)`` with a [0, 1] color scale.
    phi : callable
        1D link function. Called with a 1D array for the first figure and
        with scalars for the third one.
    psi, delta, delta_max : callable, optional
        ``delta`` is called with the (n, 2) grid points, ``psi`` with a 1D
        array and ``delta_max`` with a 1D array and with scalars. If any of
        the three is None, only the first figure is built.
    w, w_orth : indexable of length >= 2, optional
        Vectors drawn as arrows from the origin. The line returned by
        ``separating_line2D`` is also drawn for ``w`` (figures 1, 3 and 4)
        and for ``w_orth`` (figure 2).
    mean, cov : optional
        Forwarded to ``plot_covariance2D`` on the first figure when both are
        given.
    w_learned : indexable of length 3, optional
        Drawn on the fourth figure only; ``w_learned[1:]`` is drawn as an
        arrow and ``w_learned[0]`` is passed as the last argument of
        ``separating_line2D`` (presumably the intercept - TODO confirm).
    trim : bool
        If True, figures 1 to 3 contain only the 2D panel (no 1D panels
        below) and all use ``figsize``.
    figsize : tuple
        Size of the first figure, and of figures 2 and 3 when ``trim`` is
        True.

    Returns
    -------
    fig1 : Figure
        Returned alone when ``psi``, ``delta`` or ``delta_max`` is None.
    (fig1, fig2, fig3, fig4) : tuple of Figure
        Returned otherwise.
    """
    set_latex_font()

    fontsize_xcal = 9
    fontsize = 12
    h = 100
    x_min, x_max = -2, 2
    y_min, y_max = -2, 2
    x_lim = (x_min, x_max)
    y_lim = (y_min, y_max)
    XX0, YY0 = np.meshgrid(np.linspace(x_min, x_max, h),
                           np.linspace(y_min, y_max, h))
    grid_points = np.c_[XX0.ravel(), YY0.ravel()]

    cm = plt.cm.RdBu_r
    half_ticklabels = ['0', r'$\frac{1}{2}$', '1']

    # ---- Figure 1: S(X) = phi(w^T X) ----
    fig1, (ax1, ax2) = _ffstar2d_new_figure(trim, figsize, figsize, [2, 1])

    if w is not None:
        _ffstar2d_draw_vector(ax1, w, r'$\omega$', ha='left')
        _ffstar2d_draw_boundary(ax1, w, x_lim, label=r'$S = \frac{1}{2}$')

    if mean is not None and cov is not None:
        plot_covariance2D(mean, cov, ax1, n_std=0.7, edgecolor='black', lw=0.5, linestyle=':', label=r'$\Sigma$', zorder=3)

    _ffstar2d_frame(ax1, x_lim, y_lim)
    ax1.set_title(r'$S(X) = \varphi(\omega^TX)$')
    ax1.set_aspect('equal')
    ax1.legend(bbox_to_anchor=(0, 0), loc='lower left', ncol=2)
    _ffstar2d_axis_names(ax1)

    Z1 = f(grid_points)
    Z1 = Z1.reshape(XX0.shape)
    cbar = _ffstar2d_heatmap(fig1, ax1, XX0, YY0, Z1, 0, 1, cm, r'$S(X)$')
    cbar.ax.set_yticks([0, 0.5, 1])
    cbar.ax.set_yticklabels(half_ticklabels)

    _ffstar2d_feature_space_tag(ax1, fontsize_xcal)

    if ax2 is not None:
        X1_min, X1_max = -2, 2
        X1 = np.linspace(X1_min, X1_max, 500)
        ax2.plot(X1, phi(X1), label=r'$\varphi(w^TX)$')
        ax2.set_xlabel(r'$w^TX$')
        ax2.set_xticks([X1_min, 0, X1_max])
        ax2.set_yticks([0, 0.5, 1])
        ax2.set_yticklabels(half_ticklabels)
        ax2.axhline(0.5, color='black', lw=0.5)
        ax2.legend()

    if psi is None or delta is None or delta_max is None:
        return fig1

    # ---- Figure 2: Delta(X) ----
    fig2, (ax1, ax2, ax3) = _ffstar2d_new_figure(
        trim, figsize, (4.5, 8), [2.4, 1, 1], hspace=0.3)

    Z2 = delta(grid_points)
    Z2 = Z2.reshape(XX0.shape)
    # Symmetric color scale so that zero maps to the middle of the colormap.
    Z2_lim = max(np.abs(np.max(Z2)), np.abs(np.min(Z2)))
    cbar = _ffstar2d_heatmap(fig2, ax1, XX0, YY0, Z2, -Z2_lim, Z2_lim, cm,
                             r'$\Delta(X)$')
    cbar.set_ticks([-Z2_lim, 0, Z2_lim])
    cbar.ax.set_yticklabels([f'{-Z2_lim:.2g}', '0', f'{Z2_lim:.2g}'])

    if w is not None:
        _ffstar2d_draw_vector(ax1, w, r'$\omega$', ha='left', fontsize=fontsize)

    if w_orth is not None:
        _ffstar2d_draw_vector(ax1, w_orth, r'$\omega_{\perp}$', ha='right',
                              fontsize=fontsize)
        _ffstar2d_draw_boundary(ax1, w_orth, x_lim,
                                label=r'$\omega_{\perp}^TX = 0$')

    _ffstar2d_frame(ax1, x_lim, y_lim)
    _ffstar2d_axis_names(ax1)
    ax1.set_title(r'$\Delta(X) = \psi(\omega_{\perp}^TX)\Delta_{max}(X)$')
    ax1.set_aspect('equal')
    ax1.legend(bbox_to_anchor=(0, 0), loc='lower left')
    _ffstar2d_feature_space_tag(ax1, fontsize_xcal)

    if ax2 is not None:
        X2_min, X2_max = -10, 10
        X2 = np.linspace(X2_min, X2_max, 501)
        ax2.plot(X2, psi(X2), label=r'$\psi(w_{\perp}^TX)$', color='tab:orange')
        ax2.set_xlabel(r'$w_{\perp}^TX$')
        ax2.set_xticks([X2_min, 0, X2_max])
        ax2.set_yticks([-1, 0, 1])
        ax2.axhline(0, color='black', lw=0.5)
        ax2.legend()

    if ax3 is not None:
        X2_min, X2_max = -10, 10
        X2 = np.linspace(X2_min, X2_max, 501)
        ax3.plot(X2, delta_max(X2), label=r'$\Delta_{max}(w^TX)$', color='tab:green')
        ax3.set_xlabel(r'$w^TX$')
        ax3.set_xticks([X2_min, 0, X2_max])
        ax3.set_yticks([0, Z2_lim])
        ax3.set_yticklabels(['0', f'{Z2_lim:.2g}'])
        ax3.legend()

    # ---- Figure 3: Q(X) = S(X) + Delta(X) ----
    fig3, (ax1, ax2) = _ffstar2d_new_figure(trim, figsize, (4.5, 6), [2, 1])

    Z3 = Z1 + Z2
    cbar = _ffstar2d_heatmap(fig3, ax1, XX0, YY0, Z3, 0, 1, cm, r'$Q(X)$')
    cbar.set_ticks([0, 0.5, 1])
    cbar.ax.set_yticklabels(half_ticklabels)

    if w is not None:
        _ffstar2d_draw_vector(ax1, w, r'$\omega$', ha='left', fontsize=fontsize)
        _ffstar2d_draw_boundary(ax1, w, x_lim, label=r'$S(X) = \frac{1}{2}$')

    if w_orth is not None:
        _ffstar2d_draw_vector(ax1, w_orth, r'$\omega_{\perp}$', ha='right',
                              fontsize=fontsize)

    # This panel hides the y tick labels and the X_2 axis name.
    _ffstar2d_frame(ax1, x_lim, y_lim, show_y_ticklabels=False)
    _ffstar2d_axis_names(ax1, with_y=False)
    ax1.set_title(r'$Q(X) = S(X) + \Delta(X)$')
    ax1.set_aspect('equal')
    ax1.legend(bbox_to_anchor=(0, 0), loc='lower left')
    _ffstar2d_feature_space_tag(ax1, fontsize_xcal)

    if ax2 is not None:
        X2_min, X2_max = -10, 10
        X2 = np.linspace(X2_min, X2_max, 501)
        # One pair of curves per value of the first coordinate: the dotted
        # horizontal line is phi at that value, the solid curve adds the
        # perturbation psi * delta_max to it.
        for level in np.linspace(-3, 3, 8):
            score = phi(level)
            color = cm(score)
            ax2.plot([X2_min, X2_max], [score, score], ls=':', lw=0.75, color=color)
            perturbed = score + np.multiply(psi(X2), delta_max(level))
            ax2.plot(X2, perturbed, color=color)

        # Empty black proxies so that the legend shows line styles only.
        ax2.plot([], [], color='black', ls=':', label=r'$f(X)$')
        ax2.plot([], [], color='black', label=r'$f^{\star}(X)$')

        ax2.set_xlabel(r'$w_{\perp}^TX$')
        ax2.set_xticks([X2_min, 0, X2_max])
        ax2.set_yticks([0, 0.5, 1])
        ax2.set_yticklabels(half_ticklabels)
        ax2.axhline(0.5, color='black', lw=0.5)
        ax2.legend()

    # ---- Figure 4: f*(X) with the learned direction ----
    # Only the top cell of the 2-row grid is used, whatever ``trim`` is.
    fig4 = plt.figure(figsize=(4.5, 6))
    gs = gridspec.GridSpec(2, 1, figure=fig4, height_ratios=[2, 1])
    ax1 = plt.subplot(gs[0])
    crf = ax1.contourf(XX0, YY0, Z3, levels=np.linspace(0, 1, 101), cmap=cm, alpha=.8, vmin=0, vmax=1)
    # The colorbar must be requested from the figure that owns ``ax1``.
    cbar = fig4.colorbar(crf, ax=ax1)
    cbar.ax.set_title(r'$f^{\star}(X)$')
    cbar.set_ticks([0, 0.5, 1])
    cbar.ax.set_yticklabels(half_ticklabels)

    if w is not None:
        ax1.annotate("", xy=(w[0], w[1]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='black', mutation_scale=20))
        ax1.text(w[0]/2, w[1]/2, r'$w$', va='top', ha='left')
        X1 = np.linspace(x_min, x_max, 2)
        X2 = separating_line2D(X1, w, 0)
        ax1.plot(X1, X2, color='black', linestyle=':', label=r'$f(X) = \frac{1}{2}$')

    if w_learned is not None:
        label = 'Learned'
        if w_learned[2] != 0:
            X1 = np.linspace(x_min, x_max, 100)
            X2 = separating_line2D(X1, w_learned[1:], w_learned[0])
            ax1.plot(X1, X2, color='tab:orange', linestyle=':', label=label, zorder=11)
        else:
            # Zero second coefficient: a vertical line at 0 is drawn instead
            # of calling ``separating_line2D``.
            ax1.axvline(0, color='tab:orange', linestyle=':', label=label, zorder=11)
        ax1.annotate("", xy=(w_learned[1], w_learned[2]), xytext=(0, 0), arrowprops=dict(arrowstyle='-|>', color='tab:orange', mutation_scale=15), zorder=11)
        ax1.text(w_learned[1]/2, w_learned[2]/2, r'$\beta$', va='bottom', ha='right', color='tab:orange')

    ax1.set_xlim((x_min, x_max))
    ax1.set_ylim((y_min, y_max))
    ax1.set_xticks([x_min, 0, x_max])
    ax1.set_yticks([y_min, 0, y_max])
    ax1.set_title(r'$f^{\star}(X) := f(X) + \Delta(X)$')
    ax1.set_aspect('equal')
    ax1.legend()

    return fig1, fig2, fig3, fig4


def plot_example_metrics(df, x='mse', y='brier'):
    """Scatter column ``y`` of ``df`` against column ``x`` on log-log axes.

    Points are colored by the ``kind`` column and styled by the ``dist``
    column. Every row is annotated with its index label, written vertically
    in small gray text slightly offset from the point.
    """
    set_latex_font()

    plt.figure()
    sns.scatterplot(data=df, x=x, y=y, hue='kind', style='dist')

    ax = plt.gca()
    ax.set_yscale('log')
    ax.set_xscale('log')

    # Shared styling of the per-point name annotations.
    text_style = dict(ha='center', va='bottom', fontsize='xx-small',
                      rotation='vertical', color='gray', zorder=-1)
    for name, row in df.iterrows():
        x_val = row[x]
        y_val = row[y]
        ax.annotate(name, xy=(x_val, y_val), xytext=(x_val*1.01, y_val*1.01),
                    **text_style)


def _love_plot(df, aspect=1):
    """Draw a love plot from a table indexed by ``kind``.

    ``df`` is reset and melted so that every non-``kind`` column becomes a
    ``feature`` with its ``value``; values are drawn per feature, colored by
    ``kind``, on a symmetric symlog x-axis. Returns the current axes.
    """
    long_df = pd.melt(df.reset_index(), id_vars=['kind'], var_name='feature')

    grid = sns.catplot(data=long_df, x='value', y='feature', hue='kind',
                       aspect=aspect, jitter=False, height=3, legend_out=True,
                       palette=['#F7756C', '#0ABDC3'])

    ax = plt.gca()
    ax.set_xlabel('$E[X_i|Y=1] - E[X_i|Y=0]$')
    ax.set_ylabel('Feature')

    # Symmetric limits with a 10% margin around the largest absolute value.
    half_width = long_df['value'].abs().max()*1.1
    ax.set_xlim((-half_width, half_width))
    ax.set_xscale('symlog', linthresh=1e-1)
    ax.axvline(0, lw=0.5, color='gray')

    # Retitle the legend and capitalize its entries.
    legend = grid._legend
    legend.set_title('Sample')
    for entry in grid._legend.texts:
        entry.set_text(entry.get_text().capitalize())

    return ax


def love_plot(X, y, p, plot=True):
    """Compare per-feature mean differences between classes, raw and reweighted.

    Parameters
    ----------
        X : (n, d) array
            Features; a 1D array is treated as a single column.
        y : (n,) array
            Labels; rows are grouped by label and the first group mean is
            subtracted from the second one.
        p : (n,) array
            Used to weight the rows by 1/p, or by 1/(1 - p) where y == 0.
        plot : bool
            Whether to draw the table with ``_love_plot``.

    Returns
    -------
        DataFrame indexed by ``kind`` with one row ``unadjusted`` and one
        row ``adjusted``, and one column ``X1`` ... ``Xd`` per feature.
        The table is also printed.

    """
    set_latex_font()

    # Row weights: 1/p everywhere, overwritten by 1/(1 - p) where y == 0.
    inverse_weights = np.divide(1, p)
    inverse_weights[y == 0] = np.divide(1, 1 - p[y == 0])

    def class_mean_gap(features, kind, row_weights=None):
        # One-row frame of (mean in second group) - (mean in first group).
        if features.ndim == 1:
            features = features.reshape(-1, 1)
        if row_weights is not None:
            features = np.multiply(features, row_weights[:, None])
        table = pd.DataFrame(features)
        table.columns = [f'X{int(col)+1}' for col in table.columns]
        table['y'] = y
        group_means = table.groupby('y').agg('mean')
        gap = group_means.iloc[1] - group_means.iloc[0]
        gap['kind'] = kind
        return gap.to_frame().T.set_index('kind')

    diff = pd.concat(
        [
            class_mean_gap(X, 'unadjusted'),
            class_mean_gap(X, 'adjusted', inverse_weights),
        ],
        axis=0,
    )
    print(diff)

    if plot:
        _love_plot(diff)

    return diff


def heterogeneity_plot_wx(X, y_fraction, P, y_true_probas=None, y_scores=None,
                          n_var_max=None):
    """Plot local fractions of positives along projected directions.

    For every pair ``(i, j)`` the curve ``y_fraction[i, j, :]`` is drawn
    against ``X[i, j, :, :]`` projected on column ``i+1`` of ``P``, together
    with a dotted horizontal line at the curve's mean. When given,
    ``y_true_probas[i, j, :]`` (solid black) and ``y_scores[i, j, :]``
    (dotted black) are overlaid.

    If ``n_var_max`` is given, only the ``n_var_max`` indices ``i`` with the
    largest variance of ``y_fraction`` along its last axis are kept for
    each ``j``. Returns the figure.
    """
    set_latex_font()

    fig = plt.figure()
    ax = fig.gca()

    # Rank the first-axis indices by decreasing variance, separately per j.
    variance_rank = np.argsort(-np.var(y_fraction, axis=2), axis=0)

    n_dims, n_groups = X.shape[0], X.shape[1]

    for j in range(n_groups):
        if n_var_max is None:
            selected = range(n_dims)
        else:
            selected = variance_rank[:n_var_max, j]

        # Colors restart at C0 for every j and follow the plotting order.
        for position, i in enumerate(selected):
            color = f'C{position}'
            # Only the first group feeds the legend to avoid duplicates.
            label = f'$\\omega^{{({i+1})}}_{{\\perp}}$' if j == 0 else None
            projection = np.dot(X[i, j, :, :], P[:, i+1])
            fractions = y_fraction[i, j, :]

            ax.plot(projection, fractions, label=label, color=color, zorder=1)
            ax.axhline(np.mean(fractions), ls=':', color=color, zorder=0)
            if y_true_probas is not None:
                ax.plot(projection, y_true_probas[i, j, :], color='black', zorder=0)
            if y_scores is not None:
                ax.plot(projection, y_scores[i, j, :], ls=':', color='black', zorder=0)

    # Show the legend only when it stays small.
    few_dims = n_dims < 10
    few_selected = n_var_max is not None and n_var_max <= 10
    if few_dims or few_selected:
        ax.legend(ncol=5, loc='upper center', bbox_to_anchor=[0.5, 1.2])

    ax.set_xlabel(r'${\omega_{\perp}^{(i)}}^Tx$')
    ax.set_ylabel('Local fraction of positives')
    ax.locator_params(axis='x', nbins=3)

    return fig


def plot_2D_softmax(W, B):
    """Plot the decision regions of a 2D softmax classifier.

    The predicted label ``argmax softmax(W x + B)`` is shown as filled
    regions over the square [-100, 100] x [-100, 100]. For every class, the
    weight vector is drawn as an arrow starting at ``-b * w / ||w||^2`` and
    ``plot_orthogonal_line`` is called with ``(w, b)`` in the class color;
    for every pair of classes it is called with ``(w_i - w_j, b_i - b_j)``
    in black.

    Parameters
    ----------
    W : (k, 2) array
    B : (k,) array

    Returns
    -------
    fig : Figure

    Raises
    ------
    ValueError
        If ``W`` is not of shape (k, 2) or ``B`` is not of shape (k,).
    """
    # Cast to float64 so that the weights match the dtype of the grid below.
    W = np.asarray(W, dtype=float)
    B = np.asarray(B, dtype=float)
    if W.ndim != 2 or W.shape[1] != 2:
        raise ValueError(f'W must be of shape (k, 2). Given {W.shape}.')
    K = W.shape[0]
    if B.shape != (K,):
        raise ValueError(f'B must be of shape ({K},). Given {B.shape}.')

    fig = plt.figure()
    ax = plt.gca()

    xmin = -100
    xmax = 100

    XX1, XX2 = np.meshgrid(np.linspace(xmin, xmax, 300),
                           np.linspace(xmin, xmax, 300))
    XX = np.c_[XX1.ravel(), XX2.ravel()]

    # One arrow and one line per class.
    line_colors = itertools.cycle(sns.color_palette('muted', K))
    for w, b in zip(W, B):
        color = next(line_colors)
        xy_start = -b*w/np.square(np.linalg.norm(w))
        ax.arrow(xy_start[0], xy_start[1], dx=w[0], dy=w[1], color=color,
                 head_width=0.03, length_includes_head=True, zorder=1)
        plot_orthogonal_line(ax, w, b, xmin=xmin, xmax=xmax, ls='-', alpha=1, lw=1, color=color, zorder=1)

    # One black line per pair of classes.
    for i, j in itertools.combinations(range(K), 2):
        plot_orthogonal_line(ax, W[i, :] - W[j, :], B[i] - B[j], xmin=xmin, xmax=xmax, ls='-', alpha=1, lw=1, color='black', zorder=1)

    # Functional form: no module to instantiate, hence no random
    # initialization consuming the global torch RNG.
    y_logits = torch.nn.functional.linear(
        torch.from_numpy(XX), torch.from_numpy(W), torch.from_numpy(B))
    y_scores = torch.nn.functional.softmax(y_logits, dim=1)
    y_labels = torch.argmax(y_scores, dim=1)
    n_labels = len(torch.unique(y_labels))

    region_colors = list(sns.color_palette('pastel', K).as_hex())
    cmap = mpl.colors.ListedColormap(region_colors)

    ZZ = y_labels.numpy().reshape(XX1.shape)
    # Levels at half-integers so that each integer label gets its own band.
    crf = ax.contourf(XX1, XX2, ZZ, cmap=cmap, alpha=0.8, levels=np.arange(K+1)-0.5, zorder=0)
    cbar = fig.colorbar(crf, ax=ax)
    cbar.ax.set_title('Label')
    # The colorbar is cut to the number of distinct predicted labels.
    # NOTE(review): this assumes the predicted labels are 0..n_labels-1.
    cbar.ax.set_yticks(np.arange(n_labels))
    cbar.ax.set_ylim((-0.5, n_labels-0.5))

    ax.set_xlim((xmin, xmax))
    ax.set_ylim((xmin, xmax))

    ax.locator_params(axis='x', nbins=3)
    ax.locator_params(axis='y', nbins=3)

    ax.set_aspect('equal')

    return fig


def plot_K_softmax():
    """Plot the first softmax output over a 2D slice of 3D inputs.

    The first two inputs span a 300 x 300 grid over [-10, 10] x [-10, 10]
    and the third one is fixed to -1. The first component of the softmax is
    drawn as filled contours with equal axes aspect. The shapes of the
    input and output arrays are printed. Returns the figure.
    """
    xmin, xmax = -10, 10

    grid_x, grid_y = np.meshgrid(np.linspace(xmin, xmax, 300),
                                 np.linspace(xmin, xmax, 300))
    # Third input is constant over the whole grid.
    inputs = np.column_stack([
        grid_x.ravel(),
        grid_y.ravel(),
        np.full(grid_x.size, -1.0),
    ])
    print(inputs.shape)

    scores = torch.nn.functional.softmax(torch.from_numpy(inputs), dim=1)
    first_score = scores.numpy()[:, 0]
    print(first_score.shape)

    fig = plt.figure()
    ax = plt.gca()
    filled = ax.contourf(grid_x, grid_y, first_score.reshape(grid_x.shape),
                         levels=9)
    ax.clabel(filled, inline=True, fontsize=10)
    ax.set_aspect('equal')

    return fig


def plot_frac_pos_vs_scores(frac_pos, counts, mean_scores, y_scores=None,
                            y_labels=None, bins=None,
                            legend_loc='best', bbox_to_anchor=None, ncol=3,
                            xlim_margin=0.15, ylim_margin=0.15, title=None,
                            min_cluster_size=1, hist=False,
                            k_largest_variance=None,
                            k_largest_miscalibration=None,
                            ci=None, mean_only=False,
                            ax=None,
                            mean_label=None,
                            color_cycler=None,
                            xlabel='Confidence score',
                            ylabel='Fraction of positives',
                            plot_cluster_id=False,
                            legend_cluster_sizes=True,
                            legend_sizes_only=False,
                            vary_cluster_size=True,
                            capsize=2,
                            cluster_size=None,
                            absolute_size_scale=None,
                            plot_cal_hist=False,
                            figsize=None,
                            legend_n_sizes=None,
                            # legend_size=8,
                            legend_min_max=True,
                            plot_first_last_bins=True,
                            grid_space=0.2,
                            legend_title='Cluster sizes',
                            ):
    """Plot fraction of positives in clusters versus the mean scores assigned
    to the clusters, as well as the calibration curve.
    """
    set_latex_font()

    frac_pos = np.array(frac_pos)
    counts = np.array(counts)
    mean_scores = np.array(mean_scores)

    if k_largest_variance is not None and k_largest_miscalibration is not None:
        raise ValueError('Both k_largest_variance and k_largest_miscalibration'
                         ' should not be passed.')

    if k_largest_variance is not None:
        pass  # select the k classes that has the greatest variance
    # code in utils ?

    if k_largest_miscalibration is not None:
        pass  # select the k classes that has the greatest miscalibration

    if frac_pos.ndim >= 3:
        frac_pos = frac_pos.reshape(frac_pos.shape[0], -1)
    if counts.ndim >= 3:
        counts = counts.reshape(counts.shape[0], -1)
    if mean_scores.ndim >= 3:
        mean_scores = mean_scores.reshape(mean_scores.shape[0], -1)

    if frac_pos.shape != counts.shape:
        raise ValueError(f'Shape mismatch between frac_pos {frac_pos.shape} and counts {counts.shape}')

    if frac_pos.shape != mean_scores.shape:
        raise ValueError(f'Shape mismatch between frac_pos {frac_pos.shape} and mean_scores {mean_scores.shape}')

    available_ci = [None, 'clopper', 'binomtest']
    if ci not in available_ci:
        raise ValueError(f'Unkown CI {ci}. Availables: {available_ci}.')

    if hist and ax is not None:
        raise ValueError("Can't specify ax when hist=True.")

    handles = []
    labels = []

    if legend_cluster_sizes and not legend_sizes_only:
        dummy = mpl.patches.Rectangle((0, 0), 1, 1, fill=False, edgecolor='none',
                                    visible=False)
        handles.append(dummy)
        labels.append('')

    n_bins, n_clusters = frac_pos.shape
    # significant_edgecolor = 'black'
    significant_edgecolor = 'crimson'
    # significant_errcolor = 'black'
    # significant_errcolor = 'lightcoral'
    significant_errcolor = 'crimson'
    alpha_nonsignificant = 0.6

    if bins is None:
        bins = np.linspace(0, 1, n_bins + 1)

    prob_bins_na, mean_bins_na = calibration_curve(frac_pos, counts,
                                                   mean_scores,
                                                   remove_empty=False)

    # Remove empty bins
    non_empty = np.sum(counts, axis=1, dtype=float) > 0
    # Remove bins too small:
    big_enough = np.any(counts >= min_cluster_size, axis=1)
    prob_bins = prob_bins_na[non_empty & big_enough]
    mean_bins = mean_bins_na[non_empty & big_enough]

    if min_cluster_size is not None:
        idx_valid_clusters = counts.flatten() >= min_cluster_size
    else:
        idx_valid_clusters = counts.flatten() == counts.flatten()

    cluster_id = np.tile(np.arange(n_clusters), (n_bins, 1))

    df = pd.DataFrame({
        'y_scores': mean_scores.flatten(),
        'y_frac_pos': frac_pos.flatten(),
        'y_size': counts.flatten(),
        'y_prob_bins': np.tile(prob_bins_na, (frac_pos.shape[1], 1)).T.flatten(),
        'y_valid_clusters': idx_valid_clusters,
        'cluster_id': cluster_id.flatten(),
        })

    if hist:
        extra_kwargs = {}
        if figsize is not None:
            a, b = figsize
            if a != b:
                raise ValueError(f'Jointplot will be squared. Given {figsize}.')
            extra_kwargs['height'] = a

        g = sns.JointGrid(
            data=df,
            x='y_scores',
            y='y_frac_pos',
            hue='y_size',
            palette='flare',
            ratio=10,
            space=grid_space,
            **extra_kwargs
            )
        ax_top = g.figure.axes[1]
        ax = g.figure.axes[0]
        fig = g.fig
        ax_right = g.figure.axes[2]

    elif ax is None:
        fig = plt.figure(figsize=figsize)
        ax_top = plt.gca()
        ax = ax_top

    else:
        fig = ax.figure
        ax_top = ax

    if mean_only:
        if color_cycler is not None:
            ax.set_prop_cycle(color_cycler)
        cal_color = next(ax._get_lines.prop_cycler)['color']

    else:
        cal_color = 'black'

    # Significance
    if ci == 'clopper':
        if not mean_only:
            ci_idx = df['y_valid_clusters']
            ci_count = df['y_frac_pos'][ci_idx]*df['y_size'][ci_idx]
            ci_nobs = df['y_size'][ci_idx]
            ci_low, ci_upp = proportion_confint(
                count=ci_count,
                nobs=ci_nobs,
                alpha=0.05,
                method='beta',
                )
            ci_scores = df['y_scores'][ci_idx]
            ci_frac_pos = df['y_frac_pos'][ci_idx]
            ci_prob_bins = df['y_prob_bins'][ci_idx]

            # Significant cluster are the ones whose CI does not contain
            # the fraction of positive of the bin
            idx_significant = np.logical_or(ci_low > ci_prob_bins, ci_upp < ci_prob_bins)

            # ci_low and ci_upp are not error widths but CI bounds values
            # the errorbar function requires error widths.
            y_err_low = ci_frac_pos - ci_low
            y_err_upp = ci_upp - ci_frac_pos
            y_err = np.stack([y_err_low, y_err_upp], axis=0)

            p1 = ax.errorbar(
                x=ci_scores[~idx_significant],
                y=ci_frac_pos[~idx_significant],
                yerr=y_err[:, ~idx_significant],
                fmt='none',
                elinewidth=0.5,
                capsize=capsize,
                color='lightgray',
                alpha=alpha_nonsignificant,
                zorder=4,
            )
            p2 = ax.errorbar(
                x=ci_scores[idx_significant],
                y=ci_frac_pos[idx_significant],
                yerr=y_err[:, idx_significant],
                fmt='none',
                elinewidth=0.5,
                capsize=capsize,
                color=significant_errcolor,
                # alpha=0.6,
                zorder=4,
            )
            if not legend_sizes_only:
                if np.any(idx_significant):
                    handles.append(p2)
                else:
                    handles.append(p1)
                labels.append(r'$95\%$ conf. interval')

        else:
            count_bins = np.sum(counts.reshape(n_bins, -1), axis=1)[non_empty]
            ci_low, ci_upp = proportion_confint(
                count=prob_bins*count_bins,
                nobs=count_bins,
                alpha=0.05,
                method='beta',
                )
            ci_frac_pos = prob_bins
            y_err_low = ci_frac_pos - ci_low
            y_err_upp = ci_upp - ci_frac_pos
            y_err = np.stack([y_err_low, y_err_upp], axis=0)

            # ax.errorbar(
            #     x=mean_bins,
            #     y=prob_bins,
            #     yerr=y_err,
            #     fmt='none',
            #     elinewidth=0.5,
            #     capsize=2,
            #     color='black',
            #     # alpha=alpha_nonsignificant,
            # )

            ax.plot(mean_bins, ci_upp, color=cal_color, lw=0.5, zorder=1, alpha=0.5)
            ax.plot(mean_bins, ci_low, color=cal_color, lw=0.5, zorder=1, alpha=0.5)
            ax.fill_between(mean_bins, ci_low, ci_upp, color=cal_color,
                            zorder=1, alpha=0.05)

    elif ci == 'binomtest':
        pvalues = []

        ci_idx = df['y_valid_clusters']
        ci_scores = df['y_scores'][ci_idx]
        ci_frac_pos = df['y_frac_pos'][ci_idx]
        ci_prob_bins = df['y_prob_bins'][ci_idx]
        ci_count = df['y_frac_pos'][ci_idx]*df['y_size'][ci_idx]
        ci_nobs = df['y_size'][ci_idx]

        for i in range(len(ci_scores)):
            res = scipy.stats.binomtest(
                k=int(ci_count.iloc[i]),
                n=int(ci_nobs.iloc[i]),
                p=ci_prob_bins.iloc[i],
                )

            pvalues.append(res.pvalue)

        pvalues = pd.Series(pvalues, index=ci_scores.index)

        idx_significant = pvalues < 0.05

    else:
        idx_significant = None

    def scatter_with_size(x, y, hue, valid_clusters, significant_clusters, cluster_id):
        x = x[valid_clusters]
        y = y[valid_clusters]
        hue = hue[valid_clusters]

        min_scatter_size = 25
        max_scatter_size = 80

        cmap = sns.color_palette('flare', as_cmap=True)
        # print(cmap.colors)
        # cmap.colors = [(0.1*np.array(c)).tolist() for c in cmap.colors]
        # cmap = mpl.colors.ListedColormap([np.clip(1.1*np.array(c), 0, 1).tolist() for c in cmap.colors])
        # cmap = mpl.colors.ListedColormap([desaturate(c, 0.75) for c in cmap.colors])

        min_cluster_size = np.min(hue)
        max_cluster_size = np.max(hue)

        if absolute_size_scale is not None:
            vmin, vmax = absolute_size_scale
            if vmin is None:
                vmin = min_cluster_size
            if vmax is None:
                vmax = max_cluster_size
        else:
            vmin = min_cluster_size
            vmax = max_cluster_size

        if vary_cluster_size:
            size = hue
            sizes = (min_scatter_size, max_scatter_size)

            size_norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        else:
            size = None
            sizes = None
            size_norm = None

        hue_norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

        extra_kwargs = {}
        if cluster_size is not None:
            extra_kwargs['s'] = cluster_size

        g = sns.scatterplot(
            x=x,
            y=y,
            size=size,
            hue=hue,
            hue_norm=hue_norm,
            sizes=sizes,
            size_norm=size_norm,
            palette=cmap,
            zorder=5,
            # edgecolor='lightgray',
            alpha=alpha_nonsignificant if significant_clusters is not None else 1,
            legend="auto" if legend_cluster_sizes else False,
            **extra_kwargs
        )
        if significant_clusters is not None:
            g = sns.scatterplot(
                x=x[significant_clusters],
                y=y[significant_clusters],
                size=size[significant_clusters] if size is not None else None,
                hue=hue[significant_clusters],
                hue_norm=hue_norm,
                sizes=sizes,
                size_norm=size_norm,
                palette=sns.color_palette('flare', as_cmap=True),
                zorder=5,
                edgecolor=significant_edgecolor,
                linewidth=0.75,
                legend=False,
                **extra_kwargs
            )

        if legend_cluster_sizes:
            # Add minimum and maximum cluster sizes in legend handles and labels
            H, L = g.get_legend_handles_labels()
            cmap = sns.color_palette("flare", as_cmap=True)

            _handles = handles[:]
            _labels = labels[:]
            handles.clear()
            labels.clear()

            min_str = str(int(np.min(hue)))
            max_str = str(int(np.max(hue)))

            if min_str != L[0]:
                handles.append(copy(H[0]))
                if vary_cluster_size:
                    handles[0].set_sizes([min_scatter_size])
                handles[0].set_facecolors([cmap(hue_norm(min_cluster_size))])
                handles[0].set_edgecolors([cmap(hue_norm(min_cluster_size))])

                s = ' (min)' if legend_min_max else ''
                labels.append(f'{min_str}{s}')

            if legend_n_sizes is None or legend_n_sizes+2 >= len(H):
                choices = np.arange(len(H))
            else:
                choices = np.linspace(0, len(H), legend_n_sizes+2, dtype=int)
                choices = choices[1:-1]

            for i in choices:
                handles.append(H[i])
                labels.append(L[i])

            if max_str != L[-1]:
                handles.append(copy(H[-1]))
                if vary_cluster_size:
                    handles[-1].set_sizes([max_scatter_size])
                handles[-1].set_facecolors([cmap(hue_norm(max_cluster_size))])
                handles[-1].set_edgecolors([cmap(hue_norm(max_cluster_size))])

                s = ' (max)' if legend_min_max else ''
                labels.append(f'{max_str}{s}')

            handles.extend(_handles)
            labels.extend(_labels)


        if plot_cluster_id:
            cluster_id = cluster_id[valid_clusters]
            for i in range(len(x)):
                ax.annotate(cluster_id.iloc[i], (x.iloc[i], y.iloc[i]), zorder=6, ha='center', va='center', color='white', fontsize='xx-small')

        return g

    def histplot_with_size(x, vertical, hue):
        color_hue = np.ones_like(x)
        if vertical:
            x, y = None, x
        else:
            x, y = x, None
        sns.histplot(
            x=x,
            y=y,
            weights=hue,
            hue=color_hue,
            palette='flare',
            bins=list(bins),
            legend=False,
        )

    if hist:
        if not mean_only:
            g.plot_joint(scatter_with_size, valid_clusters=df['y_valid_clusters'],
                        significant_clusters=idx_significant,
                        cluster_id=df['cluster_id'])
        g.plot_marginals(histplot_with_size)

    elif not mean_only:
        g = scatter_with_size(df['y_scores'], df['y_frac_pos'], df['y_size'],
                              valid_clusters=df['y_valid_clusters'],
                              significant_clusters=idx_significant,
                              cluster_id=df['cluster_id'])

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)


    bins = np.linspace(0, 1, n_bins+1)
    if not plot_first_last_bins:
        bins = bins[1:-1]

    for x in bins:
        ax.axvline(x, lw=0.5, ls='--', color='grey', zorder=-1)

    # if not mean_only:
    #     j = 0
    #     ypos = -ylim_margin/2
    #     for i in range(n_bins):
    #         if j >= len(prob_bins):
    #             break
    #         p = prob_bins[j]
    #         m = mean_bins[j]
    #         if i/n_bins <= m <= (i+1)/n_bins:
    #             ax.annotate(f'{int(100*p)}', xy=((i+.5)/n_bins, ypos), ha='center', va='center', color='grey', size='x-small')
    #             j += 1

    # Plot calibration curve
    p0, = ax.plot([0, 1], [0, 1], ls='--', lw=1, color='black', zorder=0)
    if mean_only:
        marker = 'o'
        markeredgecolor = 'white'
        markeredgewidth = 0.1
    else:
        marker = '.'
        markeredgecolor = None
        markeredgewidth = None

    p1, = ax.plot(mean_bins, prob_bins, marker=marker, markersize=5, color=cal_color,
                  label=mean_label, zorder=2, markeredgecolor=markeredgecolor,
                  markeredgewidth=markeredgewidth)#, linestyle='None')

    if not legend_sizes_only:
        handles.append(p0)
        labels.append('Perfect calibration')
        handles.append(p1)
        labels.append('Calibration curve')

    # Plot background histogram for calibration
    if plot_cal_hist:
        x = ((0.5+np.arange(n_bins))/n_bins)[non_empty]
        y = prob_bins
        #(0, 0, 0, 0.15)
        ax.bar(x, height=y, width=1/n_bins, color=(0.85, 0.85, 0.85, 1), zorder=0, edgecolor=(0.5, 0.5, 0.5, 1))

    metrics = {}

    # Plot upper and lower bounds of grouping loss
    if y_scores is not None and y_labels is not None and bins is not None:
        # metrics = compute_calib_metrics(frac_pos, counts, y_scores, y_labels, bins)
        lower_bound_bins = grouping_loss_lower_bound(frac_pos, counts, reduce_bin=False)
        upper_bound_bins = grouping_loss_upper_bound(frac_pos, counts, y_scores, y_labels, bins, reduce_bin=False)

        lower_bound_bins = lower_bound_bins[non_empty]
        upper_bound_bins = upper_bound_bins[non_empty]

        p2, = ax.plot(mean_bins, prob_bins + lower_bound_bins, ls=':', lw=1, color='red', zorder=2)
        p3, = ax.plot(mean_bins, prob_bins - lower_bound_bins, ls=':', lw=1, color='red', zorder=2)
        p4, = ax.plot(mean_bins, prob_bins + upper_bound_bins, ls=':', lw=1, color='darkgray', zorder=2)
        p5, = ax.plot(mean_bins, prob_bins - upper_bound_bins, ls=':', lw=1, color='darkgray', zorder=2)
        p6 = ax.fill_between(mean_bins, prob_bins + lower_bound_bins, prob_bins + upper_bound_bins, color='black', alpha=0.05, zorder=2)
        p7 = ax.fill_between(mean_bins, prob_bins - lower_bound_bins, prob_bins - upper_bound_bins, color='black', alpha=0.05, zorder=2)

        if not legend_sizes_only:
            handles.append(p2)
            labels.append(r'$C$ $\pm$ $\sigma^-$')
            handles.append(p4)
            labels.append(r'$C \pm~\sigma^+$')

        # Plot metrics
        metrics.update(compute_multi_classif_metrics(y_scores, y_labels))
        metrics.update(compute_calib_metrics(frac_pos, counts, y_scores, y_labels, bins))

    rename = {
        'acc': 'Acc',
        'max_ece': 'ECE',
        'max_mce': 'MCE',
        'max_rmsce': 'RMSCE',
        'auc': 'AUC',
        'lower_bound': r'GL-',
        'upper_bound': r'GL+',
    }
    if hist and metrics:
        for i, (name, val) in enumerate(metrics.items()):
            ax_right.annotate(f'{rename.get(name, name)}: {val:.3g}',
                              xy=(0.05, 1.1-0.02*i), ha='left', va='center',
                              color='grey', size='x-small', xycoords='axes fraction',
                              )

    if not mean_only and len(handles) > 0:
        legend = ax.legend(loc=legend_loc, ncol=ncol, bbox_to_anchor=bbox_to_anchor,
                        handles=handles, labels=labels, fancybox=True, framealpha=1,)
                        # title_fontsize=legend_size, fontsize=legend_size)
        if legend_cluster_sizes:
            legend.set_title(legend_title)
    if xlim_margin is not None:
        ax.set_xlim((-xlim_margin, 1 + xlim_margin))
    if ylim_margin is not None:
        ax.set_ylim((-ylim_margin, 1 + ylim_margin))

    ax.set_aspect('equal')
    if title is not None:
        ax_top.set_title(title)

    return fig


def plot_grouping_bounds(lower_bounds, upper_bounds, labels):
    """Scatter the lower and upper grouping-loss bounds of each clustering.

    Each clustering occupies one row of the plot (y axis) and gets two
    points: its lower bound and its upper bound (x axis), distinguished by
    color. Rows are ordered by decreasing lower bound.

    Parameters
    ----------
    lower_bounds : sequence of float
        Lower bound of the grouping loss, one value per clustering.
    upper_bounds : sequence of float
        Upper bound of the grouping loss, same length as ``lower_bounds``.
    labels : sequence
        Name of each clustering, same length as ``lower_bounds``.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.

    Raises
    ------
    ValueError
        If the three inputs do not all have the same length.
    """
    if len(lower_bounds) != len(upper_bounds):
        raise ValueError(f'Shape mismatch: lower_bounds {len(lower_bounds)} '
                         f'upper_bounds {len(upper_bounds)}')

    if len(lower_bounds) != len(labels):
        raise ValueError(f'Shape mismatch: lower_bounds {len(lower_bounds)} '
                         f'labels {len(labels)}')

    df = pd.DataFrame({
        'labels': labels,
        'lower_bounds': lower_bounds,
        'upper_bounds': upper_bounds,
    })

    # Sort before melting so that clusterings appear by decreasing lower bound.
    df = df.sort_values(by=['lower_bounds'], ascending=False)
    df = df.melt(id_vars=['labels'], value_vars=['lower_bounds', 'upper_bounds'], var_name='bound')

    # Prettify the legend entries. The replacement is restricted to the
    # 'bound' column so that a clustering whose label happens to be
    # 'lower_bounds' or 'upper_bounds' is left untouched.
    df['bound'] = df['bound'].replace({
        'lower_bounds': 'Lower bound',
        'upper_bounds': 'Upper bound',
    })

    fig = plt.figure()
    ax = plt.gca()
    sns.scatterplot(data=df, x='value', y='labels', hue='bound', ax=ax)

    ax.set_xlabel('Grouping loss')
    ax.set_ylabel('Clustering')

    return fig


def barplot_ece_gl_brier(net, ece, gl_lower_bound, gl_lower_bound_debiased, brier, acc,
                label_brier='Brier binarized',
                label_lb='Lower bound binarized',
                label_lbd='Lower bound binarized debiased',
                label_ece='L2-ECE binarized',):
    """Overlay Brier, ECE and grouping-loss lower-bound bars for each network.

    A filled gray bar shows the Brier score of each network; three outlined
    bars are drawn on top of it for ``ece + gl_lower_bound``,
    ``ece + gl_lower_bound_debiased`` and ``ece``. Networks are listed by
    decreasing accuracy and their accuracy is appended to their name.

    Parameters
    ----------
    net : sequence of str
        Network names, one per bar row.
    ece, gl_lower_bound, gl_lower_bound_debiased : array-like
        Per-network values; they are summed with ``+`` so they must support
        elementwise addition (e.g. numpy arrays).
    brier : array-like
        Per-network Brier scores.
    acc : array-like of float
        Per-network accuracies in [0, 1] scale (displayed as percentages).
        Lists are accepted.
    label_brier, label_lb, label_lbd, label_ece : str
        Legend labels of the four bar series.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    fig = plt.figure()
    # Convert up front so that negating a plain list does not raise,
    # consistently with barplot_ece_gl.
    acc = np.asarray(acc)
    idx_sort = np.argsort(-acc)
    net = np.array([f'{n} (Acc={100*a:.1f}%)' for n, a in zip(net, acc)])
    order = net[idx_sort]
    ax = sns.barplot(x=brier, y=net, color='gray', order=order, label=label_brier)
    sns.barplot(x=ece+gl_lower_bound, y=net, ax=ax, edgecolor='tab:orange', order=order, label=label_lb, facecolor='none', lw=2)
    sns.barplot(x=ece+gl_lower_bound_debiased, y=net, ax=ax, edgecolor='tab:red', order=order, label=label_lbd, facecolor='none', lw=2)
    sns.barplot(x=ece, y=net, ax=ax, edgecolor='tab:blue', order=order, label=label_ece, facecolor='none', lw=2)
    ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    return fig


def barplot_ece_gl(net, val, acc, ax=None,
                legend=True,
                bbox_to_anchor=(1, 1),
                loc='upper left',
                append_acc=True,
                table_acc=False,
                table_fontsize=12,
                ax_ratio=1,
                gray_bg=None,
                ncol=1,
                detailed=False,
                plot_xlabel=False,
                ):
    """Draw overlaid horizontal bars of calibration and grouping-loss terms.

    Networks are listed by decreasing accuracy. Bars are drawn on top of each
    other, the widest first, so that the visible part of each bar reads as
    one term of a sum.

    Parameters
    ----------
    net : sequence of str
        Network names, one per bar row.
    val : tuple of 5 array-like
        ``(cl, glexp, glexp_bias, glind, clind)``, each holding one value per
        network and supporting elementwise ``+`` and ``-``.
    acc : array-like of float
        Per-network accuracies (displayed as percentages).
    ax : matplotlib axes, optional
        Axes to draw on. A new figure is created when omitted.
    legend : bool
        Whether to draw the legend (entries listed in reverse drawing order).
    bbox_to_anchor, loc, ncol
        Forwarded to ``ax.legend``.
    append_acc : bool
        Append the accuracy to each network name.
    table_acc : bool
        Add a one-column table of accuracies to the right of the axes.
    table_fontsize : float
        Font size of the accuracy table.
    ax_ratio : float
        Scales the width and horizontal offset of the accuracy table.
    gray_bg : color, optional
        When given, every other row (and table cell) is shaded with it.
    detailed : bool
        If True, plot the raw terms plus the bias and induced terms;
        otherwise plot the corrected terms ``glexp - glexp_bias - glind``
        and ``cl - clind``.
    plot_xlabel : bool
        Whether to label the x axis with the plotted sum.

    Returns
    -------
    matplotlib axes
        The axes the bars were drawn on.
    """
    cl, glexp, glexp_bias, glind, clind = val

    if ax is None:
        # Open a fresh figure; only its current axes are needed afterwards.
        plt.figure()
        ax = plt.gca()

    acc = np.array(acc)
    idx_sort = np.argsort(-acc)  # most accurate network first
    if append_acc:
        net = [f'{n} (Acc={100*a:.1f}%)' for n, a in zip(net, acc)]
    net = np.array(net)
    order = net[idx_sort]

    def overlay(widths, label, facecolor, edgecolor):
        # One bar series; all series share rows, ordering and line width.
        sns.barplot(x=widths, y=net, order=order, ax=ax, label=label,
                    facecolor=facecolor, edgecolor=edgecolor, lw=2)

    if detailed:
        cl_label = r'$\widehat{\mathrm{CL}}(S_B)$'
        gl_label = r'$\widehat{\mathrm{GL}}_{explained}(S_B)$'
        overlay(glexp+cl, cl_label, 'tab:blue', 'tab:blue')
        overlay(glexp, gl_label, 'tab:red', 'tab:red')
        overlay(glexp_bias+glind, 'bias', 'tab:orange', 'none')
        overlay(glind, r'$\widehat{\mathrm{GL}}_{induced}$', 'tab:green', 'none')
    else:
        gl_part = glexp - glexp_bias - glind
        cl_part = cl - clind
        cl_label = r'$\widehat{\mathrm{CL}}$'
        gl_label = r'$\widehat{\mathrm{GL}}_{\mathrm{LB}}$'
        overlay(gl_part+cl_part, cl_label, 'tab:blue', 'tab:blue')
        overlay(gl_part, gl_label, 'tab:red', 'tab:red')

    if legend:
        # List entries in reverse drawing order.
        drawn_handles, drawn_labels = ax.get_legend_handles_labels()
        ax.legend(handles=drawn_handles[::-1], labels=drawn_labels[::-1],
                  bbox_to_anchor=bbox_to_anchor, loc=loc, ncol=ncol)

    ax.set_xlabel(f'{gl_label} + {cl_label}' if plot_xlabel else None)
    ax.set_ylabel(None)

    # Single-column table content: accuracies in the same order as the rows.
    cellText = np.array([[f'{100*a:.1f}' for a in acc[idx_sort]]]).T
    n_nets = len(net)

    if gray_bg is None:
        cellColours = None
    else:
        # Shade every other row; drawing the spans must not move the y limits.
        ylim = ax.get_ylim()
        cellColours = [[gray_bg] if i % 2 == 0 else ['white']
                       for i in range(n_nets)]
        for row in range(0, n_nets, 2):
            ax.axhspan(row-0.5, row+0.5, color=gray_bg, zorder=0)
        ax.set_ylim(ylim)

    if table_acc:
        # Positions are in axes-fraction units, to the right of the axes.
        table_width = .14*ax_ratio
        table_xpos = 1.02 + 0.025*ax_ratio
        title_pos = (table_xpos + table_width + 0.02, 0.5)
        ax.annotate(r'Accuracy$\uparrow$ (\%)', xy=title_pos, xytext=title_pos,
                    xycoords='axes fraction', ha='left', va='center',
                    fontsize=plt.rcParams['axes.labelsize'],
                    rotation=-90,
                    )
        table = ax.table(cellText=cellText, loc='right',
                         rowLabels=None,
                         colLabels=None,
                         bbox=[table_xpos, 0, table_width, 1],
                         cellColours=cellColours,
                         cellLoc='center',
                         )
        table.auto_set_font_size(False)
        table.set_fontsize(table_fontsize)

    return ax


def barplots_ece_gl_cal(net, val1, val2, acc, plot_table=True, keep_scale=True,
                        figsize=(4, 3.5), loc='center right', bbox_to_anchor=(1, 0.5)):
    """Draw two side-by-side ``barplot_ece_gl`` panels sharing their rows.

    The left panel (titled 'No recalibration') shows ``val1`` and carries the
    legend and x label; the right panel (titled 'Isotonic') shows ``val2``
    and optionally the accuracy table. The left panel is four times wider
    than the right one.

    Note that this function changes global matplotlib rc settings (fonts via
    ``set_latex_font`` and several ``ytick``/``axes``/``legend`` sizes) and
    does not restore them.

    Parameters
    ----------
    net : sequence of str
        Network names, one per bar row.
    val1 : tuple of 5 array-like
        Values forwarded to ``barplot_ece_gl`` for the left panel.
    val2 : tuple of 5 array-like or None
        Values for the right panel. When None the right panel has no bars.
    acc : array-like of float
        Per-network accuracies, used to order the rows and fill the table.
    plot_table : bool
        Whether to add the accuracy table next to the right panel.
    keep_scale : bool
        If True, the x range of the right panel is set to the left range
        divided by the width ratio, so both panels use the same scale, and
        the right panel gets only two x ticks.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    loc, bbox_to_anchor
        Legend placement forwarded to ``barplot_ece_gl`` for the left panel.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    set_latex_font()

    # Font sizes and legend spacing (global rc settings).
    gray_bg = '.96'
    table_fontsize = 11
    plt.rc('ytick', labelsize=14)
    plt.rc('axes', labelsize=16)
    plt.rc('axes', titlesize=18)
    plt.rc('legend', fontsize=14)
    plt.rc('legend', borderpad=0.3)
    plt.rc('legend', borderaxespad=0.1)
    plt.rc('legend', handlelength=1.6)
    plt.rc('legend', labelspacing=0.4)
    plt.rc('legend', handletextpad=0.4)

    ax_ratio = 4  # width of the left panel relative to the right one
    fig, axes = plt.subplots(1, 2, figsize=figsize, gridspec_kw={'width_ratios': [ax_ratio, 1]})
    plt.subplots_adjust(wspace=0.05)
    ax1 = axes[0]
    ax2 = axes[1]

    barplot_ece_gl(net, val1, acc, ax1, bbox_to_anchor=bbox_to_anchor, loc=loc,
                   append_acc=False, ncol=1, gray_bg=gray_bg, plot_xlabel=True)

    if val2 is not None:
        barplot_ece_gl(net, val2, acc, ax2, legend=False, table_acc=plot_table, ax_ratio=ax_ratio,
                       table_fontsize=table_fontsize, gray_bg=gray_bg, plot_xlabel=False)

    ax1.set_xlim(0, 0.22)
    xmin, xmax = ax1.get_xlim()
    if keep_scale:
        # Single tick slightly inside the right panel's range, rounded to one
        # significant digit so that its label stays short.
        xtick = xmin + 0.23*(xmax - xmin)
        xtick = float(f'{xtick:.1g}')
        ax2.set_xticks([0, xtick])
        # The right panel is ax_ratio times narrower: shrinking its x range
        # by the same factor keeps bar lengths comparable across panels.
        ax2.set_xlim((xmin, xmax/ax_ratio))
    ax2.get_yaxis().set_visible(False)
    ax1.set_title('No recalibration')
    ax2.set_title('Isotonic')

    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.spines['top'].set_visible(False)

    return fig

def plot_lower_bound_vs_acc(net, gl_lower_bound, acc, style=None):
    """Scatter accuracy against the binarized grouping-loss lower bound.

    Points are colored by ``net`` and, when given, shaped by ``style``.
    Returns the newly created figure.
    """
    fig = plt.figure()
    axes = fig.gca()
    sns.scatterplot(x=gl_lower_bound, y=acc, hue=net, style=style, ax=axes)
    # Legend outside of the axes, anchored at their upper right corner.
    axes.legend(loc='upper left', bbox_to_anchor=(1, 1))
    axes.set_xlabel('Grouping loss lower bound (binarized)')
    axes.set_ylabel('Accuracy')
    return fig


def plot_lower_bound_vs_brier(net, gl_lower_bound, brier, style=None):
    """Scatter the Brier score against the binarized grouping-loss lower bound.

    Points are colored by ``net`` and, when given, shaped by ``style``.
    Returns the newly created figure.
    """
    fig = plt.figure()
    axes = fig.gca()
    sns.scatterplot(x=gl_lower_bound, y=brier, hue=net, style=style, ax=axes)
    # Legend outside of the axes, anchored at their upper right corner.
    axes.legend(loc='upper left', bbox_to_anchor=(1, 1))
    axes.set_xlabel('Grouping loss lower bound (binarized)')
    axes.set_ylabel('Brier')
    return fig


def plot_lower_bound_vs_ece(net, gl_lower_bound, ece, style=None):
    """Scatter the ECE against the binarized grouping-loss lower bound.

    Points are colored by ``net`` and, when given, shaped by ``style``.
    Returns the newly created figure.
    """
    fig = plt.figure()
    axes = fig.gca()
    sns.scatterplot(x=gl_lower_bound, y=ece, hue=net, style=style, ax=axes)
    # Legend outside of the axes, anchored at their upper right corner.
    axes.legend(loc='upper left', bbox_to_anchor=(1, 1))
    axes.set_xlabel('Grouping loss lower bound (binarized)')
    axes.set_ylabel('ECE')
    return fig


def plot_brier_acc(net, brier, acc, style=None):
    """Scatter accuracy against the Brier score.

    Points are colored by ``net`` and, when given, shaped by ``style``.
    Returns the newly created figure.
    """
    fig = plt.figure()
    axes = fig.gca()
    sns.scatterplot(x=brier, y=acc, hue=net, style=style, ax=axes)
    # Legend outside of the axes, anchored at their upper right corner.
    axes.legend(loc='upper left', bbox_to_anchor=(1, 1))
    axes.set_xlabel('Brier')
    axes.set_ylabel('Acc')
    return fig


def plot_brier_ece(net, brier, ece, style=None):
    """Scatter the ECE against the Brier score.

    Points are colored by ``net`` and, when given, shaped by ``style``.
    Returns the newly created figure.
    """
    fig = plt.figure()
    axes = fig.gca()
    sns.scatterplot(x=brier, y=ece, hue=net, style=style, ax=axes)
    # Legend outside of the axes, anchored at their upper right corner.
    axes.legend(loc='upper left', bbox_to_anchor=(1, 1))
    axes.set_xlabel('Brier')
    axes.set_ylabel('ECE')
    return fig


def plot_fig_theorem(isoline_right=True, squared=True, legend_right=True):
    """Draw an illustrative 2D feature-space figure with two regions in a score band.

    The figure shows two parallel quadratic curves (labelled '0.7' and '0.8'
    as score isolines), the band between them split by a straight cut line
    into 'Region 1' and 'Region 2', and positive/negative sample points.
    Points are drawn from a mixture of two Gaussians with a fixed random
    seed, so the output is deterministic.

    Parameters
    ----------
    isoline_right : bool
        Put the isoline value labels at the right edge (else at the left
        edge). Ignored when ``squared`` is False, in which case it is
        overridden by ``legend_right``.
    squared : bool
        Use a square y range with the legend inside at the top center;
        otherwise use a shorter y range with the legend outside the axes.
    legend_right : bool
        Side on which the legend is placed when ``squared`` is False.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    set_latex_font()
    # fig = plt.figure()
    fig = plt.figure(figsize=(3.7, 3.7))
    ax = plt.gca()
    colors = sns.color_palette("hls", 8).as_hex()
    fontsize = 14
    legend_fontsize = 10.65  # NOTE(review): only referenced in commented-out code below
    color_cluster1 = colors[5]#'tab:red'
    color_cluster2 = colors[0]#'tab:yellow'
    alpha_clusters = 0.5
    marker_pos = '+'
    marker_neg = 'o'
    # Visible window of the feature space.
    xmin, xmax = 0.14, 0.93
    if squared:
        ymin, ymax = 0.15, 0.94  # For the squared version
    else:
        isoline_right = legend_right
        ymin, ymax = 0.15, 0.735
        # ymin, ymax = 0.15, 0.75
    # x positions of the region and band annotations, and the values displayed.
    x_under = 0.35
    x_above = 0.85
    x_mid = 0.54
    p_under = 0.6
    p_above = 0.9

    # Coefficients of the lower quadratic curve a*x**2 + b*x + c.
    a = -1.6666666666666521e-001
    b = 5.6666666666666499e-001
    c = 0.1

    # Vertical offset between the lower and upper curves.
    width = 0.2
    X = np.linspace(0, 1, 1000)

    def curve1(X):
        # Lower boundary of the band.
        return a*X**2 + b*X + c

    def curve2(X):
        # Upper boundary of the band: the lower one shifted up by `width`.
        return curve1(X) + width

    def cut(X):
        # Straight line separating region 1 (below) from region 2 (above).
        return 1.6 - 2.3*X

    Y1 = curve1(X)
    Y2 = curve2(X)
    Y_cut = cut(X)

    ax.plot(X, Y1, color='black', label='Score isoline')#'Score level line')
    ax.plot(X, Y2, color='black')

    # Shade the part of the band below the cut (region 1) and above it (region 2).
    ax.fill_between(X, Y1, np.minimum(Y2, Y_cut), where=np.minimum(Y2, Y_cut) >= Y1, color=color_cluster1, label='Region 1', alpha=alpha_clusters, edgecolor='none')
    ax.fill_between(X, Y2, np.maximum(Y1, Y_cut), where=np.maximum(Y1, Y_cut) <= Y2, color=color_cluster2, label='Region 2', alpha=alpha_clusters, edgecolor='none')

    # Only draw the cut line where it lies inside the band.
    idx_cut_visible = (Y1 <= Y_cut) & (Y_cut <= Y2)
    ax.plot(X[idx_cut_visible], Y_cut[idx_cut_visible], color='black', lw=1, ls=':')


    # Isoline value labels, placed just outside the chosen edge of the window.
    x_isoline = xmax if isoline_right else xmin
    delta = 0.01 if isoline_right else -0.01
    deltay = 0 if squared else 0.01
    deltay = deltay if legend_right else -0.005
    ha = 'left' if isoline_right else 'right'
    xy_low = (x_isoline, curve1(x_isoline) + deltay)
    xy_up = (x_isoline, curve2(x_isoline) + deltay)
    # NOTE(review): xy_mid and xytext_mid are only used by the commented-out annotation below.
    xy_mid = (x_isoline, 0.5*(curve1(x_isoline) + curve2(x_isoline) + deltay))
    xytext_low = (x_isoline + delta, curve1(x_isoline) + deltay)
    xytext_up = (x_isoline + delta, curve2(x_isoline) + deltay)
    xytext_mid = (x_isoline + delta, 0.5*(curve1(x_isoline) + curve2(x_isoline) + deltay))
    va = 'center'
    ax.annotate('0.7', xy=xy_low, xytext=xytext_low, va=va, ha=ha, fontsize=fontsize)
    # ax.annotate(r'$\bar{S} = 0.75$', xy=xy_mid, xytext=xytext_mid, va=va, fontsize=fontsize)
    ax.annotate('0.8', xy=xy_up, xytext=xytext_up, va=va, ha=ha, fontsize=fontsize)
    # Per-region values written above the band, and their mean written below it.
    ax.annotate(fr'$\bar{{Q}}_1 = {p_under}$', xy=(x_under, curve2(x_under)), xytext=(x_under, curve2(x_under)), va='bottom', ha='right', color=color_cluster1, fontsize=fontsize)
    ax.annotate(fr'$\bar{{Q}}_2 = {p_above}$', xy=(x_above, curve2(x_above)), xytext=(x_above, curve2(x_above)), va='bottom', ha='right', color=color_cluster2, fontsize=fontsize)
    ax.annotate(fr'$\bar{{C}} = {0.5*(p_above + p_under)}$', xy=(x_mid, curve1(x_mid)), xytext=(x_mid, curve1(x_mid) - 0.01), va='top', ha='left', color='black', fontsize=fontsize)

    # Membership tests for a 2D point X = (x, y). A positive margin shrinks
    # the region, a negative margin grows it.
    def is_in_cluster1(X, margin=0):
        x, y = X
        return y >= curve1(x) + margin and y <= curve2(x) - margin and y <= cut(x) - margin

    def is_in_cluster2(X, margin=0):
        x, y = X
        return y >= curve1(x) + margin and y <= curve2(x) - margin and y > cut(x) + margin

    def is_in_cluster1_or_2(X, margin=0):
        x, y = X
        return is_in_cluster1(X, margin) or is_in_cluster2(X, margin)

    def is_out_frame(X, margin=0):
        # True when the point is outside the unit square shrunk by `margin`.
        x, y = X
        return x <= margin or x >= 1 - margin or y <= margin or y >= 1 - margin

    n = 12  # number of points to draw outside the band, per group
    m = 10  # number of points to draw inside each region
    rs = np.random.RandomState(0)
    # NOTE(review): y_under and y_above are not used afterwards, but these two
    # draws advance the random state and therefore affect the points sampled
    # below. Do not remove them without accepting a different figure.
    y_under = rs.binomial(n=1, p=p_under, size=n)
    y_above = rs.binomial(n=1, p=p_above, size=n)

    # Parameters of the two Gaussian components the points are sampled from.
    mean1 = np.array([0.48, 0.32])
    mean2 = np.array([0.45, 0.55])
    cov1 = 1/180*np.diag([5, 1])
    cov2 = 1/140*np.diag([7, 1])
    # Probability of a positive label in each component: points from the
    # first component are always negative, from the second always positive.
    p1 = 0
    p2 = 1

    # Rotate
    def rotation(cov, theta):
        # Covariance matrix rotated by angle theta (radians).
        R = np.array([
            [np.cos(theta), -np.sin(theta)],
            [np.sin(theta), np.cos(theta)]
        ])
        return R @ cov @ R.T

    theta1 = np.pi/12
    theta2 = np.pi/12
    cov1 = rotation(cov1, theta1)
    cov2 = rotation(cov2, theta2)

    # plot_covariance2D(mean1, cov1, ax, n_std=1, edgecolor='black', label=None)
    # plot_covariance2D(mean2, cov2, ax, n_std=1, edgecolor='black', label=None)

    def sample_x():
        # Draw one (point, label) pair from an equal-weight mixture of the
        # two Gaussian components.
        # return rs.uniform(0, 1, size=2)
        if rs.binomial(n=1, p=0.5):
            return rs.multivariate_normal(mean1, cov1), rs.binomial(n=1, p=p1)
        return rs.multivariate_normal(mean2, cov2), rs.binomial(n=1, p=p2)

    L_X_dist1 = []  # samples in dist 1
    L_X_dist2 = []  # samples in dist 2
    L_X_cluster1 = []
    L_X_cluster2 = []
    L_y = []
    L_y_cluster1 = []
    L_y_cluster2 = []
    # Rejection sampling: keep drawing until there are n + n points outside
    # the band and m points inside each region.
    while len(L_X_dist1) < n or len(L_X_dist2) < n or len(L_X_cluster1) < m or len(L_X_cluster2) < m:
        x_prop, y_prop = sample_x()

        if is_out_frame(x_prop, margin=0.02):
            continue

        if is_in_cluster1(x_prop):
            # Accept only points away from the region boundaries, and cap
            # the label counts at int(p_under*m) positives and the rest negatives.
            if is_in_cluster1(x_prop, margin=0.02) and len(L_X_cluster1) < m:
                if np.sum(L_y_cluster1) >= int(p_under*m) and y_prop == 1:
                    continue # ignore sample with label 1 because too many already
                if len(L_y_cluster1) - np.sum(L_y_cluster1) >= m - int(p_under*m) and y_prop == 0:
                    continue # ignore sample with label 0 because too many already
                L_X_cluster1.append(x_prop)
                L_y_cluster1.append(y_prop)
            continue

        if is_in_cluster2(x_prop):
            # Same as above with the label counts derived from p_above.
            if is_in_cluster2(x_prop, margin=0.02) and len(L_X_cluster2) < m:
                if np.sum(L_y_cluster2) >= int(p_above*m) and y_prop == 1:
                    continue # ignore sample with label 1 because too many already
                if len(L_y_cluster2) - np.sum(L_y_cluster2) >= m - int(p_above*m) and y_prop == 0:
                    continue # ignore sample with label 0 because too many already
                L_X_cluster2.append(x_prop)
                L_y_cluster2.append(y_prop)
            continue

        # Ignore points too close to clusters boundaries
        if is_in_cluster1_or_2(x_prop, margin=-0.02):
            continue

        # Points outside the band fill the first list, then the second one;
        # the split is by arrival order, not by position relative to the cut.
        # if cut(x_prop[0]) >= x_prop[1]:
        if len(L_X_dist1) < n:
            L_X_dist1.append(x_prop)
            L_y.append(y_prop)

        elif len(L_X_dist2) < n:
            L_X_dist2.append(x_prop)
            L_y.append(y_prop)

        # if cut(x_prop[0]) >= x_prop[1]:
        #     if len(L_X_dist1) < n:
        #         L_X_dist1.append(x_prop)
        #         L_y.append(y_prop)

        # elif len(L_X_dist2) < n:
        #     L_X_dist2.append(x_prop)
        #     L_y.append(y_prop)

    assert len(L_X_dist2) == n
    assert len(L_X_dist1) == n
    assert len(L_y) == 2*n

    Xs = np.array(L_X_dist1 + L_X_dist2)
    # y_labels = np.concatenate([y_under, y_above])
    # y_labels = np.concatenate([y_under, y_above])
    y_labels = np.array(L_y)
    Xs_clusters = np.array(L_X_cluster1 + L_X_cluster2)
    # y_labels_cluster1 = np.array([1]*int(m*p_under) + [0]*(m - int(m*p_under)))
    # y_labels_cluster2 = np.array([1]*int(m*p_above) + [0]*(m - int(m*p_above)))
    y_labels_clusters1 = np.array(L_y_cluster1)
    y_labels_clusters2 = np.array(L_y_cluster2)
    y_labels_clusters = np.concatenate([y_labels_clusters1, y_labels_clusters2])

    assert Xs.shape == (2*n, 2)
    assert y_labels.shape == (2*n,)
    assert y_labels_clusters.shape == (2*m,)

    # Points outside the band (no legend entry), then points inside the
    # regions (these carry the 'Positive'/'Negative' legend entries).
    idx_pos = y_labels == 1
    ax.scatter(Xs[idx_pos, 0], Xs[idx_pos, 1], color='black', marker=marker_pos)  # (0.8, 0.8, 0.8, 1)
    ax.scatter(Xs[~idx_pos, 0], Xs[~idx_pos, 1], color='black', marker=marker_neg)
    idx_pos = y_labels_clusters == 1
    ax.scatter(Xs_clusters[idx_pos, 0], Xs_clusters[idx_pos, 1], color='black', marker=marker_pos, label='Positive')
    ax.scatter(Xs_clusters[~idx_pos, 0], Xs_clusters[~idx_pos, 1], color='black', marker=marker_neg, label='Negative')

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    if squared:
        legend = ax.legend(loc='upper center', ncol=2)#, fontsize=legend_fontsize)
    else:
        # Legend outside of the axes, on the side selected by legend_right.
        bbox_to_anchor = (1, -0.02) if legend_right else (0.0, 1.02)
        loc = 'lower left' if legend_right else 'upper right'
        legend = ax.legend(loc=loc, ncol=1, bbox_to_anchor=bbox_to_anchor)#, fontsize=legend_fontsize)
        # legend = ax.legend(loc='upper left', ncol=1, bbox_to_anchor=(1, 1.02))#, fontsize=legend_fontsize)
    legend.get_frame().set_alpha(None)
    # ax.set_axis_off()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_aspect('equal')
    # ax.set_title('Feature space $X$')

    # Plot Xcal feature space in corner
    # The first annotation only draws the white box; the symbol is drawn on
    # top of it by the second annotation, with a larger font.
    ax.annotate('    ', xy=(0.02, 0.972), xycoords='axes fraction',
                bbox=dict(boxstyle='square', ec='black', fc='white', alpha=1, linewidth=0.7),
                ha='left', va='top', fontsize=fontsize)
    ax.annotate(r'$\mathcal{X}~$', xy=(0.02, 0.972), xycoords='axes fraction',
                # bbox=dict(boxstyle='square', ec='black', fc='white', alpha=1, linewidth=0.7),
                ha='left', va='top', fontsize=fontsize+3)
    return fig



def plot_fig_theorem_v2(isoline_right=True, squared=True, legend_right=True):
    """Draw the illustrative figure of a level set of S colored by Q.

    The level set is the band between the parabola ``curve1`` and the same
    parabola shifted up by ``width``. The band is filled with the ``RdBu_r``
    color map of a synthetic Q (see the nested ``Q`` helper), a dotted "cut"
    line splits the band into the regions R_1 and R_2, and text annotations
    give the conditional expectations of Q on the band and on each region.

    Calling this function modifies the global matplotlib rcParams (through
    ``set_latex_font`` and ``plt.rc``).

    Parameters
    ----------
    isoline_right : bool
        Has no effect on the figure: it only positioned isoline annotations
        that are not drawn. Kept for backward compatibility.
    squared : bool
        If True, use the y-range that makes the axes square and place the
        automatic legend at the top center. If False, use a shorter y-range
        and place a legend with an empty box handle in the lower right
        corner.
    legend_right : bool
        Has no effect on the figure: the legend position of the non-squared
        layout is fixed. Kept for backward compatibility.

    Returns
    -------
    fig : the matplotlib figure holding the plot and its colorbar.
    """
    set_latex_font(extra_preamble=[r'\usepackage[mathscr]{eucal}'])

    plt.rc('legend', borderpad=0.1)
    plt.rc('legend', borderaxespad=0.1)
    plt.rc('legend', labelspacing=0.3)

    fig = plt.figure(figsize=(3.7, 3.7))
    ax = plt.gca()
    colors = sns.color_palette("hls", 8).as_hex()
    fontsize = 14
    color_cluster1 = colors[5]
    color_cluster2 = colors[0]

    xmin, xmax = 0.14, 0.93
    ymin = 0.15
    # With ymax = 0.94 the x and y ranges have the same length (0.79).
    ymax = 0.94 if squared else 0.735

    # Abscissae anchoring the text annotations.
    x_e = 0.48
    x_er1 = 0.43
    x_er2 = 0.8
    x_r1 = 0.52
    x_r2 = 0.49
    p_under = 0.6
    p_above = 0.8
    p_mid = 0.5*(p_above + p_under)

    # Coefficients of the parabola bounding the level set from below.
    a = -1.6666666666666521e-001
    b = 5.6666666666666499e-001
    c = 0.1

    width = 0.2
    X = np.linspace(0, 1, 1000)

    def curve1(X):
        """Lower boundary of the level set."""
        return a*X**2 + b*X + c

    def curve2(X):
        """Upper boundary of the level set (lower one shifted by width)."""
        return curve1(X) + width

    def cut(X):
        """Straight line separating region 1 from region 2."""
        return 1.6 - 2.3*X

    Y1 = curve1(X)
    Y2 = curve2(X)
    Y_cut = cut(X)

    line, = ax.plot(X, Y1, color='black', label='Level set $S = 0.7$')
    ax.plot(X, Y2, color='black')

    h = 100
    qmin = 0.5
    qmax = 0.9
    n_levels = 500

    # Grid covering exactly the band: abscissa x offset above curve1.
    XX0, WW0 = np.meshgrid(np.linspace(xmin, xmax, h),
                           np.linspace(0, 0.2, h))
    YY0 = curve1(XX0) + WW0

    def Q(X):
        """Synthetic Q for points X of shape (n, 2), scaled to [qmin, qmax].

        Each point is projected on the vector (1, d curve1/dx) taken at its
        abscissa, and the projections are affinely mapped so that their
        minimum is qmin and their maximum is qmax.
        """
        x = X[:, 0]
        v_perp = np.stack([np.ones_like(x), 2*a*x+b], axis=1)
        t = np.sum(X*v_perp, axis=1)
        tmax = np.max(t)
        tmin = np.min(t)
        return (qmax - qmin)/(tmax - tmin)*(t - tmin) + qmin

    Z2 = Q(np.c_[XX0.ravel(), YY0.ravel()]).reshape(XX0.shape)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cm = plt.cm.RdBu_r
    crf = ax.contourf(XX0, YY0, Z2, levels=np.linspace(qmin, qmax, n_levels), cmap=cm, alpha=1, vmin=qmin, vmax=qmax)
    cbar = fig.colorbar(crf, cax=cax)
    cbar.set_ticks([qmin, qmax])
    cbar.ax.set_yticklabels([f'{qmin:.2g}', f'{qmax:.2g}'])

    # Draw thin face-colored edges to hide the seams between the filled
    # contour levels. Recent Matplotlib versions no longer expose
    # `ContourSet.collections`; the ContourSet is then styled directly.
    for collection in getattr(crf, 'collections', [crf]):
        collection.set_edgecolor("face")
        collection.set_linewidth(0.02)

    # Only draw the part of the cut line that lies inside the band.
    idx_cut_visible = (Y1 <= Y_cut) & (Y_cut <= Y2)
    ax.plot(X[idx_cut_visible], Y_cut[idx_cut_visible], color='black', lw=1, ls=':')

    dd = 0.005  # small offset between an anchor point and its text
    ax.annotate(fr'$\mathbb{{E}}[Q|S] = {p_mid}$', xy=(x_e, curve2(x_e)), xytext=(x_e-dd, curve2(x_e)+dd), va='bottom', ha='right', color='black', fontsize=fontsize)
    ax.annotate(fr'$\mathbb{{E}}[Q|S,\mathscr{{R}}_1] = {p_under}$', xy=(x_er1, curve1(x_er1)), xytext=(x_er1+2*dd, curve1(x_er1)-2*dd), va='top', ha='left', color=color_cluster1, fontsize=fontsize)
    ax.annotate(fr'$\mathbb{{E}}[Q|S,\mathscr{{R}}_2] = {p_above}$', xy=(x_er2, curve2(x_er2)), xytext=(x_er2-dd, curve2(x_er2)), va='bottom', ha='right', color=color_cluster2, fontsize=fontsize)
    ax.annotate(r'$\mathscr{{R}}_1$', xy=(x_r1, curve2(x_r1)), xytext=(x_r1-dd, curve1(x_r1)+dd), va='bottom', ha='right', color='black', fontsize=fontsize)
    ax.annotate(r'$\mathscr{{R}}_2$', xy=(x_r2, curve2(x_r2)), xytext=(x_r2+dd, curve2(x_r2)-dd), va='top', ha='left', color='black', fontsize=fontsize)

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    if squared:
        legend = ax.legend(loc='upper center', ncol=2)
    else:
        # Represent the level set by an empty box instead of a line.
        handles = [Patch(facecolor='none', edgecolor='black', linewidth=line.get_linewidth(), label=line.get_label())]
        legend = ax.legend(loc='lower right', ncol=1, bbox_to_anchor=(1, 0), fancybox=False, framealpha=0, handles=handles)
    legend.get_frame().set_linewidth(0)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_aspect('equal')

    # Name of the feature space in the top left corner and of the colorbar.
    ax.annotate(r'$\mathcal{X}\hspace{1cm}$', xy=(0.02, 0.972), xycoords='axes fraction',
                ha='left', va='top', fontsize=fontsize+3)
    cbar.ax.annotate(r'$Q$', xy=(1.6, 0.5), xycoords='axes fraction',
                     ha='left', va='center', fontsize=fontsize)

    # Box around the feature space name. `Axes.patches` cannot be mutated in
    # recent Matplotlib versions, hence `add_patch`; clipping is disabled so
    # that the box edges lying on the axes border are fully drawn.
    ax.add_patch(plt.Rectangle((0, 0.86), 0.115, 0.14, linewidth=0.7,
                               fill=False, color='black', alpha=1, zorder=1000,
                               transform=ax.transAxes, clip_on=False))

    return fig


def plot_fig_counter_example(with_arrow=False):
    """Plot the 1D counter-example with piecewise-constant S(X) and Q(X).

    Both curves are drawn for X in [-1, 1] together with a thin horizontal
    line at 0.5 labelled "decision threshold". The plateau values are
    written on the left and right sides of the axes.

    Parameters
    ----------
    with_arrow : bool
        If True, the "decision threshold" label sits below the line and a
        curved arrow links it to the line. Otherwise the label is centered
        right under the line.

    Returns
    -------
    fig : the matplotlib figure.
    """
    set_latex_font()

    text_size = 11
    threshold_text_size = text_size - 2.5
    curve_lw = 2
    plt.rc(
        'legend',
        fontsize=10,
        handletextpad=0.5,
        columnspacing=1.3,
        borderpad=0.3,
        borderaxespad=0.2,
        handlelength=1.5,
        labelspacing=0.3,
    )
    plt.rc('axes', labelsize=11)

    def f(x):
        """S: 0.7 where x > 0, 0.2 where x <= 0, 0 elsewhere (NaN)."""
        x = np.atleast_1d(x)
        return np.select([x > 0, x <= 0], [0.7, 0.2], default=0.0)

    def f_star(x):
        """Q: four plateaus 0.1, 0.3, 0.6, 0.8; 0 where x is NaN."""
        x = np.atleast_1d(x)
        # The first matching condition wins, so the narrower ones go first.
        return np.select(
            [x > 0.5, x > 0, x < -0.5, x <= 0],
            [0.8, 0.6, 0.1, 0.3],
            default=0.0,
        )

    fig = plt.figure(figsize=(1.8, 1.8))
    ax = plt.gca()

    ax.axhline(0.5, color='black', lw=0.5)

    gap = 0.01
    X = np.linspace(-1, 1, 100)
    X, S = insert_nan_at_discontinuities(X, f(X), min_gap=gap)
    ax.plot(X, S, label='$S(X)$', color='black', lw=curve_lw)
    # Q is evaluated on the abscissae returned by the call above.
    X, Q = insert_nan_at_discontinuities(X, f_star(X), min_gap=gap)
    ax.plot(X, Q, label='$Q(X)$', color='tab:red', ls='-', lw=curve_lw)

    ax.set_xlabel(r'$X \sim U([-1, 1])$')
    ax.set_ylabel('Output')
    ax.xaxis.set_label_coords(0.5, -0.05)
    ax.set_xticks([-1, 1])
    ax.set_yticks([0, 0.5, 1])

    offset_left = 0.05
    offset_right = 0.05

    # Plateau values on the right side of the axes.
    for text, level, color in (('0.6', 0.6, 'tab:red'),
                               ('0.7', 0.7, 'black'),
                               ('0.8', 0.8, 'tab:red')):
        ax.annotate(text, xy=(1, level), xytext=(1 + offset_right, level),
                    color=color, va='center', ha='left', fontsize=text_size)

    # Plateau values on the left side, linked to the axes by a headless arrow.
    tick_line = dict(arrowstyle='->, head_length=0, head_width=0', lw=1)
    for text, level, color in (('0.1', 0.1, 'tab:red'),
                               ('0.2', 0.2, 'black'),
                               ('0.3', 0.3, 'tab:red')):
        ax.annotate(text, xy=(-1, level), xytext=(-1 - offset_left, level),
                    color=color, va='center', ha='right', fontsize=text_size,
                    arrowprops=tick_line)

    if not with_arrow:
        ax.annotate('decision threshold', xy=(0, 0.5), xytext=(0, 0.49),
                    color='black', va='top', ha='center',
                    fontsize=threshold_text_size)
    else:
        anchor = (-0.55, 0.43)
        ax.annotate('decision threshold', xy=anchor, xytext=anchor,
                    color='black', va='center', ha='left',
                    fontsize=threshold_text_size)
        arrow = patches.FancyArrowPatch(
            anchor, (-.98, 0.495),
            connectionstyle="arc3,rad=-.22",
            arrowstyle="Simple, tail_width=0.01, head_width=2, head_length=3",
            color="k",
        )
        ax.add_patch(arrow)

    ax.legend(loc='lower center', bbox_to_anchor=(0.5, .81), ncol=2)
    ax.set_xlim(-1, 1)

    return fig


def plot_fig_renditions(df, x, y, hue=None, z=None, figsize=None):
    """Strip plot of `x` for each level of `y`, with per-level means.

    One row is drawn per level of `y`, ordered by decreasing mean of `x`.
    Individual points are colored by `hue`; the mean of each row is overlaid
    as a black point. Every other row gets a light gray background and a
    vertical line marks x = 0.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot.
    x : str
        Name of the numeric column plotted on the x-axis.
    y : str
        Name of the categorical column defining the rows.
    hue : str or None
        Name of the column used to color the points. If None, no hue is
        used.
    z : str or None
        Name of the numeric column whose per-hue mean (decreasing) defines
        the order of the hue levels. If `hue` or `z` is None, seaborn's
        default hue order is used.
    figsize : unused
        Has no effect; kept for backward compatibility.

    Returns
    -------
    fig : the matplotlib figure created by seaborn.
    """
    set_latex_font()

    # Average only the plotted column: averaging the whole frame fails on
    # non-numeric columns with recent pandas versions.
    dfgb = df.groupby(y)[x].mean().reset_index()
    order = dfgb.sort_values(x, ascending=False)[y].tolist()

    if hue is not None and z is not None:
        hue_means = df.groupby(hue)[z].mean()
        hue_order = hue_means.sort_values(ascending=False).index.tolist()
    else:
        hue_order = None

    g = sns.catplot(data=df, x=x, y=y, hue=hue, seed=0, order=order,
                    hue_order=hue_order)
    fig = g.figure
    ax = fig.axes[0]

    # Overlay the per-row means.
    sns.stripplot(data=dfgb, x=x, y=y, color='black', ax=ax, jitter=0, order=order)

    n_renditions = dfgb.shape[0]

    ax.set_ylim((n_renditions-0.5, -0.5))

    # Add gray layouts in the background every other rows
    for k in range(0, n_renditions, 2):
        ax.axhspan(k-0.5, k+0.5, color='.93', zorder=-1)

    ax.axvline(0, color='darkgray', lw=1, zorder=0)
    ax.set_ylabel('Rendition')
    ax.set_xlabel(r'$\widehat \mathcal{L}_{GL}^{rendition} - \widehat \mathcal{L}_{GL}^{all}$')

    return fig


def plot_cost_vs_gl(df, x='gl', y='error_bin', hue='estimator', style=None, bin=True):
    """Scatter `y` against `x` with a LOWESS trend line.

    All rows are drawn semi-transparent; rows with a non-negative `x` are
    drawn again fully opaque on top and are the ones listed in the legend,
    which is placed outside the axes on the right. The black curve is a
    LOWESS fit (frac=0.75) computed on all rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot.
    x, y : str
        Names of the columns plotted on the x- and y-axis.
    hue : str
        Name of the column mapped to the marker color.
    style : str or None
        Name of the column mapped to the marker style, if any.
    bin : bool
        Selects which pair of axis labels is used (with or without the
        "bin" subscript).

    Returns
    -------
    fig : the matplotlib figure.
    """
    set_latex_font()
    plt.rc(
        'legend',
        fontsize=10,
        handletextpad=0.5,
        columnspacing=1.3,
        borderpad=0.3,
        borderaxespad=0.2,
        handlelength=1.5,
        labelspacing=0.3,
    )
    for tick_group in ('xtick', 'ytick'):
        plt.rc(tick_group, labelsize=7)
    plt.rc('axes', labelsize=11)

    hue_levels = df[hue].unique()
    style_levels = None if style is None else df[style].unique()

    print(hue_levels)

    nonneg_rows = df.query(f'{x} >= 0')

    fig, ax = plt.subplots(1, 1, figsize=(2.5, 2))

    # Settings shared by the faded and the opaque scatter layers.
    shared = dict(
        x=x,
        y=y,
        hue=hue,
        style=style,
        ax=ax,
        palette=None,
        hue_order=hue_levels,
        style_order=style_levels,
        s=40,
    )
    sns.scatterplot(data=df, legend=None, alpha=0.4, **shared)
    sns.scatterplot(data=nonneg_rows, legend='full', **shared)
    ax.legend(loc='upper left', bbox_to_anchor=(1, 1.02))

    if bin:
        xlabel = r'$\widehat \mathcal{L}_{GL,bin}$'
        ylabel = r'$|C^{\star}_{bin} - \hat C_{bin}|$'
    else:
        xlabel = r'$\widehat \mathcal{L}_{GL}$'
        ylabel = r'$|C^{\star} - \hat C|$'
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    trend = sm.nonparametric.lowess(df[y], df[x], frac=0.75)
    ax.plot(trend[:, 0], trend[:, 1], lw=1.5, color='black')

    return fig


def plot_renditions_calibration(df, x='diff', y='rendition', hue='net'):
    """Strip plot of `x` for each level of `y`, colored by `hue`.

    Rows are ordered by decreasing mean of `x`, and the mean of each row is
    overlaid as a black point. A vertical line marks x = 0 and every other
    row gets a light gray background.

    Calling this function seeds numpy's global random generator (used to
    make the plot reproducible) and modifies the matplotlib rcParams.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot.
    x : str
        Name of the numeric column plotted on the x-axis.
    y : str
        Name of the categorical column defining the rows.
    hue : str
        Name of the column used to color the points.

    Returns
    -------
    fig : the matplotlib figure created by seaborn.
    """
    set_latex_font()

    # Average only the plotted column: averaging the whole frame fails on
    # non-numeric columns (e.g. the hue column) with recent pandas versions.
    dfgb = df.groupby(y)[x].mean().reset_index()
    order = dfgb.sort_values(x, ascending=False)[y].tolist()

    plt.rc('legend', fontsize=11)
    plt.rc('legend', title_fontsize=12)
    plt.rc('legend', handletextpad=0.5)
    plt.rc('legend', columnspacing=1.3)
    plt.rc('legend', borderpad=0.3)
    plt.rc('legend', borderaxespad=0.2)
    plt.rc('legend', handlelength=1.5)
    plt.rc('legend', labelspacing=0.1)
    plt.rc('xtick', labelsize=7)
    plt.rc('ytick', labelsize=11)
    plt.rc('axes', labelsize=12)

    n_renditions = len(np.unique(df[y]))
    np.random.seed(0)
    g = sns.catplot(data=df, x=x, y=y, hue=hue, order=order, height=3.5)
    fig = g.figure
    ax = fig.axes[0]
    # Overlay the per-row means.
    sns.stripplot(data=dfgb, x=x, y=y, color='black', ax=ax, jitter=0, order=order)

    ax.axvline(0, color='darkgray', lw=1, zorder=0)
    # NOTE(review): the upper limit is 0.5 (not -0.5), which presumably
    # leaves the first row of `order` out of view -- confirm it is intended.
    ax.set_ylim((n_renditions-0.5, 0.5))

    # Add gray layouts in the background every other rows
    for k in range(1, n_renditions, 2):
        ax.axhspan(k-0.5, k+0.5, color='.93', zorder=-1)

    ax.set_xlabel(r'$\bar{C}_{rendition} - \bar{C}_{all}$')
    ax.set_ylabel('Renditions')
    # There is no legend to rename when no hue is given.
    if g.legend is not None:
        g.legend.set_title('Network')

    return fig


def plot_Q_vs_S_ex(ex, n=1000, ax=None):
    """Scatter Q against S for samples drawn from `ex`, with marginals.

    Points (S, Q) are colored by their label (orange when truthy, blue
    otherwise) and small black dots show (S, C). Histograms of S and Q
    over 10 equal-width bins of [0, 1] are attached on the top and on the
    right, and dashed vertical lines mark the bin edges.

    Parameters
    ----------
    ex : object
        Must provide `generate_X_y(n)` returning `(X, y)` and the callables
        `S`, `C` and `Q` taking `X`.
    n : int
        Number of samples requested from `ex.generate_X_y`.
    ax : matplotlib axes or None
        Axes to draw on. A new 3x3 inches figure is created if None.

    Returns
    -------
    fig : the matplotlib figure holding `ax`.
    """
    set_latex_font()

    X, y = ex.generate_X_y(n)
    S, C, Q = (np.squeeze(values) for values in (ex.S(X), ex.C(X), ex.Q(X)))

    if ax is None:
        _, ax = plt.subplots(figsize=(3, 3))
    fig = ax.figure

    orange, blue = to_rgb('tab:orange'), to_rgb('tab:blue')
    point_colors = [orange if label else blue for label in y]
    ax.scatter(S, Q, c=point_colors, edgecolor='white', linewidth=0.5)
    ax.scatter(S, C, color='black', linewidth=0.5, marker='.', s=5)

    pad = 0.05
    ax.set_xlim((-pad, 1 + pad))
    ax.set_ylim((-pad, 1 + pad))

    divider = make_axes_locatable(ax)
    ax_top = divider.append_axes("top", size="10%", pad=0.0)
    ax_top.set_xlim(ax.get_xlim())
    ax_right = divider.append_axes("right", size="10%", pad=0.0)
    ax_right.set_ylim(ax.get_ylim())

    # Strip the marginal axes down to the spine shared with the main axes.
    for marginal_ax, kept_spine in ((ax_top, 'bottom'), (ax_right, 'left')):
        marginal_ax.get_xaxis().set_visible(False)
        marginal_ax.get_yaxis().set_visible(False)
        for side in ('right', 'top', 'left', 'bottom'):
            if side != kept_spine:
                marginal_ax.spines[side].set_visible(False)

    bin_edges = np.linspace(0, 1, 10 + 1)
    hist_style = dict(
        bins=bin_edges,
        density=False,
        histtype="bar",
        color='silver',
        edgecolor='black',
    )
    ax_top.hist(S, linewidth=1, **hist_style)
    ax_right.hist(Q, orientation='horizontal', **hist_style)

    for edge in bin_edges:
        ax.axvline(edge, lw=0.5, ls='--', color='grey', zorder=-1)

    ax.set_aspect('equal')
    ax.set_xlabel('Confidence score $S$')
    ax.set_ylabel('True probability $Q$')

    return fig


def plot_QS_1D_ex(ex, disc_gap=0.1, figsize=(2, 2), ax=None, N=100,
                   bbox_to_anchor=(1, 0), loc='lower right', lw=1, frameon=True):
    """Plot S(X) and Q(X) of a 1D example with the density of X on top.

    Parameters
    ----------
    ex : object
        Must provide `x_min`, `x_max`, `dist()` (whose result has a `pdf`
        method) and the callables `S` and `Q`.
    disc_gap : float
        Passed as `min_gap` to `insert_nan_at_discontinuities` for both
        curves.
    figsize : tuple
        Size of the figure created when `ax` is None.
    ax : matplotlib axes or None
        Axes to draw on. A new figure is created if None.
    N : int
        Number of evaluation points between `ex.x_min` and `ex.x_max`.
    bbox_to_anchor, loc, frameon
        Forwarded to the legend of the main axes.
    lw : float
        Line width of the S and Q curves.

    Returns
    -------
    fig : the matplotlib figure holding `ax`.
    """
    set_latex_font()

    grid = np.linspace(ex.x_min, ex.x_max, N)
    density = ex.dist().pdf(grid)
    S = ex.S(grid)
    Q = ex.Q(grid)

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    fig = ax.figure

    # Thin axes above the main one showing the density of X.
    ax_top = make_axes_locatable(ax).append_axes("top", size="10%", pad=0.0)
    ax_top.plot(grid, density, label='$\\mathbb{P}_X$', lw=0.5)
    ax_top.fill_between(grid, density, alpha=0.2)
    ax_top.get_yaxis().set_visible(False)
    for side in ('right', 'top', 'left'):
        ax_top.spines[side].set_visible(False)
    ax_top.set_xticklabels([])
    _, density_top = ax_top.get_ylim()
    ax_top.set_ylim((0, density_top))

    ax.axhline(0.5, color='black', lw=0.5)
    xs, ys = insert_nan_at_discontinuities(grid, S, min_gap=disc_gap)
    ax.plot(xs, ys, label='$S(X)$', color='black', lw=lw)
    xs, ys = insert_nan_at_discontinuities(grid, Q, min_gap=disc_gap)
    ax.plot(xs, ys, label='$Q(X)$', color='tab:red', ls='-', lw=lw)
    ax.set_xlabel('$X$')
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '$\\frac{1}{2}$', '1'])

    # Single legend gathering the curves first, then the density.
    top_handles, top_labels = ax_top.get_legend_handles_labels()
    main_handles, main_labels = ax.get_legend_handles_labels()
    ax.legend(
        handles=main_handles + top_handles,
        labels=main_labels + top_labels,
        bbox_to_anchor=bbox_to_anchor,
        loc=loc,
        frameon=frameon,
    )
    ax.spines['top'].set_visible(False)

    return fig


def plot_GL_bounds(df, x):
    """Plot GL and its upper bounds against `x`, with accuracies on top.

    The main axes shows the columns 'UB_known', 'UB_acc_sq', 'UB_ER' and
    'GL' of `df` as functions of `df[x]`, with dashed horizontal guides at
    1/4 and 1/16. A smaller axes on top shows the columns 'acc_s',
    'acc_bayes' and 'acc_s_wrt_q'.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the column `x` and the seven columns listed above.
    x : str
        Name of the column used as abscissa. The x-axis label is this name
        prefixed by a backslash in math mode (e.g. 'alpha' is rendered as
        the Greek letter), so it should be a valid math command name.

    Returns
    -------
    fig : the matplotlib figure.
    """
    set_latex_font()

    fig, ax = plt.subplots(1, 1, figsize=(5, 4))

    # (column, legend label, color, extra plot options); GL is drawn below
    # the bounds.
    bound_curves = [
        ('UB_known', r'$\mathrm{UB}_{known}$', 'tab:blue', {}),
        ('UB_acc_sq', r'$\mathrm{UB}_{acc}$', 'tab:red', {}),
        ('UB_ER', r'$\mathrm{UB}_{ER}$', 'tab:green', {}),
        ('GL', r'$\mathrm{GL}$', 'black', {'zorder': 0}),
    ]
    for column, label, color, extra in bound_curves:
        ax.plot(df[x], df[column], label=label, color=color, marker='.', **extra)

    ax.legend()
    ax.set(xlabel=f'$\\{x}$', ylabel='Value')

    for level in (1/4, 1/16):
        ax.axhline(level, lw=0.25, ls='--', color='grey', zorder=-1)

    divider = make_axes_locatable(ax)
    ax_top = divider.append_axes("top", size="30%", pad=0.05)

    ax_top.set_xlim(ax.get_xlim())
    ax_top.get_xaxis().set_visible(False)

    accuracy_curves = [
        ('acc_s', r'$\mathrm{Acc}(S)$', 'tab:purple'),
        ('acc_bayes', r'$\mathrm{Acc}(Q)$', 'tab:pink'),
        ('acc_s_wrt_q', r'$\mathrm{Acc}(S\|\|Q)$', 'tab:brown'),
    ]
    for column, label, color in accuracy_curves:
        ax_top.plot(df[x], df[column], label=label, color=color, marker='.')
    ax_top.set(ylabel='Accuracy')
    ax_top.legend(ncol=3)

    return fig
