import os
import pickle
import random
from collections import Counter, defaultdict

import networkx as nx
# Directory that contains this script (absolute, with symlinks resolved).
# The dataset paths used below are built relative to it.
dir_path = os.path.split(os.path.realpath(__file__))[0]
# Parent directory of ``dir_path``; not referenced in the code visible here.
root_path = os.path.abspath(os.path.join(dir_path, os.pardir))

def read_graph(data_path=None):
    """Load the pickled IMDB-MULTI graph list from disk.

    Args:
        data_path: Path of the pickle file to read. When omitted, defaults to
            ``<script dir>/dataset/IMDBMULTI/IMDBMULTI.pkl``.

    Returns:
        The unpickled object. The ``__main__`` block treats it as an
        iterable of graphs whose ``graph`` dict carries a ``'label'`` key.

    Note:
        ``pickle.load`` can execute arbitrary code, so only point this at
        trusted, locally produced dataset files.
    """
    if data_path is None:
        # os.path.join instead of '+ "/..."' so the separator is platform-correct.
        data_path = os.path.join(dir_path, 'dataset', 'IMDBMULTI', 'IMDBMULTI.pkl')
    with open(data_path, 'rb') as file:
        return pickle.load(file)

def save_graph(data_list, save_path=None):
    """Pickle ``data_list`` to disk and print a completion message.

    Args:
        data_list: Object to serialize (the resampled graph list).
        save_path: Destination file. When omitted, defaults to
            ``<script dir>/dataset/Resampled_IMDBMULTI/Resampled_IMDBMULTI.pkl``.

    Any missing parent directories of the destination are created first.
    Otherwise ``open(..., 'wb')`` raises ``FileNotFoundError`` when the
    output directory does not exist yet.
    """
    if save_path is None:
        save_path = os.path.join(
            dir_path, 'dataset', 'Resampled_IMDBMULTI', 'Resampled_IMDBMULTI.pkl'
        )
    parent = os.path.dirname(save_path)
    if parent:  # a bare filename has no directory component to create
        os.makedirs(parent, exist_ok=True)
    with open(save_path, 'wb') as file:
        pickle.dump(data_list, file)
    print("finish dump.")


def sample_per_class(graphs, num_per_class, rng=None):
    """Draw ``num_per_class`` graphs without replacement from every label.

    Graphs are grouped by ``g.graph['label']``. The result lists the samples
    class by class, in the order in which each label first appears in
    ``graphs``.

    Args:
        graphs: Iterable of objects exposing a ``graph`` dict with a
            ``'label'`` key (e.g. networkx graphs).
        num_per_class: Number of graphs to keep for each label.
        rng: Optional ``random.Random`` instance for reproducible sampling.
            When ``None``, the module-level ``random`` functions are used.

    Returns:
        A new list containing ``num_per_class`` graphs per label.

    Raises:
        ValueError: If a label has fewer than ``num_per_class`` graphs. Also
            raised by ``sample`` itself if ``num_per_class`` is negative.
    """
    sampler = random if rng is None else rng
    by_label = defaultdict(list)
    for g in graphs:
        by_label[g.graph['label']].append(g)

    sampled = []
    for label, members in by_label.items():
        if len(members) < num_per_class:
            # Fail with a message that names the offending class instead of
            # random.sample's generic "Sample larger than population".
            raise ValueError(
                "label {!r} has only {} graphs; cannot sample {} without "
                "replacement".format(label, len(members), num_per_class)
            )
        sampled.extend(sampler.sample(members, num_per_class))
    return sampled


if __name__ == '__main__':
    num_per_class = 200
    # Set to an int for a reproducible resample. None keeps the original,
    # unseeded behaviour.
    seed = None
    rng = random.Random(seed) if seed is not None else None

    sampled_graph_list = sample_per_class(read_graph(), num_per_class, rng=rng)
    save_graph(sampled_graph_list)

    # Sanity check: every label should now appear exactly num_per_class times.
    label_counts = Counter(g.graph['label'] for g in sampled_graph_list)
    print(dict(label_counts))