#!/usr/bin/env python 
# -*- coding: utf-8 -*- 
# @Time : 2020/12/3 16:57 
# @Author : wangzhaorong
# @Site :  
# @File : qnet.py 
# @Software: PyCharm
import tensorflow as tf
import tensorflow.contrib.layers as layers
import numpy as np

TINY = 1e-6


class Posterior(object):
    """Diagonal-Gaussian model of an encoding conditioned on (state, action).

    Two copies of the same MLP are built: ``qsa`` (trained with Adam) and
    ``qsa_target`` (moved towards ``qsa`` by averaging after each step).  Each
    maps the concatenated state and scaled action to a tanh-bounded mean
    ``mu`` and an unbounded per-dimension ``log_sigma``.  Training minimises
    ``mean(log_sigma + 0.5 * ((encode - mu) / sigma) ** 2)``, i.e. the
    Gaussian negative log-likelihood up to an additive constant.
    """

    def __init__(self, args):
        # Width of the ``input_state`` placeholder (fixed, not read from args).
        self.state_dim = 4
        self.encode_dim = args.num_units
        self.hidden = [256, 128]
        self.lr = args.lr
        self.batch_size = args.batch_size
        self.noptepochs = args.noptepochs

    def construct_model(self, gpu=-1):
        """Create the session and build placeholders, networks, loss, saver and target-update op.

        Args:
            gpu: GPU index to place the graph on, or -1 to run on the CPU.
        """
        self.sess, device = self.get_session(gpu)
        with tf.device(device):
            self.create_placeholder()
            self.build_network()
            self.compute_loss()
            self.build_saver()
            self.build_update_target_params_op()

    def get_session(self, gpu):
        """Return ``(session, device_string)`` for the requested device.

        Both configurations use single-threaded intra/inter-op pools.
        """
        if gpu == -1:
            device = '/cpu:0'
            sess_config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
        else:  # use GPU
            device = '/gpu:' + str(gpu)
            sess_config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True,
                                         intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
            sess_config.gpu_options.allow_growth = True
            # NOTE(review): only GPUs 0 and 1 are made visible regardless of `gpu`;
            # with allow_soft_placement a larger index silently falls back.
            sess_config.gpu_options.visible_device_list = '0, 1'
        sess = tf.Session(graph=tf.get_default_graph(), config=sess_config)
        return sess, device

    def create_placeholder(self):
        """Create the state/action/target placeholders and the network input."""
        self.input_state = tf.placeholder(tf.float32, [None, self.state_dim], name='input_state')
        self.target_encode = tf.placeholder(dtype=tf.float32, shape=[None, self.encode_dim], name='target_encode')
        self.taken_action = tf.placeholder(dtype=tf.float32, shape=[None, ], name='taken_action')
        # The action is divided by 10 before being appended as one extra input
        # column (presumably to keep it on a scale comparable to the state; confirm).
        self.input_state_action = tf.concat([self.input_state, tf.expand_dims(self.taken_action / 10, -1)], -1)

    def build_network(self):
        """Build the online (``qsa``) and target (``qsa_target``) networks with identical layouts."""
        with tf.variable_scope('qsa'):
            self.mu, self.log_sigma = self.qsa_network()
        with tf.variable_scope('qsa_target'):
            self.mu_target, self.log_sigma_target = self.qsa_network()

    def qsa_network(self):
        """Two leaky-ReLU hidden layers followed by ``mu`` (tanh) and ``log_sigma`` (linear) heads.

        Returns:
            ``(mu, log_sigma)``, each with ``encode_dim`` columns.
        """
        fc0 = layers.fully_connected(inputs=self.input_state_action, num_outputs=self.hidden[0],
                                     activation_fn=tf.nn.leaky_relu,
                                     scope='q_hidden_0')
        fc1 = layers.fully_connected(inputs=fc0, num_outputs=self.hidden[1], activation_fn=tf.nn.leaky_relu,
                                     scope='q_hidden_1')
        mu = layers.fully_connected(inputs=fc1, num_outputs=self.encode_dim, activation_fn=tf.nn.tanh,
                                    scope='q_mu')
        log_sigma = layers.fully_connected(inputs=fc1, num_outputs=self.encode_dim, activation_fn=tf.identity,
                                           scope='q_std')
        return mu, log_sigma

    def compute_loss(self):
        """Define the Gaussian NLL loss of ``target_encode`` under ``qsa`` and its Adam train op."""
        with tf.name_scope('loss'):
            # Standardised residual; TINY guards against division by ~0.
            epsilon = (self.target_encode - self.mu) / (tf.exp(self.log_sigma) + TINY)
            # Averaged over both the batch and the encoding dimensions.
            self.loss = tf.reduce_mean(self.log_sigma + 0.5 * tf.square(epsilon))
            self.train_op = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)

    def build_saver(self):
        """Create a saver over all global variables in the graph."""
        self.saver = tf.train.Saver()

    @staticmethod
    def _target_source_pairs():
        """Return ``(target_var, online_var)`` pairs matched by creation order.

        The trailing ``/`` matters: ``tf.get_collection`` matches ``scope`` as a
        regex prefix, so a bare ``'qsa'`` would also pick up ``qsa_target/``
        variables and Adam's ``qsa/.../Adam`` slots.  Restricting to trainable
        variables excludes optimizer slots.

        Raises:
            ValueError: if the two networks do not have the same number of variables.
        """
        online = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='qsa/')
        target = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='qsa_target/')
        if len(online) != len(target):
            raise ValueError('qsa has %d trainable variables but qsa_target has %d'
                             % (len(online), len(target)))
        return list(zip(target, online))

    def build_update_target_params_op(self, tau=0.5):
        """Create ``target_replace_op``: ``target <- (1 - tau) * target + tau * online``.

        Args:
            tau: interpolation weight of the online network (0.5 by default).
        """
        self.target_replace_op = [tf.assign(t, (1.0 - tau) * t + tau * q)
                                  for t, q in self._target_source_pairs()]

    def sync_params_op(self):
        """Copy the online ``qsa`` weights into ``qsa_target`` (hard update).

        The assign ops are built on the first call and reused afterwards, so
        repeated calls do not keep adding nodes to the graph.
        """
        sync_op = getattr(self, '_sync_op', None)
        if sync_op is None:
            sync_op = [tf.assign(t, q) for t, q in self._target_source_pairs()]
            self._sync_op = sync_op
        self.sess.run(sync_op)

    def predict(self, states, actions, encodes):
        """Score ``encodes`` under the target network's Gaussian.

        Args:
            states: rows fed to ``input_state``.
            actions: 1-D values fed to ``taken_action``, aligned with ``states``.
            encodes: encodings compared against the target mean, one row per state.

        Returns:
            1-D array ``-mean(log_sigma + 0.5 * eps ** 2, axis=1)`` per row; larger
            values mean the target network considers ``encodes`` more likely.
        """
        mu_target, log_sigma_target = self.sess.run(
            [self.mu_target, self.log_sigma_target],
            feed_dict={self.input_state: states, self.taken_action: actions})
        epsilon = (encodes - mu_target) / (np.exp(log_sigma_target) + TINY)
        return -np.mean(log_sigma_target + 0.5 * np.square(epsilon), axis=1)

    @staticmethod
    def _minibatch_indices(inds, batch_size, num_batches):
        """Yield up to ``num_batches`` consecutive, non-empty slices of ``inds`` of size ``batch_size``."""
        for start in range(0, num_batches * batch_size, batch_size):
            batch = inds[start:start + batch_size]
            if len(batch) == 0:
                break
            yield batch

    @staticmethod
    def _mean_or_nan(values):
        """Mean of ``values``, or NaN (without a RuntimeWarning) when empty."""
        return np.mean(values) if len(values) else float('nan')

    def update_model(self, observations, actions, encodes):
        """Train ``qsa`` for ``noptepochs`` epochs and report averaged losses.

        The first 80% of rows form the training split (shuffled each epoch and
        consumed in consecutive mini-batches of ``batch_size``); the remaining
        rows form the validation split.  After every Adam step the target
        network is soft-updated and scored on the validation split.

        Args:
            observations: states, indexed by row.
            actions: actions aligned with ``observations``.
            encodes: target encodings aligned with ``observations``.

        Returns:
            ``(entropy, loss, val_loss)``. ``entropy`` and ``loss`` are both the
            mean training loss (same quantity); ``val_loss`` is the mean negative
            validation score from ``predict``.  A value is NaN when there was
            nothing to average (no training rows / no validation rows).
        """
        num_total_sizes = observations.shape[0]
        train_val_ratio = 0.8
        num_train = int(num_total_sizes * train_val_ratio)

        observations_train, observations_val = observations[:num_train], observations[num_train:]
        actions_train, actions_val = actions[:num_train], actions[num_train:]
        encodes_train, encodes_val = encodes[:num_train], encodes[num_train:]
        has_val = num_total_sizes > num_train

        # Whole mini-batches per epoch; at least one so small datasets still train.
        batch_step = max(num_train // self.batch_size, 1)
        entropy_list, loss_list, val_loss_list = [], [], []
        inds = np.arange(num_train)

        for _ in range(self.noptepochs):
            np.random.shuffle(inds)
            for batch_inds in self._minibatch_indices(inds, self.batch_size, batch_step):
                loss_cur = self.sess.run([self.loss, self.train_op],
                                         feed_dict={self.input_state: observations_train[batch_inds],
                                                    self.target_encode: encodes_train[batch_inds],
                                                    self.taken_action: actions_train[batch_inds]})[0]
                self.sess.run(self.target_replace_op)
                entropy_list.append(loss_cur)
                loss_list.append(loss_cur)
                if has_val:
                    cross_entropy = -np.average(self.predict(observations_val, actions_val, encodes_val))
                    val_loss_list.append(cross_entropy)

        return (self._mean_or_nan(entropy_list),
                self._mean_or_nan(loss_list),
                self._mean_or_nan(val_loss_list))
