from args import Args
from sb3_contrib.common.maskable.evaluation import evaluate_policy
import utils

def main():
    """Train a model, save it, and print an evaluation summary.

    Steps, in order:

    1. Parse the command-line arguments with ``Args``.
    2. Unless ``args.no_wandb`` is set, start a Weights & Biases run that
       records the parsed arguments as its config and syncs TensorBoard logs.
    3. Build the training and evaluation environments and the model through
       the project's ``utils`` helpers, then train for
       ``args.total_timesteps`` steps.
    4. Save the arguments and the model under ``results/ppo_mask``.
    5. Evaluate the trained model for 1000 episodes and print the result.

    The model class is chosen inside ``utils.create_model``; the save path and
    the maskable ``evaluate_policy`` import suggest MaskablePPO, but confirm
    there.
    """
    args = Args().parse_args()

    if not args.no_wandb:
        # Imported lazily so wandb is only required when logging is enabled.
        import wandb

        wandb.init(
            entity="",  # fill in wandb
            project="",
            config=vars(args),
            sync_tensorboard=True,
        )

    vec_env = utils.get_vec_env(args)
    # Second argument is presumably the number of parallel envs; verify in utils.
    eval_env = utils.get_vec_env(args, 1)

    # The third returned value is not needed here, so it is discarded.
    model, callbacks, _ = utils.create_model(args, vec_env, eval_env)
    model.learn(total_timesteps=args.total_timesteps, callback=callbacks)

    # One shared location for both the saved arguments and the saved model.
    save_path = "results/ppo_mask"
    utils.save_args(args, save_path)
    model.save(save_path)

    # NOTE(review): evaluation runs on the training env (vec_env), not on
    # eval_env; confirm this is intentional.
    # Named `evaluation` rather than `eval` to avoid shadowing the builtin.
    evaluation = evaluate_policy(model, vec_env, n_eval_episodes=1000, warn=False)
    print("Evaluation (mean/std):", evaluation)

# Run the training/evaluation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
