
import argparse
import os
import numpy as np
from castle.metrics import MetricsDAG
import pandas as pd
import matplotlib.pyplot as plt
from cdt.metrics import SHD, SID
from scipy.special import comb
import networkx as nx


# Random seeds of the synthetic runs. They are kept as strings (several have a
# leading zero) because eval_method interpolates them verbatim into file names.
SEEDS =  ["1234" , "5678" , "9012" , "3456" , "7890" , "2346" , "26498" , "82763" , "4567" , "8901" , "2345" , "6789" , "0123" , "82764" , "8264" , "9274" , "8562" , "02784" , "7455" , "8472"]

# Metric names: used as the result columns in eval_method, as the keys of the
# all-zero result in eval, and as one subplot each in gen_plots.
metrics = ['fdr', 'tpr', 'fpr', 'shd', 'nnz', 'precision', 'recall', 'F1', 'gscore', 'cdt-shd', 'cdt-sid', 'cdt-shd-norm', 'cdt-sid-norm'] # 

# Maps the method name accepted on the command line to the method identifier
# that is embedded in the estimate file names.
# NOTE(review): "cascade_insstant" looks like a typo for "cascade_instant", but
# it has to match the file names on disk -- confirm before changing it.
valid_methods = {"cascade":"cascade", "cascade-i":"cascade_insstant", "pc":"gCastle_pc", "ges":"gCastle_ges", "notears":"gCastle_notears", "ttpm":"gCastle_ttpm", "mdlh":"mdlh", "mdlh-sparse":"mdlh_sparse", "cause":"CAUSE_ERPP", "nphc":"CAUSE_NPHC", "nphca":"CAUSE_NPHCA", "causea":"CAUSE_ERPPA"}

# Experiment names accepted by the -experiment argument; the name is also the
# sub-directory that the script reads from and writes to.
experiments = {"inc_event_types", "inc_skip_rate", "inc_mean_delay", "inc_collition", "inc_chain", "instant_effect", "sanity_check", "instant_effect_all"}

# Default iteration values per experiment, used when -iterations is not given.
# Each value becomes the leading component of the graph file names; the
# experiments with a single "" entry therefore have an empty leading component.
experiment_iterations = {"inc_event_types": [5, 10, 15, 20, 30, 40],
                          "inc_skip_rate": ["0.9", "0.8" , "0.7", "0.6", "0.5", "0.1"],
                          "inc_mean_delay": [5, 10, 15, 20, 30, 40, 50, 70, 90, 110],
                          "inc_collition": [50,60,70,80,90,100,150,200], # alternative value set: [2, 4, 6, 8, 10, 15, 20, 30, 40]
                          "inc_chain": [2, 4, 6, 8, 10, 15, 20, 30, 40],
                          "instant_effect": [""],
                          "instant_effect_all": [""],
                          "sanity_check": [""]
                          }

# When True, eval reports NaN for the SID metrics instead of computing them.
# Overwritten from the --skip-SID flag when the file is run as a script.
skip_SID = False

