"""
OfflineLight
"""

from tensorflow.keras.layers import Input, Dense, Reshape,  Lambda,  Activation, Embedding, Conv2D, concatenate, add,\
    multiply, MultiHeadAttention
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
import numpy as np
import random
import tensorflow as tf
from tensorflow.keras.losses import MeanSquaredError
import copy
from .actor_critic_network_agent import ActorCriticNetworkAgent
import scipy.stats as st 

class GeneralAgent_Offline(ActorCriticNetworkAgent):
    """Offline actor-critic agent (OfflineLight).

    The critic (``q_network``) and the actor (``a_network``) share the same
    phase-attention encoder produced by :meth:`build_hidden`.
    :meth:`train_network` fits the critic with a one-step TD loss plus a
    conservative (CQL-style) ``logsumexp(Q) - Q(s, a)`` penalty. It fits the
    actor with ``-Q(s, pi(s))`` plus an MSE term that pulls the actor output
    toward the one-hot logged action. The ``*_bar`` target networks are then
    soft-updated.

    NOTE(review): ``num_feat``, ``num_phases``, ``phase_map``, ``min_q_weight``,
    ``dic_agent_conf``, ``dic_traffic_env_conf`` and the four networks are not
    created in this class; presumably the base class / constructor sets them up.
    """

    def build_hidden(self, ins0, ins1):
        """Build the shared phase-attention encoder and return one score per phase.

        :param ins0: lane features, ``Input(shape=(12, self.num_feat))``.
        :param ins1: current-phase encoding, ``Input(shape=(8,))``. It is used
            as indices into a 2-row embedding table, so entries are presumably
            0/1.
        :returns: a ``(batch, 4)`` tensor with one linear score per phase.

        NOTE(review): the concatenation with the current-phase features and the
        final ``Reshape((4,))`` only line up when ``self.num_phases == 4``.
        """
        # Current phase: embed each of the 8 entries into 4 dims, fold into
        # (2, 4, 4) and sum the two halves -> (4, 4), i.e. 4 features per phase.
        cur_phase_emb = Activation('sigmoid')(Embedding(2, 4, input_length=8)(ins1))
        cur_phase_emb = Reshape((2, 4, 4))(cur_phase_emb)
        cur_phase_feat = Lambda(lambda x: K.sum(x, axis=1), name="feature_as_phase")(cur_phase_emb)

        # Lane features: embed every scalar feature into 4 dims, then flatten per lane.
        feat1 = Reshape((12, self.num_feat, 1))(ins0)
        feat_emb = Dense(4, activation='sigmoid', name="feature_embedding")(feat1)
        feat_emb = Reshape((12, self.num_feat * 4))(feat_emb)
        lane_feat_s = tf.split(feat_emb, 12, axis=1)
        Sum1 = Lambda(lambda x: K.sum(x, axis=1, keepdims=True))

        # Phase features: sum the embeddings of the lanes listed in phase_map[i].
        phase_feats_map_2 = []
        for i in range(self.num_phases):
            tmp_feat_1 = tf.concat([lane_feat_s[idx] for idx in self.phase_map[i]], axis=1)
            phase_feats_map_2.append(Sum1(tmp_feat_1))

        # (batch, num_phases, 4 * num_feat) ++ (batch, 4, 4) along the feature axis.
        phase_feat_all = tf.concat(phase_feats_map_2, axis=1)
        phase_feat_all = concatenate([phase_feat_all, cur_phase_feat])

        # Self-attention across phases, then a small per-phase MLP head.
        att_encoding = MultiHeadAttention(4, 8, attention_axes=1)(phase_feat_all, phase_feat_all)
        hidden = Dense(20, activation="relu")(att_encoding)
        hidden = Dense(20, activation="relu")(hidden)

        hidden = Dense(1, activation="linear", name="final_critic_score")(hidden)
        hidden = Reshape((4,))(hidden)
        return hidden

    def build_q_network(self):
        """Build and compile the critic ``Q(state, phase, action) -> (batch, 4)``.

        NOTE(review): ``ins2`` (the action) is declared as a model input, but
        :meth:`build_hidden` only receives ``ins0`` and ``ins1``. The critic's
        output therefore does not depend on the action passed to it, so the
        ``-Q(s, pi(s))`` term of the actor loss in :meth:`train_network`
        contributes no gradient. Confirm this is intended.
        """
        ins0 = Input(shape=(12, self.num_feat), name="input_total_features")
        ins1 = Input(shape=(8, ), name="input_cur_phase")
        ins2 = Input(shape=(4, ), name="input_cur_action_q")

        hidden = self.build_hidden(ins0, ins1)
        network = Model(inputs=[ins0, ins1, ins2],
                        outputs=hidden)
        network.compile()
        network.summary()
        return network

    def build_a_network(self):
        """Build and compile the actor ``pi(state, phase) -> (batch, 4)`` scores."""
        ins0 = Input(shape=(12, self.num_feat), name="input_total_features")
        ins1 = Input(shape=(8, ), name="input_cur_phase")
        action_q_values = self.build_hidden(ins0, ins1)
        network = Model(inputs=[ins0, ins1],
                        outputs=action_q_values)
        network.compile()
        network.summary()
        return network

    def choose_action(self, states):
        """Return the greedy action (argmax of the actor output) for each state.

        :param states: sized sequence of state dicts keyed by the names in
            ``LIST_STATE_FEATURE``. ``"new_phase"`` is fed to the phase input;
            every other feature is reshaped to ``(12, -1)`` and concatenated
            along the last axis.
        :returns: ``np.ndarray`` of action indices, one per state.
        :raises ValueError: if ``"new_phase"`` is not in ``LIST_STATE_FEATURE``.
        """
        list_state_feature = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
        dic_state_feature_arrays = {feature_name: [] for feature_name in list_state_feature}
        cur_phase_info = []

        for s in states:
            for feature_name in list_state_feature:
                if feature_name == "new_phase":
                    cur_phase_info.append(s[feature_name])
                else:
                    dic_state_feature_arrays[feature_name].append(s[feature_name])

        used_feature = list(list_state_feature)
        used_feature.remove("new_phase")
        state_input = [np.array(dic_state_feature_arrays[feature_name]).reshape(len(states), 12, -1)
                       for feature_name in used_feature]
        state_input = np.concatenate(state_input, axis=-1)

        q_values = self.a_network.predict([state_input, np.array(cur_phase_info)])

        action = np.argmax(q_values, axis=1)
        return action

    def prepare_samples(self, memory):
        """Convert a replay tuple into network-ready arrays.

        :param memory: ``(state, action, next_state, p_reward, ql_reward)``.
            ``state`` and ``next_state`` map feature names to per-sample values.
        :returns: ``([features, phase], action, [next_features, next_phase], reward)``.
            ``features`` has shape ``(N, 12, F)``: the non-phase features, each
            reshaped to ``(N, 12, -1)`` and concatenated on the last axis.
            ``reward`` is ``p_reward`` when ``"pressure"`` is a configured reward
            key, otherwise ``ql_reward``.
        """
        state, action, next_state, p_reward, ql_reward = memory
        used_feature = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]

        memory_size = len(action)

        _state = [[], None]
        _next_state = [[], None]
        for feat_name in used_feature:
            if feat_name == "new_phase":
                _state[1] = np.array(state[feat_name])
                _next_state[1] = np.array(next_state[feat_name])
            else:
                _state[0].append(np.array(state[feat_name]).reshape(memory_size, 12, -1))
                _next_state[0].append(np.array(next_state[feat_name]).reshape(memory_size, 12, -1))

        # ========= select reward information ===============
        if "pressure" in self.dic_traffic_env_conf["DIC_REWARD_INFO"].keys():
            my_reward = p_reward
        else:
            my_reward = ql_reward

        return ([np.concatenate(_state[0], axis=-1), _state[1]], action,
                [np.concatenate(_next_state[0], axis=-1), _next_state[1]], my_reward)

    def update_target(self, target_weights, weights, tau):
        """Soft-update in place: ``target <- tau * online + (1 - tau) * target``."""
        for (a, b) in zip(target_weights, weights):
            a.assign(b * tau + a * (1 - tau))

    def train_network(self, memory):
        """Run offline actor-critic training on one replay tuple.

        For every full mini-batch:

        1. Update the critic with an MSE TD loss on the logged action plus
           ``min_q_weight * (mean logsumexp(Q) - mean Q(s, a))``.
        2. Every ``freq_a`` batches, update the actor.
        3. Soft-update both target networks with rate ``tau``.

        Trailing samples that do not fill a whole batch are skipped.

        :param memory: replay tuple accepted by :meth:`prepare_samples`.
        """
        # memory[1] holds the logged actions (see prepare_samples). With no
        # transitions there is nothing to fit, and both the reshapes in
        # prepare_samples and the batch count below would fail.
        if len(memory[1]) == 0:
            return

        _state, _action, _next_state, _reward = self.prepare_samples(memory)
        num_samples = len(_action)

        # ==== shuffle the samples ============
        random_index = np.random.permutation(num_samples)
        _state[0] = _state[0][random_index, :, :]
        _state[1] = _state[1][random_index, :]
        _action = np.array(_action)[random_index]
        _next_state[0] = _next_state[0][random_index, :, :]
        _next_state[1] = _next_state[1][random_index, :]
        _reward = np.array(_reward)[random_index]

        epochs = self.dic_agent_conf["EPOCHS"]
        batch_size = min(self.dic_agent_conf["BATCH_SIZE"], num_samples)
        num_batch = num_samples // batch_size
        normal_factor = self.dic_agent_conf["NORMAL_FACTOR"]
        gamma = self.dic_agent_conf["GAMMA"]

        loss_fn = MeanSquaredError()
        # Pass `learning_rate`, not the deprecated `lr` alias. Newer Keras
        # optimizers do not honour `lr` (they drop it or reject it), which would
        # silently train with the default rate instead of LEARNING_RATE.
        critic_optimizer = Adam(learning_rate=self.dic_agent_conf["LEARNING_RATE"])
        actor_optimizer = Adam(learning_rate=self.dic_agent_conf["LEARNING_RATE"])
        std = 2e-2   # std-dev of the Gaussian noise added to target actions
        c = 0.2      # clip bound for that noise
        tau = 0.01   # soft target-update rate
        freq_a = 10  # update the actor once every `freq_a` batches
        batch_rows = np.arange(batch_size)

        for epoch in range(epochs):
            for ba in range(num_batch):
                # A single scalar noise sample shared by the whole batch.
                noise = tf.convert_to_tensor(np.random.normal(loc=0, scale=std), dtype=tf.float32)
                action_noise = tf.clip_by_value(noise, -c, c)

                batch = slice(ba * batch_size, (ba + 1) * batch_size)
                batch_feature_Xs1 = _state[0][batch]
                batch_phase_Xs1 = _state[1][batch]
                batch_feature_Xs2 = _next_state[0][batch]
                batch_phase_Xs2 = _next_state[1][batch]
                batch_r = _reward[batch]
                batch_a = _action[batch]

                # ---- critic update ----
                # Actions from the online actor (s) and the target actor (s').
                # They are converted to numpy, so no gradient reaches the actors here.
                batch_action_Xs1 = self.a_network([batch_feature_Xs1, batch_phase_Xs1]).numpy()
                batch_action_Xs2 = self.a_network_bar([batch_feature_Xs2, batch_phase_Xs2]).numpy()
                batch_action_Xs2_noise = batch_action_Xs2 + action_noise
                next_q = self.q_network_bar(
                    [batch_feature_Xs2, batch_phase_Xs2, batch_action_Xs2_noise]).numpy()
                # TD target for the logged action: r / NORMAL_FACTOR + GAMMA * max_a' Q_bar(s', a').
                td_target = np.reshape(batch_r, -1) / normal_factor + gamma * next_q.max(axis=1)

                with tf.GradientTape() as tape:
                    tmp_cur_q = self.q_network([batch_feature_Xs1, batch_phase_Xs1, batch_action_Xs1])
                    # Only the logged action's entry is moved to the TD target.
                    # The other entries equal the current prediction, so their error is zero.
                    tmp_target = np.copy(tmp_cur_q)
                    tmp_target[batch_rows, batch_a] = td_target
                    base_loss = loss_fn(tmp_target, tmp_cur_q)

                    # CQL penalty: push down logsumexp over all actions, push up the logged action.
                    replay_action_one_hot = tf.one_hot(batch_a, 4, 1., 0., name='action_one_hot')
                    replay_chosen_q = tf.reduce_sum(tmp_cur_q * replay_action_one_hot, axis=1)
                    dataset_expec = tf.reduce_mean(replay_chosen_q)
                    negative_sampling = tf.reduce_mean(tf.reduce_logsumexp(tmp_cur_q, 1))
                    min_q_loss = (negative_sampling - dataset_expec) * self.min_q_weight

                    tmp_loss = base_loss + min_q_loss
                critic_grads = tape.gradient(tmp_loss, self.q_network.trainable_weights)
                critic_optimizer.apply_gradients(zip(critic_grads, self.q_network.trainable_weights))
                print("===== Epoch {} | Batch {} / {} |  Critic Loss {}".format(epoch, ba, num_batch, tmp_loss))

                # ---- actor update (every `freq_a` batches) ----
                if ba % freq_a == 0:
                    with tf.GradientTape() as tape:
                        pi_theta_action = self.a_network([batch_feature_Xs1, batch_phase_Xs1])
                        actor_loss1 = self.q_network([batch_feature_Xs1, batch_phase_Xs1, pi_theta_action])
                        # MSE toward the one-hot logged action keeps the actor close to the data.
                        pi_beta = tf.one_hot(batch_a, 4, 1., 0., name='one_hot_action')
                        actor_loss2 = loss_fn(pi_beta, pi_theta_action)
                        actor_loss = - tf.reduce_mean(actor_loss1) + tf.reduce_mean(actor_loss2)
                    actor_grads = tape.gradient(actor_loss, self.a_network.trainable_weights)
                    actor_optimizer.apply_gradients(zip(actor_grads, self.a_network.trainable_weights))
                    print("===== Epoch {} | Batch {} / {} |  Actor Loss {}".format(epoch, ba, num_batch, actor_loss))

                self.update_target(self.a_network_bar.variables, self.a_network.variables, tau)
                self.update_target(self.q_network_bar.variables, self.q_network.variables, tau)