Sample Submission
01 Dec 2021 | Distributionally Robust Neural Networks, Regularization, Robustness, Empirical Risk Minimization
Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-case Generalization
Content:
- Problem definition
- Dataset
- ERM vs DRO
- Possible solution in the literature (DRO)
- ERM and DRO with Heavy regularization
- ERM and DRO with group adjustments
- The proposed Optimization Algorithm
- Compare to existing algorithms (importance weighting)
- Takeaway.
Problem definition:
While state-of-the-art deep learning models such as ResNet and BERT are widely used today because of their high average performance, these models may not perform equally well on specific subsets or groups of the data.
In some medical applications, for example, a single mistake can lead to catastrophic events. It is not acceptable to have a model that fails on inputs from a specific group of patients (e.g., Black people, women, babies), even if these groups are rare within the population. A model with 97% average accuracy that fails on most of the data from babies is still unacceptable, even if babies represent only 1% of the patients. Groups such as these are minority groups in the dataset, and we need to make sure that our model performs as well as possible on all groups.
The cause of this undesirable tendency is the way we train our neural networks. Standard training minimizes the average loss and does not look at how the model performs on each group of examples. This is the standard Empirical Risk Minimization (ERM) objective in equation (\ref{eq1}). With ERM it is possible to end up with some groups having very low losses and other groups having very high losses, so looking at the average accuracy/loss alone is not enough to judge whether the model performs well on all groups.
\[\begin{equation} \hat{\theta}_{ERM} := \arg\min_{\theta\in\Theta} \mathbb{E}_{(x,y)\sim \hat{P}} [\ell(\theta ; (x, y))]\label{eq1}\tag{1} \end{equation}\]
- $(x, y)$ are the input features and label, drawn from some distribution $\hat{P}$.
- $\Theta$ is the model family.
- $\ell$ is the loss function.
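To make the objective concrete, here is a minimal PyTorch-style sketch of one ERM training step (the model, optimizer, and function names are illustrative, not the paper's code); it simply minimizes the mean loss over a batch, ignoring which group each example comes from.

```python
import torch
import torch.nn.functional as F

def erm_step(model, optimizer, x, y):
    """One ERM update: minimize the average loss over the batch,
    with no knowledge of which group each example belongs to."""
    optimizer.zero_grad()
    logits = model(x)
    loss = F.cross_entropy(logits, y)  # mean loss over the batch
    loss.backward()
    optimizer.step()
    return loss.item()
```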
Average Accuracy and Worst-Group Accuracy:
Suppose that we have a dataset of images of two types of birds: waterbirds, which usually appear near water in nature, and landbirds, which usually appear on land. All the waterbird images have a water background and all the landbird images have a land background.
Table(1): An example of data with (waterbirds, water background) and (landbirds, land background) images.
Here the model can learn to classify the background instead of the foreground object (the bird) and get the same accuracy as classifying the actual birds. In this case, there is a "correlation" between the background and the bird class. Let's now add a few images that break this correlation: (waterbirds, land background) and (landbirds, water background) images.
Table(2): Adding a few (waterbirds, land background) and (landbirds, water background) images.
Since the (waterbirds, water background) and (landbirds, land background) images dominate both the training and test datasets, we can still get good average test accuracy by just looking at the background.
Although we are performing well on average, the accuracy for the two minority groups of images, (waterbirds, land background) and (landbirds, water background), could be very low since the model does not see many images from these groups during training. We will call the group with the worst accuracy (performance) the worst group. In the figure below, while the average accuracy is 97%, the (waterbirds, land background) group is the worst group with just 20% accuracy.
Figure(1): The model can achieve a high average test accuracy but poor accuracy on some groups of images.
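As a small illustrative sketch (not the paper's evaluation code), average and worst-group accuracy can be computed from per-example correctness and integer group labels as follows.

```python
import torch

def average_and_worst_group_accuracy(preds, labels, groups, n_groups):
    """preds, labels, groups: 1-D tensors of the same length.
    Returns (average accuracy, worst-group accuracy)."""
    correct = (preds == labels).float()
    avg_acc = correct.mean().item()
    group_accs = []
    for g in range(n_groups):
        mask = groups == g
        if mask.any():  # skip groups with no test examples
            group_accs.append(correct[mask].mean().item())
    return avg_acc, min(group_accs)
```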
Now it is clear that the model can give a good average test accuracy but a low worst group test accuracy. A direct solution is to optimize for the worst-case training loss and this is what an optimization method called Distributionally Robust Optimization (DRO) does. The paper studies the efficiency of DRO in improving the worst-group test error.
This paper studies the problem in the setting of overparameterized neural networks, where the network can perfectly fit all the training data (all the groups) and achieve good average test performance while still performing very poorly on some (minority) groups at test time. The paper studies how to make the worst-group test performance as good as possible while maintaining good average test performance.
Before diving into the paper let’s define some of the terms used:
Group: A group is a sub-distribution in a mixture distribution. Empirically, this corresponds to a subset of the sampled dataset. The paper assumes that we have prior knowledge about the data to define these groups.
Overparameterized Neural Network: The models of interest in the paper are those that have enough capacity (parameters) to achieve near-optimal average training and test accuracy. We refer to these models as overparameterized neural networks; as examples, this paper focuses on two networks, ResNet50 and BERT. Most prior DRO work has focused on the classic (underparameterized) setting. In contrast, this paper studies group DRO in the overparameterized regime with vanishing training loss and poor worst-case generalization.
Datasets:
The paper utilizes three datasets: two image classification datasets and one natural language inference dataset. All the datasets share some features:
a. Each dataset consists of several groups.
b. Each dataset has groups that contain a large number of images/sentences. Within these groups there is a correlation between the class of each example and some feature of the example (like the correlation between bird type and background above).
c. Each dataset has groups with a small number of examples. These groups contain examples that break the correlation in (b), which is what makes the correlation in (b) a spurious correlation.
For all the datasets, most of the examples follow this class-feature correlation, and a small chunk of the dataset breaks it.
- Object recognition with correlated backgrounds (Waterbirds dataset, constructed from the CUB dataset (Wah et al., 2011)): A dataset of images of two kinds of birds, waterbirds and landbirds, with one of two backgrounds, water or land. The groups (waterbirds, water) and (landbirds, land) represent the majority of the images. We have only a few images from the groups (waterbirds, land) and (landbirds, water).
- Object recognition with correlated demographics (CelebA dataset), (Liu et al., 2015): A dataset of face images. Most of the male images have dark hair and most of the female images have blond hair. We have a few images from the groups (male, blond hair), and (female, dark hair).
- Natural Language Inference (MultiNLI dataset), (Williams et al., 2018):
Here we have two sentences, a hypothesis and a premise, and we want to predict whether the hypothesis is entailed by, neutral with, or contradicts the given premise. The presence of negation words (like no, not, and never) often indicates contradiction, but this is not always true.
Samples from the datasets are shown in figure (2).
Figure(2): Samples from the Waterbirds, CelebA, and MultiNLI datasets.
The paper uses two models to study the worst-case performance on the three datasets:
- ResNet50 for the Waterbirds and CelebA datasets.
- BERT for the MultiNLI dataset.
Empirical Risk Minimization (ERM) vs Distributionally Robust Optimization (DRO):
The paper finds that overparameterized neural networks can perfectly fit all the training data (all the groups) for all three datasets (Waterbirds, CelebA, and MultiNLI) and give good average test performance, but very poor worst-group test performance (see table 3).
This happens because, during training, we optimize with ERM for the average loss only and we don’t consider the worst group.
The direct solution is to optimize for the worst-case loss itself. In Distributionally Robust Optimization (DRO) we try to minimize the worst (highest) loss over the groups. In other words, we compute the loss for every group of examples and minimize the highest one (equation \ref{eq2}).
\(\begin{equation}
\hat{\theta}_{DRO} := \arg\min_{\theta\in\Theta} \{\max_{g\in\mathbb{G}} \mathbb{E}_{(x,y)\sim \hat{P}_g} [\ell(\theta ; (x, y))]\}\label{eq2}\tag{2}
\end{equation}\)
This means that we need to predefine the groups in our dataset. The question that arises is how to define these groups; the answer is that we need some domain knowledge of the dataset at hand. Here we consider the case where we have prior knowledge about the groups in the dataset.
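For concreteness, here is a minimal sketch of the group DRO objective in equation (\ref{eq2}): compute the average loss separately for each group present in the batch and take the maximum. The batching strategy and names are illustrative, not the authors' implementation.

```python
import torch
import torch.nn.functional as F

def worst_group_loss(model, x, y, groups, n_groups):
    """Return the maximum per-group average loss within the batch."""
    logits = model(x)
    per_example = F.cross_entropy(logits, y, reduction="none")
    group_losses = []
    for g in range(n_groups):
        mask = groups == g
        if mask.any():  # only groups represented in this batch
            group_losses.append(per_example[mask].mean())
    return torch.stack(group_losses).max()
```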
Both ERM and DRO Have Poor Worst-group Test Accuracy In The Overparameterized Regime:
When tested on the above datasets, ERM models perform poorly on the worst-case group at test time.
Even with DRO, models perform similarly to ERM models, attaining near-perfect training accuracies and high average test accuracies, but poor worst-group test accuracies (see table 3). This indicates that the models can generalize well on average (good average test accuracy), but they do not generalize well on the worst-case group (good worst-case group performance on the training set and bad worst-case group performance on the test set).
Table(3): Both ERM and DRO models perform poorly on the worst-case group in the absence of regularization
We can see in table (3) that with DRO we achieve good average train-test generalization, but we still suffer from poor worst-group generalization.
DRO Improves Worst-group Accuracy Only Under Appropriate Regularization:
The gap between the worst-group training performance and the worst-group test performance is called the (worst-group) generalization gap. This gap suggests an overfitting problem often encountered in ML models; a natural solution is therefore to apply more regularization on top of the standard regularization used for ResNet and BERT. The paper experiments with two regularization methods: a strong $\ell_{2}$ penalty and early stopping.
As expected, regularization changes the results (see table (4)).
Under heavy regularization, DRO attains high worst-group accuracy on both the training and test sets, whereas ERM fails to do so and instead attains low worst-group accuracy and high average accuracy on both.
It is also important to choose the best regularization technique for each dataset/model. The paper mentions that $\ell_{2}$ regularization works well with ResNet50 (Waterbirds and CelebA datasets) but fails with BERT (MultiNLI dataset), where early stopping is used instead. See table (4) below for detailed results and the regularization parameters used for each dataset.
Table(4): Average and worst-group accuracies for each training method. Both ERM and DRO models perform poorly on the worst-case group in the absence of regularization (top). With strong regularization (middle, bottom), DRO achieves high worst-group performance, significantly improving over ERM. Cells are colored by accuracy, from low (red) to medium (white) to high (blue).
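For concreteness, the strong $\ell_{2}$ penalty can be applied either through the optimizer's weight decay or as an explicit term added to the training loss; the sketch below shows the explicit form with an illustrative coefficient (not the values used in the paper).

```python
import torch

def l2_penalty(model, weight_decay):
    """Explicit L2 penalty over all model parameters (illustrative coefficient)."""
    return weight_decay * sum((p ** 2).sum() for p in model.parameters())

# Illustrative usage inside a training step (hypothetical names):
#   loss = worst_group_loss(model, x, y, groups, n_groups) + l2_penalty(model, 0.1)
#   loss.backward(); optimizer.step()
```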
Accounting For Generalization Through Group Adjustments Improves DRO:
The paper shows that using DRO with strong regularization can narrow the worst-group generalization gap. However, even with heavy regularization, some groups may still have relatively large gaps. To push each group's test performance closer to its training performance, the paper proposes prioritizing among groups during training with DRO.
If we can guess which groups will have a higher gap, we can prioritize obtaining lower training loss for these groups in training.
Usually it is the groups with few examples that suffer from this generalization gap, so the paper tries to make these small groups attain low loss on both the training and test data. A simple approach is to use the group size in the objective: add a term proportional to the inverse square root of the group size, as in equation (\ref{eq3}).
\(\begin{equation}
\hat{\theta}_{adj} := \arg\min_{\theta\in\Theta} \{\max_{g\in\mathbb{G}}\{ \mathbb{E}_{(x,y)\sim \hat{P}_g} [\ell(\theta ; (x, y))] + \frac{C}{\sqrt{n_g}}\}\}\label{eq3}\tag{3}
\end{equation}\)
This is referred to in the paper as group adjustments; by using it we encourage the model to focus on fitting the smaller groups. The next table shows how group adjustments can narrow the worst-group test accuracy gap, and a short sketch of the adjusted objective follows the table.
Table(5): Average and worst-group test accuracies with and without group adjustments. Group adjustments improve worst-group accuracy, though average accuracy drops for Waterbirds.
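As a hedged sketch (our own helper names, not the paper's code), the group adjustment in equation (\ref{eq3}) can be added on top of the per-group losses before taking the maximum; $C$ is a hyperparameter.

```python
import math
import torch
import torch.nn.functional as F

def adjusted_worst_group_loss(model, x, y, groups, group_sizes, C=1.0):
    """Per-group average loss plus C / sqrt(n_g), then take the maximum (equation 3).
    group_sizes[g] is the number of training examples in group g."""
    logits = model(x)
    per_example = F.cross_entropy(logits, y, reduction="none")
    adjusted = []
    for g, n_g in enumerate(group_sizes):
        mask = groups == g
        if mask.any():
            adjusted.append(per_example[mask].mean() + C / math.sqrt(n_g))
    return torch.stack(adjusted).max()
```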
Algorithm Implementation:
The algorithm implements the minimax optimization problem in equation (\ref{eq2}) by maintaining a distribution q over groups, with high masses on high-loss groups, and updating on each example proportionally to the mass on its group.
To explain the algorithm, we start from equation (\ref{eq4}), the standard gradient descent update used to learn the network weights, where $\eta_{\theta}$ is a learning rate that is the same for all groups.
\(\begin{equation}
\theta^{t} \leftarrow \theta^{t-1} - \eta_{\theta} \nabla \ell(\theta^{t-1} ; (x, y))\label{eq4}\tag{4}
\end{equation}\)
The proposed algorithm is slightly different from equation (\ref{eq4}) above: it learns to give a higher learning rate to the groups with high losses during training. Accordingly, each group has its own learning rate, and these learning rates change during training based on the model's performance (loss) on each group. The higher the loss on a group, the higher its associated learning rate.
Now let's replace the learning rate $\eta_{\theta}$ with a new per-group learning rate $\eta_{\theta}\, q^t_{g}$ to get equation (\ref{eq5}).
\(\begin{equation}
\theta^{t} \leftarrow \theta^{t-1} - \eta_{\theta} q^t_{g} \nabla \ell(\theta^{t-1};(x, y))\label{eq5}\tag{5}
\end{equation}\)
Here $q^t_{g}$ represents the learnable part of the new learning rate, and it is updated during each training iteration. A simple practical implementation proposed by the paper is as follows:
- We sample an example/batch from only one group during each iteration.
- Then we compute the loss on this example/batch and update $q^t_{g}$ for that group using exponentiated gradient ascent, as in equation (\ref{eq6}), after which $q$ is renormalized so that it remains a distribution over the groups:
\(\begin{equation} q^{t}_{g} \leftarrow q^{t-1}_{g} \exp(\eta_{q}\, \ell(\theta^{t-1} ; (x, y)))\label{eq6}\tag{6} \end{equation}\)
- Finally, we use $q^t_{g}$ to update the model using equation (\ref{eq5}); a compact sketch of the full step follows below.
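Putting the pieces together, here is a compact sketch of one iteration of the online algorithm: sample a batch from one group, update that group's weight $q_g$ by exponentiated gradient ascent, renormalize $q$, and scale the gradient step by $q_g$. Names and hyperparameters are illustrative, not the authors' released code.

```python
import torch
import torch.nn.functional as F

def group_dro_step(model, optimizer, q, g, x, y, eta_q):
    """One online group DRO iteration (sketch).
    q: tensor of shape (n_groups,) holding the current group weights (a distribution).
    g: index of the group the batch (x, y) was sampled from."""
    logits = model(x)
    loss = F.cross_entropy(logits, y)      # average loss on this group's batch

    with torch.no_grad():                  # exponentiated gradient ascent on q
        q[g] = q[g] * torch.exp(eta_q * loss)
        q /= q.sum()                       # renormalize so q stays a distribution

    optimizer.zero_grad()
    (q[g] * loss).backward()               # scales the step by q_g, i.e. eta * q_g
    optimizer.step()
    return loss.item()
```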
The algorithm provides stability and convergence guarantees, in contrast to the existing group DRO algorithm of Oren et al. (2019), by updating $q$ using gradients instead of picking the group with the worst average loss at each iteration.
Comparison With an Existing Method (Importance Weighting):
In addition to the ERM baseline mentioned above, a commonly used method in ML for building robust models is importance weighting. It simply optimizes a weighted loss, as shown in equation (\ref{eq8}).
\(\begin{equation} \hat{\theta}_{w} := \arg\min_{\theta\in\Theta} \mathbb{E}_{(x,y,g)\sim \hat{P}} [w_{g}\, \ell(\theta ; (x, y))]\label{eq8}\tag{8} \end{equation}\)
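As a minimal sketch (not the paper's code), the importance-weighted objective simply scales each example's loss by the weight of its group before averaging.

```python
import torch
import torch.nn.functional as F

def importance_weighted_loss(model, x, y, groups, group_weights):
    """group_weights: tensor of shape (n_groups,) with one weight per group."""
    logits = model(x)
    per_example = F.cross_entropy(logits, y, reduction="none")
    return (group_weights[groups] * per_example).mean()
```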
The weights are chosen heuristically; one common choice is the inverse of the group frequency, giving the minority groups higher weight in the training loss. The paper shows experimentally that group DRO outperforms importance weighting (see table (6)), and also theoretically, by showing that importance weighting can fail to find reasonable solutions in non-convex problems, unlike group DRO.
Table(6): Comparison of ERM, upweighting (UW), and group DRO models on the three datasets.
The Takeaway:
The paper concludes the following:
- Looking at the average accuracy/loss only is not enough to judge the model performance (we need to consider the group performance as well).
- The paper focuses on overparameterized neural networks. ERM fails to improve the worst-group test performance as it optimizes for the average loss.
- Without strong regularization, group DRO models perform similarly to ERM models, attaining near-perfect training accuracies and high average test accuracies, but poor worst-group test accuracies.
- With strong regularization, group DRO achieves a small worst-group generalization gap, translating good worst-group training performance into good worst-group test performance.
- To guarantee convergence for the group DRO method, a new stochastic optimization algorithm is proposed.
- It is shown empirically and theoretically that group DRO outperforms the existing importance weighting method.
References:
- S. Sagawa, P. W. Koh, T. Hashimoto, and P. Liang. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. In International Conference on Learning Representations (ICLR), 2020.
- Y. Oren, S. Sagawa, T. Hashimoto, and P. Liang. Distributionally robust language modeling. In Empirical Methods in Natural Language Processing (EMNLP), 2019.
- A. Williams, N. Nangia, and S. Bowman. A broad-coverage challenge corpus for sentence understanding through inference. In Association for Computational Linguistics (ACL), pp. 1112–1122, 2018.
- Z. Liu, P. Luo, X. Wang, and X. Tang. Deep learning face attributes in the wild. In Proceedings of the IEEE International Conference on Computer Vision, pp. 3730–3738, 2015.
- C. Wah, S. Branson, P. Welinder, P. Perona, and S. Belongie. The Caltech-UCSD Birds-200-2011 dataset. Technical report, California Institute of Technology, 2011.