Supplementary Material: pdf
Track: Extended Abstract Track
Keywords: RNNs, AI interpretability, probabilistic representations, sequence learning
TL;DR: Linear RNNs can approximate Bayesian filters and learn representations of latent hidden states from sequences alone.
Abstract: Understanding the computational mechanisms by which neural networks perform probabilistic inference is a central problem in mechanistic interpretability and sequence modeling. Here, we use token sequences generated from Hidden Markov Models (HMMs) with known latent dynamics and analytically specified Bayes-optimal forward filters to provide a mechanistic understanding of probabilistic inference in recurrent neural networks. HMMs provide both a ground-truth generative process and an exact optimal solution, enabling precise comparison between model representations and optimal latent state inference. Surprisingly, purely linear recurrent networks with a softmax readout layer, either hand-engineered to approximate the optimal Bayesian filter or trained from data, achieve near-optimal prediction performance. Moreover, trained linear RNNs recover low-dimensional representations of latent state probabilities despite never being given direct access to those states. These findings suggest that linear recurrent architectures can serve as both effective and interpretable models for structured probabilistic sequence prediction.
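Below is a minimal NumPy sketch of the setup the abstract describes: an exact HMM forward filter as the Bayes-optimal baseline, alongside the purely-linear-recurrence-plus-softmax-readout model class. The random HMM, all dimensions, and the (untrained) RNN parameters `W`, `U`, `R` are illustrative assumptions for exposition, not the authors' hand-engineered construction or training procedure.

```python
import numpy as np

rng = np.random.default_rng(0)

# --- Ground-truth HMM: K latent states, V observable tokens (illustrative sizes) ---
K, V = 3, 5
A = rng.dirichlet(np.ones(K), size=K)   # A[i, j] = P(z_t = j | z_{t-1} = i)
B = rng.dirichlet(np.ones(V), size=K)   # B[i, v] = P(x_t = v | z_t = i)
pi = rng.dirichlet(np.ones(K))          # initial latent state distribution

def sample_hmm(T):
    """Sample a token sequence of length T from the HMM."""
    z, xs = rng.choice(K, p=pi), []
    for _ in range(T):
        xs.append(rng.choice(V, p=B[z]))
        z = rng.choice(K, p=A[z])
    return np.array(xs)

def forward_filter(xs):
    """Exact Bayesian filter: p(z_t | x_{1:t}) and predictive p(x_{t+1} | x_{1:t})."""
    belief, beliefs, preds = pi.copy(), [], []
    for x in xs:
        belief = belief * B[:, x]        # condition on the observed token
        belief /= belief.sum()           # normalize -> p(z_t | x_{1:t})
        beliefs.append(belief)
        belief = A.T @ belief            # propagate through latent dynamics
        preds.append(B.T @ belief)       # predictive distribution over next tokens
    return np.array(beliefs), np.array(preds)

# --- Model class from the abstract: purely linear recurrence + softmax readout ---
def linear_rnn_predict(xs, W, U, R, h0):
    """h_t = W h_{t-1} + U e_{x_t};  p(x_{t+1} | x_{1:t}) = softmax(R h_t)."""
    h, preds = h0.copy(), []
    for x in xs:
        h = W @ h + U[:, x]              # linear update driven by one-hot input
        logits = R @ h
        e = np.exp(logits - logits.max())
        preds.append(e / e.sum())
    return np.array(preds)

xs = sample_hmm(200)
_, bayes_preds = forward_filter(xs)

H = K                                    # hidden size matched to number of latent states
W0 = 0.1 * rng.standard_normal((H, H))
U0 = 0.1 * rng.standard_normal((H, V))
R0 = 0.1 * rng.standard_normal((V, H))
rnn_preds = linear_rnn_predict(xs, W0, U0, R0, np.zeros(H))

# Next-token negative log-likelihood: preds[t] is compared against x_{t+1}.
# The Bayes filter gives the optimal score; the RNN here is untrained, so its
# parameters would need to be fit (or hand-engineered) to approach it.
nll = lambda p: -np.mean(np.log(p[np.arange(len(xs) - 1), xs[1:]] + 1e-12))
print(f"Bayes filter NLL: {nll(bayes_preds):.3f}   linear RNN NLL: {nll(rnn_preds):.3f}")
```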
Submission Number: 39