# Run-level settings.
# NOTE(review): !ENV and !SUB are custom tags whose constructors are not part
# of this file. Presumably !ENV expands environment variables and !SUB fills
# ${name} placeholders from values supplied by the caller -- confirm against
# the config loader.
# The &_project_name and &_revision anchors are reused by alias when
# accounting.logdir is built further down.
project_name: &_project_name !ENV "${LOGNAME}_bfn_molopt"
exp_name: !SUB ${exp_name}
revision: &_revision !SUB ${revision}
debug: !SUB ${debug}
no_wandb: !SUB ${no_wandb}
wandb_resume_id: !SUB ${wandb_resume_id}
logging_level: !SUB ${logging_level}
seed: !SUB ${seed}
test_only: !SUB ${test_only}
empty_folder: !SUB ${empty_folder}


# Model
# NOTE(review): !SUB is a custom tag; presumably it fills ${name} placeholders
# from caller-supplied values -- confirm against the config loader.
dynamics:
  t_min: !SUB ${t_min}
  sigma1_coord: !SUB ${sigma1_coord}
  beta1: !SUB ${beta1}
  protein_atom_feature_dim: 27  # TODO
  ligand_atom_feature_dim: 13
  use_discrete_t: !SUB ${use_discrete_t}
  discrete_steps: !SUB ${discrete_steps}
  destination_prediction: !SUB ${destination_prediction}
  sampling_strategy: !SUB ${sampling_strategy}

  # Currently disabled options, kept for reference:
  # no_diff_coord: !SUB ${no_diff_coord}
  # charge_discretised_loss: !SUB ${charge_discretised_loss}
  node_indicator: true
  time_emb_dim: !SUB ${time_emb_dim}
  time_emb_mode: !SUB ${time_emb_mode}
  center_pos_mode: protein
  pos_init_mode: !SUB ${pos_init_mode}
  # Settings for the network named by net_config.name.
  net_config:
    name: "unio2net"
    num_blocks: 1
    num_layers: 9
    hidden_dim: 128
    n_heads: 16
    edge_feat_dim: 4  # edge type feat
    num_r_gaussian: 20
    knn: 32  # !
    num_node_types: 8
    act_fn: relu
    norm: true
    cutoff_mode: knn  # alternatives noted by the author: radius, none
    ew_net_type: global  # alternatives noted by the author: r, m, none
    num_x2h: 1
    num_h2x: 1
    r_max: 10.0
    x2h_out_fc: false
    sync_twoup: false


# Output locations for a run. Every entry is derived from the &_logdir anchor,
# whose components are the log root under ${HOME}, the project name, the
# experiment name and the revision.
# NOTE(review): !PATHJOIN is a custom tag; presumably it joins the list items
# into a single filesystem path -- confirm against the config loader.
accounting:
  logdir: &_logdir !PATHJOIN [!ENV "${HOME}/project/logs/", *_project_name, !SUB "${exp_name}", *_revision]
  dump_config_path: !PATHJOIN [*_logdir, "config.yaml"]
  wandb_logdir: *_logdir
  checkpoint_dir: !PATHJOIN [*_logdir, "checkpoints"]
  generated_mol_dir: !PATHJOIN [*_logdir, "generated_mol"]
  test_outputs_dir: !PATHJOIN [*_logdir, "test_outputs"]


# Dataset settings.
data:
  name: pl_dcmp  # other names noted by the author: pl, pl_tr
  path: /sharefs/share/sbdd_data/crossdocked_pocket10
  split: null
  transform:
    ligand_atom_mode: add_aromatic
  with_split: true
  # The three lists below each hold eight entries.
  # NOTE(review): presumably colors_dic and radius_dic are indexed in parallel
  # with atom_decoder -- confirm against the consumer.
  atom_decoder: ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl']
  colors_dic: ['#FFFFFF99', 'C7', 'C0', 'C3', 'C1', 'C4', 'C8', 'C9']
  radius_dic: [0.3, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6]
  normalizer_dict:
    pos: !SUB ${pos_normalizer}

# Visualisation settings.
# NOTE(review): !SUB is a custom tag; presumably it fills the ${visual_chain}
# placeholder from a caller-supplied value -- confirm against the config loader.
visual:
  save_mols: true
  visual_nums: 10
  visual_chain: !SUB ${visual_chain}


# Training settings.
# NOTE(review): !SUB is a custom tag; presumably it fills ${name} placeholders
# from caller-supplied values -- confirm against the config loader.
train:
  batch_size: !SUB ${batch_size}
  num_workers: 8
  pos_noise_std: !SUB ${pos_noise_std}
  random_rot: !SUB ${random_rot}
  val_freq: 1000
  epochs: !SUB ${epochs}
  resume: !SUB ${resume}
  v_loss_weight: !SUB ${v_loss_weight}
  ckpt_freq: 1
  max_grad_norm: !SUB ${max_grad_norm}
  optimizer:
    type: "adam"
    lr: !SUB ${lr}
    weight_decay: !SUB ${weight_decay}
    # NOTE(review): a separate key from dynamics.beta1; these look like the
    # optimizer's own beta coefficients -- confirm.
    beta1: 0.95
    beta2: 0.999
  scheduler:
    type: !SUB ${scheduler}
    factor: 0.6
    patience: 10
    min_lr: 1.0e-6
    max_iters: 20000
  ema_decay: 0.999


# Evaluation / sampling settings.
# NOTE(review): !SUB is a custom tag; presumably it fills ${name} placeholders
# from caller-supplied values. The ${name:-value} form looks like a shell-style
# fallback used when the variable is not supplied -- confirm both against the
# config loader.
evaluation:
  batch_size: 5
  last_ckpt: !SUB ${last_ckpt}
  sample_steps: !SUB ${sample_steps}
  num_samples: !SUB ${num_samples}  # x-% test
  sample_num_atoms: !SUB ${sample_num_atoms}  # options noted by the author: 'prior', 'ref'
  docking_config:
    mode: !SUB ${docking_mode} # options noted by the author: 'qvina', 'vina_score', 'vina_dock'
    protein_root: /sharefs/share/sbdd_data/test_set
    exhaustiveness: 16
  # These two fall back to the text "null" when not supplied.
  ligand_path: !SUB ${ligand_path:-null}
  protein_path: !SUB ${protein_path:-null}
  guide_mode: !SUB ${guide_mode}
  objective: !SUB ${objective}
  pos_grad_weight: !SUB ${pos_grad_weight}
  type_grad_weight: !SUB ${type_grad_weight}
  save_traj: !SUB ${save_traj}
  # The remaining keys carry inline fallbacks (False, 0.5 and 0.0 respectively).
  change_scaffold: !SUB ${change_scaffold:-False}
  interpolate_coef: !SUB ${interpolate_coef:-0.5}
  cfg_coef: !SUB ${cfg_coef:-0.0}


