%%%%%%%% ICML 2025 EXAMPLE LATEX SUBMISSION FILE %%%%%%%%%%%%%%%%%

\documentclass{article}

% Recommended, but optional, packages for figures and better typesetting:
\usepackage{microtype}
\usepackage{graphicx}
% \usepackage{subfigure}
\usepackage{booktabs} % for professional tables
 \usepackage{caption}
 \usepackage{subcaption}

% hyperref makes hyperlinks in the resulting PDF.
% If your build breaks (sometimes temporarily if a hyperlink spans a page)
% please comment out the following usepackage line and replace
% \usepackage{icml2025} with \usepackage[nohyperref]{icml2025} above.
\usepackage{hyperref}


% Attempt to make hyperref and algorithmic work together better:
\newcommand{\theHalgorithm}{\arabic{algorithm}}
\usepackage[noend]{algorithmic}

% Use the following line for the initial blind version submitted for review:
\usepackage{icml2025}

% If accepted, instead use the following line for the camera-ready submission:
% \usepackage[accepted]{icml2025}

% For theorems and such
\usepackage{amsmath}
\usepackage{amssymb}
\usepackage{mathtools}
\usepackage{amsthm}

% if you use cleveref..
\usepackage[capitalize,noabbrev]{cleveref}

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% THEOREMS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Theorem-like environments (amsthm).
% `theorem` owns the counter and is numbered within each section; every other
% environment below is declared with [theorem], so they all share that one
% running counter instead of keeping separate ones.
% Environments declared under the `plain` style:
\theoremstyle{plain}
\newtheorem{theorem}{Theorem}[section]
\newtheorem{proposition}[theorem]{Proposition}
\newtheorem{lemma}[theorem]{Lemma}
\newtheorem{corollary}[theorem]{Corollary}
% Environments declared under the `definition` style:
\theoremstyle{definition}
\newtheorem{definition}[theorem]{Definition}
\newtheorem{assumption}[theorem]{Assumption}
% Environments declared under the `remark` style:
\theoremstyle{remark}
\newtheorem{remark}[theorem]{Remark}

% Todonotes is useful during development; simply uncomment the next line
%    and comment out the line below the next line to turn off comments
%\usepackage[disable,textsize=tiny]{todonotes}
\usepackage[textsize=tiny]{todonotes}
\usepackage{longtable} % For multi-page tables, if needed



