# Analysis of qss_datas_details_dict.pickle: data utilization rate (DUR) of each query strategy relative to the uniform-sampling baseline.
import pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from collections import OrderedDict

# Per-dataset attribute table indexed by dataset name.
data_attributes = pd.read_csv(filepath_or_buffer='dataset.csv', index_col=0)
# Lower-case the index labels, presumably so they match the lower-case dataset
# names used in data_list below.
data_attributes.index = data_attributes.index.str.lower()

# Display label for every query-strategy (QS) key, grouped by source library.
qs_dict = {
    'uniform': 'Uniform',
    'qbc': 'QBC', 'hintsvm': 'HintSVM', 'quire': 'QUIRE', 'albl': 'ALBL', 'dwus': 'DWUS', 'kcenter': 'Core-Set',  # libact
    'margin': 'US', 'graph': 'Graph', 'hier': 'Hier', 'mcm': 'MCM',  # google
    'lal': 'LAL', 'bmdr': 'BMDR',  # alipy
    'skactiveml_bald': 'BALD',  # scikit-activeml
}
# Order of the QS display labels in the result table; the first entry
# ('Uniform') is the baseline and is dropped from the DUR table.
ordered_qs = ['Uniform', 'US', 'QBC', 'BALD', 'Hier', 'Graph',
              'Core-Set', 'HintSVM', 'QUIRE', 'DWUS', 'MCM', 'BMDR', 'ALBL', 'LAL', ]
small_data_list = ["appendicitis", "sonar", "parkinsons", "ex8b", "heart", "haberman", "ionosphere", "clean1",
                   "breast", "wdbc", "australian", "diabetes", "mammographic", "ex8a", "tic", "german",
                   "splice", "gcloudb", "gcloudub", "checkerboard"]
large_data_list = ["spambase", "banana",
                   "phoneme", "ringnorm", "twonorm", "phishing"]
real_data_list = ['covertype', 'bioresponse', 'pol', ]
# All datasets, in table-row order, and their capitalised display names.
data_list = small_data_list + large_data_list + real_data_list
data_dict = {d: d.capitalize() for d in data_list}
# QS keys and labels for visualisation. The keys must be the same internal QS
# keys as qs_dict / qs_vis_color (the BMDR entry was previously misspelled
# 'bmdrt', so it matched neither map).
qs_vis = {
    'uniform': 'Uniform',
    'margin': 'US',
    'skactiveml_bald': 'BALD',
    'qbc': 'QBC',
    'kcenter': 'Core-Set',
    'graph': 'Graph',
    'hier': 'Hier',
    'hintsvm': 'HintSVM',
    'quire': 'QUIRE',
    'dwus': 'DWUS',
    'mcm': 'MCM',
    'bmdr': 'BMDR',
    'albl': 'ALBL',
    'lal': 'LAL',
}
# Subset of strategies to highlight in visualisations.
qs_vis_focus = {
    'uniform': 'Uniform',
    'margin': 'US',
    'skactiveml_bald': 'BALD',
    'kcenter': 'Core-Set',
    'lal': 'LAL',
}
# Colour per QS key (presumably consumed by plotting code elsewhere).
qs_vis_color = {
    'uniform': 'gray',
    'margin': 'blue',
    'skactiveml_bald': 'purple',
    'qbc': 'orange',
    'kcenter': 'green',
    'graph': 'lightgreen',
    'hier': 'darkgreen',
    'hintsvm': 'cyan',
    'quire': 'lightblue',
    'dwus': 'darkblue',
    'mcm': 'brown',
    'bmdr': 'blueviolet',
    'albl': 'lightcoral',
    'lal': 'red',
}

# Load the experiment results: a dict keyed by (query strategy, dataset) tuples
# whose values are test-score DataFrames (or None for missing runs).
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
with open('qss_datas_details_dict.pickle', mode='rb') as f:
    qss_datas_details_dict = pickle.load(f)


def scale_float(target):
    """Snap ``target`` to the nearest of 0.1, 0.01 and 0.001.

    On a tie, the candidate listed first (the larger one) wins.
    """
    best = None
    for candidate in (0.1, 0.01, 0.001):
        if best is None or abs(candidate - target) < abs(best - target):
            best = candidate
    return best


def to_percentage(x, precision=2):
    """Format the fraction ``x`` as a percentage string with ``precision`` decimals.

    Example: ``to_percentage(0.1234)`` -> ``'12.34%'``.
    """
    return "{:.{}f}%".format(x * 100, precision)


