from seq2seq import Seq2seq
import argparse
import tensorflow as tf
import common.grid_world as grid_world
from common.config import Config
from queue import Queue
import copy
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def init_encode_input(seq2seq):
    """Create the encoder input queue, pre-filled with placeholder steps.

    The queue is created with ``maxsize=seq2seq.max_encode_length`` and then
    loaded with that many copies of a fixed 5-element placeholder vector, so
    it is already at capacity before any real step is recorded.

    Args:
        seq2seq: Object exposing a ``max_encode_length`` attribute.

    Returns:
        A ``queue.Queue`` holding ``max_encode_length`` placeholder lists.
    """
    placeholder_step = (0., 0., 0.9, 0.9, 0.0)
    window_size = seq2seq.max_encode_length
    history = Queue(maxsize=window_size)
    for _ in range(window_size):
        # Each slot gets its own list so queued entries never alias.
        history.put(list(placeholder_step))
    return history

def update_encode_input(state, action, encode_input_state):
    """Push the latest (state, action) step onto the encoder input queue.

    The step vector is a deep copy of ``state`` with ``action / 10`` appended,
    so the caller's ``state`` is left untouched. If the queue is already
    full, its oldest entry is removed first.

    Args:
        state: Current state; must support ``append`` once copied
            (e.g. a list).
        action: Action taken; stored divided by 10.
        encode_input_state: ``queue.Queue`` of step vectors, modified in place.
    """
    step_vector = copy.deepcopy(state)
    step_vector.append(action / 10)

    window = encode_input_state
    if window.full():
        # Make room by discarding the oldest step.
        window.get()
    window.put(step_vector)


def get_encode_state(encode_input_state, seq2seq):
    """Run the encoder on the steps currently held in the queue.

    The queue contents are snapshotted into a list (the queue itself is not
    consumed), wrapped in a one-element list together with their count, and
    passed to ``seq2seq.predict_encode_state``.

    Args:
        encode_input_state: Queue-like object whose ``queue`` attribute holds
            the step vectors.
        seq2seq: Object exposing ``predict_encode_state(feed)``.

    Returns:
        Whatever ``seq2seq.predict_encode_state`` returns for the feed.
    """
    steps = list(encode_input_state.queue)
    return seq2seq.predict_encode_state({
        "encode_input": [steps],
        "encode_sequence_length": [len(steps)],
    })


def get_env_config(args):
    """Create the environment and its config for a given environment name.

    Args:
        args: Environment name as a string (despite the parameter name, this
            is not the parsed argument namespace). Only
            ``'two_way_grid_world'`` is supported.

    Returns:
        A tuple ``(env, config)``.

    Raises:
        ValueError: If ``args`` is not a supported environment name.
    """
    if args == 'two_way_grid_world':
        env = grid_world.TwoWayGridWorld()
        config = Config(env, 'two_way_grid_world')
        return env, config
    # Fail with a clear message instead of hitting unbound locals at return.
    raise ValueError(
        "unsupported env {!r}; expected 'two_way_grid_world'".format(args))

