# Random seed. NOTE(review): confirm how the training script treats null
# (presumably no fixed seed).
seed: null

data:
  channels: 1
  dataset: mnist
  # Images are resized to this size; must be a multiple of 32.
  image_size: 32
  # Number of workers for data loading.
  num_workers: 4
  random_flip: true
  # Batch size.
  bs: 64

# Noising process: pdmp, diffusion, or None.
# NOTE(review): an unquoted `None` is the string "None" in YAML, not null;
# confirm which one the consumer checks for.
noising_process: pdmp

pdmp:
  # Sampler: ZigZag, BPS or HMC.
  sampler: HMC
  time_horizon: 10
  refresh_rate: 1.0
  # Losses to add (list). hyvarinen and logistic are commented-out
  # alternatives.
  add_losses:
  # - hyvarinen
  - ml
  # - logistic


diffusion:
  alpha: 2.0
  clamp_a: null
  clamp_eps: null
  # Our diffusion models are trained with the uniform distribution over
  # [1, reverse_steps] instead of [0, 1], so reverse_steps is part of the
  # training parameters (not only of evaluation).
  reverse_steps: 100
  LIM: false
  isotropic: true
  rescale_timesteps: true
  mean_predict: EPSILON
  var_predict: FIXED
  loss_type: LP_EPS_LOSS


eval:
  data_to_generate: 128
  batch_size: 128
  # For images: number of real images to store and compare against.
  real_data: 128

  # PDMP evaluation settings.
  pdmp:
    # Backward scheme: splitting (alternative: euler).
    backward_scheme: splitting
    # For PDMPs, reverse_steps is only an evaluation parameter, not a
    # training one.
    reverse_steps: 50
    get_sample_history: false

    # Disabled options:
    # clip_denoised: false
    # new_time_spacing

  # Diffusion evaluation settings.
  diffusion:
    reverse_steps: 100
    clip_denoised: false
    ddim: false
    eta: 1.0
    # Disabled option:
    # new_time_spacing

model:
  # MLP network; used for 2d ZigZag and 2d diffusion.
  mlp:
    a_emb_size: 32
    a_pos_emb: false
    act: silu
    compute_gamma: false
    dropout_rate: 0.0
    group_norm: true
    nblocks: 2
    no_a: true
    nunits: 64
    skip_connection: true
    time_emb_size: 8
    time_emb_type: learnable
    use_a_t: false
  unet:
    model_type: 'ddpm'
    attn_resolutions: [2, 4]
    channel_mult: [1, 2, 2, 2]
    compute_gamma: false
    dropout: 0.1
    model_channels: 32
    num_heads: 4
    num_res_blocks: 2
  # Normalizing flows; used by the other samplers.
  normalizing_flow:
    # Alternative (larger) settings, kept for reference:
    # transforms: 24
    # hidden_width: 1024  # or 2048
    # hidden_depth: 4
    # time_emb_size: 32
    # time_emb_type: learnable
    # x_emb_size: 32
    # x_emb_type: mlp  # mlp, unet
    # (model-level) vae: false
    transforms: 24
    hidden_width: 128  # alternative: 2048
    hidden_depth: 3
    time_emb_size: 16
    time_emb_type: learnable
    x_emb_size: 16
    x_emb_type: mlp  # mlp, unet
  vae: false

training:
  pdmp:
    ema_rates:
    - 0.99
    grad_clip: null  # alternative: 1.0
    # Used by ZigZag.
    subsamples: 5
    # Must be one of: VAE, RATIO, NORMAL, NORMAL_WITH_VAE (asserted by the
    # training code).
    train_type: VAE
    train_alternate: false
  diffusion:
    ema_rates:
    - 0.99
    grad_clip: null  # alternative: 1.0
    # How the loss is aggregated over the batch of M sampled a's: mean or
    # median.
    loss_monte_carlo: mean
    lploss: 2.0
    # For each (t, x_0, z_t): number of different a_t_1, a_t' to generate.
    monte_carlo_steps: 1
    # Number of groups of a's; the loss is the median of the group means.
    monte_carlo_groups: 1


optim:
  optimizer: adamw
  # Alternative: steplr.
  schedule: null
  # Keep in decimal form: PyYAML (YAML 1.1) reads `5e-4` as a string.
  lr: 0.0005
  # Alternative: 100.
  warmup: 50
  lr_steps: 300000
  lr_step_size: 150
  lr_gamma: 0.99

run:
  epochs: 50
  eval_freq: null
  checkpoint_freq: 25
  # Print a progress bar.
  progress: true