# Every DataFrame in qss_datas_details_dict has the same format, with columns
# ['expno', 'round', 'res_tst_score'] (entries may also be None).
# Each one is indexed by (expno, round), completed so that every expno has every
# round, and written back into qss_datas_details_dict. The 'uniform' (baseline)
# entries are also collected in rs_datas_details_dict.
rs_datas_details_dict = {}
for qs_data, test_scores in tqdm(qss_datas_details_dict.items()):
    if test_scores is None:
        print(qs_data)
        continue
    test_scores = test_scores.set_index(['expno', 'round'])
    # remove duplicates in index (keeps the first record of each (expno, round))
    test_scores = test_scores[~test_scores.index.duplicated()]
    # Complete the (expno, round) grid; absent combinations become NaN rows.
    expno_curr = test_scores.index.get_level_values('expno').unique()
    round_curr = test_scores.index.get_level_values('round').unique()
    new_index = pd.MultiIndex.from_product(
        (expno_curr, round_curr), names=test_scores.index.names)
    test_scores = test_scores.reindex(new_index).sort_index()
    # Forward-fill gaps *within* each experiment, from that experiment's previous
    # round. The previous `unstack().ffill().stack()` filled along axis 0 of the
    # unstacked frame, i.e. across expno, copying another experiment's score.
    test_scores = test_scores.groupby(level='expno').ffill()
    qss_datas_details_dict[qs_data] = test_scores
    if len(test_scores) < 100:
        print('error exps')
        print(qs_data)
        continue
    if 'uniform' in qs_data:
        rs_datas_details_dict[qs_data] = test_scores

# Uniform baseline (RS), per dataset and per expno:
#   target accuracy = RS test score at the last round - 0.01
#   osp             = first round at which RS reaches that target
data_target_acc = {}
data_rs_osts = {}
for qs_data in tqdm(rs_datas_details_dict):
    rs_qs_data = rs_datas_details_dict[qs_data]
    # The overall max round equals the max over the per-expno maxima.
    total_budget = rs_qs_data.index.get_level_values('round').max()
    rs_qs_data_finalAcc = rs_qs_data.xs(total_budget, level='round')
    target_acc = rs_qs_data_finalAcc - 0.01
    # Broadcast each expno's target onto all of that expno's rows.
    target_acc_aligned = rs_qs_data.index.get_level_values(
        'expno').map(target_acc['res_tst_score'])
    aligned_target_acc_df = pd.DataFrame(
        {'target_res_tst_score': target_acc_aligned}, index=rs_qs_data.index)
    comparison_result = aligned_target_acc_df['target_res_tst_score'] <= rs_qs_data['res_tst_score']
    reached = comparison_result.groupby('expno')
    # idxmax yields the (expno, round) label of the first True; keep the round.
    first_round = reached.idxmax().map(lambda label: label[1])
    # idxmax of an all-False group (e.g. a NaN final score) returns the group's
    # FIRST label, which would record "never reached" as "reached at the first
    # round". Record NaN for those experiments instead.
    first_round = first_round.where(reached.any()).to_frame('osp')
    data_rs_osts[qs_data[1]] = first_round
    data_target_acc[qs_data[1]] = target_acc

# For every non-uniform strategy: the first round (per expno) at which it reaches
# the uniform baseline's target accuracy on the same dataset and expno.
data_qs_osts = {x: {} for x in data_list}
for qs_data in tqdm(qss_datas_details_dict):
    test_scores = qss_datas_details_dict[qs_data]
    if test_scores is None:
        print(qs_data)
        continue
    if len(test_scores) < 100:
        print(qs_data)
        continue
    if 'uniform' in qs_data:
        continue
    rs_key = ('uniform', qs_data[1])
    # No usable uniform baseline (or a dataset outside data_list): there is
    # nothing to compare against, so report and skip instead of raising KeyError.
    if rs_key not in rs_datas_details_dict or qs_data[1] not in data_qs_osts:
        print(qs_data)
        continue

    # Skip strategies that ran fewer rounds than the baseline.
    qs_total_rounds = test_scores.index.get_level_values('round').max()
    rs_total_rounds = rs_datas_details_dict[rs_key].index.get_level_values('round').max()
    if qs_total_rounds < rs_total_rounds:
        print(qs_data)
        continue

    target_acc = data_target_acc[qs_data[1]]
    # Broadcast each expno's baseline target onto all of that expno's rows;
    # expnos with no baseline target map to NaN and never count as reached.
    target_acc_aligned = test_scores.index.get_level_values(
        'expno').map(target_acc['res_tst_score'])
    aligned_target_acc_df = pd.DataFrame(
        {'target_res_tst_score': target_acc_aligned}, index=test_scores.index)
    comparison_result = aligned_target_acc_df['target_res_tst_score'] <= test_scores['res_tst_score']
    reached = comparison_result.groupby('expno')
    # idxmax yields the (expno, round) label of the first True; keep the round.
    first_round = reached.idxmax().map(lambda label: label[1])
    # idxmax of an all-False group returns the group's FIRST label, so a strategy
    # that never reaches the target would be credited with the earliest round.
    # Record NaN instead.
    # NOTE(review): DataFrame.mean() skips NaN, so such runs drop out of the DUR
    # average; substitute a penalty value (e.g. the budget) if that is preferred.
    first_round = first_round.where(reached.any()).to_frame('osp')
    data_qs_osts[qs_data[1]][qs_data[0]] = first_round

