import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    LSTM,
    Activation,
    Dense,
    Embedding,
    Input,
    LayerNormalization,
    MultiHeadAttention,
    Reshape,
)
from tensorflow.keras.losses import SparseCategoricalCrossentropy

from microsoft_nlp.gpt_models import (
    GptBlock_minimal,
    TokenAndPositionEmbedding,
    create_model_gpt1,
    create_model_gpt2,
)
from microsoft_nlp.huggingface_gpt1 import create_huggingface_gpt1
from microsoft_nlp.layers import (
    FFN,
    LMUD,
    LMUE,
    LMUContraction,
    Scaling,
    de_embedding,
    lmu_layer,
)


def create_lmud(
    vocab_size,
    seq_len,
    n_layers,
    activation,
    order,
    theta,
    theta_factor,
    embed_dim,
    n_filters,
    share_filters,
    lmud_order,
    post_ffn,
    pre_ffn,
    n_heads,
    weight_tying,
    option2,
    lmue_like,
    eqn11,
    order_increment,
):
    """Build and compile the "lmud" Keras language model.

    Token ids are embedded, passed through ``n_layers`` blocks (optional
    pre-block, ``lmu_layer``, ``LMUD``, scaled skip connection, post FFN,
    layer normalization), projected with ``Dense(embed_dim)`` and finally
    handed to ``de_embedding`` to produce the model output.

    Args:
        vocab_size: Number of rows of the token embedding table; also
            forwarded to ``de_embedding``.
        seq_len: Length of the model's ``Input``.
        n_layers: Number of blocks stacked after the embedding.
        activation: Activation forwarded to every ``FFN`` and ``LMUD``.
        order: Base order; block ``n`` uses ``order + order_increment * n``.
        theta: Base theta; block ``n`` uses
            ``int(round(theta * theta_factor ** n))``.
        theta_factor: Per-block multiplicative factor applied to ``theta``.
        embed_dim: Embedding width, also used for the FFN / LMUD / Dense
            layers and the ``Reshape`` that follows ``lmu_layer``.
        n_filters: Forwarded to ``LMUD``.
        share_filters: Forwarded to ``LMUD``.
        lmud_order: Multiplier; ``LMUD`` receives
            ``order=int(lmud_order * order_i)``.
        post_ffn: ``inner_ratio`` of the FFN applied at the end of each block.
        pre_ffn: ``inner_ratio`` of the FFN applied at the start of each
            block; only used when ``eqn11 == 3`` and ``n_heads is None``.
        n_heads: When ``eqn11 == 3`` and this is not ``None``, a
            ``GptBlock_minimal(embed_dim, n_heads)`` is used as the pre-block
            instead of the FFN.
        weight_tying: Forwarded to ``de_embedding``.
        option2: Forwarded to ``LMUD`` as ``second_gate``.
        lmue_like: Forwarded to ``LMUD``.
        eqn11: Forwarded to ``LMUD``; the value ``3`` additionally enables
            the pre-block described above.
        order_increment: Amount added to the order for each successive block.

    Returns:
        A ``Model`` compiled with the ``"adam"`` optimizer and
        ``SparseCategoricalCrossentropy(from_logits=True)``.

    Raises:
        AssertionError: If a block's rounded theta is smaller than ``order``.
    """

    inputs = Input((seq_len,))

    # Keras ``Embedding`` takes (input_dim, output_dim): the table has one row
    # per vocabulary entry and ``embed_dim`` columns.  The sequence length is
    # already fixed by ``Input`` and is not an argument of this layer.
    embed_layer = Embedding(vocab_size, embed_dim)
    x = embed_layer(inputs)

    for n in range(n_layers):
        # Per-block hyper-parameters: theta grows geometrically, order linearly.
        theta_i = int(round(theta * theta_factor ** n))
        order_i = order + order_increment * n
        # NOTE(review): this compares against the base ``order`` rather than
        # the per-block ``order_i`` -- confirm that is intended.
        assert theta_i >= order, "Silly to have theta < order"
        if eqn11 == 3:
            # Optional pre-block: scaled input plus either an FFN or a
            # minimal GPT block, depending on whether heads were requested.
            if n_heads is None:
                x = Scaling()(x) + FFN(
                    embed_dim, inner_ratio=pre_ffn, activation=activation
                )(x)
            else:
                x = Scaling()(x) + GptBlock_minimal(embed_dim, n_heads)(x)
        skip_x = x
        x = lmu_layer(order_i, theta_i)(x)
        x = Reshape((-1, embed_dim, order_i))(x)
        x = LMUD(
            d=embed_dim,
            order=int(lmud_order * order_i),
            activation=activation,
            # post_ffn=post_ffn,
            n_filters=n_filters,
            share_filters=share_filters,
            second_gate=option2,
            lmue_like=lmue_like,
            eqn11=eqn11,
        )(x)
        # Scaled skip connection around the LMU/LMUD pair.
        x = Scaling()(skip_x) + x
        # Post FFN, again combined with a scaled copy of its own input.
        x = Scaling()(x) + FFN(embed_dim, inner_ratio=post_ffn, activation=activation)(
            x
        )
        x = LayerNormalization()(x)
    x = Dense(embed_dim)(x)
    out = de_embedding(
        inputs=x,
        embed_layer=embed_layer,
        vocab_size=vocab_size,
        weight_tying=weight_tying,
    )
    # TODO: Auto-generate name.
    name = f"lmud-{embed_dim}-{n_filters}-{n_layers}-{order}-{lmud_order}-{theta}"
    model = Model(inputs, out, name=name)
    model.compile(
        optimizer="adam", loss=SparseCategoricalCrossentropy(from_logits=True)
    )
    return model


