import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.lines as lines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import seaborn as sns
from calibration._linalg import create_orthonormal_vector
from calibration._plot import (plot_ffstar_1d, plot_score_vs_probas2,
                               set_latex_font)
from calibration._utils import compute_classif_metrics, save_fig
from calibration.CalibrationExample import CustomUniform, SigmoidExample, Steps
from calibration.xp_nn_calibration._utils import compute_calib_metrics
from calibration.xp_nn_calibration.main import cluster_evaluate
from matplotlib.colors import LogNorm, Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_metrics(df, var, bayes_opt=None, rev=False):
    """Scatter pairs of metrics on a 6x4 grid, coloured by the column ``var``.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per experiment. Must contain ``var`` together with the columns
        ``acc``, ``acc_bayes``, ``rmsce``, ``auroc``, ``GL``, ``brier``,
        ``brier_bayes``, ``mse``, ``lower_bound_debiased`` and ``cost_error``.
        The frame is not modified: derived columns are added to a copy.
    var : str
        Column that drives the point colours and is used as the x-axis of the
        first panels.
    bayes_opt : optional
        When not None and ``df`` has a ``bayes_opt`` column, only the rows
        whose ``bayes_opt`` equals this value are plotted.
    rev : bool
        Selects the ``'RdBu'`` colormap instead of ``'RdBu_r'`` and decides
        whether missing ``var`` values are replaced by ``+inf`` (True) or
        ``-inf`` (False).

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the panels and the colorbar.
    """
    if bayes_opt is not None and 'bayes_opt' in df.columns:
        df = df.query('bayes_opt == @bayes_opt')

    # Work on a private copy: the columns added below and the NaN replacement
    # must neither leak into the caller's frame nor be written to a filtered
    # view (which triggers pandas' SettingWithCopy warnings).
    df = df.copy()
    df['bayes_rate'] = df['acc'] - df['acc_bayes']
    df['ece_l2'] = np.square(df['rmsce'])
    print(df)

    cmap = 'RdBu' if rev else 'RdBu_r'
    # The colour scale is computed before NaNs are replaced so that it spans
    # the finite values only.
    vmin = np.nanmin(df[var])
    vmax = np.nanmax(df[var])
    print(vmin, vmax)
    m = cm.ScalarMappable(norm=Normalize(vmin, vmax), cmap=cmap)

    # Missing values are pushed outside the normalised range.
    df[var] = df[var].fillna(np.inf if rev else -np.inf)
    colors = m.to_rgba(df[var])

    fig, axes = plt.subplots(6, 4, figsize=(10, 13),
                             gridspec_kw=dict(wspace=0.8, hspace=0.8))

    # One entry per panel, in row-major order; None leaves the panel empty.
    pairs = [
        # Row 1 (the 4th slot hosts the colorbar, see below).
        (var, 'bayes_rate'), (var, 'acc'), (var, 'auroc'), None,
        # Row 2.
        (var, 'ece_l2'), (var, 'GL'), (var, 'brier_bayes'), (var, 'brier'),
        # Row 3.
        ('GL', 'mse'), ('lower_bound_debiased', 'GL'), None, None,
        # Row 4.
        ('brier', 'bayes_rate'), ('acc', 'bayes_rate'),
        ('auroc', 'bayes_rate'), ('GL', 'bayes_rate'),
        # Row 5.
        ('brier', 'mse'), ('acc', 'mse'), ('auroc', 'mse'), ('GL', 'mse'),
        # Row 6.
        ('brier', 'cost_error'), ('acc', 'cost_error'),
        ('auroc', 'cost_error'), ('GL', 'cost_error'),
    ]

    rename_labels = {
        'bayes_rate': 'Bayes rate',
        'mse': 'MSE',
        'auroc': 'AUC',
        'cost_error': 'Cost error',
        'ece_l2': r'ECE $\ell^2$',
        'brier': 'Brier',
        'brier_bayes': 'Brier Bayes (IL)',
        'lower_bound_debiased': 'Lower bound debiased',
        'acc': 'Accuracy',
    }

    panels = fig.axes
    for i, ax in enumerate(panels):
        pair = pairs[i] if i < len(pairs) else None
        if pair is None:
            ax.axis('off')
            continue

        x, y = pair
        ax.scatter(df[x], df[y], color=colors)
        # Columns without a pretty name fall back to their identifier with
        # underscores turned into spaces.
        ax.set_xlabel(rename_labels.get(x, x).replace('_', ' '))
        ax.set_ylabel(rename_labels.get(y, y).replace('_', ' '))
        if 'ece' in y:
            # NOTE(review): the upper limit is stretched to 100x the
            # autoscaled one, which squeezes the points towards zero --
            # confirm this is the intended rendering.
            ylim = ax.get_ylim()
            ax.set_ylim(0, 100*ylim[1])

    # Re-enable the empty 4th panel and draw the colorbar inside it.
    cb_ax_id = 3
    panels[cb_ax_id].axis('on')
    cb = plt.colorbar(m, cax=panels[cb_ax_id])
    cb.ax.set_title(var)

    # Thin horizontal rule in figure coordinates.
    for y in [0.62]:
        fig.add_artist(lines.Line2D([0.05, .9], [y, y], color='silver', lw=0.5))

    return fig


