---
# W&B sweep definition: Bayesian search over trainer.py hyperparameters.
program: trainer.py
method: bayes

# ${args_no_hyphens} is the W&B macro that passes each parameter as
# key=value (no leading "--"); presumably trainer.py parses these as
# dotted config overrides -- confirm.
command:
  - 'python'
  - 'trainer.py'
  - '${args_no_hyphens}'

# Objective of the search.
metric:
  name: val_epitope_f1
  goal: maximize

parameters:
  # fixed run settings
  mode:
    value: val
  wandb.project:
    value: "m3epi_v3"
  dataset.split.method:
    value: "epitope_ratio"
  dataset.tensor:
    value: "epiformer_dataset.pkl"
  num_threads:
    value: 3
  resume:
    value: false
  model.enable_pretraining:
    value: false
  model.graph_type:
    value: "raad"
  callbacks.early_stopping.patience:
    value: 15

  # antibody encoder architecture
  model.ab_encoder.resmp_enabled:
    value: true
  model.ab_encoder.resmp_type:
    value: "egnn"
  model.ab_encoder.edgemp_enabled:
    value: false
  model.ab_encoder.atommp_enabled:
    value: false
  model.ab_encoder.residue_layers:
    value: 4
  model.ab_encoder.feature_fusion_type:
    value: "gated"

  # antigen encoder architecture
  model.ag_encoder.resmp_enabled:
    value: true
  model.ag_encoder.resmp_type:
    value: "egnn"
  model.ag_encoder.edgemp_enabled:
    value: false
  model.ag_encoder.atommp_enabled:
    value: false
  model.ag_encoder.residue_layers:
    value: 4
  model.ag_encoder.feature_fusion_type:
    value: "concat"

  # decoder configuration
  model.decoder.type:
    value: "cross_attention"
  model.decoder.sampling_strat:
    value: "top_k_mean"
  model.decoder.decoder_layers:
    value: 3
  model.decoder.d_ff:
    value: 512
  model.decoder.n_heads:
    value: 8
  model.decoder.predict_distances:
    value: false

  # normalization and dropout
  model.use_layer_norm:
    values: [true, false]
  model.dropout:
    value: 0.24
  model.dropout_rates.res_mp:
    value: 0.25
  model.dropout_rates.decoder:
    value: 0.025
  model.dropout_rates.projections:
    value: 0.025

  # loss configuration
  # Loss weights whose range spans two orders of magnitude are sampled
  # with log_uniform_values (as epi_pos_weight and the commented-out
  # edge/contrastive weights are); uniform sampling over such a range
  # puts ~91% of trials in the top decade.

  # node prediction
  loss.node_prediction.enabled:
    value: true
  loss.node_prediction.name:
    value: "bce"
  loss.node_prediction.task:
    value: "epi_only"
  loss.node_prediction.weight:
    value: 1.0
    # distribution: uniform
    # min: 0.3
    # max: 5.0
  loss.node_prediction.bce_weight:
    distribution: log_uniform_values
    min: 0.1
    max: 10.0
  loss.node_prediction.epi_pos_weight:
    distribution: log_uniform_values
    min: 5.0
    max: 100.0

  # node prediction regularizers
  loss.node_prediction.count_regularizer_enabled:
    value: true
  # NOTE(review): the enable flag above is under loss.node_prediction.*,
  # but the settings below are under loss.count_regularizer.* -- confirm
  # trainer.py reads both paths, otherwise these overrides are ignored.
  loss.count_regularizer.per_graph_matching:
    value: true
  loss.count_regularizer.epitope_weight:
    distribution: log_uniform_values
    min: 0.01
    max: 1.0
  # loss.count_regularizer.paratope_weight:
  #   distribution: uniform
  #   min: 0.05
  #   max: 0.7
  loss.node_prediction.dice_enabled:
    value: true
  loss.node_prediction.dice_weight:
    distribution: log_uniform_values
    min: 0.01
    max: 1.0
  loss.node_prediction.smoothness_enabled:
    value: false
  # loss.node_prediction.smoothness_weight:
  #   distribution: uniform
  #   min: 0.01
  #   max: 0.7
  loss.node_prediction.edge_node_consistency_enabled:
    value: false
  # loss.node_prediction.consistency_weight:
  #   distribution: uniform
  #   min: 0.01
  #   max: 0.7

  # edge prediction
  loss.edge_prediction.enabled:
    value: false
  # loss.edge_prediction.weight:
  #   distribution: log_uniform_values
  #   min: 0.005
  #   max: 0.5
  # loss.edge_prediction.pos_weight:
  #   distribution: log_uniform_values
  #   min: 5.0
  #   max: 50

  # contrastive
  loss.contrastive.enabled:
    value: false
  # loss.contrastive.name:
  #   value: "infonce"
  # loss.contrastive.weight:
  #   distribution: log_uniform_values
  #   min: 0.005
  #   max: 0.5
  # loss.contrastive.temperature:
  #   distribution: uniform
  #   min: 0.01
  #   max: 0.7
  # loss.contrastive.inter_weight:
  #   distribution: uniform
  #   min: 0.01
  #   max: 0.7
  # loss.contrastive.intra_weight:
  #   distribution: uniform
  #   min: 0.01
  #   max: 0.7

  # other losses (disabled)
  loss.walle.enabled:
    value: false
  loss.force.enabled:
    value: false

  # training hparams
  hparams.train.num_epochs:
    value: 130
  hparams.train.regularization.use_l2_reg:
    value: false
  hparams.train.learning_rate:
    value: 0.0003
  hparams.train.batch_size:
    value: 8
  hparams.train.weight_decay:
    value: 0.0001

# Maximum number of runs to launch for this sweep.
count: 50

# Hyperband early termination on the sweep metric.
# W&B requires `s` (number of brackets) whenever `max_iter` is given.
# With max_iter=10, eta=3, s=2 the bracket checkpoints fall at about
# 10/3^2 ~= 1.1 and 10/3 ~= 3.3 logged values of the metric.
# NOTE(review): presumably val_epitope_f1 is logged once per epoch --
# confirm. If so, runs can be cut after ~1 epoch of a 130-epoch
# schedule; raise max_iter (or switch to min_iter) if that is too
# aggressive.
early_terminate:
  type: hyperband
  max_iter: 10
  s: 2
  eta: 3