

# Optimizing importance weighting in the presence of sub-population shifts


This folder contains the code used for the paper 'Optimizing importance weighting in the presence of sub-population shifts'. The repository is licensed under the terms of the BSD-3 license.

To replicate the results from the paper, take the following steps:

1. Create a new virtual environment (recommended) before reproducing the results.
2. Install the required libraries listed in `requirements.txt` via `pip install -r requirements.txt`.

Not all datasets used in the paper are included in this folder because of their size. They can be obtained as follows:

* **Waterbirds** is constructed from the Places dataset ([link](http://places.csail.mit.edu)) and the Caltech-UCSD Birds-200-2011 (CUB) dataset ([link](https://www.vision.caltech.edu/datasets/cub_200_2011/)).
* **CelebA** can be found [here](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
* **MultiNLI** can be found [here](https://gluebenchmark.com/tasks).
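In sub-population shift benchmarks such as Waterbirds, each example belongs to a group defined by its class label and a spurious attribute (for Waterbirds, bird type and background type, giving four groups). The snippet below is a minimal, illustrative sketch of how such group indices could be formed and counted to expose group imbalance. The encoding (`y` for the class, `a` for the attribute) and the helper names `group_id` and `group_counts` are hypothetical assumptions and are not taken from `data.py`.

```python
from collections import Counter
from typing import Iterable


def group_id(y: int, a: int, n_attrs: int = 2) -> int:
    """Map a (class label, spurious attribute) pair to one group index.

    Hypothetical encoding: y = 0 landbird, 1 waterbird;
    a = 0 land background, 1 water background.
    """
    return y * n_attrs + a


def group_counts(labels: Iterable[int], attrs: Iterable[int], n_attrs: int = 2) -> Counter:
    """Count how many examples fall into each group."""
    return Counter(group_id(y, a, n_attrs) for y, a in zip(labels, attrs))


# Toy data: most examples sit in the "majority" groups (0 and 3).
labels = [0, 0, 0, 0, 1, 1, 0, 1]
attrs = [0, 0, 0, 1, 1, 1, 0, 0]
counts = group_counts(labels, attrs)
print(sorted(counts.items()))  # [(0, 4), (1, 1), (2, 1), (3, 2)]
```

Small groups such as 1 and 2 above are the minority sub-populations that importance weighting aims to upweight.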

Below is a description of the key files in the folder. 
* **create_embeddings.py**: contains functions to create last-layer embeddings from finetuned models
* **create_pred_model.py**: contains functions to get predictions of finetuned models for the JTT (Just Train Twice) method
* **data.py**: class(es) used to create data for finetuning and subsequently load the embeddings
* **finetune.py**: used to finetune ResNet-50 and BERT models
* **generate_results_standard.py/generate_results_opt.py**: used to generate results with standard/optimized weights, given the created embeddings and hyper-parameters
* **helpers.py/metrics.py**: helper functions for small tasks and functions for calculating results
* **model.py**: class(es) for all the models used
* **param_search_last_layer.py**: used to perform a hyper-parameter sweep for standard weights
* **training.py**: used to train ResNet-50 and BERT models
* **utils_glue.py**: utilities for the MultiNLI dataset
* **weight_searcher.py**: class that searches for optimal weights
* **weights.py**: class that defines weights for models
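As a rough illustration of what "standard weights" on last-layer embeddings can look like, the sketch below assumes they are importance weights `w_i = p_target(g_i) / p_train(g_i)` with a uniform target distribution over groups, used in a weighted logistic loss and evaluated with worst-group accuracy. This is an assumption for illustration only: the function names are hypothetical, and the sketch does not reproduce `weights.py`, `metrics.py`, or the optimized weights found by `weight_searcher.py`.

```python
import numpy as np


def group_balanced_weights(groups: np.ndarray, n_groups: int) -> np.ndarray:
    """Per-sample importance weights p_target(g) / p_train(g), uniform target over groups."""
    counts = np.bincount(groups, minlength=n_groups).astype(float)
    p_train = counts / counts.sum()
    p_target = np.full(n_groups, 1.0 / n_groups)
    # Groups with no training samples get weight 0 (no sample ever indexes them).
    ratio = np.divide(p_target, p_train, out=np.zeros(n_groups), where=p_train > 0)
    return ratio[groups]


def weighted_log_loss(w: np.ndarray, logits: np.ndarray, y: np.ndarray) -> float:
    """Weighted binary cross-entropy on logits: log(1 + exp(z)) - y * z."""
    per_sample = np.logaddexp(0.0, logits) - y * logits
    return float(np.sum(w * per_sample) / np.sum(w))


def worst_group_accuracy(preds: np.ndarray, y: np.ndarray, groups: np.ndarray, n_groups: int) -> float:
    """Minimum accuracy over the groups that appear in the data."""
    accs = [
        np.mean(preds[groups == g] == y[groups == g])
        for g in range(n_groups)
        if np.any(groups == g)
    ]
    return float(min(accs))


groups = np.array([0, 0, 0, 1])
y = np.array([0, 0, 1, 1])
preds = np.array([0, 0, 1, 0])

w = group_balanced_weights(groups, n_groups=2)
print(w)  # majority group gets 2/3, minority group gets 2
print(worst_group_accuracy(preds, y, groups, n_groups=2))  # 0.0
```

With this choice the weights average to 1 over the training set, so the weighted loss stays on the same scale as the unweighted one.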




