# Take One Gram of Neural Features, Get Enhanced Group Robustness

Our implementation builds upon the [WILDS framework](https://github.com/p-lambda/wilds). It adds scripts to extract features (`examples/extract_style_features.py`)
and to partition a dataset into clusters and assign pseudo-group labels (`examples/discover_env_labels.py`).


## How to install 

1. Install the required libraries

```bash
pip install torch==1.10.0 coloredlogs torchsummary verboselogs incense yacs munkres torch_tb_profiler torchaudio==0.10.0 torchvision==0.11.1 tensorflow tensorflow-gpu "keras==2.6.*" transformers
```


2. Install `torch-scatter` and the custom `wilds` package (run this from the root of this repository)
```bash
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu102.html && pip install -e .
```
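The `torch-scatter` wheel index above targets torch 1.10.0 built with CUDA 10.2 (`cu102`). If your torch build uses a different CUDA version, the index URL changes accordingly. The helpers below are a hypothetical sketch, not part of the repository: they turn the string printed by `torch.version.cuda` (for example `10.2`, or `None` for a CPU-only build) into a wheel tag and build the matching index URL. We assume the PyG index follows the same `torch-<version>+<tag>.html` pattern for the other tags; check that the page exists before relying on it.

```bash
#!/usr/bin/env bash

# Convert a CUDA version string as printed by torch.version.cuda (e.g. "10.2")
# into the wheel tag used in the PyG index URL (e.g. "cu102").
# An empty value or "None" (CPU-only torch build) maps to "cpu".
cuda_tag() {
  local cuda_version="${1:-}"
  if [ -z "$cuda_version" ] || [ "$cuda_version" = "None" ]; then
    echo "cpu"
  else
    # Strip every dot: "10.2" -> "102", "11.3" -> "113".
    echo "cu${cuda_version//./}"
  fi
}

# Build the wheel index URL for a torch version and a wheel tag.
# Defaults reproduce the URL used in the install command above.
scatter_index_url() {
  local torch_version="${1:-1.10.0}"
  local tag="${2:-cu102}"
  echo "https://data.pyg.org/whl/torch-${torch_version}+${tag}.html"
}
```

For example, `pip install torch-scatter -f "$(scatter_index_url 1.10.0 "$(cuda_tag "$(python -c 'import torch; print(torch.version.cuda)')")")"` queries the installed torch for its CUDA version instead of hard-coding `cu102`.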

### Downloading and training on the WILDS datasets
The first time you run these scripts, you might need to download the datasets. You can do so with the `--download` argument, for example:
```bash
python examples/run_expt.py --dataset waterbirds --algorithm groupDRO --root_dir data --download
```

## Running experiments

In the following, we illustrate experiments with `ERM`, `GroupDRO` and `GramStyle` on the Waterbirds dataset. Available datasets include Waterbirds and CelebA.
COCO-on-Places-224 images will be released if the paper is accepted. The code used to generate this dataset is taken from Ahmed et al.'s [repository](https://github.com/Faruk-Ahmed/predictive_group_invariance/blob/main/coco/data/data_makers/coco_places.py).

### ERM
```bash
python3 examples/run_expt.py --dataset waterbirds --algorithm ERM --model resnet50 --root_dir <root_dir>  --log_dir <log_dir> --weight_decay 1e-3 --lr 1e-4 --progress_bar
```

### GroupDRO
```bash
python3 examples/run_expt.py --dataset waterbirds --algorithm GroupDRO --model resnet50 --root_dir <root_dir>  --log_dir <log_dir> --weight_decay 1. --lr 1e-5 --progress_bar
```
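The two baselines differ only in the algorithm and in their weight decay and learning rate. A minimal sketch, assuming you want to launch both from one place: the hypothetical `run_baselines` function below loops over `ERM` and `GroupDRO` with the hyperparameters given above and writes each run to its own log directory. The `LOG_ROOT` variable, its default `logs`, and the default `ROOT_DIR=data` are placeholders, not repository conventions. Setting `DRY_RUN=1` prints the commands instead of executing them.

```bash
#!/usr/bin/env bash

# Print each command when DRY_RUN=1; otherwise execute it.
run() {
  if [ "${DRY_RUN:-0}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Train both Waterbirds baselines with the hyperparameters listed above.
# ROOT_DIR and LOG_ROOT are hypothetical variables with placeholder defaults.
run_baselines() {
  local root_dir="${ROOT_DIR:-data}"
  local log_root="${LOG_ROOT:-logs}"
  local algo wd lr
  for algo in ERM GroupDRO; do
    case "$algo" in
      ERM)      wd=1e-3; lr=1e-4 ;;
      GroupDRO) wd=1.;   lr=1e-5 ;;
    esac
    # Stop at the first failing run instead of launching the next one.
    run python3 examples/run_expt.py --dataset waterbirds --algorithm "$algo" \
      --model resnet50 --root_dir "$root_dir" --log_dir "${log_root}/${algo}" \
      --weight_decay "$wd" --lr "$lr" --progress_bar || return
  done
}
```

After sourcing the file, `DRY_RUN=1 run_baselines` previews both commands, and `ROOT_DIR=/path/to/data run_baselines` runs them one after the other.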

### GramStyle
1. Train a VGG-19 model with ERM for 1 epoch
```bash
python3 examples/run_expt.py --dataset waterbirds --algorithm ERM --model vgg19 --root_dir <root_dir>  --log_dir <log_dir_id_model> --weight_decay 1e-3 --lr 1e-4 --n_epochs 1 --progress_bar
```

2. Extract the Gram matrices of the features
```bash
python3 examples/extract_style_features.py --dataset waterbirds --root_dir <root_dir> --model vgg19 --log_dir <log_dir_id_model> --output_name <output_name> --cluster_feature_type gram_matrix --normalization
```

3. Run a clustering algorithm to discover pseudo-group labels
```bash
python3 examples/discover_env_labels.py --dataset waterbirds --root_dir <root_dir> --env_dir <log_dir_id_model/env_output_name> --clustering_algo kmeans --layers features.28
```

4. Train a ResNet-50 with GroupDRO, using the predicted environment labels stored in `log_dir_id_model/env_output_name`:
```bash
python3 examples/run_expt.py --dataset waterbirds --algorithm GroupDRO --model resnet50 --root_dir <root_dir>  --log_dir <log_dir> --env_file_path <log_dir_id_model/env_output_name/env_labels_kmeans_features.28.csv> --weight_decay 1. --lr 1e-5 --progress_bar
```
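The four steps above share several paths (the identification-model log directory, the output name, the discovered label file), so it can help to chain them with shared variables. The hypothetical `gramstyle_pipeline` function below is a sketch, not part of the repository. Its default paths are placeholders, and the layout `<log_dir_id_model>/env_<output_name>/env_labels_<algo>_<layer>.csv` is an assumption inferred from the example file name in step 4. Step 1 uses `--algorithm ERM`, as its description states. Setting `DRY_RUN=1` prints the commands instead of executing them.

```bash
#!/usr/bin/env bash

# Print each command when DRY_RUN=1; otherwise execute it.
run() {
  if [ "${DRY_RUN:-0}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Chain the four GramStyle steps on Waterbirds; stops at the first failure.
gramstyle_pipeline() {
  local root_dir="${ROOT_DIR:-data}"
  local id_model_dir="${ID_MODEL_DIR:-logs/vgg19_erm}"      # placeholder default
  local log_dir="${LOG_DIR:-logs/resnet50_gramstyle}"       # placeholder default
  local output_name="${OUTPUT_NAME:-gram}"                  # placeholder default
  local algo="${CLUSTER_ALGO:-kmeans}"
  local layer="${LAYER:-features.28}"
  # Assumption: directory and file names follow the pattern of the example above.
  local env_dir="${id_model_dir}/env_${output_name}"
  local env_file="${env_dir}/env_labels_${algo}_${layer}.csv"

  # 1. VGG-19 trained with ERM for one epoch.
  run python3 examples/run_expt.py --dataset waterbirds --algorithm ERM --model vgg19 \
    --root_dir "$root_dir" --log_dir "$id_model_dir" \
    --weight_decay 1e-3 --lr 1e-4 --n_epochs 1 --progress_bar || return
  # 2. Normalized Gram matrices of the VGG-19 features.
  run python3 examples/extract_style_features.py --dataset waterbirds --root_dir "$root_dir" \
    --model vgg19 --log_dir "$id_model_dir" --output_name "$output_name" \
    --cluster_feature_type gram_matrix --normalization || return
  # 3. Clustering of the chosen layer's features into pseudo-groups.
  run python3 examples/discover_env_labels.py --dataset waterbirds --root_dir "$root_dir" \
    --env_dir "$env_dir" --clustering_algo "$algo" --layers "$layer" || return
  # 4. ResNet-50 trained with GroupDRO on the discovered labels.
  run python3 examples/run_expt.py --dataset waterbirds --algorithm GroupDRO --model resnet50 \
    --root_dir "$root_dir" --log_dir "$log_dir" --env_file_path "$env_file" \
    --weight_decay 1. --lr 1e-5 --progress_bar || return
}
```

After sourcing the file, `DRY_RUN=1 gramstyle_pipeline` previews the four commands, and `ROOT_DIR=/path/to/data ID_MODEL_DIR=/path/to/vgg19_log gramstyle_pipeline` runs them. Before relying on the assumed `env_<output_name>` directory name, check which directory step 2 actually creates.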

