\documentclass[accepted]{uai2023} % `accepted' option selected, i.e. not the plain initial-submission setup -- TODO confirm against the uai2023 class docs
% \documentclass[preprint]{uai2023} % after acceptance, for a revised
% version; also before submission to
% see how the non-anonymous paper
% would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2023} % ptmx math instead of Computer
% Modern (has noticeable issues)
% \documentclass[mathfont=newtx]{uai2023} % newtx fonts (improves upon
% ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

% Recommended, but optional, packages for figures and better typesetting:
\usepackage{microtype}
\usepackage{bmpsize}
\usepackage{booktabs}
\hypersetup{
    colorlinks=true,
    linkcolor=blue,
    urlcolor=blue,
    citecolor=blue,
    anchorcolor=blue}

% Attempt to make hyperref and algorithmic work together better:
\newcommand{\theHalgorithm}{\arabic{algorithm}}

% For theorems and such
\usepackage{amsmath}
\usepackage{amssymb}
\usepackage{mathtools}
\usepackage{amsthm}
\usepackage{algorithmic,algorithm}

% if you use cleveref..
%\usepackage[capitalize,noabbrev]{cleveref}
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}

\usepackage[inline]{enumitem}
\usepackage[table, dvipsnames]{xcolor}
\usepackage{color}

\usepackage{amsmath}
\usepackage{float}
\usepackage{adjustbox}
\usepackage{caption} 
% \usepackage{subcaption}
% \usepackage{subfig}
\usepackage{mathtools, nccmath}
\usepackage{tikz}
\usepackage[algo2e, ruled,vlined,boxed,linesnumbered]{algorithm2e}
%\usepackage[noend]{algorithmic}
\SetArgSty{textnormal}
\usepackage{listings}
\usepackage{multicol}
\usepackage{wrapfig}
\usepackage{enumitem}
\usepackage{makecell}
\usepackage{upgreek}

\usepackage[toc,page]{appendix}      % for appendix

\usepackage{amsmath}		% for AMS macros
\usepackage{amssymb}		% for AMS symbols
\usepackage{amsfonts}		% for AMS fonts
\usepackage{amsthm}		% for theorems

\usepackage{mathtools}		% for advanced math
\usepackage{dsfont}		% for blackboard bold font

\usepackage{acronym}		% for acronyms

% Helpers layered on top of the `acronym' package.
% Labels in the acronym list are set with the short-form font.
\renewcommand*{\aclabelfont}[1]{%
  \acsfont{#1}%
}
% Long form in italics (singular, then plural).
\newcommand{\acli}[1]{%
  \textit{\acl{#1}}%
}
\newcommand{\aclip}[1]{%
  \textit{\aclp{#1}}%
}
% Definition form: italic long form, upright short form in parentheses,
% then mark the acronym as used (singular, then plural).
\newcommand{\acdef}[1]{%
  \textit{\acl{#1}} \textup{(\acs{#1})}\acused{#1}%
}
\newcommand{\acdefp}[1]{%
  \textit{\aclp{#1}} \textup{(\acsp{#1})}\acused{#1}%
}

% Punctuation that \para appends to its heading text.
\newcommand*{\afterhead}{.}
% Empty placeholder used to work around a period bug in the acknowledgments.
\newcommand*{\ackperiod}{}
\usepackage{titlesec}
% Run-in heading: a small vertical skip, then a bold \paragraph title
% terminated by \afterhead.
\newcommand{\para}[1]{%
  \smallskip
  \paragraph{\textbf{#1\afterhead}}%
}


\usepackage{quoting}			% for managing spaces with quotations
\quotingsetup{vskip=\medskipamount}

\usepackage{stmaryrd}		% for extra symbols
\usepackage{wasysym}		% for extra symbols

\usepackage{booktabs}		% for better tables

\usepackage[sort&compress,capitalize,nameinlink]{cleveref}		% for cleveref formatting
% Names printed by \cref for custom label types: {type}{singular}{plural}.
% NOTE(review): no environment called assumption/algo/example/method is declared
% in the visible preamble -- presumably defined elsewhere or unused; confirm.
\crefname{assumption}{Assumption}{Assumptions}
%\crefname{algorithm}{Alg.}{Algs.}
\crefname{algo}{Algorithm}{Algorithms}
\crefname{example}{Example}{Examples}
\crefname{method}{Method}{Methods}
\newcommand{\crefrangeconjunction}{\textendash}		% reference ranges are joined by an en dash
%\crefrangeformat{equation}{\upshape(#3#1#4)\textendash(#5#2#6)}
% In \creflabelformat, #1 is the label itself and #2/#3 mark where the
% hyperlink starts and ends; here the label is wrapped in upright parentheses.
\creflabelformat{assumption}{\upshape(#2#1#3\upshape)}

% Same name as `assumption', but the label is printed bare (no parentheses).
\crefname{assumptionenum}{Assumption}{Assumptions}
\creflabelformat{assumptionenum}{#2#1#3}

% Empty names: \cref prints only the (bare) label for `item' references.
\crefname{item}{}{}
\creflabelformat{item}{#2#1#3}

% Empty names: \cref prints only the parenthesised label for `eq' references.
\crefname{eq}{}{}
\creflabelformat{eq}{\upshape(#2#1#3\upshape)}


% \def\endenv{\hfill{\small$\blacktriangle$}}
\usepackage{thmtools}		% for theorem tools
\usepackage{thm-restate}		% for restating theorems

% Theorem-like environments. No declaration passes a shared-counter or
% number-within argument, so each environment keeps its own counter.
\newtheorem{theorem}{Theorem}		% for theorems
\newtheorem{corollary}{Corollary}		% for corollaries
\newtheorem{lemma}{Lemma}		% for lemmas
\newtheorem{proposition}{Proposition}		% for propositions

\newtheorem{conjecture}{Conjecture}		% for conjectures
\newtheorem{claim}{Claim}		% for claims

% \newtheorem{example}{{\small$\blacktriangledown$} Example}		% for examples
% \newtheorem{algo}{{\small$\blacktriangledown$} Algorithm}	
% Draft markers: a bold upright red tag followed by \xspace.
% \xspace comes from the xspace package, which no visible \usepackage line
% loads; load it here explicitly (harmless if it is already loaded).
\usepackage{xspace}
\newcommand{\needref}{{\color{red}\upshape\textbf{[??]}}\xspace}	% for missing refs
\newcommand{\attn}{{\color{red}\upshape\textbf{[!!]}}\xspace}		% for attention

%\newcommand{\debug}[1]{{\color{MyRed}#1}}		% for macro coloring
\newcommand{\debug}[1]{#1}		% for removing macro coloring

% \newcommand{\commtag}[1]{\tag*{\small\{#1\}}}

% new commands
% \colcircle{<colour>}: small filled circle (radius .5ex) drawn and filled in <colour>.
\newcommand{\colcircle}[1]{\tikz\draw[#1, fill=#1] (0,0) circle (.5ex);}
% Colour palette.
% NOTE(review): `blue', `green' and `red' below replace the base colours of
% the same name, so anything that refers to them afterwards (e.g. the link
% colours given to \hypersetup, presumably resolved at use time) gets these
% shades -- confirm this is intended.
\definecolor{darkblue}{HTML}{1A254B}
\definecolor{lightblue}{HTML}{A7BED3}
\definecolor{blue}{HTML}{114083}
\definecolor{green}{HTML}{81B5AE}
\definecolor{pink}{HTML}{F2545B}
\definecolor{red}{HTML}{A4243B}
\definecolor{airforceblue}{rgb}{0.36, 0.54, 0.66}
\definecolor{thistle}{rgb}{0.85, 0.75, 0.85}
\definecolor{ticklemepink}{rgb}{0.99, 0.54, 0.67}
\definecolor{thulianpink}{rgb}{0.67, 0.24, 0.43}
\definecolor{tealblue}{rgb}{0.11, 0.36, 0.43}
% Text highlights: \bl sets its argument in tealblue, \rl in thulianpink.
\newcommand{\bl}[1]{\textcolor{tealblue}{#1}}
\newcommand{\rl}[1]{\textcolor{thulianpink}{#1}}


% math definitions
\newcommand{\defeq}{\vcentcolon=}
%\newcommand{\norm}[1]{\left\lVert#1\right\rVert}
\def\Pcal{\mathcal{P}}
\def\Ncal{\mathcal{N}}
\def\bP{\mathbf{P}}
\def\bC{\mathbf{C}}
\def\RR{\mathbb{R}}
\def\cW{\overline{W}_{\varepsilon}}
\def\We{W_{\varepsilon}}
\def\ones{\mathbf{1}}
\DeclarePairedDelimiterX{\dotp}[2]{\langle}{\rangle}{#1, #2}
% Arg-min / arg-max operators; the starred declaration puts limits
% underneath in display style.
% \argminB and \argmaxB keep the unspaced spelling.
\DeclareMathOperator*{\argminB}{argmin}
\DeclareMathOperator*{\argmaxB}{argmax}
% Spelled with a thin space so that it matches \argmax, which this file
% declares as arg\,max.
\DeclareMathOperator*{\argmin}{arg\,min}
\def\rset{\mathbb{R}}
\def\rmd{\mathrm{d}}
\def\bfX{\mathbf{X}}
\def\Leb{\mathrm{Leb}}
\newcommand{\expe}[1]{\mathbb{E}[#1]}


% Symbols with a hard-coded time index t.
% NOTE(review): the names suggest a drift and a Doob h-transform -- confirm against the text.
\newcommand{\fdrift}{b_t}
\newcommand{\doob}{h_t}
% \doobs uses \reg, which is defined two lines below; this is fine because
% the body is only expanded when \doobs is used.
\newcommand{\doobs}{h_{t,\reg}}
\newcommand{\Loss}{L}
\newcommand{\reg}{\tau}
\newcommand{\cvolatbase}{\beta}
% \cvolat[<index>]: \cvolatbase with a subscript; the default index is \ctime,
% which is defined further down in this file.
\newcommand{\cvolat}[1][\ctime]{\cvolatbase_{#1}}


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% MACROS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%----------------------------------------------------------------------
%%% MACROS
%----------------------------------------------------------------------
% !TEX root = ./Main.tex


%**********************************************************************
%***    MACROS: GENERAL
%**********************************************************************
% \newmacro{\cmd}{body}: defines the argument-less macro \cmd as \debug{body}.
% While \debug is the identity, \cmd expands to `body' WITHOUT surrounding
% braces, so a sub/superscript placed after \cmd attaches to the last atom of `body'.
\newcommand{\newmacro}[2]{\newcommand{#1}{\debug{#2}}}		% for shorthand definitions
% \newop{\cmd}{text}: the same wrapper for operators, via \DeclareMathOperator.
\newcommand{\newop}[2]{\DeclareMathOperator{#1}{\debug{#2}}}		% for shorthand definitions

%----------------------------------------------------------------------
%% Delimiters
%----------------------------------------------------------------------
% Paired delimiters declared through mathtools. Per the mathtools manual each
% command takes a starred form (\cmd*{..}, auto-sized with \left/\right) and
% an optional size argument (\cmd[\big]{..}).
\DeclarePairedDelimiter{\braces}{\{}{\}}		% for braces
\DeclarePairedDelimiter{\bracks}{[}{]}		% for brackets
\DeclarePairedDelimiter{\parens}{(}{)}		% for parentheses

\DeclarePairedDelimiter{\abs}{\lvert}{\rvert}		% for absolute value
\DeclarePairedDelimiter{\ceil}{\lceil}{\rceil}		% for ceiling
\DeclarePairedDelimiter{\floor}{\lfloor}{\rfloor}		% for floor
\DeclarePairedDelimiter{\clip}{[}{]}		% for clipping (same delimiters as \bracks)
\DeclarePairedDelimiter{\negpart}{[}{]_{-}}		% for negative part: brackets with subscript -
\DeclarePairedDelimiter{\pospart}{[}{]_{+}}		% for positive part: brackets with subscript +

% Two-argument form: \inner{a}{b} gives <a, b>.
\DeclarePairedDelimiterX{\inner}[2]{\langle}{\rangle}{#1, #2}		% for scalar product
%\DeclarePairedDelimiterX{\inner}[2]{\langle}{\rangle}{#1,#2}		% for scalar product

\DeclarePairedDelimiter{\norm}{\lVert}{ \rVert}		% for norm
% \twonorm has empty pre/post code, so it prints the same double bars as \norm (no subscript).
\DeclarePairedDelimiterXPP{\twonorm}[1]{}{\lVert}{\rVert}{}{#1}		% for L2 norm
% \dnorm appends the subscript \ast after the closing delimiter.
\DeclarePairedDelimiterXPP{\dnorm}[1]{}{\lVert}{\rVert}{_{\ast}}{#1}		% for dual norm
%\newcommand{\dnorm}[1]{\norm{#1}_{\ast}}		% for dual norm

\DeclarePairedDelimiter{\bra}{\langle}{\rvert}		% for bras
\DeclarePairedDelimiter{\ket}{\lvert}{\rangle}		% for kets
% \braket{a}{b} separates its arguments with a comma; the variant with a
% vertical bar is the commented-out line below.
\DeclarePairedDelimiterX{\braket}[2]{\langle}{\rangle}{#1,#2}		% for brakets
%\DeclarePairedDelimiterX{\braket}[2]{\langle}{\rangle}{#1\mathopen{}\delimsize\vert\mathopen{}#2}

% \setdef{x}{cond} gives {x:cond}.
\DeclarePairedDelimiterX{\setdef}[2]{\{}{\}}{#1:#2}		% for set builder notation
% \exclude{x}: a \setminus followed by the braced argument.
\DeclarePairedDelimiterXPP{\exclude}[1]{\mathopen{}\setminus}{\{}{\}}{}{#1}


%----------------------------------------------------------------------
%% Modifiers
%----------------------------------------------------------------------
\newcommand{\alt}[1]{#1'}		% for alternates


%----------------------------------------------------------------------
%% Number fields
%----------------------------------------------------------------------
% Blackboard-bold shorthands for the usual number systems.
\newcommand*{\F}{\mathbb{F}}		% generic field
\newcommand*{\N}{\mathbb{N}}		% naturals
\newcommand*{\Z}{\mathbb{Z}}		% integers
\newcommand*{\Q}{\mathbb{Q}}		% rationals
\newcommand*{\R}{\mathbb{R}}		% reals
\newcommand*{\C}{\mathbb{C}}		% complex numbers (name may clash with other packages)

%----------------------------------------------------------------------
%% Operators
%----------------------------------------------------------------------
\DeclareMathOperator*{\argmax}{arg\,max}		% for argmax
%\DeclareMathOperator*{\argmin}{arg\,min}		% for argmin
\DeclareMathOperator*{\intersect}{\bigcap}		% for intersections
\DeclareMathOperator*{\union}{\bigcup}		% for unions

\DeclareMathOperator{\aff}{aff}		% for affine hull
\DeclareMathOperator{\bd}{bd}		% for boundary
\DeclareMathOperator{\bigoh}{\mathcal{O}}		% for Landau O
\DeclareMathOperator{\card}{card}		% for cardinality
\DeclareMathOperator{\cl}{cl}		% for closure
\DeclareMathOperator{\conv}{conv}		% for convex hull (but see also \simplex)
\DeclareMathOperator{\crit}{crit}		% for gap function
\DeclareMathOperator{\diag}{diag}		% for diagonal matrices
\DeclareMathOperator{\diam}{diam}		% for diameter
\DeclareMathOperator{\dist}{dist}		% for distance
\DeclareMathOperator{\dom}{dom}		% for domain
\DeclareMathOperator{\eig}{eig}		% for eigenvalues
\DeclareMathOperator{\ess}{ess}		% for essential
\DeclareMathOperator{\grad}{\nabla}		% for gradient
\DeclareMathOperator{\Hess}{Hess}		% for Hessian
\DeclareMathOperator{\ind}{ind}		% for index
\DeclareMathOperator{\im}{im}		% for image
\DeclareMathOperator{\intr}{int}		% for interior
\DeclareMathOperator{\Jac}{D}		% for Jacobian
\DeclareMathOperator{\one}{\mathds{1}}		% for indicator
\DeclareMathOperator{\proj}{pr}		% for projection
\DeclareMathOperator{\prox}{prox}		% for prox
\DeclareMathOperator{\rank}{rank}		% for rank
\DeclareMathOperator{\relint}{ri}		% for relative interior
\DeclareMathOperator{\sign}{sgn}		% for sign
\DeclareMathOperator{\supp}{supp}		% for support
\DeclareMathOperator{\Sym}{Sym}		% for symmetric
\DeclareMathOperator{\tr}{tr}		% for trace
\DeclareMathOperator{\unif}{unif}		% for uniform distribution
\DeclareMathOperator{\vol}{vol}		% for volume



%----------------------------------------------------------------------
%% Sundries
%----------------------------------------------------------------------
\newmacro{\coef}{\lambda}		% for coefficient
\newmacro{\dd}{\:\mathrm{d}}		% for integrators
\newmacro{\intR}{\int_{\R^{\vdim}}}		% for integration over full domains
\newmacro{\intRR}{\int_{\R^{\vdim}  \times \R^{\vdim}  }}		% for integration over double full domains
\newmacro{\nn}{\nonumber}		% for equations

\newcommand{\subs}{\leftarrow}      % for substitution


%\newcommand{\ddt}[1]{\frac{d#1}{dt}}		% for Leibniz
\newcommand{\ddt}{\frac{\mathrm{d}}{\mathrm{d}t}}		% for Leibniz
\newcommand{\ddc}{\frac{\partial}{\partial \point_{\coord} }}		% for Leibniz

\newcommand{\del}{\partial}		% for derivatives
\newcommand{\eps}{\varepsilon}		% for better epsilon
\newcommand{\pd}{\partial}		% for derivatives
\newcommand{\wilde}{\widetilde}		% for wide tildes

\newcommand{\insum}{\sum\nolimits}		% for compact sums
\newcommand{\inprod}{\prod\nolimits}		% for compact products

\newmacro{\pexp}{p}		% for first exponent
\newmacro{\qexp}{q}		% for second exponent
\newmacro{\rexp}{r}		% for third exponent

\newcommand{\dsum}{\oplus}		% for direct sums


\newcommand{\const}{ \mathrm{const.} }

%----------------------------------------------------------------------
%% Text and formatting
%----------------------------------------------------------------------
% Text abbreviations that supply their own trailing space through \xspace.
% \xspace comes from the xspace package, which no visible \usepackage line
% loads; load it here explicitly (harmless if it is already loaded).
\usepackage{xspace}
% `cf.' and `vs.' end in a period: \@ after the period resets the space
% factor so the following space is an ordinary interword space rather than
% an end-of-sentence space.
\newcommand{\cf}{cf.\@\xspace}		% for consistency
\newcommand{\eg}{e.g.,\xspace}		% for consistency
\newcommand{\ie}{i.e.,\xspace}		% for consistency
\newcommand{\vs}{vs.\@\xspace}		% for consistency

\newcommand{\textbrac}[1]{\textup[#1\textup]}		% for upshape brackets
\newcommand{\textpar}[1]{\textup(#1\textup)}		% for upshape parentheses

\newcommand{\dis}{\displaystyle}		% for forcing display style
\newcommand{\txs}{\textstyle}		% for forcing inline style



%----------------------------------------------------------------------
%% Riemannian Manifolds
%----------------------------------------------------------------------
\newmacro{\mfd}{\mathcal{M}}		% for the manifold (the metric tensor is \gmat)
\newmacro{\curve}{\gamma}          % for curves
% \ptrans{a}{b}{v}: \Gamma_{a \rightarrow b}(v) with auto-sized parentheses.
\newcommand{\ptrans}[3]{ \Gamma_{#1 \rightarrow #2} \left( #3 \right)}      % for parallel transport
\newmacro{\sect}{\mathcal{K}}    % for sectional curvatures

%**********************************************************************
%***    MACROS: SET THEORY
%**********************************************************************

%----------------------------------------------------------------------
%% Points and sets
%----------------------------------------------------------------------
\newcommand{\from}{\colon}		% for function definition
\newcommand{\too}{\rightrightarrows}		% for correspondences
\newcommand{\injects}{\hookrightarrow}		% for injections
\newcommand{\surjects}{\twoheadrightarrow}		% for surjections

%\newcommand{\defeq}{\coloneqq}		% for direct definition
\newcommand{\eqdef}{\eqqcolon}		% for reverse definition

\newmacro{\sset}{\mathcal{S}}		% for generic set

%\newmacro{\points}{\mathcal{Z}}		% for point set
\newmacro{\points}{\mfd}		% for Riemannian RM
\newmacro{\intpoints}{\points^{\circ}}		%for point set interior
\newmacro{\point}{x}		% for generic point
\newmacro{\pointalt}{\alt\point}		% for alternate point

\newmacro{\dpoints}{\mathcal{W}}		% for second point set (duals, etc.)
\newmacro{\dpoint}{w}		% for second generic point
\newmacro{\dpointalt}{\alt\dpoint}		% for second alternate variable

\newmacro{\base}{p}		% for reference point
\newmacro{\basealt}{q}		% for alternate reference point

\newcommand{\test}[1][\point]{\hat{#1}}		% for test point (\point by default)
\newcommand{\tests}{\test[\points]}		% for set of test points

\newmacro{\open}{\mathcal{U}}		% for open sets
\newmacro{\closed}{\mathcal{C}}		% for closed sets
\newmacro{\cpt}{\mathcal{K}}		% for compact sets
\newmacro{\nbhd}{\mathcal{U}}		% for neighborhoods


%**********************************************************************
%*****	MACROS: SEQUENCES AND TIME SERIES
%**********************************************************************

%----------------------------------------------------------------------
%% Basic indices
%----------------------------------------------------------------------
\newmacro{\start}{1}		% for start index
\newmacro{\halfafterstart}{3/2}		% for second index
\newmacro{\afterstart}{2}		% for second index
\newmacro{\running}{\start,\afterstart,\dotsc}		% for running index
\newmacro{\halfrunning}{\start,\halfafterstart,\dotsc}

\newmacro{\runalt}{k}		% for running sequence index
\newmacro{\run}{n}		% for main sequence index
\newmacro{\nRuns}{T}		% for total number of runs
\newmacro{\runs}{\mathcal{\nRuns}}		% for set of runs


%----------------------------------------------------------------------
%% Sequences and recursions
%----------------------------------------------------------------------
\newmacro{\state}{Z}		% for main iterate
\newmacro{\dstate}{Y}		% for other iterate

% Each macro takes one optional argument: the symbol to decorate. Unless noted
% the default is \state (defined above as Z); indices come from \start,
% \afterstart, \runalt and \run.
\newcommand{\avg}[1][\state]{\bar{#1}}		% for averaging (\state by default)
%\newcommand{\avg}[1][\state]{\debug{\bar#1}_{\nRuns}}		% for last ergodic state (\state by default)
\newcommand{\new}[1][\point]{#1^{+}}		% for new iterate (\point by default)

\newcommand{\init}[1][\state]{\debug{#1}_{\start}}		% for initial value, index \start (\state by default)
\newcommand{\afterinit}[1][\state]{\debug{#1}_{\afterstart}}		% for second value, index \afterstart (\state by default)
\newcommand{\preiter}[1][\state]{\debug{#1}_{\runalt-1}}		% for value at index \runalt-1 (\state by default)
\newcommand{\iter}[1][\state]{\debug{#1}_{\runalt}}		% for value at index \runalt (\state by default)
\newcommand{\afteriter}[1][\state]{\debug{#1}_{\runalt+1}}		% for value at index \runalt+1 (\state by default)
\newcommand{\preprev}[1][\state]{\debug{#1}_{\run-2}}		% for value at index \run-2 (\state by default)
\newcommand{\prev}[1][\state]{\debug{#1}_{\run-1}}		% for previous value, index \run-1 (\state by default)
\newcommand{\curr}[1][\state]{\debug{#1}_{\run}}		% for current value, index \run (\state by default)
\newcommand{\prelead}[1][\state]{\debug{#1}_{\run-1}^{+}}		% for value at index \run-1 with superscript + (\state by default)
\newcommand{\lead}[1][\state]{\debug{#1}_{\run}^{+}}		% for value at index \run with superscript + (\state by default)
%\renewcommand{\next}[1][\state]{\debug{#1}_{\run+1}}		% for current value (X by default)
%\newcommand{\next}[1][\state]{\debug{#1}_{\run+1}}



%**********************************************************************
%***    MACROS: LINEAR ALGEBRA
%**********************************************************************

%----------------------------------------------------------------------
%% Vector spaces
%----------------------------------------------------------------------
\newmacro{\vecspace}{\R^{\vdim}}		% for generic vector space

\newmacro{\coord}{i}		% for index
\newmacro{\vdim}{d}		% for dimension
\newmacro{\vvec}{v}		% for generic vector
\newmacro{\bvec}{e}		% for basis vector
\newmacro{\bvecs}{\mathcal{E}}		% for basis vectors

\newmacro{\subspace}{\mathcal{W}}		% for subspace
\newmacro{\wvec}{w}		% for generic subspace vector
\newmacro{\subdim}{m}		% for subspace dimension

\newmacro{\tanhull}{\mathcal{Z}}		% for tangent hull
\newmacro{\tanvec}{z}		% for tangent vectors


%----------------------------------------------------------------------
%% Duality
%----------------------------------------------------------------------
% \dual{x}: x with a superscript \ast. The argument is inserted unbraced, so
% it must not already end in a superscript.
\newcommand{\dual}[1]{#1^{\ast}}		% for dual variables
% \vecspace expands (through the identity \debug) to the unbraced \R^{\vdim},
% so a bare \dual\vecspace would produce \R^{d}^{\ast} and a "Double
% superscript" error. The extra braces make the whole space the base of the star.
\newmacro{\dspace}{\dual{{\vecspace}}}		% for dual space
\newmacro{\dvec}{v}		% for dual vector
\newmacro{\dbvec}{\eps}		% for dual basis vectors


%----------------------------------------------------------------------
%% Matrices and vectors
%----------------------------------------------------------------------
%\newmacro{\ones}{\mathbf{1}}		% for vector of ones
\newmacro{\mat}{M}		% for generic matrix
\newmacro{\eye}{I}		% for identity matrix

\newcommand{\mg}{\succ}		% for positive-definite
\newcommand{\mgeq}{\succcurlyeq}		% for positive-semidefinite
\newcommand{\ml}{\prec}		% for negative-definite
\newcommand{\mleq}{\preccurlyeq}		% for negative-semidefinite






%**********************************************************************
%***    MACROS: PROBABILITY AND STATISTICS
%**********************************************************************

%----------------------------------------------------------------------
%% Probability
%----------------------------------------------------------------------
\DeclareMathOperator{\ex}{\mathbb{E}}		% for expectations
\DeclareMathOperator{\prob}{\mathbb{P}}		% for probability
\DeclareMathOperator{\Var}{Var}		% for variance
\DeclareMathOperator{\simplex}{\Delta}		% for simplices

\newmacro{\seed}{\omega}		% for seed
\newmacro{\seeds}{\Omega}		% for seed space
\newmacro{\history}{\mathcal{H}}		% for filtrations

\newmacro{\sample}{\omega}		% for samples
\newmacro{\samples}{\Omega}		% for sample space
\newmacro{\filter}{\mathcal{F}}		% for filtrations
\newmacro{\probspace}{(\samples,\filter,\prob)}		% for probability space

\newmacro{\event}{\mathcal{E}}       % for event
\newmacro{\eventalt}{\mathcal{H}}       % for alternate event
\newcommand{\comp}[1]{#1^{\mathtt{c}}}		% for complement

\newmacro{\mean}{\mu}		% for mean of distribution
\newmacro{\sdev}{\sigma}		% for standard deviation of distribution
\newmacro{\variance}{\sdev^{2}}		% for variance of distribution (square of \sdev)

\newmacro{\dkl}{D_{\mathrm{KL}}}		% for Kullback Leibler
%\newcommand{\as}{\debug{\textpar{a.s.}}\xspace}		% for almost surely
\newcommand{\as}{{{a.s.}}}		% for almost surely


% Conditional expectation / probability.
% \given is empty by default and is redefined inside the body of \exof and
% \probof as a vertical bar sized with the surrounding delimiters (\delimsize).
% Intended use: \exof{X \given Y}, \probof{A \given B}.
\providecommand\given{}		% empty command for conditionals

% \exof{..}: \ex followed by brackets; the bar is padded with thin spaces (\,).
\DeclarePairedDelimiterXPP{\exof}[1]{\ex}{[}{]}{}{%		% for conditional expectations
\renewcommand\given{\nonscript\,\delimsize\vert\nonscript\,\mathopen{}} #1}

% \probof{..}: \prob followed by parentheses; the bar is padded with medium spaces (\:).
\DeclarePairedDelimiterXPP{\probof}[1]{\prob}{(}{)}{}{%		% for conditional probabilities
\renewcommand\given{\nonscript\:\delimsize\vert\nonscript\:\mathopen{}} #1}

\newcommand{\oneof}[1]{\one_{\{#1\}}}
%\DeclarePairedDelimiterXPP{\oneof}[1]{\one}{\{}{\}}{}{%		% for conditional expectations
%\renewcommand\given{\nonscript\,\delimsize\vert\nonscript\,\mathopen{}} #1}




%----------------------------------------------------------------------
%% Geometry
%----------------------------------------------------------------------
\newmacro{\gmat}{g}		% for metric tensor
\newmacro{\gdist}{\dist_{\gmat}}
\newmacro{\ball}{\mathbb{B}}		% for balls
\newmacro{\sphere}{\mathbb{S}}		% for spheres




%**********************************************************************
%***    MACROS: SCHROEDINGER BRIDGES
%**********************************************************************

% Means and covariances at the two endpoints, indexed by 0 and by \horizon.
% Subscripts are braced so they stay correct even if \horizon is ever
% changed to expand to more than one token.
\newmacro{\mbase}{\mu}
\newmacro{\m}{\mbase_{0}}     % for mean of the first gaussian
\newmacro{\malt}{\mbase_{\horizon}}     % for mean of the second gaussian

\newmacro{\covarbase}{\Sigma}
\newmacro{\covar}{\covarbase_{0}}     % for covariance of the first gaussian
\newmacro{\covaralt}{\covarbase_{\horizon}}     % for covariance of the second gaussian

% Time variables and the horizon (end time).
\newmacro{\ctime}{t}
\newmacro{\ctimealt}{s}
\newmacro{\horizon}{1}

\newmacro{\ratiosym}{r}    % for symbol of the ratio between \ctime and \horizon
\newcommand{\ratio}[1][\ctime]{\ratiosym_{#1}}     % for ratio between \ctime and \horizon (optional argument: index, \ctime by default)
\newcommand{\ratioc}[1][\ctime]{\bar{\ratiosym}_{#1}}     % for complementary ratio between \ctime and \horizon (barred symbol)




\newmacro{\scalingbase}{\mathrm{QV}}     % for quadratic variation
\newcommand{\scaling}[1][\ctime]{ \scalingbase\parens*{ #1 } } 
\newcommand{\scalingsq}[1][\ctime]{ \scalingbase^2\parens*{ #1 } }

\newmacro{\KLbase}{D_{\mathrm{KL}}}
\newcommand{\KL}[2]{ \KLbase\parens*{ #1 \Vert #2 } }


\newmacro{\sdebase}{ X }
\newcommand{\sde}[1][\ctime]{ \sdebase_{ #1 } }     % for marginal variables
\newcommand{\dsde}[1][\ctime]{ \dd\sdebase_{ #1 } }     % for marginal variables
%\newmacro{\dXt}{ \dd X_t }     % for SDEs
\newmacro{\tinv}{\tau}

\newmacro{\testfbase}{u}
\newmacro{\generator}{\mathcal{L}_{\ctime}}

%\newcommand{\testf}[2][\ctime]{\testfbase\parens*{#1, #2} }
%\newcommand*\Laplace{\mathop{}\!\mathbin\bigtriangleup}
\newcommand{\Laplace}{\Delta}

\usepackage{xparse}
\NewDocumentCommand{\testf}{ O{\ctime} O{\point} }{ \testfbase\parens*{#1,#2} }


\newmacro{\dconst}{\lambda}     % for constant drift
\newmacro{\sconst}{\mathbf{v}}     % for constant shift
\newmacro{\vconst}{\omega}     % for constant volatility


\renewcommand\theadgape{\Gape[4pt]}
\renewcommand\cellgape{\Gape[4pt]}

\newmacro{\qvbase}{\mathrm{q}}

\newcommand{\qv}[1][\ctime]{\qvbase\parens*{ #1 } }
\newcommand{\qve}[1][\horizon]{\qvbase\parens*{ #1 } }
\newcommand{\dqv}[1][\ctime]{ \dot{\qvbase}\parens*{ #1 } }

\newmacro{\subVPfbase}{ \beta }
\newcommand{\subVPf}[1][\ctime]{ {\subVPfbase}\parens*{ #1 } }

\newcommand{\aggtimeqr}[1][\ctime]{ \aggtimebase^4_{ #1 } }


\newmacro{\refsdebase}{ Y }     % for ref SDEs
\newcommand{\refsde}[1][\ctime]{ \refsdebase_{ #1 } }
\newcommand{\drefsde}[1][\ctime]{ \dd \refsdebase_{ #1 } }

\newmacro{\wiescalebase}{ g }
\newcommand{\wiescale}[1][\ctime]{ \wiescalebase\parens*{ #1 } }
\newcommand{\wiescalesq}[1][\ctime]{ \wiescalebase^2\parens*{ #1 } }

% \QVbase applied to a time argument (default \ctime), and its dotted
% (time-derivative) version.
\newmacro{\QVbase}{ \mathrm{qv} }
\newcommand{\QV}[1][\ctime]{ \QVbase\parens*{ #1 }  }
% Was \patens*, an undefined control sequence; the delimiter macro is \parens.
\newcommand{\dQV}[1][\ctime]{ \dot{\QVbase}\parens*{ #1 }  }

\newmacro{\driftbase}{  c  }
\newcommand{\drift}[1][\ctime]{ \driftbase\parens*{ #1 }  }

\newmacro{\shiftbase}{  \alpha  }
\newcommand{\shift}[1][\ctime]{ \shiftbase\parens*{ #1 }  }

\newmacro{\volatbase}{  g  }
\newcommand{\volat}[1][\ctime]{ \volatbase_{ #1 }  }
\newcommand{\volatsq}[1][\ctimealt]{ \volatbase^2_{ #1 }  }

\newmacro{\refprobase}{\mathbb{Q}}     % for alternative general stochastic processes alphabet
\newcommand{\refpro}[1][\ctime]{\refprobase_{ #1 }}
\newmacro{\refjoint}{\refprobase_{\mathrm{0\horizon}} }    % for alternative general stochastic processes alphabet
\newcommand{\refapprox}[1][\ctime]{ {\refprobase}^{\solbase}_{ #1 }}

%----------------------------------------------------------------------
%%% OBSOLETE
%%----------------------------------------------------------------------
%\newmacro{\wienerbase}{\mathbb{W}}     % for reversible Wiener processes alphabet
%\newmacro{\Wt}{\wienerbase_{\ctime}}     % for reversible Wiener processes
%\newmacro{\dWt}{ \dd \Wt }     % for Wiener processes increments
%----------------------------------------------------------------------
%%% LINEAR SDES
%----------------------------------------------------------------------



\newmacro{\Wienerbase}{\mathbb{W}}     % for reversible Wiener processes alphabet
\newcommand{\Wiener}[1][\ctime]{ \Wienerbase_{#1} }     % for reversible Wiener processes
\newcommand{\dWiener}[1][\ctime]{ \dd \Wienerbase_{#1} }     % for Wiener processes increments

\newmacro{\aggtimebase}{  \tau  }
%\newcommand{\aggtime}[1][\ctime]{ \aggtimebase\parens*{ #1 }  }
%\newcommand{\aggtimeinv}[1][\ctimealt]{ \aggtimebase^{-1}\parens*{ #1 }  }
%\newcommand{\aggtimesq}[1][\ctimealt]{ \aggtimebase^2\parens*{ #1 }  }
%\newcommand{\aggtimesqinv}[1][\ctimealt]{ \aggtimebase^{-2}\parens*{ #1 }  }
%\newcommand{\daggtime}[1][\ctime]{ \dot{\aggtimebase}\parens*{ #1 }  }

\newcommand{\aggtime}[1][\ctime]{ \aggtimebase_{ #1 }  }
\newcommand{\aggtimeinv}[1][\ctimealt]{ \aggtimebase^{\ssstyle -1}_{ #1 }  }
\newcommand{\aggtimesq}[1][\ctimealt]{ \aggtimebase^2_{ #1 }  }
\newcommand{\aggtimesqinv}[1][\ctimealt]{ \aggtimebase^{-2}_{ #1 }  }
\newcommand{\daggtime}[1][\ctime]{ \dot{\aggtimebase}_{ #1 }  }
%\newmacro{\mYbase}{ Y }

\newmacro{\mrsdebase}{ \eta}
\NewDocumentCommand{\mYcinit}{ O{\ctime}  }{  \mrsdebase\parens*{#1 } } %  \middle| \refsde[0] } }

\newmacro{\kernelbase}{ \kappa}
\NewDocumentCommand{\kernel}{ O{\ctime} O{\ctime'}  }{  \kernelbase\parens*{#1, #2} }% \middle| \refsde[0] } }


\newmacro{\intdasq}{ \int_0^\ctime {\aggtimesqinv[\ctimealt]}{\volatsq} \dd \ctimealt }
\newmacro{\intdasqT}{ \int_0^\horizon {\aggtimesqinv[\ctimealt]}{\volatsq} \dd \ctimealt }


\newcommand{\law}[1]{ \mathrm{law}\parens*{ #1 } }





\newmacro{\ssstyle}{\scriptscriptstyle}
\newmacro{\sssNcal}{\ssstyle\Ncal}

\newmacro{\solbase}{\star}

\newmacro{\Cstar}{C_{\sdev_{\solbase}}}

\newmacro{\pbase}{\mathbb{P}}     % for general stochastic processes alphabet
%\newmacro{\Pt}{\pbase_{\ctime}}     % for general stochastic processes
\newmacro{\Pinit}{\pbase_{{0}}}     % for initial marginal
\newmacro{\Pend}{\pbase_{{\horizon}}}     % for end marginal
\newcommand{\Pmargin}[1][\ctime]{ \pbase_{{#1}} }     % for general marginals
%\newcommand{\Pro}[1][\ctime]{ \pbase_{#1} }     % for general marginals
\newcommand{\Psol}[1][\ctime]{ \pbase^{\solbase}_{ #1 } }
\newmacro{\Pjoint}{ \pbase_{ \mathrm{0\horizon}} }

\newmacro{\distbase}{ \pbase }
\newmacro{\ini}{ {0} }
\newmacro{\distinit}{ \hat{\distbase}_{ \ini } }
\newmacro{\en}{ {\horizon} }
\newmacro{\distend}{ \hat{\distbase}_{ \en} }


\newcommand{\Xsol}[1][\ctime]{ X^{\solbase}_{ #1 } }

\newcommand{\dXsol}[1][\ctime]{ \dd X^{\solbase}_{ #1 } }

\newcommand{\meansol}[1][\ctime]{ \mu^{\solbase}_{ #1 } }
\newcommand{\dmeansol}[1][\ctime]{ \dot{\mu}^{\solbase}_{ #1 } }
\newcommand{\Sigmasol}[1][\ctime]{ \Sigma^{\solbase}_{ #1 } }
\newcommand{\Sigmasolinv}[1][\ctime]{ \Sigma^{\solbase-1}_{ #1 } }
\newcommand{\dSigmasol}[1][\ctime]{ \dot{\Sigma}^{\solbase}_{ #1 } }

\newcommand{\dratio}[1][\ctime]{\dot{\ratiosym}_{#1}}     % for ratio between \ctime and \horizon
\newcommand{\dratioc}[1][\ctime]{\dot{\bar{\ratiosym}}_{#1}}     % for complementary ratio between \ctime and \horizon


\newcommand{\cmeansol}[2]{ \mu^{\solbase}_{ #1 \vert #2 } }
\newcommand{\cSigmasol}[2]{ \Sigma^{\solbase}_{ #1 \vert #2 } }

%\newmacro{\effsc}{\rho}    % for effectively scaling



\newmacro{\efftrbase}{\rho}    % for effectively scaling
\newcommand{\efftr}[1][\ctime]{  \efftrbase_{ #1 }  }
%\newcommand{\deffsc}[1][\ctime]{\dot{\effscbase}\parens*{#1}}
%\newcommand{\scalingsol}[1][\ctime]{ \scalingbase^\star\parens*{ #1 } }
%\newcommand{\scalingsolsq}[1][\ctime]{ \scalingbase^{\star2}\parens*{ #1 } }

\newmacro{\paramf}{ \theta }
\newmacro{\SBfbase}{ Z }
%\newmacro{\SBbbase}{ \SBfbase^{\scriptscriptstyle\textup{rev}} }
\newmacro{\paramb}{ \phi }
\newmacro{\SBbbase}{ \hat{\SBfbase} }
\NewDocumentCommand{\SBf}{ O{\ctime} O{\point} O{\paramf} }{ \SBfbase_{#1}^{#3}\parens{#2} }
\NewDocumentCommand{\SBb}{ O{\ctime} O{\point} O{\paramb} }{ \SBbbase_{#1}^{#3}\parens{#2} }

\newmacro{\GSBfbase}{ f_{\sssNcal} } %f_{\scriptscriptstyle\textup{GSB}} }
\newmacro{\GSBbbase}{ \hat{\GSBfbase}}%^{\scriptscriptstyle\textup{rev}} }
\NewDocumentCommand{\GSBf}{ O{\ctime} O{\point} }{ \GSBfbase\parens*{#1,#2} }

%\newcommand{\dscaling}[1][\ctime]{ \dot{\scalingbase}\parens*{ #1 } } 

\newmacro{\tshiftbase}{ \zeta }%[1][\ctime]{ \dot{\scalingbase}\parens*{ #1 } } 
\newcommand{\tshift}[1][\ctime]{ \tshiftbase\parens*{ #1 } } 


\newcommand{\Div}[1][\point]{ \nabla_{ {#1} } \cdot }
\newmacro{\dt}{ \dd \ctime}


\newmacro{\loss}{\ell}
%\newcommand{\lossf}{ \loss\parens*{\point_{\horizon};\paramf} }
\NewDocumentCommand{\lossf}{ O{\point_{\horizon}} O{\paramf} }{  \loss\parens*{ #1; #2 }}

%\newcommand{\lossb}{ \loss\parens*{\point_{0};\paramb} }
\NewDocumentCommand{\lossb}{ O{\point_{0}} O{\paramb} }{  \loss\parens*{ #1; #2 }}


% Symbols for algorithm sizes: \caching is M, and K carries an upright
% out / in subscript.
\newmacro{\caching}{M}
% The subscript is braced explicitly: K_\textup{..} only works through the
% way \textup happens to expand after `_'.
\newmacro{\outeriter}{K_{\textup{out}}}
\newmacro{\inneriter}{K_{\textup{in}}}

\newmacro{\pretriterf}{K_{\paramf}}
\newmacro{\pretriterb}{K_{\paramb}}

\newmacro{\lrbase}{\gamma}
\newmacro{\lrf}{\lrbase_{\paramf}}
\newmacro{\lrb}{\lrbase_{\paramb}}


\newmacro{\Ninit}{\Ncal_{\ini}}
\newmacro{\Nend}{\Ncal_{\en}}

\newmacro{\tdriftbase}{ f }
%\newcommand{\tdrift}[1][\ctime]{ \tdriftbase\parens*{ #1 } }

\NewDocumentCommand{\tdrift}{ O{\ctime} O{\refsde} }{ \tdriftbase\parens*{#1,#2} }

\newmacro{\SB}{ {\scriptscriptstyle\textup{SB}} } %{\scriptscriptstyle\mathrm{SB}} }

\newcommand{\pSB}[1]{  p^{\SB}_{#1} }


%Problem setup:
\newcommand{\T}{\mathbb{T}}
\newcommand{\Sp}{\mathbb{S}}
\newcommand{\x}{\mathbf{x}}
\newcommand{\Pssol}{\pi^\star}

%----------------------------------------------------------------------
%%% ACRONYMS
%----------------------------------------------------------------------
\newacro{LHS}{left-hand side}
\newacro{RHS}{right-hand side}
\newacro{iid}[i.i.d.]{independent and identically distributed}
\newacro{lsc}[l.s.c.]{lower semi-continuous}


\newacro{GAN}{generative adversarial network}
\newacro{NN}{neural network}
\newacro{FTRL}{``follow the regularized leader''}
\newacro{wp1}[w.p.$1$]{with probability $1$}


\newacro{SDE}{stochastic differential equation}
\newacro{SB}{Schr\"odinger bridge}
\newacro{GSB}[GSB]{Gaussian Schr\"odinger bridge}


\newacro{SGM}{score-based generative model}

\newacro{SMLD}{score matching with Langevin dynamics}
\newacro{DDPM}{denoising diffusion probabilistic model}

% Named processes and models. The dashes in compound names use \textendash
% throughout, so the source stays ASCII and does not depend on the input encoding.
\newacro{OU}{Ornstein\textendash Uhlenbeck}
\newacro{BM}{Brownian motion}
\newacro{BDT}{Black\textendash Derman\textendash Toy}



\newacro{VESDE}[VE SDE]{variance exploding \ac{SDE}}
\newacro{VPSDE}[VP SDE]{variance preserving \ac{SDE}}
\newacro{DSB}{diffusion Schr\"odinger bridge}
\newacro{IPF}{iterative proportional fitting}

\newmacro{\acroalg}{\textsc{GSBflow}}   % acronym for our overall algorithm
%\newacro{PSD}[p.s.d.]{positive definite}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% THEOREMS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
\theoremstyle{plain}
%\newtheorem{theorem}{Theorem}[section]
%\newtheorem{proposition}[theorem]{Proposition}
%\newtheorem{lemma}[theorem]{Lemma}
%\newtheorem{corollary}[theorem]{Corollary}
%\theoremstyle{definition}
%\newtheorem{definition}[theorem]{Definition}
%\newtheorem{assumption}[theorem]{Assumption}
%\theoremstyle{remark}
%\newtheorem{remark}[theorem]{Remark}

% Todonotes is useful during development; simply uncomment the next line
%    and comment out the line below the next line to turn off comments
%\usepackage[disable,textsize=tiny]{todonotes}
% \usepackage[textsize=tiny]{todonotes}

% Example helper left over from the conference template; the note that used to
% sit here referred to \icmltitle, whereas this document sets its title with \title.
% \swap[<sep>]{<a>}{<b>} expands to <b><sep><a>; <sep> defaults to "-".
\newcommand{\swap}[3][-]{#3#1#2} % just an example

\renewcommand{\thefootnote}{\fnsymbol{footnote}}

\title{Aligned Diffusion Schr\"odinger Bridges (Supplementary Material)}

% The standard author block has changed for UAI 2023 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\author[1,2]{Vignesh Ram Somnath$^*$}
\author[1,3]{Matteo Pariset$^*$}
\author[1]{Ya-Ping Hsieh}
\author[2]{\\Maria Rodriguez Martinez}
\author[1]{Andreas Krause}
\author[1]{Charlotte Bunne}
% Add affiliations after the authors
\affil[1]{%
    Department of Computer Science\\
    ETH Z\"urich
}
\affil[2]{%
    IBM Research Z\"urich
}
\affil[3]{%
    Department of Computer Science\\
    EPFL
  }

\usepackage{xr}
\externaldocument{somnath_658}

  
\begin{document}
\renewcommand\ttdefault{lmtt}

\onecolumn %% Remove this line if a two-column layout is desired for the supplement
\maketitle

\footnotetext[1]{Equal contribution.}

\section{Additional Results} 
\label{app:more_results}

% \subsection{Loss Function Design}
% \label{app:sec:loss_design}

\subsection{Variance Reduction}
\label{app:sec:var_reduction}

In this paragraph, we elaborate on the need to parametrize also Doob's $\doob$ function, along with the drift $\fdrift$.
Introducing $m^{\phi}$ removes the need to evaluate \eqref{eq:softdoob} which is difficult to approximate in practice on high-dimensional spaces. This equation amounts, in fact, to a Gaussian Kernel Density Estimation of the conditional probability $\prob(X_1 = \x_1\vert X_t = \x)$ along (unconditional) paths obtained from \eqref{eq:SB-SDE}.
Faithful approximations of \eqref{eq:softdoob} would, therefore, require:
\begin{itemize}
    \item good-quality paths, which are scarce at the beginning of training when the drift $\fdrift^\theta$ has not yet been learned;
    \item exponentially many trajectories (in the dimension of the state space);
    \item that points $\x_1$ (obtained from conditional trajectories, Eq.~\eqref{eq:SB-SD-conditioned}) be reasonably close to $\x_1$ (obtained from unconditional trajectories, Eq.~\eqref{eq:SB-SDE}).
\end{itemize}
Even if all the above conditions were satisfied, the quantity $\doob(x) = \prob(X_1 = \x_1\vert X_t = \x)$ would still be challenging to directly manipulate. 
It is, in fact, much smaller at earlier times $t$ (see Table \ref{tab:doob_magnitude_evo}), since knowledge of the far past has a weaker influence on the location $X_1$ of particles at time $t=1$. 
Precision errors at $t \approx 0$ would then be amplified when computing the score of $\doob$ (i.e., $\nabla \log \doob$)---which appears in the loss \eqref{eq:loss_modified}---and accumulate over timesteps, eventually leading trajectories astray.
By directly parameterizing the score, we instead sidestep this problem. The magnitude of $m^\phi_t \approx \nabla \log \doob$ can, in fact, be more easily controlled and regularized.
\begin{table}[h]
    \centering
    \begin{tabular}{l|ccccccc}
    \toprule
          & \multicolumn{7}{c}{\textbf{Time $t$}} \\
          & \textbf{0} & \textbf{0.15} & \textbf{0.30} & \textbf{0.45} & \textbf{0.60} & \textbf{0.75} & \textbf{0.90} \\
    \midrule
         \textbf{Mean $\doobs$ value} &  2.92e-14 & 4.03e-13 & 2.54e-11 & 1.72e-09 & 1.47e-07 & 2.66e-05 & 8.53e-03 \\
    \bottomrule
    \end{tabular}
    \caption{Average $\doobs$ values along paths, at different timesteps. $\prob(X_1 = \x_1\vert X_t = \x)$ ranges over 11 orders of magnitude across the time interval and is smallest when $t \approx 0$.}
    \label{tab:doob_magnitude_evo}
\end{table}


\subsection{Rigid Protein Docking}
\label{app:sec:rigid_docking}

\vspace{-2pt}
\para{Baselines} We compare our method to \textsc{EquiDock} as well as traditional docking software including \textsc{Attract}~\citep{attract2017,de2015web}, \textsc{HDock}~\citep{yan2020hdock}, \textsc{ClusPro}~\citep{desta2020performance,kozakov2017cluspro}, and \textsc{PatchDock}~\citep{mashiach2010integrated,schneidman2005patchdock}.
% ~\footnote{\textsc{ClusPro}: \url{https://cluspro.bu.edu/}, \textsc{Attract}: \url{www.attract.ph.tum.de/services/ATTRACT/ATTRACT.vdi.gz}, \textsc{PatchDock}: \url{https://bioinfo3d.cs.tau.ac.il/PatchDock/}, \textsc{HDock}: \url{http://huanglab.phys.hust.edu.cn/software/HDOCK/}}
As described in Section~\ref{app:datasets-rigid_docking}, for ligands in the test set, we generate the corresponding unbound versions by applying the rotation and translation sampled during training. We evaluate the trained model from \textsc{EquiDock} and \textsc{SBalign} on these unbound structures and report corresponding evaluation metrics. For the remaining baselines, we include the numbers reported by \citet{ganea2022independent}. These baselines typically sample several candidate complexes by considering small increments of rotation angles. We expect this to make them somewhat invariant to arbitrary initialization, and the corresponding docking scores not to be severely impacted.

\begin{table}
    \caption{\textbf{Rigid docking results.} Complex and interface RMSD between predicted and true bound structures (after Kabsch alignment). $^*$ denotes methods for which we use values directly from \citet{ganea2022independent}. All other results show the performance on our test set. \vspace{-5pt}
    %As stated in the main text, the proprietary baselines might internally use parts of the test sets (e.g., to extract templates or features), thus their numbers might be optimistic. 
    }
    \label{app:tab:results_docking}
    \centering
    \adjustbox{max width=\linewidth}{%
    \begin{tabular}{lcccccc}
    \toprule
     & \multicolumn{6}{c}{\textbf{DB5.5 Test Set}} \\
     & \multicolumn{3}{c}{Complex RMSD} & \multicolumn{3}{c}{Interface RMSD}  \\
    \cmidrule(lr){2-7}
    \textbf{Methods} & Median & Mean & Std & Median & Mean & Std\\
    \midrule
     \textsc{Attract}$^*$ & 9.55 & 10.09 & 9.88 & 7.48 & 10.69 & 10.90 \\
     \textsc{HDock}$^*$ & 0.30 & 5.34 & 12.04 & 0.24 & 4.76 & 10.83 \\
     \textsc{ClusPro}$^*$& 3.38 & 8.25 & 7.92 & 2.31 & 8.71 & 9.89 \\
     \textsc{PatchDock}$^*$ &  18.26 & 18.00 & 10.12 & 18.88 & 18.75 &  10.06 \\
     \textsc{EquiDock}$^*$ & 14.13 & 14.72 & 5.31 & 11.97 & 13.23 & 4.93 \\
     \cmidrule(lr){1-7}
     \textsc{EquiDock} & 14.12 & 14.73 & 5.31 & 11.97 & 13.23 & 4.93 \\
     \textbf{\textsc{SBalign}} & 6.59 & 6.69 & 2.04 & 7.69 & 8.11 & 2.39 \\
     \bottomrule \vspace{-15pt}
    \end{tabular}
}
\end{table}


\begin{figure}[H]
    \centering
    \includegraphics[width=.9\textwidth]{figures/fig_pred_1NW9.pdf}
    \caption{Ground truth and predicted bound structures for the complex with PDB ID: 1NW9. \textsc{SBalign} is able to identify the true binding pocket.}
    \label{fig:results_docked_1NW9}
\end{figure}

\para{Results} The model performance is summarized in Table~\ref{app:tab:results_docking}. Our method \textsc{SBalign} considerably outperforms \textsc{EquiDock} across all metrics. \textsc{SBalign} also achieves performance comparable to or better than that of traditional docking software without relying on extensive candidate sampling and re-ranking or learning surface templates from parts of the current test set. An example of docked structures, in direct comparison with \textsc{EquiDock}, is displayed in Fig.~\ref{fig:results_docked_1QA9}. Beyond the results in Table~\ref{app:tab:results_docking}, we display the ground truth and docked complexes in Figs.~\ref{fig:results_docked_1QA9},~\ref{fig:results_docked_1NW9}, and~\ref{fig:results_docked_1JIW}.

\begin{figure}[H]
    \centering
    \includegraphics[width=.9\textwidth]{figures/fig_pred_1JIW.pdf}
    \caption{Ground truth and predicted bound structures for the complex with PDB ID: 1JIW. In contrast to \textsc{EquiDock}, \textsc{SBalign} is able to find the true binding interface.}
    \label{fig:results_docked_1JIW}
\end{figure}

\section{Datasets} 
\label{app:datasets}
\subsection{Synthetic Datasets}
In the following, we provide further insights and experimental results in order to assess the performance of \textsc{SBalign} in comparison with different baselines and across tasks of various nature.
For each dataset, we describe in detail its origin as well as preprocessing and featurization steps.

\begin{figure}[ht]
    \centering
    \includegraphics[width=0.45\textwidth]{figures/fig_synthetic_datasets.pdf}
    \caption{Initial (\textit{blue}) and final (\textit{red}) marginals for the two toy datasets \textbf{(a)} moon and \textbf{(b)} T, together with arrows indicating a few alignments}
    \label{fig:synthetic_datasets}
\end{figure}

\para{Moon dataset}
The \texttt{moon} toy dataset (Fig.~\ref{fig:synthetic_datasets}a) is generated by first sampling $\distend$ and then applying a clockwise rotation of $233^\circ$ around the origin to obtain $\distinit$.
The points on the two semi-circumferences supporting $\distend$ are initially placed equally-spaced along each semi-circumference and then moved by applying additive Gaussian noise to both coordinates. While classic generative models will choose the shortest path and connect ends of both moons closest in Euclidean distance, only methods equipped with additional knowledge or insight on the intended alignment will be able to solve this task.

\para{T dataset}
This toy dataset (Fig.~\ref{fig:synthetic_datasets}b) is generated by placing an equal number of samples at each of the four extremes of a T-shaped area having a ratio between the \textit{x} and \textit{y} dimensions equal to 51/55.
If run with a Brownian prior, classical \acp{SB} also fail on this dataset because they produce swapped pairings: i.e., they match the left (\textit{resp.} top) point cloud with the bottom (\textit{resp.} right) one.
At the same time, though, this dataset prevents reference drifts with simple analytical forms (such as spatially-symmetric or time-constant functions) from fixing classical \acp{SB} runs. It therefore illustrates the need for general, plug-and-play methods capable of generating approximate reference drifts to use in the computation of classical \acp{SB}.

\subsection{Cell Differentiation Datasets}
\label{app:datasets-cell_differentiation}
\begin{figure}[ht]
    \centering
    \includegraphics[width=0.65\linewidth]{figures/fig_overview_cells.pdf}
    \caption{Overview of \textsc{SBalign} in the setting of cell differentiation with the goal of learning the evolutionary process that morphs a population from its state at $t$ to that at $t+1$. Through genetic tagging (i.e., barcodes) we are able to trace progenitor cells at time point $t$ into their descendants at $t+1$. This provides us with an alignment between populations at consecutive time steps. Our goal is then to recover a stochastic trajectory from $\x_0$ to $\x_1$. To achieve this, we connect the characterization of an SDE conditioned on $\x_0$ and $\x_1$ (utilizing Doob's \emph{$h$-transform}) with that of a Brownian bridge between $\x_0$ and $\x_1$ (classical Schr\"odinger bridge theory), leading to a simpler training procedure with lower variance and strong empirical results.}
    \label{fig:overview_cells}
\end{figure}

\para{Dataset description}
We obtain the datapoints used in our cell differentiation task from the dataset generated by \cite{weinreb2020lineage}, which contains 130861 observations/cells.
We follow the preprocessing steps in \citet{bunne2021learning} and use the Python package \texttt{scanpy} \citep{wolf2018scanpy}. After processing, each observation records the level of expression of 1622 different highly-variable genes as well as the following meta information per cell:
\begin{itemize}
    \item a \texttt{timestamp}, expressed in days and taking values in \{2, 4, 6\};
    \item a \texttt{barcode}, which is a short DNA sequence that allows tracing the identity of cells and their lineage by means of single-cell sequencing readouts;
    \item an additional \texttt{annotation}, which describes the current differentiation fate of the cell.
\end{itemize}
\begin{figure}[ht]
    \centering
    \includegraphics[width=.7\textwidth]{figures/fig_cell_marginals.pdf}
    \caption{Distribution of the cell population (i.e., marginals) at time $t=t_0$ and $t=t_1$ for (\textbf{a}) the ground truth, and (\textbf{b}) \textsc{SBalign}, after projection along their first two principal components.}
    \label{fig:results_cell_marginals}
\end{figure}


\para{Dataset preparation}
We only retain cells with barcodes that appear both on days 2 and 4, taking care of excluding cells that are already differentiated on day 2.
We construct matchings by pairing cells measured at two different times but which share the barcode. 
Additionally, we filter cells to make sure that no cell appears in more than one pair.
To reduce the very high dimensionality of these datapoints, we perform a PCA projection down to 50 components.

We end up with a total of 4702 pairs of cells, which we partition into train, validation, and test sets according to the split 80\%/10\%/10\%.

\subsection{Protein Docking}

\subsubsection{Conformational Changes}
\label{app:datasets-protein_conf}

\para{Dataset description} For the task of predicting protein conformational changes, we utilize the D3PM dataset. The dataset consists of both unbound and bound structures for 4330 proteins, under different types of protein motions. The PDB IDs were downloaded from \url{https://www.d3pharma.com/D3PM/}. For the PDB IDs making up the dataset, we download the corresponding (.cif) files from the Protein Data Bank. 

\para{Dataset preparation} For the scope of this work, we only focus on protein structure pairs where the provided RMSD between the C$\alpha$ carbon atoms is $>3\,$\r{A}, amounting to 2370 examples in the D3PM dataset. For each pair of structures, we first identify common residues, and compute the RMSD between C$\alpha$ carbon atoms of the common residues after superimposing them using the Kabsch algorithm \citep{kabsch1976solution}, and only accept the structure pair if the computed C$\alpha$ RMSD is within a certain margin of the provided C$\alpha$ RMSD. The rationale behind this step was to only retain examples where we could reconstruct the RMSD values provided with the dataset. Common residues are identified through a combination of residue position and name. This step is, however, prone to experimental errors, and we leave it to future work to improve the common residue identification step (potentially using a combination of common subsequences and/or residue positions).

After applying the above preprocessing steps, we obtain a dataset with 1591 examples, which is then split into a train/valid/test split of 1291/150/150 examples respectively. The structures used in training and inference are the Kabsch superimposed versions, therefore ensuring that the Brownian bridges are sampled between the unbound and bound states of the proteins, and no artifacts are introduced by 3D rotations and translations, which do not contribute to conformational changes.

\para{Featurization} Following standard practice and for memory and computational efficiency, we only use the C$\alpha$ coordinates of the residues to represent our protein structures instead of the full-atom structures. For each amino acid residue, we compute the following features: a one hot encoding of the amino acid identity $f_e$ of size $23$, hydrophobicity $f_h \in [-4.5, 4.5]$, volume $f_v \in [60.1, 227.8]$, the charge $f_c \in \{-1, 0, 1\}$, polarity $f_p \in \{0, 1\}$, and whether the amino acid residue is a hydrogen bond donor $f_d \in \{0, 1\}$ or acceptor $f_a \in \{0, 1\}$. The hydropathy and volume features are expanded into a radial basis with interval sizes $0.1$ and $10$ respectively. To equip the model with a notion of time, we use a sinusoidal embedding of time $\phi(t)$ of embedding dimensionality $32$. These are concatenated to the amino acid features to form our input features for the amino acid residues. The edge features consist of a radial basis expansion of the distances between the residues. We also compute the spherical harmonics of the edge vectors between the residues, which is used in the tensor product message passing layers.

\para{Position at t} For any time $t$, we sample the positions of the C$\alpha$ atoms using the Brownian bridge: given the coordinates $\x_0$ at $t=0$ and the coordinates $\x_1$ at $t=1$, the Brownian bridge between $\x_0$ and $\x_1$ satisfies $\x_t \sim \mathcal{N}\left((1-t)\x_0 + t\x_1,\, t(1-t)\right)$.

\subsubsection{Rigid Protein Docking}
\label{app:datasets-rigid_docking}

\para{Dataset description} We use the DB5.5 dataset \citep{vreven2015updates} for our empirical evaluation. The DB5.5 dataset is a standard dataset used in protein-protein docking, however, it only has 253 complexes. The dataset was downloaded from \url{https://zlab.umassmed.edu/benchmark/}. The dataset consists of both unbound and bound structures, but the structures are largely rigid, with an average complex RMSD of 0.96 between the bound and unbound structures. We utilize the same splits as EquiDock \citep{ganea2022independent}, with 203 complexes in the training set, 25 complexes in the validation set, and 25 complexes in the test set. For the evaluation in Table~\ref{app:tab:results_docking}, we use the full DB5.5 test set. For ligands in the test set, we generate the corresponding unbound versions by applying the rotation and translation sampled during training.

\para{Dataset preparation} Following a convention similar to that of \textsc{EquiDock} \citep{ganea2022independent}, we treat the receptor as fixed. We use the same splits as \textsc{EquiDock} \citep{ganea2022independent}, with 203 complexes in the training set and 25 complexes each in the validation and test sets. For each ligand, the final 3D structure corresponds to its bound version, and the unbound version is generated by applying a random rotation $R$ and translation $b$ to the bound version. However, applying a different rotation and translation to each ligand would result in a different Brownian bridge, thus providing limited learning signal for the drift $\fdrift^{\theta}$. To avoid this, we create a rotation matrix $R$ during training by sampling a random angle between $30^{\circ}$ and $45^{\circ}$ along each axis, and a translation $b$ with a maximum magnitude between $5.0$ and $10.0$. The same $R$ and $b$ are also applied to the validation and test sets. We leave it to future work to extend the algorithm to work for arbitrary rotations and translations.

\para{Featurization} Following standard practice and for memory and computational efficiency, we only use the C$\alpha$ coordinates of the residues to represent our protein structures instead of the full-atom structures. For each amino acid residue, we compute the following features: a one hot encoding of the amino acid identity $f_e$ of size $23$, hydrophobicity $f_h \in [-4.5, 4.5]$, volume $f_v \in [60.1, 227.8]$, the charge $f_c \in \{-1, 0, 1\}$, polarity $f_p \in \{0, 1\}$, and whether the amino acid residue is a hydrogen bond donor $f_d \in \{0, 1\}$ or acceptor $f_a \in \{0, 1\}$. The hydropathy and volume features are expanded into a radial basis with interval size $0.1$ and $10$ respectively. To equip the model with a notion of time, we use a sinusoidal embedding of time $\phi(t)$ of embedding dimensionality $32$. These are concatenated to the amino acid features to form our input features for the amino acid residues.

\para{Position at t} For any time $t$, we sample the positions of the C$\alpha$ atoms using the Brownian bridge: given the coordinates $\x_0$ at $t=0$ and the coordinates $\x_1$ at $t=1$, the Brownian bridge between $\x_0$ and $\x_1$ satisfies $\x_t \sim \mathcal{N}\left((1-t)\x_0 + t\x_1,\, t(1-t)\right)$.

\section{Experimental Details}

In the following, we provide further experimental details on the chosen evaluation metrics, network architectures, and hyperparameters.

\subsection{Evaluation Metrics} 
\label{app:metrics}

\subsubsection{Cell Differentiation}
\label{app:metrics-cell_differentation}
For fairness of comparison between our method and the baseline (\textsc{fbSB}) ---which only works at the level of distribution of cells--- we also consider three evaluation metrics (i.e., $\We$, MMD and $\ell_2$) that capture the similarity between the end marginal $\distend$ and our prediction $\pi^\star_1$, irrespective of matchings.

In what follows, we denote with $\hat{\nu}$ the predicted end marginal $\pi^\star_1$ ---i.e., the predicted status of cells at day 4--- and with $\nu$ the distribution of observed transcriptomes.

\para{Wasserstein-2 distance} We measure accuracy of the predicted target population $\hat{\nu}$ to the observed target population $\nu$ using the entropy-regularized Wasserstein distance \citep{cuturi2013sinkhorn} provided in the \texttt{OTT} library \citep{jax2018github,cuturi2022optimal} defined as
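\begin{equation}\label{eq:reg-ot}
    \We(\hat{\nu}, \nu) \defeq \min_{\bP \in U(\hat{\nu},\nu)} \left\langle \bP, \left[\lVert x_i - y_j \rVert^2\right]_{ij} \right\rangle - \varepsilon H(\bP),
\end{equation}
where $x_i$ and $y_j$ denote the support points of $\hat{\nu}$ and $\nu$, respectively, $\varepsilon > 0$ is the regularization strength, $H(\bP) \defeq -\sum_{ij} \bP_{ij} (\log \bP_{ij} - 1)$ and the polytope $U(\hat{\nu},\nu)$ is the set of $n\times m$ matrices $\{\bP\in\mathbb{R}^{n \times m}_+, \bP\mathbf{1}_m = \hat{\nu}, \bP^\top\mathbf{1}_n=\nu\}$.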

\para{Maximum mean discrepancy} Kernel maximum mean discrepancy~\citep{gretton2012kernel} is another metric to measure distances between distributions, i.e., in our case between predicted population $\hat{\nu}$ and observed one $\nu$.
Given two random variables $x$ and $y$ with distributions $\hat{\nu}$ and $\nu$, and a kernel function $\omega$, \citet{gretton2012kernel} define the squared MMD as:
\begin{equation*}
    \text{MMD}(\hat{\nu},\nu; \omega) = \mathbb{E}_{x,x^\prime}[\omega(x, x^\prime)] + \mathbb{E}_{y,y^\prime}[\omega(y, y^\prime)] - 2\mathbb{E}_{x,y}[\omega(x, y)].
\end{equation*}
We report an unbiased estimate of $\text{MMD}(\hat{\nu},\nu)$, in which the expectations are evaluated by averages over the population particles in each set. We utilize the RBF kernel, and as is usually done, report the MMD as an average over the length scales: $2, 1, 0.5, 0.1, 0.01, 0.005$.

\para{Perturbation signature $\ell_2$}
A common method to quantify the effect of a perturbation on a population is to compute its perturbation signature \citep[PS;][]{stathias2018drug}, computed via the difference in means between the distributions of perturbed states and control states of each feature, e.g., here individual genes. $\ell_2$(PS) then refers to the $\ell_2$-distance between the perturbation signatures computed on the observed and predicted distributions, $\nu$ and $\hat{\nu}$. The PS is defined as
\begin{equation*}
    \text{PS}(\nu, \mu) = \frac{1}{m}\sum_{y_i \in \nu}{y_i} - \frac{1}{n}\sum_{x_i \in \mu}{x_i},
\end{equation*}
where $n$ is the size of the unperturbed and $m$ of the perturbed population.
We report the $\ell_2$ distance between the observed signature $\text{PS}(\nu, \mu)$ and the predicted signature $\text{PS}(\hat{\nu}, \mu)$, which is equivalent to simply computing the difference in the means between the observed and predicted distributions.

\para{RMSD}
To measure the quality of matchings sampled from \textsc{SBalign} $(\hat{x}^i_0, \hat{x}^i_1)$ ---compared to the observed ones $(x^i_0, x^i_1)$--- we compute:
\begin{equation}
    \text{RMSD}(\{x^i_1\}^n,\{\hat{x}^i_1\}^n) = \sqrt{\frac{1}{n}\sum^n_{i=1} \lVert x^i_1 - \hat{x}^i_1\rVert^2}
\end{equation}
which, when squared, represents the mean of the square norm of the differences between predicted and observed statuses of the cells on day 4.

\para{Cell type classification accuracy}
We assess the quality of \textsc{SBalign} trajectories by trying to predict the differentiation fate of cells, starting from (our compressed representation of) their transcriptome.
For this, we train a simple MLP-based classifier on observed cells and use it on the last time-frame of trajectories sampled from \textsc{SBalign} to infer the differentiation of cells on day 4.
We use the classifier \texttt{MLPClassifier} offered by the library \texttt{scikit-learn} with the following parameters:
\begin{itemize}
    \item 2 hidden layers, each with a hidden dimension of 50;
    \item the \textit{logistic} function as non-linearity;
    \item $\ell_2$-norm regularization with coefficient $0.1$.
\end{itemize}

We report the subset accuracy of the predictions on the \textit{test} set, measured as the fraction of predicted labels (i.e., cell types) coinciding with the ground truth.

\subsubsection{Rigid Protein Docking}
\label{app:metrics-rigid_docking}

We report two metrics, Complex Root Mean Square Deviation (Complex RMSD), and Interface Root Mean Square Deviation (Interface RMSD). Following \citet{ganea2022independent}, the ground truth and predicted complex structures are first superimposed using the Kabsch algorithm \citep{kabsch1976solution}, and the Complex RMSD is then computed between the superimposed versions. A similar procedure is used for computing Interface RMSD, but only using the residues from the two proteins that are within $8\,$\r{A} of each other. Given a ligand with $m$ residues and a receptor with $n$ residues, we denote the predicted bound structure with $\mathbf{Z}' \in \R^{3 \times (n + m)}$ and the ground truth bound structure with $\mathbf{Z}^{*} \in \R^{3 \times (n + m)}$. We first superimpose the predicted and ground truth bound structures using the Kabsch algorithm and then compute the Complex RMSD as $\text{C}_{\text{rmsd}} = \sqrt{\frac{1}{n+m}\norm*{\mathbf{Z}' - \mathbf{Z}^{*}}_F^2}$. The Interface RMSD $\text{I}_{\text{rmsd}}$ is computed similarly, but only using the residues from the two proteins that are within $8\,$\r{A} of each other.

\subsection{Network Architectures}
\label{app:architectures}

\subsubsection{Cell Differentiation and Synthetic Datasets}
\label{app:architectures-cell_diff_synthetic}

We parameterize both $b^\theta(t, X_t)$ and $m^\phi(t, b_t, X_t)$ using a model composed of:
\begin{enumerate}
    \item \textbf{\texttt{x\_enc}}: 3-layer MLP performing the expansion of spatial coordinates (or drift) into hidden states (of dimension 64 to 256);
    \item \textbf{\texttt{t\_enc}}: sinusoidal embedding of time (on 64 to 256 dimensions), followed by a two layer MLP;
    \item \textbf{\texttt{mlp}}: 3-layer MLP which maps the concatenation of embedded spatial and temporal information (output of modules 1 and 2 above) to drift magnitude values along each dimension.
\end{enumerate}

After every linear layer (except the last one), we apply a non-linearity and dropout (level 0.1).
In all the experiments, we set the diffusivity function $g(t)$ in \eqref{eq:SB-SDE} to a constant $g$, which is optimized (see \S~\ref{app:hyperparams}).

\subsubsection{Protein Conformational Changes}
\label{app:architectures-protein_conf}

As the architecture for $b^{\theta}_t(X_t)$, suitable for approximating the true drift $b_t$, we construct a graph neural network with tensor-product message passing layers using \texttt{e3nn} \citep{thomas2018tensor, geiger2022e3nn}. To build the graph, we consider a maximum of 40 neighbors---located within a radius of 40\,\r{A}---for each residue. The model is SE(3) equivariant and receives node and edge features capturing relevant residue properties, and distance embeddings. For the baseline \textsc{EGNN} model, we consider the variant proposed by \citet{xu2022geodiff}, owing to its strong performance on the molecule conformer generation task.

\subsubsection{Rigid Protein Docking}
\label{app:architectures-rigid_docking}

For the scope of this paper, we use an MLP for both $\fdrift^{\theta}$ and $m^{\phi}$. As inputs, both $\fdrift^{\theta}$ and $m^{\phi}$ receive input node features and the C$\alpha$ coordinates at time $t$, as described in Section~\ref{app:datasets-rigid_docking}, with $m^{\phi}$ receiving the prediction $\fdrift^{\theta}$ as additional input. Both models have $3$ hidden layers, each with a dimension of $64$ and an output dimension of $3$, with around 50K parameters in total. Our current architectures are not equivariant to global rotations and translations, which is a desirable property in protein docking as the structures of the proteins themselves are invariant to the choice of reference coordinate frames. We leave a thorough exploration of other architectures, such as equivariant GNN architectures similar to those adopted by \citet{ganea2022independent}, to future work.


\subsection{Hyperparameters}
\label{app:hyperparams}

In the following, we will provide an overview of the selected hyperparameters as well as chosen training procedures.

\subsubsection{Synthetic Tasks}
\label{app:hyperparams-synthetic}

We perform hyper-parameter optimization using the Python package \texttt{ray.tune} \citep{liaw2018tune} on:
\begin{itemize}
    \item \textbf{activation}, chosen among \texttt{leaky\_relu}, \texttt{relu}, \texttt{selu} and \texttt{silu} as implemented in the Python library \texttt{PyTorch} \citep{pytorch2019paszke}. We find \texttt{selu} to achieve marginally better performance on toy datasets.
    \item \textbf{g}, the value of the diffusivity constant, chosen among $\{1, 2, 5, 10\}$. We find $g=1$ to yield optimal results.
\end{itemize}

\subsubsection{Cell Differentiation}
\label{app:hyperparams-cell_differentiation}

We perform hyper-parameter optimization using the Python package \texttt{ray.tune} \citep{liaw2018tune} on:
\begin{itemize}
    \item \textbf{activation}, chosen among \texttt{leaky\_relu}, \texttt{relu}, \texttt{selu}, and \texttt{silu} as implemented in the Python library \texttt{PyTorch} \citep{pytorch2019paszke}. We observe that \texttt{silu} brings noticeable performance improvements on the cell differentiation dataset.
    \item \textbf{g}, the value of the diffusivity constant, chosen among $\{0.01, 0.1, 0.8, 1, 1.2, 2, 5\}$. We find $g=1$ to yield optimal results.
\end{itemize}

\subsubsection{Protein Conformational Changes}
\label{app:hyperparams-protein_conf}

We use \texttt{AdamW} as our optimizer with an initial learning rate of $0.001$, and training batch size of $2$. For each protein pair, we sample $10$ timepoints in every epoch, so the model sees realizations from different timepoints of the corresponding Brownian Bridge. This was done to improve the training speed. We use a regularization strength of $1.0$ for $m^{\phi}$ for all $t$. Inference on the validation set during training is carried out using the exponential moving average of parameters, and the moving average is updated every optimization step with a decay rate of $0.9$. The model training is set to a maximum of $1000$ epochs, but training is typically stopped after $200$ epochs, beyond which no improvements in the validation metrics are observed.

Our model has 0.54M parameters and is trained for 200 epochs. After every epoch, we simulate trajectories on the validation set using our model and compute the mean RMSD. The best model selected using this procedure is used for inference on the test set. The baseline \textsc{EGNN} model has 0.76M parameters and is trained for 1000 epochs.

\subsubsection{Rigid Protein Docking}
\label{app:hyperparams-rigid-docking}

We use \textsc{Adam} as our optimizer with a learning rate of $0.001$, and training batch size of $2$. For each ligand, we sample $5$ timepoints during every training epoch so that the model is exposed to different timepoints from the corresponding Brownian Bridge for each ligand. This number was chosen as a tradeoff between CUDA memory and coverage of timepoints between $0$ and $1$. We use a regularization strength of $1.0$ for $m^{\phi}$ for all $t$. Inference on the validation set during training is carried out using the exponential moving average of parameters, and the moving average is updated every optimization step with a decay rate of $0.999$. The model training is set to a maximum of $1000$ epochs, but training is typically stopped after $100$ epochs, beyond which no improvements in the validation metrics are observed.

\section{Reproducibility}
Code utilized in this publication can be found at \url{https://github.com/vsomnath/aligned_diffusion_bridges}, with a mirror at \url{https://github.com/IBM/aligned_diffusion_bridges}.

\bibliography{somnath_658}

\end{document}