import keras_nlp
import tensorflow as tf
from tf_agents.networks import sequential
from tf_agents.specs import tensor_spec

from config import FC_LAYER_PARAMS


def init_q_net(env):
    """Build a feed-forward Q-network for ``env``.

    The network is a stack of ReLU ``Dense`` layers, one per entry of
    ``FC_LAYER_PARAMS``, followed by a linear ``Dense`` layer that emits one
    Q-value per available action.

    Args:
        env: Environment exposing ``action_spec()``. The number of actions is
            derived as ``maximum - minimum + 1`` of that spec.

    Returns:
        A ``tf_agents.networks.sequential.Sequential`` wrapping the hidden
        layers and the final Q-value layer.
    """
    action_tensor_spec = tensor_spec.from_spec(env.action_spec())
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

    # Define a helper function to create Dense layers configured with the right
    # activation and kernel initializer.
    def dense_layer(num_units):
        return tf.keras.layers.Dense(
            num_units,
            activation=tf.keras.activations.relu,
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2.0, mode="fan_in", distribution="truncated_normal"
            ),
        )

    # QNetwork consists of a sequence of Dense layers followed by a dense layer
    # with `num_actions` units to generate one q_value per available action as
    # its output.
    dense_layers = [dense_layer(num_units) for num_units in FC_LAYER_PARAMS]
    q_values_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        # Small symmetric kernel init and a constant negative bias for the
        # output layer.
        kernel_initializer=tf.keras.initializers.RandomUniform(
            minval=-0.03, maxval=0.03
        ),
        bias_initializer=tf.keras.initializers.Constant(-0.2),
    )
    return sequential.Sequential(dense_layers + [q_values_layer])


def init_transformer_network(env):
    """Build a Q-network with a single TransformerEncoder block for ``env``.

    Layer pipeline (B: batch size, hidden: hidden dimension, fixed to 64):

    1. Dense + ReLU: input [B, P] -> [B, hidden].
       P usually is the number of players.
    2. Reshape (unsqueeze): [B, hidden] -> [B, 1, hidden], i.e. a sequence of
       length one for the encoder.
    3. TransformerEncoder (4 heads): [B, 1, hidden] -> [B, 1, hidden].
    4. Dense (action head): [B, 1, hidden] -> [B, 1, num_actions].
    5. Reshape (squeeze): [B, 1, num_actions] -> [B, num_actions].

    Args:
        env: Environment exposing ``action_spec()``. The number of actions is
            derived as ``maximum - minimum + 1`` of that spec.

    Returns:
        A ``tf_agents.networks.sequential.Sequential`` wrapping the five
        layers above, producing one Q-value per action.
    """
    action_tensor_spec = tensor_spec.from_spec(env.action_spec())
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

    hidden_dim = 64

    lin_hidden = tf.keras.layers.Dense(
        hidden_dim,
        activation="relu",
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_in", distribution="truncated_normal"
        ),
    )

    # The encoder operates on sequences, so add a length-1 sequence axis.
    unsqueeze_layer = tf.keras.layers.Reshape((1, hidden_dim))

    encoder = keras_nlp.layers.TransformerEncoder(
        intermediate_dim=hidden_dim, num_heads=4
    )

    # The action head is applied while the sequence axis is still present;
    # the final reshape then drops that axis.
    final_layer = tf.keras.layers.Dense(num_actions)

    squeeze_layer = tf.keras.layers.Reshape((num_actions,))

    return sequential.Sequential(
        [lin_hidden, unsqueeze_layer, encoder, final_layer, squeeze_layer]
    )


if __name__ == "__main__":
    # Ad-hoc scratch section: instantiate a standalone encoder and a
    # single-unit linear output layer. Nothing is called or trained here.
    encoder = keras_nlp.layers.TransformerEncoder(num_heads=8, intermediate_dim=64)

    kernel_init = tf.keras.initializers.RandomUniform(minval=-0.03, maxval=0.03)
    bias_init = tf.keras.initializers.Constant(-0.2)
    q_values_layer = tf.keras.layers.Dense(
        units=1,
        activation=None,
        kernel_initializer=kernel_init,
        bias_initializer=bias_init,
    )
