% \documentclass{uai2023} % for initial submission
\documentclass[accepted]{uai2023} % after acceptance, for a revised
                                    % version; also before submission to
                                    % see how the non-anonymous paper
                                    % would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2023} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2023} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
\usepackage[american]{babel}
% \usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{abbrvnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{tikz} % nice language for creating drawings and diagrams

\usepackage{multirow}
\usepackage{url}            % simple URL typesetting
\usepackage{booktabs}       % professional-quality tables
\usepackage{amsfonts}       % blackboard math symbols
\usepackage{nicefrac}       % compact symbols for 1/2, etc.
\usepackage{microtype}      % microtypography
\usepackage{hyperref}
\usepackage{graphicx}
\usepackage{amssymb,amsmath,amsthm}
\usepackage{epstopdf}
\usepackage{algorithm,algcompatible}
\usepackage{mathtools}
\usepackage{wrapfig}
\usepackage{soul}
\usepackage{tikz}
\usepackage{xspace}
\usepackage{subcaption}
%\usepackage{enumitem}

\renewcommand{\thefootnote}{\fnsymbol{footnote}}

\usepackage[inline]{enumitem}

% Hyperlink styling: render every link (internal references, citations and
% URLs) in a single muted blue instead of hyperref's default framed links.
\definecolor{Blue}{rgb}{0.2,0.2,0.6}
\hypersetup{%
  colorlinks = true,
  linkcolor  = Blue,
  citecolor  = Blue,
  urlcolor   = Blue
}

% TikZ libraries required by the figures.
\usetikzlibrary{fit, positioning, arrows, automata}

% mathtools: only displayed equations that are actually referenced get a number.
\mathtoolsset{showonlyrefs}

 
% Theorem-like environments (amsthm). Each environment is numbered with its own
% counter, except `defn', which is an alias of `definition': it shares that
% counter so mixing the two environments never prints two different
% definitions under the same number.
\newtheorem{theorem}{Theorem}
\newtheorem{definition}{Definition}
\newtheorem{proposition}{Proposition}
\newtheorem{lemma}{Lemma}
\newtheorem{corollary}{Corollary}
\newtheorem{remark}{Remark}
\newtheorem{example}{Example}
\newtheorem{defn}[definition]{Definition}
\newtheorem{assum}{Assumption}

% Bold upright letters (vectors/matrices), for use in math mode.
% All of them use \mathbf; the obsolete \bf switch is avoided because it is not
% defined by every class. \def silently overwrites any earlier definition.
\def\bx{{\mathbf{x}}}
\def\bz{{\mathbf{z}}}
\def\bG{{\mathbf{G}}}
\def\bA{{\mathbf{A}}}
\def\bB{{\mathbf{B}}}
\def\bC{{\mathbf{C}}}
\def\bH{{\mathbf{H}}}
\def\bR{{\mathbf{R}}}
\def\bP{{\mathbf{P}}}
\def\bS{{\mathbf{S}}}
\def\bX{{\mathbf{X}}}
\def\bJ{{\mathbf{J}}}
\def\bQ{{\mathbf{Q}}}
\def\bK{{\mathbf{K}}}
\def\bU{{\mathbf{U}}}
\def\bV{{\mathbf{V}}}
\def\bF{{\mathbf{F}}}

% Blackboard-bold shortcuts (m-prefix).
% NOTE(review): \mV is calligraphic (identical to \cV) despite the m-prefix
% convention; kept unchanged so existing uses render as before.
\newcommand{\mC}{{\mathbb C}}
\newcommand{\mD}{{\mathbb D}}
\newcommand{\mV}{{\mathcal{V}}}
\newcommand{\mE}{{\mathbb E}}
\newcommand{\mP}{{\mathbb P}}
\newcommand{\mR}{{\mathbb R}}
\newcommand{\mN}{{\mathbb N}}
\newcommand{\mS}{{\mathbb S}}
% Calligraphic shortcuts (c-prefix).
\newcommand{\cB}{{\mathcal B}}
\newcommand{\cC}{{\mathcal C}}
\newcommand{\cD}{{\mathcal D}}
\newcommand{\cE}{{\mathcal E}}
\newcommand{\cF}{{\mathcal F}}
\newcommand{\cG}{{\mathcal G}}
\newcommand{\cH}{{\mathcal H}}
\newcommand{\cI}{{\mathcal I}}
\newcommand{\cJ}{{\mathcal J}}
\newcommand{\cL}{{\mathcal L}}
\newcommand{\cN}{{\mathcal N}}
\newcommand{\cO}{{\mathcal O}}
\newcommand{\cP}{{\mathcal P}}
\newcommand{\cR}{{\mathcal R}}
\newcommand{\cS}{{\mathcal S}}
\newcommand{\cT}{{\mathcal T}}
\newcommand{\cU}{{\mathcal U}}
\newcommand{\cV}{{\mathcal V}}
\newcommand{\cW}{{\mathcal W}}
\newcommand{\cX}{{\mathcal X}}
\newcommand{\cY}{{\mathcal Y}}
\newcommand{\cZ}{{\mathcal Z}}
% All-ones vector.
\newcommand{\one}{{\mathbf 1}}
% Transpose marker, e.g. x^\tT. \textup keeps the T upright even inside
% italic contexts such as theorem statements (\text would inherit the italics).
\newcommand{\tT}{{\textup{T}}}

% Upright operator-like names. \textup keeps them upright even inside italic
% contexts (e.g. theorem statements), whereas \text inherits the current shape.
\def\OT{{\textup{OT}}}
\def\KL{{\textup{KL}}}
\def\TV{{\textup{TV}}}
\def\LSI{{\textup{LSI}}}
\def\TI{{\textup{TI}}}
% Angle brackets. \def is deliberate: \newcommand would refuse to redefine
% LaTeX's existing \< and \> (the latter is the medium math space), which these
% replace for the whole document.
\def\<{{\langle}}
\def\>{{\rangle}}
\def\Rn{{\mathbb{R}^n}}
\def\R{{\mathbb{R}}}
% Upright differential, e.g. \int f(x) \d x. NOTE: this overrides LaTeX's \d
% (dot-under accent), which is therefore unavailable in this document.
\def\d{{\textup{d}}}
% \def\l{{\left}}
% \def\r{{\right}}
\newcommand{\Cov}{{\textup{Cov}}}
\newcommand{\Var}{{\textup{Var}}}


% Author annotation macros: \fan{...} marks text by one author (currently
% rendered in black, i.e. invisible; swap in the violet line to highlight),
% \dam{...} marks text by the other author in orange.
% \newcommand{\fan}[1]{{\color{violet}{#1}}}
\newcommand{\fan}[1]{{\color{black}{#1}}}
\newcommand{\dam}[1]{{\color{orange}{#1}}}
% All-ones vector. NOTE(review): same output as \one defined above.
\newcommand{\ones}{\mathbf{1}} 

% Auto-sizable ceiling/floor delimiters (mathtools): \ceil{x}, \ceil*{x}.
\DeclarePairedDelimiter\ceil{\lceil}{\rceil}
\DeclarePairedDelimiter\floor{\lfloor}{\rfloor}

% Starred operators: subscripts are set below the operator in display math.
\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}

% Space-saving hack: ask TeX to typeset every paragraph one line shorter when
% possible. NOTE(review): LaTeX's sectioning and list code also assign
% \everypar, so this setting may be lost after the first heading -- confirm it
% still takes effect where intended.
\everypar{\looseness=-1}

%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

% Footnote indentation. \footnotemargin is a footmisc length; footmisc itself
% is not loaded here (line below is commented out).
% NOTE(review): presumably the uai2023 class provides \footnotemargin --
% confirm, otherwise this \setlength fails with an undefined control sequence.
% \usepackage[hang]{footmisc}
\setlength{\footnotemargin}{5pt}

\title{Generating Synthetic Datasets by \\ Interpolating along Generalized Geodesics}

% The standard author block has changed for UAI 2023 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\stepcounter{footnote} % Increment counter of footnote to get \dagger instead of \ast
% The mailto address is written without the template's <...> placeholder
% brackets, which would otherwise become part of the (then invalid) address.
\author[1]{\href{mailto:jiaojiaofan@gatech.edu?Subject=Your UAI 2023 paper}{Jiaojiao Fan\thanks{Work done partly during an internship at Microsoft Research.}}{}}
\author[2]{David Alvarez-Melis}
\affil[1]{%
    Georgia Tech\\
    Atlanta, Georgia, USA
}
\affil[2]{%
    Microsoft Research \& Harvard University\\
    Cambridge, Massachusetts, USA
}

  
\begin{document}
\maketitle

\begin{abstract}
  Data for pretraining machine learning models often consists of collections of heterogeneous datasets. Although training on their union is reasonable in agnostic settings, it might be suboptimal when the target domain ---where the model will ultimately be used--- is known in advance. In that case, one would ideally pretrain only on the dataset(s) most similar to the target one. Instead of limiting this choice to those datasets already present in the pretraining collection, here we explore extending this search to all datasets that can be synthesized as `combinations' of them. We define such combinations as multi-dataset interpolations, formalized through the notion of generalized geodesics from optimal transport (OT) theory. We compute these geodesics using a recent notion of distance between labeled datasets, and derive alternative interpolation schemes based on it: using either barycentric projections or optimal transport maps, the latter computed using recent neural OT methods. These methods are scalable, efficient, and ---notably--- can be used to interpolate even between datasets with distinct and unrelated label sets. Through various experiments in transfer learning in computer vision, we demonstrate this is a promising new approach for targeted on-demand dataset synthesis. 
\end{abstract}



\section{Introduction}


Recent progress in machine learning has been characterized by the rapid adoption of large pretrained models as a fundamental building block \citep{brown2020language}. These models are typically pretrained on large amounts of general-purpose data and then adapted (e.g., \textit{fine-tuned}) to a specific task of interest. Such pretraining datasets usually draw from multiple heterogeneous data sources, e.g., arising from different domains. Traditionally, all available datasets are used in their entirety during pretraining, for example by pooling them together into a single dataset (when they all share the same label sets) or by training on all of them sequentially, one by one. These strategies, however, come with important disadvantages. Training on the union of multiple datasets might be prohibitive or too time-consuming, and it might even be detrimental --- indeed, there is a growing line of research showing that removing pretraining data sometimes improves transfer performance \citep{jain2022data}. On the other hand, sequential learning (i.e., consuming datasets one by one) is infamously prone to \textit{catastrophic forgetting} \citep{mccloskey1989catastrophic, kirkpatrick2017overcoming}: the information from earlier datasets gradually vanishes as the model is trained on new datasets. The pitfalls of both of these approaches suggest training instead on a \textit{subset} of the available pretraining datasets, but how to choose that subset is unclear. However, when the target dataset on which the model is to be used is known in advance, the choice becomes much easier: intuitively, one would train only on those datasets relevant to the target one, e.g., those most similar to it. Indeed, recent work has shown that selecting pretraining datasets based on the distance to the target is a successful strategy \citep{alvarez2020geometric, gao2021information}. However, such methods are limited to selecting (only) among individual datasets already present in the collection.

In this work, we propose a novel approach to \textit{generate} synthetic pretraining datasets as combinations of existing ones. In particular, this method searches among all possible continuous combinations of the available datasets and is thus not limited to selecting a single one of them. When given access to the target dataset of interest, we seek among all such combinations the one closest (in terms of a metric between datasets) to the target. By characterizing datasets as samples from an underlying probability distribution, this problem can be understood as a generalization (from Euclidean to probability space) of the problem of finding, within the convex hull of a set of reference points, the point closest to a query point. While this problem has a simple closed-form solution in Euclidean space (via an orthogonal projection), solving it in probability space is ---as we shall see here--- significantly more challenging.

We tackle this problem from the perspective of interpolation. Formally, we model the combination of datasets as an interpolation between their distributions, formalized through the notion of geodesics in probability space endowed with the Wasserstein metric \citep{ambrosio2008gradient, villani2008optimal}. In particular, we rely on \textit{generalized geodesics} \citep{craig2016exponential, ambrosio2008gradient}: constant-speed curves connecting two (or more) distributions, parametrized with respect to a `base' distribution whose role is played by the target dataset in our setting. Computing such geodesics requires access to either an optimal transport coupling or a map between the base distribution and every other reference distribution. Couplings can be computed very efficiently with off-the-shelf OT solvers, but they are limited to generating only as many samples as those on which the problem was originally solved. In contrast, OT maps allow for on-demand out-of-sample mapping and can be estimated using recent advances in neural OT methods \citep{fan2020scalable, korotin2022neural, makkuva2020optimal}. However, most existing OT methods assume unlabeled (feature-only) distributions, whereas our goal here is to interpolate between classification (i.e., labeled) datasets. Therefore, we leverage a recent generalization of OT to labeled datasets to compute couplings \citep{alvarez2020geometric}, and we adapt and generalize neural OT methods to the labeled setting to estimate OT maps.

In summary, the contributions of this paper are:
\begin{enumerate*}[label=(\roman*)]
  \item a novel approach to generate new synthetic classification datasets from existing ones by using geodesic interpolations, applicable even if they have disjoint label sets,
  \item two efficient methods to solve OT between labeled datasets,
%   compute generalize geodesics,
  which might be of independent interest,
  \item empirical validation of the method in various transfer learning settings.
\end{enumerate*}



\section{Related work}

% Context:
% \begin{itemize}
%     \item Other approaches to sharing knowledge across domains for a new domain typically rely on model ---rather than data--- combination: mixture-of-experts, model weighting, and interpolation (CITE). Much less work on non-trivial data combination, especially interpolation. 
% \end{itemize}


\paragraph{Mixup and related In-Domain Interpolation} Generating training data through convex combinations was popularized by \textit{mixup} \citep{zhang2018mixup}: a simple data augmentation technique that interpolates features and labels between pairs of points. This and other works based on it \citep{zhang2021how, chuang2021fair,yao2021meta} use mixup to improve in-domain model robustness~\citep{zhu2023interpolation} and generalization by increasing the in-distribution diversity of the training data. Although sharing some intuitive principles with mixup, our method interpolates entire datasets ---rather than individual datapoints--- with the goal of improving across-distribution diversity and therefore out-of-domain generalization.% In particular, interpolation in this generalized scenario 

\paragraph{Dataset synthesis in machine learning} Generating data beyond what is provided as a training dataset is a crucial component of machine learning in practice. Basic transformations such as rotations, cropping, and pixel transformations can be found in most state-of-the-art computer vision models. More recently, Generative Adversarial Nets (GAN) have been used to generate synthetic data in various contexts~\citep{bowles2018GAN, yoon2019pate-gan},
a technique that has proven particularly successful in the medical imaging domain \citep{sandfort2019data}. Since GANs are trained to replicate the distribution of their training data, these approaches are typically confined to generating in-distribution diversity and usually operate on features only.



\paragraph{Discrete OT, Neural OT, Gradient Flows}
Barycentric projection~\citep{ambrosio2008gradient,perrot2016mapping} is a simple and effective method to approximate an optimal transport map with discrete regularized OT.
On the other hand, there has been remarkable recent progress in methods to estimate OT maps in Euclidean space using neural networks~\citep{makkuva2020optimal,fan2021scalable,rout2022generative}, which have been successfully used for image generation~\citep{rout2022generative}, style transfer~\citep{korotin2022neural}, among other applications. However, the estimation of an optimal map between (labeled) datasets has so far received much less attention. Some conditional Monge map solvers~\citep{bunne2022supervised} 
% \fan{[asa is not conditional, better replace this asa. by fan]}
utilize the label information in a semi-supervised manner, where they assume the label-to-label correspondence between two distributions is known. Our method differs from theirs in that we do not require a pre-specified label-to-label mapping, but instead estimate it from data. Geodesics and interpolation in general metric spaces have been studied extensively in the optimal transport and metric geometry literature \citep{mccann1997convexity, agueh2011barycenters, ambrosio2008gradient, santambrogio2015optimal, villani2008optimal, craig2016exponential}, albeit mostly in a theoretical setting. Gradient flows \citep{santambrogio2015optimal}, increasingly popular in machine learning to model existing processes
\citep{bunne2022proximal, mokrov2021large-scale, fan2022variational, hua2023dynamic} or to solve optimization problems over datasets \citep{alvarez2021dataset}, provide an alternative approach to interpolating between distributions, but they are computationally expensive.

% On the computational side, estimating such maps using neural networks is a flourishing area of research \citep{makkuva2020optimal, fan2020scalable, korotin2022neural}. 

% Jiaojiao, Korotin neural OT works for solving maps\\
% conditioned / label guided maps, Bunne et al. \\
% interpolation for labeled datasets: mixup, etc

%  When the label correspondence aligns with the feature similarity, we are the same as Korotin's or normal Monge map, such as MNIST-USPS example. When the label correspondence conflicts with the feature similarity, we are far away from theirs, but still similar to the normal Monge map, such as 0-1 example in Teams. 
% When are we different with normal Monge map??
% We can have stochastic map, but if there is no enforcing the diversity in pushforward map, it would have conditional collapse.


\section{Background}
% \vspace{-0.2cm}
\subsection{Distribution interpolation with OT}
% \vspace{-0.2cm}
Consider $\cP(\cX)$, the space of probability distributions with finite second moments over some Euclidean space $\cX$. Given $\mu,\nu\in \cP(\cX)$, the Monge formulation of the optimal transport problem seeks a map $T:\cX\rightarrow\cX$ that transforms $\mu$ into $\nu$ at a minimal cost. Formally, the objective of this problem is
% \begin{align}
%  \min_{T: T\sharp \mu = \nu} \int_{\mR^d} \|x-T(x)\|_2^2 \d\nu(x),   
% \end{align}
$\min_{T: T\sharp \mu = \nu} \int_{\mR^d} \|x-T(x)\|_2^2 \d\mu(x),$
where the minimization is over all the maps that pushforward distribution $\mu$ into distribution $\nu$. While a solution to this problem might not exist, a relaxation due to Kantorovich is guaranteed to have one. This modified version yields the Wasserstein-2 distance:
$ W_2^2(\mu,\nu) = \min_{\pi \in \Pi(\mu, \nu)} \int_{\mR^d} \|x-x'\|_2^2 \d\pi(x, x'),$
% \begin{align}
%   W_2^2(\mu,\nu) = \min_{\pi \in \Pi(\mu, \nu)} \int_{\mR^d} \|x-x'\|_2^2 \d\pi(x, x'),   
% \end{align}
where now the constraint set $\Pi(\mu,\nu)= \{ \pi \in \mathcal{P}(\mathcal{X}^2) \mid P_{0\sharp}\pi = \mu, P_{1\sharp}\pi=\nu \}$, with $P_0$ and $P_1$ the projections onto the first and second coordinates, contains all couplings with marginals $\mu$ and $\nu$. The optimal such coupling is known as the OT plan. A celebrated result by \citet{brenier1991polar} states that whenever $\mu$ has a density with respect to the Lebesgue measure, the optimal map $T^*$ exists and is unique. In that case, the Kantorovich and Monge formulations coincide and their solutions are linked by $\pi^* = (\mathrm{Id}, T^*)_\sharp \mu$, where $\mathrm{Id}$ is the identity map. The Wasserstein-2 distance enjoys many desirable geometrical properties compared to other distances for distributions \citep{ambrosio2008gradient}. One such property is the characterization of geodesics in probability space \citep{agueh2011barycenters,santambrogio2015optimal}. When $\cP(\cX)$ is equipped with the metric $W_2$, the unique minimal geodesic between any two distributions $\mu_0$ and $\mu_1$ is fully determined by $\pi$, the optimal transport plan between them, through the relation:
\begin{align}\label{eq:displacement}
  \rho_t^D: = ((1-t)x + t y )\sharp \pi(x,y) , \quad t \in [0,1],
\end{align}
known as \emph{displacement interpolation}. If the Monge map from $\mu_0$ to $\mu_1$ exists, the geodesic can also be written as
\begin{align}\label{eq:mccan}
  \rho^M_t: = ((1-t) {\rm Id} + t T^* )\sharp \mu_0 , \quad t \in [0,1],
\end{align}
and is known as \emph{McCann's interpolation}~\citep{mccann1997convexity}. It is easy to see that $\rho^M_0=\mu_0$ and $\rho^M_1=\mu_1$.

Such interpolations are only defined between two distributions.
%  \emph{Wasserstein geodesic}~\citep{Vil03,benamou2000computational} is the minimizing, constant-speed geodesic curve in Wasserstein space, and it coincides with McCann's interpolation~\citep[\S5.4]{santambrogio2015optimal}.
When there are $m \ge 2$ marginal distributions $\{\mu_1, \ldots, \mu_m\}$, the \emph{Wasserstein barycenter}
% $\rho^B_a: = \argmin_\rho \sum_{i=1}^m a_i W_2^2(\rho, \mu_i) , \quad a \in \Delta_{m-1} \subset \mR^m$
\begin{align}
    % \vspace{-0.3cm}
  \rho^B_a: = \argmin_\rho \sum_{i=1}^m a_i W_2^2(\rho, \mu_i) , \quad a \in \Delta_{m-1} \subset \mR_{\ge 0}^m
\end{align}
generalizes McCann's interpolation~\citep{agueh2011barycenters}. Intuitively, the interpolation parameters $a = [a_1,\dots, a_m]$ determine the `mixture proportions' of each dataset in the combination, akin to the weights in a convex combination of points in Euclidean space. In particular, when $a$ is a one-hot vector with $a_i=1$, then $\rho^B_a = \mu_i$, i.e., the barycenter is simply the $i$-th distribution. Barycenters have attracted significant attention in machine learning recently~\citep{srivastava2018scalable,korotin2021continuous}, but they remain challenging to compute in high dimension~\citep{fan2020scalable,korotin2022wasserstein}.

Another limitation of these interpolation notions is the non-convexity of $W_2^2$ along them. In Euclidean space, given three points $x_1,x_2,y \in \mR^d$, the function $t \mapsto \|x_t-y\|_2^2$, where $x_t$ is the interpolation $x_t = (1-t) x_1 + t x_2 $, is convex.
% Despite of their analogy to the point interpolations in Euclidean space, where $\|() \|$, 
In contrast, in Wasserstein space, neither the function $t \mapsto W_2^2(\rho^M_t , \nu)$ nor $a \mapsto W_2^2(\rho^B_a , \nu)$ is guaranteed to be convex~\citep[\S4.4]{santambrogio2017euclidean}. This complicates their theoretical analysis, such as in gradient flows.
To circumvent this issue, \citet{ambrosio2008gradient} introduced the \emph{generalized geodesic} of $\{\mu_1,\ldots,\mu_m\}$ with base $\nu\in\mathcal{P}(\mathcal{X})$:
% $\rho^G_a := \left(\sum_{i=1}^m a_i T^*_i \right)\sharp \nu  , \quad a \in \Delta_{m-1},$
\begin{align}
  \rho^G_a := \left(\sum_{i=1}^m a_i T^*_i \right)\sharp \nu  , \quad a \in \Delta_{m-1},
\end{align}
where $T^*_i $ is the optimal map from $\nu$ to $\mu_i$.
% Next, the lemma below will show that 

% \fan{[
% % Maybe we should just put transport metric here, 
% remove barycenter stuff to appendix, remove lemma1 as well.]}
\begin{lemma}\label{lem:convex_w2}
  The functional $\mu \mapsto W_2^2( \mu, \nu)$ is convex along the generalized geodesics, and
  $
    W_2^2(\rho^G_a, \nu ) \le  \sum_{i=1}^m a_i W_2^2(\mu_i, \nu) .
  $
  % And the function $a \mapsto W_2^2(\rho_a^G , \nu)$ is convex. 
\end{lemma}


Thus, unlike the barycenter, the generalized geodesic does yield a notion of convexity satisfied by the Wasserstein distance and is easier to compute. 
The proof of Lemma \ref{lem:convex_w2} is postponed to \S A.
% \ref{sec:proof}.
For these reasons, we adopt this notion of interpolation for our approach. It remains to discuss how to use it on (labeled) datasets.

% Compared to barycenter, \eqref{eq:gen_geodesic} is also much easier to solve. 
% This together with the convexity inspires us to study the analog of it in dataset space.
% in the sense that $\rho^B_{[t,1-t]} = \rho^M_{t} $
% Interpolation in normal Euclidean space: introduce barycenter, mcccan interpolation, generalized geodesic
% \vspace{-0.1cm}
\subsection{Dataset distance}
% \vspace{-0.2cm}
Consider a dataset $\cD_P = \{ z^{(i)}\}_{i=1}^N = \{ x^{(i)}, y^{(i)}\}_{i=1}^N \overset{i.i.d.}{\sim} P(x,y)$. The Optimal Transport Dataset Distance (OTDD)~\citep{alvarez2020geometric} measures its distance to another dataset $\cD_Q$ as:
\begin{align}\label{eq:otdd}
  %   \vspace{-0.25cm}
   & d^2_{\OT} (\cD_P,\cD_Q ) = \nonumber                                                                    \\
   & \min_{\pi \in \Pi (P,Q)} \int \left( \|x-x'\|_2^2 + W_2^2(\alpha_y, \alpha_{y'}) \right) \d\pi (z,z' ),
  %   \vspace{-0.25cm}
\end{align}
which defines a proper metric between datasets. Here, $\alpha_y, \alpha_{y'}$ are class-conditional measures corresponding to $P(x|y)$ and $Q(x|y')$. This distance is strongly correlated with transfer learning performance, i.e., the accuracy achieved when training a model on
$\cD_P$ and then fine-tuning and evaluating on $\cD_Q$. Therefore, it can be used to select pretraining datasets for a given target domain. Henceforth, for simplicity, we abuse the notation $P$ to represent both a dataset and its underlying distribution. To avoid confusion, we use $\nu$ and $\mu$ to represent distributions in the feature space (typically $\mathbb{R}^d$) and use $P,Q$ to represent distributions in the product space of features and labels.


% \fan{
% % The underlying distribution of datasets $\cD_\nu$ are $\nu$ or we abuse the notation of them?
% We can consider different number of labels. TODO: Need to explain notations here}
% In product space, they're different, and we hope to make analog


% \section{Method and algorithm}
\section{Dataset interpolation along generalized geodesic}
% \vspace{-0.2cm}
% We introduce our generalized geodesic scheme to interpolate among multiple datasets $P_i$. It relies on solving all optimal maps $\cT_i$ from an external dataset $Q$ to each dataset $P_i$ first (\S\ref{sec:map}), and then do the convex combination over all the pushforward data to get interpolation dataset $P_a$ (\S\ref{sec:comb}).
% For the downstream transfer learning task, where we assume $Q$ is the test dataset, and all $P_i$ are the training datasets, we propose a method to locate the closest dataset $P_a^*$ to the dataset $Q$.
Our method consists of two steps: estimating optimal transport maps between the target dataset and all training datasets (\S\ref{sec:map}), and using them to generate a convex combination of these datasets by interpolating along generalized geodesics (\S\ref{sec:comb}). 
\fan{The OT maps are estimated either in the original input space or in a feature space, depending on the dimension of the datasets.} 
For some downstream applications, we will additionally project the target dataset into the `convex hull' of the training datasets (\S\ref{sec:proj}).


% \fan{TODO: highlight that map can also be in feature space}

\subsection{Estimating optimal maps between labeled datasets}\label{sec:map}
% \vspace{-0.25cm}
% Compared to solving Monge map in Euclidean space, the challenge here is in two folds: the labels can have different
The OTDD is a special case of the Wasserstein distance, so it is natural to consider the alternative Monge (map-based) formulation of \eqref{eq:otdd}.
% We call this map OTDD map.
We propose two methods to approximate the OTDD map: one based on entropy-regularized OT and another based on neural OT.
% \fan{[change to barycentric projection]}


\paragraph{OTDD barycentric projection.}
Barycentric projections~\citep{ambrosio2008gradient,pooladian2021entropic} can be efficiently computed for entropy-regularized OT using the Sinkhorn algorithm~\citep{sinkhorn1967diagonal}.
% \fan{[change to uniform distribution]}
% Assume that we have i.i.d. samples $X_\nu = (x_\nu^{(1)}, \ldots, x_\nu^{(N_\nu)}) \in \mR^{N_\nu \times d}, X_\mu = (x_\mu^{(1)}, \ldots, x_\mu^{(N_\mu)}) \in \mR^{N_\mu \times d} $ from two distributions $\nu$ and $\mu$ separately.
Assume that we have empirical distributions $\nu = \sum_{i=1}^{N_\nu} \frac{1}{N_\nu} \delta_{x_\nu^{(i)}} $ and $ \mu = \sum_{i=1}^{N_\mu} \frac{1}{N_\mu} \delta_{x_\mu^{(i)}} $, where $\delta_x$ is the Dirac measure at $x$.
% from two distributions $\nu$ and $\mu$ separately.
Denote all the samples compactly as matrices: $X_\nu = \left(x_\nu^{(1)}, \ldots, x_\nu^{(N_\nu)} \right) \in \mR^{N_\nu \times d}, X_\mu = \left(x_\mu^{(1)}, \ldots, x_\mu^{(N_\mu)} \right) \in \mR^{N_\mu \times d} $.
After solving the optimal coupling
$\pi^*: =
  % \min_{\pi \in \Pi (\nu, \mu)} \int  \|x-x'\|_2^2  \d\pi (x,x' )
  \argmin_{\pi \in \Pi (\nu, \mu)}
  \sum_{i,j} \|x_\nu^{(i)}-x_\mu^{(j)}\|^2 \pi(i,j)
  % \int  \|x-x'\|_2^2  \d\pi (x,x' )
$,
the barycentric projection can be expressed as
$T_B (X_\nu) = N_\nu  \pi^* X_\mu.$
% \begin{align}
%     T_B (X_\nu) = N_\nu  \pi^* X_\mu.
% \end{align}
We extend this method to two datasets $Z_Q = \{X_Q, Y_Q \}, Z_P = \{X_P, Y_P \}$, where we have additional one-hot label data \fan{$Y_Q = (y_Q^{(1)}, \ldots, y_Q^{(N_Q)} ) \in \{0,1\}^{N_Q \times C_Q},
Y_P = (y_P^{(1)}, \ldots, y_{P}^{(N_P)}) \in
\{0,1\}^{N_P \times C_P }
$.} Here, $C_Q$ and $C_P$ denote the numbers of classes in datasets $Q$ and $P$, respectively.
% where the data $z=(x,y)$ includes the features $x$ and the labels $y$, 
We compute the optimal coupling $\pi^* \in \mR_{\ge 0}^{N_Q \times N_P}$ of OTDD~\eqref{eq:otdd} following the regularized scheme of~\citet{alvarez2020geometric}.
% , and represent labels as one-hot vectors $y \in \mR^C $.
%To be compatible with the downstream task, we pre-process the label $y \in \mR^C $ to be one-hot vector, where $C$ is the number of classes.
The barycentric projection can then be written as:
\begin{align}\label{eq:bary_map}
  \cT_B (Z_Q) = [N_Q  \pi^* X_P , N_Q  \pi^* Y_P].
\end{align}
A visualization of barycentrically projected data appears in Figure~\ref{fig:bary_projection}.
\begin{figure}[t]
  \centering
  \begin{subfigure}{0.51\textwidth}
    \centering
    \includegraphics[width=1\linewidth]{images/bary_projection.png}
  \end{subfigure}
  \caption{Visualization of the OTDD barycentric projection on the binary PCAM dataset. \fan{
  We first solve the optimal coupling $\pi^* \in [0,1]^{N_Q \times N_P}$ for the problem \eqref{eq:otdd} using entropy regularization. Next, we map the $i$-th datapoint in the source dataset to a pair consisting of a weighted image and a weighted soft label. The weight vector, extracted from the $i$-th row of the coupling $\pi^*$, is then normalized to sum to 1. As a result, the mapped image (or soft label) is obtained as a convex combination of all the images (or one-hot labels) in the target dataset.
  % We first solve the optimal coupling $\pi^* \in [0,1]^{N_Q \times N_P}$ to the problem \eqref{eq:otdd} with a entropy regularization.
  % Then the $i$-th datapoint in the source dataset will be mapped to a pair of a weighted image and a weighted soft label. 
  % The weight vector $\in [0,1]^{N_P}$ is extracted from the $i$-th row in the coupling $\pi^*$, and normalized to sum to 1. The mapped image (or soft label) is thus a convex combination of all images (or one-hot labels) in the target dataset.
  }
  }
  \label{fig:bary_projection}
\end{figure}
However, this approach has two important limitations: it cannot naturally map out-of-sample data, and it does not scale well to large datasets (due to the quadratic dependency on sample size). \fan{To alleviate the scaling issue, we use a batched version of the OTDD barycentric projection in this paper (see the complexity discussion in \S\ref{sec:conclude}).}

\paragraph{OTDD neural map.} Inspired by recent approaches to estimate Monge maps using neural networks \citep{rout2022generative,fan2021scalable}, we design a similar framework for the OTDD setting. \citet{fan2021scalable} approach the Monge OT problem with general cost functions by solving its max-min dual problem
\[
  \sup_f\inf_T  \int \left[ c( x ,T(x))-f(T(x)  )\right] \d\nu(x)  + \int f(x') \d\mu(x').
\]
% $\sup_f\inf_T  \int \left[ c( x ,T(x))-f(T(x)  )\right] \d\nu(x)
%   + \int f(x') \d\mu(x')$.
We extend this method to distributions involving labels by introducing an additional classifier in the map. Given two datasets $P,Q$, we parameterize the map $\cT_N: \mR^d \times [0,1]^{C_Q} \rightarrow  \mR^d \times [0,1]^{C_P}$ as
\begin{align}\label{eq:map}
  \cT_N(z) = \cT_N(x,y)  =[\bar x ; \bar y ]= [G(z) ; \ell(G(z))],
\end{align}
where $G:
  \mR^d \times [0,1]^{C_Q} \rightarrow \mR^d $ is the pushforward feature map, and $\ell:  \mR^d \rightarrow [0,1]^{C_P} $ is a frozen classifier pretrained on dataset $P$.
%   Assume that the mapping 
Notice that, with the cost $c(z, \cT(z)) =  \| x- G( z)\|_2^2 + W_2^2(\alpha_{y}, \alpha_{\bar y} ) $, the Monge formulation of OTDD \eqref{eq:otdd} reads $\inf_{\cT\sharp Q = P }  \int \left[ \| x- G( z)\|_2^2 + W_2^2(\alpha_{y}, \alpha_{\bar y} ) \right] \d Q(z).$
%   \[   \inf_{T\sharp Q = P }  \int \| x- G( z)\|_2^2 + W_2^2(\alpha_{y}, \alpha_{\bar y} ) \d Q(z). \]
We therefore propose to solve the max-min dual problem
\begin{align}\label{eq:max-min}
  %   \vspace{-0.3cm}
  \sup_f \inf_G \, & \int \left[ \| x- G( z)\|_2^2 + W_2^2(\alpha_{y}, \alpha_{\bar y} )\right] \d Q(z) \nonumber \\
                   & - \int f(\bar x , \bar y ) \d Q(z)
  + \int f(x' ,y') \d P(z').
  %   \vspace{-0.2cm}
\end{align}

% About $W_2$ distance, there are two confusing parts: 

% 1) Which classifier $\ell$ should we use? Choice 1: 
% % This should be a part of pushforward map $T(\cdot)$, so it should be generated by 
% a classifier trained on the fly. Choice 2: a pre-trained classifier on the target dataset. Choice 3: don't you also need large batch size?
% I think choice 2 is better.
% % if we can assume that given each feature in the target domain is associated with only one label.
% % One should also notice that in this case, $\ell$ only depends on $G(z)$. 
% But in this way, we cannot detach the features.

% 2) Given a pushforward label $\bar y$, which feature distribution $\alpha_{\bar y}$ should we use? Choice 1: feature distribution from target distribution (computed in advance), Choice 2: feature distribution from pushforward distribution (computed on the fly).  I think we can use both because they're equivalent in~\eqref{eq:otdd_monge} given the constraint $T \sharp \mu =\nu$.  
% Given this cost, \eqref{eq:otdd_monge} corresponds to the $W_2^2$ distance.

% The iterpretation is \eqref{eq:otdd_monge} tends to learn a map that is based on the normal Monge map, and an additional conditional distribution in target domain. I think we can still say something from Kantorovich duality.

% With embedding of $x,x'$, $d_x$ is still symmetric, triangular inequality, but $x \ne x'$ may still gives $d_x(emb(x),emb(x'))=0$.

% Option 1: $W_2^2(\alpha_{y}, \alpha_{\ell(G(z),y)}) $ or $W_2^2(\alpha_{y}, \alpha_{\ell(G(z))}) $ where $\ell(\cdot)$ is a trainable classifier.

% Option 2: $W_2^2(\alpha_{y}, \alpha_{\ell(G(z),y)}) $

\begin{figure}[t]
  \centering
  \begin{subfigure}{0.5\textwidth}
    \centering
    \includegraphics[width=1\linewidth]{images/neural_map.png}
  \end{subfigure}
  \caption{Training paradigm for learning the OTDD neural map between two datasets (distributions), parametrized via a pushforward feature map $G$ and a labeling function $\ell$, using 
  % a discriminator (or dual potential)
  a dual potential
  $f$.
  }
  \label{fig:neural_map}
\end{figure}

Implementation details are provided in \S B.
% \ref{sec:neural_map}. 
Compared to previous conditional Monge map solvers~\citep{bunne2022supervised,asadulaev2022neural}, the two methods proposed here: (i) do not assume class overlap across datasets, allowing for maps between datasets with different label sets; (ii) are invariant to class permutation and re-labeling; (iii) do not force one-to-one class alignments, e.g., samples can be mapped across similar classes.
%We will revisit these three traits later with concrete examples.
% For example, if the source and target datasets are MNIST and EMNIST respectively, it is likely to map $0$ digit images to $C$ or $D$ character images.

\subsection{Convex combination in dataset space}\label{sec:comb}
% \vspace{-0.25cm}
% \fan{[TODO: add diagrams for otdd barycentric projection, neural map, and convex combination]}
\begin{figure}[ht!]
%   \vspace{-0.5cm}
  \centering
  \begin{subfigure}{0.51\textwidth}
    \centering
    \includegraphics[width=1\linewidth]{images/convex_combination.png}
  \end{subfigure}
  \caption{
    % Our pipeline of generating a training dataset that is closest to the test dataset.
    In few-shot settings, we use pseudo-labels for the test dataset, generated e.g.~via kNN using the few-shot examples. 
    If more labeled data from the test dataset is available, we use it instead of the pseudo-labels. \fan{The projection dataset has the same number of samples as the test dataset.}
  }
  \label{fig:convex_comb}
\end{figure}

Computing generalized geodesics requires constructing convex combinations of datapoints from different datasets. Given a weight vector $a \in \Delta_{m-1}$, features can be naturally combined as $x_a = \sum_{i=1}^m a_i x_i $. But combining labels is not as simple because: (i) we allow for datasets with different numbers of labels, so adding them directly is not possible; (ii) we do not assume that different datasets have the same label sets, e.g., MNIST (digits) vs.\ CIFAR10 (objects). Our solution is to represent all labels in a space of the same dimension by padding them with zeros in all entries corresponding to other datasets. As an example, consider three datasets with $2,3$, and $4$ classes, respectively. Given a label vector $y_1\in \R^2$ for the first one, we embed it into $\R^9$ as
$\tilde{y}_1 = [y_1; \mathbf{0}_3; \mathbf{0}_4]^\top.$
% $\tilde{y}_1 = [y_1; \mathbf{0}_3; \mathbf{0}_4]^\top$. 
Defining $\tilde{y}_2, \tilde{y}_3$ analogously, we compute their combination  as
% \begin{align}
%  y_a = a_1\tilde{y}_1 + a_2\tilde{y}_2 + a_3\tilde{y}_3.
% \end{align}
$y_a = a_1\tilde{y}_1 + a_2\tilde{y}_2 + a_3\tilde{y}_3$.
This representation is lossless and preserves the distinction of labels across datasets.
Figure~\ref{fig:convex_comb} visualizes our convex combination pipeline.
% A label vector for the first ones

% Therefore, we extend all soft labels $y_i$ to have the same length by padding zeros to "non-meaningful" output slots. One example could self-explain this idea. Assume $(x_i,y_i) \sim P_i,~i=1,2,3 ,$ and they contain 7, 3, 11 classes respectively, i.e. $y_1 \in \mR^7, y_2\in \mR^3, y_3\in \mR^{11}$. Then the combination of labels is
% \begin{align}
%   y_a = a_1 
%   \begin{bmatrix}
%      y_1 \\ \mathbf{0}_{3} \\\mathbf{0}_{11}
%   \end{bmatrix} 
%   + a_2
%   \begin{bmatrix}
%     \mathbf{0}_7 \\ y_2 \\ \mathbf{0}_{11}
%   \end{bmatrix} 
% + a_3
%   \begin{bmatrix}
%     \mathbf{0}_7 \\\mathbf{0}_3 \\ y_3
%   \end{bmatrix},
% \end{align}
% where $\mathbf{0}_b$ is a $b$-dimensional zero vector.
% This combination assumes that the labels from different datasets are all distinguished, and is capable of storing the full information of original labels.

\begin{figure*}[ht!]
  \centering
  \begin{subfigure}{0.25\textwidth}
    \centering
    \includegraphics[width=1\linewidth]{images/projection.png}
    \caption{}
  \end{subfigure}
  \begin{subfigure}{0.65\textwidth}
    \centering
    \includegraphics[width=1\linewidth]{images/mixup_vs_projection_caption.png}
    \caption{}
    % \vspace{-0.6cm}
  \end{subfigure}
  % \vspace{-0.3cm}
  \caption{\textbf{Visualization and comparison of dataset interpolation methods.} (a) The reference dataset $Q$ (with color-coded classes) is projected onto the generalized geodesic of the training datasets $P_i$, resulting in $\widehat{P}_a$. (b) 2D visualizations of (left-to-right): dataset $Q$, the `optimal' interpolated dataset $\widehat{P}_a := \left(\sum_{i=1}^m \hat a_i \cT^*_i \right)\sharp Q$ using the true OTDD maps $\cT_i^*$, and a naively interpolated dataset $\left(\sum_{i=1}^m \hat a_i \cT_i \right)\sharp Q $ using randomly generated maps $\cT_i$.
  %(a) Approximated projection $\widehat{P}_a$ of 3D dataset $Q$ onto the generalized geodesic of datasets $\{P_i\}$.
    % uses mixup.
    % The datasets $Q$ and $P_i$ are the same as in Figure \ref{fig:proj}. 
  }
  \label{fig:proj_2d}
\end{figure*}
\subsection{Projection onto generalized geodesic of datasets}\label{sec:proj}
% \vspace{-0.25cm}


We now put together the components in Sections~\ref{sec:map} and~\ref{sec:comb} to construct generalized geodesics between datasets in two steps. First, we compute OTDD maps $\cT_i^*$ from $Q$ to each of the other datasets $P_i, i=1,\ldots, m$ using the discrete or neural OT approaches. Then, for any interpolation vector  $a \in \Delta_{m-1}$ we identify a dataset along the generalized geodesic via
\begin{align}
  P_a := \left(\sum_{i=1}^m a_i \cT^*_i \right)\sharp Q.
\end{align}
% $P_a := \left(\sum_{i=1}^m a_i \cT^*_i \right)\sharp Q$. 
By using the convex combination method in \S\ref{sec:comb} for labeled data, we can efficiently sample from $P_a$. 

Next, we find the dataset $P^*_a$ that minimizes the distance between $P_a$ and $Q$, i.e. the projection of $Q$ onto the generalized geodesic. We first approach this problem from a Euclidean viewpoint.
Suppose there are several distributions $\{\mu_i\}_{i=1}^m$ and an additional distribution $\nu$ on Euclidean space $\mR^d$. Lemma \ref{lem:convex_w2} guarantees
% the convexity along generalized geodesic $\rho^G_a$ promises 
there exists a unique parameter $a^*$ that minimizes $W_2^2(\rho_{a}^G, \nu)$. However, finding $a^*$ is far from trivial because there is no closed-form formula for the map $a \mapsto W_2^2(\rho_{a}^G, \nu)$, and it can be expensive to calculate $W_2^2(\rho_{a}^G, \nu)$ for all possible $a$. To solve this problem, we resort to another transport distance: the (2,$\nu$)-transport metric.

% \fan{[TODO: 1) add a plot for convexity of transport metric vs non-convexity of w distance. 2) elaborate why convexity is so important.]}
\begin{definition}[\citet{craig2016exponential}]
  Given distributions $\mu_i, \mu_j$, the (2,$\nu$)-transport metric between them is given by \[W_{2,\nu}(\mu_i,\mu_j) := \left( \int \|T_i^*(x) - T_j^*(x) \|_2^2 \d \nu (x) \right)^{1/2},\] where $T_i^*$ is the optimal map from $\nu$ to $\mu_i$.
\end{definition}
When $\nu$ has a density with respect to the Lebesgue measure, $W_{2,\nu}$ is a valid metric~\citep[Prop. 1.15]{craig2016exponential}. Moreover, we can derive a closed-form formula for the map $a \mapsto W_{2, \nu }^2(\rho_{a}^G, \nu)$.
% where we postpone the proof to appendix.
% The following Proposition provides an surrogate to locate $a^*$.
% distribution from xx to test dataset. 
\begin{proposition}\label{prop:eq}
  % For transportation metric!! we have equation.
  $W_{2, \nu }^2(\rho^G_a, \nu )
    =  \sum_{i=1}^m a_i W_{2,\nu }^2(\mu_i, \nu ) - \frac{1}{2} \sum_{i \neq j} a_i a_j W_{2,\nu }^2(\mu_i, \mu_j ).
    % =  \sum_{i=1}^m a_i W_{2, \nu }^2(\mu_i, \nu ) - \frac{1}{2} \sum_{i \neq j} \alpha_i \alpha_j W_{2, \nu }^2(\mu_i, \mu_j ) .
  $
  % When $a^* = $
  % \begin{align}
  % W_{2, \nu }^2(\rho^G_a, \nu ) =  \sum_{i=1}^m 
  % \end{align}
\end{proposition}
This equation
% is a generalization of the Prop. 1.15 in \citet{craig2016exponential} and 
implies that, given distributions $\{\mu_i\}, \nu$ in Euclidean space, we can readily solve for the optimal $a^*$ that minimizes $W_{2, \nu }^2(\rho^G_a, \nu ) $ with a quadratic programming solver\footnote{We use the implementation at \url{https://github.com/stephane-caron/qpsolvers}.}. The proof (\S A
% \ref{sec:proof}
) relies on Brenier's theorem.
% This motivates us to locate a "best suited" task 
Inspired by this, we also define a transport metric for datasets:
\begin{definition}\label{def:Q_ds_distance}
  % Denote $M \in \mR^{C_{P_i} \times C_{P_j}}$ as the label-to-label matrix where $M(i,j) := W_2^2(\alpha_{y_i}, \alpha_{y_j}) .$ The squared (2,$Q$)-dataset distance is given by $\cW^2_{2,Q}(P_i, P_j) := \int \left( \|\bar x_i - \bar x_j \|_2^2 + \bar y_i^\top M \bar y_j \right) \d Q $, where $\cT_i^*(z) = [\bar x_i; \bar y_i ] $ is the OTDD map from dataset $Q$ to $P_i$, $i=1,\ldots, m$.
  % Denote $M \in \mR^{C_{P_i} \times C_{P_j}}$ as the label-to-label matrix where $M(i,j) := W_2^2(\alpha_{y_i}, \alpha_{y_j}) .$ 
  The squared (2,$Q$)-dataset distance is \[\cW^2_{2,Q}(P_i, P_j) := \int \left( \| x_i -  x_j \|_2^2 +
    W_2^2(\alpha_{y_i}, \alpha_{y_j})
    % \bar  y_i^\top M \bar y_j
    \right) \d Q(z), \] where $ [ x_i;  y_i ] =\cT_i^*(z)$ and $\cT_i^*$ is the OTDD map from $Q$ to $P_i$.%., $i=1,\ldots, m$.
\end{definition}
Denote by $\cP_{2,Q} (\cX \times \cP(\cX) ) $ the set of all probability measures $P$ that satisfy $  d_\OT (P,Q) < \infty $ and for which the OTDD map from $Q$ to $P $ exists. The following result shows that the (2,$Q$)-dataset distance is a proper metric. The proof is again deferred to \S A.
% \ref{sec:proof}.
\begin{proposition}\label{prop:metric}
  % Denote $\cP_{2,Q} = \{P: d_\OT (P,Q) < \infty \}$
  $\cW_{2,Q}$ is a valid metric on $\cP_{2,Q}(\cX \times \cP(\cX))$.
  % \vspace{-0.1cm}
\end{proposition}
Unfortunately, in this case $\cW^2_{2,Q}(P_a, Q)$ does not have an analytic form like before because Brenier's theorem may not hold for a general transport cost problem.
However, we still borrow this idea and define an approximate projection $\widehat{P}_a$ as the minimizer of the function
\begin{align}\label{eq:ds_eq}
  % \vspace{-0.4cm}
   & \cW^2(P_a, Q) :=  \nonumber                                                                       \\
   & \sum_{i=1}^m a_i \cW^2_{2,Q}(P_i, Q) - \frac{1}{2} \sum_{i \neq j} a_i a_j \cW^2_{2,Q}(P_i, P_j),
  % \vspace{-0.3cm}
\end{align}
which is an analog of Proposition \ref{prop:eq}. 
\fan{Since ${P}_a$ is defined by its interpolation weight $a$, finding $\widehat{P}_a$ is equivalent to finding the weight
\begin{align}\label{eq:hat_a}
    \hat a = \argmin_{a \in \Delta_{m-1}} \cW^2(P_a, Q),
\end{align}
% $\hat a$ that minimizes \eqref{eq:ds_eq}, 
which is a simple quadratic programming problem.}
Unlike the Wasserstein distance, $\cW^2_{2,Q}(\cdot, \cdot)$ does not require solving an additional optimization problem once the OTDD maps $\cT_i^*$ have been computed, so it is relatively cheap to find the minimizer of $\cW^2(P_a, Q)$.
Experimentally, we observe that
% there is still have high correlation between
$\cW^2(P_a, Q)$ is predictive of model transferability across tasks (see \S\ref{sec:nist}). Figure~\ref{fig:proj_2d}(a) illustrates this projection on toy 3D datasets, color-coded by class.

% and the generalization accuracy in practice. \fan{[TODO: define the projection of datasets concretely, put the equation here.]}
% Therefore, we hope to
% borrow the idea of generalized geodesic in feature space and extend it to the product space of the feature and the label. 
% According to \eqref{eq:gen_geodesic}, the procedures of generating synthetic datasets on the generalized geodesic can be split into two steps: 1) solving the optimal map $T_i^*$ from $\nu$ to $\mu_i$. 2) get the convex combination of data from the pushforward measures $\{ T_i^* \sharp \nu \} $. However, the difficulties are two folds: 1) no good map solver, gradient flow is too time-consuming. 2) how to combine label?

% \begin{wrapfigure}[14]{r}{5cm}
%     % \begin{figure}[h]
%     \centering
%     \vspace{-1cm}
%     \begin{subfigure}{0.35\textwidth}
%         % \centering
%         \includegraphics[width=1\linewidth]{images/projection.png}
%         \vspace{-0.5cm}
%     \end{subfigure}
%     \caption{Approximated projection $\widehat{P}_a$ of 3D dataset $Q$ onto the generalized geodesic of datasets $\{P_i\}$.\fan{[change name]}}
%     \label{fig:proj}
% \end{wrapfigure}


\section{Experiments}
% \vspace{-0.2cm}



\subsection{Learning OTDD maps}
In this section, we visualize the quality of the learned OTDD maps on both synthetic and realistic datasets.
% \vspace{-0.3cm}
\paragraph{Synthetic datasets}
% \vspace{-0.25cm}
% \begin{wrapfigure}[7]{r}{8cm}
%     % \begin{figure}[h]
%     \centering
%     \vspace{-2cm}
%     \begin{subfigure}{0.6\textwidth}
%         % \centering
%         \includegraphics[width=1\linewidth]{images/mixup_vs_projection.png}
%         \vspace{-0.5cm}
%     \end{subfigure}
%     \caption{Left to right: the 2D projection of the datasets $Q$, $\left(\sum_{i=1}^m a^*_i \cT^*_i \right)\sharp Q$, $\left(\sum_{i=1}^m a^*_i \cT_i \right)\sharp Q $. The datasets $Q$ and $P_i$ are the same as in Figure \ref{fig:proj}.}
%     \label{fig:optimality}
% \end{wrapfigure}
% [done] 2D Gaussian mixture

% [done] Some quality result of digits? 


% One figure Compare the target value with discrete OT? 

% One figure OTDD between pushforward and target?

% The OTDD codebase calculates the 
% \begin{align}
%     \inf_\pi  \frac{1}{2} \int  \lambda_x \|x - x' \|^2 +  \lambda_y W_2^2(\alpha_y, \alpha_{y'}) d\pi(z,z')
% \end{align}
% Now I take $\lambda_x = 0.01, \lambda_y = 0.1$, and $x$ are extended to three channels, and normalized to $[-1,1]$.


% Give an example of generate more samples on target domain?? (no advantage..)
Figure~\ref{fig:proj_2d}(b) illustrates the role of the optimal map in estimating the projection of a dataset onto the generalized geodesic hull of three others. Using maps $\cT_i^*$ estimated via barycentric projection \eqref{eq:bary_map} better preserves the four-mode class structure, whereas using non-optimal maps $\cT_i$ based on random couplings (as the usual \emph{mixup} does) destroys the class structure.

\paragraph{*NIST datasets}
\begin{figure}[h!]
  \centering
  \begin{subfigure}{0.48\textwidth}
    \centering
    \includegraphics[width=1\linewidth]{images/EMNIST.png}
  \end{subfigure}
  \caption{Datasets generated by pushing forward $Q$ (the EMNIST dataset) towards Fashion-MNIST, MNIST, USPS, KMNIST, using OTDD maps $\mathcal{T}_i$, obtained using the neural OT method described in Section~\ref{sec:map}.}
  \label{fig:EMNIST}
\end{figure}
In Figure~\ref{fig:EMNIST}, we provide qualitative results of the OTDD map from the EMNIST (letters)~\citep{cohen2017emnist} dataset to the other *NIST datasets and the USPS dataset. These results confirm the three traits of OTDD maps mentioned at the end of \S\ref{sec:map}.
% \begin{wrapfigure}[11]{r}{8.6cm}
%   % \begin{figure}[h]
%   \centering
% %   \vspace{-0.4cm}
%   \begin{subfigure}{0.5\textwidth}
%     \includegraphics[width=1\linewidth]{images/permute.png}
%     % \vspace{-0.5cm}
%   \end{subfigure}
%   \caption{The numbers above images are the labels. In the first labelling method, all 0 MNIST digits are assigned as class "0", and they are labelled as class "7" in the bottom labelling.}
%   \label{fig:labelling}
% \end{wrapfigure}

\begin{figure*}[ht!]
  \centering
  \begin{subfigure}{1\textwidth}
    \centering
    \includegraphics[width=1\linewidth]{images/ternary.png}
  \end{subfigure}
  % \vspace{-0.2cm}
  \caption{Relationship between the function $\cW^2(P_a, Q)$ and the accuracy of the fine-tuned model.
    % The test dataset is MNIST, EMNIST, USPS, FMNIST, KMNIST from left to right.
    The model trained on the projection dataset $\hat P_a$, i.e. the minimizer of $\cW^2(P_a, Q)$, tends to have a better generalization accuracy.
    The training datasets are marked on the vertexes of each ternary plot. Each ternary plot is an average of 5 runs with distinct random seeds.
    % The correlation of $\cW^2(P_a, Q)$ and the accuracy for each dataset is $-0.7849, -0.8267, -0.2709, 0.2939, -0.4090$ from left to right respectively. % We calculated this by only using the average of data, so for each dataset, the W2 would be a 36 dimension vector, so does accuracy. Then we stack them to be (2,36) matrix and use torch.correff to get the correlation.
  }
  \label{fig:ternary}
\end{figure*}

1) We do not assume a known correspondence between source and target labels, so we can map between two unrelated datasets such as EMNIST and Fashion-MNIST.
% As a result, two datasets can have different number of labels; 
2) The map is invariant to permutations of the label assignment. For example, we show two different labelings in Figure~1 in the appendix
% \ref{fig:labelling}, 
that yield the same final OTDD map.
3) It does not enforce a one-to-one label-to-label mapping, but instead follows feature similarity. In Figure~\ref{fig:EMNIST}, we observe many cross-class mappings. For example, when the target domain is the USPS~\citep{hull1994database} dataset, the lower-case letter ``l'' is always mapped to the digit 1, while the capital letter ``L'' is mapped to other digits such as 6 or 0, because the map follows feature similarity.


