import flax.nnx as nnx
import jax
import jax.numpy as jnp


def act(x: jax.Array, axis: int):
    """Map ``x`` to positive weights that sum to one along ``axis``.

    ``elu(x) + 1`` is strictly positive, so dividing by its sum over ``axis``
    yields a normalized, non-negative feature map.
    """
    shifted = jax.nn.elu(x) + 1
    total = jnp.sum(shifted, axis=axis, keepdims=True)
    return shifted / total


# Default delta-rule step size for SRWM and StateConditionalSRWM.
STEP_SIZE: int = 1


def generate_ergodic_rotation(key: jax.Array, N: int, epsilon: float = 0.1) -> jax.Array:
    """Return a random N x N rotation near the identity via the Cayley transform.

    A skew-symmetric generator ``A = epsilon * (M - M.T)`` with ``M ~ N(0, 1)``
    is mapped to ``R = (I - A)(I + A)^{-1}``. For skew-symmetric ``A`` the
    eigenvalues of ``I + A`` are ``1 + i*lambda`` (never zero), so the inverse
    always exists. The result is orthogonal with ``det(R) = 1``. ``epsilon``
    controls the distance from ``I``; ``epsilon=0`` gives exactly ``I``.

    Args:
        key: JAX PRNG key used to sample ``M``.
        N: matrix size.
        epsilon: scale of the skew-symmetric generator.

    Returns:
        (N, N) float32 rotation matrix.
    """
    M = jax.random.normal(key, (N, N), dtype=jnp.float32)
    A = epsilon * (M - M.T)
    # Match A's dtype so R stays float32 even when jax x64 mode is enabled.
    Id = jnp.eye(N, dtype=A.dtype)
    # I - A and I + A commute (both are polynomials in A), so
    # (I - A)(I + A)^{-1} == (I + A)^{-1}(I - A). Solving the linear system
    # avoids forming an explicit inverse, which is less accurate.
    return jnp.linalg.solve(Id + A, Id - A)


def _step(
    x: jax.Array,
    W_slow: jax.Array,
    b_slow: jax.Array,
    W: jax.Array,
    step_size: int,
    R: jax.Array,
    sign: int = 1,
    idx: int = 1,
):
    """Advance the fast weights by one delta-rule step and read them out.

    ``W`` is treated throughout as a per-sample (out, in) linear map applied as
    ``W @ k``. With ``v_bar = W @ act(k)`` and ``W' = W @ R`` on odd ``idx``
    (else ``W' = W``), the update is the delta rule

        W_new = W' + sign * step_size * beta * (v - v_bar) act(k)^T

    For ``sign=1`` this moves the value stored under ``act(k)`` toward ``v``.
    The output is the read-out ``W_new @ act(q)``.

    Args:
        x: (batch_size, input_dim) layer inputs.
        W_slow: (input_dim, 3 * hidden_dim + 1) slow projection producing
            q, k, v (hidden_dim each) plus one beta logit per sample.
        b_slow: (3 * hidden_dim + 1, 1) slow bias, transposed before adding.
        W: (batch_size, hidden_dim, hidden_dim) fast weights.
        step_size: scale of the delta-rule update.
        R: (hidden_dim, hidden_dim) matrix right-multiplied onto ``W`` when
            ``idx`` is odd.
        sign: multiplier applied to the update.
        idx: step counter; may be a Python int or a traced integer.

    Returns:
        ``(W_new, y_hat)``: updated fast weights of shape
        (batch_size, hidden_dim, hidden_dim) and the read-out of shape
        (batch_size, hidden_dim).
    """
    # Project x into q, k, v and a scalar write strength per sample.
    qkvb = x @ W_slow + b_slow.T  # (batch_size, 3 * hidden_dim + 1)
    q, k, v = jnp.split(qkvb[:, :-1], 3, axis=1)  # each (batch_size, hidden_dim)
    beta = jax.nn.sigmoid(qkvb[:, -1])  # (batch_size,), in (0, 1)

    norm_k = act(k, axis=1)
    norm_q = act(q, axis=1)

    # Value currently stored under k: v_bar = W @ k for every sample.
    v_bar = jnp.einsum("bij,bj->bi", W, norm_k)

    # Rotate on odd steps. NOTE(review): v_bar is read from the un-rotated W
    # while the correction is added to the rotated one; this ordering is kept
    # from the original implementation — confirm it is intended.
    W_rot = jax.lax.cond(idx % 2 == 1, lambda: W @ R, lambda: W)

    # Delta-rule correction (v - v_bar) k^T. The row index is the value
    # dimension, matching how W is applied as W @ k above.
    delta = jnp.einsum("b,bi,bj->bij", beta, v - v_bar, norm_k)
    W_new = W_rot + sign * step_size * delta

    # Read out with the same (out, in) convention: y_hat = W_new @ q.
    y_hat = jnp.einsum("bij,bj->bi", W_new, norm_q)
    return W_new, y_hat


