import os
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from scipy import stats

# Query-strategy key (as it appears in result file names) -> display name.
# Insertion order matters: it fixes the iteration order of every loop below.
qs_dict = dict(
    uniform='Uniform',
    # libact
    qbc='QBC', hintsvm='HintSVM', quire='QUIRE', albl='ALBL', dwus='DWUS',
    kcenter='Core-Set',
    # google
    margin='US', graph='Graph', hier='Hier', mcm='MCM',
    # alipy
    lal='LAL', bmdr='BMDR',
    # scikit-activeml
    skactiveml_bald='BALD',
)
# Column order (display names) used for every exported table.
ordered_qs = [
    'Uniform', 'US', 'QBC', 'BALD', 'Hier', 'Graph', 'Core-Set',
    'HintSVM', 'QUIRE', 'DWUS', 'MCM', 'BMDR', 'ALBL', 'LAL',
]
small_data_list = [
    "appendicitis", "sonar", "parkinsons", "ex8b", "heart", "haberman",
    "ionosphere", "clean1", "breast", "wdbc", "australian", "diabetes",
    "mammographic", "ex8a", "tic", "german", "splice", "gcloudb", "gcloudub",
    "checkerboard",
]
large_data_list = [
    "spambase", "banana", "phoneme", "ringnorm", "twonorm", "phishing",
]
real_data_list = ['covertype', 'bioresponse', 'pol']
data_list = [*small_data_list, *large_data_list, *real_data_list]
# Dataset key -> display name (first letter upper-cased, rest lower-cased).
data_dict = dict(zip(data_list, map(str.capitalize, data_list)))

# Result files are named '<file_prefix><data>-<qs><file_suffix>'.
file_prefix = 'aubc/'
file_suffix = '-XGBoost-XGBoost-RS_noFix_scale-aubc.csv'

# Number of leading experiment ids kept per dataset: K_S runs for datasets in
# small_data_list, K_L runs for every other dataset.
N_RUNS_SMALL = 100
N_RUNS_LARGE = 10

# Collect the valid experiment ids (res_expno) of each dataset from its
# 'uniform' result file; they are used as the row index when selecting the same
# runs from every query strategy's results.
# NOTE: ids are taken in file order (the first K rows of the CSV), not sorted.
valid_res_expno_dict = {}
for data in data_dict:
    uniform_data_aubc_curr = pd.read_csv(
        f'{file_prefix}{data}-uniform{file_suffix}', index_col=0)
    n_runs = N_RUNS_SMALL if data in small_data_list else N_RUNS_LARGE
    valid_res_expno_dict[data] = uniform_data_aubc_curr.index.values[:n_runs]

# AUBC series per (qs, data) pair, restricted to the valid experiment ids and
# sorted by id; None when the result file is missing or incomplete.
qs_data_aubc_dict = {}
for qs in qs_dict:
    for data in data_dict:
        qs_data_aubc_dict[(qs, data)] = None
        try:
            aubcs_raw = pd.read_csv(
                f'{file_prefix}{data}-{qs}{file_suffix}', index_col=0)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            print(f'No result {qs}, {data}.')
            continue
        try:
            aubcs_curr = aubcs_raw.loc[valid_res_expno_dict[data],
                                       'res_tst_score']
        except KeyError:
            # Some valid experiment ids (or the score column) are missing.
            print(f'Exps of ({qs}, {data}) is not complete.')
            print('Index:', np.setdiff1d(
                valid_res_expno_dict[data], aubcs_raw.index))
            continue
        # Keep the last row when an experiment id was recorded more than once.
        aubcs_curr = aubcs_curr[~aubcs_curr.index.duplicated(keep='last')]
        qs_data_aubc_dict[(qs, data)] = aubcs_curr.sort_index()

r""" Mean and standard deviation of the AUBCs for each (dataset, query strategy) pair, over 100 runs for small datasets and 10 runs for large datasets.
## Table. Summary Table

- small datasets $n < 2000$ : only use the first 100 indices, $K_{S} = 100$.
- large datasets $n \geq 2000$ : only use the first 10 indices, $K_{L} = 10$.

For query strategy $q$ and dataset $s$, calculate the average (mean) and the sample standard deviation of the AUBCs by
$$
\overline{\mathrm{AUBC}}_{q, s} = \frac{\sum_{k=1}^{K_{\bullet}} \mathrm{AUBC}_{q, s, k}}{K_{\bullet}},
\qquad
\sigma_{q, s} = \sqrt{\frac{\sum_{k=1}^{K_{\bullet}} \left(\mathrm{AUBC}_{q, s, k} - \overline{\mathrm{AUBC}}_{q, s}\right)^{2}}{K_{\bullet} - 1}},
$$
where $K_{\bullet} \in \{K_{S}, K_{L}\}$.
"""

