import tensorflow as tf
from numpy.core.fromnumeric import transpose
from tensorflow.keras import Sequential
from tensorflow.keras.layers import (
    Activation,
    Dense,
    Layer,
    LayerNormalization,
    MultiHeadAttention,
    Reshape,
    Softmax,
    TimeDistributed,
)

from microsoft_nlp.lmu_fft import LMUFFT


class Scaling(Layer):
    """Scale the inputs by a learnable factor squashed through a sigmoid.

    The layer owns a single trainable parameter ``s`` and returns
    ``sigmoid(s) * inputs``. Intended to be used with skip connections as
    follows: ``x + Scaling()(Some_Layer(x))``.
    """

    def __init__(self, initial_value=5.0, **kwargs):
        """
        initial_value: float, default is 5.0 because sigmoid(5.0) ~ 1
        **kwargs: forwarded unchanged to ``Layer.__init__``.
        """
        super().__init__(**kwargs)
        # Kept so that get_config() can reproduce the constructor argument.
        self.initial_value = initial_value
        self.initializer = tf.keras.initializers.Constant(
            value=tf.cast(initial_value, tf.float32)
        )

        # NOTE(review): the weight name contains a space. It is left unchanged
        #  here because renaming could affect name-based weight loading --
        #  confirm whether a stricter identifier is wanted.
        self.scaling_parameter = self.add_weight(
            name="scaling parameter",
            shape=(1,),
            initializer=self.initializer,
            dtype="float32",
            trainable=True,
        )

    def call(self, inputs):
        """Return ``sigmoid(scaling_parameter) * inputs``."""
        return tf.keras.activations.sigmoid(self.scaling_parameter) * inputs

    def get_config(self):
        """Return the layer config, including the ``initial_value`` argument."""
        config = super().get_config()
        config.update({"initial_value": self.initial_value})
        return config


