# Evaluate perturbed/generated graphs against a reference test set using GNN-based metrics
import random

import utils.graph_generators as gen
import pickle
import torch
import dgl
# import tensorflow as tf
import numpy as np
import os
import networkx as nx
import matplotlib.pyplot as plt

def preprocess_graphs(list_of_NXGraph):
    """Clean a list of NetworkX graphs and keep each one's largest component.

    Side effect: every input graph is modified in place -- its self-loop edges
    are removed first, then its isolated nodes.

    Graphs left without any edge are dropped. For every remaining graph the
    largest connected component is copied into a fresh ``nx.Graph``. The list
    of copies is then passed through ``selfLoop`` and returned.
    """
    for graph in list_of_NXGraph:
        graph.remove_edges_from(list(nx.selfloop_edges(graph)))
        graph.remove_nodes_from(list(nx.isolates(graph)))

    components = []
    for graph in list_of_NXGraph:
        if nx.is_empty(graph):
            # no edges remain (and therefore no nodes) -- nothing to keep
            continue
        biggest = max(nx.connected_components(graph), key=len)
        components.append(nx.Graph(graph.subgraph(biggest)))

    return selfLoop(components)

def selfLoop(graph_list):
    """Add a self-loop ``(v, v)`` to every node of every graph, in place.

    Args:
        graph_list: iterable of NetworkX-style graphs (anything exposing
            ``nodes()`` and ``add_edges_from``).

    Returns:
        The same ``graph_list`` object, for call chaining.
    """
    for graph in graph_list:
        # Use the graph's real node labels. Callers pass connected-component
        # copies whose labels are not guaranteed to be 0..n-1; iterating
        # range(n) would create brand-new, spurious nodes instead.
        # The list is materialised first so the node view is not iterated
        # while edges are being added.
        graph.add_edges_from([(node, node) for node in graph.nodes()])
    return graph_list



def load_graphs(graph_pkl):
    """Read every object pickled back-to-back in ``graph_pkl`` and return them.

    Objects are loaded with repeated ``load`` calls until end of file. The
    ``pickle5`` backport is used when installed; otherwise the standard
    ``pickle`` module is used (it supports protocol 5 natively on
    Python 3.8+).

    Reading is best-effort. If a record fails to unpickle, reading stops, a
    warning is printed, and the objects read so far are returned.

    SECURITY: unpickling can execute arbitrary code -- only call this on
    trusted files.

    Args:
        graph_pkl: path to the pickle file.

    Returns:
        list of the unpickled objects, in file order.
    """
    try:
        import pickle5 as cp
    except ImportError:
        import pickle as cp

    graphs = []
    with open(graph_pkl, 'rb') as f:
        while True:
            try:
                g = cp.load(f)
            except EOFError:
                break  # normal termination: no more records
            except Exception as err:
                # Corrupt or truncated trailing record: keep what was read,
                # but do not hide the problem.
                print("stopped reading {}: {}".format(graph_pkl, err))
                break
            graphs.append(g)

    return graphs

def substructure_features(NXGraph):
    """Return the per-node clustering coefficients of ``NXGraph``.

    Thin wrapper around ``networkx.clustering``.
    """
    coefficients = nx.clustering(NXGraph)
    return coefficients

def toDGL(NXgraph, Structuralinf=True):
    """Convert a NetworkX graph to a DGL graph, optionally with structural features.

    Args:
        NXgraph: the NetworkX graph to convert.
        Structuralinf: if True, every node gets a 3-element float64 feature
            vector stored under the node attribute ``"attr"``:
            ``[degree, clustering coefficient, square clustering]``.
            NetworkX counts a self-loop twice in ``degree``.
            NOTE: the ``"attr"`` attribute is written onto ``NXgraph`` itself.
            If False, the graph is converted without node features.

    Returns:
        The DGL graph (with ``ndata["attr"]`` when ``Structuralinf`` is True).
    """
    if not Structuralinf:
        return dgl.DGLGraph(NXgraph)

    # Only the graph that is actually returned is built. The original also
    # constructed a full DGLGraph up front and then discarded it.
    cluster_attri = nx.clustering(NXgraph)
    degree_attri = dict(NXgraph.degree())
    square_clust = nx.square_clustering(NXgraph)
    attri = {
        node: np.array(
            [degree, cluster_attri[node], square_clust[node]], dtype=np.float64
        )
        for node, degree in degree_attri.items()
    }

    nx.set_node_attributes(NXgraph, attri, "attr")
    return dgl.from_networkx(NXgraph, node_attrs=["attr"])