# Rows: dataset keys; columns: query-strategy keys. Missing pairs stay NaN.
mean_aubc_qs = pd.DataFrame(index=data_dict.keys(),
                            columns=qs_dict.keys(), data=np.nan)
std_aubc_qs = pd.DataFrame(index=data_dict.keys(),
                           columns=qs_dict.keys(), data=np.nan)
for (qs, data), qs_data_aubc_curr in qs_data_aubc_dict.items():
    # Pairs without complete results were stored as None above.
    if qs_data_aubc_curr is None:
        continue
    mean_aubc_qs.loc[data, qs] = qs_data_aubc_curr.mean()
    # pandas' default ddof=1, i.e. the sample standard deviation.
    std_aubc_qs.loc[data, qs] = qs_data_aubc_curr.std()

# Top-3 strategies by mean AUBC for each dataset (NaN ignored; 'uniform' is
# ranked like any other strategy).
aubc_nlargest_qs = mean_aubc_qs.apply(
    lambda x: x.nlargest(3).index, axis=1)
# Split into one column per rank; rows with fewer than 3 non-NaN strategies
# are padded with missing values.
aubc_nlargest_qs = pd.DataFrame(aubc_nlargest_qs.values.tolist(
), index=aubc_nlargest_qs.index, columns=['¹', '²', '³'])
# Format every cell as "mean±std" in percent with 2 decimals.
mean_std_aubc = mean_aubc_qs.map(
    '{:.2%}'.format) + '±' + std_aubc_qs.map('{:.2%}'.format)
# Append the rank superscript (¹, ², ³) to each dataset's top-3 cells.
for data in mean_std_aubc.index:
    # Drop padding so a dataset with < 3 results does not index with None.
    top_qs = aubc_nlargest_qs.loc[data].dropna()
    mean_std_aubc.loc[data, top_qs.values] = (
        mean_std_aubc.loc[data, top_qs.values] + top_qs.index)

# Replace keys with display names, then order columns by ordered_qs.
mean_std_aubc.columns = mean_std_aubc.columns.map(qs_dict)
mean_std_aubc.index = mean_std_aubc.index.map(data_dict)
mean_std_aubc = mean_std_aubc[ordered_qs]
mean_std_aubc.to_csv('github/mean_std_aubc.csv')
# Export the uniform baseline and the best and worst strategy for each dataset.
# Columns: Uniform, BEST_QS, BEST, WORST_QS, WORST (mean AUBCs in percent).
# 'uniform' takes part in the best/worst comparison like any other strategy.
best_worst_qs = pd.DataFrame(index=data_dict.keys(), columns=[
                             'Uniform', 'BEST_QS', 'BEST', 'WORST_QS', 'WORST'], data=np.nan)
best_worst_qs['Uniform'] = mean_aubc_qs['uniform']
best_worst_qs['BEST_QS'] = mean_aubc_qs.idxmax(axis=1).map(qs_dict)
best_worst_qs['BEST'] = mean_aubc_qs.max(axis=1)
best_worst_qs['WORST_QS'] = mean_aubc_qs.idxmin(axis=1).map(qs_dict)
best_worst_qs['WORST'] = mean_aubc_qs.min(axis=1)
# Display names + percent formatting, then export as CSV and LaTeX.
best_worst_qs.index = best_worst_qs.index.map(data_dict)
best_worst_qs[['Uniform', 'BEST', 'WORST']] = best_worst_qs[[
    'Uniform', 'BEST', 'WORST']].map('{:.2%}'.format)
best_worst_qs.to_csv('github/bestworst.csv')
best_worst_qs.to_latex('tables/bestworst.tex', escape=True)

r""" Mean and standard deviation of the improvement of each query strategy over uniform sampling, for each dataset.
## Table. Improvement Table

For query strategy $q$, dataset $s$ and run $k$:
$$
\tau_{q, s, k} = \mathrm{AUBC}_{q, s, k} - \mathrm{AUBC}_{\text{uniform}, s, k}
$$
"""
# NOTE(review): not used by the code shown here; candidate for removal.
data_uniform_key = [
    data_qs for data_qs in qs_data_aubc_dict.keys() if 'uniform' in data_qs]
