Keywords: signal propagation, transformers, pre-training, statistical physics
TL;DR: We develop an exact theory of signal propagation in deep transformers that unifies two failure modes previously discussed separately: rank collapse and entropy collapse.
Abstract: Finding the right initialisation for neural networks is crucial to ensure smooth training and good performance. In transformers, the wrong initialisation can lead to one of two failure modes of self-attention layers: rank collapse, where all tokens collapse into similar representations, and entropy collapse, where highly concentrated attention scores lead to training instability. While previous work has studied different scaling regimes for transformers, an asymptotically exact, down-to-the-constant prescription for how to initialise transformers has so far been lacking. Here, we provide an analytical theory of signal propagation through deep transformers with self-attention, layer normalisation, skip connections and MLP blocks. Our theory yields a simple algorithm to compute trainability diagrams that identify the correct choice of initialisation hyper-parameters for a given architecture. We overcome the key challenge, an exact treatment of the self-attention layer, by establishing a formal parallel with the Random Energy Model from statistical physics. We also analyse gradients in the backward pass and determine the regime in which gradients vanish at initialisation. We demonstrate the versatility of our framework through three case studies. Our framework gives a unified perspective on the two failure modes of self-attention and makes quantitative predictions for the scales of both weights and residual connections that guarantee smooth training.
Submission Number: 6
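As a rough illustration of the kind of trainability diagram the abstract describes, here is a minimal, hypothetical sketch; it is not the authors' algorithm. It forward-propagates random tokens through a toy stack of self-attention layers with skip connections and, for each weight scale sigma, tracks two proxies: mean token cosine similarity (a rank-collapse indicator) and mean attention entropy (an entropy-collapse indicator). All names, the crude normalisation step, and the specific metrics are illustrative assumptions.

```python
# Hypothetical sketch of a trainability scan; not the paper's exact algorithm.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_layer(X, sigma, rng):
    """One toy self-attention layer with Gaussian weights of scale sigma."""
    d = X.shape[1]
    Wq, Wk, Wv = (rng.normal(0.0, sigma / np.sqrt(d), size=(d, d)) for _ in range(3))
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    A = softmax(Q @ K.T / np.sqrt(d), axis=-1)  # (tokens x tokens) attention matrix
    return A, X + A @ V                          # skip connection

def diagnostics(X, A):
    """Rank-collapse proxy (mean off-diagonal cosine similarity) and attention entropy."""
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    cos = Xn @ Xn.T
    T = X.shape[0]
    mean_sim = (cos.sum() - T) / (T * (T - 1))            # exclude the diagonal of ones
    entropy = -(A * np.log(A + 1e-12)).sum(axis=-1).mean()  # mean per-token entropy
    return mean_sim, entropy

rng = np.random.default_rng(0)
T, d, depth = 32, 64, 24
for sigma in [0.5, 1.0, 2.0, 4.0]:
    X = rng.normal(size=(T, d))
    for _ in range(depth):
        A, X = attention_layer(X, sigma, rng)
        X = X / X.std()  # crude stand-in for layer normalisation, keeps scales bounded
    sim, ent = diagnostics(X, A)
    print(f"sigma={sigma:.1f}  token-similarity={sim:.3f}  attn-entropy={ent:.3f}")
```

High token similarity at small sigma signals rank collapse, while low attention entropy at large sigma signals entropy collapse; scanning over initialisation hyper-parameters in this way traces out a simple empirical analogue of a trainability diagram.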