import bnn
import torch.nn as nn


def net(ap_spec, in_shape, out_classes, kwargs_lower=None, kwargs_top=None,
        channels=64):
    """Build a small three-convolution classifier as an ``nn.Sequential``.

    Layer order: conv -> ReLU -> pool -> conv -> ReLU -> pool -> conv -> ReLU
    -> ``bnn.Conv2d_2_FC()`` -> linear head.

    Args:
        ap_spec: Object providing the layer factories ``lower_conv2d`` (called
            for all three convolutional layers) and ``top_linear`` (called for
            the final layer). What these factories build is defined by the
            caller's spec object.
        in_shape: ``(in_channels, H, W)`` describing a single input sample.
            ``H`` must equal ``W``.
        out_classes: Output size passed to ``ap_spec.top_linear``.
        kwargs_lower: Extra keyword arguments forwarded to every
            ``ap_spec.lower_conv2d`` call. ``None`` (default) forwards nothing.
        kwargs_top: Extra keyword arguments forwarded to
            ``ap_spec.top_linear``. ``None`` (default) forwards nothing.
        channels: Number of channels used by every convolutional layer.

    Returns:
        The assembled ``nn.Sequential`` model.

    Raises:
        AssertionError: If ``H != W`` (note: skipped when Python runs with -O).
    """
    # None sentinels instead of shared mutable ``{}`` defaults.
    kwargs_lower = {} if kwargs_lower is None else kwargs_lower
    kwargs_top = {} if kwargs_top is None else kwargs_top

    C = channels
    (in_channels, H, W) = in_shape
    assert H == W, f"expected a square input, got H={H}, W={W}"

    # Input width of the linear head. Assumes the 3x3 / padding=1 / stride=1
    # convolutions preserve spatial size and that each bnn.MaxPool2d(2) halves
    # H and W with floor rounding (two pools -> //4); presumably
    # bnn.Conv2d_2_FC flattens (C, H//4, W//4) into one vector -- confirm
    # against bnn.
    flat_features = C * (H // 4) * (W // 4)

    model = nn.Sequential(
        ap_spec.lower_conv2d(in_channels, C, 3, padding=1, stride=1, **kwargs_lower),
        nn.ReLU(),
        bnn.MaxPool2d(2),
        ap_spec.lower_conv2d(C, C, 3, padding=1, stride=1, **kwargs_lower),
        nn.ReLU(),
        bnn.MaxPool2d(2),
        ap_spec.lower_conv2d(C, C, 3, padding=1, stride=1, **kwargs_lower),
        nn.ReLU(),
        bnn.Conv2d_2_FC(),
        ap_spec.top_linear(flat_features, out_classes, bias=True, **kwargs_top),
    )
    return model
