Keywords: linear attention, structured pruning, dimension reduction, deltanet, associative memory
TL;DR: We demonstrate rank collapse in the associative memory implemented by linear attention models and introduce a pruning framework that reduces the hidden state size by 50% with minimal performance degradation.
Abstract: Linear attention offers a computationally efficient yet expressive alternative to softmax attention, maintaining a recurrent state that functions as a linear associative memory. However, recent empirical results indicate that the associative memory of trained linear attention models often exhibits a low-rank structure, suggesting that these models underexploit their capacity in practice. To illuminate this phenomenon, we provide a theoretical analysis of the role of rank in linear attention, revealing that low effective rank can increase retrieval error by amplifying query noise and can poorly condition query gradients. Building on these theoretical insights, we conjecture that the low-rank states can be substantially compressed post-training with only minimal performance degradation, yielding faster and more efficient models. To this end, we propose a hardware-aware approach that structurally prunes key and query matrices, effectively reducing the state size. We adapt existing pruning strategies to fit our framework and, building on our theoretical analysis, propose a structured pruning method based on a rank-revealing QR decomposition. Our empirical results demonstrate the effectiveness of this associative memory reduction framework. We highlight that it enables the removal of 50\% of the query and key channels with only a minor increase in perplexity.
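To make the pruning idea concrete, below is a minimal sketch of selecting key/query channels with a rank-revealing (column-pivoted) QR decomposition. The shapes, variable names, and calibration procedure are illustrative assumptions, not the paper's actual implementation.

```python
import numpy as np
from scipy.linalg import qr

# Hypothetical sketch: prune key/query channels via column-pivoted QR.
rng = np.random.default_rng(0)
d = 64         # original key/query channel count (assumed state dimension)
d_pruned = 32  # target after a 50% reduction, as reported in the abstract
T = 256        # number of calibration tokens (assumption)

# Simulated key activations with approximately low-rank structure,
# standing in for activations collected on calibration data.
K = rng.standard_normal((T, 8)) @ rng.standard_normal((8, d))
K += 0.01 * rng.standard_normal((T, d))

# Column-pivoted QR orders channels by how much new energy each adds;
# keeping the first d_pruned pivots selects a well-conditioned subset.
_, _, piv = qr(K, mode="economic", pivoting=True)
keep = np.sort(piv[:d_pruned])

# Structural pruning: drop the same channels from both the key and query
# projections so query-key inner products stay dimensionally consistent.
W_q = rng.standard_normal((d, d))  # stand-in projection weights
W_k = rng.standard_normal((d, d))
W_q_pruned, W_k_pruned = W_q[:, keep], W_k[:, keep]

# Sanity check: the kept channels should reconstruct the full activations.
B, *_ = np.linalg.lstsq(K[:, keep], K, rcond=None)
rel_err = np.linalg.norm(K - K[:, keep] @ B) / np.linalg.norm(K)
```

Because the same channel subset is removed from both projections, the recurrent state shrinks along the key dimension while retrieval remains a plain inner product over the surviving channels.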
Submission Number: 11