# Requirements

First, install a PyTorch build that matches your CUDA environment (version 1.7 or later; 1.10 is recommended). For example, for CUDA 11.1:
```bash
pip3 install torch==1.10.0+cu111 torchvision==0.11.1+cu111 torchaudio==0.10.0+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
```
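It can help to confirm that the installed build meets the minimum version and can see a GPU before going further. The following is a minimal sketch that assumes a Linux shell whose ``sort`` supports ``-V`` (GNU coreutils). ``version_ge`` is a hypothetical helper and is not part of this repository:

```bash
#!/usr/bin/env bash
# Hypothetical helper: succeeds if version $1 >= version $2 (uses GNU `sort -V`).
version_ge() {
  [ "$(printf '%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ]
}

if python3 -c 'import torch' 2>/dev/null; then
  # Strip the local build tag, e.g. "1.10.0+cu111" -> "1.10.0"
  TORCH_VERSION="$(python3 -c 'import torch; print(torch.__version__.split("+")[0])')"
  if version_ge "$TORCH_VERSION" "1.7"; then
    echo "PyTorch $TORCH_VERSION meets the 1.7 minimum"
  else
    echo "PyTorch $TORCH_VERSION is older than 1.7" >&2
  fi
  # Report whether this PyTorch build can use a CUDA device
  python3 -c 'import torch; print("CUDA available:", torch.cuda.is_available())'
else
  echo "PyTorch is not installed" >&2
fi
```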

Then, use the following command to install the rest of the libraries:
```bash
pip3 install tqdm ninja h5py kornia matplotlib pandas scikit-learn scipy seaborn wandb PyYaml click requests pyspng imageio-ffmpeg prdc torchdiffeq
```
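Several of these packages are imported under a different name than the one passed to ``pip3``. A small check like the one below can list anything still missing. It is a sketch. ``check_imports`` is a hypothetical helper, and the module names are the standard import names for the packages above:

```bash
#!/usr/bin/env bash
# Hypothetical helper: print each module in "$@" that python3 cannot import
# and return the number of missing modules (0 means every import succeeded).
check_imports() {
  local missing=0 mod
  for mod in "$@"; do
    if ! python3 -c "import ${mod}" 2>/dev/null; then
      echo "missing: ${mod}"
      missing=$((missing + 1))
    fi
  done
  return "$missing"
}

# Import names differ from some pip names:
# scikit-learn -> sklearn, PyYAML -> yaml, imageio-ffmpeg -> imageio_ffmpeg
check_imports tqdm ninja h5py kornia matplotlib pandas sklearn scipy seaborn wandb \
  yaml click requests pyspng imageio_ffmpeg prdc torchdiffeq || echo "some packages are missing"
```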

# Quick Start

Before starting, log in to wandb (Weights & Biases) with your personal API key:

```bash
wandb login PERSONAL_API_KEY
```
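Interactive login may be inconvenient in batch jobs or on shared clusters. In those cases, wandb can also read the key from the ``WANDB_API_KEY`` environment variable. The following is a minimal sketch. ``PERSONAL_API_KEY`` is the same placeholder as above, and any key already set in the environment is kept:

```bash
#!/usr/bin/env bash
# Non-interactive alternative to `wandb login`: wandb reads the key from WANDB_API_KEY.
# PERSONAL_API_KEY is a placeholder; an already exported key takes precedence.
export WANDB_API_KEY="${WANDB_API_KEY:-PERSONAL_API_KEY}"

# Optional: keep runs local only and upload them later with `wandb sync`.
# export WANDB_MODE=offline
```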

## Dataset

* CIFAR10 is downloaded automatically the first time you run ``main.py``.

* To load all data into main memory, add ``-hdf5 -l``:
  ```bash
  CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -t -hdf5 -l -cfg ./SPI-GAN/src/configs/CIFAR10/StyleGAN2.yaml -data DATA_PATH -save ./SPI-GAN
  ```
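The ``0,...,N`` in these commands is the list of GPU indices to expose. You can generate this list instead of typing it. The following is a convenience sketch in which ``NUM_GPUS`` and ``GPUS`` are hypothetical variable names and the default of 4 GPUs is an assumption:

```bash
#!/usr/bin/env bash
# NUM_GPUS is an assumed value; set it to the number of GPUs you want to use.
NUM_GPUS="${NUM_GPUS:-4}"

# Build the comma-separated list "0,1,...,NUM_GPUS-1" for CUDA_VISIBLE_DEVICES.
GPUS="$(seq -s, 0 $((NUM_GPUS - 1)))"
echo "CUDA_VISIBLE_DEVICES=${GPUS}"
```

The list can then be used as ``CUDA_VISIBLE_DEVICES="$GPUS" python3 src/main.py -t -hdf5 -l -cfg ./SPI-GAN/src/configs/CIFAR10/StyleGAN2.yaml -data DATA_PATH -save ./SPI-GAN``.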

## Train

* Train (``-t``) the model defined in the config file (``-cfg``) on GPU ``0`` and evaluate IS, FID, Precision, Recall, Density, and Coverage (``-metrics is fid prdc``).
```bash
CUDA_VISIBLE_DEVICES=0 python3 src/main.py -t -hdf5 -l -metrics is fid prdc -mpc -cfg ./SPI-GAN/src/configs/CIFAR10/StyleGAN2.yaml -data DATA_PATH -save ./SPI-GAN
```

* Train (``-t``) the same model through ``DataParallel`` on GPUs ``0,1,2,3`` and evaluate the same metrics. FID is evaluated by default, so you can omit ``-metrics`` if FID is the only metric you need.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 src/main.py -t -hdf5 -l -metrics is fid prdc -mpc -cfg ./SPI-GAN/src/configs/CIFAR10/StyleGAN2.yaml -data DATA_PATH -save ./SPI-GAN
```

Run ``python3 src/main.py -h`` to see all available options.
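The training, test, and visualization commands in this README all repeat the same ``-cfg``, ``-data``, and ``-save`` arguments. One way to avoid retyping them is to keep them in a bash array. This is a sketch only. ``CFG``, ``SAVE_DIR``, ``COMMON_ARGS``, and the ``./data`` default for ``DATA_PATH`` are hypothetical names and values to adjust:

```bash
#!/usr/bin/env bash
# Assumed locations; DATA_PATH defaults to a hypothetical ./data directory.
DATA_PATH="${DATA_PATH:-./data}"
CFG="./SPI-GAN/src/configs/CIFAR10/StyleGAN2.yaml"
SAVE_DIR="./SPI-GAN"

# Flags shared by every command in this README, kept in a bash array so that
# paths containing spaces stay intact when expanded with "${COMMON_ARGS[@]}".
COMMON_ARGS=(-cfg "$CFG" -data "$DATA_PATH" -save "$SAVE_DIR")

# Example: single-GPU training with IS, FID, and PRDC evaluation.
# CUDA_VISIBLE_DEVICES=0 python3 src/main.py -t -hdf5 -l -metrics is fid prdc -mpc "${COMMON_ARGS[@]}"
```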


## Test
* Download the pretrained checkpoints from [Google Drive](https://drive.google.com/drive/folders/1cNOyaNzyBsJsDS1FCbh9vXkL9uCLFheC).

* Evaluate a model loaded from a stored checkpoint (``-ckpt``):
```bash
CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -cfg ./SPI-GAN/src/configs/CIFAR10/StyleGAN2.yaml -ckpt ./SPI-GAN/checkpoints/SPIGAN_CIFAR10 -save ./SPI-GAN -data DATA_PATH -metrics is fid prdc
```
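On a headless server, the evaluation command fails if the checkpoint folder was never fetched. The sketch below checks for it first and, if it is missing, tries to download the shared Drive folder. It relies on the third-party ``gdown`` tool, which is not part of the requirements above. It also assumes the download ends up at ``./SPI-GAN/checkpoints/SPIGAN_CIFAR10``, the path used in the commands above, so check the actual layout and rename the folder if it differs. ``ckpt_ready`` is a hypothetical helper:

```bash
#!/usr/bin/env bash
# Hypothetical helper: succeeds if directory $1 exists and is not empty.
ckpt_ready() {
  [ -d "$1" ] && [ -n "$(ls -A "$1" 2>/dev/null)" ]
}

CKPT_DIR="./SPI-GAN/checkpoints/SPIGAN_CIFAR10"
CKPT_URL="https://drive.google.com/drive/folders/1cNOyaNzyBsJsDS1FCbh9vXkL9uCLFheC"

if ckpt_ready "$CKPT_DIR"; then
  echo "Found checkpoint in $CKPT_DIR"
elif command -v gdown >/dev/null 2>&1; then
  # gdown --folder downloads a shared Drive folder; the resulting layout is an assumption.
  gdown --folder "$CKPT_URL" -O ./SPI-GAN/checkpoints
else
  echo "Checkpoint not found: install gdown (pip3 install gdown) or download it manually" >&2
fi
```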

## Visualization
* Visualize (``-v``) a model loaded from a stored checkpoint (``-ckpt``):
```bash
CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -v -cfg ./SPI-GAN/src/configs/CIFAR10/StyleGAN2.yaml -ckpt ./SPI-GAN/checkpoints/SPIGAN_CIFAR10 -save ./SPI-GAN -data DATA_PATH
```
