Abstract: Despite rapid progress, efficiently training large language models (LLMs) on extremely long contexts remains challenging.
Existing methods fall back to training LLMs with short contexts (up to a few thousand tokens) and rely on inference-time techniques when evaluating on very long contexts (above 1M tokens).
Training on very long contexts is limited by available GPU memory and by the prohibitively long training times it requires, even on state-of-the-art hardware.
Meanwhile, many real-life applications require training or fine-tuning with long contexts on specific tasks, for example augmenting the context with various sources of raw reference information for extraction, summarization, or fact-reconciliation tasks.
We propose adjoint sharding, a novel technique that shards gradient calculation during training, reducing memory requirements by orders of magnitude and making training on very long contexts computationally tractable. At the core of our adjoint sharding algorithm lies the adjoint method, which efficiently computes gradients that are provably equivalent to those computed by standard backpropagation (see the sketch below).
We also propose truncated adjoint sharding to accelerate the algorithm while maintaining performance.
We also provide a distributed version and a parallel-computing version of adjoint sharding to further speed up training and to show that adjoint sharding is compatible with these standard scaling techniques.
Empirical results show that the proposed adjoint sharding algorithm reduces memory usage by up to 3$\times$ when training a 1.27B-parameter large language model with a 1M-token context length. This reduction in memory usage allows increasing the maximum training context length of a 1.27B-parameter model from 35K tokens to above 100K tokens on a training infrastructure composed of five AWS P4 instances.
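The sketch below is a minimal, hypothetical illustration of the adjoint-style gradient computation described in the abstract, not the authors' implementation. It assumes a simple recurrent cell `h_t = f(h_{t-1}, x_t; theta)` with a per-step loss that depends only on the state; the function name `adjoint_sharded_grads` and arguments `cell`, `h0`, `xs`, `step_loss`, and `shard_size` are ours. Forward states are stored without an autograd graph, the adjoint vector is propagated backward, and the per-step parameter vector-Jacobian products are accumulated shard by shard, so peak autograd memory scales with a single step rather than the full sequence.

```python
import torch

def adjoint_sharded_grads(cell, h0, xs, step_loss, shard_size=128):
    """Adjoint-style gradients for h_t = cell(h_{t-1}, x_t), L = sum_t step_loss(h_t).

    Illustrative sketch only: names and structure are assumptions, not the paper's code.
    """
    params = list(cell.parameters())

    # Phase 1: forward pass, store states but build no autograd graph.
    with torch.no_grad():
        hs = [h0]
        for x in xs:
            hs.append(cell(hs[-1], x))

    grads = [torch.zeros_like(p) for p in params]
    adj = torch.zeros_like(hs[-1])  # gradient of L w.r.t. h_t arriving from future steps

    # Phase 2: backward sweep in shards; each shard's per-step VJP work is independent
    # given the stored states and could be dispatched to a separate worker.
    for end in range(len(xs), 0, -shard_size):
        for t in reversed(range(max(end - shard_size, 0), end)):
            h_prev = hs[t].detach().requires_grad_(True)
            h_t = cell(h_prev, xs[t])  # re-run one step with a local graph
            # a_t = (gradient from future steps) + d step_loss(h_t) / d h_t
            a_t = adj + torch.autograd.grad(step_loss(h_t), h_t, retain_graph=True)[0]
            # Vector-Jacobian products of h_t w.r.t. the previous state and the parameters.
            vjps = torch.autograd.grad(h_t, [h_prev] + params, grad_outputs=a_t)
            adj = vjps[0]  # propagate a_t * dh_t/dh_{t-1} to the previous step
            for g, v in zip(grads, vjps[1:]):
                g.add_(v)  # accumulate a_t * dh_t/dtheta
    return grads  # matches the gradients from full backpropagation through time
```

In this sketch only one step's graph is alive at any time, which is the source of the memory reduction; how the per-shard work is actually scheduled across devices is the subject of the distributed and parallel-computing versions mentioned above.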
Submission Length: Regular submission (no more than 12 pages of main content)
Assigned Action Editor: ~Hugo_Touvron1
Submission Number: 4874