% Independence symbol: two \perp glyphs pulled together with negative thin spaces.
\newcommand{\indep}{\perp \!\!\! \perp}
% Colour shorthands used for draft/review annotations in the text.
% Each expands to \textcolor{<colour>} and therefore consumes the following
% braced group as the text to colour, e.g. \red{needs a citation}.
\newcommand{\blue}{\textcolor{blue}}
\newcommand{\red}{\textcolor{red}}
\newcommand{\orng}{\textcolor{orange}}
% \mk{<text>}: red inline note printed as "MK: {<text>}".
\newcommand{\mk}[1]{{\color{red} MK: \{#1\}}}
% \mr expands to the same orange \textcolor as \orng.
\newcommand{\mr}{\textcolor{orange}}
\newcommand{\gray}{\textcolor{gray}}
\newcommand{\lgray}{\textcolor{lightgray}}

% Math notation shorthands. Bold (\mathbf) and calligraphic (\mathcal) letters.
\newcommand{\h}{\mathcal{H}}
\newcommand{\x}{\mathbf{x}}
\newcommand{\y}{\mathbf{y}}

\newcommand{\Oc}{\mathcal{O}}
% NOTE(review): \Xc and \X (defined further down) both expand to \mathcal{X}.
\newcommand{\Xc}{\mathcal{X}}
\newcommand{\Y}{\mathbf{Y}}
% \V: bold V, used in the text for the set of variables.
\newcommand{\V}{\mathbf{V}}

% \vb: bold lowercase v, used in the text for a joint value of the variables.
\newcommand{\vb}{\mathbf{v}}

% \mbf: short alias for \mathbf.
\newcommand{\mbf}{\mathbf}
\newcommand{\data}{\mathcal{D}}
% \M: calligraphic M, used in the text for a structural causal model.
\newcommand{\M}{\mathcal{M}}
% \F: calligraphic F, used in the text for a set of mechanisms.
\newcommand{\F}{\mathcal{F}}
% \Vc: calligraphic V with a prime already attached; the text uses it for the
% heterogeneous variable, so do not add another prime at the use site.
\newcommand{\Vc}{\mathcal{V}'}
\newcommand{\X}{\mathcal{X}}
\newcommand{\T}{\mathcal{T}}

% Causal-inference notation.
% \fedcm and \ydox end in \xspace, which is provided by the xspace package.
% No other \usepackage line in this preamble loads it, so it is loaded here;
% without it the first use of either macro stops with
% "Undefined control sequence \xspace".
\usepackage{xspace}
% \Do: upright "do" for the do-operator inside math, e.g. P(S \mid \Do(...)).
\newcommand{\Do}{\text{do}}
% \fedcm: the method name; \xspace restores the following inter-word space.
\newcommand{\fedcm}{\text{FeDCM}\xspace}
% \ydox: the interventional distribution P_x(y).
\newcommand{\ydox}{P_{\mathbf{x}}(\mathbf{y})\xspace}
% \mP: P-hat with subscript theta; the text uses it for the distribution
% induced by the DCM.
\newcommand{\mP}{\hat{P}_{\theta}}
% \iP: Q-hat.
\newcommand{\iP}{\hat{Q}}



% TikZ plus every library the figures rely on, loaded in a single call.
% The list is the union of the libraries requested, in order of first
% appearance; the trailing % signs keep line breaks out of the list.
\usepackage{tikz}
\usepackage{scalefnt}
\usetikzlibrary{%
  positioning, calc, shapes.geometric, shapes, shapes.multipart,%
  arrows.meta, arrows, decorations.markings, external, trees,%
  backgrounds, automata, shapes.misc, fit%
}
% Global TikZ defaults and named styles for the figures in this paper.
\tikzset{
	% Defaults set at the top level: Latex arrow tip at the end of paths,
	% automatic label placement, 1cm node distance, semithick lines.
	-Latex,auto,node distance =1 cm and 1 cm,semithick,
	% state: ellipse-shaped node.
	state/.style ={ellipse, draw, minimum width = 0.7 cm},
	% point: small filled circle with empty contents.
	point/.style = {circle, draw, inner sep=0.04cm,fill,node contents={}},
	% nnh: thick rectangular box, 1.5cm x 1.0cm minimum.
  	nnh/.style={
		 rectangle, draw,thick,minimum width=1.5cm,minimum height=1.0cm
	},
	% nnv / nnvsm: grey-filled rounded boxes (1.2cm and 1.0cm minimum side).
	 nnv/.style={
  % circle,
    rectangle, draw, very thick, fill=gray!28, inner sep=0.04cm, minimum width=1.2cm, minimum height=1.2cm, rounded corners=0.05cm
  },
   nnvsm/.style={
    rectangle, draw, very thick, fill=gray!28, inner sep=0.0cm, minimum width=1.0cm, minimum height=1.0cm, rounded corners=0.05cm
  },
	% outer / outer1 / outer1t / louter: filled backgrounds (blue or green tint)
	% that differ only in inner sep and fill colour.
   outer/.style={ inner sep=3pt, fill=blue!15
  },
  outer1/.style={ inner sep=3pt, fill=green!15
  },
  outer1t/.style={ inner sep=0pt, fill=green!15
  },
  louter/.style={ inner sep=5pt, fill=blue!15
  },
	% XOR: circle with a vertical and a horizontal line drawn through it.
	XOR/.style={draw,circle,append after command={
			[shorten >=\pgflinewidth, shorten <=\pgflinewidth,]
			(\tikzlastnode.north) edge (\tikzlastnode.south)
			(\tikzlastnode.east) edge (\tikzlastnode.west)
		}
	},
	% bidirected: dashed edge with arrow tips at both ends; the causal graphs
	% use it for the bidirected (unobserved confounder) edges.
	bidirected/.style={Latex-Latex,dashed},
	% el: sloped, left-aligned edge label.
	el/.style = {inner sep=2pt, align=left, sloped},
	% cross: "cross out" node whose size is set by the style argument.
	cross/.style={cross out, draw=black, minimum size=2*(#1-\pgflinewidth), inner sep=0pt, outer sep=0pt},
	%default radius will be 1pt. 
	cross/.default={1pt}
}





% The \icmltitle you define below is probably too long as a header.
% Therefore, a short form for the running title is supplied here:
\icmltitlerunning{Federated DCM}

\begin{document}

\twocolumn[
\icmltitle{Federated DCM}

% It is OKAY to include author information, even for blind
% submissions: the style file will automatically remove it for you
% unless you've provided the [accepted] option to the icml2025
% package.

% List of affiliations: The first argument should be a (short)
% identifier you will use later to specify author affiliations
% Academic affiliations should list Department, University, City, Region, Country
% Industry affiliations should list Company, City, Region, Country

% You can specify symbols, otherwise they are numbered in order.
% Ideally, you should not use this facility. Affiliations will be numbered
% in order of appearance and this is the preferred way.
\icmlsetsymbol{equal}{*}

\begin{icmlauthorlist}
\icmlauthor{Firstname1 Lastname1}{equal,yyy}
\icmlauthor{Firstname2 Lastname2}{equal,yyy,comp}
\icmlauthor{Firstname3 Lastname3}{comp}
\icmlauthor{Firstname4 Lastname4}{sch}
\icmlauthor{Firstname5 Lastname5}{yyy}
\icmlauthor{Firstname6 Lastname6}{sch,yyy,comp}
\icmlauthor{Firstname7 Lastname7}{comp}
%\icmlauthor{}{sch}
\icmlauthor{Firstname8 Lastname8}{sch}
\icmlauthor{Firstname8 Lastname8}{yyy,comp}
%\icmlauthor{}{sch}
%\icmlauthor{}{sch}
\end{icmlauthorlist}

\icmlaffiliation{yyy}{Department of XXX, University of YYY, Location, Country}
\icmlaffiliation{comp}{Company Name, Location, Country}
\icmlaffiliation{sch}{School of ZZZ, Institute of WWW, Location, Country}

\icmlcorrespondingauthor{Firstname1 Lastname1}{first1.last1@xxx.edu}
\icmlcorrespondingauthor{Firstname2 Lastname2}{first2.last2@www.uk}

% You may provide any keywords that you
% find helpful for describing your paper; these are used to populate
% the "keywords" metadata in the PDF but will not be shown in the document
\icmlkeywords{Machine Learning, ICML}

\vskip 0.3in
]

% this must go after the closing bracket ] following \twocolumn[ ...

% This command actually creates the footnote in the first column
% listing the affiliations and the copyright notice.
% The command takes one argument, which is text to display at the start of the footnote.
% The \icmlEqualContribution command is standard text for equal contribution.
% Remove it (just {}) if you do not need this facility.

%\printAffiliationsAndNotice{}  % leave blank if no need to mention equal contribution
\printAffiliationsAndNotice{\icmlEqualContribution} % otherwise use the standard text.

\begin{abstract}
Causal inference in a federated learning setup is a largely unexplored research area. Existing work focuses on learning a particular model from the conditional distributions available in observational training data. Such predictions remain vulnerable to spurious correlations and perform worse in new domains. Since causal mechanisms stay invariant across domains, in this paper we aim to approximate the structural causal model with deep generative models, utilizing decentralized observational data sources.
To specify the heterogeneous mechanisms across clients, we represent the non-iid data setup in FL as the selection bias problem in causal inference. Next, we define a neighborhood around it based on the causal graph and train the neighborhood mechanisms globally in a federated fashion.
For the rest of the SCM mechanisms, we train them in individual clients using only local data.
We perform extensive experiments on synthetic and real-world setups to illustrate the utility and performance of our approach. Finally, we map existing few-shot federated learning algorithms to a causal problem and improve their performance by re-designing their architecture.
\end{abstract}



\section{Introduction}



% \blue{There is pain:}

% What happens if we dont have causality 

% What happens if we dont have deep causal generative model.

% What happens if we dont have efficient training of deep-scm.

% \blue{Why do we need causal model in federated learning?:}
% \red{Do we need FL for DCM training or need DCM for federated training? Or why do we need DCM in federated learning.}
\par 
Federated learning (FL) is an important approach for learning mechanisms across distributed clients that cannot be learned in isolation due to data scarcity. Many FL algorithms learn only a single mechanism $f: \mathbf{X} \rightarrow \mathbf{y}$ (e.g., image classification) as a conditional distribution $P(\mathbf{y}| \mathbf{X})$; such predictions might be susceptible to domain-specific bias or spurious correlation. As a result, when the clients are deployed in a test domain, their performance might deteriorate.

In many real-world scenarios, the domain-invariant relation between the features and the target variable is not a mere conditional distribution; rather, they are parts of a causal system where variables are connected through multiple causal mechanisms. Such a system can be represented as a structural causal model (SCM). Suppose we want to predict age/attractiveness features (classification) present in generated images of a specific sex class (image generation) for fairness evaluation. A client's local data might contain spurious correlation between sex and age. The causal mechanisms in such a system are [sex $\rightarrow$ images $\rightarrow$ predicted age; sex $\leftrightarrow$ predicted age]. Learning the causal mechanisms allows us to utilize the causal attributes to obtain invariant predictions while ignoring irrelevant non-causal correlation. These properties become useful in federated learning setups, particularly with feature distribution shift and data scarcity.
%
% \red{everyone learns a conditional distribution in FL which is domain specific. Need to learn the causal model.}
%
% \item What does it mean to have Trivial FL for causal models? Discuss its computation complexity and challenges for that.

% --------------
Learning the structural causal model from data offers us invariant predictions. Researchers~\cite{kocaoglu2018causalgan, pawlowski2020deep,xia2021causal, zhang2021treatment, rahman2024modular} have employed deep generative models to learn structural causal models with or without unobserved confounders, with arbitrary causal graphs, and for low- or high-dimensional datasets. Such a neural-network-based architecture for learning the SCM is known as a deep causal generative model (DCM). The core idea is to arrange neural network architectures to mimic the causal structure and to perform adversarial training to match the joint distribution implied by the DCM with the real distribution of the system. After training, these methods can be used to obtain samples from implicitly modeled identifiable interventional distributions.


% --------------


In a federated setup, the trivial approach to learn the DCM would be executing FedAVG~\cite{mcmahan2017communication} for all mechanisms of the causal model~\cite{vo2022adaptive}. However, a unified causal model across all clients is infeasible for two major reasons.
%
 % \blue{Why do we need different causal models at different clients} 
 %
 Firstly, a single global causal model might not perform well due to differences in clients~\cite{9766407}. Each client might have partially similar but mostly different concepts in their causal models. Similar to personalized federated learning (PFL) methods, we can  train personalized causal models according to each client's 
individual needs. Also, in federated learning, we might have local clients interested in different tasks such as classification, segmentation, pose estimation etc for the same image distribution~\cite{zhuang2024coala}, i.e., some causal mechanisms stay same across clients and some mechanisms are client specific. Performance across various tasks in different clients is preferred.
%
% \blue{Why full dcm is not feasible and modular-dcm required?}
%
Secondly,
FedAVG would suggest sending gradients of all neural networks that were used to learn the client's causal model, to the global server. Such communication overhead is infeasible for clients such as edge devices with limited compute and memory.  Thus, we do not want to transfer all causal mechanisms between server and clients.

In this paper, we aim to resolve this issue, by training the computationally expensive or client-heterogeneous mechanisms (ex: image generation/next token prediction) collaboratively and learning the rest of the mechanisms locally at each client. Since these mechanisms are computationally less expensive and specific to an individual client, local data is sufficient to learn them. 
However, it is not clear from existing work if we can learn any arbitrary set of mechanisms of the causal model with FedAVG and piece them together with rest of the locally trained mechanisms for performing causal inference. 

% \blue{When federated learning requires two model training and when one model is sufficient.} 
Besides learning a single mechanism such as an image classifier, many approaches learn multiple mechanisms where they manually design their architectures. 
For example, if we wish to utilize a foundation model's generalization for our private dataset, Low-Rank Adaptation (LoRA) based methods offer a fine-tuning approach that trains two modules (low-rank trainable matrices) with local data. For federated learning setups in which clients exhibit data heterogeneity, \cite{wang2024flora, bai2024federated, yang2024dual} suggest aggregating, at the server, the two LoRA modules that were trained at the local clients. Such fine-tuned models are later used to perform downstream tasks such as client-specific personalization and handling test-time distribution shifts.
\cite{hu2024fissionvae} train encoder and decoder and \cite{cao2022perfed} propose training generator-discriminator jointly with all clients for image generation and classification given non-IID data environments.
\cite{jothimurugesan2023federated} 
proposes a method for adapting to distribution drift of multiple concepts by employing one model per concept, where clients collaboratively train these models. These existing works justify our proposed method, where we train a set of models globally and the rest of the models locally.

% \red{discussion on confounders and neural causal models is absent}

\red{I am losing causality people here. Also, Addressed problem is not very clear. Why non-causal people should not care about this? This granular approach is better for something. What does it buy me to average ML person.  Target problem:  Causal effect estimation. For that purpose, we need to learn some dcm.}
These above works raise two questions: i) given (possibly different) causal models of all clients, which mechanisms are feasible to be trained collaboratively while learning rest of the mechanisms locally. ii) given the opportunity to train a specific set of models collaboratively (puzzle blocks), how can we fit them in rest of the client specific causal model (jigsaw puzzle). 
In this paper, we answer these questions based on possible modularization of the causal model. Instead of trying to match the whole joint distribution $P(V)$ with DCM, we utilize graph \red{c-components} to factorize the joint into multiple c-factors. Next, we train the DCM to learn each c-factor (\red{what is DCM,c-component- why does these matter}). This c-component-based modularity provides us with the flexibility to train heterogeneous mechanisms globally using FL and the rest of the mechanisms locally, reducing the communication cost. To our knowledge, we are the first to propose an efficient approach to learn client-specific structural causal models in a federated learning setup.
Precisely, our contributions are:
\begin{itemize}
    \item Contribution 1
    \item Contribution 2
    \item Experimental contribution
\end{itemize}


% No cure       

% \blue{What is neural/deep causal models? What is modular-dcm? How does that become useful here?}


% \blue{What problem are you trying to solve? What is your proposed solution?}



\section{Background and Problem Description}

\begin{definition}[Structural causal model (SCM)~\citep{pearl2009causality}]
% \textbf{Definition 1} (Structural causal model, (SCM)).
An SCM $\mathcal{M}$ is a $5$-tuple 
$ \mathcal{M}=(\mathcal{V}, \mathcal{N}, \mathcal{U}, \mathcal{F}, P(.) )$, where each observed variable $V_i\in\mathcal{V}$ is realized as an evaluation of the function $f_i\in\mathcal{F}$ which looks at a subset of the remaining observed variables $Pa_i\subset \mathcal{V}$, an unobserved exogenous noise variable $E_i\in \mathcal{N}$, and an unobserved confounding (latent) variable $U_i\in\mathcal{U}$. 
This refers to the \textbf{semi-Markovian causal model}.
$P(.)$ is a product joint distribution over all unobserved variables $\mathcal{N}\cup\mathcal{U}$. 
\end{definition}


\begin{definition}[Acyclic Directed Mixed Graph (ADMG)]
Each SCM induces a directed graph called the \emph{causal graph},
or acyclic directed mixed graph (ADMG)
with $\mathcal{V}$ as the vertex set. The directed edges are determined by which variables directly affect which other variable by appearing explicitly in that variable's function. Thus the causal graph is $G=(V,E)$ where $V_i\rightarrow V_j$ iff $V_i\in Pa_j$. The set $Pa_j$ is called the parent set of $V_j$. We assume this directed graph is acyclic (DAG). Under the semi-Markovian assumption, each unobserved confounder can appear in the equation of exactly two observed variables. We represent the existence of an unobserved confounder between $X,Y$ in the SCM by adding a bidirected edge $X\leftrightarrow Y$ to the causal graph. These graphs are no longer DAGs although still acyclic. $V_i$ is called an ancestor for $V_j$ if there is a directed path from $V_i$ to $V_j$. Then $V_j$ is said to be a descendant of $V_i$. The set of ancestors of $V_i$ in graph $G$ is shown by $An_G(V_i)$.

Given an ADMG $G$, a maximal subset of nodes where any two nodes are connected by  bidirected paths is called a \textbf{c-component} $C(G)$. For any $S\in C(G)$, $P(S|\Do(\V\setminus S))$ is called a c-factor. We assume that we have access to the ADMG through some causal structure learning algorithm and expert knowledge.
\end{definition}

\begin{definition}[Causal effect and do-intervention]
% \textbf{Causal effect, Layer 1, Layer 2:}
A do-intervention $do(v_i)$ replaces the functional equation of $V_i$ with $V_i=v_i$ without affecting other equations. The distribution induced on the observed variables after such an intervention is called an interventional distribution, shown by $P_{v_i}(\mathcal{V})$. $P_{\emptyset}(\mathcal{V})=P(\mathcal{V})$ is called the observational distribution. 
\end{definition}

\begin{definition}[Deep causal generative models (DCM)~\cite{rahman2024modular}]
\label{def:scm}
	A neural net architecture $\mathbb{G}$ is called a deep causal generative model (DCM) for an ADMG $G=(\mathcal{V},\mathcal{E})$ if it is composed of a collection of neural nets, one  $\mathbb{G}_i$ for each $V_i\in\mathcal{V}$ such that 
		i) \emph{each $\mathbb{G}_i$ accepts a sufficiently high-dimensional noise vector $N_i$,} 
		ii) \emph{the output of $\mathbb{G}_j$ is input to $\mathbb{G}_i$ iff $V_j\in Pa_G(V_i)$,}
		iii) \emph{$N_i=N_j$ iff $V_i\leftrightarrow V_j$. }
\end{definition}
We define $\mP(.)$ as the distribution induced by the DCM. Noise vectors $N_i$ replace both the exogenous noises and the unobserved confounders in the true SCM. They are of sufficiently high dimension to induce the observed distribution. We say that a DCM is \emph{representative enough for an SCM} if the neural networks have sufficiently many parameters to induce the observed distribution induced by the SCM. 
For the neural architectures of variables in the same c-component, we can consider conditional GANs~\citep{mirza2014conditional}, as they are effective in matching the joint distribution by feeding the same prior noise $N_i=N_j$ (as confounders) into multiple generators. For variables that are not confounded ($N_i\neq N_j$), we can use conditional models such as diffusion models~\cite{ho2022classifier}.
With Definition~\ref{def:scm}, we have the following, similar to \cite{xia2021causal}:
%%%%%%%%%%%%% DCM theoretical guarnatee .%%%%%%%%%%%%%
\begin{theorem}
\cite{xia2021causal, rahman2024modular}
	\label{th:identifiability}
	Consider any SCM $\mathcal{M}=(G, \mathcal{N}, \mathcal{U}, \mathcal{F}, P(.) )$.  A DCM $\mathbb{G}$ for $G$ entails the same identifiable interventional distributions as the SCM $\mathcal{M}$ if it entails the same observational distribution.  
	\end{theorem}

Thus, even with high-dimensional variables in the true SCM, given a causal graph, in principle, any identifiable interventional query can be sampled from, with a DCM that fits the observational distribution.


\begin{definition}[Interventional sampling with DCM]
\label{def:dcm-sample}
Given that variables {$\V$} are connected as a directed acyclic graph
and we have diffusion models trained to learn the distributions $P(v_i|pa(v_i))$, we can perform \textbf{ancestral sampling} from the joint distribution, $P(\vb) = \prod_{V_i\in \V} P(v_i|pa(v_i))$ by making one pass through each 
model in the topological order while sampling from the conditional distributions~\citep{bishop2006pattern}.
\end{definition}

\textbf{Federated Learning (FL):}
We consider a federated learning setting where $C$ clients participate at each round of a training process coordinated at a central server. Training data is decentralized where for each client dataset $D^c$ is sampled from a joint distribution $P^c(.)$. For a two variable case, each sample in $D^c$ is denoted as $(x,y)\in \mathcal{X}_1 \times \mathcal{X}_2$ with $\mathcal{X}_1, \mathcal{X}_2$ being support of input $X$ and output $Y$.
Clients collaboratively train a mechanism: $F(\theta, x ): \mathcal{X}_1 \rightarrow \mathcal{X}_2$ to learn the conditional distribution $P(y|x)$. The global optimization problem is designed as~\cite{mcmahan2017communication, tang2024fusefl}: 
\begin{equation}
\label{eq:fed1mech}
\min_{\theta} L(\theta) = \sum_{c=1}^{C}  L_c(\theta) 
= \sum_{c=1}^{C} \mathbb{E}_{(x, y) \sim \mathcal{D}_c} l(F(\theta,x), y)
\end{equation}
In the classic FedAvg~\cite{mcmahan2017communication} algorithm, the central server samples a subset of $C$ clients and broadcasts the mechanism parameters $\theta^t$ to those clients during round $t$. After performing local gradient updates, these clients return the optimized mechanism parameters $\theta^t_c$ to the server. The central server aggregates and averages the local models to obtain a global model. In this paper, we consider a more general case where each client has $V^c$ variables.


% Given a learned SCM and its ADMG $G$, what is the most efficient way to transfer to a new SCM where the new SCM might have i) change in its conditional distributions. ii) change in the edges iii) change in the number of variables.

