"""Code that generates the figures of the paper."""
import os
from os.path import join
import pandas as pd
import numpy as np
import pytest
import torch
import matplotlib.pyplot as plt
import itertools
import re
from joblib import Memory
import shutil
from pathlib import Path
from sklearn.utils import check_random_state

from .._plot import plot_frac_pos_vs_scores, barplot_ece_gl_brier, plot_lower_bound_vs_acc, plot_lower_bound_vs_brier, plot_lower_bound_vs_ece, plot_brier_acc, plot_brier_ece, plot_fig_theorem, plot_fig_counter_example, barplot_ece_gl, barplots_ece_gl_cal, plot_fig_renditions, plot_cost_vs_gl, plot_renditions_calibration
from .._utils import save_fig, compute_classif_metrics
from ._utils import binarize_multiclass_max, compute_calib_metrics, compute_multi_classif_metrics, calibrate_scores, bin_train_test_split, binarize_multiclass_marginal, calibration_curve
from .main import cluster_evaluate_marginals, cluster_evaluate_max
from .networks import ZeroShotBartYahoo, IMAGENET_ConvNeXt, IMAGENET_AlexNet, IMAGENET_VGG, IMAGENET_ResNet, IMAGENET_DenseNet, IMAGENET_Inception, IMAGENET_GoogLeNet, IMAGENET_ShuffleNet, IMAGENET_MobileNet, IMAGENET_ResNext, IMAGENET_WideResNet, IMAGENET_MNASNet, IMAGENET_EfficientNet, IMAGENET_RegNet, IMAGENET_VisionTransformer, IMAGENET_MLPMixer, Amazon_BertBase, Emotion_DistilBert, ALL_IMAGENET_NETWORKS
from .results import _get_out_kwargs
from .test_results import read_los_prediction_results2
from ..CalibrationExample import SigmoidExample
from .._linalg import create_orthonormal_vector
from .._plot import plot_ffstar_2d_v2
from ..CalibrationExample import CustomUniform, CustomUnconstrained
from .._plot import plot_ffstar_1d, plot_score_vs_probas2

# Disk-backed joblib cache stored in the 'joblib_cache' directory (a relative
# path); used further down through the `@memory.cache()` decorator.
memory = Memory('joblib_cache')


# Maps an ImageNet network class to one version identifier. Every identifier
# used here is also a key of `rename_versions`.
# NOTE(review): "best" presumably means the best-performing variant of each
# architecture -- confirm against the code that consumes this table.
best_versions = {
    IMAGENET_VGG: '19_bn',
    IMAGENET_ResNet: '152',
    IMAGENET_DenseNet: '161',
    IMAGENET_ShuffleNet: '1_0',
    IMAGENET_MobileNet: 'v3L',
    IMAGENET_ResNext: '101',
    IMAGENET_WideResNet: '101',
    IMAGENET_MNASNet: '1_0',
    IMAGENET_EfficientNet: 'b7',
    IMAGENET_RegNet: 'y_32gf',
    IMAGENET_VisionTransformer: 'l_16',
    IMAGENET_ConvNeXt: 'large',
}

# Maps a version identifier to its display string. Every non-empty value
# starts with '-' or ' ', presumably so that it can be appended directly to a
# network name -- TODO confirm against the caller. The empty identifier maps
# to an empty string.
rename_versions = {
    '19_bn': '-19 BN',
    '50': '-50',
    '11': '-11',
    '18': '-18',
    '152': '-152',
    '161': '-161',
    '121': '-121',
    '1_0': ' 1.0',
    '0_5': ' 0.5',
    'v3L': ' V3L',
    # Alternative, longer label previously considered for 'v3L': ' V3 Large'.
    'v2': ' V2',
    '101': '-101',
    'b7': '-B7',
    'b0': '-B0',
    'l_16': ' L-16',
    'b_16': ' B-16',
    'large': ' Large',
    'y_400mf': ' y_400mf',
    'y_32gf': ' y_32gf',
    '': '',
}


