Fast Finite Width Neural Tangent Kernel

Published: 28 Jan 2022, Last Modified: 22 Oct 2023; ICLR 2022 Submitted; Readers: Everyone
Keywords: Neural Tangent Kernel, NTK, Finite Width, Fast, Algorithm, JAX, Jacobian, Software
Abstract: The Neural Tangent Kernel (NTK), defined as the outer product of the neural network (NN) Jacobians, $\Theta_\theta(x_1, x_2) = \left[\partial f(\theta, x_1)\big/\partial \theta\right] \left[\partial f(\theta, x_2)\big/\partial \theta\right]^T$, has emerged as a central object of study in deep learning. In the infinite width limit, the NTK can sometimes be computed analytically and is useful for understanding training and generalization of NN architectures. At finite widths, the NTK is also used to better initialize NNs, compare the conditioning across models, perform architecture search, and do meta-learning. Unfortunately, the finite-width NTK is notoriously expensive to compute, which severely limits its practical utility. We perform the first in-depth analysis of the compute and memory requirements for NTK computation in finite width networks. Leveraging the structure of neural networks, we further propose two novel algorithms that change the exponent of the compute and memory requirements of the finite width NTK, dramatically improving efficiency. We open-source (https://github.com/iclr2022anon/fast_finite_width_ntk) our two algorithms as general-purpose JAX function transformations that apply to any differentiable computation (convolutions, attention, recurrence, etc.) and introduce no new hyper-parameters.
One-sentence Summary: We develop and open-source a new algorithm for fast computation of the finite width Neural Tangent Kernel, the outer product of Jacobians of a neural network.
Supplementary Material: zip
Community Implementations: [5 code implementations](https://www.catalyzex.com/paper/arxiv:2206.08720/code)
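
To make the definition in the abstract concrete, the following is a minimal JAX sketch of the direct computation $\Theta_\theta(x_1, x_2) = J(x_1) J(x_2)^T$, where $J(x) = \partial f(\theta, x) / \partial \theta$. It materializes both Jacobians and contracts them. This is the expensive baseline the abstract refers to, not either of the two proposed fast algorithms. The toy network `f`, the helper names `init_params`, `flat_jacobian`, and `ntk`, and all shapes are assumptions chosen for illustration; they are not taken from the paper or its repository.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in=3, width=16, d_out=2):
    # Hypothetical two-layer network parameters (illustration only).
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "w2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width),
    }


def f(params, x):
    # Maps a single input of shape (d_in,) to an output of shape (d_out,).
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


def flat_jacobian(params, x):
    # Jacobian of f with respect to params: a pytree whose leaves have
    # shape (d_out, *param_shape). Flatten it into one (d_out, P) matrix.
    jac = jax.jacobian(f)(params, x)
    leaves = jax.tree_util.tree_leaves(jac)
    return jnp.concatenate(
        [leaf.reshape(leaf.shape[0], -1) for leaf in leaves], axis=1
    )


def ntk(params, x1, x2):
    # Direct definition: contract the two Jacobians over the parameter axis.
    j1 = flat_jacobian(params, x1)
    j2 = flat_jacobian(params, x2)
    return j1 @ j2.T  # shape (d_out, d_out)


key = jax.random.PRNGKey(0)
params = init_params(key)
x1 = jnp.array([1.0, 0.5, -0.3])
x2 = jnp.array([-0.2, 0.8, 0.1])
theta = ntk(params, x1, x2)
```

For a single pair of inputs the result is a `d_out × d_out` matrix. The cost of this approach grows with the number of parameters times the number of outputs, because each Jacobian is stored in full before the contraction.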
