# Understanding Overfitting in Reweighting Algorithms for Worst-group Performance
ICLR 2022 Submission

## Table of Contents
- [Quick Start](#quick-start)
- [Introduction](#introduction)
- [Training](#training)

## Quick Start
Before training, download the datasets. Waterbirds is available at this [link](https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz), and CelebA is available [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).

To reproduce our experimental results exactly, create a virtual environment with Anaconda from the provided `env.yml`:
```shell
conda env create --file env.yml
```


To train a model on Waterbirds, run the following command:
```shell
python waterbirds_train.py --data_root /path/to/dataset --alg [ALG] --wd [WD] --test_train --seed [SEED]
```
where `[ALG]` is `erm` (empirical risk minimization), `iw` (importance weighting), or `gdro` (Group DRO), `[WD]` is the weight decay level, and `[SEED]` is the random seed.

To train a model on CelebA, run the following command:
```shell
python celeba_train.py --data_root /path/to/dataset --alg [ALG] --wd [WD] --test_train --seed [SEED]
```

## Introduction

In this work, we prove a pessimistic theoretical result: all reweighting algorithms overfit. If regularization is applied, it must be strong enough to prevent the model from achieving nearly perfect training performance; otherwise, the model still overfits.

This repository contains the code for the experiments that empirically validate our theoretical results. Specifically, we run the experiments on two datasets: Waterbirds and CelebA.
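
For readers unfamiliar with the three algorithms selected by `--alg`, the sketch below illustrates how each one weights the per-group average losses in its textbook form. It is a minimal illustration and not the code used by `waterbirds_train.py` or `celeba_train.py`; the function names, the group sizes, and the `step_size` value are hypothetical.

```python
import math


def erm_weights(group_sizes):
    """ERM: every sample counts equally, so a group's weight is its share of the data."""
    n = sum(group_sizes)
    return [size / n for size in group_sizes]


def iw_weights(group_sizes):
    """Importance weighting: per-sample weight is proportional to 1 / group size,
    so every group ends up with the same total weight."""
    k = len(group_sizes)
    return [1.0 / k for _ in group_sizes]


def gdro_update(q, group_losses, step_size=0.01):
    """One exponentiated-gradient step on the Group DRO group weights.

    Groups with a larger current loss receive a larger weight.
    `step_size` is a hypothetical value chosen for illustration.
    """
    unnormalized = [qg * math.exp(step_size * loss) for qg, loss in zip(q, group_losses)]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


def weighted_loss(weights, group_losses):
    """Training objective: weighted sum of the per-group average losses."""
    return sum(w * loss for w, loss in zip(weights, group_losses))


if __name__ == "__main__":
    sizes = [900, 50, 50]      # hypothetical group sizes
    losses = [0.1, 1.5, 2.0]   # hypothetical per-group average losses
    print("ERM objective:", weighted_loss(erm_weights(sizes), losses))
    print("IW objective: ", weighted_loss(iw_weights(sizes), losses))
    q = gdro_update([1 / 3] * 3, losses)
    print("Group DRO weights after one step:", q)
```

All three objectives are weighted sums of the same per-group losses; the algorithms differ only in how the weights are chosen, which is why they fall under the common name of reweighting algorithms.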

## Training

On each dataset, we use a ResNet-18 as the model and optimize it with SGD with momentum. Our code provides command-line options for the learning rate (`--lr`), the weight decay level (`--wd`), and a multi-step learning rate decay scheduler (`--scheduler`), so it is simple to train a model under different optimization settings.

For instance, to train a model on CelebA with Group DRO for 300 epochs, using learning rate `0.001` and weight decay `0.01`, with the learning rate decayed at epochs 200 and 250, run:
```shell
python celeba_train.py --data_root /path/to/dataset --alg gdro --lr 0.001 --wd 0.01 --epochs 300 --scheduler 200,250 --test_train --seed [SEED]
```
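
To make the effect of `--scheduler 200,250` concrete, the sketch below computes the learning rate in each epoch under a multi-step decay. It is only an illustration: the helper names are hypothetical, and the decay factor `gamma=0.1` is an assumption (it is the default of PyTorch's `MultiStepLR`), since the factor actually used by the training scripts is not stated here.

```python
def parse_milestones(arg):
    """Turn a comma-separated string such as "200,250" into a list of epochs."""
    return [int(x) for x in arg.split(",") if x]


def lr_at_epoch(epoch, base_lr=0.001, milestones=(200, 250), gamma=0.1):
    """Learning rate used in `epoch` (0-indexed) under multi-step decay.

    The rate is multiplied by `gamma` once for every milestone already reached.
    `gamma` is an assumed value, not one taken from the training scripts.
    """
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** num_decays


if __name__ == "__main__":
    milestones = parse_milestones("200,250")
    for epoch in (0, 199, 200, 249, 250, 299):
        print(epoch, lr_at_epoch(epoch, base_lr=0.001, milestones=milestones))
```

Under these assumptions, the learning rate stays at its initial value until epoch 200, drops by a factor of ten there, and drops by another factor of ten at epoch 250.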

In our experiments, we use the following fixed set of random seeds: `2002, 2022, 2042, 2062, 2082`.
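
One way to repeat a run over all five seeds is a small driver script such as the sketch below. It is not part of the repository: the helper names and default argument values are illustrative, and it only uses the command-line flags shown in this README. By default it prints the commands instead of running them.

```python
import shlex
import subprocess
import sys

# Seeds listed in this README.
SEEDS = [2002, 2022, 2042, 2062, 2082]


def build_command(seed, script="celeba_train.py", data_root="/path/to/dataset",
                  alg="gdro", lr=0.001, wd=0.01, epochs=300, scheduler="200,250"):
    """Assemble the training command for one seed as an argument list."""
    return [
        sys.executable, script,
        "--data_root", data_root,
        "--alg", alg,
        "--lr", str(lr),
        "--wd", str(wd),
        "--epochs", str(epochs),
        "--scheduler", scheduler,
        "--test_train",
        "--seed", str(seed),
    ]


def run_all(dry_run=True):
    """Run (or, with dry_run=True, only print) one training job per seed."""
    commands = [build_command(seed) for seed in SEEDS]
    for cmd in commands:
        if dry_run:
            print(shlex.join(cmd))
        else:
            # check=True stops the sweep if a training run fails.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    run_all(dry_run=True)
```

Calling `run_all(dry_run=False)` launches the runs sequentially; replace `/path/to/dataset` with the actual dataset location first.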