TL;DR: We propose UltraTWD, a novel framework for accurate tree-Wasserstein distance computation using optimized ultrametric trees, significantly improving performance in text-based applications.
Abstract: The Wasserstein distance is a widely used metric for measuring differences between distributions, but its super-cubic time complexity imposes a substantial computational burden. To mitigate this, the tree-Wasserstein distance (TWD) offers a linear-time approximation by leveraging a tree structure; however, existing TWD methods often compromise accuracy because of suboptimal tree structures and edge weights. To address this, we introduce UltraTWD, a novel unsupervised framework that simultaneously optimizes both ultrametric tree structures and edge weights to more faithfully approximate the cost matrix. Specifically, we develop algorithms based on minimum spanning trees, iterative projection, and gradient descent to efficiently learn high-quality ultrametric trees. Empirical results across document retrieval, ranking, and classification tasks demonstrate that UltraTWD achieves superior approximation accuracy and competitive downstream performance. Code is available at: https://github.com/NeXAIS/UltraTWD.
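To make the linear-time claim concrete, the following is a minimal Python sketch (Python is an assumption; the page names no language) of the standard closed-form TWD: the sum over tree edges of the edge weight times the absolute difference of the two distributions' masses in the subtree below that edge. The tree here is built by single linkage, i.e. Kruskal's minimum spanning tree over the cost matrix, which yields the subdominant ultrametric (the largest ultrametric that does not exceed the cost). This is only a baseline MST-based construction, not UltraTWD's optimized tree; the iterative-projection and gradient-descent steps are not detailed in the abstract and are not reproduced. The names `build_ultrametric_tree` and `tree_wasserstein` are hypothetical.

```python
from itertools import combinations


def build_ultrametric_tree(cost):
    """Hypothetical helper: single-linkage (Kruskal MST) ultrametric tree.

    Leaves are nodes 0..n-1; each merge appends an internal node whose height
    is half the merge distance, so leaf-to-leaf tree distance equals the
    subdominant ultrametric. Returns (parent, edge_weight); the root has parent -1.
    """
    n = len(cost)
    uf = list(range(n))  # union-find over leaves

    def find(x):
        while uf[x] != x:
            uf[x] = uf[uf[x]]  # path halving
            x = uf[x]
        return x

    node_of = list(range(n))  # tree node representing each union-find root
    height = [0.0] * n
    parent = [-1] * n
    edges = sorted((cost[i][j], i, j) for i, j in combinations(range(n), 2))
    for d, i, j in edges:
        ri, rj = find(i), find(j)
        if ri == rj:
            continue  # would close a cycle, so not an MST edge
        new = len(height)
        height.append(d / 2.0)
        parent.append(-1)
        parent[node_of[ri]] = new
        parent[node_of[rj]] = new
        uf[rj] = ri
        node_of[ri] = new
    # edge from v to its parent has length height[parent] - height[v] (>= 0)
    edge_weight = [height[parent[v]] - height[v] if parent[v] >= 0 else 0.0
                   for v in range(len(height))]
    return parent, edge_weight


def tree_wasserstein(mu, nu, parent, edge_weight):
    """Closed-form TWD: sum_e w_e * |mu(subtree_e) - nu(subtree_e)|, O(#nodes)."""
    diff = [0.0] * len(parent)
    for i, (a, b) in enumerate(zip(mu, nu)):
        diff[i] = a - b
    total = 0.0
    # children always have smaller indices than their parent, so increasing
    # index order accumulates each subtree's mass difference before it is used
    for v, p in enumerate(parent):
        if p >= 0:
            total += edge_weight[v] * abs(diff[v])
            diff[p] += diff[v]
    return total
```

For example, with `cost = [[0, 2, 4], [2, 0, 4], [4, 4, 0]]` (already an ultrametric), `tree_wasserstein([1, 0, 0], [0, 0, 1], *build_ultrametric_tree(cost))` returns `4.0`, matching the exact 1-Wasserstein distance. For a non-ultrametric cost, the single-linkage tree only lower-bounds the cost matrix, which is the kind of approximation gap that learning better trees and edge weights aims to reduce.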
Lay Summary: Understanding how similar two distributions are (for example, two documents) is a fundamental problem in machine learning. A powerful tool for this is the “Wasserstein distance,” which measures how much effort it takes to transform one distribution into another. However, calculating it exactly is often slow and impractical for large datasets.
Our method, called UltraTWD, approximates the Wasserstein distance much faster by computing it on a tree. It does this by learning a special type of tree structure, called an “ultrametric tree,” that simplifies the calculation while preserving neighbor relationships between data points. Unlike earlier methods, UltraTWD learns both the tree structure and the edge weights simultaneously, in a scalable and unsupervised way.
This makes our approach both accurate and computationally efficient, enabling its application to large-scale problems involving Wasserstein distances. By accelerating a core machine learning tool, UltraTWD paves the way for more practical and scalable AI systems.
Link To Code: https://github.com/NeXAIS/UltraTWD
Primary Area: General Machine Learning
Keywords: 1-Wasserstein Distance, Tree-Wasserstein Distance, Ultrametric Tree, Optimization
Submission Number: 7472