import os
import csv
import json
import pandas as pd
import numpy as np

# Display name of each query strategy, keyed by the id that appears in the
# result file names; entries are grouped by the library providing them.
qs_dict = {
    'uniform': 'Uniform',
    # libact
    'qbc': 'QBC',
    'hintsvm': 'HintSVM',
    'quire': 'QUIRE',
    'albl': 'ALBL',
    'dwus': 'DWUS',
    'kcenter': 'Core-Set',
    # google
    'margin': 'US',
    'graph': 'Graph',
    'hier': 'Hier',
    'mcm': 'MCM',
    # alipy
    'lal': 'LAL',
    # scikit-activeml
    'skactiveml_bald': 'BALD',
}
# Ordering of strategy display names (not used in this section; presumably
# the order for reports/plots).
ordered_qs = [
    'Uniform', 'US', 'QBC', 'BALD', 'Hier', 'Graph',
    'Core-Set', 'HintSVM', 'QUIRE', 'DWUS', 'MCM', 'ALBL', 'LAL',
]
# Datasets to consolidate, split by size: the split decides how many
# experiments per dataset are kept further down (100 small / 10 large).
small_data_list = [
    "appendicitis", "sonar", "parkinsons", "ex8b", "heart",
    "haberman", "ionosphere", "clean1", "breast", "wdbc",
    "australian", "diabetes", "mammographic", "ex8a", "tic",
    "german", "splice", "gcloudb", "gcloudub", "checkerboard",
]
large_data_list = [
    "spambase", "banana", "phoneme", "ringnorm", "twonorm", "phishing",
]
data_list = [*small_data_list, *large_data_list]
# Dataset id -> capitalized display name.
data_dict = dict(zip(data_list, map(str.capitalize, data_list)))

# Directory of the AUBC result files and the file-name suffixes to read.
file_prefix = 'random-forest/aubc/'
file_suffix_skal = '-RandomForest-RandomForest-RS_noFix_scale-aubc.csv'
file_suffix = '-zhan' + file_suffix_skal

# Strategies whose result files lack the '-zhan' tag (read with file_suffix_skal).
qs_name_wo_zhan = ['skactiveml_bald'] + [
    f'eps_greedyeps={eps}' for eps in (0.1, 0.2, 0.3)]
# (qs, data) combinations to leave out: unfinished runs and combinations
# marked infeasible.
undone_qs_data = []
infeasible_qs_data = [
    ('quire', name) for name in ('spambase', 'twonorm', 'phishing')]
skip_qs_data = [*undone_qs_data, *infeasible_qs_data]

# For every dataset, collect the experiment numbers (res_expno, the index
# column of the result CSVs) that all query strategies are compared on.
# They come from the 'uniform' baseline's result file: the smallest 100
# distinct expnos for small datasets and the smallest 10 for large ones.
valid_res_expno_dict = {}
for data in data_dict:
    uniform_data_aubc_curr = pd.read_csv(
        f'{file_prefix}{data}-uniform{file_suffix}', index_col=0)
    n_valid = 100 if data in small_data_list else 10
    # np.unique sorts ascending and drops repeated expnos (an expno may be
    # logged more than once), so the slice takes the n smallest distinct
    # expnos instead of whatever rows happen to come first in the file.
    valid_res_expno_dict[data] = np.unique(
        uniform_data_aubc_curr.index.values)[:n_valid]
json_data = []

# Load the per-expno AUBC ('res_tst_score') of every (query strategy,
# dataset) pair, restricted to the expnos chosen in valid_res_expno_dict.
# Each loaded frame has columns [<csv index>, 'res_tst_score', 'data', 'qs'].
# Missing/unreadable files and incomplete runs are reported and skipped.
qs_data_aubc_list = []
for qs in qs_dict:
    # Strategies in qs_name_wo_zhan use the suffix without the '-zhan' tag.
    suffix = file_suffix_skal if qs in qs_name_wo_zhan else file_suffix
    for data in data_dict:
        if (qs, data) in skip_qs_data:
            continue
        try:
            aubcs_curr = pd.read_csv(
                f'{file_prefix}{data}-{qs}{suffix}', index_col=0)
        except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError):
            print(f'No result {qs}, {data}.')
            continue

        valid_expnos = valid_res_expno_dict[data]
        # Check completeness explicitly instead of relying on .loc raising
        # KeyError (older pandas silently returns NaN for missing labels).
        missing = np.setdiff1d(valid_expnos, aubcs_curr.index)
        if missing.size > 0 or 'res_tst_score' not in aubcs_curr.columns:
            print(f'Exps of ({qs}, {data}) is not complete.')
            print('Index:', missing)
            continue

        aubcs_curr = aubcs_curr.loc[valid_expnos, 'res_tst_score']
        # If an expno was recorded more than once, keep its last row.
        aubcs_curr = aubcs_curr[~aubcs_curr.index.duplicated(keep='last')]
        aubcs_curr = aubcs_curr.sort_index().reset_index()
        aubcs_curr['data'] = data
        aubcs_curr['qs'] = qs
        qs_data_aubc_list.append(aubcs_curr)

from itertools import product

# Stack the per-(qs, data) frames into one long table holding one AUBC value
# per (dataset, seed, strategy).
qs_data_aubc = pd.concat(qs_data_aubc_list)
# Each frame came from reset_index() on a score Series (index column first,
# score second) with 'data' and 'qs' appended, so rename positionally.
qs_data_aubc.columns = ['seed', 'aubc', 'data', 'qs']
qs_data_aubc = qs_data_aubc.loc[:, ['data', 'seed', 'qs', 'aubc']]
# Sorted MultiIndex of the distinct (data, seed) combinations.
data_seed_idx = qs_data_aubc.groupby(['data', 'seed']).size().index
# Every ordered pair of two different query strategies.
qs_pairs = [pair for pair in product(qs_dict, repeat=2) if pair[0] != pair[1]]
qs_data_battle = pd.DataFrame(
    np.nan,
    index=data_seed_idx,
    columns=['qs_a', 'qs_b', 'aubc_a', 'aubc_b'],
)
breakpoint()
