
"""
Runs one instance of the environment and optimizes using the Deep
Deterministic Policy Gradient (DDPG) algorithm, with periodic evaluation in
separate evaluation environments. Can use a GPU for the agent (applies to both
sampling and training). No parallelism is employed: everything happens in one
Python process, which can make it easier to debug.

Requires OpenAI gym (and maybe mujoco).  If not installed, move on to next
example.

"""

from open_source.rlpyt.rlpyt.samplers.serial.sampler import SerialSampler
from open_source.rlpyt.rlpyt.envs.gym import make as gym_make
from open_source.rlpyt.rlpyt.algos.qpg.sac import SAC
from open_source.rlpyt.rlpyt.agents.qpg.sac_agent import SacAgent
from open_source.rlpyt.rlpyt.runners.minibatch_rl import MinibatchRl, MinibatchRlEval
from open_source.rlpyt.rlpyt.utils.logging.context import logger_context
from open_source.rlpyt.rlpyt.algos.qpg.ddpg import *
from open_source.rlpyt.rlpyt.agents.qpg.ddpg_agent import *
import datetime

def build_and_train(env_id="Hopper-v3", run_ID=0, cuda_idx=None):
    """Train a DDPG agent on one gym environment with periodic evaluation.

    Sampling is serial: a single training environment is stepped one
    time-step per sampler iteration. Five separate evaluation environments
    are used by ``MinibatchRlEval``. Training runs for 100k steps and logs
    every 10k steps.

    Args:
        env_id: Gym environment ID. It is passed to ``gym_make`` for both the
            training and evaluation environments.
        run_ID: Run identifier forwarded to ``logger_context``.
        cuda_idx: Index of the GPU for the agent, handed to the runner as
            ``affinity=dict(cuda_idx=...)``. ``None`` requests no specific
            GPU.
    """
    sampler = SerialSampler(
        EnvCls=gym_make,
        env_kwargs=dict(id=env_id),
        eval_env_kwargs=dict(id=env_id),
        batch_T=1,  # One time-step per sampler iteration.
        batch_B=1,  # One environment (i.e. sampler Batch dimension).
        max_decorrelation_steps=1,
        eval_n_envs=5,
        eval_max_steps=int(2e3),
        eval_max_trajectories=5,
    )
    # Every DDPG hyperparameter except min_steps_learn uses its default.
    # min_steps_learn is presumably the number of sampled steps collected
    # before optimization starts; confirm against the DDPG implementation.
    algo = DDPG(min_steps_learn=1000)
    agent = DdpgAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=100000,
        log_interval_steps=10000,
        affinity=dict(cuda_idx=cuda_idx),
    )
    # Logged as the run's parameters, so the environment is recorded with
    # the results. Passed by keyword: passing it positionally together with
    # `log_params=` would supply the same argument twice.
    config = dict(env_id=env_id)
    name = "ddpg_" + env_id
    # One directory per launch, stamped to the minute (local time).
    log_dir = "data/gym_example/{}".format(
        datetime.datetime.today().strftime("%Y%m%d_%H%M"))
    with logger_context(log_dir,
                        run_ID,
                        name,
                        log_params=config,
                        snapshot_mode="last",
                        use_summary_writer=True,
                        override_prefix=True):
        runner.train()


if __name__ == "__main__":
    import argparse

    # Command-line entry point. Each option's dest (env_id, run_ID, cuda_idx)
    # is exactly the name of a build_and_train keyword parameter, so the
    # parsed namespace can be forwarded wholesale.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument("--env_id", default='Pendulum-v0',
                     help='environment ID')
    cli.add_argument("--run_ID", type=int, default=0,
                     help='run identifier (logging)')
    cli.add_argument("--cuda_idx", type=int, default=None,
                     help='gpu to use ')
    build_and_train(**vars(cli.parse_args()))
