<h1 align="center">K-DCT DDPM</h1>

This repository contains the PyTorch implementation of the submitted paper "**Improved denoising diffusion probabilistic models with efficient non-diagonal covariance modeling**".


> The sampling process of Denoising Diffusion Probabilistic Models (DDPMs) can be accelerated by leveraging second-order information in the form of approximations to the denoising posterior covariance. Previous attempts at using such information have used drastic (e.g. diagonal) simplifications of the covariance. These do not do justice to the peculiar statistical structure of natural images, which exhibit strong non-diagonal correlations between pixels and color channels, and a slow-decaying power-law frequency spectrum. Here, we develop a novel covariance model that captures these features. Our Kronecker-DCT (K-DCT) model uses a Kronecker-factored decomposition of inter-color  covariances and spatial covariances modeled in the frequency domain using the Discrete Cosine Transform (DCT). The use of the DCT reduces the computational complexity from quadratic to log-linear, resulting in negligible computational and memory overhead in the sampling process. By learning K-DCT-structured amortizations of the denoising posterior covariance using pre-trained score models on CIFAR-10, Celeb-A, and ImageNet datasets, we show improved performance both in terms of FID and likelihoods compared to previous approximations.
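
The NumPy sketch below makes the K-DCT structure concrete. It is an illustration only and not the implementation in this repository. It assumes a covariance of the form `Sigma = C ⊗ (D^T diag(lambda) D)`, where `C` is a 3x3 inter-color covariance, `D` is the orthonormal 2D DCT-II, and `lambda` is a per-frequency spatial spectrum shared by all channels. Applying `Sigma`, or its square root for sampling, only needs a DCT, an elementwise scaling, an inverse DCT, and a 3x3 channel mix, so the dense `(3HW) x (3HW)` matrix is never formed. The function names (`dct_matrix`, `kdct_matvec`, `kdct_sample`), the shared spectrum, and the example values are hypothetical. How the `blockcirc` modes predict these factors from the score model is not reproduced here. To keep the example NumPy-only, the DCT is applied as an explicit matrix product. A fast transform such as `scipy.fft.dctn(x, norm="ortho", axes=(-2, -1))` would give the log-linear cost described in the abstract.

```python
import numpy as np


def dct_matrix(n: int) -> np.ndarray:
    """Orthonormal DCT-II matrix D of shape (n, n), so that D @ D.T == I."""
    k = np.arange(n)[:, None]  # frequency index (rows)
    i = np.arange(n)[None, :]  # sample index (columns)
    d = np.cos(np.pi * (2 * i + 1) * k / (2 * n)) * np.sqrt(2.0 / n)
    d[0, :] /= np.sqrt(2.0)  # rescale the DC row for orthonormality
    return d


def kdct_matvec(x: np.ndarray, color_mat: np.ndarray, spectrum: np.ndarray) -> np.ndarray:
    """Apply (color_mat ⊗ D^T diag(spectrum) D) to x of shape (C, H, W).

    The spatial factor is diagonal in the 2D DCT basis: transform each channel
    to DCT coefficients, scale them by `spectrum`, and transform back. The color
    factor then mixes channels with a small C x C matrix.
    """
    _, h, w = x.shape
    dh, dw = dct_matrix(h), dct_matrix(w)
    coeffs = dh @ x @ dw.T                       # 2D DCT of every channel
    spatial = dh.T @ (spectrum * coeffs) @ dw    # scale, then inverse 2D DCT
    return np.einsum("ij,jhw->ihw", color_mat, spatial)  # mix color channels


def kdct_sample(color_cov: np.ndarray, spectrum: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Draw z ~ N(0, Sigma) using the Kronecker square root
    Sigma^{1/2} = chol(color_cov) ⊗ (D^T diag(sqrt(spectrum)) D)."""
    c = color_cov.shape[0]
    h, w = spectrum.shape
    eps = rng.standard_normal((c, h, w))
    chol = np.linalg.cholesky(color_cov)
    return kdct_matvec(eps, chol, np.sqrt(spectrum))


# Usage with hypothetical values: a power-law spectrum and correlated RGB channels.
rng = np.random.default_rng(0)
H = W = 8
freq = np.add.outer(np.arange(H), np.arange(W)) + 1.0
spectrum = freq ** -2.0
color_cov = np.array([[1.0, 0.8, 0.6],
                      [0.8, 1.0, 0.8],
                      [0.6, 0.8, 1.0]])
z = kdct_sample(color_cov, spectrum, rng)
print(z.shape)
```

The Kronecker square root works because `(A ⊗ B)(A ⊗ B)^T = AA^T ⊗ BB^T`. A Cholesky factor of the 3x3 color covariance and a symmetric DCT-domain square root of the spatial factor are therefore enough to sample from the full covariance.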

## Installation

Our implementation is based on the [Extended Analytic-DPM](https://github.com/baofff/Extended-Analytic-DPM) and [OCM-DPM](https://github.com/J-zin/OCM_DPM) repositories. To set up the environment, please follow the installation instructions provided in those repositories. The core functionality of our code closely mirrors the original repositories, and detailed usage instructions are given below.

## Training

To train the model, you can use the following command:

```bash
python run_train.py --pretrained_path path/to/pretrained_dpm --dataset dataset --workspace path/to/working_directory $train_hparams
```

* `pretrained_path` is the path to a pretrained diffusion probabilistic model (DPM). Here are the links to the pretrained models:
  [CIFAR10 (LS)](https://drive.google.com/file/d/1rhZBWUDK3_q37Iac3sXq6WnxR_OHhPyI/view), [CIFAR10 (CS)](https://drive.google.com/file/d/1ONNLpqPDLr4NesC0TfVZ3dCyaVBu7Xw0/view), [CelebA64](https://drive.google.com/file/d/1bGQGTsFOnqQ2z3FN5rdkj1FPN1_5nYF4/view), [ImageNet64](https://drive.google.com/file/d/1evlXbMOg55y2BIjiALcD6Smbm07k7XGW/view).
* `dataset` represents the training dataset, one of <`cifar10`|`celeba64`|`imagenet64`>.
* `workspace` is the place to put training outputs, e.g., logs and checkpoints.
* `train_hparams` specifies the other hyperparameters used in training.

We provide the `train_hparams` used to train our models on each dataset:

  * CIFAR10 (LS): `--method pred_eps_<obj>_blockcirc_pretrained --mode blockcirc`
  * CIFAR10 (CS): `--method pred_eps_<obj>_blockcirc_pretrained --mode blockcirc --schedule cosine_1000`
  * CelebA64: `--method pred_eps_<obj>_blockcirc_pretrained --mode blockcirc_complex`
  * ImageNet64: `--method pred_eps_<obj>_blockcirc_pretrained --mode blockcirc_complex`

where `<obj>` is either `hes` to train with the OCM objective or `epsc` to train with the NPR objective. For example, to train the CIFAR10 (LS) model with the NPR objective, you can run:

```bash
python run_train.py \
  --pretrained_path path/to/pretrained_dpm \
  --dataset cifar10 \
  --workspace path/to/working_directory \
  --method pred_eps_epsc_blockcirc_pretrained \
  --mode blockcirc
```
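
The per-dataset `train_hparams` and the `<obj>` choice can also be assembled in a script, for example when launching several runs. The sketch below is a hypothetical helper and is not part of the repository. The names `build_train_command`, the model keys (`cifar10_ls`, `cifar10_cs`, ...), and the `ocm`/`npr` labels are ours. It only combines flags listed in this section and returns an argument list that could be passed to `subprocess.run`.

```python
import shlex

# train_hparams per model, copied from the list in this section: (dataset, extra flags).
TRAIN_HPARAMS = {
    "cifar10_ls": ("cifar10", "--mode blockcirc"),
    "cifar10_cs": ("cifar10", "--mode blockcirc --schedule cosine_1000"),
    "celeba64": ("celeba64", "--mode blockcirc_complex"),
    "imagenet64": ("imagenet64", "--mode blockcirc_complex"),
}
# <obj> value for each training objective.
OBJECTIVES = {"ocm": "hes", "npr": "epsc"}


def build_train_command(model: str, objective: str, pretrained_path: str, workspace: str) -> list[str]:
    """Return the run_train.py argument list for one model/objective pair."""
    dataset, extra = TRAIN_HPARAMS[model]
    method = f"pred_eps_{OBJECTIVES[objective]}_blockcirc_pretrained"
    return [
        "python", "run_train.py",
        "--pretrained_path", pretrained_path,
        "--dataset", dataset,
        "--workspace", workspace,
        "--method", method,
        *shlex.split(extra),  # split the flag string into separate arguments
    ]


# Reproduces the CIFAR10 (LS) / NPR command shown above.
cmd = build_train_command("cifar10_ls", "npr", "path/to/pretrained_dpm", "path/to/working_directory")
print(shlex.join(cmd))
# To launch: subprocess.run(cmd, check=True)
```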

## Evaluation

To evaluate the model, you can use the following command:

```bash
python run_eval.py --pretrained_path path/to/evaluated_model --dataset dataset --workspace path/to/working_directory --phase phase --sample_steps sample_steps --batch_size batch_size --method pred_eps_<obj>_blockcirc_pretrained $eval_hparams
```
* `pretrained_path` is the path to a model to evaluate. We will provide all checkpoints trained with the proposed K-DCT parametrization after acceptance.
* `dataset` represents the dataset the model is trained on, one of <`cifar10`|`celeba64`|`imagenet64`>.
* `workspace` is the place to put evaluation outputs, e.g., logs, samples and bpd values.
* `phase` specifies running FID or likelihood evaluation, one of <`sample4test`|`nll4test`>.
* `sample_steps` is the number of sampling steps used at inference; the smaller this value, the faster the inference.
* `batch_size` is the batch size, e.g., 500.
* `eval_hparams` specifies other optional hyperparameters used in evaluation.

We provide the `eval_hparams` used for the FID and NLL results reported in the paper:
- FID Evaluation (DDPM)
  * CIFAR10 (LS): `--mode blockcirc --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2`
  * CIFAR10 (CS): `--mode blockcirc --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000`
  * CelebA64: `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --mode blockcirc_complex`
  * ImageNet64: `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode blockcirc_complex`
- NLL Evaluation
  * CIFAR10 (LS): `--mode blockcirc --rev_var_type optimal`
  * CIFAR10 (CS): `--mode blockcirc --rev_var_type optimal --schedule cosine_1000`
  * CelebA64: `--rev_var_type optimal --mode blockcirc_complex`
  * ImageNet64: `--rev_var_type optimal --mode blockcirc_complex`
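
The `eval_hparams` lists can be looked up by phase and model in a script. The sketch below is a hypothetical helper and is not part of the repository. `build_eval_command` and the model keys are ours. `method` is passed through unchanged, so it can be set to the method the checkpoint was trained with, and `batch_size` defaults to the 500 suggested above. The CIFAR10 (CS) NLL entry uses `--mode blockcirc`, the mode used for that model everywhere else in this README.

```python
import shlex

# eval_hparams per phase and model, copied from the FID (DDPM) and NLL lists.
EVAL_HPARAMS = {
    "sample4test": {  # FID evaluation (DDPM)
        "cifar10_ls": "--mode blockcirc --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2",
        "cifar10_cs": "--mode blockcirc --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000",
        "celeba64": "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --mode blockcirc_complex",
        "imagenet64": "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode blockcirc_complex",
    },
    "nll4test": {  # NLL evaluation
        "cifar10_ls": "--mode blockcirc --rev_var_type optimal",
        "cifar10_cs": "--mode blockcirc --rev_var_type optimal --schedule cosine_1000",
        "celeba64": "--rev_var_type optimal --mode blockcirc_complex",
        "imagenet64": "--rev_var_type optimal --mode blockcirc_complex",
    },
}
# --dataset value for each model key.
DATASETS = {"cifar10_ls": "cifar10", "cifar10_cs": "cifar10", "celeba64": "celeba64", "imagenet64": "imagenet64"}


def build_eval_command(model: str, phase: str, method: str, pretrained_path: str,
                       workspace: str, sample_steps: int, batch_size: int = 500) -> list[str]:
    """Return the run_eval.py argument list for one model and evaluation phase."""
    return [
        "python", "run_eval.py",
        "--pretrained_path", pretrained_path,
        "--dataset", DATASETS[model],
        "--workspace", workspace,
        "--phase", phase,
        "--sample_steps", str(sample_steps),
        "--batch_size", str(batch_size),
        "--method", method,
        *shlex.split(EVAL_HPARAMS[phase][model]),
    ]


# NLL evaluation of a CIFAR10 (CS) model; sample_steps=10 is an arbitrary example value.
cmd = build_eval_command("cifar10_cs", "nll4test", "pred_eps_epsc_blockcirc_pretrained",
                         "path/to/evaluated_model", "path/to/working_directory", sample_steps=10)
print(shlex.join(cmd))
```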

Precalculated FID statistics for CIFAR10, CelebA64, and ImageNet64 are available at this [link](https://drive.google.com/drive/folders/1aqSXiJSFRqtqHBAsgUw4puZcRqrqOoHx?usp=sharing).
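
FID compares Gaussian fits (the mean and covariance of Inception features) of generated and reference images. The NumPy sketch below computes the Fréchet distance between two such sets of statistics. It assumes the precalculated statistics have already been loaded as a mean vector and a covariance matrix; the file layout in the linked folder is not checked here. This is not the evaluation code used by `run_eval.py`. Widely used implementations take the matrix square root with `scipy.linalg.sqrtm`. To stay NumPy-only, this sketch gets its trace from eigenvalues instead.

```python
import numpy as np


def frechet_distance(mu1: np.ndarray, sigma1: np.ndarray, mu2: np.ndarray, sigma2: np.ndarray) -> float:
    """FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2}).

    Tr((sigma1 sigma2)^{1/2}) is the sum of square roots of the eigenvalues of
    sigma1 @ sigma2. For PSD inputs these are real and non-negative, so the real
    part is taken and tiny negative round-off is clipped to zero.
    """
    diff = mu1 - mu2
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


# Toy check with diagonal covariances: FID = ||d||^2 + sum_i (sqrt(a_i) - sqrt(b_i))^2 = 2 + 4 + 4.
a, b = np.array([1.0, 4.0]), np.array([9.0, 16.0])
print(frechet_distance(np.zeros(2), np.diag(a), np.ones(2), np.diag(b)))  # → 10.0 (up to round-off)
```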

As an example, to evaluate the FID (DDPM) result of the CIFAR10 (LS) model, you can run:
```bash
python run_eval.py \
  --pretrained_path path/to/evaluated_model \
  --dataset cifar10 \
  --workspace path/to/working_directory \
  --phase sample4test \
  --sample_steps sample_steps \
  --batch_size batch_size \
  --mode blockcirc \
  --method pred_eps_epsc_blockcirc_pretrained \
  --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2
```