import torch.nn as nn
import bnn


def block(conv2d, C, kwargs):
    """Build one two-branch unit combined by ``bnn.Sum``.

    The first branch is an empty ``nn.Sequential``, which passes its input
    through unchanged. The second branch is two ``conv2d(C, C, 3, padding=1)``
    layers, each followed by ``nn.ReLU``. ``bnn.Sum`` presumably adds the
    branch outputs, giving a residual connection (TODO confirm against bnn).

    Args:
        conv2d: layer factory called as ``conv2d(C, C, 3, padding=1, **kwargs)``.
        C: channel count passed as both the first and second factory argument.
        kwargs: mapping unpacked into every ``conv2d`` call.

    Returns:
        The ``bnn.Sum`` module wrapping the two branches.
    """
    skip = nn.Sequential()

    # Built in call order (conv, relu, conv, relu) so conv2d is invoked in
    # the same sequence as before, keeping initialization reproducible.
    stages = []
    for _ in range(2):
        stages.append(conv2d(C, C, 3, padding=1, **kwargs))
        stages.append(nn.ReLU())
    residual = nn.Sequential(*stages)

    return bnn.Sum(skip, residual)


def net(ap_spec, in_shape, out_classes, kwargs=None, out_kwargs=None, channels=32):
    """Build a small convolutional classifier made of three two-branch units.

    Layout: stem conv + ReLU -> block -> 2x2 avg-pool -> block -> 2x2 avg-pool
    -> block -> adaptive avg-pool to (1, 1) -> ``bnn.Conv2d_2_FC`` -> linear.

    Args:
        ap_spec: object supplying the layer factories. ``lower_conv2d`` is used
            for every convolution and ``top_linear`` for the final layer.
        in_shape: input shape. Only ``in_shape[-3]`` is read, as the number of
            input channels (presumably a (..., C, H, W) layout; confirm).
        out_classes: output size passed to the final ``linear`` factory.
        kwargs: extra keyword arguments forwarded to every ``conv2d`` call.
            ``None`` (the default) means no extra arguments.
        out_kwargs: extra keyword arguments forwarded to the final ``linear``
            call. ``None`` (the default) means no extra arguments.
        channels: width ``C`` used by every convolution and as the input size
            of the final linear layer.

    Returns:
        An ``nn.Sequential`` holding the whole network.
    """
    # None sentinels instead of `{}` defaults: a dict default is created once
    # at definition time and shared by every call.
    kwargs = {} if kwargs is None else kwargs
    out_kwargs = {} if out_kwargs is None else out_kwargs

    C = channels
    in_channels = in_shape[-3]
    conv2d = ap_spec.lower_conv2d
    linear = ap_spec.top_linear

    return nn.Sequential(
        # Stem: lift the input channels to width C.
        conv2d(in_channels, C, 3, padding=1, **kwargs),
        nn.ReLU(),
        block(conv2d, C, kwargs),

        # Presumably halves spatial resolution (TODO confirm bnn semantics).
        bnn.AvgPool2d((2, 2)),
        block(conv2d, C, kwargs),

        bnn.AvgPool2d((2, 2)),
        block(conv2d, C, kwargs),

        # Global pooling, then presumably a reshape from (N, C, 1, 1) to
        # (N, C) so the C-input linear layer can consume it (TODO confirm).
        bnn.AdaptiveAvgPool2d((1, 1)),
        bnn.Conv2d_2_FC(),
        linear(C, out_classes, bias=True, **out_kwargs),
    )