def load_graph_list(fname,remove_self=True):
    """Load graphs from ``fname`` and reduce each to its largest connected component.

    Files whose name ends in ``pkl`` are read with ``load_graphs``. Anything
    else is read with ``np.load(..., allow_pickle=True)``.

    Each stored item may be:
      * a list ``[nodes, edges]`` (skipped when ``nodes`` is empty),
      * a NetworkX graph (the loaded object itself is modified),
      * a SciPy sparse matrix in CSR format (adjacency matrix),
      * anything else accepted by ``nx.from_numpy_array`` (e.g. a dense
        NumPy adjacency matrix).

    SECURITY: both loaders may unpickle data -- only use on trusted files.

    Args:
        fname: path to the graph file.
        remove_self: if True, self-loop edges are removed before processing.

    Returns:
        list of ``nx.Graph``. Each entry is the largest connected component
        of one input graph after its isolated nodes were removed. Items that
        cannot be converted are reported and skipped. Graphs left with no
        nodes are skipped silently.
    """
    if fname.endswith("pkl"):
        glist = load_graphs(fname)
    else:
        with open(fname, "rb") as f:
            glist = np.load(f, allow_pickle=True)

    graph_list = []
    for G in glist:
        try:
            if isinstance(G, list):
                if len(G[0]) == 0:
                    continue
                graph = nx.Graph()
                graph.add_nodes_from(G[0])
                graph.add_edges_from(G[1])
            elif isinstance(G, nx.Graph):
                graph = G
            elif getattr(G, "format", None) == "csr":
                # Check the *instance* attribute. The original read
                # `type(G).format`, which raised AttributeError for NumPy arrays
                # and silently dropped every dense adjacency matrix.
                # from_scipy_sparse_matrix was removed in NetworkX 3.0.
                from_sparse = (getattr(nx, "from_scipy_sparse_array", None)
                               or nx.from_scipy_sparse_matrix)
                graph = from_sparse(G)
            else:
                # from_numpy_matrix was removed in NetworkX 3.0;
                # from_numpy_array exists since 2.0.
                graph = nx.from_numpy_array(G)

            if remove_self:
                graph.remove_edges_from(list(nx.selfloop_edges(graph)))
            graph.remove_nodes_from(list(nx.isolates(graph)))
            if graph.number_of_nodes() == 0:
                continue  # nothing left; not a read error

            # max() returns the first largest component, matching the stable
            # sorted(..., reverse=True)[0] it replaces.
            largest = max(nx.connected_components(graph), key=len)
            graph_list.append(nx.Graph(graph.subgraph(largest)))
        except Exception as err:
            print("could not read a graph: {}".format(err))
    return graph_list

def read_MMD(dir, line=-1):
    """Return one line (default: the last) of the text file at ``dir``.

    The returned line keeps its trailing newline, if it had one. The call is
    best-effort: any failure (missing file, index out of range, decoding
    error, ...) yields the placeholder string "Could not read".
    """
    try:
        with open(dir) as handle:
            return handle.readlines()[line]
    except Exception:
        return "Could not read"

def graphPloter(listOfNXgraphs, dir):
    """Save a drawing of each graph via ``plotG``.

    Graph number ``k`` (0-based) is written to the path ``dir + "_" + str(k)``.
    """
    for position, graph in enumerate(listOfNXgraphs):
        target = "_".join((dir, str(position)))
        plotG(graph, file_name=target)

def plotG(G, type="", graph_name="Generated Graph", file_name=None):
    """Draw ``G`` with a spring layout and save the figure to disk.

    Args:
        G: NetworkX graph to draw.
        type: filename prefix, used only when ``file_name`` is None; the image
            is then saved as ``type + "_graph.png"``.
        graph_name: matplotlib figure label. Any open figure with this label
            is closed first, so each call draws on a fresh figure.
        file_name: explicit output path; takes precedence over ``type``.
    """
    plt.close(graph_name)

    layout = nx.spring_layout(G, iterations=1000)
    figure = plt.figure(graph_name)
    nx.draw(G, layout, node_size=20, width=1, edge_color="black")

    if file_name is None:
        target = type + "_graph.png"
    else:
        target = file_name
    figure.savefig(target)



# ---------------------------------------------------------------------------
# Evaluation script (runs at import time).
#
# For every entry of `dir` whose path contains `pattern`, the perturbed graphs
# are compared with the reference test graphs using the project's `Evaluator`.
# This is repeated 10 times, each with a freshly constructed Evaluator. The
# mean/std of f1_pr, mmd_rbf, precision and recall are collected in `reprot`
# and written to a CSV file inside `dir`.
# ---------------------------------------------------------------------------