% \begin{figure}[h!]
%   \centering
%   \begin{subfigure}{0.48\textwidth}
%     \centering
%     \includegraphics[width=1\linewidth]{images/EMNIST.png}
%   \end{subfigure}
%   \caption{The dataset $Q$ is EMNIST (letters). We show all the datasets pushforwarded towards Fashion-MNIST, MNIST, USPS, KMNIST by OTDD map. The OTDD map is solved by neural OT method.}
%   \label{fig:EMNIST}
% \end{figure}

%    \begin{figure}[h]
%   \centering
% %   \vspace{-0.4cm}
%   \begin{subfigure}{0.35\textwidth}
%     \includegraphics[width=1\linewidth]{images/permute.png}
%     % \vspace{-0.5cm}
%   \end{subfigure}
%   \caption{The numbers above images are the labels. In the first labelling method, all 0 MNIST digits are assigned as class "0", and they are labelled as class "7" in the bottom labelling.}
%   \label{fig:labelling}
% \end{figure}

% $\cT^*$ is solved by barycentric projection \eqref{eq:bary_map} and $\cT$ is sub-optimal map. Both of them push the samples from $Q$ towards $P_i$.  It is easy to construct a sub-optimal map. We simply draw samples from $P_i$ and randomly assign them as the output of $\cT_i$. The last two plots in Figure \ref{fig:proj_2d} (b) show that a sub-optimal map could discard the information from dataset $Q$, and end up with a random "mixup" of samples from $P_i$. \fan{[explain mixup is not vanilla]} As a result, the optimality of maps $\cT_i^*$ is crucial for the projection dataset to inherit the characteristic of external dataset $Q$.


