from util import Environment, default_circuit
from q_func import QFunction, RPQFunction, RPDQN
import numpy as np
import chainer
import chainerrl
from estimator import ThreatEstimator, RecoNet
from chainerrl.wrappers import ScaleReward, CastObservationToFloat32
from statistics import mean
import argparse


if __name__ == '__main__':
    # Command-line interface of the script.  All options are read back through
    # `args` by the setup and mode-dispatch code that follows.
    parser = argparse.ArgumentParser(prog='run.py', description='run learning')

    # --- execution setup -------------------------------------------------
    # --gpu  : seed/use GPU 0 and hand the flag to the threat estimator.
    # --seed : value given to chainerrl.misc.set_random_seed.
    # --load : path passed to agent.load() in demo and training modes.
    parser.add_argument('--gpu', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--load', type=str, default='')

    # --normal : build QFunction + DoubleDQN instead of RPQFunction + RPDQN.
    parser.add_argument('--normal', action='store_true')

    # --- learning hyper-parameters ---------------------------------------
    # --adameps / --adamalpha : `eps` and `alpha` of chainer.optimizers.Adam.
    # --gamma                 : gamma argument of the agent constructor.
    # --alllog / --lmd        : forwarded to Environment as all_log / lmd.
    # --scale                 : factor given to the ScaleReward wrapper.
    parser.add_argument('--adameps', dest='adam_eps', type=float, default=1e-2)
    parser.add_argument(
        '--adamalpha', dest='adam_alpha', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.90)
    parser.add_argument('--alllog', dest='all_log', action='store_true')
    parser.add_argument('--lmd', type=int, default=200)
    parser.add_argument('--scale', type=float, default=1.0)

    # --firsteps : first argument of LinearDecayEpsilonGreedy.
    # --step     : training step budget, also used as the epsilon decay length.
    parser.add_argument('--firsteps', type=float, default=1.0)
    parser.add_argument('--step', type=int, default=3 * 10 ** 6)

    # --- run modes --------------------------------------------------------
    # --demo   : roll out episodes with rendering instead of training.
    # --render : forwarded to Environment as render.
    # --eval   : run name; agents are loaded from 'agents/<name>/'.
    # -t       : number of episodes per agent in demo and eval modes.
    parser.add_argument('--demo', action='store_true')

    parser.add_argument('--render', dest='ren', action='store_true')
    parser.add_argument('--eval', type=str, default='')
    parser.add_argument('-t', dest='times', type=int, default=100)

    args = parser.parse_args()

    # GPU ids handed to the seeding helper: device 0 with --gpu, none otherwise.
    if args.gpu:
        gpus = (0,)
    else:
        gpus = ()

    chainerrl.misc.set_random_seed(args.seed, gpus)

    # Training environment; the initial state is randomised except in demo mode.
    circuit = default_circuit()
    rand = not args.demo
    env = Environment(
        circuit=circuit,
        random_init=rand,
        file='crash_train.log',
        all_log=args.all_log,
        lmd=args.lmd,
        render=args.ren,
    )

    # Size of the discrete action set, read before the env gets wrapped.
    n_actions = len(env.agent.action_list)

    env = ScaleReward(env, args.scale)

    # Threat estimator built from the model file 'circuit/threat.model'.
    reconet = RecoNet()
    estimator = ThreatEstimator(reconet, 'circuit/threat.model', args.gpu)

    danger_limit = 1e-3  # threshold forwarded to RPQFunction
    step = args.step

    # --normal selects the plain Q-function; otherwise the estimator-aware one.
    q_func = (QFunction(n_actions) if args.normal
              else RPQFunction(n_actions, estimator, danger_limit))

    optimizer = chainer.optimizers.Adam(
        alpha=args.adam_alpha, eps=args.adam_eps)
    optimizer.setup(q_func)

    def random_action():
        """Return a uniformly drawn action index in [0, n_actions)."""
        return np.random.randint(n_actions)

    # Epsilon goes linearly from --firsteps down to 0.05 over `step` steps.
    explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(
        args.firsteps, 0.05, step, random_action_func=random_action)

    replay_buffer = chainerrl.replay_buffer.PrioritizedReplayBuffer(1e6)

    # Both agent classes are constructed with the same arguments, so pick the
    # class first and call it once.
    agent_cls = chainerrl.agents.DoubleDQN if args.normal else RPDQN
    agent = agent_cls(
        q_func, optimizer, replay_buffer, args.gamma, explorer,
        clip_delta=False,
        replay_start_size=600,
        update_interval=1,
        target_update_interval=1e3,
    )

    # Expose the agent on the raw (unwrapped) environment.
    # NOTE(review): presumably used by Environment to reach the agent (e.g. to
    # save snapshots) -- confirm in util.Environment.
    env.unwrapped.result_agent = agent

    def run_episode(episode_agent, episode_env, max_steps=200, render=False):
        """Roll out one episode and return the summed reward.

        Actions come from ``episode_agent.act``.  The episode stops as soon as
        the environment reports ``done`` or after ``max_steps`` steps,
        whichever happens first.  With ``render=True`` the unwrapped
        environment is rendered after every step.
        """
        obs = episode_env.reset()
        total = 0
        for _ in range(max_steps):
            action = episode_agent.act(obs)
            obs, r, done, _info = episode_env.step(action)
            if render:
                episode_env.unwrapped.render()
            total += r
            if done:
                break
        return total

    if args.demo:
        # Demo mode: optionally load an agent, then play rendered episodes.
        if args.load:
            agent.load(args.load)

        for _ in range(args.times):
            total = run_episode(agent, env, render=True)
            print('Reward:', total)

    elif args.eval:
        # Evaluation mode: measure crash ratio and mean reward of every saved
        # agent of the run named by --eval and store both curves as .npy files.
        import os  # local import: only this mode touches the filesystem itself

        def gen_dir_name(jobid):
            """Yield the agent directories of run *jobid* in training order.

            The first item is the empty string, meaning "do not load
            anything" (the agent as constructed).  It is followed by
            'agents/<jobid>/agent<k>' for each intermediate snapshot and
            finally 'agents/<jobid>/<step>_finish'.
            """
            times = step // 10**5
            yield ''
            dirname = 'agents/' + jobid + '/'
            for i in range(times - 1):
                yield dirname + 'agent' + str(i + 1)
            yield dirname + str(int(step)) + '_finish'

        agent_dirs = list(gen_dir_name(args.eval))

        # Training step of each evaluated agent: 0 for the unloaded agent, one
        # entry per 10**5 steps for the snapshots, and `step` for the final
        # one.  Deriving it from `agent_dirs` keeps both sequences the same
        # length even when `step` is below or not a multiple of 10**5.
        steps = np.array(
            [0]
            + [10**5 * k for k in range(1, len(agent_dirs) - 1)]
            + [step])

        # Create the output directory up front so a missing 'results/' cannot
        # discard the evaluation at the final np.save calls.
        os.makedirs('results', exist_ok=True)

        crash_ratio = []
        reward_list = []

        for agent_dir_name in agent_dirs:
            if agent_dir_name:
                agent.load(agent_dir_name)
            print('agent:', agent_dir_name)

            # Fresh environment per agent so its crash counter starts over.
            env = Environment(circuit=circuit,
                              random_init=True, file='crash_train.log',
                              all_log=args.all_log, lmd=args.lmd)

            # Rewards of the episodes that did not end in a crash.
            total_episode_reward = []

            for _ in range(args.times):
                total = run_episode(agent, env)
                if not env.crashed:
                    total_episode_reward.append(total)

            # NaN marks agents for which every episode crashed.
            ave_reward = (mean(total_episode_reward)
                          if total_episode_reward else np.nan)
            ratio = env.crash_cnt / args.times

            print('result: crash_cnt ', ratio,
                  ' pure_reward ', ave_reward, end='\n\n')
            crash_ratio.append(ratio)
            reward_list.append(ave_reward)

        # Each result is a 2 x N array: row 0 = training steps, row 1 = metric.
        crash_ratio = np.array(crash_ratio)
        reward_list = np.array(reward_list)
        data = np.vstack((steps, crash_ratio))
        data2 = np.vstack((steps, reward_list))
        print(data)
        np.save('results/crash.npy', data)
        print(data2)
        np.save('results/reward.npy', data2)

    else:
        # Training mode: optionally resume from --load, then train for `step`
        # steps, evaluating one episode every 1e4 steps on a separate env.
        if args.load:
            agent.load(args.load)

        # NOTE(review): the evaluation env always uses lmd=200 and
        # all_log=True, independent of --lmd / --alllog -- confirm this is
        # intended.
        chainerrl.experiments.train_agent_with_evaluation(
            agent, env, steps=step, eval_n_steps=None, eval_n_episodes=1,
            train_max_episode_len=200, eval_interval=1e4, outdir='results',
            eval_env=Environment(circuit=circuit,  file='crash_test.log', all_log=True, lmd=200))
