import jax
import jax.numpy as jnp

import flax.nnx as nnx


def _get_laplacian(adj: jax.Array):
    """Build the symmetric-normalized message-passing operator for ``adj``.

    Despite the name, the result is not a Laplacian: it is the adjacency
    matrix with self-loops added, scaled on both sides by the inverse square
    root of the row sums, i.e. ``D^{-1/2} (adj + I) D^{-1/2}``.

    Args:
        adj: Batched adjacency matrices, indexed as (batch, node, node).

    Returns:
        Array of shape (batch, node, node) holding the normalized operator.
    """
    num_nodes = adj.shape[-1]

    # Self-loops: every node also receives its own features.
    with_loops = adj + jnp.expand_dims(jnp.eye(num_nodes), axis=0)

    # Row sums play the role of the diagonal degree matrix.
    degree = jnp.sum(with_loops, axis=2)  # (b, n)
    inv_sqrt_degree = 1 / degree**0.5

    # Broadcasting replaces the two diagonal matrix products:
    # scale row i by d_i^{-1/2}, then column j by d_j^{-1/2}.
    row_scale = inv_sqrt_degree[..., None]  # (b, n, 1)
    col_scale = inv_sqrt_degree[:, None, :]  # (b, 1, n)
    return row_scale * with_loops * col_scale  # (b, n, n)


