from pathlib import Path
import pickle
from matplotlib import pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.lines as mlines

import cfg
import pandas as pd
import numpy as np
from src.figures import (merge_results as merge_results,
                         figures as figures,
                         tables as tables,
                         result_table as result_table,
                         result_merge as result_merge)

# create a df with all available result files
# The guard is hard-coded to False, so this step only runs when it is switched on by hand.
if False:
    # The three calls run strictly in this order.
    # NOTE(review): presumably they produce the 'results.pkl' that the figure blocks below read — TODO confirm.
    result_table.create()
    result_table.approximate_learning_curve()
    result_merge.create()


# figure: schematic of an AL learning curve (single panel, written to 'sf_schematic.pdf')
if False:
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.3, 2.2))
    figures.fig_al_schematic(ax)
    fig.tight_layout()
    # explicit margins are applied after tight_layout
    fig.subplots_adjust(left=0.13, right=0.995, top=1, bottom=0.2)
    schematic_path = Path(cfg.path_fig) / 'sf_schematic.pdf'
    plt.savefig(schematic_path)
    plt.show()

"""
MSCOCO: SPECIFIC FIGURES
"""
# learning curve: MSCOCO 2x2 overview, every panel varies one entry of the baseline configuration
if False:
    results = pd.read_pickle(Path(cfg.path_exp, 'results.pkl'))
    nrows = 2
    ncols = 2
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes)
    datasets = ('mscoco', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 80, 'ALL')
    sampling_methods = ['random', 'ratio_max']
    # Line style of every plotted variant; the legend entries below are derived from the same
    # mapping so that curves and legend cannot drift apart.
    linestyle_dict = {2: ':',
                      20: '-',
                      200: '--',
                      'augmented': ':',
                      'sl': '-',
                      'semi-sl': '--',
                      'random': ':',
                      'tl': '-',
                      'self-sl': '--',
                      'frozen': '-',
                      'finetune': ':',
                      'finetune-last-2': '-.',
                      'finetune-last-5': '--'}

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5), sharex=True)
    for row in range(nrows):
        for col in range(ncols):
            panel = ax[row, col]
            legend_title = None
            legend_entries = []  # pairs of (key into linestyle_dict, legend label)

            if row == 0 and col == 0:
                # top-left: initial budget and query size are varied together
                for budget in [2, 20, 200]:
                    current_dataset = datasets[:2] + (budget, budget) + datasets[4:]
                    figures.fig_performance_over_training_samples(panel, current_dataset, results,
                                                                  sampling_methods,
                                                                  linestyle=linestyle_dict[budget])
                legend_title = 'Initial Budget / Query Size'
                legend_entries = [(2, '2'), (20, '20'), (200, '200')]
            elif row == 0 and col == 1:
                # top-right: training paradigm, with budget / query size fixed to 200
                for train_paradigm in cfg.all_train_paradigms:
                    current_dataset = datasets[:2] + (200, 200) + datasets[4:5] + (train_paradigm,) + datasets[6:]
                    figures.fig_performance_over_training_samples(panel, current_dataset, results,
                                                                  sampling_methods,
                                                                  linestyle=linestyle_dict[train_paradigm])
                legend_title = 'Training Paradigm'
                legend_entries = [('augmented', 'SL (incl. augmentations)'), ('sl', 'SL'), ('semi-sl', 'Semi-SL')]
            elif row == 1 and col == 0:
                # bottom-left: weight initialisation, always with 'finetune' training
                for weight_init in cfg.all_weight_inits:
                    current_dataset = datasets[:6] + (weight_init, 'finetune') + datasets[8:]
                    figures.fig_performance_over_training_samples(panel, current_dataset, results,
                                                                  sampling_methods,
                                                                  linestyle=linestyle_dict[weight_init])
                legend_title = 'Weight Initialization'
                legend_entries = [('random', 'Random'), ('tl', 'Transfer Learning'), ('self-sl', 'Self-SL')]
            elif row == 1 and col == 1:
                # bottom-right: training strategy
                for training in cfg.all_training:
                    current_dataset = datasets[:7] + (training,) + datasets[8:]
                    figures.fig_performance_over_training_samples(panel, current_dataset, results,
                                                                  sampling_methods,
                                                                  linestyle=linestyle_dict[training])
                legend_title = 'Training Strategy'
                legend_entries = [('finetune', 'Fine-Tune'), ('frozen', 'Frozen')]

            # The legend is built once per panel, after all curves of that panel are drawn.
            if legend_entries:
                handles = [mlines.Line2D([], [], color='black', linestyle=linestyle_dict[key], label=label)
                           for key, label in legend_entries]
                panel.legend(handles=handles, title=legend_title, frameon=False, loc='lower right')

            panel.spines['right'].set_visible(False)
            panel.spines['top'].set_visible(False)
            panel.set_xlim([0, 2000])
    plt.setp(ax[-1, :], xlabel='# training samples')
    plt.setp(ax[:, 0], ylabel='Macro F1 Score')
    fig.tight_layout()
    fig.show()

# metric change: MSCOCO, only the two top panels receive curves
if False:
    nrows, ncols = 2, 2
    results = pd.read_pickle(Path(cfg.path_exp) / 'results.pkl')
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes)
    datasets = ('mscoco', 'complete', 20, 20, 1400, 'sl', 'tl', 'frozen', 80, 'ALL')
    sampling_methods = ['random', 'ratio_max']

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5), sharex=True)
    # ax.flat walks the grid row by row: index 0 is the top-left panel, index 1 the top-right one
    for idx, panel in enumerate(ax.flat):
        if idx == 0:
            # budget and query size are varied together
            variants = [datasets[:2] + (budget, budget,) + datasets[4:] for budget in [2, 20, 200]]
        elif idx == 1:
            # training paradigm is varied
            variants = [datasets[:5] + (train_paradigm,) + datasets[6:] for train_paradigm in ['sl', 'semi-sl']]
        else:
            variants = []

        for current_dataset in variants:
            figures.fig_performance_metrics_over_training_samples(panel, current_dataset, results,
                                                                  sampling_methods)

        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        panel.set_xlim([700, 1400])
    plt.setp(ax[-1, :], xlabel='# training samples')
    plt.setp(ax[:, 0], ylabel='Macro F1 Score')
    fig.tight_layout()
    fig.show()

"""
2K FIGURES
"""
# figure: 2k datasets, all datasets, performance over iterations
if False:
    results = pd.read_pickle(Path(cfg.path_exp, 'results.pkl'))
    nrows = 2
    ncols = 2
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes)
    datasets = [('carina', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 37, 'ALL'),
                ('mscoco', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 80, 'ALL'),
                ('reuters', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 31, 'ALL'),
                ('scene', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 6, 'ALL'),]
    sampling_methods = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5), sharex=True)
    # One index per panel (row-major) is used for both the plotted data and the title.
    # The previous title lookup used row+col, which labelled three of the four panels
    # with the wrong dataset name.
    for idx, panel in enumerate(ax.flat):
        dataset_config = datasets[idx]
        figures.fig_performance_over_training_samples(panel, dataset_config, results, sampling_methods)
        panel.spines['right'].set_visible(False)
        panel.spines['top'].set_visible(False)
        panel.set_title(cfg.figures[dataset_config[0]]['text'])
        panel.set_xlim([0, 2000])
        if idx == 1:
            # legend only in the top-right panel
            panel.legend(frameon=False, labelspacing=0.2)
    plt.setp(ax[-1, :], xlabel='# training samples')
    plt.setp(ax[:, 0], ylabel='Macro F1 Score')
    fig.tight_layout()
    fig.show()

# figure: 2k single_label, all datasets, frac(AL_samples / rand_samples) over performance
if False:
    nrows, ncols = 2, 2
    results = pd.read_pickle(Path(cfg.path_exp) / 'results.pkl')
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes)
    datasets = [
        ('carina', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 37, 'ALL'),
        ('mscoco', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 80, 'ALL'),
        ('reuters', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 31, 'ALL'),
        ('scene', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 6, 'ALL'),
    ]
    sampling_methods = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']
    # hand-picked (xlim, ylim) per dataset
    axis_limits = {
        'carina': ([0.17, 0.35], [0.6, 1.5]),
        'mscoco': ([0.05, 0.45], [0.7, 1.35]),
        'reuters': ([0.1, 0.35], [0.87, 1.2]),
        'scene': ([0.09, 0.7], [0.4, 1.2]),
    }

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    # panels are visited row by row, matching the order of `datasets`
    for idx, panel in enumerate(ax.flat):
        figures.fig_frac_al_rand_over_performance(panel, datasets[idx], results, sampling_methods)
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        dataset = datasets[idx][0]
        panel.set_title(cfg.figures[dataset]['text'])
        if dataset in axis_limits:
            x_range, y_range = axis_limits[dataset]
            panel.set_xlim(x_range)
            panel.set_ylim(y_range)

        if idx == 1:
            # legend only in the top-right panel
            panel.legend(frameon=False, labelspacing=0.2)
    plt.setp(ax[-1, :], xlabel='Performance [Macro F1 Score]')
    fig.supylabel('# training samples AL / # training samples RANDOM')
    fig.tight_layout()
    fig.show()

# figure: 2k performance over number training samples with approximation
if False:
    results = pd.read_pickle(Path(cfg.path_exp, 'results.pkl'))
    nrows = 2
    ncols = 2
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes)
    datasets = [('carina', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 37, 'ALL'),
                ('mscoco', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 80, 'ALL'),
                ('reuters', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 31, 'ALL'),
                ('scene', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 6, 'ALL')]
    # one [baseline, query method] pair per dataset, in the same order as `datasets`
    sampling_methods = [['random', 'multilabel_simple_crw'],
                        ['random', 'ratio_max'],
                        ['random', 'kmeans'],
                        ['random', 'bald']]

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5), sharex=True)
    # One row-major index per panel selects data, sampling methods and title alike.
    # The previous title lookup used row+col and therefore mislabelled three panels.
    for idx, panel in enumerate(ax.flat):
        dataset_config = datasets[idx]
        figures.fig_learning_curve_approximation(panel, dataset_config, results, sampling_methods[idx])
        panel.spines['right'].set_visible(False)
        panel.spines['top'].set_visible(False)
        panel.set_title(cfg.figures[dataset_config[0]]['text'])

    plt.setp(ax[:, 0], ylabel='Performance')
    plt.setp(ax[-1, :], xlabel='# added training samples')
    fig.tight_layout()
    fig.show()

