# OT-VAE (ICLR 2023 Submission)

This repository contains the PyTorch implementation of OT-VAE used to reproduce the experiments on CelebA and CelebAMask-HQ.

## Dependencies

This code has been tested with `torch==1.7.1+cu110` and `torchvision==0.8.2+cu110`. To install the dependencies, run:
```
pip3 install -r requirements.txt
```
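Because only these versions were tested, it can help to confirm what is installed before training. Below is a minimal sketch using the standard-library `importlib.metadata` (Python 3.8+). `version_report` is a hypothetical helper, not part of this repository, and a version mismatch does not necessarily mean the code will fail.

```python
# Sketch: compare installed package versions against the tested ones.
# TESTED is copied from this README; version_report is a hypothetical helper.
from importlib import metadata

TESTED = {"torch": "1.7.1+cu110", "torchvision": "0.8.2+cu110"}


def version_report(expected=TESTED):
    """Return {package: (installed_or_None, expected)} for packages that differ."""
    mismatches = {}
    for name, want in expected.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            have = None  # package is not installed at all
        if have != want:
            mismatches[name] = (have, want)
    return mismatches


if __name__ == "__main__":
    report = version_report()
    if not report:
        print("Installed versions match the tested setup.")
    for name, (have, want) in report.items():
        print(f"{name}: installed {have}, tested with {want}")
```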

## Datasets 

### CelebA

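Download the aligned and cropped images (`img_align_celeba.zip`) from this [link](https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg&usp=sharing). The annotation files shown in the directory tree below (`identity_CelebA.txt` and the `list_*.txt` files) are available from the same link.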

Place `img_align_celeba.zip` in `./data/CelebA/celeba` and unzip it there.
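If you prefer to script this step, here is a minimal Python sketch using the standard `zipfile` module. It assumes the archive contains a top-level `img_align_celeba/` folder, as the target layout suggests; `extract_celeba` is a hypothetical helper, not part of this repository.

```python
# Sketch: extract the CelebA image archive into the directory this README expects.
# Assumption: the archive holds a top-level img_align_celeba/ folder.
import zipfile
from pathlib import Path

CELEBA_DIR = Path("./data/CelebA/celeba")


def extract_celeba(archive=CELEBA_DIR / "img_align_celeba.zip", dest=CELEBA_DIR):
    """Unzip `archive` into `dest` and return the path of the image folder."""
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    return dest / "img_align_celeba"
```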

The data should be organized as follows:
```
./data/CelebA/celeba
├── identity_CelebA.txt
├── list_attr_celeba.txt
├── list_bbox_celeba.txt
├── list_eval_partition.txt
├── list_landmarks_align_celeba.txt
└── img_align_celeba
```
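To catch a misplaced file before training, the layout above can be checked with a short script. `missing_celeba_entries` is a hypothetical helper, not part of the repository. The layout also appears to match what `torchvision.datasets.CelebA(root="./data/CelebA")` expects in its `celeba` sub-folder, though whether `main.py` loads CelebA through torchvision is an assumption.

```python
# Sketch: report entries missing from the CelebA layout shown above.
from pathlib import Path

CELEBA_ENTRIES = [
    "identity_CelebA.txt",
    "list_attr_celeba.txt",
    "list_bbox_celeba.txt",
    "list_eval_partition.txt",
    "list_landmarks_align_celeba.txt",
    "img_align_celeba",  # directory with the extracted .jpg images
]


def missing_celeba_entries(root="./data/CelebA/celeba"):
    """Return the expected files/folders that do not exist under `root`."""
    root = Path(root)
    return [name for name in CELEBA_ENTRIES if not (root / name).exists()]
```

An empty list means every expected entry is in place.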

### CelebAMask
To preprocess CelebAMask-HQ, follow the instructions at this [link](https://github.com/switchablenorms/CelebAMask-HQ/tree/master/face_parsing).

The data should be organized as follows:
```
./data/CelebAMask-HQ/
├── train_img
├── train_label
├── train_list.txt
├── val_img
├── val_label
├── val_list.txt
├── test_img
├── test_label
└── test_list.txt
```
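A similar sanity check for a preprocessed CelebAMask-HQ split is sketched below. It assumes that each image in `<split>_img` has a label with the same file stem in `<split>_label` (for example `0.jpg` and `0.png`); the linked preprocessing scripts define the actual naming, and `unpaired_files` is a hypothetical helper.

```python
# Sketch: find images without labels (and vice versa) in one CelebAMask-HQ split.
# Assumption: images and labels are matched by file stem, e.g. 0.jpg <-> 0.png.
from pathlib import Path


def unpaired_files(root="./data/CelebAMask-HQ", split="train"):
    """Return (images_without_label, labels_without_image) as sorted stem lists."""
    root = Path(root)
    img_stems = {p.stem for p in (root / f"{split}_img").iterdir() if p.is_file()}
    lbl_stems = {p.stem for p in (root / f"{split}_label").iterdir() if p.is_file()}
    return sorted(img_stems - lbl_stems), sorted(lbl_stems - img_stems)
```

On a correctly preprocessed split, `unpaired_files(split="val")` should return two empty lists.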


## Training
To train a model, call `main.py` with the corresponding YAML config file; the config files used in the experiments are listed below.
Please refer to `main.py` (or run `python main.py --help`) for the usage of additional arguments.

### Setup before training
* Set the checkpoint path `_C.path` (`configs/defaults.py`, line 4).
* Set the dataset path `_C.path_dataset` (`configs/defaults.py`, line 5).
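The two edits might look like the excerpt below. This assumes `configs/defaults.py` defines a config node named `_C` (the naming suggests a yacs-style `CfgNode`, but that is an assumption). Both paths are placeholders; the dataset path is presumably the `data` directory that holds both dataset folders shown above.

```python
# configs/defaults.py (excerpt; placeholder values only)
_C.path = "/path/to/checkpoints"   # line 4: where checkpoints are written
_C.path_dataset = "/path/to/data"  # line 5: root directory of the datasets
```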


### Reproducing experiments on CelebA

Run the experiment with `lambda = 1e-3` and `log(1/tau) = 1`, repeated three times with different seeds:
```
python3 main.py -c celeba_lambda0001_tau1.yaml --save --dbg --seed 0
python3 main.py -c celeba_lambda0001_tau1.yaml --save --dbg --seed 123
python3 main.py -c celeba_lambda0001_tau1.yaml --save --dbg --seed 456
```
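The config names appear to encode these hyperparameters (`lambda0001` for `lambda = 1e-3`, `tau1` for `log(1/tau) = 1`); that reading is inferred from the names. Assuming `log` is the natural logarithm, the setting corresponds to:

```latex
\log\frac{1}{\tau} = 1
\quad\Longrightarrow\quad
\tau = e^{-1} \approx 0.368
\qquad (\text{if } \log = \log_{10}, \text{ then } \tau = 10^{-1} = 0.1)
```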

### Reproducing experiments on CelebAMask-HQ

Run the experiment with `lambda = 1e-3` and `log(1/tau) = 1`, repeated three times with different seeds:
```
python3 main.py -c celebamask_lambda0001_tau1.yaml --save --dbg --seed 0
python3 main.py -c celebamask_lambda0001_tau1.yaml --save --dbg --seed 123
python3 main.py -c celebamask_lambda0001_tau1.yaml --save --dbg --seed 456
```
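To launch all six runs listed above in one go, here is a minimal sketch that builds the same commands and executes them sequentially. `build_commands` and `run_all` are hypothetical helpers, and the script assumes it is launched from the repository root where `main.py` lives.

```python
# Sketch: run every config/seed combination from this README sequentially.
# Config names, flags and seeds are copied from the commands above.
import subprocess

CONFIGS = ["celeba_lambda0001_tau1.yaml", "celebamask_lambda0001_tau1.yaml"]
SEEDS = [0, 123, 456]


def build_commands(configs=CONFIGS, seeds=SEEDS):
    """Return one argv list per (config, seed) run."""
    return [
        ["python3", "main.py", "-c", cfg, "--save", "--dbg", "--seed", str(seed)]
        for cfg in configs
        for seed in seeds
    ]


def run_all(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for argv in commands:
        print(" ".join(argv))
        if not dry_run:
            # check=True stops the sweep at the first run that exits with an error
            subprocess.run(argv, check=True)
```

`run_all(build_commands())` only prints the commands; pass `dry_run=False` to actually launch them. Running sequentially avoids contention when only one GPU is available.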


