MINI-SEQUENCE TRANSFORMER: Optimizing Intermediate Memory for Long Sequences Training

Published: 18 Jun 2024, Last Modified: 16 Jul 2024
Venue: LCFM 2024
License: CC BY 4.0
Keywords: long context training, memory optimization
TL;DR: We propose Mini Sequence to reduce intermediate memory overhead in long-sequence training, enabling Llama3-8B training on a single A100 device with sequences 12x longer than the standard implementation.
Abstract: We introduce MINI-SEQUENCE TRANSFORMER (MsT), a simple and effective methodology for highly efficient and accurate LLM training with extremely long sequences. MsT partitions input sequences and iteratively processes mini-sequences to reduce intermediate memory usage. Integrated with activation recomputation, it enables significant memory savings in both the forward and backward passes. In experiments with the Llama3-8B model, we measure no degradation in throughput or convergence with MsT, even with sequences 12x longer than those of standard implementations, owing to our careful memory optimizations. MsT is fully general, implementation-agnostic, and requires minimal code changes to integrate with existing LLM training frameworks.
Submission Number: 6
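
The partitioning idea described in the abstract can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: the choice of a position-wise feed-forward block as the layer with the large intermediate activation, the function names `mlp_full` and `mlp_mini_sequence`, and all tensor sizes are assumptions made for illustration only. The sketch covers the forward pass; the backward-pass savings mentioned in the abstract additionally rely on activation recomputation, which is not shown.

```python
import numpy as np


def mlp_full(x, w_up, w_down):
    """Reference path: process the whole sequence at once."""
    # Intermediate activation has shape (seq_len, d_hidden).
    hidden = np.maximum(x @ w_up, 0.0)
    return hidden @ w_down


def mlp_mini_sequence(x, w_up, w_down, num_chunks):
    """Hypothetical mini-sequence path: split along the sequence axis
    and process one chunk at a time."""
    out = np.empty((x.shape[0], w_down.shape[1]), dtype=x.dtype)
    start = 0
    for chunk in np.array_split(x, num_chunks, axis=0):
        # Only a (chunk_len, d_hidden) intermediate exists at any moment.
        hidden = np.maximum(chunk @ w_up, 0.0)
        out[start:start + chunk.shape[0]] = hidden @ w_down
        start += chunk.shape[0]
    return out


# Illustrative sizes (assumptions, not taken from the paper).
rng = np.random.default_rng(0)
seq_len, d_model, d_hidden = 1024, 64, 256
x = rng.standard_normal((seq_len, d_model))
w_up = rng.standard_normal((d_model, d_hidden))
w_down = rng.standard_normal((d_hidden, d_model))

full = mlp_full(x, w_up, w_down)
chunked = mlp_mini_sequence(x, w_up, w_down, num_chunks=8)
```

Because the block acts on each position independently, the chunked path reproduces the full result up to floating-point rounding, while the live intermediate shrinks from `seq_len x d_hidden` to roughly `(seq_len / num_chunks) x d_hidden` elements.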