# @package _global_

# Hydra defaults list: selects which option of each config group this
# experiment is composed from. `override` replaces a selection that the
# primary config has already made for that group.
defaults:
  - override /model: l2d.yaml
  - override /callbacks: default.yaml
  - override /trainer: default.yaml
  - override /logger: wandb.yaml

# Environment settings. `_target_` names the class to instantiate.
# NOTE(review): `${env.name}` is interpolated elsewhere in this file but is not
# defined here -- presumably supplied by the base env config; confirm.
env:
  _target_: rl4co.envs.TSPEnv4PPO
  generator_params:
    # Presumably the number of locations per generated instance -- verify
    # against the generator.
    num_loc: 20

# Weights & Biases run metadata. `group` and `name` are assembled from the
# environment name and `env.generator_params.num_loc` via interpolation.
logger:
  wandb:
    project: rl4co
    tags:
      - am-stepwise-ppo
      - "${env.name}"
    group: "${env.name}${env.generator_params.num_loc}"
    name: "ppo-${env.name}${env.generator_params.num_loc}"

# Trainer settings layered on top of the selected /trainer option.
trainer:
  max_epochs: 10
  # Quoted so the value is unambiguously a string rather than number-like.
  precision: "32-true"

# Shared hyperparameters, defined once at the top level and referenced from
# the model section through ${embed_dim} / ${num_heads} interpolation.
embed_dim: 256
num_heads: 8
# Stepwise-PPO model definition. Each `_target_` names the class to
# instantiate for that node. ${embed_dim} and ${num_heads} resolve to the
# top-level keys of this file; ${env.name} is not defined in this file and
# must be provided by the rest of the composed config.
model:
  _target_: rl4co.models.StepwisePPO
  policy:
    _target_: rl4co.models.L2DPolicy4PPO
    decoder:
      _target_: rl4co.models.zoo.l2d.decoder.L2DDecoder
      env_name: ${env.name}
      embed_dim: ${embed_dim}
      # Attention-model encoder configured as the decoder's feature extractor.
      feature_extractor:
        _target_: rl4co.models.zoo.am.encoder.AttentionModelEncoder
        embed_dim: ${embed_dim}
        num_heads: ${num_heads}
        num_layers: 4
        normalization: "batch"
        # NOTE(review): hard-coded, unlike the sibling `env_name` keys that
        # interpolate ${env.name} -- confirm this is intentional.
        env_name: "tsp"
      actor:
        _target_: rl4co.models.zoo.l2d.decoder.AttnActor
        embed_dim: ${embed_dim}
        num_heads: ${num_heads}
        env_name: ${env.name}
    embed_dim: ${embed_dim}
    env_name: ${env.name}
    # Canonical lowercase boolean (was `False`).
    het_emb: false
  batch_size: 512
  mini_batch_size: 512
  train_data_size: 20000
  # Plain integers (were `1_000`): the underscore form is only an int under
  # YAML 1.1 resolvers and would load as a string under the 1.2 core schema.
  val_data_size: 1000
  test_data_size: 1000
  # NOTE(review): `scale` is passed through as a string; its meaning is
  # defined by the model class, not by this file.
  reward_scale: scale
  optimizer_kwargs:
    # Written with a decimal point so every YAML loader resolves it as a
    # float; a dotless `1e-4` is read as a string by stock PyYAML.
    lr: 1.0e-4