# figure: 2k Comparison different evaluation metrics
if False:
    nrows, ncols = 2, 2
    results = pd.read_pickle(Path(cfg.path_exp) / 'results.pkl')
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes)
    datasets = [
        ('carina', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 37, 'ALL'),
        ('mscoco', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 80, 'ALL'),
        ('reuters', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 31, 'ALL'),
        ('scene', '2k', 40, 20, 2000, 'sl', 'tl', 'frozen', 6, 'ALL'),
    ]
    # sampling_methods[i] belongs to datasets[i]
    sampling_methods = [
        ['random', 'multilabel_simple_crw'],
        ['random', 'ratio_max'],
        ['random', 'kmeans'],
        ['random', 'bald'],
    ]

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5), sharex=True)
    # row-major walk over the grid; idx addresses both parallel lists
    for idx, panel in enumerate(ax.flat):
        figures.fig_performance_metrics_over_training_samples(panel, datasets[idx], results,
                                                              sampling_methods[idx])
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        panel.set_title(cfg.figures[datasets[idx][0]]['text'])

    plt.setp(ax[:, 0], ylabel='Performance Metric Change [%]')
    plt.setp(ax[-1, :], xlabel='# added training samples')
    fig.tight_layout()
    fig.show()



"""
COMPLETE FIGURES
"""
# figure: complete datasets, all datasets, performance over iterations
if False:
    nrows, ncols = 2, 2
    results = pd.read_pickle(Path(cfg.path_exp) / 'results.pkl')
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes)
    datasets = [
        ('carina', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 37, 'ALL'),
        ('mscoco', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 80, 'ALL'),
        ('reuters', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 31, 'ALL'),
        ('scene', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 6, 'ALL'),
    ]
    sampling_methods = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5), sharex=True)
    # panels are filled row by row, in the order of `datasets`
    for idx, panel in enumerate(ax.flat):
        figures.fig_performance_over_training_samples(panel, datasets[idx], results, sampling_methods)
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        panel.set_title(cfg.figures[datasets[idx][0]]['text'])
        panel.set_xlim([0, 1400])
        if idx == 1:
            # legend only in the top-right panel
            panel.legend(frameon=False, labelspacing=0.2)
    plt.setp(ax[-1, :], xlabel='# training samples')
    plt.setp(ax[:, 0], ylabel='Macro F1 Score')
    fig.tight_layout()
    fig.show()

# figure: complete performance over number training samples with approximation
if False:
    results = pd.read_pickle(Path(cfg.path_exp, 'results.pkl'))
    nrows = 2
    ncols = 2
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes)
    datasets = [('carina', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 's'),
                ('mscoco', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'orange'),
                ('reuters', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'earn'),
                ('scene', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'Sunset')]
    # one [baseline, query method] pair per dataset, in the same order as `datasets`
    sampling_methods = [['random', 'multilabel_simple_crw'],
                        ['random', 'ratio_max'],
                        ['random', 'kmeans'],
                        ['random', 'bald']]

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5), sharex=True)
    # One row-major index per panel selects data, sampling methods and title alike.
    # The previous title lookup used row+col and therefore mislabelled three panels.
    for idx, panel in enumerate(ax.flat):
        dataset_config = datasets[idx]
        figures.fig_learning_curve_approximation(panel, dataset_config, results, sampling_methods[idx])
        panel.spines['right'].set_visible(False)
        panel.spines['top'].set_visible(False)
        panel.set_title(cfg.figures[dataset_config[0]]['text'])

    # The figure is shown once, after labels and layout are final (the former
    # per-panel fig.show() displayed a half-finished figure on every iteration).
    plt.setp(ax[:, 0], ylabel='Performance')
    plt.setp(ax[-1, :], xlabel='# added training samples')
    fig.tight_layout()
    fig.show()

# figure: complete Comparison different evaluation metrics
if False:
    nrows, ncols = 2, 2
    results = pd.read_pickle(Path(cfg.path_exp) / 'results.pkl')
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes)
    datasets = [
        ('carina', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 's'),
        ('mscoco', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'orange'),
        ('reuters', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'earn'),
        ('scene', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'Sunset'),
    ]
    # sampling_methods[i] belongs to datasets[i]
    sampling_methods = [
        ['random', 'multilabel_simple_crw'],
        ['random', 'ratio_max'],
        ['random', 'kmeans'],
        ['random', 'bald'],
    ]

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5), sharex=True)
    # row-major walk over the grid; idx addresses both parallel lists
    for idx, panel in enumerate(ax.flat):
        figures.fig_performance_metrics_over_training_samples(panel, datasets[idx], results,
                                                              sampling_methods[idx])
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        panel.set_title(cfg.figures[datasets[idx][0]]['text'])

    plt.setp(ax[:, 0], ylabel='Performance Metric Change [%]')
    plt.setp(ax[-1, :], xlabel='# added training samples')
    fig.tight_layout()
    fig.show()




"""
PROPERTY FIGURES
"""
# figure: ALL samples, different evaluation metrics
if False:
    nrows, ncols = 2, 2
    # entries: (dataset, hashes, [baseline, query method], number, label)
    datasets_ml_rare_freq = [
        ('carina', ['1e38f5524602ce1e190fe32b26931a59'], ['random', 'multilabel_simple_crw'], 98663, 'g'),
        ('mscoco', ['2f0debf09caf973af3512632ff386f88'], ['random', 'ratio_max'], 118287, 'car'),
        ('reuters', ['4fbeb7601b0e58d7e5942b6aa618257e'], ['random', 'kmeans'], 53571, 'interest'),
        ('scene', ['a2d204afbc000092ba4f929212989359'], ['random', 'bald'], 2407, 'Sunset'),
    ]
    # The metric dictionary is loaded from a pickle; the disabled line below shows an alternative source.
    #dict_performance_metric = tables.table_al_metric_comparison(datasets_ml_rare_freq, 'multi_label')
    metric_pickle = Path(cfg.path_fig) / 'dict_performance_metric_change_ml.pkl'
    with open(metric_pickle, 'rb') as handle:
        dict_performance_metric = pickle.load(handle)

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    for idx, panel in enumerate(ax.flat):
        dataset, _, sampling_methods, _, _ = datasets_ml_rare_freq[idx]
        # NOTE(review): this call passes (dataset, 'multi_label', dict, method), whereas other calls of the
        # same function in this file pass (dataset tuple, results, methods) — verify against its signature.
        figures.fig_performance_metrics_over_training_samples(panel, dataset, 'multi_label',
                                                              dict_performance_metric, sampling_methods[1])

        panel.set_title(cfg.figures[dataset]['text'])
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)

    plt.setp(ax[:, 0], ylabel='Performance Metric Change [%]')
    plt.setp(ax[-1, :], xlabel='# added training samples')
    fig.tight_layout()
    fig.show()

# figure: 2k samples, different evaluation metrics
if False:
    nrows, ncols = 2, 2
    # entries: (dataset, hashes, [baseline, query method])
    datasets_sl_all = [
        ('carina', ['1b81389580c2de87bb6f790dfdd6bd13'], ['random', 'multilabel_simple_crw']),
        ('mscoco', ['a8592bba830413ec16f00f2e6ed9dec4'], ['random', 'ratio_max']),
        ('reuters', ['1af4bfdc5ffcca8d3d5df20ce543840a'], ['random', 'kmeans']),
        ('scene', ['a268ec081b548d36cc1ac961023431e9'], ['random', 'bald']),
    ]
    # The metric dictionary is loaded from a pickle; the disabled line below shows an alternative source.
    #dict_performance_metric = tables.table_al_metric_comparison(datasets_sl_all, 'single_label')
    metric_pickle = Path(cfg.path_fig) / 'dict_performance_metric_change_sl.pkl'
    with open(metric_pickle, 'rb') as handle:
        dict_performance_metric = pickle.load(handle)

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    for idx, panel in enumerate(ax.flat):
        dataset, hashs, sampling_methods = datasets_sl_all[idx]
        # NOTE(review): this call passes (dataset, 'single_label', dict, method), whereas other calls of the
        # same function in this file pass (dataset tuple, results, methods) — verify against its signature.
        figures.fig_performance_metrics_over_training_samples(panel, dataset, 'single_label',
                                                              dict_performance_metric, sampling_methods[1])

        panel.set_title(cfg.figures[dataset]['text'])
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)

    plt.setp(ax[:, 0], ylabel='Performance Metric Change [%]')
    plt.setp(ax[-1, :], xlabel='# added training samples')
    fig.tight_layout()
    fig.show()

# figure: Time_qm/time_rand over nr training samples 2k
if False:
    nrows = 2
    ncols = 2
    # entries: (dataset, hashes, [baseline, query method]); the third field is not used in this figure
    datasets_sl_all = [('carina', ['1b81389580c2de87bb6f790dfdd6bd13'], ['random', 'multilabel_simple_crw']),
                       ('mscoco', ['a8592bba830413ec16f00f2e6ed9dec4'], ['random', 'ratio_max']),
                       ('reuters', ['1af4bfdc5ffcca8d3d5df20ce543840a'], ['random', 'kmeans']),
                       ('scene', ['a268ec081b548d36cc1ac961023431e9'], ['random', 'bald'])]
    sampling_methods_all = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']
    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    # panels are visited row by row, matching the order of datasets_sl_all
    for idx, panel in enumerate(ax.flat):
        dataset, hashs, _ = datasets_sl_all[idx]
        print(dataset)  # progress output

        figures.fig_processing_time_over_training_samples(panel, 'single_label', dataset, hashs,
                                                          sampling_methods_all, tag='evaluation')
        panel.spines['right'].set_visible(False)
        panel.spines['top'].set_visible(False)
        panel.set_title(cfg.figures[dataset]['text'])
        panel.set_xlabel('# added training samples')
        panel.set_ylabel('$t_{qm}/t_{rand}$')
        # the time ratio is drawn on a logarithmic axis
        panel.set_yscale('log')

    fig.tight_layout()
    fig.show()

# figure: Time_qm/time_rand over nr training samples ALL DATA
if False:
    nrows = 2
    ncols = 2
    # entries: (dataset, hashes)
    datasets_ml_all = [('carina', ['0faf6596d6158f1f9713e1360f22d328']),
                       ('mscoco', ['e4eba74b67f7755aea568080a6a0bf3e']),
                       ('reuters', ['ceef4756f7256c06275c7b1e1c3265c5']),
                       ('scene', ['a268ec081b548d36cc1ac961023431e9'])]
    sampling_methods_all = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']
    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    # panels are visited row by row, matching the order of datasets_ml_all
    for idx, panel in enumerate(ax.flat):
        dataset, hashs = datasets_ml_all[idx]
        print(dataset)  # progress output

        figures.fig_processing_time_over_training_samples(panel, 'multi_label', dataset, hashs,
                                                          sampling_methods_all, tag='evaluation')
        panel.spines['right'].set_visible(False)
        panel.spines['top'].set_visible(False)
        panel.set_title(cfg.figures[dataset]['text'])
        panel.set_xlabel('# added training samples')
        panel.set_ylabel('$t_{qm}/t_{rand}$')
        # the time ratio is drawn on a logarithmic axis
        panel.set_yscale('log')

    fig.tight_layout()
    fig.show()