class DeltaLayer(nnx.Module):
    """Delta-rule fast-weight layer whose fast weights ``W`` are module state.

    ``lazy_init(batch_size)`` must run before the first ``__call__``. It builds
    per-sample fast weights of shape (batch_size, hidden_dim, hidden_dim) from
    the low-rank factors ``_a @ _b.T`` and resets the step counter ``idx``.
    Each ``__call__`` advances ``W`` by one ``_step`` and increments ``idx``.
    """

    W: nnx.Data[jax.Array | None]
    idx: nnx.Data[int | None]

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        K: int,
        step_size: int,
        layer_norm: bool = False,
        initializer: jax.nn.initializers.Initializer = jax.nn.initializers.glorot_uniform(),
    ):
        """Create the slow parameters.

        Args:
            input_dim: feature size of the inputs to ``__call__``.
            hidden_dim: size of q/k/v and of each side of the fast weights.
            K: inner dimension of the initial fast-weight factors, so
                ``_a @ _b.T`` has rank at most K.
            step_size: delta-rule step size passed to ``_step``.
            layer_norm: if True, apply ``nnx.LayerNorm`` to the output.
            initializer: initializer used for every slow parameter.
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Fixed seed: initialization here does not consume any nnx.Rngs.
        self._key_idx = 42
        self.key = jax.random.key(self._key_idx)

        # Draw one independent subkey per tensor. Reusing a single key made
        # _a and _b (same shape and initializer) identical, so the initial
        # fast weights _a @ _b.T were always symmetric. It also correlated
        # W_slow, b_slow and the rotation matrix.
        k_rot, k_w_slow, k_b_slow, k_a, k_b = jax.random.split(self.key, 5)

        self.rotation_matrix = generate_ergodic_rotation(k_rot, N=self.hidden_dim)

        self.W_slow = nnx.Param(
            initializer(
                key=k_w_slow,
                shape=(input_dim, 3 * hidden_dim + 1),
                dtype=jnp.float32,
            ),
        )
        self.b_slow = nnx.Param(
            initializer(key=k_b_slow, shape=(3 * hidden_dim + 1, 1), dtype=jnp.float32),
        )

        # Low-rank factors of the initial fast weights (see lazy_init).
        self._a = nnx.Param(
            initializer(key=k_a, shape=(hidden_dim, K), dtype=jnp.float32),
        )
        self._b = nnx.Param(
            initializer(key=k_b, shape=(hidden_dim, K), dtype=jnp.float32),
        )

        self.step_size = step_size

        self._layer_norm = layer_norm
        if self._layer_norm:
            self.layer_norm = nnx.LayerNorm(hidden_dim, rngs=nnx.Rngs(1))

        # Assigned by lazy_init().
        self.W: jax.Array
        self.idx: int

    def __call__(self, x: jax.Array):
        """Run one fast-weight step on ``x`` and return the read-out.

        Args:
            x: (batch_size, input_dim) inputs. The batch is assumed to match
                the ``batch_size`` given to ``lazy_init``.

        Returns:
            (batch_size, hidden_dim) output, layer-normalized if enabled.
        """
        W, y_hat = _step(x, self.W_slow, self.b_slow, self.W, self.step_size, self.rotation_matrix, idx=self.idx)

        self.W = nnx.data(W)
        self.idx = nnx.data(self.idx + 1)

        if self._layer_norm:
            y_hat = self.layer_norm(y_hat)

        return y_hat

    def lazy_init(self, batch_size: int):
        """Reset the fast weights to ``_a @ _b.T`` for every sample and ``idx`` to 1."""
        W = jnp.broadcast_to(
            (self._a @ self._b.T)[None, ...],
            (batch_size, self.hidden_dim, self.hidden_dim),
        )  # (batch_size, hidden_dim, hidden_dim)
        self.W = nnx.data(W)
        self.key = jax.random.key(self._key_idx)
        # Mark as data, consistent with how __call__ updates the counter.
        self.idx = nnx.data(1)


class StateConditionalDeltaLayer(nnx.Module):
    """Delta-rule fast-weight layer whose initial fast weights depend on a state.

    ``lazy_init(state)`` must run before the first ``__call__``. It sets the
    per-sample fast weights to the outer product of ``state @ _a`` and
    ``state @ _b``. ``_step`` is called without ``idx``, so its default
    (``idx=1``) applies and the rotation happens on every call.
    """

    W: nnx.Data[jax.Array | None]

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        state_dim: int,
        step_size: int,
        initializer: jax.nn.initializers.Initializer = jax.nn.initializers.glorot_uniform(),
    ):
        """Create the slow parameters.

        Args:
            input_dim: feature size of the inputs to ``__call__``.
            hidden_dim: size of q/k/v and of each side of the fast weights.
            state_dim: feature size of the ``state`` passed to ``lazy_init``.
            step_size: delta-rule step size passed to ``_step``.
            initializer: initializer used for every slow parameter.
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Fixed seed 42, split into one independent subkey per tensor.
        # Reusing the same key everywhere made _a and _b identical, so the
        # initial fast weights were always a symmetric rank-1 matrix.
        k_w_slow, k_b_slow, k_a, k_b, k_rot = jax.random.split(jax.random.key(42), 5)

        self.W_slow = nnx.Param(
            initializer(
                key=k_w_slow,
                shape=(input_dim, 3 * hidden_dim + 1),
                dtype=jnp.float32,
            ),
        )
        self.b_slow = nnx.Param(
            initializer(key=k_b_slow, shape=(3 * hidden_dim + 1, 1), dtype=jnp.float32),
        )

        # Projections from the state to the two outer-product factors.
        self._a = nnx.Param(
            initializer(key=k_a, shape=(state_dim, hidden_dim), dtype=jnp.float32),
        )
        self._b = nnx.Param(
            initializer(key=k_b, shape=(state_dim, hidden_dim), dtype=jnp.float32),
        )

        # Fixed rotation; no gradient is meant to flow through it.
        self.rotation_matrix = jax.lax.stop_gradient(
            generate_ergodic_rotation(k_rot, N=self.hidden_dim)
        )

        self.step_size = step_size

        # Assigned by lazy_init().
        self.W: jax.Array

    def __call__(self, x: jax.Array):
        """Run one fast-weight step on ``x`` and return the read-out.

        Args:
            x: (batch_size, input_dim) inputs. The batch is assumed to match
                the state given to ``lazy_init``.

        Returns:
            (batch_size, hidden_dim) read-out.
        """
        W, y_hat = _step(x, self.W_slow, self.b_slow, self.W, self.step_size, sign=1, R=self.rotation_matrix)
        self.W = nnx.data(W)

        return y_hat

    def lazy_init(self, state: jax.Array):
        """Set the fast weights from ``state`` of shape (batch_size, state_dim)."""
        w_a = state @ self._a  # (batch_size, hidden_dim)
        w_b = state @ self._b  # (batch_size, hidden_dim)

        W = jnp.einsum("bi,bj->bij", w_a, w_b)  # (batch_size, hidden_dim, hidden_dim)

        self.W = nnx.data(W)


