
import os
import math
import json
import warnings
import argparse
import pandas as pd
# from scipy import stats
# import scikit_posthocs as sp
import matplotlib.pyplot as plt
# import seaborn as sns
import numpy as np
from scipy.stats import rankdata

# from cliffs_delta import cliffs_delta

# Command-line interface. Each entry pairs an option flag with the keyword
# arguments forwarded to ``parser.add_argument``.
parser = argparse.ArgumentParser(description='Script description')
_cli_options = (
    ('--task', dict(type=str, default='acc_loss',
                    help='Description of task')),
    ('--dataset', dict(type=str, default='ImageNet',
                       help='Path of benchmark data')),
    ('--ult_objs', dict(type=str, nargs='+',
                        default=['test_accuracy', 'test_losses'],
                        help='List of strings (default: ["test_accuracy", "test_losses"])')),
    ('--max_iters', dict(type=int, nargs='+',
                         default=[20, 50, 81, 120, 150],
                         help='List of integers (default: [20, 50, 81, 120, 150])')),
    ('--eta', dict(type=int, default=3,
                   help='Fraction of saving in hyperband')),
)
for _flag, _spec in _cli_options:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

# Unpack the parsed command-line options into module-level settings.
ult_objs = args.ult_objs
dataset = args.dataset
# Root directory holding the recorded runs of the selected dataset.
dataset_path = os.path.join('../../Records/Hyperband/', dataset)
max_iters = args.max_iters
# NOTE: for every dataset other than the three listed below, the value given
# via --max_iters is discarded and this fixed list is used instead.
if dataset not in ('ImageNet', 'Cifar10', 'Cifar100'):
    max_iters = [3, 6, 10, 15, 30]

file_name = args.task
eta = args.eta
# Number of record{r}.csv files that are read and averaged per setting.
rounds = 1000

import matplotlib.patches as mpatches



# Legend text for each criterion. The keys double as the criterion names that
# the rest of the script iterates over, so their order matters.
label_dic = dict(
    train_accuracy='Training accuracy',
    train_losses='Training loss',
    valid_accuracy='Validation accuracy',
    valid_losses='Validation loss',
)

# Panel titles for the datasets that have a dedicated display name.
title_dic = dict(
    Cifar10='CIFAR-10',
    Cifar100='CIFAR-100',
    ImageNet='ImageNet-16-120',
)

def _summarise_round(record):
    """Compute the per-stage statistics of one recorded run.

    The last row of ``record`` is treated as the result the run finally
    returned; all earlier rows are individual evaluations.

    Args:
        record: DataFrame read from a ``record{r}.csv`` file. It must provide
            the ``n_iteration``, ``s`` and ``test_accuracy`` columns
            (``s`` is presumably the Hyperband bracket index -- confirm
            against the code that writes the records).

    Returns:
        A ``(mean_ranks, min_regrets, mean_accs)`` tuple of 1-D float arrays
        with one entry per distinct ``n_iteration`` value (in order of first
        appearance). Entry ``k`` (all but the last) covers the rows whose
        ``n_iteration`` is at least the ``(k + 1)``-th distinct value; the
        last entry describes the final row.

    Raises:
        KeyError: if an accuracy that has to be ranked does not occur among
            the sampled configurations.
    """
    record = record.copy()
    # Round the iteration number down to an integer.
    record['n_iteration'] = record['n_iteration'].apply(math.floor)
    # Convert the whole column *before* splitting off the final row, so the
    # final accuracy has exactly the same precision as the values it is
    # compared with and subtracted from below.
    record['test_accuracy'] = pd.to_numeric(record['test_accuracy'], downcast='float')

    # The last row is the finally obtained result.
    final_acc = record['test_accuracy'].iloc[-1]
    trials = record.iloc[:-1]

    # Overall best final accuracy among all evaluations.
    best_acc = trials['test_accuracy'].max()

    # Collect all configuration samples: the rows at the very first budget,
    # plus, for every further value of ``s``, the rows at that group's first
    # budget.
    first_budget = trials['n_iteration'].iloc[0]
    parts = [trials[trials['n_iteration'] == first_budget]]
    for s in trials['s'].unique()[1:]:
        bracket = trials[trials['s'] == s]
        parts.append(bracket[bracket['n_iteration'] == bracket['n_iteration'].iloc[0]])
    samples = pd.concat(parts)
    sample_ranks = samples['test_accuracy'].rank(ascending=False, method='first')

    # Accuracy -> rank of the first sample with that accuracy. Built once so
    # every lookup below is O(1) instead of a scan over all samples.
    rank_of = {}
    for acc, rank in zip(samples['test_accuracy'].tolist(), sample_ranks.tolist()):
        rank_of.setdefault(acc, rank)

    budgets = trials['n_iteration'].unique()
    n_stages = budgets.shape[0]
    mean_ranks = np.zeros(n_stages)
    min_regrets = np.zeros(n_stages)
    mean_accs = np.zeros(n_stages)
    for stage, budget in enumerate(budgets[1:]):
        survivors = trials.loc[trials['n_iteration'] >= budget, 'test_accuracy']
        rank_sum = sum(rank_of[acc] for acc in survivors.tolist())
        mean_ranks[stage] = rank_sum / survivors.shape[0]
        min_regrets[stage] = best_acc - survivors.max()
        mean_accs[stage] = survivors.mean()

    # The last slot always describes the finally returned result.
    mean_ranks[-1] = rank_of[float(final_acc)]
    min_regrets[-1] = best_acc - final_acc
    mean_accs[-1] = final_acc
    return mean_ranks, min_regrets, mean_accs


