Keywords: Optimal Transport, Wasserstein distance, Sliced Wasserstein distance, Regression
TL;DR: We propose a fast method for estimating Wasserstein distances across multiple pairs of distributions by formulating a regression problem, using variants of sliced Wasserstein distances as predictors.
Abstract: We address the problem of efficiently computing Wasserstein distances for multiple pairs of distributions drawn from a meta-distribution. To this end, we propose RG, a fast estimation method based on regressing the Wasserstein distance on sliced Wasserstein (SW) distances. Specifically, we use both standard SW distances, which provide lower bounds, and lifted SW distances, which provide upper bounds, as predictors of the true Wasserstein distance. To keep the approach parsimonious, we introduce two linear models: an unconstrained model with a closed-form least-squares solution, and a constrained model that uses only half as many parameters. We show that accurate models can be learned from a small number of distribution pairs. Once fitted, the model predicts the Wasserstein distance for any pair of distributions as a linear combination of SW distances, which makes prediction highly efficient. Empirically, we validate our approach on diverse tasks, including Gaussian mixtures, point-cloud classification, and Wasserstein-space visualizations of 3D point clouds. Across datasets such as MNIST point clouds, ShapeNetV2, MERFISH Cell Niches, and scRNA-seq, our method consistently approximates the Wasserstein distance better than the state-of-the-art method, Wasserstein Wormhole, and classical methods, particularly in low-data regimes. To illustrate its robustness, we also evaluate the method in intra- and inter-class settings. Finally, we demonstrate that RG can accelerate Wasserstein Wormhole training, yielding RG-Wormhole.
Primary Area: other topics in machine learning (i.e., none of the above)
Submission Number: 14821
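As a rough illustration of the regression idea in the abstract, here is a minimal sketch under stated assumptions. It assumes equal-size uniform point clouds and the 2-Wasserstein distance, so the exact distance reduces to an assignment problem. The lower-bound predictor is a Monte Carlo SW distance. The abstract does not define lifted SW, so the upper-bound predictor here is an assumed stand-in: lift each 1D sorted matching back to R^d as a coupling and keep the cheapest one. Only the unconstrained least-squares model is shown, without an intercept. The helper names (`wasserstein2_exact`, `sliced_bounds`, `random_pair`) and the toy Gaussian data are hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def wasserstein2_exact(x, y):
    """Exact W2 between two uniform point clouds of equal size (an assignment problem)."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    rows, cols = linear_sum_assignment(cost)
    return float(np.sqrt(cost[rows, cols].mean()))


def sliced_bounds(x, y, n_proj=50, seed=0):
    """Return (lower, upper): Monte Carlo SW2 and an assumed projection-based upper bound on W2."""
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=(n_proj, x.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # unit directions
    px, py = x @ theta.T, y @ theta.T                        # (n, n_proj) projections
    ix, iy = np.argsort(px, axis=0), np.argsort(py, axis=0)
    # 1D optimal transport matches sorted projections; averaging W2^2 over directions gives SW2^2.
    sw2 = np.mean((np.take_along_axis(px, ix, 0) - np.take_along_axis(py, iy, 0)) ** 2)
    # Each 1D matching is a permutation, i.e. a feasible coupling in R^d, so its full cost
    # upper-bounds W2^2 (stand-in for lifted SW); keep the cheapest direction.
    ub2 = min(((x[ix[:, k]] - y[iy[:, k]]) ** 2).sum(-1).mean() for k in range(n_proj))
    return float(np.sqrt(sw2)), float(np.sqrt(ub2))


def random_pair(rng, n=64, d=3):
    """Hypothetical toy data: two Gaussian clouds with random shift and scale."""
    x = rng.normal(size=(n, d))
    y = rng.uniform(0.5, 2.0) * rng.normal(size=(n, d)) + rng.uniform(0.0, 2.0)
    return x, y


rng = np.random.default_rng(0)
train_pairs = [random_pair(rng) for _ in range(20)]                # small training set
X_train = np.array([sliced_bounds(x, y) for x, y in train_pairs])  # columns: lower, upper
y_train = np.array([wasserstein2_exact(x, y) for x, y in train_pairs])

# Unconstrained linear model W ~ a * lower + b * upper, fit by closed-form least squares.
coef, *_ = np.linalg.lstsq(X_train, y_train, rcond=None)

# Predicting for new pairs needs only the cheap sliced quantities, with no OT solve.
test_pairs = [random_pair(rng) for _ in range(5)]
X_test = np.array([sliced_bounds(x, y) for x, y in test_pairs])
w_pred = X_test @ coef
```

The exact solver is called only on the training pairs. This reflects the low-data setting described above: a handful of exact distances calibrates the model, and every later prediction costs only a few projections and sorts.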