
import os
import json
import warnings
import argparse
import pandas as pd
# from scipy import stats
# import scikit_posthocs as sp
import matplotlib.pyplot as plt
# import seaborn as sns
import numpy as np

# from cliffs_delta import cliffs_delta

# Command-line configuration for the plotting script.
parser = argparse.ArgumentParser(
    description='Plot Hyperband results for the train/validation selection criteria '
                '(saved as train_val.png and train_val.pdf)')
parser.add_argument('--task', type=str, default='acc_loss',
                    help='Task name (default: acc_loss)')
parser.add_argument('--dataset_path', type=str, default='../../Records/Hyperband/ImageNet',
                    help='Path of benchmark data (default: ../../Records/Hyperband/ImageNet)')
parser.add_argument('--ult_objs', type=str, nargs='+', default=['test_accuracy', 'test_losses'],
                    help='List of strings (default: ["test_accuracy", "test_losses"])')
parser.add_argument('--max_iters', type=int, nargs='+', default=[20, 50, 81, 120, 150],
                    help='List of integers (default: [20, 50, 81, 120, 150])')
parser.add_argument('--eta', type=int, default=3,
                    help='Hyperband reduction factor eta, as used in the record '
                         'directory names Max_iter_<N>_eta_<eta> (default: 3)')
args = parser.parse_args()

# Module-level aliases for the parsed arguments.
# NOTE(review): the per-dataset plotting loop reassigns max_iters and
# dataset_path, so --max_iters and --dataset_path currently have no effect on
# the figure.
max_iters = args.max_iters
ult_objs = args.ult_objs
dataset_path = args.dataset_path

file_name = args.task
eta = args.eta


import matplotlib.patches as mpatches



# Legend label for each selection criterion. Each key is read as a column of
# the per-run CSV files and as a criterion key of the avg_rank_*.json files.
# Key order fixes the plotting order (and therefore the line colours).
label_dic = dict(
    train_accuracy='Training accuracy',
    train_losses='Training loss',
    valid_accuracy='Validation accuracy',
    valid_losses='Validation loss',
)

# Display titles for datasets whose directory name differs from the plot title.
# Membership in this dict is also used by the bottom-row loop to choose the
# "150" budget key and its x-tick layout, so adding entries changes that logic.
title_dic = dict(
    Cifar10='CIFAR-10',
    Cifar100='CIFAR-100',
    ImageNet='ImageNet-16-120',
)

fig, axs = plt.subplots(2, 5, figsize=(10, 3))

# Top row: for each dataset, the mean of every selection-criterion column of
# the per-run CSV, plotted against the Hyperband max_iter budget.
data_dir = '../../Records/Hyperband/'
# Datasets recorded on the short budget grid; all others use the long grid.
short_budget_datasets = {'Fashion-MNIST', 'higgs', 'adult', 'jasmine', 'vehicle', 'volkert'}
for idx, dataset in enumerate(['Cifar10', 'Cifar100', 'ImageNet', 'Fashion-MNIST', 'jasmine']):
    # NOTE: these assignments override the --dataset_path / --max_iters values.
    dataset_path = os.path.join(data_dir, dataset)
    if dataset in short_budget_datasets:
        max_iters = [3, 6, 10, 15, 30]
    else:
        max_iters = [20, 35, 50, 65, 81, 95, 110, 120, 135, 150]

    mean_values = {col: np.zeros(len(max_iters)) for col in label_dic}
    for i, max_iter in enumerate(max_iters):
        run_dir = os.path.join(dataset_path, f"Max_iter_{max_iter}_eta_{eta}",
                               "cta", "obj_test_accuracy")
        # --task (file_name) names the CSV; the default 'acc_loss' reads acc_loss.csv.
        df = pd.read_csv(os.path.join(run_dir, f"{file_name}.csv"))
        for col in label_dic:
            mean_values[col][i] = df[col].mean()

    ax = axs[0, idx]
    for col in label_dic:
        # Only the training-accuracy curve gets a '+' marker.
        marker_kw = {'marker': '+'} if col == 'train_accuracy' else {}
        ax.plot(max_iters, mean_values[col], label=label_dic[col], lw=1, **marker_kw)
    # Major grid on both axes. Axes.grid's first positional parameter is the
    # on/off flag (the axis is chosen via axis=..., default 'both'), so pass True
    # instead of an axis name.
    ax.grid(True, lw=0.6)
    ax.set_xlabel("Budget")
    ax.set_title(title_dic.get(dataset, dataset), fontsize=11)

# Decorations shared by the whole top row.
top_row = axs[0]
top_row[0].set_ylabel("Test accuracy")
# Column 2 is ImageNet in the top-row dataset order; its ticks are hand-picked.
top_row[2].set_yticks([42, 45])
# A single legend for the row: lower-left corner anchored 4 axes-widths to the
# left of, and above, the last panel, with a transparent frame.
top_row[-1].legend(bbox_to_anchor=(-4, 1.15), loc=3, ncol=4, framealpha=0)


# Bottom row: min-regret curves per selection criterion, loaded from
# avg_rank_<dataset>.json in the current working directory.
# NOTE(review): this column order (ImageNet, Cifar10, Cifar100, Fashion-MNIST,
# adult) differs from the titled top row (Cifar10, Cifar100, ImageNet,
# Fashion-MNIST, jasmine), and these panels get no titles of their own. The
# hard-coded y-ticks for axs[1, 0..2] assume this order, so confirm the intended
# layout before reordering.
for d, dataset in enumerate(['ImageNet', 'Cifar10', 'Cifar100', 'Fashion-MNIST', 'adult']):
    with open(f'avg_rank_{dataset}.json', 'r', encoding='utf-8') as fh:
        mean_rank_dict = json.load(fh)
    ax = axs[1, d]
    # Datasets listed in title_dic use the "150" budget entry (the largest
    # long-grid max_iter); the others use "30" (the largest short-grid max_iter).
    uses_long_budget = dataset in title_dic
    budget_key = "150" if uses_long_budget else "30"
    for criteria in label_dic:
        # Only the training-accuracy curve gets a '+' marker.
        marker_kw = {'marker': '+'} if 'train_accuracy' in criteria else {}
        # Plotted against its index; presumably a list of regret values per
        # budget fraction — TODO confirm JSON layout.
        ax.plot(mean_rank_dict[budget_key][criteria]['Min-Regret (acc)'],
                label=label_dic[criteria], **marker_kw)
    if uses_long_budget:
        ax.set_xticks([0, 2, 4, 6])
    else:
        ax.set_yticks([0.2, 0.5])
        ax.set_xticks([0, 1, 2, 3])
    ax.set_xticklabels([r'$10^{-3}$', r'$10^{-2}$', r'$10^{-1}$', r'$10^0$'])
    ax.set_xlabel("Fraction of budget")
    # Major grid on both axes (first positional parameter is the on/off flag).
    ax.grid(True, lw=0.6)

# Hand-picked y-ticks for the first three bottom-row panels; the column indices
# follow the bottom-row dataset order.
bottom_yticks = {
    0: [0, 0.5, 1, 1.5],
    1: [0, 0.5, 1],
    2: [0, 1, 2],
}
for col_idx, ticks in bottom_yticks.items():
    axs[1, col_idx].set_yticks(ticks)
axs[1, 0].set_ylabel("Min regret")

# Final layout, then write the same figure in both raster and vector form.
plt.subplots_adjust(left=0.065, right=0.99, top=0.85, bottom=0.17,
                    wspace=0.22, hspace=0.7)
for out_path in ('train_val.png', 'train_val.pdf'):
    plt.savefig(out_path)
