# analysis qss_datas_details_dict.pickle
import pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


# Dataset metadata, one row per dataset. The index is lower-cased so lookups
# with the lower-case dataset names used in this script succeed.
data_attributes = pd.read_csv('dataset.csv', index_col=0).pipe(
    lambda frame: frame.set_axis(frame.index.str.lower(), axis=0)
)

# Raw query-strategy keys (as stored in the results) -> display names,
# grouped by the library that provides each implementation.
qs_dict = {
    'uniform': 'Uniform',
    # libact
    'qbc': 'QBC',
    'hintsvm': 'HintSVM',
    'quire': 'QUIRE',
    'albl': 'ALBL',
    'dwus': 'DWUS',
    'kcenter': 'Core-Set',
    # google
    'margin': 'US',
    'graph': 'Graph',
    'hier': 'Hier',
    'mcm': 'MCM',
    # alipy
    'lal': 'LAL',
    'bmdr': 'BMDR',
    # scikit-activeml
    'skactiveml_bald': 'BALD',
}
# Preferred ordering of the strategy display names.
ordered_qs = [
    'Uniform', 'US', 'QBC', 'BALD', 'Hier', 'Graph', 'Core-Set',
    'HintSVM', 'QUIRE', 'DWUS', 'MCM', 'BMDR', 'ALBL', 'LAL',
]
small_data_list = [
    'appendicitis', 'sonar', 'parkinsons', 'ex8b', 'heart', 'haberman',
    'ionosphere', 'clean1', 'breast', 'wdbc', 'australian', 'diabetes',
    'mammographic', 'ex8a', 'tic', 'german', 'splice', 'gcloudb',
    'gcloudub', 'checkerboard',
]
large_data_list = ['spambase', 'banana', 'phoneme', 'ringnorm', 'twonorm', 'phishing']
real_data_list = ['covertype', 'bioresponse', 'pol']
data_list = [*small_data_list, *large_data_list, *real_data_list]
# Dataset key -> display name (str.capitalize: first letter upper, rest lower).
data_dict = dict(zip(data_list, map(str.capitalize, data_list)))
# visualization qss
# One row per plotted strategy: (raw key, display label, line colour).
# Row order is the plotting/legend order.
_qs_vis_table = [
    ('uniform', 'Uniform', 'gray'),
    ('margin', 'US', 'blue'),
    ('skactiveml_bald', 'BALD', 'purple'),
    ('qbc', 'QBC', 'orange'),
    ('kcenter', 'Core-Set', 'green'),
    ('graph', 'Graph', 'lightgreen'),
    ('hier', 'Hier', 'darkgreen'),
    ('hintsvm', 'HintSVM', 'cyan'),
    ('quire', 'QUIRE', 'lightblue'),
    ('dwus', 'DWUS', 'darkblue'),
    ('mcm', 'MCM', 'brown'),
    ('albl', 'ALBL', 'lightcoral'),
    ('lal', 'LAL', 'red'),
]
qs_vis = {key: label for key, label, _ in _qs_vis_table}
# Subset of strategies highlighted in the "focus" figures.
qs_vis_focus = {key: qs_vis[key] for key in ('uniform', 'margin', 'skactiveml_bald', 'kcenter', 'lal')}
qs_vis_color = {key: color for key, _, color in _qs_vis_table}
del _qs_vis_table

# load data
# NOTE: pickle.load can execute arbitrary code; only load result files that
# you produced yourself.
with open('qss_datas_details_dict.pickle', mode='rb') as pickle_file:
    qss_datas_details_dict = pickle.load(pickle_file)

def AUBC(budgets, metrics):
    """Return the area under the budget curve (AUBC), rounded to 5 decimals.

    The curve ``metrics`` vs. ``budgets`` is integrated with the trapezoidal
    rule, and the area is divided by the number of budget points
    (``len(budgets)``).

    Parameters
    ----------
    budgets : array-like, 1-D
        x-coordinates of the curve (e.g. the unique ``round`` values).
    metrics : array-like, 1-D
        Metric value at each budget; same length as ``budgets``.

    Returns
    -------
    numpy.float64
        Normalized area, rounded to 5 decimals.
    """
    budgets = np.asarray(budgets)
    metrics = np.asarray(metrics)
    # NumPy 2.0 renamed np.trapz to np.trapezoid and deprecated the old name;
    # use the new function when available and fall back on older NumPy.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    total_budget = budgets.shape[0]
    # NOTE(review): dividing by the number of points (rather than by the budget
    # span budgets[-1] - budgets[0]) only gives an average-metric-like value
    # when consecutive budgets are 1 apart — confirm this is intended.
    ressum = trapezoid(metrics, x=budgets) / total_budget
    return np.round(ressum, 5)

# Every DataFrame in qss_datas_details_dict has the columns
# ['expno', 'round', 'res_tst_score']. Index each one by (expno, round),
# drop duplicated index entries, and complete the expno x round grid so that
# every experiment has a row for every round.
for key, frame in qss_datas_details_dict.items():
    if frame is None:
        # Report keys without results and leave them as None.
        print(key)
        continue
    frame.set_index(['expno', 'round'], inplace=True)
    frame = frame.loc[~frame.index.duplicated()]
    # Cartesian grid of the experiment ids and rounds seen in this frame.
    full_grid = pd.MultiIndex.from_product(
        (
            frame.index.get_level_values('expno').unique(),
            frame.index.get_level_values('round').unique(),
        ),
        names=frame.index.names,
    )
    wide = frame.reindex(full_grid).unstack()
    # NOTE(review): unstack() moves 'round' into the columns, so ffill() (axis=0)
    # fills a missing (expno, round) cell from the previous *experiment* at the
    # same round, not from the previous round of the same experiment.
    # Confirm this is intended rather than ffill(axis=1).
    qss_datas_details_dict[key] = wide.ffill().stack(future_stack=True)

