Incorporating Interventional Independence Improves Robustness against Interventional Distribution Shift
Abstract: We consider the problem of learning robust discriminative representations of latent variables that are causally related to each other via a directed graph. In addition to passively collected observational data, the training dataset also includes interventional data obtained through targeted interventions on some of these latent variables, with the goal of learning representations that are robust against the resulting interventional distribution shifts. However, existing approaches treat interventional data like observational data, even when the underlying causal model is known, and ignore the independence relations that arise from these interventions. Since these approaches do not fully exploit the causal relational information resulting from interventions, they learn representations that produce large disparities in predictive performance on observational and interventional data. This performance disparity worsens when the number of interventional data samples available for training is limited. In this paper, (1) we first identify a strong correlation between this performance disparity and the adherence of the representations to the statistical independence conditions induced by the underlying causal model during interventions. (2) For linear models, we derive sufficient conditions on the proportion of interventional data in the training dataset under which enforcing statistical independence between representations corresponding to the intervened node and its non-descendants during interventions lowers the test-time error on interventional data. Combining these insights, (3) we propose RepLIn, a training algorithm that explicitly enforces this statistical independence during interventions. We demonstrate the utility of RepLIn on a synthetic dataset and on real image and text datasets for facial attribute classification and toxicity detection, respectively.
Our experiments show that RepLIn scales with the number of nodes in the causal graph and improves the robustness of representations against interventional distribution shifts for both continuous and discrete latent variables, compared to ERM baselines.
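The abstract describes enforcing statistical independence between the representation of the intervened node and those of its non-descendants. One standard way to penalize statistical dependence between two representations is the Hilbert-Schmidt Independence Criterion (HSIC). The sketch below is illustrative only and is not taken from the paper: the function names, the Gaussian-kernel choice, and the bandwidth are all assumptions, and the actual RepLIn loss may differ.

```python
import numpy as np

def _gram(x, sigma=1.0):
    # Gaussian (RBF) Gram matrix over the rows of x.
    sq = np.sum(x ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * (x @ x.T)
    return np.exp(-d2 / (2.0 * sigma ** 2))

def hsic(x, y, sigma=1.0):
    """Biased empirical HSIC between paired samples x, y (shape (n, d)).

    Values near zero suggest statistical independence; larger values
    indicate dependence. Used here only as an illustrative penalty.
    """
    n = x.shape[0]
    h = np.eye(n) - np.ones((n, n)) / n  # centering matrix
    kx, ky = _gram(x, sigma), _gram(y, sigma)
    return np.trace(h @ kx @ h @ ky) / (n - 1) ** 2

# Hypothetical usage on interventional samples, combining a task loss
# with an independence penalty weighted by a hyperparameter lam_dep
# (analogous in spirit to the paper's lambda_dep):
#   total_loss = task_loss + lam_dep * hsic(rep_intervened, rep_nondesc)
```

In practice such a penalty would be computed per mini-batch of interventional samples and added to the discriminative loss, so that dependence between the intervened-node representation and its non-descendants' representations is driven toward zero during interventions.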
Submission Length: Long submission (more than 12 pages of main content)
Previous TMLR Submission Url: https://openreview.net/forum?id=pZRanZlab4
Changes Since Last Submission: - Added experimental results on the effect of loss hyperparameters $\lambda_{\text{dep}}$ and $\lambda_{\text{self}}$.
- Added new baselines to the primary experiments. Added results for toxicity detection on the CivilComments dataset.
- Added subsection on limitations of the proposed method.
- Minor writing changes to further clarify our objectives and claims.
Assigned Action Editor: ~Yu_Yao3
Submission Number: 5315