from typing import Callable

import flax.nnx as nnx
import jax
import jax.numpy as jnp


class Perceptron(nnx.Module):
    """Two-layer feed-forward network built from two ``nnx.Linear`` layers.

    The activation passed as ``act`` is stored on the module but is currently
    not applied in ``__call__``; the forward pass is therefore the plain
    composition ``linear_out(linear_in(x))``.

    Args:
        input_dim: Number of input features of ``linear_in``.
        hidden_dim: Number of output features of ``linear_in`` (and input
            features of ``linear_out``).
        output_dim: Number of output features of ``linear_out``.
        act: Activation function kept on the module (not applied at present).
        rngs: NNX random streams used to initialise both linear layers.
    """

    act: Callable[[jax.Array], jax.Array]

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        act: Callable[[jax.Array], jax.Array] = nnx.leaky_relu,
        *,
        rngs: nnx.Rngs,
    ):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Both kernels share the same initialisation scheme; each layer draws
        # its own key from ``rngs``.
        initializer = nnx.initializers.lecun_uniform()

        self.linear_in = nnx.Linear(input_dim, hidden_dim, rngs=rngs, kernel_init=initializer)
        self.linear_out = nnx.Linear(hidden_dim, output_dim, rngs=rngs, kernel_init=initializer)

        self.act = act

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply ``linear_in`` then ``linear_out`` to ``x`` and return the result."""
        y = self.linear_in(x)
        # NOTE: the non-linearity is currently disabled; re-enable it with
        # ``y = self.act(y)`` to obtain a non-linear hidden layer.
        return self.linear_out(y)


class LinearControl(nnx.Module):
    """Linear layer followed by a rank-one readout whose factor ``a`` evolves per call.

    The readout matrix is ``w = b @ a^T``. On every call the factor ``a`` is
    updated as ``a <- a + w_a @ x`` and stored back on the module, so the
    module is stateful across calls. ``lazy_init`` must be called before the
    first forward pass: it creates the per-batch state ``a``, ``b`` and
    ``w_a`` from the initial values ``_a``, ``_b`` and the factors
    ``w_la``/``w_ra``.

    The activation passed as ``act`` is stored but currently not applied.
    A matching update rule for ``b`` (with factors ``w_lb``/``w_rb``) is not
    implemented; ``b`` stays at its broadcast initial value.

    Args:
        input_dim: Number of input features.
        hidden_dim: Number of output features of ``linear_in``.
        output_dim: Number of rows of the readout matrix.
        act: Activation function kept on the module (not applied at present).
        rngs: NNX random streams used for all initialisations.
    """

    a: nnx.Data[jax.Array | None]
    b: nnx.Data[jax.Array | None]

    act: Callable[[jax.Array], jax.Array]

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        act: Callable[[jax.Array], jax.Array] = nnx.leaky_relu,
        *,
        rngs: nnx.Rngs,
    ):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.act = act

        initializer = nnx.initializers.lecun_uniform()

        self.linear_in = nnx.Linear(input_dim, hidden_dim, rngs=rngs, kernel_init=initializer)

        def draw(shape: tuple[int, ...]) -> jax.Array:
            # Every tensor gets its own key. Re-using a single key would make
            # equally shaped tensors (``_a`` and ``w_la``) start out identical.
            return initializer(key=rngs.params(), shape=shape, dtype=jnp.float32)

        # Initial values of the readout factors; ``_a`` is a trainable
        # parameter while ``_b`` is stored as plain data.
        self._a = nnx.Param(draw((hidden_dim, 1)))
        self._b = nnx.data(draw((output_dim, 1)))

        # Left/right factors of the rank-one matrix that drives the update of ``a``.
        self.w_la = nnx.Param(draw((hidden_dim, 1)))
        self.w_ra = nnx.Param(draw((input_dim, 1)))

        # Type declarations only: these attributes are created by ``lazy_init``.
        self.a: jax.Array
        self.b: jax.Array

        self.w_a: jax.Array

    def __call__(self, x: jax.Array):
        """Compute the output for ``x`` and advance the internal factor ``a``.

        Requires ``lazy_init`` to have been called; ``x`` is expected to be
        ``(batch_size, input_dim)`` with the same ``batch_size``.

        Returns:
            Array of shape ``(batch_size, output_dim)``.
        """
        # Readout matrix from the current factors.
        w = self.b @ jnp.transpose(self.a, (0, 2, 1))  # (batch_size, output_dim, hidden_dim)

        # Output logits. The activation is currently disabled
        # (``y = self.act(y)`` would go after ``linear_in``).
        y = self.linear_in(x)  # (batch_size, hidden_dim)

        y = w @ y[..., None]  # (batch_size, output_dim, 1)
        y = y.squeeze(axis=2)

        # Update rule for the readout factor: a <- a + w_a @ x.
        # The output above was computed with the factor *before* this update.
        a = self.a + self.w_a @ x[..., None]  # (batch_size, hidden_dim, 1)

        self.a = nnx.data(a)

        return y

    def lazy_init(self, batch_size: int):
        """Create the per-batch state ``a``, ``b`` and ``w_a`` for ``batch_size`` samples.

        ``a`` and ``b`` are the initial values ``_a``/``_b`` broadcast along a
        new leading batch axis; ``w_a`` is the outer product of ``w_la`` and
        ``w_ra`` broadcast the same way.
        """
        a = jnp.broadcast_to(self._a[None, ...], (batch_size, self.hidden_dim, 1))
        b = jnp.broadcast_to(self._b[None, ...], (batch_size, self.output_dim, 1))

        w_a = jnp.broadcast_to(
            jnp.outer(self.w_la, self.w_ra)[None, ...],
            (batch_size, self.hidden_dim, self.input_dim),
        )

        self.a = nnx.data(a)
        self.b = nnx.data(b)
        self.w_a = nnx.data(w_a)


class StateConditionalLC(nnx.Module):
    """Rank-one hidden layer whose factors are derived from a conditioning state.

    ``lazy_init(state)`` computes the factors ``_a = state @ W_a`` and
    ``_b = state @ W_b`` and must be called before the first forward pass.
    Each call builds the per-sample matrix ``W = _b _a^T``, applies it to the
    input, passes the result through ``leaky_relu`` and ``linear_out``, and
    then updates ``_a`` with a sigmoid-gated step. The module is therefore
    stateful across calls.

    Args:
        input_dim: Number of input features (also the width of ``_a``).
        hidden_dim: Width of ``_b`` and number of input features of ``linear_out``.
        output_dim: Number of output features of ``linear_out``.
        rngs: NNX random streams used for all initialisations.
    """

    _a: nnx.Data[jax.Array | None]
    _b: nnx.Data[jax.Array | None]

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, *, rngs: nnx.Rngs):
        initializer = nnx.initializers.lecun_uniform()

        def draw(shape: tuple[int, ...]) -> jax.Array:
            # Every tensor gets its own key. Re-using a single key would make
            # the equally shaped ``W_step``, ``w_la`` and ``w_ra`` identical.
            return initializer(key=rngs.params(), shape=shape, dtype=jnp.float32)

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Projections from the conditioning state to the two rank-one factors.
        self.W_a = nnx.Param(draw((input_dim, input_dim)))
        self.W_b = nnx.Param(draw((input_dim, hidden_dim)))

        # Projection producing the gate logit ``beta`` in ``__call__``.
        self.W_step = nnx.Param(draw((input_dim, 1)))

        # NOTE(review): ``w_la`` is stored as a plain array, not an
        # ``nnx.Param`` like ``w_ra`` - confirm whether it is meant to be trainable.
        self.w_la = draw((input_dim, 1))

        self.w_ra = nnx.Param(draw((input_dim, 1)))

        # Type declarations only: these attributes are created by ``lazy_init``.
        self._a: jax.Array
        self._b: jax.Array

        self.linear_out = nnx.Linear(hidden_dim, output_dim, kernel_init=initializer, rngs=rngs)

    def __call__(self, x: jax.Array):
        """Compute the output for ``x`` and update the factor ``_a``.

        Requires ``lazy_init`` to have been called. ``x`` is expected to be
        ``(B, input_dim)`` with the same batch size as the state - TODO confirm.

        Returns:
            The result of ``linear_out`` applied to the activated hidden vector.
        """
        # Per-sample rank-one matrix: outer product of ``_b`` and ``_a``.
        W = self._b[..., None] @ self._a[:, None, :]

        y = jnp.einsum("bji,bi->bj", W, x)  # (B, H)
        y = nnx.leaky_relu(y)

        # Gate logit; squashed with a sigmoid below.
        beta = x @ self.W_step

        # Update a: a <- a - sigmoid(beta) * (w_la * (x @ w_ra)).
        # The output above was computed with the factor *before* this update.
        step = self.w_la.squeeze(axis=1) * (x @ self.w_ra)
        a = self._a - nnx.sigmoid(beta) * step
        self._a = nnx.data(a)

        return self.linear_out(y)

    def lazy_init(self, state: jax.Array):
        """Derive the factors ``_a`` and ``_b`` from ``state`` and store them.

        ``state`` must have ``input_dim`` as its last axis so that it can be
        multiplied with ``W_a`` and ``W_b``.
        """
        a = state @ self.W_a
        b = state @ self.W_b

        self._a = nnx.data(a)
        self._b = nnx.data(b)