""" FIGURES: HELPER"""
# print all single classes whole dataset one sampling method
# For every dataset one figure is written to '<dataset>.pdf', holding one panel per single class.
if False:
    results = pd.read_pickle(Path(cfg.path_exp, 'results.pkl'))
    # tuple layout:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm, weight_init,
    #  training, nr_classes, classes); the last entry is replaced by each class label below
    datasets = [('carina', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'ALL'),
                ('mscoco', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'ALL'),
                ('reuters', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'ALL'),
                ('scene', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'ALL')]
    # one [baseline, query method] pair per dataset, in the same order as `datasets`
    sampling_methods = [['random', 'multilabel_simple_crw'],
                        ['random', 'ratio_max'],
                        ['random', 'kmeans'],
                        ['random', 'bald']]

    for tuple_, qms in zip(datasets, sampling_methods):
        # get unique classes over all single-class result rows of this dataset
        df_filter = results[(results['dataset'] == tuple_[0]) & (results['nr_classes'] == 1)]
        all_classes = set()
        for class_list in df_filter['classes']:
            all_classes.update(class_list)
        # Strip cfg.label_prefix where present; sorting gives a reproducible panel order,
        # since set iteration order is arbitrary.
        all_classes_no_prefix = sorted(
            cls[len(cfg.label_prefix):] if cls.startswith(cfg.label_prefix) else cls for cls in all_classes)

        if not all_classes_no_prefix:
            # plt.subplots cannot build a grid with zero columns
            print(tuple_[0], ' ', 'no single-class results found, skipping')
            continue

        # squeeze=False keeps `ax` two-dimensional, so a dataset with exactly one class
        # still yields an indexable row of axes.
        fig, ax = plt.subplots(1, len(all_classes_no_prefix), figsize=(len(all_classes_no_prefix) * 5, 5),
                               squeeze=False)
        for panel, label in zip(ax[0], all_classes_no_prefix):
            print(tuple_[0], ' ', label)
            tuple_class = tuple_[:-1] + (label,)
            figures.fig_learning_curve_approximation(panel, tuple_class, results, qms)

            panel.spines['right'].set_visible(False)
            panel.spines['top'].set_visible(False)
            panel.set_title(label)

        fig.tight_layout()
        plt.savefig(Path(cfg.path_fig, tuple_[0] + '.pdf'))
        fig.show()






""" FIGURES ABOUT AL PROPERTY INVESTIGATION."""
# figure: speedup factor over number events
if False:
    nrows, ncols = 2, 2
    # entries: (dataset, query method)
    qm_sl_single = [
        ('carina', 'multilabel_simple_crw'),
        ('mscoco', 'ratio_max'),
        ('reuters', 'kmeans'),
        ('scene', 'bald'),
    ]
    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    # panels are filled row by row, in the order of qm_sl_single
    for idx, panel in enumerate(ax.flat):
        dataset, sampling_method = qm_sl_single[idx]

        figures.fig_sf_over_events(panel, 'single_label', dataset, sampling_method)
        panel.set_xscale('log')
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        # the title is the dataset text directly followed by the method name
        panel.set_title(cfg.figures[dataset]['text'] + sampling_method)
        panel.set_xlabel('# Events')
        panel.set_ylabel('Speedup factor')

    fig.tight_layout()
    fig.show()

# figure: performance over # added training samples, freq and rare species with fit
if False:
    nrows, ncols = 2, 2
    # entries: (dataset, hashes, [baseline, query method])
    datasets = [
        ('carina', ['1e38f5524602ce1e190fe32b26931a59'], ['random', 'multilabel_simple_crw']),
        ('mscoco', ['2f0debf09caf973af3512632ff386f88'], ['random', 'ratio_max']),
        ('reuters', ['4fbeb7601b0e58d7e5942b6aa618257e'], ['random', 'kmeans']),
        ('scene', ['a2d204afbc000092ba4f929212989359'], ['random', 'bald']),
    ]
    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    for idx, panel in enumerate(ax.flat):
        dataset, hashs, sampling_methods = datasets[idx]
        # NOTE(review): this call passes ('multi_label', dataset, hashs, methods), whereas other calls of
        # fig_learning_curve_approximation in this file pass (dataset tuple, results, methods) — verify
        # against its signature.
        figures.fig_learning_curve_approximation(panel, 'multi_label', dataset, hashs, sampling_methods)
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        panel.set_title(cfg.figures[dataset]['text'])

    plt.setp(ax[:, 0], ylabel='Performance')
    plt.setp(ax[-1, :], xlabel='# added training samples')
    fig.tight_layout()
    fig.show()

# figure: ALL samples, % events over nr training samples
if False:
    nrows, ncols = 2, 2
    # entries: (dataset, hashes, [baseline, query method])
    datasets = [
        ('carina', ['1e38f5524602ce1e190fe32b26931a59'], ['random', 'multilabel_simple_crw']),
        ('mscoco', ['2f0debf09caf973af3512632ff386f88'], ['random', 'ratio_max']),
        ('reuters', ['4fbeb7601b0e58d7e5942b6aa618257e'], ['random', 'kmeans']),
        ('scene', ['a2d204afbc000092ba4f929212989359'], ['random', 'bald']),
    ]
    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    # panels are filled row by row, in the order of `datasets`
    for idx, panel in enumerate(ax.flat):
        dataset, hashs, sampling_methods = datasets[idx]
        figures.fig_events_over_training_samples(panel, 'multi_label', dataset, hashs, sampling_methods)
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        panel.set_title(cfg.figures[dataset]['text'])

    plt.setp(ax[:, 0], ylabel='positive samples [%]')
    plt.setp(ax[-1, :], xlabel='# added training samples')
    fig.tight_layout()
    fig.show()

# figure: ALL samples, speedup factor over # classes
if False:
    nrows, ncols = 2, 2
    # entries: (dataset, hashes, [baseline, query method], dataset_size)
    datasets = [
        ('carina', ['1e38f5524602ce1e190fe32b26931a59'], ['random', 'multilabel_simple_crw'], 98663),
        ('mscoco', ['2f0debf09caf973af3512632ff386f88'], ['random', 'ratio_max'], 118287),
        ('reuters', ['4fbeb7601b0e58d7e5942b6aa618257e'], ['random', 'kmeans'], 53571),
        ('scene', ['a2d204afbc000092ba4f929212989359'], ['random', 'bald'], 2407),
    ]
    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    # panels are filled row by row, in the order of `datasets`
    for idx, panel in enumerate(ax.flat):
        dataset, _, sampling_methods, dataset_size = datasets[idx]
        print(dataset)
        # only the second entry of the method pair is passed on
        figures.fig_sf_over_classes(panel, 'multi_label', dataset, sampling_methods[1], dataset_size)
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        panel.set_title(cfg.figures[dataset]['text'])

    plt.setp(ax[:, 0], ylabel='Speedup Factor')
    plt.setp(ax[-1, :], xlabel='# classes')
    fig.tight_layout()
    fig.show()

# figure: ALL samples, % events over nr training samples, 1 class, DS with 1-4 classes
if False:
    nrows, ncols = 2, 2
    # entries: (dataset, hashes, [baseline, query method], dataset_size, label)
    datasets = [
        ('carina', ['1e38f5524602ce1e190fe32b26931a59'], ['random', 'multilabel_simple_crw'], 98663, 'g'),
        ('mscoco', ['2f0debf09caf973af3512632ff386f88'], ['random', 'ratio_max'], 118287, 'car'),
        ('reuters', ['4fbeb7601b0e58d7e5942b6aa618257e'], ['random', 'kmeans'], 53571, 'interest'),
        ('scene', ['a2d204afbc000092ba4f929212989359'], ['random', 'bald'], 2407, 'Sunset'),
    ]

    fig, ax = plt.subplots(nrows, ncols, figsize=(10, 5))
    # panels are filled row by row, in the order of `datasets`
    for idx, panel in enumerate(ax.flat):
        dataset, _, sampling_methods, dataset_size, label = datasets[idx]
        print(dataset)
        # only the second entry of the method pair is passed on
        figures.fig_events_over_training_samples_one_class(panel, 'multi_label', dataset,
                                                           sampling_methods[1], dataset_size, label)
        for side in ('right', 'top'):
            panel.spines[side].set_visible(False)
        panel.set_title(cfg.figures[dataset]['text'] + ' | ' + label)

    plt.setp(ax[:, 0], ylabel='positive samples [%]')
    plt.setp(ax[-1, :], xlabel='# added training samples')
    fig.tight_layout()
    fig.show()





