Keywords: representation learning, self-supervised learning, data-augmentation, learning dynamics, sample efficient SSL, compute efficient SSL
TL;DR: We characterize the implicit biases imposed by learning dynamics along with architecture and loss function for self-supervised representation learning, leading to practical recommendations for improving its sample and compute efficiency.
Abstract: Recent progress in self-supervised learning (SSL) of visual representations has produced several frameworks that all rely on image augmentations but use different loss functions.
However, there are few theoretically grounded principles to guide practice, so practical implementation of each SSL framework requires several heuristics to achieve competitive performance.
In this work, we build on recent analytical results to design practical recommendations for competitive and efficient SSL that are grounded in theory.
Specifically, recent theory tells us that existing SSL frameworks are in fact minimizing the same idealized loss: learning features that best match the data similarity kernel defined by the augmentations used.
We show how this idealized loss can be reformulated as a functionally equivalent loss that is more efficient to compute.
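The abstract does not spell out the reformulation. As a purely illustrative sketch, assume the idealized loss is the squared Frobenius distance between a symmetric similarity kernel K (n x n) and the feature Gram matrix Z Z^T, with Z the n x d feature matrix. A generic algebraic identity then shows how such a loss can be rewritten so that its quadratic term only needs the d x d matrix Z^T Z. This is an assumption for illustration, not necessarily the authors' reformulation; the function names are hypothetical.

```python
import numpy as np

def kernel_matching_loss(K, Z):
    """Direct form: ||K - Z Z^T||_F^2, which materializes an n x n matrix."""
    return float(np.sum((K - Z @ Z.T) ** 2))

def kernel_matching_loss_expanded(K, Z):
    """Equivalent form for symmetric K:
    ||K||_F^2 - 2 tr(Z^T K Z) + ||Z^T Z||_F^2.
    The last term only uses the d x d Gram matrix Z^T Z."""
    gram = Z.T @ Z  # d x d, cheap when d << n
    return float(np.sum(K ** 2) - 2.0 * np.trace(Z.T @ K @ Z) + np.sum(gram ** 2))

# Usage: both forms agree on a random symmetric kernel.
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 8))
K = A @ A.T  # symmetric positive semi-definite stand-in for the kernel
Z = rng.standard_normal((8, 3))
print(kernel_matching_loss(K, Z), kernel_matching_loss_expanded(K, Z))
```

The identity relies on K being symmetric, so that the cross term <K, Z Z^T> equals tr(Z^T K Z), and on ||Z Z^T||_F^2 = ||Z^T Z||_F^2.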
We study the implicit bias of using gradient descent to minimize our reformulated loss function and find that a stronger orthogonalization constraint combined with a reduced projector dimensionality should yield good representations.
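To make "orthogonalization constraint" and "projector dimensionality" concrete, here is a minimal sketch assuming a generic penalty that pushes the empirical covariance of the d-dimensional projector outputs toward the identity, added to a simple two-view invariance term. The weight `beta` (a stronger constraint means a larger `beta`) and all function names are hypothetical; this is not presented as the authors' exact objective.

```python
import numpy as np

def orthogonality_penalty(Z):
    """||Cov(Z) - I_d||_F^2 for projector outputs Z of shape (n, d)."""
    n, d = Z.shape
    Zc = Z - Z.mean(axis=0, keepdims=True)  # center each feature
    cov = Zc.T @ Zc / n                      # d x d empirical covariance
    return float(np.sum((cov - np.eye(d)) ** 2))

def ssl_loss(Z1, Z2, beta):
    """Hypothetical objective: invariance between two augmented views
    plus beta times the orthogonality penalty on each view."""
    invariance = float(np.mean(np.sum((Z1 - Z2) ** 2, axis=1)))
    penalty = 0.5 * (orthogonality_penalty(Z1) + orthogonality_penalty(Z2))
    return invariance + beta * penalty

# Usage: these columns are centered, unit-variance and uncorrelated.
Z_white = np.array([[1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]])
print(orthogonality_penalty(Z_white))                # whitened features: no penalty
print(orthogonality_penalty(np.zeros((10, 3))))      # collapsed features: penalty d
```

A smaller d shrinks the covariance matrix being constrained, which makes a strong constraint cheaper to enforce.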
Furthermore, the theory tells us that the approximation of the reformulated loss should improve as the number of augmentations increases, so using multiple augmentations should lead to improved convergence.
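The benefit of more augmentations can be illustrated with a toy Monte-Carlo experiment, assuming that the quantity being approximated is an expectation over augmentations, E_aug[f(aug(x))]. The encoder here is the identity and the "augmentation" adds Gaussian noise. Both are hypothetical stand-ins chosen so that the true mean is x itself. The squared error of the m-augmentation average should shrink roughly as 1/m.

```python
import numpy as np

def toy_augment_and_embed(x, rng, noise_std=1.0):
    # Hypothetical stand-in for f(aug(x)): identity encoder, Gaussian-noise augmentation.
    return x + noise_std * rng.standard_normal(x.shape)

def augmentation_average(x, m, rng):
    # Monte-Carlo estimate of E_aug[f(aug(x))] from m augmentations.
    return np.mean([toy_augment_and_embed(x, rng) for _ in range(m)], axis=0)

def estimation_error(m, trials=2000, dim=16, seed=0):
    # Mean squared error of the m-sample estimate; the true mean here is x.
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(dim)
    errs = [np.sum((augmentation_average(x, m, rng) - x) ** 2) for _ in range(trials)]
    return float(np.mean(errs))

for m in (2, 4, 8):
    print(m, round(estimation_error(m), 2))
```

With unit noise in 16 dimensions, the expected error is about 16/m, so doubling the number of augmentations roughly halves it.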
We empirically verify our findings on CIFAR, STL and ImageNet datasets, where we demonstrate improved linear readout performance when training a ResNet backbone using our theoretically grounded recommendations.
Remarkably, we also demonstrate that by leveraging these insights, we can reduce the pretraining dataset size by up to 2$\times$ while maintaining downstream accuracy simply by using more data augmentations.
Taken together, our work provides theoretically grounded recommendations that can be used to improve SSL convergence and efficiency.
Supplementary Material: zip
Primary Area: Machine vision
Submission Number: 19645