from lib2to3.pgen2 import token
import gym
import os
import numpy as np
from typing import Dict, Optional, Sequence

from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.utils.typing import TensorType, ModelConfigDict



from models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.layers import SkipConnection
from utils.framework import try_import_tf
from utils.torch_utils import FLOAT_MIN
from utils.typing import ModelConfigDict

# Resolve the TensorFlow handles through the project's import helper.
# `tf` may be falsy: the class statements below test it and fall back to
# `object` as base class in that case.
# NOTE(review): presumably (compat-v1 module, main module, version), as in
# RLlib's helper of the same name -- confirm against utils.framework.
tf1, tf, tfv = try_import_tf()

class MultiHeadAttention(tf.keras.layers.Layer if tf else object):
    """Multi-head scaled dot-product self-attention as a Keras layer.

    Queries, keys and values are all projected from the same input by three
    separate biased Dense layers. No attention mask and no positional
    encoding are applied. The per-head results are concatenated and mapped
    to ``out_dim`` by a bias-free, time-distributed Dense layer.
    """

    def __init__(self,
                 out_dim: int,
                 num_heads: int,
                 head_dim: int,
                 input_layernorm: bool = False,
                 output_activation: Optional["tf.nn.activation"] = None,
                 **kwargs):
        """Initializes a MultiHeadAttention keras Layer object.

        Args:
            out_dim (int): The output dimensions of the multi-head attention
                unit.
            num_heads (int): The number of attention heads to use.
            head_dim (int): The dimension of a single(!) attention head within
                a multi-head attention unit.
            input_layernorm (bool): Whether to apply a LayerNormalization
                (over the last axis) to the inputs before the q/k/v
                projections.
            output_activation (Optional[tf.nn.activation]): Optional
                activation applied by the final output projection.
            **kwargs: Forwarded to the Keras ``Layer`` constructor.
        """
        super().__init__(**kwargs)

        self.num_heads = num_heads
        self.head_dim = head_dim

        # Separate projections for queries, keys and values.
        # NOTE: keep the creation order q -> k -> v -> output. Keras assigns
        # the automatic layer names ("dense", "dense_1", "dense_2", ...) in
        # creation order, and checkpoint import code elsewhere in this file
        # addresses the variables by those names.
        proj_units = num_heads * head_dim
        self.q_layer = tf.keras.layers.Dense(proj_units, use_bias=True)
        self.k_layer = tf.keras.layers.Dense(proj_units, use_bias=True)
        self.v_layer = tf.keras.layers.Dense(proj_units, use_bias=True)

        self.linear_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(
                out_dim, use_bias=False, activation=output_activation))

        # Optional input normalization. Created last so that it cannot
        # influence the automatic names of the Dense layers above.
        self._input_layernorm = None
        if input_layernorm:
            self._input_layernorm = tf.keras.layers.LayerNormalization(
                axis=-1)

    def call(self, inputs: TensorType,
             memory: Optional[TensorType] = None) -> TensorType:
        """Applies self-attention along axis 1 of ``inputs``.

        Args:
            inputs (TensorType): Rank-3 tensor; axis 1 is the sequence/token
                axis that is attended over.
            memory (Optional[TensorType]): Accepted for signature
                compatibility only; it is not used by this layer.

        Returns:
            TensorType: Tensor with the same leading two axes as ``inputs``
            and ``out_dim`` features on the last axis.
        """
        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        T = tf.shape(inputs)[1]  # length of segment (time)
        H = self.num_heads  # number of attention heads
        d = self.head_dim  # attention head dimension

        # Project, then split the feature axis into H heads of size d.
        queries = tf.reshape(self.q_layer(inputs), [-1, T, H, d])
        keys = tf.reshape(self.k_layer(inputs), [-1, T, H, d])
        values = tf.reshape(self.v_layer(inputs), [-1, T, H, d])

        # score[b, i, j, h]: scaled dot product of query i with key j.
        score = tf.einsum("bihd,bjhd->bijh", queries, keys)
        score = score / d**0.5

        # Normalize over the key axis (j) so each query's weights sum to 1.
        wmat = tf.nn.softmax(score, axis=2)

        out = tf.einsum("bijh,bjhd->bihd", wmat, values)
        # Merge the head axes back into one feature axis of size H * d.
        merged_shape = tf.concat((tf.shape(out)[:2], [H * d]), axis=0)
        out = tf.reshape(out, merged_shape)
        return self.linear_layer(out)

