DS-RNNs: Conditional Computation in Recurrent Models via Input-Dependent Sparse Gating

27 Apr 2026 (modified: 10 May 2026) · Under review for TMLR · CC BY 4.0
Abstract: Recurrent neural networks (RNNs) and state-space models (SSMs) typically execute the same dense computation for every input, coupling inference cost to parameter count and exacerbating interference in multi-task or heterogeneous regimes. We introduce Dynamic-Sparse RNNs (DS-RNNs), a framework for conditional computation via learnable, input-dependent sparse gating. In DS-RNNs, a small router network predicts sparse masks over input and hidden-state channels, effectively routing each input to a specialized sparse subnetwork. Unlike prior adaptive methods such as DeltaRNNs or static sparse training such as RigL, our approach is fully learned and budget-controlled, and uses structured masking that translates into practical FLOP savings on commodity hardware by reducing sparse matrix multiplications to dense operations on active submatrices. Empirically, DS-RNNs retain most of the dense model's performance at 90% sparsity, sometimes even exceeding it, across diverse architectures including LSTMs, GRUs, LTCs, and S4 models. They match or outperform RigL and DeltaRNNs on most tasks while remaining stable in regimes where these baselines fail. We further show that DS-RNNs naturally induce subnetwork specialization without explicit supervision: masks for different classes (single-task) or tasks (multi-task) exhibit low overlap and are predictive of class identity. Our theoretical analysis provides intuition for this behavior, linking it to a formal bound on gradient interference. DS-RNNs also improve robustness: DS-LSTM scales better with task count in multi-task regression and consistently reduces forgetting across benchmarks when combined with standard continual learning methods (e.g., GEM, replay).
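To make the gating mechanism concrete, here is a minimal PyTorch sketch of input-dependent top-k channel gating of the kind the abstract describes. All names (TopKRouter, DSLSTMCell) and details (a sigmoid-based straight-through relaxation, routing both masks from the current input) are illustrative assumptions, not the paper's implementation; the sketch also applies elementwise masks rather than the dense active-submatrix kernels the paper credits for its FLOP savings.

```python
# Hypothetical sketch of input-dependent sparse gating for an LSTM cell.
# Not the authors' code: router design and relaxation are assumptions.
import torch
import torch.nn as nn

class TopKRouter(nn.Module):
    """Predicts a binary mask that keeps the top-k channels per input."""
    def __init__(self, in_dim, out_dim, sparsity=0.9):
        super().__init__()
        self.scorer = nn.Linear(in_dim, out_dim)
        # Budget control: keep a fixed fraction (1 - sparsity) of channels.
        self.k = max(1, int(out_dim * (1 - sparsity)))

    def forward(self, x):
        scores = self.scorer(x)                        # (batch, out_dim)
        topk = scores.topk(self.k, dim=-1).indices
        mask = torch.zeros_like(scores)
        mask.scatter_(-1, topk, 1.0)                   # hard 0/1 mask
        # Straight-through estimator: forward uses the hard mask,
        # backward passes gradients through the sigmoid surrogate.
        soft = scores.sigmoid()
        return mask + soft - soft.detach()

class DSLSTMCell(nn.Module):
    """LSTM cell whose input and hidden channels are gated per step."""
    def __init__(self, in_dim, hidden_dim, sparsity=0.9):
        super().__init__()
        self.cell = nn.LSTMCell(in_dim, hidden_dim)
        self.in_router = TopKRouter(in_dim, in_dim, sparsity)
        self.h_router = TopKRouter(in_dim, hidden_dim, sparsity)

    def forward(self, x, state):
        h, c = state
        x = x * self.in_router(x)   # sparsify input channels
        h = h * self.h_router(x)    # sparsify hidden-state channels
        return self.cell(x, (h, c))
```

Under these assumptions the compute budget is set directly by the sparsity level (k = (1 - sparsity) * dim), matching the abstract's claim that the gating is budget-controlled rather than threshold-driven.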
Submission Type: Long submission (more than 12 pages of main content)
Assigned Action Editor: ~Ofir_Lindenbaum1
Submission Number: 8642