# Name of this experiment configuration.
experiment_name: 'NTXentMultiplePositives_620000'

# NOTE(review): presumably the epochs after which a model snapshot is kept — confirm against the trainer.
models_to_save: [5, 10, 20, 35, 50, 65, 80, 100, 125, 150, 200, 300, 400]
# Dataset, training-loop and loss settings.
dataset: qmugs
num_epochs: 1000
batch_size: 500
log_iterations: 50
# NOTE(review): separate from lr_scheduler_params.patience (25); presumably an early-stopping patience — confirm.
patience: 35
loss_func: NTXentMultiplePositives
# Parameters for the loss named in loss_func.
loss_params:
  # NOTE(review): presumably the temperature of the NT-Xent loss — confirm.
  tau: 0.1
# Same number as the suffix of experiment_name.
num_train: 620000

# Data entries requested for this run.
required_data:
  - dgl_graph
  - conformations
# Metrics to compute.
metrics:
  - positive_similarity
  - negative_similarity
  - contrastive_accuracy
  - true_negative_rate
  - true_positive_rate
  - uniformity
  - alignment
  - batch_variance
  - dimension_covariance
# NOTE(review): 'loss' is not in the metrics list above; presumably the loss is always tracked — confirm.
main_metric: loss
collate_function: conformer_collate
# NOTE(review): presumably the number of conformers used per molecule by conformer_collate — confirm.
num_conformers: 3


# Optimizer selection and its parameters (lr = learning rate).
optimizer: Adam
optimizer_params:
  lr: 8.0e-5


# Learning-rate schedule: a WarmUpWrapper around the scheduler named in wrapped_scheduler.
# Booleans are written in canonical lowercase form (true/false).
scheduler_step_per_batch: false
lr_scheduler: WarmUpWrapper
lr_scheduler_params:
  warmup_steps: [700]
  # parameters of scheduler to run after warmup
  # NOTE(review): the names below match torch.optim.lr_scheduler.ReduceLROnPlateau arguments;
  # presumably they are forwarded to it unchanged — confirm in WarmUpWrapper.
  wrapped_scheduler: ReduceLROnPlateau
  cooldown: 20
  factor: 0.6
  patience: 25
  min_lr: 1.0e-6
  threshold: 1.0e-4
  mode: 'min'
  verbose: true



# Parameters of the model selected by model_type
model_type: 'PNA'
model_parameters:
  target_dim: 256
  hidden_dim: 200
  mid_batch_norm: true
  last_batch_norm: true
  readout_batchnorm: true
  # batch_norm_momentum = e^(ln(forgetfulness ~0.001) / number of steps per epoch)
  #   e.g. with 100 steps per epoch: e^(ln(0.001) / 100) ~= 0.933
  # NOTE(review): if an epoch is num_train / batch_size = 620000 / 500 = 1240 steps,
  # the formula gives ~= 0.994 — confirm that 0.93 is the intended value.
  batch_norm_momentum: 0.93
  readout_hidden_dim: 200
  readout_layers: 2
  dropout: 0.0
  propagation_depth: 7
  aggregators:
    - mean
    - max
    - min
    - std
  scalers:
    - identity
    - amplification
    - attenuation
  readout_aggregators:
    - min
    - max
    - mean
  pretrans_layers: 2
  posttrans_layers: 1
  residual: true

# Parameters of the 3D model selected by model3d_type
model3d_type: 'Net3D'
model3d_parameters:
  target_dim: 256
  hidden_dim: 20
  hidden_edge_dim: 20
  node_wise_output_layers: 0
  message_net_layers: 1
  update_net_layers: 1
  reduce_func: 'mean'
  fourier_encodings: 4
  propagation_depth: 1
  dropout: 0.0
  batch_norm: true
  readout_batchnorm: true
  # batch_norm_momentum = e^(ln(forgetfulness ~0.001) / number of steps per epoch)
  #   e.g. with 100 steps per epoch: e^(ln(0.001) / 100) ~= 0.933
  # NOTE(review): if an epoch is num_train / batch_size = 620000 / 500 = 1240 steps,
  # the formula gives ~= 0.994 — confirm that 0.93 is the intended value.
  batch_norm_momentum: 0.93
  readout_hidden_dim: 20
  readout_layers: 1
  readout_aggregators:
    - min
    - max
    - mean





# continue training from checkpoint:
#checkpoint: runs/PNAReadout_2_layer_03-04_15-29-07/last_checkpoint.pt