#!/usr/bin/env python 
# -*- coding: utf-8 -*- 
# @Time : 2020/11/30 20:53 
# @Author : wangzhaorong
# @Site :  
# @File : ppo.py 
# @Software: PyCharm

import tensorflow as tf
import tensorflow.contrib.layers as layers
import numpy as np


class Policy(object):
    """PPO actor-critic with a clipped surrogate objective (TF1 graph mode).

    The policy is a categorical distribution over ``action_dim`` discrete
    actions. Both the actor and the critic are fully connected networks fed
    with the state concatenated with an externally supplied encoding vector.

    Typical life cycle: ``Policy(args)`` -> ``construct_model(gpu)`` ->
    ``sample_action`` / ``compute_gae`` / ``update_model``.
    NOTE(review): no variable initializer is run anywhere in this class;
    presumably the caller initializes or restores variables through
    ``self.sess`` / ``self.saver`` -- confirm against the training script.
    """

    def __init__(self, args):
        """Store hyper-parameters; no graph or session is created here.

        ``args`` must expose ``num_units``, ``gamma``, ``lammbda``, ``lr``,
        ``epsilon``, ``var_beta``, ``v_coef``, ``batch_size`` and
        ``noptepochs``.
        """
        # Problem dimensions and layer widths are fixed in code.
        self.action_dim = 4
        self.state_dim = 4
        self.encode_dim = args.num_units
        self.hidden = [256, 128, 128]

        self.gamma = args.gamma          # discount factor used by compute_gae
        self.lammbda = args.lammbda      # GAE lambda (attribute spelling kept for callers)
        self.lr = args.lr                # Adam learning rate
        self.epsilon = args.epsilon      # PPO clipping range
        self.var_beta = args.var_beta    # entropy bonus coefficient
        self.v_coef = args.v_coef        # value-loss coefficient
        self.nsteps_per_train = 10       # not read inside this class
        self.batch_size = args.batch_size
        self.noptepochs = args.noptepochs

    def action_network(self, input_state, input_encode):
        """Build the actor head and return a ``[batch, action_dim]`` tensor.

        The output is turned into probabilities by the softmax in
        ``compute_loss``.
        """
        state_encode = tf.concat([input_state, input_encode], -1)
        fc0 = layers.fully_connected(inputs=state_encode, num_outputs=self.hidden[0],
                                     activation_fn=tf.nn.leaky_relu, scope='action_hidden_0')
        fc1 = layers.fully_connected(inputs=fc0, num_outputs=self.hidden[1],
                                     activation_fn=tf.nn.leaky_relu, scope='action_hidden_1')
        fc2 = layers.fully_connected(inputs=fc1, num_outputs=self.hidden[2],
                                     activation_fn=tf.nn.leaky_relu, scope='action_hidden_2')
        # NOTE(review): the logits also pass through leaky_relu before the
        # softmax; a linear output layer is more common -- confirm this is
        # intended. Left unchanged so existing checkpoints behave the same.
        action = layers.fully_connected(inputs=fc2, num_outputs=self.action_dim,
                                        activation_fn=tf.nn.leaky_relu, scope='action_layer')
        return action

    def value_network(self, input_state, input_encode):
        """Build the critic head and return a ``[batch, 1]`` value tensor."""
        state_encode = tf.concat([input_state, input_encode], 1)
        fc0 = layers.fully_connected(inputs=state_encode, num_outputs=self.hidden[0],
                                     activation_fn=tf.nn.leaky_relu, scope='value_hidden_0')
        fc1 = layers.fully_connected(inputs=fc0, num_outputs=self.hidden[1],
                                     activation_fn=tf.nn.leaky_relu, scope='value_hidden_1')
        fc2 = layers.fully_connected(inputs=fc1, num_outputs=self.hidden[2],
                                     activation_fn=tf.nn.leaky_relu, scope='value_hidden_2')
        # Linear output: the value estimate is unbounded.
        value = layers.fully_connected(inputs=fc2, num_outputs=1, activation_fn=None,
                                       scope='value_layer')
        return value

    def construct_model(self, gpu=-1):
        """Create the session, then build placeholders, networks, loss and saver.

        ``gpu == -1`` places the graph on the CPU; any other value ``k``
        places it on ``/gpu:k``.
        """
        self.sess, device = self.get_session(gpu)
        with tf.device(device):
            self.create_placeholder()
            self.build_network()
            self.compute_loss()
            self.build_saver()

    def get_session(self, gpu):
        """Return ``(session, device_string)`` for the requested device.

        Both configurations restrict TensorFlow to a single intra-op and a
        single inter-op thread.
        """
        if gpu == -1:
            device = '/cpu:0'
            sess_config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
        else:  # use GPU
            device = '/gpu:' + str(gpu)
            sess_config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True,
                                         intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
            sess_config.gpu_options.allow_growth = True
            # NOTE(review): the visible devices are hard-coded regardless of
            # the requested ``gpu`` index -- confirm this matches the hosts
            # the code is deployed on.
            sess_config.gpu_options.visible_device_list = '0, 1'
        sess = tf.Session(graph=tf.get_default_graph(), config=sess_config)
        return sess, device

    def create_placeholder(self):
        """Create the graph inputs used for acting and for training."""
        self.input_state = tf.placeholder(tf.float32, [None, self.state_dim], name='input_state')
        self.encodes = tf.placeholder(tf.float32, [None, self.encode_dim], name='input_encode')
        # Index of the action that was actually taken at each step.
        self.taken_action = tf.placeholder(dtype=tf.int32, shape=[None, ], name='taken_action')
        self.adv = tf.placeholder(dtype=tf.float32, shape=[None, ], name='adv')
        # Value-regression target; shaped [batch, 1] to match the critic output.
        self.y_r = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='y_r')
        # Full action distribution of the behaviour policy at collection time.
        self.old_prob = tf.placeholder(tf.float32, [None, self.action_dim], name='old_prob')

    def build_network(self):
        """Instantiate actor and critic under the 'action' / 'value' scopes."""
        with tf.variable_scope('action'):
            self.action = self.action_network(self.input_state, self.encodes)
        with tf.variable_scope('value'):
            self.value = self.value_network(self.input_state, self.encodes)

    def compute_loss(self):
        """Define the PPO loss tensors and the Adam training op."""
        with tf.name_scope('loss'):
            # Max-subtracted softmax, then smoothed by 1e-10 per entry and
            # renormalised so every probability is > 0 (finite logs) while each
            # row still sums to one.
            self.prob = (tf.nn.softmax(self.action - tf.reduce_max(self.action, axis=1, keepdims=True), axis=1,
                                       name='prob') + 1e-10) / (1.0 + 1e-10 * self.action_dim)
            action_mask = layers.one_hot_encoding(self.taken_action, self.action_dim)
            log_p = tf.log(tf.reduce_sum(self.prob * action_mask, axis=1))
            old_log_p = tf.log(tf.reduce_sum(self.old_prob * action_mask, axis=1))
            # Probability ratio new/old for the taken action.
            ratio = tf.exp(log_p - old_log_p)
            clipped_ratio = tf.clip_by_value(ratio, 1 - self.epsilon, 1 + self.epsilon)
            cost_policy = tf.reduce_mean(tf.minimum(ratio * self.adv, clipped_ratio * self.adv))
            cost_entropy = tf.reduce_mean(-tf.reduce_sum(self.prob * tf.log(self.prob), axis=1))
            # The optimizer minimises, so negate the surrogate + entropy bonus.
            cost_p = -(cost_policy + self.var_beta * cost_entropy)
            cost_v = self.v_coef * tf.reduce_mean(tf.square(self.value - self.y_r))
            self.loss = cost_p + cost_v
            self.cost_p = cost_p
            self.cost_v = cost_v
            self.cost_entropy = cost_entropy
            self.train_op = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)

    def build_saver(self):
        """Create the ``tf.train.Saver`` used for checkpointing."""
        self.saver = tf.train.Saver()

    def sample_action(self, state, encode):
        """Sample one action for a single state.

        Returns ``(action, prob, value)``: the sampled action index, the
        action-probability vector, and the scalar value estimate.
        """
        prob, value = self.sess.run([self.prob, self.value], feed_dict={self.input_state: [state],
                                                                        self.encodes: [encode]})
        # Renormalise in numpy so np.random.choice's sum-to-one check passes.
        prediction = prob / np.sum(prob)
        action = np.random.choice(list(range(self.action_dim)), p=prediction[0])
        return action, prob[0], value[0][0]

    def test_action(self, state, encode):
        """Return the greedy (highest-probability) action for a single state."""
        prob = self.sess.run(self.prob, feed_dict={self.input_state: [state],
                                                   self.encodes: [encode]})
        # Only one state is fed, so row 0 is the whole distribution.
        action = np.argmax(prob[0])
        return action

    def update_model(self, observations, actions, returns, values, old_probs, encodes):
        """Run ``noptepochs`` epochs of minibatch PPO updates over a rollout.

        All arguments are numpy arrays whose first axis indexes transitions.
        Each epoch reshuffles the transitions and visits
        ``max(n // batch_size, 1)`` disjoint minibatches; a final partial
        minibatch is skipped for that epoch. Advantages are computed as
        ``returns - values`` and standardised per minibatch.

        Returns:
            ``(loss, cost_p, cost_v, cost_entropy)``, each averaged over all
            minibatch updates performed.
        """
        num_total_sizes = observations.shape[0]
        batch_step = max(int(num_total_sizes / self.batch_size), 1)
        inds = np.arange(num_total_sizes)
        fetches = [self.loss, self.cost_p, self.cost_v, self.cost_entropy, self.train_op]
        loss_list, cost_p_list, cost_v_list, cost_entropy_list = [], [], [], []
        for _ in range(self.noptepochs):
            np.random.shuffle(inds)
            for i in range(batch_step):
                # Derive the window from the minibatch index so consecutive
                # minibatches are disjoint (previously every minibatch in an
                # epoch reused the first ``batch_size`` shuffled indices).
                start = i * self.batch_size
                ninds = inds[start:start + self.batch_size]
                returnsl = returns[ninds]
                advs = returnsl - values[ninds]
                advs = (advs - advs.mean()) / (advs.std() + 1e-5)
                feed_dict = {self.input_state: observations[ninds],
                             self.encodes: encodes[ninds],
                             self.taken_action: actions[ninds],
                             self.adv: advs,
                             self.old_prob: old_probs[ninds],
                             self.y_r: returnsl[:, np.newaxis]}
                loss_cur, cost_p, cost_v, cost_entropy, _ = self.sess.run(fetches, feed_dict=feed_dict)
                loss_list.append(loss_cur)
                cost_p_list.append(cost_p)
                cost_v_list.append(cost_v)
                cost_entropy_list.append(cost_entropy)
        loss = np.mean(loss_list)
        cost_p = np.mean(cost_p_list)
        cost_v = np.mean(cost_v_list)
        cost_entropy = np.mean(cost_entropy_list)
        return loss, cost_p, cost_v, cost_entropy

    def compute_gae(self, next_obs, next_encode, rewards, values, dones):
        """Compute GAE(lambda) returns for one rollout.

        ``rewards``, ``values`` and ``dones`` are arrays over the rollout
        steps; ``next_obs`` / ``next_encode`` describe the state after the
        last step and are used to bootstrap its value from the critic.

        Returns:
            An array shaped like ``values`` holding ``advantage + value`` for
            every step (the regression target for the critic).
        """
        next_value = self.sess.run(self.value, feed_dict={self.input_state: [next_obs],
                                                          self.encodes: [next_encode]})[0][0]
        gae = 0
        returns = np.zeros_like(values)
        for step in reversed(range(rewards.shape[0])):
            # ``dones[step]`` masks both the bootstrap value and the
            # accumulated advantage so nothing leaks across episode ends.
            not_done = 1. - dones[step]
            td_delta = rewards[step] + self.gamma * not_done * next_value - values[step]
            # Walking backwards: this step's value is the "next value" of the
            # preceding step.
            next_value = values[step]
            gae = self.gamma * self.lammbda * not_done * gae + td_delta
            returns[step] = gae + values[step]
        return returns


