import pickle
import os
import numpy as np
import torch
import networkx as nx
from baselines.DiGress.src.analysis.spectre_utils import degree_stats, clustering_stats, orbit_stats_all


# Maximum number of graphs taken from each set of generated graphs.
_MAX_GRAPHS = 500
# The real/test graphs are only truncated (to _MAX_GRAPHS) when there are more
# than this many of them.
# NOTE(review): a test set with 501-1000 graphs is therefore used in full while
# the generated sets are capped at 500 - confirm this asymmetry is intended.
_REAL_GRAPHS_TRUNCATION_THRESHOLD = 1000


def _load_pickled_graphs(path: str):
    """Return the object unpickled from the file at ``path``."""
    # NOTE: pickle can execute arbitrary code while loading; only use this on
    # files produced by trusted runs.
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_digress_graphs(path: str, limit: int) -> list:
    """
    Build networkx graphs from the first ``limit`` arrays of an ``.npz`` archive.

    Each stored array is passed to ``nx.from_numpy_array``. The archive is
    opened in a ``with`` block so its file handle is closed once the graphs
    have been built, and only the entries that are actually kept are converted.
    """
    with np.load(path) as archive:
        return [nx.from_numpy_array(archive[key]) for key in archive.files[:limit]]


def _write_statistics(fl, model_name: str, real_graphs: list, generated_graphs: list) -> None:
    """
    Compute degree, orbit and clustering statistics and write them to ``fl``.

    ``fl`` is a writable text file object. One line is written per statistic,
    prefixed with ``model_name``, followed by an empty line.
    """
    degree = degree_stats(real_graphs, generated_graphs, is_parallel=True, compute_emd=True)
    orbit = orbit_stats_all(real_graphs, generated_graphs, compute_emd=True)
    clustering = clustering_stats(real_graphs, generated_graphs, bins=100, is_parallel=True,
                                  compute_emd=True)

    results = (
        ('degree', degree),
        ('orbit', orbit),
        ('clustering coefficient', clustering),
    )
    for criterion, value in results:
        fl.write(f'{model_name} MMD to real graphs based on {criterion}:{value!s}\n')
    fl.write('\n')


def main() -> None:
    """
    Compare graphs generated by GRAN, GraphRNN and Digress with the real test graphs.

    Changes the working directory to ``./arrow_diff``, then, for every dataset
    in the hard-coded list, loads the real test graphs and the graphs generated
    by each model, and writes the graph counts and the degree / orbit /
    clustering statistics to ``multigraph_eval_MMD_emdDistance_comm20.csv``.
    """
    os.chdir("./arrow_diff")

    fname_output = 'multigraph_eval_MMD_emdDistance_comm20.csv'

    with open(fname_output, 'w') as fl:
        datasets = ["comm20"]

        for dataset in datasets:
            fl.write(f'Evaluation on dataset: {dataset}\n')

            # Real graphs: the saved test split of the dataset.
            real_graphs_list = _load_pickled_graphs(
                f'./GraphRNN_output_new/graphs/GraphRNN_RNN_{dataset}_4_64_test_0.dat')
            fl.write(f'Number of real/test graphs:{dataset},{len(real_graphs_list)}\n')
            if len(real_graphs_list) > _REAL_GRAPHS_TRUNCATION_THRESHOLD:
                real_graphs_list = real_graphs_list[:_MAX_GRAPHS]
            fl.write(f'Number of real/test graphs to be considered:{dataset},{len(real_graphs_list)}\n')

            # Graphs generated by GRAN.
            # NOTE(review): recent PyTorch releases default torch.load to
            # weights_only=True, which may reject pickled graph objects -
            # confirm against the installed version.
            gran_gen_graphs_list = torch.load(f'./GRAN_output_{dataset}/generated_graphs.pt')[:_MAX_GRAPHS]
            fl.write(f'Number of GRAN generated graphs:{dataset},{len(gran_gen_graphs_list)}\n')

            # Graphs generated by GraphRNN.
            graphRNN_gen_graphs_list = _load_pickled_graphs(
                f'./GraphRNN_output_new/graphs/GraphRNN_RNN_{dataset}_4_64_pred_3000_1.dat')[:_MAX_GRAPHS]
            fl.write(f'Number of GraphRNN generated graphs:{dataset},{len(graphRNN_gen_graphs_list)}\n')

            # Graphs generated by Digress, stored as adjacency matrices.
            # The run directory is dataset specific:
            #   citeseer_small: ./output_Digress_{dataset}_epochs100k/2023-08-26/16-25-36-graph-tf-model/generated_adjs.npz
            #   comm20:         the path used below
            digress_gen_graphs_list = _load_digress_graphs(
                f'./output_Digress_{dataset}_epochs100k/2023-09-10/22-04-38-graph-tf-model/generated_adjs.npz',
                _MAX_GRAPHS)
            fl.write(f'Number of Digress generated graphs:{dataset},{len(digress_gen_graphs_list)}\n')

            # Compute and report the statistics model by model, so that the
            # results of a finished model are written before the next one starts.
            generated_by_model = (
                ('GRAN', gran_gen_graphs_list),
                ('GraphRNN', graphRNN_gen_graphs_list),
                ('Digress', digress_gen_graphs_list),
            )
            for model_name, generated_graphs in generated_by_model:
                _write_statistics(fl, model_name, real_graphs_list, generated_graphs)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
