Differentiable Filtering for Learning Hidden Markov Models

Published: 17 Jun 2026, Last Modified: 07 May 2026, PMLR, Everyone, CC BY 4.0
Abstract: Hidden Markov Models (HMMs) are fundamental for modeling sequential data, yet learning their parameters from observations remains challenging. Classical methods like the Baum-Welch algorithm are computationally intensive and prone to local optima, while modern spectral algorithms offer provable guarantees but may produce probability outputs outside valid ranges. This work introduces Belief Net, a differentiable filtering framework that learns HMM parameters by formulating the forward filter as a structured neural network and optimizing it with stochastic gradient descent. This architecture recursively updates the belief state, which represents the posterior probability distribution over hidden states based on the observation history. Unlike black-box transformer models, Belief Net’s learnable weights are explicitly the logits of the initial distribution, transition matrix, and emission matrix, ensuring full interpretability. The model processes observation sequences using a decoder-only (causal) architecture and is trained end-to-end with standard autoregressive next-observation prediction loss. On synthetic HMM data, Belief Net achieves faster convergence than Baum-Welch while successfully recovering parameters in both undercomplete and overcomplete settings, whereas spectral methods prove ineffective in the latter. Comparisons with transformer-based models are also presented on real-world language data.
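The recursion described in the abstract can be illustrated with a minimal sketch of a forward filter whose only parameters are logits, scored with a next-observation negative log-likelihood. This is not the authors' implementation: the function names (`softmax`, `filter_nll`), the row-wise softmax parameterization, and the use of plain Python lists are assumptions made for illustration. It computes only the forward pass and the loss; in practice the gradients for stochastic gradient descent would presumably come from an automatic differentiation framework.

```python
import math

def softmax(logits):
    """Map a list of logits to a probability vector (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def filter_nll(pi_logits, A_logits, B_logits, obs):
    """Run the HMM forward filter and return (mean next-observation NLL, beliefs).

    pi_logits: logits of the initial state distribution, length S
    A_logits:  S x S logits; row i gives the transition distribution from state i
    B_logits:  S x V logits; row i gives the emission distribution of state i
    obs:       list of observation indices in range(V)
    All names and shapes here are illustrative assumptions.
    """
    pi = softmax(pi_logits)
    A = [softmax(row) for row in A_logits]
    B = [softmax(row) for row in B_logits]
    n = len(pi)

    prior = pi          # predicted state distribution before seeing obs[t]
    nll = 0.0
    beliefs = []
    for y in obs:
        # Joint of state and current observation given the history.
        joint = [prior[i] * B[i][y] for i in range(n)]
        # Predictive probability of the observation given earlier observations.
        p_y = sum(joint)
        nll -= math.log(p_y)
        # Bayes update: posterior belief over hidden states.
        belief = [j / p_y for j in joint]
        beliefs.append(belief)
        # Propagate the belief through the transition matrix.
        prior = [sum(belief[i] * A[i][j] for i in range(n)) for j in range(n)]
    return nll / len(obs), beliefs
```

Because each step uses only the observations seen so far, the computation is causal in the same sense as a decoder-only model, and the learned logits can be read back as HMM parameters by applying `softmax` to them.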