import io
import pickle

import networkx as nx
import torch

import dgl


def _reconstruct_pickle(obj):
    f = io.BytesIO()
    pickle.dump(obj, f)
    f.seek(0)
    obj = pickle.load(f)
    f.close()
    return obj


def test_pickling_batched_graph():
    """Check that a batched graph with features survives a pickle round trip.

    Builds five path graphs, batches them, attaches random node and edge
    features, pickles/unpickles the batch and verifies that the sizes and
    the feature tensors of the restored graph match the original.
    """
    # NOTE: this is a test for a weird bug mentioned in
    #   https://github.com/dmlc/dgl/issues/438
    glist = [nx.path_graph(i + 5) for i in range(5)]
    glist = [dgl.DGLGraph(g) for g in glist]
    bg = dgl.batch(glist)
    # 35 = 5 + 6 + 7 + 8 + 9 nodes across the five path graphs.
    bg.ndata["x"] = torch.randn((35, 5))
    # 60 = 2 * (4 + 5 + 6 + 7 + 8); presumably each undirected path edge is
    # stored once per direction -- the feature shape used here relies on it.
    bg.edata["y"] = torch.randn((60, 3))
    new_bg = _reconstruct_pickle(bg)

    # Not crashing is not enough: the restored graph must carry the same
    # structure sizes and exactly the same feature values.
    assert new_bg.number_of_nodes() == bg.number_of_nodes()
    assert new_bg.number_of_edges() == bg.number_of_edges()
    assert torch.equal(new_bg.ndata["x"], bg.ndata["x"])
    assert torch.equal(new_bg.edata["y"], bg.edata["y"])


# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_pickling_batched_graph()