# Per-run improvement over uniform; None when the strategy has no results.
qs_data_tau_dict = {}
for data in data_dict:
    uniform_data_aubc = qs_data_aubc_dict[('uniform', data)]
    for qs in qs_dict:
        if qs == 'uniform':
            continue
        aubcs_curr = qs_data_aubc_dict[(qs, data)]
        if aubcs_curr is None:
            qs_data_tau_dict[(qs, data)] = None
            print(f'No tau of ({qs}, {data})')
        else:
            # Series subtraction aligns on experiment id, so run k of the
            # strategy is paired with run k of uniform.
            qs_data_tau_dict[(qs, data)] = aubcs_curr - uniform_data_aubc

# Rows: dataset keys; columns: strategy keys ('uniform' stays all NaN).
mean_tau_qs = pd.DataFrame(index=data_dict.keys(),
                           columns=qs_dict.keys(), data=np.nan)
std_tau_qs = pd.DataFrame(index=data_dict.keys(),
                          columns=qs_dict.keys(), data=np.nan)
for (qs, data), qs_data_tau_curr in qs_data_tau_dict.items():
    if qs_data_tau_curr is None:
        continue
    mean_tau_qs.loc[data, qs] = qs_data_tau_curr.mean()
    std_tau_qs.loc[data, qs] = qs_data_tau_curr.std()

# One-sided paired t-test per (dataset, strategy). The alternative hypothesis is
# that uniform's mean AUBC is less than the strategy's (strategy is better).
# Cells for 'uniform' and for strategies without results stay NaN.
ttest_pval_qs = pd.DataFrame(
    index=data_dict.keys(), columns=qs_dict.keys(), data=np.nan)
for data in data_dict:
    uniform_data_aubc = qs_data_aubc_dict[('uniform', data)]
    for qs in qs_dict:
        if qs == 'uniform':
            continue
        aubcs_curr = qs_data_aubc_dict[(qs, data)]
        if aubcs_curr is None:
            print(f'No t-test of ({qs}, {data})')
            continue
        # Pair runs by experiment id rather than by position.
        uniform_paired, qs_paired = uniform_data_aubc.align(
            aubcs_curr, join='inner')
        _, pval = stats.ttest_rel(
            uniform_paired, qs_paired, alternative='less')
        ttest_pval_qs.loc[data, qs] = pval

alpha_95 = 0.05
# 1 where the strategy significantly beats uniform at the 5% level, else 0.
# NaN p-values (no result, or a degenerate test) become 0, since NaN < alpha is
# False.
test_pval_qs = (ttest_pval_qs < alpha_95).astype('int64')
# 'uniform' was never tested (all-NaN p-values -> all zeros); drop it.
test_pval_qs = test_pval_qs.drop(['uniform'], axis=1)
test_pval_qs.columns = test_pval_qs.columns.map(qs_dict)
# Count of datasets where each strategy is significantly better (qs view), and
# count of significantly better strategies on each dataset (data view).
tau_data_qs_cnt_qsview = test_pval_qs.sum(axis=0)
tau_data_qs_cnt_dataview = test_pval_qs.sum(axis=1)
# Horizontal bar charts of both counts, side by side.
fig, (ax_qs, ax_data) = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
tau_data_qs_cnt_qsview.plot(kind='barh', ax=ax_qs)
tau_data_qs_cnt_dataview.plot(kind='barh', ax=ax_data)

ax_qs.set_title('Query strategy aspect', fontsize=24)
ax_qs.set_ylabel('')
datalim = len(data_dict)
ax_qs.set_xlim(0, datalim)
ax_qs.invert_yaxis()
ax_qs.tick_params(axis='both', which='major', labelsize=20)

# Annotate each bar with its value
for index, value in enumerate(tau_data_qs_cnt_qsview):
    ax_qs.text(value + 0.5, index, str(value), va='center', fontsize=18)

ax_data.set_title('Dataset aspect', fontsize=24)
ax_data.set_ylabel('')
ax_data.invert_yaxis()
qslim = len(qs_dict)
ax_data.set_xlim(0, qslim)
ax_data.tick_params(axis='both', which='major', labelsize=20)

# Annotate each bar with its value
for index, value in enumerate(tau_data_qs_cnt_dataview):
    ax_data.text(value + 0.5, index, str(value), va='center', fontsize=18)

plt.subplots_adjust(left=None, bottom=None, right=None,
                    top=None, wspace=0.7, hspace=None)

fig.savefig('images/n_QSgtRS.pdf', bbox_inches='tight')
# Close (not merely clear) the figure so pyplot releases it.
plt.close(fig)

# Top-3 strategies by mean improvement over uniform for each dataset (the
# 'uniform' column of mean_tau_qs is all NaN, so nlargest skips it).
tau_nlargest_qs = mean_tau_qs.apply(
    lambda x: x.nlargest(3).index, axis=1)