# calculate tau (difference of qs and uniform) for each dataframe in qss_datas_details_dict
# tau = score(qs) - score(uniform); pandas aligns the two frames on their
# (expno, round) index before subtracting.
qss_datas_taus_dict = {}
for data in data_dict:
    test_scores_uniform = qss_datas_details_dict.get(('uniform', data))
    if test_scores_uniform is None:
        # Without a uniform baseline tau is undefined; skip the dataset instead
        # of failing on `DataFrame - None`.
        print(f'no uniform baseline for {data}; skipping tau')
        continue
    for qs in qs_dict:
        if qs == 'uniform':
            continue
        test_scores_qs = qss_datas_details_dict.get((qs, data))
        if test_scores_qs is None:
            continue
        qss_datas_taus_dict[(qs, data)] = test_scores_qs - test_scores_uniform

# plot tau of different query strategies on the same figure
for data in data_dict:
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    for qs in qs_vis_focus:
        tau = qss_datas_taus_dict.get((qs, data))
        if tau is None:
            continue
        # rows: expno; columns: (value column, round)
        tau = tau.unstack()
        # Skip the first round (presumably the shared initial labeled set,
        # where tau is ~0 — TODO confirm).
        x = tau.columns.levels[1][1:]
        tau_sel = tau.loc[:, (slice(None), x)]
        tau_mean = tau_sel.mean(axis=0).values
        # 95% confidence band of the mean (normal approximation):
        # 1.96 * standard error. sem() divides by the number of non-NaN
        # experiments per round rather than the total row count.
        tau_ci = 1.96 * tau_sel.sem(axis=0).values
        color = qs_vis_color[qs]
        ax.plot(x.values, tau_mean, label=qs_dict[qs], color=color)
        ax.fill_between(x.values, tau_mean - tau_ci, tau_mean + tau_ci, color=color, alpha=0.2)

    # Vertical line at 10% of the total budget, capped at 3000 examples.
    total_budget = data_attributes.loc[data, 'n']
    max_vline = 3000
    x_vline = 0.1 * total_budget
    if x_vline <= max_vline:
        vline_label = '10% of total budget'
    else:
        x_vline = max_vline
        percentage = round(max_vline / total_budget * 100, 2)
        vline_label = f'{percentage}% of total budget'
    ax.axvline(x_vline, color='gold', linestyle='--', label=vline_label)

    # Optionally limit the x-axis to 5000 examples:
    # if data == 'covertype':
    #     ax.set_xlim(left=0, right=5000)
    # horizontal reference line at tau = 0 (no improvement over uniform)
    ax.axhline(0, color='black', linestyle='--')
    ax.set_xlabel('No. of labeled examples')
    ax.set_ylabel('Performance improvement (τ)')
    ax.set_title(f'Difference of accuracy on {data_dict[data]}')
    ax.legend()
    fig.savefig(f'images/tau-{data}.pdf')
    plt.close(fig)

# calculate AUBC for each dataframe: one AUBC value per experiment (expno),
# stored under the (strategy display name, dataset display name) key.
aubc_dict = {}
for key, scores in qss_datas_details_dict.items():
    if scores is None:
        print(key)
        continue
    rounds = scores.index.get_level_values('round').unique()
    per_experiment = scores.groupby('expno').apply(
        lambda group: AUBC(rounds, group['res_tst_score'])
    )
    display_key = (qs_dict[key[0]], data_dict[key[1]])
    aubc_dict[display_key] = per_experiment

# plot learning curves of test_scores
# plot multiple query strategies on the same figure:
#   fig1: mean curves of every strategy in qs_vis
#   fig2: mean curves of the focus strategies with a 95% confidence band
#   fig3: mean curves of the focus strategies only
for data in data_dict:
    fig1, ax1 = plt.subplots(1, 1, figsize=(8, 6))
    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 6))
    fig3, ax3 = plt.subplots(1, 1, figsize=(8, 6))
    for qs in qs_vis:
        test_scores = qss_datas_details_dict.get((qs, data))
        if test_scores is None:
            continue
        # rows: expno; columns: (value column, round)
        test_scores = test_scores.unstack()
        x = test_scores.columns.levels[1]
        test_scores_mean = test_scores.mean(axis=0).values
        # 95% confidence band of the mean (normal approximation):
        # 1.96 * standard error. sem() divides by the number of non-NaN
        # experiments per round rather than the total row count.
        test_scores_ci = 1.96 * test_scores.sem(axis=0).values
        color = qs_vis_color[qs]
        label = qs_dict[qs]
        ax1.plot(x.values, test_scores_mean, label=label, color=color)
        if qs in qs_vis_focus:
            ax2.plot(x.values, test_scores_mean, label=label, color=color)
            ax2.fill_between(x.values, test_scores_mean - test_scores_ci,
                             test_scores_mean + test_scores_ci, color=color, alpha=0.2)
            ax3.plot(x.values, test_scores_mean, label=label, color=color)

    figures = (
        (fig1, ax1, f'Learning curves on {data_dict[data]}',
         f'images/learning_curves-mean-{data}.pdf'),
        (fig2, ax2, f'Learning curves with 95% C.I. on {data_dict[data]}',
         f'images/learning_curves-95ci-{data}.pdf'),
        (fig3, ax3, f'Learning curves on {data_dict[data]}',
         f'images/learning_curves-mean_focus-{data}.pdf'),
    )
    for fig, ax, title, path in figures:
        ax.set_xlabel('No. of labeled examples')
        ax.set_ylabel('Accuracy')
        ax.set_title(title)
        ax.legend()
        fig.savefig(path)
        plt.close(fig)