import tensorflow as tf
from transformers import OpenAIGPTConfig
from transformers.activations_tf import get_tf_activation
from transformers.modeling_tf_utils import (
    TFConv1D,
    TFSharedEmbeddings,
    get_initializer,
    shape_list,
)


class TFAttention(tf.keras.layers.Layer):
    """Multi-head causal self-attention layer.

    A single ``c_attn`` projection is split three ways into query, key and
    value; the attended result is merged back across heads and passed through
    ``c_proj`` and a residual dropout.
    """

    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
        """Create the projection and dropout sub-layers.

        Args:
            nx: Hidden width of the layer; must be a multiple of
                ``config.n_head``.
            n_ctx: Context size, stored on the instance as ``n_ctx``.
            config: Configuration object providing ``n_head``,
                ``output_attentions``, ``initializer_range``, ``attn_pdrop``
                and ``resid_pdrop``.
            scale: When true, attention scores are divided by the square root
                of the per-head feature size.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)

        # The attention state width is the same as the input width (n_state == nx).
        assert (
            nx % config.n_head == 0
        ), f"Hidden dimension {nx} not dividable by number of heads {config.n_head}"

        init_range = config.initializer_range

        self.n_ctx = n_ctx
        self.n_head = config.n_head
        self.split_size = nx
        self.scale = scale
        self.output_attentions = config.output_attentions

        # Sub-layers are created in a fixed order: c_attn, c_proj, dropouts.
        self.c_attn = TFConv1D(3 * nx, nx, initializer_range=init_range, name="c_attn")
        self.c_proj = TFConv1D(nx, nx, initializer_range=init_range, name="c_proj")
        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Placeholder; head pruning is not implemented and this does nothing."""

    @staticmethod
    def causal_attention_mask(nd, ns):
        """Build a boolean ``[nd, ns]`` causal mask.

        True entries form the lower triangle, counted from the lower-right
        corner. Equivalent to ``tf.matrix_band_part(tf.ones([nd, ns]), -1,
        ns-nd)`` but does not produce garbage on TPUs.
        """
        dst_index = tf.range(nd)[:, None]
        src_index = tf.range(ns)
        return dst_index >= src_index - ns + nd

    def _attn(
        self, q, k, v, attention_mask, head_mask, output_attentions, training=False
    ):
        """Compute masked attention for already-split heads.

        Args:
            q, k, v: Tensors shaped [batch, heads, sequence, features].
            attention_mask: Optional additive mask added to the scores.
            head_mask: Optional multiplicative mask applied to the
                attention probabilities.
            output_attentions: When truthy, the probabilities are returned too.
            training: Forwarded to the attention dropout.

        Returns:
            ``[context]`` or ``[context, probabilities]``.
        """
        scores = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            # Scale by sqrt of the per-head feature size.
            feature_dim = tf.cast(shape_list(k)[-1], dtype=scores.dtype)
            scores = scores / tf.math.sqrt(feature_dim)

        # scores: [batch, heads, dst_sequence, src_sequence]; information
        # flows from src to dst.
        _batch, _heads, dst_len, src_len = shape_list(scores)
        causal = tf.reshape(
            tf.cast(self.causal_attention_mask(dst_len, src_len), dtype=scores.dtype),
            [1, 1, dst_len, src_len],
        )
        # Keep allowed positions, push disallowed ones to a large negative value.
        scores = scores * causal - 1e4 * (1 - causal)

        if attention_mask is not None:
            scores = scores + tf.cast(attention_mask, dtype=scores.dtype)

        probs = self.attn_dropout(tf.nn.softmax(scores, axis=-1), training=training)

        if head_mask is not None:
            probs = probs * head_mask

        context = tf.matmul(probs, v)
        return [context, probs] if output_attentions else [context]

    def merge_heads(self, x):
        """Swap axes 1 and 2, then fold the last two axes into one."""
        swapped = tf.transpose(x, [0, 2, 1, 3])
        *leading, heads, features = shape_list(swapped)
        return tf.reshape(swapped, leading + [heads * features])

    def split_heads(self, x):
        """Split the last axis into ``n_head`` heads.

        Returns a tensor laid out as (batch, head, seq_length, head_features).
        """
        *leading, width = shape_list(x)
        per_head = tf.reshape(x, leading + [self.n_head, width // self.n_head])
        return tf.transpose(per_head, (0, 2, 1, 3))

    def call(self, x, attention_mask, head_mask, output_attentions, training=False):
        """Run self-attention over ``x``.

        Returns:
            A list ``[a]`` or ``[a, attentions]`` where ``a`` is the projected,
            dropout-regularised attention output.
        """
        q_part, k_part, v_part = tf.split(self.c_attn(x), 3, axis=2)

        context, *extras = self._attn(
            self.split_heads(q_part),
            self.split_heads(k_part),
            self.split_heads(v_part),
            attention_mask,
            head_mask,
            output_attentions,
            training=training,
        )

        projected = self.c_proj(self.merge_heads(context))
        projected = self.resid_dropout(projected, training=training)

        return [projected, *extras]  # a, (attentions)


class TFMLP(tf.keras.layers.Layer):
    """Feed-forward sub-layer: ``c_fc`` -> GELU -> ``c_proj`` -> dropout."""

    def __init__(self, n_state, config, **kwargs):
        """Create the two projections, the activation and the dropout.

        Args:
            n_state: Size passed as the first argument of ``c_fc`` and the
                second argument of ``c_proj``.
            config: Configuration object providing ``n_embd``,
                ``initializer_range`` and ``resid_pdrop``.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        embed_width = config.n_embd
        init_range = config.initializer_range

        self.c_fc = TFConv1D(
            n_state, embed_width, initializer_range=init_range, name="c_fc"
        )
        self.c_proj = TFConv1D(
            embed_width, n_state, initializer_range=init_range, name="c_proj"
        )
        self.act = get_tf_activation("gelu")
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, x, training=False):
        """Apply the feed-forward transformation to ``x``."""
        expanded = self.c_fc(x)
        activated = self.act(expanded)
        projected = self.c_proj(activated)
        return self.dropout(projected, training=training)


class TFBlock(tf.keras.layers.Layer):
    """One transformer block: attention and MLP, each followed by a residual
    add and a layer normalisation."""

    def __init__(self, n_ctx, config, scale=False, **kwargs):
        """Create the attention, MLP and layer-norm sub-layers.

        Args:
            n_ctx: Context size handed to ``TFAttention``.
            config: Configuration object providing ``n_embd`` and
                ``layer_norm_epsilon`` (plus whatever the sub-layers read).
            scale: Forwarded to ``TFAttention``.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        width = config.n_embd
        eps = config.layer_norm_epsilon

        # Creation order is attn, ln_1, mlp, ln_2.
        self.attn = TFAttention(width, n_ctx, config, scale, name="attn")
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=eps, name="ln_1")
        self.mlp = TFMLP(4 * width, config, name="mlp")
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=eps, name="ln_2")

    def call(self, x, attention_mask, head_mask, output_attentions, training=False):
        """Run the block on ``x``.

        Returns:
            A list ``[h]`` or ``[h, attentions]``; the trailing entries are
            whatever the attention layer returned after its first output.
        """
        attn_result = self.attn(
            x, attention_mask, head_mask, output_attentions, training=training
        )
        attended, extras = attn_result[0], attn_result[1:]

        # Residual add followed by layer norm, once per sub-layer.
        normed = self.ln_1(x + attended)
        fed_forward = self.mlp(normed, training=training)
        hidden = self.ln_2(normed + fed_forward)

        return [hidden] + extras  # x, (attentions)


