# Wavelet Attribution Method: a unified approach for black-box feature attribution to any modality

This repository contains the code for the paper <i>One Wave to Explain Them All: A Unifying Perspective on Black-Box Explainability</i>. 

## Motivation 

### Method overview


The growing use of deep neural networks has highlighted the need for explainable AI (XAI) techniques due to the black-box nature of these models. XAI methods, especially gradient-based feature attribution methods like saliency maps, have become popular for interpreting model behavior, particularly in image classification. However, traditional pixel-based methods collapse the structural properties of images and are limited to image data. This work introduces the <b>Wavelet Attribution Method (WAM)</b>, a universal feature attribution technique that operates in the wavelet domain and preserves the inter-scale dependencies of the input data. By extending existing methods like SmoothGrad and Integrated Gradients, WAM enables saliency mapping for various modalities, including images, audio, and 3D shapes. The wavelet-domain approach provides more robust and interpretable explanations, outperforming existing methods in vision, audio processing, and 3D classification tasks.

![Flowchart of the Wavelet Attribution Method](figures/git-flowchart.png)
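
To make the idea of attributing in the wavelet domain concrete, here is a minimal, dependency-free sketch. It is not the repository's implementation: the one-level Haar transform, the linear `toy_model`, its `WEIGHTS`, and the finite-difference gradient are all assumptions chosen to keep the example short. The point it illustrates is that the model output is differentiated with respect to the wavelet coefficients, by passing perturbed coefficients through the inverse transform before evaluating the model.

```python
import math

def haar_forward(x):
    """One-level orthonormal Haar transform: returns (approx, detail)."""
    s = math.sqrt(2.0)
    approx = [(x[i] + x[i + 1]) / s for i in range(0, len(x), 2)]
    detail = [(x[i] - x[i + 1]) / s for i in range(0, len(x), 2)]
    return approx, detail

def haar_inverse(approx, detail):
    """Inverse of haar_forward: rebuilds the signal from its coefficients."""
    s = math.sqrt(2.0)
    x = []
    for a, d in zip(approx, detail):
        x.extend([(a + d) / s, (a - d) / s])
    return x

# Hypothetical weights of a toy linear "classifier logit" (assumption).
WEIGHTS = [0.5, -1.0, 2.0, 0.0, 1.5, 1.5, -0.5, 0.25]

def toy_model(x):
    return sum(w * v for w, v in zip(WEIGHTS, x))

def wavelet_gradient(model, x, eps=1e-5):
    """Central finite differences of model(inverse(coeffs)) w.r.t. each coefficient."""
    approx, detail = haar_forward(x)
    coeffs = approx + detail
    n = len(approx)
    grads = []
    for k in range(len(coeffs)):
        plus = list(coeffs)
        minus = list(coeffs)
        plus[k] += eps
        minus[k] -= eps
        f_plus = model(haar_inverse(plus[:n], plus[n:]))
        f_minus = model(haar_inverse(minus[:n], minus[n:]))
        grads.append((f_plus - f_minus) / (2 * eps))
    return grads[:n], grads[n:]

signal = [1.0, 2.0, 0.0, -1.0, 3.0, 3.0, 0.5, 0.5]
grad_approx, grad_detail = wavelet_gradient(toy_model, signal)
```

For a linear model and an orthonormal transform, the gradient with respect to the coefficients is simply the Haar transform of the weights, which gives an easy sanity check. With a real network one would typically rely on automatic differentiation instead of finite differences.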


#### Illustration of the method for 2D and 3D signals 

This figure shows how WAM can decompose the different scales that are important for the prediction of a given signal. For images, scales are indexed in pixels: the finest scales are 1-2 pixels wide, the intermediate scales are 2-4 pixels wide, and the coarsest scale corresponds to the approximation coefficients. We can see that the important content coincides with the location of the target class.

![Scale decomposition of 3D shapes and images](figures/git-diagram.png)
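
The scale indexing described above can be sketched on a 1D signal. This example only illustrates how decomposition levels map to pixel widths (level `j` covers `2^(j-1)` to `2^j` pixels, with the approximation as the coarsest scale); it does not compute an attribution. The Haar wavelet, the helper names, and the example `row` are assumptions made for illustration.

```python
import math

def haar_step(x):
    """One level of the orthonormal Haar transform."""
    s = math.sqrt(2.0)
    approx = [(x[i] + x[i + 1]) / s for i in range(0, len(x), 2)]
    detail = [(x[i] - x[i + 1]) / s for i in range(0, len(x), 2)]
    return approx, detail

def decompose(x, levels):
    """Multi-level decomposition; details[0] holds the finest scale."""
    details = []
    approx = list(x)
    for _ in range(levels):
        approx, d = haar_step(approx)
        details.append(d)
    return approx, details

def scale_label(level):
    """Level 1 -> '1-2 pixels', level 2 -> '2-4 pixels', and so on."""
    return f"{2 ** (level - 1)}-{2 ** level} pixels"

def magnitude_per_scale(x, levels):
    """Sum of absolute coefficient values at each scale."""
    approx, details = decompose(x, levels)
    scores = {scale_label(j + 1): sum(abs(c) for c in d)
              for j, d in enumerate(details)}
    scores["approximation"] = sum(abs(c) for c in approx)
    return scores

# Hypothetical row of values with a sharp, pixel-level alternating pattern.
row = [1.0, -1.0, 1.0, -1.0, 0.0, 0.0, 0.0, 0.0]
scores = magnitude_per_scale(row, levels=2)
```

Because the pattern in `row` alternates from one pixel to the next, all of its magnitude lands in the finest (1-2 pixels) scale, and the coarser scales stay at zero.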

#### Noise example on 1D signal 

We add 0 dB white noise to the audio of the target class ('Crow') to form the input to the classifier. The interpretation audio, reconstructed from the important wavelet coefficients, virtually eliminates the noise and clearly emphasizes parts of the target class audio (indicated by the green box).

![Noise experiment](figures/git-audio.png)
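
As a rough illustration of the corruption step, the sketch below adds white Gaussian noise to a signal at a chosen signal-to-noise ratio. It assumes that "0 dB" refers to the SNR, meaning the noise has the same average power as the clean audio. The sine tone standing in for the target-class audio, the sample rate, and the function names are hypothetical and not taken from the repository.

```python
import math
import random

def power(x):
    """Average power of a signal."""
    return sum(v * v for v in x) / len(x)

def add_white_noise(signal, snr_db, seed=0):
    """Return (noisy_signal, noise) with the noise rescaled to the requested SNR."""
    rng = random.Random(seed)
    noise = [rng.gauss(0.0, 1.0) for _ in signal]
    # SNR_dB = 10 * log10(P_signal / P_noise)  =>  P_noise = P_signal / 10^(SNR_dB / 10)
    target_noise_power = power(signal) / (10 ** (snr_db / 10))
    scale = math.sqrt(target_noise_power / power(noise))
    noise = [scale * n for n in noise]
    noisy = [s + n for s, n in zip(signal, noise)]
    return noisy, noise

# Hypothetical stand-in for the clean audio: a 440 Hz tone, 0.1 s at 16 kHz.
sample_rate = 16000
clean = [math.sin(2 * math.pi * 440 * t / sample_rate) for t in range(1600)]

noisy, noise = add_white_noise(clean, snr_db=0.0)
measured_snr_db = 10 * math.log10(power(clean) / power(noise))
```

Rescaling the noise by its empirical power, rather than relying on its nominal variance, makes the measured SNR match the target up to floating-point error.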


## Repository usage

You can use this repository on its own to visualize explanations on images. To generate examples on sounds and shapes, you will need to download the model weights and the datasets, supplied in the companion data folder `material`. This folder can be downloaded at the following URL: [https://doi.org/10.5281/zenodo.13873810](https://doi.org/10.5281/zenodo.13873810).

Once downloaded, set the root path in the notebooks and scripts as the path to the unzipped folder. For more details on the content of the folder `material`, please see the attached Readme.

We recommend that you create a virtual environment using the following commands:

```bash
cd wavelet-attribution
conda env create --file wavelet_attrib.yml
conda activate wavelet_attrib
```

### Demo of the feature attribution methods

Use the notebooks `demo_3D_voxels.ipynb`, `demo_images.ipynb` or `demo_sounds.ipynb` to generate and visualize WAM explanations.

### Replication of the quantitative results

To replicate the tables from the paper, use the scripts provided in the `eval` folder. The script `evaluate_1d.py` replicates the audio results and requires the ESC-50 dataset. The script `evaluate_2d.py` requires either the full ImageNet validation set, which you have to download yourself, or the 1,000-sample subset that we used to compute the metrics, provided in the `benchmark` subfolder of the `material` folder.


To execute a script, use the following command, replacing `{dimension}` with `1d` or `2d`:

```bash
python evaluate_{dimension}.py --source_dir /path/to/data --destination_dir /path/to/results --method wcam --metric insertion --device cuda
```

Similarly, you can reproduce the randomization sanity check by running the script `randomization.py`. The script `robustness.py` reproduces the experiment on the link between the robustness of a model and its reliance on coarse scales (i.e. low frequency ranges).
