Abstract: Causal inference seeks to estimate the effect of a treatment, such as a medicine or the dosage of a medication. To reduce the confounding bias caused by non-randomized treatment assignment, most existing methods reduce the distribution shift between subpopulations receiving different treatments. However, these methods split the limited training samples into smaller groups, which reduces the number of samples in each group, whereas precise distribution estimation and alignment rely heavily on a sufficient number of training samples. In this paper, we propose a distribution alignment paradigm without data splitting, which applies naturally to both binary and continuous treatments. To this end, we characterize the confounding bias by considering different probability measures over the same set containing all the training samples, and we exploit optimal transport theory to analyze the confounding bias and the outcome estimation error. Based on this analysis, we propose to learn balanced representations by reducing the discrepancy between the marginal distribution and the conditional distribution given a treatment. As a result, the data reduction caused by splitting is avoided, and an outcome prediction model trained on one treatment group can generalize to the entire population. Experiments in both binary and continuous treatment settings demonstrate the effectiveness of our method.
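The following Python sketch illustrates, under stated assumptions, what it can mean to compare different probability measures on the same set of all training samples in the binary setting. It is not the paper's method: the synthetic covariates, the logistic assignment rule, the helper name `sinkhorn`, and the regularization `eps` are hypothetical choices, and entropic (Sinkhorn) optimal transport stands in for whatever transport formulation the paper actually uses. Both measures share the support {z_1, ..., z_n}: the marginal puts weight 1/n on every sample, while the treatment-conditional measure puts all of its mass on the treated samples, so no data are split off.

```python
import numpy as np
from scipy.special import logsumexp


def sinkhorn(a, b, cost, eps=0.1, n_iters=1000):
    """Entropic OT between weight vectors a and b defined on one shared support.

    Runs Sinkhorn iterations in the log domain, so points with zero weight
    (log 0 = -inf) receive no mass. Returns (transport cost <P, C>, plan P).
    """
    log_k = -cost / eps
    with np.errstate(divide="ignore"):  # log(0) -> -inf is intended here
        log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iters):
        f = log_a - logsumexp(log_k + g[None, :], axis=1)  # match row sums to a
        g = log_b - logsumexp(log_k + f[:, None], axis=0)  # match column sums to b
    plan = np.exp(f[:, None] + log_k + g[None, :])
    return float(np.sum(plan * cost)), plan


rng = np.random.default_rng(0)
n = 200
z = rng.normal(size=(n, 2))  # hypothetical stand-in for learned representations phi(x_i)
cost = ((z[:, None, :] - z[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean costs

# Hypothetical confounded assignment: larger z[:, 0] makes treatment more likely.
treated = rng.random(n) < 1.0 / (1.0 + np.exp(-2.0 * z[:, 0]))
# Randomized assignment with the same number of treated samples, for comparison.
randomized = np.zeros(n, dtype=bool)
randomized[rng.choice(n, size=int(treated.sum()), replace=False)] = True

# Two probability measures on the SAME support {z_1, ..., z_n}; nothing is split off.
marginal = np.full(n, 1.0 / n)
cond_confounded = treated / treated.sum()  # mass 1/n_1 on each treated sample, 0 elsewhere
cond_randomized = randomized / randomized.sum()

gap_confounded, plan = sinkhorn(marginal, cond_confounded, cost)
gap_randomized, _ = sinkhorn(marginal, cond_randomized, cost)
print(f"confounded: {gap_confounded:.3f}, randomized: {gap_randomized:.3f}")
```

Under the confounded assignment the transport cost should come out clearly larger than under a randomized assignment of the same size. A representation-learning approach in this spirit would train the encoder to shrink such a discrepancy; the sketch does not attempt that step.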
Lay Summary: Causal inference aims to evaluate the effect of a treatment, such as a medicine or the dosage of a medication.
Ideally, the treatment effect can be estimated by randomized controlled trials,
in which individuals are assigned to the treated group and the control group randomly,
and the effect can be calculated by comparing the results of the two groups.
However, in practice, individuals often do not receive treatments randomly.
For example, sicker patients might be more likely to receive a particular medication,
making it difficult to estimate the treatment effect accurately since the treated and control groups follow different distributions.
Existing methods usually reduce the distribution discrepancy between the groups
by splitting the training samples into different subsets,
which cuts down the number of samples in each group and hampers distribution estimation and alignment.
To address this, we propose a distribution alignment paradigm without data splitting.
Specifically, we analyze the distribution discrepancy and the effect estimation error,
and align different distributions using optimal transport between probability measures
defined over all the samples rather than over subsets.
As a result, data splitting is avoided, and data efficiency is improved for accurate distribution modeling.
We conduct experiments in different settings, including the binary treatment setting (receiving a treatment or not) and the continuous treatment setting (receiving a dosage of the treatment).
The experimental results demonstrate that our method can estimate the treatment effect accurately.
Existing methods often mitigate confounding bias by splitting data into subgroups, but this reduces the number of training samples per group, which undermines reliable distribution estimation and outcome prediction.
We tackle this problem with a novel distribution alignment framework that avoids data splitting.
We model confounding bias using different probability measures over the same dataset and employ optimal transport theory to analyze and reduce both confounding bias and outcome prediction error.
By learning balanced representations that align the marginal and treatment-conditional distributions, our method generalizes to the entire population without reducing the sample size.
This research provides a more data-efficient and theoretically grounded approach to causal inference for both binary and continuous treatments.
By avoiding data splitting and leveraging optimal transport,
our method works for both simple yes-or-no treatments and more complex cases such as varying dosage levels. It improves estimation accuracy and helps build more reliable models to guide medical decisions, benefiting both researchers and patients.
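For the continuous (dosage) setting, the conditional distribution given a dosage t can likewise be written as a reweighting of all n samples instead of a subset. The Python sketch below assumes, purely for illustration, Gaussian kernel weights in the treatment value to approximate that conditional measure, a one-dimensional stand-in for the representation, and synthetic dosages. `weighted_w1` and `conditional_weights` are hypothetical helpers, and the exact 1-D Wasserstein-1 formula (the integrated absolute gap between the two CDFs) is a swapped-in illustration, not the paper's estimator.

```python
import numpy as np


def weighted_w1(x, a, b):
    """Exact Wasserstein-1 distance between two weightings a and b of the same 1-D points x."""
    order = np.argsort(x)
    cdf_gap = np.cumsum(a[order] - b[order])[:-1]  # F_a - F_b between consecutive points
    return float(np.sum(np.abs(cdf_gap) * np.diff(x[order])))


def conditional_weights(t, t0, bandwidth=0.3):
    """Hypothetical Gaussian-kernel weights approximating p(z | t = t0) on all n samples."""
    w = np.exp(-0.5 * ((t - t0) / bandwidth) ** 2)
    return w / w.sum()


rng = np.random.default_rng(1)
n = 500
z = rng.normal(size=n)                                   # 1-D stand-in for phi(x_i)
t_confounded = z + 0.5 * rng.normal(size=n)              # hypothetical: dosage depends on z
t_randomized = rng.normal(scale=np.sqrt(1.25), size=n)   # same spread, independent of z

marginal = np.full(n, 1.0 / n)  # every sample keeps weight 1/n
gap_confounded = weighted_w1(z, marginal, conditional_weights(t_confounded, 1.0))
gap_randomized = weighted_w1(z, marginal, conditional_weights(t_randomized, 1.0))
print(f"confounded: {gap_confounded:.3f}, randomized: {gap_randomized:.3f}")
```

Because every sample keeps a positive weight, the comparison uses the whole dataset at each dosage level. When the dosage depends on the covariate, the conditional measure at t = 1 drifts toward larger z and its distance to the marginal grows, while an independently assigned dosage leaves the distance small, up to sampling noise.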
Primary Area: General Machine Learning->Causality
Keywords: Causal inference, continuous treatment, optimal transport
Submission Number: 15581