
import os
import json
import warnings
import argparse
import pandas as pd
from scipy import stats
import scikit_posthocs as sp
import matplotlib.pyplot as plt
import seaborn as sns

# from cliffs_delta import cliffs_delta

# Command-line options selecting the benchmark records and Hyperband settings.
parser = argparse.ArgumentParser(description='Script description')
parser.add_argument('--task', type=str, default='acc_loss',
                    help='Description of task')
parser.add_argument('--dataset_path', type=str, default='../Records/Hyperband/ImageNet',
                    help='Path of benchmark data')
parser.add_argument('--ult_objs', type=str, nargs='+',
                    default=['test_accuracy', 'test_losses'],
                    help='List of strings (default: ["test_accuracy", "test_losses"])')
parser.add_argument('--max_iters', type=int, nargs='+',
                    default=[20, 50, 81, 120, 150],
                    help='List of integers (default: [20, 50, 81, 120, 150])')
parser.add_argument('--eta', type=int, default=3,
                    help='Fraction of saving in hyperband')
args = parser.parse_args()

dataset_path = args.dataset_path
ult_objs = args.ult_objs
file_name = args.task
eta = args.eta
# When the dataset path mentions one of these benchmarks, the fixed budget grid
# below replaces whatever was passed via --max_iters.
if any(tag in dataset_path for tag in ('Fashion-MNIST', 'higgs', 'adult',
                                       'jasmine', 'vehicle', 'volkert')):
    max_iters = [3, 6, 10, 15, 30]
else:
    max_iters = args.max_iters

import matplotlib.patches as mpatches

# Legend entries gathered while plotting: (colour patch, legend text) pairs.
labels = []
def add_label(violin, label):
    """Append a legend entry whose patch uses the fill colour of *violin*.

    Args:
        violin: The dict returned by ``Axes.violinplot``; its first body's
            face colour is used for the patch.
        label: Text shown next to the patch in the legend.
    """
    first_body = violin["bodies"][0]
    face_rgba = first_body.get_facecolor().flatten()
    patch = mpatches.Patch(color=face_rgba)
    labels.append((patch, label))


def draw_significance(ax, pvalues, cta1_pos, interval=None):
    """Draw significance brackets between a reference violin and the others.

    For every ``(p, pos)`` pair with ``p < 0.05`` a bracket is drawn from the
    violin at ``pos`` to the reference violin at ``cta1_pos`` and labelled
    ``*`` (p < 0.05), ``**`` (p < 0.01) or ``***`` (p < 0.001). Brackets are
    stacked above the current top of the axes; brackets spanning more violin
    slots are placed higher.

    Args:
        ax: Axes that already contains the violins (only ``get_ylim``,
            ``set_ylim``, ``plot`` and ``text`` are used).
        pvalues: Iterable of ``(p_value, x_position)`` pairs.
        cta1_pos: x-position of the reference violin.
        interval: Horizontal spacing between adjacent violins. Defaults to the
            module-level ``interval`` used when the violins were positioned.

    Side effects:
        Adds lines and text to ``ax``. The y-limits are enlarged to make room
        for the brackets only when at least one bracket is drawn; if nothing
        is significant the axis limits are left unchanged.
    """
    if interval is None:
        # Fall back to the spacing used when the violins were positioned.
        interval = globals()['interval']

    bottom, top = ax.get_ylim()
    y_range = top - bottom
    # Short y-ranges get relatively taller steps between stacked brackets.
    param = 0.15 if y_range < 7 else 0.09

    max_bar_height = None  # None until a bracket is actually drawn
    for p, pos in pvalues:
        if p < 0.001:
            sig_symbol = '***'
        elif p < 0.01:
            sig_symbol = '**'
        elif p < 0.05:
            sig_symbol = '*'
        else:
            continue  # not significant (NaN p-values also land here)

        # Leave a small gap at the reference violin's end of the bracket.
        if pos < cta1_pos:
            pos1, pos2 = pos, cta1_pos - 0.05
        else:
            pos1, pos2 = cta1_pos + 0.05, pos
        # Stacking level = number of violin slots the bracket spans.
        level = round((pos2 - pos1) / interval)
        bar_height = (y_range * param * level) + top * 0.991
        bar_tips = bar_height - (y_range * 0.02)
        ax.plot(
            [pos1, pos1, pos2, pos2],
            [bar_tips, bar_height, bar_height, bar_tips], lw=1, c='k'
        )

        text_height = bar_height - 0.05 * y_range
        ax.text((pos1 + pos2) * 0.5, text_height, sig_symbol, ha='center', va='bottom', c='k')

        if max_bar_height is None or bar_height > max_bar_height:
            max_bar_height = bar_height

    # Only grow the axes when a bracket exists; using a 0 placeholder here
    # would collapse (or invert) the y-range for data far from zero.
    if max_bar_height is not None:
        ax.set_ylim([bottom, max_bar_height + param * y_range])


