# Evaluate generated graphs against reference graphs with GNN-based metrics (F1 PR and MMD-RBF) and write one CSV report per data split.
import utils.graph_generators as gen
import pickle
import torch
import dgl
import tensorflow as tf
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

def selfLoop(graph_list):
    """Add a self-loop ``(v, v)`` on every node of every graph, in place.

    Args:
        graph_list: list of graphs exposing ``nodes`` and ``add_edges_from``
            (networkx graphs in this script).

    Returns:
        The same ``graph_list`` object, for call chaining.
    """
    for graph in graph_list:
        # Use the graph's actual node labels rather than range(n): graphs
        # produced by load_graph_list keep only the largest connected
        # component, so their labels are generally NOT 0..n-1. Iterating
        # range(n) would add loops on (and thereby create) non-existent
        # nodes while missing real ones.
        graph.add_edges_from([(v, v) for v in graph.nodes])
    return graph_list

def load_graphs(graph_pkl):
    """Read every object pickled back-to-back in *graph_pkl*.

    The file may hold zero or more consecutive pickle records (e.g. written by
    repeated ``dump`` calls); they are returned as a list in file order.
    Reading stops at end of file. A truncated/corrupt trailing record also
    stops reading (with a message), keeping the records read so far.

    Security: unpickling can execute arbitrary code — only use on trusted files.
    """
    try:
        import pickle5 as cp  # backport of pickle protocol 5 for Python < 3.8
    except ImportError:
        import pickle as cp  # stdlib supports protocol 5 natively on 3.8+

    graphs = []
    with open(graph_pkl, 'rb') as f:
        while True:
            try:
                g = cp.load(f)
            except EOFError:
                break  # clean end of the record stream
            except cp.UnpicklingError as e:
                # Best-effort: a damaged tail should not discard good records.
                print(f"stopped reading {graph_pkl}: {e!r}")
                break
            graphs.append(g)

    return graphs


def load_graph_list(fname, remove_self=True):
    """Load graphs from *fname*, each reduced to its largest connected component.

    Args:
        fname: path (str or path-like) to either a file ending in ``pkl``,
            read with ``load_graphs`` (back-to-back pickle records), or any
            other file, read with ``np.load(..., allow_pickle=True)``.
        remove_self: if True, drop self-loop edges before processing.

    Each record may be:
      * a list ``[nodes, edges]`` — skipped when ``nodes`` is empty;
      * an ``nx.Graph`` (exact type) — used as-is and modified in place;
      * anything ``nx.from_numpy_array`` accepts (presumably a dense
        adjacency matrix — confirm against the files being loaded).

    Isolated nodes are removed, graphs left empty are skipped, and only the
    largest connected component is kept. Records that fail to convert are
    reported and skipped.

    Returns:
        list of ``nx.Graph`` (independent copies, not subgraph views).
    """
    if str(fname).endswith("pkl"):
        glist = load_graphs(fname)
    else:
        with open(fname, "rb") as f:
            glist = np.load(f, allow_pickle=True)

    graph_list = []
    for G in glist:
        try:
            if isinstance(G, list):
                if len(G[0]) == 0:
                    continue
                graph = nx.Graph()
                graph.add_nodes_from(G[0])
                graph.add_edges_from(G[1])
            elif type(G) is nx.Graph:  # exact type: directed graphs are not handled
                graph = G
            else:
                # from_numpy_matrix was removed in networkx 3.0;
                # from_numpy_array is its replacement and accepts ndarrays.
                graph = nx.from_numpy_array(G)
            if remove_self:
                # Materialize the generator before mutating the graph.
                graph.remove_edges_from(list(nx.selfloop_edges(graph)))
            graph.remove_nodes_from(list(nx.isolates(graph)))
            if graph.number_of_nodes() == 0:
                # e.g. an all-zero adjacency matrix: nothing to keep.
                continue
            # max() returns the first largest component, matching a stable
            # descending sort's first element.
            largest = max(nx.connected_components(graph), key=len)
            graph_list.append(nx.Graph(graph.subgraph(largest)))
        except Exception as e:
            print(f"could not read a graph: {e!r}")
    return graph_list


