
import os
import json
import warnings
import argparse
import pandas as pd
from scipy import stats
import scikit_posthocs as sp
import matplotlib.pyplot as plt
import seaborn as sns

# from cliffs_delta import cliffs_delta

# ---------------------------------------------------------------------------
# Command-line configuration and derived settings.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Run pairwise one-sided Wilcoxon signed-rank tests between the '
                'CTA result columns stored under ../Records/Hyperband/<dataset>/.')
parser.add_argument('--task', type=str, default='acc_loss',
                    help='Stem of the CSV file to analyse (reads <task>.csv)')
parser.add_argument('--dataset', type=str, default='ImageNet',
                    help='Dataset name, i.e. the sub-directory of ../Records/Hyperband/')
parser.add_argument('--ult_objs', type=str, nargs='+', default=['test_accuracy'],
                    help='List of strings (default: ["test_accuracy"])')
# default=None distinguishes "flag not given" from an explicit value, so the
# dataset-dependent fallback below no longer overrides a user-supplied list.
parser.add_argument('--max_iters', type=int, nargs='+', default=None,
                    help='List of integers (default: [3, 6, 10, 15, 30] for '
                         'Fashion-MNIST/higgs/adult/jasmine/vehicle/volkert, '
                         'otherwise [20, 50, 81, 120, 150])')
parser.add_argument('--eta', type=int, default=3, help='Fraction of saving in hyperband')
args = parser.parse_args()

# Human-readable display names for metric columns and dataset names.
label_dic = {'train_accuracy': 'Training accuracy',
             'train_losses': 'Training loss',
             'valid_accuracy': 'Validation accuracy',
             'valid_losses': 'Validation loss'}
title_dic = {'Cifar10': 'CIFAR-10',
             'Cifar100': 'CIFAR-100',
             'ImageNet': 'ImageNet-16-120'}

ult_objs = args.ult_objs
# NOTE: relative path, resolved against the current working directory.
dataset_path = '../Records/Hyperband/'
dataset = args.dataset
dataset_path = os.path.join(dataset_path, dataset)

# Datasets whose result directories use the shorter max_iter grid. Matching is
# a substring test against the full dataset path.
_SHORT_ITER_DATASETS = ('Fashion-MNIST', 'higgs', 'adult', 'jasmine', 'vehicle', 'volkert')

if args.max_iters is not None:
    # An explicit --max_iters always takes precedence.
    max_iters = args.max_iters
elif any(name in dataset_path for name in _SHORT_ITER_DATASETS):
    max_iters = [3, 6, 10, 15, 30]
else:
    max_iters = [20, 50, 81, 120, 150]

file_name = args.task
eta = args.eta

# Columns excluded from the pairwise comparisons (by name they look like
# summary/best-value columns rather than per-CTA results — confirm).
_SUMMARY_COLUMNS = frozenset({'Max_test_accuracy', 'Min_test_loss'})

for max_iter in max_iters:
    print()
    print(f"******** max_iter = {max_iter} ********")
    for obj in ult_objs:
        print(f"*********** obj = {obj} ************")
        # Load data
        obj_dir = os.path.join(dataset_path, f"Max_iter_{max_iter}_eta_{eta}", "cta", f"obj_{obj}")
        if not os.path.exists(obj_dir):
            warnings.warn(f"Directory {obj_dir} doesn't exist.")
            continue

        task_csv = os.path.join(obj_dir, f"{file_name}.csv")
        if not os.path.isfile(task_csv):
            # Same best-effort policy as for a missing directory: skip, don't crash.
            warnings.warn(f"File {task_csv} doesn't exist.")
            continue
        df = pd.read_csv(task_csv)

        # If the task file carries no loss columns, the left-hand side of each
        # comparison comes from acc_loss.csv in the same directory instead.
        if 'train_losses' not in df.columns and 'valid_losses' not in df.columns:
            base_csv = os.path.join(obj_dir, "acc_loss.csv")
            if not os.path.isfile(base_csv):
                warnings.warn(f"File {base_csv} doesn't exist.")
                continue
            df_base = pd.read_csv(base_csv)
        else:
            df_base = df

        # Pair-wise one-sided Wilcoxon signed-rank test. With alternative="less"
        # H1 is that df_base[cta1] - df[cta2] is shifted below zero.
        for cta1 in df_base.columns:
            if cta1 in _SUMMARY_COLUMNS:
                continue
            for cta2 in df.columns:
                if cta2 in _SUMMARY_COLUMNS:
                    continue
                # NOTE(review): pandas aligns on the row index, so rows present
                # in only one frame produce NaN differences.
                d = df_base[cta1] - df[cta2]
                # An all-zero difference (e.g. a column against itself) has
                # nothing to rank, so the test is skipped.
                if (d == 0).all():
                    print("zero")
                    continue
                res = stats.wilcoxon(d, alternative="less")
                print(f"cta1 = {cta1}, cta2 = {cta2}, res.pvalue = {res.pvalue}, max diff = {d.max()}, min diff = {d.min()}, mean diff = {d.mean()}")
