---
# Weights & Biases sweep: Bayesian search over epiformer hyperparameters.
# Each trial runs `python trainer.py` with the sampled parameters appended
# as `key=value` overrides (the `${args_no_hyphens}` macro).
program: trainer.py
method: bayes
command:
  - python
  - trainer.py
  - ${args_no_hyphens}
# Objective the Bayesian search optimizes; it has to be a metric that
# trainer.py logs to W&B.
metric:
  name: val_epitope_f1
  goal: maximize

parameters:
  # --- fixed run settings (single `value`, not searched) ---
  mode: {value: 'val'}
  wandb.project: {value: 'm3epi_v3'}
  dataset.split.method: {value: 'epitope_group'}
  dataset.graph_type: {value: 'raad-plm'}
  dataset.plm_type: {value: 'esm2_650m'}
  dataset.graph_num_relations: {value: 4}
  num_threads: {value: 3}
  resume: {value: false}
  model.name: {value: 'epiformer'}
  model.enable_pretraining: {value: false}
  callbacks.early_stopping.patience: {value: 10}

  # --- encoder (epiformer) architecture ---
  model.epiformer.residue_layers: {values: [3, 4, 5]}
  model.epiformer.residue_dim: {value: 128}
  model.epiformer.residue_hidden_dim: {value: 128}
  model.epiformer.plm_dim: {value: 128}
  model.epiformer.n_heads: {value: 8}
  model.epiformer.use_layer_norm: {value: true}
  model.epiformer.use_gradient_checkpointing: {value: false}
  model.epiformer.ag_feature_fusion_type: {value: 'concat'}
  model.epiformer.ab_feature_fusion_type: {value: 'gated'}
  model.epiformer.activation: {value: 'silu'}
  # Dropout rates are sampled log-uniformly in [0.05, 0.5]; the decoder rate
  # uses the same range as the encoder rate.
  model.epiformer.dropout:
    distribution: log_uniform_values
    min: 0.05
    max: 0.5
  model.dropout_rates.decoder:
    distribution: log_uniform_values
    min: 0.05
    max: 0.5


  # --- decoder ---
  model.decoder.type: {value: 'cross_attention'}
  model.decoder.num_rbf: {value: 16}
  model.decoder.d_model: {value: 128}
  model.decoder.d_k: {value: 64}
  model.decoder.d_ff: {value: 128}
  model.decoder.n_heads: {value: 8}
  model.decoder.decoder_layers: {values: [2, 3, 4]}
  model.decoder.sampling_strat: {value: 'top_k_mean_2'}


  # --- prediction thresholds and shared model flags ---
  model.epi_threshold:
    value: 0.3
  model.para_threshold:
    value: 0.3
  model.use_layer_norm:
    value: true
  # Early-stopping patience is configured once, with the run-level settings
  # at the top of `parameters`; repeating the key here would be a duplicate
  # mapping key (invalid YAML, silently last-wins in PyYAML).


  # ===== loss configuration =====

  # --- edge prediction loss ---
  loss.edge_prediction.enabled: {value: true}
  loss.edge_prediction.weight: {value: 1.0}
  loss.edge_prediction.pos_weight:
    distribution: log_uniform_values
    min: 30
    max: 150

  # --- node prediction loss ---
  loss.node_prediction.enabled: {value: true}
  loss.node_prediction.name: {value: 'bce'}
  loss.node_prediction.task: {value: 'epi_only'}
  loss.node_prediction.weight:
    distribution: uniform
    min: 0.05
    max: 0.5
  loss.node_prediction.bce_weight:
    distribution: uniform
    min: 2
    max: 10
  loss.node_prediction.epi_pos_weight:
    distribution: log_uniform_values
    min: 10
    max: 60

  # --- node prediction regularizers ---
  # NOTE(review): the count-regularizer enable flag is under
  # `loss.node_prediction.*` but its settings are under
  # `loss.count_regularizer.*`; confirm trainer.py reads both prefixes.
  loss.node_prediction.count_regularizer_enabled: {value: true}
  loss.count_regularizer.per_graph_matching: {value: true}
  loss.count_regularizer.epitope_weight:
    distribution: uniform
    min: 0.05
    max: 1.0
  loss.node_prediction.dice_enabled: {value: true}
  loss.node_prediction.dice_weight:
    distribution: uniform
    min: 0.1
    max: 3.0

  # --- auxiliary distance prediction ---
  model.decoder.predict_distances: {value: true}
  loss.auxiliary_distance.enabled: {value: true}
  loss.auxiliary_distance.weight:
    distribution: uniform
    min: 0.05
    max: 0.3
  loss.auxiliary_distance.distance_weighting: {value: true}
  loss.auxiliary_distance.class_balancing: {value: true}
  loss.auxiliary_distance.max_distance: {value: 32.0}

  # --- node prediction terms left disabled in this sweep ---
  loss.node_prediction.smoothness_enabled: {value: false}
  loss.node_prediction.edge_node_consistency_enabled: {value: false}



  # --- contrastive and other auxiliary losses (all disabled) ---
  loss.contrastive.enabled:
    value: false
  # Search space to restore if the contrastive loss is re-enabled:
  # loss.contrastive.name:
  #   value: "infonce"
  # loss.contrastive.weight:
  #   distribution: log_uniform_values
  #   min: 0.005
  #   max: 0.5
  # loss.contrastive.temperature:
  #   distribution: uniform
  #   min: 0.01
  #   max: 0.7
  # loss.contrastive.inter_weight:
  #   distribution: uniform
  #   min: 0.01
  #   max: 0.7
  # loss.contrastive.intra_weight:
  #   distribution: uniform
  #   min: 0.01
  #   max: 0.7
  loss.walle.enabled:
    value: false
  loss.force.enabled:
    value: false


  # --- training hyperparameters ---
  hparams.train.num_epochs:
    value: 130
  hparams.train.regularization.use_l2_reg:
    value: false
  # Learning rate and weight decay are sampled log-uniformly over one decade.
  # Bounds are written as plain decimals on purpose: YAML 1.1 loaders such as
  # PyYAML read an exponent form like `1e-5` (no dot) as a string.
  hparams.train.learning_rate:
    distribution: log_uniform_values
    min: 0.00001
    max: 0.0001
  hparams.train.batch_size:
    value: 8
  hparams.train.weight_decay:
    distribution: log_uniform_values
    min: 0.00001
    max: 0.0001

# NOTE(review): `count` is not a W&B sweep-configuration key (the server-side
# run cap is `run_cap`). Presumably a launcher script reads it and passes it to
# `wandb agent --count` / `wandb.agent(count=...)` — confirm.
count: 80
# Hyperband early termination. When `max_iter` is used, W&B also needs `s`,
# the number of brackets. With the default eta=3, max_iter=10 and s=2 give
# brackets at 10/3 -> 3 and 10/9 -> 1 logged iterations of the target metric.
# NOTE(review): an iteration is one logged value of `val_epitope_f1`. If that
# is logged once per epoch, runs can be stopped after 1-3 of the 130 epochs;
# consider a larger max_iter if that is too aggressive.
early_terminate:
  type: hyperband
  max_iter: 10
  s: 2