""" FINAL FIGURE PAPER"""
# figure: all figures for the paper OLD
# Grid of 10 rows x 4 columns, one column per dataset. Rows 0-3 show the single-label
# runs, row 4 is a hidden spacer, rows 5-9 show the multi-label runs.
if False:
    nrows = 10
    ncols = 4
    # single-label runs, one entry per column: (dataset, hashes, sampling methods)
    datasets_sl_all = [('carina', ['1b81389580c2de87bb6f790dfdd6bd13'], ['random', 'multilabel_simple_crw']),
                       ('mscoco', ['a8592bba830413ec16f00f2e6ed9dec4'], ['random', 'ratio_max']),
                       ('reuters', ['1af4bfdc5ffcca8d3d5df20ce543840a'], ['random', 'kmeans']),
                       ('scene', ['a268ec081b548d36cc1ac961023431e9'], ['random', 'bald'])]
    sampling_methods_all = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']

    # multi-label runs, one entry per column: (dataset, hashes)
    datasets_ml_all = [('carina', ['0faf6596d6158f1f9713e1360f22d328']),
                       ('mscoco', ['e4eba74b67f7755aea568080a6a0bf3e']),
                       ('reuters', ['ceef4756f7256c06275c7b1e1c3265c5']),
                       ('scene', ['a268ec081b548d36cc1ac961023431e9'])]

    # multi-label runs, one entry per column:
    # (dataset, hashes, sampling methods, dataset size, class label)
    datasets_ml_rare_freq = [
        ('carina', ['1e38f5524602ce1e190fe32b26931a59'], ['random', 'multilabel_simple_crw'], 98663, 'g'),
        ('mscoco', ['2f0debf09caf973af3512632ff386f88'], ['random', 'ratio_max'], 118287, 'car'),
        ('reuters', ['4fbeb7601b0e58d7e5942b6aa618257e'], ['random', 'kmeans'], 53571, 'interest'),
        ('scene', ['a2d204afbc000092ba4f929212989359'], ['random', 'bald'], 2407, 'Sunset')]

    def _format_ytick_old(x, pos):
        """Label integer ticks without decimals, all others with one decimal and no leading zero."""
        return f'{int(x)}' if x.is_integer() else f'{x:.1f}'.lstrip('0')

    # NOTE(review): the -.1 height ratio belongs to the hidden spacer row (row 4);
    # presumably it is there to tune the gap between the two groups of rows - confirm.
    fig, ax = plt.subplots(nrows, ncols,
                           figsize=(16, 19),
                           gridspec_kw={'hspace': .5, 'height_ratios': [1, 1, 1, 1, -.1, 1, 1, 1, 1, 1], 'wspace': .15}
                           )

    for row in range(nrows):
        for col in range(ncols):
            print(f'Row {row + 1}/{nrows} | Col {col + 1}/{ncols}')
            axis = ax[row, col]
            # y labels are only written on the first panel of each row
            row_head = ax[row, 0]

            # row 0: Performance over # train samples, ALL classes, ALL samples are 2000
            if row == 0:
                dataset, hashs, _ = datasets_sl_all[col]
                figures.fig_performance_over_training_samples(axis, 'single_label', dataset, hashs,
                                                              sampling_methods_all)
                axis.set_title(cfg.figures[dataset]['text'] + '$_{2k}$', fontweight='bold')
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('Performance')
                axis.set_xlim([0, 2000])
            # row 1: fraction train samples AL/fraction train samples rand over performance
            elif row == 1:
                dataset, hashs, _ = datasets_sl_all[col]
                figures.fig_frac_al_rand_over_performance(axis, 'single_label', dataset, hashs,
                                                          sampling_methods_all)
                axis.set_xlabel('Performance')
                row_head.set_ylabel('$x_{qm}/x_{rand}$')
            # row 2: learning curve approximation fit, selected QMs
            elif row == 2:
                dataset, hashs, sampling_methods = datasets_sl_all[col]
                figures.fig_learning_curve_approximation(axis, 'single_label', dataset, hashs, sampling_methods)
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('Performance')
                axis.set_xlim([0, 2000])
            # row 3: speedup factor over # events
            elif row == 3:
                dataset, _, sampling_methods = datasets_sl_all[col]
                sampling_method = sampling_methods[1]
                figures.fig_sf_over_events(axis, 'single_label', dataset, sampling_method)
                axis.set_xscale('log')
                axis.set_xlabel('# positive samples')
                row_head.set_ylabel('Speedup Factor')
            # row 4: hidden spacer row between the single-label and multi-label rows
            elif row == 4:
                axis.set_visible(False)
            # row 5: Performance over # train samples, ALL classes, ALL samples
            elif row == 5:
                dataset, hashs = datasets_ml_all[col]
                figures.fig_performance_over_training_samples(axis, 'multi_label', dataset, hashs,
                                                              sampling_methods_all)
                axis.set_title(cfg.figures[dataset]['text'], fontweight='bold')
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('Performance')
                axis.set_xlim([0, 1400])
                # the legend of the last panel in this row is created in the legend
                # pass below, once every panel has been drawn
            # row 6: Performance over # train samples, rare and freq DS, ALL samples
            elif row == 6:
                dataset, hashs, sampling_methods, _, _ = datasets_ml_rare_freq[col]
                figures.fig_learning_curve_approximation(axis, 'multi_label', dataset, hashs, sampling_methods)
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('Performance')
                axis.set_xlim([0, 1400])
            # row 7: positive samples [%] over # added training samples
            elif row == 7:
                dataset, hashs, sampling_methods, _, _ = datasets_ml_rare_freq[col]
                figures.fig_events_over_training_samples(axis, 'multi_label', dataset, hashs, sampling_methods)
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('positive samples [%]')
                axis.set_xlim([0, 1400])
            # row 8: speedup factor over # classes
            elif row == 8:
                dataset, _, sampling_methods, dataset_size, _ = datasets_ml_rare_freq[col]
                figures.fig_sf_over_classes(axis, 'multi_label', dataset, sampling_methods[1], dataset_size)
                axis.set_xlabel('# classes')
                row_head.set_ylabel('Speedup Factor')
            # row 9: positive samples [%] over # added training samples, one class per dataset
            elif row == 9:
                dataset, _, sampling_methods, dataset_size, label = datasets_ml_rare_freq[col]
                figures.fig_events_over_training_samples_one_class(axis, 'multi_label', dataset,
                                                                   sampling_methods[1], dataset_size, label)
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('positive samples [%]')
                axis.set_xlim([0, 1400])

            axis.spines['right'].set_visible(False)
            axis.spines['top'].set_visible(False)
            # a formatter is bound to the axis it is set on, so every panel gets its own instance
            axis.yaxis.set_major_formatter(ticker.FuncFormatter(_format_ytick_old))

    fig.subplots_adjust(left=0.0325, right=0.989, top=0.988, bottom=0.025)

    # change legends: rows 2, 3 and 6 get a legend on every panel,
    # rows 0, 1 and 5 only on the last panel
    legend_style = dict(frameon=False, labelspacing=0, loc='lower right', bbox_to_anchor=(1.085, 0))
    rows_legend_everywhere = (2, 3, 6)
    rows_legend_last_col = (0, 1, 5)
    for row in range(nrows):
        for col in range(ncols):
            if row in rows_legend_everywhere or (row in rows_legend_last_col and col == 3):
                ax[row, col].legend(**legend_style)
            if row == 1 and col == 3:
                ax[row, col].set_xlim([None, 0.9])

    plt.savefig(Path(cfg.path_fig, 'experiments.pdf'))
    fig.show()

# figure: all figures for paper IJCAI
# Grid of 10 rows x 4 columns, one column per dataset. Rows 0-4 show the single-label
# runs, row 5 is a hidden spacer, rows 6-9 show the multi-label runs.
if False:
    nrows = 10
    ncols = 4
    # single-label runs, one entry per column: (dataset, hashes, sampling methods)
    datasets_sl_all = [('carina', ['1b81389580c2de87bb6f790dfdd6bd13'], ['random', 'multilabel_simple_crw']),
                       ('mscoco', ['a8592bba830413ec16f00f2e6ed9dec4'], ['random', 'ratio_max']),
                       ('reuters', ['1af4bfdc5ffcca8d3d5df20ce543840a'], ['random', 'kmeans']),
                       ('scene', ['a268ec081b548d36cc1ac961023431e9'], ['random', 'bald'])]
    sampling_methods_all = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']

    # multi-label runs, one entry per column: (dataset, hashes)
    datasets_ml_all = [('carina', ['0faf6596d6158f1f9713e1360f22d328']),
                       ('mscoco', ['e4eba74b67f7755aea568080a6a0bf3e']),
                       ('reuters', ['ceef4756f7256c06275c7b1e1c3265c5']),
                       ('scene', ['a268ec081b548d36cc1ac961023431e9'])]

    # multi-label runs, one entry per column:
    # (dataset, hashes, sampling methods, dataset size, class label)
    datasets_ml_rare_freq = [
        ('carina', ['1e38f5524602ce1e190fe32b26931a59'], ['random', 'multilabel_simple_crw'], 98663, 'g'),
        ('mscoco', ['2f0debf09caf973af3512632ff386f88'], ['random', 'ratio_max'], 118287, 'car'),
        ('reuters', ['4fbeb7601b0e58d7e5942b6aa618257e'], ['random', 'kmeans'], 53571, 'interest'),
        ('scene', ['a2d204afbc000092ba4f929212989359'], ['random', 'bald'], 2407, 'Sunset')]

    # axis limits of the row-1 panels (x_qm/x_rand over performance), one (xlim, ylim)
    # pair per column, set for better visualisation
    frac_axis_limits = [([0.17, None], [0.6, 1.5]),
                        ([0.05, None], [0.7, 1.35]),
                        ([0.06, None], [0.87, 1.2]),
                        ([0.09, None], [0.4, 1.2])]

    def _ytick_formatter_ijcai(decimals):
        """Build a y tick formatter.

        Integer ticks are written without decimals, all other ticks with
        ``decimals`` decimal places and without the leading zero.
        """
        def _format(x, pos):
            return f'{int(x)}' if x.is_integer() else f'{x:.{decimals}f}'.lstrip('0')
        return ticker.FuncFormatter(_format)

    # NOTE: pickle.load can execute arbitrary code stored in the file; only load
    # pickles that were written by this project.
    with open(Path(cfg.path_fig, 'dict_performance_metric_change_ml.pkl'), 'rb') as f:
        dict_performance_metric_ml = pickle.load(f)

    with open(Path(cfg.path_fig, 'dict_performance_metric_change_sl.pkl'), 'rb') as f:
        dict_performance_metric_sl = pickle.load(f)

    if False:
        tables.table_al_metric_comparison(datasets_ml_rare_freq)

    # NOTE(review): the -.1 height ratio belongs to the hidden spacer row (row 5);
    # presumably it is there to tune the gap between the two groups of rows - confirm.
    fig, ax = plt.subplots(nrows, ncols,
                           figsize=(16, 19),
                           gridspec_kw={'hspace': .5, 'height_ratios': [1, 1, 1, 1, 1, -.1, 1, 1, 1, 1], 'wspace': .15}
                           )

    for row in range(nrows):
        for col in range(ncols):
            print(f'Row {row + 1}/{nrows} | Col {col + 1}/{ncols}')
            axis = ax[row, col]
            # y labels are only written on the first panel of each row
            row_head = ax[row, 0]

            # row 0: Performance over # train samples, ALL classes, ALL samples are 2000
            if row == 0:
                dataset, hashs, _ = datasets_sl_all[col]
                figures.fig_performance_over_training_samples(axis, 'single_label', dataset, hashs,
                                                              sampling_methods_all)
                axis.set_title(cfg.figures[dataset]['text'] + '$_{2k}$', fontweight='bold')
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('Performance')
                axis.set_xlim([0, 2000])
            # row 1: fraction train samples AL/fraction train samples rand over performance
            elif row == 1:
                dataset, hashs, _ = datasets_sl_all[col]
                figures.fig_frac_al_rand_over_performance(axis, 'single_label', dataset, hashs,
                                                          sampling_methods_all)
                axis.set_xlabel('Performance')
                row_head.set_ylabel('$x_{qm}/x_{rand}$')
                # set limits for better visualisation
                xlim, ylim = frac_axis_limits[col]
                axis.set_xlim(xlim)
                axis.set_ylim(ylim)
            # row 2: learning curve approximation fit, selected QMs
            elif row == 2:
                dataset, hashs, sampling_methods = datasets_sl_all[col]
                S = figures.fig_learning_curve_approximation(axis, 'single_label', dataset, hashs, sampling_methods)
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('Performance')
                axis.set_xlim([0, 2000])

                # plot speedup factor hline in plot above
                panel_above = ax[row - 1, col]
                color = cfg.figures[sampling_methods[1]]['color']
                label = 'S ' + cfg.figures[sampling_methods[1]]['label']
                panel_above.axhline(y=S, linestyle='--', color=color, xmin=0, xmax=1, label=label)
                # Only keep the last handle and label, i.e. the hline that was just added
                handles, labels = panel_above.get_legend_handles_labels()
                if col == 1:
                    panel_above.legend([handles[-1]], [labels[-1]], frameon=False, loc='lower center',
                                       bbox_to_anchor=(0.4, -0.01))
                else:
                    panel_above.legend([handles[-1]], [labels[-1]], frameon=False)
            # row 3: Metric performance change over nr training samples
            elif row == 3:
                dataset, _, sampling_methods = datasets_sl_all[col]
                figures.fig_performance_metrics_over_training_samples(axis, dataset, 'single_label',
                                                                      dict_performance_metric_sl, sampling_methods[1])
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('AL Perf. Metric Change [%]')
                axis.set_xlim([700, 2000])
            # row 4: processing time over nr training samples
            elif row == 4:
                dataset, hashs, _ = datasets_sl_all[col]
                figures.fig_processing_time_over_training_samples(axis, 'single_label', dataset, hashs,
                                                                  sampling_methods_all)
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('$t_{qm}/t_{rand}$')
                axis.set_yscale('log')
                axis.set_xlim([0, 2000])
            # row 5: hidden spacer row between the single-label and multi-label rows
            elif row == 5:
                axis.set_visible(False)
            # row 6: Performance over # train samples, ALL classes, ALL samples
            elif row == 6:
                dataset, hashs = datasets_ml_all[col]
                figures.fig_performance_over_training_samples(axis, 'multi_label', dataset, hashs,
                                                              sampling_methods_all)
                axis.set_title(cfg.figures[dataset]['text'], fontweight='bold')
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('Performance')
                axis.set_xlim([0, 1400])
            # row 7: Performance over # train samples, rare and freq DS, ALL samples
            elif row == 7:
                dataset, hashs, sampling_methods, _, _ = datasets_ml_rare_freq[col]
                figures.fig_learning_curve_approximation(axis, 'multi_label', dataset, hashs, sampling_methods)
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('Performance')
                axis.set_xlim([0, 1400])
            # row 8: Metric performance change over nr training samples
            elif row == 8:
                dataset, _, sampling_methods, _, _ = datasets_ml_rare_freq[col]
                figures.fig_performance_metrics_over_training_samples(axis, dataset, 'multi_label',
                                                                      dict_performance_metric_ml, sampling_methods[1])
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('AL Perf. Metric Change [%]')
                axis.set_xlim([700, 1400])
            # row 9: processing time over nr training samples
            elif row == 9:
                dataset, hashs = datasets_ml_all[col]
                figures.fig_processing_time_over_training_samples(axis, 'multi_label', dataset, hashs,
                                                                  sampling_methods_all, tag='evaluation')
                axis.set_xlabel('# added training samples')
                row_head.set_ylabel('$t_{qm}/t_{rand}$')
                axis.set_yscale('log')
                axis.set_xlim([0, 1400])

            axis.spines['right'].set_visible(False)
            axis.spines['top'].set_visible(False)
            # log-scaled panels keep their default tick labels
            if axis.get_yscale() != 'log':
                # the first panel of row 2 is labelled with two decimals, all others with one
                decimals = 2 if (row == 2 and col == 0) else 1
                axis.yaxis.set_major_formatter(_ytick_formatter_ijcai(decimals))

    # legend below plot, built from the entries of the first panel
    handles, labels = ax[0, 0].get_legend_handles_labels()
    # invisible dummy handle that carries the 'Legend:' caption; created as a standalone
    # artist so that nothing is added to any of the panels
    ph = [mlines.Line2D([], [], marker='', linestyle='')]
    handles = ph + handles
    labels = ['Legend:'] + labels
    fig.legend(handles, labels, loc='lower center', ncol=8, bbox_to_anchor=(0.5, -0.002))

    # crop out figure
    fig.subplots_adjust(left=0.06, right=0.989, top=0.988, bottom=0.04)

    # row identifiers and per-row legends
    # one label per grid row; index 5 belongs to the hidden spacer row, whose text is
    # attached to the hidden axes
    row_labels = ['A)', 'B)', 'C)', 'D)', 'E)', 'F)', 'F)', 'G)', 'H)', 'I)']
    for row in range(nrows):
        for col in range(ncols):
            axis = ax[row, col]
            if col == 0:
                axis.text(-0.25, 0.5, row_labels[row], transform=axis.transAxes, fontsize=14,
                          fontweight='bold', va='center', ha='center')
            if row == 2:
                axis.legend(frameon=False, labelspacing=0, loc='lower right')
            elif row == 3:
                _, _, sampling_methods = datasets_sl_all[col]
                legend_title = cfg.figures[sampling_methods[1]]['label']
                if col in [1, 3]:
                    axis.legend(title=legend_title, frameon=False, labelspacing=0, loc='upper right')
                else:
                    axis.legend(title=legend_title, frameon=False, labelspacing=0, ncol=2, columnspacing=0,
                                loc='lower center', bbox_to_anchor=(0.7, 0.45))
            elif row == 7:
                if col == 1:
                    axis.legend(frameon=False, labelspacing=0, loc='lower right', ncol=2, columnspacing=0,
                                bbox_to_anchor=(1.085, -0.08))
                else:
                    axis.legend(frameon=False, labelspacing=0, loc='lower right')
            elif row == 8:
                _, _, sampling_methods, _, _ = datasets_ml_rare_freq[col]
                legend_title = cfg.figures[sampling_methods[1]]['label']
                if col == 0:
                    axis.legend(title=legend_title, frameon=False, labelspacing=0, ncol=2, columnspacing=0,
                                loc='lower center', bbox_to_anchor=(0.3, 0.25))
                else:
                    axis.legend(title=legend_title, frameon=False, labelspacing=0, loc='upper right')

    plt.savefig(Path(cfg.path_fig, 'experiments_IJCAI.pdf'))
    fig.show()