def eval(est_graph, true_graph):
    """Score an estimated adjacency matrix against the ground-truth graph.

    The name shadows the builtin ``eval``; it is kept because other code in
    this file calls it under this name.

    Args:
        est_graph: Estimated adjacency matrix (NumPy array). It is not
            modified; self-loops are removed on a private copy.
        true_graph: Ground-truth adjacency matrix (NumPy array).

    Returns:
        A mapping from metric name to value.

        * If ``true_graph`` sums to zero, every entry of ``metrics`` is 0
          except ``'fdr'``, which holds the sum of the off-diagonal entries
          of the estimate.
        * Otherwise the ``metrics`` attribute of ``MetricsDAG`` extended with
          ``'cdt-shd'``, ``'cdt-sid'`` and their normalised variants (divided
          by ``2 * C(n, 2)`` with ``n = true_graph.shape[0]``). The SID
          entries are NaN when the estimate without self-loops is not a DAG
          or when the module-level ``skip_SID`` flag is set.
    """
    # np.fill_diagonal works in place, so operate on a copy to leave the
    # caller's array untouched.
    est_no_loops = np.array(est_graph, copy=True)
    np.fill_diagonal(est_no_loops, 0)

    if np.sum(true_graph) == 0:
        results = {metric: 0 for metric in metrics}
        results["fdr"] = np.sum(est_no_loops)
        return results

    # MetricsDAG receives the estimate exactly as passed in (diagonal included).
    pc_metrics = MetricsDAG(est_graph, true_graph)
    if np.sum(est_graph) == 0:
        # Empty estimate against a non-empty truth: report precision as 1;
        # recall will be 0, hence F1 will be 0.
        pc_metrics.metrics['precision'] = 1
        pc_metrics.metrics['F1'] = 0

    normaliser = 2 * comb(true_graph.shape[0], 2)
    pc_metrics.metrics['cdt-shd'] = SHD(true_graph, est_no_loops)
    pc_metrics.metrics['cdt-shd-norm'] = pc_metrics.metrics['cdt-shd'] / normaliser

    # SID is only computed for acyclic estimates and when not disabled.
    G = nx.from_numpy_array(est_no_loops, create_using=nx.DiGraph)
    if nx.is_directed_acyclic_graph(G) and not skip_SID:
        pc_metrics.metrics['cdt-sid'] = SID(true_graph, est_no_loops)
        pc_metrics.metrics['cdt-sid-norm'] = pc_metrics.metrics['cdt-sid'] / normaliser
    else:
        pc_metrics.metrics['cdt-sid'] = np.nan
        pc_metrics.metrics['cdt-sid-norm'] = np.nan
    return pc_metrics.metrics

def gen_plots(df, col, group_by):
    """Draw one box plot per metric, stacked vertically in a single figure.

    Args:
        df: DataFrame holding one column per entry of ``metrics`` plus the
            grouping column(s).
        col: Figure width passed to matplotlib; the height is 3 per metric.
        group_by: Column name(s) forwarded to ``DataFrame.boxplot(by=...)``.

    Returns:
        The matplotlib figure containing all subplots.
    """
    n_rows = len(metrics)
    fig, axs = plt.subplots(n_rows, 1, figsize=(col, 3 * n_rows))

    for row, name in enumerate(metrics):
        axs[row] = df.boxplot(column=[name], by=group_by, ax=axs[row])
        panel = axs[row]
        # Lower y-limit sits slightly below zero (-0.1).
        panel.set_ylim(bottom=-0.1)
        # Title and y-axis label are both the metric name.
        panel.set_title(name)
        panel.set_ylabel(name)
        # Small tick labels, rotated by 90 degrees on both axes.
        panel.tick_params(axis='both', labelsize='x-small', rotation=90)

    # Adjust the spacing between subplots.
    plt.tight_layout()
    return fig

def load_and_eval(estimate_path, true_graph_path):
    """Load an estimated and a ground-truth graph from disk and score them.

    Args:
        estimate_path: Path of the ``.npy`` file with the estimated graph.
        true_graph_path: Path of the ``.npy`` file with the true graph.

    Returns:
        Whatever ``eval`` returns for the two loaded arrays.
    """
    # The ground truth is loaded first, then the estimate.
    ground_truth = np.load(true_graph_path)
    estimate = np.load(estimate_path)
    scores = eval(estimate, ground_truth)
    return scores




