# analysis qss_datas_details_dict.pickle
import os
import pickle
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

# Dataset attribute table read from dataset.csv, indexed by its first column.
# Index labels are lower-cased, presumably to match the lower-case dataset
# keys used elsewhere in this script.
data_attributes = pd.read_csv('dataset.csv', index_col=0)
data_attributes = data_attributes.set_axis(
    data_attributes.index.str.lower(), axis='index')

# Internal query-strategy keys (as stored in the results) -> display names used
# in the exported tables. Trailing comments name the library of origin.
qs_dict = {
    'uniform': 'Uniform',
    'qbc': 'QBC', 'hintsvm': 'HintSVM', 'quire': 'QUIRE', 'albl': 'ALBL', 'dwus': 'DWUS', 'kcenter': 'Core-Set',  # libact
    'margin': 'US', 'graph': 'Graph', 'hier': 'Hier', 'mcm': 'MCM',  # google
    'lal': 'LAL', 'bmdr': 'BMDR',  # alipy
    'skactiveml_bald': 'BALD',  # scikit-activeml
}
# Column order (display names) of the exported tables.
ordered_qs = ['Uniform', 'US', 'QBC', 'BALD', 'Hier', 'Graph',
              'Core-Set', 'HintSVM', 'QUIRE', 'DWUS', 'MCM', 'BMDR', 'ALBL', 'LAL', ]
small_data_list = ["appendicitis", "sonar", "parkinsons", "ex8b", "heart", "haberman", "ionosphere", "clean1",
                   "breast", "wdbc", "australian", "diabetes", "mammographic", "ex8a", "tic", "german",
                   "splice", "gcloudb", "gcloudub", "checkerboard"]
large_data_list = ["spambase", "banana",
                   "phoneme", "ringnorm", "twonorm", "phishing"]
real_data_list = ['covertype', 'bioresponse', 'pol', ]
# Row order of the exported tables (lower-case dataset keys).
data_list = small_data_list + large_data_list + real_data_list
# Dataset display names: the key with its first letter capitalised.
data_dict = {d: d.capitalize() for d in data_list}
# Visualization: query strategies to plot (keys must match qs_dict / qs_vis_color).
qs_vis = {
    'uniform': 'Uniform',
    'margin': 'US',
    'skactiveml_bald': 'BALD',
    'qbc': 'QBC',
    'kcenter': 'Core-Set',
    'graph': 'Graph',
    'hier': 'Hier',
    'hintsvm': 'HintSVM',
    'quire': 'QUIRE',
    'dwus': 'DWUS',
    'mcm': 'MCM',
    'bmdr': 'BMDR',  # was 'bmdrt', which matched no key in qs_dict / qs_vis_color
    'albl': 'ALBL',
    'lal': 'LAL',
}
# Subset of qs_vis for focused plots.
qs_vis_focus = {
    'uniform': 'Uniform',
    'margin': 'US',
    'skactiveml_bald': 'BALD',
    'kcenter': 'Core-Set',
    'lal': 'LAL',
}
# Matplotlib colour per query-strategy key.
qs_vis_color = {
    'uniform': 'gray',
    'margin': 'blue',
    'skactiveml_bald': 'purple',
    'qbc': 'orange',
    'kcenter': 'green',
    'graph': 'lightgreen',
    'hier': 'darkgreen',
    'hintsvm': 'cyan',
    'quire': 'lightblue',
    'dwus': 'darkblue',
    'mcm': 'brown',
    'bmdr': 'blueviolet',
    'albl': 'lightcoral',
    'lal': 'red',
}

# Load the results dict: keys are (qs, data) pairs, values are result
# DataFrames or None. pickle can execute arbitrary code on load, so this file
# must come from a trusted source.
with open('qss_datas_details_dict.pickle', mode='rb') as pickle_file:
    qss_datas_details_dict = pickle.load(pickle_file)


