Navigating Extremes: Dynamic Sparsity in Large Output Spaces

Published: 25 Sept 2024, Last Modified: 06 Nov 2024, NeurIPS 2024 poster, License: CC BY-NC 4.0
Keywords: Dynamic sparse training, extreme classification, memory efficient training, large output spaces, scalable machine learning
TL;DR: Investigates Dynamic Sparse Training for large output spaces. By leveraging semi-structured sparsity, intermediate layers, and an auxiliary loss, it enables end-to-end training with millions of labels on commodity hardware at near-dense performance.
Abstract: In recent years, Dynamic Sparse Training (DST) has emerged as an alternative to post-training pruning for generating efficient models. In principle, DST allows for a much more memory-efficient training process, as it maintains sparsity throughout the entire training run. However, current DST implementations fail to capitalize on this: because sparse matrix multiplication is much less efficient than dense matrix multiplication on GPUs, most implementations simulate sparsity by masking the weights of a dense matrix, which still has to be stored in full. In this paper, we leverage recent advances in semi-structured sparse training to apply DST in the domain of classification with large output spaces, where memory efficiency is paramount. With a label space of possibly millions of candidates, the classification layer alone can consume several gigabytes of memory. However, switching from a dense classification layer to a fixed fan-in sparse layer updated with sparse evolutionary training (SET) severely hampers training convergence, especially for the largest label spaces. We find that the gradients fed back from the classifier into the text encoder make it much more difficult to learn good input representations, even though the encoder itself remains dense. By employing an intermediate layer or adding an auxiliary training objective, we recover most of the generalisation performance of the dense model. Overall, we demonstrate the applicability of DST in a challenging domain, characterized by a highly skewed label distribution, that lies outside of DST's typical benchmark datasets, and we enable end-to-end training with millions of labels on commodity hardware.
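The memory argument and the fixed fan-in/SET combination in the abstract can be illustrated with a small pure-Python sketch. It assumes 4-byte weights and 4-byte (int32) indices, and the sizes in the usage lines (3 million labels, a 768-dimensional encoder output, fan-in 64) are illustrative choices, not the paper's configuration. `classifier_memory_gib` and `FixedFanInLayer` are hypothetical names. The update rule follows the general SET idea (prune the smallest-magnitude weights, regrow the same number at random positions) rather than the authors' implementation, and the sketch leaves out GPU kernels, the intermediate layer, and the auxiliary objective entirely.

```python
import random


def classifier_memory_gib(num_labels, dim, fan_in=None, value_bytes=4, index_bytes=4):
    """Approximate weight storage in GiB for the final classification layer only.

    fan_in=None: dense [num_labels x dim] matrix (a masked, "simulated sparse"
    layer also stores this full matrix, plus its mask).
    fan_in=k: fixed fan-in layer storing k values and k input indices per label.
    Gradients, optimiser states and activations are ignored.
    """
    if fan_in is None:
        n_bytes = num_labels * dim * value_bytes
    else:
        n_bytes = num_labels * fan_in * (value_bytes + index_bytes)
    return n_bytes / 2**30


class FixedFanInLayer:
    """Hypothetical sparse classifier: every label row has exactly `fan_in` inputs."""

    def __init__(self, in_dim, out_dim, fan_in, seed=0):
        assert 0 < fan_in <= in_dim
        self.in_dim, self.fan_in = in_dim, fan_in
        self.rng = random.Random(seed)
        # Row-wise storage: indices[l][j] is the input feature feeding label l.
        self.indices = [self.rng.sample(range(in_dim), fan_in) for _ in range(out_dim)]
        self.values = [[self.rng.gauss(0.0, 0.02) for _ in range(fan_in)]
                       for _ in range(out_dim)]

    def forward(self, x):
        # logit[l] = sum_j values[l][j] * x[indices[l][j]]
        return [sum(w * x[i] for w, i in zip(ws, idx))
                for ws, idx in zip(self.values, self.indices)]

    def set_step(self, prune_frac=0.3):
        """One SET-style update: drop the smallest-magnitude weights in each row,
        then regrow the same number at random unused inputs (fan-in stays fixed)."""
        n_swap = int(prune_frac * self.fan_in)
        for row, (idx, ws) in enumerate(zip(self.indices, self.values)):
            order = sorted(range(self.fan_in), key=lambda j: abs(ws[j]))
            keep = order[n_swap:]  # the largest-magnitude connections survive
            new_idx = [idx[j] for j in keep]
            new_ws = [ws[j] for j in keep]
            used = set(new_idx)
            free = [i for i in range(self.in_dim) if i not in used]
            new_idx += self.rng.sample(free, n_swap)  # random regrowth
            new_ws += [0.0] * n_swap  # regrown weights start at zero in this sketch
            self.indices[row], self.values[row] = new_idx, new_ws


if __name__ == "__main__":
    print(round(classifier_memory_gib(3_000_000, 768), 2))             # 8.58
    print(round(classifier_memory_gib(3_000_000, 768, fan_in=64), 2))  # 1.43
    layer = FixedFanInLayer(in_dim=16, out_dim=5, fan_in=4)
    layer.set_step()
    print([len(r) for r in layer.indices])  # [4, 4, 4, 4, 4]
```

Because every label row holds the same number of connections, the weights fit into two regular `[labels x fan_in]` arrays (values and indices) instead of a general sparse format. A layout this regular is generally easier to map to efficient GPU kernels than unstructured sparsity, whereas a masked dense layer saves no memory over the dense baseline.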
Primary Area: Optimization for deep networks
Submission Number: 19358