Modular Learning of Deep Causal Generative Models for High-dimensional Causal Inference

23 Sept 2023 (modified: 11 Feb 2024), Submitted to ICLR 2024
Supplementary Material: zip
Primary Area: causal reasoning
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Keywords: Causal inference, Structural causal models, Causal graphs, Causal effect, Generative Adversarial Networks.
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2024/AuthorGuide.
TL;DR: We propose a novel adversarial training algorithm to learn deep causal generative models from high-dimensional data in a modular fashion, enabling the use of pre-trained models for causal and counterfactual effect estimation.
Abstract: Pearl’s causal hierarchy establishes a clear separation between observational, interventional, and counterfactual questions. Researchers have proposed sound and complete algorithms that compute identifiable causal queries at a given level of the hierarchy using the causal structure and data from lower levels of the hierarchy. However, most of these algorithms assume that the probability distribution of the data can be estimated accurately, which is impractical for high-dimensional variables such as images. On the other hand, modern deep generative architectures can be trained to sample accurately from such high-dimensional distributions. Especially with the recent rise of foundation models for images, it is desirable to leverage pre-trained models to answer causal queries involving such high-dimensional data. To address this, we propose a sequential training algorithm that, given the causal structure and a pre-trained conditional generative model, trains a deep causal generative model that utilizes the pre-trained model and provably samples from identifiable interventional and counterfactual distributions. Our algorithm, called WhatIfGAN, uses adversarial training to learn the network weights and, to the best of our knowledge, is the first that can leverage pre-trained models to provably sample from any identifiable causal query in the presence of latent confounders with high-dimensional data. We demonstrate the utility of our algorithm on semi-synthetic and real-world datasets containing images as variables in the causal structure.
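To make the notion of sampling from an interventional distribution with a neural causal generative model concrete, the sketch below is a minimal, illustrative example only; it is not the paper's WhatIfGAN algorithm, its modular training procedure, or its adversarial loss. It assumes a hypothetical toy graph X -> Y <- Z with one feed-forward mechanism per variable, and shows how do(X = x) can be simulated by clamping X's mechanism. All class names, dimensions, and the graph itself are assumptions made for illustration.

```python
# Minimal sketch of a neural structural causal model (NOT the authors' WhatIfGAN):
# each variable has its own generator module ("mechanism"), so an intervention
# do(X = x) is simulated by bypassing X's mechanism and clamping its value.
from typing import Optional

import torch
import torch.nn as nn


class Mechanism(nn.Module):
    """Generates one variable from its parents and fresh exogenous noise."""

    def __init__(self, parent_dim: int, noise_dim: int, out_dim: int):
        super().__init__()
        self.noise_dim = noise_dim
        self.net = nn.Sequential(
            nn.Linear(parent_dim + noise_dim, 64),
            nn.ReLU(),
            nn.Linear(64, out_dim),
        )

    def forward(self, parents: torch.Tensor) -> torch.Tensor:
        noise = torch.randn(parents.shape[0], self.noise_dim)
        return self.net(torch.cat([parents, noise], dim=-1))


class NeuralSCM(nn.Module):
    """Toy causal graph X -> Y <- Z with one-dimensional variables (illustrative)."""

    def __init__(self):
        super().__init__()
        self.f_x = Mechanism(parent_dim=0, noise_dim=4, out_dim=1)
        self.f_z = Mechanism(parent_dim=0, noise_dim=4, out_dim=1)
        self.f_y = Mechanism(parent_dim=2, noise_dim=4, out_dim=1)

    def sample(self, n: int, do_x: Optional[torch.Tensor] = None) -> dict:
        empty = torch.zeros(n, 0)
        # Observational sampling draws X from its mechanism;
        # interventional sampling clamps X to the intervened value.
        x = self.f_x(empty) if do_x is None else do_x.expand(n, 1)
        z = self.f_z(empty)
        y = self.f_y(torch.cat([x, z], dim=-1))
        return {"X": x, "Y": y, "Z": z}


scm = NeuralSCM()
obs = scm.sample(n=8)                               # samples from P(X, Y, Z)
intv = scm.sample(n=8, do_x=torch.tensor([[1.0]]))  # samples from P(Y, Z | do(X = 1))
```

In the paper's setting, the mechanisms for high-dimensional variables would be (pre-trained) conditional generative models whose weights are fit adversarially against data, and latent confounders would be modeled by shared exogenous noise; here the mechanisms are randomly initialized MLPs purely to show the sampling interface. Counterfactual sampling would additionally require holding the exogenous noise fixed across the factual and intervened worlds, which this sketch does not implement.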
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Submission Number: 8025