Addax: Utilizing Zeroth-Order Gradients to Improve Memory Efficiency and Performance of SGD for Fine-Tuning Language Models

Published: 10 Oct 2024, Last Modified: 28 Nov 2024, Venue: FITML 2024 Poster, License: CC BY 4.0
Keywords: Large Language Models, Memory-Efficient Fine-Tuning, Zeroth-Order Optimization
Abstract: Fine-tuning language models (LMs) with the standard Adam optimizer often demands excessive memory, limiting accessibility. The "in-place" version of Stochastic Gradient Descent (IP-SGD) and the Memory-Efficient Zeroth-order Optimizer (MeZO) have been proposed as solutions to improve memory efficiency. However, IP-SGD still requires a substantial amount of memory, and MeZO suffers from slow convergence and degraded final performance due to its zeroth-order nature. This paper introduces Addax, a novel method that improves both the memory efficiency and the algorithmic performance of IP-SGD by integrating it with MeZO. Specifically, Addax computes either the zeroth-order or the first-order gradient of each data point in the minibatch based on its memory consumption, and combines the zeroth- and first-order gradient estimates to obtain the update direction in each step. By computing the zeroth-order gradient of data points that require more memory and the first-order gradient of those that require less memory, Addax overcomes the slow convergence of MeZO and the excessive memory requirement of IP-SGD. Additionally, the zeroth-order gradient acts as a regularizer for the first-order gradient, further enhancing the model's final performance. Theoretically, we establish the convergence of Addax under mild assumptions, demonstrating faster convergence and less restrictive hyper-parameter choices than MeZO. Our extensive experiments with diverse LMs and tasks show that Addax consistently outperforms MeZO in terms of accuracy and convergence speed while having a comparable memory footprint. In particular, our experiments using one A100 GPU on the OPT-13B model reveal that, on average, Addax outperforms MeZO in terms of accuracy/F1 score by 14% and runs 15× faster with a comparable memory footprint. In our experiments on the larger OPT-30B model, on average, Addax outperforms MeZO in terms of accuracy/F1 score by more than 16% and runs 30× faster on a single H100 GPU. Moreover, Addax surpasses the performance of standard fine-tuning approaches, such as IP-SGD and Adam, in most tasks in terms of accuracy/F1 score while requiring significantly less memory.
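The idea of routing data points to a zeroth- or first-order gradient by memory cost can be illustrated with a minimal sketch. This is not the paper's implementation: the abstract does not state the exact update rule, so the convex combination with mixing weight `alpha`, the fixed length `threshold`, the helper names, and the toy linear-regression problem (standing in for a language model) are all assumptions made for illustration. The zeroth-order estimate uses an SPSA-style two-point difference, the kind of estimator MeZO is built on.

```python
import random

def predict(w, x):
    """Linear model output for one input vector."""
    return sum(wi * xi for wi, xi in zip(w, x))

def loss(w, batch):
    """Mean squared error over (x, y, length) triples; `length` is ignored here."""
    return sum((predict(w, x) - y) ** 2 for x, y, _ in batch) / len(batch)

def first_order_grad(w, batch):
    """Exact gradient of the mean squared error (stands in for backpropagation)."""
    g = [0.0] * len(w)
    for x, y, _ in batch:
        err = predict(w, x) - y
        for j, xj in enumerate(x):
            g[j] += 2.0 * err * xj / len(batch)
    return g

def zeroth_order_grad(w, batch, rng, eps=1e-3):
    """SPSA-style estimate: two loss evaluations along a random direction z."""
    z = [rng.gauss(0.0, 1.0) for _ in w]
    w_plus = [wi + eps * zi for wi, zi in zip(w, z)]
    w_minus = [wi - eps * zi for wi, zi in zip(w, z)]
    scale = (loss(w_plus, batch) - loss(w_minus, batch)) / (2.0 * eps)
    return [scale * zi for zi in z]

def mixed_step(w, batch, threshold, lr, alpha, rng):
    """One update. Assumption: points with length <= threshold get first-order
    gradients, longer ones get zeroth-order estimates, and the two are mixed
    with a hypothetical weight alpha."""
    short = [p for p in batch if p[2] <= threshold]
    long_ = [p for p in batch if p[2] > threshold]
    g1 = first_order_grad(w, short) if short else [0.0] * len(w)
    g0 = zeroth_order_grad(w, long_, rng) if long_ else [0.0] * len(w)
    return [wi - lr * ((1.0 - alpha) * a + alpha * b)
            for wi, a, b in zip(w, g1, g0)]

def run_demo(seed=0, steps=500):
    """Fit a toy regression problem; returns (initial_loss, final_loss)."""
    rng = random.Random(seed)
    w_true = [1.0, -2.0, 0.5]
    data = []
    for _ in range(64):
        x = [rng.gauss(0.0, 1.0) for _ in w_true]
        # The third field is a made-up "sequence length" used only for routing.
        data.append((x, predict(w_true, x), rng.randint(1, 100)))
    w = [0.0] * len(w_true)
    initial = loss(w, data)
    for _ in range(steps):
        batch = rng.sample(data, 8)
        w = mixed_step(w, batch, threshold=50, lr=0.02, alpha=0.2, rng=rng)
    return initial, loss(w, data)

initial_loss, final_loss = run_demo()
print(initial_loss, final_loss)
```

In this sketch the memory argument is only implicit: the zeroth-order branch needs loss evaluations alone, so in a real model it would avoid storing activations for the longest sequences, while the first-order branch is reserved for inputs whose backward pass is cheap.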
Submission Number: 32