class GCNConv(nnx.Module):
    """Single graph-convolution layer computing ``mp @ (x @ W)``.

    ``W`` is a learned ``(input_dim, output_dim)`` matrix; ``mp`` is the
    message-passing operator supplied by the caller on every call.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        initializer: jax.nn.initializers.Initializer,
        key: jax.random.PRNGKey,
    ):
        """Create the layer and initialise its weight matrix.

        Args:
            input_dim: Size of the incoming node features.
            output_dim: Size of the produced node features.
            initializer: Callable invoked with ``key``, ``shape`` and ``dtype``
                to produce the initial weights.
            key: PRNG key handed to ``initializer``.
        """
        self.input_dim = input_dim
        self.output_dim = output_dim

        weight_shape = (self.input_dim, self.output_dim)
        initial_weights = initializer(key=key, shape=weight_shape, dtype=jnp.float32)
        self.W = nnx.Param(initial_weights)

    # NOTE: a sparse aggregation (segment_sum or similar) would scale better
    # than a dense matrix product, but for the small illustrative problems
    # we are interested in, the dense form is enough.
    def __call__(self, x: jax.Array, mp: jax.Array):
        """Apply the layer.

        Args:
            x: Node features, (b, n, d) - batch size, nodes, dimension.
            mp: Message-passing operator, (b, n, n) - batch size, nodes, nodes.

        Returns:
            Updated node features, (b, n, o).
        """
        # Project the features first, then aggregate over neighbours.
        projected = x @ self.W  # (b, n, o)
        return mp @ projected  # (b, n, o)


class RecurrentGCNConv(nnx.Module):
    """Graph convolution whose weight matrix is a recurrent rank-1 state.

    The effective weight is ``W = a @ b^T`` where ``a`` and ``b`` are
    per-sample column vectors. They start from the learned ``_a`` / ``_b``
    (see ``lazy_init``) and are updated on every call from the aggregated
    messages.

    ``lazy_init`` must be called before the first ``__call__``: until then
    ``a``, ``b``, ``w_a`` and ``w_b`` are only declared, not assigned.
    """

    a: nnx.Data[jax.Array | None]
    b: nnx.Data[jax.Array | None]

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        initializer: jax.nn.initializers.Initializer,
        key: jax.random.PRNGKey,
    ):
        """Create the learned parameters.

        Args:
            input_dim: Size of the incoming node features.
            output_dim: Size of the produced node features.
            initializer: Callable invoked with ``key``, ``shape`` and ``dtype``
                to produce initial values.
            key: PRNG key; it is split so every parameter gets its own key.
        """
        self.input_dim = input_dim
        self.output_dim = output_dim

        # One independent key per parameter. Reusing a single key would make
        # every parameter of the same shape start with identical values.
        k_a, k_b, k_la, k_ra, k_lb, k_rb = jax.random.split(key, 6)

        # We implement the same architecture used for the linear control, but simpler
        self._a = nnx.Param(
            initializer(key=k_a, shape=(self.input_dim, 1), dtype=jnp.float32)
        )
        self._b = nnx.Param(
            initializer(key=k_b, shape=(self.output_dim, 1), dtype=jnp.float32)
        )

        # Rank-1 factors of the update matrix for ``a`` (see ``lazy_init``).
        self.w_la = nnx.Param(
            initializer(key=k_la, shape=(self.output_dim, 1), dtype=jnp.float32)
        )
        self.w_ra = nnx.Param(
            initializer(key=k_ra, shape=(self.input_dim, 1), dtype=jnp.float32)
        )

        # Rank-1 factors of the update matrix for ``b`` (see ``lazy_init``).
        self.w_lb = nnx.Param(
            initializer(key=k_lb, shape=(self.output_dim, 1), dtype=jnp.float32)
        )
        self.w_rb = nnx.Param(
            initializer(key=k_rb, shape=(self.output_dim, 1), dtype=jnp.float32)
        )

        # Declarations only: these are assigned by ``lazy_init``.
        self.a: jax.Array
        self.b: jax.Array

        self.w_a: jax.Array
        self.w_b: jax.Array

    def __call__(self, x: jax.Array, lp: jax.Array):
        """Apply the layer and advance the recurrent state.

        Args:
            x: Node features, (b, n, i) - batch size, nodes, input dimension.
            lp: Message-passing operator, (b, n, n).

        Returns:
            Updated node features, (b, n, o). As a side effect ``self.a`` and
            ``self.b`` are replaced by their updated values.
        """
        W = self.a @ jnp.transpose(self.b, (0, 2, 1))  # (b, i, o)
        m = x @ W  # (b, n, o)
        y = lp @ m  # (b, n, o)

        # Update the underlying matrices
        h = m.sum(axis=1)  # (b, o)
        a = self.a + (h @ self.w_a)[..., None]  # (b, i, 1)
        b = self.b + (h @ self.w_b)[..., None]  # (b, o, 1)

        # Soft normalisation keeps the state bounded. The norm is taken per
        # sample (over the two non-batch axes) so that one sample's state
        # never depends on the other samples in the batch.
        a = a / (1 + jnp.linalg.norm(a, axis=(1, 2), keepdims=True))
        b = b / (1 + jnp.linalg.norm(b, axis=(1, 2), keepdims=True))

        self.a = nnx.data(a)
        self.b = nnx.data(b)

        # Return the state representation
        return y

    def lazy_init(self, batch_size: int):
        """Reset the recurrent state for a batch of ``batch_size`` samples.

        Broadcasts the learned initial vectors across the batch and builds
        the rank-1 update matrices from their factors.
        """
        a = jnp.broadcast_to(self._a[None, ...], (batch_size, self.input_dim, 1))
        b = jnp.broadcast_to(self._b[None, ...], (batch_size, self.output_dim, 1))

        w_a = jnp.outer(self.w_la, self.w_ra)  # (o, i)
        w_b = jnp.outer(self.w_lb, self.w_rb)  # (o, o)

        self.w_a = nnx.data(w_a)
        self.w_b = nnx.data(w_b)

        self.a = nnx.data(a)
        self.b = nnx.data(b)


class GCN(nnx.Module):
    """Stack of ``GCNConv`` layers followed by a per-node linear head.

    The model has ``num_layers`` graph convolutions in total: one mapping
    ``input_dim -> hidden_dim`` and ``num_layers - 1`` mapping
    ``hidden_dim -> hidden_dim``. A leaky ReLU follows every convolution.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        *,
        rngs: nnx.Rngs,
    ):
        """Build the layers.

        Args:
            input_dim: Size of the incoming node features.
            hidden_dim: Size of the hidden node features.
            output_dim: Size of the per-node output.
            num_layers: Total number of graph convolutions (at least 1).
            rngs: Random streams used to initialise every layer.

        Raises:
            ValueError: If ``num_layers`` is smaller than 1.
        """
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim

        initializer = jax.nn.initializers.lecun_uniform()

        # Draw a fresh key for every layer. Sharing one key would give all
        # hidden layers (which have the same shape) identical initial weights.
        self.gcn_in = GCNConv(
            self.input_dim, self.hidden_dim, initializer, key=rngs.params()
        )

        # Sequential is used purely as a container here; see __call__.
        self.layers = nnx.Sequential(
            *[
                GCNConv(
                    self.hidden_dim, self.hidden_dim, initializer, key=rngs.params()
                )
                for _ in range(num_layers - 1)
            ]
        )

        self.linear_out = nnx.Linear(self.hidden_dim, self.output_dim, rngs=rngs)

    def __call__(self, x: jax.Array, adj: jax.Array):
        """Run the model.

        Args:
            x: Node features, (b, n, input_dim).
            adj: Adjacency matrices, (b, n, n).

        Returns:
            Per-node outputs, (b, n, output_dim).
        """
        lp = _get_laplacian(adj)
        y = nnx.leaky_relu(self.gcn_in(x, lp))

        # Iterate the contained layers by hand: calling the Sequential would
        # pass ``lp`` to the first layer only, so any further GCNConv would be
        # invoked without its message-passing operator.
        for layer in self.layers.layers:
            y = nnx.leaky_relu(layer(y, lp))

        return self.linear_out(y)