def AUBC(budgets, metrics):
    """Return the area under a budget/metric curve, rounded to 5 decimals.

    The curve is integrated with the trapezoidal rule over ``budgets`` and the
    area is divided by the NUMBER of budget points (``len(budgets)``), not by
    the budget range.

    Args:
        budgets: 1-D sequence of budget values (x positions); lists, numpy
            arrays and pandas Index/Series are accepted.
        metrics: metric values at those budgets (same length as ``budgets``).

    Returns:
        The normalised area as a numpy float rounded to 5 decimal places.
    """
    budgets = np.asarray(budgets)
    total_budget = budgets.shape[0]
    # np.trapz was renamed to np.trapezoid in NumPy 2.0 (old name deprecated).
    trapezoid = getattr(np, 'trapezoid', None)
    if trapezoid is None:  # NumPy < 2.0
        trapezoid = np.trapz
    ressum = trapezoid(metrics, x=budgets) / total_budget
    return np.round(ressum, 5)


def scale_float(target):
    """Snap ``target`` to the nearest of 0.1, 0.01 and 0.001.

    On equal distance the earlier candidate (the larger value) wins.
    """
    candidates = (0.1, 0.01, 0.001)
    distances = [abs(candidate - target) for candidate in candidates]
    return candidates[distances.index(min(distances))]


def to_percentage(x, precision=2):
    """Format the fraction ``x`` as a percentage string, e.g. 0.5 -> '50.00%'.

    ``precision`` is the number of decimals kept after scaling by 100.
    """
    scaled = x * 100
    return "{:.{}f}%".format(scaled, precision)


# Normalise every result DataFrame (columns 'expno', 'round', 'res_tst_score')
# onto the complete expno x round grid, indexed by the (expno, round)
# MultiIndex, and store it back into qss_datas_details_dict.
for qs_data, test_scores in tqdm(qss_datas_details_dict.items()):
    if test_scores is None:
        # no results for this (qs, data) pair: report it, keep the None entry
        print(qs_data)
        continue

    # note: inplace=True also re-indexes the frame object held by the dict
    test_scores.set_index(['expno', 'round'], inplace=True)
    # keep only the first row of each repeated (expno, round) pair
    test_scores = test_scores.loc[~test_scores.index.duplicated()]

    # every observed expno crossed with every observed round (first-seen order)
    grid = pd.MultiIndex.from_product(
        [test_scores.index.get_level_values(level).unique()
         for level in ('expno', 'round')],
        names=test_scores.index.names,
    )
    # Cells missing from the grid are forward-filled on the unstacked frame.
    # NOTE(review): unstack() moves the last level ('round') into the columns,
    # so ffill() runs down the 'expno' axis: a gap is filled with the same
    # round's value from the preceding expno, not from the previous round of
    # the same experiment, and gaps in the first expno stay NaN. Confirm this
    # is the intended imputation.
    test_scores = (
        test_scores.reindex(grid)
        .unstack()
        .ffill()
        .stack(future_stack=True)
    )
    qss_datas_details_dict[qs_data] = test_scores

    # fewer than 100 rows is reported as a broken run
    if len(test_scores) < 100:
        print('error exps')
        print(qs_data)

# Collect test accuracy at fixed fractions of each run's total budget.
# For every checkpoint, acc_ckpts[name]['mean'] / ['std'] receive one
# [qs, data, value] entry per (qs, data) run: the mean / std over experiments.
ckpt_plan = [("5%", 0.05), ("10%", 0.1), ("15%", 0.15), ("20%", 0.2),
             ("30%", 0.3), ("40%", 0.4), ("50%", 0.5)]
acc_ckpts = OrderedDict(
    (ckpt_name, {'mean': [], 'std': []}) for ckpt_name, _ in ckpt_plan
)
durs = {}  # not filled by this loop
rs_datas_details_dict = {}  # results of runs whose query strategy is 'uniform'

for qs_data, test_scores in tqdm(qss_datas_details_dict.items()):
    # missing or too-short (< 100 rows) results are reported and skipped
    if test_scores is None or len(test_scores) < 100:
        print(qs_data)
        continue

    if qs_data[0] == 'uniform':
        rs_datas_details_dict[qs_data] = test_scores

    total_budget = test_scores.index.get_level_values('round').unique().max()
    for ckpt_name, fraction in ckpt_plan:
        bg = int(round(total_budget * fraction))
        # checkpoints that fall below round 20 are not recorded
        if bg < 20:
            continue
        # NOTE(review): xs raises KeyError if bg is not an observed 'round'
        # value; this assumes every integer round is present -- confirm.
        accs_budget = test_scores.xs(bg, level='round')
        run_key = list(qs_data)
        acc_ckpts[ckpt_name]['mean'].append(run_key + [accs_budget.mean().iloc[0]])
        acc_ckpts[ckpt_name]['std'].append(run_key + [accs_budget.std().iloc[0]])