# figure: AAAI submitted / MULTILABEL
if False:
    # load result file
    results = pd.read_pickle(Path(cfg.path_exp, 'results.pkl'))
    # define properties to plot for 2k datasets
    datasets_2k = [('carina', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 37, 'ALL'),
                   ('mscoco', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 80, 'ALL'),
                   ('reuters', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 31, 'ALL'),
                   ('scene', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 6, 'ALL')]
    # define properties to plot for complete datasets
    datasets_complete = [('carina', 'complete', 20, 20, 1400, 'sl', 'tl', 'frozen', 37, 'ALL'),
                         ('mscoco', 'complete', 20, 20, 1400, 'sl', 'tl', 'frozen', 80, 'ALL'),
                         ('reuters', 'complete', 20, 20, 1400, 'sl', 'tl', 'frozen', 31, 'ALL'),
                         ('scene', 'complete', 20, 20, 1400, 'sl', 'tl', 'frozen', 6, 'ALL')]
    # define properties to plot for complete datasets with 1 class
    datasets_complete_1class = [('carina', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 's'),
                                ('mscoco', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'orange'),
                                ('reuters', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'earn'),
                                ('scene', 'complete', 40, 20, 1400, 'sl', 'tl', 'frozen', 1, 'Sunset')]
    # datasets for first 2 rows
    first2rows_datasets = ('mscoco', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 80, 'ALL')

    # all sampling methods selected
    sampling_methods = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']
    # sampling methods for the specific datasets
    sampling_methods_specific = [['random', 'multilabel_simple_crw'],
                                 ['random', 'ratio_max'],
                                 ['random', 'kmeans'],
                                 ['random', 'badge']]
    # sampling methods first 2 rows
    first2rows_sampling_methods = ['random', 'ratio_max']

    linestyle_dict_first2rows = {2: ':',
                                 20: '-',
                                 200: '--',
                                 'augmented': ':',
                                 'sl': '-',
                                 'semi-sl': '--',
                                 'random': ':',
                                 'tl': '-',
                                 'self-sl': '--',
                                 'frozen': '-',
                                 'finetune': ':',
                                 'finetune-last-2': '-.',
                                 'finetune-last-5': '--'}

    # create figure
    nrows = 11
    ncols = 4
    fig, ax = plt.subplots(nrows, ncols, figsize=(16, 19),
                           gridspec_kw={'hspace': .5,
                                        'height_ratios': [1, 1, -.1, 1, 1, 1, 1, -.1, 1, 1, 1],
                                        'wspace': .15}
                           )

    # iterate over all subfigures and plot figures
    for row in range(nrows):
        for col in range(ncols):
            print(f'Row {row}/{nrows-1} | Col {col}/{ncols-1}')

            # row 0: Dataset: MSCOCO: Performance over # train samples,
            if row == 0:
                if col == 0:
                    for budget in cfg.all_budgets:
                        current_dataset = first2rows_datasets[:2] + (budget, budget,) + first2rows_datasets[4:]
                        figures.fig_performance_over_training_samples(ax[row, col], current_dataset, results,
                                                                      first2rows_sampling_methods,
                                                                      linestyle=linestyle_dict_first2rows[budget])
                elif col == 1:
                    for train_paradigm in cfg.all_train_paradigms:
                        current_dataset = (first2rows_datasets[:2] + (200, 200,) + first2rows_datasets[4:5] +
                                           (train_paradigm,) + first2rows_datasets[6:])
                        figures.fig_performance_over_training_samples(ax[row, col], current_dataset, results,
                                                                      first2rows_sampling_methods,
                                                                      linestyle=linestyle_dict_first2rows[train_paradigm])
                elif col == 2:
                    for weight_init in cfg.all_weight_inits:
                        current_dataset = first2rows_datasets[:6] + (weight_init, 'finetune') + first2rows_datasets[8:]
                        figures.fig_performance_over_training_samples(ax[row, col], current_dataset, results,
                                                                      first2rows_sampling_methods,
                                                                      linestyle=linestyle_dict_first2rows[weight_init],
                                                                      moving_avg=True)
                    ax[row, col].set_ylim([0, 0.1])
                elif col == 3:
                    for training in cfg.all_training:
                        current_dataset = first2rows_datasets[:7] + (training,) + first2rows_datasets[8:]
                        figures.fig_performance_over_training_samples(ax[row, col], current_dataset, results,
                                                                      first2rows_sampling_methods,
                                                                      linestyle=linestyle_dict_first2rows[training])
                    ax[row, col].set_ylim([-.09, None])

                ax[row, col].set_xlabel('# added training samples')
                ax[row, 0].set_ylabel('Performance')
                ax[row, col].set_xlim([0, 2000])
                ax[row, col].set_title('MS COCO$_{2k}$', fontweight='bold')

            # row 1: Dataset: MSCOCO: Metric performance change over nr training samples
            elif row == 1:
                if col == 0:
                    for budget in cfg.all_budgets:
                        current_dataset = first2rows_datasets[:2] + (budget, budget,) + first2rows_datasets[4:]
                        figures.fig_frac_al_rand_over_performance(ax[row, col], current_dataset, results,
                                                                  first2rows_sampling_methods,
                                                                  linestyle=linestyle_dict_first2rows[budget])
                    ax[row, col].set_xlim([0.12, None])
                    ax[row, col].set_ylim([0.9, 1.3])

                elif col == 1:
                    for train_paradigm in cfg.all_train_paradigms:
                        current_dataset = (first2rows_datasets[:2] + (200, 200,) + first2rows_datasets[4:5] +
                                           (train_paradigm,) + first2rows_datasets[6:])
                        figures.fig_frac_al_rand_over_performance(ax[row, col], current_dataset, results,
                                                                  first2rows_sampling_methods,
                                                                  linestyle=linestyle_dict_first2rows[train_paradigm])
                    ax[row, col].set_xlim([0.16, None])
                    ax[row, col].set_ylim([0.84, 1.2])

                elif col == 2:
                    for weight_init in cfg.all_weight_inits:
                        current_dataset = first2rows_datasets[:6] + (weight_init, 'finetune') + first2rows_datasets[8:]
                        figures.fig_frac_al_rand_over_performance(ax[row, col], current_dataset, results,
                                                                  first2rows_sampling_methods,
                                                                  linestyle=linestyle_dict_first2rows[weight_init])
                    ax[row, col].set_xlim([0.03, None])
                    ax[row, col].set_ylim([-4, None])

                elif col == 3:
                    for training in cfg.all_training:
                        current_dataset = first2rows_datasets[:7] + (training,) + first2rows_datasets[8:]
                        figures.fig_frac_al_rand_over_performance(ax[row, col], current_dataset, results,
                                                                  first2rows_sampling_methods,
                                                                  linestyle=linestyle_dict_first2rows[training])
                    ax[row, col].set_xlim([0.15, None])
                    ax[row, col].set_ylim([0.5, 1.5])

                ax[row, col].set_xlabel('Performance')
                ax[row, 0].set_ylabel('$x_{qm}/x_{rand}$')

            # row 2: empty row for vertical space
            elif row == 2:
                ax[row, col].set_visible(False)

            # row 3: 2k datasets: Performance all QMs
            elif row == 3:
                figures.fig_performance_over_training_samples(ax[row, col], datasets_2k[col], results, sampling_methods)
                ax[row, col].set_title(cfg.figures[datasets_2k[col][0]]['text'] + '$_{2k}$', fontweight='bold')
                plt.setp(ax[row, col], xlabel='# added training samples')
                plt.setp(ax[row, 0], ylabel='Performance')
                ax[row, col].set_xlim([0, 2000])

            # row 4: 2k datasets: fraction train samples AL/fraction train samples rand over performance
            elif row == 4:
                figures.fig_frac_al_rand_over_performance(ax[row, col], datasets_2k[col], results, sampling_methods)
                ax[row, col].set_xlabel('Performance')
                ax[row, 0].set_ylabel('$x_{qm}/x_{rand}$')
                # set limits for better visualisation
                if col == 0:
                    xlim = [0.225, None]
                    ylim = [0.86, 1.4]
                elif col == 1:
                    xlim = [0.15, None]
                    ylim = [0.3, 2]
                elif col == 2:
                    xlim = [0.1, None]
                    ylim = [0.7, 1.15]
                elif col == 3:
                    xlim = [0.25, None]
                    ylim = [0.2, 1.35]
                ax[row, col].set_xlim(xlim)
                ax[row, col].set_ylim(ylim)

            # row 5: 2k datasets: Learning curve with SF approximation, selected QM
            elif row == 5:
                S = figures.fig_learning_curve_approximation(ax[row, col], datasets_2k[col], results,
                                                             sampling_methods_specific[col])

                ax[row, col].set_xlabel('# added training samples')
                ax[row, 0].set_ylabel('Performance')
                ax[row, col].set_xlim([0, 2000])

                # plot speedup factor hline in plot above
                qm = sampling_methods_specific[col][1]
                color = cfg.figures[qm]['color']
                label = 'S ' + cfg.figures[qm]['label']
                ax[row - 1, col].axhline(y=S, linestyle='--', color=color, xmin=0, xmax=1, label=label)
                # Only keep the last handle and label
                handles, labels = ax[row - 1, col].get_legend_handles_labels()
                ax[row - 1, col].legend([handles[-1]], [labels[-1]], frameon=False, loc='lower right', bbox_to_anchor=(1,-.09))

            # row 6: 2k datasets: Metric stability
            elif row == 6:
                figures.fig_performance_metrics_over_training_samples(ax[row, col], datasets_2k[col], results,
                                                                      sampling_methods_specific[col])

                plt.setp(ax[row, col], xlabel='Stop Budget')
                plt.setp(ax[row, 0], ylabel='AL Perf. Metric Change [%]')
                ax[row, col].set_xlim([700, 2000])

            # row 7: empty row for vertical space
            elif row == 7:
                ax[row, col].set_visible(False)

            # row 8: complete datasets: Performance all QMs
            elif row == 8:
                figures.fig_performance_over_training_samples(ax[row, col], datasets_complete[col], results,
                                                              sampling_methods)

                ax[row, col].set_title(cfg.figures[datasets_complete[col][0]]['text'], fontweight='bold')
                plt.setp(ax[row, col], xlabel='# added training samples')
                plt.setp(ax[row, 0], ylabel='Performance')
                ax[row, col].set_xlim([0, 1400])

            # row 9: complete datasets: Performance over # train samples, single label, different performance
            elif row == 9:
                figures.fig_learning_curve_approximation(ax[row, col], datasets_complete_1class[col], results,
                                                         sampling_methods_specific[col])
                plt.setp(ax[row, col], xlabel='# added training samples')
                plt.setp(ax[row, 0], ylabel='Performance')
                ax[row, col].set_xlim([0, 1400])

            # row 10: complete datasets: Metric stability
            elif row == 10:
                figures.fig_performance_metrics_over_training_samples(ax[row, col], datasets_complete_1class[col],
                                                                      results, sampling_methods_specific[col])
                plt.setp(ax[row, col], xlabel='Stop Budget')
                plt.setp(ax[row, 0], ylabel='AL Perf. Metric Change [%]')
                ax[row, col].set_xlim([700, 1400])

            # set upper and right bound to invisible
            ax[row, col].spines['right'].set_visible(False)
            ax[row, col].spines['top'].set_visible(False)

            # delete leading 0 for yaxis
            ax[row, col].yaxis.set_major_formatter(
                ticker.FuncFormatter(lambda x, pos: f'{int(x)}' if x.is_integer() else f'{x:.2f}'.rstrip('0').lstrip('0')))

    # legend below plot
    handles, labels = ax[3, 0].get_legend_handles_labels()
    ph = [plt.plot([], marker='', ls='')[0]]  # dummy
    handles = ph + handles
    labels = ['Legend:'] + labels
    fig.legend(handles, labels, loc='lower center', ncol=8, bbox_to_anchor=(0.5, 0))

    # crop out figure
    fig.subplots_adjust(left=0.06, right=0.989, top=0.988, bottom=0.04)


    # change legends
    row_labels = ['A)', 'B)', 'C)', 'C)', 'D)', 'E)', 'F)', 'F)', 'G)', 'H)', 'I)']
    for row in range(nrows):
        for col in range(ncols):
            # Set row identifier (letters)
            if col == 0:
                ax[row, col].text(-0.25, 0.5, row_labels[row], transform=ax[row, col].transAxes, fontsize=14,
                                  fontweight='bold', va='center', ha='center')

            # adjust legends for specific plots
            if row in [0, 1]:
                label_dot = '_nolegend_'
                label_solid = '_nolegend_'
                label_dash = '_nolegend_'
                ncol=1
                if col == 0:
                    title = 'Initial Budget / Query Size'
                    label_dot = '2'
                    label_solid = '20'
                    label_dash = '200'
                    legend_bbox_anchor = (1, -.09)
                    ncol = 3
                elif col == 1:
                    title = 'Training Paradigm'
                    label_dot = 'SL (incl. augmentations)'
                    label_solid = 'SL'
                    label_dash = 'Semi-SL'
                    legend_bbox_anchor = (1, -.09)
                elif col == 2:
                    title = 'Weight Initialization'
                    label_dot = 'Random'
                    label_solid = 'Transfer Learning'
                    label_dash = 'Self-SL'
                    legend_bbox_anchor = (1, -.09)
                elif col == 3:
                    title = 'Fine-Tuned Layers (last N)'
                    label_dot = 'All'
                    label_solid = '1'
                    label_dash = '5'
                    label_dashdot = '2'
                    legend_bbox_anchor = (1, -.09)

                # create legend
                line_dot = mlines.Line2D([], [], color='black', linestyle=':', label=label_dot)
                line_solid = mlines.Line2D([], [], color='black', linestyle='-', label=label_solid)
                line_dash = mlines.Line2D([], [], color='black', linestyle='--', label=label_dash)
                if col != 3:
                    ax[row, col].legend(handles=[line_dot, line_solid, line_dash], title=title, labelspacing=0,
                                        frameon=False, loc='lower right', ncol=ncol, bbox_to_anchor=legend_bbox_anchor)
                else:
                    line_dashdot = mlines.Line2D([], [], color='black', linestyle='-.', label=label_dashdot)
                    ax[row, col].legend(handles=[line_solid, line_dashdot, line_dash, line_dot], title=title,
                                        labelspacing=0, frameon=False, loc='lower right', ncol=4, columnspacing=1.5,
                                        bbox_to_anchor=legend_bbox_anchor)

            if row == 5:
                ax[row, col].legend(frameon=False, labelspacing=0, loc='lower right')
            elif row == 6:
                qm = sampling_methods_specific[col][1]
                legend_title = cfg.figures[qm]['label']
                if col in [1, 3]:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, loc='upper right')
                elif col == 0:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, ncol=2, columnspacing=0,
                                        loc='lower center', bbox_to_anchor=(0.7, 0.45))
                elif col == 2:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, ncol=2, columnspacing=0,
                                        loc='lower center', bbox_to_anchor=(0.6, 0.35))
            elif row == 9:
                if col == 1:
                    ax[row, col].legend(frameon=False, labelspacing=0, loc='lower right', ncol=2, columnspacing=0,
                                        bbox_to_anchor=(1.085, -0.08))
                else:
                    ax[row, col].legend(frameon=False, labelspacing=0, loc='lower right')
            elif row == 10:
                qm = sampling_methods_specific[col][1]
                legend_title = cfg.figures[qm]['label']
                if col == 0:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, ncol=2, columnspacing=0,
                                        loc='lower center', bbox_to_anchor=(0.3, 0.25))
                elif col == 3:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, ncol=2, columnspacing=0,
                                        loc='lower center', bbox_to_anchor=(0.6, 0.25))
                else:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, loc='upper right')

    # save and show plot
    plt.savefig(Path(cfg.path_fig, 'experiments_AAAI.pdf'))
    fig.show()

