# Multi-scale Conditional Generative Model

This repository contains the code for the paper "Multi-scale Conditional Generative Model". As shown in the figure below, the model consists of two parts: a Brownian Bridge Diffusion Process (BBDP) that generates the low-frequency coefficients at the coarsest wavelet scale, and a multi-scale GAN that generates the high-frequency coefficients at the remaining wavelet scales.

![Multi-scale Conditional Generative Model](imgs/arch.png)
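
To make the split between the two parts concrete, here is a minimal sketch of a multi-level 2D wavelet decomposition and its coarse-to-fine inverse. It assumes a Haar wavelet, a single-channel image, and side lengths divisible by `2**levels`; the repository's actual wavelet, normalization, and number of scales may differ, and `haar_dwt2`, `haar_idwt2`, `decompose`, and `reconstruct` are hypothetical helpers. The coarsest low-frequency band corresponds to what the BBDP generates, and the high-frequency bands at each finer scale correspond to what the multi-scale GAN generates.

```python
import numpy as np

# Illustrative assumption: orthonormal Haar wavelet on a 2D single-channel image
# whose sides are divisible by 2**levels. The repository may use another wavelet.


def haar_dwt2(x):
    """One Haar level: return the low-frequency band and three high-frequency bands."""
    a, b = x[0::2, 0::2], x[0::2, 1::2]  # top-left, top-right of each 2x2 block
    c, d = x[1::2, 0::2], x[1::2, 1::2]  # bottom-left, bottom-right
    low = (a + b + c + d) / 2.0
    col_detail = (a - b + c - d) / 2.0   # differences between neighboring columns
    row_detail = (a + b - c - d) / 2.0   # differences between neighboring rows
    diag_detail = (a - b - c + d) / 2.0  # diagonal differences
    return low, (col_detail, row_detail, diag_detail)


def haar_idwt2(low, highs):
    """Exact inverse of haar_dwt2: rebuild the image at twice the resolution."""
    col, row, diag = highs
    h, w = low.shape
    x = np.empty((2 * h, 2 * w), dtype=np.float64)
    x[0::2, 0::2] = (low + col + row + diag) / 2.0
    x[0::2, 1::2] = (low - col + row - diag) / 2.0
    x[1::2, 0::2] = (low + col - row - diag) / 2.0
    x[1::2, 1::2] = (low - col - row + diag) / 2.0
    return x


def decompose(image, levels):
    """Return the coarsest low-frequency band and the high-frequency bands per scale."""
    low = np.asarray(image, dtype=np.float64)
    details = []
    for _ in range(levels):
        low, highs = haar_dwt2(low)
        details.append(highs)  # details[0] is the finest scale
    return low, details


def reconstruct(low, details):
    """Coarse-to-fine: add high-frequency bands back, starting at the coarsest scale."""
    x = low
    for highs in reversed(details):
        x = haar_idwt2(x, highs)
    return x


rng = np.random.default_rng(0)
image = rng.random((64, 64))
coarse, details = decompose(image, levels=3)
print(coarse.shape, [highs[0].shape for highs in details])
# (8, 8) [(32, 32), (16, 16), (8, 8)]
assert np.allclose(reconstruct(coarse, details), image)
```

Here both `coarse` and `details` come from the same ground-truth image, so the reconstruction is exact; during sampling, `coarse` would instead come from the BBDP and each entry of `details` from the GAN at the corresponding scale.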


## Installation

First, clone this repository and navigate to it in your terminal. Then run:

```
pip install -e .
```

This should install the `improved_diffusion` Python package that the scripts depend on.

## Preparing Data

The training code reads images from a directory of image files. In `improved_diffusion/image_datasets`, we provide several dataloaders for paired images stored as regular image files, `.mat` files, or `.npy` files.
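
The exact pairing rule is defined by the dataloaders themselves. As an illustration only, the sketch below assumes that inputs and targets live in two separate directories and are paired by identical file stems, keeping only the formats mentioned above; the directory layout, the extension list, and the `pair_files` helper are hypothetical.

```python
from pathlib import Path

# Hypothetical extension list covering image, .mat and .npy files; adjust to your data.
SUPPORTED = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".mat", ".npy"}


def pair_files(input_dir, target_dir):
    """Pair input/target files that share the same stem (assumed convention)."""

    def index(directory):
        # Map file stem -> path for every supported file in the directory.
        return {
            p.stem: p
            for p in Path(directory).iterdir()
            if p.is_file() and p.suffix.lower() in SUPPORTED
        }

    inputs, targets = index(input_dir), index(target_dir)
    unmatched = sorted(inputs.keys() ^ targets.keys())
    if unmatched:
        print(f"warning: {len(unmatched)} file(s) have no partner, e.g. {unmatched[:3]}")
    # Sorted order keeps the pairing deterministic across runs.
    return [(inputs[s], targets[s]) for s in sorted(inputs.keys() & targets.keys())]
```

Pairing by name rather than by sorted position avoids silently misaligning inputs and targets when one directory contains an extra file.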

To create your own dataset, please refer to the `Improved_Diffusion` repository folder, where we provide instructions and scripts for preparing these directories for various datasets.

The data-loading pipeline automatically scales and center-crops the images. However, because interpolation is relatively time-consuming, we recommend performing the upscaling once as a pre-processing step to speed up training. Simply pass your data directories to the training script, and it will take care of the rest.
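
A minimal sketch of such an offline pre-processing step, assuming the low-resolution inputs are stored as `.npy` arrays, the upscaling factor is an integer, and bilinear interpolation is acceptable (the pipeline's own interpolation method and crop size may differ). Each input is center-cropped, upscaled once, and cached, so the dataloader only has to read it; `center_crop`, `bilinear_upscale`, and `preprocess` are hypothetical names.

```python
from pathlib import Path

import numpy as np

# Assumptions: .npy inputs, integer scale factor, bilinear interpolation.
# The repository's own data pipeline may interpolate differently.


def center_crop(img, size):
    """Crop the central size x size region of an (H, W) or (H, W, C) array."""
    h, w = img.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size]


def bilinear_upscale(img, scale):
    """Bilinear upscaling by an integer factor, sampling at pixel centers."""
    img = np.asarray(img, dtype=np.float32)
    h, w = img.shape[:2]
    out_h, out_w = h * scale, w * scale
    # Map output pixel centers back to input coordinates, clamped to the border.
    ys = np.clip((np.arange(out_h) + 0.5) / scale - 0.5, 0, h - 1)
    xs = np.clip((np.arange(out_w) + 0.5) / scale - 0.5, 0, w - 1)
    y0, x0 = np.floor(ys).astype(int), np.floor(xs).astype(int)
    y1, x1 = np.minimum(y0 + 1, h - 1), np.minimum(x0 + 1, w - 1)
    # Interpolation weights, shaped to broadcast over an optional channel axis.
    extra = (1,) * (img.ndim - 2)
    wy = (ys - y0).reshape((out_h, 1) + extra)
    wx = (xs - x0).reshape((1, out_w) + extra)
    top = img[y0][:, x0] * (1 - wx) + img[y0][:, x1] * wx
    bottom = img[y1][:, x0] * (1 - wx) + img[y1][:, x1] * wx
    return (top * (1 - wy) + bottom * wy).astype(np.float32)


def preprocess(src_dir, dst_dir, crop, scale):
    """Center-crop and upscale every .npy file in src_dir once, saving to dst_dir."""
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    for path in sorted(Path(src_dir).glob("*.npy")):
        lr = center_crop(np.load(path), crop)
        np.save(dst / path.name, bilinear_upscale(lr, scale))
```

Paying the interpolation cost once here, rather than in every epoch, is the speed-up the recommendation above refers to.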

## Training

1. The BBDP part is mainly adapted from [Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672). To train a model, first choose its hyperparameters, which are set inside the training scripts. For example, the training script for [natural image super-resolution](scripts/super_res_train.py) lists default hyperparameter settings that can be freely modified; please refer to the scripts for details on each hyperparameter. Once you have set up your hyperparameters, you can run an experiment like so:

    ```
    python scripts/image_train.py
    ```

    You may also want to train in a distributed manner. In this case, run the same command with `mpiexec`:

    ```
    mpiexec -n $NUM_GPUS python scripts/image_train.py
    ```

    When training in a distributed manner, you must manually divide the `--batch_size` argument by the number of ranks. In lieu of distributed training, you may use `--microbatch 16` (or `--microbatch 1` in extreme memory-limited cases) to reduce memory usage.

    Logs and saved models will be written to the `log` folder, and checkpoints to the `models` folder.

2. The multi-scale GAN part is adapted from [Simple Baselines for Image Restoration](https://arxiv.org/abs/2204.04676). Similarly, first set the hyperparameters in [parameter.py](NAFNet_GAN/GAN/parameter.py) and [train_gan.py](NAFNet_GAN/train_gan.py). The training script `train_gan.py` can be used for all tasks presented in our manuscript and also specifies the input and target image directories. Once the hyperparameters are set, run the training script directly:
    ```
    python NAFNet_GAN/train_gan.py
    ```
    Logs and saved models will be written to the `NAFNet_GAN/log` folder, and checkpoints to the `NAFNet_GAN/checkpoint` folder.
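
The batch-size rules from the BBDP training step can be sketched in plain Python. Under `mpiexec`, each rank processes `--batch_size` samples, so the global batch is the per-rank value times the number of ranks; `--microbatch` splits each rank's batch into smaller chunks that are processed one after another to save memory. This is a hypothetical illustration of the arithmetic, not code from the training scripts; `per_rank_batch_size` and `microbatch_slices` are made-up names.

```python
def per_rank_batch_size(global_batch_size, num_ranks):
    """Divide the desired global batch size evenly across MPI ranks."""
    if global_batch_size % num_ranks != 0:
        raise ValueError(
            f"global batch {global_batch_size} is not divisible by {num_ranks} ranks"
        )
    return global_batch_size // num_ranks


def microbatch_slices(batch_size, microbatch):
    """Yield (start, stop) index ranges that split one batch into microbatches.

    Assumption for this sketch: microbatch <= 0 disables microbatching,
    i.e. a single slice covers the whole batch.
    """
    step = batch_size if microbatch <= 0 else microbatch
    for start in range(0, batch_size, step):
        yield start, min(start + step, batch_size)


# Example: a global batch of 128 on 4 GPUs, each processing microbatches of 16.
local = per_rank_batch_size(128, 4)
print(local, list(microbatch_slices(local, 16)))
# 32 [(0, 16), (16, 32)]
```

Microbatching keeps the effective batch size per rank unchanged while lowering peak memory, at the cost of more sequential forward and backward passes.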

## Sampling

Similar to training, sampling consists of two steps: sampling the low-frequency coefficients at the coarsest scale from the BBDP, then sampling the remaining high-frequency coefficients with the multi-scale GAN.

1. The BBDP training script saves checkpoints as `.pt` files in the logging directory, with names like `ema_0.9999_200000.pt` and `model200000.pt`. We suggest sampling from the EMA models, since they produce much better samples.

    As with training, hyperparameters such as the data directory, batch size, and path to the saved model can be adjusted within the sampling script.
    Once you have set the path to your model, you can generate a large batch of samples using:

    ```
    python scripts/image_sample.py
    ```

    This saves the results to the `outputs` folder as a large `.npz` file, where `arr_0` holds the batch of samples.

    As with training, you can run the sampling scripts (e.g., `super_res_sample.py`) through MPI to use multiple GPUs and machines. The number of diffusion steps should be set to 1000 for both training and sampling.

2. Next, run the sampling based on the multi-scale GAN with:
    ```
    python NAFNet_GAN/sample_mscgm.py
    ```
    Note that you need to set the directory containing the low-frequency images sampled by the BBDP in [NAFNet_GAN/sample_mscgm.py](NAFNet_GAN/sample_mscgm.py). A separate script, [NAFNet_GAN/sample_gan.py](NAFNet_GAN/sample_gan.py), evaluates the trained GAN on its own, starting from the ground-truth low-frequency images.

    The results will be saved to the `NAFNet_GAN/outputs` folder.
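
Two small helpers for the BBDP sampling step, sketched under assumptions: picking the most recent EMA checkpoint by parsing the step number from names like `ema_0.9999_200000.pt`, and reading the `arr_0` array back from the saved `.npz` file with NumPy. The filename pattern is inferred from the example names above, and `latest_ema_checkpoint` and `load_samples` are hypothetical helpers, not part of the repository.

```python
import re
from pathlib import Path

import numpy as np

# Assumed naming scheme, inferred from "ema_0.9999_200000.pt": EMA rate, then step.
EMA_PATTERN = re.compile(r"ema_([0-9.]+)_(\d+)\.pt")


def latest_ema_checkpoint(ckpt_dir, rate="0.9999"):
    """Return the EMA checkpoint with the highest training step for the given rate."""
    best_step, best_path = -1, None
    for path in Path(ckpt_dir).glob("ema_*.pt"):
        m = EMA_PATTERN.fullmatch(path.name)
        if m and m.group(1) == rate and int(m.group(2)) > best_step:
            best_step, best_path = int(m.group(2)), path
    if best_path is None:
        raise FileNotFoundError(f"no ema_{rate}_*.pt checkpoint in {ckpt_dir}")
    return best_path


def load_samples(npz_path):
    """Load the batch of samples stored under the default key 'arr_0'."""
    with np.load(npz_path) as data:
        return data["arr_0"]
```

The returned checkpoint path can then be set as the model path in the sampling script.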


## Demo
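The demo results for microscopy images of nanobeads and HeLa cells can be downloaded from [Google Drive](https://drive.google.com/file/d/1Lpqc1pZ5AN3QG5-x0L2V6i3AYn4pnXn9/view?usp=sharing). Running `wget` on this sharing link only retrieves the Drive preview page, not the file itself; one command-line alternative is the third-party `gdown` tool (`pip install gdown`):
```
gdown "https://drive.google.com/uc?id=1Lpqc1pZ5AN3QG5-x0L2V6i3AYn4pnXn9"
```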