#!/usr/bin/env python 
# -*- coding: utf-8 -*- 
# @Time : 2020/11/30 20:53
# @Author : wangzhaorong
# @Site :
# @File : main.py
# @Software: PyCharm
import os
import sys

# Extend the module search path before the remaining imports run.
# NOTE(review): './../__file__' is a plain string literal, not this module's
# __file__ attribute. Its dirname is './..', so `path` resolves to the parent
# of the current working directory rather than of this script's location --
# confirm this is intended (the script then depends on where it is launched).
path = os.path.abspath(os.path.dirname('./../__file__'))
sys.path.append(path)
import time
import random
import copy
import numpy as np
import argparse
from baselines import logger
import tensorflow as tf
import common.grid_world as grid_world
from queue import Queue
from common.config import Config
from ppo import Policy
from ppo import ReplayBuffer
from seq2seq import Seq2seq
from qnet import Posterior


# NOTE(review): PYTHONHASHSEED is read once at interpreter start-up, so
# assigning it here does not change hash randomization for this running
# process; the value is only inherited by child processes. Export it in the
# shell before launching the script if deterministic hashing is required.
os.environ['PYTHONHASHSEED'] = '0'


def _latest_checkpoint_path(model_dir, label):
    """Return the latest checkpoint path recorded under ``model_dir``.

    Args:
        model_dir: Directory passed to ``tf.train.get_checkpoint_state``.
        label: Human-readable model name used in error messages.

    Raises:
        FileNotFoundError: If no checkpoint state is found in ``model_dir``
            or the recorded checkpoint does not exist.
    """
    ckpt = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when the directory has no checkpoint
    # state, so guard before dereferencing it.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError(
            "No {} checkpoint state found in '{}'.".format(label, model_dir))
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        raise FileNotFoundError(
            "{} checkpoint '{}' does not exist.".format(
                label, ckpt.model_checkpoint_path))
    return ckpt.model_checkpoint_path


def main(args):
    """Restore the seq2seq model and the agent, then run evaluation episodes.

    For each of ``args.num_episodes`` episodes the environment is reset and
    stepped with actions from ``agent.test_action`` until the environment
    reports ``done`` or the step count reaches ``config.solving_criteria[0]``.
    Every step renders the environment together with the seq2seq decoder
    output and prints the state transition.

    Args:
        args: Parsed command-line namespace (see ``args_parse``).

    Raises:
        FileNotFoundError: If either checkpoint directory holds no usable
            checkpoint.
    """
    set_random_seed(args.seed)
    # create env
    env, config = get_env_config(args)
    # create policy
    seq2seq = Seq2seq(args.max_encode_length, args.max_decode_length,
                      args.vocabulary_dim, args.num_layers, args.num_units, mode='infer')
    seq2seq_ckpt_path = _latest_checkpoint_path(args.seq2seq_model_dir, 'seq2seq')
    print('Reloading seq2seq model parameters..')
    seq2seq.saver.restore(seq2seq.sess, seq2seq_ckpt_path)
    agent = Policy(args=args)
    agent.construct_model(gpu=args.gpu)
    agent_ckpt_path = _latest_checkpoint_path(args.ppo_load_path, 'agent')
    print('Reloading agent model parameters..')
    agent.saver.restore(agent.sess, agent_ckpt_path)
    # Eval
    for ep_cur in range(args.num_episodes):
        state = env.reset()
        encode_input_state = init_encode_input(seq2seq)
        # Scale every state component by env.n_height before feeding the agent.
        state = [e / env.n_height for e in state]
        ep_length, ep_reward = 0, 0
        while True:
            ep_length += 1
            encode_state = get_encode_state(encode_input_state, seq2seq)
            # First row of the second encoder-state component is the extra
            # input handed to the agent alongside the state.
            encode = list(encode_state[1])[0]
            out_list = get_decode_out(args, encode_state, seq2seq)
            env.render(extra_info=out_list)
            action = agent.test_action(state, encode)
            next_state, reward, done, _ = env.step(action)
            print("state:{}， next_state：{}".format(state, next_state))

            # Presumably slows the loop so each rendered frame can be watched.
            time.sleep(0.5)
            # The history records the (scaled) state that produced the action.
            update_encode_input(state, action, encode_input_state)
            next_state = [e / env.n_height for e in next_state]
            state = next_state
            ep_reward += reward
            if done or ep_length >= config.solving_criteria[0]:
                break
        print("Episode: {}, Path length:{}, Reward:{}.".format(ep_cur, ep_length, ep_reward))