if __name__ == '__main__':
    # Script: restore a trained seq2seq model, replay the recorded demo
    # trajectories through the environment, collect the encoder state at every
    # step, and plot a 2-D PCA projection of those states coloured by gate.
    parser = argparse.ArgumentParser()
    parser.add_argument('--algo', type=str, default='ppo', help='algo name')
    parser.add_argument('--seq2seq_model_dir', type=str, default='seq2seq/model/',
                        help='seq2seq checkpoint directory')
    parser.add_argument('--ppo_load_path', type=str, default=None, help='ppo_load_path')
    parser.add_argument('--env', type=str, default='two_way_grid_world', help='env')
    parser.add_argument('--gpu', type=int, default=-1, help='running on a specify gpu, -1 indicates using cpu')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--num_episodes', type=int, default=100000, help='num_episodes')
    # for algo
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--epsilon', type=float, default=0.2, help='epsilon')
    parser.add_argument('--var_beta', type=float, default=0.2, help='var_beta')
    parser.add_argument('--v_coef', type=float, default=1.0, help='v_coef')
    parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
    parser.add_argument('--lammbda', type=float, default=0.95, help='lammbda')
    parser.add_argument('--save_step_count', type=int, default=50, help='save step')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_total_sizes', type=int, default=1024, help='num_total_sizes')
    parser.add_argument('--noptepochs', type=int, default=10, help='noptepochs')
    # for seq2seq
    parser.add_argument('--max_encode_length', type=int, default=4, help='max_encode_length')
    parser.add_argument("--max_decode_length", type=int, default=4, help="max_decode_length")
    parser.add_argument("--vocabulary_dim", type=int, default=5, help="vocabulary_dim")
    parser.add_argument("--num_layers", default=1, type=int, help="num_layers")
    parser.add_argument("--num_units", default=128, type=int, help="num_units")
    args = parser.parse_args()

    # Load the seq2seq model and restore its parameters from the checkpoint.
    seq2seq = Seq2seq(args.max_encode_length, args.max_decode_length,
                      args.vocabulary_dim, args.num_layers, args.num_units, mode='train')
    ckpt = tf.train.get_checkpoint_state(args.seq2seq_model_dir)
    # get_checkpoint_state returns None when the directory has no checkpoint
    # state, so test for that before touching its attributes. An explicit
    # raise is used because `assert` is stripped when Python runs with -O.
    if ckpt is None or not tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        raise FileNotFoundError(
            'no seq2seq checkpoint found in {!r}'.format(args.seq2seq_model_dir))
    print('Reloading seq2seq model parameters..')
    seq2seq.saver.restore(seq2seq.sess, ckpt.model_checkpoint_path)

    env, config = get_env_config(args.env)

    # The demo file is consumed as 3-line records. Within each record, the
    # second line is "<key>:<action>" and the third is "<key>:<done flag>";
    # the first line of each record is not used here.
    with open('grid_word_demo.txt', 'r') as demo_file:
        demo_lines = demo_file.readlines()

    trajectories = []
    act_list = []
    for act_ind in range(1, len(demo_lines), 3):
        act_list.append(int(demo_lines[act_ind].strip().split(':')[1]))
        done_ind = act_ind + 1
        if done_ind >= len(demo_lines):
            # Truncated final record: without a done flag the episode cannot
            # be closed, so the partial action list is dropped.
            break
        if demo_lines[done_ind].strip().split(':')[1] == 'True':
            # Episode finished: store its actions and start a new list.
            trajectories.append(act_list)
            act_list = []

    embeds_x = []  # encoder states, one per replayed step
    embeds_y = []  # gate label (0 or 1) for each entry in embeds_x
    for count, tra in enumerate(trajectories):
        state = env.reset()
        encode_input_state = init_encode_input(seq2seq)
        state = [e / env.n_height for e in state]
        # The first 3 expert trajectories go through the lower gate (label 0);
        # all later ones are labelled as the upper gate (label 1).
        label = 0 if count < 3 else 1
        for act in tra:
            # Record the encoding of the history *before* this step is taken.
            encode_state = get_encode_state(encode_input_state, seq2seq)
            # NOTE(review): element 1 is presumably the hidden part of the
            # encoder state and [0] the single batch entry - confirm against
            # Seq2seq.predict_encode_state.
            encode = list(encode_state[1])[0]
            embeds_x.append(encode)
            embeds_y.append(label)
            next_state, reward, done, _ = env.step(act)
            update_encode_input(state, act, encode_input_state)
            state = [e / env.n_height for e in next_state]

    print(embeds_y)
    pca = PCA(n_components=2)
    reduced_x = pca.fit_transform(embeds_x)
    print(len(reduced_x))

    # Split the projected points by gate label for plotting.
    red_x, red_y = [], []
    blue_x, blue_y = [], []
    for point, point_label in zip(reduced_x, embeds_y):
        if point_label == 0:
            red_x.append(point[0])
            red_y.append(point[1])
        else:
            blue_x.append(point[0])
            blue_y.append(point[1])

    plt.scatter(red_x, red_y, c='r', marker='x', label = 'lower_gate')
    plt.scatter(blue_x, blue_y, c='b', marker='D', label = 'upper_gate')
    plt.legend()
    plt.show()
