Keywords: Curvature, Causal Representation Learning, Geometric Deep Learning, Graph Neural Networks
TL;DR: This paper explores the relationship between graph curvature and causal representation learning on network data, showing that positive Ricci curvature in graphs corresponds to more accurate estimation of causal parameters.
Abstract: Learning the causal mechanisms that govern networked units of data is a notoriously challenging task with a wide range of applications. Graph Neural Networks (GNNs) have proven effective at learning representations that capture complex dependencies between data units. This effectiveness is largely due to how readily GNNs accommodate tools that characterize the geometry of graphs. The potential of geometric deep learning for GNN-based causal representation learning, however, remains underexplored. This work makes three key contributions to bridge this gap. First, we establish a theoretical connection between graph curvature and causal inference, showing that negative curvature poses challenges to learning the causal mechanisms underlying network data. Second, building on this theoretical insight, we use Ricci curvature to gauge the error of treatment effect estimates made from representations learned by GNNs, and show empirically that regions of positive curvature yield more accurate estimates. Lastly, to illustrate the potential of this new connection between geometry and causal inference, we propose a method that uses Ricci flow to improve treatment effect estimation on networked data. Our experiments confirm that this method reduces the error in treatment effect estimates by flattening the network, demonstrating the utility of geometric methods for enhancing causal representation learning. Our findings open new avenues for leveraging discrete geometry in causal representation learning, offering insights and tools that improve how well GNNs learn robust structural relationships.
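The abstract does not specify which discrete Ricci curvature the authors use. Ollivier-Ricci curvature is a common choice, but it requires solving an optimal-transport problem for every edge. As a lightweight, hypothetical stand-in, the sketch below computes the augmented Forman-Ricci curvature of each edge in a simple, unweighted, undirected graph, F(u, v) = 4 - deg(u) - deg(v) + 3 t(u, v), where t(u, v) counts the triangles containing the edge. The function name `augmented_forman_curvature` and the toy graphs are illustrative assumptions, not the authors' code. It shows the intuition behind the curvature signs: edges inside densely connected, clique-like regions come out positive, while bridge edges in tree-like regions come out negative.

```python
from itertools import combinations


def augmented_forman_curvature(edges):
    """Augmented Forman-Ricci curvature of every edge (a stand-in, not the paper's exact measure).

    Assumes a simple, unweighted, undirected graph given as a list of (u, v) pairs
    with no self-loops or duplicate edges. Returns {(u, v): curvature}.
    """
    # Build adjacency sets so degrees and common neighbours are cheap to query.
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)

    curvature = {}
    for u, v in edges:
        # Each common neighbour of u and v closes one triangle containing the edge.
        triangles = len(adj[u] & adj[v])
        curvature[(u, v)] = 4 - len(adj[u]) - len(adj[v]) + 3 * triangles
    return curvature


# Hypothetical toy graphs: a 4-clique (dense) and a small tree (sparse, branching).
clique = list(combinations(range(4), 2))
tree = [(0, 1), (0, 2), (0, 3), (1, 4), (1, 5)]

if __name__ == "__main__":
    print(augmented_forman_curvature(clique)[(0, 1)])  # 4: positive inside the clique
    print(augmented_forman_curvature(tree)[(0, 1)])    # -2: negative on the tree's bridge edge
```

Forman curvature is used here only because it needs nothing beyond node degrees and triangle counts; an Ollivier-Ricci implementation would replace the per-edge formula with a Wasserstein distance between neighbourhood distributions.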
Submission Number: 3