\textbf{Problem setup:}
\red{Addressed problem is not very clear.}
For simplicity, we assume all clients share the same underlying SCM containing same set of variables and mechanisms. We relax this assumption later. 

Each client's dataset $D^c = \{\mathbf{v}_i\}_{i=1}^{n_c}$ is collected from its environment with joint distribution $P^c(v)$, which is assumed to be generated from an unknown SCM $\M^{*}$. Due to data heterogeneity, we assume that there exists a heterogeneous variable $\Vc$ such that, for any two distinct clients $i$ and $j$, $P^{i}(v') \neq P^{j}(v')$. We relax this to an arbitrary mechanism set later.

 % \cite{vo2022adaptive}
Our task is to i) Federated challenge: learn a proxy of the true SCM $\M^{*}$ with a globally learned $\hat \M$ without exchanging any client data \red{st joint matches,} ii) Causal challenge: estimate any causal effect or sample from any interventional distribution $P(y|do(x))$ where $X,Y \in V$. Formally, in any $c\in C$, for arbitrary $X,Y \in V$,
\begin{equation}
    \min_{\hat \M}  d(P(y|do(x)), \hat{P}(y|do(x))) 
\end{equation}



\textbf{Challenges of applying DCM in FL:}
% Example causal graph referenced as fig:fail-ex.
% Node handles do not match the displayed labels:
%   (x) shows W, (w1) shows Z, (y1) shows Y_1, (w2) shows X, (y) shows Y_2.
% Directed edges:   W -> Z, W -> Y_1, Z -> Y_1, Z -> X, X -> Y_2.
% Bidirected edges: W <-> X, Z <-> Y_2.
\begin{figure}[t!]
    \centering
%
\begin{subfigure}[c]{0.45\linewidth}
  \centering
\begin{tikzpicture}[scale=0.7, transform shape]
    %\tikzstyle{every node}=[font=\tiny]
    \tikzstyle{every node}=[]
    \node   [ ] (x) {$W$};
    \node [below =0.8cm of x] (w1) {$Z$};
    \node [below left =0.3cm and 0.4 of x] (y1) {$Y_1$};
    \node [right =1cm of x] (w2) {$X$};
    \node [below =0.8cm of w2] (y) {$Y_2$};
    \draw[ thick] (x) to  (w1); 
    \draw[ thick] (x) to  (y1); 
    \draw[ thick] (w1) to  (y1); 
    \draw[ thick] (w1) to  (w2); 
    \draw[ thick] (w2) to  (y); 
    \path[bidirected] (x) edge[bend left=35] (w2);
    \path[bidirected] (w1) edge[bend right=35] (y);
\end{tikzpicture}
\caption{Graph example}
\label{fig:fail-ex}
\end{subfigure}
\caption{Causal graphs}
\end{figure}

Given the causal graph $G(V)$, we have $\{f_i\}_{i=1}^{|V|}$ mechanisms in the causal model. 
Our task requires us to learn the true joint distribution $P(v)$ by training local models on client data. 

