import jax
import jax.numpy as jnp
import flax.nnx as nnx


def act(x: jax.Array):
    y = jax.nn.elu(x) + 1
    y = y / y.sum(axis=1, keepdims=True)
    return y


def _step(
    W_qkv: jax.Array,
    W_beta: jax.Array,
    W: jax.Array,
    W_in: jax.Array,
    W_out: jax.Array,
    z_norm: jax.Array,
    x: jax.Array,
):
    # This has been shown to be an implementation of linear transformers with
    # attention normalization
    eps = 1e-6

    x = x @ W_in
    # W_qk: (input_dim, 3 * hidden_dim)
    # W: (batch_size, hidden_dim, hidden_dim)
    # x: (batch_size, input_dim)
    qkv = x @ W_qkv  # (batch_size, 3 * hidden_dim)
    q, k, v = jnp.split(
        qkv, 3, axis=1
    )  # (batch_size, hidden_dim), (batch_size, hidden_dim)

    k_norm = act(k)

    v_bar = jnp.einsum("bij,bj->bi", W, k_norm)  # (batch_size, hidden_dim)

    z_dot_k = (z_norm.squeeze(axis=2) * k_norm).sum(axis=1, keepdims=True)
    v_bar = v_bar / (z_dot_k + eps)

    beta = nnx.sigmoid(x @ W_beta)  # (batch_size, 1)

    v_new = beta * v + (1 - beta) * v_bar  # (batch_size, hidden_dim)

    # Expand the dimensions of v_new and k_norm to match the shape of W
    v_new = jnp.expand_dims(v_new, axis=1)  # (batch_size, 1, hidden_dim)
    v_bar = jnp.expand_dims(v_bar, axis=1)  # (batch_size, 1, hidden_dim)
    k_norm = jnp.expand_dims(k_norm, axis=2)  # (batch_size, hidden_dim, 1)

    W = W + jnp.matmul(v_new, k_norm) - jnp.matmul(v_bar, k_norm)

    # Update the z_norm
    z_norm = z_norm + k_norm  # (batch_size, hidden_dim, 1)

    q_norm = act(q)
    z_dot_q = (z_norm.squeeze(axis=2) * q_norm).sum(axis=1, keepdims=True)

    y = jnp.einsum("bij,bj->bi", W, q_norm)  # (batch_size, hidden_dim)
    y = y / (z_dot_q + eps)  # (batch_size, hidden_dim)

    # (batch_size, output_dim)
    y = y @ W_out

    return W, z_norm, y


class FWP(nnx.Module):
    """Recurrent fast-weight layer that advances per-sample state on each call.

    Trainable parameters are the input/output projections, the query/key/value
    projection, the write-gate projection, and the learned initial state
    ``_W`` / ``_z_norm``. The per-sample state ``W`` / ``z_norm`` is created
    by ``lazy_init`` and overwritten by every ``__call__``.

    Args:
        input_dim: Feature size of the inputs passed to ``__call__``.
        hidden_dim: Size of queries/keys/values and of the fast-weight matrix.
        output_dim: Feature size of the returned outputs.
        rngs: Optional ``nnx.Rngs`` used to draw initialisation keys. When
            omitted, the fixed seeds 42, 43, 44, 45 are used (as before), so
            every instance starts from identical weights.
    """

    W: nnx.Data[jax.Array | None]
    z_norm: nnx.Data[jax.Array | None]

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        *,
        rngs: nnx.Rngs | None = None,
    ):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        initializer = jax.nn.initializers.glorot_uniform()

        def _glorot_param(seed: int, shape: tuple[int, ...]) -> nnx.Param:
            # Keep the historical fixed seed when no Rngs is supplied so
            # existing results stay reproducible; otherwise draw a fresh key.
            key = jax.random.key(seed) if rngs is None else rngs.params()
            return nnx.Param(initializer(key=key, shape=shape, dtype=jnp.float32))

        self.W_in = _glorot_param(42, (input_dim, hidden_dim))
        self.W_out = _glorot_param(43, (hidden_dim, output_dim))
        self.W_qkv = _glorot_param(44, (hidden_dim, 3 * hidden_dim))
        self.W_beta = _glorot_param(45, (hidden_dim, 1))

        # Learned initial state shared by all samples; broadcast to the batch
        # in lazy_init.
        self._W = nnx.Param(jnp.zeros((hidden_dim, hidden_dim), dtype=jnp.float32))
        self._z_norm = nnx.Param(jnp.ones((hidden_dim, 1), dtype=jnp.float32))

        # Per-sample recurrent state; None until lazy_init is called.
        self.W = None
        self.z_norm = None

    def lazy_init(self, batch_size: int):
        """(Re)initialise the per-sample state for ``batch_size`` samples.

        Broadcasts the learned initial state to
        ``W: (batch_size, hidden_dim, hidden_dim)`` and
        ``z_norm: (batch_size, hidden_dim, 1)``, discarding any previous state.
        Must be called before ``__call__`` (presumably also at the start of
        each new sequence — confirm against callers).
        """
        W = jnp.broadcast_to(
            self._W.value[None, ...],
            (batch_size, self.hidden_dim, self.hidden_dim),
        )
        z_norm = jnp.broadcast_to(
            self._z_norm.value[None, ...],
            (batch_size, self.hidden_dim, 1),
        )
        self.W = nnx.data(W)
        self.z_norm = nnx.data(z_norm)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Run one recurrent step and update the stored state.

        Args:
            x: ``(batch_size, input_dim)`` input; ``batch_size`` must match the
                value passed to ``lazy_init``.

        Returns:
            ``(batch_size, output_dim)`` output. ``self.W`` and
            ``self.z_norm`` are replaced with their updated values.

        Raises:
            RuntimeError: If ``lazy_init`` has not been called yet.
        """
        if self.W is None or self.z_norm is None:
            raise RuntimeError(
                "FWP state is not initialised; call lazy_init(batch_size) first."
            )
        # Unwrap Params with .value, consistent with lazy_init.
        W, z_norm, y = _step(
            self.W_qkv.value,
            self.W_beta.value,
            self.W,
            self.W_in.value,
            self.W_out.value,
            self.z_norm,
            x,
        )
        self.W = nnx.data(W)
        self.z_norm = nnx.data(z_norm)

        return y
