import torch
import torch.nn as nn
import torch.nn.functional as F
from ..blocks.bottleneck_post import BottleneckPost
from ..blocks.basicblock_post import BasicBlockPost

class ResNet(nn.Module):
    """Residual network assembled from ``state`` settings and ``event`` factories.

    Settings read from ``state``: ``"first_conv_filters"``,
    ``"first_conv_max_pool"``, ``"strides"``, ``"num_planes"``,
    ``"num_blocks"``, ``"replace_stride_with_dilation"``, ``"expansion"``,
    ``"zero_init_residual"`` and ``state.all["dataset.num_classes"]``.

    Hooks used from ``event``: ``resblock`` (residual block constructor),
    ``pre_resblockchain()`` and ``post_resblockchain(in_planes)`` (each returns
    the modules placed before / after the residual layers) and
    ``optional.init_net_finished(net)`` (called once the network is built).
    """

    def __init__(self, state, event):
        """Build all sub-modules.

        Args:
            state: Settings container; indexed by key and exposing ``.all``.
            event: Provider of the block/stem/tail factories listed on the class.

        Raises:
            AssertionError: If the per-layer settings have inconsistent lengths.
        """
        assert len(state["replace_stride_with_dilation"])+1 == len(state["strides"]) and len(state["strides"]) == len(state["num_blocks"]), "len(replace_stride_with_dilation)+1, len(strides) should be same size as len(num_blocks)"
        super().__init__()
        block = event.resblock
        self.state = state
        self.event = event
        # running channel count / dilation; both are advanced by _make_layer
        self.in_planes = state["first_conv_filters"]
        self.dilation = 1

        # first conv ensures right number of channels
        stem = list(event.pre_resblockchain())
        if state["first_conv_max_pool"]:
            stem.append(nn.MaxPool2d(kernel_size=3, stride=state["strides"][0], padding=1))
        self.pre_resblockchain = nn.Sequential(*stem)

        # residual blocks
        # note: first residual blocks may differ if first_conv applies max pooling
        # `resblocks` is a plain list giving flat access to every block; the
        # blocks themselves are registered as sub-modules through `self.blocks`.
        self.resblocks = []
        self.blocks = nn.ModuleList([])
        # The first layer never dilates. Unpacking (rather than list `+`) lets
        # the setting be a list or a tuple.
        dilate_flags = [False, *state["replace_stride_with_dilation"]]
        layer_settings = zip(state["num_planes"], state["num_blocks"], state["strides"], dilate_flags)
        for i, (num_planes, num_blocks, stride, dilate) in enumerate(layer_settings):
            if i == 0 and state["first_conv_max_pool"]:
                # the max pool above already applied strides[0]
                stride = 1
            self.blocks.append(self._make_layer(state, block, num_planes, num_blocks, stride=stride, dilate=dilate))

        # shortcut-path may not have seen any batchnorm
        self.post_resblockchain = nn.Sequential(
            *event.post_resblockchain(self.in_planes)
        )

        # global average pooling & fully connected layer
        out_features = state["num_planes"][-1] * state["expansion"]
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.last_bn = nn.BatchNorm1d(out_features)
        self.fc = nn.Linear(out_features, state.all["dataset.num_classes"])

        # optional: apply weight initializer
        event.optional.init_net_finished(self)

        # initialize resnet to behave like identity
        # This must run after the initializer hook: the `weights_init` handler
        # registered in this file resets every BatchNorm2d/GroupNorm weight to
        # 1, which would otherwise undo the zero-initialisation.
        if state["zero_init_residual"]:
            for m in self.modules():
                if isinstance(m, (BottleneckPost, BasicBlockPost)):
                    last_weight = getattr(m.block_seq[-1], 'weight', None)
                    # skip layers without a learnable weight
                    if last_weight is not None:
                        nn.init.constant_(last_weight, 0)

    # residual blocks
    def _make_layer(self, state, block, planes, num_blocks, stride=1, dilate=False):
        """Create one layer: a sequence of ``num_blocks`` residual blocks.

        Updates ``self.in_planes`` (to ``planes * state["expansion"]``) and,
        when ``dilate`` is true, ``self.dilation``. Every created block is also
        appended to ``self.resblocks``.

        Args:
            state: Settings container; only ``"expansion"`` is read here.
            block: Callable ``block(in_planes, planes, stride=..., dilation=...)``.
            planes: Channel argument handed to every block of this layer.
            num_blocks: Number of blocks in the layer.
            stride: Stride of the first block; the remaining blocks use 1.
            dilate: If true, multiply the running dilation by ``stride`` and
                use stride 1 instead.

        Returns:
            nn.Sequential: The blocks of this layer, in order.
        """
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        # note: the first block defined differs from the rest of the block-chain
        layers = []
        for i in range(num_blocks):
            is_first = i == 0
            block_ = block(
                self.in_planes,
                planes,
                stride=stride if is_first else 1,
                dilation=previous_dilation if is_first else self.dilation,
            )
            self.in_planes = planes * state["expansion"]
            layers.append(block_)
            self.resblocks.append(block_)

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run stem, residual layers, tail, pooling, BatchNorm1d and the classifier.

        Args:
            x: Input batch for the stem (assumed image-like ``(N, C, H, W)`` —
                confirm against the modules ``event.pre_resblockchain()`` returns).

        Returns:
            Output of the final linear layer, with one row per input sample.
        """
        x = self.pre_resblockchain(x)
        for block in self.blocks:
            x = block(x)
        x = self.post_resblockchain(x)
        # Flatten from dim 1 so the batch axis survives for single-sample
        # batches; a bare squeeze() would drop it and break BatchNorm1d.
        x = torch.flatten(self.avgpool(x), 1)
        x = self.last_bn(x)
        x = self.fc(x)
        return x

def weights_init(self):
    """Initialise the convolution and normalisation layers of a network.

    Registered below as an ``init_net_finished`` handler, so ``self`` is the
    network instance rather than a bound method receiver.

    Every ``nn.Conv2d`` is handed to ``self.event.init_conv``; every
    ``nn.BatchNorm2d`` / ``nn.GroupNorm`` gets weight 1 and bias 0.

    Args:
        self: Network exposing ``modules()`` and an ``event`` attribute with
            an ``init_conv(module)`` callable.
    """
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            self.event.init_conv(m)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            # Norm layers built with affine=False have weight/bias set to
            # None; there is nothing to initialise for them.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

def register(mf):
    """Register the ``model.resnet`` scope, its defaults and its events on ``mf``.

    Args:
        mf: Registry object providing ``redefine_scope``,
            ``register_defaults`` and ``register_event``.
    """
    mf.redefine_scope("model.resnet")

    # defaults of resnet20 / cifar-variant
    resnet_defaults = {
        # base settings: one entry per residual layer
        "num_blocks": [3, 3, 3],
        "num_planes": [16, 32, 64],
        "strides": [1, 2, 2],

        # One flag per layer after the first: when True, that layer's stride
        # is replaced by a dilated convolution instead.
        "replace_stride_with_dilation": [False, False],

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        "zero_init_residual": True,
    }
    mf.register_defaults(resnet_defaults)

    # network constructor, followed by the weight-initialisation handler
    mf.register_event("init_net", ResNet, unique=True)
    mf.register_event("init_net_finished", weights_init)
