#!/usr/bin/env python 
# -*- coding: utf-8 -*- 
# @Time : 2020/11/30 20:53 
# @Author : wangzhaorong
# @Site :  
# @File : main.py 
# @Software: PyCharm
import os
import sys
from datetime import datetime

from matplotlib import pyplot as plt, animation

# Extend the import search path before the project-local imports below.
# NOTE(review): './../__file__' is a plain string literal, not the __file__
# variable, so dirname() yields './..' and the appended directory is the parent
# of the current working directory, not the parent of this file - confirm that
# this is intended (it only matches when the script is launched from its own
# directory).
path = os.path.abspath(os.path.dirname('./../__file__'))
sys.path.append(path)
import time
import random
import copy
import numpy as np
import argparse
from baselines import logger
import tensorflow as tf
import common.grid_world as grid_world
from queue import Queue
from common.config import Config
from ppo import Policy
from ppo import ReplayBuffer
from seq2seq import Seq2seq
from qnet_mod import Posterior

# NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so setting it
# here presumably only affects child processes, not hashing in this process -
# export it in the shell before launching if reproducible hashing is required.
os.environ['PYTHONHASHSEED'] = '0'


# Usage: python main_mod.py --num_units <int> --add_reward <float> --seed <int>

def create_posterior(args):
    """Build a ``Posterior`` network and initialise its variables.

    The model is constructed for ``args.gpu``, the global variable initializer
    is run in the posterior's own session, and ``sync_params_op`` is invoked
    once before the instance is returned.

    Args:
        args: parsed command-line namespace forwarded to ``Posterior``.

    Returns:
        The ready-to-use ``Posterior`` instance.
    """
    net = Posterior(args=args)
    net.construct_model(gpu=args.gpu)
    session = net.sess
    session.run(tf.global_variables_initializer())
    net.sync_params_op()
    return net


def create_seq2seq(args):
    """Create the seq2seq model and restore its pretrained parameters.

    The checkpoint directory is derived from ``args.num_units``
    (``seq2seq/model_<num_units>/``) and written back to
    ``args.seq2seq_model_dir``, replacing whatever value was stored there.

    Args:
        args: namespace providing ``max_encode_length``, ``max_decode_length``,
            ``vocabulary_dim``, ``num_layers`` and ``num_units``.

    Returns:
        The ``Seq2seq`` instance with its weights restored from the checkpoint.

    Raises:
        FileNotFoundError: if the directory contains no usable checkpoint.
    """
    # create the seq2seq encoder/decoder
    seq2seq = Seq2seq(args.max_encode_length, args.max_decode_length,
                      args.vocabulary_dim, args.num_layers, args.num_units, mode='train')
    args.seq2seq_model_dir = 'seq2seq/model_' + str(args.num_units) + '/'
    ckpt = tf.train.get_checkpoint_state(args.seq2seq_model_dir)
    # get_checkpoint_state returns None when no checkpoint state file exists, so
    # check that first instead of failing with an AttributeError; an explicit
    # raise is used because assert statements are stripped under ``python -O``.
    if ckpt is None or not tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        raise FileNotFoundError(
            f"No seq2seq checkpoint found in '{args.seq2seq_model_dir}'.")
    print('Reloading seq2seq model parameters..')
    seq2seq.saver.restore(seq2seq.sess, ckpt.model_checkpoint_path)
    return seq2seq


