# Fast Sampling of Diffusion Models with Exponential Integrator

# Usage

```shell
# for PyTorch users
pip install "jax[cpu]"
```

## If diffusion models are trained with continuous time

```py
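import jax.numpy as jnp
import jax_deis as deis

# eps_model: your trained noise-prediction network, called as eps_model(x_t, vec_t)
def eps_fn(x_t, scalar_t):
    # broadcast the scalar time to one entry per batch element
    vec_t = jnp.ones(x_t.shape[0]) * scalar_t
    return eps_model(x_t, vec_t)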

# pytorch
# import torch as th
# import th_deis as deis
# def eps_fn(x_t, scalar_t):
#     vec_t = (th.ones(x_t.shape[0])).float().to(x_t) * scalar_t
#     with th.no_grad():
#         return eps_model(x_t, vec_t)

# mappings between t and alpha in VPSDE
# linear and cosine mappings are provided
t2alpha_fn, alpha2t_fn = deis.get_linear_alpha_fns(beta_0=0.01, beta_1=20)

vpsde = deis.VPSDE(
    t2alpha_fn, 
    alpha2t_fn,
    sampling_eps, # sampling end time t_0
    sampling_T # sampling starting time t_T
)

sampler_fn = deis.get_sampler(
    # args for diffusion model
    vpsde,
    eps_fn,
    # args for timestamps scheduling
    ts_phase="t", # supports "rho", "t", "log"
    ts_order=2.0,
    num_step=10,
    # deis choice
    method="t_ab", # deis sampling algorithm: supports "rho_rk", "rho_ab", "t_ab", "ipndm"
    ab_order=3, # greater than 0; used by "rho_ab" and "t_ab", ignored by other algorithms
    rk_method="3kutta" # used by "rho_rk", ignored by other algorithms
)

sample = sampler_fn(noise)
```
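
To make the mapping between `t` and `alpha` concrete, here is a minimal pure-Python sketch of a linear-schedule mapping and its inverse. It assumes `alpha_t = exp(-int_0^t beta(s) ds)` with `beta(s) = beta_0 + (beta_1 - beta_0) * s`; the helper name `linear_alpha_fns_sketch` is hypothetical, and `deis.get_linear_alpha_fns` may differ in its details.

```py
import math

def linear_alpha_fns_sketch(beta_0, beta_1):
    """Hypothetical stand-in for a linear t <-> alpha mapping (not the library code)."""

    def t2alpha(t):
        # assumed definition: alpha_t = exp(-int_0^t beta(s) ds) for a linear beta(s)
        return math.exp(-0.5 * (beta_1 - beta_0) * t ** 2 - beta_0 * t)

    def alpha2t(alpha):
        # solve 0.5 * (beta_1 - beta_0) * t^2 + beta_0 * t + log(alpha) = 0 for t >= 0
        disc = beta_0 ** 2 - 2.0 * (beta_1 - beta_0) * math.log(alpha)
        return (math.sqrt(disc) - beta_0) / (beta_1 - beta_0)

    return t2alpha, alpha2t

t2alpha, alpha2t = linear_alpha_fns_sketch(beta_0=0.01, beta_1=20)
round_trip = alpha2t(t2alpha(0.5))  # should recover t = 0.5 up to rounding
```

The two functions are inverses of each other on `t >= 0`, which is the property a sampler needs when it converts between time and noise level.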

## If diffusion models are trained with discrete time

```py
#! by default, the example assumes sampling
#! from t=len(discrete_alpha) - 1 to t=0,
#! i.e. len(discrete_alpha) steps in total if we use delta_t = 1
vpsde = deis.DiscreteVPSDE(discrete_alpha)
```
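
As an illustration of what `discrete_alpha` might contain, the sketch below builds the cumulative product of `1 - beta_i` for a DDPM-style linear beta schedule. That `discrete_alpha` is this cumulative product is an assumption here, and the function name and default values are hypothetical; use the schedule your model was actually trained with.

```py
def linear_beta_discrete_alpha(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Hypothetical helper: cumulative product of (1 - beta_i) for a linear beta schedule."""
    alphas = []
    running = 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        alphas.append(running)
    return alphas

discrete_alpha_sketch = linear_beta_discrete_alpha()
```

The resulting list decreases monotonically from just below 1 toward 0; convert it to the array type of your framework before passing it on.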

# A short derivation for DEIS

<details>
<summary>Exponential integrator in diffusion model</summary>

The key idea of an exponential integrator is to exploit all the known mathematical structure of an ODE so that the discretization error is as small as possible.

In diffusion models, this structure consists of the semilinear form of the ODE and the analytic formulas for the drift and diffusion coefficients.

Below is a short derivation of how the exponential integrator applies to diffusion models.

## Forward SDE

$$
dx = F_tx dt + G_td\mathbf{w}
$$

## Backward ODE

$$
dx = F_tx dt + 0.5 G_tG_t^T L_t^{-T} \epsilon(x, t) dt
$$

where $L_t L_t^{T} = \Sigma_t$
and $\Sigma_t$ is the covariance of $p_{0t}(x_t | x_0)$.
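
As a worked instance, consider the scalar VPSDE case. The notation $\beta_t$ for the noise schedule and $\alpha_t = \exp(-\int_0^t \beta_\tau d\tau)$ is an assumption introduced here for illustration. The coefficients are then

$$
F_t = -\tfrac{1}{2}\beta_t, \quad G_t = \sqrt{\beta_t}, \quad \Sigma_t = (1 - \alpha_t) I, \quad L_t = \sqrt{1 - \alpha_t}
$$

and the backward ODE reads

$$
dx = -\tfrac{1}{2}\beta_t x dt + \frac{\beta_t}{2\sqrt{1 - \alpha_t}} \epsilon(x, t) dt
$$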

## Exponential Integrator

The **exponential integrator** removes the linear term of this semilinear ODE by introducing a new variable $y$:

$$
y_t = \Psi(t) x_t, \quad \Psi(t) = \exp\left(-\int_0^{t} F_\tau d \tau\right)
$$

The ODE then simplifies to

$$
\dot{y}_t = 0.5 \Psi(t) G_t G_t^T L_t^{-T} \epsilon(x(y), t)
$$

where $x(y)$ maps $y_t$ to $x_t$.
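
The intermediate step can be spelled out as follows, assuming $F_t$ is a scalar (or commutes with itself across time) so that $\dot{\Psi}(t) = -\Psi(t) F_t$. Differentiating $y_t = \Psi(t) x_t$ and substituting the backward ODE gives

$$
\dot{y}_t = \dot{\Psi}(t) x_t + \Psi(t) \dot{x}_t = -\Psi(t) F_t x_t + \Psi(t) \left( F_t x_t + 0.5 G_t G_t^T L_t^{-T} \epsilon(x_t, t) \right) = 0.5 \Psi(t) G_t G_t^T L_t^{-T} \epsilon(x_t, t)
$$

so the linear term cancels exactly and only the network term remains.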


## Time scaling

When $F_t, G_t$ are scalars, we can go one step further by rescaling time:

$$
\dot{v}_\rho = \epsilon(x(v), t(\rho))
$$

where $y_t = v_\rho$ and $d \rho = 0.5 \Psi(t) G_t G_t^T L_t^{-T} dt$.
Here $x(v)$ maps $v_\rho$ to $x_t$, and $t(\rho)$ maps $\rho$ to $t$.
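
For a scalar VPSDE the rescaled time has a closed form, which the sketch below checks numerically. It assumes $F_t = -0.5 \beta_t$, $G_t = \sqrt{\beta_t}$, $L_t = \sqrt{1 - \alpha_t}$ with a linear $\beta_t$ and $\alpha_t = \exp(-\int_0^t \beta_\tau d\tau)$, so that $\Psi(t) = \alpha_t^{-1/2}$ and $\rho(t) = \sqrt{(1 - \alpha_t) / \alpha_t}$. All function names are hypothetical and not part of the library.

```py
import math

BETA_0, BETA_1 = 0.01, 20.0  # schedule values taken from the usage example

def beta(t):
    return BETA_0 + (BETA_1 - BETA_0) * t

def alpha(t):
    # assumed definition: alpha_t = exp(-int_0^t beta(s) ds)
    return math.exp(-0.5 * (BETA_1 - BETA_0) * t * t - BETA_0 * t)

def psi(t):
    # Psi(t) = exp(-int_0^t F) with F = -0.5 * beta, i.e. 1 / sqrt(alpha_t)
    return 1.0 / math.sqrt(alpha(t))

def rho(t):
    # candidate closed form for the rescaled time
    return math.sqrt((1.0 - alpha(t)) / alpha(t))

def drho_dt(t):
    # right-hand side of d rho = 0.5 * Psi * G * G / L dt in the scalar case
    return 0.5 * psi(t) * beta(t) / math.sqrt(1.0 - alpha(t))

# a central finite difference of rho should match drho_dt
t, h = 0.5, 1e-5
finite_diff = (rho(t + h) - rho(t - h)) / (2.0 * h)
rel_err = abs(finite_diff - drho_dt(t)) / drho_dt(t)
```

Under these assumptions $\rho$ is the ratio of noise scale to signal scale, so it grows monotonically with $t$ and the map $t(\rho)$ is well defined.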

## High order solver

After absorbing all of the known mathematical structure, we arrive at the following ODE:

$$
\dot{v}_\rho = \epsilon(x(v), t(\rho))
$$

Because the right-hand side is a neural network, the ODE cannot be simplified further without additional knowledge of this black-box function.
We can therefore apply well-established ODE solvers, such as multistep and Runge-Kutta methods.
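
To illustrate the multistep idea, here is a minimal sketch of a second-order Adams-Bashforth solver for $\dot{v}_\rho = \epsilon(v, \rho)$ on a uniform grid in $\rho$, compared against plain Euler. The function `toy_eps` is a hypothetical stand-in for the network, chosen so that the exact solution is known; the sketch integrates forward in $\rho$ for simplicity and is not the library's implementation.

```py
import math

def euler_solve(f, v0, rhos):
    """First-order baseline: one evaluation of f per step."""
    v = v0
    for cur, nxt in zip(rhos[:-1], rhos[1:]):
        v = v + (nxt - cur) * f(v, cur)
    return v

def ab2_solve(f, v0, rhos):
    """Second-order Adams-Bashforth on a uniform grid; the first step falls back to Euler."""
    h = rhos[1] - rhos[0]
    v = v0
    prev_f = None
    for cur in rhos[:-1]:
        cur_f = f(v, cur)
        if prev_f is None:
            v = v + h * cur_f
        else:
            # reuse the previous evaluation to extrapolate f linearly over the step
            v = v + h * (1.5 * cur_f - 0.5 * prev_f)
        prev_f = cur_f
    return v

def toy_eps(v, rho):
    # toy right-hand side with exact solution v(rho) = v0 * exp(-rho)
    return -v

rhos = [i / 10 for i in range(11)]  # 10 uniform steps from 0 to 1
exact = math.exp(-1.0)
err_euler = abs(euler_solve(toy_eps, 1.0, rhos) - exact)
err_ab2 = abs(ab2_solve(toy_eps, 1.0, rhos) - exact)
```

Both solvers call the right-hand side once per step, but the multistep variant reuses the previous evaluation and reaches a noticeably smaller error on this toy problem. That reuse is what makes multistep methods attractive when every evaluation is a network forward pass.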
</details>

# Demo

- [continuous vpsde](demo/continuous_cifar/deis.ipynb) Based on [score_sde codebase](https://github.com/yang-song/score_sde). CIFAR10 images in 7 steps
- [discrete vpsde](demo/discrete_celeba) Based on [PNDM codebase](https://github.com/luping-liu/PNDM)

