diffIRM: A Diffusion-Augmented Invariant Risk Minimization Framework for Spatiotemporal Prediction over Graphs
Abstract: Spatiotemporal prediction over graphs (STPG) is challenging because real-world data suffer from the out-of-distribution (OOD) generalization problem, in which test data follow a different distribution from the training data. To address this issue, invariant risk minimization (IRM) has emerged as a promising approach for learning invariant representations across different environments. However, IRM and its variants were originally designed for Euclidean data, such as images, and may not generalize well to graph-structured data, such as spatiotemporal graphs, because of spatial correlations in graphs. To overcome the challenge posed by graph-structured data, existing graph OOD methods adhere to either the principle of invariance existence (i.e., there exist invariant features that consistently relate to the label across various environments) or that of environment diversity (i.e., diversifying training environments increases the likelihood that test environments align with training ones). However, little research has combined both principles for the STPG problem, even though combining the two is crucial for efficiently distinguishing invariant features from spurious ones. In this study, we fill this research gap and propose a diffusion-augmented invariant risk minimization (diffIRM) framework that combines these two principles for the STPG problem. diffIRM consists of two processes: (1) data augmentation and (2) invariant learning. In the data augmentation process, a causal mask generator identifies causal features, and a graph-based diffusion model acts as an environment augmentor to generate augmented spatiotemporal graph data. In the invariant learning process, an invariance penalty is designed using the augmented data and then serves as a regularizer for training the spatiotemporal prediction model. We provide theoretical evidence supporting diffIRM's ability to identify invariant features.
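The abstract does not state the exact form of diffIRM's invariance penalty. As a rough illustration of how an IRM-style penalty acts as a regularizer across environments, the sketch below assumes the IRMv1 penalty from the original IRM formulation (the squared gradient of each environment's risk with respect to a dummy output scale w, evaluated at w = 1) with a squared-error loss, which gives a closed-form gradient. The toy structural causal model, the second environment with a flipped spurious link (standing in for an augmented environment, not the paper's diffusion augmentor), and the names `irm_penalty_mse`, `make_env`, and `penalized_objective` are all hypothetical.

```python
import random


def mse(preds, targets):
    """Mean squared error of one environment."""
    return sum((p - y) ** 2 for p, y in zip(preds, targets)) / len(preds)


def irm_penalty_mse(preds, targets):
    """IRMv1-style penalty (assumed form): (dR/dw at w = 1)^2,
    where R(w) = mean((w * pred - y)^2) and w is a dummy scalar."""
    n = len(preds)
    grad = sum(2.0 * (p - y) * p for p, y in zip(preds, targets)) / n
    return grad ** 2


def make_env(n, spurious_coef, seed):
    """Hypothetical toy SCM: x_inv -> y -> x_spu.
    The y | x_inv mechanism is fixed; the x_spu | y link varies per environment."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x_inv = rng.gauss(0.0, 1.0)
        y = 2.0 * x_inv + rng.gauss(0.0, 0.1)
        x_spu = spurious_coef * y + rng.gauss(0.0, 0.1)
        data.append((x_inv, x_spu, y))
    return data


def penalized_objective(predict, envs, lam):
    """Sum over environments of risk + lam * invariance penalty."""
    total = 0.0
    for env in envs:
        preds = [predict(x_inv, x_spu) for x_inv, x_spu, _ in env]
        ys = [y for _, _, y in env]
        total += mse(preds, ys) + lam * irm_penalty_mse(preds, ys)
    return total


if __name__ == "__main__":
    # One "observed" environment and one hypothetical shifted environment.
    envs = [make_env(500, 1.0, seed=0), make_env(500, -1.0, seed=1)]
    invariant = lambda x_inv, x_spu: 2.0 * x_inv  # uses only the causal feature
    spurious = lambda x_inv, x_spu: x_spu         # fits env 0 well, fails in env 1
    print("invariant:", penalized_objective(invariant, envs, lam=1.0))
    print("spurious: ", penalized_objective(spurious, envs, lam=1.0))
```

The per-environment penalty is near zero when rescaling the predictions cannot lower that environment's risk, so the predictor that relies on the invariant feature stays nearly unpenalized. The predictor that relies on the spurious feature is heavily penalized in the environment where the spurious link flips. This is why having more diverse environments, the role diffIRM assigns to its augmentor, makes spurious features easier to expose.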
The effectiveness of diffIRM is further demonstrated through experiments on both numerical and real-world data. The numerical data are generated from a known structural causal model (SCM), and diffIRM successfully identifies the true invariant features. The real-world experiments use three human mobility data sets: SafeGraph, PeMS04, and PeMS08. diffIRM outperforms the baseline models. Furthermore, the model is interpretable: it discerns invariant features while making predictions.
History: This paper has been accepted for the Transportation Science Special Issue on Machine Learning Methods for Urban Passenger Mobility.
Funding: This work was supported by the National Science Foundation [Grant 2218809].
External IDs: doi:10.1287/trsc.2024.0562