
import os
import json
import warnings
import argparse
import pandas as pd
from scipy import stats
import scikit_posthocs as sp
import matplotlib.pyplot as plt
import numpy as np

# Command-line configuration for the plotting script.
parser = argparse.ArgumentParser(description='Script description')
parser.add_argument(
    '--dataset_path',
    type=str,
    default='../../Records/Hyperband/ImageNet',
    help='Path of benchmark data',
)
parser.add_argument(
    '--ult_objs',
    type=str,
    nargs='+',
    default=['test_accuracy', 'test_losses'],
    help='List of strings (default: ["test_accuracy", "test_losses"])',
)
parser.add_argument(
    '--max_iters',
    type=int,
    nargs='+',
    default=[20, 50, 81, 120, 150],
    help='List of integers (default: [20, 50, 81, 120, 150])',
)
# NOTE(review): presumably Hyperband's reduction factor eta; it is used to
# build the "Max_iter_{R}_eta_{eta}" run directory names.
parser.add_argument(
    '--eta',
    type=int,
    default=4,
    help='Fraction of saving in hyperband',
)
args = parser.parse_args()

# Expose the parsed options as module-level names.
max_iters, ult_objs, dataset_path, eta = (
    args.max_iters,
    args.ult_objs,
    args.dataset_path,
    args.eta,
)
# NOTE(review): file_name1 and dataset_path are rebound further down in this
# module, so the values assigned here are not the ones the plots end up using.
file_name1, file_name2 = "acc_loss", "dyn_win_size"


import matplotlib.patches as mpatches

# Collected (legend proxy patch, label text) pairs, appended to by add_label.
labels = []


def add_label(violin, label):
    """Append a legend entry matching the colour of a violin plot.

    Args:
        violin: The mapping returned by a violin-plot call; its first
            ``"bodies"`` entry supplies the face colour.
        label: Text to show next to the colour swatch.

    The ``(patch, label)`` pair is appended to the module-level ``labels``
    list; nothing is returned.
    """
    first_body = violin["bodies"][0]
    swatch = mpatches.Patch(color=first_body.get_facecolor().flatten())
    labels.append((swatch, label))
    
# Short human-readable names for metric column identifiers.
label_dic = dict(
    train_accuracy='Train. acc.',
    valid_accuracy='Valid. acc.',
    train_losses='Train. loss',
    valid_losses='Valid. loss',
    train_accuracy_seed='Train. avg. acc.',
    valid_accuracy_seed='Valid. avg. acc.',
    train_losses_seed='Train. avg. loss',
    valid_losses_seed='Valid. avg. loss',
)

# CSV stems read by the plotting code: file_name is the baseline table,
# file_name1 the window-smoothed ("dyn_win_size") table.
file_name, file_name1 = 'acc_loss', 'dyn_win_size'


import matplotlib.cm as cm  # Import the color map module

# Horizontal spacing between adjacent x positions; draw_significance also uses
# it to turn a bracket's span into a stacking level.
interval = 0.6


def draw_significance(ax, pvalues):
    """Draw significance brackets between pairs of x positions on ``ax``.

    Args:
        ax: Matplotlib axes to draw on. Its current y-limits set the scale.
        pvalues: Iterable of ``(p, pos1, pos2)`` triples. Pairs with
            ``p >= 0.05`` are skipped. The others get a bracket from ``pos1``
            to ``pos2`` labelled ``*``, ``**`` or ``***`` for p < 0.05, 0.01
            or 0.001 respectively.

    A bracket spanning ``k`` intervals (``k = round((pos2 - pos1) / interval)``)
    is drawn at ``0.993 * top + 0.1 * k * y_range``. Wider brackets therefore
    sit higher than narrower ones. Finally the upper y-limit is raised to the
    larger of ``top + 0.3`` and the highest bracket.
    """
    bottom, top = ax.get_ylim()
    y_range = top - bottom
    # Start from the current top (not 0). With all-negative data a 0 start
    # would needlessly drag the final upper y-limit up to 0.
    max_bar_height = top
    for p, pos1, pos2 in pvalues:
        if p < 0.001:
            sig_symbol = '***'
        elif p < 0.01:
            sig_symbol = '**'
        elif p < 0.05:
            sig_symbol = '*'
        else:
            continue

        # Stacking level of this bracket, from how many intervals it spans.
        level = round((pos2 - pos1) / interval)
        bar_height = (y_range * 0.1 * level) + top * 0.993
        bar_tips = bar_height - (y_range * 0.02)
        ax.plot(
            [pos1, pos1, pos2, pos2],
            [bar_tips, bar_height, bar_height, bar_tips], lw=1, c='k'
        )

        text_height = bar_height - 0.04 * y_range
        ax.text((pos1 + pos2) * 0.5, text_height, sig_symbol, ha='center', va='bottom', c='k')

        max_bar_height = max(max_bar_height, bar_height)
    ax.set_ylim([bottom, max(top + 0.3, max_bar_height)])


