# Tutorial: Estimating Shapley Values for CIFAR diffusion models

This tutorial is a step-by-step guide to running experiments with our framework on the CIFAR dataset. It includes commands for training models, calculating global model behaviors, and estimating data attributions with Shapley values. Before proceeding, install the required packages and configure the directory paths as outlined in the [README](README.md).

## Setup Environment Variable

Set the base directory for your experiments so that paths stay consistent across commands:

```bash
export BASEDIR=/data/diffusion-attr
```

## Train the Full Model

Train a diffusion model on the complete CIFAR dataset. This serves as the baseline for later comparisons.

Command:

```bash
python main.py --dataset cifar100 --method retrain
```

## Train Models on Subsets of Data Contributors and Calculate Global Behaviors

* Sparsified unlearning

Use sparsified fine-tuning to efficiently calculate model behavior for subsets of data contributors:

```bash
python unconditional_generation/unlearn.py \
    --dataset cifar100 \
    --method gd \
    --removal_dist shapley \
    --removal_seed 24 \
    --db $BASEDIR/results/cifar100/global_behavior.jsonl \
    --n_samples 10240 \
    --generate_samples \
    --batch_size 128 \
    --exp_name gd_vs_train \
    --use_ema \
    --trained_steps 1000
```

* Naive retraining from scratch (Reference)

Train the model from scratch on a subset of data contributors for comparison:

```bash
python unconditional_generation/main.py \
    --dataset cifar100 \
    --method retrain \
    --removal_dist shapley \
    --removal_seed 0
```
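
The `--removal_dist shapley` and `--removal_seed` flags suggest that each seed deterministically selects one subset of data contributors. The following is a minimal sketch of how such a subset could be drawn, assuming the Shapley kernel distribution used in kernel-based Shapley estimation (subset size `k` chosen with probability proportional to `1 / (k * (n - k))`, then a uniformly random subset of that size). The function name `sample_shapley_subset` and the contributor count of 100 are hypothetical, and the repository's actual sampler may differ.

```python
import random


def sample_shapley_subset(n_contributors, seed):
    """Draw one subset of contributor indices under the Shapley kernel distribution.

    Hypothetical helper: an illustration of the idea, not the repository's sampler.
    """
    if n_contributors < 2:
        raise ValueError("need at least two contributors")
    rng = random.Random(seed)  # the seed makes the subset reproducible
    sizes = list(range(1, n_contributors))
    # The Shapley kernel places total mass proportional to 1 / (k * (n - k))
    # on the subsets of size k, so small and large subsets are favored.
    weights = [1.0 / (k * (n_contributors - k)) for k in sizes]
    size = rng.choices(sizes, weights=weights, k=1)[0]
    # Given the size, every subset of that size is equally likely.
    return sorted(rng.sample(range(n_contributors), size))


# Placeholder values: 100 contributors, seed 24.
subset = sample_shapley_subset(n_contributors=100, seed=24)
print(len(subset), subset[:5])
```

Because the generator is seeded, calling the function again with the same seed returns the same subset, which is what lets separate training and evaluation runs refer to the same set of contributors.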

## Calculating global model behaviors

Measure global model properties such as Inception Score or diversity score. The following command generates `--n_samples` samples and calculates global model behaviors for the specified method:

* Retraining from scratch (Reference)

```bash
python unconditional_generation/calculate_global_scores.py \
    --dataset cifar100 \
    --method retrain \
    --removal_dist shapley \
    --removal_seed 24 \
    --db $BASEDIR/results/cifar100/global_behavior.jsonl \
    --n_samples 10240 \
    --generate_samples \
    --batch_size 128 \
    --exp_name gd_vs_train \
    --use_ema \
    --trained_steps 1000
```
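
The `--db` flag points to a JSON Lines file, which holds one JSON record per line. The sketch below shows one way to read such a file back and pick out a behavior value per removal seed. The field names `exp_name`, `removal_seed`, and `is` are assumptions inferred from the command-line flags, and the demo records contain made-up numbers, so check an actual output file before relying on this layout.

```python
import json
import os
import tempfile


def load_behaviors(db_path, exp_name, behavior_key):
    """Return (removal_seed, behavior) pairs for one experiment from a JSONL file."""
    rows = []
    with open(db_path, encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            # Keep only records from the requested experiment that carry the key.
            if record.get("exp_name") == exp_name and behavior_key in record:
                rows.append((record.get("removal_seed"), record[behavior_key]))
    return rows


# Demo with made-up records (hypothetical schema and values).
demo_records = [
    {"exp_name": "gd_vs_train", "removal_seed": 24, "is": 5.1},
    {"exp_name": "other_run", "removal_seed": 24, "is": 9.9},
    {"exp_name": "gd_vs_train", "removal_seed": 25, "is": 4.8},
]
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "global_behavior.jsonl")
    with open(demo_path, "w", encoding="utf-8") as handle:
        for record in demo_records:
            handle.write(json.dumps(record) + "\n")
    loaded = load_behaviors(demo_path, "gd_vs_train", "is")

print(loaded)
```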

## Quantifying data attributions and calculating LDS scores

* Retraining/sFT/FT

Estimate Shapley values and compute LDS (Linear Datamodeling Score) for attribution methods:

```bash
python lds.py \
    --train_db $BASEDIR/results/cifar100/shapley/retrain_gd_global_behavior.jsonl \
    --dataset cifar100 \
    --test_exp_name retrain_vs_train \
    --train_exp_name retrain_vs_train \
    --removal_dist shapley \
    --method retrain \
    --model_behavior_key is \
    --num_test_subset 100 \
    --max_train_size 500 \
    --by_class \
    --null_db $BASEDIR/results/cifar100/shapley/null_global_behavior.jsonl \
    --full_db $BASEDIR/results/cifar100/shapley/full_global_behavior.jsonl
```
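
To make the two quantities in this step concrete, here is a toy sketch that computes exact Shapley values for four contributors and then scores them with LDS, taken here as the Spearman rank correlation between the summed attributions of each held-out subset and that subset's actual model behavior. The behavior function `toy_behavior`, its weights, and all helper names are hypothetical. `lds.py` estimates Shapley values from sampled subsets rather than enumerating orderings, and its exact procedure may differ.

```python
import itertools
import math
import random


def exact_shapley(n, value_fn):
    """Exact Shapley values: average marginal contribution over all orderings."""
    totals = [0.0] * n
    n_orderings = 0
    for ordering in itertools.permutations(range(n)):
        members = set()
        previous = value_fn(frozenset(members))
        for player in ordering:
            members.add(player)
            current = value_fn(frozenset(members))
            totals[player] += current - previous
            previous = current
        n_orderings += 1
    return [total / n_orderings for total in totals]


def average_ranks(values):
    """Ranks starting at 1; tied values share their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        shared = (start + end) / 2.0 + 1.0
        for pos in range(start, end + 1):
            ranks[order[pos]] = shared
        start = end + 1
    return ranks


def spearman(xs, ys):
    """Spearman correlation: Pearson correlation of the rank vectors."""
    rx, ry = average_ranks(xs), average_ranks(ys)
    mean_x, mean_y = sum(rx) / len(rx), sum(ry) / len(ry)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(rx, ry))
    var_x = sum((a - mean_x) ** 2 for a in rx)
    var_y = sum((b - mean_y) ** 2 for b in ry)
    if var_x == 0.0 or var_y == 0.0:
        raise ValueError("rank correlation is undefined for constant input")
    return cov / math.sqrt(var_x * var_y)


def lds(attributions, test_subsets, value_fn):
    """Compare additive predictions from attributions with actual behaviors."""
    predicted = [sum(attributions[i] for i in subset) for subset in test_subsets]
    actual = [value_fn(frozenset(subset)) for subset in test_subsets]
    return spearman(predicted, actual)


WEIGHTS = [1.0, 2.0, 4.0, 8.0]  # hypothetical per-contributor effects


def toy_behavior(subset):
    """Toy model behavior: purely additive in the contributors present."""
    return sum(WEIGHTS[i] for i in subset)


shapley_values = exact_shapley(len(WEIGHTS), toy_behavior)
rng = random.Random(0)
test_subsets = [rng.sample(range(len(WEIGHTS)), rng.randint(1, 3)) for _ in range(20)]
score = lds(shapley_values, test_subsets, toy_behavior)
print(shapley_values, score)
```

Because the toy behavior is purely additive, the Shapley values recover the weights and the LDS is 1 by construction. Real model behaviors are not additive, so scores below 1 are expected in practice.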

* Baseline (TRAK)

Use TRAK as a baseline for comparison with our method:

```bash
# D-TRAK computation
python src/attributions/methods/d_trak_grad.py \
    --dataset cifar100 \
    --model_behavior mean-squared-l2-norm \
    --t_strategy uniform \
    --k_partition 100 \
    --projector_dim 4096 \
    --method retrain
```

```bash
python baseline_lds.py \
    --dataset cifar100 \
    --test_exp_name retrain_vs_train \
    --sample_dir $BASEDIR/cifar100/local_scores/ema_generated_samples \
    --num_test_subset 100 \
    --by_class \
    --sample_size 50 \
    --model_behavior_key is \
    --method trak \
    --gradient_type trak \
    --k_partition 100 \
    --projector_dim 4096
```

* Baseline (CLIP similarity)

Calculate CLIP similarity as another baseline:

```bash
python baseline_lds.py \
    --dataset cifar100 \
    --test_exp_name retrain_vs_train \
    --method clip_score \
    --model_behavior_key is \
    --num_test_subset 100 \
    --by_class \
    --sample_dir $BASEDIR/results/cifar100/retrain/ema_generated_samples/full/steps=20000 \
    --training_dir datasets/cifar100/train \
    --datamodel_alpha 0.25
```
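
A similarity baseline of this kind scores each training image by how close its embedding is to the embeddings of generated samples. The sketch below uses short placeholder vectors in place of real CLIP image embeddings and averages cosine similarity over the generated samples. The averaging rule and the function names are assumptions for illustration; `baseline_lds.py` may aggregate differently.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)


def similarity_scores(train_embeddings, generated_embeddings):
    """Score each training item by its mean similarity to the generated samples."""
    scores = []
    for train_vec in train_embeddings:
        sims = [cosine_similarity(train_vec, gen_vec) for gen_vec in generated_embeddings]
        scores.append(sum(sims) / len(sims))
    return scores


# Placeholder 2-D vectors standing in for CLIP embeddings.
train_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
generated_embeddings = [[1.0, 0.0], [2.0, 0.0]]
scores = similarity_scores(train_embeddings, generated_embeddings)
print(scores)
```

Cosine similarity ignores vector length, so the two generated vectors above contribute identically even though one is twice as long as the other.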
