# Top-level settings of this run configuration.
run_name: drugflow_preference_alignment

checkpoint: ./reference.ckpt  # TODO: specify reference checkpoint
dpo_mode: single_dpo_comp_v3

pocket_representation: CA+
virtual_nodes: [0, 10]
# Booleans are written in canonical lowercase form.
# Per the notes in loss_params, lambda_chi only takes effect when `flexible`
# is enabled, and lambda_trans / lambda_rot only when `flexible_bb` is enabled.
flexible: false
flexible_bb: false

# Settings grouped under train_params: paths, batch and optimisation values,
# worker / GPU counts and dataset switches.
train_params:
  logdir: ./runs  # symlink to any location you like
  datadir: ./processed_crossdocked  # symlink to the dataset location
  enable_progress_bar: true
  num_sanity_val_steps: 0
  batch_size: 64
  accumulate_grad_batches: 2
  # Keep both the decimal point and the signed exponent: YAML 1.1 loaders such
  # as PyYAML read a spelling like `5e-5` as a string instead of a float.
  lr: 5.0e-5
  n_epochs: 500
  num_workers: 0
  gpus: 1
  clip_grad: true
  gnina: gnina  # path to gnina binary
  sample_from_clusters: false
  sharded_dataset: false

# Settings grouped under wandb_params.
wandb_params:
  mode: online  # disabled, offline, online
  # No entity is configured. An explicit null loads exactly like a bare
  # `entity:` but cannot be mistaken for an accidentally truncated line.
  entity: null
  group: crossdocked

# Loss weights and DPO-specific loss settings.
loss_params:
  discrete_loss: VLB  # VLB or CE
  lambda_x: 1.0
  lambda_h: 500
  dpo_lambda_h: 2500
  lambda_e: 500
  dpo_lambda_e: 2500
  lambda_chi: 0.5  # only effective if the top-level `flexible` is enabled
  lambda_trans: 1.0  # only effective if the top-level `flexible_bb` is enabled
  lambda_rot: 0.1  # only effective if the top-level `flexible_bb` is enabled
  lambda_clash: null
  # Options: null, or a string of the form sigmoid_a=?_b=?
  # (for example sigmoid_a=1_b=10).
  timestep_weights: null
  dpo_beta: 100.0
  dpo_beta_schedule: 't'
  dpo_lambda_w: 1.0
  dpo_lambda_l: 0.2
  clamp_dpo: false

# Settings grouped under simulation_params: step count, the priors used for
# h and e, and two prediction switches (both off here).
simulation_params:
  n_steps: 5000
  prior_h: marginal  # uniform, marginal
  prior_e: uniform  # uniform, marginal
  predict_final: false
  predict_confidence: false

# Settings grouped under eval_params: evaluation frequency, sample counts,
# sampling steps and visualisation options.
eval_params:
  eval_epochs: 4
  n_eval_samples: 1
  n_sampling_steps: 500
  eval_batch_size: 16
  visualize_sample_epoch: 1
  n_visualize_samples: 10
  visualize_chain_epoch: 1
  keep_frames: 100
  sample_with_ground_truth_size: true

# Settings grouped under predictor_params: graph construction, feature
# switches and the nested options for the selected backbone.
predictor_params:
  heterogeneous_graph: true
  backbone: gvp
  num_rbf_time: 16
  # A null cutoff is set for ligand edges, while pocket and interaction edges
  # use a numeric cutoff.
  edge_cutoff_ligand: null
  edge_cutoff_pocket: 10.0
  edge_cutoff_interaction: 10.0
  cycle_counts: true
  spectral_feat: false
  reflection_equivariant: false
  num_rbf: 16
  d_max: 15.0
  self_conditioning: true
  augment_residue_sc: false
  augment_ligand_sc: false
  normal_modes: false
  add_chi_as_feature: false
  angle_act_fn: null
  add_all_atom_diff: false

  # Options nested under gvp_params (backbone above is set to gvp).
  gvp_params:
    n_layers: 5
    node_h_dim: [128, 32]  # (s, V)
    edge_h_dim: [128, 32]
    dropout: 0.0
    vector_gate: true
