import numpy as np
import gym
from typing import Dict, Optional, Sequence

from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.layers import SkipConnection
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType, List, ModelConfigDict, Tuple

tf1, tf, tfv = try_import_tf()


class MultiHeadAttention(tf.keras.layers.Layer if tf else object):
    """Multi-head self-attention over the tokens of a `[B, T, dim]` input.

    Queries, keys and values are all projected from the same input by three
    separate biased Dense layers; no attention mask is applied.
    """

    def __init__(self,
                 out_dim: int,
                 num_heads: int,
                 head_dim: int,
                 input_layernorm: bool = False,
                 output_activation: Optional["tf.nn.activation"] = None,
                 **kwargs):
        """Initializes a MultiHeadAttention keras Layer object.

        Args:
            out_dim (int): The output dimensions of the multi-head attention
                unit.
            num_heads (int): The number of attention heads to use.
                Denoted `H` in `call`.
            head_dim (int): The dimension of a single(!) attention head within
                a multi-head attention unit. Denoted `d` in `call`.
            input_layernorm (bool): Whether to prepend a LayerNorm before
                everything else.
            output_activation (Optional[tf.nn.activation]): Optional tf.nn
                activation function applied by the final linear layer.
            **kwargs: Forwarded to the keras `Layer` constructor.
        """
        super().__init__(**kwargs)

        self.num_heads = num_heads
        self.head_dim = head_dim

        # One projection each for queries, keys and values.
        self.q_layer = tf.keras.layers.Dense(
            num_heads * head_dim, use_bias=True)
        self.k_layer = tf.keras.layers.Dense(
            num_heads * head_dim, use_bias=True)
        self.v_layer = tf.keras.layers.Dense(
            num_heads * head_dim, use_bias=True)

        # Maps the concatenated heads (H * d) to `out_dim`, per token.
        self.linear_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(
                out_dim, use_bias=False, activation=output_activation))

        # Honor `input_layernorm`: the flag used to be accepted but ignored.
        self._input_layernorm = None
        if input_layernorm:
            self._input_layernorm = tf.keras.layers.LayerNormalization(
                axis=-1)

    def call(self, inputs: TensorType,
             memory: Optional[TensorType] = None) -> TensorType:
        """Applies self-attention to `inputs`.

        Args:
            inputs: Tensor indexed as `[B, T, dim]` (axis 1 is read as the
                token/segment length `T`).
            memory: Currently unused.

        Returns:
            Tensor of shape `[B, T, out_dim]`.
        """
        T = tf.shape(inputs)[1]  # length of segment (time)
        H = self.num_heads  # number of attention heads
        d = self.head_dim  # attention head dimension

        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        # Project, then split the last axis into H heads of size d.
        queries = tf.reshape(self.q_layer(inputs), [-1, T, H, d])
        keys = tf.reshape(self.k_layer(inputs), [-1, T, H, d])
        values = tf.reshape(self.v_layer(inputs), [-1, T, H, d])

        # Scaled dot-product scores: [B, T(query i), T(key j), H].
        score = tf.einsum("bihd,bjhd->bijh", queries, keys)
        score = score / d**0.5

        # Normalize over the key axis (j).
        wmat = tf.nn.softmax(score, axis=2)

        # Weighted sum of values, then merge the heads: [B, T, H * d].
        out = tf.einsum("bijh,bjhd->bihd", wmat, values)
        merged_shape = tf.concat((tf.shape(out)[:2], [H * d]), axis=0)
        out = tf.reshape(out, merged_shape)
        return self.linear_layer(out)
        

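# TODO: (sven) obsolete this class once we only support native keras models.
class TransformerNetwork(TFModelV2):
    """Attention-based policy/value network implemented in ModelV2 API.

    The flat observation is split into per-unit feature slices, each slice is
    embedded into a token, and the tokens pass through multi-head
    self-attention and mean pooling before the output layers.
    """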

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str, **kwargs):
        """Builds the keras base model with a policy and a value output.

        Args:
            obs_space: Observation space; only its shape is used, to size
                the flat "observations" input.
            action_space: Action space, forwarded to `TFModelV2`.
            num_outputs: Size of the policy output. If falsy, the last
                hidden layer is returned as the model output instead.
            model_config: Model config dict. Keys read here:
                `custom_model_config["input_split_shape"]`,
                `fcnet_hiddens`, `post_fcnet_hiddens`, `fcnet_activation`,
                `post_fcnet_activation`, `no_final_linear`,
                `vf_share_layers` and `free_log_std`.
                `input_split_shape` is a (nested) list of slice widths along
                the last axis of the flat observation: slice 0 is embedded
                as the ego token, the next `len(input_split_shape[1])`
                slices as enemy tokens, the remaining ones as ally tokens,
                and the very last slice is concatenated onto the ego slice.
                Defaults to one slice covering the whole observation.
            name: Model name, forwarded to `TFModelV2`.
            **kwargs: Optional `token_dim` (default 256), `num_heads`
                (default 4) and `head_dim` (default 256).
        """
        super(TransformerNetwork, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)

        # `np.prod` replaces the `np.product` alias, removed in NumPy 2.0.
        flat_obs_dim = int(np.prod(obs_space.shape))

        custom_config = model_config.get("custom_model_config") or {}
        input_split_shape = list(custom_config.get(
            "input_split_shape", [[flat_obs_dim]]))
        token_dim = kwargs.get("token_dim", 256)
        num_heads = kwargs.get("num_heads", 4)
        head_dim = kwargs.get("head_dim", 256)
        self.num_att_layers = 1
        hiddens = list(model_config.get("fcnet_hiddens", [])) + \
            list(model_config.get("post_fcnet_hiddens", []))
        activation = model_config.get("fcnet_activation")
        if not model_config.get("fcnet_hiddens", []):
            activation = model_config.get("post_fcnet_activation")
        activation = get_activation_fn(activation)
        no_final_linear = model_config.get("no_final_linear")
        vf_share_layers = model_config.get("vf_share_layers")
        free_log_std = model_config.get("free_log_std")

        # Generate free-floating bias variables for the second half of
        # the outputs.
        if free_log_std:
            assert num_outputs % 2 == 0, (
                "num_outputs must be divisible by two", num_outputs)
            num_outputs = num_outputs // 2
            self.log_std_var = tf.Variable(
                [0.0] * num_outputs, dtype=tf.float32, name="log_std")

        # We are using obs_flat, so take the flattened shape as input.
        inputs = tf.keras.layers.Input(
            shape=(flat_obs_dim, ), name="observations")

        # Split the input by `input_split_shape` into per-unit slices.
        def flatten_list(nested):
            """Flattens arbitrarily nested lists/tuples into a flat list."""
            if isinstance(nested, (list, tuple)):
                return [
                    leaf for item in nested for leaf in flatten_list(item)
                ]
            return [nested]

        data_split_shape = flatten_list(input_split_shape)
        agent_feature_data_list = tf.split(inputs, data_split_shape, axis=-1)

        # Fold the trailing slice into the first (ego) slice. With a single
        # slice there is nothing to fold (doing so would empty the list).
        if len(agent_feature_data_list) > 1:
            trailing_slice = agent_feature_data_list.pop(-1)
            agent_feature_data_list[0] = tf.concat(
                [agent_feature_data_list[0], trailing_slice], axis=-1)

        # Get different representations for different types of units.
        agent_feature_out_list = []
        fc_ego_feature_dense = tf.keras.layers.Dense(
            token_dim,
            name="fc_ego_feature",
            activation=activation,
            kernel_initializer=normc_initializer(1.0))
        fc_enemy_feature_dense = tf.keras.layers.Dense(
            token_dim,
            name="fc_enemy_feature",
            activation=activation,
            kernel_initializer=normc_initializer(1.0))
        fc_ally_feature_dense = tf.keras.layers.Dense(
            token_dim,
            name="fc_ally_feature",
            activation=activation,
            kernel_initializer=normc_initializer(1.0))
        for i, agent_feature_data in enumerate(agent_feature_data_list):
            if i == 0:
                agent_feature_out_list.append(
                    fc_ego_feature_dense(agent_feature_data))
            elif i <= len(input_split_shape[1]):
                agent_feature_out_list.append(
                    fc_enemy_feature_dense(agent_feature_data))
            else:
                agent_feature_out_list.append(
                    fc_ally_feature_dense(agent_feature_data))

        # [batch_size, length, token_size]
        agent_feature_out = tf.keras.layers.Lambda(
            lambda x: tf.stack(x, axis=-2))(agent_feature_out_list)

        # Building the transformer.
        # Last hidden layer output (before logits outputs).
        last_layer = agent_feature_out

        for _ in range(self.num_att_layers):
            # `out_dim` must equal the token width, because SkipConnection
            # adds the attention output to its `token_dim`-wide input.
            last_layer = SkipConnection(
                MultiHeadAttention(
                    out_dim=token_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    input_layernorm=False,
                    output_activation=None),
                fan_in_layer=None)(last_layer)

        # Average-pool over the token axis: [batch_size, token_size].
        last_layer = tf.keras.layers.Lambda(
            lambda x: tf.keras.backend.mean(x, axis=1))(last_layer)

        # The action distribution outputs.
        logits_out = None

        # The last layer is adjusted to be of size num_outputs, but it's a
        # layer with activation.
        if no_final_linear and num_outputs:
            logits_out = tf.keras.layers.Dense(
                num_outputs,
                name="fc_out",
                activation=activation,
                kernel_initializer=normc_initializer(1.0))(last_layer)
        # Add one hidden layer of the last configured size (if any), plus -
        # iff num_outputs > 0 - a last linear layer of size num_outputs.
        else:
            if len(hiddens) > 0:
                # Same layer name as before, when it came from the token
                # loop's leaked index; kept so variable names do not change.
                hidden_layer_idx = len(agent_feature_data_list) - 1
                last_layer = tf.keras.layers.Dense(
                    hiddens[-1],
                    name="fc_{}".format(hidden_layer_idx),
                    activation=activation,
                    kernel_initializer=normc_initializer(1.0))(last_layer)
            if num_outputs:
                logits_out = tf.keras.layers.Dense(
                    num_outputs,
                    name="fc_out",
                    activation=None,
                    kernel_initializer=normc_initializer(0.01))(last_layer)
            # Adjust num_outputs to be the number of nodes in the last layer:
            # the hidden layer if present, else the pooled token width.
            else:
                self.num_outputs = ([token_dim] + hiddens[-1:])[-1]

        # Concat the log std vars to the end of the state-dependent means.
        if free_log_std and logits_out is not None:

            def tiled_log_std(x):
                return tf.tile(
                    tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1])

            log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs)
            logits_out = tf.keras.layers.Concatenate(axis=1)(
                [logits_out, log_std_out])

        last_vf_layer = None
        if not vf_share_layers:
            # Build a parallel set of hidden layers for the value net,
            # fed directly from the flat observations.
            last_vf_layer = inputs
            for vf_idx, size in enumerate(hiddens, start=1):
                last_vf_layer = tf.keras.layers.Dense(
                    size,
                    name="fc_value_{}".format(vf_idx),
                    activation=activation,
                    kernel_initializer=normc_initializer(1.0))(last_vf_layer)

        value_out = tf.keras.layers.Dense(
            1,
            name="value_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(
                last_vf_layer if last_vf_layer is not None else last_layer)

        self.base_model = tf.keras.Model(
            inputs, [(logits_out
                      if logits_out is not None else last_layer), value_out])

    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        """Runs the base model on the flattened observations.

        Args:
            input_dict: Batch dict; only its "obs_flat" entry is read.
            state: Returned unchanged.
            seq_lens: Unused.

        Returns:
            Tuple of the base model's first output and the unchanged `state`.
        """
        flat_obs = input_dict["obs_flat"]
        policy_out, value_out = self.base_model(flat_obs)
        # Keep the value-branch output around for `value_function()`.
        self._value_out = value_out
        return policy_out, state

    def value_function(self) -> TensorType:
        """Returns the value output stored by `forward()`, flattened to 1-D."""
        flat_shape = [-1]
        return tf.reshape(self._value_out, flat_shape)

    def load_weights(self, dir_path):
        """Maps attention variables of a checkpoint onto this model's variables.

        Despite its name, this method restores nothing itself: it only builds
        and returns a `{checkpoint_variable_name: model_variable}` dict, e.g.
        for use as the `var_list` of a TF1 `Saver`.

        Only the query/key/value kernels and biases of attention layer 0 are
        mapped. Checkpoint variables are expected to end with
        `fc_encoder_layer0_{q,k,v}_{weight,bias}` ("king" naming).

        Args:
            dir_path: Checkpoint directory or file, passed to
                `tf.train.list_variables`.

        Returns:
            Dict mapping each matched checkpoint variable name to the
            corresponding trainable variable of the current graph.

        Raises:
            IndexError: If the checkpoint holds no variable with one of the
                expected suffixes, or no trainable variable matches it.
        """
        # Checkpoint-name suffix -> fragment of the variable name that the
        # query/key/value Dense layers of `MultiHeadAttention` get in this
        # model. Only attention layer 0 has a known mapping.
        # NOTE: checkpoints using the "smac" naming store these variables
        # under the right-hand names directly; that scheme is not handled.
        model_var_fragments = {
            "fc_encoder_layer0_q_weight": "multi_head_attention/dense/kernel",
            "fc_encoder_layer0_q_bias": "multi_head_attention/dense/bias",
            "fc_encoder_layer0_k_weight":
                "multi_head_attention/dense_1/kernel",
            "fc_encoder_layer0_k_bias": "multi_head_attention/dense_1/bias",
            "fc_encoder_layer0_v_weight":
                "multi_head_attention/dense_2/kernel",
            "fc_encoder_layer0_v_bias": "multi_head_attention/dense_2/bias",
        }
        suffix_templates = (
            "fc_encoder_layer{}_q_weight",
            "fc_encoder_layer{}_q_bias",
            "fc_encoder_layer{}_k_weight",
            "fc_encoder_layer{}_k_bias",
            "fc_encoder_layer{}_v_weight",
            "fc_encoder_layer{}_v_bias",
        )
        attention_names = [
            template.format(layer_idx)
            for layer_idx in range(self.num_att_layers)
            for template in suffix_templates
        ]

        # `trainable_variables()` only exists in the TF1 (compat.v1) API.
        model_vars = tf1.trainable_variables()
        checkpoint_names = [
            ckpt_name for ckpt_name, _ in tf.train.list_variables(dir_path)
        ]

        load_dict = {}
        for attention_name in attention_names:
            # First checkpoint variable whose name ends with the suffix.
            ckpt_name = next(
                (n for n in checkpoint_names if n.endswith(attention_name)),
                None)
            if ckpt_name is None:
                raise IndexError(
                    "No variable ending with '{}' found in checkpoint "
                    "'{}'".format(attention_name, dir_path))

            # First trainable variable whose name contains the fragment.
            fragment = model_var_fragments.get(attention_name)
            model_var = None
            if fragment is not None:
                model_var = next(
                    (v for v in model_vars if fragment in v.name), None)
            if model_var is None:
                raise IndexError(
                    "No trainable model variable found for checkpoint "
                    "variable '{}'".format(ckpt_name))

            load_dict[ckpt_name] = model_var

        return load_dict
