# Dataset selection, labeling options and input-feature selection.
# NOTE(review): the code that reads these keys is not part of this file; notes marked
# "presumably" are inferred from key names and should be confirmed against the loader.
data:
  data_name: plbind
  data_dir: ../data  # relative path; presumably resolved against the working directory -- confirm
  neg_aug_p: 0.1
  use_rdkit_coords: true
  # Pocket labels are enabled and contact labels are disabled in this configuration.
  use_pocket_label: true
  pocket_cutoff: 15  # either 15 or 13 is good
  use_contact_label: false
  nearby_res_cutoff: 0
  bin_thres: 10

  # One of only_pos, only_x or both_x_pos; used by FeatEncoder.
  # Equivariant models will use pos by default.
  feature_type: only_x
  # The *_lig keys mirror the two keys that follow them (presumably the ligand-side
  # counterparts -- confirm).
  n_categorical_feat_to_use_lig: 1  # in {-1, 0, 1}, -1 means all categorical features
  n_scalar_feat_to_use_lig: -1  # in {-1, 0}, -1 means all scalar features
  n_categorical_feat_to_use: 1  # in {-1, 0, 1}, -1 means all categorical features
  n_scalar_feat_to_use: -1  # in {-1, 0}, -1 means all scalar features
  use_lig_info: true

# Logging and reporting options.
logging:
  tensorboard: false
  # Cutoffs 20/40/60; the reasoning behind these numbers is given in the comment that follows.
  topk: [20, 40, 60]
  # If we assume that
  #   1/2 of the labeled amino acids are critical, then each sample has on average 60 critical amino acids;
  #   1/3 of the labeled amino acids are critical, then each sample has on average 40 critical amino acids;
  #   1/6 of the labeled amino acids are critical, then each sample has on average 20 critical amino acids.

# Optimizer settings shared across methods.
# The wp_*, attn_* and emb_* keys come in lr/wd pairs (presumably learning rate and
# weight decay for three separate parameter groups -- confirm). All three groups
# currently use the same values.
optimizer:
  batch_size: 128
  wp_lr: 1.0e-3
  wp_wd: 1.0e-5
  attn_lr: 1.0e-3
  attn_wd: 1.0e-5
  emb_lr: 1.0e-3
  emb_wd: 1.0e-5
  do_adamw: false  # presumably toggles AdamW versus the default optimizer -- confirm

# Backbone model settings, one sub-section per backbone (egnn, pointtrans, dgcnn).
# All three sub-sections currently hold identical values; keep them in sync
# deliberately or diverge them deliberately.
model:
  egnn:
    n_layers: 4
    hidden_size: 64
    dropout_p: 0.2
    norm_type: batch
    act_type: relu
    pool: add
  pointtrans:
    n_layers: 4
    hidden_size: 64
    dropout_p: 0.2
    norm_type: batch
    act_type: relu
    pool: add
  dgcnn:
    n_layers: 4
    hidden_size: 64
    dropout_p: 0.2
    norm_type: batch
    act_type: relu
    pool: add


# The default hyperparameters below are tuned for dgcnn.
# Where the tuned value for another backbone model differs, it is given in a trailing comment on the same line.

# Hyperparameters for the lri_bern method.
# Trailing "pointtrans:" / "egnn:" comments give the value to use for that backbone
# instead of the value written here.
lri_bern:
  epochs: 50
  warmup: 50  # same value as epochs in this section
  final_r: 0.5
  info_loss_coef: 0.01         # pointtrans: 1.0, egnn: 0.1
  one_encoder: false
  attn_constraint: smooth_min

  temperature: 1.0
  # NOTE(review): presumably r moves from init_r towards final_r, changing by decay_r
  # every decay_interval epochs -- confirm against the training code.
  decay_interval: 10
  decay_r: 0.1
  init_r: 0.9
  pred_loss_coef: 1.0
  pred_lr: 1.0e-3
  pred_wd: 1.0e-5
  dropout_p: 0.2
  norm_type: batch
  act_type: relu

# Hyperparameters for the lri_gaussian method.
# Trailing "pointtrans:" / "egnn:" comments give the value to use for that backbone
# instead of the value written here.
lri_gaussian:
  epochs: 50
  warmup: 50  # same value as epochs in this section
  pos_coef: 1.2            # pointtrans: 0.9
  info_loss_coef: 1.0      # pointtrans: 0.1, egnn: 0.01
  kr: 7
  one_encoder: false
  attn_constraint: smooth_min

  covar_dim: 3  # NOTE(review): meaning not documented here; pointmask uses covar_dim: 2 -- confirm
  pred_loss_coef: 1.0
  pred_lr: 1.0e-3
  pred_wd: 1.0e-5
  dropout_p: 0.2
  norm_type: batch
  act_type: relu

# Settings for the gradcam method.
# A single epoch after a 100-epoch warmup; pred_lr and pred_wd are both 0 here
# (presumably the predictor is not updated outside warmup -- confirm).
gradcam:
  epochs: 1
  warmup: 100
  pred_lr: 0
  pred_wd: 0

# Settings for the gradgeo method.
# Same epochs/warmup/pred_* values as the gradcam section, plus a gradgeo flag.
gradgeo:
  epochs: 1
  warmup: 100
  gradgeo: true  # flag with the same name as this section
  pred_lr: 0
  pred_wd: 0

# Hyperparameters for the bernmask_p method.
# Trailing "pointtrans:" / "egnn:" comments give the value to use for that backbone
# instead of the value written here.
bernmask_p:
  epochs: 50
  warmup: 100
  size_loss_coef: 0.1             # egnn: 0.01
  mask_ent_loss_coef: 1.0         # pointtrans: 0.1, egnn: 0.1

  temp: [1.0, 1.0]  # two-element list; presumably a start/end temperature pair -- confirm
  pred_loss_coef: 1.0
  # pred_lr and pred_wd are both 0 in this section.
  pred_lr: 0
  pred_wd: 0
  dropout_p: 0.2
  norm_type: batch
  act_type: relu

# Hyperparameters for the bernmask method.
# Trailing "pointtrans:" / "egnn:" comments give the value to use for that backbone
# instead of the value written here.
bernmask:
  epochs: 1
  warmup: 100
  size_loss_coef: 0.01            # pointtrans: 1.0, egnn: 0.1
  mask_ent_loss_coef: 0.01        # egnn: 0.1
  iter_lr: 1.0e-1

  # NOTE(review): presumably the number of optimization iterations run per sample,
  # using iter_lr above -- confirm.
  iter_per_sample: 500
  pred_loss_coef: 1.0
  # pred_lr and pred_wd are both 0 in this section.
  pred_lr: 0
  pred_wd: 0
  dropout_p: 0.2
  norm_type: batch
  act_type: relu

# Hyperparameters for the pointmask method.
# Trailing "egnn:" comments give the value to use for that backbone instead of the
# value written here.
pointmask:
  epochs: 100
  warmup: 0  # the only method section in this file with no warmup
  t: 0.2                          # egnn: 0.8
  kl_loss_coef: 0.01              # egnn: 0.1

  one_encoder: false
  covar_dim: 2  # MLP output 2 dims for mu and sigma
  pred_loss_coef: 1.0
  pred_lr: 1.0e-3
  pred_wd: 1.0e-5
  dropout_p: 0.2
  norm_type: batch
  act_type: relu