class TFPositionalEmbedding(tf.keras.layers.Layer):
    """Lookup table mapping position ids to embedding vectors.

    The table is a single weight of shape ``[n_positions, n_embd]`` created in
    ``build`` under the ``positions_embed`` name scope.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Record the sizes and initializer range taken from ``config``."""
        super().__init__(*inputs, **kwargs)

        self.config = config
        self.n_embd = config.n_embd
        self.n_positions = config.n_positions
        self.initializer_range = config.initializer_range

    def build(self, input_shape):
        """Create the embedding weight, then defer to the base ``build``."""
        table_shape = [self.n_positions, self.n_embd]
        table_init = get_initializer(self.initializer_range)
        with tf.name_scope("positions_embed"):
            self.positions_embed = self.add_weight(
                name="embeddings",
                shape=table_shape,
                initializer=table_init,
            )

        super().build(input_shape)

    def call(self, position_ids):
        """Gather the embedding rows selected by ``position_ids``."""
        return tf.gather(self.positions_embed, position_ids)


class TFTokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sum of token embeddings and position embeddings.

    ``tokens_embed`` is a ``TFSharedEmbeddings`` instance, so the same object
    can also be called with ``mode="linear"`` by code outside this layer.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Create the token and position embedding sub-layers.

        Args:
            config: Configuration object providing ``n_positions``,
                ``vocab_size``, ``n_embd`` and ``initializer_range``.
            *inputs, **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(*inputs, **kwargs)

        # Size of the position table; kept as an attribute for callers.
        self.n_positions = config.n_positions

        self.tokens_embed = TFSharedEmbeddings(
            config.vocab_size,
            config.n_embd,
            initializer_range=config.initializer_range,
            name="tokens_embed",
        )
        self.positions_embed = TFPositionalEmbedding(config)

    def call(self, x):
        """Embed token ids ``x`` and add the matching position embeddings.

        Position ids are ``0 .. L-1`` where ``L`` is the size of the last axis
        of ``x``, so inputs shorter than ``n_positions`` are supported. (Using
        a fixed ``range(n_positions)`` only works when ``L == n_positions``
        and silently broadcasts when ``L == 1``.)

        NOTE(review): ``L`` should not exceed ``n_positions``; the lookup
        would then index past the end of the position table.
        """
        seq_len = shape_list(x)[-1]
        positions = tf.expand_dims(tf.range(seq_len), axis=0)
        return self.tokens_embed(x, mode="embedding") + self.positions_embed(positions)