# Export one table per budget checkpoint (the 5% checkpoint is skipped): a CSV
# of mean test accuracy per (dataset, query strategy), and a LaTeX version with
# the best strategy of each dataset in bold.
os.makedirs('github', exist_ok=True)
os.makedirs('tables', exist_ok=True)

# The LaTeX file is written with escape=False because pandas' escape=True also
# escapes the backslash and braces of the \textbf{...} markup added below.
# Cell text and labels are escaped by hand instead; this covers &, %, $, #, _
# (and renders the ± separator as $\pm$) -- extend if labels may contain others.
LATEX_ESCAPES = str.maketrans({
    '&': r'\&', '%': r'\%', '$': r'\$', '#': r'\#', '_': r'\_', '±': r'$\pm$',
})
# Set to True to append "±std" (std as a percentage with 3 decimals) to every cell.
SHOW_STD = False


def _latex_escape(text):
    """Return ``text`` as a string with LaTeX special characters escaped."""
    return str(text).translate(LATEX_ESCAPES)


for accs in tqdm(acc_ckpts):
    if accs == '5%':
        continue

    mean_rows = acc_ckpts[accs]['mean']
    if not mean_rows:
        print(f'no accuracies recorded for checkpoint {accs}; skipped')
        continue

    # wide table: one row per dataset, one column per query strategy (raw keys)
    mean_accs = pd.pivot_table(
        pd.DataFrame(mean_rows, columns=['qs', 'data', 'mean']),
        values='mean', index='data', columns='qs')
    # best strategy per dataset; all-NaN rows are dropped first because idxmax
    # over an all-NaN row is deprecated (and will raise) in recent pandas
    max_qs = mean_accs.dropna(how='all').idxmax(axis=1)

    # mean accuracy as a percentage with 2 decimals
    accs_final = mean_accs.map(to_percentage).astype(str)
    if SHOW_STD:
        std_accs = pd.pivot_table(
            pd.DataFrame(acc_ckpts[accs]['std'], columns=['qs', 'data', 'std']),
            values='std', index='data', columns='qs')
        accs_final = (accs_final + '±'
                      + std_accs.map(lambda x: to_percentage(x, 3)).astype(str))

    # keep only datasets listed in data_list, in that order, with display names
    data_list_keep = [d for d in data_list if d in accs_final.index]
    accs_final = accs_final.loc[data_list_keep]
    accs_final.index = accs_final.index.map(data_dict)
    # keep only strategies listed in ordered_qs, in that order, with display names
    accs_final.columns = accs_final.columns.map(qs_dict)
    qs_list_keep = [qs for qs in ordered_qs if qs in accs_final.columns]
    accs_final = accs_final[qs_list_keep]

    export_name = f'analysis_accs_{accs}'.replace('%', 'percent')
    accs_final.to_csv(f'github/{export_name}.csv')

    # LaTeX table: escape cell text, then bold the best strategy per dataset
    accs_final_max = accs_final.map(_latex_escape)
    best_qs = max_qs.map(qs_dict)
    best_qs.index = best_qs.index.map(data_dict)
    for data_name, qs_name in best_qs.items():
        # datasets/strategies filtered out above (or without a display name,
        # i.e. mapped to NaN) are not in the table; looking them up would raise
        if data_name in accs_final_max.index and qs_name in accs_final_max.columns:
            cell = accs_final_max.at[data_name, qs_name]
            accs_final_max.at[data_name, qs_name] = f"\\textbf{{{cell}}}"
    accs_final_max.index = accs_final_max.index.map(_latex_escape)
    accs_final_max.columns = accs_final_max.columns.map(_latex_escape)

    accs_final_max.to_latex(f'tables/{export_name}.tex', escape=False)
