# Dataset and Splits
# Selects the dataset interface class, the dataset name, and the splits file.
# presumably `data_root` is the folder under which datasets are stored -- confirm
data_root: DATA
dataset_class: pydgn.data.dataset.PlanetoidDatasetInterface
dataset_name: cora
# NOTE(review): the file name suggests 1 outer fold and 1 inner fold -- confirm.
data_splits_file: DATA_SPLITS/cora/cora_outer1_inner1.splits


# Hardware
# Target device and CPU/GPU budget for the run.
# presumably `gpus_per_task` is the number of GPUs given to each task -- confirm
device: cuda
max_cpus: 8
max_gpus: 1
gpus_per_task: 1


# Data Loading
# Data provider class plus the loader class and its arguments.
dataset_getter: provider.SingleGraphNodeDataProvider
data_loader:
  class_name: torch_geometric.loader.DataLoader
  # presumably forwarded as keyword arguments to `class_name` -- confirm
  args:
    num_workers: 1
    pin_memory: true


# Reproducibility
# Fixed seed; presumably used to seed the random number generators -- confirm.
seed: 42


# Experiment
# Output location, experiment name/class, and evaluation settings.
# NOTE(review): `exp_name` and `result_folder` say "unsupervised" while the
# experiment class is SupervisedTask -- presumably intentional; confirm.
result_folder: GSPN_RESULTS/UNSUPERVISED
exp_name: unsupervised_embedding_generation_bernoulli
experiment: pydgn.experiment.supervised_task.SupervisedTask
higher_results_are_better: true
evaluate_every: 1
final_training_runs: 5

# Hyper-parameter grid.
# NOTE(review): list-valued entries (e.g. `num_layers`) are presumably expanded
# into alternative configurations by the consumer -- confirm.
grid:
  supervised_config:
    model: model_mask.GSPN
    checkpoint: true
    shuffle: true
    batch_size: 32  # does not matter with single graph
    epochs: 1000

    # Model specific arguments #

    convolution_class: model_mask.GSPNBaseLogConv
    emission_class: model_mask.GSPNBernoulliEmission
    weight_features:
      - false
      # - true  (alternative currently disabled)
    avg_parameters_across_layers:
      - false
      - true
    num_mixtures:
      - 4
      - 8
    num_hidden_neurons: 0
    num_layers:
      - 2
      - 3
      - 10

    # ------------------------ #

    # Optimizer
    optimizer:
      - class_name: pydgn.training.callback.optimizer.Optimizer
        args:
          optimizer_class_name: torch.optim.Adam
          lr:
            - 0.1
          # do/do not accumulate gradient across mini-batches
          accumulate_gradients: false

    # Scheduler (none configured)
    scheduler: null

    # Loss metric
    loss:
      # Do not consider masked nodes in the optimization
      - class_name: metric.GSPNNodeLogLikelihood
        args:
          use_nodes_batch_size: true

    # Score metric: one main scorer plus additional scorers on masked nodes
    scorer:
      - class_name: pydgn.training.callback.metric.MultiScore
        args:
          main_scorer:
            - class_name: metric.GSPNNodeLogLikelihood
              args:
                use_nodes_batch_size: true
          masked_micro_f1:
            - class_name: metric.MaskedNodeMicroF1
              args:
                use_nodes_batch_size: true
          masked_macro_f1:
            - class_name: metric.MaskedNodeMacroF1
              args:
                use_nodes_batch_size: true
          masked_mse:
            - class_name: metric.MaskedNodeMSE
              args:
                use_nodes_batch_size: true
          masked_top10recall:
            - class_name: metric.MaskedTopKNodeRecall
              args:
                use_nodes_batch_size: true
                topK: 10
          # NOTE(review): "NGDC" is presumably a transposition of NDCG; the key
          # and class name are kept as-is because they must match the `metric`
          # module -- confirm before renaming.
          masked_top10NGDC:
            - class_name: metric.MaskedTopKNodeNGDC
              args:
                use_nodes_batch_size: true
                topK: 10
          masked_binary_ece:
            - class_name: metric.MaskedBinaryECE
              args:
                use_nodes_batch_size: true
          masked_binary_mce:
            - class_name: metric.MaskedBinaryMCE
              args:
                use_nodes_batch_size: true

    # Readout (optional)
    readout: null

    # Training engine
    engine: pydgn.training.engine.TrainingEngine

    # Gradient clipper (optional)
    gradient_clipper: null

    # Early stopper: monitors `validation_main_loss` in `min` mode with a
    # patience of 150
    early_stopper:
      - class_name: pydgn.training.callback.early_stopping.PatienceEarlyStopper
        args:
          patience: 150
          monitor: validation_main_loss
          mode: min
          checkpoint: true

    # Plotter of metrics
    plotter: pydgn.training.callback.plotter.Plotter