def create_base_path(base_path, args, **kwargs):
    """Compose the output directory layout for one run.

    The run directory name is built from a selection of ``args`` attributes
    (which ones depends on ``args.algo``), followed by every extra keyword
    argument as ``<name>_<value>``, followed by a timestamp. Floats are
    formatted in scientific notation; attributes missing from ``args`` are
    recorded as ``0``.

    Args:
        base_path: root directory under which the run directory is placed.
        args: namespace holding at least ``algo``.
        **kwargs: optional extra name/value pairs (e.g. ``seed=0``) appended to
            the directory name. None of them is required.

    Returns:
        dict mapping ``base_path``, ``tb_path``, ``logger_path``, ``save_path``,
        ``save_style_dir``, ``test_success`` and ``test_fail`` to path strings.
        No directory is created here.
    """
    time_str = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if "gail" in args.algo:
        args_path = {"algo": "algo", "add_reward": "ad", "encode_dim": "nu", "max_sample_num": "msn", "random_reset": "rr", "test": "test"}
    elif "stylebc" in args.algo:
        args_path = {"algo": "algo", "learning_rate": "lr", "encode_dim": "nu", "max_sample_num": "msn", "random_reset": "rr", "test": "test", "batch_size": "bs", "style_num": "sn"}
    else:
        args_path = {"algo": "algo", "add_reward": "ad", "num_units": "nu", "max_sample_num": "msn", "random_reset": "rr", "test": "test"}

    list_attrs = []
    for idx, (attr_name, short_name) in enumerate(args_path.items()):
        # the first entry (the algo) is written without its label
        if idx != 0:
            list_attrs.append(short_name)
        list_attrs.append(getattr(args, attr_name, 0))
    # Extra keywords are optional: the previous unconditional kwargs["seed"]
    # lookup raised KeyError for callers that pass no seed.
    for k, v in kwargs.items():
        list_attrs.append(k)
        list_attrs.append(v)
    path_list_str = "_".join(["{:.2e}".format(k) if isinstance(k, float) else str(k) for k in list_attrs])
    base_path = f"{base_path}/{path_list_str}_{time_str}/"
    return {"base_path": base_path, "tb_path": base_path + "tb_event/", "logger_path": base_path + "logger/", "save_path": base_path + "checkpoint/", "save_style_dir": base_path + "style/",
            "test_success": base_path + "test_success/", "test_fail": base_path + "test_fail/"}


