import pickle
from collections import defaultdict


def read_graph(data_path):
    """Load and return the object pickled in the file at ``data_path``.

    The file is opened in binary mode and deserialized with ``pickle.load``;
    whatever object was stored is returned unchanged.

    NOTE: unpickling can execute arbitrary code, so only point this at
    files from a trusted source.
    """
    with open(data_path, 'rb') as handle:
        return pickle.load(handle)

def data_stat(graph_set):
    """Print the graph count, total nodes/edges and per-graph averages.

    Args:
        graph_set: Iterable of graph objects, each exposing
            ``number_of_nodes()`` and ``number_of_edges()`` (presumably
            networkx graphs -- confirm against the pickled data). It is
            traversed exactly once, so a generator is acceptable.

    Returns:
        None. The statistics are written to stdout as a single line.
    """
    graph_num = 0
    total_node = 0
    total_edge = 0
    for graph in graph_set:
        total_node += graph.number_of_nodes()
        total_edge += graph.number_of_edges()
        graph_num += 1

    # An empty collection would otherwise raise ZeroDivisionError; report
    # averages of 0.0 so the summary line is still printed.
    avg_node = total_node / graph_num if graph_num else 0.0
    avg_edge = total_edge / graph_num if graph_num else 0.0

    print('Total graph Number:{}, total node:{}, total edge:{}, avg node:{}, avg edge:{}'.
          format(graph_num, total_node, total_edge, avg_node, avg_edge))




if __name__ == '__main__':
    # Load the pickled graph collection and print its aggregate statistics.
    data_path = '/Users/haonan/Downloads/github/DP-generative-model/DP-graphSAGE-VAE/data/new_IMDB_MULTI.pickle'
    graph_set = read_graph(data_path)
    data_stat(graph_set)

    # Count how many graphs carry each label, and keep the labels in
    # dataset order in label_list.
    label_dict = defaultdict(int)
    label_list = []
    for g in graph_set:
        print(g.number_of_nodes())
        # The label is read from the graph-level attribute dict.
        label = g.graph['label']
        print(label)
        label_dict[label] += 1
        label_list.append(label)
    print(label_dict)

    # generate_graph_path = '/Users/haonan/Downloads/github/DP-generative-model/DPgraphGen_cleaned/data/imdb_new_generated_netgan_nx2.pickle'
    # generate_graph_set = read_graph(generate_graph_path)
    # for id, graph in enumerate(generate_graph_set):
    #     graph.graph['label'] = int(label_list[id])
    #
    # with open(generate_graph_path, 'wb') as tf:
    #     pickle.dump(generate_graph_set, tf)
    #
    #
    # generate_graph_path = '/Users/haonan/Downloads/github/DP-generative-model/DPgraphGen_cleaned/data/imdb_new_generated_graphrnn_nx2.pickle'
    # generate_graph_set = read_graph(generate_graph_path)
    # for id, graph in enumerate(generate_graph_set):
    #     graph.graph['label'] = int(label_list[id])
    #
    # with open(generate_graph_path, 'wb') as tf:
    #     pickle.dump(generate_graph_set, tf)