---
# W&B sweep: Bayesian search over the hyperparameters of trainer.py.
program: trainer.py
method: bayes

# Launch line for each trial. ${args_no_hyphens} is a W&B command macro that
# is replaced by the sampled parameters written without leading hyphens.
command:
  - python
  - trainer.py
  - ${args_no_hyphens}

# Objective that the Bayesian optimizer tries to maximize.
metric:
  name: val_epitope_f1
  goal: maximize

parameters:
  # ---- run-level settings ----
  # Everything in this section is pinned except early-stopping patience.
  mode:
    value: val
  wandb.project:
    value: 'm3epi_v3'
  dataset.split.method:
    value: 'epitope_ratio'
  dataset.tensor:
    value: 'epiformer_dataset.pkl'
  num_threads:
    value: 3
  resume:
    value: false
  model.enable_pretraining:
    value: false
  model.graph_type:
    value: 'raad'
  callbacks.early_stopping.patience:
    values: [20, 40]

  # ---- antibody encoder ----
  # Only resmp is enabled (type egnn); edgemp and atommp are switched off.
  # Depth and feature-fusion type are searched.
  model.ab_encoder.resmp_enabled:
    value: true
  model.ab_encoder.resmp_type:
    value: 'egnn'
  model.ab_encoder.edgemp_enabled:
    value: false
  model.ab_encoder.atommp_enabled:
    value: false
  model.ab_encoder.residue_layers:
    values: [3, 4]
  model.ab_encoder.feature_fusion_type:
    values: ['concat', 'gated']

  # ---- antigen encoder ----
  # Mirrors the antibody encoder: resmp (egnn) on, edgemp/atommp off,
  # depth and feature-fusion type searched.
  model.ag_encoder.resmp_enabled:
    value: true
  model.ag_encoder.resmp_type:
    value: 'egnn'
  model.ag_encoder.edgemp_enabled:
    value: false
  model.ag_encoder.atommp_enabled:
    value: false
  model.ag_encoder.residue_layers:
    values: [3, 4]
  model.ag_encoder.feature_fusion_type:
    values: ['concat', 'gated']

  # ---- decoder ----
  # Decoder type, sampling strategy, depth and FFN width are searched;
  # head count and distance prediction are pinned.
  model.decoder.type:
    values: ['dot_product', 'cross_attention']
  model.decoder.sampling_strat:
    values: ['max_row', 'top_k_mean']
  model.decoder.decoder_layers:
    values: [3, 4]
  model.decoder.d_ff:
    values: [256, 512]
  model.decoder.n_heads:
    value: 8
  model.decoder.predict_distances:
    value: false

  # ---- normalization and dropout ----
  # Every dropout rate below is drawn uniformly from [0.01, 0.3].
  model.use_layer_norm:
    values: [false, true]
  model.dropout:
    distribution: uniform
    min: 0.01
    max: 0.3
  model.dropout_rates.res_mp:
    distribution: uniform
    min: 0.01
    max: 0.3
  model.dropout_rates.decoder:
    distribution: uniform
    min: 0.01
    max: 0.3
  model.dropout_rates.projections:
    distribution: uniform
    min: 0.01
    max: 0.3


  # loss configurations

  # node prediction: always enabled, BCE loss with task "epi_only".
  loss.node_prediction.enabled:
    value: true
  loss.node_prediction.name:
    value: "bce"
  loss.node_prediction.task:
    value: "epi_only"
  loss.node_prediction.weight:
    distribution: uniform
    min: 0.3
    max: 5.0
  loss.node_prediction.bce_weight:
    distribution: uniform
    min: 0.1
    max: 5.0
  # Positive-class weight spans more than an order of magnitude, so it is
  # sampled log-uniformly.
  loss.node_prediction.epi_pos_weight:
    distribution: log_uniform_values
    min: 5.0
    max: 200.0

  # node prediction regularizers
  # NOTE(review): the count-regularizer switch lives under
  # loss.node_prediction, while its settings live under loss.count_regularizer.
  # Confirm that both paths exist in the trainer's config schema.
  loss.node_prediction.count_regularizer_enabled:
    values: [true, false]
  loss.count_regularizer.per_graph_matching:
    value: true
  loss.count_regularizer.epitope_weight:
    distribution: uniform
    min: 0.01
    max: 0.7
  loss.count_regularizer.paratope_weight:
    distribution: uniform
    min: 0.05
    max: 0.7
  loss.node_prediction.dice_enabled:
    values: [true, false]
  loss.node_prediction.dice_weight:
    distribution: uniform
    min: 0.01
    max: 0.7
  loss.node_prediction.smoothness_enabled:
    values: [true, false]
  loss.node_prediction.smoothness_weight:
    distribution: uniform
    min: 0.01
    max: 0.7
  loss.node_prediction.edge_node_consistency_enabled:
    value: false
  # NOTE(review): consistency_weight is still searched, although the flag
  # above is pinned to false. If the trainer ignores the weight while the flag
  # is off, this is a dead dimension for the Bayesian search. Consider
  # pinning it or dropping it from the sweep.
  loss.node_prediction.consistency_weight:
    distribution: uniform
    min: 0.01
    max: 0.7

  # edge prediction (toggled per trial)
  # Loss weight and positive-class weight are both sampled log-uniformly.
  # Bounds are written as floats, consistent with every other continuous range.
  loss.edge_prediction.enabled:
    values: [true, false]
  loss.edge_prediction.weight:
    distribution: log_uniform_values
    min: 0.005
    max: 0.5
  loss.edge_prediction.pos_weight:
    distribution: log_uniform_values
    min: 5.0
    max: 50.0

  # ---- contrastive loss (InfoNCE, toggled per trial) ----
  # The overall weight is searched on a log scale; temperature and the
  # inter/intra weights are searched uniformly in [0.01, 0.7].
  loss.contrastive.enabled:
    values: [false, true]
  loss.contrastive.name:
    value: 'infonce'
  loss.contrastive.weight:
    distribution: log_uniform_values
    min: 0.005
    max: 0.5
  loss.contrastive.temperature:
    distribution: uniform
    min: 0.01
    max: 0.7
  loss.contrastive.inter_weight:
    distribution: uniform
    min: 0.01
    max: 0.7
  loss.contrastive.intra_weight:
    distribution: uniform
    min: 0.01
    max: 0.7

  # Auxiliary losses pinned off for this sweep.
  loss.walle.enabled:
    value: false

  loss.force.enabled:
    value: false

  # ---- optimizer and training schedule ----
  # Learning rate and weight decay are sampled log-uniformly; epoch count and
  # batch size are picked from small discrete sets; L2 regularization is off.
  hparams.train.num_epochs:
    values: [100, 130]
  hparams.train.regularization.use_l2_reg:
    value: false
  hparams.train.learning_rate:
    distribution: log_uniform_values
    min: 0.0001
    max: 0.0005
  hparams.train.batch_size:
    values: [8, 16]
  hparams.train.weight_decay:
    distribution: log_uniform_values
    min: 0.0001
    max: 0.01



# Run budget. A W&B sweep config caps the number of runs with `run_cap`;
# `count` is an agent-side option (`wandb agent --count`). `count` is kept
# unchanged in case a local launcher reads it from this file.
count: 150
run_cap: 150

# Hyperband early termination. When `max_iter` is used, W&B also requires
# `s` (the number of brackets); `eta` defaults to 3.
# NOTE(review): iterations are counted in logged steps of the metric. With
# max_iter: 10, the brackets fall only a few steps into a 100-130 epoch run,
# well before the 20-40 early-stopping patience. Consider `min_iter` instead.
early_terminate:
  type: hyperband
  max_iter: 10
  s: 2