def main(args):
    """Train the encode-conditioned PPO agent together with the posterior network.

    Each environment step feeds the current history queue through the seq2seq
    encoder, samples an action from the agent conditioned on that encoding and
    adds ``args.add_reward * posterior.predict(...)`` to the environment reward
    before storing the transition. Whenever the replay buffer reports
    ``enough_data`` its contents are taken out, returns are computed with
    ``agent.compute_gae`` and both the agent and the posterior are updated.
    Every ``args.save_step_count`` updates a checkpoint is written and
    ``env_test`` is run. Training stops once ``args.max_sample_num``
    environment steps have been consumed.

    Args:
        args: parsed command-line namespace (see ``args_parse``).
    """
    set_random_seed(args.seed)

    # create env
    env, config = get_env_config(args)
    seq2seq = create_seq2seq(args)
    agent = Policy(args=args)
    agent.construct_model(gpu=args.gpu)  # build the model
    posterior = create_posterior(args)

    # create replay_buffer
    replay_buffer = ReplayBuffer(num_total_sizes=args.num_total_sizes, act_dims=env.action_space.n,
                                 obs_dims=len(env.state), encode_dims=seq2seq.num_units)
    # create_base_path takes extra name/value pairs as keywords; pass the seed
    # so it becomes part of the run directory name.
    base_path = create_base_path("result", args, seed=args.seed)
    writer = tf.summary.FileWriter(base_path["tb_path"], tf.get_default_graph())
    logger.configure(dir=base_path["logger_path"])
    save_path = base_path["save_path"]
    if args.ppo_load_path is not None:
        # Restore from the checkpoint the user pointed at; save_path is a freshly
        # generated, timestamped directory and cannot contain a checkpoint yet.
        agent.saver.restore(agent.sess, args.ppo_load_path)
    else:
        agent.sess.run(tf.global_variables_initializer())

    train_step = 0
    sample_cost = 0
    ep_cur = 0
    # Train; the writer is closed even if training fails or is interrupted
    try:
        while sample_cost < args.max_sample_num:
            state = env.random_reset()
            state = [e / env.n_height for e in state]
            encode_input_state = init_encode_input(seq2seq, state)
            ep_length, ep_total_reward, ep_env_reward, ep_q_reward = 0, 0, 0, 0
            while True:
                ep_length += 1
                sample_cost += 1
                encode_state = get_encode_state(encode_input_state, seq2seq)
                encode = list(encode_state[1])[0]
                action, prob, value = agent.sample_action(state, encode)
                next_state, reward, done, _ = env.random_step(action)
                ep_env_reward += reward
                # bonus from the posterior, weighted by add_reward
                cross_entropy = posterior.predict([state], [action], [encode])
                ep_q_reward += args.add_reward * cross_entropy
                reward += args.add_reward * cross_entropy
                replay_buffer.store_data(cur_obs=state, cur_action=action, reward=reward,
                                         done=done, old_prob=prob, value=value, encode=encode)
                update_encode_input(state, action, encode_input_state)
                next_state = [e / env.n_height for e in next_state]
                state = next_state
                if replay_buffer.enough_data:
                    # take the batch out before clearing the buffer
                    observations, actions = replay_buffer.observations, replay_buffer.actions
                    rewards, dones = replay_buffer.rewards, replay_buffer.dones
                    values, old_probs = replay_buffer.values, replay_buffer.old_probs
                    encodes = replay_buffer.encodes
                    replay_buffer.clear_data()
                    # encoding of the history that now includes the latest step
                    encode_state = get_encode_state(encode_input_state, seq2seq)
                    next_encode = list(encode_state[1])[0]
                    returns = agent.compute_gae(next_obs=next_state, next_encode=next_encode, rewards=rewards,
                                                values=values, dones=dones)
                    ppo_loss, cost_p, cost_v, cost_entropy = agent.update_model(
                        observations=observations, actions=actions, returns=returns,
                        values=values, old_probs=old_probs, encodes=encodes)
                    _, posterior_loss, _ = posterior.update_model(
                        observations=observations, actions=actions, encodes=encodes)
                    log_data = {
                        "ppo/cost_all": ppo_loss,
                        "ppo/cost_p": cost_p,
                        "ppo/cost_v": cost_v,
                        "ppo/cost_entropy": cost_entropy,
                        "posterior/posterior_loss": posterior_loss,
                    }
                    train_step += 1
                    if train_step % args.train_log_interval == 0:
                        update_logger(log_data)
                        update_tb(writer, sample_cost, log_data)
                    if train_step % args.save_step_count == 0:
                        # The old `if not save_path` guard was never true for a
                        # non-empty path, so the directory was never created.
                        os.makedirs(save_path, exist_ok=True)
                        save_name = save_path + str(train_step) + "_" + str(sample_cost)
                        agent.saver.save(agent.sess, save_name)
                        print('Model saved %s' % save_name)
                        test_log_data = env_test(agent, seq2seq, None, args)
                        update_logger(test_log_data)
                        update_tb(writer, sample_cost, test_log_data)
                ep_total_reward += reward
                if done or ep_length >= config.solving_criteria[0]:
                    ep_cur += 1
                    break
            log_ep_data = {
                "train/ep_total_reward": ep_total_reward,
                "train/ep_env_reward": ep_env_reward,
                "train/ep_length": ep_length,
                "train/ep_index": ep_cur,
                "train/sample_cost": sample_cost,
                "train/ep_q_reward": ep_q_reward,
                "train/train_step": train_step,
            }
            update_logger(log_ep_data)
            logger.dump_tabular()
            if ep_cur % args.train_log_interval == 0 and ep_cur != 0:
                update_tb(writer, sample_cost, log_ep_data)
    finally:
        writer.close()


