seed: null

data:
    dataset: cifar10
    image_size: 32
    channels: 3
    random_flip: true
    num_workers: 4 # 2 * num_GPUs
    #num_classes: 2
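
# A minimal sketch (assuming PyTorch/torchvision; the loader below is
# illustrative, not the repo's actual code) of how this data section
# could be consumed:
#     import yaml
#     import torchvision.transforms as T
#     from torch.utils.data import DataLoader
#     from torchvision.datasets import CIFAR10
#     cfg = yaml.safe_load(open("config.yml"))
#     flip = [T.RandomHorizontalFlip()] if cfg["data"]["random_flip"] else []
#     ds = CIFAR10("./data", train=True, download=True,
#                  transform=T.Compose(flip + [T.ToTensor()]))
#     loader = DataLoader(ds, batch_size=cfg["training"]["bs"], shuffle=True,
#                         num_workers=cfg["data"]["num_workers"])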

eval:
    data_to_generate: 4096
    ddim: false
    eval_eta: 1.0
    reduce_timesteps: 4.0 # 1.0
    clip_denoised: false
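
# How these knobs read (an assumption about the repo's convention, not
# confirmed by this file): reduce_timesteps divides the training schedule,
# so sampling should use diffusion_steps / reduce_timesteps = 4000 / 4.0
# = 1000 steps, and eval_eta is the DDIM-style eta, only expected to matter
# when ddim is true:
#     num_eval_steps = int(cfg["diffusion"]["diffusion_steps"]
#                          / cfg["eval"]["reduce_timesteps"])  # -> 1000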

model:
    model_type: "ddpm"
    #in_channels: 3
    #out_ch: 3
    attn_resolutions: [16,]
    channel_mult: [1, 2, 2, 2]
    compute_gamma: false
    dropout: 0.1
    model_channels: 128
    num_heads: 4
    num_res_blocks: 2
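
# Under the usual DDPM U-Net convention (assumed here, not confirmed by this
# file), model_channels is scaled by channel_mult at each resolution level,
# and attention is applied where the feature-map side length matches
# attn_resolutions:
#     widths = [cfg["model"]["model_channels"] * m
#               for m in cfg["model"]["channel_mult"]]   # [128, 256, 256, 256]
#     sides = [cfg["data"]["image_size"] // 2 ** i
#              for i in range(len(widths))]              # [32, 16, 8, 4]
#     # attn_resolutions: [16,] -> attention at the 16x16 feature maps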

diffusion:
    LIM: false
    alpha: 1.8 #1.9
    clamp_a: null
    clamp_eps: null
    diffusion_steps: 4000
    isotropic: true
    mean_predict: EPSILON
    rescale_timesteps: true
    var_predict: FIXED
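
# Reading of these flags (hedged; EPSILON / FIXED follow the common DDPM
# convention in which the network predicts the noise and the variance is not
# learned): alpha is presumably the stability index of the driving noise,
# with alpha = 2 the Gaussian case and alpha = 1.8 heavy-tailed. Combined
# with loss_type: LP_EPS_LOSS and lploss: 2. in the training section below,
# a plausible objective is L = E[ |eps - eps_theta(x_t, t)|^p ] with p = 2,
# e.g. (a sketch; names illustrative, eps and eps_pred are torch tensors):
#     def lp_eps_loss(eps, eps_pred, p=2.0):
#         return (eps - eps_pred).abs().pow(p).mean()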

training:
    bs: 64
    ema_rates:
    #- 0.999
    - 0.9999
    grad_clip: null #1.0
    loss_monte_carlo: mean # aggregation applied over the batch of M sampled a's; can be mean or median
    loss_type: LP_EPS_LOSS
    lploss: 2.
    monte_carlo_steps: 1 # number M of different (a_t_1, a_t') samples to draw for each (t, x_0, z_t)
    monte_carlo_groups: 1 # number of groups; the loss is the median of the per-group means of the a's
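
# A hedged sketch of the aggregation described above (helper name is
# hypothetical): the M = monte_carlo_steps per-sample losses are split into
# G = monte_carlo_groups groups, each group is averaged, and the median of
# the group means is taken; M = G = 1 reduces to an ordinary single-sample
# loss. EMA weights at rate 0.9999 would follow the standard update
# ema = 0.9999 * ema + (1 - 0.9999) * param (assumed here).
#     def aggregate(losses, G):  # losses: 1-D torch tensor of M values,
#                                # assumes M is divisible by G
#         return losses.reshape(G, -1).mean(dim=1).median()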

optim:
    optimizer: adamw
    schedule: steplr
    lr: 0.0002
    warmup: 500 #100
    lr_steps: 300000
    lr_step_size: 1000
    lr_gamma: 0.99
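
# Rough decay arithmetic, assuming "steplr" means a PyTorch-style StepLR
# stepped once per optimizer step after the 500-step warmup: the lr is
# multiplied by lr_gamma = 0.99 every lr_step_size = 1000 steps, so over
# lr_steps = 300000 steps it shrinks by 0.99**300 ~= 0.049, taking the lr
# from 2e-4 down to roughly 1e-5.
#     from torch.optim.lr_scheduler import StepLR
#     sched = StepLR(opt, step_size=cfg["optim"]["lr_step_size"],
#                    gamma=cfg["optim"]["lr_gamma"])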

run:
    epochs: 400 #10000
    eval_freq: null
    checkpoint_freq: 50
    progress: true # print a progress bar