On the geometry of generalization and memorization in deep neural networks

Published: 12 Jan 2021, Last Modified: 05 May 2023
ICLR 2021 Poster
Readers: Everyone
Keywords: deep learning theory, representation learning, statistical physics methods, double descent
Abstract: Understanding how large neural networks avoid memorizing training data is key to explaining their high generalization performance. To examine the structure of when and where memorization occurs in a deep network, we use a recently developed replica-based mean field theoretic geometric analysis method. We find that all layers preferentially learn from examples which share features, and link this behavior to generalization performance. Memorization occurs predominantly in the deeper layers, due to a decrease in the radius and dimension of object manifolds, whereas early layers are minimally affected. This predicts that generalization can be restored by reverting the weights of the final few layers to earlier epochs, before significant memorization occurred, which is confirmed by experiments. Additionally, by studying generalization under different model sizes, we reveal the connection between the double descent phenomenon and the underlying model geometry. Finally, an analytical study shows that networks avoid memorization early in training because, close to initialization, the gradient contributions from permuted examples are small. These findings provide quantitative evidence for the structure of memorization across layers of a deep neural network, the drivers of this structure, and its connection to manifold geometric properties.
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics
One-sentence Summary: We analyze the representational geometry of deep neural networks during generalization and memorization, characterizing when and where memorization occurs across layers as well as the drivers of this emerging structure.
Data: [CIFAR-100](https://paperswithcode.com/dataset/cifar-100), [ImageNet](https://paperswithcode.com/dataset/imagenet)
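
The layer-rewinding experiment mentioned in the abstract (reverting the final few layers to weights from an earlier epoch while keeping the late-epoch weights elsewhere) can be illustrated with a minimal sketch. This is not the authors' code: the function name `rewind_final_layers`, the layer names, and the representation of a checkpoint as a plain dictionary of parameter lists are all assumptions made for illustration, and the numbers are toy values rather than results.

```python
from typing import Dict, List, Sequence

# Hypothetical checkpoint format: layer name -> flat list of parameters.
Weights = Dict[str, List[float]]


def rewind_final_layers(
    late: Weights,
    early: Weights,
    layer_order: Sequence[str],
    n_final: int,
) -> Weights:
    """Return a copy of `late` whose last `n_final` layers use `early` weights."""
    if not 0 <= n_final <= len(layer_order):
        raise ValueError("n_final must be between 0 and the number of layers")
    # Names of the layers to revert, counted from the output end of the network.
    rewound = set(layer_order[len(layer_order) - n_final:])
    return {
        name: list(early[name] if name in rewound else late[name])
        for name in layer_order
    }


# Toy checkpoints with made-up numbers, only to show the mechanics.
layer_order = ["conv1", "conv2", "fc"]
early_ckpt = {"conv1": [0.1], "conv2": [0.2], "fc": [0.3]}
late_ckpt = {"conv1": [1.1], "conv2": [1.2], "fc": [1.3]}

mixed = rewind_final_layers(late_ckpt, early_ckpt, layer_order, n_final=1)
print(mixed)  # {'conv1': [1.1], 'conv2': [1.2], 'fc': [0.3]}
```

In a real framework the same idea would be applied to saved model checkpoints, and the mixed model would then be evaluated on held-out data to check whether generalization is restored.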