If the heterogeneous variable $\Vc$ is not caused by any unobserved confounder, such as $\Vc= {Y_1}$ in Figure~\ref{fig:fail-ex}, then we can use equation~\ref{eq:fed1mech} to learn $f_{Y_1}$ globally and train $\{f_i\}_{v_i\in V\setminus \{\Vc\}}$ locally. This way we can match $P(V)$ and learn the full SCM.

However, we are dealing with a more general case: Non-Markovian causal model where an unobserved shared parent $U$ might cause both $\Vc$ and some other $W\in V$, i.e., $\Vc \leftarrow U \rightarrow W$ (also represented as $\Vc \leftrightarrow W$). In such case, two trivial solutions might come to our mind, and here we show how they fail.


\textbf{Trivial Solution 1:} Since $P(v')$ changes across clients due to data heterogeneity, we might want to train only the mechanism  $f_{\Vc}$ globally according to equation~\ref{eq:fed1mech} and train $\{f_i\}_{v_i\in V\setminus \{\Vc\}}$ locally.  If we have bi-directed edge $\Vc \leftrightarrow W$ in the causal graph, $\Vc$ and $W$ do not cause each other, but they share a joint $P(v', w)$ that must be matched to be consistent with the full joint $P(V)$. As a result, both $f_{\Vc}$ and $f_{W}$ need to be trained together with the same confounding noise $U$. Training $f_{\Vc}$ globally and $f_{W}$ locally would not allow feeding the same confounding noise and matching $P(v', w)$. In Figure~\ref{fig:fail-ex}, if $\Vc=Y_2$, we can not train $f_{Y_2}$ globally and $f_{Z}$ locally as we have to feed the same confounding noise to both models.

\textbf{Trivial Solution 2:}
We might want to train all mechanisms of the SCM: $\{f_i\}_{v_i\in V}$ globally. That would ensure that $P(v)$ is matched. To obtain that, we can minimize the following loss function:
\begin{equation}
\label{eq:triv-sol2}
\begin{split}
\min_{\theta} L(\theta)  = \sum_{c=1}^{C}  L_c(\theta) 
& = \sum_{c=1}^{C} \sum_{v\in V} \mathbb{E}_{[pa(v),v] \sim \mathcal{D}_c} l( \hat{v}, v)\\
   & \hat{v} =  f_i(\theta_i, pa(v), U_i )
\end{split}
\end{equation}

However, this will be computationally expensive and unnecessary since clients only differ at $P(v')$. In this paper, we aim to learn a minimal mechanism set $\F$ globally, with
$f_{\Vc} \in \F$, that allows us to match $P(v')= \hat{P}(v')$.

\textbf{Valid Solution:}
We can obtain a valid solution for Figure~\ref{fig:fail-ex} by considering the c-components. We train $\{f_Z, f_{Y_2}\}$ globally and $\{f_{W}, f_{X}, f_{Y_1}\}$ locally.  Below, we generalize the idea of federated learning based on c-components.

% containing our expensive mechanism $f_m$ (i.e.,  $f_m \in \mathcal{F}$) that we can train globally while training rest of the mechanism locally.

% Our objective is to approximate the true SCM with $\hat \M^c$ such that $P^{c}(v)= \hat{P}^{c}(v)$ with the same causal graph $G^c$.




% \begin{figure}
%     \centering
%     \includegraphics[width=1\linewidth]{Figures/Mod2/fed.pdf}
%     \caption{Enter Caption}
%     \label{fig:enter-label}
% \end{figure}

% \begin{figure}
%     \centering
%     \includegraphics[width=1.1\linewidth]{Figures/Mod2/root.pdf}
%     \caption{Enter Caption}
%     \label{fig:enter-label}
% \end{figure}

% \begin{figure}
%     \centering
%     \includegraphics[width=1\linewidth]{Figures/Mod2/architectures.pdf}
%     \caption{Applications of modularity}
%     \label{fig:enter-label}
% \end{figure}



%%%%%%% Methodology
% Minimum modularity (how modular-dcm can be improved
% Theoretical guarantee of minimum modularity and why can we not modularize further.
% Connect it with federated learning : Collaborate based on agreed upor architecture
% If pre-specified architecture, how do we adapt local training,










%%%%%%%%

\section{Methodology }
% \red{remove example to the appendix}

\subsection{Designing the FL setup as a selection bias problem \red{for detection?}}


% Four causal graphs, one subfigure per case; in every case the selection
% variable C is a child of X (edge X -> C). Node handle (w2) always shows Y
% and handle (s) always shows C.
%   Case 1: X -> C, X -> Y.
%   Case 2: S -> X, X -> C, X -> Y.
%   Case 3: S -> X, X -> C, X -> Y, A -> S, A -> Y.
%   Case 4: Z -> X, X -> C, X -> Y, and a bidirected edge Z <-> Y.
% NOTE(review): the caption explains X, Y, S, A and C but not Z (Case 4).
% NOTE(review): the figure-level \label is commented out, so only the
% individual subfigures can be referenced.
\begin{figure}[t!]
    \centering
\begin{subfigure}[c]{0.15\linewidth}
  \centering
\begin{tikzpicture}[scale=0.7, transform shape]
    %\tikzstyle{every node}=[font=\tiny]
    \tikzstyle{every node}=[]
    \node   [] (x) {${X}$};
    \node [ right =0.8cm of x] (w2) {${Y}$};
    \node [ below =0.6cm of x] (s) {$C$};
    \draw[ thick] (x) to  (s); 
    \draw[ thick] (x) to  (w2); 
\end{tikzpicture}
\caption{Case 1}
\label{Case 1}
\end{subfigure}
\begin{subfigure}[c]{0.27\linewidth}
  \centering
\begin{tikzpicture}[scale=0.7, transform shape]
    %\tikzstyle{every node}=[font=\tiny]
    \tikzstyle{every node}=[]
    \node   [] (x) {${X}$};
    \node [ right =0.8cm of x] (w2) {${Y}$};
    \node [ left =0.8cm of x] (z) {${S}$};
    \node [ below =0.6cm of x] (s) {$C$};
    \draw[ thick] (z) to  (x); 
    \draw[ thick] (x) to  (s); 
    \draw[ thick] (x) to  (w2); 
\end{tikzpicture}
\caption{Case 2}
\label{Case 2}
\end{subfigure}
\begin{subfigure}[c]{0.27\linewidth}
  \centering
\begin{tikzpicture}[scale=0.7, transform shape]
    %\tikzstyle{every node}=[font=\tiny]
    \tikzstyle{every node}=[]
    \node   [] (x) {${X}$};
    \node [ right =0.8cm of x] (w2) {${Y}$};
    \node [ left =0.8cm of x] (z) {${S}$};
    \node [ above =0.1cm of x] (w) {$A$};
    \node [ below =0.5cm of x] (s) {$C$};
    \draw[ thick] (z) to  (x); 
    \draw[ thick] (x) to  (s); 
    \draw[ thick] (x) to  (w2);
    \draw[ thick] (w) to  (z); 
    \draw[ thick] (w) to  (w2); 
\end{tikzpicture}
\caption{Case 3}
\label{Case 3}
\end{subfigure}
\begin{subfigure}[c]{0.27\linewidth}
  \centering
\begin{tikzpicture}[scale=0.7, transform shape]
    %\tikzstyle{every node}=[font=\tiny]
    \tikzstyle{every node}=[]
    \node   [] (x) {${X}$};
    \node [ right =0.8cm of x] (w2) {${Y}$};
    \node [ left =0.8cm of x] (z) {${Z}$};
    \node [ below =0.6cm of x] (s) {$C$};
    \draw[ thick] (z) to  (x); 
    \draw[ thick] (x) to  (s); 
    \draw[ thick] (x) to  (w2); 
    \path[bidirected] (z) edge[bend left=35] (w2);

\end{tikzpicture}
\caption{Case 4}
\label{Case 4}
\end{subfigure}
%
\caption{$X:$ X-ray image, $Y$: pneumonia prediction, $S$: symptoms, $A$: patient age. Selection variable $C$ indicates different distributions in client $C=c$. }
% \label{fig:gen-form}
\end{figure}


Selection bias is a common real-world challenge created by preferential selection of data points. \red{Cite} For example, in epidemiological case-control studies, patients with an unusual disease or complication (outcome) generally report their cases ($C=1$) while non-severe outcomes remain unreported ($C=0$). This forms a selection bias in the dataset. To formalize this problem, a selection variable $C$ is explicitly added to represent the selection mechanism, where $C = 1$ indicates that a sample is present in the dataset and $C = 0$ that it is not.

\red{re-write} We can represent the federated learning problem with the same setup.  
We explicitly add a selection variable $C$ as a child of the heterogeneous variable $X$.
$X\rightarrow C$ represents that based on the value of only $X=x$,
it is decided whether a sample $\mathbf{v}$ will be found in the local data pool of a client. $P(C=c|\mathbf{v}(x))$ implies the probability of sample $\mathbf{v}$ with $X=x$ being included in the local data of client $c$.
$P(c|\mathbf{v}(x)) = \frac{P(\mathbf{v}(x)|c) P(c)}{P(\mathbf{v}(x))} \propto P(\mathbf{v}(x)|c)$, as we consider $P(c)$ to be uniform and $P(\mathbf{v}(x))$ to be the same for all clients.
\red{$P(C)$ should be proportional to its dataset size.}
% $P(c=i| \mathbf{v}(x)) \neq P(c=j| \mathbf{v}(x)) $
Thus, if for some $x\in \mathcal{X}$,
$P(\mathbf{v}(x)|c=i) \ll P(\mathbf{v}(x)| c=j)$ then $P(c=i| \mathbf{v}(x)) \ll P(c=j| \mathbf{v}(x))$. We need to perform federated learning to transfer the required mechanisms learned from the local data of other clients $j \in \mathcal{C}\setminus\{i\}$ to client $i$ as a global model.
%
% dependent on $X$ from client $i$ to client $j$ (and other similar clients). 
%
Since the clients disagree on the values of $X$, we define $X$ as the heterogeneous variable.

The above selection bias representation of FL allows us to understand which mechanisms we should train globally.
Suppose client $i$ has partial support $\mathcal{X}_i$ for variable $X$. Client $i$ needs to learn how mechanism $f_X(.)$ should perform to generate $X=x$ such that $x\in \mathcal{X}\setminus \mathcal{X}_i$ and how the mechanisms $f_{V: Ch(X)} (x)$ should perform for $X=x$ with support $x\in \mathcal{X}\setminus \mathcal{X}_i$. Below, we build the formal characterization of the set of mechanisms that need to be trained globally, case by case:


Since we do not want any spurious correlation, our goal is always the interventional distribution (e.g., $P(y|do(x))$) and not the conditional distribution (e.g., $P(y|x)$).

Case 1 (Figure~\ref{Case 1}): Consider predicting pneumonia ($Y$) from x-ray images ($X$).
For prediction $P(y|do(x))= P(y|x)$, federated learning of $f_y(x); \forall x\in \X$ is sufficient. Thus, $\mathcal{T}= \{f_{y}\}$.

Case 2 (Figure~\ref{Case 2}): 
Suppose, we have information about patient symptoms $S$ now. Symptoms affect how the x-ray would look like ($S\rightarrow X)$ but 
pneumonia prediction ($Y$) should depend on $X$ not on $S$. Thus, no $S\rightarrow Y$. 
% should be independent from $Y$ for a specific x-ray image ($Z\indep Y|X$).
For prediction $P(y|do(s))= \sum_{x}P(y,x|do(s))= \sum_x P(x|do(s)) P(y|do(s),x) = \sum_x P(x|s) P(y|x) $ = $\sum_{x\in \X_i} P(x|s) P(y|x) +  \sum_{x\in \X \setminus \X_i} P(x|s) P(y|x)$. Thus, client $i$ needs to learn how i) $x= f_X(s);\forall x\in \X \setminus \X_i$ and $\forall s\in \mathcal{S}$ are generated and ii) $y= f_Y(x);\forall x\in \X \setminus \X_i$ and $\forall y\in \mathcal{Y}$ are generated. Thus, $\T = \{f_x, f_y\}$.

Case 3 (Figure~\ref{Case 3}): Suppose, now we have access to patient ages $A$ and client 1 represents a hospital that 
mostly serves older demographic. Older patients are more likely to develop different symptoms ($A\rightarrow S$) and have a higher risk of getting pneumonia ($A\rightarrow Y$).
However, partial support does not affect $A$ or vice-versa. Thus, even though we have an additional measured confounder $A$, the FL training set $\T = \{f_x, f_y\}$ stays the same. 

Case 4 (Figure~\ref{Case 4}): As FL will train the global model locally, using the age attribute $A$ might cause the model to overfit to local data (\red{doesn't age have the same distribution across clients?}). Thus, we do not want age as a feature in our models. We consider $A$ as an unobserved confounder. This is represented as a bi-directed edge $Z\leftrightarrow Y$. $\{Z,Y\}$ is now a c-component. Similar to cases 2 and 3, we have $f_x\in\T$ and $f_y\in \T$. However, note that $Z \not\perp Y|X$. If we train $f_Y$ with only $X$ as input, we have $y= f_Y(x)$. This makes $Z \perp Y|X$. Thus, we need to train both $z=f_Z(u)$ and $y= f_Y(x,u)$ together with the same confounding noise $U\sim \mathcal{N}(0, I)$. Therefore, the FL training set is $\T= \{f_z, f_x, f_y\}$.

\subsection{Defining the Neighborhood Mechanisms for FL}
We need to learn i) the mechanism for generating the new $X=x$ given its parent ($Pa(X)$) values and ii) the mechanisms that determine what values its children ($Ch(X)$) should take for the previously unseen $X=x$. For example, suppose the clients maintain a causal graph: $\text{color } (C) \rightarrow \text{Image } (X) \rightarrow \text{digits } (d)$.
Both clients have images of colors: $R,G,B$ and digits: $0-9$ but images in their local data are collected from different environments.
Client 1 has images of $0-4$ w/ all red color while client 2 has images of $5-9$ with green color. In this setup,


% If $C$ is child of both $X$ and $Y$ them it implies that based on both $X,Y$, the sample belongs to a client. For example: a specific X-ray image without major symptoms has equal likelihood of being assingned to both emegency and regular client. Howver, if a clinician predicts somenthings  





\subsection{Modular learning of Deep causal generative models}
Given an ADMG of a semi-Markovian model, \cite{tian2002general} utilizes the c-component sub-graph modules of the ADMG and factorizes the joint distribution $P(\mathcal{V})$ into c-factors: the joint distributions of each c-component $S_i$ intervened on their parents, i.e., $P(s_i|\Do(pa(s_i)))$.

\begin{equation}
\label{eq:c-fact}
	\begin{split}
		P(v)&= \prod_{s_i\in C(G)} P(s_i|\Do(pa(s_i)))\\
	\end{split}
\end{equation}
%
%
%
This factorization implies that if we can enforce our approximated DCM $\hat{\mathcal{M}}$ to match each of the c-factors, the joint distribution implied by the DCM (Def~\ref{def:dcm-sample}) will also match $P(\mathcal{V})$.



% \textbf{Limitations of modular-dcm and minimum modularity}
Existing approaches such as \cite{rahman2024modular} utilize the c-factorization to modularize the DCM learning. However, they point out that each c-factor in Equation~\ref{eq:c-fact} is an interventional distribution. Since we only have access to a dataset sampled from $P(\mathcal{V})$, they suggest learning a proxy distribution of each c-factor involving more variables than the c-component. This becomes wasteful, especially in our federated learning setup.
In this paper, we show that c-component based modularity is sufficient and necessary to learn the DCM matching the joint distribution.


Suppose $P(V)$ represents the original data distribution and $\mP(V)$ is the distribution implied by the $\theta$-parameterized DCM. We need to minimize the following for each c-factor by training the DCM:
\begin{equation}
               L_0 = 
               d(\mP(S|do(pa(S)))  
               , P(S|do(pa(S))) )
\end{equation}



Let $\Vc$ be the heterogeneous variable and $S_{\Vc}$ be the c-component containing $\Vc$. Our main idea is to train the heterogeneous c-component $S_{\Vc}$ globally in a federated manner, and train rest of the c-components $\{{S_i}\}_i\setminus \{S_{\Vc}\}$ locally.


As an example, consider the graph in Figure~\ref{fig:gen-form}: $W \rightarrow {X} \rightarrow {Z} \rightarrow {Y}; {X} \leftrightarrow {Y}$. Here, 

\begin{equation}
\begin{split}
   % & P(W, {X}, {Y}, {Z})\\
    P(v) = P(w|\Do(\emptyset)) &P({x}, {y}|\Do({w, z}))  P({z}|\Do({x}))\\
\end{split}
\end{equation}

Suppose $P(y|z)$ exhibits heterogeneity across clients. Our algorithm would suggest training $\{f_{x}, f_{y}\}$ globally to match $P(x,y|do(w,z))$ with a FL algorithm and train $\{f_w, f_z\}$ locally with only client local data.


\begin{algorithm}[t!]
\caption{Fed-DCM Algorithm}
\begin{algorithmic}[1]
\STATE \textbf{Input:} Dataset $\mathcal{D}$, Causal graph $\mathcal{G}$, Variables $\mathbf{V} = \{V_1, V_2, \dots, V_n\}, n = |\mathbf{V}|$

\STATE \textbf{Client initialization:}
\FOR{each $ V \in \mathbf{V}$}
    \STATE Initialize weights of $f_{V}(Pa(V), U_{V} )$ as $w[V]$
\ENDFOR
\STATE $[\mathbf{S}_i, \text{Pa}(\mathbf{S}_i)] \leftarrow \text{c\_component\_partition}(\mathcal{G})$
\STATE $S_{\Vc} =$ Find c-component $S \in \{S_i\}_i $ s.t. $\Vc \in S$

\STATE \textbf{Server executes: }
\STATE Initialize model weights $w_0[V];$ for all $V\in S_{\Vc}$.
\FOR{each round $t = 1, 2, \dots$}
    % \STATE $m \gets \max(C \cdot K, 1)$
    \STATE $C_t \gets$ (random set of $\max(\alpha C, 1)$ clients)
    \FOR{each client $k \in C_t$ \textbf{in parallel}}
        \STATE $w_{t+1}^{k}$ $\gets \textsc{ClientUpdate}(k, w_t)$
        % \FOR{each variable $V\in S_{\Vc}$ }
        % \STATE $w_{t+1}^{k}[V] \gets  keep[V] $
        % \ENDFOR
    \ENDFOR
    \FOR{each variable $V\in S_{\Vc}$ }
    \STATE $w_{t+1}[V] \gets \sum_{k=1}^K \frac{n_k}{n} w_{t+1}^{k}[V]$
    \ENDFOR
\ENDFOR
\vspace{1em}


\STATE \textbf{LocalTraining}($w, \mathcal{B}, S, Pa(S)$): \textit{(Models in a cc)}
\FOR{each local epoch $i$ from $1$ to $E$}
    \FOR{each batch $b \in \mathcal{B}$}
\STATE Sample   $pa(\mathbf{S}_i)  \sim \text{Uniform}(\text{support}(\text{Pa}(\mathbf{S}_i)))$
    \STATE $D^R[\mathbf{S}] = $ getRealIntvData($b, \mathcal{G}, \mathbf{S_i}, pa(\mathbf{S_i})$)
    \STATE $D^F[\mathbf{S}] = $ getFakeIntvData($f_{V_i\in \mathbf{V}}, \mathbf{S_i}, pa(\mathbf{S_i})$)
    \STATE $\ell$ = dist($D^F, D^R$)
    \FOR{each $V\in S$}
        \STATE $w[V] \gets w[V] - \eta \nabla \ell$
    \ENDFOR
    \ENDFOR
\ENDFOR
\STATE \textbf{return} $w$

\vspace{1em}
\STATE \textbf{ClientUpdate($c, w$):} \textit{(Run on client $c$)}
\STATE $\mathcal{B} \gets$ (split $D^c$ into batches of size $B$)

\FOR{each $ {S} \in \{{S_i}\}_i\setminus \{S_{\Vc}\} $ }
\STATE $w = \textbf{LocalTraining($w, \mathcal{B}, S, Pa(S)$)}$ \textbf{in parallel}
\STATE Save $\{w[V]\}_{V\in S}$ locally.
\ENDFOR
\STATE $w = \textbf{LocalTraining($w, \mathcal{B}, S_{\Vc}, Pa(S_{\Vc})$)}$ \textbf{in parallel}
\STATE \textbf{return} $\{w[V]\}_{V\in S_{\Vc}}$ to server
% \STATE \textbf{return } SCM mechanisms $f_{V}$
\end{algorithmic}
\end{algorithm}




\subsection{Connection with Federated Learning}

\textbf{Client initialization:}
We initialize the weights of each model $f_V$ in the DCM as $w[V]$. We obtain all c-components from the causal graph and denote the c-component containing the heterogeneous variable $\Vc$ as $S_{\Vc}$.

\textbf{Client update:}
In each client $c$, we train the models in c-components that are personalized and non-heterogeneous ($\{{S_i}\}_i\setminus \{S_{\Vc}\}$) using the function \textbf{LocalTraining(.)}. We save these models locally.

Calling \textbf{LocalTraining(.)} similarly, we train models in the heterogeneous c-component $S_{\Vc}$ but send those models to the server to be aggregated with the global model.  That is, each client locally takes one step of gradient descent on the current global models $\{f_{V_j}\}_{V_j \in S_{\Vc}}$  using its local data, and the server then takes a weighted average of the resulting models. 





When client $c$ calls \textbf{LocalTraining(.)} for a c-component $S$, it makes $E$ training passes over its local dataset $D^c$ of $\mathcal{B}$ batches to train models $\{f_{V}\}_{V \in S}$. Client $c$ compares the fake samples $D^F \sim P_{\theta}(s|do(pa(s)))$ with real interventional data $D^R \sim P(s|do(pa(s)))$ to obtain a loss function $\ell$ (discussed in Section~\ref{subsec:c-factor-train}). The model weights $w[V]$ of each function $f_{V}$ in $S$ are updated as $w[V] \gets w[V] - \eta \nabla \ell$.
This local training for each c-component is performed in parallel.


The computation complexity of a client depends on parameters such as $\alpha$: the fraction of clients that perform computation in each round, $E$: the number of epochs on local data in each round, $B$: the local minibatch size, and $|\{f_{V_j}\}_{V_j \in S_{\Vc}}|$: the number of models contained in the heterogeneous c-component $S_{\Vc}$.

\textbf{Server executes:}
The server receives the heterogeneous c-component models from each client and takes a weighted average of the sent models. For the weights of each function $f_{V}$ in $S_{\Vc}$, the server performs $w_{t+1}[V] \gets \sum_{k=1}^K \frac{n_k}{n} w_{t+1}^{k}[V]$.


\red{First, each source computes gradients from all sources and subsequently updates the model. Next, the server broadcasts the new the local gradient,  using its own data and sends to the server. The server, then, collects these model to all the sources.}

\subsection{Efficient learning of c-factors: $P(s_i|do(pa(s_i)))$}
\label{subsec:c-factor-train}

\begin{figure}[t!]
    \centering
        \begin{subfigure}[c]{0.45\linewidth}
  \centering
\begin{tikzpicture}[scale=0.7, transform shape]
    %\tikzstyle{every node}=[font=\tiny]
    \tikzstyle{every node}=[]
    \node   [] (x) {${X}$};
   \node [ left =1cm of x] (W) {${W}$};
    \node [below right =0.8cm and 0.4cm of x] (w1) {$\mathbf{Z}$};
    \node [ right =1.4cm of x] (w2) {${Y}$};
    \draw[ thick] (x) to  (w1); 
    \draw[ thick] (W) to  (x);
    \draw[ thick] (w1) to  (w2); 
    \path[bidirected] (x) edge[bend left=35] (w2);
\end{tikzpicture}
% \caption{Generalized format}
\end{subfigure}
%
\caption{Causal graphs}
\label{fig:gen-form}
\end{figure}



% We can apply step 7 of the ID algorithm for the case in modular-DCM when rule-2 does not apply. 

Equation~\ref{eq:c-fact} suggests that we have to match each of the c-factors in the product. This can be done independently and in-parallel. 
Thus, we first focus on how a specific c-factor can be matched by training the mechanisms of the c-component. The main idea is that we train a set of models $\{M_j\}$ to learn conditional distributions and utilize them to generate samples from the c-factor $P(s_i|do(pa(s_i)))$. Now, given that we have generated real interventional data for $P(s_i|do(pa(s_i)))$, we can train the mechanisms in the DCM $\{f_{V_j}\}_{V_j\in S_i}$ on those. We generate fake data from the GAN architecture and perform adversarial training considering the generated interventional data as the real dataset.

    \begin{algorithm}[t!]
\caption{getRealIntvData($\mathcal{D}, \mathcal{G}, \mathbf{S}, pa(\mathbf{S})$)}
% The label must follow \caption, outside algorithmic, so that \ref gives the algorithm number and not a line number.
\label{alg:getRealIntvData}
\begin{algorithmic}[1]
\STATE \textbf{Input:} Dataset $\mathcal{D}$, Causal graph $\mathcal{G}$, C-component $\mathbf{S}$, blanket $Pa(\mathbf{S})$.
\STATE $An= V_{\pi^{j-1}} \cap (S_i \cup pa(S_i))$; $\pi_{\mathcal{G}}$ be the ancestral order.
\FOR{each $V_j \in \mathbf{S}$}
\STATE Train $M_j(An)$ on $\mathcal{D}$ such that $M_j(An) \sim P(v_j|An)$
\ENDFOR
\STATE Fix $ pa(\mathbf{S}) $ in $M_{j: V_j \in \mathbf{S}}$ and ancestral sample to obtain $ D^R[\mathbf{S}] \sim P(\mathbf{S} \mid \text{do}(\text{Pa}(\mathbf{S}))) $
\STATE \textbf{Return } $D^R[\mathbf{S}]$
\end{algorithmic}
\end{algorithm}


\textbf{Real interventional data generation:} To be more precise, let us consider a c-component $S_i$ for which we need to train all mechanisms $f_{V_j}$, $V_j \in S_i$.
 Here, the c-factor $P(s_i|do(pa(s_i)))$ is identifiable, i.e., we can uniquely estimate it as a function of the observational distribution $P(V)$, as the intervention set $\mathbf{X}= Pa(S_i)$ is located outside the c-component~\cite{tian2002general}. We can estimate the c-factor with the following formula:
\begin{equation}
\label{eq:step6}
    % P(s_i| do(pa(s_i))=\prod_{V_j\in S_i} P(v_j|v_{\pi^{j-1}} \cap (S_i \cup pa(S_i)))
    P(s_i| do(pa(s_i)))=\prod_{\{j|V_j\in S_i\}} P(v_j|v_{\pi}^{j-1} )
\end{equation}

We can train a conditional model $M_j$ for each of these conditional distributions in Equation~\ref{eq:step6} with the loss function below.
Here, $\iP_{\theta'}(V)$ is the distribution learned by the $\theta'$-parameterized models trained in Algorithm~\ref{alg:getRealIntvData}.

\begin{equation}
\begin{split}
 L_1 = &d(P(S|do(pa(S))),  {\iP}_{\theta'}(S|do(pa(S))) )  \\
= & \sum_{i:V_i\in S} d(P(v_i| v_{\pi}^{i-1}), 
 \iP_{\theta'_i}(v_i| v_{\pi}^{i-1}))    
\end{split}
\end{equation}
After convergence, we connect these trained models according to input-outputs and perform ancestral sampling to generate samples from the interventional distribution $P(s_i| do(pa(s_i)))$. This is illustrated in Algorithm~\ref{alg:getRealIntvData}: getRealIntvData(.), lines 3-5.


In Figure~\ref{fig:gen-form},
\begin{equation}
    P({x}, {y}|\Do(w, {z})) =  P({x}|w) P({y}|w, {x}, {z})
\end{equation}
To generate samples from $P(x,y|\Do(w,z))$:
\begin{enumerate}
    \item First, we train a model $M_{X}$ with $W$ as input such that its output $M_{X}(w)\sim P(x|w)$. Next, we train a model $M_Y$ such that $M_{Y}(w,x,z)\sim P(y|w,x,z)$. 
    \item 
    Connect the models as $[w] \rightarrow M_{X} \rightarrow M_{Y} \leftarrow [w,z]$. Next, perform ancestral sampling: feed the intervened value $W=w$ to sample ${\hat{X}}=M_{X}(w)$ and then feed the generated ${\hat{X}}$ and intervened ${(W,Z)=(w,z)}$ to sample $\hat{Y}= M_Y(w, \hat{X}, {z})$. Here, $\{\hat{X}, \hat{Y}\}\sim P(x,y|do(w,z))$.
\end{enumerate}

This implies that we can generate interventional data from observational data only having access to models $M_{V_j}; V_j\in S_i$ and data: $D[Pa(S_i), S_i]$. This c-component based modularity was unexplored previously and the closest work~\cite{rahman2024modular} considered additional models and data of proxy variables.


\begin{algorithm}[t!]
\caption{getFakeIntvData($\mathbb{G}_{V_i\in \mathbf{V}}, \mathbf{S}, pa(\mathbf{S})$)}
% The label must follow \caption, outside algorithmic, so that \ref gives the algorithm number and not a line number.
\label{alg:getFakeIntvData}
\begin{algorithmic}[1]
\STATE \textbf{Input:}  DCM $\mathbb{G}_{V_i\in \mathbf{V}}$, C-component $\mathbf{S}$, blanket $Pa(\mathbf{S})$.
\STATE Fix $ pa(\mathbf{S}) $ in $\mathbb{G}_{V_i \in \mathbf{V}}$ and ancestral sample to obtain $ D^F[\mathbf{S}] \sim P_{\theta}(\mathbf{S} \mid \text{do}(\text{Pa}(\mathbf{S}))) $
\STATE \textbf{Return } $D^F[\mathbf{S}]$
\end{algorithmic}
\end{algorithm}


\textbf{Fake interventional data generation:}
Now that we have obtained real interventional samples from $P(s_i| do(pa(s_i)))$, we can utilize those as training data to train mechanisms $f_{V_j}; V_j\in S_i$.





Algorithm~\ref{alg:getFakeIntvData}: getFakeIntvData
\begin{enumerate}
    \item Feed the generated ${\hat{X}}$ and random ${Z}$ to generate ${Y}$ as inference.
    \item Train $M_{{Y}}(N,{Z})$ by matching with ${Y}$. This will ensure that $M_{{Y}}$ is being trained on $P({X,Y|do(Z)})$.
    \end{enumerate}


Now, we can compare the generated fake interventional dataset  as Algorithm~\ref{alg:getFakeIntvData} with generated real interventional dataset as Algorithm~\ref{alg:getRealIntvData} and train the models in $\{f_{V_j}\}_{V_j\in S_i}$ accordingly.


The loss function for the DCM is as follows:
\begin{equation}
               L_2 = 
               d(Q_{\theta'}(S|do(pa(S))), P_{\theta}(S|do(pa(S))))
\end{equation}

Now, we can backpropagate on $L_1$ and $L_2$
% , \forall i:V_i\in S 
which lets us update the parameters $\theta'_i$ (for all $i$) and $\theta$ independently.


Finally, according to triangle inequality,
\begin{equation}
    \begin{split}
        & d(P(S|do(pa(S))), P_{\theta}(S|do(pa(S))))  \\
               &\leq 
               d(P(S|do(pa(S))), Q_{\theta'}(S|do(pa(S)))) \\
        & +  d(Q_{\theta'}(S|do(pa(S))), P_{\theta}(S|do(pa(S)))) \\
        & \implies L_0 \leq L_1 + L_2
    \end{split}
\end{equation}

As we are minimizing the loss functions $L_1$ and $L_2$, our target loss function $L_0$ will be minimized as well.

% Suppose the graph can be generalized in the format: $W \rightarrow {X} \rightarrow {Z} \rightarrow {Y}; {X} \leftrightarrow {Y}$ as shown in Figure~\ref{fig:gen-form}.





% ID-GEN should be able to be utilized easily
% to obtain c-component based modularity in modular-DCM.
% Previously we discussed about applications such as adaptation to distribution shift, dcm for time-series, transportability etc.

% Also \cite{jung2024estimating} show that any
% g-identifiable causal effect can be expressed as a function of generalized multi outcome sequential back-door adjustments that are amenable to estimation.


% To obtain a DCM we need to train $|V|$ models for $V$ variables and match all of the following terms.
% \begin{equation}
%     P(V) = \prod_{i\in \{n\}} P(S_i| do(pa(S_i))
% \end{equation}


\section{Experiments}
TBD
\section{Related works}
TBD
\section{Conclusion}
TBD


\clearpage

\section{Experiments}


\subsection{Synthetic data: }
Compare with FedRRF as baseline.

\subsection{Domain shift/spurious correlation}
CelebA.

\subsection{Benefits of modularity \& causal modeling}

We show the applications of our algorithms in the following directions:

i) \textbf{Efficient adaptation to distribution shift:} 
Suppose we observe a distribution shift in the coming data.
This can be represented as soft intervention in the causal graph which
changed some mechanism. We can locate the c-component $S_i$ and fine-tune it to adapt to the distribution shift.

ii) \textbf{Transportability of mechanisms}:
Suppose we want to transport some part of our model to a different domain. We can compare our training domain and test domain to determine which mechanisms/blankets stay invariant. We can transport that part and train the rest of the mechanisms in the causal graph.
We can use Reddit data for this purpose. We train our DCM on a weight-gain subreddit and transport it to a weight-loss subreddit.

iii) \textbf{Federated Learning}: We have some pre-trained setup. Next, we have a new output variable. We add it to the SCM and train only the necessary part of the causal graph.


multimodal learning or when auxiliary information can improve the classification task.


\subsection{Real-world experiment 1}
COALA~\cite{zhuang2024coala}, at the task level, we
extend to a broader spectrum of 15 CV tasks, including classification, object detection, segmentation, pose estimation,
face recognition. At the data level, support
semi-supervised FL, unsupervised FL, and multi-domain
FL with feature distribution shifts among local training data. At the model level, clients can train multiple models with varying
parameters and architectures.


\subsection{Real-world experiment 2}
\cite{zhang2023federated} leverage Stable Diffusion to synthesize high quality training data on the server based on the text embeddings collected from clients. They generate prompts based on the characteristics of the client’s data, which are used as inputs to a specific text encoder to obtain corresponding text embeddings. Once all text embeddings are collected from the clients, the server performs embedding aggregation
and then synthesizes a high-quality substitute training dataset. This public synthetic dataset serves as a proxy for the clients’ private data and can be used to train a global model on the server.





\section{Related works}

\cite{hu2024fissionvae} uses federated setup to generate images with VAE and GANs.

\cite{song2021federated} uses cycle gan in federated setup for image translation task.


\cite{ng2022towards} proposes federated bayesian network structure learning with continuous optimization.


\cite{han2021federated} Federated adaptive causal estimation (face) of target treatment effects.

\cite{xiong2023federated} Federated causal inference in heterogeneous.

\cite{qiao2023collaborative} Collaborative causal inference with fair incentives.


% \FOR{each $ \mathbf{S} \in [\mathbf{S}_i] $}
% \STATE Connect input and output of $ G_{V_i} $ for all $ V_i $, according to directed edges $ \rightarrow $ in $ \mathcal{G} $
% % \ENDFOR
% \FOR{each bi-directed edge $ \leftrightarrow $ in $ \mathcal{G} $}
% \STATE $ U \sim \mathcal{N}(0, I) $
% \ENDFOR




\clearpage


% % Acknowledgements should only appear in the accepted version.
% \section*{Acknowledgements}

\section*{Impact Statement}

Authors are \textbf{required} to include a statement of the potential 
broader impact of their work, including its ethical aspects and future 
societal consequences. This statement should be in an unnumbered 
section at the end of the paper (co-located with Acknowledgements -- 
the two may appear in either order, but both must be before References), 
and does not count toward the paper page limit. In many cases, where 
the ethical impacts and expected societal implications are those that 
are well established when advancing the field of Machine Learning, 
substantial discussion is not required, and a simple statement such 
as the following will suffice:

``This paper presents work whose goal is to advance the field of 
Machine Learning. There are many potential societal consequences 
of our work, none which we feel must be specifically highlighted here.''

The above statement can be used verbatim in such cases, but we 
encourage authors to think about whether there is content which does 
warrant further discussion, as this statement will be apparent if the 
paper is later flagged for ethics review.


% In the unusual situation where you want a paper to appear in the
% references without citing it in the main text, use \nocite
\nocite{langley00}

\bibliography{references}
\bibliographystyle{icml2025}


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% APPENDIX
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
\newpage
\appendix
\onecolumn





\section{Issues need fix}
\begin{itemize}
    \item For overlapping variables, don't we need their parents?
    \item Why is this restricted to gans only? Does that mean this is minimum component that is dependent on a gan architecture?
    \item Heterogenous and personalized.
    \item cite more recent DCM/NCM works. Replace old works with new ones.
    \item Rephrase SCM def and intro.
    \item FL def and eq.

    \item \textbf{Trivial Solution 1:} joint $P(v', w)$\red{*}; confounding noise $U$ \red{*}
    \item Equation~\ref{eq:triv-sol2} $U$.
    \item Section 3.1 taken from Modular-DCM paper.
\end{itemize}


\section*{Mathematical Notation}
The table below lists and defines the mathematical symbols used throughout this paper:

\begin{longtable}{|c|l|}
\hline
\textbf{Symbol} & \textbf{Description} \\ \hline
\endfirsthead
\hline
\textbf{Symbol} & \textbf{Description} \\ \hline
\endhead
\hline
\endfoot

% Add rows for each symbol
$X$ & The given/detected variable heterogeneous across clients. \\ \hline
$\T$ & Set of mechanisms that we select for federated learning. \\ \hline
$\mathcal{C}$ & Set of all clients \\ \hline
$f(x)$ & Probability density function \\ \hline
$P(A)$ & Probability of event $A$ \\ \hline
$\mathbf{V}$ & Set of all causal variables in a structural causal model \\ \hline
$\mathbb{R}$ & Set of real numbers \\ \hline
$\nabla f(x)$ & Gradient of $f(x)$ \\ \hline
$\partial$ & Partial derivative \\ \hline
$\sum_{i=1}^n x_i$ & Summation over $n$ terms \\ \hline
$\arg\max_x f(x)$ & Argument that maximizes $f(x)$ \\ \hline
$\mathcal{N}(\mu, \sigma^2)$ & Gaussian distribution with mean $\mu$ and variance $\sigma^2$ \\ \hline
$\lambda$ & Regularization parameter \\ \hline
$\alpha, \beta$ & Model hyperparameters \\ \hline

\end{longtable}





\section{Extra }
\begin{itemize}



\item \cite{mclaughlin2024personalized}  frame representation learning as a
generative modeling task, where representations are trained with a classifier based
on the global feature distribution. Their algorithm efficiently generates personalized models by adapting global generative classifiers
to their local feature distributions. 

% \item Knowledge distillation. Teacher student federated distillation. Different nodes of the causal graph are different representation of the data/different task.
    
\end{itemize}







\blue{Is there any other method that can solve the same problem, or one that you can build upon?}
\begin{itemize}
    \item \cite{tang2024fusefl} provide a causal view to understand the gap between multi-round FL and OFL, showing that
augmenting intermediate features from other clients helps improve OFL. They are the first to use causality to analyze the data heterogeneity of OFL. They have code as well. We might build upon their work.

% \item
% \cite{makhija2024a} propose personalized federated learning utilizing Bayesian principles for improved robustness and reliability, particularly in contexts where data is scarce. We might utilize its approach.


% \item (shifted) Even though each client share some common mechanisms that they learn through federated setup, they generally have their own complex causal system which contains the learned mechanism. There exists no work that discusses the connection or dynamics between other variables in the system and the federated mechanism and performs further prediction. For example, in Figure~\ref{?}, the learned image is an intermediate variable. The input and the output of the variables are confounded and we want a causal effect of the image input on the image classifier output.


\end{itemize}

\blue{What is the proposed solutions?}

\begin{itemize}
    \item Also, we utilize partial identification and remove the most weak edges based on the constraints.
\end{itemize}



% \blue{Additional Benefits}
% \begin{itemize}

% \item Local PLM finetuning involves other variables. And finetuning is expensive for edge devices. fine-tuning PLMs
% through FL requires the clients and server to frequently exchange model parameters or gradients, usually on a scale of
% millions or even billions of parameters. Method which can deal with these: adapter tuning , prefix tuning , LoRA and BitFit

% \item Bias Variance tradeoff with causal approach feature distribution shift and data scarcity. Causal graph will also allow us feature reduction.
% \end{itemize}





\begin{theorem}
\textbf{Sufficiency:}
Following the c-component is correct. And this is sufficient to sample from  $P(y|do(x))$.

\textbf{Necessity:}
Theoretical guarantee of c-component based minimum modularity

Theoretical Guarantees
We prove the necessary number of model updates and necessary amount of variables required for the adaptation to new SCM. This is possible due to the c-blanket of the c-component.
    
\end{theorem}


\subsection{Pre-specified mechanism utilization}




\textbf{ii) Adapting local mechanisms for special cases}
Assumption: Bi-directed neighbors without any parents or bi-directed neighbors.





\begin{figure}
    \centering
    \includegraphics[width=1.1\linewidth]{Figures/swig_type_modularization.pdf}
    \caption{Enter Caption}
    \label{fig:enter-label}
\end{figure}



Fed-DCM in Practice:
In practice, we might face some situation when only some specific variable mechanism is eligible to be learned in the federated setup. In this section, we show how we can learn the causal model in such scenario.


Figure 1 left, modularing till c-component level.
We match $P(x,y,z)$.
\begin{equation}
    \begin{split}
P(x,y,z)= P(x,y|do(z)) P(z|do(x))\\
= P(x) P(y|x,z) P(z|x)
    \end{split}
\end{equation}

Figure 1 right, 
\begin{equation}
\begin{split}
   &\text{Let } X=X';\\
    &P(X)= P(X');\\
    &P(X,X') = P(X)\\ 
\end{split}
\end{equation}

Now match $P(x, x', y, z)$.

\begin{equation}
    \begin{split}
&P(x, x', y, z) \\
&=P(x,x') P(z|do(x)) P(y|do(x',z))\\
&=P(x,x') P(z|x) P(y|x',z)\\
&=P(x) P(z|x) P(y|x,z)\\
&= P(x,y,z)
    \end{split}
\end{equation}
This implies that matching $P(x, x', y, z)$ is equivalent to matching $P(x,y,z)$.

Now, we estimate the causal effect for the new graph:

\begin{equation}
\begin{split}
&P(y| do(x))\\
&=\sum_{x',z} P(z|x) P(x') P(y|x',z)\\
&=\sum_{z} P(z|x) \sum_{x'} P(x') P(y|x',z)
\end{split}
\end{equation}
which equals the causal effect in the original causal graph (Figure 1 right).



\subsection{Fed-DCM in test domain with distribution shift}



\subsection{Further efficiency}
It appears that we need to train $|S_i|$ many models to sample from the distribution in equation~\ref{eq:step6}. However, we can reduce the training cost. Let $\mathbf{Z_i}=\{V_j: V_{\pi^{j-1}} \cap Pa(S_i) =\emptyset \}$, i.e, $Pa(S_i)$ are not ancestors of such $V_j \in \mathbf{Z_i}$.

\begin{equation}
\begin{split}
    &\prod_{V_j\in S_i} P(v_j|v_{\pi^{j-1}} \cap (S_i \cup pa(S_i)))\\
    = &\prod_{V_j\in \mathbf{Z}_i} P(v_j|v_{\pi^{j-1}} \cap (S_i \cup pa(S_i)))\\
    & \times \prod_{V_j\in S_i\setminus \mathbf{Z}_i} P(v_j|v_{\pi^{j-1}} \cap (S_i \cup pa(S_i)))\\
    = &\prod_{V_j\in \mathbf{Z}_i} P(v_j|
    (v_{\pi^{j-1}} \cap S_i) \cup (v_{\pi^{j-1}} \cap pa(S_i)))\\
    & \times \prod_{V_j\in S_i \setminus \mathbf{Z}_i} P(v_j|v_{\pi^{j-1}} \cap (S_i \cup pa(S_i)))
        \end{split}
\end{equation}
\begin{equation}
\begin{split}
    = &\prod_{V_j\in \mathbf{Z}_i} P(v_j|
    (v_{\pi^{j-1}} \cap S_i))\\
    & \times \prod_{V_j\in S_i \setminus \mathbf{Z}_i} P(v_j|v_{\pi^{j-1}} \cap (S_i \cup pa(S_i)))\\
    = &P(\mathbf{Z}_i)
    \prod_{V_j\in S_i \setminus \mathbf{Z}_i} P(v_j|v_{\pi^{j-1}} \cap (S_i \cup pa(S_i)))
\end{split}
\end{equation}

Thus, we only need to train $|S_i \setminus \mathbf{Z}_i|$ models to generate the required interventional data.

We can call it a blanket, which keeps the mechanisms inside the blanket independent of any change that occurs outside the blanket.


Let $|S_i \setminus \mathbf{Z}_i|=M$ and $|S_i|=N$. We need to train $N$ models in the DCM with data generated from $M$ models. Here we can compress even further. If the number of interventions is $|I|$, then we need to train $|I|+1$ models, which can be represented as a chain of nodes with each intervention being a parent of each node.


\subsection{Arbitrary Causal Model across different clients}
Suppose each client $c$ has causal variables $V^c$ in its environment, with an overlapping set of variables $V$ shared with the other clients, i.e., $V=\bigcap_{c=1}^{C} V^{c}$.

We relax the assumption of having the same acyclic directed mixed graph (ADMG) across all clients~\cite{vo2022adaptive} and assume access to ADMGs $\{G_1, \dots, G_C\}$ such that $\bigcap_{i=1}^{C} G_i = G(V)$.

Due to data heterogeneity, the overlapping set of variables has a different distribution for any two clients $i$ and $j$ with $i \neq j$: $P^{i}(v) \neq P^{j}(v)$.


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


\end{document}


% This document was modified from the file originally made available by
% Pat Langley and Andrea Danyluk for ICML-2K. This version was created
% by Iain Murray in 2018, and modified by Alexandre Bouchard in
% 2019 and 2021 and by Csaba Szepesvari, Gang Niu and Sivan Sabato in 2022.
% Modified again in 2023 and 2024 by Sivan Sabato and Jonathan Scarlett.
% Previous contributors include Dan Roy, Lise Getoor and Tobias
% Scheffer, which was slightly modified from the 2010 version by
% Thorsten Joachims & Johannes Fuernkranz, slightly modified from the
% 2009 version by Kiri Wagstaff and Sam Roweis's 2008 version, which is
% slightly modified from Prasad Tadepalli's 2007 version which is a
% lightly changed version of the previous year's version by Andrew
% Moore, which was in turn edited from those of Kristian Kersting and
% Codrina Lauth. Alex Smola contributed to the algorithmic style files.
