---
# Telegram bot: its settings are kept in a separate file
telegram_config_file: 'telegram_config.yml'


# Dataset and splits
storage_folder: DATA
dataset_class: dataset.WMT14_EN_DE
# pre-computed data splits (the file name suggests 1 outer / 1 inner fold)
data_splits_file: DATA_SPLITS/WMT14_EN_DE/WMT14_EN_DE_outer1_inner1.splits


# Hardware
device: cuda
max_cpus: 64
max_gpus: 4
gpus_per_task: 1
# Optional GPU id selection (disabled).
# NOTE(review): the example lists 9 ids while max_gpus is 4 — confirm before enabling.
# gpus_subset: 0,1,2,3,4,5,6,7,8


# Data loading
dataset_getter: mlwiz.data.provider.DataProvider
data_loader:
  class_name: torch.utils.data.DataLoader  # plain PyTorch DataLoader: the dataset is not a graph
  args:
    num_workers: 4
    pin_memory: True


# Reproducibility
seed: 42

# Experiment
result_folder: 'RESULTS/GRID/TRANSFORMER/'
exp_name: 'transformer_awn'
experiment: mlwiz.experiment.Experiment
# the main scorer in the grid is the classification loss metric, so lower is better
higher_results_are_better: False
evaluate_every: 1
model_selection_training_runs: 1
risk_assessment_training_runs: 1


grid:

  model: model.AWN
  checkpoint: True
  shuffle: True
  batch_size: 32
  epochs: 2

  # Model-specific arguments #

  num_hidden_layers:
    - 12  # 6 encoder + 6 decoder layers: workaround so the Transformer fits into AWN (keep consistent with num_enc_dec_layers)

  share_width_distribution:
    - False

  # Rescale the minibatch ELBO as if it had been computed over a full pass of the dataset:
  # this rescales the minibatch gradient of the classification term,
  # preventing the other terms from dominating the loss.
  # NOTE(review): the full WMT14 EN-DE training set is commonly ~4.5M sentence pairs;
  # confirm 1000000 matches the training split actually used.
  n_observations: 1000000  # approximate training set size

  quantile: 0.9

  dynamic_architecture: dynamic_transformer.DynamicTransformer

  truncated_distribution:
    - class_name: distribution.TruncatedDistribution
      args:
        discretized_distribution:
          - class_name: distribution.DiscretizedDistribution
            args:
              base_distribution:
                - class_name: distribution.Exponential
                  args:  # initial values for the distribution
                    # The 0.9 quantile of Exp(rate) is ln(10) / rate, so 0.004 gives ~576 neurons
                    # (about 0.0045 would give ~512), assuming the usual rate parameterization.
                    rate:
                      - 0.004

  # scale for p(theta)
  theta_prior_scale:
    - 1.0
    # - 3.0

  # mean for p(alpha)
  alpha_prior_mean:
    - 0.0001

  # scale for p(alpha)
  alpha_prior_scale:
    - null

  init_type:
    - gaussian

  # classical activation of neurons
  activation: torch.nn.functional.relu

  # activation applied after the importance renormalization (disabled here)
  activation_outer: null  # e.g. torch.nn.functional.relu

  embedding_dim:
    - 256
    - 512

  num_attention_heads: 8
  num_enc_dec_layers: 6
  dropout: 0.1

  # ------------------------ #

  # Optimizer
  optimizer:
    - class_name: mlwiz.training.callback.optimizer.Optimizer
      args:
        optimizer_class_name: torch.optim.Adam
        lr: 0.01
        weight_decay: 0.0005
        eps: 0.000000005  # 5e-9 in decimal form: PyYAML (YAML 1.1) reads 5e-9 without a dot as a string

  # Scheduler (optional)
  scheduler:
    - class_name: mlwiz.training.callback.scheduler.MetricScheduler
      args:
        use_loss: True
        monitor: training_Multiclass Classification
        scheduler_class_name: torch.optim.lr_scheduler.ReduceLROnPlateau
        factor: 0.9

  # Loss metric
  loss:
    - class_name: metric.MachineTranslationMulticlassClassification
      args:
        accumulate_over_epoch: False  # reduces memory consumption

  # Score metric
  scorer:
    - class_name: mlwiz.training.callback.metric.MultiScore
      args:
        accumulate_over_epoch: False  # reduces memory consumption
        # TODO: change the main scorer to BLEU (metric.BLEU). When doing so, presumably also
        # set higher_results_are_better to True, since higher BLEU is better.
        # main_scorer: metric.BLEU
        main_scorer: metric.MachineTranslationMulticlassClassification


  # Training engine
  engine: mlwiz.training.engine.TrainingEngine

  # Gradient clipper (optional)
  gradient_clipper: null

  # Early stopper (optional): "patience" early stopping on the validation loss
  early_stopper:
    - class_name:
        - mlwiz.training.callback.early_stopping.PatienceEarlyStopper
      args:
        patience:
          - 1
        # SYNTAX: (train_|validation_)[name of the scorer or loss to monitor];
        # the main loss / main score can be referenced as main_loss / main_score
        monitor: validation_main_loss
        mode: min  # is the best value the `max` or the `min` of the monitored quantity?
        checkpoint: True  # store the best checkpoint

  # Plotter of metrics
  plotter: plotter.MiniBatchPlotter
