Keywords: function space distance, continual learning, transformers
TL;DR: We generalize a data-free estimator of function space distance to arbitrary architectures, including transformers, and demonstrate its successful application to continual learning.
Abstract: Measuring how neural network functions evolve during training, fine-tuning, or editing is critical for several applications. Such shifts can be formalized via a function space distance (FSD), the expected squared difference in network outputs under a data distribution, but computing the true FSD requires dataset access that is often infeasible. The previously proposed Linearized Activation Function TRick (LAFTR) circumvents this challenge via approximations specific to networks composed of linear layers and ReLU activations. We extend this to a more general LInearized Function TRick (LIFTR), enabling data-free FSD estimation for arbitrary architectures, with a particular focus on transformers. Our approach decomposes FSD estimation into moment propagation using only pre-computed activation statistics of the data, resulting in a modular implementation that generalizes readily to arbitrary functions. On a modular arithmetic continual learning task, we show that a stochastic variant of LIFTR approaches oracle performance while outperforming parameter-space linearization baselines. LIFTR estimates correlate strongly with the oracle FSD and yield better-aligned gradients than competing methods. We further demonstrate that LIFTR degrades more gracefully with network depth than global parameter-space linearization.
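The abstract defines FSD as the expected squared difference in network outputs under a data distribution. For reference, a minimal sketch of the oracle quantity that LIFTR approximates without dataset access is given below; the function names (`oracle_fsd`, `f_old`, `f_new`) and the per-example squared-L2 normalization are assumptions for illustration, not the paper's implementation.

```python
import torch

@torch.no_grad()
def oracle_fsd(f_old, f_new, data_loader):
    """Oracle FSD: expected squared difference between the outputs of two
    networks, averaged over examples drawn from the data distribution.
    Computing this directly requires the dataset, which LIFTR avoids by
    relying on pre-computed activation statistics instead."""
    total_sq_diff, n_examples = 0.0, 0
    for x, *_ in data_loader:
        diff = f_new(x) - f_old(x)                    # per-example output difference
        total_sq_diff += diff.pow(2).sum().item()     # squared L2 norm over output dims
        n_examples += x.shape[0]
    # Monte Carlo estimate of E_x[ ||f_new(x) - f_old(x)||^2 ]
    return total_sq_diff / n_examples
```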
Submission Number: 57