FiRST: Finetuning Router-Selective Transformers for Input-Adaptive Latency Reduction

28 Sept 2024 (modified: 25 Nov 2024) · ICLR 2025 Conference Withdrawn Submission · CC BY 4.0
Keywords: Input-Adaptive Layer Selection; Resource-Constrained Environments; Latency Reduction; Finetuning
TL;DR: We propose FiRST, an algorithm that reduces inference latency in Large Language Models by adaptively selecting transformer layers while retaining performance and compatibility with KV caching.
Abstract: Autoregressive Large Language Models (LLMs) demonstrate remarkable performance across domains such as vision and language processing. However, because of sequential processing through a stack of transformer layers, autoregressive decoding faces significant computation and latency challenges, particularly in resource-constrained environments such as mobile and edge devices. Existing approaches that aim to improve latency by skipping layers come in two distinct flavors: 1) early exit, and 2) input-agnostic heuristics, where tokens exit at pre-determined layers irrespective of the input sequence. Both strategies have limitations: the former is incompatible with the KV caching necessary for speed-ups in modern frameworks, and the latter does not capture the variation in layer importance across tasks or, more generally, across input sequences. To address both limitations, we propose FiRST, an algorithm that reduces inference latency by using layer-specific routers to select a subset of transformer layers adaptively for each input sequence: the prompt (during the prefill stage) decides which layers will be skipped during decoding. FiRST preserves compatibility with KV caching, enabling faster inference while being quality-aware. FiRST is model-agnostic and can be easily enabled on any pre-trained LLM. We further improve performance by incorporating LoRA adapters for fine-tuning on external datasets, enhancing task-specific accuracy while maintaining latency benefits. Our approach reveals that input adaptivity is critical: different middle layers play a crucial role in evolving hidden representations depending on the task. Extensive experiments show that FiRST significantly reduces latency while retaining competitive performance (as compared to baselines), making our approach an efficient solution for LLM deployment in low-resource environments.
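The routing idea in the abstract (per-layer routers pick the layers during prefill, and decoding reuses that choice) can be illustrated with a toy sketch. This is not the authors' implementation. The mean pooling, the sigmoid router, the 0.5 threshold, the toy "layers", and every function name below are assumptions made purely for illustration.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[Vector], Vector]


def mean_pool(hidden_states: Sequence[Vector]) -> Vector:
    """Average the per-token hidden states into one vector (assumed pooling)."""
    n = len(hidden_states)
    return [sum(col) / n for col in zip(*hidden_states)]


def router_score(weights: Vector, pooled: Vector) -> float:
    """Hypothetical router: a linear map followed by a sigmoid."""
    logit = sum(w * x for w, x in zip(weights, pooled))
    return 1.0 / (1.0 + math.exp(-logit))


def prefill(prompt_states, layers, routers, threshold=0.5):
    """Process the prompt and decide, layer by layer, which layers to keep."""
    states = [list(h) for h in prompt_states]
    keep_mask = []
    for layer, weights in zip(layers, routers):
        keep = router_score(weights, mean_pool(states)) >= threshold
        keep_mask.append(keep)
        if keep:  # a skipped layer leaves the hidden states unchanged
            states = [layer(h) for h in states]
    return states, keep_mask


def decode_step(token_state, layers, keep_mask):
    """Decode one token, reusing the layer choice fixed during prefill."""
    h = list(token_state)
    for layer, keep in zip(layers, keep_mask):
        if keep:
            h = layer(h)
    return h


# Toy stand-ins for transformer layers: layer k adds k to every coordinate.
layers = [lambda h, k=k: [x + k for x in h] for k in (1.0, 2.0, 3.0)]
# One hypothetical router weight vector per layer.
routers = [[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0]]
prompt = [[1.0, 0.0], [3.0, 0.0]]

final_states, keep_mask = prefill(prompt, layers, routers)
next_state = decode_step([0.0, 0.0], layers, keep_mask)
```

Because the keep mask is fixed once per sequence, every decoded token passes through the same layers as the prompt did. In a real model this means the kept layers always have cached keys and values for all earlier positions, which is the property the abstract refers to as compatibility with KV caching.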
Primary Area: generative models
Submission Number: 13480