Weighted Grouped Query Attention in Transformers

ACL ARR 2024 June Submission 3182 Authors

15 Jun 2024 (modified: 04 Jul 2024), ACL ARR 2024 June Submission, CC BY 4.0
Abstract: The attention mechanism is the foundational building block of transformer language models. Recent approaches show that scaling these models achieves human-level performance. However, with increasing demands for scaling and constraints on hardware memory, the inference costs of these models remain high. To reduce inference time, Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) were introduced. In this paper, we propose a variation of Grouped-Query Attention, termed Weighted Grouped-Query Attention (WGQA). We introduce new learnable parameters for each key and value head in the T5 decoder attention blocks, enabling the model to take a weighted average during finetuning. Our model achieves an average improvement of 0.53% over GQA, and its performance converges to that of traditional Multi-Head Attention (MHA) with no additional overhead during inference. Our evaluation shows that introducing these parameters and subsequently finetuning informs the model about the grouping mechanism during training, thereby enhancing performance. Additionally, we demonstrate the scaling laws in our analysis by comparing the results between the T5-small and T5-base architectures.
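The difference between GQA-style mean pooling and the weighted average described in the abstract can be illustrated with a small sketch. This is not the authors' implementation: the function name merge_heads, the use of one scalar weight per head, and the toy values are all assumptions made for illustration, and the abstract does not specify the exact form or initialization of the learnable parameters. The sketch only shows why a learned weighted merge adds no inference cost: once the weights are fixed, they fold into a single merged projection per group.

```python
def merge_heads(heads, num_groups, weights=None):
    """Collapse per-head projection matrices into one matrix per group.

    heads:   list of H matrices (nested lists), all the same shape.
    weights: optional list of H scalars (hypothetical learnable parameters);
             None falls back to the uniform mean used by plain GQA.
    """
    num_heads = len(heads)
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    group_size = num_heads // num_groups
    if weights is None:
        # Uniform weights reproduce GQA mean pooling.
        weights = [1.0 / group_size] * num_heads
    rows, cols = len(heads[0]), len(heads[0][0])
    merged = []
    for g in range(num_groups):
        members = range(g * group_size, (g + 1) * group_size)
        # Weighted sum of the head matrices that belong to group g.
        merged.append([
            [sum(weights[h] * heads[h][r][c] for h in members)
             for c in range(cols)]
            for r in range(rows)
        ])
    return merged


# Toy example: 4 key heads, each a 2x2 projection filled with h + 1,
# merged into 2 groups of 2 heads.
key_heads = [[[float(h + 1)] * 2 for _ in range(2)] for h in range(4)]

gqa_keys = merge_heads(key_heads, num_groups=2)
# Assumed (not learned) weights, chosen only to show a non-uniform merge.
wgqa_keys = merge_heads(key_heads, num_groups=2,
                        weights=[0.75, 0.25, 0.5, 0.5])
```

In a real finetuning setup the weights would be trainable tensors updated by the optimizer; here they are plain floats so the merge arithmetic is easy to follow. The same function would apply to value heads.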
Paper Type: Short
Research Area: Machine Learning for NLP
Research Area Keywords: Natural Language Processing, Transformers, Memory Efficiency, Grouped Query Attention, Large Language Models, Computational Linguistics
Contribution Types: NLP engineering experiment, Approaches low compute settings-efficiency, Publicly available software and/or pre-trained models
Languages Studied: English
Submission Number: 3182