Low-rank Linearization of Large Language Models

Published: 21 Jun 2024 · Last Modified: 26 Jul 2024 · ES-FoMo-II 2024 Poster · CC BY 4.0
Keywords: linear attention, large language models, knowledge distillation
TL;DR: We propose a method for obtaining state-of-the-art subquadratic LLMs by converting existing LLMs into linear-attention models whose attentions are trained to match the original softmax attention.
Abstract: Recent subquadratic architectures show exciting progress, now rivaling Transformers in various quality metrics. However, scaling these models up to modern large language model (LLM) sizes can be prohibitively costly. We thus study how to efficiently *linearize* existing LLMs---swapping their attentions with fast analogs and only adapting the fast analog weights---to quickly create high-quality linear-time, constant-memory LLMs. Our approach, Low-rank Linear Conversion via Attention Transfer (lolcat), is a simple two-step method that (1) *learns* expressive yet efficient attention analogs by training linear attentions to match the outputs of the LLM's softmax attentions, before (2) replacing the attentions and adjusting for this swap with only *low-rank adaptation* (LoRA). By linearizing Llama 3 8B and Mistral-7B, lolcat produces strong linear attention LLMs that outperform state-of-the-art non-Transformer 7B LLMs, scoring 1.8-4.7 points higher on popular LM Evaluation Harness (LM Eval) tasks, while training only 0.4% of their model parameters with 0.003% of their training tokens. lolcat-linearized LLMs further achieve 2.5-4$\times$ the throughput of the original FlashAttention-2 LLMs, while increasing memory only 1.1$\times$ when scaling generation length 512$\times$ from 512 to 131K tokens. Finally, lolcat significantly improves linearization quality and training efficiency, scoring 5.0 LM Eval points higher than concurrent methods with only 0.2% of their training tokens.
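To make the two-step recipe described above concrete, here is a minimal PyTorch sketch of the core idea: train a kernelized linear attention to match a frozen softmax attention's outputs (attention transfer), then adjust the swapped model with LoRA. All module names, the feature-map choice, and the training setup below are illustrative assumptions, not the authors' released implementation.

```python
# Minimal sketch of the two-step linearization recipe described in the abstract.
# Names, feature map, and hyperparameters are assumptions for illustration only.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearAttention(nn.Module):
    """Kernelized linear attention intended to mimic a softmax attention head."""

    def __init__(self, head_dim: int):
        super().__init__()
        # Learnable feature map (assumption: a small linear layer + nonlinearity).
        self.feature_map = nn.Sequential(nn.Linear(head_dim, head_dim), nn.ReLU())

    def forward(self, q, k, v):
        # q, k, v: (batch, seq_len, head_dim)
        phi_q, phi_k = self.feature_map(q), self.feature_map(k)
        # Causal linear attention via cumulative sums (linear in sequence length).
        kv = torch.cumsum(torch.einsum("bnd,bne->bnde", phi_k, v), dim=1)
        z = torch.cumsum(phi_k, dim=1)
        num = torch.einsum("bnd,bnde->bne", phi_q, kv)
        den = torch.einsum("bnd,bnd->bn", phi_q, z).clamp(min=1e-6)
        return num / den.unsqueeze(-1)


def softmax_attention(q, k, v):
    """Reference causal softmax attention the linear analog is trained to match."""
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
    return F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v


# Step 1 (attention transfer): train the linear attention so its outputs match
# those of the frozen softmax attention on sample activations.
head_dim, seq_len = 64, 128
lin_attn = LinearAttention(head_dim)
opt = torch.optim.Adam(lin_attn.parameters(), lr=1e-3)
for _ in range(100):
    q, k, v = (torch.randn(2, seq_len, head_dim) for _ in range(3))
    with torch.no_grad():
        target = softmax_attention(q, k, v)
    loss = F.mse_loss(lin_attn(q, k, v), target)
    opt.zero_grad()
    loss.backward()
    opt.step()

# Step 2 (low-rank adjustment): after swapping the trained linear attentions
# into the LLM, only low-rank adapters are trained to compensate for the swap,
# e.g. with the PEFT library (config values here are illustrative):
#   from peft import LoraConfig, get_peft_model
#   model = get_peft_model(swapped_llm, LoraConfig(r=8, target_modules=[...]))
```

In this sketch the transfer loss is a plain MSE on attention outputs; the actual objective, feature map, and layer-wise training schedule may differ in the paper.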
Submission Number: 83