Approximating 1-Wasserstein Distance with Trees

Published: 18 Sept 2022, Last Modified: 28 Feb 2023. Accepted by TMLR.
Abstract: The Wasserstein distance, which measures the discrepancy between distributions, has proven effective in various natural language processing and computer vision applications. One challenge in estimating the Wasserstein distance is that it is computationally expensive and scales poorly to tasks that require many distribution comparisons. In this study, we aim to approximate the 1-Wasserstein distance by the tree-Wasserstein distance (TWD), which is the 1-Wasserstein distance with a tree-based embedding and can be computed in linear time with respect to the number of nodes in the tree. More specifically, we propose a simple yet efficient L1-regularized approach for learning the edge weights of a tree. To this end, we first show that the 1-Wasserstein approximation problem can be formulated as a distance approximation problem using the shortest path distance on a tree. We then show that the shortest path distance can be represented by a linear model, which leads to a Lasso-based regression problem. Owing to the convex formulation, we can efficiently obtain a globally optimal solution. We also propose a tree-sliced variant of these methods. Through experiments, we demonstrate that, with the proposed weight estimation technique, the TWD can accurately approximate the original 1-Wasserstein distance. Our code can be found in the GitHub repository.
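The linear-time claim follows from a well-known closed form: with a tree metric as the ground cost, the 1-Wasserstein distance equals the sum, over edges, of the edge weight times the absolute difference between the masses the two distributions place in the subtree below that edge. The sketch below (Python with NumPy) evaluates this in a single bottom-up pass. It also shows one way to write the shortest path distance as a linear model in the edge weights, $d_T(i, j) = \mathbf{w}^\top(\mathbf{b}_i + \mathbf{b}_j - 2\,\mathbf{b}_i \circ \mathbf{b}_j)$, where $\mathbf{b}_i$ indicates the nodes on the root-to-$i$ path. The tree encoding (a `parent` array in which every parent index is smaller than its children's), the function names, and the toy tree are illustrative assumptions, not the authors' code.

```python
import numpy as np


def tree_wasserstein(parent, weight, mu, nu):
    """Tree-Wasserstein distance: sum_v weight[v] * |mu(subtree v) - nu(subtree v)|.

    parent[v] is the parent of node v (-1 for the root, node 0), and every
    parent index is smaller than the indices of its children. weight[v] is the
    length of the edge (parent[v], v); weight[0] is unused.
    """
    # diff[v] accumulates the net mass difference inside the subtree of v.
    diff = np.asarray(mu, dtype=float) - np.asarray(nu, dtype=float)
    total = 0.0
    # Reverse index order visits children before parents: O(N) overall.
    for v in range(len(parent) - 1, 0, -1):
        total += weight[v] * abs(diff[v])
        diff[parent[v]] += diff[v]
    return total


def path_vector(parent, node):
    """Indicator vector b of the nodes on the path from the root to `node`."""
    b = np.zeros(len(parent))
    while node != -1:
        b[node] = 1.0
        node = parent[node]
    return b


def tree_distance(parent, weight, i, j):
    """Shortest path distance as a linear model: w^T (b_i + b_j - 2 b_i * b_j)."""
    bi, bj = path_vector(parent, i), path_vector(parent, j)
    return float(np.dot(weight, bi + bj - 2.0 * bi * bj))


# Toy tree (assumed): root 0 -> {1, 2}, 1 -> {3, 4}, 2 -> {5, 6}; leaves 3..6.
parent = [-1, 0, 0, 1, 1, 2, 2]
weight = np.array([0.0, 1.0, 2.0, 0.5, 0.5, 1.5, 1.5])
mu = np.array([0, 0, 0, 0.5, 0.0, 0.5, 0])  # half the mass on leaf 3, half on leaf 5
nu = np.array([0, 0, 0, 0.0, 1.0, 0.0, 0])  # all the mass on leaf 4
print(tree_wasserstein(parent, weight, mu, nu))  # 3.0
print(tree_distance(parent, weight, 3, 5))       # 5.0
```

Each node is visited once, so the cost is linear in the number of nodes. With Dirac masses on leaves $i$ and $j$, `tree_wasserstein` reduces to `tree_distance`, which is why approximating the 1-Wasserstein distance can be cast as approximating distances with the shortest path distance on a tree.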
Submission Length: Regular submission (no more than 12 pages of main content)
Changes Since Last Submission: Dear Action Editor,

Thank you again for handling our paper. We are pleased with the decision. Based on the comments, we have updated the paper as follows.

> discuss the time complexity of all sub-routines. Stating that FISTA has an "efficient solver" is not clear.

We added an explanation of FISTA and its computational complexity, as follows. For weight estimation, we employed the fast iterative shrinkage-thresholding algorithm (FISTA) \citep{beck2009fast}, which combines the iterative shrinkage-thresholding algorithm (ISTA) with Nesterov's accelerated gradient method \citep{nesterov1983method}. FISTA converges to the optimal objective value at a rate of $O(1/k^2)$, where $k$ is the number of iterations. For example, when FISTA is applied to the Lasso problem $\|\mathbf{y} - \mathbf{A}\mathbf{w}\|_2^2 + \lambda \|\mathbf{w}\|_1$, the dominant computation is the matrix-vector multiplication with $\mathbf{A}$ and $\mathbf{A}^\top$. In our case, $\mathbf{A}$ is of size $|\Omega| \times N$, so each iteration requires on the order of $|\Omega| N$ multiplications and additions. In practice, the cost of these matrix-vector multiplications is small.

> discuss the space complexity of the method and its competitors (e.g. Sinkhorn).

We added the space complexity of the Sinkhorn algorithm, as well as that of the proposed method when $\mathbf{B}$ is stored as a sparse matrix. The Sinkhorn algorithm needs to store the $N_\text{leaf} \times N_\text{leaf}$ cost matrix, whereas tree methods need to store the $N \times N_\text{leaf}$ matrix $\mathbf{B}$, where $N$ ($N > N_\text{leaf}$) is the total number of nodes. Thus, if the matrices are stored in a dense format, the tree methods require slightly more memory than the Sinkhorn algorithm. When $\mathbf{B}$ is stored as a sparse matrix, with ClusterTree its size is 94.4 KB, 226.2 KB, and 681.1 KB for the Twitter, BBCsport, and Amazon datasets, respectively.
> report error bars for metrics in Table 1 and 2

We added the standard deviations. During the revision, we found that some values in Table 2 of the original manuscript were incorrect because of a mistake in copying the experimental results. We have corrected these numbers in Table 2 (Table 4 in the camera-ready version). For example, the QuadTree result for the Twitter dataset has been updated from 0.691 to 0.701, and the WMD (Sinkhorn) result for the Amazon dataset has been updated from 0.905 to 0.903. These corrections do not change our conclusions. However, since they involve changes to reported numbers, we would appreciate it if the Action Editor could check the results.

> increase font sizes of axis labels in plots (e.g. Fig3)

We enlarged the font sizes in the figures.

> draw tables without vertical bars and boxes (one can use the booktabs package, see eg, [1])

Thank you for the information about booktabs. We updated all the tables using the booktabs package.

In addition to the above modifications, we made the following edits.
- We proofread the paper again and improved the English.
- We made the source code available in the GitHub repository.
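The per-iteration cost described in the response can be illustrated with a minimal FISTA sketch (Python with NumPy) for the Lasso problem $\|\mathbf{y} - \mathbf{A}\mathbf{w}\|_2^2 + \lambda\|\mathbf{w}\|_1$. Here each row of $\mathbf{A}$ encodes the shortest path between one pair of leaves and $\mathbf{y}$ holds the target distances. The nonnegativity constraint on $\mathbf{w}$ (natural for edge lengths), the constant step size $1/L$ with $L = 2\sigma_{\max}(\mathbf{A})^2$, the iteration count, the function names, and the toy tree are our assumptions for illustration, not the authors' implementation.

```python
import itertools

import numpy as np


def fista_nonneg_lasso(A, y, lam, n_iter=2000):
    """Minimize ||y - A w||_2^2 + lam * ||w||_1 subject to w >= 0 using FISTA."""
    # Lipschitz constant of the gradient 2 A^T (A w - y): L = 2 * sigma_max(A)^2.
    L = 2.0 * np.linalg.norm(A, 2) ** 2
    w = np.zeros(A.shape[1])
    z = w.copy()  # extrapolated (momentum) point
    t = 1.0
    for _ in range(n_iter):
        # Dominant cost: one product with A and one with A^T, O(|Omega| N) if dense.
        grad = 2.0 * (A.T @ (A @ z - y))
        # Proximal step for lam * ||w||_1 restricted to w >= 0.
        w_next = np.maximum(z - (grad + lam) / L, 0.0)
        # Nesterov momentum, which yields the O(1/k^2) convergence rate.
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        z = w_next + ((t - 1.0) / t_next) * (w_next - w)
        w, t = w_next, t_next
    return w


def path_vector(parent, node):
    """Indicator of the nodes on the path from the root (node 0) to `node`."""
    b = np.zeros(len(parent))
    while node != -1:
        b[node] = 1.0
        node = parent[node]
    return b


# Toy tree (assumed): root 0 -> {1, 2, 3}, node 1 -> {4, 5}; leaves 2, 3, 4, 5.
parent = [-1, 0, 0, 0, 1, 1]
w_true = np.array([0.0, 1.0, 2.0, 1.0, 0.5, 1.5])
pairs = list(itertools.combinations([2, 3, 4, 5], 2))
# Row for pair (i, j) is b_i + b_j - 2 b_i * b_j, so (A @ w)[row] = d_T(i, j).
rows = []
for i, j in pairs:
    bi, bj = path_vector(parent, i), path_vector(parent, j)
    rows.append(bi + bj - 2.0 * bi * bj)
A = np.array(rows)
y = A @ w_true  # synthetic targets standing in for ground-truth distances
w_hat = fista_nonneg_lasso(A, y, lam=1e-3)
print(np.round(w_hat, 3))
```

The dense `A` keeps the sketch short. Storing `A` (or $\mathbf{B}$) as a SciPy sparse matrix reduces each product to the number of nonzeros, and `np.linalg.norm(A, 2)` (a one-off SVD) could be replaced by a power-iteration estimate of $\sigma_{\max}(\mathbf{A})$ for large problems. On this toy tree, `w_hat` should be close to `w_true`, up to a small bias introduced by the L1 penalty.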
Assigned Action Editor: ~antonio_vergari2
License: Creative Commons Attribution 4.0 International (CC BY 4.0)
Submission Number: 207