#!/usr/bin/env python 
# -*- coding: utf-8 -*- 
# @Time : 2020/11/30 20:53 
# @Author : wangzhaorong
# @Site :  
# @File : main.py 
# @Software: PyCharm
import os
import sys

# Make sibling packages importable by appending a directory to sys.path.
# NOTE(review): '__file__' below is a string literal, not the module's
# __file__ variable, so dirname() yields './..' and the appended entry is the
# parent of the *current working directory*, not of this file. Confirm that
# the script is always launched from the expected directory.
path = os.path.abspath(os.path.dirname('./../__file__'))
sys.path.append(path)
import time
import random
import copy
import numpy as np
import argparse
from baselines import logger
import tensorflow as tf
import common.grid_world as grid_world
from queue import Queue
from common.config import Config
from ppo import Policy
from ppo import ReplayBuffer
from seq2seq import Seq2seq
from qnet import Posterior

# NOTE(review): PYTHONHASHSEED is read by the interpreter at startup, so
# assigning it here is only inherited by child processes; it does not change
# hash randomization of the already-running process. Export it in the shell
# before launching if deterministic str/bytes hashing is required.
os.environ['PYTHONHASHSEED'] = '0'


def main(args):
    """Run the PPO training loop on the configured environment.

    The function seeds the RNGs, builds the environment, restores a
    pre-trained seq2seq model, builds the PPO policy and the posterior
    network, and then collects transitions episode by episode. Whenever the
    replay buffer reports ``enough_data`` the policy and the posterior are
    updated, and every ``args.save_step_count`` updates the policy is
    checkpointed under ``ppo/<algo>_seed_<seed>_ts_<timestamp>/checkpoint/``.

    Args:
        args: Parsed command-line options as produced by ``args_parse``.

    Raises:
        FileNotFoundError: If no seq2seq checkpoint is found in
            ``args.seq2seq_model_dir``, or if ``args.ppo_load_path`` is a
            directory that contains no checkpoint.
    """
    set_random_seed(args.seed)

    # create env
    env, config = get_env_config(args)

    # create policy
    seq2seq = Seq2seq(args.max_encode_length, args.max_decode_length,
                      args.vocabulary_dim, args.num_layers, args.num_units, mode='train')
    ckpt = tf.train.get_checkpoint_state(args.seq2seq_model_dir)
    # get_checkpoint_state returns None when the directory has no checkpoint;
    # validate explicitly instead of with an assert (asserts vanish under -O).
    if ckpt is None or not tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        raise FileNotFoundError(
            'No seq2seq checkpoint found in %s' % args.seq2seq_model_dir)
    print('Reloading seq2seq model parameters..')
    seq2seq.saver.restore(seq2seq.sess, ckpt.model_checkpoint_path)
    agent = Policy(args=args)
    agent.construct_model(gpu=args.gpu)
    posterior = Posterior(args=args)
    posterior.construct_model(gpu=args.gpu)
    posterior.sess.run(tf.global_variables_initializer())
    posterior.sync_params_op()

    # create replay_buffer
    replay_buffer = ReplayBuffer(num_total_sizes=args.num_total_sizes, act_dims=env.action_space.n,
                                 obs_dims=len(env.state), encode_dims=seq2seq.num_units)

    # create dir
    ts = int(time.time())
    base_path = 'ppo/' + args.algo + '_seed_' + str(args.seed) + '_ts_' + str(ts) + '/'
    logger.configure(dir=base_path + "tb_event/")
    save_path = base_path + "checkpoint/"
    if args.ppo_load_path is not None:
        # Restore from the user-supplied path. save_path belongs to this
        # brand-new run and cannot contain a checkpoint yet.
        load_path = args.ppo_load_path
        if os.path.isdir(load_path):
            # A directory was given: pick the most recent checkpoint in it.
            latest = tf.train.latest_checkpoint(load_path)
            if latest is None:
                raise FileNotFoundError('No PPO checkpoint found in %s' % load_path)
            load_path = latest
        agent.saver.restore(agent.sess, load_path)
    else:
        agent.sess.run(tf.global_variables_initializer())

    train_step = 0
    # Train
    for ep_cur in range(args.num_episodes):
        state = env.reset()
        # Fresh history window for the encoder at the start of every episode.
        encode_input_state = init_encode_input(seq2seq)
        # Scale every state component by the grid height.
        state = [e / env.n_height for e in state]
        ep_length, ep_reward = 0, 0
        while True:
            ep_length += 1
            encode_state = get_encode_state(encode_input_state, seq2seq)
            # Element 1 of the encoder state, first row.
            # NOTE(review): presumably the hidden vector for the single batch
            # entry - confirm against Seq2seq.predict_encode_state.
            encode = list(encode_state[1])[0]
            action, prob, value = agent.sample_action(state, encode)
            next_state, reward, done, _ = env.step(action)
            env.render()
            # Shape the reward: subtract the scaled posterior prediction.
            cross_entropy = posterior.predict([state], [action], [encode])
            reward -= 0.1 * cross_entropy
            replay_buffer.store_data(cur_obs=state, cur_action=action, reward=reward,
                                     done=done, old_prob=prob, value=value, encode=encode)
            update_encode_input(state, action, encode_input_state)
            next_state = [e / env.n_height for e in next_state]
            state = next_state
            if replay_buffer.enough_data:
                observations, actions = replay_buffer.observations, replay_buffer.actions
                rewards, dones = replay_buffer.rewards, replay_buffer.dones
                values, old_probs = replay_buffer.values, replay_buffer.old_probs
                encodes = replay_buffer.encodes
                replay_buffer.clear_data()
                # Encoding of the history including the latest transition,
                # passed to compute_gae together with next_state.
                encode_state = get_encode_state(encode_input_state, seq2seq)
                next_encode = list(encode_state[1])[0]
                returns = agent.compute_gae(next_obs=next_state, next_encode=next_encode, rewards=rewards,
                                            values=values, dones=dones)
                ppo_loss, cost_p, cost_v, cost_entropy = agent.update_model(observations=observations, actions=actions,
                                                                            returns=returns, values=values,
                                                                            old_probs=old_probs, encodes=encodes)
                posterior_entropy, posterior_loss, posterior_val_loss = posterior.update_model(
                    observations=observations, actions=actions, encodes=encodes)
                update_logger(ppo_loss, cost_p, cost_v, cost_entropy, posterior_entropy, posterior_loss,
                              posterior_val_loss)
                train_step += 1
                if train_step % args.save_step_count == args.save_step_count - 1:
                    # The checkpoint directory must exist before Saver.save;
                    # the previous `if not save_path` test was never true.
                    os.makedirs(save_path, exist_ok=True)
                    save_name = save_path + str(train_step)
                    agent.saver.save(agent.sess, save_name)
                    print('Model saved %s' % save_name)
            ep_reward += reward
            # Stop on termination or when the step limit from the config is hit.
            if done or ep_length >= config.solving_criteria[0]:
                break
        logger.logkv("ep_reward", ep_reward)
        logger.dump_tabular()
        print("Episode: {}, Path length:{}, Reward:{}.".format(ep_cur, ep_length, ep_reward))


