import os.path

import numpy as np
import tensorflow as tf


class BC:
    """Behaviour-cloning model with a target policy and a style-conditioned policy.

    Two heads are trained from the same (state, action) batches:

    * the *target* policy: ``share`` trunk -> ``target_policy`` head, optimised by
      ``train_op_policy`` on softmax cross-entropy against ``action``;
    * the *style* policy: ``style_share`` trunk concatenated with a softmax-normalised
      style embedding (row ``sample_index`` of ``total_normal_style``) ->
      ``weighted_style`` -> ``policy`` head, optimised by ``train_op_style`` on
      ``policy_loss - target_policy_loss``.

    After every target-policy step, ``train`` pulls the ``style_share`` trunk toward
    ``share`` with a soft update (``replace_share``).

    The model owns a private ``tf.Graph`` and ``tf.Session`` (TensorFlow 1.x API).
    """

    def __init__(self, name: str, state_dim, act_dim, encode_dim, sample_num, **kwargs):
        """Store hyper-parameters, then build the graph and open the session.

        :param name: variable-scope name prefixing every variable of the model.
        :param state_dim: width of the ``state`` placeholder.
        :param act_dim: number of discrete actions (width of the logits and of the
            ``action`` placeholder).
        :param encode_dim: width of each style embedding.
        :param sample_num: number of rows in the style table; ``sample_index``
            values index into it.
        :param kwargs: optional ``learning_rate`` (default 1e-4), ``soft_update_tau``
            (0.01), ``save_path`` ("./model"), ``save_style_dir`` ("./style"),
            ``tb_dir`` ("./tb_event") and ``max_to_keep`` (5; stored but not
            currently applied by the saver).
        """
        self.overall_scope_name = name
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.encode_dim = encode_dim
        self.sample_num = sample_num
        self.hidden = [512, 256, 128]
        self.var_learning_rate = kwargs.get("learning_rate", 1e-4)
        self.soft_update_tau = kwargs.get("soft_update_tau", 0.01)
        self.save_dir = kwargs.get("save_path", "./model")
        self.save_style_dir = kwargs.get("save_style_dir", "./style")
        self.tb_dir = kwargs.get("tb_dir", "./tb_event")
        self.max_to_keep = kwargs.get("max_to_keep", 5)
        self.activate = tf.nn.leaky_relu
        self._init_graph()

    def create_graph(self):
        """Build network, variable lists, losses and train ops (in that order)."""
        self._create_network()
        self._create_vars()
        self._create_loss()
        self._create_train_op()

    def _create_network(self):
        """Create placeholders, step counters, both policy heads and action-selection ops."""
        with tf.variable_scope(self.overall_scope_name):
            self.obs = tf.placeholder(dtype=tf.float32, shape=[None, self.state_dim], name='state')
            # Cross-entropy labels (presumably one-hot rows -- confirm with the data pipeline).
            self.action = tf.placeholder(dtype=tf.float32, shape=[None, self.act_dim], name='action')
            # One style-table row index per example in the batch.
            self.sample_index = tf.placeholder(dtype=tf.int64, shape=[None], name='sample_index')
            self.global_step = tf.Variable(0, trainable=False, name='step')
            self.style_global_step = tf.Variable(0, trainable=False, name='style_step')
            with tf.variable_scope('share'):
                layer_1 = tf.layers.dense(inputs=self.obs, units=self.hidden[0], activation=self.activate, name="layer1")
                layer_2 = tf.layers.dense(inputs=layer_1, units=self.hidden[1], activation=self.activate, name="layer2")
            with tf.variable_scope('style_share'):
                style_layer_1 = tf.layers.dense(inputs=self.obs, units=self.hidden[0], activation=self.activate, name="layer1")
                style_layer_2 = tf.layers.dense(inputs=style_layer_1, units=self.hidden[1], activation=self.activate, name="layer2")
            with tf.variable_scope("target_policy"):
                self.target_logits = tf.layers.dense(inputs=layer_2, units=self.act_dim, activation=tf.identity, name="target_layer1")
                self.target_probs = tf.nn.softmax(self.target_logits, axis=1)
            with tf.variable_scope("style"):
                # Learnable style table: one encode_dim-wide row per sample.
                self.total_normal_style = tf.Variable(tf.random_normal([self.sample_num, self.encode_dim]), name="total_normal_style")
                self.total_softmax_style = tf.nn.softmax(self.total_normal_style, name="total_softmax_style")
                self.normal_style = tf.gather(self.total_normal_style, self.sample_index, name="normal_style")
                style_shape = self.normal_style.get_shape().as_list()
                assert len(style_shape) == 2 and style_shape[1] == self.encode_dim, \
                    f"Wrong normal_style shape {style_shape}, expected (None, {self.encode_dim})."
                self.softmax_style = tf.nn.softmax(self.normal_style, axis=1, name="softmax_style")
                # Condition the style trunk on the normalised style vector.
                self.style_state = tf.concat([style_layer_2, self.softmax_style], axis=1)
                self.weighted_style = tf.layers.dense(inputs=self.style_state, units=self.hidden[1], activation=tf.nn.tanh, name="weighted_style")
            with tf.variable_scope("policy"):
                # Unnamed on purpose: the default layer name is part of checkpoint variable names.
                self.logits = tf.layers.dense(inputs=self.weighted_style, units=self.act_dim, activation=tf.identity)
                self.probs = tf.nn.softmax(self.logits, axis=1)

            # tf.multinomial expects unnormalised log-probabilities, so sample straight from
            # the logits rather than log(softmax(logits)), which becomes -inf on underflow.
            self.act_stochastic = tf.reshape(tf.multinomial(self.logits, num_samples=1), shape=[-1])
            self.act_deterministic = tf.argmax(self.probs, axis=1)

            self.target_act_stochastic = tf.reshape(tf.multinomial(self.target_logits, num_samples=1), shape=[-1])
            self.target_act_deterministic = tf.argmax(self.target_probs, axis=1)

    @staticmethod
    def _print_vars(title, var_list):
        """Print ``title`` followed by the name and static shape of each variable."""
        print(title)
        for var in var_list:
            print(var.name, var.shape)

    def _scoped_vars(self, *sub_scopes):
        """Trainable variables under ``<overall_scope_name>/<sub_scope>``, scope by scope."""
        return self.get_trainable_variables_by_name([f"{self.overall_scope_name}/{k}" for k in sub_scopes])

    def _create_vars(self):
        """Collect the variable lists used by the optimisers, soft updates and saver."""
        self.vars = self.get_trainable_variables()
        self._print_vars("All trainable vars:", self.vars)
        self.target_policy_trainable_vars = self._scoped_vars("share", "target_policy")
        self._print_vars("Target trainable vars:", self.target_policy_trainable_vars)
        # Pairs for replace_policy (policy <- target_policy); train() does not call it currently.
        self.target_policy_assign_vars = self._scoped_vars("target_policy")
        self.policy_assign_vars = self._scoped_vars("policy")
        # Pairs for replace_share (style_share <- share).
        self.target_share_assign_vars = self._scoped_vars("share")
        self.share_assign_vars = self._scoped_vars("style_share")
        # NOTE(review): get_collection filters scopes with re.match, so the "style" prefix also
        # matches "style_share/..." and the style optimiser trains the style_share trunk too.
        # Confirm this is intended (it also receives the soft update from "share").
        self.style_trainable_vars = self._scoped_vars("style", "policy")
        self._print_vars("Style trainable vars:", self.style_trainable_vars)

    def _create_loss(self):
        """Define cross-entropy losses for both heads and the style objective."""
        self.target_policy_loss = tf.losses.softmax_cross_entropy(self.action, self.target_logits)
        self.policy_loss = tf.losses.softmax_cross_entropy(self.action, self.logits)
        # target_policy_loss does not depend on the style variables, so it only shifts the
        # reported value; gradients w.r.t. style_trainable_vars come from policy_loss.
        self.style_loss = self.policy_loss - self.target_policy_loss

    def _create_train_op(self):
        """Create one Adam train op per head, each with its own step counter."""
        # NOTE(review): both minimize() calls share one AdamOptimizer instance; in TF1 its
        # beta-power accumulators are presumably shared and advanced by both ops -- verify,
        # and use two optimizers if that is unintended.
        self.opt = tf.train.AdamOptimizer(learning_rate=self.var_learning_rate)
        self.train_op_policy = self.opt.minimize(self.target_policy_loss, global_step=self.global_step, var_list=self.target_policy_trainable_vars)
        self.train_op_style = self.opt.minimize(self.style_loss, global_step=self.style_global_step, var_list=self.style_trainable_vars)

    def _create_saver(self, var_dict):
        """Create the checkpoint saver over ``var_dict``.

        ``max_to_keep=0`` deletes no old checkpoints. NOTE(review): the ``max_to_keep``
        kwarg stored in ``__init__`` is not applied here; switching to it would start
        deleting older checkpoints.
        """
        self.saver = tf.train.Saver(var_dict, max_to_keep=0)

    def _create_tb(self, graph):
        """Open a TensorBoard writer for ``graph`` in ``tb_dir``."""
        self.log_writer = tf.summary.FileWriter(self.tb_dir, graph=graph)

    def _soft_update_ops(self, dst_vars, src_vars):
        """Build ops doing ``dst <- tau * src + (1 - tau) * dst`` for each paired variable.

        :raises ValueError: if the two lists differ in length (zip would silently drop pairs).
        """
        if len(dst_vars) != len(src_vars):
            raise ValueError(f"Soft update needs equal-length variable lists, got {len(dst_vars)} and {len(src_vars)}.")
        tau = self.soft_update_tau
        return [tf.assign(dst, tau * src + (1 - tau) * dst) for dst, src in zip(dst_vars, src_vars)]

    def _init_graph(self, **kwargs):
        """Build the model in a private graph, open and initialise a session, then create
        the soft-update ops, the saver and the TensorBoard writer."""
        self.graph = tf.Graph()
        # Session options; empty, so every option below uses its fallback value.
        self.tf_conf = {}
        with self.graph.as_default() as g:
            self.create_graph()
            self.sess = tf.Session(
                graph=self.graph,
                config=tf.ConfigProto(
                    inter_op_parallelism_threads=self.tf_conf.get("inter_op_parallelism_threads", 0),
                    intra_op_parallelism_threads=self.tf_conf.get("intra_op_parallelism_threads", 0),
                    allow_soft_placement=self.tf_conf.get("allow_soft_placement", True),
                    log_device_placement=self.tf_conf.get("log_device_placement", False),
                    gpu_options=tf.GPUOptions(allow_growth=True)))
            self.sess.run(tf.global_variables_initializer())
            self.policy_update_op = self._soft_update_ops(self.policy_assign_vars, self.target_policy_assign_vars)
            self.share_update_op = self._soft_update_ops(self.share_assign_vars, self.target_share_assign_vars)
            self._create_saver({var.name: var for var in self.vars})
            self._create_tb(g)

    def train(self, states, actions, sample_index):
        """Run one target-policy step, soft-update the style trunk, then one style step.

        :returns: ``(train_step, target_loss, style_step, style_loss, policy_loss)``.
        """
        feed_dict = {
            self.obs: states,
            self.action: actions,
            self.sample_index: sample_index
        }
        train_step, policy_loss, target_loss, _ = self.sess.run([self.global_step, self.policy_loss, self.target_policy_loss, self.train_op_policy], feed_dict=feed_dict)
        # replace_policy() is intentionally not called here; only the trunk is soft-updated.
        self.replace_share()
        style_step, style_loss, _ = self.sess.run([self.style_global_step, self.style_loss, self.train_op_style], feed_dict=feed_dict)
        return train_step, target_loss, style_step, style_loss, policy_loss

    def replace_policy(self):
        """Soft-update the ``policy`` head toward ``target_policy``."""
        self.sess.run(self.policy_update_op)

    def replace_share(self):
        """Soft-update the ``style_share`` trunk toward ``share``."""
        self.sess.run(self.share_update_op)

    def save(self, model_step):
        """Save trainable variables to ``<save_dir>/model_<model_step>.ckpt``, creating the directory."""
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, f"model_{model_step}.ckpt")
        self.saver.save(self.sess, save_path)

    def load(self, model_path, model_step):
        """Restore variables from ``<model_path>/model_<model_step>.ckpt``."""
        filename = os.path.join(model_path, f"model_{model_step}.ckpt")
        self.saver.restore(self.sess, filename)
        print('restore checkpoint', filename)

    def save_styles(self, model_step):
        """Write the raw and softmax-normalised style tables as text files in ``save_style_dir``."""
        style, softmax_style = self.sess.run([self.total_normal_style, self.total_softmax_style])
        os.makedirs(self.save_style_dir, exist_ok=True)
        np.savetxt(os.path.join(self.save_style_dir, f"style_{model_step}.txt"), np.array(style))
        np.savetxt(os.path.join(self.save_style_dir, f"style_softmax_{model_step}.txt"), np.array(softmax_style))

    def act(self, obs, softmax_style, stochastic=True):
        """Pick actions with the style policy for an explicitly supplied style vector.

        ``softmax_style`` is fed in place of the gathered style, so ``sample_index`` is not needed.

        :returns: ``[actions, probs]``.
        """
        action_op = self.act_stochastic if stochastic else self.act_deterministic
        return self.sess.run([action_op, self.probs], feed_dict={self.obs: obs, self.softmax_style: softmax_style})

    def get_action(self, obs, sample_index, stochastic=True):
        """Pick actions with the style policy, using style-table rows ``sample_index``.

        :returns: ``[actions, probs]``.
        """
        action_op = self.act_stochastic if stochastic else self.act_deterministic
        return self.sess.run([action_op, self.probs], feed_dict={self.obs: obs, self.sample_index: sample_index})

    def get_target_action(self, obs, stochastic=True):
        """Pick actions with the target policy.

        :returns: ``[actions, target_probs]``.
        """
        action_op = self.target_act_stochastic if stochastic else self.target_act_deterministic
        return self.sess.run([action_op, self.target_probs], feed_dict={self.obs: obs})

    def get_trainable_variables(self):
        """All trainable variables of this model's graph (independent of the default graph)."""
        return self.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    def get_trainable_variables_by_name(self, var_scopes):
        """Trainable variables whose names match each scope prefix (``re.match``), in scope order."""
        var_list = []
        for var_scope in var_scopes:
            var_list.extend(self.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=var_scope))
        return var_list