# Legend text for each CSV column name that can appear in the violin plots.
label_dic = dict(
    train_accuracy='Training accuracy',
    valid_accuracy='Validation accuracy',
    train_losses='Training loss',
    valid_losses='Validation loss',
)



# Figure grid: one row per benchmark directory, one column per Hyperband budget R.
# Each cell holds one violin per CSV column (after the first) read from
# <benchmark>/Max_iter_<R>_eta_<eta>/cta/obj_test_accuracy/<task>.csv, plus
# Wilcoxon signed-rank significance brackets against the ``cta1`` column.

# Horizontal spacing between violins; draw_significance also reads this global.
interval = 0.6

cta1 = 'valid_losses'  # reference column every other column is compared with
pathes = ['../Records/Hyperband/Fashion-MNIST', '../Records/Hyperband/adult', '../Records/Hyperband/higgs',
          '../Records/Hyperband/jasmine', '../Records/Hyperband/vehicle', '../Records/Hyperband/volkert']
# Budget grid for the benchmarks above; overrides --max_iters.
# ``eta`` is taken from the --eta argument parsed earlier (default 3).
max_iters = [3, 6, 10, 15, 30]

# squeeze=False keeps ``axs`` 2-D even if a list above has a single entry.
fig, axs = plt.subplots(len(pathes), len(max_iters), figsize=(10, 9), squeeze=False)

for row, bench_path in enumerate(pathes):
    # Last path component is the benchmark name, e.g. 'Fashion-MNIST'.
    bench_name = os.path.basename(os.path.normpath(bench_path))
    for col_idx, budget in enumerate(max_iters):
        ax = axs[row, col_idx]
        csv_dir = os.path.join(bench_path, f"Max_iter_{budget}_eta_{eta}", "cta", "obj_test_accuracy")
        df = pd.read_csv(os.path.join(csv_dir, f"{file_name}.csv"))

        pvalues = []
        cta1_pos = 0

        # The first CSV column is skipped (presumably a row index — TODO confirm).
        for i, col in enumerate(df.columns[1:]):
            pos = [0.5 + interval * i]
            vio = ax.violinplot(df[col], positions=pos,
                                showmeans=False,
                                showmedians=True)
            if row == 0 and col_idx == 0:
                add_label(vio, label_dic[col])
            if col != cta1:
                # Paired Wilcoxon signed-rank test across the CSV rows.
                _, p = stats.wilcoxon(df[cta1] - df[col])
                pvalues.append([p, pos[0]])
                print(f'cta1 = {cta1}, cta2 = {col}, p = {p}')
            else:
                cta1_pos = pos[0]

        ax.set_xlabel(rf'{bench_name} ($R$={budget})')
        draw_significance(ax, pvalues, cta1_pos)

        ax.yaxis.grid(True)
        ax.set_xticks([])

    axs[row, 0].set_ylabel('Test accuracy %')


# The legend is attached to the current (last) axes; the anchor offsets are
# tuned for the 6 x 5 grid and place it above the whole figure.
plt.legend(*zip(*labels), bbox_to_anchor=(-4.5,7.5), loc=3, ncol = 4, framealpha = 0)
plt.subplots_adjust(left=0.062, right=0.99, top=0.97, bottom=0.03, wspace =0.3, hspace=0.3)
plt.savefig(f'app_hpo_train_val_eta_{eta}.png')
plt.savefig(f'app_hpo_train_val_eta_{eta}.pdf')