def eval_method(method, experiment_path, iterations, idf=False):
    """Evaluate one method over every seed and iteration of an experiment.

    For each combination of a seed from ``SEEDS`` and an entry of
    ``iterations`` the true graph and the method's estimate are loaded from
    ``experiment_path`` and scored. Combinations whose files are missing, or
    whose evaluation raises ``AssertionError`` or ``RuntimeError``, are
    reported on stdout and skipped.

    Args:
        method: Method identifier as it appears in the estimate file names
            (a value of ``valid_methods``).
        experiment_path: Directory containing the ``.npy`` files.
        iterations: Iteration values that prefix the file names.
        idf: If true, the two cascade variants are read from their
            ``.instant-idf.npy`` estimate files instead of the plain ones.

    Returns:
        A DataFrame with one row per successfully evaluated combination and
        the columns ``'iter'``, ``'seed'`` followed by the metric names.
    """
    print("evaluating method: ", method)

    base_columns = ['iter', 'seed'] + metrics
    # Only the cascade variants have an "instant-idf" estimate file.
    use_idf_file = idf and method in (valid_methods["cascade"], valid_methods["cascade-i"])

    # Rows are collected in a list and turned into a frame once; growing a
    # DataFrame with pd.concat inside the loop is quadratic and relies on
    # deprecated handling of the initial empty frame.
    rows = []
    for seed in SEEDS:
        print("seed: ", seed)
        for iteration in iterations:
            print("iter: ", iteration)
            true_graph_path = os.path.join(experiment_path, "{}_{}_true_graph.npy".format(iteration, seed))
            if use_idf_file:
                est_str = "{}_{}_{}_estimate.npy.instant-idf.npy".format(iteration, seed, method)
            else:
                est_str = "{}_{}_{}_estimate.npy".format(iteration, seed, method)
            estimate_path = os.path.join(experiment_path, est_str)
            try:
                result = load_and_eval(estimate_path, true_graph_path)
            except FileNotFoundError as e:
                # Name the file that is actually missing; it can be the true
                # graph as well as the estimate.
                print("file not found: ", e.filename or estimate_path)
            except AssertionError as e:
                print("assertion error: ", estimate_path)
                print(e)
            except RuntimeError as e:
                print("runtime error: ", estimate_path)
                print(e)
            else:
                rows.append({**result, "seed": seed, "iter": iteration})

    if not rows:
        return pd.DataFrame(columns=base_columns)

    result_df = pd.DataFrame(rows)
    # Keep the documented column order; unexpected extra keys go last.
    extra_columns = [c for c in result_df.columns if c not in base_columns]
    return result_df.reindex(columns=base_columns + extra_columns)

def build_arg_parser():
    """Create the command-line parser of the evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace exposes
        ``methods``, ``experiment``, ``iterations``, ``skip_SID`` and ``idf``.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-methods", type=str, nargs='+', required=True, help="methods to be evaluated")
    argparser.add_argument("-experiment", type=str, required=True, help="experiment name")
    argparser.add_argument("-iterations", type=str, nargs='+', required=False, help="iterations to be evaluated")
    argparser.add_argument("--skip-SID", action='store_true', help="skip SID metric")
    # Without an explicit dest argparse stores this flag as "no_idf", while the
    # script reads "args.idf"; dest="idf" makes the two agree (default True,
    # --no-idf switches it to False).
    argparser.add_argument("--no-idf", dest="idf", action='store_false', help="don't use  idf", default=True)
    return argparser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # Module-level flag read by eval().
    skip_SID = args.skip_SID

    # Validate the requested methods and experiment before touching any file.
    for method in args.methods:
        if method not in valid_methods:
            raise ValueError("{} method not recognized".format(method))

    if args.experiment not in experiments:
        raise ValueError("experiment not recognized")

    experiment_path = "synthetic_data_server/{}/".format(args.experiment)

    # Fall back to the experiment's default iterations when none are given.
    iterations = experiment_iterations[args.experiment] if args.iterations is None else args.iterations

    results = []
    for method in args.methods:
        result_df = eval_method(valid_methods[method], experiment_path, iterations, idf=args.idf)
        # Label rows with the command-line method name, not the file identifier.
        result_df['method'] = method
        results.append(result_df)

    all_results = pd.concat(results)
    all_results.to_csv(os.path.join(experiment_path, "all_results.csv"), index=False, na_rep='nan')

    # Figure width grows with the number of (iteration, method) groups.
    fig = gen_plots(all_results, len(args.methods) * len(iterations), ['iter', 'method'])
    fig.savefig(os.path.join(experiment_path, "all_results.pdf"))