
import os
import json
import warnings
import argparse
import pandas as pd
from scipy import stats
import scikit_posthocs as sp
import matplotlib.pyplot as plt
import numpy as np

# Command-line options: location of the benchmark records, the objectives to
# analyse, and the max_iter / eta values used to build the per-run directory
# names further down.
parser = argparse.ArgumentParser(description='Script description')
parser.add_argument('--dataset_path', type=str, default='../../Records/Hyperband/Fashion-MNIST', help='Path of benchmark data')
parser.add_argument('--ult_objs', type=str, nargs='+', default=['test_accuracy', 'test_losses'], help='List of strings (default: ["test_accuracy", "test_losses"])')
parser.add_argument('--max_iters', type=int, nargs='+', default=[20, 50, 81, 120, 150], help='List of integers (default: [20, 50, 81, 120, 150])')
parser.add_argument('--eta', type=int, default=3, help='Fraction of saving in hyperband')
args = parser.parse_args()


max_iters = args.max_iters
ult_objs = args.ult_objs
dataset_path = args.dataset_path
eta = args.eta

# Any dataset path containing 'Fashion-MNIST' (which includes the default path)
# always uses the fixed list [10, 15, 30], replacing whatever --max_iters held.
# Warn when that discards a value the user explicitly asked for, so the
# override is no longer silent.
if 'Fashion-MNIST' in dataset_path:
    if max_iters not in (parser.get_default('max_iters'), [10, 15, 30]):
        warnings.warn(
            f"--max_iters {max_iters} is ignored for Fashion-MNIST records; "
            "using [10, 15, 30] instead."
        )
    max_iters = [10,15,30]

# Base name (without extension) of the CSV read from each objective directory.
file_name1 = 'win_5_valid'
# Columns treated as baselines; every other column is only ever a comparison
# partner of one of these.
cta_base = np.array(['train_losses', 'valid_losses'])

# For each max_iter value, compare every baseline column with the (up to) three
# columns that immediately follow it in the CSV, using a Wilcoxon signed-rank
# test on the paired differences. Pairs with p < 0.01 are printed together with
# an effect size (mean difference divided by its standard deviation).
for max_iter in max_iters:
    print()
    print(f"******** max_iter = {max_iter} ********")
    # Only the first requested objective is analysed.
    for obj in [ult_objs[0]]:
        # Load data
        obj_dir = os.path.join(dataset_path, f"Max_iter_{max_iter}_eta_{eta}", "cta", f"obj_{obj}")
        if not os.path.exists(obj_dir):
            warnings.warn(f"Directory {obj_dir} doesn't exist.")
            continue

        file1 = os.path.join(obj_dir, f"{file_name1}.csv")
        # Treat a missing CSV like a missing directory: warn and move on
        # instead of aborting the whole run inside pd.read_csv.
        if not os.path.isfile(file1):
            warnings.warn(f"File {file1} doesn't exist.")
            continue
        df = pd.read_csv(file1)

        # Column 0 is never considered (presumably an index/identifier column
        # -- TODO confirm). `pos` is the column's position in df.columns.
        for pos, cta in enumerate(df.columns[1:], start=1):
            if cta not in cta_base:
                continue
            data_base = df[cta]
            # The slice silently shortens near the end of the frame, so fewer
            # than three partners may be compared.
            for cta1 in df.columns[pos + 1:pos + 4]:
                data = df[cta1]
                d = data_base - data
                # Skip pairs with no non-zero difference before testing.
                if not d.any():
                    print(f"cta1 = {cta}, cta2 = {cta1}, ALL zero")
                    continue
                _, p = stats.wilcoxon(d)
                if p < 0.01:
                    print(f"cta1 = {cta}, cta2 = {cta1}, res.pvalue = {p}, Cohen's d = {d.mean()/np.std(d)}")