# figure: AAAI submitted / MULTICLASS
# Draws an 11 x 4 grid of subplots for the multiclass datasets (rows 2 and 7 are hidden spacer rows) and
# saves it to cfg.path_fig as 'AAAI_multiclass.pdf'. The whole section is switched off by `if False:`.
if False:
    # load result file
    results = pd.read_pickle(Path(cfg.path_exp, 'results.pkl'))
    # Every dataset tuple below follows the layout noted at the top of this file:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm,
    #  weight_init, training, nr_classes, classes)
    # define properties to plot for 2k datasets
    datasets_2k = [('urbansound8k', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 10, 'ALL'),
                   ('cifar10', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 10, 'ALL'),
                   ('agnews', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 4, 'ALL'),
                   ('letter', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 26, 'ALL')]
    # define properties to plot for complete datasets
    datasets_complete = [('urbansound8k', 'complete', 20, 20, 1400, 'sl', 'tl', 'frozen', 10, 'ALL'),
                   ('cifar10', 'complete', 20, 20, 1400, 'sl', 'tl', 'frozen', 10, 'ALL'),
                   ('agnews', 'complete', 20, 20, 1400, 'sl', 'tl', 'frozen', 4, 'ALL'),
                   ('letter', 'complete', 20, 20, 1400, 'sl', 'tl', 'frozen', 26, 'ALL')]
    # datasets for first 2 rows: one base configuration; rows 0 and 1 replace single fields of it per column
    first2rows_datasets = ('cifar10', '2k', 20, 20, 2000, 'sl', 'tl', 'frozen', 10, 'ALL')

    # all sampling methods selected
    sampling_methods = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']
    # sampling methods for the specific datasets: one ['random', <query method>] pair per column;
    # index 1 of each pair is read below as the query method (qm) of that column
    sampling_methods_specific = [['random', 'bald'],
                                 ['random', 'kmeans'],
                                 ['random', 'badge'],
                                 ['random', 'ratio_max']]
    # sampling methods first 2 rows
    first2rows_sampling_methods = ['random', 'ratio_max']

    # line style per setting value that is varied in rows 0 and 1
    # (keys are looked up with the budget, train paradigm, weight init and training values respectively)
    linestyle_dict_first2rows = {2: ':',
                                 20: '-',
                                 200: '--',
                                 'augmented': ':',
                                 'sl': '-',
                                 'semi-sl': '--',
                                 'random': ':',
                                 'tl': '-',
                                 'self-sl': '--',
                                 'frozen': '-',
                                 'finetune': ':',
                                 'finetune-last-2': '-.',
                                 'finetune-last-5': '--'}

    # create figure
    # Rows 2 and 7 are the spacer rows that are hidden in the loop below; they carry a height ratio of -.1.
    # NOTE(review): presumably the negative ratio pulls the neighbouring rows closer than hspace alone
    # would allow - confirm that the installed matplotlib version accepts negative ratios.
    nrows = 11
    ncols = 4
    fig, ax = plt.subplots(nrows, ncols, figsize=(16, 19),
                           gridspec_kw={'hspace': .5,
                                        'height_ratios': [1, 1, -.1, 1, 1, 1, 1, -.1, 1, 1, 1],
                                        'wspace': .15}
                           )

    # iterate over all subfigures and plot figures
    for row in range(nrows):
        for col in range(ncols):
            # progress output, one line per subplot
            print(f'Row {row}/{nrows-1} | Col {col}/{ncols-1}')

            # row 0: Dataset: CIFAR10: Performance over # train samples,
            # each column varies one field of first2rows_datasets and draws one set of curves per value
            if row == 0:
                if col == 0:
                    # vary init_train_sample and add_train_sample (tuple indices 2 and 3) together
                    for budget in cfg.all_budgets:
                        current_dataset = first2rows_datasets[:2] + (budget, budget,) + first2rows_datasets[4:]
                        figures.fig_performance_over_training_samples(ax[row, col], current_dataset, results,
                                                                      first2rows_sampling_methods,
                                                                      linestyle=linestyle_dict_first2rows[budget])
                elif col == 1:
                    # vary train_paradigm (tuple index 5) with both budget fields fixed to 200
                    # NOTE(review): row 1, col 1 fixes the budget fields to 20 instead of 200 - confirm
                    # that the two rows are meant to use different budgets.
                    for train_paradigm in cfg.all_train_paradigms:
                        current_dataset = (first2rows_datasets[:2] + (200, 200,) + first2rows_datasets[4:5] +
                                           (train_paradigm,) + first2rows_datasets[6:])
                        figures.fig_performance_over_training_samples(ax[row, col], current_dataset, results,
                                                                      first2rows_sampling_methods,
                                                                      linestyle=linestyle_dict_first2rows[train_paradigm])
                elif col == 2:
                    # vary weight_init (tuple index 6) with training (index 7) fixed to 'finetune'
                    for weight_init in cfg.all_weight_inits:
                        current_dataset = first2rows_datasets[:6] + (weight_init, 'finetune') + first2rows_datasets[8:]
                        figures.fig_performance_over_training_samples(ax[row, col], current_dataset, results,
                                                                      first2rows_sampling_methods,
                                                                      linestyle=linestyle_dict_first2rows[weight_init],
                                                                      moving_avg=True)
                    ax[row, col].set_ylim([0, 0.6])
                elif col == 3:
                    # vary training (tuple index 7); moving_avg=True is only passed for 'finetune'
                    for training in cfg.all_training:
                        moving_avg = False
                        if training == 'finetune':
                            moving_avg = True
                        current_dataset = first2rows_datasets[:7] + (training,) + first2rows_datasets[8:]
                        figures.fig_performance_over_training_samples(ax[row, col], current_dataset, results,
                                                                      first2rows_sampling_methods,
                                                                      linestyle=linestyle_dict_first2rows[training],
                                                                      moving_avg=moving_avg)
                    #ax[row, col].set_ylim([-.09, None])

                ax[row, col].set_xlabel('# added training samples')
                ax[row, 0].set_ylabel('Performance')
                ax[row, col].set_xlim([0, 2000])
                ax[row, col].set_title('CIFAR-10$_{2k}$', fontweight='bold')

            # row 1: Dataset: CIFAR10 (first2rows_datasets): x_qm/x_rand over performance,
            # same per-column variations as row 0, with hand-tuned axis limits per column
            elif row == 1:
                if col == 0:
                    for budget in cfg.all_budgets:
                        current_dataset = first2rows_datasets[:2] + (budget, budget,) + first2rows_datasets[4:]
                        figures.fig_frac_al_rand_over_performance(ax[row, col], current_dataset, results,
                                                                  first2rows_sampling_methods,
                                                                  linestyle=linestyle_dict_first2rows[budget])
                    ax[row, col].set_xlim([0.4, 0.8])
                    ax[row, col].set_ylim([0.6, 1.2])

                elif col == 1:
                    # budget fields are fixed to 20 here (row 0, col 1 uses 200)
                    for train_paradigm in cfg.all_train_paradigms:
                        current_dataset = (first2rows_datasets[:2] + (20, 20,) + first2rows_datasets[4:5] +
                                           (train_paradigm,) + first2rows_datasets[6:])
                        figures.fig_frac_al_rand_over_performance(ax[row, col], current_dataset, results,
                                                                  first2rows_sampling_methods,
                                                                  linestyle=linestyle_dict_first2rows[train_paradigm])
                    ax[row, col].set_xlim([0.4, 0.8])
                    #ax[row, col].set_ylim([0.84, 1.2])

                elif col == 2:
                    for weight_init in cfg.all_weight_inits:
                        current_dataset = first2rows_datasets[:6] + (weight_init, 'finetune') + first2rows_datasets[8:]
                        figures.fig_frac_al_rand_over_performance(ax[row, col], current_dataset, results,
                                                                  first2rows_sampling_methods,
                                                                  linestyle=linestyle_dict_first2rows[weight_init],
                                                                  moving_avg=True)
                    ax[row, col].set_xlim([0.25, 0.55])
                    ax[row, col].set_ylim([0, None])

                elif col == 3:
                    # NOTE(review): 'finetune' is not in the list iterated here (row 0 iterates
                    # cfg.all_training instead), so the `training == 'finetune'` check below is never true
                    # and moving_avg stays False - confirm that leaving out 'finetune' is intended.
                    for training in ['frozen', 'finetune-last-2', 'finetune-last-5']:
                        moving_avg = False
                        if training == 'finetune':
                            moving_avg = True
                        current_dataset = first2rows_datasets[:7] + (training,) + first2rows_datasets[8:]
                        figures.fig_frac_al_rand_over_performance(ax[row, col], current_dataset, results,
                                                                  first2rows_sampling_methods,
                                                                  linestyle=linestyle_dict_first2rows[training],
                                                                  moving_avg=moving_avg)
                    ax[row, col].set_xlim([None, 0.8])
                    ax[row, col].set_ylim([0.6, None])

                ax[row, col].set_xlabel('Performance')
                ax[row, 0].set_ylabel('$x_{qm}/x_{rand}$')

            # row 2: empty row for vertical space
            elif row == 2:
                ax[row, col].set_visible(False)

            # row 3: 2k datasets: Performance all QMs
            elif row == 3:
                figures.fig_performance_over_training_samples(ax[row, col], datasets_2k[col], results, sampling_methods)
                ax[row, col].set_title(cfg.figures[datasets_2k[col][0]]['text'] + '$_{2k}$', fontweight='bold')
                plt.setp(ax[row, col], xlabel='# added training samples')
                plt.setp(ax[row, 0], ylabel='Performance')
                ax[row, col].set_xlim([0, 2000])

            # row 4: 2k datasets: fraction train samples AL/fraction train samples rand over performance
            elif row == 4:
                figures.fig_frac_al_rand_over_performance(ax[row, col], datasets_2k[col], results, sampling_methods)
                ax[row, col].set_xlabel('Performance')
                ax[row, 0].set_ylabel('$x_{qm}/x_{rand}$')
                # set limits for better visualisation (None leaves that bound to matplotlib)
                if col == 0:
                    xlim = [0.03, None]
                    ylim = [None, None]
                elif col == 1:
                    xlim = [0.4, 0.8]
                    ylim = [0.7, None]
                elif col == 2:
                    xlim = [0.5, None]
                    ylim = [0, None]
                elif col == 3:
                    xlim = [0.1, None]
                    ylim = [None, 2.2]
                ax[row, col].set_xlim(xlim)
                ax[row, col].set_ylim(ylim)

            # row 5: 2k datasets: Learning curve with SF approximation, selected QM
            elif row == 5:
                # the helper's return value S is reused below as the y position of a horizontal line
                S = figures.fig_learning_curve_approximation(ax[row, col], datasets_2k[col], results,
                                                             sampling_methods_specific[col])

                ax[row, col].set_xlabel('# added training samples')
                ax[row, 0].set_ylabel('Performance')
                ax[row, col].set_xlim([0, 2000])

                # plot speedup factor hline in plot above (row 4, same column), coloured like the column's QM
                qm = sampling_methods_specific[col][1]
                color = cfg.figures[qm]['color']
                label = 'S ' + cfg.figures[qm]['label']
                ax[row - 1, col].axhline(y=S, linestyle='--', color=color, xmin=0, xmax=1, label=label)
                # Only keep the last handle and label (the hline that was just added)
                handles, labels = ax[row - 1, col].get_legend_handles_labels()
                ax[row - 1, col].legend([handles[-1]], [labels[-1]], frameon=False, loc='lower right', bbox_to_anchor=(1,-.09))

            # row 6: 2k datasets: Metric stability
            elif row == 6:
                figures.fig_performance_metrics_over_training_samples(ax[row, col], datasets_2k[col], results,
                                                                      sampling_methods_specific[col])

                plt.setp(ax[row, col], xlabel='Stop Budget')
                plt.setp(ax[row, 0], ylabel='AL Perf. Metric Change [%]')
                ax[row, col].set_xlim([700, 2000])

            # row 7: empty row for vertical space
            elif row == 7:
                ax[row, col].set_visible(False)

            # row 8: complete datasets: Performance all QMs
            elif row == 8:
                figures.fig_performance_over_training_samples(ax[row, col], datasets_complete[col], results,
                                                              sampling_methods)

                ax[row, col].set_title(cfg.figures[datasets_complete[col][0]]['text'], fontweight='bold')
                plt.setp(ax[row, col], xlabel='# added training samples')
                plt.setp(ax[row, 0], ylabel='Performance')
                ax[row, col].set_xlim([0, 1400])

            # row 9: complete datasets: Learning curve approximation, selected QM
            # (the helper's return value is not used here, unlike row 5)
            elif row == 9:
                figures.fig_learning_curve_approximation(ax[row, col], datasets_complete[col], results,
                                                         sampling_methods_specific[col])
                plt.setp(ax[row, col], xlabel='# added training samples')
                plt.setp(ax[row, 0], ylabel='Performance')
                ax[row, col].set_xlim([0, 1400])

            # row 10: complete datasets: Metric stability
            elif row == 10:
                figures.fig_performance_metrics_over_training_samples(ax[row, col], datasets_complete[col],
                                                                      results, sampling_methods_specific[col])
                plt.setp(ax[row, col], xlabel='Stop Budget')
                plt.setp(ax[row, 0], ylabel='AL Perf. Metric Change [%]')
                ax[row, col].set_xlim([700, 1400])

            # set upper and right bound to invisible
            ax[row, col].spines['right'].set_visible(False)
            ax[row, col].spines['top'].set_visible(False)

            # delete leading 0 for yaxis: integer ticks are printed without decimals, all other ticks are
            # formatted with two decimals and then stripped of trailing and leading '0' characters
            ax[row, col].yaxis.set_major_formatter(
                ticker.FuncFormatter(lambda x, pos: f'{int(x)}' if x.is_integer() else f'{x:.2f}'.rstrip('0').lstrip('0')))

    # legend below plot: reuse the handles of subplot (3, 0) and prepend an empty, invisible line whose
    # label 'Legend:' acts as a heading
    handles, labels = ax[3, 0].get_legend_handles_labels()
    ph = [plt.plot([], marker='', ls='')[0]]  # dummy
    handles = ph + handles
    labels = ['Legend:'] + labels
    fig.legend(handles, labels, loc='lower center', ncol=8, bbox_to_anchor=(0.5, 0))

    # crop out figure
    fig.subplots_adjust(left=0.06, right=0.989, top=0.988, bottom=0.04)


    # change legends (second pass over all subplots, after everything has been drawn)
    # row_labels holds one entry per grid row; indices 2 and 7 belong to the hidden spacer rows
    row_labels = ['A)', 'B)', 'C)', 'C)', 'D)', 'E)', 'F)', 'F)', 'G)', 'H)', 'I)']
    for row in range(nrows):
        for col in range(ncols):
            # Set row identifier (letters) left of the first column
            if col == 0:
                ax[row, col].text(-0.25, 0.5, row_labels[row], transform=ax[row, col].transAxes, fontsize=14,
                                  fontweight='bold', va='center', ha='center')

            # adjust legends for specific plots
            # rows 0 and 1: replace the legend by black proxy lines that explain the line styles
            if row in [0, 1]:
                label_dot = '_nolegend_'
                label_solid = '_nolegend_'
                label_dash = '_nolegend_'
                ncol=1
                if col == 0:
                    title = 'Initial Budget / Query Size'
                    label_dot = '2'
                    label_solid = '20'
                    label_dash = '200'
                    legend_bbox_anchor = (1, -.09)
                    ncol = 3
                elif col == 1:
                    title = 'Training Paradigm'
                    label_dot = 'SL (incl. augmentations)'
                    label_solid = 'SL'
                    label_dash = 'Semi-SL'
                    legend_bbox_anchor = (1, -.09)
                elif col == 2:
                    title = 'Weight Initialization'
                    label_dot = 'Random'
                    label_solid = 'Transfer Learning'
                    label_dash = 'Self-SL'
                    legend_bbox_anchor = (1, -.09)
                elif col == 3:
                    # NOTE(review): the dotted 'All' entry is also shown for row 1, where no 'finetune'
                    # curve (dotted in linestyle_dict_first2rows) is plotted - confirm this is acceptable.
                    title = 'Fine-Tuned Layers (last N)'
                    label_dot = 'All'
                    label_solid = '1'
                    label_dash = '5'
                    label_dashdot = '2'
                    legend_bbox_anchor = (1, -.09)

                # create legend
                line_dot = mlines.Line2D([], [], color='black', linestyle=':', label=label_dot)
                line_solid = mlines.Line2D([], [], color='black', linestyle='-', label=label_solid)
                line_dash = mlines.Line2D([], [], color='black', linestyle='--', label=label_dash)
                if col != 3:
                    ax[row, col].legend(handles=[line_dot, line_solid, line_dash], title=title, labelspacing=0,
                                        frameon=False, loc='lower right', ncol=ncol, bbox_to_anchor=legend_bbox_anchor)
                else:
                    # column 3 has a fourth line style (dash-dot) and lists its entries in the order 1, 2, 5, All
                    line_dashdot = mlines.Line2D([], [], color='black', linestyle='-.', label=label_dashdot)
                    ax[row, col].legend(handles=[line_solid, line_dashdot, line_dash, line_dot], title=title,
                                        labelspacing=0, frameon=False, loc='lower right', ncol=4, columnspacing=1.5,
                                        bbox_to_anchor=legend_bbox_anchor)

            # rows 5, 6, 9 and 10: default legend entries with hand-tuned placement per row / column;
            # rows 6 and 10 use the label of the column's query method as the legend title
            if row == 5:
                ax[row, col].legend(frameon=False, labelspacing=0, loc='lower right')
            elif row == 6:
                qm = sampling_methods_specific[col][1]
                legend_title = cfg.figures[qm]['label']
                if col in [0, 1]:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, ncol=2, columnspacing=0,
                                        loc='lower center', bbox_to_anchor=(0.3, 0.35))
                else:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, loc='upper right')
            elif row == 9:
                ax[row, col].legend(frameon=False, labelspacing=0, loc='lower right', ncol=2, columnspacing=0,
                                    bbox_to_anchor=(1.085, -0.08))
            elif row == 10:
                qm = sampling_methods_specific[col][1]
                legend_title = cfg.figures[qm]['label']
                if col in [0, 3]:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, ncol=2, columnspacing=0,
                                        loc='lower center', bbox_to_anchor=(0.7, 0.4))
                else:
                    ax[row, col].legend(title=legend_title, frameon=False, labelspacing=0, loc='upper right')

    # save and show plot
    plt.savefig(Path(cfg.path_fig, 'AAAI_multiclass.pdf'))
    fig.show()

