Deep Mean Field Theory: Layerwise Variance and Width Variation as Methods to Control Gradient Explosion

Greg Yang, Sam S. Schoenholz

Feb 15, 2018 (modified: Feb 15, 2018) ICLR 2018 Conference Blind Submission readers: everyone
  • Abstract: A recent line of work has studied the statistical properties of neural networks to great success from a {\it mean field theory} perspective, making and verifying very precise predictions of neural network behavior and test time performance. In this paper, we build upon these works to explore two methods for taming the behaviors of random residual networks (with only fully connected layers and no batchnorm). The first method is {\it width variation (WV)}, i.e., varying the widths of layers as a function of depth. We show that width decay reduces gradient explosion without affecting the mean forward dynamics of the random network. The second method is {\it variance variation (VV)}, i.e., changing the initialization variances of weights and biases over depth. We show that VV, used appropriately, can reduce gradient explosion of tanh and ReLU resnets from $\exp(\Theta(\sqrt L))$ and $\exp(\Theta(L))$ respectively to constant $\Theta(1)$. A complete phase diagram is derived for how variance decay affects different dynamics, such as those of gradient and activation norms. In particular, we show the existence of many phase transitions where these dynamics switch between exponential, polynomial, logarithmic, and even constant behaviors. Using the obtained mean field theory, we are able to track surprisingly well how VV at initialization time affects training and test time performance on MNIST after a set number of epochs: the level sets of test/train set accuracies coincide with the level sets of the expectations of certain gradient norms or of metric expressivity (as defined in \cite{yang_meanfield_2017}), a measure of expansion in a random neural network. Based on insights from past works in deep mean field theory and information geometry, we also provide a new perspective on the gradient explosion/vanishing problems: they lead to ill-conditioning of the Fisher information matrix, causing optimization troubles.
  • TL;DR: By setting the width or the initialization variance of each layer differently, we can actually subdue gradient explosion problems in residual networks (with fully connected layers and no batchnorm). A mathematical theory is developed that not only tells you how to do it, but also surprisingly is able to predict, after you apply such tricks, how fast your network trains to achieve a certain test set performance. This is some black magic stuff, and it's called "Deep Mean Field Theory."
  • Keywords: mean field, dynamics, residual network, variance variation, width variation, initialization
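
The variance variation idea in the abstract can be illustrated with a small numerical sketch. It is not the authors' code: the residual update x_l = x_{l-1} + V_l tanh(W_l x_{l-1} + b_l), the power-law schedule (variance proportional to l^{-beta}), the shared exponent for weights and biases, and the function name grad_growth are all assumptions made for illustration. The sketch backpropagates a random stand-in gradient from the last layer to the input and reports how much its squared norm grows.

```python
import numpy as np

def grad_growth(depth, width, beta, seed=0):
    """Ratio ||dE/dx_0||^2 / ||dE/dx_L||^2 for one random tanh resnet.

    Assumed (hypothetical) setup: x_l = x_{l-1} + V_l tanh(W_l x_{l-1} + b_l),
    with entries of W_l, V_l drawn with variance l**(-beta) / width and
    entries of b_l with variance l**(-beta). beta = 0 means no decay.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(width)
    cache = []
    # Forward pass through the randomly initialized residual network.
    for l in range(1, depth + 1):
        scale = float(l) ** (-beta)
        W = rng.standard_normal((width, width)) * np.sqrt(scale / width)
        V = rng.standard_normal((width, width)) * np.sqrt(scale / width)
        b = rng.standard_normal(width) * np.sqrt(scale)
        h = W @ x + b
        x = x + V @ np.tanh(h)
        cache.append((W, V, h))
    # Random vector standing in for the gradient of a loss w.r.t. x_L.
    g = rng.standard_normal(width)
    out_sq = g @ g
    # Backward pass: g_{l-1} = g_l + W^T (tanh'(h) * (V^T g_l)).
    for W, V, h in reversed(cache):
        g = g + W.T @ ((1.0 - np.tanh(h) ** 2) * (V.T @ g))
    return (g @ g) / out_sq

if __name__ == "__main__":
    for beta in (0.0, 1.0):
        print(f"beta={beta}: growth={grad_growth(50, 200, beta):.3g}")
```

Running it with beta = 0 and beta = 1 at the same depth and width gives a rough feel for how decaying the initialization variance with depth tempers the growth of the backpropagated gradient. A single random draw is only suggestive; the paper's claims concern expectations over initializations.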