# Ideal
# dir = "/local-scratch/kiarash/BiGG/google-research-0c1bbe5fc971a1de1a427debc814e66ab4f1e7fa/bigg/data/LDP/IMDBBINARY/"
dir = "/local-scratch/kiarash/AAAI/LDPVAE/data/LDP/DD/"
perturbed_file_name = "LDP_train.npy"  # generated file pattern
Test_file_name = "DD-LDP_Test.npy"
# # BiGG
# dir = "/local-scratch/kiarash/BiGG/google-research-0c1bbe5fc971a1de1a427debc814e66ab4f1e7fa/bigg/data/LDP/ogbg-molbbbp/"
# # dir = "/local-scratch/kiarash/AAAI/LDPVAE/data/LDP/PTC/"
# perturbed_file_name = "generated.npy" # generated file pattern
# Test_file_name = "ogbg-molbbbp-LDP_Test.npy"

MMD_fileName = "/MMD.log"  # only used by the commented-out read_MMD call below
pattern = perturbed_file_name

# True: 3-dim structural node features via toDGL; False: plain DGL graphs.
Structural_Feature_attri = True
import glob

# NOTE: recursive=True has no effect without "**" in the pattern; this lists
# only the direct children of `dir`.
sub_dirs = glob.glob(dir + '*', recursive=True)
print(sub_dirs)
reprot = [["path", "f1_mean", "f1_std", "mmd_rbf_mean", "mmd_rbf_std",
           "precision_mean", "precision_std", "recall_mean", "recall_std",
           "Statistics Report"]]
for file in sub_dirs:
    if pattern not in file:
        continue
    try:
        perturbed_train_data = load_graph_list(file)
        refrence = load_graph_list(dir + Test_file_name)  # reference set (original note: max 1000 graphs)

        print(file)
        # Random (unseeded) subsample of the perturbed set, sized to the
        # raw reference set.
        random.shuffle(perturbed_train_data)
        perturbed_train_data = perturbed_train_data[:len(refrence)]
        perturbed_train_data = preprocess_graphs(perturbed_train_data)

        refrence = preprocess_graphs(refrence)

        device = torch.device('cpu')
        # graphPloter(perturbed_train_data[:20], file + "_generated_sample")
        if Structural_Feature_attri:
            perturbed_train_data = [toDGL(g) for g in perturbed_train_data]
            refrence = [toDGL(g) for g in refrence]
        else:
            # Convert graphs to DGL from NetworkX (no node features).
            perturbed_train_data = [dgl.DGLGraph(g).to(device) for g in perturbed_train_data]
            refrence = [dgl.DGLGraph(g).to(device) for g in refrence]

        # Compute all GNN-based metrics at once
        from evaluation.evaluator import Evaluator

        f1 = []
        mmd_rbf = []
        recall = []
        precision = []
        for _ in range(10):
            # A new Evaluator per repetition -- presumably a freshly initialised
            # random GNN each time (hence "RandomGNN" in the output name); verify.
            evaluator = Evaluator(input_dim=3, device=device)
            result = evaluator.evaluate_all(perturbed_train_data, refrence)
            f1.append(result["f1_pr"])
            recall.append(result["recall"])
            precision.append(result["precision"])
            mmd_rbf.append(result["mmd_rbf"])

        # ndarray.std() is the population standard deviation (ddof=0).
        f1_std = np.array(f1).std()
        f1_mean = np.array(f1).mean()

        mmd_rbf_std = np.array(mmd_rbf).std()
        mmd_rbf_mean = np.array(mmd_rbf).mean()
        # Stats = read_MMD(path+MMD_fileName)  # NOTE(review): `path` is undefined here
        Stats = ""
        recall_mean = np.array(recall).mean()
        recall_std = np.array(recall).std()
        precision_mean = np.array(precision).mean()
        precision_std = np.array(precision).std()
        res = [file, f1_mean, f1_std, mmd_rbf_mean, mmd_rbf_std,
               precision_mean, precision_std, recall_mean, recall_std, Stats]
        print(res)
        reprot.append(res)
    except Exception as e:
        reprot.append([file, "Error"])
        print(e)
import csv

if Structural_Feature_attri:
    dir = dir + "RandomGNN_" + "With_Structural_Feature_attri_orbit" + "_"
else:
    # Append, don't prepend: "RandomGNN_" + "/abs/path/" produced a relative
    # path into a nonexistent directory, so the CSV could not be written.
    dir = dir + "RandomGNN_"

# newline="" is required by the csv module to avoid blank rows on Windows.
with open(dir + ".csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(reprot)
