TL;DR: We theoretically characterize the benefits of test-time training for in-context learning and corroborate them empirically on tabular foundation models.
Abstract: Test-time training (TTT) methods explicitly update the weights of a model to adapt to the specific test instance, and they have found success in a variety of settings, including most recently language modeling and reasoning. To demystify this success, we investigate a gradient-based TTT algorithm for in-context learning, where we train a transformer model on the in-context demonstrations provided in the test prompt. Specifically, we provide a comprehensive theoretical characterization of linear transformers when the update rule is a single gradient step. Our theory (i) delineates the role of alignment between the pretraining distribution and the target task, (ii) demystifies how TTT can alleviate distribution shift, and (iii) quantifies the sample complexity of TTT, including how it can significantly reduce the eventual sample size required for in-context learning. As our empirical contribution, we study the benefits of TTT for TabPFN, a tabular foundation model. In line with our theory, we demonstrate that TTT significantly reduces the sample size required for tabular classification (3 to 5 times fewer examples), unlocking substantial inference efficiency at a negligible training cost.
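To make the procedure concrete, below is a minimal PyTorch sketch of single-step TTT for in-context learning. The `ttt_predict` function, the model's `(demos_x, demos_y, query_x)` interface, and the leave-one-out loss on the demonstrations are illustrative assumptions for exposition, not the paper's exact implementation.

```python
# Minimal sketch of single-step test-time training (TTT) for in-context learning.
# Assumption: `model` maps a prompt of (x, y) demonstrations plus a query x to a
# prediction for the query's label.
import copy
import torch
import torch.nn as nn


def ttt_predict(model: nn.Module, demos_x: torch.Tensor, demos_y: torch.Tensor,
                query_x: torch.Tensor, lr: float = 1e-2) -> torch.Tensor:
    """Adapt a copy of `model` with one gradient step on the test prompt's
    demonstrations, then predict the query label with the adapted weights."""
    adapted = copy.deepcopy(model)          # keep the pretrained weights intact
    adapted.train()
    opt = torch.optim.SGD(adapted.parameters(), lr=lr)

    # TTT loss built from the demonstrations alone: predict each held-out
    # demonstration's label from the remaining ones (leave-one-out ICL loss).
    n = demos_x.shape[0]
    loss = 0.0
    for i in range(n):
        keep = [j for j in range(n) if j != i]
        pred_i = adapted(demos_x[keep], demos_y[keep], demos_x[i:i + 1])
        loss = loss + nn.functional.mse_loss(pred_i, demos_y[i:i + 1])
    loss = loss / n

    # Single gradient step on the model weights, as in the theoretical setting.
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Predict the query with the adapted model, conditioning on all demonstrations.
    adapted.eval()
    with torch.no_grad():
        return adapted(demos_x, demos_y, query_x)
```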
Lay Summary: Modern machine learning models, like large language models, can perform well across many tasks but may still struggle when encountering unfamiliar ones. A technique called "in-context learning" addresses this by providing relevant examples (demonstrations) of a task in the prompt, helping the model understand the new task. Our work looks for ways to leverage these demonstrations further by updating the model's parameters with them. To do this, we explore a method called test-time training (TTT), which adapts the model through a single learning step on the examples of the new task provided at test time. We develop a theoretical framework showing how, even with just one learning step, TTT can enhance the model's ability to learn from the given demonstrations. Adapting the model to a new task with TTT significantly reduces the number of demonstrations needed, making inference more data-efficient by letting the model reach high performance from fewer examples. We confirm these findings on a real-world tabular foundation model, designed for table-like data: with TTT, the model maintains high performance while needing three to five times fewer examples.
Primary Area: Theory->Deep Learning
Keywords: in-context learning, transformers, test-time training, gradient descent
Submission Number: 14794