# BFR
This repository provides the source code for Bayesian Feature Reweighting (BFR), a two-stage framework that improves group robustness without any group annotations.

## Preparation

### Download datasets
- [Waterbirds](https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz)
- [CelebA](https://www.kaggle.com/datasets/jessicali9530/celeba-dataset) ([metadata](https://github.com/PolinaKirichenko/deep_feature_reweighting/blob/main/celeba_metadata.csv))
- ImageNet ([train](https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar), [val](https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar))
- [ImageNet-A](https://github.com/hendrycks/natural-adv-examples)
- CivilComments: The dataset is downloaded automatically through the `wilds` package.
- MultiNLI: Call the function `download` in `data/multinli.py` to download the dataset.

### Prepare `metadata.csv` for each dataset
- Waterbirds, CelebA, CivilComments, and MultiNLI already come with `metadata.csv` files (for CelebA, use the metadata file linked above).
- For ImageNet-9 (built from ImageNet) and ImageNet-A, run the following:
    ```python
    from data.in9_data import prepare_imagenet9_metadata, prepare_imageneta_metadata

    # Root of the ImageNet download; the ImageNet-9 metadata is built from it.
    base_dir = "path/to/imagenet/folder"
    prepare_imagenet9_metadata(base_dir)

    # Root of the ImageNet-A download.
    data_root = "path/to/imagenet-a/folder"
    prepare_imageneta_metadata(data_root)
    ```
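
Once the files are in place, a quick sanity check is to count examples per group and split. The sketch below assumes the Waterbirds `metadata.csv` layout from the original Waterbirds release (columns `y` for the bird label, `place` for the background, and `split` coded as 0/1/2 for train/val/test). Other datasets use different column names, and `count_groups` is a hypothetical helper, not part of this repository.

```python
import csv
from collections import Counter

# Assumption: Waterbirds-style columns `y` (label), `place` (background),
# and `split` coded as 0 = train, 1 = val, 2 = test.
SPLIT_NAMES = {0: "train", 1: "val", 2: "test"}


def count_groups(csv_lines):
    """Count rows per (split, y, place) in an open metadata.csv or any iterable of CSV lines."""
    counts = Counter()
    for row in csv.DictReader(csv_lines):
        split = SPLIT_NAMES[int(row["split"])]
        counts[(split, int(row["y"]), int(row["place"]))] += 1
    return counts


# Example usage (hypothetical path):
# with open("path/to/waterbirds/folder/metadata.csv", newline="") as f:
#     for (split, y, place), n in sorted(count_groups(f).items()):
#         print(f"{split:5s} y={y} place={place}: {n}")
```

Very small counts for some (y, place) combinations in the training split are expected: those minority groups are what group-robustness methods are meant to help.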

## Run BFR
The commands below use the Waterbirds dataset as an example.
### Probabilistic Pretraining
Train a base model on the chosen dataset (set with `--dataset`).
```bash
python train_basemodel.py --output_dir release/waterbirds \
    --pretrained_model \
    --num_epochs 100 \
    --weight_decay 1e-4 \
    --batch_size 32 \
    --init_lr 1e-4 \
    --eval_freq 1 \
    --dataset waterbirds \
    --augment_data \
    --data_dir path/to/data \
    --model resnet50 \
    --optimizer sgd
```
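
Group robustness is commonly reported as worst-group accuracy: the lowest accuracy over all groups, where a group is a (label, attribute) combination. The sketch below is a generic, pure-Python illustration of that metric for checking a trained model's predictions; it is not taken from this repository's evaluation code, and all function names are hypothetical. The `waterbirds_group` helper assumes the convention of indexing the four Waterbirds groups as `2 * y + place`.

```python
from collections import defaultdict


def group_accuracies(preds, labels, groups):
    """Return {group_id: accuracy} from parallel sequences of predictions, labels, and group ids."""
    preds, labels, groups = list(preds), list(labels), list(groups)
    if not (len(preds) == len(labels) == len(groups)):
        raise ValueError("preds, labels, and groups must have the same length")
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, label, group in zip(preds, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    return {group: correct[group] / total[group] for group in total}


def worst_group_accuracy(preds, labels, groups):
    """Lowest per-group accuracy over all groups present in `groups`."""
    accs = group_accuracies(preds, labels, groups)
    if not accs:
        raise ValueError("no examples given")
    return min(accs.values())


def waterbirds_group(y, place):
    """Assumed Waterbirds convention: four groups indexed as 2 * y + place."""
    return 2 * y + place
```

A model can reach high average accuracy while failing on a small group; comparing the two numbers shows how much of that gap remains after each stage.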
### Bayesian Feature Reweighting
Set `--base_dir` to the `--output_dir` used for base model training (here, `release/waterbirds`).
```bash
python bfr_retraining.py --base_dir release/waterbirds \
    --dataset waterbirds \
    --data_dir path/to/data \
    --dro_metric acc \
    --inf exact \
    --inner_lr 5e-4 \
    --max_outer_iter 100 \
    --model resnet50 \
    --multiplicative_updates \
    --normalize_weights \
    --optimizer adam \
    --outer_grad_clip 1.0 \
    --outer_lr 1.0 \
    --outer_lr_scheduler step \
    --temperature 0.1 \
    --weight_decay 0.1 \
    --workers 0 \
    --no_wandb \
    --seed 1 \
    --train_val \
    --batch_size 64
```