import torch.nn as nn
import torch.nn.functional as F


class CBA(nn.Module):
    """Convolution, then optional batch norm, then optional activation.

    Layer types are taken from ``model_cfg``. ``model_cfg.conv`` is called
    with ``nn.Conv2d``-style arguments, ``model_cfg.bn`` with
    ``(width, momentum=0.01)``, and ``model_cfg.act`` with no arguments.

    Args:
        model_cfg: object exposing ``conv``, ``bn`` and ``act`` factories.
        in_width: number of input channels passed to the conv.
        width: number of output channels of the conv (and of the norm).
        ksize: kernel size; padding is ``ksize // 2``, which keeps the
            spatial size at stride 1 for odd kernel sizes (undilated conv).
        bn: if true, add ``self.batchnorm`` after the conv.
        act: if true, add ``self.act`` as the last layer.
        stride: convolution stride.
        depthwise: use ``groups=in_width``; requires ``in_width == width``.
        bias: whether the convolution has a bias term.

    Raises:
        ValueError: if ``depthwise`` is requested with ``in_width != width``.
    """

    def __init__(
        self, model_cfg, in_width, width, ksize, bn=True, act=True, stride=1,
        depthwise=False, bias=True
    ):
        super().__init__()
        # Explicit check instead of ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if depthwise and in_width != width:
            raise ValueError(
                "DWise conv needs equal width for in- and outputs,\n"
                f"but got {in_width} and {width}"
            )

        padding = ksize // 2
        self.conv = model_cfg.conv(
            in_width, width, ksize, stride=stride, padding=padding,
            padding_mode='zeros', groups=(in_width if depthwise else 1),
            bias=bias
        )
        # Optional layers are only registered when enabled; ``forward``
        # detects them with ``hasattr``.
        if bn:
            self.batchnorm = model_cfg.bn(width, momentum=0.01)
        if act:
            self.act = model_cfg.act()

    def forward(self, nx):
        """Apply the conv, then the norm and activation if configured."""
        nx = self.conv(nx)
        if hasattr(self, "batchnorm"):
            nx = self.batchnorm(nx)
        if hasattr(self, "act"):
            nx = self.act(nx)
        return nx


class ResBlock(nn.Module):
    """Two-conv residual block computing ``last_act(block(x) + skip(x))``.

    ``block`` is a CBA carrying the stride, followed by a CBA without
    activation. ``skip`` is the identity when the input already has the
    output width and no downsampling is requested. Otherwise it is a 1x1 CBA
    projection (same stride, no activation).

    Args:
        model_cfg: object exposing the ``conv``, ``bn`` and ``act`` factories
            used by ``CBA``; ``model_cfg.act()`` also builds ``last_act``.
        in_width: number of input channels.
        width: number of output channels.
        ksize: kernel size of the two main-path convolutions.
        stride: stride of the first main-path conv and of the projection.
    """

    def __init__(self, model_cfg, in_width, width, ksize=3, stride=1):
        super().__init__()

        self.block = nn.Sequential(
            CBA(model_cfg, in_width, width, ksize, stride=stride),
            CBA(model_cfg, width,    width, ksize, act=False),
        )
        if in_width == width and stride == 1:
            # ``nn.Identity`` rather than a lambda: it is registered as a
            # submodule and can be pickled (e.g. ``torch.save(model)``).
            self.skip = nn.Identity()
        else:
            self.skip = CBA(
                model_cfg, in_width, width, 1, stride=stride, act=False
            )
        self.last_act = model_cfg.act()

    def forward(self, nx):
        """Add the main path and the shortcut, then apply the activation."""
        main = self.block(nx)
        shortcut = self.skip(nx)
        nx = main + shortcut
        return self.last_act(nx)


class ResNet(nn.Module):
    """Stem CBA, stages of ``ResBlock``s, global average pool, linear head.

    Stage ``s`` (0-based) has ``n_size`` blocks of width
    ``base_width * 2 ** s``. The first block of every stage except the first
    downsamples with stride 2. With the defaults the widths are 16, 32 and
    64, and the main path has ``6 * n_size + 2`` weighted layers (1 stem
    conv, 2 convs per block, 1 fc).

    Args:
        model_cfg: object exposing ``conv``, ``bn``, ``act`` and ``fc``
            factories. Required; the ``None`` default only exists for
            signature compatibility.
        n_size: number of residual blocks per stage.
        n_classes: output features of the final ``model_cfg.fc`` layer.
        input_size: number of input channels (not a spatial size).
        base_width: channels of the stem and of the first stage.
        n_stages: number of stages; each stage doubles the width.

    Raises:
        TypeError: if ``model_cfg`` is ``None``.
    """

    def __init__(
        self, model_cfg=None, n_size=3, n_classes=10, input_size=3,
        *, base_width=16, n_stages=3
    ):
        super().__init__()
        # Fail early with a clear message instead of an AttributeError from
        # deep inside CBA.
        if model_cfg is None:
            raise TypeError(
                "ResNet requires a model_cfg providing conv, bn, act and fc "
                "factories"
            )

        self.start_block = CBA(model_cfg, input_size, base_width, 3)
        res_blocks = []
        prev_width = base_width
        for stage in range(n_stages):
            width = base_width * 2 ** stage
            for j in range(n_size):
                # Downsample only on entry to every stage after the first.
                stride = 2 if stage > 0 and j == 0 else 1
                res_blocks.append(
                    ResBlock(model_cfg, prev_width, width, 3, stride=stride)
                )
                prev_width = width
        self.res_blocks = nn.Sequential(*res_blocks)
        # ``prev_width`` is the width of the last block, or of the stem when
        # no blocks were built.
        self.fc = model_cfg.fc(prev_width, n_classes, bias=False)

    def forward(self, nx):
        """Return the ``fc`` outputs (``n_classes`` per sample) for ``nx``.

        NOTE(review): assumes ``nx`` is an (N, C, H, W) image batch with
        ``C == input_size``; confirm against callers.
        """
        nx = self.start_block(nx)
        nx = self.res_blocks(nx)
        # Global average pool to (N, C, 1, 1), then drop the two unit dims.
        nx = F.adaptive_avg_pool2d(nx, (1, 1)).squeeze_(-1).squeeze_(-1)
        nx = self.fc(nx)
        return nx
