
# ExPLAIND: Unifying Model, Data, and Training Attribution to Study Model Behavior

This repository provides the implementation of the experiments presented in *ExPLAIND: Unifying Model, Data, and Training Attribution to Study Model Behavior*. To jump directly to the experiments, go to `experiments/`.


## Requirements

### Manual installation

We ran all our experiments with Python `3.12.7`. You can use `conda` to create a fresh environment first and then install the necessary packages with `pip`.

```
conda create -n explaind python=3.12
conda activate explaind

pip install torch torchvision numpy tqdm tensorboard pandas
```

To recreate the plots shown in the paper, you additionally need the following packages:

```setup
pip install plotly umap_learn
```

Alternatively, you can install the dependencies from the requirements file:

```setup
pip install -r requirements.txt
```

## Training models with history

To apply ExPLAIND to your own model, you need to retrain it while tracking the relevant parts of the training process with the wrappers provided in this repository. Note that, depending on model size, full history tracking can become very expensive; we are currently working on a solution that allows cheaper partial tracking. For example, the training process of the modulo addition model includes the following additions:

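```python
# SingleLayerTransformerClassifier, ModelPath, AdamWOptimizerPath, DataPath and
# RegularizedCrossEntropyLoss are provided by this repository; see
# experiments/train_models/modulo_model.py for the exact imports.
# device, alpha, reg_pow, epochs and train_loader are assumed to be defined.
checkpoint_path = "checkpoints/"  # hypothetical directory prefix for all history checkpoints

model = SingleLayerTransformerClassifier().to(device)
# wrap the model into the path wrapper
model = ModelPath(model, device=device, checkpoint_path=checkpoint_path + "model_checkpoint.pt")
loss_fct = RegularizedCrossEntropyLoss(alpha=alpha, p=reg_pow, device=device)
optimizer = AdamWOptimizerPath(model, checkpoint_path=checkpoint_path + "optimizer_checkpoint.pt")
data_path = DataPath(train_loader, checkpoint_path=checkpoint_path + "data_checkpoint.pt", overwrite=True, full_batch=False)

for epoch in range(epochs):
    for batch in data_path.dataloader:
        x, y = data_path.get_batch(batch)
        optimizer.zero_grad()
        output = model.forward(x)
        # with output_reg=True the loss function also returns the regularization term
        l, reg = loss_fct(output, y, params=model.parameters(), output_reg=True)
        l.backward()
        optimizer.step()

        # log the checkpoint values needed for the EPK prediction
        model.log_checkpoint()
        optimizer.log_checkpoint()

# save the checkpoints to disk at the locations defined above;
# loading them later, we can compute the EPK reformulation of the model
optimizer.save_checkpoints()
model.save_checkpoints()
data_path.save_checkpoints()
```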

For complete, executable training scripts, see `experiments/train_models/modulo_model.py` and `experiments/train_models/cifar2_model.py`.
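To get a rough sense of why full history tracking gets expensive, the back-of-envelope sketch below multiplies model size by the number of logged steps. In the training loop above, `log_checkpoint()` runs once per batch, so the number of logged steps is roughly `epochs × len(train_loader)`. How many parameter-sized tensors the wrappers actually store per step is an assumption here (we guess three float32 copies, e.g. the parameters plus AdamW's two moment estimates), and the function name `history_storage_bytes` is hypothetical.

```python
def history_storage_bytes(
    n_params: int,
    n_logged_steps: int,
    tensors_per_step: int = 3,
    bytes_per_value: int = 4,
) -> int:
    """Rough storage estimate for full training-history tracking.

    Hypothetical assumption (not taken from the wrappers' code): every logged
    step stores `tensors_per_step` full parameter-sized tensors, each holding
    `bytes_per_value` bytes per entry (4 for float32).
    """
    return n_params * n_logged_steps * tensors_per_step * bytes_per_value


if __name__ == "__main__":
    # one million parameters, 10,000 logged steps (e.g. 100 epochs x 100 batches)
    gigabytes = history_storage_bytes(1_000_000, 10_000) / 1e9
    print(f"~{gigabytes:.1f} GB of history under these assumptions")
```

Under these assumptions, a model with one million parameters logged for 10,000 steps already needs about 120 GB, and the cost grows linearly in both model size and training length.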

## Getting EPK predictions and kernel accumulations

Once you have the history of the training run you want to explain, you can load it into the exact path kernel (EPK) module and compute the prediction of the reformulated model as follows:

```python
import torch

# ExactPathKernelModel and RegularizedCrossEntropyLoss are provided by this repository;
# device and val_loader are assumed to be defined as in the training script.
epk = ExactPathKernelModel(
    model=model,  # wrapper from before
    optimizer=optimizer,  # wrapper from before
    loss_fn=RegularizedCrossEntropyLoss(alpha=0.0),
    data_path=data_path,  # wrapper from before
    integral_eps=0.01,  # 1/eps = 1/0.01 = 100 integral steps
    evaluate_predictions=True,
    keep_param_wise_kernel=True,
    param_wise_kernel_keep_out_dims=True,
)

# use a batch size small enough to avoid running out of GPU memory
val_loader = torch.utils.data.DataLoader(val_loader.dataset, batch_size=100, shuffle=False)

preds = []
for i, (X, y) in enumerate(val_loader):
    torch.cuda.empty_cache()  # release cached GPU memory between batches
    X = X.to(device)
    y = y.to(device)
    pred = epk.predict(X, y_test=y, keep_kernel_matrices=True)
    preds.append((i, pred, y))  # batch index, EPK prediction, and labels
```

Note that there are different settings controlling which (accumulated) slices of the kernel are stored during prediction. Depending on these choices, runtimes can vary greatly because of the GPU I/O and extra matrix computations involved. For the complete validation scripts, see `experiments/validate_epk/`.
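The `integral_eps` argument of `ExactPathKernelModel` sets the number of integration steps (`1/integral_eps`, i.e. 100 for `0.01`). As a generic illustration of the idea behind such path integrals, and not this repository's implementation, the sketch below uses the fact that the change of a differentiable function between two parameter vectors equals the line integral of its gradient along the straight segment connecting them. We assume the integral runs between consecutive logged parameter states; the function name `path_integral_delta` and the choice of the midpoint rule are our own.

```python
import math
from typing import Callable, List, Sequence


def path_integral_delta(
    grad_f: Callable[[Sequence[float]], List[float]],
    theta_start: Sequence[float],
    theta_end: Sequence[float],
    integral_eps: float = 0.01,
) -> float:
    """Approximate f(theta_end) - f(theta_start) as
    integral_0^1 grad_f(theta_start + s * delta) . delta ds,
    with delta = theta_end - theta_start, using a midpoint Riemann sum
    with round(1 / integral_eps) steps."""
    n_steps = max(1, round(1 / integral_eps))
    delta = [b - a for a, b in zip(theta_start, theta_end)]
    total = 0.0
    for k in range(n_steps):
        s = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [a + s * d for a, d in zip(theta_start, delta)]
        g = grad_f(point)
        total += sum(gi * di for gi, di in zip(g, delta)) / n_steps
    return total


if __name__ == "__main__":
    # f(theta) = sum(sin(theta_i)), so grad f(theta) = [cos(theta_i)]
    theta_a, theta_b = [0.0, 1.0], [0.5, 0.2]
    approx = path_integral_delta(lambda th: [math.cos(v) for v in th], theta_a, theta_b)
    exact = sum(math.sin(v) for v in theta_b) - sum(math.sin(v) for v in theta_a)
    print(f"approx={approx:.6f} exact={exact:.6f} abs_error={abs(approx - exact):.2e}")
```

Smaller values of `integral_eps` mean more gradient evaluations per step and a more accurate integral, which is one reason the prediction cost grows with this setting.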

## Experiments, ablations, and plots

Besides further instructions on how to reproduce the experiments in our paper, the `experiments/` folder contains all the scripts to run additional experiments, ablations, and generate plots. Any checkpoints, plots or other artifacts will be stored in `results/` by default.


