Keywords: in-context learning, memory-augmented transformers, memformers, first-order methods, conjugate gradient descent, transformers
TL;DR: Memory-augmented transformers can implement linear first-order optimization methods (LFOMs), including advanced algorithms such as conjugate gradient descent, by leveraging memory registers to store past gradients.
Abstract: We show that memory-augmented Transformers (Memformers) can implement linear first-order optimization methods such as conjugate gradient descent, momentum methods, and, more generally, methods that linearly combine past gradients. Building on prior work demonstrating how Transformers can simulate preconditioned gradient descent, we provide theoretical and empirical evidence that Memformers can learn more advanced optimization algorithms. Specifically, we analyze how memory registers in Memformers store suitable intermediate attention values, allowing them to implement algorithms such as conjugate gradient. Our results show that Memformers can efficiently learn these methods by training on random linear regression tasks, even learning methods that outperform conjugate gradient. This work advances our understanding of the algorithmic capabilities of Transformers, showing how they can learn complex optimization methods.
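To make the LFOM template concrete, here is a minimal sketch (not the paper's method) of the update x_{k+1} = x_k - sum_{i<=k} gamma_{k,i} * grad_i on a random linear regression task; the stored gradient list stands in for a memory register, and the geometric weights `eta`, `beta` are illustrative hand-chosen stand-ins for coefficients a Memformer would learn.

```python
import numpy as np

# Minimal LFOM sketch on least squares (assumptions: loss 0.5*||Ax - b||^2,
# hand-chosen geometric combination weights rather than learned ones).
rng = np.random.default_rng(0)
n, d = 20, 5
A = rng.standard_normal((n, d))
b = rng.standard_normal(n)

def grad(x):
    # Gradient of the linear regression loss 0.5 * ||A x - b||^2.
    return A.T @ (A @ x - b)

x = np.zeros(d)
memory = []            # "memory register": all past gradients
eta, beta = 0.01, 0.5  # illustrative base step size and decay
for k in range(15):
    memory.append(grad(x))
    # Geometric decay over past gradients mimics heavy-ball momentum,
    # itself a special case of an LFOM.
    gammas = [eta * beta ** (len(memory) - 1 - i) for i in range(len(memory))]
    x = x - sum(g_i * grad_i for g_i, grad_i in zip(gammas, memory))
    print(k, 0.5 * np.linalg.norm(A @ x - b) ** 2)
```

Conjugate gradient fits the same template: its update is likewise a linear combination of past gradients, with coefficients chosen adaptively per step rather than fixed in advance.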
Primary Area: optimization
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2025/AuthorGuide.
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors’ identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Submission Number: 11444