Keywords: Numerical Linear Algebra, Transformers, Out-of-distribution, Generalization
TL;DR: We show that current transformer methods for linear algebra merely learn statistics of their training data rather than generalizable algorithms.
Abstract: Matrix operations, such as linear solves, eigendecompositions, and log determinants, are foundational building blocks for many downstream applications. Any broadly capable learning system should therefore be able to approximate these operations effectively in its internal representation. Accordingly, there is great motivation to study transformers for linear algebra: if transformers cannot even semi-competently perform matrix operations, then we cannot expect them to form a basis for a generally intelligent system. We demonstrate that current techniques for developing transformers for linear algebra suffer from striking failure modes, prohibitive scaling, and particularly poor out-of-distribution generalization to other matrix distributions and to matrices of different sizes. Investigating further, we find that current transformer approaches operate as statistical interpolators rather than discovering algorithms that generalize to matrices from other distributions. Based on our understanding of these limitations, we develop a sequence of interventions that substantially improve scaling and performance: matrix embeddings through a learnable projection, linear attention, looping, and a pre-training data distribution of structured matrices. We term the resulting method the \emph{RangeFormer}, and show that it has significantly improved scaling and performance on challenging OOD matrices from the \emph{Matrix Market}. Moreover, with the RangeFormer we show for the first time that transformers can be successfully applied to downstream tasks that involve iterative matrix operations, including Gaussian process learning and improving the sampling distribution of randomized methods.
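To make the architectural interventions above concrete, the following is a minimal PyTorch sketch of a weight-tied (looped) linear-attention block operating on matrix embeddings produced by a learnable projection. All names and design details here (e.g. RangeFormerSketch, emb_dim, n_loops, the elu-based feature map) are illustrative assumptions, not the paper's actual implementation.

# Sketch only: looped linear-attention over learnable row embeddings.
# Module and argument names are assumptions, not the authors' code.
import torch
import torch.nn as nn
import torch.nn.functional as F


def linear_attention(q, k, v):
    # Kernelized attention phi(Q) (phi(K)^T V), avoiding the n x n softmax matrix.
    phi = lambda x: F.elu(x) + 1.0
    q, k = phi(q), phi(k)
    kv = torch.einsum("bnd,bne->bde", k, v)                      # d x e key/value summary
    z = 1.0 / (torch.einsum("bnd,bd->bn", q, k.sum(dim=1)) + 1e-6)
    return torch.einsum("bnd,bde,bn->bne", q, kv, z)


class LoopedLinearBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x):
        q, k, v = self.qkv(self.norm1(x)).chunk(3, dim=-1)
        x = x + linear_attention(q, k, v)
        return x + self.mlp(self.norm2(x))


class RangeFormerSketch(nn.Module):
    """Embed each matrix row with a learnable projection, then loop a single
    weight-tied linear-attention block a fixed number of times."""

    def __init__(self, n, emb_dim=64, n_loops=8):
        super().__init__()
        self.embed = nn.Linear(n, emb_dim)       # learnable projection: rows -> tokens
        self.block = LoopedLinearBlock(emb_dim)  # one block, reused (looping)
        self.n_loops = n_loops
        self.readout = nn.Linear(emb_dim, n)

    def forward(self, A):                        # A: (batch, n, n)
        h = self.embed(A)
        for _ in range(self.n_loops):            # weight-tied iteration
            h = self.block(h)
        return self.readout(h)                   # e.g. predicted solution/factor rows


# Usage: approximate an n x n matrix operation on a batch of inputs.
model = RangeFormerSketch(n=32)
out = model(torch.randn(4, 32, 32))              # shape (4, 32, 32)

Weight tying across the loop keeps the parameter count independent of iteration depth, and the linear attention keeps cost linear in the number of row tokens, which is consistent with the scaling improvements claimed in the abstract.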
Primary Area: interpretability and explainable AI
Submission Number: 19637