\documentclass[accepted]{uai2022} 








\usepackage{zref-xr,zref-user}
\zexternaldocument*{vo_458-supp}

\usepackage{setspace}
\usepackage{amsthm}

\usepackage{dsfont}
\usepackage{amssymb}
\usepackage{mathtools}
\usepackage{float}
\usepackage{upgreek}
\allowdisplaybreaks
\usepackage{bm}
\usepackage{enumitem}
\usepackage{multirow}
\usepackage{xcolor}
\usepackage{makecell}
\usepackage{graphicx}
\usepackage{booktabs}
\usepackage[linesnumbered,ruled,vlined]{algorithm2e}


\usepackage{stackengine}


\newcommand{\q}{q}
\newcommand{\qbar}{\tilde{\mathbb{Q}}}
\newcommand{\p}{p}
\newcommand{\D}{\mathbb{D}}
\newcommand{\m}{\mathbb{M}}
\newcommand{\e}{\mathbb{E}}
\newcommand{\V}{\mathbb{V}}
\newcommand{\infq}{\inf_{\q}}
\newcommand{\Ifqp}{\mathbb{I}_f(\q,\p)}
\newcommand{\eq}{\mathbb{E}_{\q}}
\newcommand{\ep}{\mathbb{E}_{\p}}
\newcommand{\ed}{\mathbb{E}_{\D}}
\newcommand{\EM}{\mathbb{E}_{\m}}
\newcommand{\Risk}{\mathbb{L}}
\newcommand{\Riskattained}{\underline{\mathbb{L}}}
\newcommand{\risk}{L}
\newcommand{\riskattained}{\underline{L}}
\newcommand{\reals}{\mathbb{R}}
\newcommand{\maps}{\rightarrow}
\newcommand{\settozero}{\overset{!}{=}0}
\newcommand{\ind}{\perp \!\!\! \perp }

\newcommand{\argmax}{\operatornamewithlimits{arg\max}}
\newcommand{\argmin}{\operatornamewithlimits{arg\min}}
\newcommand{\doo}{\textnormal{do}}
\allowdisplaybreaks



\newtheorem{theorem}{Theorem}
\newtheorem{definition}{Definition}
\newtheorem{remark}{Remark}
\newtheorem{example}{Example}
\newtheorem{proposition}{Proposition}
\newtheorem{corollary}{Corollary}
\newtheorem{lemma}{Lemma}

\newtheoremstyle{exampstyle}
  {3pt} {0} {} {} {\bfseries} {.} {.5em} {} \theoremstyle{exampstyle} \newtheorem{assumption}{Assumption}



\newcommand{\ymis}{Y^{\textrm{mis}}_i}
\newcommand{\yobs}{Y^{\textrm{obs}}_i}
\newcommand{\Ymis}{\textbf{Y}_{\textrm{mis}}}
\newcommand{\Yobs}{\textbf{Y}_{\textrm{obs}}}
\newcommand{\W}{\textbf{W}}
\newcommand{\X}{\textbf{X}}
\newcommand{\Yzero}{\textbf{Y}(0)}
\newcommand{\Yone}{\textbf{Y}(1)}
\newcommand{\yzero}{Y_i(0)}
\newcommand{\yone}{Y_i(1)}
\newcommand{\varbeta}{\sigma^2_{\beta}}
\newcommand{\betac}{\beta^{{[c]}}}
\newcommand{\betat}{\beta^{{[t]}}}
\newcommand{\x}{\textbf{x}_i}
\newcommand{\xic}{\x^\top\betac}
\newcommand{\xit}{\x^\top\betat}
\newcommand{\betaall}{\boldsymbol{\beta}}
\newcommand{\et}{{\epsilon_i^{[t]}}}
\newcommand{\ec}{{\epsilon_i^{[c]}}}
\newcommand{\epstilde}{\tilde{\epsilon}}
\newcommand{\mumis}{\mu_i^{\textrm{mis}}}
\newcommand{\muobs}{\mu_i^{\textrm{obs}}}
\newcommand{\lambdaobs}{\lambda_i^{\textrm{obs}}}
\newcommand{\lambdamis}{\lambda_i^{\textrm{mis}}}
\newcommand{\qmis}{{Q}_i^{\textrm{mis}}}
\newcommand{\Qmis}{{\textbf{Q}}^{{\textrm{mis}}}}
\newcommand{\qobs}{{Q}_i^{{\textrm{obs}}}}
\newcommand{\Qobs}{{\textbf{Q}}^{{\textrm{obs}}}}
\newcommand{\xiobs}{\xi_i^{\textrm{obs}}}
\newcommand{\ximis}{\xi_i^{\textrm{mis}}}
\newcommand{\xitilde}{\tilde{\textbf{x}}^{\textrm{obs}}_i}
\newcommand{\xitildemis}{\tilde{\textbf{x}}^{\textrm{mis}}_i}
\newcommand{\xtilde}{\tilde{\textbf{x}}}
\newcommand{\Xtilde}{\tilde{\textbf{X}}}
\newcommand{\xibreve}{\breve{\textbf{x}}_i}

