Eureka-Moments in Transformers: Multi-Step Tasks Reveal Softmax Induced Optimization Problems

ICLR 2024 Workshop ME-FoMo Submission 53 Authors

Published: 04 Mar 2024, Last Modified: 04 May 2024, ME-FoMo 2024 Poster, CC BY 4.0
Keywords: sudden convergence, transformer, transformers, grokking, in-context learning, Softmax, attention, temperature, multi-step task, multi-step decision task, two-step task, gradient, phase transition, abrupt learning, rapid improvements, eureka-moment, eureka moment
TL;DR: We study saturated loss curves followed by rapid improvements of the loss of transformers on multi-step tasks. Similar leaps can be found for in-context learning. We find that the Softmax in the attention block leads to small gradients; normalizing the attention logits alleviates this.
Abstract: In this work, we study rapid improvements of the training loss in transformers when confronted with multi-step decision tasks. We found that transformers struggle to learn the intermediate task and that both training and validation loss saturate for hundreds of epochs. When transformers finally learn the intermediate task, they do so rapidly and unexpectedly. We call these abrupt improvements Eureka-moments, since the transformer appears to suddenly learn a previously incomprehensible concept. We designed synthetic tasks to study the problem in detail, but the leaps in performance can also be observed for language modeling and in-context learning (ICL). We suspect that these abrupt transitions are caused by the multi-step nature of these tasks. Indeed, we find connections and show that ways to improve on multi-step tasks can be used to improve the training of language modeling and ICL. Using the synthetic data, we trace the problem back to the Softmax function in the self-attention block of transformers and show ways to alleviate the problem. These fixes reduce the required number of training steps, lead to a higher likelihood of learning the intermediate task, and yield higher final accuracy.
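Since the abstract attributes the stalled learning to the Softmax in self-attention and lists temperature among the relevant knobs, here is a minimal sketch of scaled dot-product attention with an adjustable Softmax temperature. This is an illustration of the general mechanism, not the authors' exact implementation; the function name, tensor shapes, and the `temperature` argument are assumptions for the example.

```python
# Minimal sketch (assumed interface, not the paper's code): scaled dot-product
# attention with a Softmax temperature knob. Flattening the Softmax (temperature > 1)
# is one way to keep the attention weights from saturating, which is the kind of
# small-gradient issue the abstract traces back to the Softmax.
import math
import torch
import torch.nn.functional as F

def attention_with_temperature(q, k, v, temperature: float = 1.0):
    """q, k, v: tensors of shape (batch, heads, seq_len, head_dim)."""
    d = q.size(-1)
    # Standard scaled dot-product attention logits.
    logits = q @ k.transpose(-2, -1) / math.sqrt(d)
    # Dividing the logits by a temperature > 1 softens the Softmax distribution,
    # so gradients through the attention weights do not collapse as easily.
    weights = F.softmax(logits / temperature, dim=-1)
    return weights @ v

# Usage: compare the default Softmax with a softened variant.
q = k = v = torch.randn(2, 4, 16, 32)
out_default = attention_with_temperature(q, k, v, temperature=1.0)
out_soft = attention_with_temperature(q, k, v, temperature=4.0)
```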
Submission Number: 53