class SRWM(nnx.Module):
    """Feed-forward model with a DeltaLayer between dense projections.

    Pipeline: ``linear_in`` (input_dim -> latent_dim) -> leaky_relu ->
    ``delta_layer`` (latent_dim -> latent_dim) -> leaky_relu ->
    ``linear_after_delta`` (latent_dim -> hidden_dim) -> leaky_relu ->
    ``linear_out`` (hidden_dim -> output_dim).

    ``lazy_init(batch_size)`` must be called before the model is applied so
    that the delta layer's fast weights exist.
    """

    # NOTE(review): nothing in this class assigns ``self._step``; this
    # annotation looks like a leftover.
    _step: nnx.Data[jax.Array | None]

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        latent_dim: int = 16,
        K: int | None = None,
        layer_norm: bool = False,
        step_size: int = STEP_SIZE,
        *,
        rngs: nnx.Rngs,
    ):
        """Build the layers.

        Args:
            input_dim: size of the model input.
            hidden_dim: width of the dense layer after the delta layer.
            output_dim: size of the model output.
            latent_dim: width of the delta layer (its q/k/v and fast weights).
            K: inner dimension of the delta layer's initial fast-weight
                factors. ``None`` (or any falsy value such as 0) selects
                ``max(hidden_dim // 8, 1)``.
            layer_norm: apply LayerNorm to the delta layer output.
            step_size: delta-rule step size.
            rngs: random streams passed to the dense layers.
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        self.step_size = step_size
        # NOTE(review): the default K is derived from hidden_dim, although the
        # delta layer's fast weights are latent_dim x latent_dim — confirm.
        self.K = K or max(hidden_dim // 8, 1)

        # lecun_uniform is set explicitly for every dense layer. Flax's own
        # default kernel init is believed to be lecun_normal — TODO confirm.
        initializer = jax.nn.initializers.lecun_uniform()

        self.linear_in = nnx.Linear(input_dim, self.latent_dim, rngs=rngs, kernel_init=initializer)
        self.delta_layer = DeltaLayer(
            self.latent_dim,
            self.latent_dim,
            layer_norm=layer_norm,
            initializer=initializer,
            K=self.K,
            step_size=step_size,
        )
        self.linear_after_delta = nnx.Linear(self.latent_dim, self.hidden_dim, rngs=rngs, kernel_init=initializer)
        self.linear_out = nnx.Linear(self.hidden_dim, self.output_dim, rngs=rngs, kernel_init=initializer)

    def __call__(self, x: jax.Array):
        """Apply the full pipeline to ``x`` (batch_size, input_dim)."""
        y = self.linear_in(x)
        y = jax.nn.leaky_relu(y)
        y = self.delta_layer(y)
        y = jax.nn.leaky_relu(y)
        y = self.linear_after_delta(y)
        y = jax.nn.leaky_relu(y)
        y = self.linear_out(y)

        return y

    def lazy_init(self, batch_size: int):
        """Create the delta layer's fast weights for ``batch_size`` samples."""
        self.delta_layer.lazy_init(batch_size)


