# D-TRAK

Data Attribution on Diffusion Models.

## Install dependencies
```bash
pip install accelerate==0.20.3
pip install datasets==2.12.0
pip install diffusers==0.16.1
pip install fast-jl==0.1.3
pip install torch==2.0.1
pip install torchvision==0.15.2
pip install traker==0.1.3
pip install transformers==4.30.1
```
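
The same pins can be kept in a single requirements file, which makes the environment easier to recreate. The sketch below only writes that file; the name `requirements.txt` is a conventional choice and is not a file this repository is known to ship.

```shell
# Write the pinned versions listed above into one file.
# The filename is an assumption, not part of the original repo.
cat > requirements.txt <<'EOF'
accelerate==0.20.3
datasets==2.12.0
diffusers==0.16.1
fast-jl==0.1.3
torch==2.0.1
torchvision==0.15.2
traker==0.1.3
transformers==4.30.1
EOF
```

Installing from it is then a single `pip install -r requirements.txt`.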

## Quickstart
Check out [quickstart.ipynb](quickstart.ipynb) to see how to conduct attribution on diffusion models directly!

## Preparation
Download the datasets and update the dataset paths in the code accordingly.

CIFAR can be loaded directly from Hugging Face.

[CelebA](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)

[ArtBench](https://github.com/liaopeiyuan/artbench)

## Getting started

We mainly provide commands for running experiments on CIFAR-2. They are easy to adapt to other datasets.

```shell
cd CIFAR2
```

Run [00_EDA.ipynb](CIFAR2/00_EDA.ipynb) first to create dataset splits and training set subsets.

### Train a diffusion model and generate images
```shell
bash run_train.sh 0 18888 5000-0.5
bash run_gen.sh 0 0 5000-0.5
```

### Construct the LDS benchmark

```shell
bash run_lds_val_sub.sh 0 18888 5000-0.5 0 63

bash run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_val.pkl 0 63
bash run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_val.pkl 0 63
bash run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_val.pkl 0 63

bash run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_gen.pkl 0 63
bash run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_gen.pkl 0 63
bash run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_gen.pkl 0 63
```
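
The six evaluation commands above differ only in their second argument (0, 1, 2) and in the index file, so they can be generated with a nested loop. This is a sketch of that pattern: the helper name `lds_eval_cmds` and the loop variable `k` are illustrative, and the meaning of each positional argument is not assumed here. The helper prints the commands as a dry run instead of executing them.

```shell
# Hypothetical helper: print the six evaluation commands listed above.
# The second argument is looped over 0..2 without assuming what it means.
lds_eval_cmds() {
  for idx in idx_val.pkl idx_gen.pkl; do
    for k in 0 1 2; do
      echo "bash run_eval_lds_val_sub.sh 0 ${k} 5000-0.5 ${idx} 0 63"
    done
  done
}

# Dry run: show what would be executed.
lds_eval_cmds
```

Once the printed commands look right, `lds_eval_cmds | bash` runs them in order.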

### Compute gradients
We shard the training set into 5 parts of 1,000 examples each.

Use the following commands to compute the gradients to be used for TRAK. 

```shell
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
```

Use the following commands to compute the gradients to be used for D-TRAK. 

```shell
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
```
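
The TRAK and D-TRAK command lists differ only in the seventh argument (`loss` versus `mean-squared-l2-norm`), and the training commands differ only in the shard index (0 to 4). A loop can generate both lists. In this sketch the helper name `grad_cmds` is illustrative, and reading the fifth argument as the shard index is an assumption based on the sharding described above. The helper prints the commands as a dry run instead of executing them.

```shell
# Hypothetical helper: print the seven gradient commands for one setting.
# $1 is the seventh argument of run_grad.sh: "loss" for TRAK or
# "mean-squared-l2-norm" for D-TRAK, as in the lists above.
grad_cmds() {
  # Five training shards (assumed to be the fifth positional argument).
  for shard in 0 1 2 3 4; do
    echo "bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl ${shard} ddpm/checkpoint-8000 $1 uniform 10 32768"
  done
  # Validation and generated sets use a single shard, index 0.
  for idx in idx-val.pkl idx-gen.pkl; do
    echo "bash scripts/run_grad.sh 0 0 5000-0.5 ${idx} 0 ddpm/checkpoint-8000 $1 uniform 10 32768"
  done
}

# Dry run for both settings.
grad_cmds loss
grad_cmds mean-squared-l2-norm
```

Once the printed commands match the lists above, pipe either call to `bash` (for example `grad_cmds loss | bash`) to run them.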

### Compute the TRAK/D-TRAK attributions and evaluate the LDS scores
Run the notebooks in [methods/04_if](CIFAR2/methods/04_if).

The implementations of other baselines can also be found in [methods](CIFAR2/methods).

### Conduct counterfactual evaluation 

Run this [notebook](CIFAR2/methods/04_if/get_indices_gen.ipynb) first to get the indices of the training examples to remove.

### Retrain models after removing the most influential training examples

```shell
bash scripts/run_counter.sh 0 18888 5000-0.5 0 59
```

### Generate images using the retrained models

Run [02_counter.ipynb](02_counter.ipynb)

### Measure L2 distance

Run [03_counter_eval_l2.ipynb](03_counter_eval_l2.ipynb)

### Measure CLIP cosine similarity

Run [03_counter_eval_clip.ipynb](03_counter_eval_clip.ipynb)

