# Bort: Towards Explainable Neural Networks with Bounded Orthogonal Constraint

This repository is the official implementation of **Bort: Towards Explainable Neural Networks with Bounded Orthogonal Constraint**. 

![Bort framework](./framework.png)

## Preparation

### Dataset

#### MNIST

- Download from [here](http://yann.lecun.com/exdb/mnist/)

#### CIFAR-10

- Download from [here](https://www.cs.toronto.edu/~kriz/cifar.html)

#### ImageNet

- Download from [here](https://www.image-net.org/)

Organize ImageNet as follows:

```
- dataset
    |- train
    |   |- class1
    |   |   |- image1
    |   |   |- ...
    |   |- ...
    |- test
        |- class1
        |   |- image1
        |   |- ...
        |- ...
```
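
Before training, it can help to confirm that the folder tree matches the layout above. The following is a minimal sketch that uses only the Python standard library. The helper names `summarize_split` and `check_dataset` and the default `dataset` root are illustrative assumptions, not part of this repository.

```python
from pathlib import Path


def summarize_split(split_dir):
    """Return {class_name: number_of_files} for one split folder (e.g. dataset/train)."""
    split_dir = Path(split_dir)
    if not split_dir.is_dir():
        raise FileNotFoundError(f"Missing split folder: {split_dir}")
    summary = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        # Count regular files only; nested folders are ignored.
        summary[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
    return summary


def check_dataset(root="dataset"):
    """Summarize both splits and list classes that appear in only one of them."""
    train = summarize_split(Path(root) / "train")
    test = summarize_split(Path(root) / "test")
    mismatched = sorted(set(train) ^ set(test))
    return train, test, mismatched
```

Calling `check_dataset("dataset")` returns the per-class file counts for `train` and `test`, plus the names of any classes found in only one split.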

For visualization, please download the sub-dataset from [here](https://ufile.io/pbmbztbj), and put the files in the [data folder](data).

### Pre-trained models

You can download the pre-trained models for visualization from [here](https://ufile.io/pbmbztbj) and put the files in the [log folder](log).

### Requirements

To install requirements:

```setup
pip install -r requirements.txt
```

### Device 

We tested our code on a Linux machine with an NVIDIA RTX 3090 GPU. We recommend a GPU with more than 8 GB of memory. Evaluation with the pre-trained models provided above needs only about 3 GB of memory.
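
To compare a machine against the memory recommendation above, one option is to query `nvidia-smi`. The sketch below is not part of this repository: the function names and the `RECOMMENDED_MIB` constant are assumptions, and it reports nothing when `nvidia-smi` is not installed.

```python
import shutil
import subprocess

RECOMMENDED_MIB = 8 * 1024  # the "> 8 GB" recommendation, expressed in MiB


def parse_total_memory_mib(nvidia_smi_output):
    """Parse one integer (MiB) per line; lines that are not plain numbers are skipped."""
    return [int(line.strip()) for line in nvidia_smi_output.splitlines()
            if line.strip().isdigit()]


def query_gpu_memory_mib():
    """Return total memory per GPU in MiB, or [] if nvidia-smi is missing or fails."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0:
        return []
    return parse_total_memory_mib(result.stdout)


def meets_recommendation(memory_mib, threshold_mib=RECOMMENDED_MIB):
    """True if the GPU has strictly more memory than the threshold."""
    return memory_mib > threshold_mib


if __name__ == "__main__":
    for index, mib in enumerate(query_gpu_memory_mib()):
        status = "ok for training" if meets_recommendation(mib) else "below recommendation"
        print(f"GPU {index}: {mib} MiB ({status})")
```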

## Training Models

### Baseline approach

To train the baseline simple (ACNN-Small) network on CIFAR-10, run the following command:
```bash
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.1 --wd 0.01 --act_type guided
```

### Bort approach

To train the simple (ACNN-Small) network on CIFAR-10 with Bort, run the following command:
```bash
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim bort --wc 0.1 --wd 0.01 --act_type guided
```

### Ablation study

To reproduce the ablation study results on CIFAR-10, run the commands below:
```bash
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim bogd --wc 1 --wd 0.01 --act_type guided
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.1 --wd 0.01 --act_type guided
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.01 --wd 0.01 --act_type guided
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.001 --wd 0.01 --act_type guided
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.0001 --wd 0.01 --act_type guided

python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.01 --wd 1 --act_type guided
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.01 --wd 0.1 --act_type guided
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.01 --wd 0.01 --act_type guided
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.01 --wd 0.001 --act_type guided
python examples/train.py --dataset cifar10 --bs 256 --epochs 40 --lr 0.01 --model simple --optim sgd --wc 0.01 --wd 0.0001 --act_type guided
```
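
The two sweeps above differ only in `--optim`, `--wc`, and `--wd`, so they could also be generated from a short driver script. The following is a sketch under that assumption and is not shipped with this repository. The names `BASE_ARGS`, `WC_SWEEP`, `WD_SWEEP`, `build_command`, and `run_sweep` are hypothetical, and the flag values are copied verbatim from the commands listed above. With `dry_run=True` it only prints the commands.

```python
import shlex
import subprocess

# Flags shared by every command in the ablation block.
BASE_ARGS = [
    "python", "examples/train.py",
    "--dataset", "cifar10", "--bs", "256", "--epochs", "40",
    "--lr", "0.01", "--model", "simple",
]

# (optim, wc, wd) triples copied from the listed commands. Values stay as
# strings so they reach the command line exactly as written.
WC_SWEEP = [("bogd", "1", "0.01")] + [
    ("sgd", wc, "0.01") for wc in ("0.1", "0.01", "0.001", "0.0001")
]
WD_SWEEP = [("sgd", "0.01", wd) for wd in ("1", "0.1", "0.01", "0.001", "0.0001")]


def build_command(optim, wc, wd):
    """Assemble one training command as an argument list."""
    return BASE_ARGS + ["--optim", optim, "--wc", wc, "--wd", wd, "--act_type", "guided"]


def run_sweep(settings, dry_run=True):
    """Print each command; execute it as well when dry_run is False."""
    commands = [build_command(*setting) for setting in settings]
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    run_sweep(WC_SWEEP + WD_SWEEP, dry_run=True)
```

Passing `dry_run=False` runs the jobs one after another and stops at the first failure because of `check=True`.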

### More options

Further options are available for training other models:

| Args | Options |
| - | - |
| --dataset | mnist / cifar10 / imagenet |
| --model | simple / allconv12 / lenet |
| --optim | sgd / bort / adamw / abort |


> **NOTE**
> - To train your networks on ImageNet, we recommend using [timm](https://github.com/rwightman/pytorch-image-models) and inserting our Bort [code](src/bort/optimizers) into [this folder](https://github.com/rwightman/pytorch-image-models/tree/master/timm/optim).
> - Since training models on ImageNet is time-consuming, we provide a download link for pre-trained models for convenience.

## Reconstruction

<table align="center">
    <tr align="center">
    <td><img src="results/reconstruction/mnist/4-ori.png" width="160" height="160" alt=""></td>
    <td><img src="results/reconstruction/mnist/4-sgd.png" width="160" height="160" alt=""></td>
    <td><img src="results/reconstruction/mnist/4-bort.png" width="160" height="160" alt=""></td>
    </tr>
    <tr align="center">
    <td>Original image</td>
    <td>SGD</td>
    <td>Bort</td>
    </tr>
</table>

- **Step 1**: Please select [hyper-parameters](_1_reconstruction.py) first:
    - `suffix` (Line 58~59): sgd / bort
    - `dataset` (Line 61~63): mnist / cifar10 / imagenet
- **Step 2**: Run the command below:
    ```
    python _1_reconstruction.py
    ```
- **Step 3**: Check the reconstructed images in the output folder.

## Saliency map

<table align="center">
    <tr align="center">
    <td><img src="results/saliency_map/mnist/979-ori.png" width="160" height="160" alt=""></td>
    <td><img src="results/saliency_map/mnist/979-sgd-cam.png" width="160" height="160" alt=""></td>
    <td><img src="results/saliency_map/mnist/979-sgd-tracetopk.png" width="160" height="160" alt=""></td>
    <td><img src="results/saliency_map/mnist/979-bort-tracetopk.png" width="160" height="160" alt=""></td>
    </tr>
    <tr align="center">
    <td>Original image</td>
    <td>CAM</td>
    <td>SGD</td>
    <td>Bort</td>
    </tr>
</table>

- **Step 1**: Please select [hyper-parameters](_2_saliency_map.py) first:
    - `xai_name` (Line 35): cam / tracetopk
    - `layer_name` (Line 36): for tracetopk, use `act3` for MNIST/CIFAR-10 and `act8` for ImageNet; for CAM, use `act3` for MNIST/CIFAR-10 and `act12` for ImageNet
    - `dataset` (Line 53~55): mnist / cifar10 / imagenet
    - `optim` (Line 57~58): sgd / bort
- **Step 2**: Run the command below:
    ```
    python _2_saliency_map.py
    ```
- **Step 3**: Check the saliency maps in the output folder.

## Deletion / Insertion

- **Step 1**: Please select [hyper-parameters](_3_run_benchmark.py) first:
    - `xai_name` (Line 32): cam / tracetopk
    - `layer_name` (Line 33): `act3` for MNIST/CIFAR-10, for both tracetopk and CAM
    - `mpath` (Line 45~48)
- **Step 2**: Run the command below:
    ```
    python _3_run_benchmark.py
    ```

## Decomposition and adversarial sample synthesis

### Decomposition results

<table align="center">
    <tr align="center">
    <td><img src="results/decompose/0-ori.png" width="160" height="160" alt=""></td>
    <td><img src="results/decompose/0-topk64-bort.png" width="160" height="160" alt=""></td>
    </tr>
    <tr align="center">
    <td>Original image</td>
    <td>Top 64 reconstruction</td>
    </tr>
    <tr align="center">
    <td><img src="results/decompose/all_decom-bort.png" width="320" height="320" alt=""></td>
    <td><img src="results/decompose/kmeans-bort.png" width="160" height="320" alt=""></td>
    </tr>
    <tr align="center">
    <td>Top 64 activations</td>
    <td>K-Means (8 clusters)</td>
    </tr>
</table>

### Adversarial generation results

<table align="center">
    <tr align="center">
    <td><img src="results/decompose/src-x: 0-srcimg.png" width="160" height="160" alt=""></td>
    <td><img src="results/decompose/src-x: 7-trgimg.png" width="160" height="160" alt=""></td>
    <td><img src="results/decompose/src-x: 7-trasimg.png" width="160" height="160" alt=""></td>
    </tr>
    <tr align="center">
    <td><img src="results/decompose/src-x: 0-srcimg-bar.png" width="160" height="80" alt=""></td>
    <td><img src="results/decompose/src-x: 7-trgimg-bar.png" width="160" height="80" alt=""></td>
    <td><img src="results/decompose/src-x: 7-trasimg-bar.png" width="160" height="80" alt=""></td>
    </tr>
    <tr align="center">
    <td>Source</td>
    <td>Target</td>
    <td>Adversarial sample</td>
    </tr>
</table>

- **Step 1**: Run the command below:
    ```
    python _4_decomposition.py
    ```

