"""In this example, the stable-baselines3 package is used to train an RL agent.

The syntax for running this example is:

    python run_stable_baselines.py <domain> <instance> <method> [<steps>] [<learning_rate>]

where:
    <domain> is the name of a domain in the rddlrepository
    <instance> is the instance number
    <method> is the algorithm to train, one of: a2c, ddpg, dqn, ppo, sac, td3
    <steps> is the total number of environment timesteps to train for (default 200000)
    <learning_rate> is the learning rate used to train the agent (if omitted, the
        algorithm's own default is used)
"""

import sys
from stable_baselines3 import *

import pyRDDLGym

from pyRDDLGym_rl.core.agent import StableBaselinesAgent
from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv

from rddlrepository import RDDLRepoManager

# Lookup table from the lowercase algorithm names accepted on the command line
# to the corresponding stable_baselines3 algorithm classes.
METHODS = dict(a2c=A2C, ddpg=DDPG, dqn=DQN, ppo=PPO, sac=SAC, td3=TD3)


def main(domain, instance, method, steps=200000, learning_rate=None):
    """Train a stable-baselines3 agent on one RDDL instance, save it, and evaluate it.

    Args:
        domain: name of the RDDL domain, passed to ``pyRDDLGym.make``.
        instance: instance identifier, passed to ``pyRDDLGym.make``.
        method: key into ``METHODS`` selecting the algorithm. Matched
            case-insensitively, so ``"ppo"`` and ``"PPO"`` are equivalent.
        steps: total number of environment timesteps passed to ``learn``.
        learning_rate: optional learning rate; when ``None`` it is not passed
            and the algorithm's own default is used.

    The trained model is saved as ``"<domain>_<instance>_<method>"`` (method in
    lowercase) and then evaluated for one rendered episode.

    Raises:
        KeyError: if ``method`` does not name an algorithm in ``METHODS``.
    """
    # Resolve the algorithm before creating the environment so that an invalid
    # method fails fast without allocating an environment.
    method_key = method.lower()
    try:
        algorithm = METHODS[method_key]
    except KeyError:
        raise KeyError(
            f"unknown method {method!r}; expected one of {sorted(METHODS)}"
        ) from None

    # set up the environment
    env = pyRDDLGym.make(
        domain,
        instance,
        base_class=SimplifiedActionRDDLEnv,
        enforce_action_constraints=True,
    )
    try:
        # train the selected agent
        kwargs = {"verbose": 1}
        if learning_rate is not None:
            kwargs["learning_rate"] = learning_rate
        # NOTE(review): "MultiInputPolicy" is SB3's policy for Dict observation
        # spaces; this assumes SimplifiedActionRDDLEnv exposes one — confirm.
        model = algorithm("MultiInputPolicy", env, **kwargs)
        model.learn(total_timesteps=steps)

        model.save(f"{domain}_{instance}_{method_key}")

        # wrap the agent in a RDDL policy and evaluate
        agent = StableBaselinesAgent(model)
        agent.evaluate(env, episodes=1, verbose=True, render=True)
    finally:
        # Always release the environment, even if training/evaluation raises,
        # so callers that catch the error and move on do not leak environments.
        env.close()


if __name__ == "__main__":
    # Command-line entry point. Argument layout:
    #   <domain> <instance> <method> [<steps>] [<learning_rate>]
    # Passing "all" as <instance> trains on every instance the repository lists
    # for the domain; any other value trains on that single instance.
    import traceback

    usage = (
        "python run_stable_baselines.py <domain> <instance|all> <method> "
        "[<steps>] [<learning_rate>]"
    )
    args = sys.argv[1:]
    if len(args) < 3:
        print(usage, file=sys.stderr)
        # sys.exit instead of the site-provided exit(), which is unavailable
        # under `python -S` and in some embedded/frozen interpreters.
        sys.exit(1)

    domain_name, instance_arg, method = args[0], args[1], args[2].lower()
    # Method names are matched case-insensitively ("PPO" and "ppo" both work).
    if method not in METHODS:
        print(f"<method> must be one of {sorted(METHODS)}", file=sys.stderr)
        sys.exit(1)

    kwargs = {"domain": domain_name, "method": method}
    try:
        if len(args) >= 4:
            kwargs["steps"] = int(args[3])
        if len(args) >= 5:
            kwargs["learning_rate"] = float(args[4])
    except ValueError as err:
        print(f"invalid numeric argument: {err}", file=sys.stderr)
        print(usage, file=sys.stderr)
        sys.exit(1)

    if instance_arg.lower() == "all":
        # Sweep mode: ask the repository which instances exist for the domain.
        manager = RDDLRepoManager()
        instances = manager.get_problem(domain_name).list_instances()
    else:
        instances = [instance_arg]

    failures = 0
    for instance in instances:
        try:
            main(instance=instance, **kwargs)
        except Exception as e:
            # Best-effort across instances: report (with traceback) and continue.
            failures += 1
            print(
                f"Error in domain {domain_name} instance {instance}: {e}",
                file=sys.stderr,
            )
            traceback.print_exc()

    # Signal failure to shells/CI if any instance failed.
    if failures:
        sys.exit(1)
