# @package _global_


######################################################################
# NOTE: do not use this directly, as it is the base config
######################################################################

# Override defaults: replace the default option chosen for each listed config
# group. The leading "/" addresses the group from the config root, rather than
# relative to the group this file lives in.
defaults:
  - override /model: base_rf.yaml
  - override /env: mtvrp.yaml
  - override /callbacks: default.yaml
  - override /trainer: default.yaml
  # To disable logging: uncomment the next line and comment out the wandb
  # override below it.
  # - override /logger: null
  - override /logger: wandb.yaml

# Random seed for the run. NOTE(review): presumably consumed by the training
# entry point (e.g. passed to a seed_everything call) — confirm.
seed: 69420

env:
  generator_params:
    num_loc: 100
    # This base config defaults to "cvrp"; the upstream default is
    # "single_feat". NOTE(review): an earlier comment here said "we use all in
    # our setting", which does not match this value — confirm the intended
    # default (experiments may override it).
    variant_preset: "cvrp"

  # Data files are "<variant_preset>/<split>/<num_loc>.npz" (e.g.
  # "cvrp/val/100.npz") and dataloaders are named "<variant_preset><num_loc>"
  # (e.g. "cvrp100").
  val_file:
    - "${env.generator_params.variant_preset}/val/${env.generator_params.num_loc}.npz"

  val_dataloader_names:
    - "${env.generator_params.variant_preset}${env.generator_params.num_loc}"

  test_file:
    - "${env.generator_params.variant_preset}/test/${env.generator_params.num_loc}.npz"

  # Test dataloaders reuse the validation dataloader names.
  test_dataloader_names: ${env.val_dataloader_names}


# Logging via Weights & Biases. Every field is left empty in this base config;
# presumably experiment configs built on it fill them in.
logger:
  wandb:
    project: ''
    tags: []
    group: ''
    # Also used as a path component of callbacks.model_checkpoint.dirpath.
    name: ''
    entity: ''

# Training budget: we use 100k instances per epoch, so the full run is 1000
# epochs instead of 10k (see link). Due to resource and time constraints we
# train for 300 epochs (trainer.max_epochs); the LR milestones below are placed
# near the end of that 300-epoch budget.
# https://github.com/FeiLiu36/MTNCO/blob/c5b3b2b8158a2262cc61238b26041ece1594e7e7/MTPOMO/POMO/train_n100.py#L66
model:
  batch_size: 256
  # NOTE(review): per the original note this may be a list with one entry per
  # val/test dataset; a single scalar is used here.
  val_batch_size: 128
  test_batch_size: ${model.val_batch_size}
  train_data_size: 100000  # 100k training instances per epoch
  # NOTE(review): per the original note this may be a list with one entry per
  # val/test dataset; a single scalar is used here.
  val_data_size: 64
  test_data_size: 64  # NOTE: unused if provided by env
  optimizer_kwargs:
    # Written with a decimal point so YAML 1.1 loaders (e.g. plain PyYAML)
    # also read these as floats; bare "3e-4" would load there as a string.
    lr: 3.0e-4  # NOTE: we will be using 3e-4 from now on
    weight_decay: 1.0e-6
  lr_scheduler: "MultiStepLR"
  lr_scheduler_kwargs:
    # MultiStepLR multiplies the LR by `gamma` at each milestone (epochs,
    # assuming the scheduler is stepped once per epoch).
    milestones: [270, 295]
    gamma: 0.1
  normalize_reward: "exponential"
  norm_operation: "div"
  alpha: 0.25
  _target_: models.model.RouteFinderBase
  policy:
    _target_: models.policy.RouteFinderPolicy
    normalization: "rms"
    encoder_use_prenorm: true
    encoder_use_post_layers_norm: true
    parallel_gated_kwargs:
      mlp_activation: "silu"
    lora_modules_ckpt_path: null


trainer:
  # Full run as per the paper would be max_epochs: 1000.
  # 100 epochs take ~8 h on 1x3090, so 300 epochs fits a ~24 h budget.
  # Keep model.lr_scheduler_kwargs.milestones in sync if this changes.
  max_epochs: 300

# Checkpoints go to a per-run folder under the log directory:
#   <paths.log_dir>/<logger.wandb.name>/<YYYY-mm-dd>_<HH-MM-SS>/checkpoints
callbacks:
  model_checkpoint:
    dirpath: "${paths.log_dir}/${logger.wandb.name}/${now:%Y-%m-%d}_${now:%H-%M-%S}/checkpoints"
    # monitor: "val/reward/cvrp100"