Keywords: Sparse Transformer, KV Cache Compression, Neural Architecture Search
TL;DR: We present Neural Attention Search (NAtS), an end-to-end learnable sparse transformer that learns the importance of each token to reduce KV cache size and inference costs.
Abstract: We present Neural Attention Search (NAtS), an end-to-end learnable sparse transformer that automatically evaluates the importance of each token within a sequence and determines whether the corresponding token can be dropped after several steps. To this end, we design a search space that contains three token types: (i) Global Tokens are preserved and queried by all subsequent tokens. (ii) Local Tokens survive until the next global token appears. (iii) Sliding Window Tokens influence only the inference of a fixed-size window of subsequent tokens. Similar to the One-Shot Neural Architecture Search approach, this token-type information can be learned jointly with the architecture weights via a learnable attention mask. Experiments on both training new transformers from scratch and fine-tuning existing large language models show that NAtS can efficiently reduce the KV cache size and inference costs of transformer-based models while maintaining the models' performance.
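As a rough illustration of the masking scheme described in the abstract (not the authors' implementation), the sketch below builds a hard attention mask from given per-token types; in NAtS the types are instead learned via a relaxed, learnable mask, in the spirit of one-shot NAS. The helper name `build_nats_mask`, the window size, and the boundary convention for Local Tokens are assumptions made for illustration.

```python
# Minimal sketch, assuming PyTorch. Token-type encodings, the helper name,
# and the "local token visible up to and including the next global token"
# convention are illustrative assumptions, not the paper's exact definitions.
import torch

GLOBAL, LOCAL, SLIDING = 0, 1, 2


def build_nats_mask(token_types: torch.Tensor, window: int) -> torch.Tensor:
    """Return a [T, T] boolean mask: mask[q, k] = True means query q may attend to key k.

    - GLOBAL keys stay visible to every later query.
    - LOCAL keys stay visible until the next GLOBAL token appears.
    - SLIDING keys stay visible only to the next `window` queries.
    """
    T = token_types.shape[0]
    q_idx = torch.arange(T).unsqueeze(1)  # [T, 1] query positions
    k_idx = torch.arange(T).unsqueeze(0)  # [1, T] key positions
    causal = q_idx >= k_idx               # standard causal visibility

    is_global = token_types == GLOBAL
    # next_global[k] = index of the first GLOBAL token strictly after k, or T if none.
    next_global = torch.full((T,), T, dtype=torch.long)
    last = T
    for k in range(T - 1, -1, -1):
        next_global[k] = last
        if is_global[k]:
            last = k

    keep_global = is_global.unsqueeze(0) & causal
    keep_local = (
        (token_types == LOCAL).unsqueeze(0)
        & causal
        & (q_idx <= next_global.unsqueeze(0))
    )
    keep_sliding = (
        (token_types == SLIDING).unsqueeze(0)
        & causal
        & ((q_idx - k_idx) <= window)
    )
    return keep_global | keep_local | keep_sliding


if __name__ == "__main__":
    types = torch.tensor([GLOBAL, SLIDING, LOCAL, SLIDING, GLOBAL, LOCAL])
    print(build_nats_mask(types, window=2).int())
```

At inference time, a key/value pair whose mask column has become all-False for future queries can be evicted from the KV cache, which is where the memory savings described in the abstract would come from.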
Supplementary Material: zip
Primary Area: Deep learning (e.g., architectures, generative models, optimization for deep networks, foundation models, LLMs)
Submission Number: 5940