import torch.nn as nn
import torch.nn.functional as F
from .resnet import CBA, ResBlock


class BottleNeck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> kxk -> 1x1 expand, plus a shortcut.

    The main branch maps ``in_width`` channels to ``width * expansion`` channels.
    The shortcut is the identity when input and output already agree
    (``in_width == width * expansion`` and ``stride == 1``). Otherwise it is a
    1x1 ``CBA`` projection with the same stride and ``act=False``. The module
    returned by ``model_cfg.act()`` is applied after the two branches are summed.

    Args:
        model_cfg: Project configuration forwarded to ``CBA``; must provide ``act()``.
        in_width: Number of input channels.
        width: Number of channels in the inner (bottleneck) convolutions.
        ksize: Kernel size of the middle convolution.
        expansion: Output channel multiplier; the block outputs ``width * expansion``.
        stride: Stride used by the first 1x1 convolution and by the projection shortcut.
    """

    def __init__(self, model_cfg, in_width, width, ksize=3, expansion=4, stride=1):
        super().__init__()

        out_width = width * expansion
        # NOTE(review): CBA presumably is Conv-BatchNorm-Activation, with
        # act=False omitting the activation; confirm in .resnet.
        self.block = nn.Sequential(
            CBA(model_cfg, in_width, width, 1, stride=stride),
            CBA(model_cfg, width, width, ksize),
            CBA(model_cfg, width, out_width, 1, act=False),
        )
        # nn.Identity instead of a lambda keeps the module picklable
        # (torch.save(model), multiprocessing). It has no parameters, so
        # state_dict keys are the same as before.
        if in_width == out_width and stride == 1:
            self.skip = nn.Identity()
        else:
            self.skip = CBA(model_cfg, in_width, out_width, 1, stride=stride, act=False)
        self.last_act = model_cfg.act()

    def forward(self, nx):
        """Return ``last_act(block(nx) + skip(nx))``."""
        main = self.block(nx)
        shortcut = self.skip(nx)
        return self.last_act(main + shortcut)

class BigResNet(nn.Module):
    """ResNet classifier with four residual stages (ResNet-18/34/50/101/152 layouts).

    The network runs in this order:

    - a 7x7 stride-2 ``CBA`` stem;
    - a 3x3 stride-2 max-pool;
    - four stages of residual blocks (the first uses stride 1, the rest stride 2,
      and channel width doubles per stage starting at 64);
    - global average pooling;
    - a final ``model_cfg.fc`` layer created with ``bias=False``.

    Depths >= 50 use ``BottleNeck`` blocks; 18 and 34 use ``ResBlock``.

    Args:
        model_cfg: Project configuration forwarded to every block; must provide
            ``fc``. Despite the ``None`` default it is required.
        n_size: Network depth; one of 18, 34, 50, 101, 152.
        n_classes: Number of output logits.
        input_size: Number of input channels.
        expansion: Channel expansion factor forwarded to each residual block.

    Raises:
        KeyError: If ``n_size`` is not a supported depth.
    """

    def __init__(
        self, model_cfg=None, n_size=50, n_classes=1000, input_size=3,
        expansion=4
    ):
        super().__init__()

        # Number of residual blocks in each of the four stages.
        structures = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }
        try:
            repetitions = structures[n_size]
        except KeyError:
            raise KeyError(
                f"unsupported n_size={n_size!r}; expected one of {sorted(structures)}"
            ) from None
        StructBlock = BottleNeck if n_size >= 50 else ResBlock

        self.start_block = CBA(model_cfg, input_size, 64, 7, stride=2)
        width = prev_width = 64
        stages = []
        # Each stage must use its OWN block count (repetitions[istage]);
        # using repetitions[0] for every stage gives the wrong depth.
        for istage, n_blocks in enumerate(repetitions):
            stride = 1 if istage == 0 else 2
            # Only the first block of a stage changes width and/or resolution.
            # NOTE(review): for 18/34, ResBlock also receives `expansion`, and
            # prev_width assumes it outputs width * expansion channels;
            # confirm against ResBlock in .resnet.
            first = StructBlock(model_cfg, prev_width, width, stride=stride, expansion=expansion)
            rest = [
                StructBlock(model_cfg, width * expansion, width, expansion=expansion)
                for _ in range(1, n_blocks)
            ]
            stages.append(nn.Sequential(first, *rest))
            prev_width = width * expansion
            width *= 2
        # Keep the block1..block4 attribute names so state_dict keys stay stable.
        self.block1, self.block2, self.block3, self.block4 = stages

        self.fc = model_cfg.fc(prev_width, n_classes, bias=False)

    def forward(self, nx):
        """Run the stem, the four stages, global average pooling and the classifier."""
        nx = self.start_block(nx)
        # NOTE(review): torchvision's ResNet pads this pool (padding=1). Without
        # padding a 112x112 map pools to 55x55 instead of 56x56; confirm intent.
        nx = F.max_pool2d(nx, kernel_size=3, stride=2)
        for stage in (self.block1, self.block2, self.block3, self.block4):
            nx = stage(nx)
        # (N, C, 1, 1) -> (N, C)
        nx = F.adaptive_avg_pool2d(nx, (1, 1)).flatten(1)
        return self.fc(nx)