def env_test(agent, seq2seq, discriminator, args, **kwargs):
    """Run evaluation episodes on a fresh environment and return mean metrics.

    For each of ``args.test_count`` rounds one episode is played per style
    index in ``range(style_num)``. How actions are chosen depends on
    ``args.algo``:

    * ``"ppo_mod"``: ``agent.test_action`` on the state and the seq2seq encoding
      of the running history queue.
    * any name containing ``"gail"``: ``agent.act(..., stochastic=False)``; the
      discriminator reward is accumulated as well.
    * any name containing ``"stylebc"``: ``agent.get_target_action`` when the
      test label is ``"target"``, otherwise ``agent.get_action`` with the
      style index.

    Args:
        agent: policy object offering the methods listed above.
        seq2seq: encoder used by the ``ppo_mod`` branch.
        discriminator: used only by the gail branch (``get_rewards``).
        args: namespace; ``test``, ``save_gif_interval`` and ``render`` are
            optional and default to "off" when the entry script's parser does
            not define them.
        **kwargs: optional ``test_label``, ``style_num`` (default 1),
            ``base_path`` (dict of output paths, needed for GIFs) and
            ``train_step`` (used in GIF file names).

    Returns:
        dict of ``<label>/ep_length``, ``<label>/ep_reward``,
        ``<label>/ep_style_index`` (plus ``<label>/ep_reward_d`` for gail),
        each averaged over all episodes played, and ``<label>/success_rate``.

    Raises:
        NotImplementedError: for an unrecognised ``args.algo``.
    """
    result = {}
    # argparse may deliver test_count as a float; range() requires an int
    test_count = int(args.test_count)
    test_env, test_config = get_env_config(args)
    success_num = 0
    episode_num = 0
    sub_test_label = kwargs.get("test_label", None)
    style_num = kwargs.get("style_num", 1)
    # These options are not defined by every entry script's parser, so read
    # them defensively instead of failing with AttributeError.
    save_gif = bool(getattr(args, "test", False))
    save_gif_interval = getattr(args, "save_gif_interval", 1)
    render = getattr(args, "render", False)
    base_path = kwargs.get("base_path", None)
    train_step = kwargs.get("train_step", 0)
    if save_gif_interval > test_count:
        save_gif_interval = test_count - 1
    # the interval is used as a modulus below and must never reach 0
    save_gif_interval = max(1, save_gif_interval)
    test_label = f"test_{sub_test_label}" if sub_test_label is not None else "test"
    for idx in range(test_count):
        # whether to record GIFs is the same for every style in this round
        gif_label = bool(save_gif and (idx + 1) % save_gif_interval == 0 and base_path is not None)
        for sample_index in range(style_num):
            frames = []
            state = test_env.reset()
            state = [e / test_env.n_height for e in state]
            if args.algo == "ppo_mod":
                encode_input_state = init_encode_input(seq2seq, state)
            elif args.algo == "info-gail":
                # one-hot code with a randomly chosen active entry
                encode_zero = np.zeros((1, args.encode_dim), dtype=np.float32)
                encode_zero[0, random.randint(0, args.encode_dim - 1)] = 1
                encode_input_state = encode_zero
            else:
                encode_input_state = None
            ep_length, ep_reward, ep_reward_d = 0, 0, 0
            while True:
                if gif_label:
                    frames.append(test_env.render(mode="rgb_array"))
                if render:
                    test_env.render(mode="human")
                if args.algo == "ppo_mod":
                    encode_state = get_encode_state(encode_input_state, seq2seq)
                    encode = list(encode_state[1])[0]
                    action = agent.test_action(state, encode)
                    update_encode_input(state, action, encode_input_state)
                elif "gail" in args.algo:
                    obs = np.stack([state]).astype(dtype=np.float32)  # prepare to feed placeholder Policy.obs
                    act, _ = agent.act(obs=obs, encode=encode_input_state, stochastic=False)
                    action = act.item()
                    d_rewards = discriminator.get_rewards(agent_s=obs, agent_a=[action]).item()
                    ep_reward_d += d_rewards
                elif "stylebc" in args.algo:
                    obs = np.stack([state]).astype(dtype=np.float32)
                    if sub_test_label == "target":
                        action, _ = agent.get_target_action(obs, stochastic=False)
                    else:
                        action, _ = agent.get_action(obs, [sample_index], stochastic=False)
                    action = action.item()
                else:
                    raise NotImplementedError(f"Unknown algo types {args.algo}.")
                next_state, reward, done, _ = test_env.step(action)
                next_state = [e / test_env.n_height for e in next_state]
                state = next_state
                ep_length += 1
                ep_reward += reward
                if done or ep_length >= test_config.solving_criteria[0]:
                    break
            episode_num += 1
            if done:
                success_num += 1
            if gif_label:
                gif_result = "success" if done else "fail"
                gif_path = base_path[f"test_{gif_result}"]
                gif_name = f"{test_label}_{gif_result}_{train_step}_style-{sample_index}_test-{idx}.gif"
                save_frames_as_gif(frames, gif_path, gif_name)
            if gif_label or render:
                test_env.render(mode='human', close=True)
            result.setdefault(f"{test_label}/ep_length", []).append(ep_length)
            result.setdefault(f"{test_label}/ep_reward", []).append(ep_reward)
            result.setdefault(f"{test_label}/ep_style_index", []).append(sample_index)
            if "gail" in args.algo:
                result.setdefault(f"{test_label}/ep_reward_d", []).append(ep_reward_d)
    # Average over the episodes actually played (test_count * style_num);
    # dividing by test_count alone inflated every metric when style_num > 1.
    ret = {k: sum(v) / len(v) for k, v in result.items()}
    ret[f"{test_label}/success_rate"] = success_num / episode_num if episode_num > 0 else 0
    return ret


