import os

from malsim_groundings import register_env
from regawa.rl.ppo_gnn import setup, Args
from regawa import GNNParams, ActionMode
from torch import nn


# Script body: register the malsim environment, build a PPO + GNN
# configuration, and hand it to regawa.rl.ppo_gnn.setup (which presumably
# launches training — confirm against regawa).

# Scenario file passed to the environment as `instance`. The default is the
# scenario1 test scenario on the original author's machine. Set the
# MALSIM_SCENARIO environment variable to use a different file without editing
# this script, e.g. the scenario4 demo:
#   /storage/GitHub/malsim-jakob-groundings/tests/data/scenario4-demo2/demo2_scenario_future_shutdown_defender.yml
DEFAULT_SCENARIO = (
    "/storage/GitHub/malsim-jakob-groundings/tests/data/scenario1/scenario.yml"
)
# `or` (rather than a .get default) also treats an empty variable as unset.
scenario = os.environ.get("MALSIM_SCENARIO") or DEFAULT_SCENARIO

# register_env() yields the id the environment is registered under; Args
# needs it, so registration must happen before the config is built.
env_id = register_env()
args = Args(
    env_id=env_id,
    total_timesteps=10000,
    num_steps=40,  # NOTE(review): presumably rollout length per update — confirm in regawa Args
    domain="",  # NOTE(review): left empty here; meaning of "domain" not visible in this file
    instance=scenario,
    update_epochs=10,
    agent_config=GNNParams(
        layers=3,
        embedding_dim=8,
        activation=nn.Mish(),
        aggregation="sum",
        action_mode=ActionMode.ACTION_THEN_NODE,
    ),
)
setup(args)
