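"""An example of recording expert trajectories with a PPO agent on OpenAI Gym Envs.
This script builds a PPO agent, loads saved weights (see ``--load``) and rolls
the agent out in an OpenAI Gym env, writing every visited observation together
with an end-of-episode flag to ``./dataset_train/exp.csv`` and
``./dataset_val/exp.csv``. No training is performed here.
Classes for both discrete and continuous policies are defined, but ``main``
only builds the discrete (softmax) one.
To record trajectories on CartPole-v0, run:
    python train_ppo_gym.py --env CartPole-v0
"""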
import argparse
import os
import sys
import csv

import chainer
from chainer import functions as F
import gym

import numpy as np

import chainerrl
from chainerrl.agents import a3c
from chainerrl.agents import PPO
from chainerrl import experiments
from chainerrl import links
from chainerrl import misc
from chainerrl.optimizers.nonbias_weight_decay import NonbiasWeightDecay
from chainerrl import policies


def obs2state(obs, status):
    """Return ``obs`` with ``status`` appended as one extra trailing element.

    ``status`` is wrapped in a one-element array and stacked onto ``obs``
    with ``np.hstack``; the stacked array is returned.
    """
    status_column = np.array([status])
    stacked = np.hstack((obs, status_column))
    return stacked

class A3CFFSoftmax(chainer.ChainList, a3c.A3CModel):
    """Feedforward A3C model with a softmax policy and a value function.

    The policy (``self.pi``) and the value function (``self.v``) are two
    separate MLPs that share the same ``hidden_sizes``. Both are registered
    as children of this ``ChainList``, policy first.
    """

    def __init__(self, ndim_obs, n_actions, hidden_sizes=(200, 200)):
        # The policy network is created before the value network; keep this
        # order so parameter initialization happens in the same sequence.
        policy_mlp = links.MLP(ndim_obs, n_actions, hidden_sizes)
        self.pi = policies.SoftmaxPolicy(model=policy_mlp)
        value_mlp = links.MLP(ndim_obs, 1, hidden_sizes=hidden_sizes)
        self.v = value_mlp
        super().__init__(self.pi, self.v)

    def pi_and_v(self, state):
        """Return ``(self.pi(state), self.v(state))`` for the given state."""
        policy_out = self.pi(state)
        value_out = self.v(state)
        return policy_out, value_out

class A3CFFGaussian(chainer.Chain, a3c.A3CModel):
    """An example of A3C feedforward Gaussian policy.

    The policy (``self.pi``) is a fully connected Gaussian policy with a
    state-independent diagonal covariance and tanh nonlinearity; the value
    function (``self.v``) is a separate MLP with hidden sizes (100, 50, 25).

    Args:
        obs_size: Size passed to both networks as their input size.
        action_space: Object providing ``low`` and ``high`` arrays; they give
            the action size (``low.size``) and the min/max actions.
        n_hidden_layers: Number of hidden layers of the policy network.
        n_hidden_channels: Number of channels per hidden policy layer.
        bound_mean: Must be ``True`` or ``False``; forwarded to the policy.
            NOTE: the default ``None`` does not pass the check below, so
            callers have to pass this argument explicitly.

    Raises:
        AssertionError: If ``bound_mean`` is neither ``True`` nor ``False``.
    """

    def __init__(self, obs_size, action_space,
                 n_hidden_layers=2, n_hidden_channels=64,
                 bound_mean=None):
        assert bound_mean in [False, True], \
            'bound_mean must be explicitly set to True or False'
        super().__init__()
        with self.init_scope():
            self.pi = policies.FCGaussianPolicyWithStateIndependentCovariance(
                obs_size, action_space.low.size,
                n_hidden_layers, n_hidden_channels,
                var_type='diagonal', nonlinearity=F.tanh,
                bound_mean=bound_mean,
                min_action=action_space.low, max_action=action_space.high,
                mean_wscale=1e-2)
            # The value network uses its own fixed layout, independent of
            # n_hidden_layers / n_hidden_channels.
            v_hidden_sizes = (100, 50, 25)
            self.v = links.MLP(obs_size, 1, hidden_sizes=v_hidden_sizes)

    def pi_and_v(self, state):
        """Return ``(self.pi(state), self.v(state))`` for the given state."""
        return self.pi(state), self.v(state)


def _record_trajectories(agent, env, csv_path, n_episodes):
    """Roll out ``agent`` in ``env`` and write the visited states to a CSV.

    One row is written per observation (including the one returned by
    ``env.reset()``). Each row is the observation followed by a status flag
    built with ``obs2state``: 1 for the observation on which the episode
    ended, 0 otherwise. The summed reward of each episode is printed.

    Args:
        agent: Object whose ``act(obs)`` returns an action for ``env.step``.
        env: Env with ``reset()`` and a ``step(action)`` that returns
            ``(obs, reward, done, info)``.
        csv_path: Path of the CSV file to (over)write.
        n_episodes: Number of episodes to record.
    """
    # newline='' is what the csv module asks for; line endings are then
    # decided solely by the writer's lineterminator. The with-block makes
    # sure the file is flushed and closed even if a rollout fails.
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")

        for _ in range(n_episodes):
            obs = env.reset()
            writer.writerow(obs2state(obs, 0))
            done = False
            reward_sum = 0

            while not done:
                act = agent.act(obs)
                obs, reward, done, _info = env.step(act)
                reward_sum += reward
                writer.writerow(obs2state(obs, 1 if done else 0))

            print("Reward sum: ", reward_sum)


def main():
    """Load a saved PPO agent and record its trajectories to CSV files.

    Parses the command-line arguments, builds the softmax-policy PPO agent,
    loads its weights from ``--load`` and records 1000 episodes each into
    ``./dataset_train/exp.csv`` and ``./dataset_val/exp.csv``.
    """
    import logging

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--env', type=str, default='CartPole-v1')
    parser.add_argument('--bound-mean', default=True)
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--outdir', type=str, default='results_ppo_traj',
                        help='Directory path to save output files.'
                             ' If it does not exist, it will be created.')
    parser.add_argument('--steps', type=int, default=10 ** 6)
    parser.add_argument('--eval-interval', type=int, default=10000)
    parser.add_argument('--eval-n-runs', type=int, default=10)
    parser.add_argument('--reward-scale-factor', type=float, default=1e-2)
    parser.add_argument('--standardize-advantages', action='store_true')
    parser.add_argument('--render', action='store_true', default=False)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--weight-decay', type=float, default=0.0)
    parser.add_argument('--demo', action='store_true', default=True)
    parser.add_argument('--load', type=str, default='./best_exp')
    parser.add_argument('--logger-level', type=int, default=logging.INFO)
    parser.add_argument('--monitor', action='store_true')
    parser.add_argument('--update-interval', type=int, default=2048)
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--entropy-coef', type=float, default=0.0)
    args = parser.parse_args()

    logging.basicConfig(level=args.logger_level)

    # Set a random seed used in ChainerRL
    misc.set_random_seed(args.seed, gpus=(args.gpu,))

    args.outdir = experiments.prepare_output_dir(args, args.outdir)

    def make_env(test):
        """Create the (optionally wrapped) Gym env named by ``--env``."""
        env = gym.make(args.env)
        # Use different random seeds for train and test envs
        env_seed = 2 ** 32 - 1 - args.seed if test else args.seed
        env.seed(env_seed)
        # Cast observations to float32 because our model uses float32
        env = chainerrl.wrappers.CastObservationToFloat32(env)
        if args.monitor:
            env = chainerrl.wrappers.Monitor(env, args.outdir)
        if not test:
            # Scale rewards (and thus returns) to a reasonable range so that
            # training is easier
            env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
        if args.render:
            env = chainerrl.wrappers.Render(env)
        return env

    # Only used to read the observation/action space definitions.
    sample_env = gym.make(args.env)

    obs_space = sample_env.observation_space
    action_space = sample_env.action_space

    # Normalize observations based on their empirical mean and variance
    obs_normalizer = chainerrl.links.EmpiricalNormalization(
        obs_space.low.size, clip_threshold=5)

    # Only the softmax (discrete-action) policy is built here, so the env's
    # action space has to provide ``n``.
    model = A3CFFSoftmax(obs_space.low.size, action_space.n)

    opt = chainer.optimizers.SMORMS3()
    opt.setup(model)
    if args.weight_decay > 0:
        opt.add_hook(NonbiasWeightDecay(args.weight_decay))
    agent = PPO(model, opt,
                obs_normalizer=obs_normalizer,
                gpu=args.gpu,
                update_interval=args.update_interval,
                minibatch_size=args.batchsize, epochs=args.epochs,
                clip_eps_vf=None, entropy_coef=args.entropy_coef,
                standardize_advantages=args.standardize_advantages,
                gamma=0.99,
                )

    agent.load(args.load)

    n_episodes = 1000
    # Train dataset first, then validation dataset.
    dataset_dirs = ("./dataset_train", "./dataset_val")

    # Make dirs if not exist (exist_ok avoids the check-then-create race).
    for dataset_dir in dataset_dirs:
        os.makedirs(dataset_dir, exist_ok=True)

    # Record expert trajectories; each dataset gets a freshly created test
    # env, and the same agent instance is used for both.
    for dataset_dir in dataset_dirs:
        env = make_env(test=True)
        _record_trajectories(
            agent, env, dataset_dir + "/exp.csv", n_episodes)


# Run main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
