# Hydra defaults list: the config-group selections below are composed with this
# file to build the final config.
defaults:
  # `_self_` is listed first, so the values defined in this file are merged
  # before the group configs that follow, and those groups may override them.
  - _self_
  # A list value selects several options from the same group at once.
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /data: rgfn_graph
  - /model: semla
  - /strategy: ddp
  - /lr_scheduler: constant_warmup
  # `group@package`: the `noise` group is loaded twice, once under the
  # `discrete_noise` key and once under the `spatial_noise` key.
  - /noise@discrete_noise: linear
  - /noise@spatial_noise: brownianbridge

# Top-level run settings.
# `mode` is one of: train / ppl_eval / sample_eval
mode: train
denoise_discrete: false
denoise_coordinates: false
true_edge_weight: 1.0
# NOTE(review): same name as the `/model: semla` entry in the defaults list;
# presumably the two are meant to be changed together - confirm.
backbone: semla
diffusion: absorbing_state
parameterization: subs
self_conditioning: false
time_conditioning: false
# NOTE(review): the meaning of `T: 0` is not defined in this file - confirm
# with the consumer before changing it.
T: 0
subs_masking: false
seed: 1

# Settings grouped under `spatial`: noising and augmentation switches, the
# coordinate prior, and the weights of the auxiliary loss terms.
spatial:
  do_noise: false
  overfit: true
  n_overfit: 100
  sample_conformer: true
  align: true
  center: true
  normalize: false
  rotate: true
  translate: true
  equivariant_ot: true
  coord_mask_value: 0.0
  # Options: uniform, gaussian, ones
  prior: gaussian
  # Enabled loss terms. Available names: bond_length, pairwise_distance,
  # smooth_lddt. Each name has a matching `<name>_coef` key below.
  bond_loss:
    - bond_length
    - pairwise_distance
    - smooth_lddt
  mse_coef: 1.0
  bond_length_coef: 1.0
  pairwise_distance_coef: 1.0
  smooth_lddt_coef: 1.0
  square_time_weight: false
  square_bond_loss: false
  pairwise_threshold: 5.0
  scale_noise: false
  bond_loss_time_threshold: 1.0  # 0.25
  scale_noise_factor: 0.2  # from Semla paper
  pharmacophore_conditioning: false
  pharmacophore_subset: 40
  # NOTE(review): YAML reads the bare word `None` as the string "None", not as
  # null. Write `null` if a real null is intended - confirm what the consumer
  # compares against before changing it.
  pharm_cond_mol: None

# Data-loading settings. The per-device sizes are derived from the global sizes
# through the `div_up` and `eval` resolvers, which are registered outside this
# file.
loader:
  global_batch_size: 512
  # Relative interpolation: evaluation reuses the training global batch size
  # unless it is overridden.
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per device**: the global size is
  # divided by `trainer.devices * trainer.num_nodes` (assuming `div_up` is a
  # rounding-up division - TODO confirm against the resolver definition).
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  # Presumably evaluated as Python by the `eval` resolver: the number of CPU
  # cores the current process may run on. `os.sched_getaffinity` is not
  # available on every platform (for example macOS and Windows lack it).
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True

# Settings used when drawing samples from the model.
sampling:
  # Options: ddpm, path_planning
  predictor: ddpm
  steps: 128
  boltz_inference_align: false
  noise_removal: true
  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_batches: 2
  num_sample_log: 2
  conf_out_dir: ${paths.output_dir}/samples
  constrain_edge_sampling: false
  refine_coordinates_steps: 0
  # NOTE(review): presumably only read when `predictor` is `path_planning` -
  # confirm.
  path_planning:
    tau: 1.0
    eta: 1.0
    score_type: confidence
  spatial:
    # NOTE(review): YAML reads the bare word `None` as the string "None", not
    # as null. Write `null` if a real null is intended - confirm what the
    # consumer compares against before changing it.
    guidance: None
    cond_mol_path: ""
    inference_annealing: false
    inference_annealing_coef: 10.0
    # Only option listed so far: euler
    integrator: euler
    stochastic: false
    churn: 1
    tmin: 0.0
    tmax: 0.2

# Training-loop options.
training:
  # NOTE(review): presumably the exponential-moving-average decay of the model
  # weights (see `eval.disable_ema`) - confirm.
  ema: 0.9999
  antithetic_sampling: true
  importance_sampling: false
  sampling_eps: 1e-3
  change_of_variables: false

# Evaluation options.
eval:
  # Used to evaluate a checkpoint after training; empty string by default.
  checkpoint_path: ""
  disable_ema: false
  generate_samples: true

# Switches grouped under `rgfn`; both are off by default.
rgfn:
  reassembly_logging: false
  compatibility_mask: false

# Optimizer hyperparameters (the key names suggest an Adam-style optimizer -
# confirm against the code that builds it).
# NOTE(review): `1e-5`, `3e-4` and `1e-8` are written without a decimal point.
# Some YAML 1.1 loaders read that form as a string rather than a float, so load
# this file through the project's Hydra/OmegaConf pipeline - confirm if it is
# ever parsed another way.
optim:
  weight_decay: 1e-5
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

# Arguments for `lightning.Trainer`; `_target_` names the class to build
# (Hydra `instantiate` convention).
trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  # `device_count` is a custom resolver registered outside this file.
  devices: ${device_count:}
  # global_batch_size / (devices * loader.batch_size * num_nodes), rounded up
  # (assuming `div_up` is ceiling division - TODO confirm). With
  # `loader.batch_size` left at its derived value this presumably resolves to 1;
  # it only grows when `loader.batch_size` is overridden to something smaller.
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 'bf16'
  num_sanity_val_steps: 2
  max_steps: 1_000_000
  log_every_n_steps: 10
  limit_train_batches: 1.0   # train on full dataset, can be used to toggle quick run
  limit_val_batches: 1.0     # validate on full dataset, can be used to toggle quick run
  check_val_every_n_epoch: 1

# Weights & Biases logging settings.
wandb:
  project: "training runs"
  notes: "Masked diffusion graph modeling for synthesizable molecules!"
  group: ""
  job_type: null
  log_model: false
  entity: syncogen
  save_dir: "${paths.output_dir}"
  mode: online
  id: null
  # Tags are filled from the composed noise and data configs.
  tags:
    - ${discrete_noise.type}
    - ${spatial_noise.type}
    - ${data.train}
    - ${data.valid}

# Hydra runtime settings.
hydra:
  run:
    # Output directory of a single run:
    # <paths.log_dir>/<data.train>/<YYYY.MM.DD>/<HHMMSS>
    dir: ${paths.log_dir}/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    # Switch the process working directory to the run directory at start-up.
    chdir: true

# Where checkpoints are written and whether a run resumes from the last one.
checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  save_dir: ${paths.output_dir}
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: true
  # Relative interpolation: resolves against `save_dir` in this same section.
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt

# Filesystem locations used by the run.
paths:
  # NOTE(review): `.` is a relative path and `hydra.job.chdir` is true, so a
  # consumer that resolves `root_dir`/`data_dir` after start-up would see them
  # relative to the run directory rather than the launch directory - confirm
  # that callers make these absolute (e.g. via `work_dir`).
  root_dir: .
  data_dir: ${paths.root_dir}/data
  log_dir: ${paths.root_dir}/logs
  # Supplied by Hydra at runtime: the run's output directory and the directory
  # the job was launched from.
  output_dir: ${hydra:runtime.output_dir}
  work_dir: ${hydra:runtime.cwd}
  use_lmdb: false
