# @package _global_

# Top-level constants section for this experiment configuration.
constants:
  seed: 42
  raise_train_error: true   # Whether the code should raise an error if it crashes during training
  entity: multitask-gnn
  # NOTE(review): this is the same string as datamodule.args.processed_graph_data_path;
  # presumably the two are meant to stay in sync — confirm.
  datacache_path: "/XXXX-3/XXXX-4/XXXX-2/.cache/graphium/data"
  name: "test"

# Accelerator selection and the overrides attached to it.
accelerator:
  type: gpu  # cpu or ipu or gpu
  # NOTE(review): the keys below mirror the `datamodule.args` and `trainer.trainer`
  # sections of this file; presumably they override those sections for the selected
  # accelerator — confirm against the config loader.
  config_override:
    datamodule:
      args:
        batch_size_training: 64
        batch_size_inference: 256
    trainer:
      trainer:
        precision: 32
        accumulate_grad_batches: 1

# Data module: one entry per task under `task_specific_args`, followed by the
# featurization and data-loading settings shared by all tasks.
datamodule:
  module_type: "MultitaskFromSmilesDataModule"
  # module_type: "FakeDataModule"  # Option to use generated data
  # Matches that in the test_multitask_datamodule.py case.
  args:
    # To be replaced by a new class "DatasetParams"
    task_specific_args:
      l1000_vcap:
        df_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv
        # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
        # or set path as the URL directly
        # NOTE(review): the URL above names LINCS_L1000_VCAP_0-4.csv.gz while df_path
        # points at LINCS_L1000_VCAP_0-2_th2.csv — confirm which file is intended.
        smiles_col: "SMILES"
        # geneID-* means all columns starting with "geneID-"
        label_cols: "geneID-*"
        # sample_size: 2000 # use sample_size for test
        task_level: graph
        # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
        splits_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/l1000_vcap_random_splits.pt
        epoch_sampling_fraction: 1.0

      l1000_mcf7:
        df: null
        df_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv
        # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
        # or set path as the URL directly
        # NOTE(review): the URL above names LINCS_L1000_MCF7_0-4.csv.gz while df_path
        # points at LINCS_L1000_MCF7_0-2_th2.csv — confirm which file is intended.
        smiles_col: "SMILES"
        # geneID-* means all columns starting with "geneID-"
        label_cols: "geneID-*"
        # sample_size: 2000 # use sample_size for test
        task_level: graph
        # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
        splits_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/l1000_mcf7_random_splits.pt
        epoch_sampling_fraction: 1.0

      pcba_1328:
        df: null
        df_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/PCBA_1328_1564k.parquet
        # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
        # or set path as the URL directly
        smiles_col: "SMILES"
        # assayID-* means all columns starting with "assayID-"
        label_cols: "assayID-*"
        # use sample_size for test
        sample_size: 300000
        task_level: graph
        # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
        splits_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/pcba_1328_random_splits.pt
        epoch_sampling_fraction: 1.0

      homolumo:
        df: null
        task_level: graph
        df_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/pcqm4mv2-20k.csv
        # wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv
        # or set path as https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv directly
        smiles_col: "cxsmiles"
        label_cols: ["homo_lumo_gap"]
        # sample_size: 6000 # use sample_size for test
        # Download with `wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt`
        splits_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/split_dict_v2.pt
        # graphium/data/PCQM4Mv2/split_dict.pt
        # graphium/data/PCQM4Mv2/pcqm4m_split.csv
        split_names: ["train", "valid", "test-dev"]

      # pcqm4m_g25:
      #   df: null
      #   df_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/PCQM4M_G25_N4.parquet
      #   # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
      #   # or set path as the URL directly
      #   smiles_col: "ordered_smiles"
      #   label_cols: graph_*  # graph_* means all columns starting with "graph_"
      #   # sample_size: 10000 # use sample_size for test
      #   task_level: graph
      #   splits_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/pcqm4m_g25_n4_random_splits.pt  # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
      #   label_normalization:
      #     normalize_val_test: True
      #     method: "normal"
      #   epoch_sampling_fraction: 1.0

      # pcqm4m_n4:
      #   df: null
      #   df_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/PCQM4M_G25_N4.parquet
      #   # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
      #   # or set path as the URL directly
      #   smiles_col: "ordered_smiles"
      #   label_cols: node_* # node_* means all columns starting with "node_"
      #   # sample_size: 2000 # use sample_size for test
      #   task_level: node
      #   splits_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/pcqm4m_g25_n4_random_splits.pt  # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
      #   seed: 42
      #   label_normalization:
      #     normalize_val_test: True
      #     method: "normal"
      #   epoch_sampling_fraction: 1.0

    # Featurization
    prepare_dict_or_graph: "pyg:graph"
    featurization_n_jobs: 48
    featurization_progress: true
    featurization_backend: "loky"
    dataloading_from: disk
    processed_graph_data_path: "/XXXX-3/XXXX-4/XXXX-2/.cache/graphium/data"
    featurization:
      # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
      # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
      # 'num_chiral_centers (not included yet)']
      atom_property_list_onehot: [atomic-number, group, period, total-valence]
      atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
      # OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
      edge_property_list: [bond-type-onehot, stereo, in-ring]
      add_self_loop: false
      # if H is included
      explicit_H: false
      use_bonds_weights: false
      # encoder dropout 0.18
      pos_encoding_as_features:
        pos_types:
          lap_eigvec:
            pos_level: node
            pos_type: laplacian_eigvec
            num_pos: 8
            # normalization already applied on the eigen vectors
            normalization: "none"
            # if eigen values/vector for disconnected graph are included
            disconnected_comp: true
          lap_eigval:
            pos_level: node
            pos_type: laplacian_eigval
            num_pos: 8
            # normalization already applied on the eigen vectors
            normalization: "none"
            # if eigen values/vector for disconnected graph are included
            disconnected_comp: true
          # use same name as pe_encoder
          rw_pos:
            pos_level: node
            pos_type: rw_return_probs
            ksteps: 16

    # -1 to use all
    num_workers: 48
    # Whether to use persistent workers at the start of each epoch.
    # Using persistent_workers false might make the start of each epoch very long.
    persistent_workers: true


