# Exploring Unified Perspective for Fast Shapley Value Estimation

This repository is the official implementation of **Exploring Unified Perspective for Fast Shapley Value Estimation**. 



## Preparation

### Dataset

#### Census

- We directly use `shap.datasets.adult()`.

#### News

- Download from [here](https://archive.ics.uci.edu/dataset/332/online+news+popularity) and place it in the `./exp/census` folder. The preprocessing code is included in every SimSHAP training file.

#### Bankruptcy

- Download from [here](https://archive.ics.uci.edu/dataset/572/taiwanese+bankruptcy+prediction). Run `preprocessing.py` in the `./exp/bankruptcy` folder, then use `Taiwan_data_ENG_95.csv` for further experiments.

#### CIFAR-10

- Download from [here](https://www.cs.toronto.edu/~kriz/cifar.html).
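
To check that the CIFAR-10 download is intact, a batch file from the "python version" archive can be read with `pickle`, following the layout described on the dataset page. This is a minimal sketch and not part of the repository: the helper name `load_cifar_batch` is hypothetical, and the training scripts may load the data differently.

```python
import pickle
from pathlib import Path


def load_cifar_batch(path):
    """Load one CIFAR-10 python-version batch file and return (data, labels).

    Hypothetical helper for a quick sanity check; not taken from this repo.
    """
    with Path(path).open("rb") as fo:
        # encoding="bytes" follows the snippet on the dataset page for
        # Python 3; the dictionary keys are then byte strings (b"data").
        batch = pickle.load(fo, encoding="bytes")
    # b"data": one row of 3072 uint8 values per image (R, G, B planes).
    # b"labels": one integer class id (0-9) per image.
    return batch[b"data"], batch[b"labels"]
```

For example, `data, labels = load_cifar_batch("cifar-10-batches-py/data_batch_1")` reads the first training batch, assuming the archive was extracted in the current directory. With the real files, `data` is a NumPy array, and `data.reshape(-1, 3, 32, 32)` gives channel-first images.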

### Requirements

To install requirements:

```bash
pip install -r requirements.txt
```

Two additional packages, [fastshap](https://github.com/iancovert/fastshap) and [shapreg](https://github.com/iancovert/shapley-regression), must be installed by following the README.md on their GitHub pages.

### Device 

We tested our code on a Linux machine with an NVIDIA RTX 3090 GPU. We recommend using a GPU with more than 8 GB of memory.
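
Before training, it can be useful to check whether the local GPU meets this recommendation. The sketch below assumes PyTorch is installed (an assumption, since this README does not name the framework). The helper names are hypothetical, and the check reports GiB, which is close to but not identical to GB.

```python
def gpu_total_memory_gib(device_index=0):
    """Return the total memory of a CUDA device in GiB, or None if unavailable."""
    try:
        import torch  # imported lazily so the helper also runs without PyTorch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    props = torch.cuda.get_device_properties(device_index)
    return props.total_memory / 1024**3


def meets_recommendation(memory_gib, minimum_gib=8.0):
    """True when the device has strictly more memory than the recommended minimum."""
    return memory_gib is not None and memory_gib > minimum_gib
```

Calling `meets_recommendation(gpu_total_memory_gib())` returns `False` when no CUDA device is found, so it can serve as a guard before launching a long run.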

### Train Model

Train the SimSHAP and FastSHAP models in the `exp` folder:
```bash
cd exp/<dataset_name>
python <dataset_name>_simshap.py --test_training_speed True/False
```
Here `<dataset_name>` is one of `census`, `news`, `bankruptcy`, or `cifar10`.

The `--test_training_speed` argument enables training-speed evaluation and defaults to `False`. More examples are in `run.sh`.

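
To train on several datasets in sequence, the training command can be wrapped in a small Python helper. This is a minimal sketch and not part of the repository: the function names are hypothetical, and the script name `<dataset_name>_simshap.py` is an assumption about the file layout that should be checked against the `exp` folder.

```python
import subprocess
import sys
from pathlib import Path

DATASETS = ("census", "news", "bankruptcy", "cifar10")


def build_train_command(dataset, test_training_speed=False):
    """Build the argument list for one training run (hypothetical helper)."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset!r}")
    script = f"{dataset}_simshap.py"  # assumed naming scheme, verify locally
    return [sys.executable, script, "--test_training_speed", str(test_training_speed)]


def train(dataset, repo_root=".", test_training_speed=False, dry_run=True):
    """Run (or, with dry_run=True, only print) the training command."""
    cwd = Path(repo_root) / "exp" / dataset
    cmd = build_train_command(dataset, test_training_speed)
    if dry_run:
        print(f"(cd {cwd} && {' '.join(cmd)})")
    else:
        # check=True raises CalledProcessError if the script exits with an error
        subprocess.run(cmd, cwd=cwd, check=True)
    return cmd
```

A loop such as `for name in DATASETS: train(name, dry_run=False)` would then train each model in turn from the repository root.
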
### Validation

- To test the loss curve for tabular datasets, run
```bash
cd exp/<dataset_name>
python evaluation_losscurve.py
```
The results are stored in the `exp/<dataset_name>/results_losscurve` folder.

- To test Insertion and Deletion metrics, run
```bash
cd exp/cifar10
python evaluation_insertion_deletion.py
```
The results are stored in the `exp/<dataset_name>/results_deletion_insertion` folder.

- To generate explanations for CIFAR-10, run
```bash
cd exp/cifar10
python generation.py
```
- To conduct ablation study on tabular datasets, run
```bash
cd exp/<dataset_name>
python ablation_study.py --lr <lr> --batch_size <batch_size> --epochs <epochs> --num_samples <num_samples>
```
For CIFAR-10, there is an additional argument `--data_percent <data_percent>`, which ranges from 0 to 1.
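
An ablation study usually sweeps several values per argument. The sketch below generates one command line per combination of the arguments listed above. It is an illustration only: the helper name, the script name `ablation_study.py`, and every value in the example grid are assumptions, not settings from the paper.

```python
import itertools
import shlex


def ablation_commands(grid, script="ablation_study.py"):
    """Yield one shell command per combination of argument values.

    `grid` maps an argument name (without the leading dashes) to a list
    of values. The script name is an assumption; adjust it as needed.
    """
    keys = list(grid)
    for values in itertools.product(*(grid[key] for key in keys)):
        args = []
        for key, value in zip(keys, values):
            args += [f"--{key}", str(value)]
        yield shlex.join(["python", script, *args])


# Hypothetical grid, for illustration only.
example_grid = {
    "lr": [1e-3, 1e-4],
    "batch_size": [32],
    "epochs": [10],
    "num_samples": [8, 16],
}
```

Printing each item of `ablation_commands(example_grid)` gives four commands to run from `exp/<dataset_name>`. For CIFAR-10, a `data_percent` entry can be added to the grid in the same way.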