# NOTE(review): presumably null means no fixed RNG seed is applied — confirm against the loader.
seed: null

# Dataset settings. `gmm_grid` is a Gaussian mixture laid out on a grid.
data:
  dataset: gmm_grid     # gaussian mixture grid
  d: 2                  # data dimension
  alpha: 2.0            # NOTE(review): role of this data-level alpha is not documented here — confirm
  isotropic: true
  n: 3                  # n*n mixture components (9 with n = 3)

  n_samples: 32000      # number of samples in the dataset
  std: 0.1              # NOTE(review): presumably the standard deviation of each component — confirm
  theta: 3.0            # NOTE(review): meaning not visible in this file — confirm
  # Mixture weights: 9 entries, matching n*n = 9; the values sum to 1.0.
  # The two commented-out entries are a disabled two-entry alternative.
  weights:
  # - 0.01
  # - 0.99
  - 0.01
  - 0.1
  - 0.3
  - 0.2
  - 0.02
  - 0.15
  - 0.02
  - 0.15
  - 0.05

# Method to run. The file carries settings for both `dlpm` and `lim`.
# NOTE(review): the other option is LIM — confirm the exact spelling the code accepts.
method: dlpm

# Settings for the LIM method.
# NOTE(review): presumably only read when the top-level `method` selects LIM — confirm.
lim:
  alpha: 1.8
  clamp_a: null         # NOTE(review): presumably null disables clamping of a — confirm
  clamp_eps: null       # NOTE(review): presumably null disables clamping of eps — confirm
  reverse_steps: 100
  isotropic: true
  rescale_timesteps: true
  # Disabled option, kept for reference:
  # quadractic_timesteps: false

# Settings for the DLPM method (the value currently chosen by the top-level `method` key).
dlpm:
  alpha: 2.0                  # with alpha = 2.0 the method becomes DDPM
  reverse_steps: 100          # number of reverse steps used during training
  isotropic: true
  mean_predict: EPSILON       # see dlpm.ModelMeanType
  rescale_timesteps: true     # rescale the timesteps to [0, 1]
  var_predict: FIXED          # see dlpm.ModelVarType
  scale: 'scale_preserving'   # NOTE(review): other accepted values are not visible in this file — confirm
  input_scaling: false


# Evaluation settings; the nested `dlpm` and `lim` mappings hold per-method sampling options.
eval:
  data_to_generate: 5000  # NOTE(review): presumably the number of generated samples used for evaluation — confirm
  real_data: 5000         # NOTE(review): presumably the number of real samples compared against — confirm
  batch_size: 1024

  dlpm:
    deterministic: false  # set to true for dlim
    reverse_steps: 100
    clip_denoised: false  # clip denoised data to [0, 1]
    clamp_a: null
    clamp_eps: null

  lim:
    deterministic: false  # set to true for ODE
    reverse_steps: 100


# Network settings. `architecture` names the sub-mapping to use; only `mlp` is defined here.
model:
  architecture: 'mlp'
  mlp:
    act: silu
    dropout_rate: 0.0
    group_norm: true
    nblocks: 4
    nunits: 64
    skip_connection: true
    time_emb_size: 32
    time_emb_type: learnable

    out_channel_mult: 1  # multiplier for the number of output channels in the model

    # Feeding a_t_0, a_t_1 into the model was tried and was unsuccessful,
    # so it is switched off via `no_a: true`.
    # NOTE(review): presumably the remaining a_* keys are ignored while no_a is true — confirm.
    no_a: true
    use_a_t: false  # use only a_t_1 instead of (a_t_0, a_t_1)
    a_emb_size: 32
    a_pos_emb: false

    # Output two channels after additional MLP blocks, to learn the variance.
    learn_variance: false

    softplus: false  # apply softplus to the output of the model
    beta: null       # beta for the softplus function
    threshold: null  # threshold for the softplus function
    


# Training settings; the nested `dlpm` and `lim` mappings hold per-method options.
training:
  batch_size: 1024
  num_workers: 0

  dlpm:
    # Resample the dataset at each epoch. The Dataset object must implement a refresh method.
    refresh_data: true

    # Explicit null rather than a bare key (both parse to null).
    # To set rates, replace null with a list, e.g. `- 0.9`.
    ema_rates: null
    grad_clip: null  # alternative value: 1.0

    # Aggregate function applied to the batch loss over M samples of a_{1:T}: mean or median.
    loss_monte_carlo: mean
    loss_type: EPS_LOSS  # other loss types still need to be reimplemented
    # p in the LP loss. DLPM loss: 2.0; can also try smooth L1 loss (p = 1) or MSE loss (p = -1).
    lploss: 2.0
    # Number of samples for the inner / outer expectation approximation
    # in the loss of Proposition (9).
    monte_carlo_inner: 1
    monte_carlo_outer: 1

    clamp_a: null
    clamp_eps: null

  lim:
    # Explicit null rather than a bare key (both parse to null).
    # To set rates, replace null with a list, e.g. `- 0.9`.
    ema_rates: null
    grad_clip: null  # alternative value: 1.0

    clamp_a: null
    clamp_eps: null

# Optimizer and learning-rate schedule settings.
optim:
  optimizer: adamw
  schedule: null      # options listed by the author: linear, steplr
  lr: 0.005
  warmup: 0           # alternative value: 1000
  # NOTE(review): presumably the lr_* keys below are only read when a schedule is set — confirm.
  lr_steps: 2000
  lr_step_size: 400
  lr_gamma: 0.99

# Run-level settings.
run:
  epochs: 20
  eval_freq: null        # NOTE(review): presumably null disables periodic evaluation — confirm
  checkpoint_freq: null  # NOTE(review): presumably null disables periodic checkpointing — confirm
  progress: false        # print progress bar