# Uni-Instruct

## Experiment results available so far
- [x] CIFAR10: Forward-KL FID=1.44, Reverse-KL (SiDA) FID=1.45, Jeffrey-KL FID=1.43
- [x] ImageNet64: Forward-KL FID=1.34, Reverse-KL (SiDA) FID=1.35, Jeffrey-KL FID=1.32

## Experiments available for training
- [ ] ImageNet512: Forward-KL, Reverse-KL (SiDA), Jeffrey-KL (Weijian)
```
sbatch run_fsim_edm2.sh 'imagenet512-xxl'
```
--- 
- [ ] ImageNet64: finetune SiD using Jeffrey-KL (Weijian)
```
sbatch run_fsim.sh 'imagenet64-cond-sid'
```
--- 
- [ ] CIFAR-10: finetune SiD using Jeffrey-KL (Yifei)

## Experiments under construction 
- [ ] 3D: Forward-KL, Reverse-KL, Jeffrey-KL

### Lower-priority experiments
- [ ] ImageNet64: finetune DMD2 with Forward-KL, Reverse-KL, Jeffrey-KL
- [ ] f-distill ImageNet64: [Forward-KL](https://github.com/a-little-hoof/Uni-Instruct/blob/main/run_fdistill_forward-kl.sh), [Reverse-KL](https://github.com/a-little-hoof/Uni-Instruct/blob/main/run_fdistill_reverse-kl.sh), [Jensen-Shannon](https://github.com/a-little-hoof/Uni-Instruct/blob/main/run_fdistill_JS.sh)

## Environment Setup
```
conda env create -f environment.yaml
conda activate sida
```

## Prepare Dataset
Follow the dataset preparation instructions in [EDM](https://github.com/NVlabs/edm). We also provide a Google Drive copy of the ImageNet64 dataset; see Model Weights and Datasets below.

*Important!* We split `imagenet512-sd.zip` (~150 GB) into 20 subsets to avoid out-of-memory (OOM) errors.
First, run the shell script:
```
bash split.sh
``` 
Then, before training, you may need to adjust the dataset path at lines 344 and 997.
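The idea behind the split can be illustrated with a short shell function. This is a hypothetical sketch, not the contents of `split.sh`: it assumes you already have the archive's file names listed one per line, and it assigns them round-robin to `N` subset lists (`subset_00.txt`, `subset_01.txt`, ...) so that each subset can be packed and loaded on its own. The function name `shard_file_list` is ours, not the repository's.

```shell
#!/usr/bin/env bash
# Hypothetical sketch (not split.sh): spread a list of file names round-robin
# over N subset lists named subset_00.txt, subset_01.txt, ...
shard_file_list() {
    local list_file="$1"   # input: one file name per line
    local num_shards="$2"  # e.g. 20, matching the number of subsets above
    local out_dir="$3"     # directory that receives the subset_XX.txt lists
    mkdir -p "$out_dir"
    # Line NR (1-based) goes to subset (NR - 1) mod num_shards.
    awk -v n="$num_shards" -v dir="$out_dir" \
        '{ f = sprintf("%s/subset_%02d.txt", dir, (NR - 1) % n); print > f }' \
        "$list_file"
}
```

For example, `unzip -Z1 imagenet512-sd.zip > all_files.txt` lists the archive contents, and `shard_file_list all_files.txt 20 ./subsets` then writes 20 balanced lists.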

## Training Script
The training script is [run_fsim.sh](https://github.com/a-little-hoof/Uni-Instruct/blob/main/run_fsim.sh).
Below is an example command from it. A few lines contain our local paths and must be replaced with your own.
```
# Notes on the arguments you will most likely change
# (bash does not allow a comment after a trailing `\`, so they are listed here):
#   --resume        resumes a previous experiment; remove it if you are training from scratch.
#   --edm_model     pretrained model, downloaded from EDM.
#   --detector_url  pretrained model used to calculate FID and IS, also downloaded from EDM.
#   --data_stat     dataset statistics, downloaded from EDM.
if [ "$dataset" = 'cifar10-cond' ]; then
    torchrun --standalone --nproc_per_node=4 fsim_train.py \
    --cond 1 \
    --tmax 800 \
    --init_sigma 2.5 \
    --batch 256 \
    --batch-gpu 32 \
    --data './datasets/cifar10-32x32.zip' \
    --outdir './image_experiment/fsim-train-runs/cifar10-cond-chi-square' \
    --divergence 'Chi-Square' \
    --resume "/ailab/user/wangyifei/fsim/image_experiment/fsim-train-runs/cifar10-cond-chi-square/00000-cifar10-32x32-SiDA-cond-ddpmpp-edm-glr1e-05-lr1e-05-ls1.0_lsg100.0_lsd1.0_lsg_gan0.01-initsigma2.5-gpus4-batch256-tmax800-fp32batchgpu32/training-state-024576.pt" \
    --nosubdir 0 \
    --arch ddpmpp \
    --edm_model '/ailab/user/wangyifei/SiD-main/checkpoints/edm-cifar10-32x32-cond-vp.pkl' \
    --detector_url '/ailab/user/wangyifei/SiD-main/checkpoints/inception-2015-12-05.pt' \
    --tick 10 \
    --snap 50 \
    --dump 200 \
    --lr 1e-5 \
    --glr 1e-5 \
    --fp16 0 \
    --ls 1 \
    --lsg 100 \
    --lsd 1 \
    --lsg_gan 0.01 \
    --duration 300 \
    --data_stat '/ailab/user/wangyifei/SiD-main/cifar10-32x32.npz' \
    --use_gan 1 \
    --metrics fid50k_full \
    --save_best_and_last 1
fi
```
Update the following 5 parameters: 
```
--resume ... (delete this line)
--outdir PATH/TO/THE/DIRECTORY/OF/THE/OUTPUTS \
--edm_model 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl' \
--detector_url 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' \
--data_stat 'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz' \
```
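Because bash does not allow a comment after a `\` line continuation, one way to keep per-argument notes is to collect the arguments in a bash array, where trailing comments are legal. The sketch below is an assumption-laden alternative, not part of the repository: it launches `fsim_train.py` directly instead of through `run_fsim.sh`, copies the flag values from the example above with the five updates applied, and uses the hypothetical names `fsim_args` and `OUTDIR`. It only prints the command.

```shell
#!/usr/bin/env bash
# Sketch: CIFAR-10 arguments in a bash array so each line can carry a comment.
# OUTDIR is a placeholder; override it from the environment.
OUTDIR="${OUTDIR:-./image_experiment/fsim-train-runs/cifar10-cond-chi-square}"

fsim_args=(
    --cond 1
    --tmax 800
    --init_sigma 2.5
    --batch 256
    --batch-gpu 32
    --data './datasets/cifar10-32x32.zip'
    --outdir "$OUTDIR"                    # your output directory
    --divergence 'Chi-Square'             # any divergence listed below
    --nosubdir 0
    --arch ddpmpp
    --edm_model 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl'  # pretrained EDM model
    --detector_url 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'  # used for FID/IS
    --tick 10 --snap 50 --dump 200
    --lr 1e-5 --glr 1e-5 --fp16 0
    --ls 1 --lsg 100 --lsd 1 --lsg_gan 0.01
    --duration 300
    --data_stat 'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz'  # dataset statistics
    --use_gan 1
    --metrics fid50k_full
    --save_best_and_last 1
    # To continue a previous run, add: --resume /PATH/TO/THE/DOWNLOADED/.PT/FILE
)

# Print the command instead of running it; drop the leading `echo` to launch training.
echo torchrun --standalone --nproc_per_node=4 fsim_train.py "${fsim_args[@]}"
```

Quoting `"${fsim_args[@]}"` passes each element as a separate argument, so values containing spaces stay intact.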
The code supports the following divergences:
```
--divergence "Forward-KL" \
--divergence "Reverse-KL" \
--divergence "Jeffrey-KL" \
--divergence "Chi-Square" \
```
We also reimplement the divergences used in f-distill, which is an integral version of our distillation method:
```
--divergence "f-distill-Forward-KL" \
--divergence "f-distill-Reverse-KL" \
--divergence "f-distill-Jensen-Shannon" \
```
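To compare several divergences, one option is a small loop that gives each run its own output directory. This is a sketch under assumptions: the `--outdir` naming scheme and the helper `outdir_for` are our own choices (they reproduce the `cifar10-cond-chi-square` directory from the example above), the remaining training flags are omitted, and the loop only prints the commands.

```shell
#!/usr/bin/env bash
# Sketch: one run per supported divergence, each with its own output directory.
divergences=(
    "Forward-KL" "Reverse-KL" "Jeffrey-KL" "Chi-Square"
    "f-distill-Forward-KL" "f-distill-Reverse-KL" "f-distill-Jensen-Shannon"
)

# Hypothetical helper: lowercase the divergence name to build a directory,
# e.g. Chi-Square -> ./image_experiment/fsim-train-runs/cifar10-cond-chi-square
outdir_for() {
    printf './image_experiment/fsim-train-runs/cifar10-cond-%s\n' \
        "$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]')"
}

for div in "${divergences[@]}"; do
    # The other flags from the example above are left out for brevity.
    echo torchrun --standalone --nproc_per_node=4 fsim_train.py \
        --divergence "$div" --outdir "$(outdir_for "$div")"
done
```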

## Model Weights and Datasets
- ImageNet64 dataset: [ImageNet 64x64](https://drive.google.com/file/d/1UYnWH40Ed9uSWzl6fdXpim33MO7uzluk/view?usp=sharing)
- CIFAR10 pretrained EDM model: [EDM-cifar10-cond](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl)
- ImageNet64 pretrained EDM model: [EDM-ImageNet64-cond](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl)

## Resume training
To resume training from previous experiments, download the `.pt` files below and add the `--resume` argument to the training script:
```
--resume /PATH/TO/THE/DOWNLOADED/.PT/FILE \
```
Checkpoints for resuming Forward-KL, Reverse-KL, and Jeffrey-KL training on ImageNet64: [Model Weights](https://disk.pku.edu.cn/link/AA1C01BF2D551748748927920652F8C5B2).
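If you are resuming one of your own runs rather than a downloaded checkpoint, the most recent state file can be picked automatically. This sketch assumes checkpoints follow the `training-state-NNNNNN.pt` naming seen in the example `--resume` path above (zero-padded step numbers, so string order matches training order); `latest_state` is a hypothetical helper, not part of the repository.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the newest training-state-*.pt found under a run directory.
# Files are ranked by file name only, so nested run subdirectories do not affect the order.
latest_state() {
    local run_dir="$1"
    find "$run_dir" -type f -name 'training-state-*.pt' \
        | awk -F/ '{ print $NF "\t" $0 }' \
        | LC_ALL=C sort \
        | tail -n 1 \
        | cut -f 2
}
```

For example: `--resume "$(latest_state ./image_experiment/fsim-train-runs/cifar10-cond-chi-square)" \`. If no checkpoint exists, the helper prints nothing.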

## Finetune other generative models
We also use Uni-Instruct to finetune other types of generative models and improve on their original results.

DMD2 is a one-step generative model trained with the Diff-Instruct loss. We start from a pretrained DMD2 checkpoint with an FID of 2.61 and finetune it with the Uni-Instruct loss.

To run our experiments, please refer to this [script](https://github.com/a-little-hoof/Uni-Instruct/blob/main/DMD2-code/main).