# Split into one column per rank; rows with fewer than 3 non-NaN strategies
# are padded with missing values.
tau_nlargest_qs = pd.DataFrame(tau_nlargest_qs.values.tolist(
), index=tau_nlargest_qs.index, columns=['¹', '²', '³'])
# Format every cell as "mean±std" of tau in percent with 2 decimals.
mean_std_tau_qs_wo_epsus = mean_tau_qs.map(
    '{:.2%}'.format) + '±' + std_tau_qs.map('{:.2%}'.format)
# Append the rank superscript (¹, ², ³) to each dataset's top-3 cells.
for data in mean_std_tau_qs_wo_epsus.index:
    # Drop padding so a dataset with < 3 results does not index with None.
    top_qs = tau_nlargest_qs.loc[data].dropna()
    mean_std_tau_qs_wo_epsus.loc[data, top_qs.values] = (
        mean_std_tau_qs_wo_epsus.loc[data, top_qs.values] + top_qs.index)

# Display names, column order by ordered_qs, and drop the uniform baseline.
mean_std_tau_qs_wo_epsus.columns = mean_std_tau_qs_wo_epsus.columns.map(
    qs_dict)
mean_std_tau_qs_wo_epsus.index = mean_std_tau_qs_wo_epsus.index.map(data_dict)
mean_std_tau_qs_wo_epsus = mean_std_tau_qs_wo_epsus[ordered_qs]
mean_std_tau_qs_wo_epsus = mean_std_tau_qs_wo_epsus.drop('Uniform', axis=1)
mean_std_tau_qs_wo_epsus.to_csv('github/mean_std_tau.csv')

# Per-run ranking of the query strategies (uniform excluded) on each dataset:
# for every valid experiment id, the strategy with the highest AUBC gets rank 1;
# ties are broken by column order (method='first').
data_qss_aubc_dict = {}
data_qss_aubc_rank_dict = {}
for data in data_dict:
    columns_by_qs = {}
    for qs in qs_dict:
        key = (qs, data)
        series = qs_data_aubc_dict[key]
        if series is None:
            # Report the (qs, data) pair that has no usable results.
            print(key)
            continue
        columns_by_qs[qs] = series[~series.index.duplicated(keep='first')]
    # One column per strategy with results, rows restricted to the valid ids.
    table = pd.concat(columns_by_qs, axis=1)
    table = table.loc[valid_res_expno_dict[data]].drop('uniform', axis=1)
    data_qss_aubc_dict[data] = table
    data_qss_aubc_rank_dict[data] = table.rank(
        axis=1, ascending=False, method='first')

# Mean per-run rank of each strategy on each dataset (lower is better).
mean_rank_qs = pd.DataFrame(index=data_dict.keys(),
                            columns=qs_dict.keys(), data=np.nan)
for data in data_dict:
    # Row assignment aligns on strategy keys; unranked 'uniform' stays NaN.
    mean_rank_qs.loc[data] = data_qss_aubc_rank_dict[data].mean()

# Display names, then order columns by ordered_qs.
mean_rank_qs.columns = mean_rank_qs.columns.map(qs_dict)
mean_rank_qs.index = mean_rank_qs.index.map(data_dict)
mean_rank_qs = mean_rank_qs[ordered_qs]
# Top-3 (smallest mean rank) strategies for each dataset; NaN is skipped.
rank_nsmallest_qs = mean_rank_qs.apply(lambda x: x.nsmallest(3).index, axis=1)
# Split into one column per rank; rows with fewer than 3 non-NaN strategies
# are padded with missing values.
rank_nsmallest_qs = pd.DataFrame(rank_nsmallest_qs.values.tolist(
), index=rank_nsmallest_qs.index, columns=['¹', '²', '³'])
# Format mean ranks with 2 decimals.
mean_rank_qs = mean_rank_qs.map('{:.2f}'.format)
# Append the rank superscript (¹, ², ³) to each dataset's top-3 cells.
for data in mean_rank_qs.index:
    # Drop padding so a dataset with < 3 results does not index with None.
    top_qs = rank_nsmallest_qs.loc[data].dropna()
    mean_rank_qs.loc[data, top_qs.values] = (
        mean_rank_qs.loc[data, top_qs.values] + top_qs.index)

# Drop the (unranked) uniform baseline and export.
mean_rank_qs = mean_rank_qs.drop('Uniform', axis=1)
mean_rank_qs.to_csv('github/mean_rank.csv')
mean_rank_qs.to_latex('tables/mean_rank.tex', escape=True)