# Representation-Level Counterfactual Calibration for Debiased Zero-Shot Recognition

This repository contains the implementation of counterfactual calibration methods that make CLIP models more robust to spurious correlations in zero-shot image recognition.

![Teaser](picture/main-method.png)

## Setup

### Option 1: Using Conda (Recommended)

Create and activate the Conda environment from the provided [`environment.yml`](environment.yml):

```bash
conda env create -f environment.yml
conda activate counterfactual_embedding
```

### Option 2: Using pip

1. Create a virtual environment:
   ```bash
   python -m venv counterfactual_embedding
   # On Windows
   counterfactual_embedding\Scripts\activate
   # On Linux/Mac
   source counterfactual_embedding/bin/activate
   ```

2. Install dependencies:
   ```bash
   pip install -r requirements.txt
   ```

## Dependencies

The main dependencies include:
- Python 3.10
- PyTorch 2.0.1
- torchvision 0.15.2
- CUDA Toolkit 11.7 (for GPU acceleration)
- NumPy, Pandas, PIL, h5py
- matplotlib, scikit-learn, tqdm
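
After installing, it can help to confirm that the pinned versions are the ones actually in the environment. The following is a minimal sketch, not part of the repository; it assumes the usual PyPI distribution names `torch` and `torchvision`, and the helper name `check_versions` is hypothetical.

```python
from importlib import metadata

# Pinned versions taken from the dependency list above. The distribution
# names ("torch", "torchvision") are an assumption, not read from the repo.
EXPECTED = {"torch": "2.0.1", "torchvision": "0.15.2"}


def check_versions(expected):
    """Return {package: (installed_version_or_None, expected_version, matches)}."""
    report = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Drop a local build suffix such as "+cu117" before comparing.
        matches = installed is not None and installed.split("+")[0] == wanted
        report[name] = (installed, wanted, matches)
    return report


if __name__ == "__main__":
    for name, (installed, wanted, ok) in check_versions(EXPECTED).items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name}: installed={installed} expected={wanted} [{status}]")
```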

## Project Structure

```
counterfactual-embedding/
├── CLIP_utils/              # CLIP model utilities
├── counterfactualCLIP/      # Main implementation of counterfactual methods
├── data/                    # Dataset loaders for various datasets
│   ├── waterbirds.py        # Waterbirds dataset loader
│   ├── COCO_GB_V1.py        # COCO_GB_V1 dataset loader
│   ├── urbancars.py         # UrbanCars dataset loader
│   └── NICO.py              # NICO dataset loader
├── others/                  # Auxiliary code
│   └── prompt/              # Visualization tools
├── demo_visual.ipynb        # Demo visualization notebook
├── environment.yml          # Conda environment specification
├── requirements.txt         # Pip requirements
└── README.md                # This documentation file
```

## Dataset Preparation

### Waterbirds
The Waterbirds dataset consists of images of birds superimposed on either land or water backgrounds. To prepare the dataset:

1. Download the dataset from the [official source](https://github.com/kohpangwei/group_DRO?tab=readme-ov-file)
2. Extract the contents to a directory
3. Update the data path in your scripts accordingly

### COCO_GB
The COCO_GB dataset is a variant of COCO with group annotations. To prepare the dataset:

1. Download the dataset from the [COCO website](https://cocodataset.org/)
2. Download the group annotations from the [COCO_GB repository](https://github.com/datamllab/Mitigating_Gender_Bias_In_Captioning_System)
3. Place the annotations in the COCO directory structure
4. Update the data path in your configuration files

### UrbanCars
The UrbanCars dataset contains car images in urban or country settings. To prepare the dataset:

1. Download the dataset from the [official source](https://github.com/facebookresearch/Whac-A-Mole)
2. Place the images in a directory structure that follows the pattern `obj-{type}_bg-{type}_co_occur_obj-{type}`
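
To sanity-check a downloaded copy, the directory pattern above can be parsed into its three fields. This is an illustrative sketch rather than the loader in `data/urbancars.py`; the names `UrbanCarsGroup` and `parse_group_dir` are hypothetical, and it assumes each `{type}` is a single token without underscores (for example `urban` or `country`).

```python
import re
from typing import NamedTuple, Optional


class UrbanCarsGroup(NamedTuple):
    obj: str           # type of the car object
    bg: str            # type of the background
    co_occur_obj: str  # type of the co-occurring object


# Assumption: every {type} is one token with no underscore in it.
_PATTERN = re.compile(
    r"^obj-(?P<obj>[^_]+)_bg-(?P<bg>[^_]+)_co_occur_obj-(?P<co>[^_]+)$"
)


def parse_group_dir(name: str) -> Optional[UrbanCarsGroup]:
    """Parse a directory name; return None if it does not follow the pattern."""
    match = _PATTERN.match(name)
    if match is None:
        return None
    return UrbanCarsGroup(match["obj"], match["bg"], match["co"])


if __name__ == "__main__":
    print(parse_group_dir("obj-urban_bg-country_co_occur_obj-urban"))
```

Directories for which `parse_group_dir` returns `None` are worth inspecting before running the extraction step.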

### NICO
NICO is the Non-I.I.D. Image Dataset with Contexts. To prepare the dataset:

1. Download the dataset from the [official website](https://nico.thumedialab.com/)
2. Extract it to the appropriate directory structure

### Places
The external variant uses the Places dataset as a source of diverse scene backgrounds. To prepare the dataset:

1. Download the Places365-Standard dataset from the [official website](http://places2.csail.mit.edu/download.html)
2. We recommend using the high-resolution version (Places365-Standard high-res images)
3. Extract the dataset to your project directory
4. Update the data path in your configuration files

This dataset is particularly important for the external variant method, which requires scene images separate from the foreground objects.

## Usage

### Preprocessing: Extract CLIP Information

First, extract the attention and MLP information from the CLIP model. The later steps need it to estimate counterfactual embeddings.

#### Extract Attention and MLP Information

```bash
# ViT-B-32 model
python extract_clip_info.py --model ViT-B-32 --pretrained laion2b_s34b_b79k --dataset waterbirds --batch_size 32 --cuda_id 0

# ViT-B-16 model
python extract_clip_info.py --model ViT-B-16 --pretrained laion2b_s34b_b88k --dataset waterbirds --batch_size 10 --cuda_id 0

# ViT-L-14 model
python extract_clip_info.py --model ViT-L-14 --pretrained laion2b_s32b_b82k --dataset waterbirds --batch_size 4 --cuda_id 0

# ViT-H-14 model
python extract_clip_info.py --model ViT-H-14 --pretrained laion2b_s32b_b79k --dataset waterbirds --batch_size 2 --cuda_id 0
```

#### Extract Prompt Information

```bash
# ViT-B-32 model
python extract_clip_info.py --model ViT-B-32 --pretrained laion2b_s34b_b79k --dataset waterbirds --just_text openai

# ViT-B-16 model
python extract_clip_info.py --model ViT-B-16 --pretrained laion2b_s34b_b88k --dataset waterbirds --just_text openai

# ViT-L-14 model
python extract_clip_info.py --model ViT-L-14 --pretrained laion2b_s32b_b82k --dataset waterbirds --just_text openai

# ViT-H-14 model
python extract_clip_info.py --model ViT-H-14 --pretrained laion2b_s32b_b79k --dataset waterbirds --just_text openai
```

#### For NICO Dataset

```bash
# Extract attention and MLP information for the NICO dataset
python extract_clip_info_nico.py --model ViT-B-32 --pretrained laion2b_s34b_b79k --dataset nico --batch_size 32 --cuda_id 0
python extract_clip_info_nico.py --model ViT-B-16 --pretrained laion2b_s34b_b88k --dataset nico --batch_size 10 --cuda_id 0
python extract_clip_info_nico.py --model ViT-L-14 --pretrained laion2b_s32b_b82k --dataset nico --batch_size 4 --cuda_id 0
python extract_clip_info_nico.py --model ViT-H-14 --pretrained laion2b_s32b_b79k --dataset nico --batch_size 2 --cuda_id 0

# Extract prompt information for NICO dataset
python extract_clip_info_nico.py --model ViT-B-32 --pretrained laion2b_s34b_b79k --dataset nico --just_text openai
```

Available parameters:
- `--model`: CLIP model architecture (ViT-B-32, ViT-B-16, ViT-L-14, ViT-H-14)
- `--pretrained`: Pre-trained model weights to use (laion2b_s34b_b79k, laion2b_s34b_b88k, laion2b_s32b_b82k, laion2b_s32b_b79k)
- `--dataset`: Dataset to process (waterbirds, urbancars, nico, coco_gb)
- `--batch_size`: Batch size for processing (adjust based on GPU memory)
- `--cuda_id`: GPU ID to use
- `--just_text`: Extract only text embeddings using the specified prompt template (e.g., openai)

This preprocessing step will generate files that store the attention maps and MLP features, which will be used by the counterfactual embedding methods.
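
The extraction commands above differ only in the model, its pretrained tag, and the batch size, so they can be generated from a small table. The sketch below is not part of the repository. It copies the tags and batch sizes from the commands above, picks `extract_clip_info_nico.py` when the dataset is `nico`, and emits a `--just_text` command for every model, which is an assumption for NICO since only the ViT-B-32 text command is listed there. The function name `extraction_commands` is hypothetical.

```python
import shlex

# (pretrained tag, batch size) per architecture, copied from the commands above.
MODELS = {
    "ViT-B-32": ("laion2b_s34b_b79k", 32),
    "ViT-B-16": ("laion2b_s34b_b88k", 10),
    "ViT-L-14": ("laion2b_s32b_b82k", 4),
    "ViT-H-14": ("laion2b_s32b_b79k", 2),
}


def extraction_commands(dataset, cuda_id=0, prompt_template="openai"):
    """Return argv lists: one image pass and one text pass per model."""
    script = "extract_clip_info_nico.py" if dataset == "nico" else "extract_clip_info.py"
    commands = []
    for model, (pretrained, batch_size) in MODELS.items():
        base = ["python", script, "--model", model,
                "--pretrained", pretrained, "--dataset", dataset]
        # Attention and MLP extraction.
        commands.append(base + ["--batch_size", str(batch_size),
                                "--cuda_id", str(cuda_id)])
        # Prompt (text) extraction.
        commands.append(base + ["--just_text", prompt_template])
    return commands


if __name__ == "__main__":
    for argv in extraction_commands("waterbirds"):
        print(shlex.join(argv))
```

Printing the commands instead of executing them keeps the sketch side-effect free; each list can be passed to `subprocess.run(argv, check=True)` once the output looks right.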

### Run

To reproduce the results of the three variants (external, internal, and virtual) for each architecture, run the corresponding command below. All examples use the Waterbirds dataset.

#### External variant

```bash
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-32 --scene_type outer_cz --dataset waterbirds --alpha 0.4 --lam_hat 1.1 --lam 0.8 --batch_size 200 --select_scene_num 50
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-16 --scene_type outer_cz --dataset waterbirds --alpha 0.2 --lam_hat 0.4 --lam 0.6 --batch_size 100 --select_scene_num 25
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-L-14 --scene_type outer_cz --dataset waterbirds --alpha 0.8 --lam_hat 1.3 --lam 0.7 --batch_size 100 --select_scene_num 2
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-H-14 --scene_type outer_cz --dataset waterbirds --alpha 0.2 --lam_hat 0.7 --lam 0.7 --batch_size 100 --select_scene_num 2
```

#### Internal variant

```bash
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-32 --scene_type inner_cz --dataset waterbirds --alpha 0.4 --lam_hat 1.1 --lam 0.8 --batch_size 200 --select_scene_num 40
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-16 --scene_type inner_cz --dataset waterbirds --alpha 0.2 --lam_hat 0.5 --lam 0.5 --batch_size 200 --select_scene_num 190
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-L-14 --scene_type inner_cz --dataset waterbirds --alpha 0.8 --lam_hat 0.8 --lam 0.7 --batch_size 200 --select_scene_num 190
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-H-14 --scene_type inner_cz --dataset waterbirds --alpha 0.2 --lam_hat 0.7 --lam 0.5 --batch_size 100 --select_scene_num 95
```

#### Virtual variant

```bash
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-32 --scene_type virtual_cz --dataset waterbirds --alpha 0.7 --lam_hat 0.7 --lam 0.8 --batch_size 300 --select_scene_num 270
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-16 --scene_type virtual_cz --dataset waterbirds --alpha 0.8 --lam_hat 1.5 --lam 0.8 --batch_size 300 --select_scene_num 270
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-L-14 --scene_type virtual_cz --dataset waterbirds --alpha 0.65 --lam_hat 1.5 --lam 0.6 --batch_size 200 --select_scene_num 190
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-H-14 --scene_type virtual_cz --dataset waterbirds --alpha 0.7 --lam_hat 1.0 --lam 0.6 --batch_size 100 --select_scene_num 90
```

### Parameters

- `--model`: CLIP model to use (ViT-B-32, ViT-B-16, ViT-L-14, ViT-H-14)
- `--dataset`: Dataset to process (waterbirds, urbancars, nico)
- `--scene_type`: Variant to run: `outer_cz` (external), `inner_cz` (internal), or `virtual_cz` (virtual)
- `--alpha`: Weight for mixing the background and the target object
- `--lam`: Weight for final prediction by counterfactual calibration
- `--lam_hat`: Weight for TDE scores
- `--batch_size`: Batch size for processing
- `--select_scene_num`: Number of scenes to select

### Ablation

To ablate the TDE module, the sampler, and the intervention module, run the commands below.

$\lambda$ (`--lam 0`):

```bash
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-32 --scene_type outer_cz --dataset waterbirds --alpha 0.4 --lam_hat 1.1 --lam 0 --batch_size 200 --select_scene_num 50
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-16 --scene_type outer_cz --dataset waterbirds --alpha 0.2 --lam_hat 0.4 --lam 0 --batch_size 100 --select_scene_num 25
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-L-14 --scene_type outer_cz --dataset waterbirds --alpha 0.8 --lam_hat 1.3 --lam 0 --batch_size 100 --select_scene_num 2
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-H-14 --scene_type outer_cz --dataset waterbirds --alpha 0.2 --lam_hat 0.7 --lam 0 --batch_size 100 --select_scene_num 2
```

$\hat{\lambda}$ (`--lam_hat 0`):

```bash
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-32 --scene_type outer_cz --dataset waterbirds --alpha 0.4 --lam_hat 0 --lam 0.8 --batch_size 200 --select_scene_num 50
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-16 --scene_type outer_cz --dataset waterbirds --alpha 0.2 --lam_hat 0 --lam 0.6 --batch_size 100 --select_scene_num 25
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-L-14 --scene_type outer_cz --dataset waterbirds --alpha 0.8 --lam_hat 0 --lam 0.7 --batch_size 100 --select_scene_num 2
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-H-14 --scene_type outer_cz --dataset waterbirds --alpha 0.2 --lam_hat 0 --lam 0.7 --batch_size 100 --select_scene_num 2
```

`select_scene_num` (`--select_scene_num 50` for every model):

```bash
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-32 --scene_type outer_cz --dataset waterbirds --alpha 0.4 --lam_hat 1.1 --lam 0.8 --batch_size 200 --select_scene_num 50
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-B-16 --scene_type outer_cz --dataset waterbirds --alpha 0.2 --lam_hat 0.4 --lam 0.6 --batch_size 100 --select_scene_num 50
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-L-14 --scene_type outer_cz --dataset waterbirds --alpha 0.8 --lam_hat 1.3 --lam 0.7 --batch_size 100 --select_scene_num 50
python counterfactualCLIP/run_counterfactualCLIP.py --model ViT-H-14 --scene_type outer_cz --dataset waterbirds --alpha 0.2 --lam_hat 0.7 --lam 0.7 --batch_size 100 --select_scene_num 50
```
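
Each ablation above changes a single flag of a full run and leaves the rest untouched. The sketch below expresses that with a frozen dataclass; it is an illustration, not code from the repository. The `RunConfig` class and `ablations` helper are hypothetical names, and the base values are the ViT-B-32 external-variant settings from the Run section.

```python
import shlex
from dataclasses import asdict, dataclass, replace


@dataclass(frozen=True)
class RunConfig:
    # Field names mirror the command-line flags, in the order used above.
    model: str
    scene_type: str
    dataset: str
    alpha: float
    lam_hat: float
    lam: float
    batch_size: int
    select_scene_num: int

    def command(self) -> str:
        argv = ["python", "counterfactualCLIP/run_counterfactualCLIP.py"]
        for name, value in asdict(self).items():
            argv += [f"--{name}", str(value)]
        return shlex.join(argv)


# ViT-B-32 external-variant settings, copied from the Run section.
BASE = RunConfig("ViT-B-32", "outer_cz", "waterbirds", 0.4, 1.1, 0.8, 200, 50)


def ablations(base: RunConfig) -> dict:
    """Return one modified config per ablation; every other field is unchanged."""
    return {
        "lam": replace(base, lam=0),
        "lam_hat": replace(base, lam_hat=0),
        "select_scene_num": replace(base, select_scene_num=50),
    }


if __name__ == "__main__":
    for name, config in ablations(BASE).items():
        print(f"# {name}")
        print(config.command())
```

Because the config is frozen, `replace` always returns a new object, so the base settings cannot be modified by accident while sweeping.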

### Visualization

To visualize attention maps using original CLIP (Grad-CAM) and C(x) (TDE) methods:

```bash
python others/prompt/compute_prs_image_attention.py --model ViT-L-14 --attens_type Grad-CAM
python others/prompt/compute_prs_image_attention.py --model ViT-L-14 --attens_type TDE
```

![Attention visualization](picture/visual_more.png)

## Jupyter Notebooks

The demo notebook `demo_visual.ipynb` visualizes attention maps produced with Grad-CAM and C(x).