# figure: AAAI submitted / computational time
# Draws a 4 x 4 grid of processing-time plots (one dataset group per row, one dataset per column) with
# logarithmic y axes and saves it to cfg.path_fig as 'AAAI_exp_time.pdf'. Switched off by `if False:`.
if False:
    # result table produced by the experiments
    results = pd.read_pickle(Path(cfg.path_exp, 'results.pkl'))

    # Dataset tuples follow the layout noted at the top of this file:
    # (dataset, dataset_type, init_train_sample, add_train_sample, max_train_samples, train_paradigm,
    #  weight_init, training, nr_classes, classes)
    # Only the dataset name and the number of classes differ within a group, so the tuples are assembled
    # from (name, nr_classes) pairs and the settings shared by the '2k' and 'complete' variants.
    base_sets = (('carina', 37), ('mscoco', 80), ('reuters', 31), ('scene', 6))
    multiclass_sets = (('urbansound8k', 10), ('cifar10', 10), ('agnews', 4), ('letter', 26))
    settings_2k = ('2k', 20, 20, 2000, 'sl', 'tl', 'frozen')
    settings_complete = ('complete', 20, 20, 1400, 'sl', 'tl', 'frozen')

    datasets_2k = [(name,) + settings_2k + (nr_classes, 'ALL') for name, nr_classes in base_sets]
    datasets_complete = [(name,) + settings_complete + (nr_classes, 'ALL') for name, nr_classes in base_sets]
    datasets_2k_multiclass = [(name,) + settings_2k + (nr_classes, 'ALL')
                              for name, nr_classes in multiclass_sets]
    datasets_complete_multiclass = [(name,) + settings_complete + (nr_classes, 'ALL')
                                    for name, nr_classes in multiclass_sets]

    # query methods shown in every subplot
    sampling_methods = ['random', 'ratio_max', 'kmeans', 'multilabel_simple_crw', 'badge', 'bald', 'beal']

    # figure with equally high rows
    nrows = 4
    ncols = 4
    grid_options = {'hspace': .5,
                    'height_ratios': [1, 1, 1, 1],
                    'wspace': .15}
    fig, ax = plt.subplots(nrows, ncols, figsize=(16, 8), gridspec_kw=grid_options)

    # one entry per grid row: (datasets of the row, suffix appended to the title, upper x limit)
    panel_rows = [(datasets_2k, '$_{2k}$', 1960),
                  (datasets_complete, '', 1400),
                  (datasets_2k_multiclass, '$_{2k}$', 1960),
                  (datasets_complete_multiclass, '', 1400)]
    last_row = nrows - 1

    for row, (row_datasets, title_suffix, x_max) in enumerate(panel_rows):
        for col in range(ncols):
            print(f'Row {row}/{nrows-1} | Col {col}/{ncols-1}')

            panel = ax[row, col]
            dataset = row_datasets[col]
            figures.fig_processing_time_over_training_samples(panel, dataset, results, sampling_methods)
            panel.set_title(cfg.figures[dataset[0]]['text'] + title_suffix, fontweight='bold')
            panel.set_xlim([0, x_max])
            # only the bottom row carries an x label
            if row == last_row:
                panel.set_xlabel('# added training samples')

            # y label on the first column only, log-scaled y axis, no top/right frame lines
            ax[row, 0].set_ylabel('$t_{qm}/t_{rand}$')
            panel.set_yscale('log')
            for side in ('right', 'top'):
                panel.spines[side].set_visible(False)

    # shared legend below the grid: handles of the first subplot, preceded by an empty line that only
    # contributes the heading 'Legend:'
    handles, labels = ax[0, 0].get_legend_handles_labels()
    placeholder = plt.plot([], marker='', ls='')[0]
    handles = [placeholder] + handles
    labels = ['Legend:'] + labels
    fig.legend(handles, labels, loc='lower center', ncol=8, bbox_to_anchor=(0.5, 0))

    # trim the outer margins, leaving room at the bottom for the legend
    fig.subplots_adjust(left=0.06, right=0.989, top=0.97, bottom=0.1)

    # row identifiers (letters) left of the first column
    row_labels = ['A)', 'B)', 'C)', 'D)']
    for row, row_label in enumerate(row_labels):
        first_panel = ax[row, 0]
        first_panel.text(-0.25, 0.5, row_label, transform=first_panel.transAxes, fontsize=14,
                         fontweight='bold', va='center', ha='center')

    # write the figure to disk, then display it
    plt.savefig(Path(cfg.path_fig, 'AAAI_exp_time.pdf'))
    fig.show()
