# @package _global_

# Example: hyperparameter optimization of an experiment with Optuna:
# python run.py -m hparams_search=mnist_optuna experiment=example_simple
# python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30
# python run.py -m hparams_search=mnist_optuna experiment=example_simple logger=wandb

# Hydra defaults list: other configs composed together with this file.
defaults:
  # Author's note lists `random` and `tpe` as the available choices here
  # (presumably the Optuna sampler preset -- confirm against the sibling configs).
  - tpe.yaml
  # A checkpoint callback is required: without one there is no best model path to use.
  - override /callbacks: checkpoint.yaml

# Metric that Optuna optimizes (alternative noted by the author: valid/acc).
optimized_metric: 'valid/micro_f1'

# Keep this enabled: if it is turned off, the saved checkpoints can fill up
# your filesystem.
remove_best_model_ckpt: true

# Number of runs whose metrics are averaged.
num_averaging: 10

# Logger output directories used specifically for Optuna sweeps.
logger:
  tensorboard:
    save_dir: '${project_dir}/logs_multi_tensorboard/'
  csv:
    save_dir: '${project_dir}/logs_multi_csv/'

hydra:
  # Optuna hyperparameter search definition.
  # The sweeper optimizes the value returned from the function decorated
  # with @hydra.main.
  # Learn more: https://hydra.cc/docs/next/plugins/optuna_sweeper
  sweeper:
    study_name: '${experiment_name}'

    # objective direction: 'minimize' or 'maximize'
    direction: maximize

    # number of trials (experiments) that will be executed
    n_trials: 100

    # Ranges of the hyperparameters to search, keyed by the config path
    # that each trial overrides.
    # Format references:
    # https://hydra.cc/docs/plugins/optuna_sweeper/
    # https://github.com/facebookresearch/hydra/blob/02495f9c781615e2fe7ae5588e16f596fcec010c/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py#L49
    search_space:
      # --- datamodule ---
      datamodule.post_edge_normalize_arg_1:
        type: float
        low: 1.0
        high: 4.0
        step: 0.25
      datamodule.post_edge_normalize_arg_2:
        type: float
        low: 0.5
        high: 2.0
        step: 0.5

      # --- optimization (entries with `log: true` are flagged for log-domain sampling) ---
      model.learning_rate:
        type: float
        low: 5e-4
        high: 1e-2
        log: true
      model._gradient_clip_val:
        type: float
        low: 0.0
        high: 0.5
        step: 0.1
      model.weight_decay:
        type: float
        low: 1e-9
        high: 1e-6
        log: true

      # --- regularization ---
      model.dropout_channels:
        type: float
        low: 0.0
        high: 0.5
        step: 0.1
      model.dropout_edges:
        type: float
        low: 0.0
        high: 0.5
        step: 0.1

      # --- main model layers ---
      model.num_layers:
        type: int
        low: 1
        high: 2
      model.layer_kwargs.eps:
        type: float
        low: 0.1
        high: 0.5
        step: 0.1

      # --- sub-node encoder layers ---
      model.sub_node_num_layers:
        type: int
        low: 1
        high: 2
      model.sub_node_encoder_layer_kwargs.eps:
        type: float
        low: 0.1
        high: 0.5
        step: 0.1