Fast Finite Width Neural Tangent Kernel

Published: 29 Jan 2022, Last Modified: 22 Oct 2023 · AABI 2022 Poster · Readers: Everyone
Keywords: Neural Tangent Kernel, NTK, Finite Width, Fast, Algorithm, JAX, Jacobian, Software
TL;DR: We develop and open-source a new algorithm for fast computation of the finite width Neural Tangent Kernel, the outer product of Jacobians of a neural network.
Abstract: The Neural Tangent Kernel (NTK), defined as the outer product of the neural network (NN) Jacobians, $\Theta_\theta(x_1, x_2) = \left[\partial f(\theta, x_1)\big/\partial \theta\right] \left[\partial f(\theta, x_2)\big/\partial \theta\right]^T$, has emerged as a central object of study in deep learning. In the infinite width limit, the NTK provides a precise posterior distribution of the NN outputs after training, and can be used for uncertainty estimation and modelling deep ensembles. However, the infinite width NTK rarely admits a closed-form solution, and when it does, it is usually prohibitively expensive to compute exactly, and a finite width NTK can be used as a Monte Carlo approximation. The finite width NTK also models approximate inference in finite Bayesian NNs (BNNs), and has widespread applications in deep learning theory, meta-learning, neural architecture search, and many other areas. Unfortunately, the finite width NTK is also notoriously expensive to compute, which severely limits its practical utility. We perform the first in-depth analysis of the compute and memory requirements for NTK computation in finite width networks. Leveraging the structure of neural networks, we further propose two novel algorithms that change the exponent of the compute and memory requirements of the finite width NTK, improving efficiency by orders of magnitude in a wide range of practical models on all major hardware platforms. We open-source our two algorithms (https://github.com/iclr2022anon/fast_finite_width_ntk) as general-purpose JAX function transformations that apply to any differentiable computation (convolutions, attention, recurrence, etc.) and introduce no new hyper-parameters.
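For intuition on the quantity being accelerated, below is a minimal sketch of the baseline "Jacobian contraction" way of computing the finite width NTK in JAX, i.e. $\Theta_\theta(x_1, x_2) = \left[\partial f(\theta, x_1)\big/\partial \theta\right] \left[\partial f(\theta, x_2)\big/\partial \theta\right]^T$ built from explicit Jacobians. This is not the paper's optimized algorithms or the released library's API; the toy network `f` and the helper name `ntk_fn` are illustrative assumptions.

```python
# Baseline finite width NTK via explicit Jacobian contraction (illustrative only;
# the paper's contribution is two faster algorithms than this naive approach).
import jax
import jax.numpy as jnp


def f(params, x):
    # Toy one-hidden-layer network: inputs [N, D] -> outputs [N, O].
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2


def ntk_fn(f, params, x1, x2):
    # Jacobians of the outputs w.r.t. every parameter leaf, for both input batches.
    j1 = jax.jacobian(f)(params, x1)  # pytree of [N1, O, *param_shape] arrays
    j2 = jax.jacobian(f)(params, x2)  # pytree of [N2, O, *param_shape] arrays

    def contract(a, b):
        # Flatten the parameter axes and contract over them:
        # Theta[n1, o1, n2, o2] = sum_theta J1[n1, o1, theta] * J2[n2, o2, theta].
        a = a.reshape(a.shape[0], a.shape[1], -1)
        b = b.reshape(b.shape[0], b.shape[1], -1)
        return jnp.einsum('nop,mqp->nomq', a, b)

    # The NTK is the sum of contributions from all parameter leaves.
    leaves1 = jax.tree_util.tree_leaves(j1)
    leaves2 = jax.tree_util.tree_leaves(j2)
    return sum(contract(a, b) for a, b in zip(leaves1, leaves2))


key = jax.random.PRNGKey(0)
x1 = jax.random.normal(key, (3, 5))
x2 = jax.random.normal(key, (4, 5))
params = (jax.random.normal(key, (5, 7)), jax.random.normal(key, (7, 2)))
theta = ntk_fn(f, params, x1, x2)
print(theta.shape)  # (3, 2, 4, 2): [N1, O, N2, O]
```

Even this small example shows why the naive approach scales poorly: it materializes full Jacobians over all parameters before contracting them, which is exactly the cost the paper's two algorithms avoid.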
Community Implementations: [5 code implementations](https://www.catalyzex.com/paper/arxiv:2206.08720/code)