Abstract: Graph neural networks (GNNs) have demonstrated remarkable success in graph representation learning, and various sampling approaches have been proposed to scale GNNs to applications with large-scale graphs. A promising class of GNN training algorithms takes advantage of historical embeddings to reduce computation and memory cost while maintaining the expressiveness of GNNs. However, these methods incur significant computation bias due to stale feature histories. In this paper, we provide a comprehensive analysis of their staleness and inferior performance on large-scale problems. Motivated by our findings, we propose a simple yet highly effective training algorithm (REST) that reduces feature staleness, leading to significantly improved performance and convergence across varying batch sizes, especially when staleness is predominant. The proposed algorithm integrates seamlessly with existing solutions and is easy to implement, and comprehensive experiments underscore its superior performance and efficiency on large-scale benchmarks. Specifically, our improvements to state-of-the-art historical embedding methods yield 2.7% and 3.6% performance gains on the ogbn-papers100M and ogbn-products datasets, respectively, accompanied by notably accelerated convergence. The code can be found at https://github.com/RXPHD/REST.
Lay Summary: Contemporary datasets, from social-media interactions to molecular structures and power-grid topologies, are most naturally modeled as graphs: collections of nodes connected by edges. Graph Neural Networks (GNNs) provide a powerful framework for analyzing such data, yet their training is computationally intensive. A common remedy is to sample smaller subgraphs, but this inevitably discards information from unsampled nodes. To avoid this loss, the widely adopted historical-embedding approach caches intermediate node representations for reuse. Unfortunately, these cached embeddings quickly become stale, forcing the model to rely on outdated signals and thereby undermining both accuracy and convergence speed.
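As a rough illustration of the caching idea described above, the sketch below shows a per-node embedding cache that is read for out-of-batch neighbors and refreshed for in-batch nodes. It is a minimal sketch only; the class and method names (`HistoryCache`, `pull`, `push`) are assumptions for illustration, not the implementation of any particular historical-embedding method.

```python
import torch

class HistoryCache:
    """Minimal per-layer cache of node embeddings, indexed by global node id (illustrative only)."""

    def __init__(self, num_nodes: int, dim: int):
        # Entries are zero (and hence stale) until a node first appears in a mini-batch.
        self.emb = torch.zeros(num_nodes, dim)

    def pull(self, node_ids: torch.Tensor) -> torch.Tensor:
        # Read possibly stale embeddings for neighbors that are not in the current batch.
        return self.emb[node_ids]

    def push(self, node_ids: torch.Tensor, values: torch.Tensor) -> None:
        # Store freshly computed in-batch embeddings, detached from the autograd graph.
        self.emb[node_ids] = values.detach()

# Conceptual use inside one GNN layer on a sampled mini-batch:
#   h_batch = layer(x[batch_ids], ...)                      # fresh embeddings for in-batch nodes
#   h_nbrs  = history.pull(out_of_batch_neighbor_ids)       # cached (possibly stale) neighbors
#   history.push(batch_ids, h_batch)                        # refresh the cache for later batches
```

The staleness problem arises because each cached entry reflects model parameters from whenever that node was last pushed, which may be many updates ago.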
To address this staleness problem, our work first offers a comprehensive analysis that pinpoints its root causes and then introduces REST, a simple yet effective training framework that mitigates the issue. REST decouples the forward and backward passes and executes them at different frequencies, allowing historical embeddings to be refreshed more often. Experiments on multiple large-scale benchmarks show that REST improves prediction accuracy, accelerates convergence, and preserves computational efficiency, delivering a promising solution for diverse real-world applications.
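The loop below is a minimal sketch of what such a decoupled schedule could look like: every mini-batch runs a forward pass that refreshes the history cache, while only a subset of batches also runs the backward pass and parameter update. All names (`model`, `history`, `backward_every`, and the PyG-style `batch.train_mask` / `batch.y` attributes) are assumptions for illustration and do not reflect the actual REST code.

```python
import torch

def train_epoch(model, loader, history, optimizer, criterion, backward_every: int = 2):
    """Illustrative loop: forward passes refresh the history cache on every batch,
    backward passes (and parameter updates) run only every `backward_every` batches."""
    for step, batch in enumerate(loader):
        do_backward = (step % backward_every == 0)
        # Gradient-free forward passes still push fresh embeddings into the cache.
        ctx = torch.enable_grad() if do_backward else torch.no_grad()
        with ctx:
            out = model(batch, history)  # pulls cached neighbors, pushes fresh in-batch embeddings
        if do_backward:
            loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```

Because forward refreshes outnumber parameter updates, the cached embeddings lag the current parameters by fewer steps, which is the intuition behind the reported gains in accuracy and convergence.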
Link To Code: https://github.com/RXPHD/REST
Primary Area: Deep Learning->Graph Neural Networks
Keywords: Graph Neural Networks, Large Scale Machine Learning, Historical Embeddings
Submission Number: 8553