import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import (
    Dense,
    Dropout,
    Embedding,
    Input,
    Layer,
    LayerNormalization,
    MultiHeadAttention,
)
from tensorflow.keras.losses import SparseCategoricalCrossentropy

from microsoft_nlp.layers import de_embedding


def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """Build a batched causal mask of shape ``(batch_size, n_dest, n_src)``.

    Entry ``(i, j)`` is set when ``j <= i + (n_src - n_dest)``; for
    ``n_dest == n_src`` this is the lower triangle including the diagonal.

    Args:
        batch_size: Scalar number of times the mask is repeated on axis 0.
        n_dest: Number of rows (destination / query positions).
        n_src: Number of columns (source / key positions).
        dtype: Dtype the boolean comparison result is cast to.

    Returns:
        The mask tiled ``batch_size`` times along a new leading axis.
    """
    dest_idx = tf.expand_dims(tf.range(n_dest), axis=1)
    src_idx = tf.range(n_src)
    # Broadcasting (n_dest, 1) against (n_src,) yields an (n_dest, n_src) grid.
    allowed = tf.greater_equal(dest_idx, src_idx - n_src + n_dest)
    single_mask = tf.expand_dims(tf.cast(allowed, dtype), axis=0)

    # Tile multiples are [batch_size, 1, 1]: repeat only along the batch axis.
    batch_multiple = tf.expand_dims(batch_size, axis=-1)
    keep_multiples = tf.ones([2], dtype=tf.int32)
    multiples = tf.concat([batch_multiple, keep_multiples], axis=0)
    return tf.tile(single_mask, multiples)