def create_lmue(
    vocab_size,
    seq_len,
    n_layers,
    activation,
    order,
    theta,
    theta_factor,
    embed_dim,
    ff_dim,
    n_filters,
    share_filters,
    layernorm,
    positional_embed,
    lmud_order,
    pre_ffn,
    post_ffn,
    weight_tying,
    option2=False,
):
    """Build and compile the "lmue" model (LMUD variant with pre and post FFNs).

    Each of the ``n_layers`` blocks optionally layer-normalizes its input,
    optionally applies a pre FFN (only when ``pre_ffn > 0``), runs
    ``lmu_layer`` followed by ``LMUE`` (which receives ``post_ffn``) and adds
    the block input back as a residual.  The result is passed to
    ``de_embedding`` and the model is compiled with ``"adam"`` and
    ``SparseCategoricalCrossentropy(from_logits=True)``.

    ``positional_embed`` selects ``TokenAndPositionEmbedding`` (with
    ``dropout_rate=0``) over a plain ``Embedding``.  Block ``n`` uses
    ``int(round(theta * theta_factor ** n))`` as its theta; an
    ``AssertionError`` is raised if that value is smaller than ``order``.
    """
    token_ids = Input((seq_len,))

    embed_layer = (
        TokenAndPositionEmbedding(seq_len, vocab_size, embed_dim, dropout_rate=0)
        if positional_embed
        else Embedding(vocab_size, embed_dim)
    )
    hidden = embed_layer(token_ids)

    for depth in range(n_layers):
        residual = hidden

        if layernorm:
            hidden = LayerNormalization()(hidden)

        # NOTE(review): when the pre FFN is skipped the Reshape below still
        # uses ff_dim -- presumably ff_dim must equal embed_dim then; confirm.
        if pre_ffn > 0:
            pre_block = FFN(ff_dim, inner_ratio=pre_ffn, activation=activation)
            hidden = pre_block(hidden)

        window = int(round(theta * theta_factor ** depth))
        assert window >= order, "Silly to have theta < order"

        memory = lmu_layer(order, window)(hidden)
        memory = Reshape((-1, ff_dim, order))(memory)
        lmue_block = LMUE(
            d=ff_dim,
            out_d=embed_dim,
            order=int(lmud_order * order),
            activation=activation,
            post_ffn=post_ffn,
            n_filters=n_filters,
            share_filters=share_filters,
            second_gate=option2,
        )
        # Residual connection around the whole block.
        hidden = lmue_block(memory) + residual

    logits = de_embedding(
        inputs=hidden,
        embed_layer=embed_layer,
        vocab_size=vocab_size,
        weight_tying=weight_tying,
    )
    # TODO: Auto-generate name.
    model = Model(
        token_ids,
        logits,
        name=f"lmue-{embed_dim}-{ff_dim}-{n_filters}-{n_layers}-{order}-{theta}",
    )
    loss_fn = SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_fn)
    return model


