from torch import nn
import torch
import torch.nn.functional as F
from ..utils import LambdaLayer

class BasicBlockPost(nn.Module):
    """Post-activation ResNet basic block: two 3x3 convs plus a shortcut, activation after the sum.

    Args:
        state: per-scope config mapping; this block reads ``"option"``,
            ``"expansion"`` and ``"short"``.
        event: provider of ``normalization_layer(num_features)``,
            ``activation_layer()`` and the ``optional.before_addition`` /
            ``optional.residual_branch_modifier`` hooks used in ``forward``.
        in_planes: number of input channels.
        planes: number of output channels of the residual branch.
        stride: stride of the first 3x3 conv (and of the option-B projection).

    Raises:
        ValueError: if a projection shortcut is required (``stride != 1`` or
            ``in_planes != planes``) but ``state["option"]`` is neither
            ``'A'`` nor ``'B'``.
    """

    def __init__(self, state, event, in_planes, planes, stride=1, *args, **kwargs):
        super().__init__()
        self.event = event
        self.state = state

        # Residual branch: conv3x3(stride) -> norm -> act -> conv3x3 -> norm.
        self.block_seq = nn.Sequential(
            nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            event.normalization_layer(planes),
            event.activation_layer(),
            nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False),
            event.normalization_layer(planes)
        )
        # Applied after the residual sum (post-activation).
        self.relu = event.activation_layer()

        if stride != 1 or in_planes != planes:
            option = self.state["option"]

            # Cifar-10 ResNet paper uses option A.
            if option == 'A':
                # Parameter-free shortcut: take every 2nd row/column, then zero-pad
                # planes//4 channels on each side of dim 1.
                # NOTE(review): only shape-correct when stride == 2 and
                # planes == 2 * in_planes -- confirm with callers.
                self.shortcut = LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                # Learned 1x1 projection shortcut.
                # NOTE(review): uses nn.BatchNorm2d rather than event.normalization_layer,
                # and emits expansion*planes channels while the residual branch emits
                # planes; the two only agree when expansion == 1.
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, state["expansion"] * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(state["expansion"] * planes)
                )
            else:
                # Fail fast: otherwise self.shortcut would be left unset and
                # forward() would die with an unhelpful AttributeError.
                raise ValueError(
                    f"Unknown shortcut option {option!r}; expected 'A' or 'B'"
                )
        else:
            # Shapes already match: identity shortcut.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``activation(residual(x) + shortcut(x))``.

        When ``state["short"]`` is falsy, the shortcut is dropped and the
        result is ``activation(residual(x))``. The intermediates are stored on
        the module as ``self.s1`` (residual branch), ``self.s2`` (shortcut
        output) and ``self.out`` (pre-activation value), presumably so that
        hooks or inspection code can read them. TODO confirm with callers.
        """
        self.s1 = self.block_seq(x)
        self.s2 = self.shortcut(x)
        if self.state["short"]:
            # The return value of residual_branch_modifier is ignored, so any
            # modification it makes must be in place (assumption -- TODO confirm).
            self.event.optional.before_addition(self.s1, self.s2)
            self.event.optional.residual_branch_modifier(self.s2)
            self.out = self.s1 + self.s2
        else:
            self.out = self.s1
        return self.relu(self.out)

def pre_resblockchain(state, event):
    """Build the stem layers that run before the residual-block chain.

    Returns a 3-tuple ``(conv, norm, act)``:

    * ``conv`` is a bias-free ``nn.Conv2d``. Its input channel count comes from
      ``state.all["dataset.num_channels"]``. Its filter count, kernel size,
      stride and padding come from the ``first_conv_*`` settings in ``state``.
    * ``norm`` is ``event.normalization_layer`` built over those filters.
    * ``act`` is a fresh ``event.activation_layer()``.
    """
    in_channels = state.all["dataset.num_channels"]
    filters = state["first_conv_filters"]

    stem_conv = nn.Conv2d(
        in_channels,
        filters,
        kernel_size=state["first_conv_kernel_size"],
        stride=state["first_conv_stride"],
        padding=state["first_conv_padding"],
        bias=False,
    )
    stem_norm = event.normalization_layer(filters)
    stem_act = event.activation_layer()
    return stem_conv, stem_norm, stem_act

def post_resblockchain(in_planes):
    """Return the layers appended after the residual-block chain.

    This variant contributes no extra layers. It ignores ``in_planes`` and
    hands back a new empty list on every call.
    """
    no_extra_layers = list()
    return no_extra_layers

def register(mf):
    """Register this basic-block ResNet variant with ``mf``.

    Sets the scope to ``"model.resnet"`` and registers the default settings
    below. It then registers ``BasicBlockPost`` as ``"resblock"`` and the
    stem/tail builders as ``"pre_resblockchain"`` / ``"post_resblockchain"``,
    each with ``unique=True``.
    """
    defaults = {
        # Residual-block settings (read by BasicBlockPost).
        "short": True,
        "option": "A",
        "expansion": 1,

        # Stem / first-conv settings (read by pre_resblockchain).
        # NOTE(review): "first_conv_max_pool" is not read by pre_resblockchain;
        # presumably consumed elsewhere -- confirm.
        "first_conv_filters": 16,
        "first_conv_kernel_size": 3,
        "first_conv_padding": 1,
        "first_conv_stride": 1,
        "first_conv_max_pool": False,
    }
    event_table = (
        ("resblock", BasicBlockPost),
        ("pre_resblockchain", pre_resblockchain),
        ("post_resblockchain", post_resblockchain),
    )

    mf.set_scope("model.resnet")
    mf.register_defaults(defaults)
    for event_name, handler in event_table:
        mf.register_event(event_name, handler, unique=True)