import csv
import glob

# Directory holding one training run's reference ("target") dumps and the
# generated-graph dumps. Named base_dir to avoid shadowing the builtin dir().
base_dir = "/local-scratch/kiarash/AAAI/TrainningCurveEvalof_LDVAE/Mutag/mutah_200/MMD_AvePool_FC_MUTAG_graphGeneration_PMIBFSTrue400001675894280.7566695/"
# One reference file per split; index-aligned with generated_pattern below.
ref_dir = [base_dir + x for x in [
    "_Target_Graphs_adj_val_2499.npy",
    "_Target_Graphs_adj_train_17499.npy",
    "_Target_Graphs_adj_test_2499.npy",
]]
# Glob patterns (relative to base_dir) for the generated graphs of each split.
generated_pattern = [
    "Max_Con_comp_generatedGraphs_adj_val_*",
    "Max_Con_comp_generatedGraphs_adj_train_*",
    "Max_Con_comp_generatedGraphs_adj_test_*",
]

MAX_GRAPHS = 1000  # cap on the number of graphs taken from each file
N_EVAL_RUNS = 10   # repeated evaluations; mean/std are reported over these
CSV_HEADER = ["file", "f1_std", "f1_mean", "mmd_rbf_std", "mmd_rbf_mean"]

device = torch.device('cpu')


def _evaluate_repeatedly(generated, reference, n_runs=N_EVAL_RUNS):
    """Run a fresh Evaluator n_runs times; return (f1_pr list, mmd_rbf list).

    A new Evaluator is built per run (the "Rand_GNN" CSV suffix suggests its
    GNN is randomly initialised each time — confirm). Failed runs are
    reported and skipped. Both lists always have equal length.
    """
    from evaluation.evaluator import Evaluator

    f1, mmd_rbf = [], []
    for _ in range(n_runs):
        try:
            evaluator = Evaluator(device=device)
            result = evaluator.evaluate_all(generated, reference)
            # Read both metrics before appending so a missing key cannot
            # leave the two lists out of sync.
            f1_value, mmd_value = result["f1_pr"], result["mmd_rbf"]
        except Exception as e:
            print(f"evaluation run failed: {e!r}")
            continue
        f1.append(f1_value)
        mmd_rbf.append(mmd_value)
    return f1, mmd_rbf


res = []  # result rows from every split, in processing order
for ref_path, pattern in zip(ref_dir, generated_pattern):
    saved_adjs_paths = glob.glob(base_dir + pattern)
    fold_rows = []  # rows for this split only; written to this split's CSV
    reference = None  # loaded lazily, once per split (same for every file)
    for file in saved_adjs_paths:
        generated = load_graph_list(file)[:MAX_GRAPHS]
        if len(generated) < 2:
            fold_rows.append([file, "Error"])
            continue
        if reference is None:
            reference = [dgl.DGLGraph(g).to(device)  # NetworkX -> DGL
                         for g in selfLoop(load_graph_list(ref_path)[:MAX_GRAPHS])]
        print(file)
        generated = [dgl.DGLGraph(g).to(device)  # NetworkX -> DGL
                     for g in selfLoop(generated)]

        f1, mmd_rbf = _evaluate_repeatedly(generated, reference)
        if len(f1) > 1:  # need at least two runs for a meaningful std
            row = [file, np.std(f1), np.mean(f1), np.std(mmd_rbf), np.mean(mmd_rbf)]
            fold_rows.append(row)
            print(row)
        else:
            fold_rows.append([file, "Error"])

    res.extend(fold_rows)

    # newline="" is required by the csv module to avoid blank lines on Windows.
    with open(ref_path + "__Rand_GNN_Eval_all.csv", "w", newline="") as fp:
        writer = csv.writer(fp, delimiter=",")
        writer.writerow(CSV_HEADER)
        writer.writerows(fold_rows)