# GAS: Improving Discretization of Diffusion ODEs via Generalized Adversarial Solver


## Table of Contents
- [Setup Environment](#setup-environment)
- [Download Pretrained Models and FID Reference Sets](#download-pretrained-models-and-fid-reference-sets)
- [Generating Teachers Data](#generating-teachers-data)
- [Calculating FID](#calculating-fid)
- [Training GAS](#training-gas)
- [Inference with trained GS](#inference-with-trained-gs)

---

## Setup Environment

See [requirements.yml](requirements.yml) for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:

```.bash
conda env create -f requirements.yml -n gas
conda activate gas
```


## Download Pretrained Models and FID Reference Sets

All necessary data will be automatically downloaded by the script. Note that this process may take some time. If you wish to skip certain downloads, you can comment out the corresponding lines in the script.

```.bash
bash scripts/downloads.sh
```

## Generating Teachers Data 

Before training **GS/GAS**, we first need to generate teacher data. To get a batch of images using a teacher solver, run:

```.bash
# Generate 64 images and save them as out/*.png
python generate.py --config=configs/edm/cifar10.yaml \
	--outdir=out \
	--seeds=00000-63 \
	--batch=64
```

Generating a large number of images can be time-consuming; the workload can be distributed across multiple GPUs by launching the above command using `torchrun`:

```.bash
# Generate 50000 images using 2 GPUs
torchrun --standalone --nproc_per_node=2 generate.py \
	--config=configs/edm/cifar10.yaml \
	--outdir=data/teachers/cifar10 \
	--seeds=00000-49999 \
	--batch=1024
```
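As a rough illustration of how a seed range can be shared between processes, the sketch below chunks the seeds into batches and hands them out to ranks in round-robin order. This is one plausible scheme and an assumption on our part, not a description of what `generate.py` actually does; `parse_seed_range` and `batches_for_rank` are hypothetical helper names.

```python
def parse_seed_range(spec):
    """Parse an inclusive range such as '00000-49999' into a list of ints."""
    start, end = spec.split("-")
    return list(range(int(start), int(end) + 1))


def batches_for_rank(seeds, batch_size, rank, world_size):
    """Chunk seeds into batches, then give every world_size-th batch to this rank."""
    batches = [seeds[i:i + batch_size] for i in range(0, len(seeds), batch_size)]
    return batches[rank::world_size]


# Same values as the torchrun example: 50000 seeds, batch size 1024, 2 processes.
seeds = parse_seed_range("00000-49999")
per_rank = [batches_for_rank(seeds, 1024, rank, 2) for rank in range(2)]
for rank, batches in enumerate(per_rank):
    print(f"rank {rank}: {len(batches)} batches, {sum(map(len, batches))} seeds")
```

Under this scheme every seed is processed exactly once, and the per-rank workload differs by at most one batch.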
Use the `--create_dataset=True` option to create subdirectories containing the dataset's chunks. The `scripts/teachers.sh` script runs teacher generation for all datasets. If you wish to skip certain datasets, you can comment out the corresponding lines in the script.

```.bash
bash scripts/teachers.sh
```

**We recommend using different seeds to generate teacher data and to evaluate student quality. For teacher data, we use seeds 0-49999 for all datasets except MS-COCO, where we use seeds 0-29999.**
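A quick way to guard against accidental overlap is to compare the two seed ranges before launching a run. The snippet below is a minimal sketch using the teacher and evaluation ranges recommended in this README; `parse_range` and `overlap` are hypothetical helpers, not part of the repository.

```python
def parse_range(spec):
    """Turn an inclusive 'start-end' string into a range object."""
    start, end = (int(part) for part in spec.split("-"))
    return range(start, end + 1)


def overlap(teacher_spec, eval_spec):
    """Return the sorted seeds that appear in both ranges."""
    return sorted(set(parse_range(teacher_spec)) & set(parse_range(eval_spec)))


# (teacher seeds, evaluation seeds) as recommended in this README.
SPLITS = {
    "default": ("0-49999", "50000-99999"),
    "MS-COCO": ("0-29999", "30000-59999"),
}

for name, (teacher_spec, eval_spec) in SPLITS.items():
    shared = overlap(teacher_spec, eval_spec)
    if shared:
        raise ValueError(f"{name}: {len(shared)} seeds are used for both teacher and evaluation")
    print(f"{name}: teacher {teacher_spec} and evaluation {eval_spec} are disjoint")
```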

## Calculating FID

To compute Fréchet inception distance (FID) for a given solver, first generate the required number of random images and then compare them against the dataset reference statistics using `fid.py`:


```.bash
torchrun --standalone --nproc_per_node=1 fid.py calc \
	--images=data/teachers/cifar10 \
	--ref=fid-refs/edm/cifar10-32x32.npz
```

It is common to use 50k images to calculate FID in most setups, whereas the MS-COCO dataset is usually evaluated on 30k images. 

The command can be parallelized across multiple GPUs by adjusting `--nproc_per_node`. The `fid.py calc` command typically takes 1-3 minutes. See `python fid.py --help` for the full list of options.
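For orientation, FID is the Fréchet distance between two Gaussians fitted to Inception features of the generated and reference images: `||mu_g - mu_r||^2 + Tr(S_g + S_r - 2 (S_g S_r)^(1/2))`. The sketch below evaluates this formula in the simplified case of diagonal covariances, where the trace term reduces to a per-dimension sum. It is a standard-library illustration with made-up toy statistics, not the computation performed by `fid.py`; the standard FID definition uses full covariance matrices.

```python
import math


def frechet_distance_diag(mean_g, var_g, mean_r, var_r):
    """Fréchet distance between two Gaussians with diagonal covariances.

    With diagonal covariances, Tr(S_g + S_r - 2 (S_g S_r)^(1/2)) becomes
    sum_i (v_g[i] + v_r[i] - 2 * sqrt(v_g[i] * v_r[i])).
    """
    mean_term = sum((mg - mr) ** 2 for mg, mr in zip(mean_g, mean_r))
    cov_term = sum(vg + vr - 2.0 * math.sqrt(vg * vr) for vg, vr in zip(var_g, var_r))
    return mean_term + cov_term


# Toy 3-dimensional statistics, made up for illustration only.
ref_mean, ref_var = [0.0, 1.0, -1.0], [1.0, 4.0, 0.25]
gen_mean, gen_var = [0.5, 1.0, -1.0], [1.0, 1.0, 0.25]

print(frechet_distance_diag(gen_mean, gen_var, ref_mean, ref_var))
```

The distance is zero when both sets of statistics match and grows as either the means or the variances drift apart, which is why lower FID indicates samples closer to the reference set.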

### Teachers FID

Run `scripts/teachers_fid.sh` to compute FID scores for all teacher solvers:

```.bash
bash scripts/teachers_fid.sh
```

## Training GAS

After generating the teacher data, you can train **GS/GAS** using `main.py`. Below is an example of training **GS** on **CIFAR-10** with four sampling steps:

```.bash
python main.py --config=configs/edm/cifar10.yaml \
	--loss_type=GS --student_step=4
```

Below is an example of training **GAS** in the same setup:

```.bash
python main.py --config=configs/edm/cifar10.yaml \
	--loss_type=GAS --student_step=4
```

The training settings can be controlled through command-line options; see `python main.py --help` for more information. 
For training **GAS**, we recommend setting `--train_size=5000` for all datasets except CIFAR-10.

```.bash
# FFHQ
python main.py --config=configs/edm/ffhq.yaml \
	--loss_type=GAS --student_step=4 --train_size=5000

# AFHQv2
python main.py --config=configs/ldm/afhqv2.yaml \
	--loss_type=GAS --student_step=4 --train_size=5000

# LSUN
python main.py --config=configs/ldm/lsun_beds256.yaml \
	--loss_type=GAS --student_step=4 --train_size=5000

# ImageNet
python main.py --config=configs/ldm/cin256-v2.yaml \
	--loss_type=GAS --student_step=4 --train_size=5000

# MS-COCO
python main.py --config=configs/sd/coco.yaml \
	--loss_type=GAS --student_step=4 --train_size=5000
```
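To launch the same runs from Python, the commands above can be assembled programmatically. The sketch below only builds and prints the argument lists, using the config paths and flags shown in this README; `build_train_command` is a hypothetical helper, and actually executing the commands (for example with `subprocess.run`) is left to the reader.

```python
import shlex

# Config paths exactly as they appear in this README.
CONFIGS = {
    "CIFAR-10": "configs/edm/cifar10.yaml",
    "FFHQ": "configs/edm/ffhq.yaml",
    "AFHQv2": "configs/ldm/afhqv2.yaml",
    "LSUN": "configs/ldm/lsun_beds256.yaml",
    "ImageNet": "configs/ldm/cin256-v2.yaml",
    "MS-COCO": "configs/sd/coco.yaml",
}


def build_train_command(dataset, loss_type="GAS", student_step=4):
    """Return the main.py argument list for one dataset."""
    cmd = [
        "python", "main.py",
        f"--config={CONFIGS[dataset]}",
        f"--loss_type={loss_type}",
        f"--student_step={student_step}",
    ]
    # The README recommends --train_size=5000 for GAS on every dataset except CIFAR-10.
    if loss_type == "GAS" and dataset != "CIFAR-10":
        cmd.append("--train_size=5000")
    return cmd


for dataset in CONFIGS:
    print(shlex.join(build_train_command(dataset)))
```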


## Inference with trained GS

If you set the option `--checkpoint_path=path`, images are generated from the **GS** checkpoint; otherwise, they are generated from the teacher solver. See `python generate.py --help` for more information. Below is an example of generating images from a trained **GS** checkpoint on **CIFAR-10** with four sampling steps:

```.bash
# Generate 50000 images using 2 GPUs and a checkpoint from checkpoint_path
torchrun --standalone --nproc_per_node=2 generate.py \
	--config=configs/edm/cifar10.yaml \
	--outdir=data/teachers/cifar10 \
	--seeds=50000-99999 \
	--batch=1024 \
	--steps=4 \
	--checkpoint_path=checkpoint_path
```

**For a fair comparison and to avoid leakage of test seeds into the training dataset, we recommend using seeds 50000-99999 for all datasets except MS-COCO, which should use seeds 30000-59999.**
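Putting the two steps together, evaluating a student consists of generating images from the checkpoint and then scoring them with `fid.py`. The sketch below assembles both commands for CIFAR-10 from the flags documented above. The output directory `data/students/cifar10_gs4` and the helper `build_eval_commands` are assumptions made for illustration; a directory separate from `data/teachers` is used so that student samples do not mix with teacher data.

```python
import shlex


def build_eval_commands(checkpoint_path, outdir, seeds="50000-99999", steps=4, gpus=2):
    """Return (generate, fid) argument lists for a CIFAR-10 student evaluation."""
    launcher = ["torchrun", "--standalone", f"--nproc_per_node={gpus}"]
    generate = launcher + [
        "generate.py",
        "--config=configs/edm/cifar10.yaml",
        f"--outdir={outdir}",
        f"--seeds={seeds}",  # evaluation seeds, disjoint from the teacher seeds
        "--batch=1024",
        f"--steps={steps}",
        f"--checkpoint_path={checkpoint_path}",
    ]
    fid = launcher + [
        "fid.py", "calc",
        f"--images={outdir}",
        "--ref=fid-refs/edm/cifar10-32x32.npz",
    ]
    return generate, fid


# "checkpoint_path" is the same placeholder used in the command above.
generate_cmd, fid_cmd = build_eval_commands("checkpoint_path", "data/students/cifar10_gs4")
print(shlex.join(generate_cmd))
print(shlex.join(fid_cmd))
```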
