Keywords: adaptive computation, thinking tokens, pretraining, architecture, transformers
TL;DR: We propose a novel architecture that enables unsupervised parallel adaptive computation by forking residual streams, trained using only LM loss.
Abstract: Current approaches for scaling inference-time compute in transformers rely on training them to emit explicit chain-of-thought tokens before producing an answer. While powerful, these methods cannot be applied during pretraining, and they can scale inference-time compute only through serially generated, natural-language verbalization. In this work, we propose **Thoughtbubbles**, a transformer variant that natively performs parallel adaptive computation in latent space by learning to fork or delete residual streams. Thus, tokens that require a large amount of computation can form a "bubble" of cloned residuals in the middle of the network for additional computation. Crucially, this behavior is learned during pretraining with only language modeling loss. **Thoughtbubbles** outperforms both standard decoder LMs and non-adaptive parallel computation approaches on OpenWebText and peS2o perplexity, as well as in zero-shot evaluations such as HellaSwag and LAMBADA, after pretraining at scales from 150M to 772M parameters. The implicit nature of our method enables adaptive computation to be learned starting at pretraining time, paving the way to unify train-time and test-time scaling behaviors.
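To make the idea of forking and deleting residual streams concrete, here is a minimal toy sketch of the bookkeeping involved. It is a hypothetical illustration, not the Thoughtbubbles mechanism: the function name `fork_residual_streams`, the per-stream `scores`, and the two hard thresholds are all assumptions standing in for whatever learned scoring the paper uses. Each stream remembers which original token it belongs to; low-scoring streams are deleted, high-scoring streams are cloned, and a token's clones stay adjacent so later layers can process them together. Because hard thresholds are not differentiable, this toy could not itself be trained with LM loss alone as the abstract describes.

```python
import numpy as np


def fork_residual_streams(residuals, token_ids, scores,
                          fork_threshold=0.5, keep_threshold=0.1):
    """Hypothetical fork/delete step over a set of residual streams.

    residuals: (n, d) array, one row per residual stream.
    token_ids: (n,) int array, the original token each stream belongs to.
    scores:    (n,) array in [0, 1], an assumed stand-in for a learned score.

    Streams scoring below keep_threshold are deleted; streams scoring at or
    above fork_threshold are kept and also cloned once.
    """
    assert keep_threshold <= fork_threshold
    keep = scores >= keep_threshold
    fork = scores >= fork_threshold

    # Never delete every stream of a token: a token with no surviving stream
    # would have nothing left to predict from, so keep its best-scoring one.
    for t in np.unique(token_ids):
        idx = np.flatnonzero(token_ids == t)
        if not keep[idx].any():
            keep[idx[np.argmax(scores[idx])]] = True

    # Surviving streams first, then one clone of each forked stream.
    new_residuals = np.concatenate([residuals[keep], residuals[fork]], axis=0)
    new_token_ids = np.concatenate([token_ids[keep], token_ids[fork]])

    # Stable sort so each token's streams (original plus clones) are contiguous.
    order = np.argsort(new_token_ids, kind="stable")
    return new_residuals[order], new_token_ids[order]


# Usage: four tokens, one stream each; tokens 0 and 3 score high and fork.
residuals = np.arange(8, dtype=float).reshape(4, 2)
out_r, out_t = fork_residual_streams(
    residuals, np.array([0, 1, 2, 3]), np.array([0.9, 0.3, 0.05, 0.7]))
print(out_t.tolist())  # [0, 0, 1, 2, 3, 3]
```

In this toy, the number of streams per token, and hence the compute spent on it, varies with its score, which is the "bubble" of extra residuals the abstract refers to; the real model would also need a differentiable way to weight and merge these streams back into each token's prediction.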
Primary Area: unsupervised, self-supervised, semi-supervised, and supervised representation learning
Submission Number: 10324