# Telegram Bot
# telegram_config_file: telegram_config.yml

# Dataset and splits: root data folder, the dataset interface class, the
# dataset name, and the precomputed splits file.
data_root: DATA
dataset_class: pydgn.data.dataset.TUDatasetInterface
dataset_name: benzene
# NOTE(review): the file name suggests one outer and one inner fold
# (a hold-out setup) - confirm against the splits file itself.
data_splits_file: DATA_SPLITS/CHEMICAL/benzene/benzene_outer1_inner1.splits


# Hardware resources
device: cuda
max_cpus: 24
max_gpus: 1
# A fractional value multiplies GPU parallelism: 0.25 means 4x, so be sure
# to provide enough CPUs for the extra concurrent tasks.
gpus_per_task: 0.25


# Data loading: provider class plus the loader class and its arguments.
dataset_getter: pydgn.data.provider.DataProvider
data_loader:
  class_name: torch_geometric.loader.DataLoader
  # Presumably forwarded as keyword arguments to the loader class above.
  args:
    num_workers: 1
    pin_memory: true


# Reproducibility
# Fixed seed for the whole experiment; presumably used to seed the random
# number generators - confirm in the framework.
seed: 42


# Experiment: output location, run name, task class and selection criterion.
result_folder: GSPN_RESULTS/UNSUPERVISED
exp_name: unsupervised_embedding_regression_mlp_weaksup_0001
# The module/class name says "classification", but the name is ambiguous and
# should be ignored: exp_name and the supervised loss describe a regression.
experiment: unsupervised_embedding_classification.ClassificationTask
# The supervised main scorer is MeanAverageError, so lower values are better.
higher_results_are_better: false
evaluate_every: 1
final_training_runs: 10

# Hyper-parameter grid for both training stages.
# NOTE(review): list-valued entries below are presumably the candidate values
# explored during model selection - confirm against the PyDGN documentation.
grid:
  # Presumably the fraction of labels available to the supervised stage;
  # the "weaksup_0001" suffix of exp_name appears to mirror it - TODO confirm.
  weak_supervision_percentage: 0.001

  # Stage 1: unsupervised GSPN training.
  # NOTE(review): embeddings_folder suggests the learned embeddings are
  # written to disk for the supervised stage - confirm in the experiment code.
  unsupervised_config:
    model: model.GSPN
    checkpoint: true
    shuffle: true
    batch_size: 1024
    epochs: 100

    # ---- Model-specific arguments ----

    embeddings_folder: 'UNSUP_GSPN_EMBEDDINGS_BASIC/'

    convolution_class: model.GSPNBaseConv
    emission_class: model.GSPNMultiCategoricalEmission
    num_mixtures: [10, 5, 20]
    num_hidden_neurons: 0  # not used at the moment
    num_layers: [5, 10, 20]

    avg_parameters_across_layers: [true, false]

    # ----------------------------------

    # Optimizer
    optimizer:
      - class_name: pydgn.training.callback.optimizer.Optimizer
        args:
          optimizer_class_name: torch.optim.Adam
          lr: 0.1
          # Whether to accumulate gradients across mini-batches.
          accumulate_gradients: false

    # Scheduler (none configured)
    scheduler: null

    # Loss metric
    loss:
      - class_name: metric.GSPNNodeLogLikelihood

    # Score metric: the main score uses the same metric class as the loss.
    scorer:
      - class_name: pydgn.training.callback.metric.MultiScore
        args:
          main_scorer:
            - class_name: metric.GSPNNodeLogLikelihood

    # Readout (optional, none configured)
    readout: null

    # Training engine
    engine: pydgn.training.engine.TrainingEngine

    # Gradient clipper (optional, none configured)
    gradient_clipper: null

    # Early stopper: patience-based, monitoring the validation main score.
    # NOTE(review): mode is "min" although the score is named a
    # log-likelihood; since the same class is also the loss, it presumably
    # yields a quantity to minimise - confirm in metric.GSPNNodeLogLikelihood.
    early_stopper:
      - class_name:
          - pydgn.training.callback.early_stopping.PatienceEarlyStopper
        args:
          patience: [50]
          monitor: validation_main_score
          mode: min
          checkpoint: true

    # Plotter of metrics
    plotter: pydgn.training.callback.plotter.Plotter

  # Stage 2: supervised MLP readout with global pooling. It points at the
  # same embeddings_folder as the unsupervised stage.
  supervised_config:
    model: readout.MLPGraphClassifier_GlobalReadout
    checkpoint: true
    shuffle: true
    batch_size: 1024
    epochs: 1000

    # ---- Model-specific arguments ----

    embeddings_folder: 'UNSUP_GSPN_EMBEDDINGS_BASIC/'

    global_pooling: [sum, mean]

    hidden_units: [64, 8, 32]

    # ----------------------------------

    # Optimizer
    optimizer:
      - class_name: pydgn.training.callback.optimizer.Optimizer
        args:
          optimizer_class_name: torch.optim.Adam
          lr: [0.01]
          weight_decay: [0.0, 0.0001]

    # Scheduler (none configured)
    scheduler: null

    # Loss metric: mean-average-error, i.e. this stage is a regression.
    loss: pydgn.training.callback.metric.MeanAverageError

    # Score metric: the main score is the same error metric as the loss.
    scorer:
      - class_name: pydgn.training.callback.metric.MultiScore
        args:
          main_scorer: pydgn.training.callback.metric.MeanAverageError

    # Readout (optional, none configured)
    readout: null

    # Training engine
    engine: pydgn.training.engine.TrainingEngine

    # Gradient clipper (optional, none configured)
    gradient_clipper: null

    # Early stopper: patience-based, monitoring the validation loss
    # (lower is better, hence mode "min").
    early_stopper:
      - class_name: pydgn.training.callback.early_stopping.PatienceEarlyStopper
        args:
          patience: 500
          monitor: validation_main_loss
          mode: min
          checkpoint: true

    # Plotter of metrics
    plotter: pydgn.training.callback.plotter.Plotter
