#!/usr/bin/env python 
# -*- coding: utf-8 -*- 
# @Time : 2020/12/3 16:57 
# @Author : wangzhaorong
# @Site :  
# @File : qnet.py 
# @Software: PyCharm
import tensorflow as tf
import tensorflow.contrib.layers as layers
import numpy as np


class Posterior(object):
    """MLP that maps a (state, action) pair to a distribution over encodes.

    Two structurally identical networks are built: an online network under
    the variable scope ``qsa`` that is trained with a softmax cross-entropy
    loss against ``target_encode``, and a second network under ``qsa_target``
    whose weights are only ever written by the blending / sync assign ops.

    Typical use: ``Posterior(args)``, then ``construct_model(gpu)``, then
    ``update_model(...)`` / ``predict(...)``. Variable initialization is not
    performed in this class and must happen before training.
    """

    def __init__(self, args):
        """Store hyper-parameters.

        Args:
            args: object exposing ``lr``, ``batch_size`` and ``noptepochs``.
        """
        self.state_dim = 4
        self.encode_dim = 128
        self.hidden = [256, 128, 128]
        self.lr = args.lr
        self.batch_size = args.batch_size
        self.noptepochs = args.noptepochs

    def construct_model(self, gpu=-1):
        """Create the session and build the whole graph on the chosen device.

        Args:
            gpu: ``-1`` to place the graph on the CPU, otherwise the GPU index
                used in the ``/gpu:<index>`` device string.
        """
        self.sess, device = self.get_session(gpu)
        with tf.device(device):
            self.create_placeholder()
            self.build_network()
            self.compute_loss()
            self.build_saver()
            self.build_update_target_params_op()

    def get_session(self, gpu):
        """Return ``(session, device_string)`` for the requested device.

        Args:
            gpu: ``-1`` for CPU, otherwise a GPU index.

        Returns:
            A ``tf.Session`` bound to the default graph and the device string
            to pass to ``tf.device``.
        """
        if gpu == -1:
            device = '/cpu:0'
            sess_config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
        else:  # use GPU
            device = '/gpu:' + str(gpu)
            sess_config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True,
                                         intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
            sess_config.gpu_options.allow_growth = True
            # NOTE(review): the visible devices are hard-coded and do not
            # depend on `gpu`; confirm this matches the deployment machines.
            sess_config.gpu_options.visible_device_list = '0, 1'
        sess = tf.Session(graph=tf.get_default_graph(), config=sess_config)
        return sess, device

    def create_placeholder(self):
        """Create the input placeholders and the concatenated network input."""
        self.input_state = tf.placeholder(tf.float32, [None, self.state_dim], name='input_state')
        self.target_encode = tf.placeholder(dtype=tf.float32, shape=[None, self.encode_dim], name='target_encode')
        self.taken_action = tf.placeholder(dtype=tf.float32, shape=[None, ], name='taken_action')
        # The action is divided by 10 and appended as one extra feature column.
        self.input_state_action = tf.concat([self.input_state, tf.expand_dims(self.taken_action / 10, -1)], -1)

    def build_network(self):
        """Instantiate the online (``qsa``) and target (``qsa_target``) networks."""
        with tf.variable_scope('qsa'):
            self.qsa = self.qsa_network()
        with tf.variable_scope('qsa_target'):
            self.qsa_target = self.qsa_network()

    def qsa_network(self):
        """Build one 4-layer fully connected network on ``input_state_action``.

        Returns:
            Tensor with ``encode_dim`` outputs per row; it is used as logits
            by ``compute_loss``.
        """
        fc0 = layers.fully_connected(inputs=self.input_state_action, num_outputs=self.hidden[0],
                                     activation_fn=tf.nn.leaky_relu,
                                     scope='q_hidden_0')
        fc1 = layers.fully_connected(inputs=fc0, num_outputs=self.hidden[1], activation_fn=tf.nn.leaky_relu,
                                     scope='q_hidden_1')
        fc2 = layers.fully_connected(inputs=fc1, num_outputs=self.hidden[2], activation_fn=tf.nn.leaky_relu,
                                     scope='q_hidden_2')
        qsa = layers.fully_connected(inputs=fc2, num_outputs=self.encode_dim, activation_fn=tf.nn.leaky_relu,
                                     scope='q_layer')
        return qsa

    def compute_loss(self):
        """Define probabilities, entropy, the training loss and the train op."""
        with tf.name_scope('loss'):
            # Softmax on max-shifted logits, then add 1e-10 to every entry and
            # renormalize so each row still sums to 1 and log() never sees 0.
            self.qsa_prob = (tf.nn.softmax(self.qsa - tf.reduce_max(self.qsa, axis=-1, keepdims=True), axis=-1,
                                           name='qsa_prob') + 1e-10) / (1.0 + 1e-10 * self.encode_dim)
            self.entropy = tf.reduce_mean(-tf.reduce_sum(self.qsa_prob * tf.log(self.qsa_prob), axis=1))
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=self.target_encode, logits=self.qsa))
            self.train_op = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)
            self.qsa_target_prob = (tf.nn.softmax(self.qsa_target - tf.reduce_max(self.qsa_target, axis=-1, keepdims=True),
                                                  axis=-1, name='qsa_target_prob') + 1e-10) / (1.0 + 1e-10 * self.encode_dim)

    def build_saver(self):
        """Create the ``tf.train.Saver`` stored in ``self.saver``."""
        self.saver = tf.train.Saver()

    def _network_params(self, scope):
        """Return the trainable variables that live directly under ``scope``.

        ``tf.get_collection`` filters names with ``re.match``, i.e. by prefix,
        so the bare scope ``'qsa'`` would also match ``'qsa_target/...'``. The
        trailing slash restricts the match to one network, and reading the
        trainable collection keeps optimizer slot variables out of the list.
        """
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/')

    def build_update_target_params_op(self):
        """Build the op list that moves the target weights halfway to the online ones."""
        qsa_params = self._network_params('qsa')
        qsa_target_params = self._network_params('qsa_target')
        self.target_replace_op = [tf.assign(t, 0.5 * t + 0.5 * q) for t, q in zip(qsa_target_params, qsa_params)]

    def sync_params_op(self):
        """Copy the online network weights into the target network.

        The assign ops are created on the first call and reused afterwards so
        that repeated syncs do not keep adding nodes to the graph.
        """
        sync_ops = getattr(self, '_sync_ops', None)
        if sync_ops is None:
            qsa_params = self._network_params('qsa')
            qsa_target_params = self._network_params('qsa_target')
            sync_ops = [tf.assign(t, q) for t, q in zip(qsa_target_params, qsa_params)]
            self._sync_ops = sync_ops
        self.sess.run(sync_ops)

    def predict(self, states, actions, encodes):
        """Return the mean cross-entropy of the target network against ``encodes``.

        Args:
            states: array fed to ``input_state``.
            actions: array fed to ``taken_action``.
            encodes: array multiplied element-wise with the log-probabilities;
                one row per sample.

        Returns:
            Scalar ``-mean(sum(log(p_target) * encodes, axis=1))``.
        """
        qsa_target_prob = self.sess.run(self.qsa_target_prob, feed_dict={self.input_state: states, self.taken_action: actions})
        cross_entropy = -np.average(np.sum(np.log(qsa_target_prob) * encodes, axis=1))
        return cross_entropy

    def update_model(self, observations, actions, encodes):
        """Train the online network and report averaged statistics.

        The first 80% of the rows form the training split and the remainder
        the validation split. For each of ``noptepochs`` epochs the training
        indices are shuffled and consumed in consecutive minibatches of
        ``batch_size``; a trailing partial batch is dropped unless the split
        is smaller than one batch. After every minibatch the target network
        is blended towards the online one and evaluated on the validation
        split.

        Args:
            observations: array indexed by row, fed to ``input_state``.
            actions: array indexed by row, fed to ``taken_action``.
            encodes: array indexed by row, fed to ``target_encode``.

        Returns:
            Tuple ``(entropy, loss, val_loss)``, each the mean over all
            minibatch steps.
        """
        num_total_sizes = observations.shape[0]
        train_val_ratio = 0.8
        num_train = int(num_total_sizes * train_val_ratio)

        observations_train = observations[:num_train]
        actions_train = actions[:num_train]
        encodes_train = encodes[:num_train]
        observations_val = observations[num_train:]
        actions_val = actions[num_train:]
        encodes_val = encodes[num_train:]

        batch_step = max(int(num_train / self.batch_size), 1)
        entropy_list, loss_list, val_loss_list = [], [], []
        inds = np.arange(num_train)

        for _epoch in range(self.noptepochs):
            np.random.shuffle(inds)
            for step in range(batch_step):
                # Advance the window every step; the previous code kept
                # `start` at 0 and re-trained on the same minibatch.
                start = step * self.batch_size
                ninds = inds[start:start + self.batch_size]
                feed = {self.input_state: observations_train[ninds],
                        self.target_encode: encodes_train[ninds],
                        self.taken_action: actions_train[ninds]}
                entropy_cur, loss_cur, _ = self.sess.run([self.entropy, self.loss, self.train_op],
                                                         feed_dict=feed)
                self.sess.run(self.target_replace_op)
                cross_entropy = self.predict(observations_val, actions_val, encodes_val)
                entropy_list.append(entropy_cur)
                loss_list.append(loss_cur)
                val_loss_list.append(cross_entropy)
        entropy = np.mean(entropy_list)
        loss = np.mean(loss_list)
        val_loss = np.mean(val_loss_list)
        return entropy, loss, val_loss