def plot_simu(df, x='n_samples_per_cluster_per_bin', legend=True, only_strat='uniform', ax=None):
    """Plot the grouping-loss estimates of a simulation against ``x``.

    Parameters
    ----------
    df : pandas.DataFrame
        Simulation results with the columns ``LB_biased``, ``bias``,
        ``GL_ind``, ``GL``, ``strategy`` (``'uniform'`` or ``'quantile'``),
        ``trial``, ``n_size_one_clusters`` and the column named by ``x``.
        The frame is not modified: ``LB_est`` is added to a copy.
    x : str
        Column used for the x-axis. It is interpolated into ``query``
        expressions, so it must be a valid identifier there.
    legend : bool
        Whether to draw a legend, placed outside the axes at the upper right.
    only_strat : str or None
        Keep a single binning strategy. With None both strategies are drawn
        and distinguished by line style.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new 2.5x2 inch figure is created when omitted.

    Returns
    -------
    matplotlib.figure.Figure
        The figure owning ``ax``.
    """
    set_latex_font()
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(2.5, 2))
    else:
        fig = ax.figure

    # Copy first so the derived column does not leak into the caller's frame.
    df = df.copy()
    df['LB_est'] = df['LB_biased'] - df['bias'] - df['GL_ind']

    # Per strategy: largest x (per trial) at which size-one clusters still
    # occur, averaged over trials. NaN when a strategy has no such row.
    idx_means = {}
    for strategy in ('uniform', 'quantile'):
        subdf = df.query('n_size_one_clusters > 0 and strategy == @strategy')
        print(subdf)
        # The string 'max' is used rather than the builtin, which recent
        # pandas versions deprecate as an aggregation callable.
        idx = subdf.groupby('trial').aggregate({x: 'max'})
        print(idx)
        print(idx.mean())
        idx_means[strategy] = idx.mean().item()

    idx_mean_max = np.nanmax(list(idx_means.values()))

    _df = df.melt(
        id_vars=[x, 'strategy', 'trial', 'n_size_one_clusters'],
        value_vars=['LB_biased', 'bias', 'GL_ind', 'LB_est', 'GL'],
    )

    # Discard the x values at which LB_est is mainly negative.
    # NOTE(review): each per-x count is compared to half of ALL negative
    # LB_est rows of the strategy, not to the number of trials at that x --
    # confirm this is the intended meaning of "mainly negative".
    for strategy in ['uniform', 'quantile']:
        negatives = _df.query('variable == "LB_est" and value < 0 and strategy == @strategy')
        x_to_discard, counts = np.unique(negatives[x], return_counts=True)
        x_to_discard = x_to_discard[counts >= 0.5*negatives.shape[0]]
        print(x_to_discard)
        print(counts)
        print(_df.shape)
        _df = _df.query(f'{x} not in @x_to_discard or strategy != @strategy or variable != "LB_est"')
        print(_df.shape)

    # Drop the part of the bias / LB_est curves lying below the per-strategy
    # threshold; the other three curves are always kept in full.
    idx_uniform = int(np.nan_to_num(idx_means["uniform"]))
    idx_quantile = int(np.nan_to_num(idx_means["quantile"]))
    print(idx_quantile)
    print(idx_uniform)
    _df = _df.query(
        f'strategy == "uniform" and {x} >= {idx_uniform}'
        f' or '
        f'strategy == "quantile" and {x} >= {idx_quantile}'
        f' or '
        f'variable in ["LB_biased", "GL_ind", "GL"]'
        )

    if only_strat is not None:
        _df = _df.query(f'strategy == "{only_strat}"')

    # `query` may hand back a view-like frame; copy before assigning columns.
    _df = _df.copy()
    _df['strategy'] = _df['strategy'].replace({
        'uniform': 'Equal-width',
        'quantile': 'Equal-mass',
    })
    var_replace = {
        'LB_biased': r'$\widehat{\mathrm{GL}}_{\textit{\scriptsize plugin}}$',
        'bias': r'$\widehat{\mathrm{GL}}_{\textit{\scriptsize bias}}$',
        'GL_ind': r'$\widehat{\mathrm{GL}}_{\textit{\scriptsize induced}}$',
        'LB_est': r'$\widehat{\mathrm{GL}}_{\mathrm{LB}}$',
        'GL': r'True $\mathrm{GL}$',
    }
    _df['variable'] = _df['variable'].replace(var_replace)
    _df = _df.rename({'strategy': 'Binning'}, axis=1)

    # Line style encodes the binning strategy only when both are shown.
    # (No trailing comma here: it would turn style_order into a 1-tuple.)
    both_strategies = only_strat is None
    style = 'Binning' if both_strategies else None
    style_order = ['Equal-width', 'Equal-mass'] if both_strategies else None

    hue_order = [var_replace[v] for v in [
        'GL',
        'LB_biased',
        'bias',
        'GL_ind',
        'LB_est',
    ]]

    palette = ['black', 'tab:blue', 'tab:orange', 'tab:green', 'tab:red']

    sns.lineplot(data=_df, x=x, y='value', hue='variable',
                 style=style,
                 style_order=style_order,
                 ax=ax,
                 errorbar=("sd", 1),
                 legend='auto' if legend else False,
                 palette=palette,
                 hue_order=hue_order,
                 err_kws=dict(edgecolor='none'),
                 )

    if not np.isnan(idx_mean_max):
        # Shade the region |x| <= idx_mean_max, then restore the limits that
        # the shading would otherwise enlarge.
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.fill_betweenx([-1, 1], -idx_mean_max, idx_mean_max, color='lightgray', edgecolor='none', alpha=0.5, zorder=0)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

    if legend:
        handles, labels = ax.get_legend_handles_labels()
        print(plt.rcParams['legend.borderaxespad'])
        plt.rc('legend', borderaxespad=0.1)

        # With a single semantic, drop a leading subtitle entry named after
        # the hue column if one is present. Checking the label itself avoids
        # removing a real curve when no such subtitle entry was emitted.
        if style is None and labels and labels[0] == 'variable':
            handles = handles[1:]
            labels = labels[1:]

        _legend = ax.legend(handles=handles, labels=labels, ncol=1, bbox_to_anchor=(1, 1), loc='upper left', fancybox=False, framealpha=1)
        _legend.get_frame().set_linewidth(0)
    ax.set(ylabel=None)

    return fig


