# @package _global_


######################################################################
# NOTE: do not use this directly, as it is the base config
######################################################################

# Override defaults: pick configs for each group below (a leading `/` makes the group path absolute, i.e. resolved from the config root)
defaults:
  # Hydra defaults list: each `override /<group>: <file>` replaces that config
  # group's default selection with the named file.
  - override /model: base_cada.yaml
  - override /env: mtvrp.yaml
  - override /callbacks: default.yaml
  - override /trainer: default.yaml
  # Logging toggle: to disable logging, uncomment the `null` line below and
  # comment out the `wandb.yaml` line after it.
  # - override /logger: null # comment this line to enable logging
  - override /logger: wandb.yaml

seed: 69420

env:
  generator_params:
    num_loc: 50
    variant_preset: "cvrp"
    subsample: true

  # Dataset files and dataloader names are derived from the generator settings
  # above; with the values here they resolve to "cvrp/val/50.npz" and "cvrp50".
  val_file:
    - "${env.generator_params.variant_preset}/val/${env.generator_params.num_loc}.npz"

  val_dataloader_names:
    - "${env.generator_params.variant_preset}${env.generator_params.num_loc}"

  test_file:
    - "${env.generator_params.variant_preset}/test/${env.generator_params.num_loc}.npz"

  # Test dataloaders reuse the validation names.
  test_dataloader_names: ${env.val_dataloader_names}


# Logging: we use Weights & Biases (wandb) here
logger:
  wandb:
    # Placeholders: this base config leaves every field empty so that the
    # experiment config extending it can fill them in. Note that `name` is
    # also interpolated into callbacks.model_checkpoint.dirpath.
    project: ''
    tags: []
    group: ''
    name: ''
    entity: ''

# Note that we use 100k samples per epoch, so the full schedule is 1000 epochs instead of 10k.
# However, due to resource constraints (and time), we train for 300 epochs (see trainer.max_epochs);
# the LR scheduler milestones below are set just before that limit.
# https://github.com/FeiLiu36/MTNCO/blob/c5b3b2b8158a2262cc61238b26041ece1594e7e7/MTPOMO/POMO/train_n100.py#L66
model:
  # Model (models.model.CadaModel) and its training hyperparameters.
  batch_size: 256
  # note: may also be given as a list with one entry per validation dataset
  # (see env.val_file); a single scalar is used here
  val_batch_size: 128
  test_batch_size: ${model.val_batch_size}
  # Written without `_` separators: `100_000` is an int only under YAML 1.1.
  train_data_size: 100000
  # note: may also be given as a list with one entry per validation dataset;
  # a single scalar is used here
  val_data_size: 64
  test_data_size: 64  # NOTE: unused if provided by env
  optimizer_kwargs:
    # Floats carry an explicit decimal point: YAML 1.1 loaders such as plain
    # PyYAML resolve `3e-4` / `1e-6` to strings rather than floats.
    lr: 3.0e-4  # NOTE: we will be using 3e-4 from now on
    weight_decay: 1.0e-6
  # presumably resolved to torch.optim.lr_scheduler.MultiStepLR — TODO confirm
  lr_scheduler: "MultiStepLR"
  lr_scheduler_kwargs:
    # LR is multiplied by `gamma` at each milestone; both sit just before
    # trainer.max_epochs (300).
    milestones: [270, 295]
    gamma: 0.1
  normalize_reward: "exponential"
  norm_operation: "div"
  alpha: 0.25
  _target_: models.model.CadaModel
  policy:
    _target_: models.policy.CadaPolicy
    normalization: "rms"
    encoder_use_prenorm: false
    encoder_use_post_layers_norm: false
    parallel_gated_kwargs:
      mlp_activation: "silu"
    attn_sparse_ratio: 0.5
    sparse_applied_to_score: true
    prompt_embedding:
      _target_: models.env_embeddings.mtvrp.init.MTVRPPromptEmbedding
      normalization: null
    lora_modules_ckpt_path: null


trainer:
  # Full run as per the paper would be max_epochs: 1000.
  # 100 epochs take ~8 hrs on 1x3090, so 300 epochs fit a ~24 hr budget
  # (earlier setting: 100).
  max_epochs: 300

# Easier default: store checkpoints under the log directory (paths.log_dir)
callbacks:
  model_checkpoint:
    # Layout: <paths.log_dir>/<wandb run name>/<date>_<time>/checkpoints
    # (an empty logger.wandb.name leaves a `//` in the path).
    dirpath: "${paths.log_dir}/${logger.wandb.name}/${now:%Y-%m-%d}_${now:%H-%M-%S}/checkpoints"
    # monitor: "val/reward/cvrp50"