Mini-Sequence Transformers: Optimizing Intermediate Memory for Long Sequences Training

Published: 25 Sept 2024, Last Modified: 06 Nov 2024 · NeurIPS 2024 poster · License: CC0 1.0
Keywords: Long-Context, Foundation Models, Systems for ML, LLM Training, GPUs, Memory-efficient Training
TL;DR: We propose Mini-Sequence Transformer (MsT) to reduce intermediate memory overhead in long-sequence training, enabling sequences 12x longer than the standard implementation of Llama3-8B training on a single A100 device.
Abstract: We introduce Mini-Sequence Transformer (MsT), a simple and effective methodology for highly efficient and accurate LLM training with extremely long sequences. MsT partitions input sequences and iteratively processes mini-sequences to reduce intermediate memory usage. Integrated with activation recomputation, it enables significant memory savings in both forward and backward passes. In experiments with the Llama3-8B model, we measure no degradation in throughput or convergence with MsT, even with sequences 12x longer than standard implementations. MsT is fully general, implementation-agnostic, and requires minimal code changes to integrate with existing LLM training frameworks. Integrated with the Hugging Face library, MsT successfully extends the maximum context length of Qwen, Mistral, and Gemma-2 by 12-24x.
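For intuition, the chunking idea can be illustrated with a minimal sketch (not the authors' implementation): a per-token block such as an MLP is applied to the input one sequence chunk at a time, so its large intermediate activations exist only for a mini-sequence rather than the full sequence. The names minisequence_forward and num_chunks below are illustrative assumptions; in MsT this partitioning is combined with activation recomputation so the backward pass obtains the same savings.

import torch
import torch.nn as nn

def minisequence_forward(module: nn.Module, hidden: torch.Tensor, num_chunks: int = 4) -> torch.Tensor:
    # Illustrative sketch: split along the sequence dimension (dim=1 for
    # tensors shaped [batch, seq, hidden]) and run the per-token module on
    # each mini-sequence, so large intermediates (e.g. the MLP's expanded
    # hidden states) are materialized only one chunk at a time.
    chunks = hidden.chunk(num_chunks, dim=1)
    return torch.cat([module(chunk) for chunk in chunks], dim=1)

# Hypothetical usage with a Hugging Face decoder layer's MLP:
#   out = minisequence_forward(model.model.layers[0].mlp, hidden_states, num_chunks=8)

Because blocks like the MLP act on each token independently, chunking the sequence dimension changes peak memory but not the computed result.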
Primary Area: Infrastructure (libraries, improved implementation and scalability, distributed solutions)
Submission Number: 1055