\documentclass[accepted]{uai2024} % after acceptance, for a revised version; 

\usepackage[american]{babel}
\usepackage[utf8]{inputenc} % allow utf-8 input
\usepackage[T1]{fontenc}    % use 8-bit T1 fonts
\usepackage{hyperref}       % hyperlinks
\usepackage{url}            % simple URL typesetting
\usepackage{booktabs}       % professional-quality tables
\usepackage{amsfonts}       % blackboard math symbols
\usepackage{nicefrac}       % compact symbols for 1/2, etc.
\usepackage{microtype}      % microtypography
\usepackage{xcolor}         % colors
\usepackage{pdfpages}       % plots
\usepackage{graphicx}       % plots
\usepackage{bm}             % bolds in math mode
\usepackage{tikz}           % check marks
\usepackage{algorithm}      % algorithms
\usepackage{amsmath}        % add math
\usepackage{algpseudocode}   % more algorithms
\usepackage{colortbl} 
     % allow colors for tables
\usepackage{amsmath,amssymb,amsthm,amsfonts,amsbsy,latexsym,dsfont,tikz}
\usepackage{booktabs}
%\usepackage{enumerate}
\usepackage[separate-uncertainty=true]{siunitx}
%\usepackage{wrapfig,lipsum}
\usepackage{comment}
\usepackage{capt-of}% or \usepackage{caption}
\usepackage{varwidth}
\usepackage{natbib}
\usepackage{cuted}
%\usepackage{dblfloatfix}
\usepackage{enumitem}
\usepackage{float}


