from flax import linen as nn

# Kernel initializer mirroring torch.nn.Linear's weight init, Unif(-a, a) with
# a = 1/sqrt(fan_in). A uniform distribution on [-a, a] has variance a^2 / 3,
# so we ask for a fan-in-scaled variance of (1/3) / fan_in.
# Only the kernel is matched: layers using this pass no bias_init, so their
# biases keep the layer's own default initializer.
kernel_init = nn.initializers.variance_scaling(1 / 3, 'fan_in', 'uniform')


class MLP(nn.Module):
    """Fully connected classifier: ``depth`` hidden Dense layers, then a linear head.

    Each hidden layer is Dense(width) -> optional BatchNorm -> ReLU. The input
    is flattened to (batch, features) first, so any trailing shape is accepted.
    The output of the final Dense layer is returned as-is (no softmax).
    """

    depth: int  # number of hidden Dense layers
    width: int  # units in every hidden Dense layer
    num_classes: int  # output size of the final Dense layer
    no_bn: bool = False  # True disables BatchNorm in all hidden layers

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Run the network on a batch.

        Args:
          x: array whose leading axis is the batch; all remaining axes are
            flattened into a single feature axis.
          train: when True (the default, identical to the previous behavior),
            BatchNorm normalizes with the current batch's statistics and
            updates its running averages, so the ``batch_stats`` collection
            must be mutable in ``apply``. When False, BatchNorm normalizes with
            the stored running averages and mutates nothing -- use this for
            evaluation. Has no effect when ``no_bn`` is True.

        Returns:
          Array of shape (batch, num_classes) from the final Dense layer.
        """
        x = x.reshape((x.shape[0], -1))
        for _ in range(self.depth):
            x = nn.Dense(self.width, kernel_init=kernel_init)(x)
            if not self.no_bn:
                # Previously hard-coded to False, which made running-average
                # inference impossible; now driven by the `train` flag.
                x = nn.BatchNorm(axis=-1, use_running_average=not train)(x)
            x = nn.relu(x)
        return nn.Dense(self.num_classes, kernel_init=kernel_init)(x)
