from neural_tangents import stax


def dense_relu(depth, *, width=1, W_std=2., b_std=0.05):
    """Build a fully-connected ReLU network with ``depth`` hidden layers.

    The input is flattened first, then ``depth`` repetitions of
    ``Dense -> Relu`` are stacked with ``stax.serial``.

    Args:
      depth: number of ``Dense``/``Relu`` pairs. ``0`` (or any non-positive
        value) yields a network that only flattens its input.
      width: output dimension of every hidden ``Dense`` layer. The default of
        1 preserves the previously hard-coded value. NOTE(review): a width of 1
        presumably means only the infinite-width ``kernel_fn`` is intended to
        be used; confirm before relying on the finite-width ``apply_fn``.
      W_std: weight standard deviation passed to each ``Dense`` layer.
      b_std: bias standard deviation passed to each ``Dense`` layer.

    Returns:
      Whatever ``stax.serial`` returns (in neural_tangents, an
      ``(init_fn, apply_fn, kernel_fn)`` triple).
    """
    hidden_block = [stax.Dense(width, W_std=W_std, b_std=b_std), stax.Relu()]
    # stax layers are tuples of functions and ``serial`` initialises each
    # position separately, so repeating the same layer objects via list
    # multiplication does not tie parameters across layers.
    return stax.serial(stax.Flatten(), *(hidden_block * depth))


def lenet5(width_factor, *, num_classes=10):
    """Build a LeNet-5-style convolutional network.

    Architecture, with every channel/unit count scaled by ``width_factor``:
    three 5x5 ``SAME``-padded convolutions with 6, 16 and 120 channels, each
    followed by ReLU. The first two convolutions are each followed by 2x2
    average pooling with stride 2. Then come flatten, a Dense layer of 84
    units with ReLU, and a final Dense layer of ``num_classes`` outputs.

    Args:
      width_factor: multiplier applied to every hidden channel/unit count.
        NOTE(review): presumably a positive int; it is used directly as a
        layer size.
      num_classes: size of the output layer. The default of 10 preserves the
        previously hard-coded value.

    Returns:
      Whatever ``stax.serial`` returns (in neural_tangents, an
      ``(init_fn, apply_fn, kernel_fn)`` triple).

    NOTE(review): the expected input layout follows ``stax.Conv``'s default
    ``dimension_numbers``; confirm against callers.
    """
    # Equivalent to the string 'SAME'; shared by all three convolutions.
    padding = stax.Padding.SAME.name
    return stax.serial(
        stax.Conv(6 * width_factor, (5, 5), padding=padding),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2)),
        stax.Conv(16 * width_factor, (5, 5), padding=padding),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2)),
        stax.Conv(120 * width_factor, (5, 5), padding=padding),
        stax.Relu(),
        stax.Flatten(),
        stax.Dense(84 * width_factor),
        stax.Relu(),
        stax.Dense(num_classes),
    )
