"""A training script of TD3 on OpenAI Gym Mujoco environments.
This script follows the settings of http://arxiv.org/abs/1802.09477 as much
as possible.
"""

import argparse
import logging
import os
import sys
import csv

import chainer
from chainer import functions as F
from chainer import links as L
from chainer import optimizers
import gym
import gym.wrappers
import numpy as np

import chainerrl
from chainerrl import experiments
from chainerrl import explorers
from chainerrl import misc
from chainerrl import replay_buffer

def obs2state(obs, status):
    return np.hstack((obs, np.array([status])))

def concat_obs_and_action(obs, action):
    """Concat observation and action to feed the critic."""
    return F.concat((obs, action), axis=-1)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--outdir', type=str, default='results_td3_traj',
                        help='Directory path to save output files.'
                             ' If it does not exist, it will be created.')
    parser.add_argument('--env', type=str, default='BipedalWalker-v2',
                        help='OpenAI Gym MuJoCo env to perform algorithm on.')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU to use, set to -1 if no GPU.')
    parser.add_argument('--load', type=str, default='./best_exp',
                        help='Directory to load agent from.')
    parser.add_argument('--steps', type=int, default=10 ** 6,
                        help='Total number of timesteps to train the agent.')
    parser.add_argument('--eval-n-runs', type=int, default=10,
                        help='Number of episodes run for each evaluation.')
    parser.add_argument('--eval-interval', type=int, default=10000,
                        help='Interval in timesteps between evaluations.')
    parser.add_argument('--replay-start-size', type=int, default=10000,
                        help='Minimum replay buffer size before ' +
                             'performing gradient updates.')
    parser.add_argument('--batch-size', type=int, default=1000,
                        help='Minibatch size')
    parser.add_argument('--render', action='store_true',
                        help='Render env states in a GUI window.')
    parser.add_argument('--demo', action='store_true', default=False,
                        help='Just run evaluation, not training.')
    parser.add_argument('--load-pretrained', action='store_true',
                        default=False)
    parser.add_argument('--pretrained-type', type=str, default="best",
                        choices=['best', 'final'])
    parser.add_argument('--monitor', action='store_true',
                        help='Wrap env with gym.wrappers.Monitor.')
    parser.add_argument('--logger-level', type=int, default=logging.INFO,
                        help='Level of the root logger.')
    parser.add_argument('--gamma', type=int, default=0.99,
                        help='The discount rate for MDP.')
    args = parser.parse_args()

    logging.basicConfig(level=args.logger_level)

    args.outdir = experiments.prepare_output_dir(
        args, args.outdir, argv=sys.argv)
    print('Output files are saved in {}'.format(args.outdir))

    # Set a random seed used in ChainerRL
    misc.set_random_seed(args.seed, gpus=(args.gpu,))

    def make_env(test):
        env = gym.make(args.env)
        # Unwrap TimeLimit wrapper
        #assert isinstance(env, gym.wrappers.TimeLimit)
        #env = env.env
        # Use different random seeds for train and test envs
        env_seed = 2 ** 32 - 1 - args.seed if test else args.seed
        env.seed(env_seed)
        # Cast observations to float32 because our model uses float32
        env = chainerrl.wrappers.CastObservationToFloat32(env)
        if args.monitor:
            env = chainerrl.wrappers.Monitor(env, args.outdir)
        if args.render and test:
            env = chainerrl.wrappers.Render(env)
        return env

    env = make_env(test=False)
    timestep_limit = env.spec.max_episode_steps
    obs_space = env.observation_space
    action_space = env.action_space
    print('Observation space:', obs_space)
    print('Action space:', action_space)

    action_size = action_space.low.size

    winit = chainer.initializers.LeCunUniform(3 ** -0.5)

    policy = chainer.Sequential(
        L.Linear(None, 400, initialW=winit),
        F.relu,
        L.Linear(None, 300, initialW=winit),
        F.relu,
        L.Linear(None, action_size, initialW=winit),
        F.tanh,
        chainerrl.distribution.ContinuousDeterministicDistribution,
    )
    policy_optimizer = optimizers.Adam().setup(policy)

    def make_q_func_with_optimizer():
        q_func = chainer.Sequential(
            concat_obs_and_action,
            L.Linear(None, 400, initialW=winit),
            F.relu,
            L.Linear(None, 300, initialW=winit),
            F.relu,
            L.Linear(None, 1, initialW=winit),
        )
        q_func_optimizer = optimizers.Adam().setup(q_func)
        return q_func, q_func_optimizer

    q_func1, q_func1_optimizer = make_q_func_with_optimizer()
    q_func2, q_func2_optimizer = make_q_func_with_optimizer()

    # Draw the computational graph and save it in the output directory.
    fake_obs = chainer.Variable(
        policy.xp.zeros_like(obs_space.low, dtype=np.float32)[None],
        name='observation')
    fake_action = chainer.Variable(
        policy.xp.zeros_like(action_space.low, dtype=np.float32)[None],
        name='action')
    chainerrl.misc.draw_computational_graph(
        [policy(fake_obs)], os.path.join(args.outdir, 'policy'))
    chainerrl.misc.draw_computational_graph(
        [q_func1(fake_obs, fake_action)], os.path.join(args.outdir, 'q_func1'))
    chainerrl.misc.draw_computational_graph(
        [q_func2(fake_obs, fake_action)], os.path.join(args.outdir, 'q_func2'))

    rbuf = replay_buffer.ReplayBuffer(10 ** 6)

    explorer = explorers.AdditiveGaussian(
        scale=0.1, low=action_space.low, high=action_space.high)

    def burnin_action_func():
        """Select random actions until model is updated one or more times."""
        return np.random.uniform(
            action_space.low, action_space.high).astype(np.float32)

    # Hyperparameters in http://arxiv.org/abs/1802.09477
    agent = chainerrl.agents.TD3(
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        rbuf,
        gamma=args.gamma,
        soft_update_tau=5e-3,
        explorer=explorer,
        replay_start_size=args.replay_start_size,
        gpu=args.gpu,
        minibatch_size=args.batch_size,
        burnin_action_func=burnin_action_func,
    )

    agent.load(args.load)

    if os.path.exists("./dataset_train") != True:
        os.mkdir("./dataset_train")

    # if os.path.exists("./dataset_val") != True:
    #     os.mkdir("./dataset_val")


    f = open("./dataset_train/exp.csv", "w")
    writer = csv.writer(f, lineterminator="\n")

    n_episodes = 1000

    env = make_env(test=True)

    for i in range(n_episodes):
        obs = env.reset()
        writer.writerow(obs2state(obs, 0))
        done = False
        reward_sum = 0

        while not done:
            act = agent.act(obs)
            obs, reward, done, _ = env.step(act)
            reward_sum += reward
            if done:
                writer.writerow(obs2state(obs, 1))
            else:
                writer.writerow(obs2state(obs, 0))

        print("Reward sum: ", reward_sum)

    # f = open("./dataset_val/exp.csv", "w")
    # writer = csv.writer(f, lineterminator="\n")
    #
    # env = make_env(test=True)
    #
    # for i in range(n_episodes):
    #     obs = env.reset()
    #     writer.writerow(obs2state(obs, 0))
    #     done = False
    #     reward_sum = 0
    #
    #     while not done:
    #         act = agent.act(obs)
    #         obs, reward, done, _ = env.step(act)
    #         reward_sum += reward
    #         if done:
    #             writer.writerow(obs2state(obs, 1))
    #         else:
    #             writer.writerow(obs2state(obs, 0))
    #
    #     print("Reward sum: ", reward_sum)



if __name__ == '__main__':
    main()
