from torch import nn
import torch.nn.functional as F
import torch

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation`` so that, at stride 1, the spatial
    size of the input is preserved.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of blocked connections (grouped convolution).
        dilation: Kernel dilation; also used as the padding.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (a stride > 1 subsamples spatially).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class BottleneckPost(nn.Module):
    """Bottleneck residual block that applies the activation after the residual sum.

    Residual branch (``block_seq``): 1x1 conv -> norm -> act -> 3x3 conv ->
    norm -> act -> 1x1 conv -> norm. Shortcut (``shortcut_downsampled``): the
    identity, or a strided 1x1 conv followed by norm whenever the stride or
    the channel count changes. The output is ``act(branch + shortcut)`` when
    ``state["short"]`` is truthy, otherwise ``act(branch)``.

    Args:
        state: Mapping of block hyperparameters. Reads ``"expansion"``,
            ``"width_per_group"`` and ``"groups"`` here, and ``"short"`` in
            ``forward``.
        event: Object providing ``normalization_layer(num_features)``,
            ``activation_layer()`` and, in ``forward``,
            ``optional.residual_branch_modifier(tensor)``.
        inplanes: Number of input channels.
        planes: Base channel count; the block outputs
            ``planes * state["expansion"]`` channels.
        stride: Stride of the 3x3 conv and of the shortcut projection.
        downsample: Unused; accepted for signature compatibility.
        dilation: Dilation (and padding) of the 3x3 conv.
    """

    def __init__(self, state, event, inplanes, planes, stride=1, downsample=None, dilation=1):
        # nn.Module.__init__ must run before any attribute assignment so that
        # submodules / parameters assigned on self are registered correctly.
        super(BottleneckPost, self).__init__()
        self.state = state
        self.event = event
        # Factor of the residual branch. Not read inside this class; presumably
        # consumed by external hooks — TODO confirm. It is a plain attribute,
        # not a registered buffer, so it does not follow .to(device)/.cuda().
        self.nu = torch.tensor(1, requires_grad=False)

        out_planes = planes * state["expansion"]

        if stride != 1 or inplanes != out_planes:
            # Projection shortcut: matches both spatial size and channel count.
            self.shortcut_downsampled = nn.Sequential(
                conv1x1(inplanes, out_planes, stride),
                event.normalization_layer(out_planes),
            )
        else:
            # Identity shortcut: an empty Sequential returns its input unchanged.
            self.shortcut_downsampled = nn.Sequential()

        # Inner (bottleneck) width. NOTE(review): int() truncates, so whenever
        # planes * width_per_group < 64 this evaluates to 0 channels — confirm
        # the width_per_group configured for this block.
        width = int(planes * (state["width_per_group"] / 64.)) * state["groups"]

        # When stride != 1, both the 3x3 conv in block_seq and the 1x1 conv in
        # shortcut_downsampled downsample the input, so their outputs align.
        self.block_seq = nn.Sequential(
            conv1x1(inplanes, width),
            event.normalization_layer(width),
            event.activation_layer(),
            conv3x3(width, width, stride, state["groups"], dilation),
            event.normalization_layer(width),
            event.activation_layer(),
            conv1x1(width, out_planes),
            event.normalization_layer(out_planes),
        )
        # Final activation applied after the residual sum (any activation the
        # event provides, despite the attribute name).
        self.relu = event.activation_layer()
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and return the activated output.

        The branch output, shortcut output and pre-activation sum are stored on
        the module as ``s1``, ``s2`` and ``out`` (presumably so external hooks
        can inspect them — TODO confirm).
        """
        self.s1 = self.block_seq(x)
        # NOTE(review): the shortcut is computed even when state["short"] is
        # falsy; its parameters then receive no gradient (relevant e.g. for DDP).
        self.s2 = self.shortcut_downsampled(x)
        if self.state["short"]:
            # The modifier's return value is discarded, so it can only affect
            # the sum by acting in place / via side effects — confirm contract.
            self.event.optional.residual_branch_modifier(self.s2)
            self.out = self.s1 + self.s2
        else:
            self.out = self.s1
        return self.relu(self.out)


def pre_resblockchain(state, event):
    """Build the stem layers placed before the chain of residual blocks.

    Args:
        state: Config object; ``state.all["dataset.num_channels"]`` gives the
            input channel count and ``state["first_conv_*"]`` the stem conv
            hyperparameters.
        event: Object providing ``normalization_layer(num_features)`` and
            ``activation_layer()``.

    Returns:
        A 3-tuple ``(conv, norm, activation)``.

    NOTE(review): the ``"first_conv_max_pool"`` default registered in this
    module is not read here — confirm whether a pooling layer is expected.
    """
    in_channels = state.all["dataset.num_channels"]
    num_filters = state["first_conv_filters"]
    stem_conv = nn.Conv2d(
        in_channels,
        num_filters,
        kernel_size=state["first_conv_kernel_size"],
        stride=state["first_conv_stride"],
        padding=state["first_conv_padding"],
        bias=False,
    )
    stem_norm = event.normalization_layer(num_filters)
    stem_act = event.activation_layer()
    return (stem_conv, stem_norm, stem_act)

def post_resblockchain(in_planes):
    """Return the layers appended after the chain of residual blocks.

    This variant adds no trailing layers, so a new empty list is returned on
    every call; ``in_planes`` is accepted for interface compatibility only.
    """
    trailing_layers = list()
    return trailing_layers



def register(mf):
    """Register this ResNet variant's defaults and layer factories with ``mf``.

    Sets the ``"model.resnet"`` scope, registers the default hyperparameters,
    and registers ``BottleneckPost`` / ``pre_resblockchain`` /
    ``post_resblockchain`` as unique handlers for the ``"resblock"``,
    ``"pre_resblockchain"`` and ``"post_resblockchain"`` events.
    """
    defaults = {
        "short": True,
        "option": "A",
        "expansion": 4,
        "groups": 1,
        # NOTE(review): BottleneckPost computes int(planes * width_per_group / 64),
        # so a value of 1 yields 0 inner channels for planes < 64 — confirm.
        "width_per_group": 1,

        # first conv layer
        "first_conv_filters": 16,
        "first_conv_kernel_size": 3,
        "first_conv_padding": 1,
        "first_conv_stride": 1,
        "first_conv_max_pool": False,
    }
    mf.set_scope("model.resnet")
    mf.register_defaults(defaults)

    handlers = (
        ("resblock", BottleneckPost),
        ("pre_resblockchain", pre_resblockchain),
        ("post_resblockchain", post_resblockchain),
    )
    for event_name, factory in handlers:
        mf.register_event(event_name, factory, unique=True)
