#!/usr/bin/env python 
# -*- coding: utf-8 -*- 
# @Time : 2020/11/26 16:25 
# @Author : wangzhaorong
# @Site :  
# @File : seq2seq.py 
# @Software: PyCharm

import tensorflow as tf
import numpy as np
import argparse
import os
import ast
import copy
from tensorflow.python.ops.rnn_cell import LSTMStateTuple


class Seq2seq(object):
    """LSTM encoder-decoder over sequences of real-valued frames (TensorFlow 1.x graph mode).

    Every frame is a vector of ``vocabulary_dim`` floats. The encoder runs a
    stack of dropout-wrapped LSTM cells over ``encoder_input``. In ``'train'``
    mode the decoder starts from the encoder's final state and is
    teacher-forced on ``decoder_input`` shifted right by one constant start
    frame (all ``-1.``). A bias-free dense layer projects decoder outputs back
    to ``vocabulary_dim`` and the loss is the mean squared error against
    ``decoder_input``.

    In ``'infer'`` mode the decoder instead starts from a caller-supplied LSTM
    state (``infer_c`` / ``infer_h``) and consumes the single frame
    ``infer_decoder_input``, so a caller can step the decoder from Python.

    Construction builds the graph, a ``tf.Session`` (``self.sess``), a
    ``tf.train.Saver`` and a merged summary op; variables are not initialised
    here.

    Args:
        max_encode_length: time steps in ``encoder_input``.
        max_decode_length: time steps in ``decoder_input``.
        vocabulary_dim: width of every frame.
        num_layers: LSTM layers in the encoder and in the decoder.
        num_units: hidden size of every LSTM layer.
        mode: ``'train'`` or ``'infer'``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'infer'``.
    """

    _MODES = ('train', 'infer')

    def __init__(self, max_encode_length, max_decode_length, vocabulary_dim, num_layers, num_units, mode='train'):
        # Fail fast: an unknown mode would otherwise make build_decoder()
        # return None and crash later with an opaque unpacking error.
        if mode not in self._MODES:
            raise ValueError("mode must be one of %s, got %r" % (self._MODES, mode))

        # input
        self.mode = mode
        self.max_encode_length = max_encode_length
        self.max_decode_length = max_decode_length
        self.vocabulary_dim = vocabulary_dim
        # Start ("go") frame prepended to the decoder input during teacher
        # forcing. Sized from vocabulary_dim; a hard-coded 5-element list
        # broke tf.concat for any other frame width.
        self.start_char = [-1.] * vocabulary_dim

        # network
        self.num_layers = num_layers
        self.dtype = tf.float32
        self.num_units = num_units
        self.max_gradient_norm = 5.0
        self.learing_rate = 0.0001  # (sic) attribute name kept for existing callers

        # build graph
        self.sess, device = self.get_session()
        with tf.device(device):
            self.get_encoder_input()
            self.build_graph()
            self.build_saver()
            self.build_summary()

    def get_session(self, gpu=-1):
        """Create a session on the default graph and pick the device string.

        Args:
            gpu: ``-1`` (the default, which ``__init__`` uses) selects
                ``'/cpu:0'`` with single-threaded op pools; any other value
                selects ``'/gpu:<gpu>'`` with soft placement and memory growth.

        Returns:
            ``(tf.Session, device_string)``.
        """
        if gpu == -1:
            device = '/cpu:0'
            sess_config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
        else:  # use GPU
            device = '/gpu:' + str(gpu)
            sess_config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True,
                                         intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
            sess_config.gpu_options.allow_growth = True
            # NOTE(review): hard-coded device list, independent of ``gpu``.
            sess_config.gpu_options.visible_device_list = '0, 1'
        sess = tf.Session(graph=tf.get_default_graph(), config=sess_config)
        return sess, device

    def get_encoder_input(self):
        """Create every placeholder the graph reads.

        Training/eval: ``encoder_input`` (batch, max_encode_length, vocabulary_dim),
        ``decoder_input`` (batch, max_decode_length, vocabulary_dim), per-example
        ``encode_sequence_length`` / ``decode_sequence_length``, scalar
        ``batch_size`` and dropout ``keep_prob``.
        Inference: decoder state halves ``infer_c`` / ``infer_h``
        (batch, num_units) and one input frame ``infer_decoder_input``
        (vocabulary_dim,).
        """
        self.encoder_input = tf.placeholder(tf.float32, [None, self.max_encode_length, self.vocabulary_dim], name='encoder_input')
        self.decoder_input = tf.placeholder(tf.float32, [None, self.max_decode_length, self.vocabulary_dim], name='decoder_input')
        self.encode_sequence_length = tf.placeholder(tf.int32, [None], name='encode_sequence_length')
        self.decode_sequence_length = tf.placeholder(tf.int32, [None], name='decode_sequence_length')
        self.batch_size = tf.placeholder(tf.int32, name='batch_size')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        self.infer_c = tf.placeholder(tf.float32, [None, self.num_units], name='infer_c')
        self.infer_h = tf.placeholder(tf.float32, [None, self.num_units], name='infer_h')
        self.infer_decoder_input = tf.placeholder(tf.float32, [self.vocabulary_dim], name='infer_decoder_input')

    def build_graph(self):
        """Build encoder, decoder, loss and train op under the ``dynamic_seq2seq`` scope."""
        with tf.variable_scope("dynamic_seq2seq", dtype=self.dtype):
            self.encoder_outputs, self.encoder_state = self.build_encoder()
            self.decoder_logits, self.decoder_cell_outputs, self.final_context_state = self.build_decoder()
            self.loss, self.train_op = self.compute_loss()

    def _build_cell(self):
        """Return ``num_layers`` dropout-wrapped BasicLSTMCells (stacked if more than one)."""
        cell_list = []
        for _ in range(self.num_layers):
            cell_layer = tf.contrib.rnn.BasicLSTMCell(self.num_units)
            cell_layer = tf.contrib.rnn.DropoutWrapper(cell=cell_layer, input_keep_prob=self.keep_prob)
            cell_list.append(cell_layer)
        if len(cell_list) == 1:
            return cell_list[0]
        return tf.contrib.rnn.MultiRNNCell(cell_list)

    def build_encoder(self):
        """Run the encoder cell over ``encoder_input``; return ``(outputs, final_state)``."""
        with tf.variable_scope("encoder"):
            cell = self._build_cell()
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(cell, self.encoder_input, dtype=tf.float32,
                                                               sequence_length=self.encode_sequence_length)
        return encoder_outputs, encoder_state

    def build_decoder(self):
        """Build the decoder for ``self.mode``.

        Returns:
            ``(decoder_logits, decoder_cell_outputs, final_context_state)``, where
            the logits are the cell outputs passed through ``self.output_layer``.
        """
        with tf.variable_scope("decoder") as scope:
            cell = self._build_cell()
            if self.mode == 'train':
                decoder_initial_state = self.encoder_state
                # Teacher forcing: prepend the start frame so step t sees target t-1.
                start_tensor = tf.expand_dims(tf.expand_dims(tf.constant(self.start_char, dtype=tf.float32), axis=0), axis=0)
                start_tensor = tf.tile(start_tensor, [self.batch_size, 1, 1])
                helper_inputs = tf.concat([start_tensor, self.decoder_input], 1)
                helper_name = 'training_helper'
            elif self.mode == 'infer':
                # NOTE(review): a single LSTMStateTuple only matches the cell's
                # state structure when num_layers == 1 (MultiRNNCell expects a
                # tuple of per-layer states).
                decoder_initial_state = LSTMStateTuple(self.infer_c, self.infer_h)
                # One frame -> shape (1, 1, vocabulary_dim): batch of 1, one step.
                helper_inputs = tf.expand_dims(tf.expand_dims(self.infer_decoder_input, axis=0), axis=0)
                helper_name = 'infer_helper'
            else:
                raise ValueError("unsupported mode %r" % (self.mode,))

            helper = tf.contrib.seq2seq.TrainingHelper(helper_inputs, self.decode_sequence_length, name=helper_name)
            my_decoder = tf.contrib.seq2seq.BasicDecoder(cell=cell, helper=helper, initial_state=decoder_initial_state)
            outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope=scope)
            decoder_cell_outputs = outputs.rnn_output
            self.output_layer = tf.layers.Dense(self.vocabulary_dim, use_bias=False, name="output_projection")
            decoder_logits = self.output_layer(outputs.rnn_output)
            return decoder_logits, decoder_cell_outputs, final_context_state

    def build_saver(self):
        """Create a ``tf.train.Saver`` over the graph's variables."""
        self.saver = tf.train.Saver()

    @staticmethod
    def _checkpoint_prefix(save_path, step):
        """Return the checkpoint prefix for ``step`` inside directory ``save_path``.

        ``os.path.join`` keeps the checkpoint inside ``save_path`` even when it
        has no trailing separator (plain concatenation wrote ``<dir><step>``
        next to the directory instead).
        """
        return os.path.join(save_path, str(step))

    def save(self, save_path, step):
        """Save all variables to ``<save_path>/<step>``, creating the directory if needed."""
        if not os.path.isdir(save_path):
            os.makedirs(save_path)
        save_path = self._checkpoint_prefix(save_path, step)
        self.saver.save(self.sess, save_path)
        print('Model saved %s' % save_path)

    def load(self, save_path, step):
        """Restore variables from the ``<save_path>/<step>`` checkpoint written by ``save``."""
        save_path = self._checkpoint_prefix(save_path, step)
        self.saver.restore(self.sess, save_path)
        print('Model load %s' % save_path)

    def build_summary(self):
        """Register a scalar ``loss`` summary and merge all summaries into ``summary_op``."""
        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()

    def compute_loss(self):
        """MSE between decoder logits and ``decoder_input``, plus a clipped-gradient Adam step.

        Built in both modes; in ``'infer'`` mode it is never run because
        ``decoder_input`` is not fed there.

        Returns:
            ``(loss, train_op)``.
        """
        loss = tf.contrib.losses.mean_squared_error(predictions=self.decoder_logits, labels=self.decoder_input)
        trainable_params = tf.trainable_variables()
        gradients = tf.gradients(loss, trainable_params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.max_gradient_norm)
        # Optimization
        optimizer = tf.train.AdamOptimizer(self.learing_rate)
        train_op = optimizer.apply_gradients(zip(clipped_gradients, trainable_params))
        return loss, train_op

    def create_feed_dict(self, mode, feed):
        """Map a plain ``feed`` dict onto this graph's placeholders.

        Args:
            mode: ``"predict_encode_state"`` (needs ``encode_input``,
                ``encode_sequence_length``); ``"predict_decode_out"`` (needs
                ``infer_encoder_input_tensor``, ``infer_c``, ``infer_h``; decodes
                exactly one step); anything else, e.g. ``"train"`` / ``"eval"``
                (needs ``encode_input``, ``decode_input`` and both length lists).
            feed: the input values keyed as above.

        Returns:
            dict suitable for ``Session.run(feed_dict=...)``. Dropout
            (``keep_prob`` 0.5) is only enabled for ``"train"``; every other
            mode feeds 1.0.
        """
        if mode == "predict_encode_state":
            feed_dict = {self.encoder_input: feed["encode_input"],
                         self.encode_sequence_length: feed["encode_sequence_length"],
                         self.keep_prob: 1.0,
                         }
        elif mode == "predict_decode_out":
            feed_dict = {self.infer_decoder_input: feed["infer_encoder_input_tensor"],
                         self.infer_c: feed["infer_c"],
                         self.infer_h: feed["infer_h"],
                         self.decode_sequence_length: [1],
                         self.keep_prob: 1.0,
                         }
        else:
            batch_size = len(feed["encode_input"])
            feed_dict = {self.encoder_input: feed["encode_input"],
                         self.decoder_input: feed["decode_input"],
                         self.encode_sequence_length: feed["encode_sequence_length"],
                         self.decode_sequence_length: feed["decode_sequence_length"],
                         self.batch_size: batch_size}
            # Previously only "train"/"eval" set keep_prob, so any other mode
            # failed at Session.run with an unfed placeholder.
            feed_dict[self.keep_prob] = 0.5 if mode == "train" else 1.0
        return feed_dict

    def train(self, feed):
        """Run one optimisation step; return ``(loss, summary)``."""
        feed_dict = self.create_feed_dict(mode="train", feed=feed)
        _, loss, summary = self.sess.run([self.train_op, self.loss, self.summary_op], feed_dict=feed_dict)
        return loss, summary

    def eval(self, feed):
        """Compute loss without dropout or parameter updates; return ``(loss, summary)``."""
        feed_dict = self.create_feed_dict(mode="eval", feed=feed)
        loss, summary = self.sess.run([self.loss, self.summary_op], feed_dict=feed_dict)
        return loss, summary

    def predict_encode_state(self, feed):
        """Return the encoder's final state for ``feed["encode_input"]``."""
        feed_dict = self.create_feed_dict(mode="predict_encode_state", feed=feed)
        encode_state = self.sess.run(self.encoder_state, feed_dict=feed_dict)
        return encode_state

    def predict_decode_out(self, feed):
        """Run one decoder step from a given state; return ``(logits, final_state)``."""
        feed_dict = self.create_feed_dict(mode="predict_decode_out", feed=feed)
        decode_out, decode_state_tuple = self.sess.run([self.decoder_logits, self.final_context_state], feed_dict=feed_dict)
        return decode_out, decode_state_tuple


