# Hydra defaults list. `_self_` marks the position at which this file's own
# values are merged relative to other entries in the list (none are listed here).
defaults:
  - _self_

# Hydra output locations. Both patterns resolve to the same layout:
#   ./hydra_logs/pcdm_train/qm9/<pcdm_args.exp_name>_<launch timestamp>/
# The timestamp comes from Hydra's `now` resolver (strftime format).
hydra:
  # Working directory for a single run.
  run:
    dir: './hydra_logs/pcdm_train/qm9/${pcdm_args.exp_name}_${now:%Y-%m-%d_%H-%M-%S}/'
  # Parent directory for multirun sweeps.
  sweep:
    dir: './hydra_logs/pcdm_train/qm9/${pcdm_args.exp_name}_${now:%Y-%m-%d_%H-%M-%S}/'



pcdm_args:
  # logging
  # exp_name is interpolated into hydra.run.dir / hydra.sweep.dir in this file.
  # NOTE(review): the name hard-codes "batch128", "noisy0.3" and "drop0.2", which
  # match batch_size, noise_sigma and attn_dropout in this file; it is not
  # derived from them, so update it by hand when overriding those values.
  exp_name: 'pcdm_batch128_noisy0.3_drop0.2_pretrain'
  # `???` is OmegaConf's mandatory-value marker: a value must be supplied
  # (e.g. as a command-line override) before this key can be read.
  wandb_usr: ???
  no_wandb: false # true = disable wandb
  online: true # true = wandb online -- false = wandb offline
  save_model: true # save model
  n_report_steps: 250 # NOTE(review): presumably the reporting interval in steps — confirm


  # About models
  pcdm_model: 'egnn_dynamics' # our_dynamics | schnet | simple_dynamics | kernel_dynamics | egnn_dynamics | gnn_dynamics
  probabilistic_model: 'diffusion' # diffusion

  diffusion_steps: 1000
  diffusion_noise_schedule: 'polynomial_2' # polynomial_2 | learned | cosine
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g. PyYAML)
  # also resolve it to a float instead of the string '1e-5'.
  diffusion_noise_precision: 1.0e-5
  diffusion_loss_type: 'l2' # vlb, l2

  # Conditioning properties (empty list here).
  conditioning: [] # arguments : homo | lumo | alpha | gap | mu | Cv

  # Latent-space / autoencoder options.
  # NOTE(review): meanings below are inferred from the key names only — confirm
  # against the training code.
  use_latent: false
  latent_nf: 32
  encoder_n_layers: 3
  decoder_n_layers: 9
  train_diffusion: true
  trainable_ae: true
  ae_path: null # no autoencoder path set
  kl_weight: 0.01


  # EGNN Config
  mode: egnn
  n_layers: 9 # number of layers
  # NOTE(review): the previous comment here ("number of layers") duplicated
  # n_layers; presumably this is the number of invariant sub-layers — confirm.
  inv_sublayers: 1
  # NOTE(review): the previous comment here said "number of layers"; presumably
  # this is the hidden feature width (cf. latent_nf, rep_nf, light_nf) — confirm.
  nf: 256
  tanh: true # use tanh in the coord_mlp
  attention: true # use attention in the EGNN
  norm_constant: 1.0 # diff/(|diff| + norm_constant)
  sin_embedding: false # whether or not to use the sin embedding
  aggregation_method: 'sum' # "sum" or "mean"
  normalization_factor: 1.0 # Normalize the sum aggregation of EGNN
  cross_attn: true
  attn_dropout: 0.2

  # Training
  # Scientific-notation values in this section carry an explicit mantissa dot
  # and exponent sign so that YAML 1.1 loaders (e.g. PyYAML) also resolve them
  # to floats rather than strings.
  n_epochs: 3000
  batch_size: 128
  inference_batch_size: 128
  lr: 1.0e-4
  break_train_epoch: false # true | false
  condition_time: true # true | false
  clip_grad: true # true | false
  ode_regularization: 1.0e-3
  num_workers: 4 # Number of workers for the dataloader
  test_epochs: 100
  noise_sigma: 0.3

  resume: null

  start_epoch: 0
  ema_decay: 0.9999 # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
  augment_noise: 0.0
  n_stability_samples: 1000 # Number of samples to compute the stability
  normalize_factors: [1, 4, 10] # normalize factors for [x, categorical, integer]
  remove_h: false
  include_charges: true # include atom charge or not
  visualize_every_batch: 1.0e+8 # Can be used to visualize multiple times per epoch

  dp: false # true | false


  # Representation-alignment loss setting.
  # NOTE(review): presumably a loss weight, with 0 turning the term off — confirm.
  rep_align_loss: 0.0

  # Classifier-free guidance
  # NOTE(review): `cfg` is presumably the guidance scale and `rep_dropout_prob`
  # the probability of dropping the representation condition — confirm.
  cfg: 1.0
  rep_dropout_prob: 0.1


  # dataset
  dataset: 'qm9' # qm9 | qm9_second_half (train only on the last 50K samples of the training dataset)
  datadir: './data' # qm9 directory
  filter_n_atoms: null # When set to an integer value, QM9 will only contain molecules with that number of atoms
  # NOTE(review): the previous comment here ("use attention in the EGNN") was
  # copied from the `attention` key and did not describe this option; confirm
  # which augmentation this flag enables.
  data_augmentation: false



  # RDM Sampling.
  # These settings can be changed after training, at evaluation time. They are set here only to monitor the trajectory.
  sampler: GtSampler # ["PCSampler", "GtSampler", "DDIMSampler"]
  #     For DDIMSampler and PCSampler
  rdm_ckpt: null # checkpoint to resume from (null = none set)
  #     For DDIMSampler
  step_num: 250
  eta: 1.0
  #     For PCSampler
  inv_temp: 1.0
  n_steps: 5
  snr: 0.01
  #     For GtSampler
  Gt_dataset: train # ["train", "test", "valid"]


  # Encoder
  encoder_type: frad # unimol or frad
  encoder_path: "./checkpoints/encoder_ckpts/QM9.ckpt"
  rep_nf: 256

  # Encoder fine-tuning options.
  # NOTE(review): the encoder_* / light_* keys below presumably only take effect
  # when finetune_encoder is true — confirm against the training code.
  # Scientific-notation values carry an explicit mantissa dot so that YAML 1.1
  # loaders (e.g. PyYAML) also resolve them to floats rather than strings.
  finetune_encoder: false
  encoder_lr: 1.0e-4
  encoder_weight_decay: 0.0
  encoder_factor: 0.8
  encoder_patience: 5
  encoder_min_lr: 1.0e-6
  light_n_layers: 3
  light_nf: 128
  finetune_save_path: "./checkpoints/encoder_ckpts/NEW_QM9_diffusion_finetune_l3_nf128.ckpt"

  # DEBUG
  # NOTE(review): what this flag changes is defined by the consuming script,
  # which is not visible from this file.
  debug: false



  
