TL;DR: We propose a pure low-precision training framework for XMC models using BFloat16 and Float8, achieving significant GPU memory savings while remaining competitive with existing baselines on most public datasets despite the low-bitwidth constraint.
Abstract: Classification with large output spaces, also referred to as extreme multilabel classification (XMC), is a setting that arises, e.g., in large-scale tagging and product-to-product recommendation, and is characterized by label counts ranging from hundreds of thousands to millions. As a consequence, the linear classification head, usually only a tiny fraction of the overall model, becomes the main driver of compute and memory demand. Current state-of-the-art XMC methods predominantly rely on FP16-FP32 mixed-precision training, which we show can be unstable as well as inefficient in terms of memory usage and computational overhead. Meanwhile, existing low-precision methods typically retain higher precision for the classification layer. In this work, we propose ELMO, a pure low-precision training framework for XMC models using the BFloat16 and Float8 data types. By leveraging Kahan summation and stochastic rounding, we demonstrate that XMC models can be effectively trained entirely in Float8, without relying on single-precision master weights or tensor scaling. Low-precision training, combined with our proposed memory optimizations (gradient fusion and chunking), enables significant reductions in GPU memory usage. For example, we train a 3-million-label XMC model with only 6.6 GiB of GPU memory, compared to the 39.7 GiB required by the optimized SOTA method Renee, without compromising accuracy.
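To make the numerical ingredients concrete, below is a minimal PyTorch sketch of stochastic rounding and a Kahan-compensated optimizer step. This is our own illustration, not the ELMO implementation: it uses BFloat16 weights rather than Float8 (PyTorch's native Float8 arithmetic support is still limited), and the function names are hypothetical.

```python
import torch

def stochastic_round_to_bf16(x_fp32: torch.Tensor) -> torch.Tensor:
    """Stochastically round an FP32 tensor to BFloat16.

    BFloat16 keeps the upper 16 bits of the FP32 bit pattern, so adding a
    random integer in [0, 2^16) to the raw bits before truncating makes a
    value round up with probability proportional to its distance from the
    lower representable neighbour; the rounding bias cancels in expectation.
    (Edge cases such as inf/NaN inputs are ignored in this sketch.)
    """
    bits = x_fp32.contiguous().view(torch.int32)
    noise = torch.randint(0, 1 << 16, bits.shape,
                          dtype=torch.int32, device=x_fp32.device)
    rounded = (bits + noise) & ~0xFFFF          # zero the low 16 bits
    return rounded.view(torch.float32).to(torch.bfloat16)


@torch.no_grad()
def kahan_sgd_step(weight_bf16, grad_bf16, comp_bf16, lr=1e-3):
    """One SGD step on BF16 weights with a BF16 Kahan compensation buffer.

    `comp_bf16` carries the part of each update that is too small to change
    the stored low-precision weight, so tiny gradients are not lost step
    after step.
    """
    update = comp_bf16.float() - lr * grad_bf16.float()
    new_weight = stochastic_round_to_bf16(weight_bf16.float() + update)
    # Whatever was rounded away from this update is carried to the next step.
    comp_bf16.copy_((update - (new_weight.float() - weight_bf16.float())).to(torch.bfloat16))
    weight_bf16.copy_(new_weight)


# Example usage: one BF16 weight tensor with a zero-initialized compensation buffer.
W = torch.randn(1024, 768).to(torch.bfloat16)
C = torch.zeros_like(W)
g = torch.randn_like(W)
kahan_sgd_step(W, g, C, lr=1e-2)
```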
Lay Summary: Modern recommendation systems, like those used to tag content or suggest related products, often need to choose from millions of possible tags or products. This entails storing and crunching through billions of numbers in the last layer of multi-layered deep learning systems, creating a major computational and memory bottleneck. Current systems represent these billions of parameters as a mix of 32-bit and 16-bit numbers to speed up calculations, but they are still memory-intensive, sometimes unstable, and slow. We introduce a new method that uses a much coarser, lower-precision representation (16 and 8 bits) for these parameters, leading to substantial memory savings and faster computation, since less data has to be moved around.
To maintain both stability and accuracy at lower precision, we apply techniques that prevent small numbers from being lost when added to much larger ones (a method known as Kahan summation) and introduce randomness to remove systematic bias in rounding errors (known as stochastic rounding). We also combine operations to reduce intermediate read/write steps and break large label computations into smaller, manageable parts (known as chunking). These optimizations cut memory use down even further. With our approach, we reduced memory use for a 3-million-label model from 39.7 GB to just 6.6 GB. This also made it possible to train models with up to 8.6 million labels on a regular consumer GPU (an RTX 4060 Ti), something previously thought impractical.
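To illustrate the chunking idea, here is a simplified PyTorch sketch (our own illustration with hypothetical names and sizes, not the fused kernels used in ELMO): the classifier weight is sliced into label blocks, and the loss and gradients are computed block by block, so the huge batch-by-labels logit matrix never has to exist in memory at once.

```python
import torch
import torch.nn.functional as F

def chunked_head_backward(features, head_weight, targets, chunk_size=65536):
    """Binary cross-entropy over a huge label space, one label chunk at a time.

    `features` is treated as a detached leaf (the encoder output); each chunk
    builds its own small autograd graph, so the full (batch x num_labels)
    logit matrix is never materialized and peak activation memory scales with
    `chunk_size` rather than with the number of labels.  The full weight-gradient
    buffer is still allocated here; in ELMO, gradient fusion avoids that as well.
    """
    num_labels = head_weight.shape[0]
    total_loss = 0.0
    for start in range(0, num_labels, chunk_size):
        w = head_weight[start:start + chunk_size]            # (chunk, dim)
        logits = (features @ w.t()).float()                  # (batch, chunk)
        loss = F.binary_cross_entropy_with_logits(
            logits, targets[:, start:start + chunk_size].float(), reduction="sum")
        loss.backward()   # gradients accumulate into head_weight.grad / features.grad
        total_loss += loss.item()
    return total_loss / features.shape[0]


# Hypothetical usage: 16 documents, 768-dim features, 100k labels, all in BF16.
feats = torch.randn(16, 768, dtype=torch.bfloat16, requires_grad=True)
W = torch.zeros(100_000, 768, dtype=torch.bfloat16, requires_grad=True)
y = torch.zeros(16, 100_000, dtype=torch.bfloat16)
loss_value = chunked_head_backward(feats, W, y, chunk_size=16_384)
# feats.grad can now be pushed back through the encoder in a single call.
```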
Link To Code: https://github.com/xmc-aalto/elmo
Primary Area: General Machine Learning->Scalable Algorithms
Keywords: Extreme Multi-Label Classification, Low-Precision Training, Peak Memory Optimization, FLOAT8 training
Submission Number: 15923