
import argparse
import os
import numpy as np
from castle.metrics import MetricsDAG
import pandas as pd
import matplotlib.pyplot as plt
from cdt.metrics import SHD, SID
from scipy.special import comb
import networkx as nx


# Seeds of the individual runs. They are kept as strings because they are
# inserted verbatim into result file names (some carry leading zeros).
SEEDS = [
    "1234", "5678", "9012", "3456", "7890",
    "2346", "26498", "82763", "4567", "8901",
    "2345", "6789", "0123", "82764", "8264",
    "9274", "8562", "02784", "7455", "8472",
]

# Result columns that are collected per run and plotted (one subplot each).
metrics = ['runtime']

# Maps the method name accepted on the command line to the method identifier
# that appears in the result file names.
#
# NOTE(review): the original literal contained a dangling "cascade-idf" key
# with no value or comma; implicit string concatenation fused it with the next
# key into "cascade-idfpc", so "pc" was rejected as an unknown method. The
# "pc" entry is restored here. The IDF variant is selected with the --idf
# flag; add an explicit "cascade-idf" entry if a separate name is wanted.
# NOTE(review): "cascade_insstant" is presumably spelled this way in the
# on-disk file names -- confirm before correcting it.
valid_methods = {
    "cascade": "cascade",
    "cascade-i": "cascade_insstant",
    "pc": "gCastle_pc",
    "ges": "gCastle_ges",
    "notears": "gCastle_notears",
    "ttpm": "gCastle_ttpm",
    "mdlh": "mdlh",
    "mdlh-sparse": "mdlh_sparse",
    "cause": "CAUSE_ERPP",
    "nphc": "CAUSE_NPHC",
    "nphca": "CAUSE_NPHCA",
    "causea": "CAUSE_ERPPA",
}

# Experiment names accepted on the command line; each one is also used as the
# name of the sub-directory the results are read from.
experiments = {
    "inc_event_types",
    "inc_skip_rate",
    "inc_mean_delay",
    "inc_collition",
    "inc_chain",
    "instant_effect",
    "sanity_check",
    "instant_effect_all",
}

# Default iteration values per experiment, used when -iterations is not given.
# Each value becomes the leading component of a result file name.
experiment_iterations = {
    "inc_event_types": [5, 10, 15, 20, 30, 40],
    "inc_skip_rate": ["0.9", "0.8", "0.7", "0.6", "0.5", "0.1"],
    "inc_mean_delay": [5, 10, 15, 20, 30, 40, 50, 70, 90, 110],
    # Earlier value list: [2, 4, 6, 8, 10, 15, 20, 30, 40]
    "inc_collition": [50, 60, 70, 80, 90, 100, 150, 200],
    "inc_chain": [2, 4, 6, 8, 10, 15, 20, 30, 40],
    "instant_effect": [""],
    "instant_effect_all": [""],
    "sanity_check": [""],
}

# Overwritten from the --skip-SID command-line flag when run as a script.
skip_SID = False

def gen_plots(df, col, group_by):
    """Draw one grouped box plot per entry of ``metrics`` and return the figure.

    Args:
        df: DataFrame containing a column for every metric as well as the
            column(s) named in ``group_by``.
        col: Width of the figure in inches.
        group_by: Column name or list of column names passed to
            ``DataFrame.boxplot(by=...)`` to form the groups.

    Returns:
        The matplotlib figure holding ``len(metrics)`` vertically stacked
        subplots.
    """
    n_rows = len(metrics)

    # One subplot row per metric. A fixed grid of two rows would leave an empty
    # subplot for a single metric and fail with IndexError for more than two.
    # squeeze=False keeps ``axs`` two-dimensional even when there is one row.
    fig, axs = plt.subplots(n_rows, 1, figsize=(col, 3 * n_rows), squeeze=False)

    for metric, ax in zip(metrics, axs[:, 0]):
        ax = df.boxplot(column=[metric], by=group_by, ax=ax)
        ax.set_ylim(bottom=-0.1)  # Lower y-axis limit just below zero
        ax.set_title(metric)  # Title of the subplot
        ax.set_ylabel(metric)  # Label of the y-axis
        ax.tick_params(axis='both', labelsize='x-small', rotation=90)  # Small tick labels rotated by 90 degrees

    plt.tight_layout()  # Adjust the spacing between subplots
    return fig