% We show the difference of final pushforward dataset caused by not using an optimal map. 
% an optimal map $\cT^*$ given by barycentric map \eqref{eq:bary_map} and a random map $\cT$. With the same external dataset $Q$ but different dataset maps, the  

% In Figure \ref{fig:EMNIST} of the appendix, we in addition provide qualitative results of OTDD map from EMNIST (letter)~\citep{cohen2017emnist} dataset to all other *NIST dataset and USPS dataset. At this point, we can confirm three traits of OTDD map, which are mentioned at the end of \S\ref{sec:map}. 
% \begin{wrapfigure}[11]{r}{7cm}
%     % \begin{figure}[h]
%     \centering
%     \vspace{-0.4cm}
%     \begin{subfigure}{0.5\textwidth}
%         % \centering
%         \includegraphics[width=1\linewidth]{images/permute.png}
%         \vspace{-0.5cm}
%     \end{subfigure}
%     \caption{The numbers above images are the labels. In the first labelling method, all 0 MNIST digits are assigned as class "0", and they are labelled as class "7" in the bottom labelling.}
%     \label{fig:labelling}
% \end{wrapfigure}
% 1) We don't assume a known source label to target label correspondence. So we can map between two irrelevent datasets such as EMNIST and FashinMNIST.
% % As a result, two datasets can have different number of labels; 
% 2) The map is invariant to the permutation of label assignment. For example, we show two different labelling in Figure \ref{fig:labelling}, and the final OTDD map will be the same. 
% 3) It doesn't enforce the label to label mapping but would follow the feature similarity. From Figure \ref{fig:EMNIST} in the appendix, we notice many cross-class mapping behaviors. For example, when the target domain is USPS~\citep{hull1994database} dataset, the lower-case letter "l" is always mapped to digit 1, and the capital letter "L" is mapped to other digits such as 6 or 0 because the map follows the feature similarity.