# Network architecture: pre-networks for node and edge features, positional-encoding
# encoders, the GNN stack, the graph-level output network and one head per task.
architecture:
  model_type: FullGraphMultiTaskNetwork
  mup_base_path: null
  # Set as null to avoid a pre-nn network
  pre_nn:
    out_dim: 256
    hidden_dims: 1024
    depth: 2
    activation: relu
    last_activation: none
    # The &dropout and &normalization anchors defined on the next two lines are
    # reused by alias further down in this section.
    dropout: &dropout 0.18
    normalization: &normalization layer_norm
    last_normalization: *normalization
    residual_type: none

  # Set as null to avoid a pre-nn network
  pre_nn_edges:
    out_dim: 128
    hidden_dims: 512
    depth: 2
    activation: relu
    last_activation: none
    dropout: *dropout
    normalization: *normalization
    last_normalization: *normalization
    residual_type: none

  pe_encoders:
    out_dim: 32
    pool: "sum"  # "mean" "max"
    # NOTE(review): YAML loads the bare `None` below as the string "None", not as a
    # null (a YAML null is spelled `null`) — confirm the consumer treats this string
    # as "no normalization".
    last_norm: None  # "batch_norm", "layer_norm"
    # la_pos | rw_pos
    encoders:
      # Original comment: "Set as null to avoid a pre-nn network".
      # NOTE(review): that wording looks copied from `pre_nn` — confirm what a null
      # value means for this encoder.
      la_pos:
        encoder_type: "laplacian_pe"
        input_keys: ["laplacian_eigvec", "laplacian_eigval"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        model_type: "DeepSet"  # "Transformer" or "DeepSet"
        num_layers: 2
        num_layers_post: 1  # Num. layers to apply after pooling
        dropout: 0.1
        first_normalization: "none"  # "batch_norm" or "layer_norm"
      rw_pos:
        encoder_type: "mlp"
        input_keys: ["rw_return_probs"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        num_layers: 2
        dropout: 0.1
        normalization: "layer_norm"  # "batch_norm" or "layer_norm"
        first_normalization: "layer_norm"  # "batch_norm" or "layer_norm"

  # Original comment: "Set as null to avoid a post-nn network".
  # NOTE(review): the "post-nn" wording does not obviously describe the GNN stack — confirm.
  gnn:
    out_dim: 256
    hidden_dims: 256
    depth: 4
    activation: gelu
    last_activation: none
    dropout: 0.1
    normalization: "layer_norm"
    last_normalization: *normalization
    residual_type: simple
    virtual_node: "none"
    layer_type: "pyg:gps"  # alternatives listed by the author: pyg:gated-gcn, pyg:gine, pyg:gps
    # Parameters for the model itself. You could define dropout_attn: 0.1
    layer_kwargs:
      node_residual: false
      mpnn_type: "pyg:mpnnplus"
      mpnn_kwargs:
        in_dim: 256
        out_dim: 256
        in_dim_edges: 128
        out_dim_edges: 128
      attn_type: "none"  # "full-attention", "none"
      # biased_attention: false
      attn_kwargs: null

  graph_output_nn:
    graph:
      pooling: [sum]
      out_dim: 256
      hidden_dims: 256
      depth: 1
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none

  # One head per task; the head names match the keys of datamodule.args.task_specific_args.
  task_heads:
    l1000_vcap:
      task_level: graph
      out_dim: 2934
      hidden_dims: 128
      depth: 2
      activation: none
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    l1000_mcf7:
      task_level: graph
      out_dim: 2934
      hidden_dims: 128
      depth: 2
      activation: none
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    pcba_1328:
      task_level: graph
      out_dim: 1328
      hidden_dims: 64
      depth: 2
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    homolumo:
      task_level: graph
      out_dim: 1
      hidden_dims: 256
      # Not needed if we have hidden_dims
      depth: 2
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    # pcqm4m_g25:
    #   task_level: graph
    #   out_dim: 25
    #   hidden_dims: 32
    #   depth: 2
    #   activation: relu
    #   last_activation: none
    #   dropout: *dropout
    #   normalization: *normalization
    #   last_normalization: "none"
    #   residual_type: none
    # pcqm4m_n4:
    #   task_level: node
    #   out_dim: 4
    #   hidden_dims: 32
    #   depth: 2
    #   activation: relu
    #   last_activation: none
    #   dropout: *dropout
    #   normalization: *normalization
    #   last_normalization: "none"
    #   residual_type: none

# Task-specific
# Predictor: per-task progress-bar and training-set metrics, per-task losses,
# optimizer and learning-rate scheduler settings.
predictor:
  metrics_on_progress_bar:
    l1000_vcap: []
    l1000_mcf7: []
    pcba_1328: []
    # pcqm4m_g25: []
    # pcqm4m_n4: []
    homolumo: []
  metrics_on_training_set:
    l1000_vcap: []
    l1000_mcf7: []
    pcba_1328: []
    # pcqm4m_g25: []
    # pcqm4m_n4: []
    homolumo: ["pearsonr"]
  loss_fun:
    l1000_vcap:
      name: hybrid_ce_ipu
      n_brackets: 3
      alpha: 0.5
    l1000_mcf7:
      name: hybrid_ce_ipu
      n_brackets: 3
      alpha: 0.5
    pcba_1328: bce_logits_ipu
    homolumo: mse

    # pcqm4m_g25: mae
    # pcqm4m_n4: mae_ipu
  random_seed: 42
  optim_kwargs:
    # warmup can be scheduled using torch_scheduler_kwargs
    # Keep the dot in `1.e-4`: YAML 1.1 loaders such as PyYAML only resolve
    # exponent notation as a float when a dot is present.
    lr: 1.e-4
    # weight_decay: 1.e-7
  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    # NOTE(review): the &max_epochs anchor is never aliased in this file, and
    # trainer.trainer.max_epochs is set to a different literal value — confirm the
    # scheduler horizon and the trainer epoch count are meant to differ.
    max_num_epochs: &max_epochs 100
    warmup_epochs: 10
    verbose: false
  # Explicit null; the commented-out keys below are the options left by the author.
  scheduler_kwargs: null
  #  monitor: &monitor qm9/mae/train
  #  mode: min
  #  frequency: 1
  # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
  target_nan_mask: null
  # flatten, mean-per-label
  multitask_handling: flatten

# Task-specific
# Evaluation metrics, keyed by task name.
metrics:
  # The &classif_metrics anchor lets l1000_mcf7 reuse this exact list by alias.
  l1000_vcap: &classif_metrics
    - name: auroc
      metric: auroc
      num_classes: 3
      task: multiclass
      target_to_int: true
      target_nan_mask: -1000
      ignore_index: -1000
      multitask_handling: mean-per-label
      threshold_kwargs: null
    - name: avpr
      metric: averageprecision
      num_classes: 3
      task: multiclass
      target_to_int: true
      target_nan_mask: -1000
      ignore_index: -1000
      multitask_handling: mean-per-label
      threshold_kwargs: null
  l1000_mcf7: *classif_metrics
  pcba_1328:
    # use auroc and averageprecision (non_ipu version) so that NaNs are handled correctly
    - name: auroc
      metric: auroc
      task: binary
      multitask_handling: mean-per-label
      target_nan_mask: ignore
      threshold_kwargs: null
    - name: avpr
      metric: averageprecision
      task: binary
      multitask_handling: mean-per-label
      target_nan_mask: ignore
      threshold_kwargs: null
  # pcqm4m_g25: &pcqm_metrics
  #   - name: mae
  #     metric: mae_ipu
  #     target_nan_mask: null
  #     multitask_handling: mean-per-label
  #     threshold_kwargs: null
  #   - name: pearsonr
  #     metric: pearsonr_ipu
  #     threshold_kwargs: null
  #     target_nan_mask: null
  #     multitask_handling: mean-per-label
  #   - name: r2
  #     metric: r2_score_ipu
  #     threshold_kwargs: null
  #     target_nan_mask: null
  #     multitask_handling: mean-per-label
  # pcqm4m_n4: *pcqm_metrics
  # NOTE(review): the `_ipu` metric variants are used below while accelerator.type
  # is `gpu` — confirm this is intended.
  homolumo:
    - name: mae
      metric: mae_ipu
      target_nan_mask: null
      multitask_handling: flatten
      threshold_kwargs: null
    - name: pearsonr
      metric: pearsonr_ipu
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label

# Trainer: seed, logger output, checkpointing and the epoch schedule.
trainer:
  seed: 42
  logger:
    save_dir: /XXXX-3/XXXX-4/XXXX-2/cell_painting/results/largemix/logs/
    name: test
    project: test
  model_checkpoint:
    dirpath: /XXXX-3/XXXX-4/XXXX-2/cell_painting/results/largemix/models/test/
    filename: test
    # monitor: *monitor
    # mode: *mode
    # save_top_k: 1
    save_last: true
  trainer:
    # NOTE(review): predictor.torch_scheduler_kwargs.max_num_epochs carries the
    # &max_epochs anchor with a different value, and that anchor is never aliased in
    # this file; presumably this key was meant to be `*max_epochs` — confirm.
    max_epochs: 50
    min_epochs: 1
    check_val_every_n_epoch: 20
