import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    """Flax module that applies a single 2-D convolution.

    NOTE(review): despite its name this module contains no dense layers; it is
    one ``nn.Conv``. The name is kept so existing callers keep working.

    Attributes:
        features: Number of output channels produced by the convolution.
        kernel_size: Spatial size of the convolution window.
    """

    # Defaults reproduce the original hard-coded ``nn.Conv(16, (3, 3))`` so
    # ``MLP()`` builds exactly the same layer (and parameter tree) as before.
    features: int = 16
    kernel_size: tuple[int, ...] = (3, 3)

    @nn.compact
    def __call__(self, x):
        """Apply the convolution to ``x``.

        Flax's ``nn.Conv`` treats the last axis of ``x`` as the feature
        (channel) axis. With a 2-D kernel, ``x`` is therefore expected to be
        channels-last, e.g. ``(batch, height, width, channels)``.

        Args:
            x: Input array.

        Returns:
            The convolved array, whose last axis has ``self.features``
            channels. ``nn.Conv`` defaults to ``"SAME"`` padding and stride 1,
            so the spatial dimensions are preserved.
        """
        return nn.Conv(self.features, self.kernel_size)(x)

# Smoke run: build the module, initialise its variables, and run one forward
# pass on a dummy all-zeros batch.
key = jax.random.PRNGKey(7)
# 4-D input for the 2-D conv in MLP; nn.Conv reads it channels-last,
# i.e. (batch, height, width, channels).
x = jnp.zeros((16, 16, 16, 16))
mlp = MLP()
# init runs the module once on `x` to create its variables (a dict such as
# {"params": ...}), which apply() consumes below.
params = mlp.init(key, x)
# Keep the output and wait for it. JAX dispatches work asynchronously, so
# without blocking, the forward pass may not have finished (and device-side
# errors may not have surfaced) before the script exits.
y = mlp.apply(params, x)
y.block_until_ready()
