Efficient Autoregressive Inference for Transformer Probabilistic Models

ICLR 2026 Conference Submission13982 Authors

18 Sept 2025 (modified: 08 Oct 2025), ICLR 2026 Conference Submission, CC BY 4.0
Keywords: probabilistic machine learning, neural processes, probabilistic meta-learning, amortized inference
TL;DR: We accelerate autoregressive inference of transformer probabilistic models such as prior-fitted networks and transformer neural processes.
Abstract: Transformer-based models for amortized probabilistic inference, such as neural processes, prior-fitted networks, and tabular foundation models, excel at single-pass marginal prediction. However, many real-world applications -- from signal interpolation to multi-column tabular predictions -- require coherent joint distributions that capture dependencies between predictions. While purely autoregressive architectures efficiently generate such distributions, they sacrifice the flexible set-conditioning that makes these models powerful for meta-learning. Conversely, the standard approach to obtaining joint distributions from set-based models requires expensive re-encoding of the entire augmented conditioning set at each autoregressive step. We introduce a causal autoregressive buffer that preserves the advantages of both paradigms. Our approach decouples context encoding from updating the conditioning set. The model processes the context once and caches it. A dynamic buffer then captures target dependencies: as targets are incorporated, they enter the buffer and attend to both the cached context and previously buffered targets. This enables efficient batched autoregressive generation and one-pass joint log-likelihood evaluation. A unified training strategy allows seamless integration of set-based and autoregressive modes at minimal additional cost. Across synthetic functions, EEG signals, cognitive models, and tabular data, our method matches the predictive accuracy of strong baselines while delivering up to $20\times$ faster joint sampling. Our approach combines the efficiency of autoregressive generative models with the representational power of set-based conditioning, making joint prediction practical for transformer-based probabilistic models.
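Illustrative sketch (not from the submission): the attention pattern described in the abstract, where buffered targets attend to the full cached context and to previously buffered targets, can be pictured as a boolean mask. The function name `buffer_attention_mask`, its parameters, and the choice to let each buffered target also attend to itself are assumptions made for illustration only; the authors' actual masking scheme may differ.

```python
def buffer_attention_mask(n_context, n_buffer):
    """Return a list of rows; mask[i][j] is True when buffered target i
    may attend to key j. Keys are ordered as [context..., buffer...].

    Hypothetical helper: a sketch of the pattern named in the abstract,
    not the authors' implementation.
    """
    mask = []
    for i in range(n_buffer):
        # Every buffered target sees the whole cached context.
        to_context = [True] * n_context
        # Buffered target i sees buffered targets 0..i.
        # Including position i itself is an assumption of this sketch.
        to_buffer = [j <= i for j in range(n_buffer)]
        mask.append(to_context + to_buffer)
    return mask


mask = buffer_attention_mask(n_context=3, n_buffer=4)
for row in mask:
    # Print each row as 1/0 for readability.
    print("".join("1" if allowed else "0" for allowed in row))
```

Because the context columns are identical for every row, the context keys and values only need to be computed once and can be reused as the buffer grows, which is the caching idea the abstract describes. The lower-triangular buffer block is what would let all conditionals in the chain-rule factorization be scored in a single pass.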
Supplementary Material: zip
Primary Area: probabilistic methods (Bayesian methods, variational inference, sampling, UQ, etc.)
Submission Number: 13982