# Primary config for the 'diffusion' model.
# The `defaults` list below uses Hydra defaults-list syntax: each entry
# selects one option (a file) from the named config group.

# NOTE(review): presumably selects the numerical backend for the run —
# confirm against the code that reads this key.
backend: jax
name: 'diffusion'

defaults:
  - neural_net: 'score_mlp'
  - params_train: 'score'
  - loss_fn: 'dsm'
  - sde: 'vpsde'
  - sampler: 'em_gaus_auto_full'
  # Hydra >= 1.1 warns when a primary config's defaults list lacks
  # `_self_`, and then appends it implicitly. Listing it last keeps that
  # order explicit: keys defined in this file are merged after, and so
  # override, values coming from the groups above.
  - _self_
