"""
OfflineLight
"""

from tensorflow.keras.layers import Input, Dense, Reshape,  Lambda,  Activation, Embedding, Conv2D, concatenate, add,\
    multiply, MultiHeadAttention
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
import numpy as np
import random
import tensorflow as tf
from tensorflow.keras.losses import MeanSquaredError
import copy
from .actor_critic_network_agent import ActorCriticNetworkAgent
import scipy.stats as st 
import pickle
import os 

class GeneralAgent_BEAR(ActorCriticNetworkAgent):

    def build_hidden(self, ins0, ins1):
        """Build the shared trunk mapping lane features and phase bits to scores.

        Shapes below omit the batch dimension.

        Args:
            ins0: Keras tensor of lane features; it is reshaped here to
                ``(12, self.num_feat, 1)``.
            ins1: Keras tensor of 8 phase indicators; it feeds an
                ``Embedding`` whose vocabulary size is 2.

        Returns:
            Keras tensor reshaped to ``(4,)``: one scalar per row of the
            per-phase feature matrix.
        """
        # --- phase branch: (8,) -> (8, 4) -> (2, 4, 4) -> summed to (4, 4)
        squash = Activation('sigmoid')
        phase_embedding = squash(Embedding(2, 4, input_length=8)(ins1))
        phase_embedding = Reshape((2, 4, 4))(phase_embedding)
        phase_branch = Lambda(lambda x: K.sum(x, axis=1), name="feature_as_phase")(phase_embedding)

        # --- lane branch: every scalar feature gets a 4-d embedding, then the
        # embeddings of one lane are flattened into a single vector.
        lane_scalars = Reshape((12, self.num_feat, 1))(ins0)
        lane_embedding = Dense(4, activation='sigmoid', name="feature_embedding")(lane_scalars)
        lane_embedding = Reshape((12, self.num_feat*4))(lane_embedding)
        per_lane = tf.split(lane_embedding, 12, axis=1)
        sum_lanes = Lambda(lambda x: K.sum(x, axis=1, keepdims=True))

        # One row per phase: the sum of the embeddings of the lanes listed in
        # self.phase_map for that phase.
        per_phase = [
            sum_lanes(tf.concat([per_lane[lane] for lane in self.phase_map[phase]], axis=1))
            for phase in range(self.num_phases)
        ]

        # Stack the phase rows and append the phase-branch features to each row.
        phase_matrix = tf.concat(per_phase, axis=1)
        phase_matrix = concatenate([phase_matrix, phase_branch])

        # Self-attention across the phase rows, followed by a per-row MLP.
        attended = MultiHeadAttention(4, 8, attention_axes=1)(phase_matrix, phase_matrix)
        scores = Dense(20, activation="relu")(attended)
        scores = Dense(20, activation="relu")(scores)

        scores = Dense(1, activation="linear", name="final_critic_score")(scores)
        return Reshape((4,))(scores)

    def build_q_network(self):
        """Create, compile and return the critic model with two output heads.

        The model takes three inputs (lane features, current-phase indicators
        and an action vector) and returns two independent linear heads, each
        a ``Dense(1)`` applied to the output of ``build_hidden``.  A summary
        of the model is printed as a side effect.

        Returns:
            The compiled Keras ``Model``.
        """
        lane_input = Input(shape=(12, self.num_feat), name="input_total_features")
        phase_input = Input(shape=(8, ), name="input_cur_phase")
        # Declared as a model input but not wired into the graph: the trunk
        # output is used as-is, without concatenating the action to it.
        action_input = Input(shape=(4, ), name="input_cur_action_q")

        trunk = self.build_hidden(lane_input, phase_input)
        first_head = Dense(1, activation="linear")(trunk)
        second_head = Dense(1, activation="linear")(trunk)

        critic = Model(inputs=[lane_input, phase_input, action_input],
                       outputs=[first_head, second_head])
        critic.compile()
        critic.summary()
        return critic


    def build_a_network(self):
        """Create, compile and return the actor model.

        The actor maps lane features and current-phase indicators straight to
        the ``(4,)`` output of ``build_hidden``.  A summary of the model is
        printed as a side effect.

        Returns:
            The compiled Keras ``Model``.
        """
        lane_input = Input(shape=(12, self.num_feat), name="input_total_features")
        phase_input = Input(shape=(8, ), name="input_cur_phase")

        actor = Model(inputs=[lane_input, phase_input],
                      outputs=self.build_hidden(lane_input, phase_input))
        actor.compile()
        actor.summary()
        return actor


    def choose_action(self, states):
        dic_state_feature_arrays = {}
        cur_phase_info = []
        used_feature = copy.deepcopy(self.dic_traffic_env_conf["LIST_STATE_FEATURE"])
        # print(used_feature)
        for feature_name in used_feature:
            dic_state_feature_arrays[feature_name] = []
        
        for s in states:
            for feature_name in self.dic_traffic_env_conf["LIST_STATE_FEATURE"]:
                if feature_name == "new_phase":
                    cur_phase_info.append(s[feature_name])
                else:
                    dic_state_feature_arrays[feature_name].append(s[feature_name])

        used_feature.remove("new_phase")
        state_input = [np.array(dic_state_feature_arrays[feature_name]).reshape(len(states), 12, -1) for feature_name in
                       used_feature]
        state_input = np.concatenate(state_input, axis=-1)

        q_values = self.a_network.predict([state_input, np.array(cur_phase_info)])

        action = np.argmax(q_values, axis=1)
        return action

    def prepare_samples(self, memory):
        """
        [state, action, next_state, final_reward, average_reward]
        """
        state, action, next_state, p_reward, ql_reward = memory
        used_feature = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
        
        memory_size = len(action)
        
        _state = [[], None]
        _next_state = [[], None]
        for feat_name in used_feature:
            if feat_name == "new_phase":
                _state[1] = np.array(state[feat_name])
                _next_state[1] = np.array(next_state[feat_name])
            else:
                _state[0].append(np.array(state[feat_name]).reshape(memory_size, 12, -1))
                _next_state[0].append(np.array(next_state[feat_name]).reshape(memory_size, 12, -1))
                
        # ========= generate reaward information ===============
        if "pressure" in self.dic_traffic_env_conf["DIC_REWARD_INFO"].keys():
            my_reward = p_reward
            # list(-np.absolute(np.sum(next_state["traffic_movement_pressure_queue_efficient"], axis=-1))/4)
        else:
            my_reward = ql_reward
            # list(-np.sum(next_state["lane_queue_vehicle_in"], axis=-1)/4)
        
        return [np.concatenate(_state[0], axis=-1), _state[1]], action, [np.concatenate(_next_state[0], axis=-1), _next_state[1]], my_reward

    def update_target(self, target_weights, weights, tau):
        for (a, b) in zip(target_weights, weights):
            a.assign(b * tau + a * (1 - tau))

    def train_network(self, memory):
        """Run the offline critic, actor and Lagrange-multiplier updates.

        The samples returned by ``prepare_samples`` are shuffled once and then
        consumed in mini-batches for ``dic_agent_conf["EPOCHS"]`` epochs.
        Samples beyond the last full mini-batch are not used.  Each mini-batch
        performs, in order:

        1. a critic step regressing both critic heads onto
           ``reward / NORMAL_FACTOR + GAMMA * next_q``, where ``next_q`` is a
           0.75 / 0.25 mix of the minimum and maximum of the two
           target-critic heads;
        2. an actor step minimising the Laplacian MMD between the logged
           one-hot actions and the actor outputs, scaled by the multiplier;
        3. a Lagrange-multiplier step driven by ``mmd - 0.05``;
        4. a soft update (``tau = 0.01``) of both target networks, after which
           the multiplier is written with ``save_lagrange`` and a progress
           line is printed.

        New Adam optimizers are created on every call.  The multiplier starts
        at 2.0 when ``self.cnt_round == 0`` and is otherwise loaded from the
        file written for the previous round.

        Args:
            memory: Five-element sequence accepted by ``prepare_samples``.

        Raises:
            ZeroDivisionError: if the effective batch size is 0.
        """
        _state, _action, _next_state, _reward = self.prepare_samples(memory)

        # ==== shuffle the samples ============
        random_index = np.random.permutation(len(_action))
        _state[0] = _state[0][random_index, :, :]
        _state[1] = _state[1][random_index, :]
        _action = np.array(_action)[random_index]
        _next_state[0] = _next_state[0][random_index, :, :]
        _next_state[1] = _next_state[1][random_index, :]
        _reward = np.array(_reward)[random_index]

        epochs = self.dic_agent_conf["EPOCHS"]
        batch_size = min(self.dic_agent_conf["BATCH_SIZE"], len(_action))
        num_batch = int(np.floor((len(_action) / batch_size)))
        normal_factor = self.dic_agent_conf["NORMAL_FACTOR"]
        gamma = self.dic_agent_conf["GAMMA"]

        loss_fn = MeanSquaredError()
        # `learning_rate` is the supported keyword; `lr` is only a deprecated
        # alias that newer Keras releases reject.
        critic_optimizer = Adam(learning_rate=1e-6)
        actor_optimizer = Adam(learning_rate=1e-6)
        if self.cnt_round == 0:
            lagrange_multiplier = tf.Variable(initial_value=2.)
        else:
            lagrange_multiplier = self.load_lagrange("round_{0}_inter_{1}".format(self.cnt_round - 1, 0))
        lagrange_optimizer = Adam(learning_rate=1e-5)
        _epsilon = 5e-2     # MMD threshold used by the multiplier update
        tau = 0.01          # soft-update rate of the target networks
        lmbda = 0.75        # weight of min(Q1, Q2) in the soft double-Q target
        p = 10              # repeats of each next state for the critic target
        mmd_samples = 5     # repeats of each state for the MMD estimate
        lagrange_file = "round_{0}_inter_{1}".format(self.cnt_round, 0)

        for epoch in range(epochs):

            for ba in range(num_batch):
                rows = slice(ba * batch_size, (ba + 1) * batch_size)
                batch_feature_Xs1 = _state[0][rows, :, :]
                batch_phase_Xs1 = _state[1][rows, :]
                batch_feature_Xs2 = _next_state[0][rows, :, :]
                batch_phase_Xs2 = _next_state[1][rows, :]
                batch_r = _reward[rows]
                batch_a = _action[rows]

                # ---------------- critic target ----------------
                # Built outside the tape: it only involves the target
                # networks and is a constant for the critic step.
                batch_feature_Xs2_rep = np.repeat(batch_feature_Xs2, p, axis=0)
                batch_phase_Xs2_rep = np.repeat(batch_phase_Xs2, p, axis=0)
                batch_action_Xs2_rep = self.a_network_bar([batch_feature_Xs2_rep, batch_phase_Xs2_rep]).numpy()
                tmp_next_q, tmp_next_q1 = self.q_network_bar(
                    [batch_feature_Xs2_rep, batch_phase_Xs2_rep, batch_action_Xs2_rep])
                # soft double q-learning
                soft_tmp_next_q = lmbda * tf.math.minimum(tmp_next_q, tmp_next_q1) + \
                    (1 - lmbda) * tf.math.maximum(tmp_next_q, tmp_next_q1)
                # Best value over the p repeats of every sample -> (batch, 1).
                best_next_q = tf.math.reduce_max(
                    tf.reshape(soft_tmp_next_q, shape=(batch_size, -1)), axis=1, keepdims=True)
                # Vectorised target; the loss casts it to the prediction dtype.
                tmp_target = np.reshape(batch_r, (batch_size, 1)) / normal_factor + \
                    gamma * best_next_q.numpy()

                # ---------------- critic step ----------------
                # Trainable weights are watched automatically by the tape.
                with tf.GradientTape() as tape:
                    tmp_cur_q, tmp_cur_q1 = self.q_network([batch_feature_Xs1, batch_phase_Xs1, batch_a])
                    tmp_loss = loss_fn(tmp_target, tmp_cur_q) + loss_fn(tmp_target, tmp_cur_q1)
                # Gradients are taken after leaving the context so that the
                # tape does not record the gradient computation itself.
                critic_grads = tape.gradient(tmp_loss, self.q_network.trainable_weights)
                critic_optimizer.apply_gradients(zip(critic_grads, self.q_network.trainable_weights))

                # ---------------- actor step ----------------
                batch_feature_Xs1_mmd = np.repeat(batch_feature_Xs1, mmd_samples, axis=0)
                batch_phase_Xs1_mmd = np.repeat(batch_phase_Xs1, mmd_samples, axis=0)
                batch_feature_Xs2_mmd = np.repeat(batch_feature_Xs2, mmd_samples, axis=0)
                batch_phase_Xs2_mmd = np.repeat(batch_phase_Xs2, mmd_samples, axis=0)
                # Logged actions as one-hot vectors of depth 4, repeated to
                # match the layout of the actor samples.
                one_hot_batch_a = tf.one_hot(np.array(batch_a), 4, 1., 0., name='one_hot_action')
                raw_actions = tf.reshape(np.repeat(one_hot_batch_a, mmd_samples, axis=0),
                                         shape=(batch_size, mmd_samples, -1))

                with tf.GradientTape() as tape:
                    pi_actions = tf.reshape(self.a_network([batch_feature_Xs1_mmd, batch_phase_Xs1_mmd]),
                                            shape=(batch_size, mmd_samples, -1))
                    mmd_loss = self._compute_laplacian_mmd(raw_actions, pi_actions, 20.)
                    # NOTE(review): the raw multiplier is used here while the
                    # multiplier update below uses exp(multiplier) - confirm
                    # that this asymmetry is intended.
                    actor_loss = tf.reduce_mean(mmd_loss * lagrange_multiplier)
                actor_grads = tape.gradient(actor_loss, self.a_network.trainable_weights)

                # Critic estimate for the multiplier loss.  It carries no
                # gradient, so it is evaluated outside any tape.
                # NOTE(review): Q is evaluated on the *next*-state features
                # although pi_actions come from the current state - confirm.
                pi_actions_mmd = tf.reshape(pi_actions, shape=(batch_size * mmd_samples, -1))
                tmp_cur_q_mmd, tmp_cur_q1_mmd = self.q_network(
                    [batch_feature_Xs2_mmd, batch_phase_Xs2_mmd, pi_actions_mmd])
                tmp_cur_q_mmd = tf.reduce_mean(
                    tf.reshape(tmp_cur_q_mmd, shape=(batch_size, mmd_samples, -1)), axis=1)
                tmp_cur_q1_mmd = tf.reduce_mean(
                    tf.reshape(tmp_cur_q1_mmd, shape=(batch_size, mmd_samples, -1)), axis=1)
                tmp_cur_q_min = tf.math.minimum(tmp_cur_q_mmd, tmp_cur_q1_mmd)

                actor_optimizer.apply_gradients(zip(actor_grads, self.a_network.trainable_weights))

                # ---------------- multiplier step ----------------
                # mmd_loss and tmp_cur_q_min act as constants here; only the
                # multiplier variable receives a gradient.
                with tf.GradientTape() as tape:
                    lagrange_loss = -tf.reduce_mean(-tmp_cur_q_min +
                                                    tf.exp(lagrange_multiplier) * (mmd_loss - _epsilon))
                lagrange_grads = tape.gradient(lagrange_loss, lagrange_multiplier)
                lagrange_optimizer.apply_gradients(zip([lagrange_grads], [lagrange_multiplier]))

                # ---------------- target networks and bookkeeping ----------------
                self.update_target(self.a_network_bar.variables, self.a_network.variables, tau)
                self.update_target(self.q_network_bar.variables, self.q_network.variables, tau)
                self.save_lagrange(lagrange_file, lagrange_multiplier)
                print("===== Epoch {} | Batch {} / {} | Critic Loss {} | Actor Loss {}".format(epoch, ba, num_batch, tmp_loss, actor_loss))

    def _compute_laplacian_mmd(self, samples1, samples2, sigma=20.0):
        """Return the MMD between two sample sets under a Laplacian kernel.

        The kernel is ``exp(-||a - b||_1 / (2 * sigma))``.  Inputs are
        indexed as ``(batch, sample, feature)``; the sample counts on axis 1
        may differ between the two sets.

        Args:
            samples1: First set of samples.
            samples2: Second set of samples.
            sigma: Kernel bandwidth.

        Returns:
            Tensor of shape ``(batch, 1)`` holding ``sqrt(mmd^2 + 1e-6)``.
        """
        def kernel_total(left, right):
            # Differences of all sample pairs: (batch, n_left, n_right, feature).
            pairwise = tf.expand_dims(left, axis=2) - tf.expand_dims(right, axis=1)
            l1_distance = tf.math.reduce_sum(tf.math.abs(pairwise), axis=-1, keepdims=True)
            # Sum the kernel values over both sample axes.
            return tf.math.reduce_sum(tf.math.exp(-l1_distance / (2.0 * sigma)), axis=(1, 2))

        count_x = samples1.shape[1]
        count_y = samples2.shape[1]

        term_xx = kernel_total(samples1, samples1) / (count_x * count_x)
        term_xy = 2.0 * kernel_total(samples1, samples2) / (count_y * count_x)
        term_yy = kernel_total(samples2, samples2) / (count_y * count_y)

        squared_mmd = term_xx - term_xy + term_yy
        # The small constant keeps the square root differentiable at zero.
        return tf.math.sqrt(squared_mmd + 1e-6)

    def save_lagrange(self, file_name, obj):
        file = open(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s_lagrange.pickle" % file_name), 'wb')
        pickle.dump(obj, file)
        file.close()

    def load_lagrange(self, file_name, file_path=None):
        if file_path is None:
            file_path = self.dic_path["PATH_TO_MODEL"]
        file_name = os.path.join(file_path, "%s_lagrange.pickle" % file_name)
        with open(file_name, 'rb') as file:
            return pickle.load(file)