# A Geometric Perspective on Diffusion Models

## Requirements

To install this codebase and its dependencies in editable mode, run:
```sh
pip install -e .
```

## Usage
To generate npz stats for Figure 2a (top), run:

```
torchrun --standalone --nproc_per_node=1 main.py --data='path/to/cifar10-32x32.zip' --num_steps=18 --outdir='./assets' --network='edm_ddpmpp_uncond' --func='normTraj_collect'
```

To generate npz stats for Figure 2b, run both of the following commands:
```
torchrun --standalone --nproc_per_node=1 main.py --data='path/to/cifar10-32x32.zip' --num_steps=18 --outdir='./assets' --network='edm_ddpmpp_uncond' --func='deviation_collect'
```
```
torchrun --standalone --nproc_per_node=1 main.py --data='path/to/cifar10-32x32.zip' --num_steps=18 --outdir='./assets' --network='edm_ddpmpp_uncond' --func='trajDistance_collect'
```

To generate trajectories shown in Figure 3 (top), run:
```
torchrun --standalone --nproc_per_node=1 main.py --data='path/to/cifar10-32x32.zip' --num_steps=18 --outdir='./assets' --network='edm_ddpmpp_uncond' --func='all_trajs'
```

To generate npz stats shown in Figure 4 (left and middle), run:
```
torchrun --standalone --nproc_per_node=1 main.py --data='path/to/cifar10-32x32.zip' --num_steps=18 --outdir='./assets' --network='edm_ddpmpp_uncond' --func='monitor_denoiser_std_collect'
```

To generate interpolation trajectories with three strategies (Linear, N-Linear, Slerp) shown in Figure 6 (right), run:
```
torchrun --standalone --nproc_per_node=1 main.py --data='path/to/cifar10-32x32.zip' --num_steps=18 --outdir='./assets' --network='edm_ddpmpp_uncond' --func='interpolation'
```

To generate interpolation trajectories for FID evaluation shown in Figure 6 (left), run:
```
torchrun --standalone --nproc_per_node=1 main.py --data='path/to/cifar10-32x32.zip' --num_steps=18 --outdir='./assets' --network='edm_ddpmpp_uncond' --func='interpolation_generate'
```

To generate trajectories for FID evaluation shown in Figure 8a, run:
```
torchrun --standalone --nproc_per_node=1 main.py --data='path/to/cifar10-32x32.zip' --num_steps=18 --outdir='./assets' --network='edm_ddpmpp_uncond' --func='traj_generate'
```

To generate npz stats for cosine evaluation shown in Figure 10, run:
```
torchrun --standalone --nproc_per_node=1 main.py --data='path/to/cifar10-32x32.zip' --num_steps=18 --outdir='./assets' --network='edm_ddpmpp_uncond' --func='cos_collect'
```
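Every command above differs only in `--func`. The following is a minimal convenience sketch, not a script shipped with this repository, that runs all of the `--func` values listed in this section with the same flags. `DATA` is the same placeholder path used above and must point at your CIFAR-10 zip. `DRY_RUN` is a hypothetical variable introduced here: when it is non-empty, the commands are printed instead of executed.

```sh
#!/usr/bin/env sh
# Convenience sketch (not part of the repository): run every --func from the
# Usage section with identical flags. DATA and DRY_RUN are placeholders.
DATA="${DATA:-path/to/cifar10-32x32.zip}"

run_func() {
  # If DRY_RUN is non-empty, prefix the command with `echo` so it is printed
  # rather than executed.
  ${DRY_RUN:+echo} torchrun --standalone --nproc_per_node=1 main.py \
    --data="$DATA" --num_steps=18 --outdir='./assets' \
    --network='edm_ddpmpp_uncond' --func="$1"
}

run_all() {
  for func in normTraj_collect deviation_collect trajDistance_collect \
              all_trajs monitor_denoiser_std_collect interpolation \
              interpolation_generate traj_generate cos_collect; do
    # Stop at the first run that fails.
    run_func "$func" || return 1
  done
}
```

Source the file and call `run_all` (for example `DRY_RUN=1 run_all` to preview the commands first). The functions only wrap the commands already listed above; they do not add any behavior to `main.py`.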

## Pre-trained Models
The checkpoints used for the CIFAR-10 experiments are from <https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/>.

Before running the commands above, place `edm-cifar10-32x32-uncond-vp.pkl` in the directory `./edms`. This is the checkpoint for `edm_ddpmpp_uncond`, the network used in every command.

### Supported checkpoints

- `edm_ddpmpp_uncond`: [edm-cifar10-32x32-uncond-vp.pkl](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl)
- `edm_ddpmpp_cond`: [edm-cifar10-32x32-cond-vp.pkl](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl)
- `edm_ncsnpp_uncond`: [edm-cifar10-32x32-uncond-ve.pkl](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-ve.pkl)
- `vp_uncond`: [baseline-cifar10-32x32-uncond-vp.pkl](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline-cifar10-32x32-uncond-vp.pkl)
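The checkpoints can also be fetched from the command line. The sketch below is not part of the repository: it maps each supported network name to the checkpoint file listed above and downloads it into `./edms` with `curl`, skipping files that are already present. The helper names `ckpt_name` and `fetch_ckpt` are hypothetical. Only `edm-cifar10-32x32-uncond-vp.pkl` is documented as belonging in `./edms`; storing the other checkpoints there too is an assumption.

```sh
#!/usr/bin/env sh
# Sketch (not part of the repository): download the checkpoint for a network
# name into ./edms. Storing non-default checkpoints in ./edms is an
# assumption; only edm-cifar10-32x32-uncond-vp.pkl is documented there.
BASE_URL='https://nvlabs-fi-cdn.nvidia.com/edm/pretrained'

ckpt_name() {
  # Map a network name to its checkpoint file name (from the list above).
  case "$1" in
    edm_ddpmpp_uncond) echo 'edm-cifar10-32x32-uncond-vp.pkl' ;;
    edm_ddpmpp_cond)   echo 'edm-cifar10-32x32-cond-vp.pkl' ;;
    edm_ncsnpp_uncond) echo 'edm-cifar10-32x32-uncond-ve.pkl' ;;
    vp_uncond)         echo 'baseline-cifar10-32x32-uncond-vp.pkl' ;;
    *) echo "unknown network: $1" >&2; return 1 ;;
  esac
}

fetch_ckpt() {
  name=$(ckpt_name "$1") || return 1
  mkdir -p ./edms
  # Download only if the file is not already in ./edms.
  # -f: fail on HTTP errors, -L: follow redirects, -o: output path.
  if [ ! -f "./edms/$name" ]; then
    curl -fL -o "./edms/$name" "$BASE_URL/$name"
  fi
}
```

For example, `fetch_ckpt edm_ddpmpp_uncond` places the checkpoint required by the commands in the Usage section.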
    