def create_huggingface_gpt1(
    embed_dim,
    n_layers,
    n_heads=None,
    vocab_size=20000,
    seq_len=256,
):
    """Build and compile a GPT-1 style Keras language model.

    The model maps int32 token ids of shape ``(batch, seq_len)`` to logits
    produced by the shared token embedding in ``mode="linear"``.

    Args:
        embed_dim: Hidden width (``n_embd``).
        n_layers: Number of transformer blocks.
        n_heads: Number of attention heads. When ``None`` it is derived as
            ``embed_dim // 64``, which requires ``embed_dim`` to be a positive
            multiple of 64.
        vocab_size: Vocabulary size.
        seq_len: Input sequence length; also used as ``n_positions``.

    Returns:
        A ``tf.keras.Model`` compiled with the Adam optimizer and sparse
        categorical cross-entropy on logits.

    Raises:
        ValueError: If ``n_heads`` is ``None`` and ``embed_dim`` is not a
            positive multiple of 64.
    """
    if n_heads is None:
        # Explicit check instead of ``assert`` so it survives ``python -O`` and
        # also rejects embed_dim == 0, which would yield zero heads.
        if embed_dim <= 0 or embed_dim % 64 != 0:
            raise ValueError(
                "embed_dim must be a positive multiple of 64 when n_heads is "
                f"None, got {embed_dim}"
            )
        n_heads = embed_dim // 64

    config = OpenAIGPTConfig(
        vocab_size=vocab_size,
        n_positions=seq_len,
        n_embd=embed_dim,
        n_layer=n_layers,
        n_head=n_heads,
    )

    # NOTE(review): not every transformers release defines ``n_ctx`` on
    # OpenAIGPTConfig (verify against the installed version); fall back to
    # ``n_positions`` when it is absent. When present, its value is used as-is.
    n_ctx = getattr(config, "n_ctx", config.n_positions)

    # set up layers
    embed = TFTokenAndPositionEmbedding(config)
    drop = tf.keras.layers.Dropout(config.embd_pdrop)
    blocks = [
        TFBlock(n_ctx, config, scale=True, name=f"h_._{i}")
        for i in range(config.n_layer)
    ]

    input_ids = tf.keras.layers.Input((seq_len,), dtype=tf.int32)
    hidden_states = embed(input_ids)
    hidden_states = drop(hidden_states)
    for block in blocks:
        outputs = block(
            hidden_states,
            attention_mask=None,
            head_mask=None,
            output_attentions=False,
        )
        # Each block returns a list whose first entry is the hidden state.
        hidden_states = outputs[0]

    # Reuse the token embedding object in "linear" mode to produce the logits.
    logits = embed.tokens_embed(hidden_states, mode="linear")

    model = tf.keras.Model(input_ids, logits, name=f"hf-GPT1-{embed_dim}-{n_layers}")
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    return model
