# Run-level arguments: training loop, logging, optimizer, dataset, encoder,
# latent-diffusion switches and debug flag.
rdm_args:
  # Training parameters
  batch_size: 128  # Batch size per GPU (effective batch size = batch_size * accum_iter * number of GPUs)
  epochs: 2000
  accum_iter: 1  # Gradient-accumulation iterations (raises the effective batch size under memory constraints)
  device: 'cuda'  # Device to use for training / testing
  seed: 0
  rdm_ckpt: null  # Checkpoint to resume from
  start_epoch: 0  # Epoch to start from
  num_workers: 10
  pin_mem: true  # Pin CPU memory in the DataLoader for (sometimes) more efficient transfer to GPU
  dp: false  # NOTE(review): presumably toggles data-parallel training -- confirm against the training script

  # Logging
  wandb_usr: null
  no_wandb: false  # true = disable wandb entirely
  online: true  # true = wandb online, false = wandb offline
  # exp_name: 'rdm_batch128_conditioningTrainable_latent'
  exp_name: 'rdm_batch128_pretrainedEncCrossAttn_LUMO'

  # Optimizer parameters
  weight_decay: 0.005  # Weight decay (default: 0.05)
  lr: null  # Absolute learning rate
  # Base learning rate: absolute_lr = base_lr * total_batch_size.
  # Written with a decimal point on purpose: YAML 1.1 loaders (e.g. PyYAML)
  # read a bare `1e-6` as the string '1e-6' rather than a float.
  blr: 1.0e-6
  min_lr: 0.0  # Lower lr bound for cyclic schedulers that hit 0
  cosine_lr: true  # Use cosine lr scheduling
  warmup_epochs: 0

  # Dataset parameters
  dataset: 'qm9_second_half'  # qm9 | qm9_second_half (train only on the last 50K samples of the training dataset)
  datadir: './data/'  # QM9 directory
  filter_n_atoms: null  # When set to an integer, QM9 will only contain molecules with that number of atoms
  remove_h: false
  include_charges: true  # Whether to include atom charges

  # Encoder
  encoder_type: frad
  encoder_path: './checkpoints/encoder_ckpts/QM9.ckpt'

  # Latent diffusion
  train_diffusion: true
  trainable_ae: true  # NOTE(review): presumably "autoencoder is trainable" -- confirm
  vae_ckpt: null

  # Debug
  debug: false


# Model definition. Each `target` is a dotted path and the sibling `params`
# mapping holds its arguments.
# NOTE(review): presumably consumed by a target/params instantiation helper -- confirm against the loader.
model_args:
  target: models.rdm.models.diffusion.ddpm.RDM
  params:
    use_ema: false
    # NOTE(review): linear_start / linear_end look like the endpoints of a
    # linear noise schedule over `timesteps` steps -- confirm in the RDM class.
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    # Same names as the `key` fields of the two embedders in cond_stage_config.
    cond_stage_key: [node_num, lumo]
    image_size: 1
    # Same value as unet_config in_channels / out_channels.
    channels: 256
    class_cond: true
    cond_stage_trainable: true
    conditioning_key: crossattn
    parameterization: x0
    # latent_config:
    #   vae_encoder_dim_sequence: [256, 192, 192, 128]
    #   vae_kl_weight: 0.01

    # Denoising network configuration (an MLP, per the target path).
    unet_config:
      target: models.rdm.modules.diffusionmodules.latentmlp.SimpleMLP
      params:
        in_channels: 256
        time_embed_dim: 256
        model_channels: 1536
        bottleneck_channels: 1536
        out_channels: 256
        num_res_blocks: 18
        use_context: true
        # Same value as embed_dim of both conditioning embedders below.
        context_channels: 256
    # Conditioning embedders: `cond_num` numbered target<N> / params<N> pairs follow.
    cond_stage_config:
      cond_num: 2
      # Embedder 1: keyed on node_num.
      target1: models.rdm.modules.encoders.modules.ClassEmbedder
      params1:
        embed_dim: 256
        n_classes: 30
        key: node_num
      # Embedder 2: keyed on lumo, 1-dimensional input.
      target2: models.rdm.modules.encoders.modules.ContinuousEmbedder
      params2:
        embed_dim: 256
        input_dim: 1
        key: lumo
