Keywords: representation learning, neural networks, deep learning
TL;DR: KLAS is a stitch selection algorithm that improves accuracy-efficiency curves by leveraging KL divergence to identify stitching points between pretrained models, without any additional training cost.
Abstract: Given the wide range of deployment targets, flexible model selection is essential for optimizing performance within a given compute budget.
Recent work demonstrates that stitching pretrained models within a model family enables cost-effective interpolation of the accuracy-efficiency tradeoff space.
Stitching transforms intermediate activations from one pretrained model into another, producing a new interpolated stitched network.
Such networks provide a pool of deployment options along the accuracy-efficiency spectrum.
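To make the stitching operation concrete, below is a minimal sketch, not the paper's implementation: the stitch is assumed to be a learned linear projection (common in the stitching literature), and all class/function names (`Stitch`, `stitched_forward`, `front_blocks`, `back_blocks`) are illustrative.

```python
import torch
import torch.nn as nn

class Stitch(nn.Module):
    """Maps activations from layer i of a 'front' model into the input
    space expected by layer j of a 'back' model (assumed linear here)."""

    def __init__(self, front_dim: int, back_dim: int):
        super().__init__()
        self.proj = nn.Linear(front_dim, back_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

def stitched_forward(front_blocks, back_blocks, stitch, x, i, j):
    """Run blocks 0..i of the front model, apply the stitch,
    then run blocks j.. of the back model."""
    for blk in front_blocks[: i + 1]:
        x = blk(x)
    x = stitch(x)
    for blk in back_blocks[j:]:
        x = blk(x)
    return x
```

Each choice of $(i, j)$ yields a different stitched network, whose cost and accuracy sit between those of the two parent models.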
However, existing stitching approaches often yield suboptimal tradeoffs and lack generalizability, as they primarily rely on heuristics to select stitch configurations.
We argue that constructing improved accuracy-efficiency tradeoffs requires explicitly capturing and leveraging the _similarity_ between pretrained models being stitched.
To this end, we introduce KLAS, a novel stitch selection framework that automates and generalizes stitch selection across model families by leveraging KL divergence between intermediate representations.
KLAS identifies the most promising stitches from the $\mathcal{O}(n^k)$ possibilities for $k$ pretrained models of depth $n$.
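A hedged sketch of what KL-based stitch scoring could look like is shown below; the exact scoring procedure is the paper's contribution, so this is only an assumption-laden illustration. Here activations at candidate layers are turned into distributions via a softmax over the feature dimension, compared with KL divergence on a shared probe batch, and the lowest-divergence pairs are kept; `kl_score`, `select_stitches`, and the matching-width restriction are all hypothetical.

```python
import torch
import torch.nn.functional as F

def kl_score(acts_a: torch.Tensor, acts_b: torch.Tensor) -> float:
    """acts_*: (num_samples, dim) activations from models A and B
    at candidate stitch layers, computed on the same probe inputs."""
    p = F.softmax(acts_a, dim=-1)          # distribution induced by model A
    log_q = F.log_softmax(acts_b, dim=-1)  # log-distribution from model B
    # F.kl_div(input=log_q, target=p) computes KL(p || q), averaged per sample.
    return F.kl_div(log_q, p, reduction="batchmean").item()

def select_stitches(acts_front: dict, acts_back: dict, top_k: int = 5):
    """acts_front/back: layer_index -> (N, dim) activation tensors.
    Returns the top_k (i, j) layer pairs with the lowest KL divergence."""
    scores = {
        (i, j): kl_score(a, b)
        for i, a in acts_front.items()
        for j, b in acts_back.items()
        if a.shape[-1] == b.shape[-1]  # compare only matching widths in this sketch
    }
    return sorted(scores, key=scores.get)[:top_k]
```

Because scoring only requires forward passes over a probe set, a selection rule of this shape avoids training every candidate stitch, consistent with the no-extra-training-cost claim above.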
Through comprehensive experiments, we demonstrate that KLAS produces an improved accuracy-efficiency curve of stitched models at the same cost as baselines.
KLAS achieves up to $1.21\%$ higher ImageNet-1K top-1 accuracy at the same computational cost, or maintains accuracy with a $1.33\times$ reduction in FLOPs.
Primary Area: unsupervised, self-supervised, semi-supervised, and supervised representation learning
Submission Number: 20851