import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.agents.ppo import PPOTrainer
import argparse
from ray.tune.logger import pretty_print

# NOTE: flag intended for running Ray in local mode, which should make
# debugging easier; see the RLlib example at
# https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py
LOCAL_MODE = True
# Training mode switch: True hands training over to tune.run(); False runs the
# manual PPOTrainer loop further down in this script.
TUNE = False

def _parse_args():
    """Parse the command-line options of this training script.

    Returns:
        argparse.Namespace with ``run`` (algorithm name, default ``"PPO"``),
        ``env`` (environment id, default ``"CartPole-v0"``) and ``n_itr``
        (number of training iterations, default ``50``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run", type=str, default="PPO")
    parser.add_argument("--env", type=str, default="CartPole-v0")
    parser.add_argument("--n_itr", type=int, default=50)
    return parser.parse_args()


def main():
    """Train an RLlib agent, either through Ray Tune or a manual PPO loop.

    The module-level ``TUNE`` constant selects the mode and ``LOCAL_MODE`` is
    forwarded to ``ray.init``. Ray is always shut down on exit, even when
    training raises.

    Raises:
        ValueError: if ``TUNE`` is False and ``--run`` is not ``"PPO"``.
    """
    args = _parse_args()

    # Fail fast on an unsupported combination, before any Ray resources are
    # started. The manual loop below is hard-wired to PPOTrainer.
    if not TUNE and args.run != "PPO":
        raise ValueError("Only --run PPO is supported when TUNE is False.")

    config = {
        "env": args.env,
        "num_workers": 1,
        "framework": "torch",
    }

    stop = {
        "training_iteration": args.n_itr,
    }

    # Start Ray explicitly so that LOCAL_MODE actually takes effect; without
    # this call the flag was never consulted.
    ray.init(local_mode=LOCAL_MODE)
    try:
        if TUNE:
            # automated run with Tune and grid search and TensorBoard
            print("Training automatically with Ray Tune")
            tune.run(args.run, config=config, stop=stop)
        else:
            # manual training with train loop using PPO and fixed learning rate
            ppo_config = ppo.DEFAULT_CONFIG.copy()
            ppo_config.update(config)
            # use fixed learning rate instead of grid search (needs tune)
            ppo_config["lr"] = 1e-3
            trainer = PPOTrainer(config=ppo_config, env=args.env)
            # run manual training loop and print results after each iteration
            for _ in range(args.n_itr):
                result = trainer.train()
                print(pretty_print(result))
    finally:
        # Release Ray resources on both the success and the failure path.
        ray.shutdown()


if __name__ == "__main__":
    main()