# Hydra defaults list. `_self_` is the only entry, so no other config
# files or config groups are composed from this list.
defaults:
  - _self_



# Hydra output directories. `run.dir` is used for single runs and
# `sweep.dir` for multirun sweeps; both are set to the same pattern:
# the experiment name from rdm_args.exp_name followed by a launch timestamp.
# The patterns are quoted because they contain `$`, `{`, `:` and `%`.
hydra:
  sweep:
    dir: './hydra_logs/rdm_train/drug/${rdm_args.exp_name}_${now:%Y-%m-%d_%H-%M-%S}/'
  run:
    dir: './hydra_logs/rdm_train/drug/${rdm_args.exp_name}_${now:%Y-%m-%d_%H-%M-%S}/'



# Arguments for the RDM training run: training loop, logging, optimizer,
# dataset, encoder and latent-diffusion switches.
rdm_args:
  # --- Training parameters ---
  # Batch size per GPU; the effective batch size is
  # batch_size * accum_iter * number of GPUs.
  batch_size: 32
  epochs: 2000
  # Gradient-accumulation iterations (raises the effective batch size
  # under memory constraints).
  accum_iter: 1
  device: 'cuda'  # device to use for training / testing
  seed: 0
  # NOTE(review): presumably a checkpoint path to load RDM weights from,
  # with null meaning "none" - confirm against the training script.
  rdm_ckpt: null
  start_epoch: 0  # start epoch
  num_workers: 10
  # Pin CPU memory in the DataLoader for (sometimes) more efficient
  # transfer to GPU.
  pin_mem: true
  # NOTE(review): presumably toggles data-parallel training - confirm.
  dp: false

  # --- Logging ---
  wandb_usr: null
  no_wandb: false  # set to true to disable wandb
  online: true  # true = wandb online, false = wandb offline
  # exp_name: 'rdm_batch128_conditioningTrainable_latent'
  # exp_name is also interpolated into hydra.run.dir and hydra.sweep.dir.
  # NOTE(review): the name says "batch128" while batch_size above is 32 -
  # confirm this is intentional.
  exp_name: 'drug_batch128'

  # --- Optimizer parameters ---
  # NOTE(review): the original note cited a default of 0.05, but the value
  # set here is 0.005 - confirm which one is intended.
  weight_decay: 0.005
  lr: null  # learning rate (absolute lr)
  # Base learning rate: absolute_lr = base_lr * total_batch_size.
  # Written with an explicit decimal point in the mantissa so that YAML 1.1
  # loaders (e.g. plain PyYAML) resolve it as a float, not the string '1e-6'.
  blr: 1.0e-6
  min_lr: 0.0  # lower lr bound for cyclic schedulers that hit 0
  cosine_lr: true  # use cosine lr scheduling
  warmup_epochs: 0

  # --- Dataset parameters ---
  # When set to an integer, only molecules with that number of atoms are kept.
  # NOTE(review): the original comment referred to QM9, but `dataset` below
  # is geom - confirm the filter applies to this dataset.
  filter_n_atoms: null
  remove_h: false
  include_charges: false  # include atom charge or not

  data_file: './data/geom/geom_drugs_30.npy'
  sequential: false
  filter_molecule_size: null
  dataset: geom

  encoder_type: unimol
  encoder_path: './checkpoints/encoder_ckpts/unimol_drug.pt'

  # --- Latent diffusion ---
  # Original note: "vae and rdm".
  # NOTE(review): exact meaning of this switch is not visible here - confirm.
  train_diffusion: true
  trainable_ae: true
  vae_ckpt: null

  # --- Debug ---
  debug: false

  



# Model specification written as `target` (dotted class path) plus `params`.
# NOTE(review): presumably consumed by an instantiate-from-config helper
# that builds target(**params) - confirm against the training script.
model_args:
  target: models.rdm.models.diffusion.ddpm.RDM
  params:
    use_ema: false
    # NOTE(review): presumably the endpoints of a linear noise (beta)
    # schedule over `timesteps` steps - confirm in the RDM class.
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    # Same value as cond_stage_config.params.key below.
    cond_stage_key: node_num
    image_size: 1
    # NOTE(review): channels is 256 while unet_config in/out_channels are
    # 512 - confirm this difference is intended.
    channels: 256
    class_cond: true
    cond_stage_trainable: true
    conditioning_key: crossattn
    # NOTE(review): presumably selects x0-prediction (vs. noise prediction) - confirm.
    parameterization: x0
    # latent_config:
    #   vae_encoder_dim_sequence: [256, 192, 192, 128]
    #   vae_kl_weight: 1.

    # Denoising network: same target/params layout as model_args itself.
    unet_config:
      target: models.rdm.modules.diffusionmodules.latentmlp.SimpleMLP
      params:
        # in_channels and out_channels are both 512.
        in_channels: 512
        time_embed_dim: 512
        model_channels: 1536
        bottleneck_channels: 1536
        out_channels: 512
        num_res_blocks: 18
        use_context: true
        # Equal to cond_stage_config.params.embed_dim (512) below.
        context_channels: 512
    # Conditioning encoder: same target/params layout as model_args itself.
    cond_stage_config:
      target: models.rdm.modules.encoders.modules.ClassEmbedder
      params:
        embed_dim: 512
        # NOTE(review): with key node_num, presumably the number of distinct
        # node-count classes that can be embedded - confirm.
        n_classes: 200
        key: node_num

