
import os
import json
import warnings
import argparse
import pandas as pd
from scipy import stats
import scikit_posthocs as sp
import matplotlib.pyplot as plt
import numpy as np

# Command-line configuration.
# NOTE(review): max_iters and ult_objs are not used by the plotting code in this
# file, and dataset_path is overwritten further down with a hard-coded path.
parser = argparse.ArgumentParser(description='Script description')
parser.add_argument(
    '--dataset_path',
    type=str,
    default='../../Records/Hyperband/ImageNet',
    help='Path of benchmark data',
)
parser.add_argument(
    '--ult_objs',
    type=str,
    nargs='+',
    default=['test_accuracy', 'test_losses'],
    help='List of strings (default: ["test_accuracy", "test_losses"])',
)
parser.add_argument(
    '--max_iters',
    type=int,
    nargs='+',
    default=[20, 50, 81, 120, 150],
    help='List of integers (default: [20, 50, 81, 120, 150])',
)
parser.add_argument(
    '--eta',
    type=int,
    default=4,
    help='Fraction of saving in hyperband',
)
args = parser.parse_args()

# Expose the parsed options as plain module-level names.
dataset_path = args.dataset_path
ult_objs = args.ult_objs
max_iters = args.max_iters
eta = args.eta
# Result-file base names (presumably read as "<name>.csv" — confirm).
file_name1, file_name2 = "acc_loss", "dyn_win_size"



import matplotlib.cm as cm  # Import the color map module

# Horizontal distance between neighbouring violin slots. The plotting code puts
# violins at 0.4 + interval * k, and draw_significance uses the same value to
# work out how many slots a bracket spans.
interval = 0.6


def draw_significance(ax, pvalues, spacing=None):
    """Draw significance brackets with star symbols above a set of positions.

    Symbols: p < 0.001 -> '***', p < 0.01 -> '**', p < 0.05 -> '*'.
    Non-significant entries (including NaN p-values) are skipped.

    Args:
        ax: Axes-like object providing ``get_ylim``, ``plot``, ``text`` and
            ``set_ylim`` (e.g. a matplotlib Axes).
        pvalues: Iterable of ``(p, pos1, pos2)`` entries: a p-value and the
            two x positions joined by the bracket.
        spacing: x distance between adjacent slots, used to compute how high
            a bracket is stacked. Defaults to the module-level ``interval``.

    Side effects:
        Draws on *ax* and raises its upper y-limit to the larger of
        ``top + 0.3`` (``top`` being the limit on entry) and the highest
        bracket drawn. The lower y-limit is left unchanged.
    """
    if spacing is None:
        spacing = interval

    bottom, top = ax.get_ylim()
    y_range = top - bottom
    # Start below every real value. An initial value of 0 would force the
    # upper limit up to 0 when the whole axis lies below zero and nothing
    # is significant.
    highest_bar = float('-inf')

    for p, pos1, pos2 in pvalues:
        if p < 0.001:
            symbol = '***'
        elif p < 0.01:
            symbol = '**'
        elif p < 0.05:
            symbol = '*'
        else:
            continue  # not significant; NaN also lands here

        # Brackets spanning more slots get a higher level, i.e. are drawn higher.
        level = round((pos2 - pos1) / spacing)
        bar_height = (y_range * 0.1 * level) + top * 0.993
        tip_height = bar_height - (y_range * 0.02)
        ax.plot(
            [pos1, pos1, pos2, pos2],
            [tip_height, bar_height, bar_height, tip_height], lw=1, c='k'
        )

        # The text is anchored by its bottom slightly below the bar line.
        text_height = bar_height - 0.04 * y_range
        ax.text((pos1 + pos2) * 0.5, text_height, symbol, ha='center', va='bottom', c='k')

        highest_bar = max(highest_bar, bar_height)

    ax.set_ylim([bottom, max(top + 0.3, highest_bar)])


# Legend proxy handles for the violin plots (used by add_label below).
import matplotlib.patches as mpatches

# Legend entries collected as (handle, text) pairs and passed to plt.legend.
labels = []

# Display names for the legend, keyed by CSV column name.
# NOTE(review): the original script used `label_dic` (and `add_label` /
# `labels`) without defining them, which raised NameError. Fill in display
# names here; any column missing from the mapping is labelled with its raw
# column name.
label_dic = {}


def add_label(violin, label):
    """Record a legend entry coloured like the first body of *violin*."""
    color = violin['bodies'][0].get_facecolor().flatten()
    labels.append((mpatches.Patch(color=color), label))


def drop_outliers(series, threshold=3):
    """Return *series* without the points whose |z-score| exceeds *threshold*.

    Used for drawing only; the paired test below runs on the full samples.
    """
    z_scores = np.abs(stats.zscore(series))
    # A constant series gives NaN z-scores. NaN > threshold is False, so
    # nothing is dropped in that case.
    return series.drop(series[z_scores > threshold].index)


def style_violin(violin, body_color, line_color, alpha=None):
    """Colour a violin's bodies and its max/min/median/bar lines."""
    for body in violin['bodies']:
        body.set_facecolor(body_color)
        body.set_edgecolor(body_color)
        if alpha is not None:
            body.set_alpha(alpha)
    for part in ('cmaxes', 'cmins', 'cmedians', 'cbars'):
        violin[part].set_color(line_color)


num_lines = 20
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9. The pyplot function
# takes the same (name, lut) arguments and is available in old and new versions.
color_palette = plt.get_cmap('tab20', num_lines)
line_colors = color_palette(range(num_lines))

