TL;DR: We propose SAFE, an optimization-based pruning method that improves generalization of sparse models by inducing flatness.
Abstract: Sparsifying neural networks often incurs seemingly inevitable performance degradation, and restoring the original performance remains challenging despite much recent progress.
Motivated by recent studies in robust optimization, we tackle this problem by finding subnetworks that are both sparse and flat.
Specifically, we formulate pruning as a sparsity-constrained optimization problem where flatness is encouraged as an objective.
We solve it explicitly via an augmented Lagrangian dual approach and extend it further with a generalized projection operation, resulting in a novel pruning method, SAFE, and its extension, SAFE$^+$.
Extensive evaluations on standard image classification and language modeling tasks show that SAFE consistently yields sparse networks with improved generalization, performing competitively against well-established baselines.
In addition, SAFE demonstrates resilience to noisy data, making it well-suited for real-world conditions.
Lay Summary: Pruning aims to reduce the memory and computational load of neural networks by zeroing out parameters, but often at the cost of accuracy. The core challenge is to preserve performance during pruning, or better yet, to design methods that progress toward optimal performance. Such objectives are best studied through the lens of optimization, where recent advances have offered new insights and strategies.
Notably, recent research has emphasized the benefits of flat minima—regions in the loss landscape with low curvature.
Intuitively, a flat minimum is a solution whose loss remains stable under small shifts in the parameters.
This has not only been shown to improve robustness and generalization in many prior studies, but is also especially desirable when the parameters themselves are modified through pruning.
This motivates SAFE, a constrained optimization algorithm that seeks better sparse solutions by enforcing flatness while gradually imposing sparsity, building on established optimization techniques. We further extend this into SAFE+, which flexibly supports diverse pruning scores within the same constrained optimization framework.
Our results show that SAFE and SAFE+ successfully find flatter sparse minima, improving over existing baselines on both image classification and post-training pruning of language models, as well as improving robustness to label noise, common image corruptions, and adversarial attacks.
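The flatness-plus-sparsity idea described above can be sketched in a few lines. This is an illustrative toy, not the released implementation: it assumes a SAM-style gradient perturbation for flatness, a hard-thresholding top-k projection for the sparsity constraint, and an augmented-Lagrangian pull toward the projected point; all function names and hyperparameters here are hypothetical.

```python
import numpy as np

def topk_project(w, k):
    """Keep the k largest-magnitude entries of w, zero the rest (hard thresholding)."""
    z = np.zeros_like(w)
    idx = np.argsort(np.abs(w))[-k:]
    z[idx] = w[idx]
    return z

def safe_sketch(grad_f, w0, k, rho=0.05, mu=1e-2, lr=1e-2, steps=500):
    """Toy SAFE-style loop: sharpness-aware gradient step plus an
    augmented-Lagrangian penalty pulling w toward its sparse projection."""
    w = w0.copy()
    lam = np.zeros_like(w)  # dual variable for the sparsity constraint
    for _ in range(steps):
        g = grad_f(w)
        # SAM-style ascent direction: evaluate the gradient at a perturbed point
        eps = rho * g / (np.linalg.norm(g) + 1e-12)
        g_flat = grad_f(w + eps)
        z = topk_project(w, k)                        # current sparse reference
        w = w - lr * (g_flat + lam + mu * (w - z))    # primal descent step
        lam = lam + mu * (w - topk_project(w, k))     # dual ascent step
    return topk_project(w, k)                         # final hard projection
```

On a toy quadratic loss, the loop drives the small coordinates toward zero while the dominant coordinates settle near their unconstrained optimum, so the final projection discards little.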
Link To Code: https://github.com/LOG-postech/safe-torch, https://github.com/LOG-postech/safe-jax
Primary Area: Deep Learning->Algorithms
Keywords: Pruning, Constrained optimization, Sharpness minimization
Submission Number: 3619