# Here, we are finetuning a FullGraphMultitaskNetwork
# trained on ToyMix. We finetune from the zinc task-head
# (graph-level) on the TDC dataset lipophilicity_astrazeneca

# Here are the changes to the architecture:
#   Change zinc task-head:
#     depth:        2 -> 2 - 1 + 2 = 3
#     out_dim:      3 -> 8
#
#   Add finetuning head
#     model_type:   mlp (FeedForwardNN)
#     out_dim:      1
#     hidden_dims:  8
#     depth:        2
#
# Finetuning training:
#   after 1 epoch, unfreeze all layers


###################################################
########### How to combine information  ###########
###################################################


###########################
### FINETUNING-SPECIFIC ###
###########################

# Finetuning setup: the new task, the pretrained module to start from, the
# changes applied to that module, the extra head appended after it, and the
# unfreezing schedule.
finetuning:
  # New task (name and the level it is predicted at)
  task: lipophilicity_astrazeneca
  level: graph

  # Pretrained model and the part of it that is finetuned:
  # the `zinc` sub-module of `task_heads` is the starting point and is
  # re-registered under the new task name.
  pretrained_model: dummy-pretrained-model
  finetuning_module: task_heads
  sub_module_from_pretrained: zinc # optional
  new_sub_module: lipophilicity_astrazeneca # optional
  # keep_modules_after_finetuning_module: # optional

  # Changes to finetuning_module.
  # Per the file header: depth 2 -> 2 - 1 + 2 = 3 and out_dim 3 -> 8.
  drop_depth: 1
  new_out_dim: 8
  added_depth: 2

  # Optional finetuning head appended to model after finetuning_module.
  # NOTE(review): presumably this key can be left empty (none) to skip the
  # extra head -- confirm against the loader.
  finetuning_head:
    task: lipophilicity_astrazeneca
    previous_module: task_heads
    incoming_level: graph
    model_type: mlp
    in_dim: 8 # same value as new_out_dim above
    out_dim: 1
    hidden_dims: 8
    depth: 2
    last_layer_is_readout: true

  # Finetuning training.
  # Per the file header: all layers are unfrozen after 1 epoch.
  # NOTE(review): unfreeze_pretrained_depth presumably counts pretrained
  # layers that are trainable from the start -- confirm.
  unfreeze_pretrained_depth: 0
  epoch_unfreeze_all: 1

# Shared values for this run.
constants:
  # Referenced through ${constants.seed} by the predictor, trainer and
  # datamodule sections.
  seed: 42
  # Epoch budget for this run; the scheduler (max_num_epochs) and the
  # trainer (max_epochs) sections use the same value.
  max_epochs: 2

# Hardware backend for this run.
accelerator:
  # NOTE(review): presumably forwarded to torch's float32 matmul precision
  # setting -- confirm against the loader.
  float32_matmul_precision: medium
  type: cpu

# Predictor settings: optimizer, learning-rate scheduler, loss and the
# metrics shown on the progress bar for the finetuning task.
predictor:
  random_seed: ${constants.seed}
  optim_kwargs:
    lr: 4.e-5  # the dot is kept so YAML 1.1 loaders read this as a float
  scheduler_kwargs: null
  target_nan_mask: null
  multitask_handling: flatten # flatten, mean-per-label

  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    # Tied to the shared epoch budget so the schedule length cannot drift
    # away from the value declared under `constants`.
    max_num_epochs: ${constants.max_epochs}
    warmup_epochs: 1
    verbose: false

  metrics_on_progress_bar:
    lipophilicity_astrazeneca: ["mae"]
  loss_fun:
    lipophilicity_astrazeneca: mae

# Metrics reported for the finetuning task. Every entry uses the same key
# order: name, metric, threshold_kwargs, target_nan_mask, multitask_handling.
metrics:
  lipophilicity_astrazeneca:
    # Error metric, computed on flattened targets.
    - name: mae
      metric: mae
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: flatten

    # Correlation / goodness-of-fit metrics, averaged per label.
    - name: spearman
      metric: spearmanr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label

    - name: pearson
      metric: pearsonr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label

    - name: r2_score
      metric: r2_score
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label

# Training-loop settings.
# NOTE(review): the nested `trainer` mapping is presumably forwarded as
# keyword arguments to the underlying trainer object -- confirm.
trainer:
  seed: ${constants.seed}
  trainer:
    precision: 32
    # Uses the shared epoch budget instead of repeating the literal value,
    # so the trainer and the LR scheduler stay in sync.
    max_epochs: ${constants.max_epochs}
    min_epochs: 1
    check_val_every_n_epoch: 1
    accumulate_grad_batches: 1
  
##################
### DATAMODULE ###
##################

# Data loading for the finetuning task (section taken from the finetuning
# setup): a TDC ADMET benchmark datamodule and its arguments.
datamodule:
  module_type: ADMETBenchmarkDataModule
  args:
    # TDC specific
    tdc_benchmark_names:
      - lipophilicity_astrazeneca
    tdc_train_val_seed: ${constants.seed}

    # Batch sizes
    batch_size_training: 200
    batch_size_inference: 200

    # Featurization
    prepare_dict_or_graph: "pyg:graph"
    featurization_n_jobs: 0
    featurization_progress: true
    featurization_backend: loky

    # Dataloader workers
    num_workers: 0
    persistent_workers: false