num_lines = 20
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in
# 3.9; `pyplot.get_cmap(name, lut)` is the supported equivalent.
color_palette = plt.get_cmap('tab20', num_lines)
line_colors = color_palette(range(num_lines))

# Tick labels for the window-size axis. They are applied only when a panel has
# exactly this many candidate window columns; the CSV column names themselves
# are not inspected, so this mapping is an assumption about the file layout.
_WINDOW_SIZE_LABELS = [str(k) for k in np.arange(2, 7)]

# (scatter marker, legend label) for each significance bucket, in legend order.
_SIG_MARKERS = (('o', r'$p<0.01$'), ('v', r'$p<0.05$'), ('*', r'$p<0.001$'))

# Baseline column compared against in the left and right panel respectively.
_BASELINES = ('valid_losses', 'train_losses')


def _sig_marker(pvalue):
    """Return the scatter marker for ``pvalue``, or None when p >= 0.05.

    Thresholds are tested from the strictest down, so each p-value lands in
    exactly one bucket.
    """
    if pvalue < 0.001:
        return '*'
    if pvalue < 0.01:
        return 'o'
    if pvalue < 0.05:
        return 'v'
    return None


def _matches_baseline(baseline, column):
    """Return True if ``column`` is comparable with ``baseline``.

    A comparable column shares its loss/accuracy and valid/train keywords and
    is not a 'sig' column.
    """
    for key in ('loss', 'accuracy', 'valid', 'train'):
        if key in baseline and key not in column:
            return False
    return 'sig' not in column


def _window_effects(df, df_smooth, baseline):
    """Compute significant paired effect sizes of smoothed columns vs. ``baseline``.

    Every column of ``df_smooth`` after the first that matches ``baseline`` is a
    candidate and gets the x position ``0.4 + interval * k``, where ``k`` is its
    index among the candidates.

    For each candidate, the paired difference
    ``d = df[baseline] - df_smooth[column]`` is tested with a Wilcoxon
    signed-rank test. Candidates with p < 0.05 contribute a point whose y value
    is ``-mean(d) / std(d)`` (population std).

    Returns:
        ``(xs, ys, markers, candidate_xs)``: coordinates and marker of each
        significant point, plus the x positions of all candidates.
    """
    xs, ys, markers, candidate_xs = [], [], [], []
    for column in df_smooth.columns[1:]:
        if not _matches_baseline(baseline, column):
            continue
        # Position by candidate index (not by how many were significant), so a
        # non-significant window does not shift later windows onto its tick.
        x = 0.4 + interval * len(candidate_xs)
        candidate_xs.append(x)
        d = df[baseline] - df_smooth[column]
        if not d.any():
            # Every paired difference is zero: nothing to test.
            continue
        marker = _sig_marker(stats.wilcoxon(d).pvalue)
        if marker is None:
            continue
        xs.append(x)
        ys.append(-d.mean() / np.std(d))
        markers.append(marker)
    return xs, ys, markers, candidate_xs


