# Weakly-Supervised Disentangled Representation Learning via Filter-Based Adaptive Swapping


## How to use

### Environment Dependencies
```
python 3.9.7
pytorch 1.10.2
visdom 0.2.4
tqdm 4.65.0
numpy 1.26.4
matplotlib 3.8.0
pandas 2.2.1
spflow 0.0.41
opencv-python 4.9.0.80
```

### 1. Download dataset

#### Traffic Sign dataset
* Download datasets from Google Drive: [TrafficSign](https://drive.google.com/drive/folders/18ifcsyuXwWztMPT2LKtDXUIA5tK2uOhM?usp=sharing) and [TrafficSignRealWorld](https://drive.google.com/drive/folders/14dSUtgxYFMhBDw6RcNa1NrOc33Bd9a4F?usp=sharing).
* Organize the datasets in the following structure:
    ```
        traffic_10x8x100
        ├── train
        │   ├── deerCrossing
        │   │   └── xxx.png
        │   ├── handicappedCrossing
        │   ├── leftCurve
        │   ├── oneway
        │   ├── speedlimit25mph
        │   ├── stopsign
        │   ├── warning
        │   └── workersAhead
        ├── class_label.csv
        └── traffic_8000_3x128x128.npz
    ```
    ```
        real_data
        ├── deercrossing
        │   └── xxx.png
        ├── leftcurve
        ├── oneway
        ├── pedestrian
        ├── speedlimit25mph
        ├── stop
        ├── warning
        ├── workersahead
        └── class_label.csv

    ```

* ```traffic_8000_3x128x128.npz``` description: 

    This dataset has 3 latent factors controlling the color, shape, and orientation of a traffic sign, with latent sizes [10, 8, 100] (10 × 8 × 100 = 8000 images). Every combination of latent values is included, and the images are stored in the ndarray order of the latents. Each image in ```dataset['imgs']``` is addressed by its (color, shape, orientation) latent values, e.g. ```(0, 0, 1)```.

    Keys in dataset:
    ```
    ['imgs', 'latent_sizes']
    ```
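If the images are stored in the ndarray order described above, a latent triple can be mapped to its image position with NumPy's `ravel_multi_index`. The sketch below is a minimal illustration under stated assumptions: it assumes row-major (C) order over (color, shape, orientation), and it accepts either a flat `imgs` array of shape (8000, ...) or one with explicit latent axes (10, 8, 100, ...). The helper names are hypothetical and are not part of this repository.

```python
import numpy as np

# Hypothetical helpers (not part of this repository). They ASSUME the images
# are stored in row-major (C) order over the (color, shape, orientation) latents.


def latent_to_index(latents, latent_sizes):
    """Flat image index for a (color, shape, orientation) tuple."""
    return int(np.ravel_multi_index(tuple(int(v) for v in latents),
                                    tuple(int(s) for s in latent_sizes)))


def index_to_latent(index, latent_sizes):
    """Inverse of latent_to_index: flat index -> latent tuple."""
    sizes = tuple(int(s) for s in latent_sizes)
    return tuple(int(v) for v in np.unravel_index(index, sizes))


def get_image(dataset, latents):
    """Fetch one image whether 'imgs' is flat (N, ...) or shaped (10, 8, 100, ...)."""
    imgs = dataset["imgs"]
    sizes = tuple(int(s) for s in dataset["latent_sizes"])
    if imgs.shape[:len(sizes)] == sizes:
        return imgs[tuple(latents)]                   # latent axes are explicit
    return imgs[latent_to_index(latents, sizes)]      # flat storage


# Usage (path taken from the directory layout above):
# data = np.load("traffic_10x8x100/traffic_8000_3x128x128.npz")
# img = get_image(data, (0, 0, 1))  # color 0, shape 0, orientation 1
```

With sizes [10, 8, 100], the triple `(0, 0, 1)` maps to index 1 and index 7999 maps back to `(9, 7, 99)`.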

### 2. Run Experiments

#### 2.1 Disentanglement
```
cd ControlVAE
```

1. Unsupervised ControlVAE:
    ```
    ./run_model128.sh
    ```

2. Swapped ControlVAE (supervised pairwise inputs):
    ```
    ./run_model128_swap.sh
    ```
* Latent traversal visualizations of the disentangled features **z** are saved to ```./outputs/[viz_name]``` by the **viz_traverse** function.

* Dimension-wise KL values over training are recorded in ```./checkpoints/[viz_name]/train.kl```.

* Losses are recorded in ```./checkpoints/[viz_name]/train.log```. As training progresses, total_kld is expected to converge to exp_kld.
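ControlVAE (ICML 2020) tunes the KL weight β with a nonlinear PI controller so that the KL term tracks a set point, which explains why total_kld should approach exp_kld. The sketch below follows the controller form given in the ControlVAE paper, β(t) = K_p / (1 + exp(e(t))) − K_i Σ e(j) + β_min with e(t) = target KL − observed KL. Treating exp_kld as that set point, the class name, the gain values, and the simple anti-windup rule are assumptions; this is not the code used in this repository.

```python
import math


class PIBetaController:
    """Minimal nonlinear PI controller for the KL weight beta (a sketch).

    beta(t) = kp / (1 + exp(e(t))) - ki * sum_j e(j) + beta_min,
    with e(t) = kl_target - kl_observed(t). Gains are illustrative only.
    """

    def __init__(self, kl_target, kp=0.01, ki=0.001, beta_min=0.0, beta_max=1.0):
        self.kl_target = kl_target  # ASSUMED to play the role of exp_kld
        self.kp = kp
        self.ki = ki
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.error_sum = 0.0        # integral term: accumulated error

    def _p_term(self, error):
        # kp / (1 + exp(error)), written so that a large |error| cannot overflow
        if error >= 0:
            z = math.exp(-error)
            return self.kp * z / (1.0 + z)
        return self.kp / (1.0 + math.exp(error))

    def step(self, kl_observed):
        """Return the beta to use for the next training step."""
        error = self.kl_target - kl_observed  # negative when KL is above target
        new_sum = self.error_sum + error
        beta = self._p_term(error) - self.ki * new_sum + self.beta_min
        # Simple anti-windup: stop integrating while beta is saturated.
        if self.beta_min <= beta <= self.beta_max:
            self.error_sum = new_sum
        return min(max(beta, self.beta_min), self.beta_max)
```

While the observed KL stays above the target, β keeps rising and pushes the KL down; when the KL falls below the target, β drops to `beta_min` so the KL can grow back toward the set point.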

#### 2.2 Classification
1. Save the **z** embeddings to .csv files: set ```--Train False``` and run
    ```
    ./run_model128_swap.sh
    ```
2. Generate attacked images and save their z embeddings:
    ```
    cd Attack
    ./run_attack_traffic.sh
    ```
3. Run the sum-product network (SPN) classification. Specify the saved embedding file paths (```train_filename```, ```test_filename```) and the indices of the relevant latent dimensions (```feature_indices```), then run
    ```
    cd ../../SPN
    python main.py
    ```
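Step 3 keeps only a subset of latent dimensions, chosen through `feature_indices`, before classification. The sketch below illustrates that selection with pandas and NumPy under a **hypothetical** CSV layout (no header, one row per image, latent values followed by a single integer label column). The function name and the file name in the usage comment are illustrative only; check the files written in step 1 before relying on this layout.

```python
import numpy as np
import pandas as pd


def load_selected_features(csv_path, feature_indices, label_column=-1):
    """Load a saved embedding CSV and keep only the chosen latent dimensions.

    HYPOTHETICAL layout: no header, one row per image, latent values
    z_0..z_{d-1} plus one integer class-label column (last by default).
    """
    table = pd.read_csv(csv_path, header=None).to_numpy()
    label_col = label_column % table.shape[1]       # allow negative indices
    labels = table[:, label_col].astype(int)
    latents = np.delete(table, label_col, axis=1)   # drop the label column
    return latents[:, list(feature_indices)], labels


# Usage (hypothetical file name and indices):
# X_train, y_train = load_selected_features("train_z.csv", feature_indices=[1, 3, 5])
```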


#### 2.3 Baselines
- KEMLP:

    Train: Specify dataset ('traffic' or 'real_data') and run
    ```
    cd ControlVAE/Attack/KEMLP
    python adv_train.py # KEMLP(AdvTrain)
    python doa_train.py # KEMLP(DOA)
    ```
    Test: specify the saved checkpoint path and run
    ```
    python eval.py
    ```

- FSR:

    Train: run
    ```
    cd ControlVAE/Attack/FSR
    python train.py --dataset 'traffic'  # or: --dataset 'realdata'
    ```
    Test: specify the saved checkpoint path and run
    ```
    python test.py --dataset 'traffic' -bs 1  # or: --dataset 'realdata'
    ```

---

## Acknowledgments

We thank the following repositories and papers for providing baselines and helpful functions.

* [ControlVAE](https://github.com/shj1987/ControlVAE-ICML2020/tree/master)
* [SWVAE](https://arxiv.org/abs/2006.09882)
* [KEMLP](https://github.com/AI-secure/Knowledge-Enhanced-Machine-Learning-Pipeline)
* [FSR](https://github.com/wkim97/FSR/tree/main)