def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and TensorFlow with the same value.

    Args:
        seed: integer seed handed unchanged to each of the three generators.
    """
    seeders = (random.seed, np.random.seed, tf.set_random_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def get_env_config(args):
    """Instantiate the environment named by ``args.env`` and its ``Config``.

    Only ``'two_way_grid_world'`` is supported. The goal cell is parsed from
    ``args.end``, a string of integers separated by underscores.

    Args:
        args: namespace with ``env``, ``random_reset``, ``random_step``, ``end``.

    Returns:
        Tuple ``(env, config)``.

    Raises:
        NotImplementedError: if ``args.env`` names any other environment.
    """
    if args.env != 'two_way_grid_world':
        raise NotImplementedError(f"Unknown env name {args.env}.")

    reset_flag = args.random_reset
    step_flag = args.random_step
    goal = tuple(int(part) for part in args.end.split("_"))
    env = grid_world.MultiWayGridWorld(
        max_step_count=200,
        random_reset=reset_flag,
        random_step=step_flag,
        end=goal,
    )
    return env, Config(env, 'two_way_grid_world')


def init_encode_input(seq2seq, state):
    """Create the encoder history queue pre-filled from the initial state.

    The queue is bounded by ``seq2seq.max_encode_length`` and filled with that
    many separate copies of ``[state[0], state[1], state[2], state[3], 0.0]``.

    Args:
        seq2seq: object exposing ``max_encode_length``.
        state: indexable with at least four entries.

    Returns:
        A full ``queue.Queue`` of five-element lists.
    """
    capacity = seq2seq.max_encode_length
    first_four = (state[0], state[1], state[2], state[3])
    history = Queue(maxsize=capacity)
    for _ in range(capacity):
        # a fresh list per slot so the entries never alias each other
        history.put([*first_four, 0.0])
    return history


def save_frames_as_gif(frames, path='./', filename='gym_animation.gif'):
    """Write a sequence of image frames to an animated GIF.

    Args:
        frames: non-empty sequence of image arrays; the figure size is taken
            from the shape of the first frame.
        path: output directory, created if it does not exist.
        filename: name of the GIF file inside ``path``.
    """
    # Mess with this to change frame size
    fig = plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=72)
    try:
        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
        os.makedirs(path, exist_ok=True)
        # join instead of concatenating so a path without a trailing separator works too
        anim.save(os.path.join(path, filename), writer='imagemagick', fps=60)
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # saved GIF leaves one more open figure behind
        plt.close(fig)


# def get_encode(encode_input_state, seq2seq):
#     input_list = list(encode_input_state.queue)
#     input_len = len(input_list)
#     feed = {"encode_input": [input_list],
#             "encode_sequence_length": [input_len]}
#     # encoder_h = [float(e) for e in list(seq2seq.predict(feed)[0])]
#     encoder_h = list(seq2seq.predict(feed)[0])
#     return encoder_h

def get_encode_state(encode_input_state, seq2seq):
    """Encode the current history queue with the seq2seq model.

    Args:
        encode_input_state: ``queue.Queue`` whose items form the input sequence.
        seq2seq: object exposing ``predict_encode_state(feed)``.

    Returns:
        Whatever ``seq2seq.predict_encode_state`` returns for a batch of one
        sequence built from the queue contents.
    """
    history = list(encode_input_state.queue)
    return seq2seq.predict_encode_state({
        "encode_input": [history],
        "encode_sequence_length": [len(history)],
    })


def update_encode_input(state, action, encode_input_state):
    """Push the latest (state, action) entry onto the history queue.

    The entry is a deep copy of ``state`` with ``action / 10`` appended, so the
    caller's ``state`` is left untouched. When the queue is already full its
    oldest entry is dropped first.

    Args:
        state: list describing the current state.
        action: numeric action taken in that state.
        encode_input_state: ``queue.Queue`` holding the history.
    """
    newest = copy.deepcopy(state)
    newest.append(action / 10)
    history = encode_input_state
    if history.full():
        history.get()  # discard the oldest entry to make room
    history.put(newest)


def update_logger(log_data):
    """Record every entry of ``log_data`` with ``logger.logkv``.

    Args:
        log_data: dict mapping metric names to values.
    """
    for name, value in log_data.items():
        logger.logkv(name, value)


def update_tb(writer, step, log_data):
    """Write each entry of ``log_data`` as its own scalar summary.

    Args:
        writer: summary writer exposing ``add_summary(summary, step)``.
        step: global step under which the values are recorded.
        log_data: dict mapping tag names to scalar values.
    """
    for tag, scalar in log_data.items():
        entry = tf.Summary.Value(tag=tag, simple_value=scalar)
        summary = tf.Summary(value=[entry])
        writer.add_summary(summary, step)


def _str2bool(value):
    """Convert a command-line token to bool (``type=bool`` maps "False" to True)."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def args_parse():
    """Define and parse the command-line options of the training script.

    Returns:
        ``argparse.Namespace`` with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--algo', type=str, default='ppo_mod', help='algo name')
    parser.add_argument('--seq2seq_model_dir', type=str, default='seq2seq/model/',
                        help='seq2seq checkpoint directory')
    parser.add_argument('--ppo_load_path', type=str, default=None,
                        help='PPO checkpoint to restore before training')
    parser.add_argument('--env', type=str, default='two_way_grid_world', help='env')
    parser.add_argument('--gpu', type=int, default=-1, help='running on a specify gpu, -1 indicates using cpu')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    # argparse does not run `type` on non-string defaults, so 2e6 stayed a float;
    # spell the defaults as ints to match what the command line produces
    parser.add_argument('--num_episodes', type=int, default=int(2e6), help='num_episodes')
    parser.add_argument('--max_sample_num', type=int, default=int(2e6),
                        help='max steps interaction with the environment')
    # for algo
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--epsilon', type=float, default=0.2, help='epsilon')
    parser.add_argument('--var_beta', type=float, default=0.2, help='var_beta')
    parser.add_argument('--v_coef', type=float, default=1.0, help='v_coef')
    parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
    parser.add_argument('--lammbda', type=float, default=0.95, help='lammbda')
    parser.add_argument('--train_log_interval', type=int, default=10, help='log interval')
    parser.add_argument('--save_step_count', type=int, default=10, help='save step')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_total_sizes', type=int, default=1024, help='num_total_sizes')
    parser.add_argument('--noptepochs', type=int, default=10, help='noptepochs')
    # for seq2seq
    parser.add_argument('--max_encode_length', type=int, default=4, help='max_encode_length')
    parser.add_argument("--max_decode_length", type=int, default=4, help="max_decode_length")
    parser.add_argument("--vocabulary_dim", type=int, default=5, help="vocabulary_dim")
    parser.add_argument("--num_layers", default=1, type=int, help="num_layers")
    parser.add_argument("--num_units", default=32, type=int, help="num_units")
    parser.add_argument('--add_reward', type=float, default=0.1, help='var_cross_entropy')
    # used as a loop count, so it has to be an integer
    parser.add_argument('--test_count', type=int, default=10, help='test episode for each saved model')
    # type=bool would turn every non-empty string (including "False") into True
    parser.add_argument('--random_reset', type=_str2bool, default=False, help='whether to random reset the env')
    parser.add_argument('--random_step', type=_str2bool, default=False, help='whether to act random steps in env')
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the command line, then start training.
    cli_args = args_parse()
    main(cli_args)
# args=args_parse()