class DenseTranspose(Layer):
    """Dense layer that applies the transpose of a tied layer's weights.

    Computes ``activation(matmul(inputs, W, transpose_b=True) + b)`` where
    ``W`` is ``tied_layer.weights[0]`` (looked up on every call) and ``b`` is
    an optional bias created by this layer.

    Args:
        tied_layer: layer whose first weight is reused (transposed).
        activation: passed to ``Activation``; applied to the result.
        use_bias: whether to create and add a zero-initialised bias.
        shape_out: size of the bias. Defaults to ``tied_layer.input_shape[-1]``.
        **kwargs: forwarded unchanged to ``Layer.__init__``.
    """

    # https://medium.com/@lmayrandprovencher/building-an-autoencoder-with-tied-weights-in-keras-c4a559c529a2

    def __init__(
        self, tied_layer, activation=None, use_bias=True, shape_out=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.tied_layer = tied_layer
        self.activation = Activation(activation)
        self.use_bias = use_bias
        # shape_out is needed for some tied layers, such as embedding layers, since
        #  they can have a different input shape (e.g., scalar input). Note that
        #  shape_out has no effect if use_bias == False.
        self.shape_out = tied_layer.input_shape[-1] if shape_out is None else shape_out

    def build(self, batch_input_shape):
        """Create the bias (when enabled) and mark the layer as built."""
        if self.use_bias:
            # add_weight expects a shape (sequence of dims); shape_out is usually
            #  a plain int, so wrap it. Sequences are passed through untouched.
            bias_shape = self.shape_out
            if not isinstance(bias_shape, (list, tuple, tf.TensorShape)):
                bias_shape = (bias_shape,)
            self.biases = self.add_weight(
                name="bias", shape=bias_shape, initializer="zeros"
            )
        super().build(batch_input_shape)

    def call(self, inputs):
        """Project ``inputs`` with the transposed tied weights."""
        z = tf.matmul(inputs, self.tied_layer.weights[0], transpose_b=True)
        if self.use_bias:
            z += self.biases
        return self.activation(z)


def de_embedding(inputs, embed_layer, vocab_size, weight_tying):
    """Typically the final output layer to convert word embeddings back to one-hot.

    With a truthy ``weight_tying`` the logits are computed by a bias-free
    ``DenseTranspose`` tied to ``embed_layer``; otherwise a fresh
    ``Dense(vocab_size)`` layer (with its default bias) is applied.
    """
    if not weight_tying:
        return Dense(vocab_size)(inputs)

    # DenseTranspose reads the tied matrix from ``embed_layer.weights[0]``.
    # TODO: Should use_bias be True or False?
    tied_projection = DenseTranspose(
        embed_layer, use_bias=False, shape_out=vocab_size
    )
    return tied_projection(inputs)


def lmu_layer(order, theta):
    """Delay network (also known as the LMU or Legendre Delay Network).

    Returns an ``LMUFFT`` built with the given ``order`` and ``theta``,
    ``input_to_hidden`` disabled and ``return_sequences`` enabled.
    """
    lmu_kwargs = {
        "order": order,
        "theta": theta,
        "input_to_hidden": False,
        "return_sequences": True,
    }
    return LMUFFT(**lmu_kwargs)


class LMUContraction(Layer):
    """Efficiently reduces the rank of the memory matrix from the LMU by one.

    The input is reshaped to ``(-1, d, order)`` per example and the ``order``
    axis is contracted into ``n_filters`` values, either with one ``Dense``
    layer shared across all ``d`` rows (``share_filters=True``) or with a
    separate weight/bias per row. Optionally the result is flattened to
    ``d * n_filters`` and passed through ``Dense(d)``.

    Args:
        d: number of rows of the memory matrix.
        order: size of the axis being contracted.
        n_filters: number of outputs produced per row.
        share_filters: use a single shared ``Dense`` instead of per-row weights.
        activation: activation spec, or ``None`` for no activation.
        has_dense_output: whether to flatten and apply a final ``Dense(d)``.
        **kwargs: forwarded unchanged to ``Layer.__init__`` (e.g. ``name``).
    """

    def __init__(
        self,
        d,
        order,
        n_filters,
        share_filters=False,
        activation="gelu",
        has_dense_output=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d = d
        self.order = order
        self.n_filters = n_filters
        self.share_filters = share_filters

        self.activation = None if activation is None else Activation(activation)
        self.reshape_input = Reshape((-1, d, order))
        self.reshape_filters = Reshape((-1, d * n_filters))
        self.dense_output = Dense(d) if has_dense_output else None
        if share_filters:
            # NOTE(review): this Dense already applies `activation`, and call()
            #  applies self.activation again right after it, so the shared
            #  path is activated twice while the per-row path is activated
            #  once -- confirm this is intended before changing it.
            self.filters = Dense(n_filters, activation)

    def build(self, input_shape):
        """Validate the trailing input dims and create the per-row weights."""
        # Accept either a flattened (..., d * order) or a (..., d, order) input.
        assert input_shape[-1] == self.d * self.order or (
            input_shape[-1] == self.order and input_shape[-2] == self.d
        ), (
            f"LMUContraction expected trailing dims ({self.d * self.order},) or "
            f"({self.d}, {self.order}); got input shape {input_shape}"
        )
        if not self.share_filters:
            # Custom Dense layer to be applied along order axis.
            self.W = self.add_weight(
                name="weight",
                shape=(self.d, self.order, self.n_filters),
                initializer="glorot_uniform",
                trainable=True,
            )
            self.b = self.add_weight(
                name="bias",
                shape=(self.d, self.n_filters),
                initializer="zeros",
                trainable=True,
            )

    def call(self, x):
        """Contract the ``order`` axis of ``x`` into ``n_filters`` features."""
        x = self.reshape_input(x)
        if self.share_filters:
            x = self.filters(x)
        else:
            # This sums across the order axis (j) for each dimension (i) and each
            #  filter (k) independently. Thus learning a different filter per dimension.
            #  tf.reduce_sum is more efficient than tf.einsum in this case.
            # x = tf.einsum("...ij,...ijk->...ik", x, self.W) + self.b
            x = tf.reduce_sum(tf.expand_dims(x, axis=-1) * self.W, axis=-2) + self.b

        x = x if self.activation is None else self.activation(x)

        if self.dense_output is not None:
            # Flatten (d, n_filters) so the final Dense mixes across rows.
            x = self.reshape_filters(x)
            x = self.dense_output(x)
            x = x if self.activation is None else self.activation(x)

        return x


class FFN(Layer):
    """Two stacked ``Dense`` layers: an activated inner layer, then a linear one.

    The inner layer has ``int(inner_ratio * size)`` units and uses
    ``activation``; the outer layer has ``out_size`` units (``size`` when
    ``out_size`` is None) and no activation.
    """

    def __init__(self, size, inner_ratio, activation, out_size=None):
        super().__init__()

        inner_units = int(inner_ratio * size)
        if out_size is None:
            out_size = size

        self.dense1 = Dense(inner_units, activation=activation)
        self.dense2 = Dense(out_size)

    def call(self, x):
        """Apply the inner layer followed by the output layer."""
        hidden = self.dense1(x)
        return self.dense2(hidden)


class LMUD(Layer):
    """A data dependent way of mixing tokens.

    Builds a softmax gate from one or two gated projections of the input,
    uses it to mix the (projected) input with a residual connection, and then
    contracts the result with ``LMUContraction``.

    Args:
        d: size of the ``d`` axis of the input.
        order: number of units of the right-hand ``Dense`` layers.
        activation: activation used by every ``Dense`` layer and the contraction.
        n_filters: filters per row for the contraction (doubled when
            ``eqn11 == 1``).
        share_filters: forwarded to ``LMUContraction``.
        second_gate: build a separate second gate instead of reusing the first.
        lmue_like: apply ``dense_r`` before computing the gates rather than after.
        eqn11: integer variant switch. ``1`` halves the width of the left-hand
            layers and adds ``dense_adjust``; values ``>= 2`` drop ``dense_l``;
            values ``>= 3`` also drop the left-hand gate layers.
            NOTE(review): presumably refers to an equation in the accompanying
            paper -- confirm.
    """

    def __init__(
        self,
        d,
        order,
        activation,
        n_filters,
        share_filters,
        second_gate,
        lmue_like,
        eqn11,
    ):
        super().__init__()
        self.d = d
        self.order = order
        self.activation = activation
        self.lmue_like = lmue_like
        self.eqn11 = eqn11
        self.dense_r = Dense(order, activation=activation)
        self.dense_r_gate1 = Dense(order, activation=activation)
        self.dense_r_gate2 = (
            Dense(order, activation=activation) if second_gate else None
        )
        lmud_d = d // 2 if eqn11 == 1 else d
        self.dense_l = Dense(lmud_d, activation=activation) if eqn11 < 2 else None
        self.dense_l_gate1 = Dense(lmud_d, activation=activation) if eqn11 < 3 else None
        self.dense_l_gate2 = (
            Dense(lmud_d, activation=activation) if second_gate and eqn11 < 3 else None
        )
        self.dense_adjust = (
            Dense(d, activation=activation) if eqn11 == 1 else None
        )  # lmud_order * order not always equal to d, so use this to scale up or down.

        # NOTE: an optional post-FFN (a ``post_ffn`` argument building
        #  FFN(d * n_filters, inner_ratio=post_ffn, activation=activation)) was
        #  prototyped here and is currently disabled.
        self.n_filters = 2 * n_filters if eqn11 == 1 else n_filters
        self.contraction = LMUContraction(
            lmud_d, self.order, self.n_filters, share_filters, activation, False
        )
        # Not used by call() below; kept so the layer's attributes are unchanged.
        self.ln = LayerNormalization()

        self.reshape_for_scaling = Reshape((-1, lmud_d * self.n_filters))
        self.reshape_output = Reshape((-1, d * n_filters))

    def call(self, x):
        """Gate, mix and contract ``x``."""
        # The transposes below swap axes 2 and 3, so x must be rank 4.
        # NOTE(review): presumably (batch, seq, d, order) -- confirm.
        if self.lmue_like:
            x = self.dense_r(x)

        g1 = self.dense_r_gate1(x)
        g1 = tf.transpose(g1, perm=[0, 1, 3, 2])
        if self.eqn11 < 3:
            g1 = self.dense_l_gate1(g1)

        if self.dense_r_gate2 is None:
            g2 = g1
        else:
            g2 = self.dense_r_gate2(x)
            g2 = tf.transpose(g2, perm=[0, 1, 3, 2])
            if self.eqn11 < 3:
                g2 = self.dense_l_gate2(g2)

        g = tf.matmul(g1, g2, transpose_b=True)
        # Plain op instead of building a new Softmax() layer on every call.
        g = tf.nn.softmax(g, axis=-1)

        if not self.lmue_like:
            x = self.dense_r(x)
        x = tf.transpose(x, perm=[0, 1, 3, 2])
        if self.eqn11 < 2:
            x = self.dense_l(x)

        # Gated mixing with a residual connection.
        x = tf.matmul(g, x) + x
        x = tf.transpose(x, perm=[0, 1, 3, 2])
        x = self.contraction(x)
        if self.eqn11 == 1:
            x = self.reshape_for_scaling(x)
            x = self.dense_adjust(x)
        else:
            x = self.reshape_output(x)
        return x


class LMUE(Layer):
    """A data dependent way of mixing tokens.

    Projects the input with ``dense_r``, builds a softmax gate from one or two
    gated projections, mixes the projected input with a residual connection,
    contracts it with ``LMUContraction`` and optionally applies an ``FFN``.

    Args:
        d: number of units of the left-hand ``Dense`` layers.
        order: number of units of the right-hand ``Dense`` layers.
        activation: activation for the left-hand layers, contraction and FFN.
        out_d: output size of the FFN; defaults to ``d``. Has no effect unless
            ``post_ffn > 0``, because it is only passed to the FFN.
        post_ffn: inner ratio of the trailing FFN; no FFN is built when ``<= 0``.
        n_filters: filters per row for the contraction.
        share_filters: forwarded to ``LMUContraction``.
        second_gate: build a separate second gate instead of reusing the first.
    """

    def __init__(
        self,
        d,
        order,
        activation,
        out_d=None,
        post_ffn=0,
        n_filters=1,
        share_filters=True,
        second_gate=False,
    ):
        super().__init__()
        self.d = d
        self.out_d = d if out_d is None else out_d
        self.order = order
        self.activation = activation
        self.post_ffn = post_ffn

        self.dense_r = Dense(order)
        self.dense_l = Dense(d, activation=activation)
        self.dense_r_gate1 = Dense(order)
        self.dense_l_gate1 = Dense(d, activation=activation)
        self.dense_r_gate2 = Dense(order) if second_gate else None
        self.dense_l_gate2 = Dense(d, activation=activation) if second_gate else None
        self.contraction = LMUContraction(
            d, self.order, n_filters, share_filters, activation, has_dense_output=False
        )
        self.ffn = (
            FFN(
                d * n_filters,
                inner_ratio=post_ffn,
                activation=activation,
                out_size=self.out_d,
            )
            if post_ffn > 0
            else None
        )

        self.reshape_output = Reshape((-1, d * n_filters))

    def call(self, x):
        """Gate, mix and contract ``x``, then apply the FFN if one was built."""
        # The transposes below swap axes 2 and 3, so x must be rank 4.
        # NOTE(review): presumably (batch, seq, d, order) -- confirm.
        x = self.dense_r(x)

        g1 = self.dense_r_gate1(x)
        g1 = tf.transpose(g1, perm=[0, 1, 3, 2])
        g1 = self.dense_l_gate1(g1)

        if self.dense_l_gate2 is None:
            g2 = g1
        else:
            g2 = self.dense_r_gate2(x)
            g2 = tf.transpose(g2, perm=[0, 1, 3, 2])
            g2 = self.dense_l_gate2(g2)

        g = tf.matmul(g1, g2, transpose_b=True)
        # Plain op instead of building a new Softmax() layer on every call.
        g = tf.nn.softmax(g, axis=-1)

        x = tf.transpose(x, perm=[0, 1, 3, 2])
        x = self.dense_l(x)

        # Gated mixing with a residual connection.
        x = tf.matmul(g, x) + x
        x = tf.transpose(x, perm=[0, 1, 3, 2])
        x = self.contraction(x)
        x = self.reshape_output(x)

        if self.ffn is not None:
            x = self.ffn(x)

        return x