class TFNMMOModel(TFModelV2):
    """Action-masking TF model over Dict observations with attention.

    The observation space (or its ``original_space``) must be a
    ``gym.spaces.Dict`` with the keys ``"terrain"``, ``"camp"``, ``"entity"``
    and ``"va"``. Terrain and camp maps are one-hot encoded and passed
    through one Conv2D each; every one of the 7 entity planes is passed
    through a shared Conv2D. Each flattened conv output is one token; the
    tokens go through residual multi-head attention, are averaged, and feed
    a logits head and a value head. ``"va"`` is used as the action mask in
    ``forward``.
    """

    # Number of planes stacked along axis 1 of the "entity" observation.
    _NUM_ENTITY_PLANES = 7

    # Default names used by `import_partially`: the q/k/v Dense variables of
    # the first attention layer, named identically in checkpoint and model.
    _DEFAULT_IMPORT_NAMES = (
        "multi_head_attention/dense/kernel",
        "multi_head_attention/dense/bias",
        "multi_head_attention/dense_1/kernel",
        "multi_head_attention/dense_1/bias",
        "multi_head_attention/dense_2/kernel",
        "multi_head_attention/dense_2/bias",
    )

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        **kwargs
    ):
        """Builds the Keras base model.

        Args:
            obs_space (gym.spaces.Space): Observation space; see the class
                docstring for the required Dict keys.
            action_space (gym.spaces.Space): Action space.
            num_outputs (int): Size of the logits output.
            model_config (ModelConfigDict): Model config. ``"activation"``
                selects the conv activation and
                ``"custom_model_config"["no_masking"]`` disables masking.
            name (str): Model name.
            **kwargs: Optional ``token_dim`` (default 256), ``num_heads``
                (default 4), ``head_dim`` (default 256) and the required
                ``input_shape`` (2D spatial shape of one map plane).

        Raises:
            ValueError: If ``input_shape`` is not given.
            AssertionError: If the observation space lacks a required key or
                a flattened conv output does not have ``token_dim`` entries.
        """
        super(TFNMMOModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        orig_space = getattr(obs_space, "original_space", obs_space)
        required_keys = ("terrain", "camp", "entity", "va")
        assert (
            isinstance(orig_space, gym.spaces.Dict)
            and all(key in orig_space.spaces for key in required_keys)
        ), "obs space must be a gym.spaces.Dict with keys {}".format(
            required_keys)

        token_dim = kwargs.get("token_dim", 256)
        num_heads = kwargs.get("num_heads", 4)
        head_dim = kwargs.get("head_dim", 256)
        self.num_att_layers = 1

        input_shape = kwargs.get("input_shape")
        if input_shape is None:
            raise ValueError(
                "TFNMMOModel requires an `input_shape` keyword argument.")
        # Normalize to a list so tuples can be prefixed with the plane axis.
        input_shape = list(input_shape)

        # The conv activation is read from model_config["activation"].
        # NOTE(review): a "conv_activation" entry in kwargs is not consulted;
        # confirm whether it should take precedence.
        activation = get_activation_fn(model_config.get("activation"))

        terrain_input = tf.keras.layers.Input(
            shape=input_shape, name="terrain_input")
        camp_input = tf.keras.layers.Input(
            shape=input_shape, name="camp_input")
        entity_input = tf.keras.layers.Input(
            shape=[self._NUM_ENTITY_PLANES] + input_shape,
            name="entity_input")

        # Move the plane axis last, then cut it into single-channel images.
        entity_input_tmp = tf.transpose(entity_input, [0, 2, 3, 1])
        entity_input_list = tf.split(
            entity_input_tmp,
            num_or_size_splits=self._NUM_ENTITY_PLANES,
            axis=-1)

        # Kept for attribute compatibility; not read inside this class.
        self.last_layer_is_flattened = False

        def make_conv(layer_name):
            # All three encoders share the same conv hyper-parameters.
            return tf.keras.layers.Conv2D(
                16,
                (6, 6),
                strides=4,
                activation=activation,
                padding="same",
                data_format="channels_last",
                name=layer_name)

        def flatten(feature_map):
            return tf.keras.layers.Flatten(
                data_format="channels_last")(feature_map)

        feature_output_list = []

        # Terrain ids -> one-hot channels -> conv -> token.
        terrain_input_tmp = tf.one_hot(
            tf.cast(terrain_input, tf.int32), depth=6)
        feature_output_list.append(
            flatten(make_conv("conv_terrain_0")(terrain_input_tmp)))

        # Camp ids -> one-hot channels -> conv -> token.
        camp_input_tmp = tf.one_hot(tf.cast(camp_input, tf.int32), depth=4)
        feature_output_list.append(
            flatten(make_conv("conv_camp_0")(camp_input_tmp)))

        # One conv layer shared by all entity planes -> one token per plane.
        entity_layer = make_conv("conv_entity_0")
        for each_entity_input in entity_input_list:
            feature_output_list.append(
                flatten(entity_layer(each_entity_input)))

        # Every flattened conv output must be exactly one token wide.
        # `as_list()` yields plain ints under both TF1 and TF2 shape APIs.
        for feature in feature_output_list:
            flat_dim = feature.shape.as_list()[-1]
            assert flat_dim == token_dim, (
                "flattened conv output has {} features, expected "
                "token_dim={}".format(flat_dim, token_dim))

        # [batch_size, length, token_size]
        last_layer = tf.keras.layers.Lambda(
            lambda x: tf.stack(x, axis=-2))(feature_output_list)

        for _ in range(self.num_att_layers):
            # The skip connection adds the attention output to its input, so
            # the attention output width has to be the token width.
            last_layer = SkipConnection(
                MultiHeadAttention(
                    out_dim=token_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    input_layernorm=False,
                    output_activation=None),
                fan_in_layer=None)(last_layer)

        # Average over the token axis.
        last_layer = tf.keras.layers.Lambda(
            lambda x: tf.keras.backend.mean(x, axis=1))(last_layer)

        # The action distribution outputs.
        logits_out = tf.keras.layers.Dense(
            num_outputs,
            name="fc_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(last_layer)

        value_out = tf.keras.layers.Dense(
            1,
            name="value_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(last_layer)

        self.base_model = tf.keras.Model(
            [terrain_input, camp_input, entity_input],
            [logits_out, value_out])

        # Disabling action masking will likely lead to invalid actions.
        custom_config = model_config.get("custom_model_config") or {}
        self.no_masking = custom_config.get("no_masking", False)

    def forward(self, input_dict, state, seq_lens):
        """Computes (optionally masked) action logits.

        Args:
            input_dict: Must provide ``["obs"]`` with the entries
                ``"terrain"``, ``"camp"``, ``"entity"`` and ``"va"``.
            state: Passed through unchanged.
            seq_lens: Unused.

        Returns:
            Tuple of (logits, state). Unless ``no_masking`` is set, the
            clipped log of the ``"va"`` mask is added to the logits.
        """
        obs = input_dict["obs"]

        # Compute the unmasked logits; keep the value output for
        # `value_function`.
        logits, self._value_out = self.base_model(
            [obs["terrain"], obs["camp"], obs["entity"]])

        # If action masking is disabled, directly return unmasked logits.
        if self.no_masking:
            return logits, state

        # Cast first so integer-typed masks are accepted by `log` and the
        # result can be added to the logits.
        action_mask = tf.cast(obs["va"], logits.dtype)

        # Convert action_mask into a [0.0 || FLOAT_MIN]-type mask.
        inf_mask = tf.clip_by_value(
            tf.math.log(action_mask),
            clip_value_min=FLOAT_MIN,
            clip_value_max=0.0)

        # Return masked logits.
        return logits + inf_mask, state

    def value_function(self):
        """Returns the value output of the last `forward` call, flattened."""
        return tf.reshape(self._value_out, [-1])

    # TODO: this only builds the name -> variable mapping; restoring the
    # values (e.g. with a Saver built from the mapping) is left to the caller.
    def import_partially(self, dir_path,
                         name_map: Optional[Dict[str, str]] = None):
        """Maps checkpoint variables onto matching trainable variables.

        Lists the variables of the checkpoint at ``<dir_path>/model`` and,
        for every entry of ``name_map``, pairs the first checkpoint variable
        whose name ends with the key with the first trainable variable whose
        name contains the value.

        Args:
            dir_path: Directory that contains the ``model`` checkpoint.
            name_map (Optional[Dict[str, str]]): Maps a checkpoint variable
                name suffix to a substring of the target model variable
                name. Defaults to the q/k/v Dense kernels and biases of
                ``multi_head_attention`` under identical names. A checkpoint
                with different naming can be imported with e.g.
                ``{"fc_encoder_layer0_q_weight":
                "multi_head_attention/dense/kernel", ...}``.

        Returns:
            Dict mapping full checkpoint variable names to model variables.

        Raises:
            IndexError: If a checkpoint variable or a model variable cannot
                be found for an entry of ``name_map``.
        """
        # Function-scope import: this module has no top-level logging import.
        import logging

        logger = logging.getLogger(__name__)

        if name_map is None:
            name_map = {n: n for n in self._DEFAULT_IMPORT_NAMES}

        weight_dir = os.path.join(dir_path, "model")
        logger.debug("Listing checkpoint variables in %s", weight_dir)
        checkpoint_names = [
            ckpt_name for ckpt_name, _ in tf.train.list_variables(weight_dir)
        ]

        # Graph-collection lookup through the v1 API handle, which is also
        # available when the installed TensorFlow is 2.x.
        model_vars = tf1.trainable_variables()

        load_dict = {}
        for ckpt_suffix, model_key in name_map.items():
            ckpt_name = next(
                (n for n in checkpoint_names if n.endswith(ckpt_suffix)),
                None)
            if ckpt_name is None:
                raise IndexError(
                    "no variable ending with {!r} in checkpoint {!r}".format(
                        ckpt_suffix, weight_dir))

            target_var = next(
                (v for v in model_vars if model_key in v.name), None)
            if target_var is None:
                raise IndexError(
                    "no trainable variable whose name contains {!r}".format(
                        model_key))

            logger.debug("Mapping %s -> %s", ckpt_name, target_var.name)
            load_dict[ckpt_name] = target_var

        return load_dict