def plot_fig_binning(N=1000, n_bins=2):
    """Illustrate the binning of the calibration curve C(s) = s**2 on [0, 1].

    The figure shows the curve, its per-bin average drawn as a dashed step,
    a marker at the centre of every bin on the x-axis, the interior bin
    edges, and the area between the curve and its binned version.

    Parameters
    ----------
    N : int
        Number of points sampling the curve; each bin's shaded area uses
        ``N // n_bins`` points.
    n_bins : int
        Number of equal-width bins on [0, 1].

    Returns
    -------
    matplotlib.figure.Figure
        The newly created 1.8x1.8 inch figure.
    """
    set_latex_font()
    # Legend spacing is set through the global rc parameters.
    plt.rc('legend', borderpad=0.4, borderaxespad=0.1,
           columnspacing=1.2, handletextpad=0.5)

    fig, ax = plt.subplots(figsize=(1.8, 1.8))

    def curve(s):
        return np.square(s)

    def first_only(text, position):
        # Label only the first artist of a family so that the legend holds a
        # single entry for it.
        return text if position == 0 else None

    # Vertical guides at the bin edges (outer edges at 0 and 1 are skipped).
    draw_outer_edges = False
    edges = np.linspace(0, 1, n_bins+1)
    guides = edges if draw_outer_edges else edges[1:-1]
    for k, edge in enumerate(guides):
        ax.axvline(edge, lw=0.5, ls='--', color='grey', zorder=-1,
                   label=first_only('Bin edge', k))

    # The calibration curve itself.
    scores = np.linspace(0, 1, N)
    ax.plot(scores, curve(scores), color='black', label='$C$')

    # Per-bin artists: average level, centre marker and shaded gap.
    for k, (a, b) in enumerate(zip(edges[:-1], edges[1:])):
        # Mean of s**2 over [a, b].
        level = 1/(3*(b-a))*(b**3 - a**3)
        ax.plot([a, b], [level, level], color='black',
                label=first_only(r'$C_B$', k), ls='--')

        marker = ax.scatter((a+b)/2, 0, color='black',
                            label=first_only('$S_B$', k))
        # The marker sits on the axis line; keep it from being clipped.
        marker.set_clip_on(False)

        inside = np.linspace(a, b, N//n_bins)
        ax.fill_between(inside, curve(inside), [level]*len(inside),
                        color='tab:red',
                        label=first_only(r'$\mathrm{GL}_{induced}$', k),
                        edgecolor='none', zorder=-2)

    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Hand-placed x-axis label, slightly below the axes.
    offset = 0.08
    ax.annotate('$S$', xy=(0.5, -offset), xytext=(0.5, -offset),
                xycoords='axes fraction', ha='center', va='top',
                fontsize=plt.rcParams['legend.fontsize'],
                )

    def reorder(entries):
        # Rotate the last entry to the front, then swap the first two.
        entries = list(entries)
        entries = entries[-1:] + entries[:-1]
        entries[0], entries[1] = entries[1], entries[0]
        return entries

    handles, labels = ax.get_legend_handles_labels()
    _legend = ax.legend(handles=reorder(handles), labels=reorder(labels),
                        ncol=1, framealpha=0, loc='upper left',
                        bbox_to_anchor=(0, 1))
    _legend.get_frame().set_linewidth(0)
    return fig
