Abstract: Despite rapid progress, efficiently training large language models (LLMs) on extremely long contexts remains challenging.
Existing methods fall back to training LLMs with short contexts (up to a few thousand tokens) and rely on inference-time techniques when evaluating on very long contexts (above 1M tokens).
Training on very long contexts is limited by GPU memory and by the prohibitively long training times required even on state-of-the-art hardware.
Meanwhile, many real-world applications require training or fine-tuning with long contexts on specific tasks, for example augmenting the context with various sources of raw reference information for extraction, summarization, or fact-reconciliation tasks.
We propose adjoint sharding, a novel technique that shards the gradient calculation during training, reducing memory requirements by orders of magnitude and making training on very long contexts computationally tractable. At the core of adjoint sharding lies the adjoint method, which efficiently computes gradients that are provably equivalent to those computed by standard backpropagation.
We also propose truncated adjoint sharding to accelerate the algorithm while maintaining performance.
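To make the core idea concrete, the following is a minimal sketch (not the authors' implementation) of adjoint-style gradient accumulation for a toy linear recurrence $h_t = A h_{t-1} + B x_t$, written in PyTorch. Once the adjoint (co-state) vector at each step is known, the per-time-step vector-Jacobian terms are independent outer products, which is the kind of work adjoint sharding distributes across workers; the hypothetical `T_bar` cutoff loosely illustrates the truncated variant by keeping only a subset of those terms. All symbols and the truncation rule shown here are illustrative assumptions, not the paper's exact formulation.

```python
# Minimal sketch: adjoint-style gradient accumulation for a toy linear
# recurrence h_t = A h_{t-1} + B x_t with loss L = 0.5 * sum_t ||h_t - y_t||^2.
# Illustrative only; A, B, T_bar, and the truncation rule are assumptions,
# not the paper's implementation.
import torch

torch.manual_seed(0)
d, T = 4, 64
A = (0.1 * torch.randn(d, d)).requires_grad_()
B = torch.randn(d, d).requires_grad_()
x, y = torch.randn(T, d), torch.randn(T, d)

def forward_states(A, B, x):
    """Run the recurrence and return all hidden states."""
    h, states = torch.zeros(d), []
    for t in range(x.shape[0]):
        h = A @ h + B @ x[t]
        states.append(h)
    return torch.stack(states)  # shape (T, d)

# Reference gradients via standard backpropagation.
loss = 0.5 * ((forward_states(A, B, x) - y) ** 2).sum()
loss.backward()
gA_ref, gB_ref = A.grad.clone(), B.grad.clone()

# Adjoint accumulation: only the forward states are stored; the gradients are
# sums of independent per-time-step outer products that could be sharded.
with torch.no_grad():
    H = forward_states(A, B, x)
    lam = torch.zeros(d)                  # adjoint (co-state) vector
    gA, gB = torch.zeros(d, d), torch.zeros(d, d)
    T_bar = T                             # set T_bar < T to drop older terms (toy truncation)
    for t in reversed(range(T)):
        lam = lam + (H[t] - y[t])         # direct dL/dh_t from the step-t loss
        if T - t <= T_bar:                # keep only the last T_bar per-step terms
            h_prev = H[t - 1] if t > 0 else torch.zeros(d)
            gA += torch.outer(lam, h_prev)    # per-step VJP term w.r.t. A
            gB += torch.outer(lam, x[t])      # per-step VJP term w.r.t. B
        lam = A.T @ lam                   # propagate the adjoint one step back

# With T_bar = T the adjoint gradients match backpropagation up to float error.
print(torch.allclose(gA, gA_ref, atol=1e-4), torch.allclose(gB, gB_ref, atol=1e-4))
```

In this toy setting, memory is dominated by the stored states rather than an autograd graph over the whole sequence, and each per-step term can be computed on a different worker; reducing `T_bar` trades gradient fidelity for less work, in the spirit of the truncated variant.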
We further provide distributed and parallel-computing versions of adjoint sharding to speed up training, showing that adjoint sharding is compatible with these standard techniques.
Empirical results show that adjoint sharding reduces memory usage by up to 3$\times$ when training a 1.27B-parameter large language model with a 1M-token context length. This reduction raises the maximum training context length of a 1.27B-parameter model from 35K tokens to above 100K tokens on a training infrastructure composed of five AWS P4 instances.
Submission Length: Regular submission (no more than 12 pages of main content)
Changes Since Last Submission: We thank the reviewers and action editor for their constructive feedback. Following the feedback, we made revisions to address four key concerns:
1. **Enhanced baseline comparison**: Added discussion of three memory reduction methods requested by reviewers (Mini-Sequence Transformers, Cut Your Losses, StreamBP) in the Related Works section, clarifying how adjoint sharding differs from and complements these approaches.
2. **Expanded architectural scope discussion**: Added detailed explanation in the Limitation section of why adjoint sharding applies to State Space Models but not Transformers, emphasizing that SSMs' Markovian recurrent structure enables VJP decomposition while Transformers' self-attention creates all-to-all dependencies that break the sequential structure the adjoint method currently requires.
3. **Truncation parameter guidance**: Added discussion on the choice of the truncation parameter $\bar{T}$, noting that $\bar{T} = 1000$–$2000$ and $\bar{T}=\sqrt{T}$ provide good trade-offs for typical configurations, and that performance degrades gracefully as $\bar{T}$ decreases.
4. **Improved positioning**: Revised conclusion to better position adjoint sharding as complementary to existing memory reduction techniques and clarify its applicability to recurrent architectures.
Assigned Action Editor: ~Hugo_Touvron1
Submission Number: 4874