# ROBUST DIFFUSION GAN USING SEMI-UNBALANCED OPTIMAL TRANSPORT #

Diffusion models, a type of generative model, have demonstrated great potential for synthesizing highly detailed images. By integrating with GANs, advanced diffusion models such as DDGAN can approach real-time performance, which opens them up to a wide range of practical applications. DDGAN addresses the three core challenges of generative modeling, namely producing high-quality samples, covering different data modes, and sampling quickly, but its performance still drops when the training dataset is corrupted with outlier samples. This work introduces a robust training technique based on semi-unbalanced optimal transport to mitigate the impact of outliers. Through comprehensive evaluations, we demonstrate that our robust diffusion GAN (RDGAN) outperforms vanilla DDGAN on the same three criteria, i.e., image quality, mode coverage of the distribution, and inference speed, and exhibits improved robustness on both clean and corrupted datasets.

## Set up datasets ##
We trained on several datasets, including CIFAR10, LSUN Church Outdoor 256, CelebA HQ 256, and MNIST.

## Training Denoising Diffusion GANs ##
We use the following commands for training robust diffusion GANs.

#### CIFAR-10 perturbed by MNIST (5%) ####

We train the baseline Denoising Diffusion GAN (DDGAN) on a single 32-GB V100 GPU.
```
python3 train_ddgan.py --dataset cifar10 --exp ddgan_cifar10_exp1 --num_channels 3 --num_channels_dae 128 --num_timesteps 4 --num_res_blocks 2 --batch_size 64 --num_epoch 1800 --ngf 64 --nz 100 --z_emb_dim 256 --n_mlp 4 --embedding_type positional --use_ema --ema_decay 0.9999 --r1_gamma 0.02 --lr_d 1.25e-4 --lr_g 1.6e-4 --lazy_reg 15 --num_process_per_node 1 --ch_mult 1 2 2 2 --save_content --version ba64 --master_port 6026 --perturb_dataset mnist --perturb_percent 5
```

We train the robust diffusion GAN (RDGAN) on a single 32-GB V100 GPU.
```
python3 train_rdgan.py --dataset cifar10 --exp ddgan_cifar10_exp1 --num_channels 3 --num_channels_dae 128 --num_timesteps 4 --num_res_blocks 2 --batch_size 64 --num_epoch 1800 --ngf 64 --nz 100 --z_emb_dim 256 --n_mlp 4 --embedding_type positional --use_ema --ema_decay 0.9999 --r1_gamma 0.02 --lr_d 1.25e-4 --lr_g 1.6e-4 --lazy_reg 15 --num_process_per_node 1 --ch_mult 1 2 2 2 --save_content --version ba64 --master_port 6023 --phi1 softplus --phi2 softplus --perturb_dataset mnist --perturb_percent 5
```
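
To make the "CIFAR-10 perturbed by MNIST (5%)" setting concrete, here is a minimal NumPy sketch of mixing outlier images into a clean dataset. It is not the repository's data loader: the helper names `mnist_like_to_rgb32` and `perturb_dataset` are hypothetical, random arrays stand in for the real images, and it assumes that 5% means the outliers make up 5% of the final mixed dataset, which may differ from how `--perturb_percent` is interpreted in the training scripts.

```python
import numpy as np


def mnist_like_to_rgb32(digits):
    """Pad 28x28 grayscale images to 32x32 and repeat them over 3 channels."""
    padded = np.pad(digits, ((0, 0), (2, 2), (2, 2)), mode="constant")
    return np.repeat(padded[..., None], 3, axis=-1)


def perturb_dataset(clean, outliers, percent, seed=0):
    """Append outliers so that they form `percent`% of the returned dataset.

    Assumption: the percentage is measured against the mixed dataset,
    not against the clean dataset alone.
    """
    rng = np.random.default_rng(seed)
    frac = percent / 100.0
    # Solve n_out / (n_clean + n_out) = frac for n_out.
    n_out = int(round(frac * len(clean) / (1.0 - frac)))
    picked = outliers[rng.choice(len(outliers), size=n_out, replace=False)]
    mixed = np.concatenate([clean, picked], axis=0)
    is_outlier = np.concatenate(
        [np.zeros(len(clean), dtype=bool), np.ones(n_out, dtype=bool)]
    )
    # Shuffle so that outliers are spread over the whole dataset.
    order = rng.permutation(len(mixed))
    return mixed[order], is_outlier[order]


# Random stand-ins for CIFAR-10 (32x32 RGB) and MNIST (28x28 grayscale).
rng = np.random.default_rng(0)
clean = rng.integers(0, 256, size=(950, 32, 32, 3), dtype=np.uint8)
digits = rng.integers(0, 256, size=(200, 28, 28), dtype=np.uint8)

mixed, is_outlier = perturb_dataset(clean, mnist_like_to_rgb32(digits), percent=5)
print(mixed.shape, int(is_outlier.sum()))
```

The returned mask is only there to make the mixing ratio easy to check; a training run would see the mixed images without any outlier labels.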

## Evaluation ##
After training, samples can be generated by calling ```test.py```. 
Below, we use `--epoch_id` to specify the checkpoint saved at a particular epoch.
Specifically, for models trained with the commands above, the script for generating samples on CIFAR-10 (DDGAN) is
```
python3 test.py --dataset cifar10 --num_channels 3 --num_channels_dae 128 --num_timesteps 4 --batch_size 1800 --num_res_blocks 2 --nz 100 --z_emb_dim 256 --n_mlp 4 --ch_mult 1 2 2 2 --version bs256 --master_port 6038 --compute_fid --epoch_start 1100 --epoch_end 1800 --epoch_jump 25
```
Note: add the `--phi1` argument when testing the model.

We use the [PyTorch implementation](https://github.com/mseitzer/pytorch-fid) to compute FID scores; in particular, the code for computing FID is adapted from [FastDPM](https://github.com/FengNiMa/FastDPM_pytorch).

To compute FID, run the same sampling script as above with the additional argument ```--compute_fid```.
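
For readers unfamiliar with the metric, the sketch below shows what FID measures: the Fréchet distance between two Gaussians fitted to feature vectors of real and generated images. It is an illustration only, not the code used in this repository. In practice the features come from an Inception network, whereas here random low-dimensional vectors stand in for them, and the function name `frechet_distance` is hypothetical.

```python
import numpy as np


def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two feature sets.

    d^2 = ||mu_a - mu_b||^2 + Tr(C_a + C_b - 2 (C_a C_b)^{1/2})
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((C_a C_b)^{1/2}) equals the sum of the square roots of the
    # eigenvalues of C_a C_b, which are real and non-negative for
    # covariance matrices (up to numerical noise, hence the clip).
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_sqrt)


# Random stand-ins for real and generated feature vectors.
rng = np.random.default_rng(0)
real = rng.normal(size=(512, 8))
fake = rng.normal(loc=0.5, size=(512, 8))

fid_same = frechet_distance(real, real)     # close to zero
fid_shifted = frechet_distance(real, fake)  # clearly positive
print(fid_same, fid_shifted)
```

Lower values mean the two feature distributions are closer, which is why FID is reported as "lower is better".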

For Inception Score, save the samples in a single NumPy array with pixel values in the range [0, 255] and run
```
python ./pytorch_fid/inception_score.py --sample_dir /path/to/sampled_images
```
where the code for computing Inception Score is adapted from [here](https://github.com/tsc2017/Inception-Score).
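
As a rough illustration of the preparation step described above, the sketch below converts a batch of generated images to a single `uint8` NumPy array and saves it. Several details are assumptions rather than facts about this repository: that the generator outputs floats in [-1, 1], that images are laid out as (N, H, W, C), and the file name `samples.npy`. Check what `inception_score.py` expects to find in `--sample_dir` before relying on it.

```python
import os
import tempfile

import numpy as np


def to_uint8_array(samples):
    """Map float images in [-1, 1] (assumed range) to uint8 in [0, 255]."""
    scaled = (np.clip(samples, -1.0, 1.0) + 1.0) * 127.5
    return np.round(scaled).astype(np.uint8)


# Random stand-in for a batch of generated 32x32 RGB images.
rng = np.random.default_rng(0)
samples = rng.uniform(-1.0, 1.0, size=(16, 32, 32, 3)).astype(np.float32)

images = to_uint8_array(samples)

# Hypothetical output location; replace with your own sample directory.
out_path = os.path.join(tempfile.gettempdir(), "samples.npy")
np.save(out_path, images)
print(images.shape, images.dtype, images.min(), images.max())
```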

For Improved Precision and Recall, follow the instructions [here](https://github.com/kynkaat/improved-precision-and-recall-metric).


## License ##