def eval_method(method, experiment_path, iterations, idf=False):
    """Collect the recorded runtimes of one method over all seeds and iterations.

    For every seed in ``SEEDS`` and every value in ``iterations`` the file
    ``<iter>_<seed>_<method>_estimate.npy_elapsed_time.txt`` inside
    ``experiment_path`` is read and its content parsed as an integer. When
    ``idf`` is set and ``method`` is one of the cascade variants, the file
    ``<iter>_<seed>_<method>_estimate.npy.instant-idf.npy_elapsed_time.txt``
    is read instead.

    Runs whose file is missing or does not contain an integer are reported on
    stdout and skipped, so one bad run does not abort the evaluation.

    Args:
        method: Method identifier as it appears in the file names.
        experiment_path: Directory containing the elapsed-time files.
        iterations: Iterable of iteration values to look up.
        idf: Whether to read the instant-IDF result files for cascade methods.

    Returns:
        DataFrame with the columns ``iter``, ``seed`` and one column per entry
        of ``metrics``; one row per run that could be read.
    """
    print("evaluating method: ", method)

    # Only the cascade variants have a separate instant-IDF result file.
    if idf and method in (valid_methods["cascade"], valid_methods["cascade-i"]):
        name_template = "{}_{}_{}_estimate.npy.instant-idf.npy"
    else:
        name_template = "{}_{}_{}_estimate.npy"

    # Rows are gathered in a list and turned into a DataFrame once at the end,
    # which avoids repeatedly concatenating onto an initially empty frame.
    rows = []

    for seed in SEEDS:
        print("seed: ", seed)
        for iteration in iterations:
            print("iter: ", iteration)
            file_name = name_template.format(iteration, seed, method)
            estimate_path = os.path.join(experiment_path, file_name) + "_elapsed_time.txt"
            try:
                with open(estimate_path, "r") as handle:
                    runtime = int(handle.read())
            except FileNotFoundError:
                print("file not found: ", estimate_path)
                continue
            except ValueError as e:
                # The file exists but does not hold a parseable integer.
                print("invalid runtime value: ", estimate_path)
                print(e)
                continue

            rows.append({"iter": iteration, "seed": seed, "runtime": runtime})

    return pd.DataFrame(rows, columns=['iter', 'seed'] + metrics)

if __name__ == "__main__":
    # Script entry point: parse the command line, gather the runtimes of every
    # requested method, then write them to a CSV file and a box-plot PDF inside
    # the experiment directory.
    cli = argparse.ArgumentParser()
    cli.add_argument("-methods", type=str, nargs='+', required=True, help="methods to be evaluated")
    cli.add_argument("-experiment", type=str, required=True, help="experiment name")
    cli.add_argument("-iterations", type=str, nargs='+', required=False, help="iterations to be evaluated")
    cli.add_argument("--skip-SID", action='store_true', help="skip SID metric")
    cli.add_argument("--idf", action='store_true', help="use idf", default=False)
    args = cli.parse_args()

    # Stored in the module-level flag of the same name.
    skip_SID = args.skip_SID

    # Reject unknown method or experiment names before touching any files.
    if any(name not in valid_methods for name in args.methods):
        raise ValueError("method not recognized")
    if args.experiment not in experiments:
        raise ValueError("experiment not recognized")

    experiment_path = "synthetic_data_server/{}/".format(args.experiment)

    # Explicit -iterations take precedence over the per-experiment defaults.
    if args.iterations is None:
        iterations = experiment_iterations[args.experiment]
    else:
        iterations = args.iterations

    # One frame per method, tagged with the command-line name of the method.
    results = []
    for method in args.methods:
        per_method = eval_method(valid_methods[method], experiment_path, iterations, idf=args.idf)
        per_method['method'] = method
        results.append(per_method)

    all_results = pd.concat(results)
    csv_path = os.path.join(experiment_path, "all_runtimes.csv")
    all_results.to_csv(csv_path, index=False, na_rep='nan')

    # Figure width grows with the number of (iteration, method) groups.
    group_count = len(args.methods) * len(iterations)
    fig = gen_plots(all_results, group_count, ['iter', 'method'])
    pdf_path = os.path.join(experiment_path, "all_runtimes.pdf")
    fig.savefig(pdf_path)