\newcommand{\highlight}[2][yellow]{\mathchoice {\colorbox{#1}{$\displaystyle#2$}}{\colorbox{#1}{$\textstyle#2$}}{\colorbox{#1}{$\scriptstyle#2$}}{\colorbox{#1}{$\scriptscriptstyle#2$}}}



\usepackage[american]{babel}


\usepackage{natbib} \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} \usepackage{booktabs} \usepackage{tikz} 



\newcommand{\swap}[3][-]{#3#1#2} 

\title{Bayesian Federated Estimation of Causal Effects from Observational Data}

\author[1]{\href{mailto:<votv@comp.nus.edu.sg>?Subject=Your UAI 2022 paper}{Thanh~Vinh~Vo}{}}
\author[2]{Young Lee}
\author[3]{Trong Nghia Hoang}
\author[1]{\href{mailto:<leongty@comp.nus.edu.sg>?Subject=Your UAI 2022 paper}{Tze-Yun~Leong}{}}
\affil[1]{School of Computing\\
    National University of Singapore
}
\affil[2]{Harvard University
}
\affil[3]{School of Electrical Engineering and Computer Science\\
    Washington State University
  }
  
  \begin{document}
\maketitle

\begin{abstract}


We propose a Bayesian framework for estimating causal effects from federated observational data sources. Bayesian causal inference is an important approach to learning the distribution of the causal estimands and understanding the uncertainty of causal effects. Our framework estimates the posterior distributions of the causal effects to compute the higher-order statistics that capture the uncertainty. We integrate local causal effects from different data sources without centralizing them. We then estimate the treatment effects from observational data using a non-parametric reformulation of the classical potential outcomes framework. We model the potential outcomes as a random function distributed by Gaussian processes, with defining parameters that can be efficiently learned from multiple data sources. Our method avoids exchanging raw data among the sources, thus contributing towards privacy-preserving causal learning. The promise of our approach is demonstrated through a set of simulated and real-world examples.









%
 \end{abstract}


\section{Introduction} 
\label{sec:intro}

Causal effect estimation is important in many real-life situations. For example: What is the effect of war in a specific region on world food supply? How would the blood pressure of a patient change if that patient took a new drug? How does coronary heart disease affect age- and gender-specific mortality rates? These questions are common in many areas, including personalized medicine \citep{powers2018some}, digital experiments \citep{taddy2016nonparametric}, political science \citep{green2012modeling}, etc., and are especially relevant to recent events such as the COVID-19 pandemic and the war in Ukraine. 
In practice, the relevant data essential for accurate and meaningful causal inference may reside in multiple, decentralized data sources which cannot be shared or combined due to geographical, organizational, process, and/or privacy constraints. 
Some alternative solutions such as establishing data use agreements or creating secure data environments may not be possible and are often not easily implemented. 
In addition, it is important to know whether the causal estimands are reliable. Thus, estimating a confidence interval of the relevant causal effect together with its point estimates would give helpful insights into the uncertainty of the causal estimand. For example, a narrow confidence interval for individual treatment effect of smoking on lung cancer, where zero falls outside the confidence interval, means that the patient is at a higher risk of getting cancer.

Most of the recent causal effect estimators, e.g., \citet{louizos2017causal,shalit2017estimating,madras2019fairness}, produce point estimates without considering the uncertainty of the causal estimands. Bayesian approaches, e.g., \citet{imbens1997bayesian,daniels2012bayesian,talbot2015bayesian,gutman2018bayesian,ning2019bayesian}, on the other hand, aim to learn the posterior distributions of the causal estimands to obtain higher-order statistics that capture the uncertainty.  To derive these posterior distributions of the causal estimands, however, most, if not all, of the existing efforts involve pooling the distributed data from multiple sources centrally to compute the model \emph{marginal} likelihood, thus violating the privacy constraints mentioned above.

We propose a Bayesian framework that can learn the causal effects of interest without combining data sources to a central site, and, at the same time, learn higher-order statistics of the causal effects to understand their uncertainty. This federated learning  approach 
\citep{mcmahan2017communication} has not been well studied for causal inference. Our contributions are summarized as follows:

\begin{itemize}[leftmargin=*,noitemsep]

\item We propose the Federated Causal Inference (FedCI) \footnote{Source code: \url{https://github.com/vothanhvinh/FedCI}.} framework that fuses federated learning and causal inference to incorporate multiple data sources while maintaining the sources at their local sites. 

\begin{itemize}[noitemsep]
\item FedCI generalizes the Bayesian imputation approach \citep[][]{imbens2015causal} to a more generic model based on Gaussian processes (GPs); the resulting model is decomposed into multiple components, each of which handles a distinct data source. 

\item FedCI minimizes information transmitted among the sources, thus enabling privacy-preserving causal inference. The framework could support multiparty computation and differential privacy in the future.
\end{itemize}
\item We propose a variational approximation scheme for the proposed model, whose evidence lower bound can be decomposed additively across different data sources. This allows the parameters to be optimized via federated gradient averaging. We then leverage the computed predictive distribution to efficiently estimate the desired treatment effect quantities.

\item We empirically evaluate the proposed framework on benchmark datasets, and show its competitive performance as compared to the recent baseline approaches trained on the combined datasets.

\end{itemize}



































    
























    

































































































































  

  

\section{Related Work}
\label{sec:potential-outcomes}


\textbf{Causal inference.} In  most causal inference literature,  the estimation of causal effects is performed directly on accessible local data sources. \citet{hill2011bayesian,Alaa:2017,Alaa:2018} proposed nonparametric approaches to estimate causal effects. 
A growing literature, including \cite{shalit2017estimating,Yoon:2018ganite,yao2018representation,kunzel2019metalearners,nie2021quasi}, used parametric methods to model the potential outcomes.  \citet{louizos2017causal,madras2019fairness} used the formulation of \cite{pearl1995causal} to estimate causal effects under the existence of latent confounding variables. \citet{bica2020Estimating,bica2020time} formalized potential outcomes for temporal data with observed and unobserved confounding variables to estimate counterfactual outcomes for treatment plans. \citet{imbens1997bayesian,daniels2012bayesian,talbot2015bayesian,gutman2018bayesian,ning2019bayesian} are typical Bayesian methods that learn posterior distributions of the causal estimands. 
None of these works was proposed for the setting of multi-source data that cannot be shared and combined into a unified dataset. 
Our model, in contrast, learns treatment effects while preserving the source data at their local sites. It is different from the problem of transportability of causal relations \citep[e.g.,][]{pearl2011transportability,bareinboim2013meta,bareinboim2013causal,bareinboim2016causal,lee2020generalized}, where theoretical tools were developed to transport causal effects from a source population to a target population, which does not take into account the above data privacy constraint. 

\textbf{Federated learning.} Federated learning aims to train an algorithm across multiple decentralized clients, thus respecting the privacy of the data 
\citep{mcmahan2017communication}. Federated stochastic gradient descent \citep{shokri2015privacy} and federated averaging \citep{mcmahan2017communication} are two variations of federated learning. Recent developments of federated learning, e.g., \citet{alvarez2019non,zhe2019scalable,de2020mogptk,joukov2020fast,sattler2019robust,mohri2019agnostic}  
are formalized for a typical classification or regression problem. Also, it has recently been applied in facilitating multi-institutional collaborations without sharing patient data \citep{rieke2020future,sheller2020federated} and healthcare informatics \citep{lee2020federated,xu2021federated}. Other biomedical applications of federated learning include predicting adverse drug reactions \citep{choudhury2019predicting}, 
stroke prevention \citep{ju2020privacy}, 
mortality prediction \citep{vaid52federated}, predicting outcomes in SARS-CoV-2 patients \citep{flores2021federated}, etc. \citet{ng2022towards,gao2021federated} are notable works that estimate causal graphs in a federated setting, which is different from our work on estimating causal effects.

Following some recent works \citep[e.g.,][]{shalit2017estimating,yao2018representation,oprescu2019orthogonal,kunzel2019metalearners,nie2021quasi}, we develop a federated causal inference algorithm based on the potential outcomes framework. 
We summarize the related models in the subsequent sections.

\subsection{Potential Outcomes and the Bayesian Imputation Model}
\label{sec:rubin-imputation}
The concept of potential outcomes was proposed in \citet{neyman1923application} for randomized trial experiments. \citet{Rubin:1975,rubin:1976a,rubin1977assignment,Rubin:1978} re-formalized the framework for observational studies. We consider the causal effects of a binary treatment $w$, with $w=1$ indicating assignment to `treatment' and $w=0$ indicating assignment to `control'. Following the literature, the causal effect for individual $i$ is defined as a comparison of the two potential outcomes, $y_i(0)$ and $y_i(1)$, where these are the outcomes that would be observed under $w_i=0$ and $w_i=1$, respectively. We can never observe both $y_i(0)$ and $y_i(1)$ for any individual $i$, because it is not possible to go back in time and expose the $i$--th  individual to the other treatment. In this work, we generalize the Bayesian imputation model of \citet[][]{imbens2015causal} since it captures uncertainty of the causal estimands in a Bayesian setting: \begin{align}
    y_i(0) &= \bm{\upbeta}_0^\top\mathbf{x}_i + \epsilon_{0i}, &y_i(1) &= \bm{\upbeta}_1^\top\mathbf{x}_i + \epsilon_{1i},\label{eq:rubin-model}
\end{align}
where $\epsilon_{0i}$ and $\epsilon_{1i}$  are Gaussian noise terms. Computing treatment effects requires both $y_i(0)$ and $y_i(1)$, but only one of them is observed for each individual, so we need to impute the other. Let $y_{i,\textrm{obs}}$ and $y_{i,\textrm{mis}}$  be the observed and the unobserved (or missing) outcomes, respectively. The idea is to find the marginal distribution $\p(y_{i,\textrm{mis}}|\mathbf{y}_{\textrm{obs}},\mathbf{X},\mathbf{w})$. Once the missing outcomes are imputed, the treatment effects can be estimated. To proceed, \citet[][]{imbens2015causal} suggested four steps based on the following equation $
    \p(y_{i,\textrm{mis}}| \mathbf{y}_{\textrm{obs}},\mathbf{X},\mathbf{w}) = \int \p(y_{i,\textrm{mis}}| \mathbf{y}_{\textrm{obs}},\mathbf{X},\mathbf{w},\theta)\p(\theta|\mathbf{y}_{\textrm{obs}},\mathbf{X},\mathbf{w})d\theta$, where $\theta = \{\bm{\upbeta}_0, \bm{\upbeta}_1\}$. The aim is to find $\p(y_{i,\textrm{mis}}| \mathbf{y}_{\textrm{obs}},\mathbf{X},\mathbf{w},\theta)$ and $\p(\theta|\mathbf{y}_{\textrm{obs}},\mathbf{X},\mathbf{w})$, and then compute the integral to obtain  $\p(y_{i,\textrm{mis}}| \mathbf{y}_{\textrm{obs}},\mathbf{X},\mathbf{w})$, which is a non-parametric prediction. 
    
\emph{The above procedure shows that learning the distribution $\p(y_{i,\emph{\textrm{mis}}}| \mathbf{y}_{\emph{\textrm{obs}}},\mathbf{X},\mathbf{w})$ would require data from all sources since it is conditional on $\mathbf{y}_{\emph{\textrm{obs}}}$, $\mathbf{X}$, and $\mathbf{w}$. Thus, it violates the data privacy constraint.} 
In Sections~\ref{sec:rubin-to-gp}, \ref{sec:model} and \ref{sec:inference}, we generalize this model with Gaussian processes and decompose it into multiple components to perform federated inference of the causal effects, which minimizes the risk of privacy leak of the data.
    
    %
 

\section{Our Approach}

We generalize the Bayesian imputation model presented in Section~\ref{sec:rubin-imputation} to a generic model based on Gaussian Processes (GPs). 
We introduce the Federated Causal Inference (FedCI) method to decompose the model into multiple components, each associated with a data source,  to estimate causal effects under a federated setting.


\subsection{Problem Formulation}
\label{sec:prob-formu}
\textbf{Problem setting \& notations.} Suppose we have $m$ data sources that are organized and curated at their local sites. Each source is denoted by $\mathsf{D}^\mathsf{s} = \{( w_i^\mathsf{s}, y_{i,\textrm{obs}}^\mathsf{s}, \mathbf{x}_i^\mathsf{s})\}_{i=1}^{n_\mathsf{s}}$, where $\mathsf{s}=1,2,\dots,m$, and the quantities $w_i^\mathsf{s}$, $y_{i,\textrm{obs}}^\mathsf{s}$ and  $\mathbf{x}_i^\mathsf{s}$ are the treatment assignment, observed outcome associated with the treatment, and covariates of individual $i$ in source $\mathsf{s}$, respectively. In this work, we focus on binary treatment $w_i^\mathsf{s} \in \{0,1\}$, thus $y_{i,\textrm{obs}}^\mathsf{s}$ can be either of the potential outcomes $y_i^\mathsf{s}(0)$ or $y_i^\mathsf{s}(1)$, i.e., for each individual $i$, we can only observe either $y_i^\mathsf{s}(0)$ or  $y_i^\mathsf{s}(1)$, but not both of them. We further denote the unobserved or missing outcome as $y_{i,\textrm{mis}}^\mathsf{s}$. These variables are related to each other through the following equations:
\begin{align}
y_i^\mathsf{s}(1) &= w_i^\mathsf{s} y^\mathsf{s}_{i,\textrm{obs}} + (1-w_i^\mathsf{s})y^\mathsf{s}_{i,\textrm{mis}}, \label{eq:y1-ymis-yobs}\\
y_i^\mathsf{s}(0) &= (1-w_i^\mathsf{s}) y^\mathsf{s}_{i,\textrm{obs}} + w_i^\mathsf{s}y^\mathsf{s}_{i,\textrm{mis}}\label{eq:y0-ymis-yobs}.
\end{align}
Thus, $y_i^\mathsf{s}(1) = y^\mathsf{s}_{i,\textrm{obs}}$ when $w_i^\mathsf{s}=1$ and $y_i^\mathsf{s}(1) = y^\mathsf{s}_{i,\textrm{mis}}$ when $w_i^\mathsf{s}=0$, and similarly for $y_i^\mathsf{s}(0)$. 
For notational convenience, we further denote $    \mathbf{y}^\mathsf{s}(0) = [y_1^\mathsf{s}(0),\!...,y_{n_\mathsf{s}}^\mathsf{s}(0)]^\top$,  $\mathbf{y}^\mathsf{s}_{\textrm{obs}} = [y^\mathsf{s}_{1,\textrm{obs}},\!...,y^\mathsf{s}_{n_\mathsf{s},\textrm{obs}}]^\top$, 
and similarly for $\mathbf{y}^\mathsf{s}(1)$, $\mathbf{y}^\mathsf{s}_{\textrm{mis}}$, $\mathbf{X}^\mathsf{s}$ and $\mathbf{w}^\mathsf{s}$.




\textbf{Causal effects of interest.} We estimate the individual treatment effect (\textrm{ITE})\footnote{Also known as conditional average treatment effect (CATE). } and the average treatment effect (\textrm{ATE}) defined as follows:
\begin{align}
\uptau_i^\mathsf{s} &\vcentcolon= y_i^\mathsf{s}(1) - y_i^\mathsf{s}(0), &\uptau &\textstyle\vcentcolon= \sum_{\mathsf{s}=1}^m\sum_{i=1}^{n_\mathsf{s}}\uptau_i^\mathsf{s}/n,\label{eq:ate}
\end{align}
where $y_i^\mathsf{s}(1)$, $y_i^\mathsf{s}(0)$ are the realized outcomes of the corresponding random variables, and $n = \sum_{\mathsf{s}=1}^m n_\mathsf{s}$. 


\textbf{The causal estimands.} Inserting Eq.~(\ref{eq:y1-ymis-yobs})~and~(\ref{eq:y0-ymis-yobs}) into (\ref{eq:ate}), we obtain the estimate of ITE:
\begin{align}
&\e[\uptau^\mathsf{s}_i]
=(2w_i^\mathsf{s}-1) (y^\mathsf{s}_{i,\textrm{obs}} - \e\big[y^\mathsf{s}_{i,\textrm{mis}}\big| \mathbf{y}_{\textrm{obs}}, \mathbf{X}, \mathbf{w}\big]),\label{eq:ite-hat-expectation}\\ 
&\V\text{ar}[\uptau^\mathsf{s}_i] =(2w_i^\mathsf{s}-1)^2\mathbb{V} \text{ar}\left[y^\mathsf{s}_{i,\textrm{mis}}\big| \mathbf{y}_{\textrm{obs}}, \mathbf{X}, \mathbf{w}\right],
\end{align}
where $\mathbf{y}_{\textrm{obs}}$, $\mathbf{X}$, $\mathbf{w}$ denote the vectors/matrices of the observed outcomes, covariates and treatments concatenated from all the sources. The estimate of ATE is as follows:
\begin{align}
&\!\!\!\!\e[\uptau] \!=\!(2\mathbf{w}-\mathbf{1})^\top\!(\mathbf{y}_{\textrm{obs}} \!-\! \e[\mathbf{y}_{\text{mis}}| \mathbf{y}_{\textrm{obs}}, \mathbf{X}, \mathbf{w}])/n,\\
&\!\!\!\!\V\text{ar}[\uptau] \!=\!(2\mathbf{w}\!-\!\mathbf{1})^\top\!\mathbb{C} \text{ov}[\mathbf{y}_{\textrm{mis}}| \mathbf{y}_{\textrm{obs}}, \mathbf{X}, \mathbf{w}](2\mathbf{w}\!-\!\mathbf{1})/n^2,\label{eq:tau-hat-variance}
\end{align}
where $\mathbf{1}$ is a vector of ones. 

Hence, the remaining task is to learn the posterior $\p(\mathbf{y}_{\textrm{mis}}\big| \mathbf{y}_{\textrm{obs}}, \mathbf{X}, \mathbf{w})$, which is the predictive distribution of $\mathbf{y}_{\textrm{mis}}$ given \emph{all} the covariates, treatments and observed outcomes from \emph{all} sources. 



\subsection{Assumptions}
\label{sec:assumptions}
The following assumptions are made to enable federated causal estimations: \begin{assumption}
[Strong Ignorability] \label{assumption:ignorability}  \emph{($i$)} The potential outcomes are independent of the treatment assignment conditional on the covariates \emph{(}unconfoundedness\emph{)}, i.e., $ y_i^\mathsf{s}(1),y_i^\mathsf{s}(0)\!\ind\! w_i^\mathsf{s}|\mathbf{x}_i^\mathsf{s}$, and \emph{($ii$)} every individual has some positive probability to be assigned to every treatment \emph{(}positivity\emph{)}, i.e., $0 < \p(w_i^\mathsf{s}\!=\!1|\mathbf{x}_i^\mathsf{s}) < 1$. \citep{rosenbaum1983central}
\end{assumption}
\begin{assumption} [Stable Unit Treatment Value Assumption or SUTVA]
\label{assumption:sutva}
\emph{($i$)} The potential outcomes for any individual do not vary with the treatments assigned to other individuals, and \emph{($ii$)} there are no different forms or versions of each treatment level, which would lead to different potential outcomes. 
\citep{imbens2015causal}
\end{assumption}
\begin{assumption} \label{assumption:share-covariates}
The individuals from all sources share the same set of common covariates.
\end{assumption}
\begin{assumption}\label{assumption:unique-ident} There exists a set of features such that any individual is uniquely identified across different sources. We refer to this set as `primary key'.
\end{assumption}
\begin{assumption}
\label{assumption:homogeneous-heterogeneous}
Data in different sources are drawn from parts of the population. The multi-source data, which may be homogeneous or heterogeneous in nature, together reflect the characteristics of the population.
\end{assumption}
Assumptions~\ref{assumption:ignorability} and~\ref{assumption:sutva} are standard in causal inference, as discussed in, e.g., \citet{imbens2015causal,shalit2017estimating}. Assumption~\ref{assumption:share-covariates} is reasonable, e.g., decentralized data in \citet{choudhury2019predicting,vaid52federated,flores2021federated} (to name a few) satisfy this assumption for federated learning. 
In Assumption~\ref{assumption:unique-ident}, a `primary key' is not limited to the observed data used for inference as described in Section~\ref{sec:prob-formu}, but it can include any features to uniquely identify an individual, such as $\{\text{nationality, national id}\}$ of a patient.  
Assumption~\ref{assumption:unique-ident} allows a secure preprocessing procedure to remove repeated individual records in different sources, if necessary, without sharing raw data among the sources (see Appendix~\zref{sec:appendix-preprocessing} for details). Assumption~\ref{assumption:homogeneous-heterogeneous} ensures that there is no imbalanced data bias across the sources. In the subsequent sections, we assume that all of the above assumptions 
are satisfied, and the preprocessing procedure is already performed if necessary.

\subsection{GP-based Imputation}
\label{sec:rubin-to-gp}
The model presented in Eq.~(\ref{eq:rubin-model}) is a simple Bayesian linear model. In this section, we present a more general nonlinear Bayesian model. In particular, since $\bm{\upbeta}_0$ and $\bm{\upbeta}_1$ in Eq.~(\ref{eq:rubin-model}) follow multivariate normal distributions, the two components $\bm{\upbeta}_0^\top\mathbf{x}_i$ and $\bm{\upbeta}_1^\top\mathbf{x}_i$ are also (jointly) normally distributed. The generalisations of these two components are $f_0(\mathbf{x}_i) = \bm{\upbeta}_0^\top\omega(\mathbf{x}_i)$ and $f_1(\mathbf{x}_i) = \bm{\upbeta}_1^\top\omega(\mathbf{x}_i)$, where $\omega(\mathbf{x}_i)$ is a vector of basis functions with input $\mathbf{x}_i$. This formulation implies that, after marginalizing out $\bm{\upbeta}_0$ and $\bm{\upbeta}_1$, $f_0(\mathbf{x})$ and $f_1(\mathbf{x})$ are Gaussian processes (GPs). Thus, we propose:\begin{align}
    \!\!\!\!\!y_i(0) &= f_0(\mathbf{x}_i) + \epsilon_{0i}, &y_i(1) &= f_1(\mathbf{x}_i) + \epsilon_{1i},\label{eq:gp-model}
\end{align}
where $f_0(\mathbf{x}_i)$ and $f_1(\mathbf{x}_i)$ are two random functions evaluated at $\mathbf{x}_i$, i.e., $f_0(\mathbf{x}_i) \sim \mathsf{GP}(\mu_0(\mathbf{X}), \mathbf{K})$ and $f_1(\mathbf{x}_i) \sim \mathsf{GP}(\mu_1(\mathbf{X}), \mathbf{K})$, where $\mathbf{K}$ denotes the covariance matrix computed with a kernel function $\mathsf{k}(\mathbf{x},\mathbf{x}')$. Similar to the imputation model of \citet{imbens2015causal}, 
this model also requires finding the marginal distribution $\p(y_{i,\textrm{mis}}\,|\,\mathbf{y}_{\textrm{obs}},\mathbf{X},\mathbf{w})$, \emph{through accessing the observed data from all the sources.} 

\emph{Similarly, although this model is generic, it requires access to all the observed data to compute $\mathbf{K}$, which is impossible without violating the privacy constraints mentioned above.} 
In the subsequent sections, we propose a federated model to address this problem. 

\subsection{The Proposed Model}
\label{sec:model}
Recall that the aim is to find $\p(\mathbf{y}_{\textrm{mis}}\,|\, \mathbf{y}_{\textrm{obs}}, \mathbf{X}, \mathbf{w})$ so that we may in turn compute Eqs.~(\ref{eq:ite-hat-expectation})--(\ref{eq:tau-hat-variance}) to arrive at the quantities of interest. To that end, we propose to model the joint distribution of the potential outcomes as follows:
\begin{align}
\begin{bmatrix}
	y_i^\mathsf{s}(0)\\
	y_i^\mathsf{s}(1)
	\end{bmatrix} = \Psi^{\frac{1}{2}}\left(\begin{bmatrix}
	f_0^\mathsf{s}(\mathbf{x}_i)\\
	f_1^\mathsf{s}(\mathbf{x}_i)
	\end{bmatrix} + \begin{bmatrix}
	g_0^\mathsf{s}\\
	g_1^\mathsf{s}
	\end{bmatrix}\right) + \Sigma^{\frac{1}{2}}\bm{\upvarepsilon}_i^\mathsf{s},
\label{eq:the-model}
\end{align}
where $\bm{\upvarepsilon}_i^\mathsf{s} \sim \mathsf{N}(\mathbf{0}, \mathbf{I}_2)$  is to handle the noise of the outcomes. 

As mentioned in Sections~\ref{sec:rubin-imputation}~and~\ref{sec:rubin-to-gp}, all the outcomes from all sources are \emph{interdependent} in the Bayesian imputation approach, which is problematic for federated learning. This dependency is handled via $f_j^\mathsf{s}(\mathbf{x}_i)$ and $g_j^\mathsf{s}$ ($j\in\{0,1\}$), which enable federated learning for the proposed model. We refer to the dependency handled by $f_j^\mathsf{s}(\mathbf{x}_i)$ as intra-dependency and the one captured by $g_j^\mathsf{s}$ as inter-dependency. 

\textbf{Intra-dependency.} 
$f_0^\mathsf{s}(\mathbf{x}_i)$ and $f_1^\mathsf{s}(\mathbf{x}_i)$ are GP-distributed functions, which allows us to model each source dataset simultaneously along with its heterogeneous correlations. Specifically, we model $f_0^\mathsf{s}(\mathbf{x}_i) \sim \mathsf{GP}(\mu_0(\mathbf{X}^\mathsf{s}), \mathbf{K}^\mathsf{s})$ and $f_1^\mathsf{s}(\mathbf{x}_i) \sim \mathsf{GP}(\mu_1(\mathbf{X}^\mathsf{s}), \mathbf{K}^\mathsf{s})$, where $\mathbf{K}^\mathsf{s}$ is a covariance matrix computed by a kernel function $\mathsf{k}(\mathbf{x}_i^\mathsf{s}, \mathbf{x}_j^\mathsf{s})$, and $\mu_0(\cdot)$, $\mu_1(\cdot)$ are functions modelling the mean of these GPs. Parameters of these functions and hyperparameters in the kernel function are shared across multiple sources. 
The above GPs handle the correlations within one source only. 

\textbf{Inter-dependency.} 
To capture \textit{dependency} among the sources, we introduce variable $\mathbf{g} = [\mathbf{g}_0, \mathbf{g}_1]$, where $\mathbf{g}_0 = [g_0^{1},\!...,g_0^{m}]^\top \sim \mathsf{N}(\bm{r}_0, \mathbf{M})$ and $\mathbf{g}_1 = [g_1^{1},\!...,g_1^{m}]^\top \sim \mathsf{N}(\bm{r}_1, \mathbf{M})$. 
Both $g_0^\mathsf{s}$ and $g_1^\mathsf{s}$ are shared within the source $\mathsf{s}$, and they are correlated across multiple sources $\mathsf{s} \in \{1,\!...,m\}$. The correlations among the sources are modelled via the covariance matrix $\mathbf{M}$ 
which is computed with a kernel function.  The inputs to the kernel function are the sufficient statistics (we used mean, variance, skewness, and kurtosis) of each covariate $\mathbf{x}^\mathsf{s}$ within the source $\mathsf{s}$. We denote the first four moments of covariates as $\mathbf{\widetilde{x}}^\mathsf{s} \in \mathbb{R}^{4 d_x \times 1}$ and the kernel function as $\gamma(\mathbf{\widetilde{x}}^\mathsf{s}, \mathbf{\widetilde{x}}^{\mathsf{s}'})$, which evaluates the correlation of two sources $\mathsf{s}$ and $\mathsf{s}'$. This formulation implies that $\mathbf{g}_0$ and $\mathbf{g}_1$ are GPs. The elements of $\bm{r}_0$ and $\bm{r}_1$ are computed with the mean functions $r_0(\mathbf{\widetilde{x}}^\mathsf{s})$ and $r_1(\mathbf{\widetilde{x}}^\mathsf{s})$, respectively. Herein, we only share the sufficient statistics of covariates, but not covariates of a specific individual. 

\textbf{The two variables $\Psi$ and $\Sigma$.} 
These are positive semi-definite matrices capturing the correlations between the two possible outcomes $y_i^\mathsf{s}(0)$ and $y_i^\mathsf{s}(1)$; $\Psi^{\frac{1}{2}}$ and $\Sigma^{\frac{1}{2}}$ are their Cholesky decomposition matrices. Note that $\Psi$ and $\Sigma$ are also random variables. Since these are positive semi-definite matrices, we model their priors using Wishart distributions, $\Psi  \sim \mathsf{Wishart}(\mathbf{V}_0, d_0)$ and $\Sigma \sim \mathsf{Wishart}(\mathbf{S}_0, n_0)$, where $\mathbf{V}_0, \mathbf{S}_0 \in \mathbb{R}^{2\times 2}$ are predefined positive semi-definite matrices and $d_0, n_0 \ge 2$ are predefined degrees of freedom.















\subsection{The Proposed Algorithm}
\label{sec:inference}
Based on some results on the joint distribution of potential outcomes, we construct a federated objective function for the proposed federated causal inference algorithm (FedCI).

\subsubsection{Joint Distribution of the Outcomes}

We first present some results that are helpful in constructing the federated objective function in Section~\ref{sec:objective-function}. 
The proofs of these results are in the appendices. To simplify the exposition, we denote $\mathbf{g}^\mathsf{s} = [\mathbf{g}_0^\mathsf{s}, \mathbf{g}_1^\mathsf{s}]$, where $\mathbf{g}_0^\mathsf{s} = [g_0^\mathsf{s},\!..., g_0^\mathsf{s}]^\top$ and $\mathbf{g}_1^\mathsf{s} = [g_1^\mathsf{s},\!..., g_1^\mathsf{s}]^\top$. 

\begin{lemma}
\label{lem:joint-prob-s}
Let $\Psi$, $\Sigma$, $\mathbf{K}^\mathsf{s}$, $\mu_0(\mathbf{X}^\mathsf{s})$, $\mu_1(\mathbf{X}^\mathsf{s})$, and $\mathbf{g}^\mathsf{s}$  satisfy the model in Eq.~\emph{(\ref{eq:the-model})}. Then,
\begin{align*}
&\begin{bmatrix}
\mathbf{y}^\mathsf{s}(0)\\
\mathbf{y}^\mathsf{s}(1)
\end{bmatrix}\Big|\Psi, \Sigma, \mathbf{X}^\mathsf{s}, \mathbf{w}^\mathsf{s}, \mathbf{g}^\mathsf{s} \\[-0.1cm]
&\sim \mathsf{N}\!\left( \left(\Psi^{\frac{1}{2}} \otimes \mathbf{I}_{n_\mathsf{s}}\right)\begin{bmatrix}\mu_0(\mathbf{X}^\mathsf{s}) + \mathbf{g}_0^\mathsf{s}\\\mu_1(\mathbf{X}^\mathsf{s}) + \mathbf{g}_1^\mathsf{s}
\end{bmatrix} \!\!, \Psi \otimes \mathbf{K}^\mathsf{s} + \Sigma \otimes \mathbf{I}_{n_\mathsf{s}}\right)\!,
\end{align*}
where $\otimes$ is the Kronecker product.
\end{lemma}
The proof of Lemma~\ref{lem:joint-prob-s} is presented in Appendix~\zref{sec:appendix-proof-lem-1}. From Lemma~\ref{lem:joint-prob-s}, we observe that $\Psi$, $\mathbf{K}^\mathsf{s}$, $\Sigma$, and $\mathbf{I}_{n_\mathsf{s}}$ are positive semi-definite, thus the covariance matrix $\Psi \otimes \mathbf{K}^\mathsf{s} + \Sigma \otimes \mathbf{I}_{n_\mathsf{s}}$ is positive semi-definite due to the fundamental property of Kronecker product. This is why we choose $\Psi$ and $\Sigma$ to be positive semi-definite in our model; otherwise, the covariance matrix is invalid. From Lemma~\ref{lem:joint-prob-s}, we can obtain the result in Lemma~\ref{lem:3} as follows:
\begin{lemma}
\label{lem:3}
Let $\Psi$, $\Sigma$, $\mathbf{K}^\mathsf{s}$, $\mu_0(\mathbf{X}^\mathsf{s})$, $\mu_1(\mathbf{X}^\mathsf{s})$, and $\mathbf{g}^\mathsf{s}$  satisfy the model in Eq.~\emph{(\ref{eq:the-model})}. Then,
\begin{align*}
&\!\!\!\begin{bmatrix}
\mathbf{y}^\mathsf{s}_{\emph{\textrm{obs}}}\\
\mathbf{y}^\mathsf{s}_{\emph{\textrm{mis}}}
\end{bmatrix}\!\!\Big|\Psi,\! \Sigma,\! \mathbf{X}^\mathsf{s}\!, \mathbf{w}^\mathsf{s}\!, \mathbf{g}^\mathsf{s} \!\sim\! \mathsf{N}\!\left(\! \begin{bmatrix}\mu_{\emph{\textrm{obs}}}(\mathbf{X}^\mathsf{s})\\\mu_{\emph{\textrm{mis}}}(\mathbf{X}^\mathsf{s})
\end{bmatrix} \!\!,\! \begin{bmatrix}
\mathbf{K}_{\emph{\textrm{obs}}}^\mathsf{s}&\mathbf{K}_{\emph{\textrm{om}}}^\mathsf{s}\\
(\mathbf{K}_{\emph{\textrm{om}}}^\mathsf{s})^\top&\mathbf{K}_{\emph{\textrm{mis}}}^\mathsf{s}\end{bmatrix}\!\right)\!\!.
\end{align*}
The mean functions $\mu_{\emph{\textrm{obs}}}(\mathbf{X}^\mathsf{s})$ and $\mu_{\emph{\textrm{mis}}}(\mathbf{X}^\mathsf{s})$ are:
\begin{align*}
\mu_{\emph{\textrm{obs}}}(\mathbf{X}^\mathsf{s}) &= (\mathbf{1} - \mathbf{w}^\mathsf{s})\odot\mathbf{m}_0 + \mathbf{w}^\mathsf{s} \odot\mathbf{m}_1,\\[-0.1cm]
\mu_{\emph{\textrm{mis}}}(\mathbf{X}^\mathsf{s}) &= \mathbf{w}^\mathsf{s} \odot\mathbf{m}_0 + (\mathbf{1} - \mathbf{w}^\mathsf{s}) \odot\mathbf{m}_1,
\end{align*}
where we denote $\mathbf{m}_0 = \psi_{11}^\ast(\mu_0(\mathbf{X}^\mathsf{s}) + \mathbf{g}_0^\mathsf{s})$ and $\mathbf{m}_1 = \psi_{21}^\ast(\mu_0(\mathbf{X}^\mathsf{s}) + \mathbf{g}_0^\mathsf{s}) + \psi_{22}^\ast(\mu_1(\mathbf{X}^\mathsf{s}) + \mathbf{g}_1^\mathsf{s})$, where $\psi^\ast_{ij}$ is the $(i,j)$--th element of the Cholesky decomposition matrix of $\Psi$, $\mathbf{1}$ is a vector of ones, and $\odot$ is the element-wise product. The covariance matrices $\mathbf{K}^\mathsf{s}_{\textrm{\emph{\textrm{obs}}}}$, $\mathbf{K}^\mathsf{s}_{\textrm{\emph{\textrm{mis}}}}$, and $\mathbf{K}^\mathsf{s}_{\textrm{\emph{\textrm{om}}}}$ are computed by kernel functions:
\begin{align*}
k_{\emph{\textrm{obs}}}(\mathbf{x}_i, \mathbf{x}_j) \!&=\! \big[(1-w_i)(1-w_j)\psi_{11} \!+\! w_iw_j\psi_{22} \\&\,\,\,+\! (1-w_i)w_j\psi_{12} \!+\! w_i(1-w_j)\psi_{21}\big] \mathsf{k}(\mathbf{x}_i, \mathbf{x}_j)\\&\,\,\,\,\, +\! \big[(1-w_i)\sigma_{11} \!+\! w_i\sigma_{22}\big] \mathds{1}_{i=j},\\k_{\emph{\textrm{mis}}}(\mathbf{x}_i, \mathbf{x}_j) \!&=\! \big[w_iw_j\psi_{11} \!+\! (1-w_i)(1-w_j)\psi_{22} \\&\,\,\, +\! (1-w_i)w_j\psi_{21} \!+\! w_i(1-w_j)\psi_{12}\big] \mathsf{k}(\mathbf{x}_i, \mathbf{x}_j)\\& \,\,\,\,\,+\! \big[w_i\sigma_{11} \!+\! (1-w_i)\sigma_{22}\big] \mathds{1}_{i=j},
\\k_{\emph{\textrm{om}}}(\mathbf{x}_i, \mathbf{x}_j) &\!=\! \big[(1-w_i)(1-w_j)\psi_{21} \!+\! w_iw_j\psi_{12} \\&\,\,\,+\! (1-w_i)w_j\psi_{22} \!+\! w_i(1-w_j)\psi_{11}\big] \mathsf{k}(\mathbf{x}_i, \mathbf{x}_j) \\&\,\,\,\,\,+\! \big[(1-w_i)\sigma_{21} \!+\! w_i\sigma_{12}\big] \mathds{1}_{i=j},
\end{align*}
where $\psi_{ab}$ and $\sigma_{ab}$ are the $(a,b)$--th elements of $\Psi$ and $\Sigma$, respectively.
\end{lemma}

The proof of Lemma~\ref{lem:3} is in Appendix~\zref{sec:appendix-proof-lem-2}. Lemma~\ref{lem:3} has two important roles in our work. First, we can obtain the conditional likelihood to help infer the parameters and hyperparameters of our proposed model. Second, we can also obtain the posterior of $\mathbf{y}^\mathsf{s}_\textrm{mis}$ to help us estimate ITE and ATE.



\subsubsection{Federated Objective Function}
\label{sec:objective-function}
The proposed model in Eq.~(\ref{eq:the-model}) would lead to an objective function that can be decomposed into $m$ components, each associated with a data source. Since estimating $\p(\mathbf{y}^\mathsf{s}_{\textrm{mis}}\,\big|\, \mathbf{y}^\mathsf{s}_{\textrm{obs}}, \mathbf{X}^\mathsf{s}, \mathbf{w}^\mathsf{s})$ exactly is intractable, we sidestep this intractability via a variational approximation.  To achieve this, we maximize the following evidence lower bound (ELBO) $\mathbf{L}$:
\begin{align}
    \log\p(\mathbf{y}_{\textrm{obs}}\,|\,\mathbf{X},\mathbf{w}) &= \log \!\int\!\p(\mathbf{y}_{\textrm{obs}},\mathbf{g}, \Psi, \Sigma\,|\,\mathbf{X},\mathbf{w}) d\mathbf{g}d\Psi d\Sigma\nonumber\\
    &\ge \sum_{\mathsf{s}=1}^m \mathbf{L}^\mathsf{s} =\vcentcolon\mathbf{L},\label{eq:loss}
\end{align}
where \begin{align*}
\!\mathbf{L}^\mathsf{s} &= \e_q\Big[\! \log\p(\mathbf{y}^\mathsf{s}_{\textrm{obs}} | \cdot)\Big]  \!-\!\frac{1}{m}\Big(\!\sum_{z \in \{\mathbf{g}, \Psi, \Sigma\}}\!\D_{\text{KL}}[\q(z)\|\p(z)]\Big).\end{align*}
Herein, $\D_{\text{KL}}[\cdot]$ is the Kullback–Leibler divergence. Details of the ELBO are presented in Appendix~\zref{sec:appendix-elbo}. The conditional likelihood $\p(\mathbf{y}^\mathsf{s}_{\textrm{obs}} | \cdot)$ is obtained from Lemma~\ref{lem:3} by marginalizing out $\mathbf{y}^\mathsf{s}_{\textrm{mis}}$, i.e.,
\begin{align}
    \!\p(\mathbf{y}^\mathsf{s}_{\textrm{obs}} | \mathbf{X}^\mathsf{s}, \mathbf{w}^\mathsf{s}, \Psi, \Sigma, \mathbf{g}^\mathsf{s}) = \mathsf{N}(\mathbf{y}^\mathsf{s}_{\textrm{obs}};\mu_{\textrm{obs}}(\mathbf{X}^\mathsf{s}), \mathbf{K}_{\textrm{obs}}^\mathsf{s}).
\end{align}
The above conditional likelihood is free of $\sigma_{21}$ and $\sigma_{12}$, which capture the correlation of the two potential outcomes. Thus the posteriors of these variables would coincide with their priors, i.e., the correlation cannot be learned from data and is instead set through the prior. This is well known, since one of the two potential outcomes can never be observed \citep{imbens2015causal}. 
In Eq.~(\ref{eq:loss}), the ELBO $\mathbf{L}$ is derived from the joint marginal likelihood of all $m$ sources, and it is factorized into $m$ components $\mathbf{L}^\mathsf{s}$, each of which corresponds to a source. This enables federated optimization of $\mathbf{L}$. The first term of $\mathbf{L}^\mathsf{s}$ is the expectation of the conditional log-likelihood with respect to the variational posterior $q(\mathbf{g}, \Psi, \Sigma)$; thus, this distribution is learned from data of all the sources. In the following, we present its factorization. 

\textbf{Variational posterior distributions.} 
We use the typical mean-field approximation to factorize among the variational posteriors $
\q(\Psi, \Sigma, \mathbf{g}) =\q(\Psi)\,\q(\Sigma)\,\q(\mathbf{g})$. Let $\mathbf{\widetilde{y}}_{\textrm{obs}}^\mathsf{s}(0)$, $\mathbf{\widetilde{y}}_{\textrm{obs}}^\mathsf{s}(1)$, $\mathbf{\widetilde{x}}^{\mathsf{s}}$, and $\mathbf{\widetilde{w}}^\mathsf{s}$ ($\mathsf{s} = 1,2,\!...,m$) be the first four moments of the observed outcomes, covariates, and treatment of the $\mathsf{s}$--th source. Let $\mathbf{\widetilde{X}} = [\mathbf{\widetilde{x}}^{1},\!...,\mathbf{\widetilde{x}}^{m}]^\top$, $\mathbf{\widetilde{y}}_{\textrm{obs}}(0) = [\mathbf{\widetilde{y}}_{\textrm{obs}}^{1}(0),\!...,\mathbf{\widetilde{y}}_{\textrm{obs}}^{m}(0)]^\top$, $\mathbf{\widetilde{y}}_{\textrm{obs}}(1) = [\mathbf{\widetilde{y}}_{\textrm{obs}}^{1}(1),\!...,\mathbf{\widetilde{y}}_{\textrm{obs}}^{m}(1)]^\top$, and $\mathbf{\widetilde{w}} = [\mathbf{\widetilde{w}}^{1},\!...,\mathbf{\widetilde{w}}^{m}]^\top$. We parameterize 
\begin{align*}
    \q(\mathbf{g}) =\textstyle \prod_{j\in\{0,1\}}\mathsf{N}(\mathbf{g}_j;h_j(\mathbf{\widetilde{y}}_{\textrm{obs}}(0), \mathbf{\widetilde{y}}_{\textrm{obs}}(1), \mathbf{\widetilde{X}}, \mathbf{\widetilde{w}}), \mathbf{U}),
\end{align*}
where $h_0(\cdot)$ and $h_1(\cdot)$ are the mean functions, $\mathbf{U}$ is the covariance matrix computed with a kernel function $\kappa(u^\mathsf{s}, u^{\mathsf{s}'})$, where $u^\mathsf{s}\vcentcolon= [\mathbf{\widetilde{y}}^\mathsf{s}_{\textrm{obs}}(0), \mathbf{\widetilde{y}}^\mathsf{s}_{\textrm{obs}}(1), \mathbf{\widetilde{x}}^\mathsf{s}, \mathbf{\widetilde{w}}^\mathsf{s}]$. 

Since $\Psi$ and $\Sigma$ are positive semi-definite matrices, we model their variational posteriors as Wishart distributions: 
\begin{align*}
    	\q(\Psi) \!=\! \mathsf{Wishart}(\Psi;\mathbf{V}_q, d_q),\quad \q(\Sigma) \!=\! \mathsf{Wishart}(\Sigma;\mathbf{S}_q, n_q),
\end{align*}
where $d_q, n_q$ are degrees of freedom, $\mathbf{V}_q, \mathbf{S}_q$ are the scale matrices. We set the form of these scale matrices as follows
\begin{align*}
\mathbf{V}_q &= \begin{bmatrix}
\nu_{1}^2&\rho\nu_{1}\nu_2\\
\rho\nu_{1}\nu_2&\nu_{2}^2
\end{bmatrix}, &\mathbf{S}_q &= \begin{bmatrix}
\delta_{1}^2&\eta\delta_{1}\delta_2\\
\eta\delta_{1}\delta_2&\delta_{2}^2
\end{bmatrix},
\end{align*}
where $\nu_i, \rho, \delta_i, \eta$ are parameters to be learned and $\rho, \eta \in [0,1]$.


\textbf{Reparameterization.} To maximize the ELBO, we approximate the expectation in $\mathbf{L}^\mathsf{s}$ with Monte Carlo integration, which requires drawing samples of $\mathbf{g}$, $\Psi$ and $\Sigma$ from their variational distributions. This requires a reparameterization to allow the gradients to pass through the random variables $\mathbf{g}$, $\Psi$ and $\Sigma$. 
The reparameterization trick for $\mathbf{g}$ is:
$
\mathbf{g}_j = h_j(\mathbf{\widetilde{y}}_{\textrm{obs}}(0), \mathbf{\widetilde{y}}_{\textrm{obs}}(1), \mathbf{\widetilde{X}}, \mathbf{\widetilde{w}}) + \mathbf{U}^{\frac{1}{2}} \bm{\xi}_j, j\in\{0,1\}$, 
where $\bm{\xi}_j \sim \mathsf{N}(\bm{0},\mathbf{I}_m)$ and $\mathbf{U}^{\frac{1}{2}}$ is the Cholesky decomposition matrix of $\mathbf{U}$.
Since $\q(\Psi)$ is a Wishart distribution, we introduce the following procedure to draw $\Psi$: $\Psi = \mathbf{V}_q^{\frac{1}{2}} \bm{\zeta} ( \mathbf{V}_q^{\frac{1}{2}})^\top, \bm{\zeta} \sim \mathsf{Wishart}(\mathbf{I}_2, d_q)$, 
where $\mathbf{V}_q^{\frac{1}{2}}$ is the Cholesky decomposition matrix of $\mathbf{V}_q$. Likewise, we also apply this procedure to draw $\Sigma$.

\textbf{Federated optimization algorithm.} 
With the above model and its objective function, we can compute gradients of the learnable parameters separately in each source without sharing data to a central server. We summarize our procedure in Algorithm~\ref{algo:maximize-elbo}. 









\subsubsection{Predicting Causal Effects from Multiple Sources}
To understand why data from all the sources can help predict causal effects in a source $\mathsf{s}$, we observe that \begin{align}
    &\p(\mathbf{y}^\mathsf{s}_\textrm{mis}\,\big| \mathbf{y}_\textrm{obs}, \mathbf{X}, \mathbf{w}) \label{eq:predictive-dist}\\
    &\simeq \e_{\q} \big[p(\mathbf{y}^\mathsf{s}_\textrm{mis}\big| \mathbf{y}^\mathsf{s}_\textrm{obs},\mathbf{X}^\mathsf{s}, \mathbf{w}^\mathsf{s}, \Psi, \Sigma, \mathbf{g})\big] \nonumber\\[-0.1cm]
&= p(\mathbf{y}^\mathsf{s}_\text{mis}\,\big| \underbrace{\addstackgap[1.3pt]{$\mathbf{y}^\mathsf{s}_\text{obs}, \mathbf{X}^\mathsf{s}, \mathbf{w}^\mathsf{s}$}}_\textbf{(i)}, \underbrace{\addstackgap[4.6pt]{$\Theta $}}_\textbf{(ii)}, \underbrace{\addstackgap[2pt]{$\mathbf{\widetilde{y}}_{\textrm{obs}}(0), \mathbf{\widetilde{y}}_{\textrm{obs}}(1), \mathbf{\widetilde{X}}, \mathbf{\widetilde{w}}$}}_\textbf{(iii)}).\nonumber
\end{align}
Eq.~(\ref{eq:predictive-dist}) is an approximation of the predictive distribution of the missing outcomes $\mathbf{y}^\mathsf{s}_\textrm{mis}$ and it depends on the following three components:
\begin{enumerate}[noitemsep,topsep=0pt,leftmargin=*,label=\textbf{(\roman*).}]
    \item The observed outcomes, covariates and treatment assignments from the same source $\mathsf{s}$.
    \item The shared parameters $\Theta$ learned from data of all the sources.
    \item Sufficient statistics of the observed data from all the sources.\end{enumerate}
The last two components \textbf{(ii)} and \textbf{(iii)} indicate that the predictive distribution in source $\mathsf{s}$ utilizes knowledge from all the sources through $\Theta$ and the sufficient statistics $[\mathbf{\widetilde{y}}_{\textrm{obs}}(0), \mathbf{\widetilde{y}}_{\textrm{obs}}(1), \mathbf{\widetilde{X}}, \mathbf{\widetilde{w}}]$. 
This explains why data from all of the sources help predict missing outcomes in source $\mathsf{s}$.


{\IncMargin{1.2em}
\begin{algorithm}\setstretch{0}
\caption{Federated causal inference}
\footnotesize
\label{algo:maximize-elbo}
\SetKwInOut{Parameter}{Parameters}
\Parameter{Let $\Theta$ be the set of parameters}
\Begin{
 		Initialize $\Theta$ and send to all source machines\;
	    \Repeat{\emph{stopping condition}}{
    		\For{\emph{ source machine} $\mathsf{s} \in \{1,2,\dots,m\}$}{
    			Compute $\nabla_\Theta \mathbf{L}^\mathsf{s}$ and send to server\;
    		}
    		In the central server, do the following steps:\\
    		\Begin{
    		Collect gradients from all sources\;
    		Compute $\nabla_\Theta\mathbf{L} = \sum_{\mathsf{s}=1}^m \nabla_\Theta\mathbf{L}^\mathsf{s}$\;
    		Update $\Theta \!\leftarrow\! \Theta + \mathsf{learning\_rate} \times\nabla_\Theta\mathbf{L} $\;
    		Broadcast the new $\Theta$ to all sources\;
    		}
		}
}
\end{algorithm}
\DecMargin{1.2em}
}








 

\section{Experiments}
\label{sec:experiment}

\textbf{The baselines and experimental objectives.} 
We first examine the performance of FedCI. We then compare the performance of FedCI against recent causal inference methods, such as BART \citep{hill2011bayesian}, 
TARNet, CFR Wass (CFRNet with Wasserstein distance), CFR MMD (CFRNet with maximum mean discrepancy distance) \citep{shalit2017estimating}, 
CEVAE \citep{louizos2017causal}, 
OrthoRF \citep{oprescu2019orthogonal}, 
X-learner \citep{kunzel2019metalearners}, and R-learner \citep{nie2021quasi}. All these methods do not consider learning causal effects in a federated setting. This analysis aims to show the efficacy of FedCI as compared with the baselines trained in three different cases: (\textbf{1}) training a local model on each source data, (\textbf{2}) training a global model with the combined data of all sources, (\textbf{3}) using bootstrap aggregating (also known as bagging, which is an ensemble learning method) of \citet{breiman1996bagging} where $m$ models are locally trained on each source data; then taking average of the predicted treatment effects of each model. Although case (\textbf{2})  \textit{violates} the privacy constraint of federated data, we use it  for comparison purposes. In general, we would like 
to assess the performance of the federated causal inference approach against the baselines using combined data in case~(\textbf{2}). 


We use publicly available libraries and source codes to implement the baseline methods. In particular, CEVAE, TARNet, CFR Wass, and CFR MMD are readily available on github. We use the online packages \texttt{BartPy} for BART,  \texttt{causalml} \citep{chen2020causalml} for X-learner  and R-learner, and 
\texttt{econml} \citep*{econml} for OrthoRF. 
For all the methods, we fine-tune the learning rate in $\{10^{-1}, 10^{-2}, 10^{-3}, 10^{-4}\}$ and regularizers in $\{10^{1}, 10^{0}, 10^{-1}, 10^{-2}, 10^{-3}\}$. 

\textbf{Evaluation metrics.} We report two evaluation metrics: (i) precision in estimation of heterogeneous effects (PEHE) \citep{hill2011bayesian} for evaluating ITE, and (ii) absolute error for evaluating ATE. Details are presented in Appendix~\zref{sec:eval-metrics}.  These metrics are for point estimates, which are the mean of ITE and ATE in their estimated distributions. 
We also report the estimated distribution of ATE in our model.



\subsection{Synthetic Data}
\label{sec:synthetic-data}
We analyse FedCI in terms of three types of outcomes: (1) real-value, (2)  binary, and (3) count. While (1) is examined in a well-specified case for the outcomes, (2) and (3) are studied in misspecified cases.
\subsubsection{Real-value Outcomes}
\label{sec:synthetic-data-real-value}

\textbf{Data.} Obtaining ground truth for evaluating causal inference algorithms is challenging. Thus, most methods are evaluated using synthetic or semi-synthetic datasets. In this experiment, we simulate the data with the following distributions:
\begin{align*}
    x_{ij} &\!\sim\! \mathsf{U}[-1,1], &\!\!y_i(0) &\!\sim\! \mathsf{N}(\lambda(b_0 \!+\! \mathbf{x}_i^\top \mathbf{b}_1), \sigma_0^2), \\[-0.1cm]
    w_i &\!\sim\! \mathsf{Bern}(\varphi(a_0 \!+\! \mathbf{x}_i^\top \mathbf{a}_1)), &\!\!y_i(1) &\!\sim\! \mathsf{N}(\lambda(c_0 \!+\! \mathbf{x}_i^\top \mathbf{c}_1), \sigma_1^2),
\end{align*}
where $\varphi(\cdot)$ is the sigmoid function, $\lambda(\cdot)$ is the softplus function, and $\mathbf{x}_i = [x_{i1},\!...,x_{id_x}]^\top \in \mathbb{R}^{d_x}$ with $d_x=20$. We simulate two synthetic datasets: DATA-1 and DATA-2. For DATA-1, the ground truth parameters are randomly set as follows: $\sigma_0=\sigma_1=1$,  $(a_0,b_0,c_0)=(0.6,0.9,2.0)$, $\mathbf{a}_1 \sim \mathsf{N}(\mathbf{0}, 2\cdot\mathbf{I}_{d_x})$, $\mathbf{b}_1 \sim \mathsf{N}(\mathbf{0}, 2\cdot\mathbf{I}_{d_x})$, $\mathbf{c}_1 \sim \mathsf{N}(\mathbf{1}, 2\cdot\mathbf{I}_{d_x})$. For DATA-2, we set $(b_0,c_0) = (6,30)$, $\mathbf{b}_1 \sim \mathsf{N}(10\cdot\mathbf{1}, 2\cdot\mathbf{I}_{d_x})$, $\mathbf{c}_1 \sim \mathsf{N}(15\cdot\mathbf{1}, 2\cdot\mathbf{I}_{d_x})$, and the other parameters are set as in DATA-1. The purpose is to make two different scales of the outcomes for the two datasets. For each dataset, we simulate $10$ replications with $n = 5000$ records. We only keep $\{(y_i, w_i, \mathbf{x}_i)\}_{i=1}^n$ as the observed data, where $y_i = y_i(0)$ if $w_i=0$ and $y_i = y_i(1)$ if $w_i=1$. We divide the data into five sources, each consisting of $n_\mathsf{s}=1000$ records. In each source, we use $50$ records for training, $450$ for testing and $400$ for validation. We report the evaluation metrics and their standard errors over the 10 replications. The parameters chosen for this simulation study satisfy Assumption~\ref{assumption:ignorability} since $y_i(0)$ and $y_i(1)$ are independent of $w_i$ given $\mathbf{x}_i$. Assumption~\ref{assumption:sutva} is respected as the treatment on an individual $i$ does not affect the outcome of another individual $j$ ($i\neq j$). Since we fix the dimension of $\mathbf{x}_i$ and draw it from the same distribution, Assumption~\ref{assumption:share-covariates} is implicitly satisfied. Assumption~\ref{assumption:unique-ident} holds true since each record drawn from the above distributions is attributed to one individual. 
This means that there are no duplicates of individuals in more than one source. Assumption~\ref{assumption:homogeneous-heterogeneous} is also satisfied since we have divided the data equally from one dataset. \begin{figure}
\centering
    \includegraphics[width=0.43\textwidth]{figures/FedCI-analysis.pdf}
\caption{Federated inference analysis on DATA-1.}
\label{fig:fedci-alalysis}
\end{figure}
\begin{figure}
\centering
    \includegraphics[width=0.43\textwidth]{figures/FedCI-inter-dependency-analysis.pdf}
\caption{The impact of inter-dependency on DATA-1.}
\label{fig:fedci-inter-dependency-alalysis}
\end{figure}
\begin{table}[!ht]
\centering
\caption{Out-of-sample errors on DATA-1 where top-3 performances are highlighted in bold (lower is better). The dashes (---) in `$\mathsf{loc}$' and `$\mathsf{agg}$' indicate that the numbers are the same as those of `$\mathsf{com}$'.}
\label{tab:error-synthetic}
\setlength{\tabcolsep}{2.3pt}
\scriptsize
\begin{tabular}{@{}lcccccc@{}}
\toprule
\multirow{2}{*}{Method}                                           & \multicolumn{3}{c}{The error of ITE ($\sqrt{\epsilon_\text{PEHE}}$)} & \multicolumn{3}{c}{The error of ATE ( $\epsilon_\text{ATE}$)} \\ \cmidrule(lr){2-4}\cmidrule(lr){5-7} 
                                                                  & 1 source      & 3 sources     & 5 sources     & 1 source      & 3 sources     & 5 sources    \\ \cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
BART$_\mathsf{loc}$                                                       & ---    & 6.04$\pm$.05    & 6.02$\pm$.04    & ---    & 0.59$\pm$.14    & 0.53$\pm$.10   \\
X-Learner$_\mathsf{loc}$                                                  & ---    & 5.81$\pm$.13    & 5.77$\pm$.09    & ---    & 0.44$\pm$.24    & 0.51$\pm$.13   \\
R-Learner$_\mathsf{loc}$                                                  & ---    & 5.94$\pm$.05    & 5.94$\pm$.03    & ---    & 0.65$\pm$.05    & 0.66$\pm$.02   \\
OrthoRF$_\mathsf{loc}$                                                    & ---   & 5.83$\pm$.12    & 6.23$\pm$.13    & ---    & \textbf{0.31$\pm$.08}    & 0.52$\pm$.10   \\
TARNet$_\mathsf{loc}$    & ---     & 4.25$\pm$.07     & 4.22$\pm$.06     & ---     & 0.85$\pm$.04     & 0.81$\pm$.02     \\ 
CFR Wass$_\mathsf{loc}$  & ---     & 4.10$\pm$.04 & 3.92$\pm$.03 & ---     & 0.81$\pm$.02     & 0.80$\pm$.02     \\ 
CFR MMD$_\mathsf{loc}$   & --- & 4.11$\pm$.06 & 3.93$\pm$.03 & ---     & 0.80$\pm$.03     & 0.79$\pm$.02 \\
CEVAE$_\mathsf{loc}$                                                     & ---   & 3.82$\pm$.09    & 3.50$\pm$.06    & ---    & 0.63$\pm$.11    & 0.52$\pm$.03   \\\cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
BART$_\mathsf{agg}$                                                 & ---             & 5.97$\pm$.05              & 5.94$\pm$.03              & ---             & 0.64$\pm$.14              & 0.47$\pm$.11             \\
X-Learner$_\mathsf{agg}$ & ---             & 5.18$\pm$.09    & 5.09$\pm$.05    & ---             & 0.46$\pm$.24    & 0.52$\pm$.13   \\
R-Learner$_\mathsf{agg}$ & ---             & 5.94$\pm$.05    & 5.93$\pm$.03    & ---             & 0.65$\pm$.05    & 0.66$\pm$.03   \\
OrthoRF$_\mathsf{agg}$   & ---             & 4.19$\pm$.13    & 3.66$\pm$.08              & ---             & \textbf{0.36$\pm$.13}    & 0.48$\pm$.12             \\
TARNet$_\mathsf{agg}$    & ---     & 4.02$\pm$.04     & 4.00$\pm$.05     & ---     & 0.79$\pm$.04     & 0.77$\pm$.02     \\ 
CFR Wass$_\mathsf{agg}$  & ---     & 3.92$\pm$.03 & 3.75$\pm$.03 & ---     & 0.78$\pm$.03     & 0.76$\pm$.02     \\ 
CFR MMD$_\mathsf{agg}$   & --- & 4.01$\pm$.05 & 3.80$\pm$.02 & ---     & 0.78$\pm$.03     & 0.76$\pm$.02 \\
CEVAE$_\mathsf{agg}$                                                     & ---    & 3.65$\pm$.10    & 2.99$\pm$.06    & ---    & 0.41$\pm$.05    & 0.37$\pm$.04   \\\cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(l){5-7}
BART$_\mathsf{com}$                                                     & 5.98$\pm$.06             & 5.97$\pm$.06    & 5.93$\pm$.03    & 0.83$\pm$.11             & 0.56$\pm$.16    & 0.38$\pm$.09   \\
X-Learner$_\mathsf{com}$    & 5.48$\pm$.15             & 4.60$\pm$.09    & 4.15$\pm$.04    & 0.93$\pm$.22             & 0.60$\pm$.11    & \textbf{0.30$\pm$.07}   \\
R-Learner$_\mathsf{com}$     & 5.93$\pm$.06             & 5.73$\pm$.08    & 5.54$\pm$.06    & 0.78$\pm$.10             & 0.47$\pm$.09    & \textbf{0.30$\pm$.07}   \\
OrthoRF$_\mathsf{com}$       & 5.86$\pm$.40             & \textbf{3.60$\pm$.12}    & \textbf{2.94$\pm$.05}    & \textbf{0.55$\pm$.14}             & 0.45$\pm$.14    & 0.34$\pm$.09   \\
TARNet$_\mathsf{com}$    & 3.93$\pm$.07     & 3.87$\pm$.05     & 3.80$\pm$.03     & 0.80$\pm$.04     & 0.77$\pm$.04     & 0.76$\pm$.02     \\ 
CFR Wass$_\mathsf{com}$  & \textbf{3.77$\pm$.05}     & 3.73$\pm$.04 & 3.71$\pm$.02 & 0.80$\pm$.04     & 0.75$\pm$.04     & 0.75$\pm$.02     \\ 
CFR MMD$_\mathsf{com}$   & 3.90$\pm$.06 & 3.73$\pm$.04 & 3.70$\pm$.02 & 0.82$\pm$.05     & 0.75$\pm$.04     & 0.75$\pm$.02 \\
CEVAE$_\mathsf{com}$        & \textbf{3.79$\pm$.07}             & \textbf{2.85$\pm$.06}    & \textbf{2.72$\pm$.04}    & \textbf{0.51$\pm$.13}             & \textbf{0.23$\pm$.07}    & \textbf{0.20$\pm$.06}   \\\cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
FedCI                                                             & \textbf{3.71$\pm$.10}     & \textbf{2.35$\pm$.09}    & \textbf{1.99$\pm$.05}    & \textbf{0.69$\pm$.12}    & \textbf{0.31$\pm$.12}    & \textbf{0.29$\pm$.06}   \\ \bottomrule
\end{tabular}
\end{table}








\begin{table}\centering
\caption{Out-of-sample errors on DATA-2. Please see the full table in Appendix~\zref{sec:appendix-data-2}.
}
\label{tab:error-synthetic-2}
\setlength{\tabcolsep}{2.1pt}
\scriptsize
\begin{tabular}{@{}lcccccc@{}}
\toprule
\multirow{2}{*}{Method}                                           & \multicolumn{3}{c}{The error of ITE ($\sqrt{\epsilon_\text{PEHE}}$)} & \multicolumn{3}{c}{The error of ATE ( $\epsilon_\text{ATE}$)} \\ \cmidrule(lr){2-4}\cmidrule(lr){5-7} 
                                                                  & 1 source      & 3 sources     & 5 sources     & 1 source      & 3 sources     & 5 sources    \\ \cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
BART$_\mathsf{com}$                                                     & \textbf{18.0$\pm$0.4}             & \textbf{17.7$\pm$0.2}    & 17.4$\pm$0.1    & \textbf{3.54$\pm$1.3}             & 2.94$\pm$0.8    & \textbf{1.84$\pm$0.5}   \\
X-Learner$_\mathsf{com}$    & 21.1$\pm$0.9             & 17.9$\pm$0.4    & \textbf{16.2$\pm$0.2}    & 4.55$\pm$1.4             & 3.29$\pm$1.0    & 2.37$\pm$0.8   \\
R-Learner$_\mathsf{com}$     & 25.9$\pm$0.6             & 23.5$\pm$0.5    & 21.3$\pm$0.4    & 19.0$\pm$0.8             & 15.6$\pm$0.7    & 12.3$\pm$0.6   \\
OrthoRF$_\mathsf{com}$       & 37.8$\pm$2.7             & \textbf{10.7$\pm$0.5}    & \textbf{9.83$\pm$0.5}    & 7.88$\pm$2.2             & \textbf{1.99$\pm$0.4}    & 2.36$\pm$0.6   \\
TARNet$_\mathsf{com}$    & 36.1$\pm$0.4     & 35.5$\pm$0.2     & 35.0$\pm$0.2     & 7.11$\pm$0.4     & 7.10$\pm$0.3     & 7.08$\pm$0.2     \\ 
CFR Wass$_\mathsf{com}$  & 35.1$\pm$0.4     & 34.5$\pm$0.2 & 34.1$\pm$0.2 & 7.10$\pm$0.4     & 7.01$\pm$0.3     & 6.90$\pm$0.2     \\ 
CFR MMD$_\mathsf{com}$   & 35.1$\pm$0.4 & 35.0$\pm$0.2 & 34.9$\pm$0.2 & 7.12$\pm$0.4     & 7.02$\pm$0.3     & 7.01$\pm$0.2     \\
CEVAE$_\mathsf{com}$        & \textbf{20.1$\pm$0.5}             & 18.4$\pm$0.6    & 16.6$\pm$0.6    & \textbf{1.50$\pm$0.3}             & \textbf{1.38$\pm$0.4}    & \textbf{1.89$\pm$0.2}   \\\cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
FedCI                                                             & \textbf{9.28$\pm$0.4}     & \textbf{6.34$\pm$0.2}    & \textbf{5.53$\pm$0.1}    & \textbf{2.37$\pm$0.5}    & \textbf{1.47$\pm$0.4}    & \textbf{0.74$\pm$0.2}   \\ \bottomrule
\end{tabular}
\end{table}


\textbf{FedCI vs. training on combined data.} Figure~\ref{fig:fedci-alalysis} reports the three evaluation metrics of FedCI compared with two data source settings: training on combined data and training locally on each data source. As expected, the figures show that the errors of FedCI are as low as those of training on the combined data. This result verifies the efficacy of the proposed federated algorithm.


\textbf{Inter-dependency component analysis.} We study the impact of the inter-dependency component (see Section~\ref{sec:model}) by removing it from the model. Figure~\ref{fig:fedci-inter-dependency-alalysis} presents the errors of FedCI compared with `no inter-dependency' (FedCI without inter-dependency). The figures show that the errors in predicting ITE and ATE of `no inter-dependency' seem to be higher than those of FedCI. This result showcases the importance of our proposed inter-dependency component.


In Figure~\ref{fig:fedci-alalysis}, the error $\epsilon_\text{ATE}$ of FedCI increases as the number of sources increases from 1 to 2. In Figure~\ref{fig:fedci-inter-dependency-alalysis}, $\epsilon_\text{ATE}$ of FedCI is larger than that of the model without inter-dependency. These results might be due to the non-convex optimisation, which could converge to a local optimum. A potential direction for improvement is to use minibatch stochastic gradient descent for GPs \citep{chen2020stochastic}.



\textbf{Contrasting with existing baselines.} In this experiment, we compare FedCI with the existing causal inference methods. None of these baseline methods considers estimating causal effects on multiple sources. Thus, we train them in three cases as explained earlier: \textbf{(1)} train locally ($\mathsf{loc}$), \textbf{(2)} train with combined data ($\mathsf{com}$), and \textbf{(3)} train with bootstrap aggregating ($\mathsf{agg}$). Note that case (\textbf{2}) violates the constraint that data are stored at their local sites. We expect the error of FedCI to be close to that of case \textbf{(2)} of the baselines. 
Tables~\ref{tab:error-synthetic}~and~\ref{tab:error-synthetic-2} report the performance of each method in estimating ATE and ITE. Regardless of the different scales of the two synthetic datasets, the tables show that FedCI achieves competitive results as compared with all the baselines. FedCI is in the top-3 performances among all the methods. Importantly, FedCI obtains lower errors than those of BART$_\mathsf{com}$, X-Learner$_\mathsf{com}$, R-Learner$_\mathsf{com}$, OrthoRF$_\mathsf{com}$, TARNet$_\mathsf{com}$, CFR~Wass$_\mathsf{com}$, and CFR~MMD$_\mathsf{com}$, which were trained on combined data and thus violate the constraint of the federated data setting. 
Compared with CEVAE$_\mathsf{com}$, FedCI is better than this method in predicting ITE and comparable with this method in predicting ATE (slightly higher errors). However, we emphasize again that this result is expected since FedCI is a federated learning algorithm while CEVAE$_\mathsf{com}$ works directly on combined data. 


\textbf{The estimated distribution of ATE.} To analyse uncertainty, we present in  Figure~\ref{fig:fedci-uncertainty-analysis-synthetic} the estimated distribution of ATE in the first source ($\mathsf{s}=1$). The figures show that the true ATE is covered by the estimated interval and the estimated mean ATE shifts towards its true value (dotted lines) when more data sources are used. This result might provide useful information about the application in practice. 

\begin{figure}
\centering
    \includegraphics[width=0.45\textwidth]{figures/FedCI-uncertainty-analysis-synthetic.pdf}
\caption{
       Estimated distribution of ATE on source \#1 of DATA-2. The dotted black lines represent the true ATE.
    } \label{fig:fedci-uncertainty-analysis-synthetic}
\end{figure}



































\subsubsection{Misspecification Analysis: Binary and Count Outcomes}

\textbf{Data.} In this experiment, we analyse the performance when the model is misspecified. We compare FedCI with the baselines in two cases: binary outcomes and count outcomes. We reuse the ground truth distributions of $x_{ij}$ and $w_i$ as in Section~\ref{sec:synthetic-data-real-value}. For the outcomes, we simulate them with the following distributions:
\begin{align*}
    &\text{Binary outcomes:} &&y_i(0) \sim \mathsf{Bern}(\varphi(b_0 \!+\! \mathbf{x}_i^\top \mathbf{b}_1)),\\[-0.06cm]
   &&& y_i(1) \sim \mathsf{Bern}(\varphi(c_0 \!+\! \mathbf{x}_i^\top \mathbf{c}_1)).\\[-0.06cm]
    &\text{Count outcomes:} &&y_i(0) \sim \mathsf{Poisson}(\exp(b_0 \!+\! \mathbf{x}_i^\top \mathbf{b}_1)),\\[-0.06cm]
    &&&y_i(1) \sim \mathsf{Poisson}(\exp(c_0 \!+\! \mathbf{x}_i^\top \mathbf{c}_1)).
\end{align*}
\textbf{Results and discussion.} From Tables~\ref{tab:error-synthetic-binary}~and~\ref{tab:error-synthetic-count}, FedCI gives competitive results compared with the baselines trained on combined data. The reason for the good performance of FedCI and some baselines in these misspecification cases is that they provide good estimates for the mean of the missing outcomes. This might in turn be because the mean estimation of the Gaussian distribution in FedCI coincides with the mean estimation of the other distributions.  
Nevertheless, since these are misspecified cases, the continuous posterior distribution is not a good estimation. To obtain better posterior distributions of the missing outcomes and the causal estimands, we would need to consider some other appropriate distributions in our model.

\begin{table}
\centering
\caption{Out-of-sample errors on binary outcomes data.}
\label{tab:error-synthetic-binary}
\setlength{\tabcolsep}{2.3pt}
\scriptsize
\begin{tabular}{@{}lllllll@{}}
\toprule
\multirow{2}{*}{Method}                                           & \multicolumn{3}{c}{The error of ITE ($\sqrt{\epsilon_\text{PEHE}}$)} & \multicolumn{3}{c}{The error of ATE ( $\epsilon_\text{ATE}$)} \\ \cmidrule(lr){2-4}\cmidrule(lr){5-7} 
           & 1 source       & 3 sources      & 5 sources      & 1 source       & 3 sources      & 5 sources      \\ \cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
BART$_\mathsf{com}$      & 0.77$\pm$.01     & 0.73$\pm$.01     & 0.70$\pm$.01     & 0.41$\pm$.01     & 0.31$\pm$.01     & 0.24$\pm$.01     \\ 
X-Learner$_\mathsf{com}$ & 0.69$\pm$.01     & 0.60$\pm$.01     & 0.56$\pm$.01     & \textbf{0.13$\pm$.03}     & 0.10$\pm$.02     & \textbf{0.09$\pm$.01}    \\ 
R-Learner$_\mathsf{com}$ & 0.65$\pm$.01     & 0.64$\pm$.01     & 0.62$\pm$.01     & \textbf{0.05$\pm$.01} & \textbf{0.03$\pm$.01} & \textbf{0.03$\pm$.01} \\ 
OthoRF$_\mathsf{com}$    & 0.94$\pm$.04     & 0.60$\pm$.01     & 0.56$\pm$.01     & 0.17$\pm$.03     & 0.18$\pm$.03     & 0.16$\pm$.03     \\ 
TARNet$_\mathsf{com}$    & 0.68$\pm$.02     & 0.68$\pm$.02     & 0.65$\pm$.01     & 0.33$\pm$.01     & 0.33$\pm$.01     & 0.32$\pm$.01     \\ 
CFR Wass$_\mathsf{com}$  & 0.61$\pm$.02     & \textbf{0.50$\pm$.01} & \textbf{0.50$\pm$.01} & 0.32$\pm$.01     & 0.30$\pm$.01     & 0.30$\pm$.01     \\ 
CFR MMD$_\mathsf{com}$   & \textbf{0.55$\pm$.01} & \textbf{0.50$\pm$.01} & \textbf{0.50$\pm$.01} & 0.32$\pm$.01     & 0.30$\pm$.01     & 0.30$\pm$.01     \\ 
CEVAE$_\mathsf{com}$     & \textbf{0.39$\pm$.01} & \textbf{0.37$\pm$.01} & \textbf{0.37$\pm$.01} & \textbf{0.08$\pm$.02} & \textbf{0.05$\pm$.01} & \textbf{0.05$\pm$.01} \\ \cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
FedCI                           & \textbf{0.41$\pm$.01}                      & \textbf{0.40$\pm$.01}                      & \textbf{0.39$\pm$.01}                      & \textbf{0.05$\pm$.01}                      & \textbf{0.04$\pm$.01}                      & \textbf{0.03$\pm$.01}                      \\ \bottomrule
\end{tabular}
\end{table}


\begin{table}
\centering
\caption{Out-of-sample errors on count outcomes data.
}
\label{tab:error-synthetic-count}
\setlength{\tabcolsep}{2.3pt}
\scriptsize
\begin{tabular}{@{}lllllll@{}}
\toprule
\multirow{2}{*}{Method}                                           & \multicolumn{3}{c}{The error of ITE ($\sqrt{\epsilon_\text{PEHE}}$)} & \multicolumn{3}{c}{The error of ATE ($\epsilon_\text{ATE}$)} \\ \cmidrule(lr){2-4}\cmidrule(lr){5-7} 
           & 1 source       & 3 sources      & 5 sources      & 1 source       & 3 sources      & 5 sources      \\ \cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
BART$_\mathsf{com}$      & 6.30$\pm$.06     & 6.29$\pm$.04     & 6.26$\pm$.03     & 0.75$\pm$.14     & 0.59$\pm$.18     & 0.47$\pm$.13     \\ 
X-Learner$_\mathsf{com}$ & 6.10$\pm$.10     & 5.16$\pm$.06     & 4.72$\pm$.03     & 1.34$\pm$.29     & 0.63$\pm$.12     & 0.42$\pm$.08    \\ 
R-Learner$_\mathsf{com}$ & 6.27$\pm$.06     & 6.09$\pm$.05     & 5.89$\pm$.04     & 0.82$\pm$.13 & 0.66$\pm$.15 & 0.56$\pm$.10 \\ 
OthoRF$_\mathsf{com}$    & 6.02$\pm$.29     & 4.15$\pm$.06     & \textbf{3.74$\pm$.05}     & 0.75$\pm$.18     & 0.54$\pm$.17     & \textbf{0.41$\pm$.10}     \\ 
TARNet$_\mathsf{com}$    & 4.54$\pm$.14     & \textbf{3.98$\pm$.05}     & 3.80$\pm$.02     & 0.77$\pm$.10     & 0.66$\pm$.02     & 0.62$\pm$.03     \\ 
CFR Wass$_\mathsf{com}$  & \textbf{4.08$\pm$.04}     & 4.03$\pm$.03 & 3.78$\pm$.02 & 0.72$\pm$.04     & \textbf{0.51$\pm$.03}     & 0.50$\pm$.03     \\ 
CFR MMD$_\mathsf{com}$   & 4.15$\pm$.06 & 4.05$\pm$.04 & 3.77$\pm$.02 & \textbf{0.69$\pm$.07}     & 0.54$\pm$.03     & 0.50$\pm$.03     \\ 
CEVAE$_\mathsf{com}$     & \textbf{3.40$\pm$.09} & \textbf{3.31$\pm$.07} & \textbf{3.08$\pm$.05} & \textbf{0.56$\pm$.16} & \textbf{0.40$\pm$.12} & \textbf{0.35$\pm$.08} \\ \cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
FedCI                           & \textbf{4.02$\pm$.10}                      & \textbf{3.05$\pm$.08}                      & \textbf{2.66$\pm$.04}                      & \textbf{0.54$\pm$.09}                      & \textbf{0.48$\pm$.08}                      & \textbf{0.25$\pm$.05}                      \\ \bottomrule
\end{tabular}
\end{table}

\subsection{IHDP Data}
\label{sec:ihdp}
\textbf{Data.} The Infant Health and Development Program (IHDP) \citep{hill2011bayesian} is a dataset with 747 data points, each with 25 covariates. These data are obtained from a randomized study on the impact of specialist visits on children's cognitive development. Herein, the specialist visit is the treatment and the child's cognitive development is the outcome. We use the NPCI package \citep{dorie2016npci} to simulate two potential outcomes for the treatment (with or without specialist visit) of each child. Hence, the \textit{true} individual treatment effect can be computed for evaluation purposes. 
There are 10 replicates of the dataset, and each of them is  divided into three sources of size 249. For each source, we then split it into three equal sets for the purpose of training, testing, and validating the models. The mean and standard error of the aforementioned evaluation metrics are reported over the above 10 replicates of the data. 

\textbf{Results and discussion.} Similar to the experiment on synthetic datasets, here we also train the baselines in three cases as explained earlier. We also expect the errors of FedCI to be close to those of the baselines trained with combined data ($\mathsf{com}$). The results reported in Table~\ref{tab:error-ihdp} show that FedCI achieves competitive results compared to the baselines (we skipped the first and second cases, $\mathsf{loc}$ and $\mathsf{agg}$; please see Appendix~\zref{sec:appendix-ihdp} for the full table). Indeed, FedCI is among the top-3 performers in most settings. This result again verifies that FedCI can be used to estimate causal effects effectively under privacy-preserving, federated data settings. The estimated distribution of ATE is presented in Appendix~\zref{sec:appendix-ihdp} due to limited space. 


\begin{table}\centering
\caption{Out-of-sample errors on IHDP dataset. Please see the full table in Appendix~\zref{sec:appendix-ihdp}.
}
\label{tab:error-ihdp}
\setlength{\tabcolsep}{2.7pt}
\scriptsize
\begin{tabular}{@{}lcccccc@{}}
\toprule
\multirow{2}{*}{Method}                                           & \multicolumn{3}{c}{The error of ITE ($\sqrt{\epsilon_\text{PEHE}}$)} & \multicolumn{3}{c}{The error of ATE ($\epsilon_\text{ATE}$)} \\ \cmidrule(lr){2-4}\cmidrule(lr){5-7} 
                                                                  & 1 source      & 2 sources     & 3 sources     & 1 source      & 2 sources     & 3 sources    \\ \cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
BART$_\mathsf{com}$                                                 & 5.98$\pm$2.7             & 4.32$\pm$2.1              & 4.04$\pm$2.0              & 1.80$\pm$1.1             & 2.09$\pm$1.1              & 1.21$\pm$0.6             \\
X-Learner$_\mathsf{com}$ & 4.22$\pm$1.6             & 4.15$\pm$1.5    & 4.06$\pm$1.8    & 1.64$\pm$0.7             & 1.93$\pm$0.8    & 0.84$\pm$0.4   \\
R-Learner$_\mathsf{com}$ & 6.97$\pm$2.1             & 4.43$\pm$1.4    & 4.47$\pm$1.7    & 3.15$\pm$0.5             & 1.34$\pm$0.5    & 1.10$\pm$0.3   \\
OthoRF$_\mathsf{com}$    & 4.49$\pm$1.9             & 3.81$\pm$1.3    & 3.75$\pm$1.5              & 1.86$\pm$0.8             & 1.61$\pm$0.6    & 1.56$\pm$0.8\\ 
TARNet$_\mathsf{com}$    & 4.50$\pm$1.4     & 3.15$\pm$0.8     & 3.79$\pm$1.1     & \textbf{1.52$\pm$0.5}     & 1.18$\pm$0.4     & 0.91$\pm$0.3     \\ 
CFR Wass$_\mathsf{com}$  & \textbf{4.37$\pm$1.2}     & 2.93$\pm$0.6 & 2.85$\pm$0.9 & \textbf{1.18$\pm$0.7}     & \textbf{0.72$\pm$0.2}     & 0.67$\pm$0.1     \\ 
CFR MMD$_\mathsf{com}$   & 4.43$\pm$1.3 & \textbf{2.85$\pm$0.6} & \textbf{2.83$\pm$1.1} & 2.32$\pm$0.8     & \textbf{0.63$\pm$0.2}     & \textbf{0.54$\pm$0.2}     \\
CEVAE$_\mathsf{com}$        & \textbf{3.16$\pm$0.6}             & \textbf{2.34$\pm$0.6}    & \textbf{2.31$\pm$0.7}    & 2.02$\pm$0.4             & \textbf{0.53$\pm$0.1}    & \textbf{0.48$\pm$0.2}   \\\cmidrule(r){1-1}\cmidrule(lr){2-4}\cmidrule(lr){5-7}
FedCI                                                             & \textbf{2.88$\pm$0.8}     & \textbf{2.36$\pm$0.5}    & \textbf{2.35$\pm$0.6}    & \textbf{1.43$\pm$0.7}    & 1.03$\pm$0.4    & \textbf{0.51$\pm$0.2}   \\ \bottomrule
\end{tabular}
\end{table}











































 

 

\section{Conclusion}
\label{sec:conclusion}
We have introduced FedCI, a Bayesian causal inference paradigm via a reformulation of multi-output GPs to learn causal effects, while keeping data at their local sites. An inference method involving the decomposition of the ELBO is presented, allowing the model to be trained in a federated setting. 

This work is an important step towards a privacy-preserving causal learning model. One interesting future research direction is to combine FedCI with differential privacy to give a stronger privacy guarantee. 
This new direction would require adding an appropriate noise component, such as Laplace noise, to the parameters while training the model. 
Our future research would also involve further exploration of combining FedCI with differential privacy for Gaussian processes \citep{smith2018differentially} and multiparty differential privacy algorithms \citep{pathak2010multiparty,rajkumar2012differentially,hamm2016learning}. 

The inherent use of GPs in our approach incurs cubic time complexity in each source, owing to the inversion of the covariance matrix. Another possible future research direction is to reformulate this in terms of sparse Gaussian process models \citep{Hensman13,hoang2017generalized,hoang2020revisiting}. 

 




























































\begin{acknowledgements} 
   This research/project is supported by the National Research Foundation Singapore and DSO National Laboratories under the AI Singapore Programme (AISG Award No: AISG2-RP-2020-016).
\end{acknowledgements}

\bibliography{ref}
























































































\end{document}
