# Data section: names the datamodule (`module_type`) and lists its `args`.
datamodule:
  module_type: "MultitaskFromSmilesDataModule"
  args: # Matches that in the test_multitask_datamodule.py case.
    # One entry per task, keyed by task name. The three tasks below use the
    # same set of keys; only `df_path` and `label_cols` differ between them.
    task_specific_args:   # To be replaced by a new class "DatasetParams"
      SA:
        df: null
        task_level: "graph"
        df_path: "graphium/data/multitask/tiny_ZINC_SA.csv"
        smiles_col: "SMILES"
        label_cols: ["SA"]
        split_val: 0.2
        split_test: 0.2
        seed: 19
        splits_path: null                 # This may not always be provided
        sample_size: null                 # This may not always be provided
        idx_col: null                     # This may not always be provided
        weights_col: null                 # This may not always be provided
        # Listed explicitly so this task mirrors `logp` and `score`.
        # NOTE(review): assumes the loader's default for a missing `weights_type` is null - confirm.
        weights_type: null
      logp:
        df: null
        task_level: "graph"
        df_path: "graphium/data/multitask/tiny_ZINC_logp.csv"
        smiles_col: "SMILES"
        label_cols: ["logp"]
        split_val: 0.2
        split_test: 0.2
        seed: 19
        splits_path: null
        sample_size: null
        idx_col: null
        weights_col: null
        weights_type: null
      score:
        df: null
        task_level: "graph"
        df_path: "graphium/data/multitask/tiny_ZINC_score.csv"
        smiles_col: "SMILES"
        label_cols: ["score"]
        split_val: 0.2
        split_test: 0.2
        seed: 19
        splits_path: null
        sample_size: null
        idx_col: null
        weights_col: null
        weights_type: null

    # Featurization
    # Booleans are written in canonical lowercase form (`true` / `false`).
    featurization_n_jobs: 16
    featurization_progress: true
    featurization:
      atom_property_list_onehot: ["atomic-number", "degree"]
      atom_property_list_float: []
      edge_property_list: ["in-ring", "bond-type-onehot"]
      add_self_loop: false
      explicit_H: false
      use_bonds_weights: false

    # Data handling-related
    batch_size_training: 16
    batch_size_inference: 16

architecture:     # The parameters for the full graph network are taken from `config_micro_ZINC.yaml`
  model_type: "FullGraphMultiTaskNetwork"

  # Set `pre_nn` to null to avoid a pre-nn network.
  # The `dropout` and `normalization` values are anchored here and reused by
  # the sections below through the `*dropout` / `*normalization` aliases.
  pre_nn:
    depth: 1
    hidden_dims: 32
    out_dim: 32
    activation: "relu"
    last_activation: "none"
    normalization: &normalization "batch_norm"
    last_normalization: *normalization
    dropout: &dropout 0.1
    residual_type: "none"

  # Set `pre_nn_edges` to null to avoid a pre-nn network.
  pre_nn_edges:
    depth: 2
    hidden_dims: 16
    out_dim: 16
    activation: "relu"
    last_activation: "none"
    normalization: *normalization
    last_normalization: *normalization
    dropout: *dropout
    residual_type: "none"

  # Graph neural network section; `layer_type` names the layer to use and
  # `layer_kwargs` holds the options given alongside it.
  # NOTE(review): the earlier comment here said "Set as null to avoid a
  # post-nn network", which looks copy-pasted - confirm whether null is valid.
  gnn:
    depth: 4
    hidden_dims: 32
    out_dim: 32
    activation: "relu"
    last_activation: "none"
    normalization: *normalization
    last_normalization: *normalization
    dropout: *dropout
    residual_type: "random"
    virtual_node: "sum"
    layer_type: "pyg:pna-msgpass"
    layer_kwargs:
      # num_heads: 3
      aggregators:
        - "mean"
        - "max"
      scalers:
        - "identity"
        - "amplification"
        - "attenuation"

  # One output network per level: node, graph, edge and nodepair.
  # Every level uses the string "none" (not null) for `last_normalization`.
  graph_output_nn:
    node:
      depth: 2
      hidden_dims: 32
      out_dim: 16
      activation: "relu"
      last_activation: "none"
      normalization: *normalization
      last_normalization: "none"
      dropout: *dropout
      residual_type: "none"
    graph:
      pooling:
        - "sum"
        - "max"
      depth: 2
      hidden_dims: 32
      out_dim: 1
      activation: "relu"
      last_activation: "none"
      normalization: *normalization
      last_normalization: "none"
      dropout: *dropout
      residual_type: "none"
    edge:
      depth: 2
      hidden_dims: 32
      out_dim: 16
      activation: "relu"
      last_activation: "none"
      normalization: *normalization
      last_normalization: "none"
      dropout: *dropout
      residual_type: "none"
    nodepair:
      depth: 2
      hidden_dims: 32
      out_dim: 16
      activation: "relu"
      last_activation: "none"
      normalization: *normalization
      last_normalization: "none"
      dropout: *dropout
      residual_type: "none"

  # Set `task_heads` to null to avoid task heads. Each entry below is keyed by
  # task name and carries the arguments for one task head (TaskHeadParams).
  task_heads:
    task_1:
      task_level: "node"
      hidden_dims: [5, 6, 7]
      # depth: none                         # Not needed if we have hidden_dims
      out_dim: 5
      activation: "relu"
      last_activation: "none"
      normalization: "none"
      dropout: 0.2
      residual_type: "none"
    task_2:
      task_level: "edge"
      hidden_dims: [8, 9, 10]
      out_dim: 3
      activation: "relu"
      last_activation: "none"
      normalization: "none"
      dropout: 0.2
      residual_type: "none"
    task_3:
      task_level: "graph"
      hidden_dims: [2, 2, 2]
      out_dim: 4
      activation: "relu"
      last_activation: "none"
      normalization: "none"
      dropout: 0.2
      residual_type: "none"
    task_4:
      task_level: "nodepair"
      hidden_dims: [2, 2, 2]
      out_dim: 4
      activation: "relu"
      last_activation: "none"
      normalization: "none"
      dropout: 0.2
      residual_type: "none"
# Accelerator section: this configuration sets the type to "cpu".
accelerator:
  type: "cpu"