# @package _global_

# Compose on top of rf_base_100.yaml: because `_self_` is listed last, every
# value set in this file overrides the corresponding value from the base config.
defaults:
  - rf_base_100.yaml
  - _self_

env:
  generator_params:
    # NOTE(review): presumably makes the generator draw from every supported
    # problem variant — confirm against the env's variant presets.
    variant_preset: 'all'

  # Validation sets. Keep this list in the same order as
  # `val_dataloader_names` below; the two lists are presumably paired by index.
  val_file:
    # 16 problem variants, 100.npz sets
    - cvrp/val/100.npz
    - vrptw/val/100.npz
    - ovrp/val/100.npz
    - vrpl/val/100.npz
    - vrpb/val/100.npz
    - ovrptw/val/100.npz
    - vrpbl/val/100.npz
    - vrpbltw/val/100.npz
    - vrpbtw/val/100.npz
    - vrpltw/val/100.npz
    - ovrpb/val/100.npz
    - ovrpbl/val/100.npz
    - ovrpbltw/val/100.npz
    - ovrpbtw/val/100.npz
    - ovrpl/val/100.npz
    - ovrpltw/val/100.npz
    # Generalization sets (50.npz; presumably smaller instances)
    - cvrp/val/50.npz
    - vrptw/val/50.npz

  # One name per entry of `val_file`, in the same order.
  val_dataloader_names:
    - cvrp100
    - vrptw100
    - ovrp100
    - vrpl100
    - vrpb100
    - ovrptw100
    - vrpbl100
    - vrpbltw100
    - vrpbtw100
    - vrpltw100
    - ovrpb100
    - ovrpbl100
    - ovrpbltw100
    - ovrpbtw100
    - ovrpl100
    - ovrpltw100
    # Generalization
    - cvrp50
    - vrptw50

  # Test sets. They reuse the validation names (see `test_dataloader_names`),
  # so this list must follow exactly the same order as `val_file`.
  test_file:
    # 16 problem variants, 100.npz sets
    - cvrp/test/100.npz
    - vrptw/test/100.npz
    - ovrp/test/100.npz
    - vrpl/test/100.npz
    - vrpb/test/100.npz
    - ovrptw/test/100.npz
    - vrpbl/test/100.npz
    - vrpbltw/test/100.npz
    - vrpbtw/test/100.npz
    - vrpltw/test/100.npz
    - ovrpb/test/100.npz
    - ovrpbl/test/100.npz
    - ovrpbltw/test/100.npz
    - ovrpbtw/test/100.npz
    - ovrpl/test/100.npz
    - ovrpltw/test/100.npz
    # Generalization sets (50.npz; presumably smaller instances)
    - cvrp/test/50.npz
    - vrptw/test/50.npz

  # OmegaConf interpolation: test loaders share the validation loader names.
  test_dataloader_names: ${env.val_dataloader_names}

model:
  # Hydra instantiates the class named by `_target_` with the sibling keys.
  _target_: models.model.CadaModel
  policy:
    _target_: models.policy.CadaPolicy
    normalization: rms
    encoder_use_prenorm: false
    encoder_use_post_layers_norm: false
    parallel_gated_kwargs:
      mlp_activation: silu
    # NOTE(review): sparse-attention settings; exact meaning of the ratio is
    # defined by CadaPolicy — confirm there before tuning.
    attn_sparse_ratio: 0.5
    sparse_applied_to_score: true
    prompt_embedding:
      _target_: models.env_embeddings.mtvrp.init.MTVRPPromptEmbedding
      normalization: null
    # null: no LoRA module checkpoint path is supplied by this config.
    lora_modules_ckpt_path: null


# Logging: Weights & Biases (wandb).
# NOTE(review): project/group/name/entity are empty-string placeholders here;
# fill them in (or confirm the logger tolerates empty values) before a run.
logger:
  wandb:
    project: ''
    tags: []
    group: ''
    name: ''
    entity: ''


callbacks:
  model_checkpoint:
    # Checkpoint selection tracks a single validation set (vrptw100), not an
    # aggregate over all dataloaders listed under `env.val_dataloader_names`.
    monitor: 'val/reward/vrptw100'