# Testing the MPNN-only model (16 layers) on the PCQM4M_G25_N4 dataset. The accelerator below is set to GPU.
# Shared values; `&name` and `&seed` are referenced through aliases in later sections.
constants:
  name: &name pcqm4mv2_mpnn_16layer
  seed: &seed 42
  # Whether the code should raise an error if it crashes during training
  raise_train_error: true

# Accelerator choice, plus the values listed under `config_override`.
accelerator:
  # cpu or ipu or gpu
  type: gpu
  config_override:
    datamodule:
      args:
        batch_size_training: 64
        batch_size_inference: 256
    trainer:
      trainer:
        precision: 32


datamodule:
  module_type: "MultitaskFromSmilesDataModule"
  # module_type: "FakeDataModule"  # Option to use generated data
  # Matches that in the test_multitask_datamodule.py case.
  args:
    # To be replaced by a new class "DatasetParams"
    task_specific_args:
      # homolumo:
      #   df: null
      #   task_level: "graph"
      #   df_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/pcqm4mv2-20k.csv
      #   # wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv
      #   # or set path as https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv directly
      #   smiles_col: "cxsmiles"
      #   label_cols: ["homo_lumo_gap"]
      #   # sample_size: 6000 # use sample_size for test
      #   # splits_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/split_dict_v2.pt  # Download with `wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt`
      #   # # graphium/data/PCQM4Mv2/split_dict.pt
      #   # # graphium/data/PCQM4Mv2/pcqm4m_split.csv
      #   # split_names: ["train", "valid", "test"]
      #   split_val: 0.1
      #   split_test: 0.1

      pcqm4m_g25:
        df: null
        # Download with
        # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
        # or set path as the URL directly
        df_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/PCQM4M_G25_N4.parquet
        smiles_col: "ordered_smiles"
        # graph_* means all columns starting with "graph_"
        label_cols: "graph_*"
        # sample_size: 2000 # use sample_size for test
        task_level: "graph"
        # Download with
        # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt
        splits_path: /XXXX-3/XXXX-4/datasets/graphium/data/large-dataset/pcqm4m_g25_n4_random_splits.pt
        # split_names: [train, val, test_seen]
        label_normalization:
          normalize_val_test: true
          method: "normal"
        epoch_sampling_fraction: 1.0


    prepare_dict_or_graph: "pyg:graph"
    featurization_n_jobs: 30
    featurization_progress: true
    featurization_backend: "loky"
    processed_graph_data_path: /XXXX-3/XXXX-4/datasets/graphium/data/featurized_data
    dataloading_from: disk
    # -1 to use all
    num_workers: 4
    # Whether to use persistent workers at the start of each epoch.
    # Using persistent_workers false might make the start of each epoch very long.
    persistent_workers: false
    featurization:
      # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
      # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
      # 'num_chiral_centers (not included yet)']
      atom_property_list_onehot:
        - atomic-number
        - group
        - period
        - total-valence
      atom_property_list_float:
        - degree
        - formal-charge
        - radical-electron
        - aromatic
        - in-ring
      # OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
      edge_property_list:
        - bond-type-onehot
        - stereo
        - in-ring
      add_self_loop: false
      # if H is included
      explicit_H: false
      use_bonds_weights: false
      # encoder dropout 0.18
      pos_encoding_as_features:
        pos_types:
          lap_eigvec:
            pos_level: node
            pos_type: laplacian_eigvec
            num_pos: 8
            # normalization already applied on the eigen vectors
            normalization: "none"
            # if eigen values/vector for disconnected graph are included
            disconnected_comp: true
          lap_eigval:
            pos_level: node
            pos_type: laplacian_eigval
            num_pos: 8
            # normalization already applied on the eigen vectors
            normalization: "none"
            # if eigen values/vector for disconnected graph are included
            disconnected_comp: true
          # use same name as pe_encoder
          rw_pos:
            pos_level: node
            pos_type: rw_return_probs
            ksteps: 16