file_name = 'acc_loss'
fig, axs = plt.subplots(1, 2, figsize=(6.3, 2.5))
# NOTE(review): this overrides the --dataset_path option. The x-axis labels
# below are hard-coded for CIFAR-10 to match.
dataset_path = '../../Records/Hyperband/Cifar10'
for a, max_iter in enumerate([50, 81]):
    obj_dir = os.path.join(dataset_path, f"Max_iter_{max_iter}_eta_{eta}", "cta", "obj_test_accuracy")
    df = pd.read_csv(os.path.join(obj_dir, f"{file_name}.csv"))
    # Random-seed runs: the same columns, each with a '_seed' suffix.
    df_seed = pd.read_csv(os.path.join(obj_dir, "acc_loss_seed.csv"))
    pvalues = []
    for i, col in enumerate(df.columns[1:]):  # every column after the first
        col_seed = col + '_seed'
        data = df[col]
        data_seed = df_seed[col_seed]
        # Each column and its seed counterpart sit in adjacent slots.
        pos = [0.4 + interval * i * 2]
        pos_seed = [0.4 + interval * (i * 2 + 1)]

        vio1 = axs[a].violinplot(drop_outliers(data), positions=pos, showmeans=False, showmedians=True)
        vio2 = axs[a].violinplot(drop_outliers(data_seed), positions=pos_seed, showmeans=False, showmedians=True)

        # Both violins share the pair's line colour; bodies differ.
        style_violin(vio1, line_colors[i * 2], line_colors[i * 2], alpha=0.4)
        style_violin(vio2, line_colors[i * 2 + 1], line_colors[i * 2])

        if a == 0:  # one legend entry per series, taken from the first panel
            add_label(vio1, label_dic.get(col, col))
            add_label(vio2, label_dic.get(col_seed, col_seed))

        # Paired Wilcoxon signed-rank test on the unfiltered samples.
        d = data - data_seed
        if d.any():
            _, p = stats.wilcoxon(d)
        else:
            # All-zero differences can make scipy's wilcoxon raise
            # (version-dependent). NaN means "no bracket drawn".
            p = np.nan
        pvalues.append([p, pos[0], pos_seed[0]])
    draw_significance(axs[a], pvalues)

axs[0].set_ylabel("Test accuracy %", fontsize=14)
axs[0].set_xticks([])
axs[1].set_xticks([])
axs[0].set_xlabel(r"CIFAR-10 ($R=50$)", fontsize=14)
axs[1].set_xlabel(r"CIFAR-10 ($R=81$)", fontsize=14)

# plt.legend attaches to the current axes (presumably axs[1]). The negative
# x anchor moves the legend to the left so it sits above both panels.
plt.legend(*zip(*labels), bbox_to_anchor=(-1.45,1), loc=3, ncol = 4, fontsize=12,
           framealpha = 0, handlelength=0.8, labelspacing=0.1, columnspacing=0.4)
plt.subplots_adjust(left=0.09, right=0.98, top=0.75, bottom=0.12, wspace =0.18, hspace=0.45)

plt.savefig('random_seed.pdf')
plt.savefig('random_seed.png')
# axs[0].set_xticklabels(np.arange(2, 7))
# axs[1].set_xticklabels(np.arange(2, 7))
# legend = axs[2].legend(bbox_to_anchor=(3,1.3,0,0), ncol=5, loc='upper right', framealpha=0.8, handlelength=3)

# axs[1].legend()
# dataset_path = '../../Records/Hyperband/Fashion-MNIST'
# file_name2 = 'dyn_win_size_'
# for iter in [10, 15]:
#     print(f"iter = {iter}")
#     dir = os.path.join(dataset_path, f"Max_iter_{iter}_eta_{eta}", "cta", f"obj_test_accuracy")
#     for file_name in ['valid_loss', 'train_loss']:
#         file = os.path.join(dir, f"dyn_win_size_{file_name}.csv")
#         df = pd.read_csv(file)
#         if file_name == 'valid_loss':
#             baseline = 'valid_losses'
#         elif file_name == 'train_loss':
#             baseline = 'train_losses'

#         i = 0
#         j = 0
#         a = 2
#         scat_xs = []
#         scat_ys = []
#         sigs = []
#         for cta2 in df.columns[2:]:
#             if 'sig' in cta2:
#                 continue
#             d = df[baseline] - df[cta2]
#             if not d.any():
#                 print(f"cta1 = {baseline}, cta2 = {cta2}, ALL zero")
#                 continue
#             res = stats.wilcoxon(d)
#             if res.pvalue < 0.01:
#                 data = df[cta2]
#                 pos = [0.4+interval*i]
#                 i += 1
#                 # vio = axs[0].violinplot(data, positions=pos,
#                 #                     showmeans=False,
#                 #                     showmedians=True)
#                 # if cta2 in label_dic:
#                 #     add_label(vio, label_dic[cta2])
#                 # else:
#                 #     add_label(vio, cta2)
#                 scat_xs.append(0.4+interval*j)
#                 scat_ys.append(-d.mean()/np.std(d))
#                 j += 1
#                 if res.pvalue < 0.001:
#                     sigs.append("***")
#                 elif res.pvalue < 0.05:
#                     sigs.append("**")
#                 elif res.pvalue < 0.01:
#                     sigs.append("*")
#         axs[a].scatter(scat_xs, scat_ys, marker='v', s=10, label=rf'R={iter}')
#         a += 1

    # break

# plt.legend(*zip(*labels), bbox_to_anchor=(-1,-1), loc=3, ncol = 4, framealpha = 0)