import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import json
import pandas as pd
import argparse
import os
import time

# Command-line interface: a single optional --path argument that is used as the
# root under which the per-dataset directories are looked up.
parser = argparse.ArgumentParser(
    description="Aggregate per-method JSON result files into per-dataset summary files."
)
parser.add_argument(
    "--path",
    type=str,
    default=".",
    help="Root directory containing one sub-directory per dataset "
         "(default: current directory).",
)

def get_method_names():
    """
    Build the list of method names to look for.

    The fixed baseline names come first, followed by one
    ``{family}_{approach}_{distance}`` name for every combination of family
    (``MAG``, ``SPREAD``), training distance and approach.

    :return: a new list of method-name strings
    """
    baseline_names = ["DiffPool", "MinCut", "NMF", "TopK", "SAGPool", "NDP", "Graclus", "Flat"]
    training_distances = ["diffusion_distance"]
    approaches = ["EDGE"]

    # Iteration order (family, then distance, then approach) fixes the order of
    # the generated names.
    composite_names = [
        f"{family}_{variant}_{distance}"
        for family in ["MAG", "SPREAD"]
        for distance in training_distances
        for variant in approaches
    ]

    return baseline_names + composite_names

def json_to_df(json_file, key="results"):
    """
    Load one top-level entry from a JSON file.

    :param json_file: path to the JSON file
    :param key: top-level key to read from the parsed JSON document
    :return: ``None`` if the file does not exist; a DataFrame built from the
        entry when ``key`` is ``"results"``; otherwise the entry exactly as
        parsed from the JSON document
    :raises json.JSONDecodeError: if the file content is not valid JSON
    :raises KeyError: if ``key`` is not present in the document
    """
    # Open first and handle the failure instead of checking for existence
    # beforehand: this avoids a race if the file disappears between the check
    # and the open. JSON text is UTF-8, so do not rely on the platform default.
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return None

    entry = data[key]
    if key == "results":
        return pd.DataFrame(entry)
    return entry

if __name__ == "__main__":
    # For every (model, dataset) pair whose directory exists under --path,
    # gather the per-method result files and write four summary files into the
    # dataset directory: per-method column means, per-method column standard
    # deviations, run-time / modification-time info, and the raw accuracies.
    path = parser.parse_args().path
    methods = get_method_names()
    datasets = ["MUTAG", "COX2", "ENZYMES", "PROTEINS", "Mutagenicity", "AIDS", "IMDB-BINARY", "IMDB-MULTI", "NCI1", "NCI109",
                "BZR", "DHFR", "ogbg-molhiv", "BZR_MD", "COX2_MD", "DHFR_MD", "ER_MD", "PROTEINS_full"]

    for model in ["GIN", "GNN"]:
        for dataset in datasets:
            dataset_path = f"{path}/{dataset}"
            if not os.path.exists(dataset_path):
                continue

            methods_found = []
            dfs_means = []
            dfs_stds = []
            all_acc = {}
            all_runtimes = {}
            acc_json = f"{dataset_path}/{dataset}_{model}_accuracies.json"

            for method in methods:
                json_path = f"{dataset_path}/{model}_{method}_{dataset}_stratified.json"
                try:
                    df = json_to_df(json_path)
                except json.decoder.JSONDecodeError:
                    print(f"Error decoding JSON for {json_path}.")
                    continue
                except KeyError:
                    # A result file without a "results" entry is skipped like
                    # any other unreadable file rather than aborting the run.
                    print(f"No 'results' entry in {json_path}.")
                    continue

                # json_to_df returns None when the file does not exist; an
                # empty frame carries nothing to summarise.
                if df is None or df.empty:
                    continue

                # A method without usable run-time information is left out of
                # all summaries, so read it before recording anything.
                try:
                    timet = json_to_df(json_path, key="time_per_run")
                    mod_time = os.path.getmtime(json_path)
                    mod_time_st = time.ctime(mod_time)
                except (KeyError, OSError, ValueError) as err:
                    print(f"Skipping {method} for {dataset}: no run-time info ({err!r}).")
                    continue

                all_runtimes[method] = [timet, mod_time, mod_time_st]
                dfs_means.append(df.mean(axis=0))
                dfs_stds.append(df.std(axis=0))
                methods_found.append(method)
                # tolist() yields native Python numbers, which json.dump can
                # always serialise (numpy integer scalars cannot be).
                all_acc[method] = df["accuracy"].tolist()

            if not dfs_means:
                print(f"No dataframes found for {dataset}.")
                continue

            # One column per method, then transpose so each method is a row.
            df_means = pd.concat(dfs_means, axis=1)
            df_stds = pd.concat(dfs_stds, axis=1)
            df_means.columns = methods_found
            df_stds.columns = methods_found
            df_means = df_means.T
            df_stds = df_stds.T
            df_means.to_csv(f"{dataset_path}/{dataset}_{model}_means.csv")
            df_stds.to_csv(f"{dataset_path}/{dataset}_{model}_stds.csv")

            runtimes = pd.DataFrame(all_runtimes).T
            runtimes.columns = ["time_per_run", "mod_time", "mod_time_human"]
            runtimes.to_csv(f"{dataset_path}/{dataset}_{model}_runtimes.csv")

            # Save accuracies to JSON
            with open(acc_json, "w", encoding="utf-8") as f:
                json.dump(all_acc, f)
                