def get_data(trajectory_file):
    """Parse a trajectory log file into a list of episodes.

    The file is read line by line:

    * a line starting with ``state`` carries a Python literal after the first
      ``:`` (parsed with ``ast.literal_eval``);
    * a line starting with ``action`` carries an integer after the first ``:``;
    * a line starting with ``done`` closes one step, appending
      ``[state, action]`` built from the most recently seen state and action.
      If the line contains ``True`` the current episode ends.

    Steps after the last ``done ... True`` line are discarded; other lines are
    ignored.

    Args:
        trajectory_file: path to the text log.

    Returns:
        list of episodes; each episode is a list of ``[state, action]`` pairs.

    Raises:
        ValueError: if a ``done`` line appears before any ``state``/``action`` line.
    """
    missing = object()
    tra_list, tra = [], []
    s = a = missing
    # ``with`` guarantees the handle is closed; it previously leaked.
    with open(trajectory_file, "r") as f:
        for line in f:
            if line.startswith("state"):
                # Split on the first ':' only, so literals that contain ':'
                # (e.g. dicts) survive. Strip because ast.literal_eval rejects
                # leading whitespace before Python 3.10.
                s = ast.literal_eval(line.split(":", 1)[1].strip())
            elif line.startswith("action"):
                a = int(line.split(":", 1)[1])
            elif line.startswith("done"):
                if s is missing or a is missing:
                    raise ValueError("'done' line before any state/action line in %s" % trajectory_file)
                tra.append([s, a])
                if "True" in line:
                    tra_list.append(tra)
                    tra = []
    return tra_list


