PLUMAGE: Probabilistic Low-Rank Unbiased Minimum-Variance Gradient Estimation Framework for Efficient Large Model Training
Keywords: Gradient Estimation, Low-Rank, Compression, Optimization
TL;DR: A k-sparse, unbiased, minimum-variance low-rank gradient estimator that fixes the projection misalignment of low-rank moments. It reduces the loss faster than the fixed top-k estimator used by Galore while matching its memory and compute footprint.
Abstract: Accelerator memory and network constraints are dominant bottlenecks when training large language models (LLMs) with billions of parameters. Low-rank gradient estimators, used by methods such as Galore and Flora, have enabled LLM training on consumer hardware by compressing gradients and optimizer tensors. However, the underlying gradient estimators are either biased or suffer from high variance. Moreover, low-rank optimizer states, such as the first and second moments expressed in the previous subspace, become misaligned whenever the projection is updated. This misalignment can cause instabilities when training with low-rank gradients.
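The misalignment can be illustrated with a small NumPy sketch. It assumes a Galore-style setup in which the first moment is stored as an $r \times n$ matrix in the coordinates of an orthonormal $m \times r$ projector $P$; the helpers `orthonormal_basis`, `reproject_moment`, and `misalignment_demo` are hypothetical names introduced for illustration. The change of basis in `reproject_moment` is a generic remedy for the first moment only and is not necessarily how Plumage resolves the misalignment.

```python
import numpy as np

def orthonormal_basis(m, r, rng):
    # Hypothetical stand-in for a projector: random m x r matrix with orthonormal columns.
    Q, _ = np.linalg.qr(rng.standard_normal((m, r)))
    return Q

def reproject_moment(M_old, P_old, P_new):
    # Generic change of basis (an assumption, not the paper's method):
    # lift the old r x n moment to full space with P_old, then project with P_new^T.
    return P_new.T @ (P_old @ M_old)

def misalignment_demo(m=8, n=4, r=2, seed=0):
    rng = np.random.default_rng(seed)
    P_old = orthonormal_basis(m, r, rng)
    P_new = orthonormal_basis(m, r, rng)
    G = rng.standard_normal((m, n))
    M_old = P_old.T @ G                     # low-rank first moment in old coordinates
    full_old = P_old @ M_old                # what that state represents in full space
    target = P_new @ (P_new.T @ full_old)   # best representation inside the new subspace
    stale = P_new @ M_old                   # naive reuse: old coordinates read in the new basis
    fixed = P_new @ reproject_moment(M_old, P_old, P_new)
    return np.linalg.norm(stale - target), np.linalg.norm(fixed - target)
```

Reusing `M_old` unchanged after the projector switches yields a nonzero `stale` error, while the re-projected moment matches the target. Adam's second moment is an elementwise average of squared coordinates, so it cannot be re-expressed in a rotated basis from its stored entries alone; this sketch does not cover it.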
We propose Plumage: Probabilistic Low‑rank Unbiased Minimum‑vAriance Gradient Estimator.
Plumage can be applied as a drop-in gradient estimator for low-rank LLM training without introducing new hyperparameters beyond the chosen rank $r$ and the update interval $\tau$. In addition, we resolve the misalignment of low-rank optimizer statistics.
We apply Plumage as a drop-in replacement for the fixed top-k component estimator used in Galore to study how the gradient bias-variance tradeoff affects the optimization of LLMs.
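For intuition on the bias-variance tradeoff, the following NumPy sketch contrasts a fixed top-k truncation of the gradient's SVD with a generic unbiased k-sparse alternative. Each singular component is kept with probability $p_i = \min(1, s_i/\lambda)$, with $\lambda$ chosen so that $\sum_i p_i = k$; this minimizes $\mathbb{E}\|\hat G - G\|_F^2 = \sum_i s_i^2 (1/p_i - 1)$. Exactly k components are then drawn by systematic sampling and reweighted by $1/p_i$ (Horvitz-Thompson). This construction is an assumption chosen for illustration, not a claim about Plumage's exact estimator, and all function names are hypothetical.

```python
import numpy as np

def topk_lowrank(G, k):
    # Fixed top-k: keep the k leading singular components (a biased estimate of G).
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    return (U[:, :k] * s[:k]) @ Vt[:k]

def min_variance_probs(s, k):
    # p_i = min(1, s_i / lam) with sum(p) = k; minimizes sum_i s_i^2 (1/p_i - 1).
    s = np.asarray(s, dtype=float)
    p = np.zeros_like(s)
    free = s > 0
    n_capped = 0
    while free.any():
        budget = k - n_capped
        p[free] = budget * s[free] / s[free].sum()
        capped = free & (p >= 1.0)
        if not capped.any():
            break
        p[capped] = 1.0          # large components are always kept
        free &= ~capped
        n_capped += int(capped.sum())
    return np.minimum(p, 1.0)

def systematic_sample(p, rng):
    # Systematic sampling: round(sum(p)) distinct indices, index i included w.p. p[i].
    c = np.concatenate(([0.0], np.cumsum(p)))
    u = rng.uniform() + np.arange(int(round(c[-1])))
    u = np.minimum(u, np.nextafter(c[-1], 0.0))  # guard against round-off at the end
    return np.searchsorted(c, u, side="right") - 1

def unbiased_ksparse(G, k, rng):
    # Unbiased k-sparse estimate: sampled components reweighted by 1/p_i, so E[G_hat] = G.
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    p = min_variance_probs(s, k)
    idx = systematic_sample(p, rng)
    w = s[idx] / p[idx]
    return (U[:, idx] * w) @ Vt[idx]
```

Averaging many draws of `unbiased_ksparse(G, k, rng)` recovers `G`, whereas `topk_lowrank(G, k)` always discards the tail components. The price of unbiasedness is the variance term above; because the rank-one SVD atoms are mutually orthogonal in the Frobenius inner product, that term does not depend on the joint inclusion probabilities of the sampling scheme.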
The resulting Plumage+Adam shrinks the gap to full-rank optimization in pre-training evaluation loss by 33% on average across models, with a 34% improvement on the commonsense benchmark for the 1B models. In fine-tuning, it shrinks the average training-loss gap across the GLUE benchmark by 28%, without retuning the full-rank learning rate and with a computational and memory footprint similar to Galore's.
In addition, in our 1B pre-training benchmark, Plumage+Adam surpassed Galore's terminal loss using 30% fewer steps.
Supplementary Material: zip
Primary Area: optimization
Submission Number: 2102