# CMID
Code for the Conditional Mutual Information-Debiasing (CMID) method for improving out-of-distribution (OOD) generalization and subgroup robustness. (The base code comes from the [`group_DRO`](https://github.com/kohpangwei/group_DRO) implementation.)

### Install dependencies
   The code uses `python 3.6.8`. Install the dependencies with:

   ```
   pip install -r requirements.txt
   ```

   Set the `root_dir` variable in `data/data.py`. Datasets are stored in, and read from, the location specified by `root_dir`. (See the [`group_DRO` repository](https://github.com/kohpangwei/group_DRO) for more details.)

### Subgroup Robustness Experiments
Experiments on Waterbirds, CelebA, MultiNLI, and CivilComments datasets.
#### Download datasets. 

   - Waterbirds:
     The code expects the following files/folders in the `[root_dir]/cub` directory:
     - `data/waterbird_complete95_forest2water2/`


     A tarball of this dataset can be downloaded from [this link](https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz). 
  
   - CelebA:
     The code expects the following files/folders in the `[root_dir]/celebA` directory:
     - `data/list_eval_partition.csv`
     - `data/list_attr_celeba.csv`
     - `data/img_align_celeba/`


     These dataset files can be downloaded from [this Kaggle link](https://www.kaggle.com/jessicali9530/celeba-dataset).
  
   - MultiNLI:
     The code expects the following files/folders in the `[root_dir]/multinli` directory:
     - `data/metadata_random.csv` 
     - `glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli`
     - `glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli-mm`
     - `glue_data/MNLI/cached_train_bert-base-uncased_128_mnli`
     
     The metadata file is included in the `dataset_metadata/multinli` folder of this repository.
     The `glue_data/MNLI` files are generated by the [huggingface Transformers library](https://github.com/huggingface/transformers) and can be downloaded [here](https://nlp.stanford.edu/data/dro/multinli_bert_features.tar.gz).
  
   - CivilComments:
     The code expects the following files/folders in the `[root_dir]/civcom` directory:
     - `all_data_with_grouped_identities.csv`
     - `all_data_with_identities.csv`
     
     A tarball of this dataset can be downloaded from [this link](https://drive.google.com/file/d/1ioV8bf5jpEhXW2UTN41z-0uxUqJk5UT6/view?usp=share_link).
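
Before launching experiments, it can help to confirm that the layout above is in place. The following is a minimal sketch, not a script shipped with this repository: `EXPECTED_FILES` and `check_dataset_layout` are hypothetical names, and the paths are copied from the lists above, relative to the `root_dir` you set in `data/data.py`. It assumes entries ending in `/` are directories and everything else is a regular file.

```python
from pathlib import Path

# Expected paths per dataset, relative to root_dir (copied from the lists above).
# Entries ending in "/" are directories; all others are regular files.
EXPECTED_FILES = {
    "cub": ["data/waterbird_complete95_forest2water2/"],
    "celebA": [
        "data/list_eval_partition.csv",
        "data/list_attr_celeba.csv",
        "data/img_align_celeba/",
    ],
    "multinli": [
        "data/metadata_random.csv",
        "glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli",
        "glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli-mm",
        "glue_data/MNLI/cached_train_bert-base-uncased_128_mnli",
    ],
    "civcom": [
        "all_data_with_grouped_identities.csv",
        "all_data_with_identities.csv",
    ],
}


def check_dataset_layout(root_dir, dataset):
    """Return the expected paths for `dataset` that are missing under root_dir."""
    base = Path(root_dir) / dataset
    missing = []
    for rel in EXPECTED_FILES[dataset]:
        path = base / rel
        ok = path.is_dir() if rel.endswith("/") else path.is_file()
        if not ok:
            missing.append(str(path))
    return missing


if __name__ == "__main__":
    import sys

    # Usage: python <this_file>.py [root_dir]; defaults to the current directory.
    root = sys.argv[1] if len(sys.argv) > 1 else "."
    for name in EXPECTED_FILES:
        missing = check_dataset_layout(root, name)
        status = "ok" if not missing else "missing: " + ", ".join(missing)
        print(f"{name}: {status}")
```

Saved under a name of your choice (for example, the hypothetical `check_data.py`), it can be run as `python check_data.py /path/to/root_dir` and prints one status line per dataset.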

   
#### Run code and infer results.
   The scripts for running the experiments and parsing the results are `run_expt.py` and `parse_log_file.py`, respectively. The commands for each dataset are listed below:
   
   - Waterbirds: 

     ```
     python run_expt.py --log_dir /CMID/log-wb -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.0005 --batch_size 128 --weight_decay 0.0001 --model resnet50 --n_epochs 100 --cmi_reg --log_every 20 --reg_st 20.0 --cmistinc --scale 4
     ```

     ```
     python parse_log_file.py --log_dir /CMID/log-wb --num_groups 4
     ```    
   
   - CelebA:

     ```
     python run_expt.py --log_dir /CMID/log-cel -s confounder -d CelebA -t Blond_Hair -c Male --lr 0.0003 --batch_size 128 --weight_decay 0.001 --model resnet50 --n_epochs 50 --cmi_reg --log_every 20 --reg_st 10.0 --cmistinc --scale 5
     ```

     ```
     python parse_log_file.py --log_dir /CMID/log-cel --num_groups 4
     ```
   
   - MultiNLI:

     ```
     python run_expt.py --log_dir /CMID/log-mnli -s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation --lr 5e-05 --batch_size 32 --weight_decay 0 --model bert --n_epochs 5 --cmi_reg --reg_st 75.0 --cmistinc --lr1 0.005
     ```

     ```
     python parse_log_file.py --log_dir /CMID/log-mnli --num_groups 6
     ```
   
   - CivilComments:

     ```
     python run_expt.py --log_dir /CMID/log-ccom -s confounder -d CivComMod -t toxicity -c identity_any --lr 0.00001 --batch_size 32 --weight_decay 0.001 --model bert-base-uncased --n_epochs 10 --cmi_reg --reg_st 25.0 --cmistinc --lr1 0.0001
     ```

     ```
     python parse_log_file.py --log_dir /CMID/log-ccom --num_groups 16
     ```
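
To run all four subgroup-robustness experiments back to back, a small driver can chain each `run_expt.py` call with the matching `parse_log_file.py` call. This is a hypothetical convenience sketch, not part of the repository: `EXPERIMENTS`, `build_commands`, and `run_all` are illustrative names, and the flags, log directories, and `--num_groups` values are copied verbatim from the commands above.

```python
import shlex
import subprocess
import sys

# Hypothetical driver. Each entry holds the run_expt.py flags (minus --log_dir),
# the log directory, and the --num_groups value for parse_log_file.py.
EXPERIMENTS = {
    "waterbirds": (
        "-s confounder -d CUB -t waterbird_complete95 -c forest2water2 "
        "--lr 0.0005 --batch_size 128 --weight_decay 0.0001 --model resnet50 "
        "--n_epochs 100 --cmi_reg --log_every 20 --reg_st 20.0 --cmistinc --scale 4",
        "/CMID/log-wb",
        4,
    ),
    "celeba": (
        "-s confounder -d CelebA -t Blond_Hair -c Male "
        "--lr 0.0003 --batch_size 128 --weight_decay 0.001 --model resnet50 "
        "--n_epochs 50 --cmi_reg --log_every 20 --reg_st 10.0 --cmistinc --scale 5",
        "/CMID/log-cel",
        4,
    ),
    "multinli": (
        "-s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation "
        "--lr 5e-05 --batch_size 32 --weight_decay 0 --model bert "
        "--n_epochs 5 --cmi_reg --reg_st 75.0 --cmistinc --lr1 0.005",
        "/CMID/log-mnli",
        6,
    ),
    "civilcomments": (
        "-s confounder -d CivComMod -t toxicity -c identity_any "
        "--lr 0.00001 --batch_size 32 --weight_decay 0.001 --model bert-base-uncased "
        "--n_epochs 10 --cmi_reg --reg_st 25.0 --cmistinc --lr1 0.0001",
        "/CMID/log-ccom",
        16,
    ),
}


def build_commands(name):
    """Return the (train, parse) argument lists for one dataset."""
    flags, log_dir, num_groups = EXPERIMENTS[name]
    # sys.executable reuses the current interpreter rather than whatever "python" is on PATH.
    train = [sys.executable, "run_expt.py", "--log_dir", log_dir] + shlex.split(flags)
    parse = [sys.executable, "parse_log_file.py", "--log_dir", log_dir,
             "--num_groups", str(num_groups)]
    return train, parse


def run_all(names=None):
    """Run training, then log parsing, for each dataset; stop at the first failure."""
    for name in names or EXPERIMENTS:
        for cmd in build_commands(name):
            print("$", " ".join(cmd))
            subprocess.run(cmd, check=True)
```

From the repository root, `run_all()` runs every dataset in turn and `run_all(["waterbirds"])` runs just one. Because `check=True` is set, a failed training run raises `subprocess.CalledProcessError` instead of moving on to parse an incomplete log.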

### OOD Generalization Experiment: Camelyon Dataset
#### Download dataset.
The code expects the following files/folders in the `./camelyon` directory:
   - `data/camelyon17_v1.0/metadata.csv`
   - `data/camelyon17_v1.0/patches/`
     
The `patches/` folder should contain all of the patch data. If these files do not exist, the code downloads them to this location at run time.
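
If you prefer to fetch Camelyon17 ahead of time instead of at run time, the WILDS package provides `wilds.get_dataset`, which downloads a dataset into `<root_dir>/camelyon17_v1.0/`. The sketch below assumes it is run from inside `./camelyon` with `data` as the WILDS root, so the files land at the paths listed above; `camelyon_ready` and `ensure_camelyon` are hypothetical helpers, not part of this repository, and the download is large.

```python
from pathlib import Path


def camelyon_ready(data_root="data"):
    """True if metadata.csv and patches/ already exist under data_root/camelyon17_v1.0."""
    base = Path(data_root) / "camelyon17_v1.0"
    return (base / "metadata.csv").is_file() and (base / "patches").is_dir()


def ensure_camelyon(data_root="data"):
    """Download Camelyon17 via WILDS unless present; return True if a download was started."""
    if camelyon_ready(data_root):
        return False
    # Imported lazily so the presence check works even without WILDS installed.
    from wilds import get_dataset

    get_dataset(dataset="camelyon17", root_dir=data_root, download=True)
    return True
```

Calling `ensure_camelyon()` from inside `./camelyon` before running `camelyon.py` should make the run-time download a no-op, assuming `camelyon.py` reads from the same location.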

#### Run code and infer results.  
Camelyon uses a separate script, `camelyon.py`, which relies on [Wilds](https://github.com/p-lambda/wilds) for data loading. To run it, change into the `./camelyon` directory and run the following sample command, which writes the results to `camelyon.txt` in the same directory:
   
```
python camelyon.py --cmi_reg --epochs 5 --epochs2 10 --lr 0.0001 --lr1 0.0001 --weight_decay 0.01 --reg_st 0.5 --batch_size 32 &> camelyon.txt
```
   