def get_train_data(tra_list, split_length):
    """Cut every trajectory into overlapping fixed-length windows of feature frames.

    Each step ``[state, action]`` becomes one frame: a copy of ``state`` as a
    list with ``action / 10`` appended as the last feature. Windows of
    ``split_length`` consecutive frames are taken at every start offset
    (stride 1). Trajectories shorter than ``split_length`` contribute nothing.
    The input is never mutated.

    Args:
        tra_list: list of trajectories, each a list of ``[state, action]``
            pairs (as produced by ``get_data``).
        split_length: number of frames per window.

    Returns:
        list of windows; each window is a list of ``split_length`` frames.
    """
    train_data = []
    for tra in tra_list:
        for start_idx in range(len(tra) - split_length + 1):
            window = tra[start_idx: start_idx + split_length]
            # deepcopy keeps every window independent of the input and of
            # other windows. list() also accepts tuple states, which have no
            # .append. ``10.0`` forces true division: under Python 2 (a bare
            # ``python`` shebang), ``int / 10`` silently truncates to 0.
            train_data.append([list(copy.deepcopy(item[0])) + [item[1] / 10.0] for item in window])
    return train_data


if __name__ == "__main__":
    # Train a Seq2seq model on sliding windows cut from a trajectory log.
    # Each window of --max_sequence_length frames is split into an encoder
    # half (--max_encode_length) and a decoder half (--max_decode_length).
    parser = argparse.ArgumentParser()
    # The default file holds 10 trajectories.
    parser.add_argument("--trajectory_file", default="./random_demo_0701.txt", type=str, help="Path for source data")
    parser.add_argument("--model_dir", default='seq2seq/model_64/', type=str, help="Path to save model checkpoints")
    parser.add_argument("--summary_dir", default='seq2seq/summary/', type=str, help="Directory for TensorBoard summaries")
    parser.add_argument("--steps_per_checkpoint", default=100, type=int, help="steps_per_checkpoint")
    # configurations for model
    parser.add_argument("--num_epochs", default=1001, type=int, help="Number of training steps (one batch per step)")
    parser.add_argument("--batch_size", default=32, type=int, help="batch_size")
    parser.add_argument("--max_sequence_length", default=8, type=int,
                        help="Frames per training window (max_encode_length + max_decode_length)")
    parser.add_argument("--max_encode_length", default=4, type=int, help="max_encode_length")
    parser.add_argument("--max_decode_length", default=4, type=int, help="max_decode_length")
    parser.add_argument("--vocabulary_dim", default=5, type=int, help="vocabulary_dim")
    parser.add_argument("--num_layers", default=1, type=int, help="num_layers")
    parser.add_argument("--num_units", default=64, type=int, help="num_units")

    args = parser.parse_args()

    # The decoder placeholder has a fixed time dimension, so the window must
    # split exactly into an encoder part and a decoder part.
    if args.max_sequence_length != args.max_encode_length + args.max_decode_length:
        parser.error("--max_sequence_length (%d) must equal --max_encode_length + --max_decode_length (%d)"
                     % (args.max_sequence_length, args.max_encode_length + args.max_decode_length))

    tra_list = get_data(args.trajectory_file)
    # Convert to an array once; this used to be rebuilt on every step.
    train_array = np.array(get_train_data(tra_list, args.max_sequence_length))
    if len(train_array) == 0:
        parser.error("no training windows of length %d in %s" % (args.max_sequence_length, args.trajectory_file))
    if train_array.ndim != 3 or train_array.shape[2] != args.vocabulary_dim:
        parser.error("training frames have shape %s; expected (N, %d, %d)"
                     % (train_array.shape, args.max_sequence_length, args.vocabulary_dim))
    idx = np.arange(len(train_array))
    np.random.shuffle(idx)

    model = Seq2seq(args.max_encode_length, args.max_decode_length,
                    args.vocabulary_dim, args.num_layers, args.num_units, mode='train')
    train_writer = tf.summary.FileWriter(args.summary_dir + 'train', graph=model.sess.graph)
    ckpt = tf.train.get_checkpoint_state(args.model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print('Reloading model parameters..')
        model.saver.restore(model.sess, ckpt.model_checkpoint_path)
    else:
        print('Created new model parameters..')
        model.sess.run(tf.global_variables_initializer())

    start = 0
    for i in range(args.num_epochs):
        # Reshuffle once the remaining indices cannot fill a whole batch.
        # '>' (not '>=') so the final exactly-fitting batch is still used.
        if start + args.batch_size > len(idx):
            start = 0
            np.random.shuffle(idx)
        ninds = idx[start: start + args.batch_size]
        start += len(ninds)
        batch_data = train_array[ninds]
        encode_input = batch_data[:, : model.max_encode_length, :]
        decode_input = batch_data[:, model.max_encode_length:, :]
        # Size the length vectors from the real batch. It is smaller than
        # --batch_size when the whole dataset is smaller than one batch.
        cur_batch_size = len(batch_data)
        encode_sequence_length = [model.max_encode_length] * cur_batch_size
        decode_sequence_length = [model.max_decode_length] * cur_batch_size
        feed = {"encode_input": encode_input, "decode_input": decode_input,
                "encode_sequence_length": encode_sequence_length, "decode_sequence_length": decode_sequence_length}
        train_loss, train_summary = model.train(feed)
        train_writer.add_summary(train_summary, i)
        print("-- train step %d --train_loss %.2f" % (i, train_loss))
        if i % args.steps_per_checkpoint == 0:
            model.save(save_path=args.model_dir, step=i)

    # Flush pending events so the last summaries are not lost on exit.
    train_writer.close()










