from regawa.rl.ppo_gnn import setup, Args
from regawa import GNNParams, ActionMode
from torch import nn
from regawa.rddl import register_env

# Register the environment first; the id it returns is passed on to Args.
env_id = register_env()

# Agent configuration, assembled on its own so the run arguments below stay
# flat and easy to scan.
gnn_config = GNNParams(
    layers=3,
    embedding_dim=8,
    activation=nn.Mish(),
    aggregation="sum",
    action_mode=ActionMode.ACTION_THEN_NODE,
)

# Run arguments targeting instance 1 of the "RecSim_ippc2023" domain.
# NOTE(review): total_timesteps=4000 with num_steps=40 looks like a short
# smoke-test budget rather than a full training run -- confirm intent.
args = Args(
    env_id=env_id,
    total_timesteps=4000,
    num_steps=40,
    domain="RecSim_ippc2023",
    instance=1,
    agent_config=gnn_config,
)

# Hand the assembled arguments to the ppo_gnn entry point.
setup(args)