% For example, if the source and target datasets are MNIST and EMNIST respectively, it is likely to map $0$ digit images to $C$ or $D$ character images.
% Add a comparison between generalized geodesic interpolation and mixup in 2D, kind of ternary plot. We investigate the importance of the first step: optimal map.

% \subsubsection{Transfer learning}

% Show the feature space distance matrix?


% Train on EMNIST $\rightarrow$ FMNIST, we fine tune on FMNIST, USPS, MNIST, KMNIST in a few-shot manner and show the curve.
% \subsection{McCann's interpolation with OTDD map}

% Next, we will extend the interpolation between two datasets to multiple datasets.
% Application: Generate more samples on target domain

% [done] 2D Gaussian mixture

% digits: transfer learning task on EMNIST -> M-MNIST compare with barycentric mapping






% \vspace{-0.5cm}


\subsection{Transfer learning on *NIST datasets }\label{sec:nist}


Next, we use our framework to generate new pretraining datasets for 
% few-shot
transfer
learning. 
% Preliminary experiments show that if we have abundant training data for the test datasets, then
Preceding works illustrate that the transfer learning performance can be quite sensitive to the type of test datasets 
if there is abundant training data from the test task~\citep[Table~1]{zhai2019large}. We therefore focus on the few-shot setting, where only a few labeled samples from the test task are available.
We first show that the generalization ability of the trained models is strongly correlated with $\cW^2(P_a, Q)$. Then we compare our framework with several baseline methods.

\paragraph{Setup}
Given $m$ labeled pretraining datasets $\{P_i\}$, we consider a few-shot task in which only a limited amount of data from the target domain is labeled, e.g.,~5 samples per class. The goal is to find a single dataset of size comparable to any individual $P_i$ that yields the best generalization to the target domain when pretraining a model on it and fine-tuning on the target few-shot data. Here, we seek this training dataset among those generated by generalized geodesics $\{P_a\}$, which can be understood as weighted interpolations of the training datasets $\{P_i\}$. Note that this includes the individual datasets as particular cases when $a$ is a one-hot vector. 


% along the generalized geodesic of data sets $\{P_i\}$, the projection of $Q$ can be 

% \fan{TODO: [use KNN for pushforward labels, streaming setup...]}

%\subsubsection{Connection to generalization ability}\label{sec:ternary}
% \vspace{-0.2cm}

% This result is using padding 1e-4 to all coupling in barycentric projection.
\begin{table*}[t]
  \caption{
  \textbf{Pretraining on synthetic data}. For each of the *NIST datasets, we treat it as the target domain and pretrain a neural net on a synthetic dataset generated as a combination of the remaining datasets with three interpolation methods. Here we show 5-shot transfer accuracy (mean $\pm$ s.d.~over 5 runs). The first baseline creates a synthetic training dataset by applying Mixup across datasets: we randomly sample data from each training dataset and combine them convexly with weights $\hat a$ (see Eq.~\eqref{eq:hat_a}). We use the same convex combination method as in \S\ref{sec:comb}, so this Mixup baseline is equivalent to our framework with suboptimal OTDD maps. The other two baselines (the bottom block) skip transfer learning and directly train the model, or run 1-NN, on the few-shot test dataset. 
  % \vspace{-0.5cm}
  }
  \begin{center}
    \begin{small}
      % \setlength\tabcolsep{2pt}
      \resizebox{\textwidth}{!}{
        \begin{tabular}{ccccccc}
          \toprule
          Methods                     & MNIST-M             & EMNIST              & MNIST               & FMNIST              & USPS                & KMNIST              \\
          \midrule
          OTDD barycentric projection & {\bf42.10$\pm$4.37} & {\bf67.06$\pm$2.55} & {\bf93.74$\pm$1.46} & {\bf70.12$\pm$3.02} & 86.01$\pm$1.50      & {\bf52.55$\pm$2.73} \\
          OTDD neural map             & 40.06$\pm$4.75      & 65.32$\pm$1.80      & 88.78$\pm$3.85      & 70.02$\pm$2.59      & 83.80$\pm$1.60      & 50.32$\pm$3.10      \\
          Mixup with weights $\hat a$                      & 33.85$\pm$2.22      & 60.95$\pm$1.38      & 88.68$\pm$1.57      & 66.74$\pm$3.79      & {\bf88.61$\pm$2.00} & 48.16$\pm$3.38      \\
          \midrule
          Train on few-shot dataset   & 19.10$\pm$3.57      & 53.60$\pm$1.18      & 72.80$\pm$3.10      & 60.50$\pm$3.07      & 80.73$\pm$2.07      & 41.67$\pm$2.11      \\
          1-NN  on few-shot dataset   & 20.95$\pm$1.39      & 39.70$\pm$0.57      & 64.50$\pm$3.32      & 60.92$\pm$2.42      & 73.64$\pm$2.35      & 40.18$\pm$3.09      \\
          \bottomrule
        \end{tabular}
      }
    \end{small}
  \end{center}
  %   \vskip -0.1in
  \label{tab:compare}
\end{table*}



% \subsubsection{Connection to generalization}
\paragraph{Connection to generalization}
The closed-form expression of $W_{2, \nu}^2 (\rho_a^G, \nu)$ (Prop.~\ref{prop:eq}) provides a distance between a base distribution $\nu$ and the distribution along generalized geodesic $\rho_a^G$ in Euclidean space.
% Since its expression is convex w.r.t. interpolation parameter $a$,
We study its analog \eqref{eq:ds_eq} for labeled datasets $Q$ and $\{P_i\}$ and visualize it in Figure \ref{fig:ternary} (first row).
% \begin{align}
% \cW^2(P_a, Q): = \sum_{i=1}^m a_i \cW^2_{2,Q}(P_i, Q) - \frac{1}{2} \sum_{i \neq j} a_i a_j \cW^2_{2,Q}(P_i, P_j)
% \end{align}
% where $\cW^2_{2,Q}(\cdot, \cdot)$ is the squared $(2,Q)$-dataset distance defined in Definition \ref{def:Q_ds_distance}. 
% Contrastive to the Wasserstein distance, $\cW^2_{2,Q}(\cdot, \cdot)$ is much easier to solve because it does not involve the optimization.
To investigate the generalization abilities of models trained on different datasets, we discretize the simplex $\Delta_2$ to obtain $36$ interpolation parameters $a$, and train a 5-layer LeNet classifier on each $P_a$. Then we fine-tune all of these classifiers on the few-shot test dataset $Q$ with only 20 samples per class. We use the same number of training and fine-tuning iterations across all experiments.
The second row of Figure~\ref{fig:ternary} shows the fine-tuning accuracy. Comparing the two rows, we find that accuracy and $\cW^2(P_a, Q)$ are highly (negatively) correlated: datasets with smaller $\cW^2(P_a, Q)$ tend to yield higher accuracy. This implies that the model trained on the minimizer of $\cW^2(P_a, Q)$ tends to have better generalization ability. We use the same colorbar range for all heatmaps across datasets to highlight the impact of the choice of training dataset. A more concrete visualization of the correlation between $\mathcal{W}^2(P_a, Q)$ and accuracy is shown in Figure~5 in the appendix.
% \ref{fig:corr}.

For some test datasets, the choice of training dataset strongly affects the fine-tuning accuracy. For example, when $Q$ is EMNIST and the training dataset is FMNIST, the fine-tuning accuracy is only $\sim 60\%$, but this can be improved to $> 70\%$ by choosing an interpolated dataset closer to MNIST. This is reasonable because MNIST is more similar to EMNIST than FMNIST or USPS are. For other test datasets, such as FMNIST and KMNIST, this difference is less pronounced because all training datasets are far away from the test dataset.
% , causing a range of $10\%$ accuracy gap.
% plot the same range.

% We show the range of accuracy can be $10 \%$ without choosing a good training dataset.

% \subsubsection{Comparison with baselines}
\paragraph{Comparison with baselines.} 



Next, we compare our method with several baseline methods on the *NIST datasets. In each set of experiments, we select one *NIST dataset as the target domain and use the rest for pretraining. We consider a 5-shot task, so we \emph{randomly} choose 5 samples per class to be the labeled data and treat the remaining samples as unlabeled. Our method first trains a model on $\widehat{P}_a$ and then fine-tunes it on the 5-shot target data. To obtain $\widehat{P}_a$, we use either the barycentric projection or the neural map to approximate the OTDD maps from the test dataset to the training datasets. Our results are shown in the first two rows of Table~\ref{tab:compare}.
% The first baseline method is to create a synthetic dataset as a training dataset by Mixup among datasets. We randomly sample data from each training dataset, and do the convex combination of them with weight $\hat a$. We use the same convex combination method in \S\ref{sec:comb}, thus this baseline is equivalent to our framework with suboptimal OTDD maps. The other two baselines (the bottom block in Table \ref{tab:compare}) skip the transfer learning part, and directly train the model or solve 1-NN on the few-shot test dataset.
% Not use ternary plot anymore, just use all other datasets, and use only one set of simplex parameter to compare with others.
% From the first block, we can tell that training can bring additional knowledge from other domains, and improve the fine-tuning accuracy a lot even though training much less iterations on the test task.
% The transfer learning methods overall outperform the second block.
Overall, transfer learning brings additional knowledge from other domains and improves the test accuracy by up to 21$\%$. Among the methods in the first block, training on datasets generated by OTDD barycentric projection outperforms the others on every target dataset except USPS, where the gap is only about 2.6$\%$.
% gives the best performance where mixup

% \begin{table}[H]
%   \caption{20shot lenet.}
%   \begin{center}
%     \begin{small}
%       \begin{tabular}{cccccc}
%         \toprule
%         Methods                     & MNIST               & USPS                & FMNIST              & KMNIST              & EMNIST              \\
%         \toprule
%         OTDD barycentric projection & {\bf93.81$\pm$0.18} & 86.60$\pm$0.52      & 71.81$\pm$1.70      & 56.51$\pm$1.94      & {\bf71.71$\pm$1.01} \\
%         OTDD neural map             & 90.32$\pm$1.30      & 85.92$\pm$0.34      & 71.77$\pm$1.85      & 54.59$\pm$0.98      & 70.34$\pm$0.93      \\
%         Mixup                       & 91.46$\pm$0.81      & 86.04$\pm$0.33      & {\bf71.87$\pm$1.48} & 56.05$\pm$1.05      & 69.43$\pm$0.49      \\
%         \midrule
%         1-NN  on 20-shot dataset    & 78.46$\pm$0.58      & {\bf88.44$\pm$0.86} & 67.80$\pm$1.05      & {\bf70.16$\pm$1.12} & 54.34$\pm$0.52      \\
%         Train on 20-shot dataset    & 85.13$\pm$1.12      & 83.79$\pm$0.65      & 69.08$\pm$3.24      & 52.87$\pm$2.13      & 64.21$\pm$0.78      \\
%         \bottomrule
%       \end{tabular}
%     \end{small}
%   \end{center}
%   \vskip -0.1in
%   % \label{tab:compare}
% \end{table}


% \begin{table}[H]
%   \caption{20shot spinalnet}
%   \begin{center}
%     \begin{small}
%       \begin{tabular}{cccccc}
%         \toprule
%         Methods                     & MNIST               & USPS                & FMNIST              & KMNIST              & EMNIST              \\
%         \toprule
%         OTDD barycentric projection & 93.46$\pm$1.64      & 91.14$\pm$0.74      & {\bf75.22$\pm$1.49} & 67.01$\pm$2.57      & 80.32$\pm$0.69      \\
%         OTDD neural map             & 90.87$\pm$4.21      & 91.74$\pm$0.90      & 74.42$\pm$1.23      & 62.26$\pm$2.62      & 80.48$\pm$1.92      \\
%         Mixup                       & {\bf93.83$\pm$0.85} & {\bf91.76$\pm$1.06} & 74.84$\pm$1.10      & 65.33$\pm$2.19      & 79.90$\pm$0.96      \\
%         \midrule
%         Train on few-shot dataset   & 93.21$\pm$0.75      & 90.36$\pm$2.45      & 74.07$\pm$0.87      & 63.19$\pm$0.95      & {\bf81.38$\pm$0.97} \\
%         1-NN  on few-shot dataset   & 78.46$\pm$0.58      & 88.44$\pm$0.86      & 67.80$\pm$1.05      & {\bf70.16$\pm$1.12} & 54.34$\pm$0.52      \\
%         \bottomrule
%       \end{tabular}
%     \end{small}
%   \end{center}
%   \vskip -0.1in
%   \label{tab:compare}
% \end{table}


% \begin{table}[H]
%   \caption{Fine-tuning accuracy results (in percent) for different 20-shot test datasets. The results are averaged over 5 runs with distinct random seeds.}
%   \begin{center}
%     \begin{small}
%       \begin{tabular}{cccccc}
%         \toprule
%         Methods                     & MNIST               & USPS                & FMNIST              & KMNIST              & EMNIST              \\
%         \toprule
%         OTDD barycentric projection & {\bf93.81$\pm$0.18} & 86.60$\pm$0.52      & 71.81$\pm$1.70      & 56.51$\pm$1.94      & {\bf71.71$\pm$1.01} \\
%         OTDD neural map             & 90.32$\pm$1.30      & 85.92$\pm$0.34      & 71.77$\pm$1.85      & 54.59$\pm$0.98      & 70.34$\pm$0.93      \\
%         Mixup                       & 91.46$\pm$0.81      & 86.04$\pm$0.33      & {\bf71.87$\pm$1.48} & 56.05$\pm$1.05      & 69.43$\pm$0.49      \\
%         \midrule
%         1-NN  on 20-shot dataset    & 78.46$\pm$0.58      & {\bf88.44$\pm$0.86} & 67.80$\pm$1.05      & {\bf70.16$\pm$1.12} & 54.34$\pm$0.52      \\
%         Train on 20-shot dataset    & 85.13$\pm$1.12      & 83.79$\pm$0.65      & 69.08$\pm$3.24      & 52.87$\pm$2.13      & 64.21$\pm$0.78      \\
%         \bottomrule
%       \end{tabular}
%     \end{small}
%   \end{center}
%   \vskip -0.1in
%   \label{tab:compare}
% \end{table}

% There maybe one case where OTDD map is better, the training dataset is on the boundary of limited, (enough to train a classifier, but relatively small.) Then our data comes in a stream, we have new test data.

% Camelyon


% It is promising direction to develop 
% So our framework may not be suitable when 
% This can make our framework expensive when there are multiple test datasets.
% For every new test dataset, we need to solve the OTDD map or projection to get the 


% \include{otdd_soft}
\subsection{Transfer learning on VTAB datasets}\label{sec:vtab}

Finally, we use our method for transfer learning with large-scale VTAB datasets~\citep{zhai2019large}. In particular, we take the \href{https://www.robots.ox.ac.uk/~vgg/data/pets/}{Oxford-IIIT Pet dataset} as the target domain, and use \href{https://data.caltech.edu/records/mzrjq-6wc02}{Caltech101}, \href{https://www.robots.ox.ac.uk/~vgg/data/dtd/}{DTD}, and \href{https://www.robots.ox.ac.uk/~vgg/data/flowers/102/}{Flowers102} for pre-training. To encode a richer geometry in our interpolation, we embed the datasets using a masked autoencoder (MAE)~\citep{he2022masked} and learn the OTDD map in this ($\sim$200K-dimensional) latent space.
% to measure feature difference $\|x-x'\|_2$ when calculating OTDD map. 
Since OTDD barycentric projection consistently works better than the OTDD neural map (see Table~\ref{tab:compare}), we only use barycentric projection 
% to solve OTDD map 
in this section.  We use ResNet-18 as the model architecture and pre-train the model on decoded MAE images (interpolated dataset) or original images (single dataset). 
\fan{In contrast, the Mixup baseline operates in pixel space and therefore
does not use the embeddings at all.}

\begin{table}[ht]
\caption{\textbf{Transfer Learning on VTAB datasets}. The table shows relative improvement (w.r.t.~a no-transfer baseline) of test accuracy on \textsc{Oxford-IIIT Pet} (mean $\pm$ std over 5 runs) given only 1000 \fan{\text{randomly selected}} samples of this dataset to fine-tune. 
The first three rows show single-pretraining-dataset baselines, and the remaining rows show methods that pretrain on a synthetic interpolation of these three, generated using Mixup or our proposed OTDD Map, using uniform or $\hat a$ (see Eq.~\eqref{eq:hat_a}) dataset interpolation weights. 
The pooling baseline pretrains on the union of all three pre-training datasets.
To construct the sub-pooling pretraining dataset, for each training sample
 from the target dataset (\textsc{Pet}) we find its 10-nearest neighbors (in embedding space) from across all pretraining datasets, and label them as belonging to the class from the target domain.
}\label{tab:vtab}\centering
{\renewcommand{\arraystretch}{1.1}%
\begin{tabular}{cccc}
\toprule
Pre-Training & Map & Weights & Rel. Improv. ($\%$) \\ \midrule
\textsc{Caltech101} & $-$ & $-$ & 59.68 $\pm$ 41.44 \\
\textsc{DTD} & $-$ & $-$ & $-$1.17 $\pm$ 9.52 \\
\textsc{Flowers102} & $-$ & $-$ & $-$2.45 $\pm$ 26.25 \\
Pooling & $-$ & $-$ & 28.96 $\pm$ 18.29 \\
Sub-pooling & $-$ & $-$ & 3.00 $\pm$ 19.10 \\
Interpolation & Mixup & uniform & 33.26 $\pm$ 21.30 \\
Interpolation & Mixup & $\hat a$ & 51.99 $\pm$ 34.10 \\
Interpolation & OTDD & uniform & 82.61 $\pm$ 25.93 \\
Interpolation & OTDD & $\hat a$ & \textbf{95.17 $\pm$ 20.57} \\
\bottomrule
\end{tabular}}
\end{table}

The pre-training interpolation dataset generated by our method has `optimal' mixture weights $\hat a=(0.43, 0.24, 0.33)$ for (\textsc{Caltech101}, \textsc{DTD}, \textsc{Flowers102}), suggesting a stronger similarity between the first of these and the target domain (\textsc{Pet}). This is consistent with the single-dataset transfer results shown in Table~\ref{tab:vtab}. However, interpolating these datasets yields better transfer than any single one of them, particularly when using our full method (interpolating with the OTDD map and the optimal mixture weights).

In Table~\ref{tab:vtab}, we compute the relative improvement per run and then average across runs; in other words, we compute the mean of ratios (MoR) rather than the ratio of means (RoM). We do so because (i) relativizing before averaging controls for the `hardness' inherent to each randomly sampled subset of \textsc{Pet}, and (ii) it is common practice to compute the MoR when the numerator and denominator correspond to paired data (as is the case here) and the terms in the sum are sampled i.i.d.\ (again, satisfied here because the subsets of the target domain are sampled randomly).

The high standard deviations in Table~\ref{tab:vtab} are due to a particularly good result of the no-transfer baseline with seed 2, while other methods, such as \textsc{Caltech101} and \textsc{Flowers102} pre-training, had particularly bad results with the same seed.

% Even though our method does not outperform pre-training on Caltech101, our method can provide a reasonable baseline when there is no foreknowledge of each dataset's transfering ability.

% \begin{table}[H]
% \caption{Improvement of test accuracy (mean $\pm$ std over 5 runs in percent) of 1000-shot learning on Oxford-IIIT Pet compared to non-pretraining.
% % of 1000-shot learning on Oxford-IIIT Pet test dataset. TL is short for transfer learning.
% % Non-TL skips the pre-training step.
% % The model is initialized with ImageNet1K \href{https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html}{pretrained weights}.
% }\label{tab:vtab}
% \centering
% {\renewcommand{\arraystretch}{1.1}%
% % \begin{tabular}{|cc|c|}
% % \hline
% % \multicolumn{1}{|c|}{\multirow{7}{*}{TL}} & OTDD map ('optimal' weight) 
% %   & 95.17 $\pm$ 20.57 \\ \cline{2-3} 
% % \multicolumn{1}{|c|}{}                    & OTDD map (uniform weight)           & 82.61 $\pm$ 25.93 \\ \cline{2-3} 
% % \multicolumn{1}{|c|}{}                    &   Mixup ('optimal' weight)      & 51.99  $\pm$ 34.10 \\ \cline{2-3} 
% % \multicolumn{1}{|c|}{}                    & Mixup (uniform weight)         & 33.26 $\pm$ 21.30 \\ \cline{2-3} 
% % \multicolumn{1}{|c|}{}                    & Caltech101  & 59.68 $\pm$ 41.44 \\ \cline{2-3} 
% % \multicolumn{1}{|c|}{}                    &  DTD   &-1.17 $\pm$ 9.52 \\ \cline{2-3} 
% % \multicolumn{1}{|c|}{}                    & Flowers102    & -2.45 
% %  $\pm$ 26.25 \\ \hline
% % \multicolumn{2}{|c|}{Non-TL}                           & 0.0 \\ \hline
% % \end{tabular}


% \begin{tabular}{|c|c|}
% \hline
%  OTDD map ('optimal' weight) 
%   & 95.17 $\pm$ 20.57 \\ \hline
%    OTDD map (uniform weight)           & 82.61 $\pm$ 25.93 \\  \hline
%                        Mixup ('optimal' weight)      & 51.99  $\pm$ 34.10 \\ \hline 
%                      Mixup (uniform weight)         & 33.26 $\pm$ 21.30 \\ \hline 
%                      Caltech101  & 59.68 $\pm$ 41.44 \\ \hline 
%                       DTD   &-1.17 $\pm$ 9.52 \\ \hline 
%                      Flowers102    & -2.45 
%  $\pm$ 26.25 \\ \hline
% \end{tabular}
% }
% \end{table}



















\section{Conclusion and discussion}\label{sec:conclude}
The method we introduce in this work provides, as shown by our experimental results, a promising new approach to generate synthetic datasets as combinations of existing ones. Crucially, our method allows one to combine datasets even if their label sets are different, and is grounded on principled and well-understood concepts from optimal transport theory. Two key applications of this approach that we envision are: 
\begin{itemize}[leftmargin=*,noitemsep,topsep=0.5pt,parsep=0.5pt,partopsep=0.5pt]
  \item \textbf{Pretraining data enrichment}. Given a collection of classification datasets, generate additional interpolated datasets to increase diversity, with the aim of achieving better out-of-distribution generalization. This could be done even without knowledge of the specific target domain (as we do here) by selecting various datasets to play the role of the `reference' distribution.
  % for generalized geodesic.
  \item \textbf{On-demand optimized synthetic data generation}. Generate a synthetic dataset, by combining existing ones, that is `optimized' for transferring a model to a new (data-limited) target domain.
\end{itemize}



\paragraph{Complexity} 
% [\textbf{Our method}] 
The complexity of solving the
OTDD \underline{barycentric projection} with the Sinkhorn algorithm is $\cO(N^2)$~\citep{dvurechensky2018computational}, where $N$ is the number of samples in each dataset. This can be expensive for large-scale datasets. In practice, we solve a batched barycentric projection, i.e., we take a batch from
% source and target
both
datasets and solve the projection from source 
% batch
to the target batch; we typically fix the batch size to $B = 10^4$. This reduces the complexity from $\cO(N^2)$ to $\cO(BN)$.
The complexity of solving the \underline{OTDD neural map} is $\cO(B K H)$, where $K$ is the number of iterations and $H$ is the size of the network. We always choose $K = \cO(N)$ in our experiments.
The complexity of solving all the \underline{$(2,Q)$-dataset distances} in Eq.~\eqref{eq:ds_eq} is $\cO(m^2N)$, since we need to compute the dataset distance between each pair of training datasets.
% The complexity of solving the quadratic programming is 
\underline{Putting these pieces together}, the complexity of approximating the interpolation parameter $\hat{a}$ for the minimizer of \eqref{eq:ds_eq} is  $\cO(N(B + m^2 ))$. 
% \fan{TODO: need to explain this mixup is not traditional mixup somewhere...}


\paragraph{Memory} As the number of pre-training tasks ($m$) increases, the interpolated label generated by our method, which concatenates the labels from all tasks, becomes an increasingly long and sparse vector. Consequently, the memory footprint of the classifier's output layer, whose size grows with $m$, could rise significantly.

\paragraph{Barycentric projection vs.\ neural map}
These two versions of our method offer complementary advantages. While the neural map (an explicit estimate of the OT map) allows for easy out-of-sample mapping and continuous generation, the barycentric projection approach often yields better downstream performance (Table~\ref{tab:compare}). We hypothesize that this is because the barycentric projection relies on (re-weighted) \textit{real} data, while the neural map \textit{generates} data that might be noisy or imperfect.


\paragraph{Pixel space vs.\ feature space} We present results with OTDD mapping in both pixel space (\S\ref{sec:nist}) and feature space (\S\ref{sec:vtab}). For the VTAB datasets with regular-sized images (e.g., $256\times 256\times 3$), we found that feature space is more appropriate for measuring data distances. For small-scale images like NIST, feature space may be overkill because most foundation models are trained on larger images. In our preliminary experiments with NIST datasets, we attempted a feature-space approach using an off-the-shelf ResNet-18 model. However, we encountered challenges in achieving convergence when training OTDD neural maps with PyTorch ResNet-18 features.

\paragraph{High variance issue}
Our method is not limited to the data scarcity regime, but indeed this is the most interesting one from the transfer learning perspective, which is why we assume limited labeled data (but potentially much more unlabeled data) from the target domain distribution. This is a typical few-shot learning scenario.
The quality of a learned OT map will likely depend on the number of samples used to fit it, and the map might suffer from high variance. To mitigate this in our setting, we augment the target dataset with additional pseudo-labeled data generated via kNN (Fig.~\ref{fig:convex_comb}). Recall that we do have access to more unlabeled data from the target domain, which is a common situation in practice.



\paragraph{Limitations}
% During the training of OTDD neural map, we notice that an overfitted classifier on the target domain can cause the instability during training \eqref{eq:max-min}. The loss can diverge earlier for a more overfitted classifier.
% We notice that it is easier to learn the map from diverse dataset to less diverse dataset, e.g. EMNIST $\rightarrow$ USPS.
Our method for generating a synthetic dataset relies on solving OTDD maps from the test dataset to each training dataset.
These OTDD maps are tailored to the considered test dataset and
cannot be reused for a new test dataset. Another limitation is that our framework relies on a pre-training and fine-tuning pipeline, which can be resource-demanding for large-scale models such as GPT~\citep{brown2020language}.
\fan{Finally, if at least one of the datasets is imbalanced, our OTDD map will struggle to match classes with similar marginal distributions.}




% \begin{contributions} % will be removed in pdf for initial submission 
%   % (without ‘accepted’ option in \documentclass)
%   % so you can already fill it to test with the
%   % ‘accepted’ class option
%   Briefly list author contributions.
%   This is a nice way of making clear who did what and to give proper credit.
%   This section is optional.

%   H.~Q.~Bovik conceived the idea and wrote the paper.
%   Coauthor One created the code.
%   Coauthor Two created the figures.
% \end{contributions}

\begin{acknowledgements} % will be removed in pdf for initial submission,
  % (without ‘accepted’ option in \documentclass)
  % so you can already fill it to test with the
  % ‘accepted’ class option
  % Briefly acknowledge people and organizations here.
We thank Yongxin Chen and Nicolò Fusi for their invaluable comments, ideas, and feedback.
We extend our gratitude to the anonymous reviewers for their useful feedback that significantly improved this work.
\end{acknowledgements}

% \newpage
% References
\bibliography{fan_236}

% \newpage
% \appendix
% \onecolumn

\end{document}
