- Keywords: recurrent neural networks, conditional computation, representation learning
- TL;DR: We show that conditionally updating individual dimensions of an RNN's hidden state at each timestep, with the update pattern driven by the input data and learned from scratch with no structural assumptions, yields higher accuracy with far fewer computations than state-of-the-art approaches.
- Abstract: Recurrent Neural Networks (RNNs) are the state-of-the-art approach to sequential learning. However, standard RNNs use the same amount of computation at each timestep, regardless of the input data. As a result, even for high-dimensional hidden states, all dimensions are updated at each timestep regardless of the recurrent memory cell used. Relaxing this rigid assumption could allow models with large hidden states to perform inference more quickly. Intuitively, not all hidden state dimensions need to be recomputed from scratch at each timestep. Recent methods have therefore begun studying this problem, mainly by imposing update patterns that are determined a priori. In contrast, we design a fully learned approach, SA-RNN, that augments any RNN by predicting discrete update patterns at the fine granularity of independent hidden state dimensions, through the parameterization of a distribution of update likelihoods driven entirely by the input data. We achieve this without imposing assumptions on the structure of the update pattern. Better yet, our method adapts the update patterns online, allowing different dimensions to be updated conditional on the input. To learn which dimensions to update, the model solves a multi-objective optimization problem, maximizing accuracy while minimizing the number of updates based on a unified control. Using publicly available datasets, we demonstrate that our method consistently achieves higher accuracy with fewer updates than state-of-the-art alternatives. Additionally, our method can be directly applied to a wide variety of models containing RNN architectures.
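
The per-dimension update idea in the abstract can be illustrated with a minimal NumPy sketch. This is not the SA-RNN implementation, which the abstract does not specify. Everything concrete here is an assumption made for illustration: the vanilla tanh cell, the sigmoid-of-linear-map parameterization of the update likelihoods, the Bernoulli sampling of the discrete pattern, the additive penalty with weight `lam`, and all function and parameter names (`selective_step`, `W_u`, `b_u`, and so on). The sketch covers only the forward pass; training through the discrete sampling step would need a gradient estimator that is not shown.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_params(input_dim, hidden_dim, rng):
    # Hypothetical parameter layout: a vanilla RNN cell plus an
    # input-driven "update likelihood" head (W_u, b_u).
    scale = 0.1
    return {
        "W_x": rng.normal(0.0, scale, (hidden_dim, input_dim)),
        "W_h": rng.normal(0.0, scale, (hidden_dim, hidden_dim)),
        "b_h": np.zeros(hidden_dim),
        "W_u": rng.normal(0.0, scale, (hidden_dim, input_dim)),
        "b_u": np.zeros(hidden_dim),
    }


def selective_step(x_t, h_prev, params, rng):
    # One update likelihood per hidden dimension, computed from the input only.
    p_update = sigmoid(params["W_u"] @ x_t + params["b_u"])
    # Sample a discrete update pattern (assumed Bernoulli per dimension).
    mask = rng.random(p_update.shape) < p_update
    idx = np.flatnonzero(mask)
    # Recompute only the selected rows; all other dimensions are copied over.
    h_t = h_prev.copy()
    h_t[idx] = np.tanh(
        params["W_x"][idx] @ x_t + params["W_h"][idx] @ h_prev + params["b_h"][idx]
    )
    return h_t, mask


def run_sequence(xs, params, rng):
    h = np.zeros(params["b_h"].shape[0])
    masks = []
    for x_t in xs:
        h, mask = selective_step(x_t, h, params, rng)
        masks.append(mask)
    return h, np.array(masks)


def combined_loss(task_loss, masks, lam=0.1):
    # Assumed scalarization of the two objectives: task loss plus a penalty
    # proportional to the fraction of hidden dimensions that were updated.
    return task_loss + lam * masks.mean()
```

A short usage example: `rng = np.random.default_rng(0)`, then `params = init_params(3, 8, rng)` and `h, masks = run_sequence(rng.normal(size=(5, 3)), params, rng)` return the final hidden state and a boolean array of shape `(5, 8)` recording which dimensions were updated at each timestep. Indexing the weight rows with `idx` is what makes the skipped dimensions cheap in this sketch: their rows of `W_x` and `W_h` never enter the matrix product.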