# Hydra defaults list. `_self_` is the only entry, so no other config group
# is composed into this file through this list.
defaults:
  - _self_



# Hydra output directories. Single runs (`run`) and multirun sweeps (`sweep`)
# both write to a folder named "<experiment name>_<launch timestamp>".
# The experiment name is interpolated from the `pcdm_args` group defined in
# this file; the previous `rdm_args` reference pointed at a group this file
# never defines.
hydra:
  sweep:
    dir: ./hydra_logs/pcdm_train/drug/${pcdm_args.exp_name}_${now:%Y-%m-%d_%H-%M-%S}/
  run:
    dir: ./hydra_logs/pcdm_train/drug/${pcdm_args.exp_name}_${now:%Y-%m-%d_%H-%M-%S}/

pcdm_args:
  # ---- Logging ----
  # NOTE(review): the name mentions "batch64" and "encoder_finetune", while
  # this file sets batch_size: 16 and finetune_encoder: false. Confirm the
  # name still describes the run.
  exp_name: drug_encoder_finetune_batch64
  # `???` marks the value as mandatory: it has to be supplied from outside
  # this file (for example as a command-line override).
  wandb_usr: ???
  # Set to true to disable wandb.
  no_wandb: false
  # true = wandb online, false = wandb offline.
  online: true
  # Save the model.
  save_model: true
  n_report_steps: 1000





  # ---- Models ----
  # Options: our_dynamics | schnet | simple_dynamics | kernel_dynamics |
  #          egnn_dynamics | gnn_dynamics
  pcdm_model: 'egnn_dynamics'
  # Options: diffusion
  probabilistic_model: 'diffusion'

  diffusion_steps: 1000
  # Other schedule names listed by the original comment: learned, cosine.
  diffusion_noise_schedule: 'polynomial_2'
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (such as
  # plain PyYAML) read a float rather than the string '1e-5'.
  diffusion_noise_precision: 1.0e-5
  # Options: vlb | l2
  diffusion_loss_type: 'l2'

  # Accepted entries: homo | lumo | alpha | gap | mu | Cv
  conditioning: []

  use_latent: false
  latent_nf: 32
  encoder_n_layers: 3
  decoder_n_layers: 9
  train_diffusion: true
  trainable_ae: true
  # Commented-out alternative value for ae_path:
  # ae_path: "./outputs/latentpcdm_batch128_norm_hadarmardsum (vae part) (with latent RDM) "
  ae_path: null
  kl_weight: 0.01


  # ---- EGNN configuration ----
  mode: egnn
  # Number of layers.
  n_layers: 4
  # NOTE(review): this key was also annotated "number of layers"; presumably
  # it counts sub-layers inside each layer. Confirm against the model code.
  inv_sublayers: 1
  # NOTE(review): this key was annotated "number of layers" too; by analogy
  # with latent_nf / rep_nf / light_nf it is presumably a feature width.
  # Confirm against the model code.
  nf: 256
  # Use tanh in the coord_mlp.
  tanh: true
  # Use attention in the EGNN.
  attention: true
  # diff / (|diff| + norm_constant)
  norm_constant: 1.0
  # Whether or not to use the sin embedding.
  sin_embedding: false
  # "sum" or "mean"
  aggregation_method: sum
  # Normalizes the sum aggregation of the EGNN.
  normalization_factor: 1.0
  cross_attn: true
  attn_dropout: 0.1

  # ---- Training ----
  # Scientific-notation values below carry an explicit mantissa dot so that
  # YAML 1.1 loaders (such as plain PyYAML) read floats, not strings.
  n_epochs: 3000
  batch_size: 16
  inference_batch_size: 16
  lr: 1.0e-4
  break_train_epoch: false  # true | false
  condition_time: true  # true | false
  clip_grad: true  # true | false
  ode_regularization: 1.0e-3
  num_workers: 4  # number of workers for the dataloader
  test_epochs: 1
  noise_sigma: 0.2

  resume: null

  start_epoch: 0
  ema_decay: 0.9999  # amount of EMA decay; 0 means off. A reasonable value is 0.999.
  augment_noise: 0.0
  n_stability_samples: 250  # number of samples used to compute the stability
  normalize_factors: [1, 4, 10]  # normalize factors for [x, categorical, integer]
  remove_h: false
  include_charges: false  # include atom charge or not
  visualize_every_batch: 10000  # can be used to visualize multiple times per epoch

  dp: false  # true | false


  rep_align_loss: 0.0

  # ---- Classifier-free guidance ----
  cfg: 1.0
  rep_dropout_prob: 0.1


  # ---- Dataset ----
  data_file: './data/geom/geom_drugs_30.npy'
  # Booleans use the lowercase spelling found everywhere else in this file.
  sequential: false
  filter_molecule_size: null
  data_augmentation: false
  dataset: geom

  # ---- RDM sampling ----
  # These settings can be changed after training, during evaluation; they are
  # set here only to monitor the trajectory.
  # Options: PCSampler | GtSampler | DDIMSampler
  sampler: GtSampler
  # DDIMSampler and PCSampler: checkpoint to resume from.
  rdm_ckpt: './checkpoints/rdm_ckpts/rdm_diffusion_finetuned.pth'
  # DDIMSampler only.
  step_num: 250
  eta: 1.0
  # PCSampler only.
  inv_temp: 1.0
  n_steps: 1
  # GtSampler only. Options: train | test | valid
  Gt_dataset: train


  # ---- Encoder ----
  encoder_type: unimol  # unimol or frad
  encoder_path: "./checkpoints/encoder_ckpts/unimol_drug.pt"
  rep_nf: 512

  finetune_encoder: false
  # Scientific-notation values below carry an explicit mantissa dot so that
  # YAML 1.1 loaders (such as plain PyYAML) read floats, not strings.
  encoder_lr: 1.0e-4
  encoder_weight_decay: 0.0
  # NOTE(review): factor / patience / min_lr look like learning-rate scheduler
  # settings for the encoder. Confirm against the training code.
  encoder_factor: 0.5
  encoder_patience: 1
  encoder_min_lr: 1.0e-6
  light_n_layers: 2
  light_nf: 128
  # NOTE(review): the file name says "l3" while light_n_layers is 2. Confirm
  # which one is intended.
  finetune_save_path: "./checkpoints/encoder_ckpts/Drug_diffusion_finetune_l3_nf128.ckpt"

  # DEBUG
  # Debug switch, off in this config.
  # NOTE(review): what it changes is decided by the consuming code, which is
  # not visible here. Confirm before enabling.
  debug: false



  
