import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, so with ``stride=1`` the
    spatial size of the input is preserved.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of channel groups.
        dilation: kernel dilation (also used as the padding amount).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )


class ResizeConv2d(nn.Module):
    """Resize the input with ``F.interpolate``, then apply a stride-1 convolution.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: kernel size of the convolution.
        scale_factor: factor passed to ``F.interpolate``.
        padding: padding of the convolution (defaults to 1, regardless of
            ``kernel_size``).
        mode: interpolation mode passed to ``F.interpolate``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, padding=1, mode='nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        # The convolution keeps its bias (nn.Conv2d default) and always uses stride 1;
        # all resizing is done by the interpolation step in forward().
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

    def forward(self, x):
        """Return ``conv(interpolate(x))``."""
        resized = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        return self.conv(resized)


class BasicEncoderBlock(nn.Module):
    """Two-convolution residual block.

    Main path: conv3x3 (with ``stride``) -> BN -> ReLU -> conv3x3 -> BN.
    The skip path is added to the main path and a final ReLU is applied.

    Args:
        inplanes: number of input channels.
        outplanes: number of output channels.
        stride: stride of the first 3x3 convolution (and of the skip
            projection, when one is built).
    """

    def __init__(self, inplanes, outplanes, stride=1):
        super().__init__()
        # When stride != 1, both self.conv1 and the 1x1 projection in
        # self.shortcut reduce the spatial resolution of the input.
        self.conv1 = conv3x3(inplanes, outplanes, stride)
        self.bn1 = nn.BatchNorm2d(outplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(outplanes, outplanes)
        self.bn2 = nn.BatchNorm2d(outplanes)

        # A plain identity skip only works if the input already has the shape
        # of the main-path output; otherwise project it with conv1x1 + BN.
        self.require_shortcut = stride > 1 or inplanes != outplanes
        if self.require_shortcut:
            projection = conv1x1(inplanes, outplanes, stride)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(outplanes))

    def forward(self, x):
        """Run the residual block on ``x`` and return the activated sum."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        skip = self.shortcut(x) if self.require_shortcut else x

        out += skip
        return self.relu(out)


class DecoderBlockWithTransposedConvolution(nn.Module):
    """Residual decoder block that upsamples with a transposed convolution.

    Main path: conv3x3 (inplanes -> inplanes) -> BN -> ReLU ->
    ConvTranspose2d (inplanes -> outplanes, ``stride``) -> BN. With
    kernel 3, padding 1 and ``output_padding = stride - 1`` the transposed
    convolution scales the spatial size by ``stride``.

    Args:
        inplanes: number of input channels.
        outplanes: number of output channels.
        stride: upsampling factor of the transposed convolution.
        use_shortcut: if true, a skip connection is added before the final
            ReLU; otherwise the block is a plain (non-residual) stack.
    """

    def __init__(self, inplanes, outplanes, stride=1, use_shortcut=True):
        super().__init__()
        # Settings shared by the main-path and skip-path transposed convolutions.
        upsample_kwargs = dict(
            kernel_size=3,
            stride=stride,
            padding=1,
            output_padding=stride - 1,
            groups=1,
            bias=False,
        )

        self.conv2 = conv3x3(inplanes, inplanes, stride=1)
        self.bn2 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.ConvTranspose2d(in_channels=inplanes, out_channels=outplanes, **upsample_kwargs)
        self.bn1 = nn.BatchNorm2d(outplanes)

        # A learned skip projection is only needed when the skip is enabled
        # and the input shape differs from the main-path output shape.
        self.require_shortcut = use_shortcut and (stride > 1 or inplanes != outplanes)
        if self.require_shortcut:
            skip_upsample = nn.ConvTranspose2d(in_channels=inplanes, out_channels=outplanes, **upsample_kwargs)
            self.shortcut = nn.Sequential(skip_upsample, nn.BatchNorm2d(outplanes))
        self.use_shortcut = use_shortcut

    def forward(self, x):
        """Run the decoder block on ``x`` and return the activated result."""
        out = self.relu(self.bn2(self.conv2(x)))
        out = self.bn1(self.conv1(out))

        # Projected skip when one was built, raw input otherwise.
        residual = self.shortcut(x) if self.require_shortcut else x
        if self.use_shortcut:
            out += residual

        return self.relu(out)


class DecoderBlockWithUpsampling(nn.Module):
    """Residual decoder block that upsamples with resize-then-convolve layers.

    Main path: conv3x3 (inplanes -> inplanes) -> BN -> ReLU ->
    ResizeConv2d (inplanes -> outplanes, kernel 3, ``scale_factor=stride``)
    -> BN.

    Args:
        inplanes: number of input channels.
        outplanes: number of output channels.
        stride: scale factor handed to the ``ResizeConv2d`` layers.
        use_shortcut: if true, a skip connection is added before the final
            ReLU; otherwise the block is a plain (non-residual) stack.
    """

    def __init__(self, inplanes, outplanes, stride=1, use_shortcut=True):
        super().__init__()
        self.conv2 = conv3x3(inplanes, inplanes, stride=1)
        self.bn2 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ResizeConv2d(inplanes, outplanes, kernel_size=3, scale_factor=stride)
        self.bn1 = nn.BatchNorm2d(outplanes)

        # A learned skip projection is only needed when the skip is enabled
        # and the input shape differs from the main-path output shape.
        self.require_shortcut = use_shortcut and (stride > 1 or inplanes != outplanes)
        if self.require_shortcut:
            # 1x1 kernel with no padding: the skip path only resizes and remaps channels.
            skip_resize = ResizeConv2d(
                in_channels=inplanes,
                out_channels=outplanes,
                kernel_size=1,
                padding=0,
                scale_factor=stride,
            )
            self.shortcut = nn.Sequential(skip_resize, nn.BatchNorm2d(outplanes))
        self.use_shortcut = use_shortcut

    def forward(self, x):
        """Run the decoder block on ``x`` and return the activated result."""
        out = self.relu(self.bn2(self.conv2(x)))
        out = self.bn1(self.conv1(out))

        # Projected skip when one was built, raw input otherwise.
        residual = self.shortcut(x) if self.require_shortcut else x
        if self.use_shortcut:
            out += residual

        return self.relu(out)