# Code for Correct-N-Contrast: a Contrastive Approach for Improving Robustness to Spurious Correlations
This folder contains code for our ICLR 2022 submission.  

## Requirements

To install requirements, we recommend setting up a virtual environment with conda:

```setup
conda env create -f environment.yml  
conda activate cnc
``` 

We also include a `requirements.txt` file for installing dependencies with `pip install -r requirements.txt`.

List of (installable) dependencies:  
* python 3.7.9  
* matplotlib 3.3.2
* numpy 1.19.2  
* pandas 1.1.3  
* pillow 8.0.1  
* pytorch 1.7.0  
* scikit-learn 0.23.2  
* scipy 1.5.2  
* transformers 4.4.2 
* torchvision 0.8.1  
* tqdm 4.54.0  
* umap-learn 0.4.6

## Datasets and code 

**Colored MNIST**: Running the training command below should automatically download and set up the Colored MNIST dataset.  

**Waterbirds**: Download the dataset from [here](https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz). Extracting this archive should produce a folder `waterbird_complete95_forest2water2`, which should be moved to `./datasets/data/Waterbirds/`.  
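The Waterbirds step can be scripted. The sketch below uses `setup_waterbirds`, a hypothetical bash helper that is not part of this repository. It assumes the archive extracts to the single top-level folder `waterbird_complete95_forest2water2`, as described above, and extracts it directly into the expected location.

```shell
#!/usr/bin/env bash
# setup_waterbirds: hypothetical helper (not part of this repository).
# $1: path to the downloaded waterbird_complete95_forest2water2.tar.gz
# $2: optional destination (defaults to the location this README expects)
setup_waterbirds() {
  local archive="$1"
  local dest="${2:-./datasets/data/Waterbirds}"
  if [ ! -f "$archive" ]; then
    echo "archive not found: $archive" >&2
    return 1
  fi
  mkdir -p "$dest"
  # Extract straight into the destination. The archive's top-level folder,
  # waterbird_complete95_forest2water2, ends up at $dest/waterbird_complete95_forest2water2.
  tar -xzf "$archive" -C "$dest"
}
```

For example, after `source`-ing the function, you could run `curl -LO https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz` followed by `setup_waterbirds waterbird_complete95_forest2water2.tar.gz`. Extracting with `tar -C` avoids a separate move step.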

**CelebA**: Download the dataset files from this [Kaggle link](https://www.kaggle.com/jessicali9530/celeba-dataset). Then move the files to `./datasets/data/CelebA/` so that the directory has the following structure:
```
# In `./datasets/data/CelebA/`:
|-- list_attr_celeba.csv
|-- list_eval_partition.csv
|-- img_align_celeba/
    |-- image1.png
    |-- ...
    |-- imageN.png
```  
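Before training, you could confirm the layout above with a quick check. `check_celeba_layout` is a hypothetical bash helper, not part of this repository. It only tests for the two CSV files and the image directory shown in the tree, and it does not validate file contents.

```shell
#!/usr/bin/env bash
# check_celeba_layout: hypothetical helper (not part of this repository).
# $1: optional CelebA root (defaults to ./datasets/data/CelebA)
# Returns 0 if the expected files/directory exist, 1 otherwise.
check_celeba_layout() {
  local root="${1:-./datasets/data/CelebA}"
  local status=0
  local f
  # The two metadata CSVs listed in the tree above
  for f in list_attr_celeba.csv list_eval_partition.csv; do
    if [ ! -f "$root/$f" ]; then
      echo "missing file: $root/$f" >&2
      status=1
    fi
  done
  # The image folder
  if [ ! -d "$root/img_align_celeba" ]; then
    echo "missing directory: $root/img_align_celeba/" >&2
    status=1
  fi
  return "$status"
}
```

Usage: `check_celeba_layout && echo "CelebA layout looks OK"`.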

**CivilComments-WILDS**: Loading this dataset requires the `transformers` package. One can download the source csv from [here](https://worksheets.codalab.org/bundles/0x8cd3de0634154aeaad2ee6eb96723c6e). Then, move `all_data_with_identities.csv` to `./datasets/data/CivilComments/all_data_with_identities.csv`.
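Moving the CSV into place can be wrapped in a small helper so the destination is created first. `place_civilcomments` is hypothetical and not part of this repository. It assumes you have already downloaded `all_data_with_identities.csv` from the CodaLab bundle.

```shell
#!/usr/bin/env bash
# place_civilcomments: hypothetical helper (not part of this repository).
# $1: path to the downloaded all_data_with_identities.csv
# $2: optional destination directory (defaults to the path this README expects)
place_civilcomments() {
  local csv="$1"
  local dest="${2:-./datasets/data/CivilComments}"
  if [ ! -f "$csv" ]; then
    echo "not found: $csv" >&2
    return 1
  fi
  mkdir -p "$dest"
  mv "$csv" "$dest/all_data_with_identities.csv"
}
```

Because this dataset needs `transformers`, you can also run `python -c "import transformers; print(transformers.__version__)"` to confirm the package is installed. The dependency list above pins 4.4.2.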


## Training and Evaluation  

For all datasets except Colored MNIST, the commands below load an initial trained ERM model. We provide these models through the download links in the **Pre-trained Models** section; they were trained as described in Appendix D.2.2. We recommend downloading them, since training the initial ERM model can take a while: e.g., ~1.5 hours for Waterbirds on a machine with 8 CPUs and 1 NVIDIA V100 GPU, and ~3 hours for CelebA on a machine with 32 CPUs and 4 NVIDIA V100 GPUs. Aside from the specific hyperparameters in Appendix D.2.2, training these ERM models involves only standard mini-batch SGD on the dataset labels.
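Because a missing ERM checkpoint is only discovered once training starts, a pre-flight check may save time. `erm_path_for` and `require_erm_model` are hypothetical bash helpers, not part of this repository. The paths they check are copied from the `--pretrained_spurious_path` values in the training commands below. Colored MNIST is treated as needing no checkpoint.

```shell
#!/usr/bin/env bash
# erm_path_for: hypothetical helper mapping a --dataset value to the ERM
# checkpoint passed via --pretrained_spurious_path in this README's commands.
erm_path_for() {
  case "$1" in
    waterbirds)    echo "./model/waterbirds/waterbirds_erm_regularized.pt" ;;
    celebA)        echo "./model/celebA/celeba_erm_regularized.pt" ;;
    civilcomments) echo "./model/civilcomments/civilcomments_erm_early.pth.tar" ;;
    colored_mnist) echo "" ;;  # Colored MNIST does not load an initial ERM model
    *) echo "unknown dataset: $1" >&2; return 2 ;;
  esac
}

# require_erm_model: returns 0 if the dataset needs no ERM model or the file
# exists, 1 if the file is missing, 2 for an unknown dataset name.
require_erm_model() {
  local erm_path
  erm_path=$(erm_path_for "$1") || return 2
  if [ -z "$erm_path" ]; then
    return 0
  fi
  if [ ! -f "$erm_path" ]; then
    echo "missing ERM model for $1: expected $erm_path" >&2
    return 1
  fi
}
```

For example, `require_erm_model waterbirds && python train_supervised_contrast.py --dataset waterbirds ...` only launches training once the checkpoint is in place.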


### Colored MNIST  

Train:  
```train
python train_supervised_contrast.py --dataset colored_mnist --train_encoder --arch cnn --data_cmap hsv --test_shift random -tc 0 1 -tc 2 3 -tc 4 5 -tc 6 7 -tc 8 9 --p_correlation 0.995 -tcr 1.0 -tcr 1.0 -tcr 1.0 -tcr 1.0 -tcr 1.0 --max_epoch_s 5 --bs_trn_s 32 --num_anchor 32 --num_positive 32 --num_negative 32 --num_negative_easy 32 --batch_factor 32 --optim sgd --lr 1e-3 --momentum 0.9 --weight_decay 1e-4 --weight_decay_c 1e-4 --target_sample_ratio 1 --temperature 0.05 --max_epoch 3 --no_projection_head --contrastive_weight 0.75 --bs_trn 32 --bs_val 32 --num_workers 0 --log_loss_interval 10 --checkpoint_interval 10000 --log_visual_interval 40000 --verbose --replicate 42 --seed 42
```

Evaluate:  
```evaluate
python train_supervised_contrast.py --dataset colored_mnist --arch cnn --evaluate --load_encoder [model_file_name.pth.tar] --data_cmap hsv --test_shift random -tc 0 1 -tc 2 3 -tc 4 5 -tc 6 7 -tc 8 9 --p_correlation 0.995 -tcr 1.0 -tcr 1.0 -tcr 1.0 -tcr 1.0 -tcr 1.0
```

### Waterbirds

Train:  
```train
python train_supervised_contrast.py --dataset waterbirds --arch resnet50_pt --train_encoder --pretrained_spurious_path "./model/waterbirds/waterbirds_erm_regularized.pt" --num_anchor 17 --num_positive 17 --num_negative 17 --num_negative_easy 17 --batch_factor 32 --optim sgd --lr 1e-4 --momentum 0.9 --weight_decay 1e-3 --weight_decay_c 1e-3 --target_sample_ratio 1 --temperature 0.1 --max_epoch 5 --no_projection_head --contrastive_weight 0.75 --log_visual_interval 10000 --checkpoint_interval 10000 --verbose --log_loss_interval 10 --replicate 42 --seed 42
```

Evaluate:  
```evaluate
python train_supervised_contrast.py --dataset waterbirds --arch resnet50_pt --evaluate --load_encoder waterbirds_cnc_pretrained.pth.tar 
```

### CelebA

Train:  
```train
python train_supervised_contrast.py --dataset celebA --arch resnet50_pt --train_encoder --pretrained_spurious_path "./model/celebA/celeba_erm_regularized.pt" --num_anchor 64 --num_positive 64 --num_negative 64 --num_negative_easy 64 --batch_factor 32 --optim sgd --lr 1e-5 --momentum 0.9 --weight_decay 1e-1 --weight_decay_c 1e-1 --target_sample_ratio 0.1 --temperature 0.05 --max_epoch 15 --no_projection_head --contrastive_weight 0.75 --log_visual_interval 10000 --checkpoint_interval 10000 --verbose --log_loss_interval 10 --replicate 42 --seed 42
```

Evaluate:  
```evaluate
python train_supervised_contrast.py --dataset celebA --arch resnet50_pt --evaluate --load_encoder celebA_cnc_pretrained.pth.tar 
```

### CivilComments-WILDS

Train:  
```train
python -W ignore train_supervised_contrast.py --dataset civilcomments --arch bert-base-uncased_pt --train_encoder --pretrained_spurious_path ./model/civilcomments/civilcomments_erm_early.pth.tar --num_anchor 16 --num_positive 16 --num_negative 16 --num_negative_easy 16 --batch_factor 128 --bs_trn 16 --clip_grad_norm --optim AdamW --lr 1e-4 --weight_decay 1e-2 --target_sample_ratio 0.1 --temperature 0.1 --max_epoch 10 --no_projection_head --contrastive_weight 0.75 --log_loss_interval 10 --checkpoint_interval 10000 --verbose --log_visual_interval 400000 --replicate 42 --seed 42
```

Evaluate:  
```evaluate
python -W ignore train_supervised_contrast.py --dataset civilcomments --arch bert-base-uncased_pt --evaluate --load_encoder civilcomments_cnc_pretrained.pth.tar 
```

## Pre-trained Models

Both pretrained initial ERM models and the trained Correct-N-Contrast models are available to download [here](https://drive.google.com/drive/folders/1SqhrdhLbBCNmTqCEf9Xm5ib4iTnt5ncx?usp=sharing).

ERM models were trained as described in Appendix D.2.2. Correct-N-Contrast models were trained as described in Appendix D.2.4.  

Once downloaded, models should be moved to the following file paths:  

**Waterbirds**  
- ERM model: `./model/waterbirds/waterbirds_erm_regularized.pt`  
- CNC model: `./model/waterbirds/config-tn=waterbird_complete95-cn=['forest2water2']/waterbirds_cnc_pretrained.pth.tar`  

**CelebA**  
- ERM model: `./model/celebA/celeba_erm_regularized.pt`
- CNC model: `./model/celebA/config/celebA_cnc_pretrained.pth.tar`

**CivilComments-WILDS**  
- ERM model: `./model/civilcomments/civilcomments_erm_early.pth.tar`
- CNC model: `./model/civilcomments/config/civilcomments_cnc_pretrained.pth.tar`
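The moves above can be done in one step. The sketch below defines `place_pretrained_models`, a hypothetical bash helper that is not part of this repository. It assumes all downloaded checkpoints sit flat in one directory, by default `./downloads` (also an assumption), and that the command is run from the repository root. Any file not found is skipped with a message. Note the quoting of the Waterbirds CNC directory, whose name contains `=`, `[`, `]`, and `'`.

```shell
#!/usr/bin/env bash
# _move_checkpoint: move $1/$2 into directory $3 if it exists (hypothetical helper).
_move_checkpoint() {
  if [ -f "$1/$2" ]; then
    mkdir -p "$3"
    mv "$1/$2" "$3/"
  else
    echo "skipping $2 (not found in $1)" >&2
  fi
}

# place_pretrained_models: hypothetical helper (not part of this repository).
# $1: optional directory holding the downloaded files (defaults to ./downloads)
place_pretrained_models() {
  local src="${1:-./downloads}"
  _move_checkpoint "$src" waterbirds_erm_regularized.pt ./model/waterbirds
  # Quoted because the directory name contains shell-special characters
  _move_checkpoint "$src" waterbirds_cnc_pretrained.pth.tar \
    "./model/waterbirds/config-tn=waterbird_complete95-cn=['forest2water2']"
  _move_checkpoint "$src" celeba_erm_regularized.pt ./model/celebA
  _move_checkpoint "$src" celebA_cnc_pretrained.pth.tar ./model/celebA/config
  _move_checkpoint "$src" civilcomments_erm_early.pth.tar ./model/civilcomments
  _move_checkpoint "$src" civilcomments_cnc_pretrained.pth.tar ./model/civilcomments/config
}
```

Usage: `place_pretrained_models ~/Downloads`, run from the repository root.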
