import torch.nn as nn
import torch.nn.functional as F


class BAC(nn.Module):
    """Pre-activation conv unit: optional BN, optional activation, then a conv.

    Layers are built from ``model_cfg`` factories and applied in the order
    ``model_cfg.bn(in_width)`` -> ``model_cfg.act()`` -> ``model_cfg.conv(...)``;
    the first two are skipped when ``bn`` / ``act`` are False.

    Args:
        model_cfg: Object providing ``bn``, ``act`` and ``conv`` layer factories.
        in_width: Input width; also the width given to ``model_cfg.bn`` and the
            group count when ``depthwise`` is set.
        width: Output width of the convolution.
        ksize: Convolution kernel size. Padding is ``ksize // 2`` with zero
            padding mode (for a standard conv this keeps the spatial size at
            stride 1 only when ``ksize`` is odd).
        bn: Whether to prepend ``model_cfg.bn(in_width)``.
        act: Whether to prepend ``model_cfg.act()``.
        stride: Convolution stride.
        depthwise: Build the conv with ``groups=in_width``; requires
            ``in_width == width``.
        bias: Whether the convolution has a bias term.

    Raises:
        ValueError: If ``depthwise`` is set and ``in_width != width``.
    """

    def __init__(
        self, model_cfg, in_width, width, ksize, bn=True, act=True, stride=1,
        depthwise=False, bias=True
    ):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if depthwise and in_width != width:
            raise ValueError(
                "Depthwise conv needs equal input and output widths, "
                f"but got {in_width} and {width}"
            )

        layers = []
        if bn:
            layers.append(model_cfg.bn(in_width))
        if act:
            layers.append(model_cfg.act())

        layers.append(model_cfg.conv(
            in_width, width, ksize, stride=stride, padding=ksize // 2,
            padding_mode='zeros', groups=(in_width if depthwise else 1),
            bias=bias,
        ))

        # Attribute name kept as ``cba`` so state_dict keys stay unchanged.
        self.cba = nn.Sequential(*layers)

    def forward(self, nx):
        """Apply the configured (BN) -> (activation) -> conv sequence to ``nx``."""
        return self.cba(nx)


class ResBlock(nn.Module):
    """Residual block: two ``BAC`` units on the main path plus a shortcut.

    Main path: ``BAC(in_width -> width, stride)`` then ``BAC(width -> width)``.
    Shortcut: the identity when the input already matches the output (same
    width and stride 1); otherwise a 1x1 ``BAC`` with no BN/activation that
    maps ``in_width`` to ``width`` with the same stride. The output is the
    element-wise sum of the two paths.

    Args:
        model_cfg: Layer-factory object passed through to ``BAC``.
        in_width: Input width.
        width: Output width.
        ksize: Kernel size of both main-path convolutions.
        stride: Stride of the first main-path conv and of the projection.
    """

    def __init__(self, model_cfg, in_width, width, ksize, stride=1):
        super().__init__()

        self.block = nn.Sequential(
            BAC(model_cfg, in_width, width, ksize, stride=stride),
            BAC(model_cfg, width,    width, ksize),
        )
        # nn.Identity instead of a lambda: a lambda attribute cannot be pickled
        # (breaks torch.save(model) / multiprocessing) and is not registered as
        # a submodule. nn.Identity has no parameters, so state_dict is unchanged.
        if in_width == width and stride == 1:
            self.skip = nn.Identity()
        else:
            self.skip = BAC(
                model_cfg, in_width, width, 1, stride=stride, bn=False, act=False
            )

    def forward(self, nx):
        """Return ``block(nx) + skip(nx)``."""
        main = self.block(nx)
        shortcut = self.skip(nx)
        return main + shortcut


class ResNetv2(nn.Module):
    """Residual network built from three stages of ``ResBlock``s.

    Layout:
      * stem: a 3x3 ``BAC`` conv without BN/activation, ``input_size`` ->
        ``int(16 * scale)`` width;
      * three stages of ``n_size`` ``ResBlock``s with widths
        ``int(16 * scale)``, ``int(32 * scale)``, ``int(64 * scale)``; the
        first block of the second and third stages uses stride 2;
      * global average pooling over the last two dims, then a bias-free
        ``model_cfg.fc`` producing ``n_classes`` outputs.

    NOTE(review): nothing normalizes/activates the features between the last
    residual block and pooling; pre-activation ResNets usually apply a final
    BN + activation there. Confirm this omission is intentional.

    Args:
        model_cfg: Layer-factory object; used directly for ``model_cfg.fc``
            and passed to every ``BAC``/``ResBlock``. Required in practice:
            the ``None`` default fails as soon as layers are built.
        n_size: Number of ``ResBlock``s per stage.
        n_classes: Output size of the final fully connected layer.
        input_size: Width (channel count) of the network input.
        scale: Multiplier applied to the base stage widths (truncated by
            ``int``).
    """

    def __init__(
        self, model_cfg=None, n_size=3, n_classes=10, input_size=3, scale=1
    ):
        super().__init__()

        stage_widths = [int(base * scale) for base in (16, 32, 64)]

        self.start_block = BAC(
            model_cfg, input_size, stage_widths[0], 3, bn=False, act=False
        )

        blocks = []
        cur_width = stage_widths[0]
        for stage_idx, out_width in enumerate(stage_widths):
            for block_idx in range(n_size):
                # Downsample only on entry to the second and third stages.
                downsample = stage_idx > 0 and block_idx == 0
                blocks.append(
                    ResBlock(
                        model_cfg, cur_width, out_width, 3,
                        stride=2 if downsample else 1,
                    )
                )
                cur_width = out_width
        self.res_blocks = nn.Sequential(*blocks)
        self.fc = model_cfg.fc(cur_width, n_classes, bias=False)

    def forward(self, nx):
        """Run stem, residual stages, global average pooling and the classifier."""
        features = self.res_blocks(self.start_block(nx))
        pooled = F.adaptive_avg_pool2d(features, (1, 1))
        # Drop the two trailing size-1 spatial dims left by the pooling.
        return self.fc(pooled[..., 0, 0])
