Unsupervised Ground Metric Learning with Tree Wasserstein Distance

Published: 17 Jun 2024, Last Modified: 12 Jul 2024, ICML 2024 Workshop GRaM, License: CC BY 4.0
Track: Extended abstract
Keywords: optimal transport, tree-embeddings, unsupervised learning, metric learning, distance-based learning
TL;DR: Unsupervised learning of ground metrics using the tree Wasserstein distance provides a geometrically motivated, computationally efficient, low-rank approximation of the full Wasserstein distance matrix.
Abstract: Optimal transport (OT) is a powerful geometric machine learning tool for computing distances between samples. Accurate OT distances rely on the underlying distance between dataset features, known as the ground metric. Ground metrics are commonly chosen heuristically or learned with supervised algorithms. However, since many interesting datasets are unlabelled, unsupervised ground metric learning approaches have recently been introduced. One promising option employs Wasserstein singular vectors (WSV), which emerge when computing OT distances between features and samples simultaneously. WSV is effective but computationally expensive ($\mathcal{O}(n^5)$ complexity). Here, we propose to augment this method by embedding samples and features on trees, on which we compute the tree Wasserstein distance (TWD). We demonstrate theoretically and in practice that the algorithm converges to a better approximation of the full WSV approach than entropy regularisation does, at a lower (cubic) computational cost. In addition, we show that the initial tree structure can be chosen flexibly, since tree geometry does not constrain the solution up to the number of edge weights. These results position unsupervised ground metric learning with TWD as a low-rank approximation of WSV with the potential for widespread low-compute application.
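The efficiency argument rests on the tree Wasserstein distance having a closed form: on a tree with edge weights w_v, TWD(mu, nu) is the sum over non-root nodes v of w_v * |mu(T_v) - nu(T_v)|, where T_v is the subtree rooted at v. The sketch below illustrates only that closed form, not the authors' WSV-with-TWD algorithm. It assumes a tree given as a parent array with root 0, a nonnegative weight on the edge from each node to its parent (weights[0] is ignored), and mass vectors over tree nodes with equal total mass. The function names `subtree_matrix`, `tree_wasserstein` and `pairwise_twd` are hypothetical.

```python
import numpy as np

def subtree_matrix(parent):
    """Hypothetical helper: B[v, u] = 1 if node u lies in the subtree rooted at v.

    Assumes node 0 is the root and every parent chain eventually reaches 0.
    """
    m = len(parent)
    B = np.zeros((m, m))
    for u in range(m):
        v = u
        while True:
            B[v, u] = 1.0  # u is in the subtree of each of its ancestors (and itself)
            if v == 0:
                break
            v = parent[v]
    return B

def tree_wasserstein(mu, nu, parent, weights):
    """TWD = sum over non-root v of weights[v] * |mu(T_v) - nu(T_v)| (mu, nu: equal total mass)."""
    B = subtree_matrix(parent)
    diff = B @ (np.asarray(mu, dtype=float) - np.asarray(nu, dtype=float))  # subtree mass differences
    w = np.asarray(weights, dtype=float)
    return float(np.sum(w[1:] * np.abs(diff[1:])))  # root has no parent edge, so skip it

def pairwise_twd(X, parent, weights):
    """Hypothetical helper: TWD between every pair of rows of X (each row a mass vector over nodes)."""
    B = subtree_matrix(parent)
    w = np.asarray(weights, dtype=float).copy()
    w[0] = 0.0  # the root has no parent edge
    Z = (np.asarray(X, dtype=float) @ B.T) * w  # weighted subtree masses, one row per sample
    # TWD is then an L1 distance between the rows of Z
    return np.abs(Z[:, None, :] - Z[None, :, :]).sum(axis=-1)

# Toy tree: root 0 with children 1 (edge weight 1) and 2 (edge weight 2).
parent = [-1, 0, 0]
weights = [0.0, 1.0, 2.0]
print(tree_wasserstein([0, 1, 0], [0, 0, 1], parent, weights))  # moving unit mass 1 -> 0 -> 2 costs 3.0
X = np.array([[0, 1, 0], [0, 0, 1], [0, 0.5, 0.5]])
print(pairwise_twd(X, parent, weights))
```

Because each distance reduces to a weighted L1 norm of subtree masses, no linear program has to be solved per pair. The dense subtree matrix is used here only for readability; a practical version would accumulate subtree masses in a single pass over the tree.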
Submission Number: 34