Just read twice: closing the recall gap for recurrent language models

Published: 21 Jun 2024, Last Modified: 26 Jul 2024, ES-FoMo-II 2024 Poster, CC BY 4.0
Keywords: efficient architectures, in-context learning, linear attention
TL;DR: Causal recurrent language models are brittle at in-context learning since they can only store a limited amount of information in memory during inference. We explore how going beyond causal language modeling can help.
Abstract: Recurrent large language models that compete with Transformers in language modeling perplexity are emerging at a rapid rate (e.g., Mamba, RWKV). Excitingly, these architectures use a constant amount of memory during inference. However, due to their limited memory, recurrent LMs cannot recall all the information in long contexts, leading to brittle in-context learning (ICL) quality. It is not clear how to improve recall quality while preserving efficiency. We observe that the order in which information is shown to the LM impacts the selection difficulty. To formalize this, we show that the hardness of information recall reduces to the hardness of set disjointness (SD), a decades-old and canonical problem in communication complexity theory that requires a streaming algorithm (e.g., a recurrent model) to decide whether the inputted sets are disjoint. We use this connection to show, both empirically and theoretically, that the recurrent memory required to solve SD changes with set order. Our analysis suggests that, to mitigate the reliance on data order, we can either put information in the right order in-context or process prompts non-causally. Towards that end, we first propose (1) JRT-Prompt, where information is repeated multiple times in the prompt, showing the model all data orders. This gives $11.1 \pm 1.2$ points of improvement, averaged across the ICL tasks, with $11.9\times$ higher throughput than FlashAttention-2 (length $32\mathrm{k}$, batch size $16$). We then propose (2) JRT-RNN, which uses non-causal cross-linear-attention to process prompts and provides $13.7$ (at $360\mathrm{M}$ params., $30\mathrm{B}$ tokens) and $6.9$ (at $1.3\mathrm{B}$ params., $50\mathrm{B}$ tokens) point average improvements over the decoder baseline, with $19.2\times$ higher throughput than FA2. Code is available at: https://github.com/HazyResearch/prefix-linear-attention
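The prompting-side intervention described in the abstract is simple to sketch: JRT-Prompt repeats the in-context information so that a fixed-memory recurrent model sees every fact both before and after the material it must recall. Below is a minimal illustrative sketch of that idea in plain Python; the function name `jrt_prompt` and the prompt template are assumptions made for illustration and are not the repository's API.

```python
# Minimal sketch of the "read twice" prompting idea: repeat the in-context
# documents n_repeats times before the question, so a recurrent LM with a
# fixed-size state encounters each piece of information in multiple orders
# relative to the rest of the context.
def jrt_prompt(context: str, question: str, n_repeats: int = 2) -> str:
    """Build a prompt in which the context is shown n_repeats times."""
    repeated = "\n\n".join([context] * n_repeats)  # context repeated back-to-back
    return f"{repeated}\n\nQuestion: {question}\nAnswer:"

# Example usage with any causal LM (recurrent or Transformer):
prompt = jrt_prompt(
    context="Alice works at Acme. Bob works at Initech.",
    question="Where does Bob work?",
)
print(prompt)
```

With two repetitions, information that originally appeared only before (or only after) the query-relevant span now also appears in the opposite position, which is the data-order effect the abstract's set-disjointness analysis formalizes.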
Submission Number: 85