# Not used
def test_fig_showcase(out, n_jobs):
    """Render the showcase figure for zero-shot BART on the Yahoo unseen split.

    The tensors ``Xt.pt``, ``y_scores.pt`` and ``y_labels.pt`` are read from
    the network's default directory and handed to
    ``cluster_evaluate_marginals`` for the class of index 1. The first three
    returned values (fraction of positives, counts, mean scores) are drawn
    with ``plot_frac_pos_vs_scores`` and the figure is passed to ``save_fig``.

    Parameters
    ----------
    out
        Output location, forwarded unchanged to ``save_fig``.
    n_jobs
        Parallelism setting, forwarded to ``cluster_evaluate_marginals``.
    """
    positive_class = 1  # index of the class treated as positive

    # Settings used by the evaluation, the plot and the fields given to save_fig.
    bin_count = 10
    cluster_method = 'decision_tree'
    held_out_fraction = 0.9
    clusters_per_bin = 2
    show_hist = False
    conf_interval = None
    varying_sizes = False
    errorbar_capsize = 4

    network = ZeroShotBartYahoo(split='test_unseen')
    data_dir = network.get_default_dirpath()
    # Loaded lazily in this order: Xt, then scores, then labels.
    Xt, y_scores, y_labels = (
        torch.load(join(data_dir, filename))
        for filename in ('Xt.pt', 'y_scores.pt', 'y_labels.pt')
    )
    positive_name = network.get_class_names()[positive_class]

    # Only the first three outputs are needed; the remainder is discarded.
    frac_pos, counts, mean_scores, *_ignored = cluster_evaluate_marginals(
        Xt, y_labels, y_scores,
        positive_class=positive_class,
        bins=bin_count,
        clustering=cluster_method,
        test_size=held_out_fraction,
        min_samples_leaf=None,
        max_clusters_bin=clusters_per_bin,
        verbose=2,
        n_jobs=n_jobs,
    )

    plot_options = dict(
        y_scores=None,
        y_labels=None,
        ncol=1,
        legend_loc='upper left',
        title=None,
        xlim_margin=0.05,
        ylim_margin=0.05,
        min_cluster_size=10,
        hist=show_hist,
        ci=conf_interval,
        legend_cluster_sizes=False,
        vary_cluster_size=varying_sizes,
        capsize=errorbar_capsize,
        xlabel='Confidence score of positive class',
        ylabel='Fraction of positives',
        cluster_size=80,
        absolute_size_scale=(10, 5000),
        plot_cal_hist=True,
        figsize=(2.5, 2.5),
    )
    figure = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores,
                                     **plot_options)

    # The insertion order of these fields is preserved because they are
    # forwarded as keyword arguments to save_fig.
    save_fields = {
        'clustering': cluster_method,
        'n_bins': bin_count,
        'ci': conf_interval,
        'net': str(network),
        'hist': show_hist,
        'test_size': held_out_fraction,
        'positive_class': positive_name,
        'vary_cluster_size': varying_sizes,
        'capsize': errorbar_capsize,
        'max_clusters_bin': clusters_per_bin,
    }
    field_order = ['net', 'clustering', 'test_size', 'max_clusters_bin']

    save_fig(figure, out, order=field_order, **save_fields)