# Network definition: pre-processing MLPs, positional-encoding encoders,
# the GNN stack, the graph-level output network and the per-task heads.
architecture:
  model_type: FullGraphMultiTaskNetwork
  mup_base_path: null
  # Set as null to avoid a pre-nn network
  pre_nn:
    out_dim: 128
    hidden_dims: 256
    depth: 2
    activation: relu
    last_activation: none
    # `&dropout` and `&normalization` are reused through aliases below.
    dropout: &dropout 0.1
    normalization: &normalization layer_norm
    last_normalization: *normalization
    residual_type: none

  # Set as null to avoid a pre-nn network on the edges
  pre_nn_edges:
    out_dim: 64
    hidden_dims: 128
    depth: 2
    activation: relu
    last_activation: none
    dropout: *dropout
    normalization: *normalization
    last_normalization: *normalization
    residual_type: none

  pe_encoders:
    out_dim: 32
    # "sum", "mean" or "max"
    pool: "sum"
    # Quoted on purpose: an unquoted None is loaded by YAML as the string
    # "None" (not as null), so the quotes keep the value and make it explicit.
    # Other options noted here: "batch_norm", "layer_norm"
    last_norm: "None"
    # la_pos | rw_pos
    encoders:
      # Laplacian positional-encoding encoder
      la_pos:
        encoder_type: "laplacian_pe"
        input_keys: ["laplacian_eigvec", "laplacian_eigval"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        # "Transformer" or "DeepSet"
        model_type: "DeepSet"
        num_layers: 2
        # Num. layers to apply after pooling
        num_layers_post: 1
        dropout: 0.1
        # "batch_norm" or "layer_norm"
        first_normalization: "none"
      rw_pos:
        encoder_type: "mlp"
        input_keys: ["rw_return_probs"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        num_layers: 2
        dropout: 0.1
        # "batch_norm" or "layer_norm"
        normalization: "layer_norm"
        # "batch_norm" or "layer_norm"
        first_normalization: "layer_norm"


  # NOTE(review): the previous comment here read "Set as null to avoid a
  # post-nn network", which looks copy-pasted; presumably null would disable
  # the GNN itself - TODO confirm.
  gnn:
    out_dim: 512
    hidden_dims: 128
    depth: 16
    activation: gelu
    last_activation: none
    dropout: 0.0
    normalization: "layer_norm"
    last_normalization: *normalization
    residual_type: simple
    virtual_node: "none"
    # Options noted here: pyg:gated-gcn, pyg:gine, pyg:gps
    layer_type: "pyg:gps"
    # Parameters for the model itself. You could define dropout_attn: 0.1
    layer_kwargs:
      node_residual: false
      mpnn_type: "pyg:mpnnplus"
      mpnn_kwargs:
        in_dim: 128
        out_dim: 128
        in_dim_edges: 64
        out_dim_edges: 64
      # "full-attention" or "none"
      attn_type: "none"
      # biased_attention: false
      attn_kwargs: null

  graph_output_nn:
    graph:
      pooling: [sum]
      out_dim: 384
      hidden_dims: 512
      depth: 2
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
  post_nn: null

  task_heads:
    pcqm4m_g25:
      task_level: graph
      out_dim: 25
      hidden_dims: 256
      depth: 2
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    # homolumo:
    #   task_level: graph
    #   out_dim: 25
    #   hidden_dims: 256
    #   depth: 2                          # Not needed if we have hidden_dims
    #   activation: relu
    #   last_activation: none
    #   dropout: *dropout
    #   normalization: *normalization
    #   last_normalization: "none"
    #   residual_type: none

# Task-specific
# Loss, optimizer and learning-rate schedule settings for the predictor.
predictor:
  # metrics_on_progress_bar:
  #   homolumo: ["mae", "pearsonr"]
  # loss_fun:
  #   homolumo: mse_ipu
  metrics_on_progress_bar:
    pcqm4m_g25: []
  metrics_on_training_set:
    pcqm4m_g25: []
  loss_fun:
    pcqm4m_g25: mae_ipu
  random_seed: *seed
  optim_kwargs:
    # warmup can be scheduled using torch_scheduler_kwargs
    # NOTE(review): keep the dot in `1.e-4`; YAML 1.1 loaders such as PyYAML
    # read a dot-less `1e-4` as a string - confirm which loader is used.
    lr: 1.e-4
    # weight_decay: 1.e-7
  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    # `&max_epochs` is also referenced by the trainer section.
    max_num_epochs: &max_epochs 40
    warmup_epochs: 10
    verbose: false
  # Explicit null: every scheduler option below is currently commented out.
  scheduler_kwargs: null
  #  monitor: &monitor homolumo/mae/train
  #  mode: min
  #  frequency: 1
  # null: no mask, 0: 0 mask, ignore: ignore nan values from loss
  target_nan_mask: null
  flag_kwargs:
    # 1
    n_steps: 0
    # 0.01
    alpha: 0.0

# Task-specific
# Metrics per task; each entry gives a display name and the metric to use.
metrics:
  # homolumo:
  #   - name: mae
  #     metric: mae_ipu
  #     target_nan_mask: null
  #     multitask_handling: flatten
  #     threshold_kwargs: null
  #   - name: pearsonr
  #     metric: pearsonr_ipu
  #     threshold_kwargs: null
  #     target_nan_mask: null
  #     multitask_handling: mean-per-label
  pcqm4m_g25:
    - name: mae
      metric: mae_ipu
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: pearsonr
      metric: pearsonr_ipu
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: r2
      metric: r2_score_ipu
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label

# Logging, checkpointing and training-loop settings.
trainer:
  logger:
    save_dir: /XXXX-3/XXXX-4/XXXX-2/cell_painting/results/mpnn/logs/
    name: *name
    project: PCQMv2_mpnn
  # early_stopping:
  #   monitor: *monitor
  #   min_delta: 0
  #   patience: 10
  #   mode: &mode min
  model_checkpoint:
    dirpath: /XXXX-3/XXXX-4/XXXX-2/cell_painting/results/mpnn/models/
    filename: *name
    # monitor: *monitor
    # mode: *mode
    save_top_k: -1
    every_n_epochs: 1
    save_last: true
  trainer:
    precision: 32
    # Alias of the value set at predictor.torch_scheduler_kwargs.max_num_epochs
    max_epochs: *max_epochs
    min_epochs: 1
    accumulate_grad_batches: 2
    check_val_every_n_epoch: 1
