"""Example of a custom experiment wrapped around an RLlib trainer."""
import argparse

import ray
from ray import tune
from src.rllib.agents import ppo

# Command-line interface for the example; parsed only when run as a script.
parser = argparse.ArgumentParser(
    description="Train PPO under Tune, then manually evaluate the final "
    "checkpoint for one episode.")
parser.add_argument(
    "--train-iterations",
    type=int,
    default=10,
    help="Number of PPO training iterations to run before evaluation "
    "(default: %(default)s).")


def experiment(config):
    """Train a PPO agent under Tune, then evaluate it for one episode.

    Runs ``config["train-iterations"]`` training iterations, reporting each
    iteration's results to Tune and checkpointing on every even (0-based)
    iteration and on the last one. A fresh trainer with ``num_workers=0`` is
    then restored from the last checkpoint and rolled out for one episode in
    its local worker's environment. The final training results merged with
    the evaluation metrics (``eval_reward``, ``eval_eps_length``) are
    reported to Tune as a last result.

    Args:
        config: RLlib trainer config plus a ``"train-iterations"`` key. The
            environment is taken from ``config["env"]``, falling back to
            ``"CartPole-v0"`` when unset. The dict is copied first, so the
            caller's (Tune's) config is never mutated.
    """
    # Work on a private copy: we pop a key and derive an eval config from it.
    config = dict(config)
    iterations = config.pop("train-iterations")
    env_name = config.get("env") or "CartPole-v0"

    train_results, checkpoint = _train(config, env_name, iterations)
    eval_results = _evaluate(config, env_name, checkpoint)

    # Eval keys win on collision, matching a plain dict merge.
    results = {**train_results, **eval_results}
    # Unpack into keyword metrics; a positional dict would be taken as
    # Tune's ``_metric`` argument instead of individual metrics.
    tune.report(**results)


def _train(config, env_name, iterations):
    """Train PPO for ``iterations`` steps; return (last results, last checkpoint).

    Both values are ``{}`` / ``None`` when ``iterations`` is not positive.
    """
    agent = ppo.PPOTrainer(config=config, env=env_name)
    checkpoint = None
    results = {}
    try:
        for i in range(iterations):
            results = agent.train()
            # Checkpoint every other iteration and always on the final one,
            # so evaluation restores the most recent weights.
            if i % 2 == 0 or i == iterations - 1:
                checkpoint = agent.save(tune.get_trial_dir())
            tune.report(**results)
    finally:
        # Release the trainer's workers even if training fails.
        agent.stop()
    return results, checkpoint


def _evaluate(config, env_name, checkpoint):
    """Roll out one episode with a local-only trainer; return eval metrics."""
    # num_workers=0 so the rollout env lives on the local worker.
    agent = ppo.PPOTrainer(config=dict(config, num_workers=0), env=env_name)
    try:
        # With zero training iterations there is no checkpoint; evaluate the
        # freshly initialised policy rather than failing in restore(None).
        if checkpoint is not None:
            agent.restore(checkpoint)
        env = agent.workers.local_worker().env

        obs = env.reset()
        done = False
        results = {"eval_reward": 0, "eval_eps_length": 0}
        while not done:
            action = agent.compute_single_action(obs)
            # Feed the new observation to the policy on the next step.
            obs, reward, done, _info = env.step(action)
            results["eval_reward"] += reward
            results["eval_eps_length"] += 1
    finally:
        agent.stop()
    return results


if __name__ == "__main__":
    cli_args = parser.parse_args()

    ray.init(num_cpus=3)

    # Start from a shallow copy of PPO's defaults so the shared default dict
    # stays untouched, then add the experiment-specific keys. The
    # "train-iterations" entry is consumed by `experiment` itself.
    trial_config = ppo.DEFAULT_CONFIG.copy()
    trial_config["train-iterations"] = cli_args.train_iterations
    trial_config["env"] = "CartPole-v0"

    # Reserve for each trial what a PPOTrainer with this config requests.
    trial_resources = ppo.PPOTrainer.default_resource_request(trial_config)
    tune.run(
        experiment,
        config=trial_config,
        resources_per_trial=trial_resources,
    )
