Keywords: large language models, memory-efficient fine-tuning, block coordinate descent, importance sampling
Abstract: The substantial memory demands of pre-training and fine-tuning large language models (LLMs) call for memory-efficient optimization algorithms. One promising approach is layer-wise optimization. It treats each transformer block as a single layer and updates the layers one at a time, freezing the others to save optimizer-state and activation memory. Although effective, these methods ignore the varying importance of the modules within each layer, which leads to suboptimal performance. Moreover, layer-wise sampling provides only limited memory savings, because at least one full layer must remain active during optimization. To overcome these limitations, we propose **M**odule-wise **I**mportance **SA**mpling (**MISA**), a novel method that divides each layer into smaller modules and assigns an importance score to each module.
MISA uses a weighted random sampling mechanism to activate modules, provably reducing gradient variance compared to layer-wise sampling.
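The following minimal Python sketch shows why weighted sampling can lower gradient variance. It illustrates generic importance sampling for block coordinate updates and is not MISA's algorithm. It makes three assumptions: one module is sampled per step, the sampled module's gradient is rescaled by 1/p_i so the estimate stays unbiased, and hypothetical per-module gradient norms serve as the importance scores. The helper names (`importance_probs`, `estimator_variance`, `sample_and_scale`) are illustrative only.

```python
import numpy as np


def importance_probs(scores):
    """Normalize nonnegative importance scores into sampling probabilities."""
    scores = np.asarray(scores, dtype=float)
    return scores / scores.sum()


def estimator_variance(block_norms_sq, probs):
    """Variance of the unbiased single-block estimator g_hat = g_i / p_i.

    Blocks are disjoint, so ||g||^2 = sum_i ||g_i||^2 and
    E||g_hat||^2 = sum_i ||g_i||^2 / p_i, giving
    Var = sum_i ||g_i||^2 / p_i - sum_i ||g_i||^2.
    """
    block_norms_sq = np.asarray(block_norms_sq, dtype=float)
    probs = np.asarray(probs, dtype=float)
    return float(np.sum(block_norms_sq / probs) - np.sum(block_norms_sq))


def sample_and_scale(block_grads, probs, rng):
    """Pick one block with probability p_i and rescale its gradient by 1/p_i."""
    i = int(rng.choice(len(block_grads), p=probs))
    return i, block_grads[i] / probs[i]


# Hypothetical per-module gradient norms, also used as importance scores here.
norms = np.array([4.0, 2.0, 1.0, 1.0])
uniform = np.full(len(norms), 1.0 / len(norms))
weighted = importance_probs(norms)

print(estimator_variance(norms**2, uniform))   # 66.0
print(estimator_variance(norms**2, weighted))  # 42.0

rng = np.random.default_rng(0)
grads = [np.array([n]) for n in norms]  # one-parameter "modules" for the demo
idx, scaled = sample_and_scale(grads, weighted, rng)
```

With these hypothetical norms, the variance falls from 66 under uniform sampling to 42 under norm-proportional sampling. More generally, minimizing the sum of ||g_i||^2 / p_i subject to the p_i summing to 1 gives p_i proportional to ||g_i||, so norm-proportional probabilities are the variance-optimal choice for this estimator.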
Additionally, we establish an $\mathcal{O}(1/\sqrt{K})$ convergence rate in the non-convex stochastic setting, where $K$ is the total number of training steps, and provide a detailed memory analysis showing MISA's advantage over existing baseline methods. Experiments on diverse learning tasks validate the effectiveness of MISA.
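Rough arithmetic on optimizer state makes the memory argument concrete. The sketch below rests on several assumptions. It uses Adam with two fp32 moment tensors (8 bytes per trainable parameter) and a hypothetical layer made of four d×d attention projections plus a two-matrix MLP of width 4d, with d = 4096. It counts optimizer state only and ignores activations, gradients, and weights. The module names and helpers are illustrative and do not come from the paper.

```python
# Assumption: Adam keeps two fp32 moments (m and v) per trainable parameter.
BYTES_PER_PARAM_ADAM = 2 * 4
GiB = 1024**3


def adam_state_bytes(num_params):
    """Optimizer-state bytes for `num_params` trainable parameters."""
    return BYTES_PER_PARAM_ADAM * num_params


def layer_modules(d):
    """Hypothetical module sizes for one transformer layer (biases ignored)."""
    return {
        "attn.q": d * d,
        "attn.k": d * d,
        "attn.v": d * d,
        "attn.o": d * d,
        "mlp.up": d * 4 * d,
        "mlp.down": 4 * d * d,
    }


modules = layer_modules(4096)
full_layer = adam_state_bytes(sum(modules.values()))  # whole layer trainable
one_module = adam_state_bytes(modules["attn.q"])      # single module trainable

print(full_layer / GiB)  # 1.5
print(one_module / GiB)  # 0.125
```

Under these assumptions, a fully trainable layer needs 1.5 GiB of Adam state, while a single attention projection needs 0.125 GiB. Activating a subset of modules rather than a whole layer targets this gap.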
Supplementary Material: zip
Primary Area: Optimization (e.g., convex and non-convex, stochastic, robust)
Submission Number: 15661