class GptBlock_minimal(Layer):
    """Causal multi-head self-attention only.

    Unlike the other blocks in this file it has no residual connection,
    normalization, dropout or feed-forward network.

    Args:
        embed_dim: Total attention width; split evenly across the heads.
        n_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, n_heads):
        super().__init__()

        assert embed_dim % n_heads == 0, (
            "Embed dim must be divisible by the number of heads"
        )

        per_head_dim = embed_dim // n_heads
        self.initializer = RandomNormal(mean=0, stddev=0.02)
        self.att = MultiHeadAttention(
            num_heads=n_heads,
            key_dim=per_head_dim,
            kernel_initializer=self.initializer,
        )

    def call(self, inputs):
        """Apply causally masked self-attention to ``inputs``.

        Axis 0 of ``inputs`` is read as the batch size and axis 1 as the
        sequence length; ``inputs`` serves as both query and value.
        """
        dims = tf.shape(inputs)
        n_batch, n_tokens = dims[0], dims[1]
        mask = causal_attention_mask(n_batch, n_tokens, n_tokens, tf.bool)
        return self.att(inputs, inputs, attention_mask=mask)


class Gpt1Block(Layer):
    """Post-layer-norm transformer decoder block (GPT-1 style).

    Causal self-attention and a two-layer feed-forward network, each followed
    by dropout, a residual addition and then layer normalization.

    Args:
        embed_dim: Width of the block's output (and of the residual stream).
        n_heads: Number of attention heads; must divide ``embed_dim``.
        ff_dim: Width of the hidden feed-forward layer.
        dropout_rate: Rate used by both dropout layers.
        act_fn: Activation of the hidden feed-forward layer.
    """

    def __init__(self, embed_dim, n_heads, ff_dim, dropout_rate=0.1, act_fn="gelu"):
        super().__init__()

        assert (
            embed_dim % n_heads == 0
        ), "Embed dim must be divisible by the number of heads"

        def new_initializer():
            # Keras documents that a single initializer instance (even an
            # unseeded one) may return the same values on every call, so each
            # layer gets its own instance instead of sharing one.
            return RandomNormal(0, 0.02)

        # Kept as a public attribute for backward compatibility.
        self.initializer = new_initializer()
        self.att = MultiHeadAttention(
            n_heads, embed_dim // n_heads, kernel_initializer=self.initializer
        )
        self.ffn = Sequential(
            [
                Dense(
                    ff_dim,
                    activation=act_fn,
                    kernel_initializer=new_initializer(),
                ),
                Dense(embed_dim, kernel_initializer=new_initializer()),
            ]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout_rate)
        self.dropout2 = Dropout(dropout_rate)

    def call(self, inputs):
        """Run the block on ``inputs``.

        Axis 0 of ``inputs`` is read as the batch size and axis 1 as the
        sequence length. Returns the normalized output of the second
        residual branch.
        """
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)

        # Attention sub-layer: attend -> dropout -> add residual -> normalize.
        x_skip = inputs
        x = self.att(x_skip, x_skip, attention_mask=causal_mask)
        x = self.dropout1(x)
        x_skip = self.layernorm1(x + x_skip)

        # Feed-forward sub-layer, same post-norm pattern.
        x = self.ffn(x_skip)
        x = self.dropout2(x)
        x = self.layernorm2(x + x_skip)
        return x


class Gpt2Block(Layer):
    """Pre-layer-norm transformer decoder block (GPT-2 style).

    Layer normalization is applied before the causal self-attention and
    before the feed-forward network; each branch output goes through dropout
    and is added back to the un-normalized residual stream.

    Args:
        embed_dim: Width of the block's output (and of the residual stream).
        n_layers: Number of blocks in the model; kernel initializers use a
            standard deviation of ``0.02 / sqrt(n_layers)``.
        n_heads: Number of attention heads; must divide ``embed_dim``.
        ff_dim: Width of the hidden feed-forward layer.
        dropout_rate: Rate used by both dropout layers.
        act_fn: Activation of the hidden feed-forward layer.

    Raises:
        ValueError: If ``n_layers`` is not positive.
    """

    def __init__(
        self, embed_dim, n_layers, n_heads, ff_dim, dropout_rate=0.1, act_fn="gelu"
    ):
        super().__init__()

        assert (
            embed_dim % n_heads == 0
        ), "Embed dim must be divisible by the number of heads"

        if n_layers < 1:
            # A non-positive count would give an infinite or undefined stddev.
            raise ValueError("n_layers must be a positive integer")

        # Plain Python float rather than a tf.Tensor, so the initializer does
        # not hold on to a tensor created at construction time.
        init_stddev = 0.02 * (float(n_layers) ** -0.5)

        def new_initializer():
            # Keras documents that a single initializer instance (even an
            # unseeded one) may return the same values on every call, so each
            # layer gets its own instance instead of sharing one.
            return RandomNormal(0, init_stddev)

        # Kept as a public attribute for backward compatibility.
        self.initializer = new_initializer()
        self.att = MultiHeadAttention(
            n_heads, embed_dim // n_heads, kernel_initializer=self.initializer
        )
        self.ffn = Sequential(
            [
                Dense(
                    ff_dim,
                    activation=act_fn,
                    kernel_initializer=new_initializer(),
                ),
                Dense(embed_dim, kernel_initializer=new_initializer()),
            ]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout_rate)
        self.dropout2 = Dropout(dropout_rate)

    def call(self, inputs):
        """Run the block on ``inputs``.

        Axis 0 of ``inputs`` is read as the batch size and axis 1 as the
        sequence length. The returned tensor is the residual stream after
        both branches; it is not normalized here.
        """
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)

        # Attention branch: normalize -> attend -> dropout -> add residual.
        x_skip = inputs
        x = self.layernorm1(x_skip)
        x = self.att(x, x, attention_mask=causal_mask)
        x = self.dropout1(x)
        x_skip = x + x_skip

        # Feed-forward branch, same pre-norm pattern.
        x = self.layernorm2(x_skip)
        x = self.ffn(x)
        x = self.dropout2(x)
        x = x + x_skip
        return x


class TokenAndPositionEmbedding(Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        maxlen: Number of rows in the position table.
        vocab_size: Number of rows in the token table.
        embed_dim: Width of both embedding tables.
        dropout_rate: Rate of the dropout applied to the summed embeddings.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, dropout_rate=0.1):
        super().__init__()
        # Kept as a public attribute for backward compatibility; used for the
        # token table only.
        self.initializer = RandomNormal(0, 0.02)
        self.token_emb = Embedding(
            input_dim=vocab_size,
            output_dim=embed_dim,
            embeddings_initializer=self.initializer,
        )
        # Keras documents that a single initializer instance (even an unseeded
        # one) may return the same values on every call, so the position table
        # gets its own instance rather than sharing the token table's.
        self.pos_emb = Embedding(
            input_dim=maxlen,
            output_dim=embed_dim,
            embeddings_initializer=RandomNormal(0, 0.02),
        )
        self.dropout = Dropout(dropout_rate)

    def call(self, x):
        """Embed token ids ``x`` and add position embeddings.

        Positions ``0 .. tf.shape(x)[-1] - 1`` are looked up in the position
        table and broadcast-added to the token embeddings before dropout.
        """
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x) + positions
        x = self.dropout(x)
        return x


def create_model_gpt1(
    embed_dim,
    n_layers,
    n_heads,
    ff_dim,
    vocab_size,
    seq_len,
    weight_tying,
):
    """Build, summarize and compile a model made of ``Gpt1Block`` layers.

    Args:
        embed_dim: Embedding width passed to the embedding layer and blocks.
        n_layers: Number of ``Gpt1Block`` layers stacked.
        n_heads: Attention heads per block.
        ff_dim: Hidden feed-forward width per block.
        vocab_size: Token vocabulary size.
        seq_len: Input length; also the size of the position table.
        weight_tying: Forwarded unchanged to ``de_embedding``.

    Returns:
        The model, compiled with the ``"adam"`` optimizer and a
        sparse categorical cross-entropy loss on logits. The model summary
        is printed as a side effect.
    """
    token_ids = Input((seq_len,))
    embedding = TokenAndPositionEmbedding(seq_len, vocab_size, embed_dim)

    hidden = embedding(token_ids)
    for _ in range(n_layers):
        block = Gpt1Block(embed_dim, n_heads, ff_dim)
        hidden = block(hidden)
    outputs = de_embedding(hidden, embedding, vocab_size, weight_tying)

    model_name = f"gpt1-{embed_dim}-{n_layers}-{n_heads}-{ff_dim}"
    model = Model(token_ids, outputs, name=model_name)
    model.summary()

    loss_fn = SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_fn)
    return model


def create_model_gpt2(
    embed_dim,
    n_layers,
    n_heads,
    ff_dim,
    vocab_size,
    seq_len,
    weight_tying,
):
    """Build, summarize and compile a model made of ``Gpt2Block`` layers.

    A final layer normalization is applied after the last block, before the
    output head produced by ``de_embedding``.

    Args:
        embed_dim: Embedding width passed to the embedding layer and blocks.
        n_layers: Number of ``Gpt2Block`` layers stacked; also passed to each
            block.
        n_heads: Attention heads per block.
        ff_dim: Hidden feed-forward width per block.
        vocab_size: Token vocabulary size.
        seq_len: Input length; also the size of the position table.
        weight_tying: Forwarded unchanged to ``de_embedding``.

    Returns:
        The model, compiled with the ``"adam"`` optimizer and a
        sparse categorical cross-entropy loss on logits. The model summary
        is printed as a side effect.
    """
    token_ids = Input((seq_len,))
    embedding = TokenAndPositionEmbedding(seq_len, vocab_size, embed_dim)

    hidden = embedding(token_ids)
    for _ in range(n_layers):
        block = Gpt2Block(embed_dim, n_layers, n_heads, ff_dim)
        hidden = block(hidden)
    final_norm = LayerNormalization(epsilon=1e-6)
    hidden = final_norm(hidden)
    outputs = de_embedding(hidden, embedding, vocab_size, weight_tying)

    model_name = f"gpt2-{embed_dim}-{n_layers}-{n_heads}-{ff_dim}"
    model = Model(token_ids, outputs, name=model_name)
    model.summary()

    loss_fn = SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_fn)
    return model
