Staleness-based subgraph sampling for large-scale GNN training

20 Sept 2023 (modified: 11 Feb 2024) · Submitted to ICLR 2024
Supplementary Material: pdf
Primary Area: learning on graphs and other geometries & topologies
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Keywords: subgraph sampling, large-scale GNN training, historical embeddings, staleness
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2024/AuthorGuide.
Abstract: Training Graph Neural Networks (GNNs) on large-scale graphs is challenging. The main difficulty is obtaining accurate node embeddings while avoiding the neighbor explosion problem. Many existing solutions use historical embeddings to tackle this challenge. Specifically, by substituting historical embeddings for the out-of-batch nodes, these methods can approximate full-batch training without dropping any input data while keeping GPU memory consumption constant. However, little work has specifically designed a subgraph sampling method that benefits these historical embedding-based methods. In this paper, we first analyze the approximation error of node embeddings caused by using historical embeddings for out-of-batch neighbors, and prove that this approximation error can be minimized by minimizing the staleness of the historical embeddings of out-of-batch nodes. Based on this theoretical analysis, we design a simple yet effective \underline{S}taleness score-based \underline{S}ubgraph \underline{S}ampling method (S3) to benefit these historical embedding-based methods. The key idea is to define the weight of each edge as the sum of the staleness scores of its source and target nodes, and then apply graph partitioning to minimize the weighted edge cut, with each resulting partition serving as a mini-batch during training. In this way, we explicitly minimize the approximation error of node embeddings. Furthermore, to handle staleness scores that change dynamically during training and to improve the efficiency of graph partitioning, we design a fast algorithm that generates mini-batches via a local refinement heuristic. Experimental results show that (1) our S3 sampling method further improves historical embedding-based methods and sets a new state of the art, and (2) our fast algorithm is 3x faster than re-partitioning the graph from scratch on the large-scale ogbn-products dataset with 2M nodes. In addition, consistent improvements across all three historical embedding-based methods (GAS, GraphFM, and LMC) demonstrate the generalizability of our subgraph sampling method.
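To make the sampling idea concrete, the following is a minimal sketch (not the authors' implementation): it weights each edge by the sum of its endpoints' staleness scores and hands the weighted graph to a METIS-style partitioner, so that minimizing the weighted edge cut keeps high-staleness neighbor pairs inside the same mini-batch. The choice of the `pymetis` wrapper, the `staleness_partition` helper, and the integer scaling of the scores are illustrative assumptions, not details from the paper.

```python
import numpy as np
import pymetis  # assumed METIS wrapper; any partitioner with edge weights works


def staleness_partition(edges, staleness, num_parts, scale=1000):
    """Split a graph into mini-batches by minimizing the staleness-weighted
    edge cut. `edges` is a list of undirected (u, v) pairs; `staleness` maps
    each node id to its current staleness score."""
    num_nodes = len(staleness)
    adj = [[] for _ in range(num_nodes)]
    wts = [[] for _ in range(num_nodes)]
    for u, v in edges:
        # Edge weight = staleness(u) + staleness(v); METIS needs positive ints,
        # so scale and shift the (possibly fractional) scores.
        w = int(scale * (staleness[u] + staleness[v])) + 1
        adj[u].append(v); wts[u].append(w)
        adj[v].append(u); wts[v].append(w)

    # Flatten the symmetric adjacency into the CSR arrays METIS expects.
    xadj, adjncy, eweights = [0], [], []
    for nbrs, ws in zip(adj, wts):
        adjncy.extend(nbrs)
        eweights.extend(ws)
        xadj.append(len(adjncy))

    # Minimizing the total weight of cut edges keeps high-staleness pairs
    # in the same partition, i.e., the same mini-batch.
    _, parts = pymetis.part_graph(num_parts, xadj=xadj, adjncy=adjncy,
                                  eweights=eweights)
    parts = np.asarray(parts)
    return [np.flatnonzero(parts == p) for p in range(num_parts)]
```

Under this sketch, the paper's fast variant would correspond to updating the partition assignment with a local refinement pass as staleness scores change, rather than calling the partitioner from scratch each time.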
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Submission Number: 2199