Deep Mean Field Theory: Variance and Width Variation by Layer as Methods to Control Gradient Explosion
- Abstract: A recent line of work has studied the statistical properties of neural networks with great success from a {\it mean field theory} perspective, making and verifying very precise predictions of neural network behavior and test time performance. In this paper, we build upon these previous works to explore two methods for taming the behaviors of random residual networks (with only fully connected layers and no batchnorm). The first method is {\it width variation (WV)}, i.e. varying the widths of layers as a function of depth. We show that width decay reduces gradient explosion without affecting the mean forward dynamics of the random network. The second method is {\it variance variation (VV)}, i.e. changing the initialization variances of weights and biases over depth. We show that VV, used appropriately, can reduce gradient explosion of tanh and ReLU resnets from $\exp(\Theta(\sqrt L))$ and $\exp(\Theta(L))$ respectively to constant $\Theta(1)$. A complete phase diagram is derived for how variance decay affects different dynamics, such as those of gradient and activation norms. In particular, we show the existence of many phase transitions where these dynamics switch between exponential, polynomial, logarithmic, and even constant behaviors. Using the obtained mean field theory, we are able to track surprisingly well how VV at initialization time affects training and test time performance on MNIST after a set number of epochs: the level sets of test/train set accuracies coincide with the level sets of certain gradient norms.
- TL;DR: By setting the width or the initialization variance of each layer differently, we can actually subdue gradient explosion problems in residual networks (with fully connected layers and no batchnorm). A mathematical theory is developed that not only tells you how to do it, but also surprisingly is able to predict, after you apply such tricks, how fast your network trains to achieve a certain test set performance. This is some black magic stuff, and it's called "Deep Mean Field Theory."
- Keywords: mean field, dynamics, residual network, variance variation, width variation, initialization
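
To make the variance variation idea concrete, here is a minimal sketch rather than the paper's actual model or experiments. It assumes a simplified bias-free residual block $x_l = x_{l-1} + W_l \tanh(x_{l-1})$, weight entries drawn with variance $\sigma_w^2 \, l^{-\beta} / N$ at layer $l$ (the power-law form of the decay is an assumption made for illustration), and hypothetical helper names `variance_schedule` and `gradient_growth`. It backpropagates a random output gradient through one random network and reports how much its squared norm grows on the way to the input.

```python
import math
import random


def variance_schedule(sigma2, beta, depth):
    """Per-layer weight variances sigma2 * l^(-beta) for l = 1..depth (assumed form)."""
    return [sigma2 * l ** (-beta) for l in range(1, depth + 1)]


def gradient_growth(beta, depth=64, width=64, sigma2=1.0, seed=0):
    """Ratio |g_0|^2 / |g_L|^2 for one random tanh resnet with variance decay beta."""
    rng = random.Random(seed)
    x = [rng.gauss(0, 1) for _ in range(width)]
    layers = []  # (W_l, tanh(x_{l-1})) pairs, kept for the backward pass

    # Forward pass: x_l = x_{l-1} + W_l tanh(x_{l-1})
    for var in variance_schedule(sigma2, beta, depth):
        std = math.sqrt(var / width)
        W = [[rng.gauss(0, std) for _ in range(width)] for _ in range(width)]
        phi = [math.tanh(v) for v in x]
        layers.append((W, phi))
        x = [x[i] + sum(W[i][j] * phi[j] for j in range(width))
             for i in range(width)]

    # Backward pass: g_{l-1} = g_l + diag(1 - tanh^2(x_{l-1})) W_l^T g_l
    g = [rng.gauss(0, 1) for _ in range(width)]
    out_sq = sum(v * v for v in g)
    for W, phi in reversed(layers):
        wtg = [sum(W[i][j] * g[i] for i in range(width)) for j in range(width)]
        g = [g[j] + (1 - phi[j] ** 2) * wtg[j] for j in range(width)]
    return sum(v * v for v in g) / out_sq


if __name__ == "__main__":
    for beta in (0.0, 0.5, 1.0, 2.0):
        print(f"beta={beta}: gradient growth {gradient_growth(beta):.3g}")
```

Running the script prints the growth factor for a few decay exponents; in this toy setting, larger `beta` should give markedly smaller gradient growth. A single random draw at finite width is noisy, so averaging over several seeds would give a fairer comparison with any mean field prediction.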