def create_lmu_mlp(
    vocab_size,
    seq_len,
    n_layers,
    activation,
    order,
    theta,
    theta_factor,
    embed_dim,
    ff_dim,
    n_filters,
    share_filters,
    # gating,
    layernorm,
    weight_tying,
):
    """Build and compile an LMU model motivated by 'Pay Attention to MLPs'.

    Every block projects the channels to ``ff_dim`` with a ``Dense`` layer and
    ``activation``, runs ``lmu_layer`` + ``LMUContraction`` over that
    projection, multiplies the contraction output by the (expanded)
    projection, maps back to ``embed_dim`` with another ``Dense`` layer and
    adds the block input as a residual.  ``layernorm`` inserts a
    ``LayerNormalization`` before the first ``Dense`` and before the LMU.

    Block ``n`` uses ``int(round(theta * theta_factor ** n))`` as its theta;
    an ``AssertionError`` is raised if that value is smaller than ``order``.
    The returned model is compiled with ``"adam"`` and
    ``SparseCategoricalCrossentropy(from_logits=True)``.
    """
    token_ids = Input((seq_len,))
    embed_layer = Embedding(vocab_size, embed_dim)
    hidden = embed_layer(token_ids)

    for depth in range(n_layers):
        residual = hidden

        # --- channel mixing: Dense + activation ---
        if layernorm:
            hidden = LayerNormalization()(hidden)
        hidden = Activation(activation)(Dense(ff_dim)(hidden))

        # --- LMU branch; the un-normalized projection is kept as a gate ---
        gate = hidden
        if layernorm:
            hidden = LayerNormalization()(hidden)

        window = int(round(theta * theta_factor ** depth))
        assert window >= order, "Silly to have theta < order"
        memory = lmu_layer(order, window)(hidden)
        contraction = LMUContraction(
            ff_dim,
            order,
            n_filters,
            share_filters=share_filters,
            activation=None,
            has_dense_output=False,
        )
        contracted = contraction(memory)
        # Add a trailing axis to the gate before the elementwise product.
        gate = tf.expand_dims(gate, axis=-1)
        gated = gate * contracted

        # --- channel mixing back to the embedding width ---
        flat = Reshape((seq_len, ff_dim * n_filters))(gated)
        hidden = Dense(embed_dim)(flat) + residual

    logits = de_embedding(
        inputs=hidden,
        embed_layer=embed_layer,
        vocab_size=vocab_size,
        weight_tying=weight_tying,
    )

    model = Model(
        token_ids,
        logits,
        name=f"lmu-mlp-{embed_dim}-{ff_dim}-{n_filters}-{n_layers}-{order}-{theta}",
    )
    loss_fn = SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_fn)
    return model


def create_lstm(
    vocab_size,
    seq_len,
    n_layers,
    embed_dim,
    hidden_dim,
    weight_tying,
):
    """Build, summarize and compile a stacked-LSTM language model.

    Token ids are embedded, passed through ``n_layers`` LSTM layers with
    ``return_sequences=True`` and, when ``hidden_dim`` differs from
    ``embed_dim``, through a ``Dense(embed_dim)`` layer before
    ``de_embedding``.  ``model.summary()`` is called as a side effect before
    the model is compiled with ``"adam"`` and
    ``SparseCategoricalCrossentropy(from_logits=True)``.
    """
    token_ids = Input((seq_len,))
    embed_layer = Embedding(vocab_size, embed_dim)
    hidden = embed_layer(token_ids)

    for _ in range(n_layers):
        recurrent = LSTM(hidden_dim, return_sequences=True)
        hidden = recurrent(hidden)

    # Only project when the LSTM width differs from the embedding width.
    needs_projection = hidden_dim != embed_dim
    if needs_projection:
        hidden = Dense(embed_dim)(hidden)

    logits = de_embedding(hidden, embed_layer, vocab_size, weight_tying)

    model_name = f"lstm-{embed_dim}-{hidden_dim}-{n_layers}"
    model = Model(token_ids, logits, name=model_name)
    model.summary()
    loss_fn = SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_fn)
    return model