def args_parse(argv=None):
    """Build the command-line parser and parse the evaluation options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, as before.

    Returns:
        argparse.Namespace holding the run, algorithm and seq2seq settings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--algo', type=str, default='ppo', help='algo name')
    parser.add_argument('--seq2seq_model_dir', type=str, default='seq2seq/model/',
                        help='directory holding the seq2seq checkpoint')
    parser.add_argument('--ppo_load_path', type=str,
                        default='./ppo/ppo_seed_0_ts_1617641161/checkpoint/',
                        help='directory holding the PPO agent checkpoint')
    parser.add_argument('--env', type=str, default='two_way_grid_world', help='env')
    parser.add_argument('--gpu', type=int, default=-1, help='running on a specify gpu, -1 indicates using cpu')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--num_episodes', type=int, default=5, help='num_episodes')
    # for algo
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--epsilon', type=float, default=0.2, help='epsilon')
    parser.add_argument('--var_beta', type=float, default=0.2, help='var_beta')
    parser.add_argument('--v_coef', type=float, default=1.0, help='v_coef')
    parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
    # The option name keeps its original spelling because it defines the
    # attribute name (args.lammbda) that other code may read.
    parser.add_argument('--lammbda', type=float, default=0.95, help='lammbda')
    parser.add_argument('--save_step_count', type=int, default=1000, help='save step')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_total_sizes', type=int, default=1024, help='num_total_sizes')
    parser.add_argument('--noptepochs', type=int, default=10, help='noptepochs')
    # for seq2seq
    parser.add_argument('--max_encode_length', type=int, default=4, help='max_encode_length')
    parser.add_argument("--max_decode_length", type=int, default=4, help="max_decode_length")
    parser.add_argument("--vocabulary_dim", type=int, default=5, help="vocabulary_dim")
    parser.add_argument("--num_layers", default=1, type=int, help="num_layers")
    parser.add_argument("--num_units", default=128, type=int, help="num_units")
    return parser.parse_args(argv)


def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and TensorFlow with the same value."""
    seeders = (random.seed, np.random.seed, tf.set_random_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def get_env_config(args):
    """Create the environment named by ``args.env`` and its ``Config``.

    Args:
        args: Namespace with an ``env`` attribute holding the environment name.

    Returns:
        Tuple ``(env, config)``.

    Raises:
        ValueError: If ``args.env`` is not a supported environment name.
    """
    # Fail with a clear message instead of falling through to an
    # UnboundLocalError on the return statement.
    if args.env != 'two_way_grid_world':
        raise ValueError(
            "Unsupported env '{}'; expected 'two_way_grid_world'.".format(args.env))
    env = grid_world.TwoWayGridWorld()
    config = Config(env, 'two_way_grid_world')
    return env, config


def init_encode_input(seq2seq):
    """Return a bounded queue pre-filled with placeholder encoder inputs.

    The queue capacity is ``seq2seq.max_encode_length`` and it starts with
    that many independent copies of the same five-value placeholder frame.
    """
    pad_frame = (0., 0., 0.9, 0.9, 0.1)
    capacity = seq2seq.max_encode_length
    history = Queue(maxsize=capacity)
    for _ in range(capacity):
        # Each slot gets its own list object so entries never alias.
        history.put(list(pad_frame))
    return history


def get_encode_state(encode_input_state, seq2seq):
    """Feed the buffered input history to the encoder and return its state.

    The queue is read without being consumed. The history is wrapped as a
    batch of one and passed with its length to
    ``seq2seq.predict_encode_state``, whose result is returned unchanged.
    """
    # Snapshot the queue contents; no entries are removed.
    history = list(encode_input_state.queue)
    return seq2seq.predict_encode_state({
        "encode_input": [history],
        "encode_sequence_length": [len(history)],
    })


def get_decode_out(args, encode_state, seq2seq):
    """Unroll the decoder for ``args.max_decode_length`` steps.

    Decoding starts from a five-element input of -1 values and the two
    components of ``encode_state``. After each step the decoder's own output
    and returned state become the next step's input. The collected outputs
    are rounded to one decimal place, printed, and returned as a list of
    lists.
    """
    feed = {"infer_encoder_input_tensor": [-1, -1, -1, -1, -1],
            "infer_c": encode_state[0],
            "infer_h": encode_state[1]}
    raw_steps = []
    for _ in range(args.max_decode_length):
        step_out, step_state = seq2seq.predict_decode_out(feed)
        latest = step_out[0][0]
        raw_steps.append(copy.deepcopy(list(latest)))
        # Re-use the same feed dict, now pointing at this step's results.
        feed.update({
            "infer_encoder_input_tensor": latest,
            "infer_c": step_state[0],
            "infer_h": step_state[1],
        })
    # Rounding happens only after every decode step has run.
    out_list = [[round(value, 1) for value in step] for step in raw_steps]
    print("decode out:", out_list)
    return out_list


def update_encode_input(state, action, encode_input_state):
    """Push the latest (state, action) record into the encoder input queue.

    The record is a deep copy of ``state`` with ``action / 10`` appended, so
    the caller's ``state`` is left untouched. If the queue is already full,
    its oldest entry is dropped first.
    """
    newest = copy.deepcopy(state)
    newest.append(action / 10)
    history = encode_input_state
    if history.full():
        # Evict the oldest record to make room; put() would block otherwise.
        history.get()
    history.put(newest)


if __name__ == "__main__":
    # Script entry point: parse the command-line options, then evaluate.
    cli_args = args_parse()
    main(cli_args)
