# @package _global_

# Hydra defaults list: composes the shared `scheduling/base` config into this
# experiment. `_self_` is not listed here, so whether the keys in this file
# take precedence over `scheduling/base` depends on Hydra's default
# composition order -- TODO confirm against the Hydra version in use.
defaults:
  - scheduling/base

# wandb logger metadata. Both the tags and the run name are assembled through
# interpolation from the environment name; the run name additionally embeds
# the generator's num_jobs ("j") and num_machines ("m") values.
logger:
  wandb:
    tags:
      - 'am-ppo'
      - '${env.name}'
    name: 'am-ppo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m'

# Shared hyper-parameters. They are referenced further down via ${embed_dim}
# (policy, decoder, feature extractor, init embedding, actor) and ${num_heads}
# (feature extractor, actor), so changing them here changes every consumer.
embed_dim: 256
num_heads: 8

# Model definition. Every `_target_` is a dotted class path; the sibling keys
# are the arguments handed to that class.
# NOTE(review): ${scaling_factor} is not defined anywhere in this file;
# presumably it is provided by `scheduling/base` -- confirm.
model:
  _target_: rl4co.models.StepwisePPO

  # Batch / dataset sizes given to the model.
  batch_size: 128
  mini_batch_size: 512
  val_batch_size: 512
  test_batch_size: 64
  train_data_size: 2000

  policy:
    _target_: rl4co.models.L2DPolicy4PPO
    env_name: ${env.name}
    embed_dim: ${embed_dim}
    scaling_factor: ${scaling_factor}
    het_emb: true

    decoder:
      _target_: rl4co.models.zoo.l2d.decoder.L2DDecoder
      env_name: ${env.name}
      embed_dim: ${embed_dim}

      # The L2D decoder is given a MatNet-module encoder as its feature
      # extractor, with an FJSP-specific initial embedding.
      feature_extractor:
        _target_: rl4co.models.zoo.matnet.matnet_w_sa.Encoder
        embed_dim: ${embed_dim}
        num_heads: ${num_heads}
        num_layers: 4
        normalization: 'batch'
        init_embedding:
          _target_: rl4co.models.nn.env_embeddings.init.FJSPMatNetInitEmbedding
          embed_dim: ${embed_dim}
          scaling_factor: ${scaling_factor}

      # Attention-based actor with its `stepwise` flag enabled.
      actor:
        _target_: rl4co.models.zoo.l2d.decoder.L2DAttnActor
        env_name: ${env.name}
        embed_dim: ${embed_dim}
        num_heads: ${num_heads}
        scaling_factor: ${scaling_factor}
        stepwise: true

# Environment override: switches on the environment's `stepwise_reward` flag.
# The model configured in this file targets StepwisePPO.
env:
  stepwise_reward: true