# mean_rank_dict[max_iter][criteria][metric] -> list with one value per stage,
# averaged over all ``rounds`` recorded runs.
mean_rank_dict = dict()
for max_iter in max_iters:
    mean_rank_dict[max_iter] = dict()
    for criteria in label_dic.keys():
        rst_dir = os.path.join(dataset_path, f'Max_iter_{max_iter}_eta_{eta}',
                               "config_rsts", criteria)
        # Running element-wise sum of the three metric arrays, shape (3, n_stages).
        totals = None
        for r in range(rounds):
            record = pd.read_csv(os.path.join(rst_dir, f'record{r}.csv'))
            round_stats = np.array(_summarise_round(record))
            totals = round_stats if totals is None else totals + round_stats
        mean_ranks_avg, min_regrets_avg, mean_accs_avg = totals / rounds
        mean_rank_dict[max_iter][criteria] = {
            'Mean-Rank': mean_ranks_avg.tolist(),
            'Min-Regret (acc)': min_regrets_avg.tolist(),
            'Mean-Acc': mean_accs_avg.tolist(),
        }

# Persist the aggregated statistics so they can be re-plotted without
# recomputing them.
with open(f'avg_rank_{dataset}.json', 'w') as fh:
    json.dump(mean_rank_dict, fh)

# One panel per max-iteration setting. ``squeeze=False`` keeps ``axs``
# indexable even when only a single setting was requested.
fig, axs = plt.subplots(1, len(max_iters), figsize=(10, 2), squeeze=False)
axs = axs[0]

# Reload from disk: JSON stores the integer max-iteration keys as strings,
# which is why the lookup below goes through ``str(max_iter)``.
with open(f'avg_rank_{dataset}.json', 'r') as fh:
    mean_rank_dict = json.load(fh)

for ax, max_iter in zip(axs, max_iters):
    for criteria, label in label_dic.items():
        ax.plot(mean_rank_dict[str(max_iter)][criteria]['Min-Regret (acc)'], label=label)

# Shared legend, anchored relative to the last panel (offsets were chosen for
# the default five-panel layout).
axs[-1].legend(bbox_to_anchor=(-4, 1), loc=3, ncol=4, framealpha=0)
plt.subplots_adjust(left=0.05, right=0.99, top=0.84, bottom=0.125, wspace=0.22, hspace=0.45)

plt.savefig(f'motiv_{dataset}.png')
plt.savefig(f'motiv_{dataset}.pdf')


# Cross-dataset overview: one panel per dataset, showing the min-regret curve
# of every criterion at that dataset's largest max-iteration setting.
fig, axs = plt.subplots(1, 5, figsize=(10, 1.7))

budget_ticklabels = [r'$10^{-3}$', r'$10^{-2}$', r'$10^{-1}$', r'$10^0$']

# Separate loop names are used so that the ``dataset`` selected on the command
# line and the ``mean_rank_dict`` computed above are not overwritten.
for d, dataset_name in enumerate(['ImageNet', 'Cifar10', 'Cifar100', 'Fashion-MNIST', 'adult']):
    with open(f'avg_rank_{dataset_name}.json', 'r') as fh:
        dataset_stats = json.load(fh)

    has_display_title = dataset_name in title_dic
    # Datasets with a display title are read at max_iter 150, the rest at 30.
    iter_key = "150" if has_display_title else "30"
    for criteria, label in label_dic.items():
        axs[d].plot(dataset_stats[iter_key][criteria]['Min-Regret (acc)'], label=label)

    if has_display_title:
        axs[d].set_title(title_dic[dataset_name], fontsize=11)
        axs[d].set_xticks([0, 2, 4, 6])
    else:
        axs[d].set_title(dataset_name, fontsize=11)
        axs[d].set_yticks([0.2, 0.5])
        axs[d].set_xticks([0, 1, 2, 3])
    axs[d].set_xticklabels(budget_ticklabels)
    axs[d].set_xlabel("Fraction of budget")
    # The first positional argument of ``grid`` is ``visible``, not the axis
    # name, so the axis is selected explicitly through ``axis=``.
    axs[d].grid(True, axis="both", lw=0.6)

# Hand-picked y ticks for the first three panels.
axs[0].set_yticks([0, 0.5, 1, 1.5])
axs[1].set_yticks([0, 0.5, 1])
axs[2].set_yticks([0, 1, 2])
axs[0].set_ylabel("Min regret")
axs[-1].legend(bbox_to_anchor=(-4, 1.16), loc=3, ncol=4, framealpha=0)
plt.subplots_adjust(left=0.055, right=0.99, top=0.75, bottom=0.285, wspace=0.22, hspace=0.45)

plt.savefig('motiv1.png')
plt.savefig('motiv1.pdf')
