Disentangled Causal Transformer: Counterfactual Prediction under Time-Varying Treatments

18 Sept 2025 (modified: 11 Feb 2026), Submitted to ICLR 2026, CC BY 4.0
Keywords: Causal Inference, Disentangled Representation Learning, Time-series Modeling, Transformer, Time-Varying Confounding Bias
Abstract: Estimating longitudinal counterfactual outcomes from observational data is pivotal to personalized medicine and other domains. However, prevailing approaches for mitigating time-varying confounding bias typically balance all covariates indiscriminately, conflating confounders with instrumental variables and thereby discarding valuable outcome-relevant information. Causal disentangled representation learning has proven effective in static settings, but extending it to the longitudinal setting, where representation disentanglement and time-series modeling must be performed jointly over time, remains a key challenge. To address this, we introduce the $\textbf{Disentangled Causal Transformer (DCT)}$, a Transformer-based architecture that integrates causal representation disentanglement directly into the sequence modeling process for robust longitudinal causal inference. DCT features a novel $\textbf{disentangled multi-head attention}$ mechanism that decomposes a patient’s history into instrumental, outcome, and confounder components. This design enables unbiased causal estimates while preserving the full predictive signal, thus mitigating the traditional trade-off between factual and counterfactual prediction accuracy. Extensive experiments on fully synthetic datasets and on semi-synthetic datasets derived from real electronic health records show that DCT consistently outperforms state-of-the-art baselines by a large margin in counterfactual outcome prediction. To the best of our knowledge, DCT is the first Transformer-based model to integrate causal representation disentanglement for longitudinal causal inference.
Primary Area: causal reasoning
Submission Number: 11440