from regawa.rl.ppo_gnn import setup, Args
from regawa import GNNParams, ActionMode
from torch import nn
from regawa.rddl import register_shuffle_env
import tyro
import json


# Rollout sizing. The total batch is split evenly across the parallel
# environments; `distributed_steps` is the per-environment share and is
# what `run` passes to `Args` as `num_steps`.
num_envs = 16
batch_size = 1024
distributed_steps = batch_size // num_envs

# Hyperparameter profiles.
#
# Each profile is a dict of keyword arguments (learning rate, entropy
# coefficient, and seed) that `run` unpacks into `Args`. Profiles are
# selected by name through the `profiles` lookup table below.

p1_settings = {"learning_rate": 1e-3, "ent_coef": 0.1, "seed": 0}

p2_settings = {"learning_rate": 1e-4, "ent_coef": 0.001, "seed": 1}

p3_settings = {"learning_rate": 1e-5, "ent_coef": 0.0001, "seed": 2}

# Profile name (as given on the command line) -> settings dict.
profiles = {
    "warmup": p1_settings,
    "anneal": p2_settings,
    "finetune": p3_settings,
}


def run(
    domain: str,
    batch_id: str,
    train_instances: str,
    eval_instances: str,
    profile: str,
    /,
    resume_from: str | None = None,
):
    """Configure and launch one PPO-GNN training run.

    Builds an ``Args`` object from the fixed hyperparameters below plus the
    settings of the selected profile, then hands it to ``setup`` together
    with ``batch_id``.

    Args:
        domain: Forwarded unchanged as ``Args.domain``.
        batch_id: Forwarded unchanged as the second argument of ``setup``.
        train_instances: JSON-encoded list of integers. Every element is
            incremented by one before being passed as ``Args.instance``
            (presumably converting 0-based indices to 1-based instance
            numbers -- confirm against the caller).
        eval_instances: JSON-encoded list of integers, shifted by one in
            the same way and passed as ``Args.eval_instance``.
        profile: Name of an entry in the module-level ``profiles`` table;
            its values are unpacked into ``Args`` as keyword arguments.
        resume_from: Forwarded unchanged as ``Args.resume_from``.

    Raises:
        KeyError: If ``profile`` is not a key of ``profiles``. The message
            lists the valid profile names.
        json.JSONDecodeError: If ``train_instances`` or ``eval_instances``
            is not valid JSON.
    """
    # Validate all user-supplied inputs before any side effect (environment
    # registration, training setup) so that a typo fails fast and clearly.
    try:
        profile_settings = profiles[profile]
    except KeyError:
        raise KeyError(
            f"unknown profile {profile!r}; expected one of {sorted(profiles)}"
        ) from None

    train_ids = [i + 1 for i in json.loads(train_instances)]
    eval_ids = [i + 1 for i in json.loads(eval_instances)]

    env_id = register_shuffle_env()

    args = Args(
        anneal_lr=False,  # for consistency with sb3 defaults
        env_id=env_id,
        total_timesteps=500000,
        resume_from=resume_from,
        domain=domain,
        remove_false=True,
        instance=train_ids,
        eval_instance=eval_ids,
        num_steps=distributed_steps,  # batch_size // num_envs
        num_minibatches=16,
        update_epochs=10,
        vf_coef=1.0,
        gamma=0.99,
        gae_lambda=0.95,
        max_grad_norm=1.0,
        multiprocess=True,
        norm_adv=False,
        clip_coef=0.2,
        num_envs=num_envs,
        weight_decay=0.0,
        # Profile supplies learning_rate, ent_coef and seed.
        **profile_settings,
        agent_config=GNNParams(
            layers=4,
            embedding_dim=16,
            activation=nn.Tanh(),
            aggregation="max",
            action_mode=ActionMode.ACTION_THEN_NODE,
        ),
    )
    setup(args, batch_id)


def main() -> None:
    """Expose ``run`` as a command-line interface via ``tyro.cli``."""
    tyro.cli(run)


# Launch the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
