% \documentclass{uai2024} % for initial submission
\documentclass[accepted]{uai2024} % after acceptance, for a revised version; 
% also before submission to see how the non-anonymous paper would look like 
                        
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2024} % ptmx math instead of Computer
                                         % Modern (has noticeable issues)
% \documentclass[mathfont=newtx]{uai2024} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
\usepackage[american]{babel}
% \usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{tikz} % nice language for creating drawings and diagrams

\usepackage{hyperref}
\usepackage{graphicx}
% \usepackage{algorithm2e}
\usepackage{algorithm}
\usepackage{algorithmic}
\usepackage{subfigure}
\DeclareMathOperator*{\argmin}{\arg\min}
\usepackage{amsmath}

\newtheorem{prop}{Proposition}
\newtheorem{theorem}{Proposition}
\newtheorem{proof}{Proof}
\usepackage{amssymb}


%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

%% Self-defined macros
\newcommand{\swap}[3][-]{#3#1#2} % just an example

\title{Functional Wasserstein Bridge Inference for Bayesian Deep Learning}

% The standard author block has changed for UAI 2024 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
% \author[1]{\href{mailto:<jj@example.edu>?Subject=Your UAI 2024 paper}{Jane~J.~von~O'L\'opez}{}}
% \author[1]{Harry~Q.~Bovik}
\author[1]{\href{mailto:<mengjing.wu@student.uts.edu.au>?Subject=Your UAI 2024 paper}{Mengjing~Wu}{}}
\author[1]{Junyu~Xuan}
\author[1]{Jie~Lu}

% Add affiliations after the authors
% \affil[1]{%
%     Computer Science Dept.\\
%     Cranberry University\\
%     Pittsburgh, Pennsylvania, USA
% }

\affil[1]{%
    Australian Artificial Intelligence Institute\\
    University of Technology Sydney\\
    Ultimo NSW 2007\\
    Australia
}

% \affil[2]{%
%     Second Affiliation\\
%     Address\\
%     …
% }
% \affil[3]{%
%     Another Affiliation\\
%     Address\\
%     …
%   }
  
  \begin{document}
\maketitle

\begin{abstract}
Bayesian deep learning (BDL) is an emerging field that combines the strong function approximation power of deep learning with the uncertainty modeling capabilities of Bayesian methods. This combination, however, also brings issues to classical parameter-space variational inference, such as non-meaningful priors, intricate posteriors, and possible pathologies.
% \textcolor{red}{the non-meaningful and pathological prior and complex and uncontrollable posterior}
In this paper, we propose a new function-space variational inference solution called Functional Wasserstein Bridge Inference (FWBI), which can assign meaningful functional priors and obtain well-behaved posteriors. Specifically, we develop a Wasserstein distance-based bridge to avoid the potential pathological behaviors of the Kullback–Leibler (KL) divergence between stochastic processes that arise in most existing functional variational inference approaches. The derived functional variational objective is well-defined and proved to be a lower bound of the model evidence. We demonstrate the improved predictive performance and better uncertainty quantification of FWBI on several tasks compared with various parameter-space and function-space variational methods.


\end{abstract}

\section{Introduction}
\label{sec:intro}

In past decades, Bayesian deep learning (BDL) approaches \citep{blundell2015weight, gal2016uncertainty, wilson2020bayesian} have shown success in combining the strong predictive performance of deep learning models with the principled uncertainty estimation of Bayesian inference. They have been recognized as an effective and irreplaceable tool for a wide range of tasks, such as uncertainty formulation in per-pixel semantic segmentation \citep{kendall2017uncertainties}, risk-sensitive reinforcement learning \citep{depeweg2018decomposition}, and safety-critical medical diagnosis such as diabetic retinopathy detection \citep{filos2019systematic, band2022benchmarking}.
% \citep{ blundell2015weight, gal2016theoretically, gal2016uncertainty, khan2018fast, maddox2019simple, wilson2020bayesian}. 


Even though impressive progress has been made, Bayesian deep learning has not achieved outstanding performance in some tasks compared with its non-Bayesian counterparts \citep{ovadia2019can, foong2019between, farquhar2020radial}.
This phenomenon can probably be attributed to at least two unresolved issues in parameter-space inference. Firstly, it is difficult to incorporate meaningful prior information about the unknown function into the inference procedure. The widely used independent and identically distributed Gaussian priors over model parameters are not always suitable for this purpose, because the function samples induced by such priors tend to be horizontally linear and lead to pathologies for deep models \citep{duvenaud2014avoiding, matthews2018gaussian, tran2020functional} (for self-containedness, we visualize this problem in Appendix \ref{apd:A}). Secondly, the effects of the given priors on posterior inference, and in turn on the resulting distributions over functions, are unclear and hard to control owing to the complex architecture and non-linearity of the models \citep{ma2021functional, fortuin2021bayesian, wild2022generalized}. 
% Secondly, due to the over-parameterized nature of deep models, the high-dimensional inference is computationally expensive \citep{neal2012bayesian} and the variational posterior may suffer from multi-modes and other pathologies \citep{foong2020expressiveness}.

To avoid these issues, increasing attention has been paid to performing Bayesian inference in function space instead of parameter space \citep{ma2019variational, rudner2020rethinking, rudner2022tractable, pielok2023approximate}. In such an inference framework, the distributions of function mappings defined by models are treated as probability measures in function space induced by the distributions over model parameters, and the variational objective is then defined directly in terms of the distributions over functions. In this situation, one can take advantage of more informative stochastic process priors, such as the classic \textit{Gaussian Processes} (GPs), which can easily encode prior knowledge about function properties (e.g., periodicity and smoothness) through corresponding kernel functions. In order to approximate posterior distributions over functions, existing function-space inference methods \citep{sun2019functional} explicitly build and minimize a divergence or distance between the true posterior and the variational posterior processes and develop a tractable estimation procedure for the functional variational objective. 

Like parameter-space variational methods, function-space inference methods mostly use the Kullback–Leibler (KL) divergence as the measure of dissimilarity. However, the KL divergence between distributions over infinite-dimensional functions may be infinite \citep{burt2020understanding}, which makes the variational objective ill-defined. Specifically, the KL divergence between probability measures is defined through the Radon-Nikodym derivative of the variational approximate posterior with respect to the prior, and this derivative exists only if the former is absolutely continuous with respect to the latter \citep{matthews2016sparse, burt2020understanding}, a condition that may fail in some situations. For example, the KL divergence between distributions over functions generated from two Bayesian neural networks with different network structures can be infinite \citep{ma2021functional}. 

In this work, we propose a new functional variational inference method called \textit{Functional Wasserstein Bridge Inference} (FWBI), which uses a Wasserstein bridge as the dissimilarity measure between distributions and thereby avoids the limitations of the KL divergence for distributions over functions. Our main contributions are as follows:
\begin{itemize}
    \item We propose a new Bayesian inference framework in function space to avoid the limitations of parameter-space mean-field variational inference, such as the difficulties of defining meaningful priors and uncontrolled pathologies from over-parametrization. 
    \item We propose a variational objective in terms of distributions over functions based on the Wasserstein bridge as the alternative for KL divergence between probability measures. We prove that our objective function is a lower bound of the model evidence and therefore is a well-defined objective for Bayesian inference.
    \item We evaluate the proposed method by comparing it against competing parameter-space and function-space inference approaches on several tasks to demonstrate its high predictive performance and reliable uncertainty estimation.
\end{itemize}


\section{Preliminaries}
\label{sec:prelimi}

Consider a supervised learning task with dataset $\mathcal{D}=\left\{\left(x_i, y_i\right)\right\}_{i=1}^n = \left\{\mathbf{X}_\mathcal{D}, \mathbf{Y}_\mathcal{D}\right\}$, where $x_i\in \mathcal{X} \subseteq \mathbb{R}^d$ are the training inputs and $y_i\in \mathcal{Y} \subseteq \mathbb{R}^c$ denote the corresponding targets. Let $f(\cdot; \mathbf{w}): \mathcal{X} \rightarrow \mathcal{Y}$ be a function mapping defined by an arbitrary machine learning model with model parameters $\mathbf{w}$. For example, $f$ can be the function mapping given by a Bayesian neural network (BNN), which is one of the most representative BDL models. BNNs are stochastic neural networks, and their parameters (weights) are multivariate random variables resulting in a random function $f(\cdot; \mathbf{w})$, that is, a stochastic process. When evaluated at finite marginal points $\mathbf{X}$ in the input domain,  $f(\mathbf{X}; \mathbf{w})$ turns into a multivariate random variable.

\paragraph{Parameter-space variational inference}
BDL models are usually trained with Bayesian inference by placing prior distributions over model parameters; for example, $p_0(\mathbf{w})$ is the prior distribution for the random network weights of a BNN, defined on a probability space $(\Omega, \mathcal{A}, P)$. Given the training data $\left\{\mathbf{X}_\mathcal{D}, \mathbf{Y}_\mathcal{D}\right\}$ and a proper likelihood $p(\mathbf{Y}_\mathcal{D}|\mathbf{X}_\mathcal{D}, \mathbf{w})$ evaluated on the training set, the posterior over weights can be inferred as $p(\mathbf{w}|\mathcal{D})\propto p_0(\mathbf{w})p(\mathbf{Y}_\mathcal{D}|\mathbf{X}_\mathcal{D}, \mathbf{w})$. However, due to the non-linear dependence of the function mapping $f$ on the random $\mathbf{w}$, the marginal integration required to compute the posterior over weights is intractable for any practical dimension. Variational inference \citep{wainwright2008graphical} is one of the most popular approximation approaches, converting the problem of estimating the posterior distribution into a tractable optimization problem. The goal of variational inference is to fit an approximate posterior distribution $q(\mathbf{w}; \boldsymbol{\theta_q})$ parametrized by $\boldsymbol{\theta_q}$ from a tractable variational family by minimizing the KL divergence between it and the true posterior, $\min_{q(\mathbf{w}; \boldsymbol{\theta_q})}\mathrm{KL}[q(\mathbf{w}; \boldsymbol{\theta_q})\| p(\mathbf{w}|\mathcal{D})]$, which is equivalent to maximizing the evidence lower bound (ELBO) as follows: 
\begin{equation}
\begin{aligned}
\mathcal{L}_{q(\mathbf{w};\boldsymbol{\theta_q})}:=&\mathbb{E}_{q(\mathbf{w};\boldsymbol{\theta_q})}\left[\log p(\mathbf{Y}_\mathcal{D} \mid \mathbf{X}_\mathcal{D}, \mathbf{w}; \boldsymbol{\theta_q})\right]\\
&-\mathrm{KL}[q(\mathbf{w}; \boldsymbol{\theta_q}) \| p_0(\mathbf{w})].
\end{aligned}
\label{eq:vip}
\end{equation}
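To make the stated equivalence concrete, the standard decomposition of the log evidence can be written out as a short sketch, in which the likelihood is abbreviated without its dependence on $\boldsymbol{\theta_q}$:
\begin{equation*}
\begin{aligned}
\log p(\mathbf{Y}_\mathcal{D} \mid \mathbf{X}_\mathcal{D}) =\;& \mathbb{E}_{q(\mathbf{w};\boldsymbol{\theta_q})}\left[\log p(\mathbf{Y}_\mathcal{D} \mid \mathbf{X}_\mathcal{D}, \mathbf{w})\right]\\
&-\mathrm{KL}[q(\mathbf{w}; \boldsymbol{\theta_q}) \| p_0(\mathbf{w})]\\
&+\mathrm{KL}[q(\mathbf{w}; \boldsymbol{\theta_q}) \| p(\mathbf{w}|\mathcal{D})].
\end{aligned}
\end{equation*}
The left-hand side does not depend on $\boldsymbol{\theta_q}$ and the last KL term is non-negative, so the first two terms, which form the ELBO in Equation \ref{eq:vip}, lower-bound the log evidence, and increasing them decreases the KL divergence to the true posterior by the same amount.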
% This parameter-space variational objective can be optimized by stochastic gradient techniques when the prior $p_0(\mathbf{w})$ is chosen to be simple i.i.d. Gaussian and the variational $q(\mathbf{w};\boldsymbol{\theta_q})$ is assumed to be a fully-factorized Gaussian distribution. Bayes by Backprop (BBB) \citep{blundell2015weight} is one of the most commonly used mean-field variational inference algorithms in parameter space.

\paragraph{Function-space variational inference} 
The core idea of function-space variational inference is to view a Bayesian deep learning model as a distribution over functions. The random function mapping (product measurable) defined on the function space $\mathcal{H}$ (a Polish space) via a BDL model is given by $f(\cdot; \mathbf{w}): \mathcal{X} \times \Omega \rightarrow \mathcal{Y}$, which is $\mathcal{A}$-measurable for every $x \in \mathcal{X}$. Let $p_0(f)$ be the prior distribution over the stochastic functions. As in Bayesian inference in parameter space, the main goal is to infer the posterior over functions $p(f|\mathcal{D})$ by combining the prior with the likelihood $p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}))$ evaluated at the training data $\mathcal{D}=\left\{\mathbf{X}_\mathcal{D}, \mathbf{Y}_\mathcal{D}\right\}$. However, this posterior is intractable for most stochastic processes. For example, for BNNs, $p_0(f)$ is the prior distribution over functions induced by the prior distribution over random network weights $p_0(\mathbf{w})$, and it has no explicit probability form. Similar to parameter-space variational inference, the variational objective in function space can be written as $\min_{q(f; \boldsymbol{\theta_q})}\mathrm{KL}[q(f; \boldsymbol{\theta_q})\| p(f|\mathcal{D})]$, where $q(f; \boldsymbol{\theta_q})$ is the variational posterior over functions induced by the variational posterior over parameters. The functional ELBO, as the variational objective function in function space to be maximized, is
\begin{equation}
\begin{aligned}
\mathcal{L}_{q(f;\boldsymbol{\theta_q})}^{KL}:= &\mathbb{E}_{q(f;\boldsymbol{\theta_q})}\left[\log p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}; \boldsymbol{\theta_q}))\right]\\
&-\mathrm{KL}[q(f; \boldsymbol{\theta_q}) \| p_0(f)]\label{eq:vif}, 
\end{aligned}
\end{equation}
where one can effectively incorporate prior information about the task via $p_0(f)$ in the KL term during the optimization. To estimate the above functional ELBO, the following three issues need to be carefully considered:
\begin{itemize}
    \item The first one concerns the validity of the definition of the objective function. To guarantee the existence of the Radon-Nikodym derivative in the KL divergence between two distributions over functions, $q(f; \boldsymbol{\theta_q})$ must be absolutely continuous with respect to $p_0(f)$. Specifically, for $q(f; \boldsymbol{\theta_q})$ and $p_0(f)$ generated from the same function mapping with different parameter distributions, such as the same neural network structure for BNNs, the condition $\mathrm{KL}[q(\mathbf{w}; \boldsymbol{\theta_q})\| p_0(\mathbf{w})] < \infty$ guarantees that $\mathrm{KL}[q(f; \boldsymbol{\theta_q})\| p_0(f)] < \infty$, because the strong data processing inequality \citep{polyanskiy2017strong} gives $\mathrm{KL}[q(f; \boldsymbol{\theta_q})\| p_0(f)] \leq  \mathrm{KL}[q(\mathbf{w}; \boldsymbol{\theta_q})\| p_0(\mathbf{w})]$.
    
    \item The second issue concerns the implicit probability density functions for $q(f; \boldsymbol{\theta_q})$ and $p_0(f)$. Note that since $\mathcal{H}$ is an infinite-dimensional function space, $p_0(f)$ and $q(f;\boldsymbol{\theta_q})$ are not probability density functions with respect to the Lebesgue measure, but probability measures over $\mathcal{H}$. Even for the marginal distributions $p_0(f(\mathbf{X}))$ and $q(f(\mathbf{X}; \boldsymbol{\theta_q}))$ at finite input points $\mathbf{X}$, explicit probability density functions are intractable for some stochastic processes, e.g., BNNs and other non-linear models.

    \item The third problem is the effective and efficient estimation of the KL divergence between two stochastic processes. To solve the intractable infinite-dimensional KL divergence between distributions over functions, \citet{sun2019functional} proved that 
    \begin{equation}
    \begin{aligned}
         \mathrm{KL}&[q(f; \boldsymbol{\theta_q}) \| p_0(f)] \\
       & = \sup_{n\in\mathbb{N}, \mathbf{X}\in\mathcal{X}^n}\mathrm{KL}[q(f(\mathbf{X};\boldsymbol{\theta_q})) \| p_0(f(\mathbf{X}))],
    \end{aligned}
    \end{equation}
    where $\mathbf{X}\in\mathcal{X}^n$ ranges over all finite sets of $n$ measurement points in the input domain $\mathcal{X}$. In other words, the functional KL divergence equals the supremum of the marginal KL divergences over all finite sets of measurement points. Unfortunately, there is no analytical way to obtain such a supremum in practical optimization.
\end{itemize}
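As an illustration of the first and third issues, consider the special case, assumed here purely for exposition, in which both marginals at $n$ measurement points $\mathbf{X}$ are Gaussian, $q(f(\mathbf{X}; \boldsymbol{\theta_q})) = \mathcal{N}(\mathbf{m}_q, \mathbf{S}_q)$ and $p_0(f(\mathbf{X})) = \mathcal{N}(\mathbf{m}_p, \mathbf{S}_p)$, with $\mathbf{S}_q$ and $\mathbf{S}_p$ invertible (the symbols $\mathbf{m}_q, \mathbf{S}_q, \mathbf{m}_p, \mathbf{S}_p$ are introduced only for this sketch). The marginal KL divergence then has the standard closed form
\begin{equation*}
\begin{aligned}
&\mathrm{KL}[q(f(\mathbf{X};\boldsymbol{\theta_q})) \| p_0(f(\mathbf{X}))] \\
&= \tfrac{1}{2}\Big[\operatorname{trace}\left(\mathbf{S}_p^{-1}\mathbf{S}_q\right) - n + \log\tfrac{\det \mathbf{S}_p}{\det \mathbf{S}_q} \\
&\qquad + (\mathbf{m}_p-\mathbf{m}_q)^{\mathrm{T}}\mathbf{S}_p^{-1}(\mathbf{m}_p-\mathbf{m}_q)\Big].
\end{aligned}
\end{equation*}
If $\mathbf{S}_p$ degenerates along a direction in which $\mathbf{S}_q$ keeps positive variance, this expression diverges, which is the finite-dimensional counterpart of a failure of absolute continuity. For BNN marginals no such closed form is available, so the supremum over $\mathbf{X}$ cannot be evaluated this way.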
 % there is no easy way to obtain such a supremum in practical optimization because of the infinite number of possible combinatorial subsets even for a medium sample size. Moreover, even for finite $\mathbf{X}$, it is challenging to solve the marginal KL divergence due to the implicit probability density functions.

% As discussed by \citet{knoblauch2022generalized}, in a more general variational optimization framework for Bayesian inference, the KL divergence can be replaced by any distance measure that satisfies basic properties of the similarity measure, such as the well-defined Wasserstein distance. 
% % See \appendixref{apd:first} for further details.

% \paragraph{Wasserstein distance} 
% The \textit{Wasserstein distance} \citep{kantorovich1960mathematical, villani2003topics} is a rigorously defined distance metric on probability measures satisfying non-negativity, symmetry and triangular inequality \citep{panaretos2019statistical} that was originally proposed for the optimal transport problem and has become popular in the machine learning community in recent years \citep{arjovsky2017wasserstein}. Suppose $(\mathcal{P}, \|\cdot\|)$ is a Polish space, the p-Wasserstein distance between probability measures $\mu$, $\nu\in (\mathcal{P}, \|\cdot\|)$ is defined as
% \begin{equation}
%     W_p(\mu, \nu)=\left(\inf _{\gamma \in \Gamma(\mu, \nu)} \int_{\mathcal{P} \times \mathcal{P}}\|x-y\|^p \mathrm{~d} \gamma(x, y)\right)^{1 / p},
% \end{equation}
% where $\Gamma(\mu, \nu)$ is the set of joint measures or coupling  $\gamma$ with marginals $\mu$ and $\nu$ on $\mathcal{P} \times \mathcal{P}$. 

% Distance measure $\|\cdot\|^p$ quantifies the effort of transporting one unit mass from measure $\mu$ to $\nu$, and p-Wasserstein distance interprets the minimal cost of reconfiguring mass distribution of one probability measure to another. 


\section{Our method}

The main obstacles in existing function-space variational inference are the definition and estimation issues regarding the KL divergence between distributions over functions. In this section, we propose a novel Wasserstein distance-based variational objective to avoid the limitations of the KL divergence and improve approximate inference in function space. We first propose a two-step variational method that approximates the posterior indirectly via a functional prior and a bridging distribution. In the first step, we distill a functional prior by fitting a bridging distribution over functions. In the second step, we form a new ELBO by matching the variational posterior and the bridging distribution in parameter space using the 2-Wasserstein distance as a surrogate for the KL divergence. Then, we further propose an integrated variational objective to jointly optimize the bridging distribution and the variational posterior.
% Considering the isotropy of the 1-Wasserstein distance utilized in the first step


\subsection{Functional Prior Induced Variational Inference}
\label{sec:gpibnn}

Suppose a random function mapping $f(\cdot; \mathbf{w}): \mathcal{X} \times \mathbb{R}^k \rightarrow \mathcal{Y}$ parametrized by random $\mathbf{w} \in \mathbb{R}^k$ is defined by a BDL model. The main variational objective is to obtain the approximate posterior $q(f; \boldsymbol{\theta_q})$ induced by the variational posterior over parameters $q(\mathbf{w}; \boldsymbol{\theta_q})$. Let $g(\cdot; \mathbf{w}_{\boldsymbol{b}})$ be a latent random function with parameters $\mathbf{w}_{\boldsymbol{b}}\in \mathbb{R}^k$, and let $p(g; \boldsymbol{\theta_b})$ denote a bridging distribution over functions induced by the distribution over model parameters $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b})$. Assume that $p(g; \boldsymbol{\theta_b})$ and $q(f; \boldsymbol{\theta_q})$ are generated from the same function structure with different parametric distributions (e.g., the same BNN structure with different distributions over weights). 


\textbf{Distilling a functional prior using a bridging distribution over functions}.
Considering that GPs are well-developed priors in function space that are interpretable and able to incorporate prior knowledge about the prediction task at hand, we assign a GP prior $p_0(f) = \mathcal{GP}(\mathbf{m}, \mathbf{K})$ to the random function $f$. Because the KL divergence between the GP prior and the non-GP variational posterior in the functional ELBO is intractable \citep{rudner2022tractable}, we first distill the GP prior into the bridging distribution over functions by minimizing the 1-Wasserstein distance between $p(g; \boldsymbol{\theta_b})$ and  $p_0(f)$ \citep{tran2022all}, whose dual form is as follows:
\begin{equation}
    W_1(p(g; \boldsymbol{\theta_b}), p_0(f)) = \sup_{\|\phi\|_{L} \leq 1} \left[\mathbb{E}_{\mathbf{x}\sim p(g; \boldsymbol{\theta_b})}\phi(\mathbf{x})-\mathbb{E}_{\mathbf{y}\sim p_0}\phi(\mathbf{y})\right].
\end{equation}
Specifically, due to the infinite-dimensional nature of random functions, in practice we evaluate the above 1-Wasserstein distance on a finite set of randomly sampled measurement points $\mathbf{X}_{\mathcal{M}} \stackrel{\text{def}}{=}\left[\mathbf{x}_1, \ldots, \mathbf{x}_M\right]^{\mathrm{T}}$, i.e., as $ W_1(p(g(\mathbf{X}_{\mathcal{M}}; \boldsymbol{\theta_b})), p_0(f(\mathbf{X}_{\mathcal{M}})))$. The 1-Wasserstein distance between distributions over functions is thereby reduced to one between multivariate random variables. The specific form is as follows:
\begin{equation}
    \begin{aligned}
        W_1(&p(g(\mathbf{X}_{\mathcal{M}}; \boldsymbol{\theta_b})), p_0(f(\mathbf{X}_{\mathcal{M}}))) =\\
       & \sup \mathbb{E}_{\mathbf{X}_{\mathcal{M}}}\left[\mathbb{E}_{p(g; \boldsymbol{\theta_b})}\phi(g(\mathbf{X}_{\mathcal{M}}))-\mathbb{E}_{p_{0}}\phi(f(\mathbf{X}_{\mathcal{M}}))\right],
    \end{aligned}
    \label{eq:w1}
\end{equation}
where $\phi$ is a 1-Lipschitz continuous function, and $g(\mathbf{X}_{\mathcal{M}})$ and  $f(\mathbf{X}_{\mathcal{M}})$ are the corresponding function values evaluated at $\mathbf{X}_{\mathcal{M}}$. 
 % \textcolor{red}{Not that there is another interesting work \citep{liu2023simple} uses the bi-Lipschitz condition to improve the uncertainty quality for single networks.}
This approximate computation is based entirely on sampling, so it can be carried out even without a closed form for $p(g; \boldsymbol{\theta_b})$. The resulting functional prior-induced bridging distribution over parameters is denoted as $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)$, where $\boldsymbol{\theta_b}^* = \argmin_{\boldsymbol{\theta_b}} W_1(p(g; \boldsymbol{\theta_b}), p_0(f))$. It is worth noting that this distillation procedure can be applied to any prior distribution over functions as long as function samples from it are available.
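The sampling-based computation described above can be sketched as follows. The sample size $S$, the function samples $g^{(s)}$ and $f^{(s)}$, and the parametrized function $\phi_{\boldsymbol{\psi}}$ are notation assumed here for illustration and are not fixed by the text; in particular, how the 1-Lipschitz constraint on $\phi_{\boldsymbol{\psi}}$ is enforced is left open. Drawing $g^{(s)} \sim p(g; \boldsymbol{\theta_b})$ and $f^{(s)} \sim p_0(f)$ for $s = 1, \ldots, S$, Equation \ref{eq:w1} may be approximated by
\begin{equation*}
\begin{aligned}
\widehat{W}_1 = \max_{\boldsymbol{\psi}} \frac{1}{S}\sum_{s=1}^{S}\Big[&\phi_{\boldsymbol{\psi}}\big(g^{(s)}(\mathbf{X}_{\mathcal{M}})\big)\\
&-\phi_{\boldsymbol{\psi}}\big(f^{(s)}(\mathbf{X}_{\mathcal{M}})\big)\Big],
\end{aligned}
\end{equation*}
after which $\boldsymbol{\theta_b}$ is updated to decrease $\widehat{W}_1$. Only samples from the two processes are required, not their densities.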

\textbf{Matching the variational posterior and the bridging distribution in parameter space}.
The main idea of this inference method is to force the distributions over functions $q(f; \boldsymbol{\theta_q})$ and $p(g; \boldsymbol{\theta_b})$ to share the same function structure. That is, once the optimal functional prior-induced bridging distribution over parameters $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)$ is obtained from the above distillation process by fitting $p(g; \boldsymbol{\theta_b})$ to $p_0(f)$, $\boldsymbol{\theta_b}^*$ is frozen and $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)$ is used in the regularization term of the ELBO for the variational $q(\mathbf{w}; \boldsymbol{\theta_q})$ as
\begin{equation}
\begin{aligned}
    % \mathcal{L}_{q(\mathbf{w}; \boldsymbol{\theta_q})} = &\mathbb{E}_{q(\mathbf{w}; \boldsymbol{\theta_q})}\left[\log p(\mathcal{D} \mid \mathbf{w}; \boldsymbol{\theta_q})\right]\\
    \mathcal{L}_{q(\mathbf{w}; \boldsymbol{\theta_q})} =& \mathbb{E}_{q(\mathbf{w};\boldsymbol{\theta_q})}\left[\log p(\mathbf{Y}_\mathcal{D} \mid \mathbf{X}_\mathcal{D}, \mathbf{w}; \boldsymbol{\theta_q})\right] \\
   &- \lambda W_2(q(\mathbf{w}; \boldsymbol{\theta_q}), p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)),
\end{aligned}
\label{eq:gpibnn}
\end{equation}
where $W_2(q(\mathbf{w}; \boldsymbol{\theta_q}), p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*))$ is the 2-Wasserstein distance between the approximate posterior $q(\mathbf{w}; \boldsymbol{\theta_q})$ and the bridging distribution over parameters $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)$, and $\lambda$ is a hyperparameter. If $q(\mathbf{w}; \boldsymbol{\theta_q})$ and $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)$ are both Gaussian, their squared 2-Wasserstein distance has the closed form 
\begin{equation}
   \begin{aligned}
       & W_2^2(q(\mathbf{w}; \boldsymbol{\theta_q}), p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)) = \\
        & \left\|\mathbf{\mu}_q-\mathbf{\mu}_b^*\right\|_2^2+\operatorname{trace}\left(\mathbf{\Sigma}_q+\mathbf{\Sigma}_b^*-2\left(\mathbf{\Sigma}_q^{1 / 2} \mathbf{\Sigma}_b^* \mathbf{\Sigma}_q^{1 / 2}\right)^{1 / 2}\right),
   \end{aligned}
   \label{eq:w2}
\end{equation}
 where $\boldsymbol{\theta_q} := \{\mathbf{\mu}_q, \mathbf{\Sigma}_q\}$ and $\boldsymbol{\theta_b}^* := \{\mathbf{\mu}_b^*, \mathbf{\Sigma}^*_b\}$ collect the respective mean vectors and covariance matrices. 
We call this variational inference approach based on the functional prior Functional Prior-induced Variational Inference (FPi-VI). Although Gaussian Wasserstein Inference (GWI) proposed by \citet{wild2022generalized} also adopts the 2-Wasserstein distance in its functional variational inference, our approach differs: the distance in Equation \ref{eq:w2} is computed in parameter space, because we have already distilled a parameter-space counterpart of the functional prior, whereas the distance in GWI is computed in function space. On the surface, the function-space 2-Wasserstein distance used in GWI is more straightforward, but it comes with a restriction: GWI has to approximate the BNN posterior with a GP posterior whose mean function is parameterized by a deterministic neural network, which forfeits the strong uncertainty modeling capability of BNNs. Our method has no such restriction, and our optimization target remains the BNN posterior. 
Pseudocode for FPi-VI is presented in Algorithm \ref{alg:fpi_vi} in Appendix \ref{apd:B}.
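
For intuition about Equation \ref{eq:w2}, consider the mean-field case, assumed here only as an illustration, in which both covariances are diagonal, $\mathbf{\Sigma}_q = \operatorname{diag}(\boldsymbol{\sigma}_q^2)$ and $\mathbf{\Sigma}_b^* = \operatorname{diag}(\boldsymbol{\sigma}_b^{*2})$, with non-negative standard-deviation vectors $\boldsymbol{\sigma}_q$ and $\boldsymbol{\sigma}_b^*$ (notation introduced for this sketch). The two matrices then commute, and the right-hand side of Equation \ref{eq:w2} reduces to
\begin{equation*}
\left\|\mathbf{\mu}_q-\mathbf{\mu}_b^*\right\|_2^2+\left\|\boldsymbol{\sigma}_q-\boldsymbol{\sigma}_b^*\right\|_2^2,
\end{equation*}
because the trace term becomes $\sum_{i=1}^{k}\left(\sigma_{q,i}^2+\sigma_{b,i}^{*2}-2\sigma_{q,i}\sigma_{b,i}^{*}\right)$. In this case the regularizer costs $O(k)$ operations for $k$ parameters and requires no matrix square root.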

\subsection{Functional Wasserstein Bridge Inference}
\label{sec:wassbri}

Due to the isotropy of the 1-Wasserstein distance, there are infinitely many candidates $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)$ with exactly the same distance to a given functional prior, and FPi-VI simply picks one of them at random in the first distillation step. Such randomness brings large fluctuations to the subsequent inference performance (see Appendix \ref{apd:e1}). Therefore, we further treat the bridging-distribution parameters $\boldsymbol{\theta_b}$ as parameters to be optimized jointly with the variational posterior parameters to obtain a more robust solution. 
Since there is no analytical way to directly estimate a functional distance (e.g., KL divergence or Wasserstein distance) between the functional prior and the variational posterior over functions for all but extremely simple distributions such as GPs, our key ingredient is a \textit{Wasserstein bridge} that decomposes such a functional distance into a parameter-space 2-Wasserstein distance and a function-space 1-Wasserstein distance, which are optimized simultaneously, as
\begin{equation}
    \begin{aligned}
        W_B(q(f; \boldsymbol{\theta_q}), p_0(f)) =& \lambda_1W_2(q(\mathbf{w}; \boldsymbol{\theta_q}), p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b})) \\
       &+ \lambda_2W_1\left(p(g; \boldsymbol{\theta_b}), p_0(f)\right),
    \end{aligned}
\end{equation}
where $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b})$ is the bridging distribution over parameters,  $p(g; \boldsymbol{\theta_b})$ is the bridging distribution over functions induced by $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b})$, $\boldsymbol{\theta_q}$ and $\boldsymbol{\theta_b}$ are the respective parameters of the approximate variational posterior and the bridging distribution, which are optimized jointly, and $\lambda_1$, $\lambda_2$ are two hyperparameters. 

% W_1\left(p(g; \boldsymbol{\theta_b}), p_0(f)\right) + M_2\left(p(g; \boldsymbol{\theta_b}), p_0(f)\right)

Based on the above Wasserstein bridge, we propose a variational objective in function space called Functional Wasserstein Bridge Inference (FWBI) and derive a practical algorithm to obtain the optimal $\{\boldsymbol{\theta_q}^*,  \boldsymbol{\theta_b}^*\}$ as:
\begin{equation}
    \begin{aligned}
     &\argmin_{\boldsymbol{\theta_q},  \boldsymbol{\theta_b}} -\mathbb{E}_{q(f; \boldsymbol{\theta_q})}\left[\log p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}; \boldsymbol{\theta_q}))\right]
   \\& \quad\quad\quad~~ + W_B(q(f; \boldsymbol{\theta_q}), p_0(f)) \\
   % =& \argmin_{\boldsymbol{\theta_q},  \boldsymbol{\theta_b}} -\frac{1}{M} \sum_{j=1}^M \left[\log p(\mathbf{Y}_\mathcal{B} \mid f(\mathbf{X}_\mathcal{B};\mathbf{\mu}_q, \mathbf{\Sigma}_q)\right] \\
   = & \argmin_{\boldsymbol{\theta_q},  \boldsymbol{\theta_b}} -\frac{1}{M} \sum_{(x, y) \in \mathcal{B}} \mathbb{E}_{q(f; \boldsymbol{\theta_q})} \left[\log p(y \mid f(x;\mathbf{\mu}_q, \mathbf{\Sigma}_q))\right] \\
  &\quad\quad\quad~ +  \lambda_1W_2(q(\mathbf{w}; \mathbf{\mu}_q, \mathbf{\Sigma}_q), p(\mathbf{w}_{\boldsymbol{b}}; \mathbf{\mu}_b, \mathbf{\Sigma}_b)) \\
 &\quad\quad\quad~~ + \lambda_2W_1\left(p(g(\mathbf{X}_{\mathcal{M}}; \mathbf{\mu}_b, \mathbf{\Sigma}_b)), p_0(f(\mathbf{X}_{\mathcal{M}}))\right),
    \label{eq:ifbnn}
    \end{aligned}
\end{equation}
where $\mathcal{B} = \{(x_j, y_j)\}_{j=1}^{M}$ is a mini-batch of the training data $\left\{\mathbf{X}_\mathcal{D}, \mathbf{Y}_\mathcal{D}\right\}$ used for the likelihood term, $\mathbf{w}$ and $\mathbf{w}_{\boldsymbol{b}}$ are reparameterized as $\mathbf{w} = \mathbf{\mu}_q + \mathbf{\Sigma}_q \odot \mathbf{\epsilon}$ and $\mathbf{w}_{\boldsymbol{b}} = \mathbf{\mu}_b + \mathbf{\Sigma}_b \odot \mathbf{\epsilon}$, respectively, with random noise $\mathbf{\epsilon}$ under the Gaussian assumption, and $\mathbf{X}_{\mathcal{M}}$ denotes the finite measurement set drawn from the input space for estimating the 1-Wasserstein distance in function space.
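
As a brief illustration of why the parameter-space term is cheap to evaluate, recall the standard closed form of the 2-Wasserstein distance between Gaussians with diagonal covariance. Assuming, as the element-wise reparameterization suggests, that $\mathbf{\Sigma}_q$ and $\mathbf{\Sigma}_b$ are vectors of per-weight standard deviations (this reading is an assumption made here for exposition), the squared distance is
\begin{equation*}
    W_2^2\left(q(\mathbf{w}; \mathbf{\mu}_q, \mathbf{\Sigma}_q), p(\mathbf{w}_{\boldsymbol{b}}; \mathbf{\mu}_b, \mathbf{\Sigma}_b)\right) = \lVert \mathbf{\mu}_q - \mathbf{\mu}_b \rVert_2^2 + \lVert \mathbf{\Sigma}_q - \mathbf{\Sigma}_b \rVert_2^2,
\end{equation*}
so under this assumption the term and its gradients with respect to both sets of parameters are available analytically.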

\begin{prop}
The functional variational objective derived from FWBI based on the Wasserstein bridge is a lower bound of the model evidence.
\label{prop:1}
\end{prop}

Proposition \ref{prop:1} shows that the variational objective of FWBI is a lower bound of the model evidence, which indicates that it is a well-defined objective for Bayesian inference. The proof is based on the law of cosines for the KL divergence and the Talagrand inequality for probability measures (see Appendix \ref{apd:C} for details).


Although the 1-Wasserstein distance used in Equation (\ref{eq:ifbnn}) can link the GP prior with the functional bridging distribution, we found that in practice it matches higher-order moments poorly. To preserve the uncertainty knowledge encoded in the functional prior, we propose the following enhanced version of the 1-Wasserstein distance with an additional second-order moment matching term:
\begin{equation}
    \begin{aligned}
        \Tilde{W}_1\left(p(g; \boldsymbol{\theta_b}), p_0(f)\right) =& 
        W_1\left(p(g; \boldsymbol{\theta_b}), p_0(f)\right) \\
        & + M_2\left(p(g; \boldsymbol{\theta_b}), p_0(f)\right),
    \end{aligned}
\end{equation}
where $M_2\left(p(g; \boldsymbol{\theta_b}), p_0(f)\right)$ is a second-order moment matching term defined as $\left|\operatorname{Var}(p(g; \boldsymbol{\theta_b})) - \operatorname{Var}(p_0(f))\right|$. In practice, we use this enhanced 1-Wasserstein distance in the Wasserstein bridge of FWBI. The pseudocode for FWBI is given in Algorithm \ref{alg:fwbi} in Appendix \ref{apd:B}.
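
As a rough illustration of how such a moment matching term could be estimated on the measurement set (the estimator below is an assumption made for exposition and is not necessarily the one used in Algorithm \ref{alg:fwbi}), one may draw $S$ function samples $g^{(s)} \sim p(g; \boldsymbol{\theta_b})$ and $f^{(s)} \sim p_0(f)$, evaluate them on $\mathbf{X}_{\mathcal{M}}$, and compare the empirical variances point by point:
\begin{equation*}
    \hat{M}_2 = \frac{1}{|\mathbf{X}_{\mathcal{M}}|} \sum_{x \in \mathbf{X}_{\mathcal{M}}} \left| \widehat{\operatorname{Var}}_{s}\left[g^{(s)}(x)\right] - \widehat{\operatorname{Var}}_{s}\left[f^{(s)}(x)\right] \right|,
\end{equation*}
where $\widehat{\operatorname{Var}}_{s}$ denotes the sample variance over the $S$ draws. The hypothetical symbols $S$, $g^{(s)}$, $f^{(s)}$, and $\hat{M}_2$ are introduced only for this sketch.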

Unlike most functional variational inference approaches \citep{sun2019functional, ma2021functional, rudner2022tractable}, which perform variational optimization directly on the approximate posterior over functions, the proposed FWBI treats the bridging distribution over functions induced by distributions over parameters as an intermediate variable that links the variational posterior and the functional prior. Specifically, FWBI distills the functional prior into the bridging distribution over functions $p(g; \boldsymbol{\theta_b})$ induced by $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b})$ via the enhanced 1-Wasserstein distance with the additional second-order matching term, and simultaneously matches the variational posterior $q(\mathbf{w}; \boldsymbol{\theta_q})$ to $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b})$ by minimizing the 2-Wasserstein distance.

The main advantages of FWBI are as follows: i) parameter-space mean-field variational inference commonly assumes an i.i.d.\ Gaussian prior over parameters \citep{blundell2015weight}, whereas FWBI can assign a more interpretable functional prior and incorporate meaningful information about the task into the inference process; ii) FWBI uses the well-defined Wasserstein bridge to regularize the parameters of the variational posterior with a functional prior, which circumvents the limitations of the functional KL divergence used in most function-space approximate inference methods \citep{sun2019functional}, and the second-order moment matching term in the enhanced 1-Wasserstein distance allows it to quantify uncertainty more accurately. Once the optimal $\boldsymbol{\theta_q}^*$ is obtained, the posterior predictive distribution is given by the following integral, which can be estimated through Monte Carlo sampling:
\begin{equation}
\begin{aligned}
    q(\mathbf{y}^*|\mathbf{x}^*) &= \int p(\mathbf{y}^*|f(\mathbf{x}^*; \boldsymbol{\theta_q}^*))q(f(\mathbf{x}^*; \boldsymbol{\theta_q}^*)) df(\mathbf{x}^*;\boldsymbol{\theta_q}^*)\\
    &\approx \frac{1}{S} \sum_{j=1}^{S} p(\mathbf{y}^*|f(\mathbf{x}^*; \mathbf{w}^{(j)})),
\end{aligned}
\end{equation}
where $\mathbf{w}^{(j)} \sim \mathcal{N}(\mathbf{\mu}_q^*, \mathbf{\Sigma}_q^*)$ and $\boldsymbol{\theta_q}^* := \{\mathbf{\mu}_q^*, \mathbf{\Sigma}_q^*\}$.
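
For a concrete instance of this Monte Carlo estimate, consider scalar regression with a Gaussian likelihood $p(y^* \mid f) = \mathcal{N}(y^*; f, \sigma^2)$, where the fixed noise variance $\sigma^2$ is an illustrative assumption. The estimate is then an equally weighted mixture of $S$ Gaussians, whose mean and variance follow from the law of total variance:
\begin{equation*}
\begin{aligned}
    \hat{\mu}(\mathbf{x}^*) &= \frac{1}{S} \sum_{j=1}^{S} f(\mathbf{x}^*; \mathbf{w}^{(j)}), \\
    \hat{\sigma}^2(\mathbf{x}^*) &= \sigma^2 + \frac{1}{S} \sum_{j=1}^{S} \left( f(\mathbf{x}^*; \mathbf{w}^{(j)}) - \hat{\mu}(\mathbf{x}^*) \right)^2.
\end{aligned}
\end{equation*}
The first term of $\hat{\sigma}^2$ is the observation noise and the second is the spread of the sampled functions, which is the part governed by the variational posterior.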

\section{Related Works}
Building on variational inference methods in parameter space, an increasing number of works focus on function-space variational approaches for various BDL models such as BNNs, and on their applications to a range of machine learning tasks where predictive uncertainty quantification is crucial \citep{Benjamin2019measuring, titsias2019functional, pan2020continual, rudner2022continual}.

\textbf{Variational inference in parameter spaces}. 
Parameter-space variational methods are widely used for approximating the posterior over weights in BNNs. \citet{hinton1993keeping} first used variational inference in BNNs. \citet{barber1998ensemble} replaced the fully factorized Gaussian assumption for the variational posterior with a full-rank Gaussian to model correlations between weights. \citet{graves2011practical} then proposed a sub-sampling technique to approximate the expected log-likelihood by Monte Carlo integration. To improve on this work, \citet{blundell2015weight} developed a variational inference algorithm called Bayes by Backprop (BBB) based on the reparameterization trick, which yields an unbiased gradient estimator of the ELBO w.r.t.\ the model parameters.
% In the same period, \citet{gal2016uncertainty} proved that dropout neural networks are equivalent to performing variational inference for approximating posteriors over weights. 

\textbf{Variational inference in function spaces}. 
To address the limitations of parameter-space variational inference, such as the difficulty of specifying meaningful priors, \citet{sun2019functional} proposed a functional ELBO that matches a GP prior and the variational posterior over functions for BNNs via a spectral Stein gradient estimator designed for implicit distributions \citep{shi2018spectral}. However, the KL divergence between stochastic processes involved in the functional ELBO may be ill-defined for a wide class of distributions, which leads to an invalid variational objective \citep{burt2020understanding}. Around the same time, \citet{wang2018function} proposed a particle optimization variational inference method in function space for posterior approximation in BNNs. \citet{rudner2020rethinking, rudner2022tractable} pointed out that the supremum of the marginal KL divergence over finite measurement sets cannot be solved analytically when estimating the functional KL divergence. They proposed to approximate the distributions over functions as Gaussian via the linearization of their mean parameters and derived a tractable and well-defined variational objective, since the functional prior and variational posterior are two BNNs that share the same network structure. \citet{ma2021functional} randomized the number of finite measurement points to derive an alternative grid-functional KL divergence, which avoids some limitations of the KL divergence between stochastic processes. However, all these methods are based on the KL divergence. Considering its potential weaknesses, some recent works replace the KL divergence with the Wasserstein distance \citep{kantorovich1960mathematical, villani2003topics}. \citet{tran2020functional} proposed to match a BNN prior to a GP prior by minimizing the 1-Wasserstein distance to obtain more interpretable functional priors in BNNs. However, they used stochastic gradient Hamiltonian Monte Carlo (SGHMC) rather than variational inference to approximate the posterior. In contrast, our work uses an enhanced 1-Wasserstein distance with an additional second-order moment matching term to better preserve the uncertainty when distilling the functional prior into the bridging distribution, and then develops a functional variational objective for posterior approximation based on the proposed Wasserstein bridge. \citet{wild2022generalized} built a functional variational objective called GWI, in which the functional prior and posterior are both Gaussian measures and the dissimilarity measure is the 2-Wasserstein distance.
However, its reliance on GP components makes it less applicable to general non-GP scenarios. In contrast, our FWBI is not restricted to any distributional assumptions: the variational posterior and prior over functions can be any reasonable stochastic process for the specific task.

% As the gradient estimator used in functional ELBO is less efficient for high-dimensional input, Moreover, they proposed an expressive variational family as the non-Gaussian generalization of variation implicit processes \citep{ma2019variational} and developed a new functional variational inference method. 


\section{Experimental Evaluation}
In this section, we evaluate the predictive performance and uncertainty quantification of FWBI on several tasks, including 1-D extrapolation toy examples, multivariate regression on UCI datasets, contextual bandits, and image classification tasks. We compare FWBI to several well-established parameter-space and function-space variational inference approaches. 

\subsection{Extrapolation Illustrative Examples}

\paragraph{Learning polynomial curves}
Consider a 1-D oscillating curve from the polynomial function: $y = \sin(3\pi x) + 0.3 \cos(9\pi x) + 0.5 \sin(7\pi x) + \epsilon$ with noise $\epsilon \sim \mathcal{N}(0,0.5^2)$. There are 20 randomly sampled observation points, half of which are sampled from the interval $[-0.75,-0.25]$, and the other half from $[0.25,0.75]$. For the parameter-space variational inference comparison, we choose BBB \citep{blundell2015weight} using the KL divergence for distributions over parameters, denoted by KLBBB, and a 2-Wasserstein distance alternative called WBBB. For functional methods, we compare with the benchmark functional BNNs (FBNN) proposed by \citet{sun2019functional}. For FWBI and FBNN, we use the same GP prior with three different kernels: the RBF kernel, the Matern kernel, and the Linear kernel (which is not suitable for modelling polynomial oscillatory curves). Results are shown in Figure \ref{fig:g3}. The leftmost column shows that the two parametric inference methods, KLBBB and WBBB, fail to fit the target function, while the two function-space approaches exhibit better predictive performance. For FWBI and FBNN, we first pre-train the GP prior to obtain a more informative functional prior. Figures \ref{fig:ifbnn_rbf_g3} and \ref{fig:ifbnn_m_g3} show that, with the appropriate RBF and Matern kernels, FWBI recovers the key polynomial characteristics of the curve in the observation range and expresses high uncertainty in the unseen regions of the input space. On the other hand, the mismatched Linear kernel in Figure \ref{fig:ifbnn_l_g3} produces an erroneous linear trend, which indicates that FWBI effectively utilizes the functional prior information in the inference process. In contrast, FBNN severely underfits the curve in both the observed and unobserved regions with the RBF and Matern kernels, while its results with the inappropriate Linear kernel are slightly better.
FBNN is less responsive to different kernel information and performs poorly in uncertainty estimation. See Appendix \ref{apd:e2} for more baseline results. Appendix \ref{apd:e3} shows detailed comparisons of posteriors of GPs and FWBI. The calibration curves for all methods are shown in Appendix \ref{apd:e4}.

\begin{figure*}[!t] %htbp
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
% \floatconts
\centering
  {
    \subfigure[KLBBB]{\label{fig:klbnn_g3}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g3_toy/klbnn}}%
    % \qquad
    \subfigure[FBNN-RBF]{\label{fig:fbnn_rbf_g3}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g3_toy/fbnn_g3_rbf}}%
    % \qquad
    \subfigure[FBNN-Matern]{\label{fig:fbnn_m_g3}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g3_toy/fbnn_g3_matern}}
    % \qquad
    \subfigure[FBNN-Linear]{\label{fig:fbnn_l_g3}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g3_toy/fbnn_g3_linear.pdf}} 
    % \quad
    \subfigure[WBBB]{\label{fig:wbnn_g3}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g3_toy/wbnn}}
    % \quad
    \subfigure[FWBI-RBF]{\label{fig:ifbnn_rbf_g3}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g3_toy/rbf_g3}}%
    \subfigure[FWBI-Matern]{\label{fig:ifbnn_m_g3}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g3_toy/matern_g3}}
    \subfigure[FWBI-Linear]{\label{fig:ifbnn_l_g3}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g3_toy/linear_g3}} 
  }
\caption{Learning polynomial curves. The green line is the ground-truth function, and the blue lines correspond to mean approximate posterior predictions. Black dots denote the 20 training points; shaded areas represent the predictive standard deviations. The leftmost column shows the two parameter-space methods, and the other three columns show the results of the functional approaches based on GP priors with three different kernels. For more details, see Appendix \ref{apd:D}.}
\label{fig:g3}
\end{figure*}



% \paragraph{Interpolation with non-GP priors} 
% One of the main advantages of FWBI is that it is not constrained by specific prior and variational posterior distribution families in function spaces, whether explicit or implicit, GP or non-GP. In this experiment, we consider using a BNN as the functional prior to fit a 1-D toy example: $y(x)=\sin (x)+0.1 x +\epsilon, \epsilon \sim \mathcal{N}(0,0.5)$. 30 observations are randomly sampled from $[-7.5, -5] \cup[-2.5, 2.5] \cup[5, 7.5]$. For comparison, we also obtained the results for GP priors. As shown in Figure \ref{fig:g2}, it is obvious that FWBI has a stronger capacity to recover the key characteristics of the true function than the other inference methods with all four different priors. The BNN prior shows a competitive performance with the GP priors, and even converges faster (see Appendix \ref{apd:e5} for details). 

\paragraph{Learning periodic curves}
One of the main advantages of FWBI is the ability to encode a variety of prior knowledge into the posterior inference process by utilizing rich functional priors (e.g., through the kernel functions of GP priors). In this experiment, we use GPs as the functional prior to fit a periodic curve: $y(x)=\sin (x)+0.1 x +\epsilon, \epsilon \sim \mathcal{N}(0,0.5)$. Thirty observations are randomly sampled from $[-7.5, -5] \cup[-2.5, 2.5] \cup[5, 7.5]$. To demonstrate the effects of different priors on posterior inference, we use three GP priors with different kernel functions: the Periodic kernel, the Matern kernel, and the Linear kernel (which is not suitable for modelling periodic curves). For comparison, we also consider the parameter-space variational inference method KLBBB \citep{blundell2015weight}, which uses an i.i.d.\ Gaussian prior over model parameters and therefore cannot readily incorporate periodic prior knowledge. The results are shown in Figure \ref{fig:g2}. As shown in Figures \ref{fig:g2_fwbi_p} and \ref{fig:g2_fwbi_m}, FWBI with appropriate GP priors recovers the key periodic characteristics of the true function and provides accurate uncertainty estimation. The GP prior with the Periodic kernel yields slightly better and smoother posterior results than that with the Matern kernel. On the other hand, Figure \ref{fig:g2_fwbi_l} illustrates that a mismatched Linear kernel in the prior can severely degrade posterior inference. In contrast, the parameter-space KLBBB in the rightmost panel, Figure \ref{fig:klbbb_g2}, fails to fit the periodic trend, which is closely related to its inability to encode the correct prior knowledge in parameter space. For more analysis of the impact of functional properties of priors (smoothness and noise) on FWBI, see Appendix \ref{apd:e6}.


\begin{figure*}[!t] %htbp
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
% \floatconts
\centering
  {
    \subfigure[FWBI-Periodic]{\label{fig:g2_fwbi_p}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g2_toy/g2_periodic}}%
    \subfigure[FWBI-Matern]{\label{fig:g2_fwbi_m}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g2_toy/g2_matern}}
    \subfigure[FWBI-Linear]{\label{fig:g2_fwbi_l}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g2_toy/g2_linear}}
    \subfigure[KLBBB]{\label{fig:klbbb_g2}%
      \includegraphics[width=0.24\linewidth]{UAI2024/g2_toy/klbbb_g2}} 
  }
\caption{Learning a periodic curve with different GP priors. We consider three different kernel functions: the Periodic kernel, the Matern kernel, and the Linear kernel. See Appendix \ref{apd:e5} for more details about the convergence of FWBI.}
\label{fig:g2}
\end{figure*}

\subsection{Multivariate Regression on UCI Datasets}

In this experiment, we evaluate the predictive performance of FWBI on multivariate regression tasks using benchmark UCI datasets. Table \ref{tab:uci} reports the average root mean square error (RMSE). The three functional inference methods generally provide better results than the parameter-space approaches (the one exception is GWI on Yacht, which is slightly worse than KLBBB), which reflects the advantages of function-space variational inference. Furthermore, FWBI significantly outperforms all other methods on all six datasets and runs efficiently (see Appendix \ref{apd:e8}).


\begin{table*}[!t]
% \setlength{\abovecaptionskip}{0.5cm}
% \setlength{\belowcaptionskip}{0.5cm}
\small
\centering
    \caption{Average RMSE for multivariate regression on UCI datasets. We split each dataset randomly into 90\% training data and 10\% test data, and this process is repeated 10 times to ensure validity. A paired-sample t-test between the results of FWBI and those of the other methods gives $p<.001$. See Appendix \ref{apd:e7} for the test negative log-likelihood (NLL) results.}
    \label{tab:uci}
    \begin{tabular}{lccccc}
    % \toprule
    \hline
    \textbf{Dataset} & \text{FWBI} & \text{GWI} & \text{FBNN} & \text{WBBB} & \text{KLBBB}\\
    % \midrule
    \hline 
    \text{Yacht} & \textbf{1.303}$\pm$\textbf{0.112} & 2.198 $\pm$ 0.083 & 1.523 $\pm$ 0.075 & 2.328 $\pm$ 0.091 & 2.131 $\pm$ 0.085\\
    \text{Boston} & \textbf{1.531}$\pm$\textbf{0.055} & 1.742 $\pm$ 0.046 & 1.683 $\pm$ 0.122 & 2.306 $\pm$ 0.102 & 1.919 $\pm$ 0.074\\
    \text{Concrete} & \textbf{1.144}$\pm$\textbf{0.057} & 1.297 $\pm$ 0.053 & 1.274 $\pm$ 0.049 & 2.131 $\pm$ 0.068 & 1.784 $\pm$ 0.063\\
    \text{Wine} & \textbf{1.242}$\pm$\textbf{0.056} & 1.680 $\pm$ 0.064 & 1.528 $\pm$ 0.053 & 2.253 $\pm$ 0.071 & 1.857 $\pm$ 0.069\\
    \text{Kin8nm} & \textbf{1.093}$\pm$\textbf{0.014} & 1.188 $\pm$ 0.015 & 1.447 $\pm$ 0.069 & 2.134 $\pm$ 0.029 & 1.787 $\pm$ 0.027\\
     \text{Protein} & \textbf{1.195}$\pm$\textbf{0.006} & 1.333 $\pm$ 0.007 & 1.503 $\pm$ 0.025 & 2.188 $\pm$ 0.012 & 1.795 $\pm$ 0.010\\
    % \bottomrule
    \hline 
    \end{tabular}
\end{table*}

\begin{figure*}[!t]
    \centering
    {
    \subfigure[p=0.4]{\label{0.4}
    \includegraphics[width=0.32\linewidth]{UAI2024/context_bandits/0.4.pdf}}
    \subfigure[p=0.5]{\label{0.5}
    \includegraphics[width=0.32\linewidth]{UAI2024/context_bandits/0.5.pdf}}
    \subfigure[p=0.6]{\label{0.6}
    \includegraphics[width=0.32\linewidth]{UAI2024/context_bandits/0.6.pdf}}
    }
    \caption{Comparisons of cumulative regrets for FWBI, KLBBB, WBBB, FBNN, GWI on the Mushroom contextual bandit task. Lower represents better performance.}
    \label{fig:bandits}
\end{figure*}

\begin{table*}[!t]
% \setlength{\abovecaptionskip}{0.5cm}
% \setlength{\belowcaptionskip}{0.5cm}
\small
\centering
    \caption{Image classification and OOD detection performance.}
    \label{tab:classification}
    \begin{tabular}{lcccccc}
    % \toprule
    \hline & \multicolumn{2}{c}{\text { MNIST }} & \multicolumn{2}{c}{\text { FMNIST }}& \multicolumn{2}{c}{\text { CIFAR10 }} \\
    \cline { 2 - 3 } \cline { 4 - 5 } \cline { 6 - 7 } \text { Model } & \text { Accuracy } & \text { OOD-AUC } & \text { Accuracy } & \text { OOD-AUC } & \text { Accuracy } & \text { OOD-AUC } \\
    % \midrule
    \hline 
    \text{FWBI} & \textbf{96.51} $\pm$ \textbf{0.00} & \textbf{0.962} $\pm$ \textbf{0.01} & \textbf{86.01} $\pm$ \textbf{0.00} & \textbf{0.838} $\pm$ \textbf{0.01} & \textbf{47.93} $\pm$ \textbf{0.01} & 0.618 $\pm$ 0.03 \\
    \text{GWI} & 95.40 $\pm$ 0.00 & 0.858 $\pm$ 0.05 & 85.43 $\pm$ 0.00 & 0.394 $\pm$ 0.04 & 44.78 $\pm$ 0.01 & \textbf{0.635} $\pm$ \textbf{0.02}\\
    \text{FBNN} & 96.09 $\pm$ 0.00 & 0.801 $\pm$ 0.07 & 85.64 $\pm$ 0.00 & 0.814 $\pm$ 0.02 & 46.29 $\pm$ 0.01 & 0.612 $\pm$ 0.03\\
    \text{WBBB} & 96.16 $\pm$ 0.00 & 0.869 $\pm$ 0.03 & 85.57 $\pm$ 0.00 & 0.819 $\pm$ 0.01 & 45.76 $\pm$ 0.01 & 0.606 $\pm$ 0.03\\
    \text{KLBBB} & 96.26 $\pm$ 0.00 & 0.868 $\pm$ 0.03 & 85.71 $\pm $ 0.00 & 0.829 $\pm$ 0.02 & 46.20 $\pm$ 0.00 & 0.606 $\pm$ 0.02\\
    % \bottomrule
    \hline 
    \end{tabular}
\end{table*}

\subsection{Contextual Bandits}

Reliable uncertainty estimation is crucial for downstream tasks such as contextual bandit problems, where the agent gradually learns the model by repeatedly observing a context and choosing the optimal action in a dynamic environment. In these scenarios, it is important to balance exploration and exploitation during optimization.
Thompson sampling \citep{thompson1933likelihood, russo2016information} is a widely used algorithm for strategy exploration in contextual bandits.
In this section, we evaluate the ability of FWBI to guide exploration on the UCI Mushroom dataset, which contains 8124 instances; each mushroom has 22 features and is labeled as edible or poisonous. The agent observes the mushroom features as the context and chooses either to eat or to reject the mushroom to maximize the reward. We consider three reward patterns. If the agent eats an edible mushroom, it receives a reward of 5. If it eats a poisonous mushroom, it receives a reward of $-35$ with probability $p$ and a reward of 5 otherwise, where $p$ is 0.4, 0.5, and 0.6 for the three patterns, respectively. If the agent rejects a mushroom, it receives a reward of 0.
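
To make the three reward patterns concrete, write $p$ for the probability of receiving the $-35$ penalty after eating a poisonous mushroom (the symbol $r$ below is introduced only for this worked example). The expected reward of eating a poisonous mushroom is
\begin{equation*}
    \mathbb{E}[r \mid \text{eat, poisonous}] = -35\,p + 5\,(1-p) = 5 - 40\,p,
\end{equation*}
which equals $-11$, $-15$, and $-19$ for $p = 0.4$, $0.5$, and $0.6$, respectively. Since rejecting always yields 0, eating a poisonous mushroom is worse than rejecting it in expectation under all three patterns, and the penalty for a wrong decision grows with $p$.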

Suppose an oracle always chooses to eat an edible mushroom (and receives a reward of 5) and to reject a poisonous one. We use the cumulative regret with respect to the reward achieved by the oracle to measure the exploration-exploitation ability of an agent. We concatenate the mushroom context and the action chosen by the agent as the model input, and the corresponding reward received is the model output. We follow the hyperparameter settings of \citet{blundell2015weight}. The cumulative regrets of all five parameter-space and function-space variational inference methods for the three reward patterns are shown in Figure \ref{fig:bandits}. FWBI performs better than the other inference methods under all three reward patterns, which indicates that FWBI can provide reliable uncertainty estimation in such decision-making scenarios.
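
A common way to write this measure is as follows; the notation is introduced here only for illustration. Let $r_t$ be the reward the agent receives at round $t$ and $r_t^{\star}$ the reward the oracle obtains on the same context, so that $r_t^{\star} = 5$ for an edible mushroom and $r_t^{\star} = 0$ for a poisonous one. The cumulative regret after $T$ rounds is then
\begin{equation*}
    R_T = \sum_{t=1}^{T} \left( r_t^{\star} - r_t \right),
\end{equation*}
so a curve that flattens over time indicates that the agent has learned to act like the oracle.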


\subsection{Classification and OOD Detection}

We evaluate the scalability of FWBI via image classification tasks with high-dimensional inputs. We assess the in-distribution predictive performance and out-of-distribution (OOD) detection ability on MNIST, FashionMNIST~\citep{xiao2017fashion}, and CIFAR-10~\citep{krizhevsky2009learning}. For all functional methods, we use the same GP prior with the RBF kernel. In Table \ref{tab:classification}, we report the test accuracy for predictive performance and the area under the curve (AUC) for the OOD detection pairs FashionMNIST/MNIST, MNIST/FashionMNIST, and CIFAR10/SVHN based on predictive entropies. FWBI consistently outperforms all parameter-space and function-space baselines in classification accuracy and performs competitively in OOD detection (Appendix \ref{apd:e9}).
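
For reference, a standard Monte Carlo form of the predictive entropy used as an OOD score is sketched below; the exact estimator used in the experiments is not specified here, so the symbols $C$, $S$, and $\bar{p}_c$ are illustrative assumptions. With $C$ classes and $S$ weight samples $\mathbf{w}^{(j)}$ from the variational posterior,
\begin{equation*}
    \bar{p}_c(\mathbf{x}) = \frac{1}{S} \sum_{j=1}^{S} p\left(y = c \mid f(\mathbf{x}; \mathbf{w}^{(j)})\right), \qquad
    \mathbb{H}\left[y \mid \mathbf{x}\right] \approx -\sum_{c=1}^{C} \bar{p}_c(\mathbf{x}) \log \bar{p}_c(\mathbf{x}).
\end{equation*}
Inputs with higher entropy are scored as more likely to be out of distribution, and the AUC measures how well this score separates in-distribution from OOD test inputs.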

% \begin{table*}[h]
% % \setlength{\abovecaptionskip}{0.5cm}
% % \setlength{\belowcaptionskip}{0.5cm}
% \small
% \centering
%     \caption{Image classification and OOD detection performance.}
%     \label{tab:classification}
%     \begin{tabular}{lcccccc}
%     % \toprule
%     \hline & \multicolumn{2}{c}{\text { MNIST }} & \multicolumn{2}{c}{\text { FMNIST }}& \multicolumn{2}{c}{\text { CIFAR10 }} \\
%     \cline { 2 - 3 } \cline { 4 - 5 } \cline { 6 - 7 } \text { Model } & \text { Accuracy } & \text { OOD-AUC } & \text { Accuracy } & \text { OOD-AUC } & \text { Accuracy } & \text { OOD-AUC } \\
%     % \midrule
%     \hline 
%     \text{FWBI} & \textbf{96.51} $\pm$ \textbf{0.00} & \textbf{0.962} $\pm$ \textbf{0.01} & \textbf{86.01} $\pm$ \textbf{0.00} & \textbf{0.838} $\pm$ \textbf{0.01} & \textbf{47.93} $\pm$ \textbf{0.01} & 0.618 $\pm$ 0.03 \\
%     \text{GWI} & 95.40 $\pm$ 0.00 & 0.858 $\pm$ 0.05 & 85.43 $\pm$ 0.00 & 0.394 $\pm$ 0.04 & 44.78 $\pm$ 0.01 & \textbf{0.635} $\pm$ \textbf{0.02}\\
%     \text{FBNN} & 96.09 $\pm$ 0.00 & 0.801 $\pm$ 0.07 & 85.64 $\pm$ 0.00 & 0.814 $\pm$ 0.02 & 46.29 $\pm$ 0.01 & 0.612 $\pm$ 0.03\\
%     \text{WBBB} & 96.16 $\pm$ 0.00 & 0.869 $\pm$ 0.03 & 85.57 $\pm$ 0.00 & 0.819 $\pm$ 0.01 & 45.76 $\pm$ 0.01 & 0.606 $\pm$ 0.03\\
%     \text{KLBBB} & 96.26 $\pm$ 0.00 & 0.868 $\pm$ 0.03 & 85.71 $\pm $ 0.00 & 0.829 $\pm$ 0.02 & 46.20 $\pm$ 0.00 & 0.606 $\pm$ 0.02\\
%     % \bottomrule
%     \hline 
%     \end{tabular}
% \end{table*}


\section{Conclusion and future work}
In this paper, we proposed a new function-space variational inference method termed Functional Wasserstein Bridge Inference (FWBI). It optimizes a Wasserstein bridge-based functional variational objective as a surrogate for the potentially problematic KL divergence between stochastic processes used in most existing functional variational inference methods. We proved that the functional variational objective derived from FWBI is a lower bound of the model evidence. Empirically, we demonstrated that FWBI can leverage various functional priors to yield high predictive performance and principled uncertainty quantification. Building on these results, our future work will focus on the theoretical comparison with other Bayesian approximation methods and on applications to more complex BDL models, such as Bayesian deep ensembles, and to more task scenarios, such as active learning and Bayesian optimization.



\begin{acknowledgements} % will be removed in pdf for initial submission,
This work is supported by the Australian Research Council under the Discovery Early Career Researcher Award DE200100245.
\end{acknowledgements}

















% References
\bibliography{uai2024-template}

\newpage

\onecolumn

\title{Functional Wasserstein Bridge Inference for Bayesian Deep Learning\\(Supplementary Material)}
\maketitle


\appendix

\section{Further Background}\label{apd:A}
\paragraph{Pathologies for parameter-space priors} Figure \ref{fig:horizon_prior} shows function samples generated from three BNNs with a Gaussian prior 
$\mathcal{N} (0, 1)$ over network weights. As the depth increases, the function samples tend to become more horizontal, which can lead to problematic posterior inference.

\begin{figure}[h]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
\centering
% \floatconts
  {
    \subfigure{\label{fig:2_layer_prior}%
      \includegraphics[width=0.32\linewidth]{UAI2024/horizon_priors/2_50_samples}}%
    % \qquad
    \subfigure{\label{fig:4_layer_prior}%
      \includegraphics[width=0.32\linewidth]{UAI2024/horizon_priors/4_50_samples}}
    % \qquad
    \subfigure{\label{fig:8_layer_prior}%
      \includegraphics[width=0.32\linewidth]{UAI2024/horizon_priors/8_50_samples}}
 }
\caption{Function samples from three fully connected BNNs with different network architectures: there are 2, 4, and 8 hidden layers, respectively, and each layer has 50 units. The prior distribution for weights is $\mathcal{N} (0, 1)$ and the activation is tanh.}
\label{fig:horizon_prior}
\end{figure}

\paragraph{Wasserstein distance} 
The \textit{Wasserstein distance} \citep{kantorovich1960mathematical, villani2003topics} is a rigorously defined distance metric on probability measures satisfying non-negativity, symmetry and the triangle inequality \citep{panaretos2019statistical}. It was originally proposed for the optimal transport problem and has become popular in the machine learning community in recent years \citep{arjovsky2017wasserstein}. Suppose $(\mathcal{P}, \|\cdot\|)$ is a Polish space. The $p$-Wasserstein distance between probability measures $\mu$ and $\nu$ on $(\mathcal{P}, \|\cdot\|)$ is defined as
\begin{equation}
    W_p(\mu, \nu)=\left(\inf _{\gamma \in \Gamma(\mu, \nu)} \int_{\mathcal{P} \times \mathcal{P}}\|x-y\|^p \mathrm{~d} \gamma(x, y)\right)^{1 / p},
\end{equation}
where $\Gamma(\mu, \nu)$ is the set of joint measures (couplings) $\gamma$ on $\mathcal{P} \times \mathcal{P}$ with marginals $\mu$ and $\nu$.
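
As a worked special case (a standard closed form, stated here only for illustration), for two univariate Gaussian measures $\mu = \mathcal{N}(m_1, \sigma_1^2)$ and $\nu = \mathcal{N}(m_2, \sigma_2^2)$ on $\mathbb{R}$ with the Euclidean norm, the 2-Wasserstein distance satisfies
\begin{equation*}
    W_2^2(\mu, \nu) = (m_1 - m_2)^2 + (\sigma_1 - \sigma_2)^2 ,
\end{equation*}
so that, for example, $W_2(\mathcal{N}(0, 1), \mathcal{N}(3, 5^2)) = \sqrt{3^2 + 4^2} = 5$. For Gaussians with diagonal covariances, the squared distance is the sum of these terms over coordinates.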

% \paragraph{Gaussian Processes}
% The measurable mapping $g: \Omega\times \mathcal{T}\rightarrow\mathcal{R}$ defined on probability space $(\Omega, \mathcal{G}, P)$ with compact index set $\mathcal{T}$ is a Gaussian process (GP) if and only if random vector $g(T)=((g(t_1), g(t_2), ..., g(t_n))$ is multivariate Gaussian for marginals over any finite index sets $T=\left\{t_i\right\}_{i=1}^{n}\subset \mathcal{T}$. A GP is entirely governed by its mean function $m(t)=\mathbb{E}[g(t)]$ and covariance (kernel) function $k(t, t')=\mathbb{E}[(g(t)-m(t))(g(t')-g(t'))]$ denoted by $g \sim \mathcal{GP}(m, k)$, $t, t'\in \mathcal{T}$ \citep{rasmussen:williams:2006}.

% \paragraph{Wasserstein distance}
% The p-Wasserstein is defined as
% \begin{equation}
%     W_p(\mu, \nu)=\left(\inf _{\gamma \in \Gamma(\mu, \nu)} \int_{\mathcal{P} \times \mathcal{P}}\|x-y\|^p \mathrm{~d} \gamma(x, y)\right)^{1 / p},
% \end{equation}
% where $(\mathcal{P}, \|\cdot\|)$ is a Polish space, $\mu$, $\nu\in (\mathcal{P}, \|\cdot\|)$ 
%  are two probability measures, $\Gamma(\mu, \nu)$ is the set of joint measures or coupling  $\gamma$ with marginals $\mu$ and $\nu$ on $\mathcal{P} \times \mathcal{P}$. \citet{wild2022generalized2} proposed a functional variational inference method based on the 2-Wasserstein distance between two Gaussian measures. \citet{tran2022all} proposed an algorithm to match a functional prior with a target GP prior using the dual representation of 1-Wasserstein distance as
%  \begin{equation}
%     W_1(\mu, \nu)=\sup _{\|\phi\| \leq 1} \mathbb{E}_{x \sim \mu} \phi(x)-\mathbb{E}_{y \sim \nu} \phi(y),
% \end{equation}
% where $\phi(\cdot): \mathcal{P}\rightarrow \mathbb{R}$ is a 1-Lipschitz continuous function.\\

\paragraph{The law of cosines for the KL divergence}
For two probability measures $p$ and $q$ on $(\mathcal{P}, \|\cdot\|)$, the law of cosines for the KL divergence between $p$ and $q$ is given by \citep{belavkin2013law}:
\begin{equation}
    \begin{aligned}
    \mathrm{KL}[p\| q] &= \mathrm{KL}[p\| r] + \mathrm{KL}[r\| q] - \int \log\frac{dq(x)}{dr(x)}[dp(x)-dr(x)]\\
    &=\mathrm{KL}[p\| r] -\mathrm{KL}[q\| r] - \int \log\frac{dq(x)}{dr(x)}[dp(x)-dq(x)].
    \end{aligned}
    \label{coskl}
\end{equation}
where $r$ is the reference measure. Consider the 1-Wasserstein distance between $p$ and $q$ as
\begin{equation}
    W_1(p, q) := \sup \{\mathbb{E}_p\{f\} - \mathbb{E}_q\{g\}: f(x)-g(y)\leq c(x, y)\}.
\end{equation}
Suppose the real functions $f(x)$ and $g(x)$ satisfy the additional constraints:
\begin{equation}
    \beta f(x) = \nabla \mathrm{KL}[p\|r] = \log\frac{dp(x)}{dr(x)}, \quad \beta \geq 0
\end{equation}
\begin{equation}
    \alpha g(x) = \nabla \mathrm{KL}[q\|r] = \log\frac{dq(x)}{dr(x)}, \quad \alpha \geq 0
\end{equation}
Thus, $\beta f$ and $\alpha g$ are the gradients of the divergences $\mathrm{KL}[p\|r]$ and $\mathrm{KL}[q\|r]$, respectively, which means that the probability measures $p$ and $q$ have the following exponential representations:
\begin{equation}
    dp(x) = e^{\beta f(x)-\kappa[\beta f]}d r(x)
\end{equation}
\begin{equation}
    dq(x) = e^{\alpha g(x)-\kappa[\alpha g]}d r(x)
\end{equation}
where $\kappa[(\cdot)] = \log \int e^{(\cdot)} dr(x)$ is the logarithm of the normalizing constant. These representations give
\begin{equation}
    \frac{d}{d \beta} \kappa[\beta f] = \mathbb{E}_{p}\{f\}, \quad \mathrm{KL}[p\|r] = \beta \mathbb{E}_{p}\{f\} - \kappa [\beta f]
\end{equation}
\begin{equation}
    \frac{d}{d \alpha} \kappa[\alpha g] = \mathbb{E}_{q}\{g\}, \quad \mathrm{KL}[q\|r] = \alpha \mathbb{E}_{q}\{g\} - \kappa [\alpha g]
\end{equation}
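The first identity in each line follows by differentiating $\kappa$, assuming that differentiation and integration can be interchanged; for $p$,
\begin{equation*}
    \frac{d}{d \beta} \kappa[\beta f] = \frac{\int f(x) e^{\beta f(x)} dr(x)}{\int e^{\beta f(x)} dr(x)} = \int f(x) e^{\beta f(x) - \kappa[\beta f]} dr(x) = \mathbb{E}_{p}\{f\},
\end{equation*}
and the second follows by integrating $\log\frac{dp(x)}{dr(x)} = \beta f(x) - \kappa[\beta f]$ against $p$. The case of $q$ is analogous.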
Substituting these formulas into \eqref{coskl}, we obtain

\begin{equation}
    \mathrm{KL}[p\|q] = \beta \mathbb{E}_{p}\{f\} - \alpha \mathbb{E}_{q}\{g\} -(\kappa [\beta f] - \kappa [\alpha g]) - \alpha \int g(x) [dp(x)- dq(x)]
\end{equation}
According to Theorem 2 in \citet{belavkin2018relation}, assuming that the Lagrange multipliers satisfy $\alpha = \beta =1$, we then have
\begin{equation}
\begin{aligned}
    \mathrm{KL}[p\|q] &= \mathbb{E}_{p}\{f\} - \mathbb{E}_{q}\{g\} -(\kappa [f] - \kappa [g]) - \int g(x) [dp(x)- dq(x)]\\
    & = W_1(p, q) - (\kappa [f] - \kappa [g]) - \int g(x) [dp(x)- dq(x)]
\end{aligned}
\end{equation}

\paragraph{Talagrand inequality} As proved by \citet{otto2000generalization}, a probability measure $q$ satisfies a Talagrand inequality with constant $\rho$ if, for all probability measures $p$ that are absolutely continuous w.r.t. $q$ and have finite moments of order 2,
\begin{equation}
    W_1(p, q) \leq W_2 (p, q) \leq \sqrt{\frac{2 \mathrm{KL}[p\|q]}{\rho}}
\end{equation}
where the first inequality can be proved by the Cauchy-Schwarz inequality.
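
Both inequalities can be checked on a simple case (the measures below are chosen for illustration and are not part of the cited result). The first holds because, for any coupling $\gamma \in \Gamma(p, q)$,
\begin{equation*}
    \int \|x-y\| \, d\gamma(x,y) \leq \left(\int \|x-y\|^2 \, d\gamma(x,y)\right)^{1/2},
\end{equation*}
and taking the infimum over $\gamma$ on both sides gives $W_1(p, q) \leq W_2(p, q)$. For the second, the standard Gaussian $q = \mathcal{N}(0, I)$ is known to satisfy the inequality with $\rho = 1$; taking $p = \mathcal{N}(m, I)$ gives
\begin{equation*}
    W_2(p,q) = \|m\|, \qquad \mathrm{KL}[p\|q] = \tfrac{1}{2}\|m\|^2, \qquad \sqrt{\frac{2\,\mathrm{KL}[p\|q]}{\rho}} = \|m\|,
\end{equation*}
so the bound is attained with equality in this case.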

% \paragraph{Periodic kernel * RBF kernel}

% \begin{equation}
% k_{\text {Periodic * RBF }}\left(\mathbf{x}, \mathbf{x}^{\prime}\right)=\exp \left(-2 \frac{\sin ^2\left(\frac{\pi}{p}\left(|\mathbf{x}-\mathbf{x}^{\prime}|\right)\right)}{l_p}\right) * \sigma_r^2 \exp \left(-\frac{\left(\mathbf{x}-\mathbf{x}^{\prime}\right)^2}{2 l_r^2}\right)
% \end{equation}
% where $p$ is the period length parameter, $\sigma_r$ is the scaling factor, $\l_p$ and $l_r$ are lengthscale parameters.

\section{Pseudocode of FPi-VI and FWBI}\label{apd:B}

\begin{algorithm} 
	\caption{ Functional Prior-induced Variational Inference (FPi-VI)} 
	\label{alg:fpi_vi} 
	\begin{algorithmic}[1]
		\REQUIRE Dataset $\mathcal{D}=\left\{\mathbf{X}_\mathcal{D}, \mathbf{Y}_\mathcal{D}\right\}$, minibatch $\mathcal{B} = \{x_j, y_j\}_{j=1}^{M} \subset \mathcal{D}$, functional prior $p_0(f)$
       \STATE Initialize $\mathbf{w}\sim\mathcal{N}(0, 1)$, $\mathbf{w}_{\boldsymbol{b}}\sim\mathcal{N}(0, 1)$, reparameterize $\mathbf{w} = \mathbf{\mu}_q + \mathbf{\Sigma}_q \odot \mathbf{\epsilon}$, $\mathbf{w}_{\boldsymbol{b}} = \mathbf{\mu}_b + \mathbf{\Sigma}_b \odot \mathbf{\epsilon}$ with $\mathbf{\epsilon}\sim\mathcal{N}(0, 1)$, $\boldsymbol{\theta_q} := \{\mathbf{\mu}_q, \mathbf{\Sigma}_q\}, \boldsymbol{\theta_b} := \{\mathbf{\mu}_b, \mathbf{\Sigma}_b\}$
		\WHILE{$\boldsymbol{\theta_b}$ not converged} 
		\STATE draw measurement set $\mathbf{X}_{\mathcal{M}}$ randomly from input domain 
   
        \STATE draw functional prior functions $f(\mathbf{X}_{\mathcal{M}})\sim p_0(f)$ at $\mathbf{X}_{\mathcal{M}}$ 
   
        \STATE draw bridging distribution functions $g(\mathbf{X}_{\mathcal{M}})\sim p(g; \boldsymbol{\theta_b})$ at $\mathbf{X}_{\mathcal{M}}$
   
       \STATE calculate $W_1\left(p(g(\mathbf{X}_{\mathcal{M}}; \mathbf{\mu}_b, \mathbf{\Sigma}_b)), p_0(f(\mathbf{X}_{\mathcal{M}}))\right)$ using Equation \ref{eq:w1}
   
      \STATE $\boldsymbol{\theta_b}$ $\leftarrow$ Optimizer($\boldsymbol{\theta_b}$, $W_1$)
   
		\ENDWHILE

     \STATE Freeze $\boldsymbol{\theta_b}^* := \{\mathbf{\mu}_b^*, \mathbf{\Sigma}^*_b\}$
     
     \WHILE{$\boldsymbol{\theta_q}$ not converged}
     
     \STATE calculate $\mathcal{L} = -\frac{1}{M} \sum_{(x, y) \in \mathcal{B}} \mathbb{E}_{q(\mathbf{w}; \boldsymbol{\theta_q})} \left[\log p(y \mid x, \mathbf{w}; \boldsymbol{\theta_q})\right] + \lambda W_2(q(\mathbf{w}; \boldsymbol{\theta_q}), p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*))$ using Equation \ref{eq:w2}

     \STATE $\boldsymbol{\theta_q}$ $\leftarrow$ Optimizer($\boldsymbol{\theta_q}$, $\mathcal{L}$)

     \ENDWHILE
     
	\end{algorithmic} 
\end{algorithm}

\begin{algorithm} 
	\caption{Functional Wasserstein Bridge Inference (FWBI)} 
	\label{alg:fwbi} 
	\begin{algorithmic}[1]
		\REQUIRE Dataset $\mathcal{D}=\left\{\mathbf{X}_\mathcal{D}, \mathbf{Y}_\mathcal{D}\right\}$, minibatch $\mathcal{B} = \{x_j, y_j\}_{j=1}^{M} \subset \mathcal{D}$, functional prior $p_0(f)$
       \STATE Initialize $\mathbf{w}\sim\mathcal{N}(0, 1)$, $\mathbf{w}_{\boldsymbol{b}}\sim\mathcal{N}(0, 1)$, reparameterize $\mathbf{w} = \mathbf{\mu}_q + \mathbf{\Sigma}_q \odot \mathbf{\epsilon}$, $\mathbf{w}_{\boldsymbol{b}} = \mathbf{\mu}_b + \mathbf{\Sigma}_b \odot \mathbf{\epsilon}$ with $\mathbf{\epsilon}\sim\mathcal{N}(0, 1)$, $\boldsymbol{\theta_q} := \{\mathbf{\mu}_q, \mathbf{\Sigma}_q\}, \boldsymbol{\theta_b} := \{\mathbf{\mu}_b, \mathbf{\Sigma}_b\}$
		\WHILE{$\boldsymbol{\theta_b}$, $\boldsymbol{\theta_q}$ not converged} 
		\STATE draw measurement set $\mathbf{X}_{\mathcal{M}}$ randomly from input domain 
   
        \STATE draw functional prior functions $f(\mathbf{X}_{\mathcal{M}})\sim p_0(f)$ at $\mathbf{X}_{\mathcal{M}}$ 
   
        \STATE draw bridging distribution functions $g(\mathbf{X}_{\mathcal{M}})\sim p(g; \boldsymbol{\theta_b})$ at $\mathbf{X}_{\mathcal{M}}$
   
       \STATE $\mathcal{L} = -\frac{1}{M} \sum_{(x, y) \in \mathcal{B}} \mathbb{E}_{q(f; \boldsymbol{\theta_q})} \left[\log p(y \mid f(x;\mathbf{\mu}_q, \mathbf{\Sigma}_q))\right] + \lambda_1 W_2(q(\mathbf{w}; \mathbf{\mu}_q, \mathbf{\Sigma}_q), p(\mathbf{w}_{\boldsymbol{b}}; \mathbf{\mu}_b, \mathbf{\Sigma}_b)) + \lambda_2 W_1\left(p(g(\mathbf{X}_{\mathcal{M}}; \mathbf{\mu}_b, \mathbf{\Sigma}_b)), p_0(f(\mathbf{X}_{\mathcal{M}}))\right) + \lambda_3M_2\left(p(g(\mathbf{X}_{\mathcal{M}}; \mathbf{\mu}_b, \mathbf{\Sigma}_b)), p_0(f(\mathbf{X}_{\mathcal{M}}))\right)$ 
   
      \STATE $\boldsymbol{\theta_b}$, $\boldsymbol{\theta_q}$ $\leftarrow$ Optimizer($\boldsymbol{\theta_b}$, $\boldsymbol{\theta_q}$, $\mathcal{L}$)
   
		\ENDWHILE 
	\end{algorithmic} 
\end{algorithm}

\section{Proof of Theoretical Results}
\label{apd:C}

To analyze the theoretical properties of FWBI for posterior variational inference, we further derive a corresponding functional variational objective as follows:
\begin{equation}
\mathcal{L}^{W}:= \mathbb{E}\left[\log p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}; \boldsymbol{\theta_q}))\right] - \lambda_1W_2(q(f; \boldsymbol{\theta_q}), p(g; \boldsymbol{\theta_b})) - \lambda_2W_1\left(p(g; \boldsymbol{\theta_b}), p_0(f)\right)\label{eq:ifbnn_elbo},
\end{equation}
where $q(f; \boldsymbol{\theta_q})$ is the variational distribution over functions induced by approximate posterior $q(\mathbf{w}; \boldsymbol{\theta_q})$. $W_2(q(f; \boldsymbol{\theta_q}), p(g; \boldsymbol{\theta_b}))$ is calculated by corresponding $W_2(q(\mathbf{w}; \boldsymbol{\theta_q}), $ $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}))$. As a variational Bayesian objective, it is worthwhile to explore whether this new variational objective based on the Wasserstein bridge is still a lower bound of the log marginal likelihood. Based on the law of cosines for the KL divergence and Talagrand inequality of probability measures (Appendix \ref{apd:A}), we derive the following Proposition \ref{thrm:1}:


\begin{theorem}
For the functional variational objective $\mathcal{L}^{W}$ derived from FWBI based on the Wasserstein bridge, we have
% \[ \log p(\mathcal{D}) \geq \mathcal{L}^{W} \geq \mathcal{L}^{KL} \]
\[ \log p(\mathcal{D}) \geq \mathcal{L}^{W} \]
where $\log p(\mathcal{D})$ is the log model evidence.
\label{thrm:1}
\end{theorem}

\textit{Proof.} Since the 1-Wasserstein distance between two probability measures is no greater than the 2-Wasserstein distance (by the Cauchy-Schwarz inequality), and assuming that $\lambda_1 = \lambda_2 = 1$, we have:
\begin{equation}
\begin{aligned}
    \mathcal{L}^{W} &:= \mathbb{E} [\log p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}; \boldsymbol{\theta_q}))] - \lambda_1 W_2(q(f; \boldsymbol{\theta_q}), p(g; \boldsymbol{\theta_b})) - \lambda_2 W_1(p(g; \boldsymbol{\theta_b}), p_0(f))\\
    & \leq \mathbb{E} [\log p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}; \boldsymbol{\theta_q}))] - \lambda_1 (\mathrm{KL}[q(f; \boldsymbol{\theta_q})\|p(g; \boldsymbol{\theta_b})] + \int g_1(f) [dq(f; \boldsymbol{\theta_q})- dp(g; \boldsymbol{\theta_b})]) - \\ & \lambda_2 (\mathrm{KL}[p(g; \boldsymbol{\theta_b})\|p_0(f)] + 
    \int g_2(f) [dp(g; \boldsymbol{\theta_b})- dp_0(f)])\\
    & = \mathbb{E} [\log p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}; \boldsymbol{\theta_q}))] - (\lambda_1 (\mathrm{KL}[q(f; \boldsymbol{\theta_q})\|p(g; \boldsymbol{\theta_b})])  + \lambda_2 (\mathrm{KL}[p(g; \boldsymbol{\theta_b})\|p_0(f)])) - \\ 
    & (\lambda_1 \int g_1(f) [dq(f; \boldsymbol{\theta_q})- dp(g; \boldsymbol{\theta_b})] + \lambda_2 \int g_2(f) [dp(g; \boldsymbol{\theta_b})- dp_0(f)])\\
    & = \mathbb{E} [\log p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}; \boldsymbol{\theta_q}))] - \mathrm{KL}[q(f; \boldsymbol{\theta_q})\|p_0(f)] - (\int \log \frac{dp_0(f)}{dp(g; \boldsymbol{\theta_b})}[dq(f; \boldsymbol{\theta_q}) - dp(g; \boldsymbol{\theta_b})] + \\  & \int \log \frac{dp(g; \boldsymbol{\theta_b})}{dr(f)}[dq(f; \boldsymbol{\theta_q}) - dp(g; \boldsymbol{\theta_b})] + \int \log \frac{dp_0(f)}{dr(f)}[dp(g; \boldsymbol{\theta_b}) - dp_0(f)]) \\
    & =  \mathbb{E} [\log p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}; \boldsymbol{\theta_q}))] - \mathrm{KL}[q(f; \boldsymbol{\theta_q})\|p_0(f)] - (\int \log \frac{dp_0(f)}{dr(f)}[dq(f; \boldsymbol{\theta_q}) - dp(g; \boldsymbol{\theta_b})] +\\  &\int \log \frac{dp_0(f)}{dr(f)}[dp(g; \boldsymbol{\theta_b}) - dp_0(f)])\\
    & = \mathbb{E} [\log p(\mathbf{Y}_\mathcal{D} \mid f(\mathbf{X}_\mathcal{D}; \boldsymbol{\theta_q}))] - \mathrm{KL}[q(f; \boldsymbol{\theta_q})\|p_0(f)] - \int \log \frac{dp_0(f)}{dr(f)}[dq(f; \boldsymbol{\theta_q}) - dp_0(f)]\\
    & = \mathcal{L}^{KL} - \int \log \frac{dp_0(f)}{dr(f)}[dq(f; \boldsymbol{\theta_q}) - dp_0(f)]\\
    & = \mathcal{L}^{KL} - \int \log \frac{dp_0(f)}{-dq(f; \boldsymbol{\theta_q})}[dq(f; \boldsymbol{\theta_q}) - dp_0(f)]\\
    & =\mathcal{L}^{KL} - (\mathrm{KL}[q(f; \boldsymbol{\theta_q})\| p_0(f)] + \mathrm{KL}[p_0(f)\| q(f; \boldsymbol{\theta_q})])\\
    & \leq \mathcal{L}^{KL}\\
    & \leq \log p(\mathcal{D})
\end{aligned}
\end{equation}
where we assume that the reference measure $dr(f) = -dq(f; \boldsymbol{\theta_q})$ in the law of cosines for the KL divergence. 

Proposition \ref{thrm:1} shows that $\mathcal{L}^{W}$ is a valid variational objective, since it is a lower bound of the log model evidence.


% For the gap toy example, we used BNNs with two hidden layers each with 100 hidden units and iterated 10000 epochs. For UCI regression, we used 2 hidden layer of 10 hidden units BNNs and trained for 2000 epochs. In all experiments, we first pre-train GP hyperparameters for 100 epochs on the uniformly sampled test data set for methods which use GP priors. Finite measurement points for functional approaches are randomly sampled from training data and inducing points sampled from the domain where one is to make predictions in toy examples. For all experiments we used tanh activation and Adam optimizer. The convergence processes of 1-Wasserstein distance and 2-Wasserstein of FWBI in two toy examples are shown in \Figref{fig:g3conv} and \Figref{fig:gapconv}.

% \begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{0cm}
% % \floatconts
% \centering
%   {
%     \subfigure[W1-RBF]{\label{fig:w1_pr}%
%       \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/ifbnn/w1_rbf}}%
%     % \qquad
%     \subfigure[W1-Matern]{\label{fig:w1_m}%
%       \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/ifbnn/w1_matern}}
%     % \qquad
%     \subfigure[W1-Linear]{\label{fig:w1_l}%
%       \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/ifbnn/w1_linear}} 
%     % \quad
%     \subfigure[W2-RBF]{\label{fig:w2_pr}%
%       \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/ifbnn/w2_rbf}}%
%     % \qquad
%     \subfigure[W2-Matern]{\label{fig:w2_m}%
%       \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/ifbnn/w2_matern}}
%     % \qquad
%     \subfigure[W2-Linear]{\label{fig:w2_l}%
%       \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/ifbnn/w2_linear}} 
%  }
% \caption{Convergence of Wasserstein distances in the training of FWBI for polynomial extrapolation.}
% \label{fig:g3conv}
% \end{figure}



\section{Experimental Setting}\label{apd:D}

\textbf{Polynomial curve extrapolation} In this experiment, we use $2\times100$ fully connected tanh BNNs as variational posteriors for all models. The functional GP priors are pre-trained on the 20 training points for 10000 epochs. For FWBI, FBNN and GWI, we also sample 40 inducing points from $[-1, 1]$ as marginal measurement points. All methods are trained for 10000 epochs.

\textbf{Periodic curve extrapolation} In this experiment, we use $2\times100$ fully connected tanh BNNs as variational posteriors for all models. The functional GP priors are pre-trained on the 30 training points for 10000 epochs. The marginal measurement set for FWBI is randomly sampled from all observations together with 30 inducing points randomly sampled from $[-10, 10]$. All inference methods are trained for 10000 epochs.

\textbf{Multivariate regression on UCI datasets} We choose BNN posteriors with two hidden layers (input-10-10-output). The GP prior uses an RBF kernel and is pre-trained on the test dataset for 100 epochs. The number of iterations for all models is 2000.

\textbf{Contextual bandits} The variational posteriors are fully connected tanh BNNs with two hidden layers (input-100-100-output), and the GP prior is pre-trained on 1000 randomly sampled points from the training data. All models are trained using the last 4096 input-output tuples in the training buffer with a batch size of 64 and training frequency 64 for each iteration. All inference methods are trained for 10000 epochs.

\textbf{Classification and OOD detection} For all models in this experiment, the variational posteriors are fully connected BNNs with 2 hidden layers, each with 800 units. The functional prior is a Dirichlet-based GP designed for classification tasks \citep{milios2018dirichlet} and is pre-trained on the test dataset for 500 epochs. All inference methods are trained for 600 epochs with a batch size of 125.

\section{Further Results}\label{apd:E}

\subsection{Demonstration for Potential Sub-optimal FPi-VI}\label{apd:e1}

\begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
\centering
% \floatconts
  {
    \subfigure[]{\label{fig:fpi_(0, 1)}%
      \includegraphics[width=0.32\linewidth]{UAI2024/isotropy/fpi_(0, 1)}}
    % \qquad
    \subfigure[]{\label{fig:fpi_(1, 3)}%
      \includegraphics[width=0.32\linewidth]{UAI2024/isotropy/fpi_(1, 3)}}
    % \qquad
    \subfigure[]{\label{fig:fpi_(2, 2)}%
      \includegraphics[width=0.32\linewidth]{UAI2024/isotropy/fpi_(2, 2)}}
    % \qquad
    \subfigure[]{\label{fig:fwbi_(0, 1)}%
      \includegraphics[width=0.32\linewidth]{UAI2024/isotropy/fwbi_(0, 1)}}
    % \qquad
    % \qquad
    \subfigure[]{\label{fig:iso_imag}%
      \includegraphics[width=0.36\linewidth]{UAI2024/isotropy/iso_imag}}

 }
\caption{Explanation for potential sub-optimal solution in FPi-VI.}
\label{fig:isotropy}
\end{figure}

Due to the isotropy of the 1-Wasserstein distance, there are infinitely many candidates $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)$ at exactly the same distance from a given functional prior, and FPi-VI effectively picks one of them at random in the first distilling step, as illustrated in Figure \ref{fig:iso_imag}. This randomness causes large fluctuations in the subsequent inference performance. To demonstrate the problem, we first train FWBI on a 1-D toy example (results are shown in Figure \ref{fig:fwbi_(0, 1)}) and evaluate the output distance of the learned bridging distribution to the GP prior, denoted as $d=0.0800$. Then, we use the 1-Wasserstein distance as the loss to (randomly) find three other $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)$ with the same distance $d$ to the GP prior\footnote{In the implementation, we allow a small deviation, i.e., $d+\epsilon$ with $\epsilon < 0.001$.}, and we run the second step of FPi-VI with each of the three $p(\mathbf{w}_{\boldsymbol{b}}; \boldsymbol{\theta_b}^*)$. The results are plotted in Figures \ref{fig:fpi_(0, 1)}, \ref{fig:fpi_(1, 3)} and \ref{fig:fpi_(2, 2)}. We observe that 1) the performance of the three FPi-VI models differs significantly even though their bridging distributions all have the same distance to the prior, which demonstrates that the large fluctuations in the first step of FPi-VI may harm the final posterior inference; 2) FWBI performs much better than the other three, which shows that FWBI can automatically learn a reasonably good bridging distribution from the infinitely many candidates.



\subsection{More Baseline Results for Toy Example}
\label{apd:e2}


\begin{figure}[t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
\centering
% \floatconts
  {
   \subfigure[GWI-RBF]{\label{fig:gw_rbf}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/gwi/rbf}}%
    % \qquad
   \subfigure[GWI-Matern]{\label{fig:gw_m}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/gwi/matern}}%
    % \qquad
    \subfigure[GWI-Linear]{\label{fig:gw_l}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/gwi/linear}}% 
     % \qquad
     
     \subfigure[MCMC]{\label{fig:mc_g3}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/psgld}}%
       % \qquad
     \subfigure[Laplace]{\label{fig:lap_g3}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_toy/lapalce_posterior}}%
 }
 \caption{More baseline results for polynomial extrapolation example.}
 \label{fig:more_g3}
\end{figure}


\begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
\centering
% \floatconts
  {
   \subfigure[GP-Matern]{\label{fig:gp_m}%
      \includegraphics[width=0.45\linewidth]{UAI2024/gp_compare/gp_matern}}%
    \qquad
    \subfigure[GP-RBF]{\label{fig:gp_rbf}%
      \includegraphics[width=0.45\linewidth]{UAI2024/gp_compare/gp_rbf}}%
     % \qquad
     
   \subfigure[FWBI-Matern]{\label{fig:fwbi_m}%
      \includegraphics[width=0.45\linewidth]{UAI2024/g3_toy/matern_g3}}%
     \qquad
     \subfigure[FWBI-RBF]{\label{fig:fwbi_rbf}%
      \includegraphics[width=0.45\linewidth]{UAI2024/g3_toy/rbf_g3}}%
 }
 \caption{Comparisons of posteriors of GP and FWBI.}
 \label{fig:gp_comp}
\end{figure}



Figure \ref{fig:more_g3} shows more baseline results for the toy example. Figures \ref{fig:gw_rbf}, \ref{fig:gw_m} and \ref{fig:gw_l} show the approximate posteriors of GWI \citep{wild2022generalized} using three different kernels, corresponding to those used in Figure \ref{fig:g3}. Figure \ref{fig:mc_g3} shows the mean of samples from the MCMC posterior obtained with Langevin dynamics, and Figure \ref{fig:lap_g3} shows the posterior from the Laplace approximation \citep{daxberger2021laplace}. Compared to these baselines, FWBI still shows a stronger ability to recover the main trend of the target function and provides more accurate uncertainty estimation in the unseen region.




\begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
\centering
% \floatconts
  {
    \subfigure[]{\label{fig:cali_1}%
      \includegraphics[width=0.45\linewidth]{UAI2024/calibration/calibration_curve_all}}%
    \qquad
    \subfigure[]{\label{fig:cali_2}%
      \includegraphics[width=0.45\linewidth]{UAI2024/calibration/calibration_curve_all2}}
    % \qquad
 }
 \caption{Calibration curves for two toy examples. The gray dashed line indicates perfect calibration.}
 \label{fig:cali_curve}
\end{figure}


\begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{0cm}
% \floatconts
\centering
  {
    \subfigure[W1-RBF]{\label{fig:w1_pr}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_conv/w1_rbf}}%
    % \qquad
    \subfigure[W1-Matern]{\label{fig:w1_m}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_conv/w1_matern}}
    % \qquad
    \subfigure[W1-Linear]{\label{fig:w1_l}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_conv/w1_linear}} 
    % \quad
    \subfigure[W2-RBF]{\label{fig:w2_pr}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_conv/w2_rbf}}%
    % \qquad
    \subfigure[W2-Matern]{\label{fig:w2_m}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_conv/w2_matern}}
    % \qquad
    \subfigure[W2-Linear]{\label{fig:w2_l}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_conv/w2_linear}} 
    % \quad
    \subfigure[Training loss-RBF]{\label{fig:train_loss_rbf}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_conv/train_loss_rbf}}%
    % \qquad
    \subfigure[Training loss-Matern]{\label{fig:train_loss_m}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_conv/train_loss_matern}}
    % \qquad
    \subfigure[Training loss-Linear]{\label{fig:train_loss_l}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g3_conv/train_loss_linear}} 
 }
\caption{Convergence of Wasserstein bridge in the training of FWBI for polynomial extrapolation.}
\label{fig:g3conv}
\end{figure}


\begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{0cm}
% \floatconts
\centering
  {
    \subfigure[W1-Periodic]{\label{fig:g2_w1_p}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g2_toy/w1_p}}%
    % \qquad
    \subfigure[W1-Matern]{\label{fig:g2_w1_m}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g2_toy/w1_m}}
    % \qquad
    \subfigure[W1-Linear]{\label{fig:g2_w1_l}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g2_toy/w1_linear}} 
    % \quad
    \subfigure[W2-Periodic]{\label{fig:w2_p}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g2_toy/w2_p}}%
    % \qquad
    \subfigure[W2-Matern]{\label{fig:w2_m}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g2_toy/w1_m}}
    % \qquad
    \subfigure[W2-Linear]{\label{fig:w2_l}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g2_toy/w2_linear}} 
     % \quad
    \subfigure[Training loss-Periodic]{\label{fig:train_loss_p}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g2_toy/train_loss_p}}%
    % \qquad
    \subfigure[Training loss-Matern]{\label{fig:train_loss_m}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g2_toy/train_loss_m}}
    % \qquad
    \subfigure[Training loss-Linear]{\label{fig:train_loss_l}%
      \includegraphics[width=0.3\linewidth]{UAI2024/g2_toy/train_loss_linear}} 
 }
\caption{Convergence of Wasserstein bridge in the training of FWBI for periodic extrapolation.}
\label{fig:g2conv}
\end{figure}


\subsection{Detailed Comparisons with GP Posteriors for Toy Example}\label{apd:e3}



In this section, we give a detailed comparison between the FWBI posterior and the GP posterior. We first pre-train GP priors with the Matern kernel and the RBF kernel on 20 training data points; the corresponding GP posteriors and FWBI posteriors are shown in Figure \ref{fig:gp_comp}. In Figure \ref{fig:gp_m}, the GP posterior shows excessive uncertainty in the well-fitted training region, which is barely distinguishable from the uncertainty in the unseen region. In contrast, our model achieves more reasonable uncertainty estimates. As shown in Figure \ref{fig:fwbi_m} for the FWBI posterior, in the well-fitted intervals $[-0.75, -0.25]$ and $[0.25, 0.75]$ containing training points, the uncertainty is significantly smaller than in the three regions without data. For the RBF kernel (Figures \ref{fig:gp_rbf} and \ref{fig:fwbi_rbf}), both models show good uncertainty estimation. However, our FWBI posterior demonstrates better fitting ability; e.g., in the middle unseen region $[-0.25, 0.25]$, it better recovers the trend of the objective function.

\subsection{Calibration Curves for Toy Examples}\label{apd:e4}

Following \citet{kuleshov2018accurate}, we use calibration curves for regression to measure how well the predicted probabilities match the observed frequencies. Figure \ref{fig:cali_curve} plots the calibration curves of all methods for the two toy examples, where the horizontal and vertical axes are the predicted cumulative distribution function (CDF) and the empirical CDF, respectively. The 45-degree diagonal represents perfect calibration, and FWBI is the best calibrated among all methods.
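
Concretely, the calibration curve of \citet{kuleshov2018accurate} can be sketched as follows; the symbols $T$, $F_t$, $p_j$, and $\hat{p}_j$ are introduced here for illustration only. Given $T$ test points $(x_t, y_t)$ with predicted CDFs $F_t$ and confidence levels $0 \le p_1 < \dots < p_m \le 1$, the empirical frequency at level $p_j$ is
\begin{equation*}
\hat{p}_j = \frac{\left|\left\{ y_t : F_t(y_t) \le p_j,\ t = 1, \dots, T \right\}\right|}{T},
\end{equation*}
and the curve plots the pairs $\{(p_j, \hat{p}_j)\}_{j=1}^{m}$. For a well-calibrated model, $\hat{p}_j$ is close to $p_j$ for all $j$, which corresponds to the diagonal.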



\subsection{Convergence of the Wasserstein Bridge in FWBI}\label{apd:e5}


The convergence of the 1-Wasserstein and 2-Wasserstein distances of FWBI in the two toy examples is shown in Figures \ref{fig:g3conv} and \ref{fig:g2conv}. Both distances in our Wasserstein bridge converge very quickly.


\subsection{Impact of Functional Properties of Prior on FWBI}\label{apd:e6}



\begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
\centering
% \floatconts
  {
    \subfigure[prior-Matern0.5]{\label{fig:prior_m0.5}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/matern/matern0.5_prior}}%
    \qquad
    \subfigure[prior-Matern2.5]{\label{fig:prior_m2.5}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/matern/matern2.5_prior}}%
    \qquad
   \subfigure[posterior-Matern0.5]{\label{fig:posterior_m0.5}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/matern/matern0.5_posterior}}%
     \qquad  
    \subfigure[posterior-Matern2.5]{\label{fig:posterior_m2.5}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/matern/matern2.5_posterior}}%
    % \qquad
 }
\caption{The effect of prior smoothness on the FWBI posterior.}
\label{fig:prior_smooth}
\end{figure}


\begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
\centering
% \floatconts
%   {
    \subfigure[Naive prior]{\label{fig:prior_p}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/prior_naive}}%
    \qquad
    \subfigure[Naive posterior]{\label{fig:posterior_p}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/posterior_naive}}%
    % \qquad
    
    \subfigure[GP(0, 0.5)]{\label{fig:prior_(0, 0.5)}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/prior_noise(0, 0.5)}}%
    \qquad
    \subfigure[GP(0, 0.5)]{\label{fig:posterior_(0, 0.5)}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/posterior_noise(0, 0.5)}}%
    % \qquad
    
   \subfigure[GP(0, 3)]{\label{fig:prior_(0, 5)}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/prior_noise(0, 3)}}%
    \qquad
     \subfigure[GP(0, 3)]{\label{fig:posterior_(0, 5)}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/posterior_noise(0, 3)}}%
 % }
\caption{The effect of prior noise with fixed mean on the FWBI posterior (subcaptions give the injected noise). The top row shows the naive GP prior and the corresponding FWBI posterior. The left column shows the GP priors with different injected GP noise, and the right column shows the corresponding FWBI posteriors.}
\label{fig:prior_noise_fixedmean}
\end{figure}



\begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{-0.5cm}
\centering
% \floatconts
  {
    \subfigure[Naive prior]{\label{fig:prior_p}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/prior_naive}}%
    \qquad
    \subfigure[Naive posterior]{\label{fig:posterior_p}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/posterior_naive}}%
    % \qquad
    
   \subfigure[GP(0.3, 0.1)]{\label{fig:prior_(0.3, 1)}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/prior_noise(0.3, 0.1)}}%
    \qquad
     \subfigure[GP(0.3, 0.1)]{\label{fig:posterior_(0.3, 1)}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/posterior_noise(0.3, 0.1)}}%
    % \qquad
    
    \subfigure[GP(3, 0.1)]{\label{fig:prior_(3, 1)}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/prior_noise(3, 0.1)}}%
    \qquad
    \subfigure[GP(3, 0.1)]{\label{fig:posterior_(3, 1)}%
      \includegraphics[width=0.45\linewidth]{UAI2024/prior_property/periodic/posterior_noise(3, 0.1)}}%
    % \qquad
 }
\caption{The effect of prior noise with fixed variance on the FWBI posterior (subcaptions give the injected noise). The top row shows the naive GP prior and the corresponding FWBI posterior. The left column shows the GP priors with different injected GP noise, and the right column shows the corresponding FWBI posteriors.}
\label{fig:prior_noise_fixedvariance}
\end{figure}


In this section, we analyze the impact of the functional properties (smoothness and noise) of the prior on the FWBI posterior. Consider a 1-D periodic function $y = 2 \sin(4x) + \epsilon$ with noise $\epsilon \sim \mathcal{N}(0, 0.01)$, from which we randomly sample 20 training points within $[-2, -0.5] \cup [0.5, 2]$. First, we fit a GP with a Matern kernel to these training data and use it as the prior. Since the Matern kernel has a parameter $\nu$ that controls the smoothness of functions drawn from the GP, we vary $\nu$ to simulate different degrees of prior smoothness. The results are shown in Figure \ref{fig:prior_smooth}, where Figures \ref{fig:prior_m0.5} and \ref{fig:prior_m2.5} show the rough ($\nu=0.5$) and smooth ($\nu=2.5$) GP priors, respectively. The corresponding FWBI posteriors are shown in Figures \ref{fig:posterior_m0.5} and \ref{fig:posterior_m2.5}, where we can see that the smoothness of the prior affects the resulting posterior; e.g., the prediction curves in Figure \ref{fig:posterior_m2.5} on the interval $[0.75, 1.5]$ are smoother than those in Figure \ref{fig:posterior_m0.5}.
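
To make the role of $\nu$ concrete, the two Matern kernels used here have well-known closed forms. In the sketch below, $r = |x - x'|$, and the signal variance $\sigma^2$ and length scale $\ell$ are illustrative symbols whose fitted values we do not report:
\begin{align*}
k_{\nu=0.5}(r) &= \sigma^2 \exp\left(-\frac{r}{\ell}\right), \\
k_{\nu=2.5}(r) &= \sigma^2 \left(1 + \frac{\sqrt{5}\, r}{\ell} + \frac{5 r^2}{3 \ell^2}\right) \exp\left(-\frac{\sqrt{5}\, r}{\ell}\right).
\end{align*}
Sample paths of a GP with $\nu = 0.5$ are continuous but not differentiable, whereas those with $\nu = 2.5$ are twice differentiable in the mean-square sense, which is why the former prior looks rough and the latter smooth.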

Next, we investigate the impact of prior noise on the posterior by adding different GP noise to the pre-trained GP prior (with a periodic kernel). Specifically, we consider two settings: GP noise with a fixed zero mean and varying variance, and GP noise with a fixed variance and varying mean. The results are shown in Figures \ref{fig:prior_noise_fixedmean} and \ref{fig:prior_noise_fixedvariance}, where the left column shows GP priors with different injected noise and the right column shows the corresponding FWBI posteriors. We observe that: 1) when the noise mean is fixed, varying the variance (from 0.5 to 3) has a significant effect: a larger noise variance degrades the predictive accuracy in the training-data region, and the predictive uncertainty increases noticeably with the noise variance; 2) when the noise variance is fixed, varying the mean (from 0.3 to 3) has little effect on the region with training data but degrades the prediction in the regions without data, and a larger mean shift leads to worse predictions.
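
One way to read the two settings is through the first two moments of the perturbed prior. The following is a sketch under assumptions that go beyond what is stated above: the injected noise $\eta$ is taken to be independent of the pre-trained prior draw $f$, with constant mean $m$ and marginal variance $s^2$ (matching the subcaption format $\mathrm{GP}(m, s^2)$); the symbols $\tilde{f}$, $\eta$, $m$, and $s^2$ are introduced for illustration only. With $\tilde{f}(x) = f(x) + \eta(x)$,
\begin{align*}
\mathbb{E}\big[\tilde{f}(x)\big] &= \mathbb{E}\big[f(x)\big] + m, \\
\operatorname{Var}\big[\tilde{f}(x)\big] &= \operatorname{Var}\big[f(x)\big] + s^2.
\end{align*}
Under this reading, varying $s^2$ with $m = 0$ widens the prior everywhere without moving its mean, while varying $m$ with $s^2$ fixed shifts the prior mean away from the target function, an offset that only the training data can correct.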




\subsection{NLL Results for UCI Regressions}
\label{apd:e7}

Table \ref{tab:nll} shows the average test negative log-likelihood (NLL) on the UCI regression tasks. FWBI remains competitive with the other weight-space and function-space variational methods, attaining the lowest NLL on Yacht, Wine, and Protein.
\begin{table}[!t]
% \setlength{\abovecaptionskip}{0.5cm}
% \setlength{\belowcaptionskip}{0.5cm}
\small
    \centering
    \caption{Average test NLL on several UCI regression tasks (lower is better; best results in bold). Each dataset is randomly split into 90\% training data and 10\% test data, and this process is repeated 10 times.}
    \label{tab:nll}
    \begin{tabular}{lccccc}
    \hline
    \textbf{Dataset} & \text{FWBI} & \text{GWI} & \text{FBNN} & \text{WBBB} & \text{KLBBB}\\
    % \midrule
    \hline 
    \text{Yacht} & \textbf{-0.994}$\pm$\textbf{0.956} & 0.112 $\pm$ 0.757 & -0.770 $\pm$ 0.869 & 2.856 $\pm$ 0.186 & 2.512 $\pm$ 0.161\\
    \text{Boston} & 0.473 $\pm$ 0.306 & -1.043 $\pm$ 0.681 & \textbf{-1.193}$\pm$\textbf{0.763} & 2.656 $\pm$ 0.179 & 2.066 $\pm$ 0.115\\
    \text{Concrete}  & -0.254 $\pm$ 0.206 & -0.684 $\pm$ 0.492 & \textbf{-1.001}$\pm$\textbf{0.520} & 2.838 $\pm$ 0.152 & 2.614 $\pm$ 0.166\\
    \text{Wine} & \textbf{0.347}$\pm$\textbf{0.114} & 0.700 $\pm$ 0.159 & 0.524 $\pm$ 0.137 & 2.843 $\pm$ 0.147 & 2.148 $\pm$ 0.125\\
    \text{Kin8nm} & -1.499 $\pm$ 0.149 & \textbf{-2.604}$\pm$\textbf{0.237} & -2.445 $\pm$ 0.622 & 2.823 $\pm$ 0.066 & 2.614 $\pm$ 0.071\\
    \text{Protein} & \textbf{-2.252}$\pm$\textbf{0.319} & -1.575 $\pm$ 0.229 & -1.486 $\pm$ 0.238 & 2.744 $\pm$ 0.026 & 2.222 $\pm$ 0.020\\
    \hline 
    \end{tabular}
\end{table}


\subsection{Running Time Comparison for UCI Regressions}\label{apd:e8}
To compare the efficiency of FWBI with that of other inference approaches, we report running times on the small Boston dataset and the large Protein dataset from the multivariate regression tasks. The Boston dataset has 455 training points with 13-dimensional features, while the larger Protein dataset has 41157 training points with 9-dimensional inputs. The GPU running time over all 2000 training epochs for each method is shown in Table \ref{tab:run}.

On the small Boston dataset, the running time of FWBI is on the same order as that of the parameter-space methods WBBB and KLBBB, and FWBI is nearly 10 times faster than FBNN. On the large Protein dataset, the running times of FBNN and GWI are 2--4$\times$ higher than that of FWBI, which indicates that FWBI is efficient relative to these function-space methods. Additionally, the convergence of the training loss for all methods is shown in Figure \ref{fig:uciconv}, where FWBI shows clear advantages in both convergence speed and stability.
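
The factors quoted above follow directly from the mean running times in Table \ref{tab:run}; the subscripts below denote the method and the dataset:
\begin{align*}
\frac{t_{\text{FBNN}}^{\text{Boston}}}{t_{\text{FWBI}}^{\text{Boston}}} &= \frac{200.67}{22.25} \approx 9.0, &
\frac{t_{\text{FWBI}}^{\text{Boston}}}{t_{\text{WBBB}}^{\text{Boston}}} &= \frac{22.25}{8.33} \approx 2.7, \\
\frac{t_{\text{FBNN}}^{\text{Protein}}}{t_{\text{FWBI}}^{\text{Protein}}} &= \frac{318.33}{134.25} \approx 2.4, &
\frac{t_{\text{GWI}}^{\text{Protein}}}{t_{\text{FWBI}}^{\text{Protein}}} &= \frac{472.67}{134.25} \approx 3.5.
\end{align*}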

\begin{table}[!t]
% \setlength{\abovecaptionskip}{0.5cm}
% \setlength{\belowcaptionskip}{0.5cm}
\small
\centering
    \caption{Running time comparison (in seconds) on the Boston and Protein datasets.}
    \label{tab:run}
    \begin{tabular}{lccccc}
    % \toprule
    \hline
    \textbf{Run time (s)} & \text{FWBI} & \text{GWI} & \text{FBNN} & \text{WBBB} & \text{KLBBB}\\
    % \midrule
    \hline 
    \text{Boston} & 22.25 $\pm$ 1.500 & 16.00 $\pm$ 1.000 & 200.67 $\pm$ 6.351 & 8.33 $\pm$ 0.577 & 8.33 $\pm$ 0.577\\
    \text{Protein} & 134.25 $\pm$ 3.500 & 472.67 $\pm$ 0.577 & 318.33 $\pm$ 5.774 & 8.00 $\pm$ 0.000 & 8.00 $\pm$ 1.000\\
    % \bottomrule
    \hline 
    \end{tabular}
\end{table}

\begin{figure}[!t]
% \setlength{\abovecaptionskip}{0cm}
% \setlength{\belowcaptionskip}{0cm}
% \floatconts
\centering
  {
    \subfigure[KLBBB]{\label{fig:bs_kl}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/boston/klbnn}}%
    % \qquad
    \subfigure[WBBB]{\label{fig:bn_w}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/boston/wbnn}}%
    % \qquad
    \subfigure[FWBI]{\label{fig:bs_fwbi}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/train_loss_boston}}% 
    % \quad
    \subfigure[GWI]{\label{fig:bs_gwi}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/boston/gwi}}%
    % \qquad
    \subfigure[FBNN]{\label{fig:bs_fbnn}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/boston/fbnn}}%
    % \qquad
    
    \subfigure[KLBBB]{\label{fig:bs_kl}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/protein/klbnn}}%
    % \qquad
    \subfigure[WBBB]{\label{fig:bn_w}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/protein/wbnn}}%
    % \qquad
    \subfigure[FWBI]{\label{fig:bs_fwbi}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/train_loss_protein}}% 
    % \quad
    \subfigure[GWI]{\label{fig:bs_gwi}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/protein/gwi}}%
    % \qquad
    \subfigure[FBNN]{\label{fig:bs_fbnn}%
      \includegraphics[width=0.19\linewidth]{UAI2024/uci_conv/protein/fbnn}}%
    % \qquad
 }
\caption{Convergence of the training loss for multivariate regression tasks. The top row shows the results for the Boston dataset, and the bottom row shows the results for the Protein dataset.}
\label{fig:uciconv}
\end{figure}


\subsection{ROC for OOD Detection in Classification Tasks}\label{apd:e9}


Figure \ref{fig:roc} shows the receiver operating characteristic (ROC) curves of all methods for OOD detection in image classification tasks. The closer a curve is to the upper-left corner, the stronger the OOD detection capability. Our FWBI performs competitively on all three datasets.
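
As a reminder of how such a curve is constructed, the sketch below assumes that each input $x$ receives a scalar uncertainty score $s(x)$ and is flagged as OOD whenever $s(x) > \tau$, with OOD inputs treated as the positive class; the score $s$ and the threshold $\tau$ are illustrative symbols rather than a description of our exact scoring rule. Each threshold gives one point of the curve,
\begin{equation*}
\mathrm{TPR}(\tau) = \frac{\mathrm{TP}(\tau)}{\mathrm{TP}(\tau) + \mathrm{FN}(\tau)}, \qquad
\mathrm{FPR}(\tau) = \frac{\mathrm{FP}(\tau)}{\mathrm{FP}(\tau) + \mathrm{TN}(\tau)},
\end{equation*}
and sweeping $\tau$ traces $\mathrm{TPR}$ against $\mathrm{FPR}$. A detector that separates in-distribution and OOD inputs perfectly passes through the point $(0, 1)$, i.e., the upper-left corner.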

\begin{figure}[!t]
    \centering
    {
    \subfigure[ROC (MNIST)]{\label{roc_m}
    \includegraphics[width=0.32\linewidth]{UAI2024/roc/rocmn.pdf}}
    \subfigure[ROC (FMNIST)]{\label{roc_fm}
    \includegraphics[width=0.32\linewidth]{UAI2024/roc/rocfm.pdf}}
    \subfigure[ROC (CIFAR-10)]{\label{roc_cf}
    \includegraphics[width=0.32\linewidth]{UAI2024/roc/roccf.pdf}}
    }
    \caption{Receiver operating characteristic (ROC) curves for out-of-distribution detection.}
    \label{fig:roc}
\end{figure}

\subsection{Wasserstein Distance vs. KL Divergence}



\begin{figure}[!t]
\centering
  {
     \subfigure[Initialization]{\label{fig:prior}%
      \includegraphics[width=0.33\linewidth]{UAI2024/distances/prior.pdf}}%
     \subfigure[Wasserstein]{\label{fig:wass}%
      \includegraphics[width=0.33\linewidth]{UAI2024/distances/Wasserstein.pdf}}%
      \subfigure[KL]{\label{fig:kl}%
      \includegraphics[width=0.33\linewidth]{UAI2024/distances/I-projection.pdf}}%
      
      \subfigure[Initialization]{\label{fig:prior1}%
      \includegraphics[width=0.33\linewidth]{UAI2024/distances/prior1.pdf}}%
     \subfigure[Wasserstein]{\label{fig:wass1}%
      \includegraphics[width=0.33\linewidth]{UAI2024/distances/Wasserstein1.pdf}}%
      \subfigure[KL]{\label{fig:kl1}%
    \includegraphics[width=0.33\linewidth]{UAI2024/distances/I-projection1.pdf}}%
}
\caption{Approximation results obtained with different loss functions. The background contours show the log-likelihood of a Gaussian mixture with three components.}
\label{fig:klw}
\end{figure}


We first define a Gaussian mixture model (GMM) as our target distribution,
\begin{equation}
    p(x) = 0.1\, \mathcal{N}\left(x; \begin{bmatrix}
0 \\
0
\end{bmatrix},
\begin{bmatrix}
2 & 0 \\
0 & 2
\end{bmatrix} \right) +
0.2\, \mathcal{N}\left(x; \begin{bmatrix}
20 \\
20
\end{bmatrix},
\begin{bmatrix}
3 & 0 \\
0 & 3
\end{bmatrix} \right) +
0.7\, \mathcal{N}\left(x; \begin{bmatrix}
-10 \\
20
\end{bmatrix},
\begin{bmatrix}
1 & 0 \\
0 & 1.5
\end{bmatrix} \right),
\end{equation}
which has three components with mixture weights 0.1, 0.2, and 0.7. The log-likelihood contours are plotted in Figure \ref{fig:klw}. We then use a single Gaussian distribution
\begin{equation}
    q(x) = \mathcal{N}\left(x; \mu, 
\begin{bmatrix}
5 & 0 \\
0 & 5 \\
\end{bmatrix} \right )
\end{equation}
to approximate the GMM defined above, where $\mu$ is the mean parameter to be optimized. Finally, we use the KL divergence $\mathrm{KL}[q \,\|\, p]$ and the Wasserstein distance $W_1(q, p)$, respectively, as the loss function to optimize $\mu$. All other settings, such as the optimizer, the number of steps, and the learning rate, are identical for the two losses.
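
The two objectives can be written out to show what each of them measures. The sketch below treats $q_\mu$ as a normalized Gaussian with the fixed covariance $\Sigma = 5\mathbf{I}$ given above; the symbols $q_\mu$, $\Sigma$, $\mathbb{H}$, and $\Gamma(q_\mu, p)$ (the set of couplings of $q_\mu$ and $p$) are introduced for illustration only. Since the entropy of a Gaussian depends only on its covariance,
\begin{align*}
\mathrm{KL}[q_\mu \,\|\, p] &= -\mathbb{H}[q_\mu] - \mathbb{E}_{x \sim q_\mu}\left[\log p(x)\right], &
\mathbb{H}[q_\mu] &= \tfrac{1}{2} \log \det\left(2 \pi e\, \Sigma\right),
\end{align*}
so minimizing the KL divergence over $\mu$ is equivalent to maximizing $\mathbb{E}_{x \sim q_\mu}[\log p(x)]$, which is governed by the density of $p$ in the neighborhood of $\mu$. In contrast,
\begin{equation*}
W_1(q_\mu, p) = \inf_{\gamma \in \Gamma(q_\mu, p)} \mathbb{E}_{(x, y) \sim \gamma}\left[\lVert x - y \rVert\right]
\end{equation*}
accounts for the transport cost to all of the mass of $p$, including components far from the current $\mu$.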


The results are shown in Figure \ref{fig:klw} for two different initializations (Figures \ref{fig:prior} and \ref{fig:prior1}). We observe the following:
\begin{itemize}
    \item The KL divergence is sensitive to initialization: the two initializations lead to two different solutions (Figures \ref{fig:kl} and \ref{fig:kl1}). In contrast, the Wasserstein distance yields the same solution under both initializations (Figures \ref{fig:wass} and \ref{fig:wass1}).
    \item The Wasserstein distance can escape the local optimum and move close to the globally optimal mode (the upper-left one with the darkest color in Figure \ref{fig:klw}).
\end{itemize}


\section{Notation Table}
Table \ref{tab:note} summarizes the notation used in this paper.
\begin{table}[!t]
% \setlength{\abovecaptionskip}{0.5cm}
% \setlength{\belowcaptionskip}{0.5cm}
\small
\centering
    \caption{Notation table}
    \label{tab:note}
    \begin{tabular}{cp{10cm}}
    % \toprule
    \hline
    \textbf{Notation} & \textbf{Meaning}\\
    % \midrule
    \hline 
    $ \mathcal{D}= \left\{\mathbf{X}_\mathcal{D}, \mathbf{Y}_\mathcal{D}\right\} $ & Training dataset\\
    $\mathcal{X} \subseteq \mathbb{R}^d$ & ($d$-dimensional) input space\\
    $\mathcal{Y} \subseteq \mathbb{R}^c$ & ($c$-dimensional) output space\\
    $\mathbf{X}$ & Finite marginal points\\
    $\mathbf{X}_{\mathcal{M}}$ & Finite measurement points\\
    $\mathbf{w} \in \mathbb{R}^k $ & Random model parameters for a BDL model (e.g., network weights of a BNN)\\
    $\mathbf{w}_{{\boldsymbol{b}}} \in \mathbb{R}^k$ & Random model parameters for a latent function\\
    $f(\cdot; \mathbf{w})$ & Random function mapping defined by a BDL model (e.g., a BNN)  parameterized by $\mathbf{w}$\\
    $g(\cdot; \mathbf{w}_{\boldsymbol{b}})$ & Random latent function parameterized by $\mathbf{w}_{\boldsymbol{b}}$\\
    $\boldsymbol{\theta_q}=\{\mathbf{\mu}_q, \mathbf{\Sigma}_q\} $ & Parameters for variational distribution\\
    $\boldsymbol{\theta_b}=\{\mathbf{\mu}_b, \mathbf{\Sigma}_b\} $ & Parameters for bridging distribution\\
    
    % $\mathbf{w}$ & Random model parameters for $f$\\
    % $\mathbf{w}_{\boldsymbol{b}}$ & Random function parameters for $g$\\
    % $\mathbb{R}^d$ & Input dimensions\\
    % $\mathbb{R}^c$ & Output dimensions\\
    % $\mathbb{R}^k$ & Parameter dimensions in $f$\\
    % $\mathbb{R}^h$ & Parameter dimensions in $g$\\
    $p_0(\mathbf{w})$ & Prior distribution over model parameters (e.g., prior over weights in a BNN)\\
    $p(\mathbf{w}|\mathcal{D})$ & Posterior over model parameters (e.g., posterior over weights in a BNN)\\
    $q(\mathbf{w}; \boldsymbol{\theta_q})$ & Variational posterior over model parameters (e.g., variational posterior over weights in a BNN)\\
     $p(\mathbf{w}_{{\boldsymbol{b}}}; \boldsymbol{\theta_b})$ & Bridging distribution over parameters\\
    $p_0(f)$ & Prior distribution over random functions\\
    $p(f|\mathcal{D})$ & Posterior over functions\\
    $q(f; \boldsymbol{\theta_q})$ & Variational posterior over functions\\
    % $q(\mathbf{w}; \boldsymbol{\theta_q})$ & Variational posterior over model parameters\\
    $p(g; \boldsymbol{\theta_b})$ & Bridging distribution over functions\\
    
    
    
    
    
    % \bottomrule
    \hline 
    \end{tabular}
\end{table}


\end{document}
