---
# Random seed. Left as null here; how the loader interprets null (e.g. no
# fixed seed) is not visible in this file — TODO confirm.
seed: null

data:
  channels: 1
  dataset: mnist
  # Images are resized to this size, which must be a multiple of 32.
  image_size: 32
  # Number of workers used for data loading.
  num_workers: 8
  # NOTE(review): presumably enables random-flip augmentation — confirm.
  random_flip: true

diffusion:
  LIM: false
  # NOTE(review): presumably the alpha-stable index of the driving noise
  # (alpha = 2 would be the Gaussian case) — confirm against the model code.
  alpha: 1.8
  # Clamping settings; both are null here (presumably no clamping) — confirm.
  clamp_a: null
  clamp_eps: null
  diffusion_steps: 1000
  isotropic: true
  mean_predict: EPSILON
  rescale_timesteps: true
  var_predict: FIXED

eval:
  # NOTE(review): presumably the number of samples generated for evaluation.
  data_to_generate: 1024
  # Sampler switches; eval_eta is presumably the DDIM eta — confirm.
  ddim: false
  eval_eta: 0.0
  # NOTE(review): presumably the fraction of diffusion steps used when
  # sampling (1.0 = all of them) — confirm.
  reduce_timesteps: 1.0
  clip_denoised: false

model:
  model_type: 'ddpm'
  # Where attention is applied; whether these are pixel resolutions or
  # downsample factors is defined by the model code — TODO confirm.
  attn_resolutions:
  - 2
  - 4
  # Per-level channel multipliers (presumably scaling model_channels).
  channel_mult:
  - 1
  - 2
  - 2
  - 2
  compute_gamma: false
  dropout: 0.1
  model_channels: 32
  num_heads: 4
  num_res_blocks: 2

training:
  # Batch size (assumed from the key name).
  bs: 256
  # EMA rates; 0.999 was listed as a commented-out alternative.
  ema_rates:
  - 0.99
  # Gradient clipping threshold; null here (presumably disabled).
  # 1.0 was noted as an alternative value.
  grad_clip: null
  # How the loss is reduced over the batch of M sampled a's: mean or median.
  loss_monte_carlo: mean
  loss_type: LP_EPS_LOSS
  # NOTE(review): presumably the exponent p of the L_p loss — confirm.
  lploss: 2.0
  # For each (t, x_0, z_t): the number M of different (a_t_1, a_t') to draw.
  monte_carlo_steps: 1
  # Number of groups of a's; the loss takes the median of the group means.
  monte_carlo_groups: 1

optim:
  # Learning rate (same value as 0.0005).
  lr: 5.0e-4
  lr_steps: 300000
  optimizer: adamw
  # Learning-rate schedule; none configured (null).
  schedule: null
  # Warmup setting; 100 was noted as an alternative value.
  warmup: 0

run:
  epochs: 150
  # Evaluation frequency; null here (presumably no periodic evaluation).
  eval_freq: null
  # Checkpoint frequency; its unit (epochs vs. steps) is set by the trainer.
  checkpoint_freq: 25
  # Whether to print a progress bar.
  progress: true