import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import json
import pandas as pd
import argparse

# Command-line interface: the only option is the root results directory.
# Expected layout (see the __main__ section):
#   <path>/<dataset>/<method>_<dataset>_experiment.json
parser = argparse.ArgumentParser(
    description="Aggregate per-method experiment results into mean/std CSV files."
)
parser.add_argument(
    "--path",
    type=str,
    default=".",
    help="Root directory containing one subdirectory per dataset "
    "(default: current directory).",
)

def get_method_names():
    """Return the raw method identifiers used in the experiment file names.

    The list starts with the baseline pooling methods, followed by one
    ``{variant}_{approach}_{distance}`` entry per combination of variant
    (``MAG``/``SPREAD``), training distance and approach.

    :return: new list of method identifier strings
    """
    baselines = ["DiffPool", "MinCut", "NMF", "TopK", "SAGPool", "NDP", "Graclus"]
    training_distances = ["diffusion_distance"]
    approaches = ["EDGE"]

    # Clause order (variant -> distance -> approach) fixes the output order.
    magnitude_variants = [
        f"{variant}_{approach}_{distance}"
        for variant in ["MAG", "SPREAD"]
        for distance in training_distances
        for approach in approaches
    ]
    return baselines + magnitude_variants

def get_method_names_clean():
    """Return human-readable method labels for tables and plots.

    The script assigns these labels positionally to the columns built from
    ``get_method_names()``. Their length and order must therefore mirror
    that function. The training distance is iterated but does not appear in
    the label.

    :return: new list of display-name strings
    """
    labels = ["DiffPool", "MinCut", "NMF", "TopK", "SAGPool", "NDP", "Graclus"]
    training_distances = ["diffusion_distance"]
    pool_suffixes = ["EdgePool"]

    for prefix in ["Mag", "Spread"]:
        # One label per distance, so that the count matches get_method_names.
        for _distance in training_distances:
            labels.extend(prefix + suffix for suffix in pool_suffixes)
    return labels

def json_to_df(json_file):
    """
    Convert a JSON experiment file to a DataFrame.

    The file's top-level value must be a mapping with a ``"results"`` entry.
    That entry is passed unchanged to ``pd.DataFrame``.

    :param json_file: path to the JSON file
    :return: DataFrame built from ``data["results"]``
    :raises KeyError: if the top-level mapping has no ``"results"`` entry
    """
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    return pd.DataFrame(data["results"])

def _summarize_dataset(path, dataset, methods):
    """Load every method's results for one dataset and aggregate them.

    :param path: root directory holding one subdirectory per dataset
    :param dataset: dataset name (subdirectory and file-name component)
    :param methods: raw method identifiers used in the JSON file names
    :return: ``(means, stds)`` DataFrames with one column per method, in the
        same order as ``methods``
    """
    means = []
    stds = []
    for method in methods:
        json_path = f"{path}/{dataset}/{method}_{dataset}_experiment.json"
        df = json_to_df(json_path)
        # NOTE(review): on pandas >= 2.0, mean/std raise TypeError for
        # non-numeric columns; if "results" can hold strings, consider
        # numeric_only=True -- confirm against the experiment output.
        means.append(df.mean(axis=0))
        stds.append(df.std(axis=0))
    return pd.concat(means, axis=1), pd.concat(stds, axis=1)


def main():
    """Write ``<dataset>_means.csv`` and ``<dataset>_stds.csv`` per dataset.

    Each CSV has one row per method and one column per field found in that
    method's ``"results"`` records.
    """
    path = parser.parse_args().path
    datasets = [
        "barbell", "torus", "erdosrenyi", "davidsensornet", "barabasialbert",
        "community", "Ring", "Sensor", "Grid2d",
    ]
    rename = True

    # Both lists are the same for every dataset, so build them once.
    methods = get_method_names()
    labels = get_method_names_clean() if rename else methods
    # Labels are paired with methods purely by position; fail fast (before
    # reading any file) instead of mislabelling or failing mid-run.
    if len(labels) != len(methods):
        raise ValueError(
            f"{len(labels)} display labels for {len(methods)} methods; "
            "get_method_names and get_method_names_clean must stay aligned"
        )

    for dataset in datasets:
        df_means, df_stds = _summarize_dataset(path, dataset, methods)
        df_means.columns = labels
        df_stds.columns = labels
        # Transpose so that methods become rows.
        df_means.T.to_csv(f"{path}/{dataset}/{dataset}_means.csv")
        df_stds.T.to_csv(f"{path}/{dataset}/{dataset}_stds.csv")


if __name__ == "__main__":
    main()