# VQ-WAE
## Training
A model is trained by running main.py with the corresponding yaml config file. The available yaml files are listed below.
Please refer to main.py (or run 'python main.py --help') for the extra command-line arguments.

### Setup before training a model
* Set the checkpoint path "_C.path" (/configs/defaults.py:4).
* Set the dataset path "_C.path_dataset" (/configs/defaults.py:5).
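
Before launching a run, it can help to confirm that both paths are usable. The following is a minimal sketch, assuming "_C.path" and "_C.path_dataset" are plain directory strings. The helper name `check_paths` is hypothetical and is not part of the repository.

```python
from pathlib import Path


def check_paths(checkpoint_path: str, dataset_path: str) -> Path:
    """Hypothetical pre-flight check for the two paths set in configs/defaults.py.

    Raises FileNotFoundError if the dataset directory does not exist, then
    creates the checkpoint directory if it is missing and returns it.
    """
    dataset_dir = Path(dataset_path).expanduser()
    if not dataset_dir.is_dir():
        raise FileNotFoundError(f"dataset path not found: {dataset_dir}")
    checkpoint_dir = Path(checkpoint_path).expanduser()
    # parents=True builds intermediate folders; exist_ok=True makes reruns safe
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    return checkpoint_dir
```

For example, `check_paths("~/vqwae/checkpoints", "~/datasets")` would use the values you put in defaults.py (these example paths are placeholders).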


### Train a model
Example 1: Gaussian SQ-VAE (I) on CIFAR10
```
python main.py -c "cifar10_sqvae_C512.yaml" --save
```
Example 2: VQ-WAE on CIFAR10
```
python main.py -c "cifar10_vqwae_C512.yaml" --save
```

### Test a trained model
Example 1: VQ-WAE on CIFAR10
```
python main.py -c "cifar10_vqwae_C512.yaml" --save -timestamp resnet_seed0_0916_0610
```



### Where to find the checkpoints
If training succeeds, checkpoint folders are created under the following path (cfgs is the configuration loaded from the yaml file passed to main.py):
```
configs.defaults._C.path + '/' + cfgs.path_spcific
```
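
To inspect what a run produced, the same path can be rebuilt in Python. The following is a minimal sketch, assuming both values are plain strings. The function names are hypothetical, and `path_spcific` is spelled exactly as in the expression above.

```python
import os
from typing import List


def checkpoint_root(path: str, path_spcific: str) -> str:
    """Mirror of configs.defaults._C.path + '/' + cfgs.path_spcific."""
    return path + "/" + path_spcific


def list_checkpoint_folders(path: str, path_spcific: str) -> List[str]:
    """Return the sorted names of the run folders under the checkpoint root."""
    root = checkpoint_root(path, path_spcific)
    if not os.path.isdir(root):
        return []  # nothing has been trained yet for this config
    # Keep only sub-directories; stray files in the root are ignored
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
```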


### List of yaml files: models that work on continuous/discrete data distributions
| Config file | Description |
|---|---|
| cifar10_sqvae_C512.yaml | Gaussian SQ-VAE (I) on CIFAR10 with codebook size of 512 |
| celeba_sqvae_C512.yaml | Gaussian SQ-VAE (I) on CelebA with codebook size of 512  |
| mnist_sqvae_C512.yaml | Gaussian SQ-VAE (I) on MNIST with codebook size of 512 |
| svhn_sqvae_C512.yaml | Gaussian SQ-VAE (I) on SVHN with codebook size of 512  |
| cifar10_vqwae_C512.yaml | VQ-WAE on CIFAR10 with codebook size of 512 |
| celeba_vqwae_C512.yaml | VQ-WAE on CelebA with codebook size of 512  |
| mnist_vqwae_C512.yaml | VQ-WAE on MNIST with codebook size of 512 |
| svhn_vqwae_C512.yaml | VQ-WAE on SVHN with codebook size of 512  |
| celeba_fast_vqwae_C512.yaml | VQ-WAE with entropic semi-discrete dual OT (for fast computation) on CelebA with codebook size of 512 |
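
Training several configurations from the table means running the same command in a loop. The following is a minimal sketch, assuming it is run from the directory that contains main.py. The helpers `train_command` and `sweep` are hypothetical. With the default `dry_run=True`, the sweep only prints the commands.

```python
import subprocess
import sys
from typing import List

# Config files copied from the table above
CONFIGS = [
    "cifar10_sqvae_C512.yaml",
    "celeba_sqvae_C512.yaml",
    "mnist_sqvae_C512.yaml",
    "svhn_sqvae_C512.yaml",
    "cifar10_vqwae_C512.yaml",
    "celeba_vqwae_C512.yaml",
    "mnist_vqwae_C512.yaml",
    "svhn_vqwae_C512.yaml",
    "celeba_fast_vqwae_C512.yaml",
]


def train_command(config: str, save: bool = True) -> List[str]:
    """Build the training command; sys.executable stands in for `python`."""
    cmd = [sys.executable, "main.py", "-c", config]
    if save:
        cmd.append("--save")
    return cmd


def sweep(configs: List[str], dry_run: bool = True) -> List[List[str]]:
    """Print one training command per config and, unless dry_run, run them in order."""
    commands = [train_command(c) for c in configs]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first run that fails
            subprocess.run(cmd, check=True)
    return commands
```

For example, `sweep([c for c in CONFIGS if "vqwae" in c], dry_run=False)` would train the five VQ-WAE models one after another.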




## Experiments
"[checkpoint_foldername_with_timestamp]" denotes the name of a folder under the path "[configs.defaults._C.path + '/' + cfgs.path_spcific]", such as resnet_seed0_0916_0610 in the test example above.
These folder names consist of the model name, the seed index, and a timestamp.
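
When several runs exist, the folder name to use at test time can be selected programmatically. The following is a minimal sketch, assuming every folder name follows the pattern of the example resnet_seed0_0916_0610, that is `<model>_seed<index>_<MMDD>_<HHMM>`. This layout is inferred from that single example. `parse_folder`, `latest_run`, and `RunFolder` are hypothetical names.

```python
import re
from typing import List, NamedTuple, Optional

# Assumed layout, inferred from "resnet_seed0_0916_0610":
# <model>_seed<index>_<MMDD>_<HHMM>
FOLDER_RE = re.compile(r"^(?P<model>.+)_seed(?P<seed>\d+)_(?P<stamp>\d{4}_\d{4})$")


class RunFolder(NamedTuple):
    name: str
    model: str
    seed: int
    stamp: str  # "MMDD_HHMM"; the year is not encoded


def parse_folder(name: str) -> Optional[RunFolder]:
    """Split a folder name into model, seed and timestamp, or return None."""
    m = FOLDER_RE.match(name)
    if m is None:
        return None
    return RunFolder(name, m["model"], int(m["seed"]), m["stamp"])


def latest_run(names: List[str], seed: Optional[int] = None) -> Optional[str]:
    """Return the most recent folder name, optionally restricted to one seed.

    "MMDD_HHMM" strings only sort chronologically within a single year.
    """
    runs = [r for r in map(parse_folder, names) if r is not None]
    if seed is not None:
        runs = [r for r in runs if r.seed == seed]
    if not runs:
        return None
    return max(runs, key=lambda r: r.stamp).name
```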

## Dependencies
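* numpy
* scipy
* torch
* torchvision
* PIL (provided by the Pillow package)
* ot (provided by the POT package, Python Optimal Transport)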
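
The following sketch checks which of these modules can be found in the current environment without importing them. It assumes the usual PyPI names, Pillow for PIL and POT for ot. `DEPENDENCIES` and `missing_packages` are hypothetical names.

```python
import importlib.util
from typing import Dict, List

# Import name -> pip package name (PIL comes from Pillow, ot from POT)
DEPENDENCIES: Dict[str, str] = {
    "numpy": "numpy",
    "scipy": "scipy",
    "torch": "torch",
    "torchvision": "torchvision",
    "PIL": "Pillow",
    "ot": "POT",
}


def missing_packages(deps: Dict[str, str] = DEPENDENCIES) -> List[str]:
    """Return the pip names of dependencies whose top-level module is not found."""
    return [pkg for mod, pkg in deps.items() if importlib.util.find_spec(mod) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("pip install " + " ".join(missing))
    else:
        print("all dependencies found")
```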

## Acknowledgements
The code is adapted from https://github.com/sony/sqvae/tree/main/vision. We thank the authors for their excellent project.

