# MMBS
This repository contains an implementation of the MMBS attribution method, together with code for benchmarking it against other attribution methods. It was used for the paper "Monte Carlo Multi-Feature Baseline Shapley (MMBS): An axiomatic attribution method for fine-grained explanations of image classification networks", which is available at [https://openreview.net/forum?id=LLFIcr7zWh](https://openreview.net/forum?id=LLFIcr7zWh).

## Scripts in this repository
### Core components
The core components that are re-used between scripts are located in the "xai" folder.

### Example notebook
examples.ipynb is a Jupyter notebook with examples.

### Scripts for generating the figures and tables in the paper
The following scripts were used to generate the figures and tables in the paper:
- *Figure 1:* plot\_attribution\_map\_differences.py
- *Figure 2:* calc\_fashion\_mnist\_convergence.py, plot\_fashion\_mnist\_convergence.py
- *Figure 3:* calc\_fashion\_mnist\_sweep.py, plot\_fashion\_mnist\_sweep.py
- *Figure 4:* calc\_method\_comparison.py, plot\_deletion\_curves.py
- *Figure 5:* plot\_baseline\_differences.py
- *Figures C.1, C.2:* calc\_imagenet\_convergence.py, plot\_imagenet\_convergence.py
- *Figure C.3:* calc\_mmbs\_mbshap\_comparison.py, plot\_mmbs\_mbshap\_comparison.py
- *Table 1:* calc\_runtime.py, parse\_runtime.py
- *Table 2:* calc\_method\_comparison.py, parse\_method\_comparison.py
- *Table 3:* calc\_baseline\_audcs.py, parse\_baseline\_audcs.py
- *Tables D.1, D.2:* calc\_AUDC\_sensitivity.py, parse\_AUDC\_sensitivity.py

The SVG files generated for the figures were edited by hand in Inkscape to add labels and improve the layout.
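Judging by their names, the calc\_ scripts appear to compute results that the matching plot\_ or parse\_ scripts then turn into a figure or table, so each group is presumably run in the order listed. The sketch below assumes exactly that ordering, that every entry is a .py file in the repository root, and that the scripts take no command-line arguments. The `SCRIPTS` dictionary and the `build_commands`/`run_scripts` helpers are hypothetical and not part of the repository.

```python
import subprocess
import sys
from pathlib import Path

# Hypothetical mapping from a paper item to its scripts, in the order they are
# assumed to run (calc_* before plot_* / parse_*).
SCRIPTS = {
    "Figure 1": ["plot_attribution_map_differences.py"],
    "Figure 2": ["calc_fashion_mnist_convergence.py", "plot_fashion_mnist_convergence.py"],
    "Figure 3": ["calc_fashion_mnist_sweep.py", "plot_fashion_mnist_sweep.py"],
    "Figure 4": ["calc_method_comparison.py", "plot_deletion_curves.py"],
    "Figure 5": ["plot_baseline_differences.py"],
    "Figures C.1, C.2": ["calc_imagenet_convergence.py", "plot_imagenet_convergence.py"],
    "Figure C.3": ["calc_mmbs_mbshap_comparison.py", "plot_mmbs_mbshap_comparison.py"],
    "Table 1": ["calc_runtime.py", "parse_runtime.py"],
    "Table 2": ["calc_method_comparison.py", "parse_method_comparison.py"],
    "Table 3": ["calc_baseline_audcs.py", "parse_baseline_audcs.py"],
    "Tables D.1, D.2": ["calc_AUDC_sensitivity.py", "parse_AUDC_sensitivity.py"],
}


def build_commands(item, repo_root="."):
    """Return one command line per script for `item`, in the listed order."""
    # Resolve to an absolute path so the commands work regardless of cwd.
    root = Path(repo_root).resolve()
    return [[sys.executable, str(root / script)] for script in SCRIPTS[item]]


def run_scripts(item, repo_root=".", dry_run=True):
    """Print the commands for `item`; execute them too when dry_run is False."""
    for cmd in build_commands(item, repo_root):
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError at the first failing script.
            subprocess.run(cmd, check=True, cwd=repo_root)
```

For example, `run_scripts("Table 2")` only prints the two commands, while `run_scripts("Table 2", dry_run=False)` would also run them, which requires *folder_locations.py* to be set up first.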

## Running the code
### Setting up a conda environment
A conda environment can be set up by running the commands in *create_environment_mmbs.txt* in a terminal where Anaconda is installed.

### Setting up the datasets and folder locations
To run the scripts from the paper and the examples, you need to create an additional script, *folder_locations.py*, that defines three functions: get\_fashion\_mnist\_data\_path(), get\_imagenet\_val\_data\_path(), and get\_experiments\_path():
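```python
from pathlib import Path

# Replace each placeholder string with a path on your machine.

def get_fashion_mnist_data_path():
    # Empty folder into which fetch_fashion_mnist.py downloads Fashion-MNIST
    return Path("<<FILL IN A PATH>>")

def get_imagenet_val_data_path():
    # The imagenet-val folder containing the ImageNet1k validation images
    return Path("<<FILL IN A PATH>>")

def get_experiments_path():
    # Empty folder in which the experiment results are saved
    return Path("<<FILL IN A PATH>>")
```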

get\_fashion\_mnist\_data\_path() should return an empty folder into which the Fashion-MNIST dataset can be downloaded. Running the script *fetch_fashion_mnist.py* downloads Fashion-MNIST to that folder.

get\_imagenet\_val\_data\_path() should return the imagenet-val folder containing the ImageNet1k validation images, which can be downloaded from [Kaggle](https://www.kaggle.com/datasets/titericz/imagenet1k-val).

get\_experiments\_path() should return an empty folder in which the experiment results are saved.
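Before starting a long experiment it may help to check that the three configured folders are usable. The helper below is a hypothetical sketch, not part of the repository: it takes the three paths returned by *folder_locations.py* and reports problems, assuming only that each folder must exist and that the ImageNet folder should contain at least one image file. The `.jpeg`/`.jpg`/`.png` extensions and the possibility of per-class subfolders are assumptions about the Kaggle download, not facts from this repository.

```python
from pathlib import Path

# File extensions assumed for the ImageNet validation images (an assumption).
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def check_data_folders(fashion_mnist_path, imagenet_val_path, experiments_path):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    named = {
        "Fashion-MNIST folder": Path(fashion_mnist_path),
        "ImageNet validation folder": Path(imagenet_val_path),
        "experiments folder": Path(experiments_path),
    }
    # Every configured location must be an existing directory.
    for name, path in named.items():
        if not path.is_dir():
            problems.append(f"{name} does not exist or is not a directory: {path}")

    imagenet = named["ImageNet validation folder"]
    if imagenet.is_dir():
        # Search recursively in case the images are grouped in subfolders;
        # any() stops at the first matching file.
        has_images = any(p.suffix.lower() in IMAGE_SUFFIXES for p in imagenet.rglob("*"))
        if not has_images:
            problems.append(f"no image files found under {imagenet}")
    return problems
```

After importing the three functions from *folder_locations.py*, it could be called as `check_data_folders(get_fashion_mnist_data_path(), get_imagenet_val_data_path(), get_experiments_path())` and any returned messages printed.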

