# Group-robust Sample Reweighting

## Dataset

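Run the following command to download and preprocess each dataset used in this repository. Replace `<DATASET>` with the dataset name (e.g. `waterbirds`) and `<DIR>` with the directory in which to store it: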
```bash
python3 prepare_dataset.py --dataset <DATASET> --root_dir <DIR>
```

## Running the algorithm
The example below uses the Waterbirds dataset.
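### Stage 1: Select a held-out set, then train a base model on the remaining data.

`--held_out_ratio` sets the fraction of the data reserved as the held-out set (0.1 in this example). The base model is trained on the rest.
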
```bash
python3 train_basemodel.py --output_dir=results/waterbirds10 --pretrained_model \
   --num_epochs=100 --weight_decay=1e-3 --batch_size=32 --init_lr=1e-3 \
   --eval_freq=1 --dataset waterbirds \
   --augment_data --held_out_ratio 0.1
```
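Stage 2 must point `--base_dir` at the directory written here, so it can help to keep that path in a single shell variable. The following is a minimal sketch and is not part of the repository. `OUT_DIR`, `run_stage1`, and the `DRY_RUN` switch (which prints the command instead of running it) are hypothetical names. All flags are copied unchanged from the command above.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical variable: the Stage 1 output path, kept in one place so that
# Stage 2's --base_dir can reuse exactly the same value.
OUT_DIR="${OUT_DIR:-results/waterbirds10}"

# Hypothetical wrapper around the Stage 1 command.
# If DRY_RUN is non-empty, ${DRY_RUN:+echo} expands to "echo" and the command
# is printed; otherwise it expands to nothing and the command is executed.
run_stage1() {
  ${DRY_RUN:+echo} python3 train_basemodel.py --output_dir="$OUT_DIR" --pretrained_model \
    --num_epochs=100 --weight_decay=1e-3 --batch_size=32 --init_lr=1e-3 \
    --eval_freq=1 --dataset waterbirds \
    --augment_data --held_out_ratio 0.1
}
```

To use it, append `run_stage1` to the end of the script and run the script with `bash`. Setting `DRY_RUN=1` in the environment first lets you check the exact command before launching a 100-epoch run.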

### Stage 2: Last-layer retraining with sample reweighting.
Set `--base_dir` to the same path passed as `--output_dir` in Stage 1 (here `results/waterbirds10`), so that Stage 2 uses the base model trained there.
```bash
python gsr_retraining.py --base_dir results/waterbirds10 --dataset waterbirds \
   --dro_metric acc --inf exact --inner_lr 0.0001 --max_outer_iter 100 \
   --model resnet50 --multiplicative_updates --normalize_weights \
   --optimizer lbfgs --outer_grad_clip 1.0 --outer_lr 1.0 --outer_lr_scheduler step \
   --temperature 0.1 --weight_decay 0.1 --workers 0 --no_wandb --seed 1
```
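To repeat Stage 2 over several random seeds, you can loop over the existing `--seed` flag. The following is a minimal sketch under stated assumptions. `BASE_DIR`, `run_stage2`, `sweep_seeds`, and the `DRY_RUN` switch are hypothetical helpers, not part of the repository. Every other flag is copied unchanged from the command above.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical variable: must equal the --output_dir passed to train_basemodel.py.
BASE_DIR="${BASE_DIR:-results/waterbirds10}"

# Hypothetical wrapper around the Stage 2 command; the seed is passed as $1.
# If DRY_RUN is non-empty, the command is printed instead of executed.
run_stage2() {
  local seed="$1"
  ${DRY_RUN:+echo} python gsr_retraining.py --base_dir "$BASE_DIR" --dataset waterbirds \
    --dro_metric acc --inf exact --inner_lr 0.0001 --max_outer_iter 100 \
    --model resnet50 --multiplicative_updates --normalize_weights \
    --optimizer lbfgs --outer_grad_clip 1.0 --outer_lr 1.0 --outer_lr_scheduler step \
    --temperature 0.1 --weight_decay 0.1 --workers 0 --no_wandb --seed "$seed"
}

# Hypothetical driver: run Stage 2 once for each seed given as an argument.
sweep_seeds() {
  local seed
  for seed in "$@"; do
    run_stage2 "$seed"
  done
}
```

For example, append `sweep_seeds 1 2 3` to the script and run it with `bash`. Because of `set -e`, the sweep stops at the first run that exits with an error. Setting `DRY_RUN=1` in the environment prints the three commands without running them.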