class ReplayBuffer(object):
    """Fixed-capacity, append-only store for one on-policy rollout.

    Every field is a pre-allocated float numpy array with one row per
    transition; ``cur_index`` is the next free row.
    """

    def __init__(self, num_total_sizes, act_dims, obs_dims, encode_dims):
        """Allocate zero-filled storage for ``num_total_sizes`` transitions."""
        capacity = num_total_sizes
        self.num_total_sizes = capacity
        # Per-step vectors.
        self.observations = np.zeros((capacity, obs_dims))
        self.old_probs = np.zeros((capacity, act_dims))
        self.encodes = np.zeros((capacity, encode_dims))
        # Per-step scalars.
        self.actions = np.zeros((capacity,))
        self.rewards = np.zeros((capacity,))
        self.values = np.zeros((capacity,))
        self.dones = np.zeros((capacity,))
        self.cur_index = 0

    def store_data(self, cur_obs, cur_action, reward, done, old_prob=None, value=None, encode=None):
        """
        Write one transition into the next free row and advance the cursor.

        cur_obs:                  array-like                  (obs_dims,    )
        cur_action:               scalar action
        reward:                   scalar
        done:                     scalar episode-end flag
        old_prob:                 array-like                  (act_dims,    )
        value:                    scalar
        encode:                   array-like                  (encode_dims, )
        """
        slot = self.cur_index
        self.observations[slot] = cur_obs
        self.actions[slot] = cur_action
        self.rewards[slot] = reward
        self.old_probs[slot] = old_prob
        self.dones[slot] = done
        self.values[slot] = value
        self.encodes[slot] = encode
        self.cur_index = slot + 1

    def clear_data(self):
        """Rewind the cursor; stored rows are overwritten by later writes."""
        self.cur_index = 0

    @property
    def enough_data(self):
        """True once the number of stored transitions equals the capacity."""
        return self.num_total_sizes == self.cur_index
