T-REX: Tied Recurrence Extraction
Keywords: optimization, looped transformers, recurrent neural networks, reasoning
TL;DR: T-REX optimizes transformers in an untied parameter space, then projects to exactly tied weights at inference, yielding recurrent models with better optimization and stronger out-of-distribution reasoning.
Abstract: Recurrent and looped transformers offer a promising inductive bias for reasoning by reusing the same computation across depth, but exact weight tying can make these models difficult to optimize. We introduce Tied Recurrence Extraction (T-REX), a simple training framework that bridges untied and tied transformers. T-REX trains layers with independent parameters and periodically projects them toward their shared parameter-wise mean, allowing optimization to benefit from untied flexibility while encouraging convergence to a recurrent computation. At inference time, the learned parameters are projected to their mean, yielding an exactly tied recurrent model. We show that parameter groups converge to their across-layer mean under learning-rate decay and illustrate how temporary untying can help escape local minima that trap tied models. Empirically, T-REX improves out-of-distribution generalization on algorithmic reasoning tasks, including soft distance propagation, maze solving, and in-context graph traversal, where it learns more stable iterative rules than both untied and tied-weight transformer baselines.
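The train-untied, project-periodically, tie-at-inference loop described in the abstract can be sketched on toy scalar parameters. This is a minimal illustration under stated assumptions, not the authors' implementation: the partial-projection form θ_l ← (1 − α)θ_l + α·μ (with μ the across-layer mean), the strength `alpha`, the projection period `every`, the linear learning-rate decay, the per-layer quadratic loss, and all function names (`project_toward_mean`, `tie_for_inference`, `toy_train`) are hypothetical choices made for illustration.

```python
from typing import List

Params = List[float]  # hypothetical layout: one flattened parameter group per layer


def layer_mean(layers: List[Params]) -> Params:
    """Parameter-wise mean across layers."""
    n = len(layers)
    return [sum(vals) / n for vals in zip(*layers)]


def project_toward_mean(layers: List[Params], alpha: float) -> List[Params]:
    """Move every layer a fraction `alpha` of the way to the across-layer mean.

    alpha = 1.0 gives exact tying; 0 < alpha < 1 is a partial projection.
    The interpolation form and `alpha` are assumptions, not taken from the paper.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    mu = layer_mean(layers)
    return [[(1.0 - alpha) * p + alpha * m for p, m in zip(layer, mu)] for layer in layers]


def max_deviation(layers: List[Params]) -> float:
    """Largest absolute gap between any layer's parameter and the mean."""
    mu = layer_mean(layers)
    return max(abs(p - m) for layer in layers for p, m in zip(layer, mu))


def tie_for_inference(layers: List[Params]) -> Params:
    """Collapse all layers to one shared parameter group (an exactly tied recurrence)."""
    return layer_mean(layers)


def toy_train(layers: List[Params], targets: List[Params], steps: int = 200,
              lr0: float = 0.1, every: int = 10, alpha: float = 0.5) -> List[Params]:
    """Hypothetical toy loop: per-layer SGD on 0.5 * (p - t)^2 with a linearly
    decaying learning rate, projecting toward the mean every `every` steps."""
    layers = [list(layer) for layer in layers]  # copy so the caller's lists are untouched
    for step in range(steps):
        lr = lr0 * (1.0 - step / steps)  # linear decay toward zero
        for layer, target in zip(layers, targets):
            for i in range(len(layer)):
                layer[i] -= lr * (layer[i] - target[i])  # gradient of 0.5 * (p - t)^2
        if (step + 1) % every == 0:
            layers = project_toward_mean(layers, alpha)
    return layers


# Worked example: three "layers", each with two parameters.
layers = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]
half = project_toward_mean(layers, 0.5)
print(half)                     # [[1.5, 3.0], [2.5, 1.0], [2.0, 2.0]]
print(tie_for_inference(half))  # [2.0, 2.0]
```

Because the projection interpolates toward the mean, it leaves the across-layer mean unchanged and shrinks every layer's deviation from it by a factor of (1 − α). The strength α therefore controls only how far layers may drift apart between projections, and the final collapse in `tie_for_inference` recovers a single shared parameter group.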
Email Sharing: We authorize the sharing of all author emails with Program Chairs.
Data Release: We authorize the release of our submission and author names to the public in the event of acceptance.
Submission Number: 55