class RecurrentGCN(nnx.Module):
    """Two graph convolutions followed by a per-node linear head.

    Both convolutions are currently plain ``GCNConv`` layers, each followed
    by a leaky ReLU.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, *, rngs: nnx.Rngs
    ):
        """Build the layers.

        Args:
            input_dim: Size of the incoming node features.
            hidden_dim: Size of the hidden node features.
            output_dim: Size of the per-node output.
            rngs: Random streams used to initialise every layer.
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        initializer = jax.nn.initializers.lecun_normal()

        # Draw a fresh key for every layer so their initial weights are
        # independent instead of being generated from one shared key.
        self.gcn_in = GCNConv(
            self.input_dim, self.hidden_dim, initializer, key=rngs.params()
        )

        self.gcn_hidden = GCNConv(
            self.hidden_dim, self.hidden_dim, initializer, key=rngs.params()
        )

        self.linear_out = nnx.Linear(self.hidden_dim, self.output_dim, rngs=rngs)

    def __call__(self, x: jax.Array, adj: jax.Array):
        """Run the model.

        Args:
            x: Node features, (b, n, input_dim).
            adj: Adjacency matrices, (b, n, n).

        Returns:
            Per-node outputs, (b, n, output_dim).
        """
        lp = _get_laplacian(adj)
        y = self.gcn_in(x, lp)
        y = nnx.leaky_relu(y)
        y = self.gcn_hidden(y, lp)
        y = nnx.leaky_relu(y)
        return self.linear_out(y)

    # NOTE: disabled because GCNConv has no lazy_init; this forwarding is only
    # meaningful if the layers above are swapped for RecurrentGCNConv.
    # def lazy_init(self, batch_size: int):
    #     self.gcn_in.lazy_init(batch_size)
    #     self.gcn_hidden.lazy_init(batch_size)


if __name__ == "__main__":
    # Smoke test: a single two-node graph with one directed edge (0 -> 1)
    # and one-hot node features, pushed through a small GCN.
    edges = jnp.array([[0, 1], [0, 0]])
    adj = jnp.expand_dims(edges, axis=0)  # add the batch axis -> (1, 2, 2)

    x = jnp.eye(2)[None, :]  # (1, 2, 2)

    gcn = GCN(
        input_dim=2,
        hidden_dim=32,
        output_dim=4,
        num_layers=2,
        rngs=nnx.Rngs(42),
    )

    y = gcn(x, adj)

    jax.debug.print("{}", y)
