# Telegram Bot
# telegram_config_file: telegram_config.yml


# Dataset and Splits
storage_folder: DATA
dataset_class: dataset.REDDIT_BINARY
data_splits_file: DATA_SPLITS/REDDIT_BINARY/REDDIT_BINARY_outer10_inner1.splits


# Hardware
device: cuda
max_cpus: 64
# Should match the number of GPU indices listed in gpus_subset below.
max_gpus: 2
gpus_per_task: 0.5
# Comma-separated GPU indices; quoted so YAML always reads it as a string.
gpus_subset: '1,2'


# Data Loading
dataset_getter: mlwiz.data.provider.DataProvider
data_loader:
  class_name: torch_geometric.loader.DataLoader
  args:
    num_workers: 1
    pin_memory: true


# Reproducibility
seed: 42

# Experiment
result_folder: RESULTS/GRID/DGN/
exp_name: dgn_fixed
experiment: mlwiz.experiment.Experiment
higher_results_are_better: true  # classification: a higher score is better
evaluate_every: 1
model_selection_training_runs: 1
risk_assessment_training_runs: 10


grid:
  # NOTE(review): list-valued entries presumably enumerate the alternatives
  # explored by the grid search; confirm against the MLWiz docs.
  model: model.DGN
  checkpoint: true
  shuffle: true
  batch_size:
    - 32
    - 128
  epochs: 1000

  # Model specific arguments #

  # PATCH: counts the MLP layers plus the graph conv layers, assuming each
  # conv layer uses an MLP with a single hidden layer.
  num_hidden_layers:
    - 2
    - 3
    - 5

  global_pooling:
    - mean

  num_hidden_neurons:
    - 32
    - 64

  # ------------------------ #

  # Optimizer
  optimizer:
    - class_name: mlwiz.training.callback.optimizer.Optimizer
      args:
        optimizer_class_name: torch.optim.Adam
        lr: 0.01

  # Scheduler (optional)
  scheduler: null

  # Loss metric
  loss: mlwiz.training.callback.metric.MulticlassClassification

  # Score metric (with an example of Multi Score)
  scorer:
    - class_name: mlwiz.training.callback.metric.MultiScore
      args:
        main_scorer: mlwiz.training.callback.metric.MulticlassAccuracy

  # Training engine
  engine:
    - class_name: mlwiz.training.engine.TrainingEngine
      args:
        eval_training: true

  # Gradient clipper (optional)
  gradient_clipper: null

  # Early stopper (optional): "patience" early stopping on the validation score
  early_stopper:
    - class_name:
        - mlwiz.training.callback.early_stopping.PatienceEarlyStopper
      args:
        patience:
          - 500
        # SYNTAX: (train_|validation_)<name of the scorer or loss to monitor>;
        # main_loss / main_score select the main loss / main scorer.
        monitor: validation_main_score
        mode: max  # whether the best monitored value is the `max` or the `min`
        checkpoint: true  # store the best checkpoint

  # Plotter of metrics
  plotter:
    - class_name: mlwiz.training.callback.plotter.Plotter
      args:
        store_on_disk: true  # store evolution of metrics over time
