# SATE Code
Code for the paper "SATE: A TWO-STAGE APPROACH FOR PERFORMANCE PREDICTION IN SUBPOPULATION SHIFT SCENARIOS"

Some of our code is based on the open-source SubpopBench codebase; more details can be found here:
[Change is Hard: A Closer Look at Subpopulation Shift](https://arxiv.org/abs/2302.12254) (Yang et al., ICML 2023). 
Their license and README are located in LICENSE.orig and README.orig. There are no URLs or names in this code related to the authors of SATE.

## 1 Quick Start
### 1.1 set up the environment
``` bash
pip install -r requirements.txt
```
### 1.2 download the datasets
Download the datasets by running "code/scripts/download.sh", which runs the following command:
``` bash
python download.py {dataset} --data_path ../dataset --download
```
"{dataset}" can be Waterbirds, CelebA or MultiNLI.

For CheXpertNoFinding:

1. Download the [downsampled CheXpert dataset](http://download.cs.stanford.edu/deep/CheXpert-v1.0-small.zip) and extract it.

2. Register for an account and download the CheXpert demographics data [here](https://stanfordaimi.azurewebsites.net/datasets/192ada7c-4d43-466e-b8bb-b81992bb80cf). Place the `CHEXPERT DEMO.xlsx` in your CheXpert directory.

3. Move the `CheXpert-v1.0-small` folder into your data directory and rename it `chexpert`, or create a symbolic link named `chexpert` that points to it.

4. Run `python -m code.scripts.download chexpert --data_path <data_path>`.

We will release instructions for downloading SNLI later.

Successfully downloaded datasets will appear in the "code/dataset" folder.

### 1.3 train test split
Run "python gen.py" to generate the metadata file, which contains the data directory, label, attribute and train-test split.
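
As an illustration of what a subgroup-aware train-test split can look like, here is a minimal sketch. It is not the logic of gen.py: the record fields (`filename`, `y`, `a`), the function name `split_by_subgroup` and the per-subgroup counts are all hypothetical, chosen only to show the idea of drawing a fixed number of training samples from each (label, attribute) subgroup and leaving the rest for testing.

```python
import random
from collections import defaultdict


def split_by_subgroup(records, train_counts, seed=0):
    """Assign a 'train' or 'test' split to each record.

    records      -- list of dicts with hypothetical keys 'y' (label) and 'a' (attribute)
    train_counts -- dict mapping (y, a) to the number of training samples to draw
    """
    rng = random.Random(seed)

    # Group record indices by subgroup (label, attribute).
    by_group = defaultdict(list)
    for idx, rec in enumerate(records):
        by_group[(rec["y"], rec["a"])].append(idx)

    # Everything is 'test' unless it is sampled into the training set.
    split = ["test"] * len(records)
    for group, indices in by_group.items():
        n_train = min(train_counts.get(group, 0), len(indices))
        for idx in rng.sample(indices, n_train):
            split[idx] = "train"

    # Return copies of the records with the split column added.
    return [dict(rec, split=s) for rec, s in zip(records, split)]


# Hypothetical toy data: 4 subgroups with 5 samples each.
records = [
    {"filename": f"img_{i}.jpg", "y": i % 2, "a": (i // 2) % 2}
    for i in range(20)
]
train_counts = {(0, 0): 3, (0, 1): 1, (1, 0): 2, (1, 1): 4}
metadata = split_by_subgroup(records, train_counts)
```

The resulting list of dicts could be written out with `csv.DictWriter` to obtain a metadata file with one row per sample.
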
### 1.4 train the models
For image tasks, run "stage1.sh" (for ERM and GroupDRO) and "stage2.sh" (for DFR, based on trained ERM model).

For language tasks, run "lan_stage1.sh" and "lan_stage2.sh".

You can choose datasets, algorithms, model architectures and other training parameters in the script files.

The trained models will be saved in "output/output_attrNo" or "output/output_attrYes". Before the next step, sort them by architecture into "output/bert", "output/resnet" or "output/vit".


### 1.5 performance prediction
Select.py is the main file that runs the whole set of performance prediction experiments. You can run it on the datasets, algorithms and model architectures mentioned above (make sure you have trained them beforehand).

See "celeba.sh" for an example of running these experiments.

"waterbirds.sh" is an example of running the "real-world shift" experiments, where you can choose to apply perturbations (fog, blur, bright, contrast) to the test datasets.

## 2 File Descriptions
1. JS.py: functions to calculate the J-S divergence of P(y), P(a) and P(y | a), which serve as quantitative metrics for subpopulation shift.
2. baselines.py and baseline_util.py: baseline algorithms
3. train.py: main file to train the models
4. Select.py: main file to run the whole performance prediction, contains SATE algorithm
5. params.py: the subgroup distribution of the manually designed test datasets. Each number represents how many samples we will extract from the corresponding subgroup for testing.
6. gen.py: the main file for the train-test split in the performance prediction experiments; it also contains the subgroup distribution of the training sets.
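
For readers unfamiliar with the J-S divergence mentioned for JS.py, the following is a minimal sketch of how it can be computed between two empirical label distributions. It is not the code in JS.py: the helper names (`empirical_dist`, `kl_divergence`, `js_divergence`) and the toy labels are assumptions made for illustration, and the divergence is measured in bits (base-2 logarithm).

```python
import math
from collections import Counter


def empirical_dist(values, support):
    """Empirical probability of each element of `support` in `values`."""
    counts = Counter(values)
    total = len(values)
    return [counts[v] / total for v in support]


def kl_divergence(p, q):
    """KL(p || q) in bits; terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture m."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


# Hypothetical labels of a training set and a shifted test set.
train_y = [0, 0, 0, 1]
test_y = [0, 1, 1, 1]
support = sorted(set(train_y) | set(test_y))

js_y = js_divergence(
    empirical_dist(train_y, support),
    empirical_dist(test_y, support),
)
```

The same function applies to P(a) by passing attribute values instead of labels. With base-2 logarithms the value lies between 0 (identical distributions) and 1 (disjoint supports).
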

## 3 Data Augmentations
In "dataset/datasets.py", we define the "pertur_transform" method for the classes "Waterbirds", "CelebA" and "BaseImageDataset" (the base class of CheXpertNoFinding).

Within that method, you can see how we apply data augmentations to generate $S'$ for subgroup accuracy estimation.