% Enumerate variants with a fixed, upright label format.
% The preamble loads `enumitem' (without the `shortlabels' option), which
% parses the optional argument of `enumerate' as key=value pairs; the
% `enumerate'-package shorthand such as [(i)] is therefore not understood
% and the label format has to be given through the `label=' key.
\newenvironment{enumeratei}{\begin{enumerate}[label=\upshape(\roman*)]}{\end{enumerate}}  % (i), (ii), ...
\newenvironment{enumeratea}{\begin{enumerate}[label=\upshape(\alph*)]}{\end{enumerate}}   % (a), (b), ...
\newenvironment{enumeraten}{\begin{enumerate}[label=\upshape\arabic*.]}{\end{enumerate}}  % 1., 2., ...
\newenvironment{enumerateA}{\begin{enumerate}[label=\upshape(\Alph*)]}{\end{enumerate}}   % (A), (B), ...

% Column type P{<width>}: a paragraph column of the given width (p{<width>})
% whose contents are centred horizontally.  \arraybackslash restores \\ as
% the row terminator after \centering has redefined it.
\newcolumntype{P}[1]{>{\centering\arraybackslash}p{#1}}


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%    
%% Local environments
% Every environment below is declared while the `remark' theorem style is
% active.  `thm' is numbered within sections, and all environments declared
% with the optional argument [thm] share its counter.
\theoremstyle{remark}
\newtheorem{thm}{Theorem}[section]
\newtheorem{lem}[thm]{Lemma}
\newtheorem{cor}[thm]{Corollary}
\newtheorem{prop}[thm]{Proposition}
\newtheorem{defn}[thm]{Definition}
\newtheorem{rem}[thm]{Remark}
\newtheorem{hyp}[thm]{Hypothesis}
\newtheorem{ass}[thm]{Assumption}
% Exercises have their own counter, reset at every section.  The optional
% argument must come AFTER the title for this: placed before the title it
% names a counter to share, so [section] there would make every exercise
% step the section counter itself.
\newtheorem{exc}{Exercise}[section]
\newtheorem{ex}[thm]{Example}
\newtheorem{conj}[thm]{Conjecture}
%% End Local environments
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Local macros
% Use the slanted inequality signs wherever \leq / \geq appear.
\renewcommand{\leq}{\leqslant}
\renewcommand{\geq}{\geqslant}
% Symbol shorthands.
\newcommand{\eset}{\varnothing}
\newcommand{\ra}{\rangle}
\newcommand{\la}{\langle}
\newcommand{\wt}{\widetilde}
\newcommand{\pms}{\{-1,1\}}
\newcommand{\ind}{\mathds{1}}
\newcommand{\eps}{\varepsilon}
\newcommand{\To}{\longrightarrow}
% Delimiters that are auto-sized around their argument.
\newcommand{\norm}[1]{\left\Vert#1\right\Vert}
\newcommand{\abs}[1]{\left\vert#1\right\vert}
\newcommand{\set}[1]{\left\{#1\right\}}
\newcommand{\goesto}{\longrightarrow}
% Double-stroke R and C (dsfont).
\newcommand{\Real}{\mathds{R}}
\newcommand{\Complex}{\mathds{C}}
% Emphasised text abbreviations; \ie and \eg already contain the trailing comma.
\newcommand{\ie}{\emph{i.e.,}}
\newcommand{\as}{\emph{a.e.}}
\newcommand{\eg}{\emph{e.g.,}}
\renewcommand{\iff}{\Leftrightarrow}
% Relations with an upright letter stacked on top: "=" with d, and a long
% arrow with P or w.
\newcommand{\equald}{\stackrel{\mathrm{d}}{=}}
\newcommand{\probc}{\stackrel{\mathrm{P}}{\longrightarrow}}
\newcommand{\weakc}{\stackrel{\mathrm{w}}{\longrightarrow}}
% Partial-derivative fractions:
%   \fpar{f}{x}    -> first derivative of f in x
%   \spar{f}{x}    -> second derivative of f in x
%   \mpar{f}{x}{y} -> mixed second derivative of f in x and y
\newcommand{\fpar}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\spar}[2]{\frac{\partial^2 #1}{\partial #2^2}}
\newcommand{\mpar}[3]{\frac{\partial^2 #1}{\partial #2 \partial #3}}
% End-of-proof mark pushed to the right margin.  \def performs no check, so
% this silently replaces any \qed defined earlier (amsthm is loaded above).
\def\qed{ \hfill $\blacksquare$}
% Upright d and e.
\newcommand{\rd}{\mathrm{d}}
\newcommand{\re}{\mathrm{e}}
% Divergence operator.  \operatorname (amsmath) typesets an upright operator
% name with operator spacing that also shrinks correctly in sub/superscripts;
% an \mbox does neither.
\renewcommand{\div}{\operatorname{div}}
% Check mark drawn as a filled TikZ path.  As with \qed, \def replaces any
% earlier definition of \checkmark without warning.
\def\checkmark{\tikz\fill[scale=0.4](0,.35) -- (.25,0) -- (1,.7) -- (.25,.15) -- cycle;}
%% End Local macros
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Greek symbols
% Short aliases \g<letter> for Greek letters.  They are created with \let,
% so each alias is a copy of the symbol's meaning at this point in the
% preamble.
% NOTE(review): \geqe looks like a search-and-replace artefact of an
% epsilon alias (compare \mvgee below); the name is kept as is because the
% rest of the document may use it -- confirm.
\let\ga=\alpha \let\gb=\beta \let\gc=\gamma \let\gd=\delta \let\geqe=\epsilon
\let\gf=\varphi \let\gh=\eta \let\gi=\iota \let\gk=\kappa \let\gl=\lambda
\let\gm=\mu \let\gn=\nu \let\go=\omega \let\gp=\pi \let\gr=\rho
\let\gs=\sigma \let\gt=\tau \let\gth=\vartheta
\let\gx=\chi \let\gy=\upsilon \let\gz=\zeta
\let\gC=\Gamma \let\gD=\Delta \let\gF=\Phi \let\gL=\Lambda \let\gTh=\Theta
\let\gO=\Omega \let\gP=\Pi \let\gPs=\Psi \let\gS=\Sigma \let\gU=\Upsilon
% Standard LaTeX has no \Chi (capital chi is set with the Latin letter X),
% and \let-ing an alias to an undefined command leaves the alias undefined.
% Copy \Chi only if some package provides it; otherwise fall back to an
% upright X.
\ifdefined\Chi
  \let\gX=\Chi
\else
  \newcommand{\gX}{\mathrm{X}}
\fi
\let\gY=\Upsilon
%% End Greek Symbols
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% MathCal symbols
% \mc{<arg>} sets its argument in \mathcal; \cA ... \cZ are one-letter
% shortcuts for the calligraphic capitals.
\newcommand{\mc}[1]{\mathcal{#1}}
\newcommand{\cA}{\mathcal{A}}
\newcommand{\cB}{\mathcal{B}}
\newcommand{\cC}{\mathcal{C}}
\newcommand{\cD}{\mathcal{D}}
\newcommand{\cE}{\mathcal{E}}
\newcommand{\cF}{\mathcal{F}}
\newcommand{\cG}{\mathcal{G}}
\newcommand{\cH}{\mathcal{H}}
\newcommand{\cI}{\mathcal{I}}
\newcommand{\cJ}{\mathcal{J}}
\newcommand{\cK}{\mathcal{K}}
\newcommand{\cL}{\mathcal{L}}
\newcommand{\cM}{\mathcal{M}}
\newcommand{\cN}{\mathcal{N}}
\newcommand{\cO}{\mathcal{O}}
\newcommand{\cP}{\mathcal{P}}
\newcommand{\cQ}{\mathcal{Q}}
\newcommand{\cR}{\mathcal{R}}
\newcommand{\cS}{\mathcal{S}}
\newcommand{\cT}{\mathcal{T}}
\newcommand{\cU}{\mathcal{U}}
\newcommand{\cV}{\mathcal{V}}
\newcommand{\cW}{\mathcal{W}}
\newcommand{\cX}{\mathcal{X}}
\newcommand{\cY}{\mathcal{Y}}
\newcommand{\cZ}{\mathcal{Z}}
%% End MathCal symbols
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Math Boldface Symbols
% Upright bold (\mathbf) shortcuts: \vzero ... \vnine for the digits,
% \vA ... \vZ for capitals and \va ... \vz for lower-case letters.
% -- digits --
\newcommand{\vzero}{\mathbf{0}}
\newcommand{\vone}{\mathbf{1}}
\newcommand{\vtwo}{\mathbf{2}}
\newcommand{\vthree}{\mathbf{3}}
\newcommand{\vfour}{\mathbf{4}}
\newcommand{\vfive}{\mathbf{5}}
\newcommand{\vsix}{\mathbf{6}}
\newcommand{\vseven}{\mathbf{7}}
\newcommand{\veight}{\mathbf{8}}
\newcommand{\vnine}{\mathbf{9}}
% -- capitals --
\newcommand{\vA}{\mathbf{A}}
\newcommand{\vB}{\mathbf{B}}
\newcommand{\vC}{\mathbf{C}}
\newcommand{\vD}{\mathbf{D}}
\newcommand{\vE}{\mathbf{E}}
\newcommand{\vF}{\mathbf{F}}
\newcommand{\vG}{\mathbf{G}}
\newcommand{\vH}{\mathbf{H}}
\newcommand{\vI}{\mathbf{I}}
\newcommand{\vJ}{\mathbf{J}}
\newcommand{\vK}{\mathbf{K}}
\newcommand{\vL}{\mathbf{L}}
\newcommand{\vM}{\mathbf{M}}
\newcommand{\vN}{\mathbf{N}}
\newcommand{\vO}{\mathbf{O}}
\newcommand{\vP}{\mathbf{P}}
\newcommand{\vQ}{\mathbf{Q}}
\newcommand{\vR}{\mathbf{R}}
\newcommand{\vS}{\mathbf{S}}
\newcommand{\vT}{\mathbf{T}}
\newcommand{\vU}{\mathbf{U}}
\newcommand{\vV}{\mathbf{V}}
\newcommand{\vW}{\mathbf{W}}
\newcommand{\vX}{\mathbf{X}}
\newcommand{\vY}{\mathbf{Y}}
\newcommand{\vZ}{\mathbf{Z}}
% -- lower case --
\newcommand{\va}{\mathbf{a}}
\newcommand{\vb}{\mathbf{b}}
\newcommand{\vc}{\mathbf{c}}
\newcommand{\vd}{\mathbf{d}}
\newcommand{\ve}{\mathbf{e}}
\newcommand{\vf}{\mathbf{f}}
\newcommand{\vg}{\mathbf{g}}
\newcommand{\vh}{\mathbf{h}}
\newcommand{\vi}{\mathbf{i}}
\newcommand{\vj}{\mathbf{j}}
\newcommand{\vk}{\mathbf{k}}
\newcommand{\vl}{\mathbf{l}}
\newcommand{\vm}{\mathbf{m}}
\newcommand{\vn}{\mathbf{n}}
\newcommand{\vo}{\mathbf{o}}
\newcommand{\vp}{\mathbf{p}}
\newcommand{\vq}{\mathbf{q}}
\newcommand{\vr}{\mathbf{r}}
\newcommand{\vs}{\mathbf{s}}
\newcommand{\vt}{\mathbf{t}}
\newcommand{\vu}{\mathbf{u}}
\newcommand{\vv}{\mathbf{v}}
\newcommand{\vw}{\mathbf{w}}
\newcommand{\vx}{\mathbf{x}}
\newcommand{\vy}{\mathbf{y}}
\newcommand{\vz}{\mathbf{z}}
%% End Math Boldface Symbols
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
%% Math Bold Symbols commands
% \mv{<arg>} sets its argument with \boldsymbol (bold that keeps the math
% italic shape and also works for Greek letters).  The remaining macros are
% fixed shortcuts: \mvzero ... \mvnine, \mvA ... \mvZ, \mva ... \mvz, then
% Greek letters under two naming schemes (\mvg<abbrev> and \mv<fullname>).
\newcommand{\mv}[1]{\boldsymbol{#1}}\newcommand{\mvinfty}{\boldsymbol{\infty}}\newcommand{\mvzero}{\boldsymbol{0}}
\newcommand{\mvone}{\boldsymbol{1}}\newcommand{\mvtwo}{\boldsymbol{2}}\newcommand{\mvthree}{\boldsymbol{3}}
\newcommand{\mvfour}{\boldsymbol{4}}\newcommand{\mvfive}{\boldsymbol{5}}\newcommand{\mvsix}{\boldsymbol{6}}
\newcommand{\mvseven}{\boldsymbol{7}}\newcommand{\mveight}{\boldsymbol{8}}\newcommand{\mvnine}{\boldsymbol{9}}
\newcommand{\mvA}{\boldsymbol{A}}\newcommand{\mvB}{\boldsymbol{B}}\newcommand{\mvC}{\boldsymbol{C}}
\newcommand{\mvD}{\boldsymbol{D}}\newcommand{\mvE}{\boldsymbol{E}}\newcommand{\mvF}{\boldsymbol{F}}
\newcommand{\mvG}{\boldsymbol{G}}\newcommand{\mvH}{\boldsymbol{H}}\newcommand{\mvI}{\boldsymbol{I}}
\newcommand{\mvJ}{\boldsymbol{J}}\newcommand{\mvK}{\boldsymbol{K}}\newcommand{\mvL}{\boldsymbol{L}}
\newcommand{\mvM}{\boldsymbol{M}}\newcommand{\mvN}{\boldsymbol{N}}\newcommand{\mvO}{\boldsymbol{O}}
\newcommand{\mvP}{\boldsymbol{P}}\newcommand{\mvQ}{\boldsymbol{Q}}\newcommand{\mvR}{\boldsymbol{R}}
\newcommand{\mvS}{\boldsymbol{S}}\newcommand{\mvT}{\boldsymbol{T}}\newcommand{\mvU}{\boldsymbol{U}}
\newcommand{\mvV}{\boldsymbol{V}}\newcommand{\mvW}{\boldsymbol{W}}\newcommand{\mvX}{\boldsymbol{X}}
\newcommand{\mvY}{\boldsymbol{Y}}\newcommand{\mvZ}{\boldsymbol{Z}}\newcommand{\mva}{\boldsymbol{a}}
\newcommand{\mvb}{\boldsymbol{b}}\newcommand{\mvc}{\boldsymbol{c}}\newcommand{\mvd}{\boldsymbol{d}}
\newcommand{\mve}{\boldsymbol{e}}\newcommand{\mvf}{\boldsymbol{f}}\newcommand{\mvg}{\boldsymbol{g}}
\newcommand{\mvh}{\boldsymbol{h}}\newcommand{\mvi}{\boldsymbol{i}}\newcommand{\mvj}{\boldsymbol{j}}
\newcommand{\mvk}{\boldsymbol{k}}\newcommand{\mvl}{\boldsymbol{l}}\newcommand{\mvm}{\boldsymbol{m}}
\newcommand{\mvn}{\boldsymbol{n}}\newcommand{\mvo}{\boldsymbol{o}}\newcommand{\mvp}{\boldsymbol{p}}
\newcommand{\mvq}{\boldsymbol{q}}\newcommand{\mvr}{\boldsymbol{r}}\newcommand{\mvs}{\boldsymbol{s}}
\newcommand{\mvt}{\boldsymbol{t}}\newcommand{\mvu}{\boldsymbol{u}}\newcommand{\mvv}{\boldsymbol{v}}
\newcommand{\mvw}{\boldsymbol{w}}\newcommand{\mvx}{\boldsymbol{x}}\newcommand{\mvy}{\boldsymbol{y}}
\newcommand{\mvz}{\boldsymbol{z}}
% Greek, abbreviated names (\mvg...).
\newcommand{\mvga}{\boldsymbol{\alpha}}\newcommand{\mvgb}{\boldsymbol{\beta}}\newcommand{\mvgc}{\boldsymbol{\gamma}}
\newcommand{\mvgC}{\boldsymbol{\Gamma}}\newcommand{\mvgd}{\boldsymbol{\delta}}\newcommand{\mvgD}{\boldsymbol{\Delta}}
\newcommand{\mvgee}{\boldsymbol{\epsilon}}\newcommand{\mveps}{\boldsymbol{\eps}}\newcommand{\mvgf}{\boldsymbol{\varphi}}
\newcommand{\mvgF}{\boldsymbol{\Phi}}\newcommand{\mvgth}{\boldsymbol{\theta}}\newcommand{\mvgTh}{\boldsymbol{\Theta}}
\newcommand{\mvgh}{\boldsymbol{\eta}}\newcommand{\mvgi}{\boldsymbol{\iota}}\newcommand{\mvgk}{\boldsymbol{\kappa}}
\newcommand{\mvgl}{\boldsymbol{\lambda}}\newcommand{\mvgL}{\boldsymbol{\Lambda}}\newcommand{\mvgm}{\boldsymbol{\mu}}
\newcommand{\mvgn}{\boldsymbol{\nu}}\newcommand{\mvgo}{\boldsymbol{\omega}}\newcommand{\mvgO}{\boldsymbol{\Omega}}
\newcommand{\mvgp}{\boldsymbol{\pi}}\newcommand{\mvgP}{\boldsymbol{\Pi}}\newcommand{\mvgr}{\boldsymbol{\rho}}
\newcommand{\mvgs}{\boldsymbol{\sigma}}\newcommand{\mvgS}{\boldsymbol{\Sigma}}\newcommand{\mvgt}{\boldsymbol{\tau}}
\newcommand{\mvgu}{\boldsymbol{\upsilon}}\newcommand{\mvgU}{\boldsymbol{\Upsilon}}\newcommand{\mvgx}{\boldsymbol{\chi}}
% Greek, full names (\mvalpha, ...), starting after \mvgz.
\newcommand{\mvgz}{\boldsymbol{\zeta}}\newcommand{\mvalpha}{\boldsymbol{\alpha}}\newcommand{\mvbeta}{\boldsymbol{\beta}}
\newcommand{\mvgamma}{\boldsymbol{\gamma}}\newcommand{\mvGamma}{\boldsymbol{\Gamma}}\newcommand{\mvdelta}{\boldsymbol{\delta}}
\newcommand{\mvDelta}{\boldsymbol{\Delta}}\newcommand{\mvepsilon}{\boldsymbol{\epsilon}}\newcommand{\mvphi}{\boldsymbol{\phi}}
\newcommand{\mvPhi}{\boldsymbol{\Phi}}\newcommand{\mvtheta}{\boldsymbol{\theta}}\newcommand{\mvTheta}{\boldsymbol{\Theta}}
\newcommand{\mveta}{\boldsymbol{\eta}}\newcommand{\mviota}{\boldsymbol{\iota}}\newcommand{\mvkappa}{\boldsymbol{\kappa}}
\newcommand{\mvlambda}{\boldsymbol{\lambda}}\newcommand{\mvLambda}{\boldsymbol{\Lambda}}\newcommand{\mvmu}{\boldsymbol{\mu}}
\newcommand{\mvnu}{\boldsymbol{\nu}}
\newcommand{\mvomega}{\boldsymbol{\omega}}\newcommand{\mvOmega}{\boldsymbol{\Omega}}\newcommand{\mvpi}{\boldsymbol{\pi}}
\newcommand{\mvPi}{\boldsymbol{\Pi}}\newcommand{\mvrho}{\boldsymbol{\rho}}\newcommand{\mvsigma}{\boldsymbol{\sigma}}
\newcommand{\mvSigma}{\boldsymbol{\Sigma}}\newcommand{\mvtau}{\boldsymbol{\tau}}\newcommand{\mvupsilon}{\boldsymbol{\upsilon}}
\newcommand{\mvUpsilon}{\boldsymbol{\Upsilon}}\newcommand{\mvchi}{\boldsymbol{\chi}}\newcommand{\mvzeta}{\boldsymbol{\zeta}}
\newcommand{\mvvartheta}{\boldsymbol{\vartheta}}\newcommand{\mvxi}{\boldsymbol{\xi}}\newcommand{\mvXi}{\boldsymbol{\Xi}}
% Omicron has no dedicated math command; it is typeset with the letter o.
% (\o is the text-mode letter "o with stroke" and is invalid in math mode.)
\newcommand{\mvomic}{\boldsymbol{o}}\newcommand{\mvvarpi}{\boldsymbol{\varpi}}\newcommand{\mvvarrho}{\boldsymbol{\varrho}}
\newcommand{\mvvarsigma}{\boldsymbol{\varsigma}}\newcommand{\mvvarphi}{\boldsymbol{\varphi}}\newcommand{\mvpsi}{\boldsymbol{\psi}}\newcommand{\mvPsi}{\boldsymbol{\Psi}}
%% Math Bold Symbols commands
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Math Fraktur fonts
% \f{<arg>} sets its argument in \mathfrak; \fA ... \fZ and \fa ... \fz are
% one-letter shortcuts.  The lower-case i is named \mfi instead of \fi,
% because \fi is the TeX primitive that closes a conditional.
\newcommand{\f}[1]{\mathfrak{#1}}
% -- capitals --
\newcommand{\fA}{\mathfrak{A}}
\newcommand{\fB}{\mathfrak{B}}
\newcommand{\fC}{\mathfrak{C}}
\newcommand{\fD}{\mathfrak{D}}
\newcommand{\fE}{\mathfrak{E}}
\newcommand{\fF}{\mathfrak{F}}
\newcommand{\fG}{\mathfrak{G}}
\newcommand{\fH}{\mathfrak{H}}
\newcommand{\fI}{\mathfrak{I}}
\newcommand{\fJ}{\mathfrak{J}}
\newcommand{\fK}{\mathfrak{K}}
\newcommand{\fL}{\mathfrak{L}}
\newcommand{\fM}{\mathfrak{M}}
\newcommand{\fN}{\mathfrak{N}}
\newcommand{\fO}{\mathfrak{O}}
\newcommand{\fP}{\mathfrak{P}}
\newcommand{\fQ}{\mathfrak{Q}}
\newcommand{\fR}{\mathfrak{R}}
\newcommand{\fS}{\mathfrak{S}}
\newcommand{\fT}{\mathfrak{T}}
\newcommand{\fU}{\mathfrak{U}}
\newcommand{\fV}{\mathfrak{V}}
\newcommand{\fW}{\mathfrak{W}}
\newcommand{\fX}{\mathfrak{X}}
\newcommand{\fY}{\mathfrak{Y}}
\newcommand{\fZ}{\mathfrak{Z}}
% -- lower case --
\newcommand{\fa}{\mathfrak{a}}
\newcommand{\fb}{\mathfrak{b}}
\newcommand{\fc}{\mathfrak{c}}
\newcommand{\fd}{\mathfrak{d}}
\newcommand{\fe}{\mathfrak{e}}
\newcommand{\ff}{\mathfrak{f}}
\newcommand{\fg}{\mathfrak{g}}
\newcommand{\fh}{\mathfrak{h}}
\newcommand{\mfi}{\mathfrak{i}}
\newcommand{\fj}{\mathfrak{j}}
\newcommand{\fk}{\mathfrak{k}}
\newcommand{\fl}{\mathfrak{l}}
\newcommand{\fm}{\mathfrak{m}}
\newcommand{\fn}{\mathfrak{n}}
\newcommand{\fo}{\mathfrak{o}}
\newcommand{\fp}{\mathfrak{p}}
\newcommand{\fq}{\mathfrak{q}}
\newcommand{\fr}{\mathfrak{r}}
\newcommand{\fs}{\mathfrak{s}}
\newcommand{\ft}{\mathfrak{t}}
\newcommand{\fu}{\mathfrak{u}}
\newcommand{\fv}{\mathfrak{v}}
\newcommand{\fw}{\mathfrak{w}}
\newcommand{\fx}{\mathfrak{x}}
\newcommand{\fy}{\mathfrak{y}}
\newcommand{\fz}{\mathfrak{z}}
%% End Math Frak
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Double capital letters
% \bb{<arg>} sets its argument in \mathbb; \bA ... \bZ are one-letter
% shortcuts for the blackboard-bold capitals.
\newcommand{\bb}[1]{\mathbb{#1}}
\newcommand{\bA}{\mathbb{A}}
\newcommand{\bB}{\mathbb{B}}
\newcommand{\bC}{\mathbb{C}}
\newcommand{\bD}{\mathbb{D}}
\newcommand{\bE}{\mathbb{E}}
\newcommand{\bF}{\mathbb{F}}
\newcommand{\bG}{\mathbb{G}}
\newcommand{\bH}{\mathbb{H}}
\newcommand{\bI}{\mathbb{I}}
\newcommand{\bJ}{\mathbb{J}}
\newcommand{\bK}{\mathbb{K}}
\newcommand{\bL}{\mathbb{L}}
\newcommand{\bM}{\mathbb{M}}
\newcommand{\bN}{\mathbb{N}}
\newcommand{\bO}{\mathbb{O}}
\newcommand{\bP}{\mathbb{P}}
\newcommand{\bQ}{\mathbb{Q}}
\newcommand{\bR}{\mathbb{R}}
\newcommand{\bS}{\mathbb{S}}
\newcommand{\bT}{\mathbb{T}}
\newcommand{\bU}{\mathbb{U}}
\newcommand{\bV}{\mathbb{V}}
\newcommand{\bW}{\mathbb{W}}
\newcommand{\bX}{\mathbb{X}}
\newcommand{\bY}{\mathbb{Y}}
\newcommand{\bZ}{\mathbb{Z}}
% End Double capital letters
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 
%% Blackboard bold
% \dA ... \dZ: one-letter shortcuts for the double-stroke capitals of the
% dsfont package (\mathds).
\newcommand{\dA}{\mathds{A}}
\newcommand{\dB}{\mathds{B}}
\newcommand{\dC}{\mathds{C}}
\newcommand{\dD}{\mathds{D}}
\newcommand{\dE}{\mathds{E}}
\newcommand{\dF}{\mathds{F}}
\newcommand{\dG}{\mathds{G}}
\newcommand{\dH}{\mathds{H}}
\newcommand{\dI}{\mathds{I}}
\newcommand{\dJ}{\mathds{J}}
\newcommand{\dK}{\mathds{K}}
\newcommand{\dL}{\mathds{L}}
\newcommand{\dM}{\mathds{M}}
\newcommand{\dN}{\mathds{N}}
\newcommand{\dO}{\mathds{O}}
\newcommand{\dP}{\mathds{P}}
\newcommand{\dQ}{\mathds{Q}}
\newcommand{\dR}{\mathds{R}}
\newcommand{\dS}{\mathds{S}}
\newcommand{\dT}{\mathds{T}}
\newcommand{\dU}{\mathds{U}}
\newcommand{\dV}{\mathds{V}}
\newcommand{\dW}{\mathds{W}}
\newcommand{\dX}{\mathds{X}}
\newcommand{\dY}{\mathds{Y}}
\newcommand{\dZ}{\mathds{Z}}
%% Blackboard bold
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Roman capital letters
% \rA ... \rZ: one-letter shortcuts for upright (\mathrm) capitals.
\newcommand{\rA}{\mathrm{A}}
\newcommand{\rB}{\mathrm{B}}
\newcommand{\rC}{\mathrm{C}}
\newcommand{\rD}{\mathrm{D}}
\newcommand{\rE}{\mathrm{E}}
\newcommand{\rF}{\mathrm{F}}
\newcommand{\rG}{\mathrm{G}}
\newcommand{\rH}{\mathrm{H}}
\newcommand{\rI}{\mathrm{I}}
\newcommand{\rJ}{\mathrm{J}}
\newcommand{\rK}{\mathrm{K}}
\newcommand{\rL}{\mathrm{L}}
\newcommand{\rM}{\mathrm{M}}
\newcommand{\rN}{\mathrm{N}}
\newcommand{\rO}{\mathrm{O}}
\newcommand{\rP}{\mathrm{P}}
\newcommand{\rQ}{\mathrm{Q}}
\newcommand{\rR}{\mathrm{R}}
\newcommand{\rS}{\mathrm{S}}
\newcommand{\rT}{\mathrm{T}}
\newcommand{\rU}{\mathrm{U}}
\newcommand{\rV}{\mathrm{V}}
\newcommand{\rW}{\mathrm{W}}
\newcommand{\rX}{\mathrm{X}}
\newcommand{\rY}{\mathrm{Y}}
\newcommand{\rZ}{\mathrm{Z}}
% End Roman capital letters
% Script capital letters
% \sA ... \sZ: one-letter shortcuts for \mathscr capitals.
% NOTE(review): none of the packages loaded in this preamble defines
% \mathscr (mathrsfs and similar packages do); presumably the document
% class supplies it -- confirm before using any of these shortcuts.
\newcommand{\sA}{\mathscr{A}}\newcommand{\sB}{\mathscr{B}}\newcommand{\sC}{\mathscr{C}}
\newcommand{\sD}{\mathscr{D}}\newcommand{\sE}{\mathscr{E}}\newcommand{\sF}{\mathscr{F}}
\newcommand{\sG}{\mathscr{G}}\newcommand{\sH}{\mathscr{H}}\newcommand{\sI}{\mathscr{I}}
\newcommand{\sJ}{\mathscr{J}}\newcommand{\sK}{\mathscr{K}}\newcommand{\sL}{\mathscr{L}}
\newcommand{\sM}{\mathscr{M}}\newcommand{\sN}{\mathscr{N}}\newcommand{\sO}{\mathscr{O}}
\newcommand{\sP}{\mathscr{P}}\newcommand{\sQ}{\mathscr{Q}}\newcommand{\sR}{\mathscr{R}}
\newcommand{\sS}{\mathscr{S}}\newcommand{\sT}{\mathscr{T}}\newcommand{\sU}{\mathscr{U}}
\newcommand{\sV}{\mathscr{V}}\newcommand{\sW}{\mathscr{W}}\newcommand{\sX}{\mathscr{X}}
\newcommand{\sY}{\mathscr{Y}}\newcommand{\sZ}{\mathscr{Z}}
% End Script capital letters
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Local Math Operator
% Operator names declared through amsmath.  \E and \pr print a
% double-stroke E and P (dsfont); the others print their name upright.
\DeclareMathOperator{\E}{\mathds{E}}
\DeclareMathOperator{\pr}{\mathds{P}}
\DeclareMathOperator{\sgn}{sgn}
\DeclareMathOperator{\var}{Var}
\DeclareMathOperator{\cov}{Cov}
% Starred form: in display style a subscript is set underneath the operator
% (as for \max and \min), which is the usual layout for argmax/argmin.
% Inline formulas are unaffected.
\DeclareMathOperator*{\argmax}{argmax}
\DeclareMathOperator*{\argmin}{argmin}
\DeclareMathOperator{\hess}{Hess}
\DeclareMathOperator{\tr}{tr}
\DeclareMathOperator{\sech}{sech}
% \bluecheck expands to nothing.
\newcommand{\bluecheck}{}%
% Green check mark drawn as a filled TikZ path in the colour defined here.
% Declared robust so that it survives moving arguments such as captions.
\definecolor{darkpastelgreen}{rgb}{0.01, 0.75, 0.24}
\DeclareRobustCommand{\greencheck}{%
  \tikz\fill[scale=0.4, color=darkpastelgreen]
  (0,.35) -- (.25,0) -- (1,.7) -- (.25,.15) -- cycle;%
}
%% End Local Math Operator
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% \XC{<text>}: typesets <text> in red, wrapped in two extra brace groups.
\newcommand{\XC}[1]{{{\textcolor{red}{#1}}}}

\title{GeONet: a neural operator for learning the Wasserstein geodesic}

% The standard author block has changed for UAI 2024 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is authomatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\author[1]{\href{mailto:<andrewgracyk@gmail.com>?Subject=Your UAI 2024 paper}{Andrew~Gracyk}}
\author[2]{\href{mailto:<xiaohuic@usc.edu>?Subject=Your UAI 2024 paper}{Xiaohui~Chen}}
%\author[1]{\href{mailto:<jj@example.edu>?Subject=Your UAI 2024 paper}{Jane~J.~von~O'L\'opez}{}}
% Add affiliations after the authors
\affil[1]{%
     Department of Statistics\\
     University of Illinois at Urbana-Champaign
}
\affil[2]{%
    Department of Mathematics\\
    University of Southern California
}

\begin{document}

\maketitle

\begin{abstract}
   Optimal transport (OT) offers a versatile framework to compare complex data distributions in a geometrically meaningful way. Traditional methods for computing the Wasserstein distance and geodesic between probability measures require mesh-specific domain discretization and suffer from the curse of dimensionality. We present \emph{GeONet}, a mesh-invariant deep neural operator network that learns the non-linear mapping from the input pair of initial and terminal distributions to the Wasserstein geodesic connecting the two endpoint distributions. In the offline training stage, GeONet learns the saddle point optimality conditions for the dynamic formulation of the OT problem in the primal and dual spaces that are characterized by a coupled PDE system. The subsequent inference stage is instantaneous and can be deployed for real-time predictions in the online learning setting. We demonstrate that GeONet achieves testing accuracy comparable to that of standard OT solvers on simulation examples and the MNIST dataset, while reducing the inference-stage computational cost by orders of magnitude.
\end{abstract}
  

\section{Introduction}

Recent years have seen tremendous progress in statistical and computational optimal transport (OT) as a lens to explore machine learning problems. One prominent example is to use the Wasserstein distance to compare data distributions in a geometrically meaningful way, which has found various applications, such as in generative models~\citep{arjovsky17a}, domain adaptation~\citep{7586038} and computational geometry~\citep{Solomon_2015}. Computing the optimal transport map (if it exists) can be expressed in a fluid dynamics formulation with the minimum kinetic energy~\citep{Benamou2000ACF}. Such a dynamical formulation defines geodesics in the Wasserstein space of probability measures, thus providing richer information for interpolating between data distributions that can be used to design efficient sampling methods from high-dimensional distributions~\citep{finlay2020train}. Moreover, learning the continuous-time dynamical Wasserstein geodesic is a practically important task in many science and engineering domains, including developmental trajectory reconstruction in cell reprogramming~\citep{Schiebinger2019}, 3D warping for shape analysis in computational geometry~\citep{7053911}, optimal control such as swarm robotics and control systems~\citep{ChenGeorgiouPavon2021,8619816,9483194}, matching supply and demand networks~\citep{10.1007/s10851-022-01121-y}, computer vision such as color transfer~\citep{10204903}, and language translation~\citep{xu-etal-2021-vocabulary}.


Traditional methods for numerically computing the Wasserstein distance and geodesic require domain discretization that is often mesh-dependent (i.e., on regular grids or triangulated domains). Classical solvers such as the Hungarian method~\citep{Kuhn1955}, the auction algorithm~\citep{BertsekaCastanon1989s}, and the transportation simplex~\citep{LuenbergerYe2015} suffer from the curse of dimensionality and scale poorly for even moderately mesh-sized problems~\citep{KlattTamelingMunk2020,Genevay2016,Benamou2000ACF}. Entropic regularized OT~\citep{NIPS2013_Cuturi} and the Sinkhorn algorithm~\citep{Sinkhorn_1964} have been shown to efficiently approximate the OT solutions at low computational cost, handling high-dimensional distributions~\citep{BenamouCarlierCuturiNennaPerye2015}; however, high accuracy is computationally obstructed with a small regularization parameter~\citep{Altschuler_2017,pmlr-v80-dvurechensky18a}. Recently, machine learning methods to compute the Wasserstein geodesic for a \emph{given} input pair of probability measures have been considered~\citep{LiuMaChenZhaZhou2021,LiuGongLiu_2023,pooladian2023multisample,tong2023improving}, as well as \emph{amortized} methods~\citep{LacombeDigneCourtyBonneel_2023,amos2023meta} for generating static OT maps.

\begin{figure*}
  \centering
  \includegraphics[scale=0.64]{GeONet_first_figure_2023_1.pdf}
  \vspace{-1mm}
  \caption{A geodesic at different spatial resolutions. Low-resolution inputs can be adapted into high-resolution geodesics (i.e., super-resolution) with our output mesh-invariant GeONet method.}
  \label{fig:GeONet_example}
  \vspace{-3mm}
\end{figure*}

A major challenge of using OT-based techniques is that one needs to recompute the Wasserstein distance and geodesic for each new input pair of probability measures. Thus, issues of scalability on large-scale datasets and suitability in the online learning setting are serious concerns for modern machine learning, computer graphics, and natural language processing tasks~\citep{Genevay2016,Solomon_2015,Kusner_2015ICML}. This motivates us to tackle the problem of learning the Wasserstein geodesic from an \emph{operator learning} perspective.

There is a recent line of work on learning neural operators for solving general differential equations or discovering equations from data, including DeepONet~\citep{LuJinPangZhangKarniadakis2021_DeepONet}, Fourier Neural Operators~\citep{Li_FourierNeuralOperator}, and physics-informed neural networks/operators (PINNs/PINOs)~\citep{RAISSI2019686,https://doi.org/10.48550/arxiv.2111.03794}. Those methods are mesh-independent, data-driven, and designed to accommodate specific physical laws governed by certain partial differential equations (PDEs).

\textbf{Our contributions.} In this paper, we propose a deep neural operator learning framework \emph{GeONet} for the Wasserstein geodesic. Our method is based on learning the optimality conditions in the dynamic formulation of the OT problem, which is characterized by a coupled PDE system in the primal and dual spaces. Our main idea is to recast the learning problem of the Wasserstein geodesic from training data into an operator learning problem for the solution of the PDEs corresponding to the primal and dual OT dynamics. Our method can learn the highly non-linear Wasserstein geodesic operator from a wide collection of training distributions. GeONet is mesh-invariant, and thus it is also suitable for zero-shot super-resolution applications on images, i.e., it is trained on lower resolution and predicts at higher resolution without seeing any higher resolution data~\citep{ZSSR}. See Figure~\ref{fig:GeONet_example} for an example of a higher-resolution Wasserstein geodesic connecting two lower-resolution Gaussian mixture distributions.

Surprisingly, the training of our GeONet does not require the true geodesic data for connecting the two endpoint distributions. Instead, it only requires the training data as boundary pairs of initial and terminal distributions. The reason that GeONet needs much less input data is that its training process is implicitly informed by the OT dynamics such that the continuity equation in the primal space and the Hamilton--Jacobi equation in the dual space must be simultaneously satisfied to ensure zero duality gap. Since the geodesic data are typically difficult to obtain without resorting to some traditional numerical solvers, the \emph{amortized inference} nature of GeONet, where inference on related training pairs can be reused~\citep{Gershman2014AmortizedII}, has a substantial computational advantage over standard computational OT methods and machine learning methods for computing the geodesic designed for a single input pair of distributions~\citep{COTFNT,LiuMaChenZhaZhou2021}.

\begin{table*}
  \caption{We compare our method GeONet with other methodologies, including traditional neural operators, physics-informed neural networks (PINNs) for learning dynamics, and traditional optimal transport solvers.}
  
  \label{tab:comparison}
  \centering
  \begin{tabular}{
    p{4.6cm}  P{3.1cm} P{1.0cm}  P{1.8cm} P{1.2cm} }%{
  %  l
  %  S[table-format = 3]
  %  S[table-format = 2]
  %  S[table-format = 1.3]
  %  S[table-format = -2.2]
  %  S[table-format = 1.3]
  %  S[table-format = 1.3]
  %  S[table-format = 2.2]
  %  }
    \toprule  
    %\cmidrule(lr){1-5}   
    \textbf{Method characteristic} & {Neural operator w/o physics-informed learning} & {PINNs} & {Traditional OT solvers} & {GeONet (Ours)} \\
    \midrule
    \text{operator learning} & {\greencheck}  &   &  & {\greencheck}  \\
    \text{satisfies the associated PDEs}   & {\greencheck} & {\greencheck} &
    {}  &  {\greencheck} \\
    \text{does not require known geodesic data} &  & {\greencheck} &  {\greencheck} &  {\greencheck} \\
    \text{output mesh independence} & {\greencheck}  & {\greencheck} &  {} & {\greencheck}  \\
    \bottomrule
\end{tabular}
\end{table*}

Once GeONet training is complete, the inference stage for predicting the geodesic connecting new initial and terminal data distributions requires only a forward pass of the network, and thus it can be performed in real-time. In contrast, standard OT methods re-compute the Wasserstein distance and geodesic for each new input distribution pair. An appealing feature of amortized inference is that a pre-trained GeONet can be used for fast geodesic computation or fine-tuning on a large number of future data distributions. A detailed comparison of our proposed method GeONet with other existing neural operators and networks for learning dynamics from data can be found in Table~\ref{tab:comparison}.


\section{Background}
\label{sec:background}

\subsection{Optimal transport problem: static and dynamic formulations}

The optimal mass transportation problem, first considered by the French engineer Gaspard Monge, is to find an optimal map $T^*$ for transporting a source distribution $\mu_0$ to a target distribution $\mu_1$ that minimizes some cost function $c : \bR^d \times \bR^d \to \bR$:
\begin{equation}
    \label{eqn:monge_problem}
    \min_{T : \bR^d \to \bR^d} \left\{ \int_{\bR^d} c(x, T(x)) \; \rd \mu_0(x) :  T_\sharp \mu_0 = \mu_1 \right\},
\end{equation}
where $T_\sharp \mu$ denotes the pushforward measure defined by $(T_\sharp \mu) (B) = \mu(T^{-1}(B))$ for any measurable subset $B \subset \bR^d$. In this paper, we focus on the quadratic cost $c(x,y) = \|x-y\|_2^2$. The Monge problem~\eqref{eqn:monge_problem} induces a metric, known as the \emph{Wasserstein distance}, on the space $\cP_2(\bR^d)$ of probability measures on $\bR^d$ with finite second moments. In particular, the 2-Wasserstein distance can be expressed in the relaxed Kantorovich form:
\begin{equation}
    \label{eqn:kantorovich_problem}
    W_2^2(\mu_0, \mu_1) := \min_{\gamma \in \Gamma(\mu_0, \mu_1)} \left\{ \int_{\bR^d \times \bR^d} \|x - y\|_2^2 \; \rd \gamma(x,y) \right\},
\end{equation}
where the minimization runs over all possible couplings $\gamma \in \Gamma(\mu_0, \mu_1)$ with marginal distributions $\mu_0$ and $\mu_1$. Problem~\eqref{eqn:kantorovich_problem} has the dual form~\citep[cf.][]{Villani2003_topics-in-ot}:
\begin{equation}
    \label{eqn:kantorovich_dual_problem}
    \begin{gathered}
        W_2^2(\mu_0, \mu_1)  = \sup_{\varphi \in L^1(\mu_0), \; \psi \in L^1(\mu_1)} \Big\{ \int_{\bR^d} \varphi \; \rd \mu_0 \\
        \ \ \ \ \ \ + \int_{\bR^d} \psi \; \rd \mu_1 : \varphi(x) + \psi(y) \leq \|x-y\|_2^2 \Big\}.
    \end{gathered}
\end{equation}
Problems~\eqref{eqn:monge_problem} and~\eqref{eqn:kantorovich_problem} are both referred to as the \emph{static OT} problems, which have a close connection to fluid dynamics. Specifically, the Benamou-Brenier dynamic formulation~\citep{Benamou2000ACF} expresses the Wasserstein distance as a minimal kinetic energy flow problem:
\begin{align}
\label{eqn:benamou-brenier_formula}
\begin{gathered}
\frac{1}{2} W_2^2(\mu_0, \mu_1) = \min_{(\mu, \vv)} \int_0^1 \int_{\mathbb{R}^d} \frac{1}{2} \| \vv(x, t) \|_2^2 \ \mu(x, t) \ \rd x \ \rd t \\
  \text{subject to}  \ \  \partial_t \mu + \div(\mu \vv) = 0, \ \ \mu(\cdot, 0) = \mu_0, \ \ \mu(\cdot, 1) = \mu_1,
 \end{gathered}
\end{align}
where $\mu_t := \mu(\cdot, t)$ is the probability density flow at time $t$ satisfying the continuity equation (CE) constraint $\partial_t \mu + \div(\mu \vv) = 0$ that ensures the conservation of unit mass along the flow $\{\mu_t\}_{t \in [0,1]}$. To solve~\eqref{eqn:benamou-brenier_formula}, we apply the Lagrange multiplier method to find the saddle point in the primal and dual variables. In particular, for any flow $\mu_t$ starting from $\mu_0$ and terminating at $\mu_1$, the Lagrangian function for~\eqref{eqn:benamou-brenier_formula} can be written as
\begin{equation}
    \label{eqn:lagrangian_benamou-brenier}
    \begin{gathered}
    \cL(\mu, \vv, u) = \int_0^1 \int_{\bR^d} \left[ \frac{1}{2} \|\vv\|_2^2 \mu + \left( \partial_t \mu + \div(\mu \vv) \right) u \right] \; \rd x \; \rd t,
    \end{gathered}
\end{equation}
where $u := u(x, t)$ is the dual variable for the CE. Using integration by parts under suitable decay conditions as $\|x\|_2 \to \infty$, we find that the optimal dual variable $u^*$ for the dynamic OT problem satisfies the Hamilton--Jacobi (HJ) equation
\begin{equation}
    \label{eqn:HJ}
    \partial_t u + {1\over2} \|\nabla u\|_2^2 = 0,
\end{equation}
and the optimal velocity vector field is given by $\vv^*(x, t) = \nabla u^*(x, t)$. Hence, the Karush–Kuhn–Tucker (KKT) optimality conditions for~\eqref{eqn:benamou-brenier_formula} state that the optimal pair $(\mu^*, u^*)$ solves the following system of PDEs:
\begin{equation}
    \label{eqn:benamou-brenier_kkt}
    \left\{
    \begin{gathered}
      \partial_t \mu + \div(\mu \nabla u) = 0, \ \ \partial_t u + {1\over2} \|\nabla u\|_2^2 = 0, \\
     \mu(\cdot,0) = \mu_0, \ \ \mu(\cdot,1) = \mu_1.
     \end{gathered} \right.
\end{equation}

In addition, if $\psi^*$ and $\varphi^*$ are the optimal Kantorovich potentials for solving the static dual OT problem~\eqref{eqn:kantorovich_dual_problem}, then the solution to the HJ equation~\eqref{eqn:HJ} can be viewed as an interpolation $u(x,t)$ of the Kantorovich potentials between the initial and terminal distributions in the sense that $u^*(x, 1) = \psi^*(x)$ and $u^*(x, 0) = -\varphi^*(x)$ (both up to some additive constants). A detailed derivation of the primal-dual optimality conditions for the dynamical OT formulation is provided in Appendix~\ref{app:sec:primal-dual_optimality}.


\subsection{Learning neural operators}

Physics-informed neural networks (PINNs)~\citep{RAISSI2019686} aim to learn the solution of a PDE from data for a {\it given} input function $a$:
\begin{equation}
    \label{eqn:pinn}
    \partial_t u + \mathcal{D}_a[u] = 0
\end{equation}
subject to some boundary data $u_0$ and $u_T$, where ${\cal D}_a$ denotes a differential operator in space that may depend on the input function $a \in {\cal A}$. Different from the classical neural network learning paradigm that is purely data-driven, a PINN has less input data (i.e., some randomly sampled data points from the solution $u$ and the boundary conditions) since the solution operator $\Gamma^\dagger : {\cal A} \to {\cal U}$ is learned by obeying the induced physical laws governed by~\eqref{eqn:pinn}, and not from observations. Even though the PINN is mesh-independent, it only learns the solution for a {\it single} instance of the input function $a$ in the PDE~\eqref{eqn:pinn}. In order to learn the behavior of the inverse problem $\Gamma^\dagger : {\cal A} \to {\cal U}$ for an entire family of $\cal A$, we consider the operator learning perspective.

A neural operator generalizes a neural network that learns the mapping $\Gamma^\dagger : {\cal A} \to {\cal U}$ between infinite-dimensional function spaces $\cal A$ and $\cal U$~\citep{Kovachki_neuraloperator,LiKovachkiAzizzadenesheliLiuBhattacharyaStuartAnandkumar}. A notable example of operator learning is that $\cal A$ and $\cal U$ contain functions defined over a space-time domain $\Omega \times [0, T]$ with $\Omega \subset \bR^d$, and the mapping of interest $\Gamma^\dagger$ is implicitly defined through a differential operator.

The idea of using neural networks to approximate any nonlinear continuous operator stems from the universal approximation theorem for operators~\citep{392253,LuJinPangZhangKarniadakis2021_DeepONet}. In particular, we construct a parametric map by a neural network $\Gamma_\theta := \Gamma(\cdot; \theta) : \mathcal{A} \rightarrow \mathcal{U}$ for a finite-dimensional parameter $\theta \in \Theta$ to approximate the true solution operator $\Gamma^{\dag}$. In this paper, we adopt the \emph{DeepONet} architecture~\citep{LuJinPangZhangKarniadakis2021_DeepONet}, which is suitable for its ability to learn mappings from pairings of initial input data to model $\Gamma^\dagger$. In the next subsection, we briefly discuss some basics of DeepONet architecture for modeling $\Gamma^\dagger$ and its enhanced version. Then, the neural operator learning problem is to find an optimal $\theta^* \in \Theta$ as a minimizer of the classical risk minimization problem
\begin{align}
\label{eqn:PINN_formulation_1}
\begin{gathered}
    \min_{\theta \in \Theta}  \E_{(a, u_0, u_T) \sim \nu} \Big[ \big\| (\partial_{t} + \mathcal{D}_a ) \Gamma_\theta(a) \big\|_{L^2(\Omega \times (0,T)) }^2   \\
  \qquad + \lambda_0 \big\|\Gamma_\theta(a)(\cdot,0) - u_0 \big\|_{L^2(\Omega)}^2 \\
  \qquad + \lambda_T \big\|\Gamma_\theta(a)(\cdot,T) - u_T \big\|_{L^2(\Omega)}^2  \Big],
\end{gathered}
\end{align}
where the input data $(a, u_0, u_T)$ are sampled from some joint distribution $\nu$. In~\eqref{eqn:PINN_formulation_1}, we minimize the PDE residual loss corresponding to $\partial_{t} u + \mathcal{D}_a[u] = 0$ while constraining the network by imposing boundary conditions. The loss function has weights $\lambda_0, \lambda_T > 0$. Given a finite set of samples $\{(a^{(i)}, u_0^{(i)}, u_T^{(i)})\}_{i=1}^n$, and data points randomly sampled in the space-time domain $\Omega \times (0, T)$, we may minimize the empirical loss analog of~\eqref{eqn:PINN_formulation_1} by replacing $\| \cdot \|_{L^2(\Omega \times (0,T))}$ with the discrete $L^2$ norm over domain $\Omega \times (0,T)$. Computation of the exact differential operators $\partial_t$ and ${\cal D}_a$ can be conveniently exploited via automatic differentiation in standard deep learning packages.

\subsection{Deep operator networks}
\label{sec:DeepONets}

The DeepONet architecture~\citep{LuJinPangZhangKarniadakis2021_DeepONet} is based on the universal approximation theorem for operators~\citep{392253}, which says a general nonlinear continuous operator  $\Gamma^{\dagger}$ may be approximated as follows:
\vspace{-2mm}
\begin{equation}
\label{eqn:DeepONet}
\Gamma^{\dagger}(u) (x,t) \approx \sum_{k=1}^{p} \mathcal{B}_{k} \big( u(x_1), \hdots, u(x_m); \theta \big) \cdot \mathcal{T}_{k} (x, t; \xi), 
\end{equation}
where $\mathcal{B}_{k}, \mathcal{T}_{k}$ are scalar elements of output of neural networks $\mathcal{B}, \mathcal{T}$, and $p$ is the number of such elements. For instance, we may take $\mathcal{B}$ and $\mathcal{T}$ as artificial neural networks parameterized by $\theta, \xi$ respectively. Networks $\mathcal{B}, \mathcal{T}$ are referred to as the {\it branch} and {\it trunk} networks, respectively. 

The unstacked DeepONet in~\eqref{eqn:DeepONet} is restricted to one input function $u$. In our problem, since we have both an initial and a terminal condition, we consider an enhanced version of DeepONet~\citep{https://doi.org/10.48550/arxiv.2202.08942}, where the operator $\Gamma^{\dagger}$ is approximated using two branch networks to encode the inputs $u_0$ and $u_1$,
\begin{align}
\label{eqn:DeepONet_enhanced}
\begin{gathered}
    \Gamma^{\dagger}( u_0, u_1) (x,t) \approx \sum_{k=1}^{p}  \mathcal{B}^0_{k} \big( u_0(x_1), \hdots, u_0(x_m); \theta^0 \big) \\ 
 \times \mathcal{B}^1_{k} \big( u_1(x_1), \hdots, u_1(x_m); \theta^1 \big) \times \mathcal{T}_{k} (x, t; \xi).
\end{gathered}
\end{align}
In~\eqref{eqn:DeepONet_enhanced}, the operator $\Gamma^\dagger$ is applied at the functions $u_0$ and $u_1$, and then evaluated at distinct locations $x_1, \hdots, x_m$ for the branch input.


\section{Our method}
\label{sec:our_method}

\begin{figure*}[t]
  \centering
  %\includegraphics[scale=0.45]{GeONet Diagram 5.pdf}
  \includegraphics[scale=0.53]{GeONet_Diagram.pdf}
  \caption{Architecture of GeONet. The solution to CE yields the geodesic. GeONet branches and trunks output vectors of dimension $p$, in which we perform multiplication among neural network elements to produce the solutions to CE and HJ.}
  \label{fig:GeONet_diagram}
  \vspace{-3mm}
\end{figure*}

We present {\it GeONet}, a geodesic operator network for learning the 2-Wasserstein geodesic $\{\mu_t\}_{t \in [0,1]}$ connecting $\mu_0$ to $\mu_1$. Let $\Omega \subset \bR^d$ be the spatial domain where the probability measures are supported. For absolutely continuous probability measures $\mu_0, \mu_1 \in {\cal P}_2(\Omega)$, it is well-known that the constant-speed geodesic $\{\mu_t\}_{t \in [0,1]}$ between $\mu_0$ and $\mu_1$ is an absolutely continuous curve in the metric space $({\cal P}_2(\Omega), W_2)$, which we denote as $\text{AC}({\cal P}_2(\Omega))$. Moreover, the geodesic $\mu_t$ solves the kinetic energy minimization problem in~\eqref{eqn:benamou-brenier_formula}~\citep{sabtanbrogio2015_OT}. Some basic facts on the metric geometry structure of the Wasserstein geodesic and its relation to the fluid dynamic formulation are reviewed and discussed in Appendix~\ref{app:sec:wasserstein_facts}. In this work, our goal is to learn the non-linear operator
\begin{align}
    \Gamma^\dagger : {\cal P}_{2}(\Omega) \times {\cal P}_{2}(\Omega) \to & \ \text{AC}({\cal P}_2(\Omega)), \\
    (\mu_0, \mu_1) \mapsto & \ \{\mu_t\}_{t \in [0,1]},
\end{align}
based on a training dataset $\{(\mu_0^{(1)}, \mu_1^{(1)}), \hdots, (\mu_0^{(n)}, \mu_1^{(n)})\}$. The core idea of GeONet is to learn the KKT optimality condition~\eqref{eqn:benamou-brenier_kkt} for the Benamou-Brenier problem. Since~\eqref{eqn:benamou-brenier_kkt} is derived to ensure the zero duality gap between the primal and dual dynamic OT problems, solving the Wasserstein geodesic requires us to introduce two sets of neural networks that train the coupled PDEs simultaneously. Specifically, we model the operator learning problem as an enhanced version of the unstacked DeepONet architecture~\citep{LuJinPangZhangKarniadakis2021_DeepONet,https://doi.org/10.48550/arxiv.2202.08942} by jointly training three primal networks in~\eqref{eqn:Continuity_sol} and three dual networks in~\eqref{eqn:HJ_sol} as follows:
%\vspace{-3mm}
\begin{equation}
\label{eqn:Continuity_sol}
\begin{gathered}
    \mathcal{C}(\mu_0, \mu_1) (x,t; \phi)  = \sum_{k=1}^p \mathcal{B}_{k}^{0,\text{cty}} (\mu_0; \theta^{0, \text{cty}}) \\
    \times  \mathcal{B}_{k}^{1, \text{cty}} (\mu_1; \theta^{1,\text{cty}}) \times  \mathcal{T}_{k}^{\text{cty}} (x,t; \xi^{\text{cty}})
\end{gathered}
\end{equation}
and
\begin{equation}
\label{eqn:HJ_sol}
\begin{gathered}
    \mathcal{H}(\mu_0, \mu_1) (x,t; \psi)  = \sum_{k=1}^p \mathcal{B}_{k}^{0,\text{HJ}} (\mu_0; \theta^{0,\text{HJ}}) \\ 
    \times \mathcal{B}_{k}^{1,\text{HJ}} (\mu_1; \theta^{1,\text{HJ}}) \times \mathcal{T}_{k}^{\text{HJ}} (x,t; \xi^{\text{HJ}}),
\end{gathered}
\end{equation}
where $\mathcal{B}^{j,\text{cty}} (\mu_j(x_1), \dots, \mu_j(x_m); \theta^{j,\text{cty}}) : \mathbb{R}^{m} \rightarrow \mathbb{R}^p$ and $\mathcal{B}^{j,\text{HJ}} (\mu_j(x_1), \dots, \mu_j(x_m); \theta^{j,\text{HJ}}) : \mathbb{R}^{m}  \rightarrow \mathbb{R}^p$ are \emph{branch} neural networks taking $m$-discretized input of initial and terminal density values at $j = 0$ and $j = 1$ respectively, and $\mathcal{T}^{\text{cty}} (x,t; \xi^{\text{cty}}) : \mathbb{R}^d \times [0,1] \rightarrow \mathbb{R}^p$ and $\mathcal{T}^{\text{HJ}} (x,t; \xi^{\text{HJ}}) : \mathbb{R}^d \times [0,1] \rightarrow \mathbb{R}^p$ are \emph{trunk} neural networks taking spatial and temporal inputs. Here $\Theta$ and $\Xi$ are finite-dimensional parameter spaces, and $p$ is the output dimension of the branch and trunk networks. Denote parameter concatenations $\phi := (\theta^{0,\text{cty}}, \theta^{1,\text{cty}}, \xi^{\text{cty}})$ and $\psi := (\theta^{0,\text{HJ}}, \theta^{1,\text{HJ}}, \xi^{\text{HJ}} )$. Then the primal operator network $\mathcal{C}_{\phi}(x,t,\mu_0, \mu_1) := \mathcal{C}(\mu_0, \mu_1)(x,t; \phi)$ for $\phi \in \Theta \times \Theta \times \Xi$ acts as an approximate solution to the CE, and hence to the true geodesic $\mu_t(x) = \Gamma^{\dag}(x, t, \mu_0(x), \mu_1(x))$, while the dual operator network $\mathcal{H}_{\psi}(x,t,\mu_0,\mu_1)$ for $\psi \in \Theta \times \Theta \times \Xi$ corresponds to that of the associated HJ equation. The overall architecture of GeONet is shown in Figure~\ref{fig:GeONet_diagram}.

In our GeONet implementation, we adopt a modified multi-layer perceptron (MLP) architecture, which has been shown to have great ability in improving performance for physics-informed DeepONets~\citep{doi:10.1126/sciadv.abi8605}. We shall elaborate on this architecture in Appendix~\ref{Modified_mlp} and describe our empirical findings with this modified MLP for GeONet in section~\ref{subsec:gaussian_mixtures}.

%{\bf Fourier feature architecture.} An additional augmented architecture is that of the Fourier feature, useful for input data exhibiting fine features, i.e., a lack of spatial differentiability. Spatial-temporal data is transformed using a Fourier mapping. We illustrate in this architecture in~\ref{Fourier_feature}.


\begin{algorithm*}[t]
\caption{End-to-end training of GeONet}\label{alg:cap}
\textbf{Input:} data pairs $(\mu_0^{(1)}, \mu_1^{(1)}), \hdots, (\mu_0^{(n)}, \mu_1^{(n)})$; batch size $N$;  initialization of the neural network parameters $\phi, \psi \in \Theta \times \Theta \times \Xi$; weight parameters $\alpha_1, \alpha_2, \beta_0, \beta_1$; domain $\Omega$ and branch domain (mesh) $\tilde{\Omega}$; denote $i \in \{1,\hdots,N\}$.
%\textbf{Initialize:} vectors $X$ and $T$.
\begin{algorithmic}[1]
\While{$\mathcal{L}_{\text{total}}$ has not converged}
\State Independently draw $N$ sample points from $(x_{\Omega}^i, t^i) \in U(\Omega) \times U(0,1)$, $N$ points from $x_{\tilde{\Omega}}^i \in U(\tilde{\Omega})$, and $N$ density pairs from $\{(\mu_0^{(\ell)},\mu_1^{(\ell)})\}_{\ell=1}^n$, possibly repeating.
\State Compute $\mathcal{R}_{\text{cty},i} =  \partial_{t} \mathcal{C}_{\phi,i} + \div(\mathcal{C}_{\phi,i} \nabla \mathcal{H}_{\psi,i} )   $ at $(x_{\Omega}^i,t^i)$.
\Comment{\texttt{continuity residual}}
\State Compute $\mathcal{R}_{\text{HJ},i} =    \partial_{t} \mathcal{H}_{\psi,i} + {1\over2}\| \nabla \mathcal{H}_{\psi,i} \|_2^2  $ at $(x_{\Omega}^i,t^i)$. \Comment{\texttt{HJ residual}}
\State Compute $B_{0,i} =    \mathcal{C}_{\phi,0,i}  - \mu_0^{(i)}(x_{\tilde{\Omega}}^i), \ \ B_{1,i} = \mathcal{C}_{\phi,1,i} - \mu_1^{(i)}(x_{\tilde{\Omega}}^i)    $. \Comment{\texttt{boundary residual}}
\State Compute 
\[
\begin{gathered}
\mathcal{L}_{\text{cty}} =  \frac{\alpha_1}{N} \sum_{i=1}^N \mathcal{R}_{\text{cty},i}^2, \ \ \ \  \mathcal{L}_{\text{HJ}} = \frac{\alpha_2}{N} \sum_{i=1}^N  \mathcal{R}_{\text{HJ},i}^2 , \\
\mathcal{L}_{\text{BC}} = \frac{1}{N}  \sum_{i=1}^N (  \beta_0  B_{0,i}^2 + \beta_1  B_{1,i}^2 ),
\end{gathered}
\]
\State Compute $\mathcal{L}_{\text{total}}(\phi, \psi) = \mathcal{L}_{\text{cty}} + \mathcal{L}_{\text{HJ}} + \mathcal{L}_{\text{BC}}$.
\State Minimize $\mathcal{L}_{\text{total}}(\phi, \psi)$ to update $\phi$ and $\psi$. \Comment{\texttt{minimize the loss function}}
\EndWhile
\end{algorithmic}
\end{algorithm*}

To train the GeONet defined in~\eqref{eqn:Continuity_sol} and~\eqref{eqn:HJ_sol}, we minimize the empirical loss function corresponding to the system of primal-dual PDEs and boundary residuals in~\eqref{eqn:benamou-brenier_kkt} over the parameter space $\Theta \times \Theta \times \Xi$:
\begin{equation}
\label{eqn:GeONet_loss}
\phi^*, \psi^* \ \ \ = \ \ \ \argmin_{\phi, \psi \in \Theta \times \Theta \times \Xi} \ \ \    \mathcal{L}_{\text{cty}} + \mathcal{L}_{\text{HJ}} + \mathcal{L}_{\text{BC}},  
\end{equation}
where 
$\mathcal{L}_{\text{cty}}$ is the loss component enforcing the CE in~\eqref{eqn:GeONet_loss_CE} and $\mathcal{L}_{\text{HJ}}$ is the loss component enforcing the HJ equation in~\eqref{eqn:GeONet_loss_HJ}, while boundary conditions are incorporated in the $\mathcal{L}_{\text{BC}}$ term in~\eqref{eqn:GeONet_loss_BC}. Automatic differentiation of our GeONet involves differentiating the coupled DeepONet architecture (cf. Figure~\ref{fig:GeONet_diagram}) to compute the physics-informed loss terms.

Our loss function involves weight parameters $\alpha_1, \alpha_2, \beta_0, \beta_1$ to impose the physics-informed loss strength. Our coefficient tuning in the loss function is motivated by and follows the general strategy outlined in~\citep{doi:10.1126/sciadv.abi8605}, where coefficients are tuned by examining errors and altered in an iterative procedure in which error is minimized. Boundary conditions are enforced to a greater extent, as precision with these affects precision in the physics loss.

We now illustrate our training procedure. The physics training is done via a \emph{collocation} procedure, following~\citep{RAISSI2019686}. We randomly sample $N$ pairs $(x,t)$ uniformly within $\Omega \times [0,1]$, where the CE and HJ expectation terms~\eqref{eqn:GeONet_loss_CE} and~\eqref{eqn:GeONet_loss_HJ} in the loss function are approximated via a discrete empirical average. For the boundary terms~\eqref{eqn:GeONet_loss_BC}, we evaluate $x$ among fixed locations within $\Omega$, typically a hypercube mesh, since these are where the known boundary data is given and where the neural operator is subsequently formulated and evaluated.
\begin{strip}
\begin{align}
\label{eqn:GeONet_loss_CE}
 \mathcal{L}_{\text{cty}} & =   \alpha_1 \E_{(\mu_0,\mu_1) \sim ( \mathcal{P}_2(\Omega), \mathcal{P}_2(\Omega))} \Big[ ||  \frac{ \partial}{\partial t} \mathcal{C}_{\phi} (x, t)  + \div ( \mathcal{C}_{\phi} (x, t) \nabla \mathcal{H}_{\psi } (x, t) )  ||_{L^2(\Omega \times (0,1))}^2 \Big],  \\
\label{eqn:GeONet_loss_HJ}
\mathcal{L}_{\text{HJ}} & =  \alpha_2 \E_{(\mu_0,\mu_1) \sim ( \mathcal{P}_2(\Omega), \mathcal{P}_2(\Omega))} \Big[ ||  \frac{\partial}{\partial t} \mathcal{H}_{\psi}(x, t)   + \frac{1}{2} || \nabla \mathcal{H}_{\psi}(x, t) ||_2^2  ||_{L^2(\Omega \times (0,1))}^2 \Big], \\ 
\label{eqn:GeONet_loss_BC}
\mathcal{L}_{\text{BC}} & =  \beta_0 \E_{(\mu_0) \sim ( \mathcal{P}_2(\Omega))} \Big[ || \mathcal{C}_{\phi}(x, 0)  - \mu_{0} ||_{L^2(\Omega)}^2 \Big] + \beta_1 \E_{(\mu_1) \sim (\mathcal{P}_2(\Omega))} \Big[ || \mathcal{C}_{\phi}(x, 1)  - \mu_{1}  ||_{L^2(\Omega)}^2 \Big].
\end{align}
\end{strip}


{\bf Entropic regularization.} Our GeONet is compatible with entropic regularization, which is related to the Schr\"odinger bridge problem and stochastic control~\citep{ChenGeorgiouPavon2016}. In Appendix~\ref{app:sec:Entropic_regularization}, we propose the \emph{entropic-regularized GeONet} (ER-GeONet), which learns a similar system of KKT conditions for the optimization as in~\eqref{eqn:benamou-brenier_kkt}. In the zero-noise limit as the entropic regularization parameter $\varepsilon \downarrow 0$, the solution of the optimal entropic interpolating flow converges to the solution of the Benamou-Brenier problem~\eqref{eqn:benamou-brenier_formula} in the sense of the method of vanishing viscosity~\citep{Mikami2004,evans2010}. On the one hand, adding a small entropy term (Laplacian) ensures the unique viscosity solution for the regularized HJ equation is smooth and benefits training. On the other hand, as in the static OT problem, adding the Laplacian only approximates the OT flow (i.e., the Wasserstein geodesic is not solved exactly).






\section{Numeric experiments}
\label{sec:experiments}

In this section, we perform simulation studies and a real-data example to demonstrate GeONet. Our code is publicly available at: \url{https://github.com/agracyk2/GeONet}.

\textbf{Error metric.} We use the $L^1$ error $\int_{\Omega} | \mathcal{C} - \mu | \rd x$ as our error metric to assess the performance, where $\mu := \mu(x, t)$ is a reference geodesic serving as a proxy of the true geodesic without entropic regularization. The $L^1$ error integral is estimated by evaluating a discrete Riemann sum along a mesh and the reference is computed using the POT Python library~\citep{Solomon_2015,flamary2021pot}. Since $\int_{\Omega} | \mu| \rd x = 1$ for all time points, the $L^1$ error is relative, thus a meaningful metric essentially corresponding to the percentage error between the neural operator geodesic and the reference. We also consider the $L^2$ and Wasserstein error metrics for predicted Wasserstein geodesics (see Appendix~\ref{app:error_metrics}).

\begin{figure*}[!t]
  \centering
  %\includegraphics[scale=0.55]{Gaussians 8.pdf}
  \includegraphics[scale=0.53]{GeONet_univariate_samples_2.pdf}
  \vspace{0mm}
  \caption{Four geodesics predicted by GeONet with reference geodesics computed by POT on test univariate Gaussian mixture distribution pairs with $k_0 = k_1 = 6$. The reference serves as a close approximation to the true geodesic. The vertical axis is space and the horizontal axis is time.}
  \vspace{0mm}
  \label{fig:GeONet_gaussian_mixture}
\end{figure*}

\begin{figure*}[t]
  \centering
  %\includegraphics[scale=0.55]{Gaussians 8.pdf}
  \includegraphics[scale=0.625]{GeONet_bivariate_cameraready2.pdf}
  \caption{Geodesics predicted by GeONet on bivariate Gaussians over a square domain. The top of each pair is the reference solution computed by POT, and the bottom is GeONet. }
  \label{fig:GeONet_gaussian_mixture_bivariate}
\end{figure*}


\subsection{Input as continuous density: Gaussian mixture distributions}
\label{subsec:gaussian_mixtures}


Since finite mixture distributions are powerful universal approximators for continuous probability density functions~\citep{doi:10.1080/25742558.2020.1750861}, we first deploy GeONet on Gaussian mixture distributions over domains of varying dimensions. We learn the Wasserstein geodesic mapping between two distributions of the form $\mu_j(x) = \sum_{i=1}^{k_j} \pi_i \mathcal{N}(x | u_i , \Sigma_i)$ subject to $ \sum_{i=1}^{k_j} \pi_i = 1$,
where $j \in \{0,1\}$ corresponds to initial and terminal distributions $\mu_0, \mu_1$, and $k_j$ denotes the number of components in the mixture. Here $u_i$ and $\Sigma_i$ are the mean vectors and covariance matrices of individual Gaussian components respectively. Due to the space limit, we defer simulation setups, model training details, and error metrics to Appendices~\ref{Hyperparameter_settings},~\ref{Training and performance} and~\ref{app:error_metrics}, respectively.


We examine errors in regard to an identity geodesic (i.e., $\mu_0 = \mu_1$), a random test pairing, and an out-of-distribution (OOD) pairing. The mesh-invariant nature of the output of GeONet allows zero-shot super-resolution for adapting low-resolution data into high-resolution geodesics, which includes initial data at $t=0,1$. Traditional OT solvers and non-operator learning based methods have no ability to do this, as they are confined to the original mesh. Thus, we also include a random test pairing on higher resolution than training data. The result is reported in Table~\ref{tab:GeONet_gaussian_mixture}.

\textbf{Univariate Gaussians.} We choose spatial domain $x \in \Omega = [0,10]$ discretized into a $100$-point mesh. We generate $20,000$ training pairs $(\mu_0, \mu_1)$ of Gaussians, taking $k_j = 6$ for the number of Gaussians in each mixture. We take means $u_i \in [2,8]$ and variances $\Sigma_i \in [0.5,0.6]$ uniformly. Empirically, we found a large batch size more suitable for training than a low one, so we take a batch size of $2,000$, meaning this many uniform collocation points are taken for both the PDE residuals and boundary points for each training iteration. We choose physical loss coefficients $\alpha_1 = 0.5, \alpha_2 = 0.25$, with boundary coefficients $\beta_0 = \beta_1 = 1$. We found these coefficients a good balance to enforce the physical constraint without sacrificing boundary restrictions after iterating these coefficients among $[0.05,20]$ and examining the error. Additional training details are given in Appendix~\ref{Hyperparameter_settings}.


\textbf{Bivariate Gaussians.} In our experiment, domain $\Omega = [0,5] \times [0,5] \subseteq \mathbb{R}^2$ was chosen, which was discretized into a $24 \times 24$ grid for GeONet input, meaning each branch network took a vector input of length $576$ in a non-convolutional architecture, but a convolutional architecture is also suitable in higher-dimensional cases as we see in Figure~\ref{fig:GeONet_3d_gaussians}. We generate $5,000$ training pairs $(\mu_0, \mu_1)$. Recall that GeONet is mesh-invariant, so the $24 \times 24$ grids can be adapted to any higher resolution, which is used in Figure~\ref{fig:GeONet_gaussian_mixture_bivariate}. We use a combination of low and high variance Gaussians in the mixture, 6 of which had variance in $[0.35,0.4]$ and 6 in $[0.75, 0.9]$, giving a total of 12 Gaussians in each mixture in each pair. Covariances were in $[-0.1,0.1]$.
Additional training details are given in Appendix~\ref{Hyperparameter_settings}.

\begin{figure*}[h]
\vspace{0mm}
  \centering
  \includegraphics[scale=0.72]{GeONet_MNIST_samples_2.pdf}
  \caption{Beginning from the top left and going clockwise, we display the initial conditions in the encoded space, the geodesics in the encoded space, and the decoded geodesics as $28 \times 28$ images.}
  \label{fig:GeONet_MNIST}
\end{figure*}

\textbf{Training.} To compute the DeepONet derivatives, we take the inner product in the enhanced DeepONet as in equations~\eqref{eqn:Continuity_sol},~\eqref{eqn:HJ_sol}, and subsequently use automatic differentiation on the result. Alternatively, we experimented by computing a Hessian for the second-order derivatives, but this is costly in terms of memory, meaning a large batch size cannot be used without a monumental memory cost, and so this method of differentiation is not viable.

We found that, given sufficient data, GeONet with a larger output dimension $p$ slightly outperforms GeONet with a smaller output dimension. In the univariate Gaussian experiment, we take $p=800$, which outperformed $p=200$ by reducing training loss from approximately $2.5\times10^{-4}$ to $1.5\times10^{-4}$ and reducing test error by about $1\%$. In the bivariate experiment, changing $p=400$ to $p=800$ reduced training loss from approximately $2.1\times10^{-5}$ to $1.8\times10^{-5}$.

Architecture generally made some difference to training loss, but not significant, making a width of around $100$-$200$ suitable for branches and trunks. For example, increasing branch width in the univariate experiment from $100$ to $150$ lowered training loss by approximately $4\times10^{-5}$. Increasing branch width to $200$ and trunk width to $150$ from $150$ and $100$ respectively had minimal effect, lowering training loss by about $1\times10^{-5}$. We found the modified MLP architecture preferable, lowering final training loss from approximately $3\times10^{-4}$ with standard architecture for univariate Gaussians.




\begin{table*}[t]
  \caption{$L^1$ error of GeONet on 50 test data of univariate and bivariate Gaussian mixtures. We compute errors on cases of the identity geodesic, a random pairing in which $\mu_0 \neq \mu_1$, high-resolution random pairings refined to $200$ and $75 \times 75$ resolutions in the 1D and 2D cases respectively, and out-of-distribution examples. We report the means and standard deviations as percentages, i.e., the actual values are those of the table multiplied by $10^{-2}$.\vspace{0mm}}
  % In the second part, we train and test upon $k_0 = k_1 = 5$, $\pi_i = 0.2$ for all $i$, with the same loss coefficients. 
  \label{tab:GeONet_gaussian_mixture}
  \centering
  \begin{tabular}[b]{
    l
    S[table-format = 3]
    S[table-format = 2]
    S[table-format = 1.3]
    S[table-format = -2.2]
    S[table-format = 1.3]
    S[table-format = 1.3]
    S[table-format = 2.2]
    }
    \toprule
    \multicolumn{1}{c}{} & 
    \multicolumn{5}{c}{GeONet $L^1$ error for Gaussian mixtures}\\
    \cmidrule(lr){2-6}        
    \textbf{Experiment \ \ } & {$\bm{ t=0 }$} & {$\bm{  t=0.25  }$} & {$\bm{ t=0.5  }$} & {$\bm{ t=0.75  }$} & {$\bm{ t =1 }$}  \\
    \midrule
    \text{1D identity} & 
    {$2.67 \pm 0.750$} & 
    {$2.85 \pm 0.912$} &
    {$3.04 \pm 1.02$} & 
    {$2.86 \pm 0.898$} &
    {$2.63 \pm 0.696$} \\
    \text{1D random}  &
    {$4.92 \pm 2.00$}  & 
    {$5.43 \pm 3.02$}  & 
    {$5.76 \pm 3.56$} & 
    {$5.26 \pm 3.25$} & 
    {$4.65 \pm 1.50$} \\
    \text{1D high-res.}  &
    {$4.76 \pm 1.53$}  & 
    {$5.49 \pm 3.00$}  & 
    {$6.01 \pm 3.53$} & 
    {$5.59 \pm 2.99$} & 
    {$4.77 \pm 1.49$} \\
    \text{1D OOD} & 
    {$14.1 \pm 4.34$} &
    {$18.8 \pm 5.96$} &  
    {$22.2 \pm 7.32$} &
    {$19.2 \pm 6.14$} &
    {$13.8 \pm 4.68$} \\
    \midrule
    \text{2D identity} & 
    {$6.50 \pm 1.15$} & 
    {$7.68 \pm 0.915$} &
    {$7.69 \pm 0.924$} & 
    {$7.70 \pm 0.889$} &
    {$6.42 \pm 1.11$} \\
    \text{2D random}  &
    {$6.59 \pm 1.01$} & 
    {$7.10 \pm 0.869$} &
    {$7.13 \pm 0.892$} & 
    {$7.04 \pm 0.780$} &
    {$6.33 \pm 0.835$} \\
    \text{2D high-res.}  &
    {$6.66 \pm 0.766$} & 
    {$7.71 \pm 1.26$} &
    {$7.88 \pm 1.21$} & 
    {$7.59 \pm 0.979$} &
    {$6.29 \pm 0.723$} \\
    \text{2D OOD} & 
    {$10.2 \pm 1.18$} &
    {$9.82 \pm 1.12$} &  
    {$9.98 \pm 1.23$} &
    {$9.67 \pm 1.03$} &
    {$9.92 \pm 0.944$} \\
    
    %\hline
    %\text{Identity $k_0 = k_1 = 5$} & 
    %{$0.23 \pm 0.23$} & 
    %{$1.7 \pm 0.67$} &
    %{$2.6 \pm 1.1$} & 
    %{$1.8 \pm 0.74$} &
    %{$0.22 \pm 0.21$} \\
    %\text{Generic $k_0 = k_1 = 5$} & 
    %{$0.22 \pm 0.15$} & 
    %{$1.7 \pm 0.92$} &
    %{$2.7 \pm 1.5$} & 
    %{$1.7 \pm 0.76$} &
    %{$0.20 \pm 0.13$} \\
    \bottomrule
\end{tabular}
\vspace{0mm}
\end{table*}



\begin{figure*}[h!]
  \centering
  \includegraphics[scale=0.7]{GeONet_runtime_comparison.pdf}
  \caption{We compare GeONet to the classical POT library on 1D and 2D Gaussians in terms of mean and standard deviations of runtime on an unmodified scale as well as one that is log-log using discretization length in one dimension as the x-axis, taken over 30 pairs. We use 20 time steps for 1D and 5 for 2D. Finer meshes are omitted for 2D for computational reasonableness. }
  \label{fig:GeONet_runtime}
\end{figure*}


\subsection{Input as point clouds: Gaussian mixture distributions}
\label{empirical_Gaussians}

\begin{figure}[h]
  \centering
  %\includegraphics[scale=0.55]{Gaussians 8.pdf}
  % \includegraphics[scale=0.66]{GeONet_MNIST_samples_3.pdf}
  \includegraphics[scale=0.55]{GeONet_3gauss_to_3gauss_cropped.pdf}
  \vspace{0mm}
  \caption{We compare GeONet to the alternative methodologies in a discrete setting, using POT as ground truth. GeONet is the only method in the comparison that captures the geodesic behavior among the translocation of points. \vspace{-2mm}}
  \label{fig:GeONet_pointclouds}
\end{figure}


GeONet can be applied to continuous densities made discrete. In scenarios with access to point clouds of data, we may use GeONet with discrete data made into empirical distributions. We test GeONet on an example of a Gaussian setup. We fix an initial and terminal distribution and sample discrete particles in $\Omega \subseteq \mathbb{R}^2$, as considered in~\citep{LiuGongLiu_2023}. The sampled particles are represented by empirical densities, on which we compare the transition of densities in the non-particle setting using POT as a baseline~\citep{flamary2021pot}. The result is reported in Table~\ref{tab:GeONet_gaussian_mixture_discrete} and an estimated geodesic example is shown in Figure~\ref{fig:GeONet_pointclouds}. We observe that conditional flow matching (CFM)~\citep{tong2023improving} and rectified flow (RF)~\citep{LiuGongLiu_2023} have estimation errors roughly 3--4 times larger than those of GeONet, except at the initial time $t = 0$, because this initial data is given and learned directly for RF and CFM. GeONet is the only framework in the comparison that captures the geodesic behavior to a considerable degree; however, we remark that GeONet tends to smooth, or regularize, the solutions. Moreover, RF and CFM have the same fixed resolution as the input probability distribution pairing, while GeONet can estimate the density flows on higher resolution than the input pairing (cf. the third row in Figure~\ref{fig:GeONet_pointclouds}).

%\setcitestyle{numbers}

\begin{table*}[h]
  \caption{$L^1$ error of GeONet, the conditional flow matching (CFM) library's optimal transport solver~\cite{tong2023improving}, and rectified flow (RF)~\cite{LiuGongLiu_2023}, using POT again as a baseline for comparison. The actual values are those of the table multiplied by $10^{-2}$.\vspace{0mm}}
  % In the second part, we train and test upon $k_0 = k_1 = 5$, $\pi_i = 0.2$ for all $i$, with the same loss coefficients. 
  \label{tab:GeONet_gaussian_mixture_discrete}
  \centering
  \begin{tabular}[b]{
    l
    S[table-format = 3]
    S[table-format = 2]
    S[table-format = 1.3]
    S[table-format = -2.2]
    S[table-format = 1.3]
    S[table-format = 1.3]
    S[table-format = 2.2]
    }
    \toprule
    \multicolumn{1}{c}{} & 
    \multicolumn{5}{c}{$L^1$ comparison error on 2D Gaussian mixture point clouds}\\
    \cmidrule(lr){2-6}        
    \textbf{Experiment \ \ } & {$\bm{ t=0 }$} & {$\bm{  t=0.25  }$} & {$\bm{ t=0.5  }$} & {$\bm{ t=0.75  }$} & {$\bm{ t =1 }$}  \\
    \midrule
    \text{GeONet}  &
    {$22.9 \pm 1.08$}  & 
    {$28.8 \pm 1.01$}  & 
    {$30.0 \pm 1.10$} & 
    {$29.6 \pm 0.877$} & 
    {$22.6 \pm 1.02$} \\
    \text{CFM}  &
    {$0.0 \pm 0.0$} & 
    {$94.1 \pm 3.68$} &
    {$98.9 \pm 2.41$} & 
    {$91.8 \pm 4.15$} &
    {$75.9 \pm 3.77$} \\
    \text{RF}  &
    {$0.0 \pm 0.0$} & 
    {$103 \pm 2.48$} &
    {$112 \pm 3.61$} & 
    {$112 \pm 5.03$} &
    {$91.3 \pm 3.79$} \\
    
    %\hline
    %\text{Identity $k_0 = k_1 = 5$} & 
    %{$0.23 \pm 0.23$} & 
    %{$1.7 \pm 0.67$} &
    %{$2.6 \pm 1.1$} & 
    %{$1.8 \pm 0.74$} &
    %{$0.22 \pm 0.21$} \\
    %\text{Generic $k_0 = k_1 = 5$} & 
    %{$0.22 \pm 0.15$} & 
    %{$1.7 \pm 0.92$} &
    %{$2.7 \pm 1.5$} & 
    %{$1.7 \pm 0.76$} &
    %{$0.20 \pm 0.13$} \\
    \bottomrule
\end{tabular}
%\vspace{-6mm}
\end{table*}


\setcitestyle{names}








\subsection{A real data application}
\label{eqn:MNIST}







Our next experiment was on the MNIST dataset of $28 \times 28$ images of single-digit numbers. It is difficult for GeONet to capture the geodesics between digits: MNIST resembles jump-discontinuous data, and is relatively piecewise constant otherwise, which is troublesome for physics-informed learning. To remedy our problems with MNIST, we use a pre-trained autoencoder to encode the MNIST digits into a low-dimensional representation $v \in \mathbb{R}^{32}$ with an encoder $\Phi$ and a decoder $\Phi^{-1} : v \rightarrow \mathbb{R}^{28} \times \mathbb{R}^{28}$ mapping the encoded representation into newly-formed digits resembling that which was fed into the encoder. The encoded data is made nonnegative via shifting upwards by a constant (we choose 10), and normalized over the domain to satisfy the density condition. This prepares the encoded data for GeONet input. We employ GeONet upon the encoded representations, learning the geodesic between highly irregular encoded data. The data can be decoded by unnormalizing and shifting downwards by the arbitrary constant. For normalization constants at $t\neq 0,1$, we use interpolation between the constants at $t=0,1$.

Table~\ref{tab:GeONet_CIFAR-10} reports the $L^1$ errors for geodesic estimated in the encoded space and recovered images in the ambient space. As expected, the ambient-space error is much larger than the encoded-space error, meaning that the geodesics in the encoded space and ambient image space do not coincide. Figure~\ref{fig:GeONet_MNIST} shows the learned geodesics in the encoded space and decoded images on the geodesics. 






\subsection{Runtime comparison}





Our method is highlighted by the fact that it is almost instantaneous: it is highly suitable when many geodesics are needed quickly, or over fine meshes. Traditional optimal transport solvers are greatly encumbered when evaluated over a fine grid, but the mesh-invariant output nature of GeONet bypasses this. In Figure~\ref{fig:GeONet_runtime}, we compare the runtime of GeONet against POT, a traditional OT library. GeONet greatly outperforms POT for fine grids in terms of runtime, especially if POT is used to compute an accurate solution. Even when POT is used with equivalent accuracy, GeONet still outperforms it, as best illustrated in the log-log plot. The log-log plot also demonstrates that our method speeds up computation by orders of magnitude. We restrict the accuracy of POT by employing a stopping threshold of $0.5$ for 1D and $10.0$ for 2D. We found that these choices gave accuracy comparable to GeONet, and remark that a threshold of $10.0$ in the 2D case is sufficiently large that even larger thresholds have limited effect on error.




\subsection{Out-of-distribution generalization}


We discuss GeONet on out-of-distribution data in the test setting upon Gaussian mixture data. Our error results are provided in Table~\ref{tab:GeONet_gaussian_mixture}. For univariate Gaussians, we choose means in $[1,9]$, which was expanded from the domain $[2,8]$ in training. This increased relative error by about $10 \%$. Variances were in $[0.3,0.4]$. A 100-point mesh is used for evaluations with POT regularization parameter $\epsilon = 6\times10^{-4}$. For 2D Gaussians, we test on 16 mixture components (training has 12). Means were in $[0.6,4.4]\times[0.6,4.4]$, which was expanded from $[0.8,4.2]\times[0.8,4.2]$ in training. There were 8 components in the mixture with variance in $[0.25,0.3]$ and the other 8 in $[0.65,0.8]$, which have lower variances than those in training. Covariances are within $[-0.15,0.15]$ for off-diagonal components in each covariance matrix. Evaluations were over a $24 \times 24$ mesh, the same used as neural operator input.





\subsection{Limitations}



There are several limitations we would like to discuss. First, the size of GeONet's branch network input increases exponentially in the spatial dimension, necessitating extensive input data even in moderately high-dimensional scenarios. One strategy to mitigate this is to leverage low-dimensional data representations, as in the MNIST experiment. GeONet is near instantaneous in any dimension, but its applicability in high dimensions is mostly limited by the ability to handle the neural network input in the branches. Second, GeONet mandates predetermined evaluation points for branch input, a requisite grounded in the pairing of initial conditions. It is of interest to extend GeONet to include training input data pairs on different resolutions. Third, given the regularity of the OT problem~\citep{hutter2021minimax,Caffarelli_1996}, developing a generalization error bound for assessing the predictive risk of GeONet is important future work. Finally, the dynamical OT problem is closely connected to mean-field planning with an extra interaction term~\citep{FU2023112346}. Extending the current operator learning perspective to such problems would be interesting.


%\subsection{Limitations}
%We discuss limitations in Appendix~\ref{Training and performance}.




\begin{acknowledgements}
    Andrew Gracyk was supported by the NSF under grant No. 1922758. Xiaohui Chen was partially supported by NSF CAREER grant DMS-2347760, NSF grant DMS-2413404, and a gift from the Simons Foundation.
\end{acknowledgements}







\clearpage
\bibliographystyle{plainnat}
%\bibliographystyle{iclr2024_conference}
\bibliography{neural_operator}


\onecolumn

\title{GeONet: a neural operator for learning the Wasserstein geodesic (Supplementary Material)}
\maketitle 

\appendix


\section{Training algorithm}



\section{Derivation of primal-dual optimality conditions for dynamical OT problem}
\label{app:sec:primal-dual_optimality}

The primal-dual analysis is a standard technique in the optimization literature such as in analyzing certain semidefinite programs~\citep{9366690}. Recall the Benamou-Brenier fluid dynamics formulation of the static optimal transport problem
\begin{align}
\label{app:eqn:min_kinetic_energy}
& \min_{(\mu, \vv)}  \int_0^1 \int_{\mathbb{R}^d} {1\over2} || \vv(x, t) ||_2^2 \ \mu(x, t) \ \rd x \ \rd t \\
\label{app:eqn:continuity_equation} 
 & \mbox{subject to}  \ \  \partial_t \mu + \div(\mu \vv) = 0,\\
 \label{app:eqn:boundary_condition}
 &   \mu(\cdot, 0) = \mu_0, \ \  \mu(\cdot, 1) = \mu_1.
\end{align}
Here, equation~\eqref{app:eqn:continuity_equation} is referred to as the \emph{continuity equation} (CE), preserving the unit mass of the density flow $\mu_t = \mu(\cdot, t)$. We write the Lagrangian function for any flow $(\mu_t)_{t \in [0, 1]}$ initializing from $\mu_0$ and terminating at $\mu_1$ as
\begin{equation}
    \label{app:eqn:lagrangian_benamou-brenier}
    \begin{gathered}
    L(\mu, \vv, u) = \int_0^1 \int_{\bR^d} \left[ {1\over2} \|\vv\|_2^2 \mu + \left( \partial_t \mu + \div(\mu \vv) \right) u \right] \; \rd x \; \rd t,
    \end{gathered}
\end{equation}
where $u := u(x, t)$ is the dual variable for (CE). To find the optimal solution $\mu^*$ for the minimum kinetic energy~\eqref{app:eqn:min_kinetic_energy}, we study the saddle point optimization problem
\begin{equation}
    \label{app:eqn:lagrangian_benamou-brenier_saddle_point}
    \min_{(\mu, \vv) \in \text{(CE)}} \max_u L(\mu, \vv, u),
\end{equation}
where the minimization over $(\mu, \vv)$ runs over all flows satisfying (CE) such that $\mu(\cdot, 0) = \mu_0$ and $\mu(\cdot, 1) = \mu_1$. Note that if $\mu \notin \text{(CE)}$, then by scaling with arbitrarily large constant, we see that
\begin{equation}
    \max_u \int_0^1 \int_{\bR^d} \left( \partial_t \mu + \div(\mu \vv) \right) u \; \rd x \; \rd t = + \infty.
\end{equation}
Thus,
\begin{align}
    \min_{(\mu, \vv) \in \text{(CE)}} \int_0^1 \int_{\mathbb{R}^d} {1\over2} || \vv ||_2^2 \mu \ \rd x \ \rd t = & \min_{(\mu, \vv)} \max_u L(\mu, \vv, u) \\ \geq & \max_u \min_{(\mu, \vv)} L(\mu, \vv, u),
\end{align}
where the minimization over $(\mu, \vv)$ is unconstrained.  Using integration-by-parts and suitable decay for vanishing boundary as $\|x\|_2 \to \infty$, we have
\begin{align*}
    L(\mu, \vv, u) = & \int_0^1 \int_{\bR^d} \left[ {1\over2} \|\vv\|_2^2 \mu - \mu \partial_t u - \langle \vv, \nabla u \rangle \mu \right] \; \rd x \; \rd t \\
    & \qquad + \int_{\bR^d} \left[ \mu(\cdot,1) u(\cdot,1) - \mu(\cdot,0) u(\cdot,0) \right] \; \rd x.
\end{align*}
Now, we fix $\mu$ and $u$, and minimize $L(\mu, \vv, u)$ over $\vv$. The optimal velocity vector is $\vv^* = \nabla u$, and we have
\begin{equation}
    L(\mu, \vv^*, u) = \int_0^1 \int_{\bR^d} \left[ -\left( {1\over2} \|\nabla u\|_2^2 + \partial_t u \right) \mu \right] \; \rd x \; \rd t  + \int_{\bR^d} \left[ u(\cdot,1) \mu_1 -  u(\cdot,0) \mu_0 \right] \; \rd x,
\end{equation}
for any flow $\mu_t$ satisfying the boundary conditions $\mu(\cdot, 0) = \mu_0$ and $\mu(\cdot, 1) = \mu_1$. If ${1\over2} \|\nabla u\|_2^2 + \partial_t u \neq 0$, then by the same scaling argument above, we have
\begin{equation}
    \min_{\mu} \int_0^1 \int_{\bR^d} \left[ -\left( {1\over2} \|\nabla u\|_2^2 + \partial_t u \right) \mu \right] \; \rd x \; \rd t = -\infty
\end{equation}
because $\mu$ is unconstrained (except for the boundary conditions). Then we deduce that
\begin{equation}
\label{app:eqn:duality}
    \min_{(\mu, \vv) \in \text{(CE)}} \int_0^1 \int_{\mathbb{R}^d} {1\over2} || \vv ||_2^2 \mu \geq \max_{u \in \text{(HJ)}} \left\{ \int_{\bR^d} u(\cdot,1) \mu_1 - \int_{\bR^d} u(\cdot,0) \mu_0 \right\},
\end{equation}
where $u \in \text{(HJ)}$ means that $u$ solves the \emph{Hamilton--Jacobi equation} (HJ)
\begin{equation}
    \label{app:eqn:HJ}
    \partial_t u + {1\over2} \|\nabla u\|_2^2 = 0.
\end{equation}
From~\eqref{app:eqn:duality}, we see that the duality gap is non-negative, and it is equal to zero if and only if $(\mu^*, u^*)$ solves the following system of PDEs
\begin{equation}
    \label{app:eqn:benamou-brenier_kkt}
    \left\{
    \begin{gathered}
      \partial_t \mu + \div(\mu \nabla u) = 0, \ \ \partial_t u + {1\over2} \|\nabla u\|_2^2 = 0, \\
     \mu(\cdot,0) = \mu_0, \ \ \mu(\cdot,1) = \mu_1.
     \end{gathered} \right.
\end{equation}
PDEs in~\eqref{app:eqn:benamou-brenier_kkt} are referred to as the Karush–Kuhn–Tucker (KKT) conditions for the Wasserstein geodesic problem.

\section{Metric geometry structure of the Wasserstein space and geodesic}
\label{app:sec:wasserstein_facts}

In this section, we review some basic facts on the metric geometry properties of the Wasserstein space and geodesic. We first discuss the general metric space $(X, d)$, and then specialize to the Wasserstein (metric) space $({\cal P}_p(\bR^d), W_p)$ for $p \geq 1$. Furthermore, we connect to the fluid dynamic formulation of optimal transport. Most of the materials are based on the reference books~\citep{BurageBuragoIvanov2001_MetricGeometry,AmbrosioGigliSavare2008_GradientFlows,sabtanbrogio2015_OT}.

\subsection{General metric space}

\begin{defn}[Absolutely continuous curve]
\label{def:ac_curve}
Let $(X, d)$ be a metric space. A curve $\omega : [0, 1] \to X$ is \emph{absolutely continuous} if there is a function $g \in L^1([0,1])$ such that for all $t_0 < t_1$, we have
\begin{equation}
\label{eqn:ac_curve}
d(\omega(t_0), \omega(t_1)) \leq \int_{t_0}^{t_1} g(\tau) \, \rd \tau.
\end{equation}
Such curves are denoted by $\text{AC}(X)$.
\end{defn}

\begin{defn}[Metric derivative]
\label{def:metric_derivative}
If $\omega : [0,1] \to X$ is a curve in a metric space $(X, d)$, the \emph{metric derivative} of $\omega$ at time $t$ is defined as
\begin{equation}
    \label{eqn:metric_derivative}
    |\omega'|(t) := \lim_{h \to 0} {d(\omega(t+h), \omega(t)) \over |h|},
\end{equation}
if the limit exists.
\end{defn}

The following theorem generalizes the classical Rademacher theorem from a Euclidean space into any metric space in terms of the metric derivative.

\begin{thm}[Rademacher]
\label{thm:rademacher}
If $\omega : [0,1] \to X$ is Lipschitz continuous, then the metric derivative $|\omega'|(t)$ exists for almost every $t \in [0,1]$. In addition, for any $0 \leq t < s \leq 1$, we have
\begin{equation}
    d(\omega(t), \omega(s)) \leq \int_{t}^{s} |\omega'|(\tau) \, \rd \tau.
\end{equation}
\end{thm}

Theorem~\ref{thm:rademacher} tells us that an absolutely continuous curve $\omega$ has a metric derivative well-defined almost everywhere, and the ``length'' of the curve $\omega$ is bounded by the integral of the metric derivative. Thus, a natural definition of the length of a curve in a general metric space is to take the best approximation over all possible meshes.

\begin{defn}[Curve length]
\label{def:curve_length}
For a curve $\omega : [0,1] \to X$, we define its \emph{length} as
\begin{equation}
    \label{eqn:curve_length}
    \text{Length}(\omega) := \sup \left\{ \sum_{k=0}^{n-1} d(\omega(t_k), \omega(t_{k+1})) : n \geq 1, 0 = t_0 < t_1 < \hdots < t_n = 1 \right\}.
\end{equation}
\end{defn}

Note that if $\omega \in \text{AC}(X)$, then
\begin{equation}
    d(\omega(t_k), \omega(t_{k+1})) \leq \int_{t_k}^{t_{k+1}} g(\tau) \, \rd \tau 
\end{equation}
so that
\begin{equation}
    \text{Length}(\omega) \leq \int_0^1 g(\tau) \, \rd \tau < \infty,
\end{equation}
i.e., the curve $\omega$ is of bounded variation.

\begin{lem}
If $\omega \in \text{AC}(X)$, then
\begin{equation}
    \text{Length}(\omega) = \int_0^1 |\omega'|(\tau) \, \rd \tau.
\end{equation}
\end{lem}

\begin{defn}[Length space and geodesic space]
Let $\omega : [0,1] \to X$ be a curve in $(X, d)$.
\begin{enumerate}
    \item The space $(X, d)$ is a \emph{length space} if
    \begin{equation}
        d(x, y) = \inf \left\{ \text{Length}(\omega) : \omega(0) = x, \omega(1) = y, \omega \in \text{AC}(X) \right\}.
    \end{equation}
    
    \item The space $(X, d)$ is a \emph{geodesic space} if
    \begin{equation}
        d(x, y) = \min \left\{ \text{Length}(\omega) : \omega(0) = x, \omega(1) = y, \omega \in \text{AC}(X) \right\}.
    \end{equation}
\end{enumerate}
\end{defn}

\begin{defn}[Geodesic]
Let $(X, d)$ be a length space. 
\begin{enumerate}
    \item A curve $\omega : [0,1] \to X$ is said to be a \emph{constant-speed geodesic} between $\omega(0)$ and $\omega(1)$ if
    \begin{equation}
        d(\omega(t), \omega(s)) = |t-s| \cdot d(\omega(0), \omega(1)),
    \end{equation}
    for any $t, s \in [0, 1]$.
    
    \item If $(X, d)$ is further a geodesic space, a curve $\omega : [0,1] \to X$ is said to be a \emph{geodesic} between $x_0 \in X$ and $x_1 \in X$ if it minimizes the length among all possible curves such that $\omega(0) = x_0$ and $\omega(1) = x_1$.
\end{enumerate}
\end{defn}
Note that in a geodesic space $(X, d)$, a constant-speed geodesic is indeed a geodesic. In addition, we have the following equivalent characterization of the geodesic in a geodesic space.

\begin{lem}
Let $(X, d)$ be a geodesic space, $p > 1$, and $\omega : [0,1] \to X$ a curve connecting $x_0$ and $x_1$. Then the following statements are equivalent.
\begin{enumerate}
    \item $\omega$ is a constant-speed geodesic.
    
    \item $\omega \in \text{AC}(X)$ such that for almost every $t \in [0, 1]$, we have
    \begin{equation}
        |\omega'|(t) = d(\omega(0), \omega(1)).
    \end{equation}
    
    \item $\omega$ solves
    \begin{equation}
        \min \left\{ \int_0^1 |\tilde\omega'|^p \, \rd t : \tilde\omega(0) = x_0, \tilde\omega(1) = x_1 \right\}.
    \end{equation}
\end{enumerate}
\end{lem}


\subsection{Wasserstein space}

Since the Wasserstein space $({\cal P}_p(\bR^d), W_p)$ for $p \geq 1$ is a metric space, the following definition specializes Definition~\ref{def:metric_derivative} to the Wasserstein metric derivative.

\begin{defn}[Wasserstein metric derivative]
\label{def:wasserstein_metric_derivative}
Let $\{\mu_t\}_{t \in [0,1]}$ be an absolutely continuous curve in the Wasserstein (metric) space $({\cal P}_p(\bR^d), W_p)$. Then the \emph{metric derivative} at time $t$ of the curve $t \mapsto \mu_t$ with respect to $W_p$ is defined as
\begin{equation}
    |\mu'|_p(t) : = \lim_{h \to 0} {W_p(\mu_{t+h}, \mu_{t}) \over |h|}.
\end{equation}
For $p = 2$, we write $|\mu'|(t) := |\mu'|_2(t)$.
\end{defn}

In the rest of this section, we consider probability measures $\mu_t$ that are absolutely continuous with respect to the Lebesgue measure on $\bR^d$ and we use $\mu_t$ to denote the probability measure, as well as its density, when the context is clear. 

\begin{thm}
\label{thm:metric_derivative_cty_eqn}
Let $p > 1$ and assume $\Omega \subset \bR^d$ is compact.

\underline{\bf Part 1.} If $\{\mu_t\}_{t \in [0,1]}$ is an absolutely continuous curve in $W_p(\Omega)$, then for almost every $t \in [0,1]$, there is a velocity vector field $\vv_t \in L^p(\mu_t)$ such that
\begin{enumerate}
    \item $\mu_t$ is a weak solution of the CE $\partial_t \mu_t + \div(\mu_t \vv_t) = 0$ in the sense of distributions (cf. the definition in~\eqref{eqn:cty_eqn_weak_solution} below);
    
    \item for almost every $t \in [0, 1]$, we have
    \begin{equation}
        \|\vv_t\|_{L^p(\mu_t)} \leq |\mu'|_p(t),
    \end{equation}
    where $\|\vv_t\|_{L^p(\mu_t)}^p = \int_{\Omega} \|\vv_t\|_2^p \, \rd \mu_t$.
\end{enumerate}

\underline{\bf Part 2.} Conversely, if $\{\mu_t\}_{t \in [0,1]}$ are probability measures in ${\cal P}_p(\Omega)$, and for each $t \in [0, 1]$ we suppose $\vv_t \in L^p(\mu_t)$ and $\int_0^1 \|\vv_t\|_{L^p(\mu_t)} \, \rd t < \infty$ such that $(\mu_t, \vv_t)$ solves the CE, then we have
\begin{enumerate}
    \item $\{\mu_t\}_{t \in [0,1]}$ is an absolutely continuous curve in $({\cal P}_p(\bR^d), W_p)$;
    
    \item for almost every $t \in [0, 1]$,
    \begin{equation}
        |\mu'|_p(t) \leq \|\vv_t\|_{L^p(\mu_t)}.
    \end{equation}
\end{enumerate}
\end{thm}

As an immediate corollary, we have the following dynamical representation of the Wasserstein metric derivative.

\begin{cor}
\label{cor:metric_derivative_cty_eqn}
If $\{\mu_t\}_{t \in [0,1]}$ is an absolutely continuous curve in $({\cal P}_p(\bR^d), W_p)$, then the velocity vector field $\vv_t$ given in Part 1 of Theorem~\ref{thm:metric_derivative_cty_eqn} must satisfy
\begin{equation}
    \|\vv_t\|_{L^p(\mu_t)} = |\mu'|_p(t).
\end{equation}
\end{cor}

Corollary~\ref{cor:metric_derivative_cty_eqn} suggests that $\vv_t$ can be viewed as the \emph{tangent vector field} of the curve $\{\mu_t\}_{t \in [0,1]}$ at time point $t$. Moreover, Corollary~\ref{cor:metric_derivative_cty_eqn} suggests the following (Euclidean) gradient flow for tracking particles in $\bR^d$: let $y(t) := y_x(t)$ be the trajectory starting from $x \in \bR^d$ (i.e., $y(0) = x$) that evolves according to the ordinary differential equation (ODE)
\begin{equation}
\label{eqn:particle_ode}
    {\rd \over \rd t} y(t) = \vv_t( y(t) ).
\end{equation}
The dynamical system~\eqref{eqn:particle_ode} defines a flow $Y_t : \Omega \to \Omega$ of vector field $\vv_t$ on $\Omega$ via
\begin{equation}
    \label{eqn:flow_representation}
    Y_t(x) = y(t).
\end{equation}
Then, it is straightforward to check that the pushforward measure flow $\mu_t := (Y_t)_\sharp \mu_0$ and the chosen velocity vector field $\vv_t$ in the ODE~\eqref{eqn:particle_ode} is a weak solution of the CE $\partial_t \mu_t + \div(\mu_t \vv_t) = 0$ in the sense that
\begin{equation}
\label{eqn:cty_eqn_weak_solution}
    {\rd \over \rd t} \int_\Omega \phi \, \rd \mu_t = \int_\Omega \langle \nabla \phi, \vv_t \rangle \, \rd \mu_t,
\end{equation}
for any ${\cal C}^1$ function $\phi : \Omega \to \bR$ with compact support.


\begin{thm}[Constant-speed Wasserstein geodesic]
\label{thm:wasserstein_geodesic}
Let $\Omega \subset \bR^d$ be a convex subset and $\mu, \nu \in {\cal P}_p(\Omega)$ for some $p > 1$. Let $\gamma$ be an optimal transport plan under the cost function $\|x-y\|_p^p$. Define
\begin{align*}
    & \pi_t : \Omega \times \Omega \to \Omega, \\
        & \pi_t(x, y) = (1-t) x + t y,
\end{align*}
as the linear interpolation between $x$ and $y$ in $\Omega$. Then, the curve $\mu_t = (\pi_t)_\sharp \gamma$ is a constant-speed geodesic in $({\cal P}_p(\bR^d), W_p)$ connecting $\mu_0 = \mu$ and $\mu_1 = \nu$.
\end{thm}

If $\mu$ has a density with respect to the Lebesgue measure on $\bR^d$, then there is an optimal transport map $T$ from $\mu$ to $\nu$~\citep{Brenier1991}. According to Theorem~\ref{thm:wasserstein_geodesic}, we obtain \emph{McCann's interpolation}~\citep{MCCANN1997153} in the Wasserstein space as
\begin{equation}
    \mu_t = [(1-t) \text{id} + t T]_\sharp \mu,
\end{equation}
which is a constant-speed geodesic in $({\cal P}_p(\bR^d), W_p)$. $\text{id}$ is the identity function in $\bR^d$.

To sum up, we collect the following facts about the geodesic structure and dynamical formulation of the OT problem. Let $p > 1$, and $\Omega \subset \bR^d$ be a convex subset (either compact or have no mass escaping at infinity).

\begin{enumerate}
    \item The metric space $({\cal P}_p(\Omega), W_p)$ is a geodesic space.
    
    \item For $\mu, \nu \in {\cal P}_p(\Omega)$, a constant-speed geodesic $\{\mu_t\}_{t \in [0, 1]}$ in $({\cal P}_p(\Omega), W_p)$ between $\mu$ and $\nu$ (i.e., $\mu_0 = \mu$ and $\mu_1 = \nu$) must satisfy $\mu_t \in \text{AC}({\cal P}_p(\Omega))$ and
    \begin{equation}
        |\mu'|_p(t) = W_p(\mu(0), \mu(1)) = W_p(\mu, \nu)
    \end{equation}
    for almost every $t \in [0, 1]$.
    
    \item The above $\mu_t$ solves 
    \begin{equation}
        \min \left\{ \int_0^1 |\tilde\mu'|^p(t) \, \rd t : \tilde\mu(0) = \mu, \tilde\mu(1) = \nu, \tilde\mu \in \text{AC}({\cal P}_p(\Omega)) \right\}.
    \end{equation}
    
    \item The above $\mu_t$ solves the Benamou-Brenier problem
    \begin{equation}
        W_p^p(\mu, \nu) = \min \left\{ \int_0^1 \|\vv_t\|_{L^p(\tilde\mu_t)}^p \, \rd t : \tilde\mu(0) = \mu, \tilde\mu(1) = \nu, \partial_t \tilde\mu_t + \div(\tilde\mu_t \vv_t) = 0 \right\},
    \end{equation}
    and $\mu_t = \mu(\cdot, t)$ defines a constant-speed geodesic in $({\cal P}_p(\Omega), W_p)$.
\end{enumerate}

\section{Entropic regularization}
\label{app:sec:Entropic_regularization}

 Our GeONet is compatible with entropic regularization, which is closely related to the Schr\"odinger bridge problem and stochastic control~\citep{ChenGeorgiouPavon2016}. Specifically, the entropic-regularized GeONet (ER-GeONet) solves the following fluid dynamic problem:
 
 \vspace{-5mm}
\begin{align}
\label{eqn:benamou-brenier_entropic_regularization}
\begin{gathered}
\min_{(\mu, \vv)} \int_0^1 \int_{\mathbb{R}^d} {1\over2} || \vv(x, t) ||_2^2 \ \mu(x, t) \ \rd x \ \rd t \\
  \qquad\qquad\qquad\qquad \mbox{subject to}  \ \  \partial_t \mu + \div(\mu \vv) + \varepsilon \Delta \mu = 0, \ \ \mu(\cdot, 0) = \mu_0, \ \ \mu(\cdot, 1) = \mu_1.
 \end{gathered}
\end{align}

Applying the same variational analysis as in the unregularized case $\varepsilon = 0$ (cf. Appendix~\ref{app:sec:primal-dual_optimality}), we obtain the KKT conditions for the optimization~\eqref{eqn:benamou-brenier_entropic_regularization} as the solution to the following system of PDEs:

\vspace{-5mm}
\begin{align}
    \label{eqn:benamou-brenier_entropic_regularization_kkt_1}
      \partial_t \mu + \div(\mu \nabla u) = & -\varepsilon \Delta \mu, \\
      \label{eqn:benamou-brenier_entropic_regularization_kkt_2}
      \partial_t u + {1\over2} \|\nabla u\|_2^2 = & \ \ \varepsilon \Delta u,
\end{align}
\vspace{-4mm}

with the boundary conditions $\mu(\cdot,0) = \mu_0, \mu(\cdot,1) = \mu_1$ for $\varepsilon > 0$. Note that~\eqref{eqn:benamou-brenier_entropic_regularization_kkt_2} is a parabolic PDE, which has a unique smooth solution $u^\varepsilon$. The term $\varepsilon \Delta u$ effectively regularizes the (dual) HJ equation in~\eqref{eqn:benamou-brenier_kkt}. In the zero-noise limit as $\varepsilon \downarrow 0$, the solution of the optimal entropic interpolating flow~\eqref{eqn:benamou-brenier_entropic_regularization} converges to the solution of the Benamou-Brenier problem~\eqref{eqn:benamou-brenier_formula} in the sense of the method of vanishing viscosity~\citep{Mikami2004,evans2010}.




\section{Gradient enhancement}
\label{Method_augmentations}

 In practice, we may fortify the base method by adding extra residual terms of the differentiated PDEs to our loss function of GeONet. Such a gradient enhancement technique has been used to strengthen PINNs~\citep{YU2022114823}; it improves efficiency, as fewer data points need to be sampled from $U(\Omega) \times U(0,1)$, as well as prediction accuracy.


The motivation behind gradient enhancement stems from minimizing the residual of a differentiated PDE. We turn our attention to PDEs of the form
\begin{equation}
\label{eqn:gradient_enhancement_1}
\left\{
\begin{gathered}
\mathcal{F} \Big(x, t, \partial_{x_1}u, \hdots, \partial_{x_d}u, \partial_{x_1 x_1} u , \hdots, \partial_{x_d x_d} u, \hdots, \partial_t u, \lambda \Big) = 0 \ \ \ \ \text{on} \ \ \ \ \Omega \times [0,1], \\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ 
u(\cdot, 0) = u_0, \ \ \ u(\cdot, 1) = u_1 \ \ \ \ \ \ \ \ \ \ \ \ \  \text{on} \ \ \ \ \Omega, \ \ 
\end{gathered}
\right.
\end{equation}
for domain $\Omega \subseteq \mathbb{R}^d$, parameter vector $\lambda$, and boundary conditions $u_0, u_1$. One may differentiate the PDE function $\mathcal{F}$ with respect to any spatial component to achieve
\begin{equation}
\label{gradient_enhancement_derivative}
\frac{ \partial }{ \partial x_{\ell}} \mathcal{F} \Big(x, t, \partial_{x_1}u, \hdots, \partial_{x_d}u, \partial_{x_1 x_1} u , \hdots, \partial_{x_d x_d} u, \hdots, \partial_t u, \lambda \Big) = 0 . 
\end{equation}
The differentiated PDE is additionally equal to $0$, similar to what we see in a PINN setup. If we substitute a neural network into the differentiated PDE of~\eqref{gradient_enhancement_derivative}, what remains is a new residual, just as we saw with the neural network substituted into the original PDE. Minimizing this new residual in the related loss function characterizes the gradient enhancement method. 

We utilize the same loss function in~\eqref{eqn:GeONet_loss}, but we add the additional terms
\begin{align}
\label{eqn:gradient_enhancement_2}
\mathcal{L}_{\text{GE,cty}} \ \ \ & = \ \ \  \sum_{\ell=1}^d  \gamma_{\ell}  \E [ | |  \ \ \frac{\partial}{\partial x_{\ell}} (\frac{\partial}{\partial t} \mathcal{C}_{\phi} + \div ( \mathcal{C}_{\phi} \nabla \mathcal{H}_{\psi}) )  \ \ | |_{L^2(\Omega \times (0,1))}^2 ],  \\
\mathcal{L}_{\text{GE, HJ}} \ \ \ & = \ \ \  \sum_{\ell=1}^d \omega_{\ell}  \E [ | | \ \ \frac{\partial}{\partial x_{\ell}} ( \frac{\partial}{\partial t} \mathcal{H}_{\psi} + \frac{1}{2} | | \nabla \mathcal{H}_{\psi}| |_2^2 )  \ \  | |_{L^2( \Omega \times (0,1))}^2 ], 
\end{align}
where the variables and neural networks that also appeared in~\eqref{eqn:GeONet_loss} are the same. Here $\gamma_{\ell}$ and $\omega_{\ell}$ are positive weights. The summation is taken in order to account for the gradient enhancement of each spatial component of $x \in \Omega$.




\begin{comment}
\section{DeepONets}
\label{app:sec:DeepONets}

A challenge resides in solving the risk minimization problem over numerous instances of data. This challenge may be conciliated by instituting a DeepONet that learns a general nonlinear operator, where one (or a pair of) neural network(s) encode(s) the input and another encodes the collocation samples. This architecture originates as an equivalence to the universal approximation theorem for operators.

\textbf{General DeepONet.} A general operator $G^{\dagger}$ may be approximated by an unstacked DeepONet~\citep{392253,LuJinPangZhangKarniadakis2021_DeepONet}

\vspace{-2mm}
\begin{equation}
\label{eqn:DeepONet}
G^{\dagger}( u_0) (x,t) \approx \sum_{k=1}^{p} \mathcal{B}_{k} \big( u_0(x_1), \hdots, u_0(x_m), \theta \big) \cdot \mathcal{T}_{k} (x, t, \xi) , 
\end{equation}

where $\mathcal{B}_{k}, \mathcal{T}_{k}$ are scalar elements of output of neural networks $\mathcal{B}, \mathcal{T}$, and $p$ is a constant denoting the number of such elements. We take $\mathcal{B}$ and $\mathcal{T}$ to be artificial neural networks parameterized by $\theta, \xi$ respectively. $\mathcal{B}, \mathcal{T}$ are known as the branch and trunk networks respectively. $u_0$ is the initial function in which the operator is applied, evaluated at distinct locations $x_1, \hdots, x_m$ for branch input. $(x,t)$ is any arbitrary point in space and time in which $G^{\dagger}(u_0)$ may be evaluated.

\textbf{Enhanced DeepONet.} The above framework is restricted to one initial input function $u_0$. We turn our attention to the enhanced DeepONet, a DeepONet styled to act upon dual initial conditions~\cite{https://doi.org/10.48550/arxiv.2202.08942}. Our true operator $\Gamma^{\dagger}$ may be approximated using a second neural network encoder for input $u_1$,

\vspace{-2mm}
\begin{equation}
\label{eqn:DeepONet_enhanced}
\Gamma^{\dagger}( u_0, u_1) (x,t) \approx \sum_{k=1}^{p} \mathcal{B}^0_{k} \big( u_0(x_1), \hdots, u_0(x_m), \theta^0 \big) \cdot
\mathcal{B}^1_{k} \big( u_1(x_1), \hdots, u_1(x_m), \theta^1 \big) \cdot
\mathcal{T}_{k} (x, t, \xi) .
\end{equation}




\textbf{Physics-informed DeepONet.} The enhanced DeepONet may be substituted into any physics-informed framework, such as that of equation~\eqref{eqn:PINN_formulation_1}, taking place of the PDE solution value in the empirical loss to be minimized~\cite{DBLP:journals/corr/abs-2103-10974}.  Generalization of the trained DeepONet permits any solution to the PDEs to be evaluated instantaneously given the appropriate input function(s).
\end{comment}


\section{Specialized architectures}

\subsection{Modified multi-layer perceptron}
\label{Modified_mlp}

Here we outline the forward pass of the modified multi-layer perceptron used throughout the experiments, as presented in~\cite{doi:10.1126/sciadv.abi8605}. Let $\sigma$ denote an activation function (at least twice differentiable to allow automatic differentiation of the networks to satisfy the PDEs), $X$ the neural network design input, $W^i$ the weights of the neural network at index $i$, and $b^i$ the bias at layer $i$. Here, $X$ can refer to either branch or trunk inputs, as this architecture is used for both.

The forward pass is given by
\begin{align}
& U = \sigma(W^1 X + b^1), \ \ \ V = \sigma(W^2 X + b^2) \\ 
& H^{1} = \sigma( W^{h,1} X + b^{h,1} ) \\ 
& Z^{k} = \sigma( W^{z,k} H^k + b^{z,k} ) \\ 
& H^{k} = (1 - Z^{k-1} ) \odot U + Z^{k-1} \odot V \\ &
\mathcal{N}_{\theta} = W^{\ell} H^{\ell} + b^{\ell} ,
\end{align}
where $k \in \{1,\hdots,\ell\}$, $\odot$ is an element-wise product, and $\mathcal{N}_{\theta}$ is the neural network final output, either a branch or a trunk.


\subsection{Fourier feature architecture}
\label{Fourier_feature}

We outline the Fourier feature architecture of \cite{doi:10.1126/sciadv.abi8605}. We embed trunk input $y=(x,t)$ in a higher-dimensional space by taking transformations of the form
\begin{equation}
U  = ( \cos(2 \pi B_x y), \sin(2 \pi B_x y))^T 
\end{equation}
and passing them in as the trunk input. Alternatively, we consider the more elaborate architecture of~\cite{Wang_2021}, which requires passing $x$ and $t$ into distinct Fourier embeddings of the form of $U$, and using separate layers for each. An element-wise product is taken before the last layer. We used this for our experiments of~\ref{empirical_Gaussians}, but generally found the Fourier feature architecture of passing in $y=(x,t)$ to formulate $U$ to be equally effective.




\newpage 

\section{Hyperparameter settings and training details}
\label{Hyperparameter_settings}

We discuss training characteristics of GeONet based on the primary experiments. An unmodified Adam optimizer was chosen for all branch and trunk neural networks with a learning rate starting from $5\mathrm{e}{-4}$. All layers share the same width. We use $\tanh$ activation for all neural networks.  Coefficients $\alpha_1,\alpha_2,\beta_0,\beta_1$ were computed after examining errors. Coefficients were selected in the range $[0.05,20]$. Neural network depths refer to $\ell$ in each modified MLP. Training is done on an NVIDIA T4 GPU.


\begin{table*}[h]
  \caption{Architecture and training details in our Gaussian mixture experiments of Section~\ref{sec:experiments} and Appendix~\ref{Training and performance}.}
\begin{center}
\begin{tabular}[H]{
l
S[table-format = 3]
S[table-format = 2]
S[table-format = 1.3]
S[table-format = -2.2]
S[table-format = 1.3]
S[table-format = 1.3]
S[table-format = 2.2]
}
\toprule
\multicolumn{1}{c}{Hyperparameter} & 
\multicolumn{1}{c}{1D Gaussians} &
\multicolumn{1}{c}{2D Gaussians} \\
\toprule
\ \ No. of initial conditions $(\mu_0, \mu_1)$ & {20,000} &  {5,000}       \\
\ \ $m$ (branch input dimension) & {100} &  {576}       \\
\ \ Branch width & {150} &  {200}       \\
\ \ Branch depth  & {7} & {7}     \\
\ \ Trunk width  & {100} & {150}       \\
\ \ Trunk depth & {7} & {7}    \\
\ \ $p$ (dimension of  outputs) & {800} & {800}\\
\ \ Batch size & {2,000} & {2,000}  \\
\ \ Final training time & {$\sim 2$ hrs} & {$\sim 2$ hrs} \\
\ \ Final training loss & {$\sim 1.5\mathrm{e}{-4}$} & {$\sim 1.8\mathrm{e}{-5}$}  \\
\ \ $\alpha_1,\alpha_2, \beta_0, \beta_1$ & {$0.5,0.25,1,1$} & {$0.5, 0.25, 1,1$}  \\
\bottomrule
\end{tabular}
\end{center}
\end{table*}


\begin{table*}[h]
  \caption{Architecture and training details in our empirical Gaussians and encoded MNIST experiments of Section~\ref{sec:experiments} and Appendix~\ref{Training and performance}.}
\begin{center}
\begin{tabular}[H]{
l
S[table-format = 3]
S[table-format = 2]
S[table-format = 1.3]
S[table-format = -2.2]
S[table-format = 1.3]
S[table-format = 1.3]
S[table-format = 2.2]
}
\toprule
\multicolumn{1}{c}{Hyperparameter} & 
\multicolumn{1}{c}{Empirical Gaussians} & \multicolumn{1}{c}{Encoded MNIST} \\
\toprule
\ \ No. of initial conditions $(\mu_0, \mu_1)$  &  {1,000}  & {30,000}      \\
\ \ $m$ (branch input dimension) &  {625}  & {32}      \\
\ \ Branch width &  {100}  & {150}     \\
\ \ Branch depth   & {7} & {7}    \\
\ \ Trunk width   & {100}  & {100}      \\
\ \ Trunk depth  & {5} & {7}    \\
\ \ $p$ (dimension of  outputs)  & {200} & {200} \\
\ \ Batch size  & {1,000} & {1,000} \\
\ \ Final training time  & {$\sim 2$ hrs} & {$\sim 4$ hrs} \\
\ \ Final training loss  & {$\sim 7.0\mathrm{e}{-4}$} & {$\sim 2.0\mathrm{e}{-2}$} \\
\ \ $\alpha_1,\alpha_2, \beta_0, \beta_1$  & {$0.5, 0.25, 1,1$} & {$1,1,1,1$} \\
\bottomrule
\end{tabular}
\end{center}
\end{table*}



\newpage 
\section{Training and performance}
\label{Training and performance}

\subsection{Univariate and bivariate Gaussian mixture experiments}



%\textbf{Univariate Gaussians.} We choose spatial domain $x \in \Omega = [0,10]$ discretized into a $100$ point mesh. We generate $20,000$ training pairs $(\mu_0, \mu_1)$ of Gaussians, taking $k_j = 6$ for the number of Gaussians in each mixture. We take means $\mu_i \in [2,8]$ and variances $\Sigma_i \in [0.5,0.6]$ uniformly. Empirically, we found a large batch size more suitable for training than a low one, and so we take a batch size of $2,000$, meaning this many uniform collocation points are taken for both the PDE residuals and boundary points for each training iteration. We choose physical loss coefficient $\alpha_1 = 0.5, \alpha_2 = 0.25$, with boundary coefficients $\beta_0 = \beta_1 = 1$. We found these coefficients a good balance to enforce the physical constraint without sacrificing boundary restrictions after iterating these coefficients among $[0.05,20]$ and examining error. Additional training details are given in Appendix ~\ref{Hyperparameter_settings}.


%\textbf{Bivariate Gaussians.} In our experiment, domain $\Omega = [0,5] \times [0,5] \subseteq \mathbb{R}^2$ was chosen, which was discretized into a $24 \times 24$ grid for GeONet input, meaning the branch networks took vector input of $576$ in length for each. We generate $5,000$ training pairs $(\mu_0, \mu_1)$. Recall that GeONet is mesh-invariant, so the $24 \times 24$ grids can be adapted to any higher resolution, which is used in figure~\ref{fig:GeONet_gaussian_mixture_bivariate}. We use a combination of low and high variance Gaussians in the mixture, 6 of which had variance in $[0.35,0.4]$ and 6 in $[0.75, 0.9]$, giving a total of 12 Gaussians in each mixture in each pair. Covariance was in $[-0.1,0.1]$. Additional training details are given in Appendix ~\ref{Hyperparameter_settings}.


\begin{comment}
\begin{figure}[h]
  \centering
  %\includegraphics[scale=0.55]{Gaussians 8.pdf}
  \includegraphics[scale=0.56]{GeONet_bivariate_samples_8.pdf}
  \caption{Geodesics predicted by GeONet on bivariate Gaussians over a square domain. The top of each pair is the reference solution computed by POT, and the bottom is GeONet.  \vspace{-3mm}}
  \label{fig:GeONet_gaussian_mixture_bivariate}
\end{figure}
\end{comment}

%\textbf{Training.} To compute the DeepONet derivatives, we take the inner product in the enhanced DeepOnet as in equations~\eqref{eqn:Continuity_sol},~\eqref{eqn:HJ_sol}, and subsequently use automatic differentiation after the inner products are taken. Alternatively, one may compute a Hessian for the second-order derivatives, but this is costly in terms of memory, meaning a large batch size cannot be used without a monumental memory cost. We found the neural networks do not train properly without a large batch size, and so this method of differentiation is not viable. We found the DeepONet output dimension taken to be quite large  slightly outperforms a lower-dimension output given sufficient data and no overfitting. In the univariate Gaussian experiment, we take $p=800$, which outperformed $p=200$ by reducing training loss from approximately $2.5\mathrm{e}{-4}$ to $1.5\mathrm{e}{-4}$ and reducing test error by about $1\%$. In the bivariate experiment, changing $p=400$ to $p=800$ reduced training loss from approximately $2.1\mathrm{e}{-5}$ to $1.8\mathrm{e}{-5}$. Architecture generally made some difference to training loss, but not significant, making a width of around $100$-$200$ suitable for branches and trunks. For example, increasing branch width in the univariate experiment from $100$ to $150$ lowered training loss by approximately $4\mathrm{e}{-5}$. Increasing branch width to $200$ and trunk width to $150$ from $150$ and $100$ respectively had minimal effect, lowering training loss by about $1\mathrm{e}{-5}$. We found the modified MLP architecture preferable, lowering final training loss from approximately $3\mathrm{e}{-4}$ with standard architecture for univariate Gaussians.




\textbf{Performance.} Our baseline results were collected by deploying GeONet on the identity geodesic in Table~\ref{tab:GeONet_gaussian_mixture}. The baseline identity geodesic provides a benchmark for comparing and interpreting the errors across different setups. The univariate cases were evaluated upon a $100$ point mesh, and the bivariate upon a $40 \times 40$ mesh, except in the zero-shot super-resolution case, in which the grid is refined as previously specified. From Table~\ref{tab:GeONet_gaussian_mixture}, we can draw the following observations. The loss boundary conditions~\eqref{eqn:GeONet_loss_BC} allow greater precision for $t=0,1$, which suggests that a lack of data-enforced conditions along the inner region of the time continuum would cause greater error. Errors for predicting the univariate Gaussian trivial identity geodesic in the intermediate $t = 0.25, 0.5, 0.75$ are uniformly smaller than other in-distribution setups since the former is an easier task. In the bivariate experiment, we found that error quickly rises as variance decreases, which is equivalent to a task of learning more complicated geodesics. We did not find lower variance drastically affects performance in the univariate experiment, suggesting GeONet and potentially physics-informed DeepONets in general are less effective as the dimension increases. We did not find the number of Gaussians in the mixtures drastically affected results, but naturally more complicated geodesics induce greater error, which is to be expected. We found that bivariate errors in the random case are similar to those in the identity case, suggesting there is some notion of base neural operator error, which may not exist with simpler data.

\subsection{Gaussian empirical densities}


\textbf{Training.} 3,000 point cloud particles were sampled from mixtures composed of 3 Gaussians for source $\mu_0$ and target $\mu_1$. 2D histograms were constructed to turn particle data into empirical densities, with bins ranging from $-7$ to $7$. Domain $\Omega = [0,5] \times [0,5]$ was discretized into a $25 \times 25$ point domain and assigned for the histograms' locations used as GeONet spatial input. A batch size of 1,000 was chosen. We take $p=200$, $\alpha_1 = 0.5, \alpha_2 = 0.25, \beta_0 = \beta_1 = 1$, which can be altered to impose strength of the boundary and physics terms accordingly. We employ the Fourier feature network architecture of~\cite{Wang_2021} for trunk networks. We take matrix $B_v$ with elements sampled in $\mathcal{N}(0,\sigma_v^2)$, subsequently taking $(\cos(2 \pi B_v v), \sin(2 \pi B_v v))^T$ as input for a fully-connected network, where $v$ is either space or time input. Our architecture for this experiment is fully outlined in Appendix~\ref{Fourier_feature}. Empirically, we found low variance necessary, and we chose $\sigma_v = 0.5$ for both $v = x,t$ for both continuity and trunk branches.

\textbf{Performance.} In this experiment, GeONet correctly captures the translocation of mass and overall geodesic behavior. The other methods are more suited for point clouds but yield high errors in learning the geodesic. GeONet tends to slightly regularize the solutions by smoothing them, and consequently has trouble learning the precision that comes with particle samples.



\subsection{MNIST experiment}


\begin{table*}[!b]
\caption{$L^1$ error of GeONet on 50 test pairings of encoded MNIST. All values are multiplied by $10^{-2}$. Error was calculated upon the geodesic in both the shifted and ambient/original space.}
  \label{tab:GeONet_CIFAR-10}
  \centering
  \begin{tabular}[h]{
    l
    S[table-format = 3]
    S[table-format = 2]
    S[table-format = 1.3]
    S[table-format = -2.2]
    S[table-format = 1.3]
    S[table-format = 1.3]
    S[table-format = 2.2]
    }
    \toprule
    \multicolumn{1}{c}{} & 
    \multicolumn{5}{c}{GeONet $L^1$ error on encoded MNIST data}\\
    \cmidrule(lr){2-6}        
    \textbf{Test setting} & {$\bm{ t=0 }$} & {$\bm{  t=0.25  }$} & {$\bm{ t=0.5  }$} & {$\bm{ t=0.75  }$} & {$\bm{ t =1 }$}  \\
    \midrule
    \text{Encoded, identity} &
    {$0.923 \pm 0.213$} &
    {$0.830 \pm 0.166$} & 
    {$0.825 \pm 0.165$} & 
    {$0.834 \pm 0.173$} & 
    {$0.931 \pm 0.215$} \\
    \text{Encoded, random} & 
    {$1.62 \pm 0.333$} & 
    {$2.14 \pm 1.22$} & 
    {$2.78 \pm 1.62$} & 
    {$2.11 \pm 1.17$} & 
    {$1.54 \pm 0.282$} \\
    \midrule
    \text{Ambient, identity} & 
    {$26.7 \pm 11.2$} & 
    {$34.0 \pm 6.88$} & 
    {$35.3 \pm 8.32$} & 
    {$36.4 \pm 9.77$} & 
    {$34.0 \pm 13.2$} \\
    \text{Ambient, random} & 
    {$32.1 \pm 16.6$} & 
    {$58.2 \pm 15.0$} & 
    {$68.1 \pm 18.8$} & 
    {$56.4 \pm 14.3$} & 
    {$24.7 \pm 10.7$} \\
    \bottomrule
\end{tabular}
\vspace{-3mm}
\end{table*}


\textbf{Training.} As described in Section~\ref{sec:experiments}, to learn the geodesic, we ensure all values within the encoded representation are nonnegative, meaning we can shift all encoded representations by some arbitrary constant. We choose $10$ for this. This constant can be subtracted in later stages to ensure the valid representation is met. We normalize the data so that the density conditions are satisfied before GeONet input. A domain of $[0,5]$ was divided into an equispaced mesh of $32$ points for the encoded representation. This domain is rather arbitrary and is chosen simply for DeepONet input purposes, which can be modified as seen fit. $30{,}000$ encoded pairs, comprising the entirety of MNIST, were chosen to train GeONet and the pre-trained autoencoder. We used a batch size of $1{,}000$. Additional details are found in Appendix~\ref{Hyperparameter_settings}.

\begin{figure}[h]
  \centering
  %\includegraphics[scale=0.55]{Gaussians 8.pdf}
  \vspace{-2mm}
  \includegraphics[scale=0.75]{GeONet_training_loss_6.pdf}
  \caption{We examine iterations of the Adam optimizer in the total and late training on a log scale. We examine late training in order to observe oscillatory behavior between the continuity and HJ loss to see if they adversarially compete in late training. We do not observe this pattern, and the continuity loss greatly surpasses the HJ loss in value. These graphs were created using the encoded MNIST experiment. \vspace{-3mm}}
  \label{fig:training_loss_late_training}
\end{figure}


\textbf{Performance.} GeONet performs well in this experiment. Scaling the physics-informed term by a constant of less than one did not prove necessary in this experiment to ensure all loss terms are met to a sufficient degree. As before, boundary terms are uniformly smaller, likely since these terms are known and included in the loss function to be minimized. The same error metric is used as in the synthetic experiments but with normalization, making the $L^1$ error relative. We remark OOD generalization is omitted because the distribution of the encoded data is not known. We also remark the decoded images, being the geodesic returned to its original state, do not directly translate to a geodesic performed upon an original pair of images. NaN values are omitted in the error computations, which are possible in the POT solutions due to the irregularity of the initial conditions. 

\textbf{Regularization.} Classical geodesic algorithms require a small regularization parameter in order to be computed. This affects the synthetic experiments trivially, but we found this regularization induces greater error in the MNIST experiment. This is to be considered when evaluating the errors, and true error is likely smaller between GeONet and the reference geodesics computed with POT than what is displayed. This regularization acts as a form of ``smoothing'' of the solutions.





\section{GeONet error for additional error metrics}
\label{app:error_metrics}


\begin{table}[H]
  \caption{We list means and standard deviations of error of GeONet on 50 random $\mu_0 \neq \mu_1$ samples for alternative error metrics, being $L^2$ error and the Wasserstein-1 distance. We remark we use sliced Wasserstein distance for the 2D case, as this metric is computationally feasible for higher dimensional cases. We perform this for random Gaussian mixture pairings. All values in the table are to be multiplied by $10^{-2}$. \vspace{3mm}}
  \label{tab:GeONet_gaussian_mixture_L2+W}
  \centering
  \begin{tabular}[b]{
    l
    S[table-format = 3]
    S[table-format = 2]
    S[table-format = 1.3]
    S[table-format = -2.2]
    S[table-format = 1.3]
    S[table-format = 1.3]
    S[table-format = 2.2]
    }
    \toprule
    \multicolumn{1}{c}{} & 
    \multicolumn{3}{c}{GeONet alternative metric error for random Gaussian mixtures}\\
    \cmidrule(lr){2-4}        
    \textbf{Experiment \ \ } & {$\bm{ t=0 }$} & {$\bm{  t=0.25  }$} & {$\bm{ t=0.5  }$}   \\
    \midrule
    \text{1D, $L^2$}  &
    {$5.19 \pm 1.74$}  & 
    {$6.91 \pm 4.81$}  & 
    {$7.28 \pm 5.39$} \\
    \text{1D, Wasserstein}  &
    {$0.352 \pm 0.116$}  & 
    {$0.364 \pm 0.178$}  & 
    {$0.403 \pm 0.228$} \\
    \midrule
    \text{2D, $L^2$}  &
    {$6.93 \pm 0.883$} & 
    {$7.72 \pm 1.23$} &
    {$8.11 \pm 1.30$} \\
    \text{2D, Wasserstein}  &
    {$0.245 \pm 0.0329$} & 
    {$0.264 \pm 0.0316$} &
    {$0.275 \pm 0.0447$}  \\

    \bottomrule
\end{tabular}

\vspace{3mm}

\begin{tabular}{
    p{2.5cm}  P{2.5cm} P{2.5cm} }
    \toprule
  %  \multicolumn{1}{c}{} 
  %  \multicolumn{5}{c}{GeONet alternative metric error for random Gaussian mixtures}\\
   % \cmidrule(lr){2-6}        
    \textbf{Experiment \ \ } &  {$\bm{ t=0.75  }$} & {$\bm{ t =1 }$}  \\
    \midrule
    \text{1D, $L^2$}  & 
    {$6.49 \pm 4.36$} & 
    {$4.81 \pm 1.58$} \\
    \text{1D, Wasserstein} & 
    {$0.386 \pm 0.166$} & 
    {$0.347 \pm 0.101$} \\
    \midrule
    \text{2D, $L^2$} & 
    {$7.79 \pm 1.14$} &
    {$6.87 \pm 1.05$} \\
    \text{2D, Wasserstein} & 
    {$0.267 \pm 0.0338$} &
    {$0.246 \pm 0.0356$} \\

    \bottomrule
\end{tabular}
%\vspace{-6mm}
\end{table}



\newpage
\section{3D Gaussians figure}

\begin{figure}[h]
  \centering
  \includegraphics[scale=0.75]{GeONet_3d_gaussians.pdf}
  \vspace{0mm}
  \caption{We illustrate GeONet on 3D Gaussians.}
  \label{fig:GeONet_3d_gaussians}
\end{figure}






\newpage
\section{Sample HJ graphs}



\begin{figure}[h]
  \centering
  %\includegraphics[scale=0.55]{Gaussians 8.pdf}
  \includegraphics[scale=0.55]{geonet_HJ_graphs.pdf}
  \caption{We present sample HJ equations for (a) three univariate Gaussian mixtures and (b) three bivariate Gaussian mixtures from the primary experiments performed in Section~\ref{sec:experiments}. The univariate HJ samples at certain times are the vertical cross-sections of the graphs, and the bivariate samples are given at certain times. \vspace{-3mm}}
  \label{fig:GeONet_HJs}
\end{figure}










\end{document}