def args_parse(argv=None):
    """Build the command-line parser and parse the options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding every option with its parsed value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--algo', type=str, default='ppo', help='algo name')
    parser.add_argument('--seq2seq_model_dir', type=str, default='seq2seq/model/',
                        help='directory containing the pre-trained seq2seq checkpoint')
    parser.add_argument('--ppo_load_path', type=str, default=None,
                        help='PPO checkpoint to restore before training; omit to start from scratch')
    parser.add_argument('--env', type=str, default='two_way_grid_world', help='env')
    parser.add_argument('--gpu', type=int, default=-1, help='running on a specify gpu, -1 indicates using cpu')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--num_episodes', type=int, default=100000, help='num_episodes')
    # for algo
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--epsilon', type=float, default=0.2, help='epsilon')
    parser.add_argument('--var_beta', type=float, default=0.2, help='var_beta')
    parser.add_argument('--v_coef', type=float, default=1.0, help='v_coef')
    parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
    # The misspelled option name is kept: other modules read args.lammbda.
    parser.add_argument('--lammbda', type=float, default=0.95, help='lammbda')
    parser.add_argument('--save_step_count', type=int, default=50, help='save step')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_total_sizes', type=int, default=1024, help='num_total_sizes')
    parser.add_argument('--noptepochs', type=int, default=10, help='noptepochs')
    # for seq2seq
    parser.add_argument('--max_encode_length', type=int, default=4, help='max_encode_length')
    parser.add_argument("--max_decode_length", type=int, default=4, help="max_decode_length")
    parser.add_argument("--vocabulary_dim", type=int, default=5, help="vocabulary_dim")
    parser.add_argument("--num_layers", default=1, type=int, help="num_layers")
    parser.add_argument("--num_units", default=128, type=int, help="num_units")
    return parser.parse_args(argv)


def set_random_seed(seed):
    """Seed Python's `random`, NumPy and TensorFlow with the same value.

    Args:
        seed: Value handed unchanged to each of the three seeding functions.
    """
    # Same order as always: stdlib first, then NumPy, then TensorFlow.
    for apply_seed in (random.seed, np.random.seed, tf.set_random_seed):
        apply_seed(seed)


def get_env_config(args):
    """Create the environment named by ``args.env`` and its matching config.

    Args:
        args: Object exposing an ``env`` attribute with the environment name.

    Returns:
        Tuple ``(env, config)``.

    Raises:
        ValueError: If ``args.env`` is not a supported environment name.
    """
    if args.env == 'two_way_grid_world':
        env = grid_world.TwoWayGridWorld()
        config = Config(env, 'two_way_grid_world')
        return env, config
    # Fail with a clear message instead of an UnboundLocalError at `return`.
    raise ValueError('Unsupported env: %r' % (args.env,))


def init_encode_input(seq2seq):
    """Return a bounded queue pre-filled with placeholder encoder inputs.

    Args:
        seq2seq: Object exposing ``max_encode_length``; that value is used both
            as the queue capacity and as the number of placeholder entries.

    Returns:
        queue.Queue in which every slot holds its own copy of the list
        ``[0., 0., 0.9, 0.9, 0.0]``.
    """
    history = Queue(maxsize=seq2seq.max_encode_length)
    remaining = seq2seq.max_encode_length
    while remaining > 0:
        # A new list per slot so the entries never alias each other.
        history.put([0., 0., 0.9, 0.9, 0.0])
        remaining -= 1
    return history


# def get_encode(encode_input_state, seq2seq):
#     input_list = list(encode_input_state.queue)
#     input_len = len(input_list)
#     feed = {"encode_input": [input_list],
#             "encode_sequence_length": [input_len]}
#     # encoder_h = [float(e) for e in list(seq2seq.predict(feed)[0])]
#     encoder_h = list(seq2seq.predict(feed)[0])
#     return encoder_h

def get_encode_state(encode_input_state, seq2seq):
    """Run the seq2seq encoder on the current contents of the input queue.

    Args:
        encode_input_state: Queue-like object whose ``queue`` attribute holds
            the buffered encoder inputs.
        seq2seq: Model exposing ``predict_encode_state(feed)``.

    Returns:
        Whatever ``seq2seq.predict_encode_state`` returns for a single-sample
        batch built from the buffered inputs.
    """
    window = list(encode_input_state.queue)
    # Both entries are wrapped in a list, i.e. a batch holding one sequence.
    return seq2seq.predict_encode_state({
        "encode_input": [window],
        "encode_sequence_length": [len(window)],
    })


def update_encode_input(state, action, encode_input_state):
    """Push the latest (state, action) entry into the encoder input queue.

    The entry is a deep copy of ``state`` with ``action / 10`` appended, so
    the caller's ``state`` is left untouched. If the queue is already full,
    its oldest entry is removed first.

    Args:
        state: List describing the current state.
        action: Action taken; stored divided by 10.
        encode_input_state: Queue that receives the new entry.
    """
    newest = copy.deepcopy(state)
    scaled_action = action / 10
    newest.append(scaled_action)

    window = encode_input_state
    if window.full():
        # Evict the oldest entry to make room for the new one.
        window.get()
    window.put(newest)


def update_logger(ppo_loss, cost_p, cost_v, cost_entropy, posterior_entropy, posterior_loss, posterior_val_loss):
    """Record the PPO and posterior training metrics with the logger.

    Each argument is logged via ``logger.logkv`` under a key equal to its
    parameter name, in the order the parameters are declared.
    """
    metrics = (
        ("ppo_loss", ppo_loss),
        ("cost_p", cost_p),
        ("cost_v", cost_v),
        ("cost_entropy", cost_entropy),
        ("posterior_entropy", posterior_entropy),
        ("posterior_loss", posterior_loss),
        ("posterior_val_loss", posterior_val_loss),
    )
    for key, value in metrics:
        logger.logkv(key, value)


# Script entry point: parse the command-line options and start training.
if __name__ == "__main__":
    main(args_parse())