# Not used
@pytest.mark.parametrize('calibrate', [None, 'isotonic'])
def test_fig_nlp_comparison(calibrate, out, n_jobs):
    """Plot top-label calibration metrics for several NLP networks.

    Metrics for every (network, calibration method) pair are computed once
    and stored in ``<out>/metrics_nlp.csv``; when that file already exists it
    is reused instead of being recomputed. The rows whose ``calibrate`` column
    matches the test parameter are then passed to the bar/scatter plotting
    helpers and each figure is handed to ``save_fig``.

    Parameters
    ----------
    calibrate : None or 'isotonic'
        Calibration setting whose rows are plotted.
    out
        Output directory (created if missing); also forwarded to ``save_fig``.
    n_jobs
        Parallelism setting, forwarded to ``cluster_evaluate_max``.
    """
    n_bins = 15
    clustering = 'decision_tree'
    test_size = 0.5
    max_clusters_bin = 2
    hist = True
    ci = 'clopper'
    min_cluster_size = 10
    breakout = False
    min_samples_leaf = None
    n_clusters = None

    bins = np.linspace(0, 1, n_bins+1)

    def compute_one():
        """Return a DataFrame with one metrics row per (network, method)."""
        nets = [
            ZeroShotBartYahoo(split='test_unseen'),
            ZeroShotBartYahoo(split='test_seen'),
            Amazon_BertBase(split='test'),
            Emotion_DistilBert(split='test'),
        ]

        dfs = []
        # `cal_method` is kept distinct from the outer `calibrate` parameter:
        # the table always contains both settings, whatever is plotted later.
        for net, cal_method in itertools.product(nets, [None, 'isotonic']):
            print(net, cal_method)
            dirpath = net.get_default_dirpath()
            Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
            y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()
            y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()

            y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)
            splitter = bin_train_test_split(y_pred_scores, test_size=test_size, n_splits=1, bins=n_bins, random_state=0)
            # Only the test indices are used downstream.
            _, test_idx = next(splitter)

            # Keep references to the raw multiclass outputs: the metrics below
            # are computed from them, not from the calibrated scores.
            _y_scores, _y_labels = y_scores, y_labels
            if cal_method is not None:
                y_scores, _ = calibrate_scores(y_scores, y_labels, method=cal_method,
                                               test_size=test_idx, max_calibration=True)
                y_labels = y_well_guess

            (frac_pos,
             counts,
             mean_scores,
             *_
             ) = cluster_evaluate_max(Xt, y_labels, y_scores, breakout=breakout,
                                      bins=n_bins, verbose=2,
                                      n_jobs=n_jobs,
                                      min_samples_leaf=min_samples_leaf,
                                      max_clusters_bin=max_clusters_bin,
                                      clustering=clustering,
                                      n_clusters=n_clusters,
                                      test_size=test_idx,
                                      binary_scores=cal_method is not None,
                                      )

            y_pred_scores, y_well_guess = binarize_multiclass_max(_y_scores, _y_labels)

            extra_out_kwargs = {
                'split': net.split,
                'dataset': net.get_dataset_name(),
                'network': net.get_class_name(False),
                'network+version': net.get_class_name(True),
                'calibrate': str(cal_method),
                'new': True,
            }
            name = str(net)
            order = []
            out_kwargs, order = _get_out_kwargs(
                clustering, n_bins, ci, name, hist, test_size,
                cal_method, max_clusters_bin, min_samples_leaf, n_clusters,
                min_cluster_size, extra_out_kwargs, order)
            out_kwargs['breakout'] = breakout

            metrics = {}
            metrics.update(out_kwargs)
            metrics.update(compute_multi_classif_metrics(_y_scores, _y_labels))
            metrics_binarized = compute_classif_metrics(y_pred_scores, y_well_guess)
            metrics_binarized.update(compute_calib_metrics(frac_pos, counts, y_pred_scores, y_well_guess, bins))
            # Prefix to tell binarized metrics apart from the multiclass ones.
            metrics_binarized = {f'binarized_{k}': v for k, v in metrics_binarized.items()}
            metrics.update(metrics_binarized)

            dfs.append(pd.DataFrame([metrics]))

        return pd.concat(dfs, axis=0, ignore_index=True)

    os.makedirs(out, exist_ok=True)
    path = join(out, 'metrics_nlp.csv')
    # The CSV acts as a cache: delete it to force a recomputation.
    if not os.path.exists(path):
        compute_one().to_csv(path)

    df = pd.read_csv(path, index_col=0)
    # The uncalibrated setting is stored as the string 'None'. Recent pandas
    # versions list "None" among the default NA markers of read_csv, so the
    # value can come back as NaN; restore it so the query below still matches.
    df['calibrate'] = df['calibrate'].fillna('None')

    print(df)

    df = df.query(f'calibrate == "{calibrate}"')

    print(df)

    ece = df['binarized_msce']
    # Negative debiased estimates are clipped to zero for plotting.
    gl_lower_bound_debiased = np.maximum(df['binarized_lower_bound_debiased'], 0)
    gl_lower_bound = df['binarized_lower_bound']
    brier = df['binarized_brier']
    acc = df['acc']
    net_names = df['network+version']
    splits = None  # no per-split marker style

    kwargs = {'cal': calibrate}

    fig = barplot_ece_gl_brier(net_names, ece, gl_lower_bound, gl_lower_bound_debiased, brier, acc)
    save_fig(fig, out, **kwargs)

    fig = plot_lower_bound_vs_acc(net_names, gl_lower_bound_debiased, acc, style=splits)
    save_fig(fig, out, which='lower_bound', **kwargs)

    fig = plot_lower_bound_vs_brier(net_names, gl_lower_bound_debiased, brier, style=splits)
    save_fig(fig, out, which='brier', **kwargs)

    fig = plot_lower_bound_vs_ece(net_names, gl_lower_bound_debiased, ece, style=splits)
    save_fig(fig, out, which='ece', **kwargs)

    fig = plot_brier_acc(net_names, brier, acc, style=splits)
    save_fig(fig, out, which='brieracc', **kwargs)

    fig = plot_brier_ece(net_names, brier, ece, style=splits)
    save_fig(fig, out, which='brierece', **kwargs)


