import torch.nn as nn
import torch.nn.functional as F
from .resnet import CBA


class ResBlock(nn.Module):
    """Two-stage residual block whose output is the halved sum of both paths.

    Main path: ``CBA(ksize, stride)`` followed by ``CBA(ksize, act=False)``.
    Shortcut: the identity when ``in_width == width`` and ``stride == 1``,
    otherwise a ``CBA`` with ``ksize=1`` and the same stride, built with
    ``bn=False, act=False``. The two paths are added, multiplied by 0.5 and
    passed through the activation built by ``model_cfg.act()``.

    Args:
        model_cfg: Configuration object forwarded to every ``CBA``; its
            ``act`` attribute is called once to build the final activation.
        in_width: Input width of the block (presumably the channel count).
        width: Output width of the block.
        ksize: Kernel size passed to both main-path ``CBA`` layers.
        stride: Stride of the first main-path ``CBA`` and of the projection
            shortcut.
    """

    def __init__(self, model_cfg, in_width, width, ksize, stride=1):
        super().__init__()

        self.block = nn.Sequential(
            CBA(model_cfg, in_width, width, ksize, stride=stride),
            CBA(model_cfg, width, width, ksize, act=False),
        )
        if in_width == width and stride == 1:
            # Use a registered nn.Identity rather than a lambda: a lambda
            # attribute makes the whole model unpicklable (e.g. for
            # torch.save(model) or multiprocessing). Identity has no
            # parameters, so the state_dict is unaffected.
            self.skip = nn.Identity()
        else:
            # Projection shortcut so the shortcut matches the main path's
            # width and stride.
            self.skip = CBA(model_cfg, in_width, width, 1, stride=stride,
                            bn=False, act=False)
        self.last_act = model_cfg.act()

    def forward(self, nx):
        """Return ``last_act(0.5 * (block(nx) + skip(nx)))``."""
        main = self.block(nx)
        shortcut = self.skip(nx)
        # ``main + shortcut`` allocates a new tensor, so halving it in place
        # does not modify either branch's output.
        nx = main + shortcut
        return self.last_act(nx.mul_(0.5))


class ResNetSimple(nn.Module):
    """Three-stage residual network with global average pooling and a linear head.

    Layout:
        * stem: ``CBA(model_cfg, input_size, int(16 * scale), 3)``;
        * stage ``k`` (k = 0, 1, 2): ``n_size`` ``ResBlock`` modules of width
          ``int(16 * 2**k * scale)`` and kernel size 3; the first block of
          stages 1 and 2 uses stride 2, every other block uses stride 1;
        * head: adaptive average pooling to 1x1, then
          ``model_cfg.fc(last_width, n_classes, bias=False)``.

    Args:
        model_cfg: Configuration object passed to every ``CBA``/``ResBlock``
            and used to build the classifier via ``model_cfg.fc``. It
            defaults to ``None``, but ``model_cfg.fc`` is called
            unconditionally, so a real config must be supplied.
        n_size: Number of residual blocks per stage.
        n_classes: Output size of the classifier.
        input_size: Input width of the stem ``CBA`` (presumably the number of
            image channels).
        scale: Multiplier applied to every width; results are truncated with
            ``int()``.
    """

    def __init__(
        self, model_cfg=None, n_size=3, n_classes=10, input_size=3, scale=1
    ):
        super().__init__()

        base_width = 16
        in_w = int(base_width * scale)
        # Modules are created stem -> blocks -> head; this order fixes the
        # order in which parameters are created (and initialized), so keep it.
        self.start_block = CBA(model_cfg, input_size, in_w, 3)

        blocks = []
        for stage in range(3):
            out_w = int(base_width * 2 ** stage * scale)
            for idx in range(n_size):
                # Only the first block of every stage after the first one
                # downsamples.
                downsample = stage > 0 and idx == 0
                blocks.append(
                    ResBlock(model_cfg, in_w, out_w, 3,
                             stride=2 if downsample else 1)
                )
                in_w = out_w
        self.res_blocks = nn.Sequential(*blocks)

        # ``in_w`` is the width of the last block (or of the stem when
        # ``n_size`` is 0).
        self.fc = model_cfg.fc(in_w, n_classes, bias=False)

    def forward(self, nx):
        """Run stem, residual stages, global average pooling and classifier."""
        features = self.res_blocks(self.start_block(nx))
        pooled = F.adaptive_avg_pool2d(features, (1, 1))
        # Drop the two trailing 1x1 spatial dimensions before the classifier.
        return self.fc(pooled[..., 0, 0])