class StateConditionalSRWM(nnx.Module):
    """Dense -> state-conditioned delta layer -> dense network.

    Shapes: ``linear_in`` maps input_dim -> hidden_dim, the delta layer keeps
    hidden_dim, and ``linear_out`` maps hidden_dim -> output_dim. Call
    ``lazy_init(state)`` to build the delta layer's fast weights from a state
    batch of width input_dim before applying the model.
    """

    _step: nnx.Data[jax.Array | None]

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        step_size: int = STEP_SIZE,
        *,
        rngs: nnx.Rngs,
    ):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.step_size = step_size

        # Every projection uses lecun_uniform explicitly. Flax's default kernel
        # init is believed to be lecun_normal — TODO confirm.
        init = jax.nn.initializers.lecun_uniform()

        # The dense layers receive `rngs`; keep their construction order stable.
        self.linear_in = nnx.Linear(input_dim, hidden_dim, rngs=rngs, kernel_init=init)
        self.delta_layer = StateConditionalDeltaLayer(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim,
            state_dim=input_dim,
            step_size=step_size,
            initializer=init,
        )
        self.linear_out = nnx.Linear(hidden_dim, output_dim, rngs=rngs, kernel_init=init)

    def __call__(self, x: jax.Array):
        """Apply dense -> leaky_relu -> delta layer -> leaky_relu -> dense."""
        hidden = jax.nn.leaky_relu(self.linear_in(x))
        hidden = jax.nn.leaky_relu(self.delta_layer(hidden))
        return self.linear_out(hidden)

    def lazy_init(self, state: jax.Array):
        """Initialise the delta layer's fast weights from ``state``."""
        return self.delta_layer.lazy_init(state)