# Not used
@pytest.mark.parametrize('calibrate', [None, 'isotonic'])
def test_fig_nlp_comparison_binary(calibrate, out, n_jobs):
    """Plot positive-class calibration metrics for zero-shot BART on Yahoo.

    Metrics for both Yahoo splits and both calibration settings are computed
    once and stored in ``<out>/metrics_nlp_yahoo.csv``; when that file already
    exists it is reused instead of being recomputed. The rows whose
    ``calibrate`` column matches the test parameter are then passed to the
    plotting helpers and each figure is handed to ``save_fig``.

    Parameters
    ----------
    calibrate : None or 'isotonic'
        Calibration setting whose rows are plotted.
    out
        Output directory (created if missing); also forwarded to ``save_fig``.
    n_jobs
        Parallelism setting, forwarded to ``cluster_evaluate_marginals``.
    """
    n_bins = 15
    clustering = 'decision_tree'
    test_size = 0.5
    max_clusters_bin = 2
    hist = True
    ci = 'clopper'
    min_cluster_size = 10
    breakout = False
    min_samples_leaf = None
    n_clusters = None

    bins = np.linspace(0, 1, n_bins+1)

    def compute_one():
        """Return a DataFrame with one metrics row per (network, method)."""
        nets = [
            ZeroShotBartYahoo(split='test_unseen'),
            ZeroShotBartYahoo(split='test_seen'),
        ]

        k = 1  # index of the class treated as positive

        dfs = []
        # `cal_method` is kept distinct from the outer `calibrate` parameter:
        # the table always contains both settings, whatever is plotted later.
        for net, cal_method in itertools.product(nets, [None, 'isotonic']):
            print(net, cal_method)
            dirpath = net.get_default_dirpath()
            Xt = torch.load(join(dirpath, 'Xt.pt')).numpy()
            y_scores = torch.load(join(dirpath, 'y_scores.pt')).numpy()
            y_labels = torch.load(join(dirpath, 'y_labels.pt')).numpy()

            splitter = bin_train_test_split(y_scores[:, k], test_size=test_size, n_splits=1, bins=n_bins, random_state=0)
            # Only the test indices are used downstream.
            _, test_idx = next(splitter)

            if cal_method is not None:
                y_scores, _ = calibrate_scores(y_scores, y_labels, method=cal_method,
                                               test_size=test_idx, max_calibration=False)

            (frac_pos,
             counts,
             mean_scores,
             *_
             ) = cluster_evaluate_marginals(Xt, y_labels, y_scores,
                                            positive_class=k,
                                            bins=n_bins,
                                            clustering=clustering,
                                            test_size=test_idx,
                                            min_samples_leaf=min_samples_leaf,
                                            max_clusters_bin=max_clusters_bin,
                                            verbose=2,
                                            n_jobs=n_jobs)

            # From here on only the positive-class column is needed.
            y_scores = y_scores[:, k]

            extra_out_kwargs = {
                'split': net.split,
                'dataset': net.get_dataset_name(),
                'network': net.get_class_name(False),
                'network+version': net.get_class_name(True),
                'calibrate': str(cal_method),
                'new': True,
            }
            name = str(net)
            order = []
            out_kwargs, order = _get_out_kwargs(
                clustering, n_bins, ci, name, hist, test_size,
                cal_method, max_clusters_bin, min_samples_leaf, n_clusters,
                min_cluster_size, extra_out_kwargs, order)
            out_kwargs['breakout'] = breakout

            metrics = {}
            metrics.update(out_kwargs)
            metrics.update(compute_classif_metrics(y_scores, y_labels))
            metrics.update(compute_calib_metrics(frac_pos, counts, y_scores, y_labels, bins))

            dfs.append(pd.DataFrame([metrics]))

        return pd.concat(dfs, axis=0, ignore_index=True)

    os.makedirs(out, exist_ok=True)
    path = join(out, 'metrics_nlp_yahoo.csv')
    # The CSV acts as a cache: delete it to force a recomputation.
    if not os.path.exists(path):
        compute_one().to_csv(path)

    df = pd.read_csv(path, index_col=0)
    # The uncalibrated setting is stored as the string 'None'. Recent pandas
    # versions list "None" among the default NA markers of read_csv, so the
    # value can come back as NaN; restore it so the query below still matches.
    df['calibrate'] = df['calibrate'].fillna('None')

    print(df)

    df = df.query(f'calibrate == "{calibrate}"')

    print(df)

    ece = df['msce']
    # Unlike the multiclass comparison, the debiased bound is not clipped at 0.
    gl_lower_bound_debiased = df['lower_bound_debiased']
    gl_lower_bound = df['lower_bound']
    brier = df['brier']
    acc = df['acc']
    net_names = df['network+version']
    splits = None  # no per-split marker style

    kwargs = {'cal': calibrate}

    fig = barplot_ece_gl_brier(
        net_names, ece, gl_lower_bound, gl_lower_bound_debiased, brier, acc,
        label_brier='Brier', label_lb='Lower bound',
        label_lbd='Lower bound debiased', label_ece='L2-ECE')
    save_fig(fig, out, **kwargs)

    fig = plot_lower_bound_vs_acc(net_names, gl_lower_bound_debiased, acc, style=splits)
    save_fig(fig, out, which='lower_bound', **kwargs)

    fig = plot_lower_bound_vs_brier(net_names, gl_lower_bound_debiased, brier, style=splits)
    save_fig(fig, out, which='brier', **kwargs)

    fig = plot_lower_bound_vs_ece(net_names, gl_lower_bound_debiased, ece, style=splits)
    save_fig(fig, out, which='ece', **kwargs)

    fig = plot_brier_acc(net_names, brier, acc, style=splits)
    save_fig(fig, out, which='brieracc', **kwargs)

    fig = plot_brier_ece(net_names, brier, ece, style=splits)
    save_fig(fig, out, which='brierece', **kwargs)