# Data utilization rate (DUR), per dataset and strategy, aligned by expno:
#   DUR = strategy osp / uniform-baseline osp
data_utilization_rate = {}
for data in data_list:
    rs_osp = data_rs_osts[data]['osp']
    data_utilization_rate[data] = {
        qs_name: qs_frame['osp'] / rs_osp
        for qs_name, qs_frame in data_qs_osts[data].items()
    }

# Per dataset, keep only the expnos whose uniform-baseline osp is >= 30 (20 + 10);
# the others are excluded from the DUR average.
valid_idx = {}
for data in data_list:
    rs_osp = data_rs_osts[data]['osp']
    keep = rs_osp >= 20 + 10
    valid_idx[data] = rs_osp[keep].index.get_level_values('expno').unique()

# Average DUR over each dataset's valid expnos. Then lay the result out as a
# dataset x strategy table: display names, configured row/column order, and the
# uniform baseline (ordered_qs[0]) left out.
mean_dur = pd.DataFrame({
    data: pd.DataFrame(data_utilization_rate[data]).loc[valid_idx[data], :].mean()
    for data in data_list
})
mean_dur.index = mean_dur.index.map(qs_dict)
mean_dur = mean_dur.loc[ordered_qs[1:], :].T.loc[data_list, :]
mean_dur.index = mean_dur.index.map(data_dict)

# get rank 1st, 2nd, 3rd of mean_dur for each row
# ranked_df = mean_dur.rank(axis=1, method='min')
# top_3_methods = ranked_df.apply(lambda x: x.nsmallest(3).index.tolist(), axis=1)
# top_3_methods_df = pd.DataFrame(top_3_methods.tolist(), index=mean_dur.index, columns=['¹', '²', '³'])

def format_as_percentage(df):
    return df.applymap(lambda x: f"{x * 100:.2f}%")

# Render every mean DUR value as a percentage string for the exported tables.
mean_dur_percentage = format_as_percentage(df=mean_dur)

def annotate(df, top3):
    """Return a copy of ``df`` with rank superscripts appended to selected cells.

    For each row label in ``top3``, the entries of ``top3.loc[row]`` are column
    labels of ``df`` in rank order. The first gets '¹', the second '²', and every
    later one '³'. The mark is appended to the cell's existing text.
    """
    superscripts = {1: "¹", 2: "²"}
    marked = df.copy()
    for row_label in top3.index:
        ranked_methods = top3.loc[row_label]
        for position, method in enumerate(ranked_methods, start=1):
            mark = superscripts.get(position, "³")
            marked.at[row_label, method] = f"{marked.at[row_label, method]}{mark}"
    return marked

def annotate_greater_than_one(df, original_df):
    r"""Return a copy of ``df`` with LaTeX italics on cells whose value is >= 1.

    ``df`` holds the display strings and ``original_df`` the numbers they came
    from; both are read with the same (row, column) labels. A cell is rewritten
    as ``\textit{<text>}`` when ``original_df`` holds a value >= 1.00 there.
    Every other cell is left as is, and ``df`` itself is not modified.
    """
    italicised = df.copy()
    positions = [(r, c) for r in df.index for c in df.columns]
    for r, c in positions:
        if original_df.at[r, c] >= 1.00:
            italicised.at[r, c] = f"\\textit{{{df.at[r, c]}}}"
    return italicised

# Export the plain percentage table as CSV.
mean_dur_percentage.to_csv(path_or_buf='github/analysis_dur.csv')
# annotated_mean_dur = annotate(mean_dur_percentage, top_3_methods_df)
# Italicise cells whose mean DUR is >= 1 (the strategy needed at least as much
# as the uniform baseline), then export as LaTeX with escaping disabled so the
# \textit markup is kept.
annotated_mean_dur = annotate_greater_than_one(df=mean_dur_percentage, original_df=mean_dur)
annotated_mean_dur.to_latex(buf='tables/analysis_dur.tex', escape=False)