
\section{Introduction}





Federated learning (FL) allows multiple parties to jointly train a consensus model without sharing user data. 
Compared to the classical centralized learning regime, federated learning keeps training data on local clients, such as mobile devices or hospitals, where data privacy, security, and access rights are a matter of vital interest.
By aggregating diverse data sources while respecting privacy constraints, federated learning shows promise in areas such as the internet of things~\cite{cys+20}, healthcare~\cite{lgd+20,lmx+19}, text data~\cite{hsc+20}, and fraud detection~\cite{zygw20}.

A standard formulation of federated learning is a distributed optimization framework that tackles communication costs, client robustness, and data heterogeneity across different clients~\cite{lsts20}. Central to the formulation is communication efficiency, which directly motivates federated averaging (FedAvg)~\cite{mmr+17}. FedAvg introduces a global model that synchronously aggregates multi-step local updates on the available clients, so that clients communicate once per round of local steps rather than after every gradient step. However, FedAvg often stagnates empirically at inferior local modes due to data heterogeneity across clients~\cite{Charles20, Woodworth20}. To tackle this issue, \cite{kkm+20, FedSplit} proposed stateful clients to avoid unstable convergence; these, however, do not scale with the number of clients in applications with mobile devices~\cite{agxr21}. In addition, the optimization framework often fails to accurately quantify the uncertainty of the parameters of interest, which is crucial for building estimators, hypothesis tests, and credible intervals. This leads to unreliable statistical inference and casts doubt on the credibility of prediction tasks or diagnoses in medical applications.
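
As a rough sketch of the FedAvg scheme described above, assume full client participation and write $\theta_k^c$ for the local iterate of client $c$ and $\eta$ for the learning rate (both symbols are illustrative), with the weights $p_c$, the number of local steps $K$, and the stochastic gradients $\nabla \tilde f^c$ as defined in the Notation paragraph. Each client then runs local stochastic gradient steps
\[
\theta_{k+1}^c = \theta_k^c - \eta \nabla \tilde f^c(\theta_k^c), \qquad c \in [N],
\]
and, after every $K$ such steps, the server forms the weighted average
\[
\theta_{k+1} = \sum_{c=1}^{N} p_c\, \theta_{k+1}^c
\]
and sends it back to all clients as the new starting point. This is only a simplified rendering; practical variants differ in how clients are selected and how the average is weighted.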


To unify optimization and uncertainty quantification in federated learning, we resort to a \emph{Bayesian treatment by sampling from a global posterior distribution}, where the latter is aggregated by infrequent communications from local posterior distributions. We adopt a popular approach for inferring posterior distributions for large datasets, the stochastic gradient Markov chain Monte Carlo (SG-MCMC) method~\cite{Welling11,VollmerZW2016,Teh16,Chen14,yian2015}, which enjoys theoretical guarantees beyond convex scenarios \cite{Maxim17, Yuchen17, Mangoubi18,ma19}. 
In particular, we examine the efficacy of the stochastic gradient Langevin dynamics (SGLD) algorithm in the federated learning setting; SGLD differs from stochastic gradient descent (SGD) by an additional injected noise term for exploring the posterior. 
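For concreteness, this contrast can be sketched for a generic loss $f$ with stochastic gradient $\nabla \tilde f$; the symbols $\theta_k$ (iterate), $\eta$ (learning rate), $\tau$ (temperature), and $\xi_k$ (Gaussian noise in dimension $d$) are illustrative and are not fixed by this section. A common form of the SGD update is
\[
\theta_{k+1} = \theta_k - \eta \nabla \tilde f(\theta_k),
\]
whereas a common form of the SGLD update is
\[
\theta_{k+1} = \theta_k - \eta \nabla \tilde f(\theta_k) + \sqrt{2\eta\tau}\,\xi_k, \qquad \xi_k \sim \mathcal{N}(0, \mathbf{I}_d).
\]
Under this sketch, dropping the noise term recovers SGD, and $\tau = 1$ corresponds to targeting a density proportional to $\exp(-f)$.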
The close resemblance naturally inspires us to adapt the optimization-based FedAvg to a distributed sampling framework. 
Similar ideas have been proposed in federated posterior averaging~\cite{agxr21}, where empirical studies and analyses on Gaussian posteriors have shown the promise of this approach. 
Compared to the appealing theoretical guarantees of optimization-based algorithms in federated learning~\cite{FedSplit,agxr21}, the convergence properties of approximate sampling algorithms in federated learning are far less understood. To fill this gap, we proceed by asking the following question:
\begin{center}
    {\it Can we build a unified algorithm with convergence guarantees for sampling in FL?}
\end{center}
In this paper, we take a first step toward answering this question in the affirmative. We propose federated averaging Langevin dynamics (FA-LD) for posterior inference beyond the Gaussian distribution. We list our contributions as follows:
\begin{itemize}
    \item We present a novel non-asymptotic convergence analysis of FA-LD for simulating strongly log-concave distributions on non-i.i.d.\ data when the learning rate is fixed. The bounded-gradient assumption in $\ell_2$ norm, frequently used in FedAvg optimization, is not required.
    \item The convergence analysis indicates that injected noise, data heterogeneity, and stochastic-gradient noise are all driving factors that affect the convergence. The analysis provides concrete guidance on the optimal number of local updates to minimize communication. 
    % \item We present a convergence result for FA-LD with decaying learning rates. This strategy accelerates the computation by a logarithmic factor to achieve the precision $\epsilon$.
    % \item We present a convergence result for FA-LD with decaying learning rates, where the computation is accelerated by a logarithmic factor to achieve the precision $\epsilon$.
    \item We can activate partial device updates to avoid straggler effects in practical applications and tune the correlation of the injected noise to protect privacy. 
    \item We also provide differential privacy guarantees, which shed light on the trade-off between data privacy and accuracy given a limited budget.
\end{itemize}



\vspace{-2mm}
\paragraph{Roadmap.}
In Section~\ref{sec:related}, we discuss related work. In Section~\ref{sec:preli_fa}, we present the preliminary knowledge. In Section~\ref{sec:posterior_inference}, we propose federated averaging Langevin dynamics for posterior inference. In Section~\ref{sec:convergence}, we lay out the required assumptions, sketch the proof, and show the theoretical convergence results. In Section~\ref{DP_section}, we prove the differential privacy (DP) guarantees. In Section~\ref{simulation_local_step}, we conduct experiments to demonstrate the necessity of local steps. In Section~\ref{sec:concl}, we conclude our work. 


\paragraph{Notation.}

For any positive integer $n$, we use $[n]$ to denote the set $\{1,2,\ldots,n\}$. Let $N$ denote the number of clients. For each $c \in [N]$, we use $f^c$ and $\nabla f^c$ to denote the loss function of client $c$ and its gradient, respectively, and $\nabla \tilde f^c(\cdot)$ to denote an \emph{unbiased} stochastic estimate of $\nabla f^c$. In addition, we denote by $p_c$ the weight of the $c$-th client, with $p_c=\frac{n_c}{\sum_{i=1}^{N} n_i}\in(0, 1)$, where $n_c>0$ is the number of data points in the $c$-th client. Let $T_{\epsilon}$ denote the number of global steps needed to achieve precision $\epsilon$. Let $K$ denote the number of local steps, so that ${T_{\epsilon}}/{K}$ is the number of communications.
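
As a small worked example of this notation, with purely hypothetical numbers: suppose $N = 3$ clients hold $n_1 = 100$, $n_2 = 300$, and $n_3 = 600$ data points. Then
\[
p_1 = \frac{100}{1000} = 0.1, \qquad p_2 = \frac{300}{1000} = 0.3, \qquad p_3 = \frac{600}{1000} = 0.6,
\]
so the weights sum to one. If reaching the precision $\epsilon$ takes $T_{\epsilon} = 10{,}000$ global steps with $K = 20$ local steps per round, the number of communications is $T_{\epsilon}/K = 500$.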