from spaghettini import quick_register

from torch import nn
import torch.nn.functional as F


@quick_register
class BasicBlock(nn.Module):
    """Two-convolution residual block for 2-D inputs, normalised with BatchNorm.

    Main path: 3x3 conv (with ``stride``) -> BatchNorm -> ReLU -> 3x3 conv ->
    BatchNorm. The shortcut branch is added to the main path, and the sum goes
    through a final ReLU. The shortcut is the identity (empty ``nn.Sequential``)
    when neither the stride nor the channel count changes. Otherwise it is a
    strided 1x1 convolution followed by BatchNorm.
    """
    # Multiplier from ``planes`` to the number of output channels.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Build the block.

        Args:
            in_planes: number of input channels.
            planes: number of channels produced by the 3x3 convolutions.
            stride: stride of the first 3x3 conv and of the projection shortcut.
        """
        super(BasicBlock, self).__init__()
        width = self.expansion * planes

        # Construction order (conv1, bn1, conv2, bn2, shortcut) is kept fixed so
        # random weight initialisation and module registration order are stable.
        self.conv1 = nn.Conv2d(
            in_planes, planes,
            kernel_size=(3, 3), stride=(stride, stride), padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes,
            kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(planes)

        shape_preserved = stride == 1 and in_planes == width
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it matches the main path's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=(1, 1),
                          stride=(stride, stride), bias=False),
                nn.BatchNorm2d(width),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        h = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        h += self.shortcut(x)
        return F.relu(h)


class Resnet1DBasicBlock(nn.Module):
    """Residual block made of two kernel-3 1-D convolutions and GroupNorm.

    ``forward`` computes::

        conv1 -> gn1 -> ReLU -> conv2 -> (+ injection) -> (+ shortcut(x)) -> gn2 -> ReLU

    Note that the second normalisation is applied *after* the residual sum.
    """
    expansion = 1
    # Whether gn1/gn2 learn a per-channel scale and shift.
    BLOCK_GN_AFFINE = True

    def __init__(self, in_planes, planes, stride=1, wnorm=True, num_groups=8):
        """Build the block.

        Args:
            in_planes: number of input channels.
            planes: number of channels produced by conv1 and conv2.
            stride: stride of conv1 and of the projection shortcut (if any).
            wnorm: if True, every convolution is wrapped with
                ``nn.utils.weight_norm``.
            num_groups: number of groups for both ``nn.GroupNorm`` layers
                (``nn.GroupNorm`` requires ``planes`` to be divisible by it).
        """
        super().__init__()
        self.num_groups = num_groups

        def _wrap(conv):
            # Optional weight normalisation, applied uniformly to every conv.
            return nn.utils.weight_norm(conv) if wnorm else conv

        out_planes = self.expansion * planes

        # conv1, conv2 and the shortcut conv are created in this order so that
        # weight initialisation and module registration order stay stable.
        self.conv1 = _wrap(nn.Conv1d(in_planes, planes, kernel_size=(3,),
                                     stride=(stride,), padding=1, bias=False))
        self.conv2 = _wrap(nn.Conv1d(planes, planes, kernel_size=(3,),
                                     stride=(1,), padding=1, bias=False))

        if stride == 1 and in_planes == out_planes:
            # Shapes already match: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            # Strided 1x1 projection so the shortcut matches conv2's output.
            self.shortcut = nn.Sequential(
                _wrap(nn.Conv1d(in_planes, out_planes, kernel_size=(1,),
                                stride=(stride,), bias=False))
            )

        self.gn1 = nn.GroupNorm(self.num_groups, planes, affine=self.BLOCK_GN_AFFINE)
        self.gn2 = nn.GroupNorm(self.num_groups, planes, affine=self.BLOCK_GN_AFFINE)

    def forward(self, x, injection=None):
        """Apply the block to ``x``.

        Args:
            x: input tensor for conv1 and the shortcut.
            injection: optional term added to conv2's output before the
                residual sum; ``None`` means no injection. Must broadcast
                against conv2's output.
        """
        h = F.relu(self.gn1(self.conv1(x)))
        extra = 0 if injection is None else injection
        h = self.conv2(h) + extra
        h += self.shortcut(x)
        return F.relu(self.gn2(h))


class Resnet1DLayer(nn.Module):
    """A sequence of ``num_blocks`` residual blocks that share one injection.

    Only the first block uses ``stride``; every later block uses stride 1.
    After the first block, each block receives ``planes * block.expansion``
    input channels.
    """

    def __init__(self, block, planes, in_planes, num_blocks, stride=1, wnorm=True,
                 **block_kwargs):
        """Build the layer.

        Args:
            block: block class, instantiated as
                ``block(in_planes, planes, stride, wnorm=wnorm, **block_kwargs)``
                and expected to expose an ``expansion`` class attribute.
            planes: ``planes`` argument passed to every block.
            in_planes: number of input channels of the first block.
            num_blocks: number of blocks to stack; 0 yields an identity layer.
            stride: stride of the first block.
            wnorm: forwarded to every block as ``wnorm=``.
            **block_kwargs: extra keyword arguments forwarded to every block
                (e.g. ``num_groups`` for ``Resnet1DBasicBlock``).

        Raises:
            ValueError: if ``num_blocks`` is negative.
        """
        super().__init__()
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be non-negative, got {num_blocks}")
        self.wnorm = wnorm
        self.num_blocks = num_blocks

        # Built from range(num_blocks) so that num_blocks == 0 creates no block.
        # The old `[stride] + [1] * (num_blocks - 1)` always built at least one,
        # and forward never used it, leaving it as dead parameters.
        strides = [stride if i == 0 else 1 for i in range(num_blocks)]

        layers = []
        for strd in strides:
            layers.append(block(in_planes, planes, strd, wnorm=self.wnorm, **block_kwargs))
            in_planes = planes * block.expansion

        self.layers = nn.ModuleList(layers)

    def forward(self, xs, injection=None):
        """Run ``xs`` through every block in order, passing ``injection`` to each."""
        out = xs
        for layer in self.layers:
            out = layer(out, injection)
        return out