# Not used
def test_fig_mimic_cost_evaluation(out, n_jobs):
    """Plot the per-bin cost estimation error against the grouping loss (MIMIC).

    For every result folder, split seed and estimator, the predictions are
    read from ``mimic_dist/<folder>``, recalibrated, and evaluated per score
    bin with samples grouped by quantiles of ``n_codes_missing``. Each bin
    yields a grouping-loss lower bound and the absolute gap between the mean
    predicted cost and the mean oracle cost. The retained bins are drawn with
    ``plot_cost_vs_gl``: one figure for all estimators, then one per estimator.

    Parameters
    ----------
    out
        Output location, forwarded unchanged to ``save_fig``.
    n_jobs
        Parallelism setting, forwarded to ``cluster_evaluate_marginals``.
    """
    dirpath = 'mimic_dist'
    clustering = 'na_2'
    calibrate = 'isotonic'
    test_size = 0.5
    n_bins = 15
    max_clusters_bin = 2
    debiased = True

    # Folder name -> whether results exist for several split seeds.
    has_splits = {
            '_los_prediction_horizon7_discharged_shift_is_medicare_or_medicaid': False,
            '_los_prediction_horizon7_discharged_shift_is_first_unit_surgical': False,
            '_los_prediction_horizon7_discharged': True,
            '_los_prediction_horizon3_discharged': True,
    }
    folders = list(has_splits)

    estimators = [
        'hist_gradient_boosting_classifier',
        'random_forest_classifier',
        'lstm',
        'logistic_regression',
    ]

    # Not imported at module level, hence the function-scope import.
    from ._utils import grouping_loss_lower_bound

    def cost_binary(y_scores_pos, y_labels, is_na_pos,
                    cost_pos_na_pos=1,
                    cost_pos_na_neg=1,
                    cost_neg_na_pos=1,
                    cost_neg_na_neg=1,
                    pos_label=1):
        """Return the per-sample predicted and oracle costs.

        Each sample has a cost for the positive outcome and one for the
        negative outcome, both selected by its group indicator ``is_na_pos``.
        The predicted cost weighs them by the positive-class score, the
        oracle cost by the binary label (1 where ``y_labels == pos_label``).
        """
        assert y_scores_pos.ndim == 1
        assert is_na_pos.ndim == 1
        y_labels = np.array(y_labels == pos_label, dtype=int)

        cost_if_pos = cost_pos_na_pos*is_na_pos + cost_pos_na_neg*(1 - is_na_pos)
        cost_if_neg = cost_neg_na_pos*is_na_pos + cost_neg_na_neg*(1 - is_na_pos)

        predicted_costs = y_scores_pos*cost_if_pos + (1 - y_scores_pos)*cost_if_neg
        oracle_costs = y_labels*cost_if_pos + (1 - y_labels)*cost_if_neg

        return predicted_costs, oracle_costs

    # NOTE(review): the settings captured from the enclosing scope (dirpath,
    # clustering, calibrate, n_bins, max_clusters_bin, n_jobs) are not
    # arguments of the cached function; clear 'joblib_cache' after changing
    # them unless the cache is known to account for them.
    @memory.cache()
    def compute_one_(folder, estimator, rs, min_cluster_size=5, debiased=False,
                cost_pos_na_pos=100,
                cost_pos_na_neg=50,
                cost_neg_na_pos=1,
                cost_neg_na_neg=10,
                test_size=None,
                ):
        """Return one row per score bin for a (folder, estimator, seed) run.

        Columns: ``bins``, ``error_bin`` (absolute gap between mean predicted
        and mean oracle cost), ``gl`` (grouping-loss lower bound rescaled per
        bin), the boolean masks ``idx``, ``idx_pos`` and ``big_enough``, and
        the identifiers ``estimator_filename``, ``estimator``, ``folder``,
        ``rs``.
        """
        xp_path = join(dirpath, folder)
        xps = read_los_prediction_results2(Path(xp_path))

        estimator_filename = f'estimator={estimator}__split_rs={rs}'

        xp = xps[estimator_filename]
        X = xp['X_imputed']
        y_labels = xp['y']
        y_scores = xp['y_pred']
        strat = xp['stratification_features']

        # Number of quantile groups, parsed from 'na_<n>'. The whole run of
        # digits is captured (a repeated one-character group would only keep
        # the last digit, e.g. 0 for 'na_10').
        parsed = re.match(r'na_(\d+)', clustering)
        if parsed is None:
            raise ValueError(
                f"clustering must look like 'na_<int>', got {clustering!r}")
        n = int(parsed.group(1))

        # Assign each sample to one of n quantile groups of n_codes_missing
        # (group ids start at 1).
        _clustering = strat['n_codes_missing']
        value_min = np.min(_clustering)
        value_max = np.max(_clustering)
        qs = np.linspace(0, 1, n + 1)[1:-1]
        group_edges = [value_min - 1] + list(np.quantile(_clustering, qs)) + [value_max + 1]
        _clustering = np.digitize(_clustering, bins=group_edges)

        if test_size is not None:
            splitter = bin_train_test_split(y_scores[:, 1], test_size=test_size, n_splits=1, bins=n_bins, random_state=0)
            _, test_idx = next(splitter)
        else:
            test_idx = None

        # Post hoc calibration with the configured method.
        if calibrate is not None:
            y_scores, _ = calibrate_scores(y_scores, y_labels, method=calibrate,
                                           test_size=test_idx, max_calibration=False)

        (frac_pos,
         counts,
         mean_scores,
         *_
         ) = cluster_evaluate_marginals(X, y_labels, y_scores,
                                        positive_class=1,
                                        bins=n_bins,
                                        clustering=_clustering,
                                        test_size=test_idx,
                                        min_samples_leaf=None,
                                        max_clusters_bin=max_clusters_bin,
                                        verbose=2,
                                        n_jobs=n_jobs)

        gl = grouping_loss_lower_bound(
            frac_pos, counts, reduce_bin=False, debiased=debiased, return_bias=False
        )
        # Unnormalize GL: rescale each bin by (total count / bin count).
        gl = gl*np.divide(np.sum(counts), np.sum(counts, axis=1))
        y_scores_pos = y_scores[:, 1]  # score of the positive class
        is_na_pos = _clustering - 1  # group indicator, in {0, 1} when n == 2

        # Assign every sample to a score bin; a score of exactly 1 falls in
        # the last bin.
        score_edges = np.linspace(0, 1, n_bins + 1)
        y_bins = np.digitize(y_scores_pos, bins=score_edges) - 1
        y_bins = np.clip(y_bins, a_min=None, a_max=n_bins-1)

        predicted_costs_bin = np.full(n_bins, np.nan)
        oracle_costs_bin = np.full(n_bins, np.nan)

        for i in range(n_bins):
            in_bin = y_bins == i

            predicted_costs, oracle_costs = cost_binary(
                y_scores_pos[in_bin], y_labels[in_bin], is_na_pos[in_bin],
                cost_pos_na_pos=cost_pos_na_pos,
                cost_pos_na_neg=cost_pos_na_neg,
                cost_neg_na_pos=cost_neg_na_pos,
                cost_neg_na_neg=cost_neg_na_neg,
            )

            predicted_costs_bin[i] = np.mean(predicted_costs)
            oracle_costs_bin[i] = np.mean(oracle_costs)

        # A bin is "big enough" when at least one of its clusters reaches
        # min_cluster_size samples.
        big_enough = np.any(counts >= min_cluster_size, axis=1)
        idx_pos = gl >= 0
        idx = np.logical_and(idx_pos, big_enough)

        error_bin = np.absolute(predicted_costs_bin - oracle_costs_bin)

        return pd.DataFrame({
            'bins': np.arange(n_bins),
            'error_bin': error_bin,
            'gl': gl,
            'idx': idx,
            'idx_pos': idx_pos,
            'big_enough': big_enough,
            'estimator_filename': np.full(n_bins, estimator_filename),
            'estimator': np.full(n_bins, estimator),
            'folder': np.full(n_bins, folder),
            'rs': np.full(n_bins, rs)
        })

    dfs = []
    for folder in folders:
        splits = np.arange(5) if has_splits[folder] else [0]

        for split in splits:
            for estimator in estimators:
                df = compute_one_(folder, estimator, split, debiased=debiased, test_size=test_size)
                dfs.append(df)

    df = pd.concat(dfs, axis=0)

    print(df)

    rename_estimator = {
        'lstm': 'LSTM',
        'hist_gradient_boosting_classifier': 'Gradient boosting',
        'random_forest_classifier': 'Random Forest',
        'logistic_regression': 'Logistic',
    }

    rename_folder = {
            '_los_prediction_horizon7_discharged_shift_is_medicare_or_medicaid': r'$\pm7j$ shift insurance',
            '_los_prediction_horizon7_discharged_shift_is_first_unit_surgical': r'$\pm7j$ shift unit',
            '_los_prediction_horizon7_discharged': r'$\pm7j$',
            '_los_prediction_horizon3_discharged': r'$\pm3j$',
    }

    # Keep bins with a large-enough cluster and a grouping loss within
    # [0.001, 0.015].
    df = df.query('big_enough == True')
    df = df.query('gl >= 0.001')
    # Copy so that the column assignments below write to an independent frame
    # rather than to a filtered selection of the concatenated results.
    df = df.query('gl <= 0.015').copy()

    df['folder'] = df['folder'].replace(rename_folder)
    df['task'] = df['folder']
    # Snapshot taken before estimators are renamed: the per-estimator figures
    # below filter on the raw estimator identifiers.
    _df = df.copy()
    df['estimator'] = df['estimator'].replace(rename_estimator)
    df['fold_est'] = df['folder'] + ' ' + df['estimator']

    fig = plot_cost_vs_gl(df, x='gl', y='error_bin', hue='estimator', style='task')
    save_fig(fig, out, debiased=debiased, test_size=test_size)

    for estimator in estimators:
        subdf = _df.query('estimator == @estimator')
        fig = plot_cost_vs_gl(subdf, x='gl', y='error_bin', hue='task')
        save_fig(fig, out, est=estimator, debiased=debiased, test_size=test_size)
