Critical Initialization of Wide and Deep Neural Networks using Partial Jacobians: General Theory and Applications

Published: 21 Sept 2023, Last Modified: 14 Jan 2024, NeurIPS 2023 spotlight
Keywords: Criticality, Gaussian Process, Jacobian, LayerNorm, Residual connections, ResNet
TL;DR: (i) We introduce a new diagnostic for critical initialization in a wide class of deep neural networks. (ii) We show that a combination of Normalization layers and residual connections leads to everywhere-critical architectures.
Abstract: Deep neural networks are notorious for defying theoretical treatment. However, when the number of parameters in each layer tends to infinity, the network function is a Gaussian process (GP) and a quantitatively predictive description is possible. The Gaussian approximation allows one to formulate criteria for selecting hyperparameters, such as the variances of weights and biases, as well as the learning rate. These criteria rely on the notion of criticality defined for deep neural networks. In this work we describe a new practical way to diagnose criticality. We introduce *partial Jacobians* of a network, defined as derivatives of preactivations in layer $l$ with respect to preactivations in layer $l_0\leq l$. We derive recurrence relations for the norms of partial Jacobians and use these relations to analyze criticality of deep fully connected neural networks with LayerNorm and/or residual connections. We derive and implement a simple and cheap numerical test that allows one to select the optimal initialization for a broad class of deep neural networks containing fully connected, convolutional, and normalization layers. Using these tools, we show quantitatively that proper stacking of LayerNorm (applied to preactivations) and residual connections leads to an architecture that is critical for any initialization. Finally, we apply our methods to analyze ResNet and MLP-Mixer architectures, demonstrating the everywhere-critical regime.
Supplementary Material: zip
Submission Number: 5449
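
Illustrative sketch (not the authors' released code): the abstract defines partial Jacobians as derivatives of preactivations in layer $l$ with respect to preactivations in layer $l_0\leq l$ and describes a cheap numerical test based on their norms. A minimal, dependency-free way to probe such a norm at initialization might look like the following. Everything specific here is an assumption made for illustration: the function name `apjn_estimate`, the bias-free constant-width ReLU MLP, the weight variance scaling $\sigma_w^2/N$, the normalization of the squared Frobenius norm by the width, and the use of random probe vectors instead of the full Jacobian.

```python
import math
import random


def apjn_estimate(sigma_w2, depth, width=200, n_probes=3, seed=0):
    """Estimate (1/width) * ||d h^l / d h^{l0}||_F^2 with l - l0 = depth.

    Hypothetical setup: bias-free ReLU MLP of constant width, weights drawn
    i.i.d. from N(0, sigma_w2 / width). The Frobenius norm is estimated with
    Gaussian probe vectors v, using E ||J v||^2 = ||J||_F^2 when E[v v^T] = I.
    """
    rng = random.Random(seed)
    std = math.sqrt(sigma_w2 / width)
    # Preactivations at layer l0 (assumed unit-variance Gaussian).
    h = [rng.gauss(0.0, 1.0) for _ in range(width)]
    probes = [[rng.gauss(0.0, 1.0) for _ in range(width)]
              for _ in range(n_probes)]

    for _ in range(depth):
        mask = [1.0 if x > 0.0 else 0.0 for x in h]      # ReLU derivative
        act = [x * m for x, m in zip(h, mask)]           # ReLU(h)
        masked = [[p * m for p, m in zip(v, mask)] for v in probes]
        new_h = []
        new_probes = [[] for _ in probes]
        for _ in range(width):
            # One row of the weight matrix, shared by the forward pass
            # and by the Jacobian-vector products.
            row = [rng.gauss(0.0, std) for _ in range(width)]
            new_h.append(sum(w * a for w, a in zip(row, act)))
            for out, v in zip(new_probes, masked):
                out.append(sum(w * x for w, x in zip(row, v)))
        h, probes = new_h, new_probes

    sq_norms = [sum(x * x for x in v) for v in probes]
    return sum(sq_norms) / (n_probes * width)


if __name__ == "__main__":
    for sigma_w2 in (1.0, 2.0, 4.0):
        print(sigma_w2, apjn_estimate(sigma_w2, depth=6))
```

For this ReLU toy model, the standard infinite-width estimate is that each layer multiplies the norm by about $\sigma_w^2/2$, so the value should decay with depth for $\sigma_w^2<2$, grow for $\sigma_w^2>2$, and stay of order one near $\sigma_w^2=2$. Scanning the initialization and looking for the depth-independent regime is the spirit of the diagnostic described in the abstract; at finite width the estimate fluctuates from seed to seed.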