from regawa.rl.ppo_gnn import setup, Args
from regawa import GNNParams, ActionMode
from torch import nn
from regawa.rddl import register_env

# PPO training run with a GNN agent on "Elevators_MDP_ippc2011", instance 1.
# Registers the environment, assembles the hyperparameters and passes them to
# `setup`. Everything here executes at module level, i.e. on import.
env_id = register_env()

# Agent architecture, handed to Args below as `agent_config`.
_agent_config = GNNParams(
    layers=4,
    embedding_dim=16,
    activation=nn.Mish(),
    aggregation="max",
    action_mode=ActionMode.ACTION_THEN_NODE,
)

args = Args(
    # --- environment / problem instance ---
    env_id=env_id,
    domain="Elevators_MDP_ippc2011",
    instance=1,
    remove_false=True,  # NOTE(review): meaning defined inside regawa; confirm
    num_envs=1,
    # --- run length and diagnostics ---
    total_timesteps=500000,
    debug=True,
    # --- rollout size and update schedule ---
    num_steps=2048,
    num_minibatches=2,
    update_epochs=3,
    # --- optimiser ---
    learning_rate=3e-4,
    anneal_lr=False,  # no LR annealing, for consistency with sb3 defaults
    max_grad_norm=0.5,
    # --- return / advantage estimation ---
    gamma=0.99,
    gae_lambda=0.95,
    norm_adv=True,
    # --- PPO loss coefficients ---
    clip_coef=0.2,
    vf_coef=1.0,
    ent_coef=0.01,
    # --- agent ---
    agent_config=_agent_config,
)

setup(args)
