# Here, we are finetuning a FullGraphMultitaskNetwork
# trained on ToyMix. We finetune from the gnn on the
# TDC dataset lipophilicity_astrazeneca

# Here are the changes to the architecture:
#   Change gnn:
#     depth:        4 -> 4 - 2 + 3 = 5
#
#   Keep modules after gnn and apply modifications
#     graph_output_nn/graph:
#       pooling: sum -> mean
#       depth: 1 -> 2
#     task_heads/zinc:
#       new_sub_module: lipophilicity_astrazeneca
#       out_dim: 3 -> 1
#
# Finetuning training:
#   unfreeze one additional layer of pretrained gnn
#   after 1 epoch, unfreeze all layers


###################################################
########### How to combine information  ###########
###################################################


###########################
### FINETUNING-SPECIFIC ###
###########################

finetuning:
  # Downstream task that the pretrained network is finetuned on
  task: lipophilicity_astrazeneca
  level: graph

  # Pretrained network and the module finetuning starts from
  pretrained_model: dummy-pretrained-model
  finetuning_module: gnn

  # Depth surgery on `finetuning_module`: drop 2 layers, then add 3
  drop_depth: 2
  added_depth: 3

  # Optional: modules located after `finetuning_module` that are kept,
  # each with the overrides to apply to it
  keep_modules_after_finetuning_module:
    graph_output_nn-graph:
      pooling:
        - mean
      depth: 2
    task_heads-zinc:
      new_sub_module: lipophilicity_astrazeneca
      out_dim: 1

  # Unfreezing schedule: one additional pretrained layer is unfrozen,
  # then all layers are unfrozen after 1 epoch
  unfreeze_pretrained_depth: 1
  epoch_unfreeze_all: 1

# Shared values that other sections read through `${constants.<key>}`
constants:
  # Referenced as ${constants.seed} by other sections of this file
  seed: 42
  # NOTE(review): presumably the epoch budget of the run; confirm
  # which sections are expected to consume it
  max_epochs: 2

# Compute settings for the run
accelerator:
  # NOTE(review): presumably forwarded to torch's float32 matmul
  # precision setting; confirm against the config loader
  float32_matmul_precision: medium
  # Accelerator kind selected for this run
  type: cpu

# Predictor settings: optimizer, LR scheduler, loss and progress-bar
# metrics for the finetuning task
predictor:
  random_seed: ${constants.seed}
  optim_kwargs:
    # Keep the decimal point: YAML 1.1 loaders only resolve exponent
    # notation as a float when the mantissa contains a "."
    lr: 4.e-5
  scheduler_kwargs: null
  target_nan_mask: null
  multitask_handling: flatten  # flatten, mean-per-label

  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    # Follow the shared epoch budget instead of repeating the literal,
    # so the scheduler horizon cannot drift from `constants.max_epochs`
    max_num_epochs: ${constants.max_epochs}
    warmup_epochs: 1
    verbose: false

  # Progress-bar metrics and loss function, keyed by task name
  metrics_on_progress_bar:
    lipophilicity_astrazeneca: ["mae"]
  loss_fun:
    lipophilicity_astrazeneca: mae

# Evaluation metrics, listed per task. Every entry uses the same key
# order: name, metric, threshold_kwargs, target_nan_mask,
# multitask_handling.
metrics:
  lipophilicity_astrazeneca:
    - name: mae
      metric: mae
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: flatten
    - name: spearman
      metric: spearmanr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: pearson
      metric: pearsonr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: r2_score
      metric: r2_score
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label

# Training-loop settings
trainer:
  seed: ${constants.seed}
  # NOTE(review): presumably the keyword arguments handed to the
  # underlying trainer object; confirm against the loader
  trainer:
    precision: 32
    # Follow the shared epoch budget instead of repeating the literal,
    # so changing `constants.max_epochs` actually changes the run length
    max_epochs: ${constants.max_epochs}
    min_epochs: 1
    check_val_every_n_epoch: 1
    accumulate_grad_batches: 1
  
##################
### DATAMODULE ###
##################

datamodule:
  ### FROM FINETUNING ###

  # Data module class and the arguments it is configured with
  module_type: ADMETBenchmarkDataModule
  args:
    # TDC specific
    tdc_benchmark_names:
      - lipophilicity_astrazeneca
    tdc_train_val_seed: ${constants.seed}

    # Batch sizes and worker counts
    batch_size_training: 200
    batch_size_inference: 200
    featurization_n_jobs: 0
    num_workers: 0

    # Featurization and data-loading options
    prepare_dict_or_graph: "pyg:graph"
    featurization_progress: true
    featurization_backend: loky
    persistent_workers: false