def _plot_window_figure(data_root, r_values, legend_x):
    """Build the two-panel "Cohen's d vs. window size" figure for one dataset.

    Args:
        data_root: Directory holding the ``Max_iter_{R}_eta_{eta}`` run folders.
        r_values: The R (max_iter) values to plot, one curve per value.
        legend_x: x coordinate of the right panel's legend anchor.

    Returns:
        ``(fig, axs)``: the new figure and its two axes.
    """
    fig, axs = plt.subplots(1, 2, figsize=(5, 2.5))
    sig_labelled = False
    for mi, max_iter in enumerate(r_values):
        run_dir = os.path.join(data_root, f"Max_iter_{max_iter}_eta_{eta}", "cta", "obj_test_accuracy")
        df = pd.read_csv(os.path.join(run_dir, f"{file_name}.csv"))
        df_smooth = pd.read_csv(os.path.join(run_dir, f"{file_name1}.csv"))
        color = line_colors[(mi * 2) % num_lines]
        for a, baseline in enumerate(_BASELINES):
            ax = axs[a]
            xs, ys, markers, candidate_xs = _window_effects(df, df_smooth, baseline)
            # Only the right panel labels its curves with R, so its legend
            # lists the R values.
            line_kwargs = {'label': rf'$R={max_iter}$'} if a > 0 else {}
            ax.plot(xs, ys, zorder=0, color=color, **line_kwargs)
            for marker, sig_label in _SIG_MARKERS:
                sx = [x for x, m in zip(xs, markers) if m == marker]
                sy = [y for y, m in zip(ys, markers) if m == marker]
                # The p-value labels are attached only once (first R, left
                # panel), so the left panel's legend lists them exactly once.
                scatter_kwargs = {} if sig_labelled else {'label': sig_label}
                ax.scatter(sx, sy, marker=marker, s=30, zorder=2, color=color, **scatter_kwargs)
            sig_labelled = True
            if len(candidate_xs) == len(_WINDOW_SIZE_LABELS):
                ax.set_xticks(candidate_xs)
                ax.set_xticklabels(_WINDOW_SIZE_LABELS)
            ax.grid(axis="x", linewidth=0.6)
            ax.grid(axis="y", linewidth=0.6)

    axs[0].set_ylabel("Cohen's d", fontsize=14)
    axs[0].set_xlabel("Window size", fontsize=12)
    axs[1].set_xlabel("Window size", fontsize=12)
    axs[0].set_title("(a) Over validation loss", y=-0.6, fontsize=14)
    axs[1].set_title("(b) Over training loss", y=-0.6, fontsize=14)

    axs[1].legend(bbox_to_anchor=(legend_x, 1.5, 0, 0), ncol=4, loc='upper right', framealpha=0,
                  handlelength=1, labelspacing=0.3, borderpad=0.3, fontsize=12)
    axs[0].legend(bbox_to_anchor=(2.3, 1.35, 0, 0), ncol=5, loc='upper right', framealpha=0,
                  handlelength=1, labelspacing=0.1, borderpad=0.3, columnspacing=1.2, fontsize=12)
    return fig, axs


# NOTE(review): dataset paths are hard-coded here, so --dataset_path has no
# effect on the plots. The Cifar10 figure uses --max_iters; Fashion-MNIST uses
# a fixed list.
fig_cifar10 = _plot_window_figure('../../Records/Hyperband/Cifar10', max_iters, 1.05)[0]

dataset_path = '../../Records/Hyperband/Fashion-MNIST'
max_iters = [3, 6, 10, 15, 30]
fig, axs = _plot_window_figure(dataset_path, max_iters, 0.92)

for figure in (fig_cifar10, fig):
    figure.subplots_adjust(left=0.145, right=0.98, top=0.8, bottom=0.3, wspace=0.24, hspace=0.45)

# Save the Cifar10 figure as well; otherwise it is drawn but never written out.
fig_cifar10.savefig('smooth_win0.pdf')
fig_cifar10.savefig('smooth_win0.png')
fig.savefig('smooth_win1.pdf')
fig.savefig('smooth_win1.png')


