Retentive Network

12 May 2024 (modified: 06 Nov 2024) · Submitted to NeurIPS 2024 · CC BY 4.0
Keywords: Retentive Network, Model Architecture
TL;DR: RetNet is a foundation architecture that simultaneously achieves training parallelism, low-cost inference, and good performance.
Abstract: In this work, we propose Retentive Network (RetNet) as a foundation architecture for large language models, simultaneously achieving training parallelism, low-cost inference, and good performance. We theoretically derive the connection between recurrence and attention. Then we propose the retention mechanism for sequence modeling, which supports three computation paradigms, i.e., parallel, recurrent, and chunkwise recurrent. Specifically, the parallel representation allows for training parallelism. The recurrent representation enables low-cost inference, which improves decoding throughput, latency, and GPU memory usage without sacrificing performance. The chunkwise recurrent representation facilitates efficient long-sequence modeling with linear complexity, where each chunk is encoded in parallel while the chunks are summarized recurrently. Experimental results on language modeling show that RetNet achieves favorable scaling results, parallel training, low-cost deployment, and efficient inference.
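The abstract's three computation paradigms can be illustrated numerically. Below is a minimal sketch (not the authors' code) of a single-head retention layer, assuming the formulation Retention(X) = (Q K^T ⊙ D) V with causal decay D[n, m] = gamma^(n-m) for n ≥ m and 0 otherwise; the shapes, the decay value `gamma`, the chunk size, and the random projections are illustrative assumptions. It checks that the parallel, recurrent, and chunkwise recurrent forms produce the same output.

```python
# Hypothetical sketch of retention's three equivalent computation paradigms.
# All hyperparameters and projections here are illustrative, not the paper's.
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, gamma, chunk = 8, 4, 0.9, 4

X = rng.standard_normal((seq_len, d_model))
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# --- Parallel form: attention-like matrix masked by causal decay D -------
n = np.arange(seq_len)
D = np.where(n[:, None] >= n[None, :], gamma ** (n[:, None] - n[None, :]), 0.0)
out_parallel = (Q @ K.T * D) @ V

# --- Recurrent form: constant-size state S_n = gamma*S_{n-1} + k_n^T v_n --
S = np.zeros((d_model, d_model))
out_recurrent = np.zeros_like(V)
for t in range(seq_len):
    S = gamma * S + np.outer(K[t], V[t])
    out_recurrent[t] = Q[t] @ S

# --- Chunkwise recurrent form: parallel within a chunk, recurrent across --
out_chunkwise = np.zeros_like(V)
S = np.zeros((d_model, d_model))
for start in range(0, seq_len, chunk):
    q, k, v = Q[start:start+chunk], K[start:start+chunk], V[start:start+chunk]
    L = q.shape[0]
    i = np.arange(L)
    D_in = np.where(i[:, None] >= i[None, :],
                    gamma ** (i[:, None] - i[None, :]), 0.0)
    inner = (q @ k.T * D_in) @ v                    # within-chunk (parallel)
    cross = (q @ S) * (gamma ** (i + 1))[:, None]   # carry-over from past chunks
    out_chunkwise[start:start+chunk] = inner + cross
    # Fold this chunk into the running state for the next chunk.
    S = (gamma ** L) * S + (k * (gamma ** (L - 1 - i))[:, None]).T @ v

print(np.allclose(out_parallel, out_recurrent))   # True
print(np.allclose(out_parallel, out_chunkwise))   # True
```

Under these assumptions, the parallel form enables training parallelism, the recurrent form gives O(1)-state decoding, and the chunkwise form keeps per-chunk parallelism while passing a compact state between chunks, matching the trade-offs the abstract describes.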
Primary Area: Deep learning architectures
Submission Number: 5128