"""
@author: Jun Wang
@date: 20201020
@contact: jun21wangustc@gmail.com
"""
# based on:
# https://github.com/liguohao96/pytorch-prnet
from torch import nn


class Conv2d(nn.Module):
    def __init__(self, input_size, in_channel, out_channel, kernel_size=4, stride=1):
        super(Conv2d, self).__init__()
        output_size = input_size // stride
        self.padding_num = stride * (output_size - 1) - input_size + kernel_size
        self.even_conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding_num // 2,
            bias=False,
        )
        self.odd_conv = nn.Sequential(
            nn.ConstantPad2d(
                (
                    self.padding_num // 2,
                    self.padding_num // 2 + 1,
                    self.padding_num // 2,
                    self.padding_num // 2 + 1,
                ),
                0,
            ),
            nn.Conv2d(
                in_channel,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                padding=0,
                bias=False,
            ),
        )

    def forward(self, x):
        if self.padding_num % 2 == 0:
            x = self.even_conv(x)
        else:
            x = self.odd_conv(x)
        return x


class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation_fn="relu"):
        super(UpBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.activation_fn = activation_fn
        self.upConvTranspose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.convTranspose = nn.Sequential(
            nn.ConstantPad2d((2, 1, 2, 1), 0),
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=1,
                padding=3,
                bias=False,
            ),
        )
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.001)
        self.activation1 = nn.ReLU(inplace=True)
        self.activation2 = nn.Sigmoid()

    def forward(self, x):
        if self.stride == 1:
            x = self.convTranspose(x)
        else:
            x = self.upConvTranspose(x)
        x = self.bn(x)
        if self.activation_fn == "relu":
            x = self.activation1(x)
        else:
            x = self.activation2(x)
        return x


class ResBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=4, stride=1, input_size=None
    ):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.conv1 = nn.Conv2d(
            in_channels, out_channels // 2, 1, 1, padding=0, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels // 2, eps=0.001, momentum=0.001)

        self.conv2 = Conv2d(
            input_size, out_channels // 2, out_channels // 2, stride=stride
        )
        self.bn2 = nn.BatchNorm2d(out_channels // 2, eps=0.001, momentum=0.001)

        input_size = input_size // stride
        self.conv3 = Conv2d(
            input_size, out_channels // 2, out_channels, kernel_size=1, stride=1
        )

        self.shortcut = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=stride, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        if (self.in_channels != self.out_channels) or (self.stride != 1):
            identity = self.shortcut(x)
        out += identity
        out = self.bn3(out)
        out = self.relu(out)
        return out


class PRNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(PRNet, self).__init__()
        size = 16
        self.input_layer = nn.Sequential(
            Conv2d(256, in_channels, size, kernel_size=4, stride=1),
            nn.BatchNorm2d(size, eps=0.001, momentum=0.001),
            nn.ReLU(inplace=True),
        )
        self.encoder_block_1 = ResBlock(
            size, size * 2, kernel_size=4, stride=2, input_size=256
        )  # 128x128x32
        self.encoder_block_2 = ResBlock(
            size * 2, size * 2, kernel_size=4, stride=1, input_size=128
        )  # 128x128x32
        self.encoder_block_3 = ResBlock(
            size * 2, size * 4, kernel_size=4, stride=2, input_size=128
        )  # 64x64x64
        self.encoder_block_4 = ResBlock(
            size * 4, size * 4, kernel_size=4, stride=1, input_size=64
        )  # 64x64x64
        self.encoder_block_5 = ResBlock(
            size * 4, size * 8, kernel_size=4, stride=2, input_size=64
        )  # 32x32x128
        self.encoder_block_6 = ResBlock(
            size * 8, size * 8, kernel_size=4, stride=1, input_size=32
        )  # 32x32x128
        self.encoder_block_7 = ResBlock(
            size * 8, size * 16, kernel_size=4, stride=2, input_size=32
        )  # 16x16x256
        self.encoder_block_8 = ResBlock(
            size * 16, size * 16, kernel_size=4, stride=1, input_size=16
        )  # 16x16x256
        self.encoder_block_9 = ResBlock(
            size * 16, size * 32, kernel_size=4, stride=2, input_size=16
        )  # 8x8x512
        self.encoder_block_10 = ResBlock(
            size * 32, size * 32, kernel_size=4, stride=1, input_size=8
        )  # 8x8x512

        self.decoder_block_1 = UpBlock(size * 32, size * 32, stride=1)  # 8x8x512
        self.decoder_block_2 = UpBlock(size * 32, size * 16, stride=2)  # 16x16x256
        self.decoder_block_3 = UpBlock(size * 16, size * 16, stride=1)  # 16x16x256
        self.decoder_block_4 = UpBlock(size * 16, size * 16, stride=1)  # 16x16x256
        self.decoder_block_5 = UpBlock(size * 16, size * 8, stride=2)  # 32 x 32 x 128
        self.decoder_block_6 = UpBlock(size * 8, size * 8, stride=1)  # 32 x 32 x 128
        self.decoder_block_7 = UpBlock(size * 8, size * 8, stride=1)  # 32 x 32 x 128
        self.decoder_block_8 = UpBlock(size * 8, size * 4, stride=2)  # 64 x 64 x 64
        self.decoder_block_9 = UpBlock(size * 4, size * 4, stride=1)  # 64 x 64 x 64
        self.decoder_block_10 = UpBlock(size * 4, size * 4, stride=1)  # 64 x 64 x 64

        self.decoder_block_11 = UpBlock(size * 4, size * 2, stride=2)  # 128 x 128 x 32
        self.decoder_block_12 = UpBlock(size * 2, size * 2, stride=1)  # 128 x 128 x 32
        self.decoder_block_13 = UpBlock(size * 2, size, stride=2)  # 256 x 256 x 16
        self.decoder_block_14 = UpBlock(size, size, stride=1)  # 256 x 256 x 16

        self.decoder_block_15 = UpBlock(size, 3, stride=1)  # 256 x 256 x 3
        self.decoder_block_16 = UpBlock(3, 3, stride=1)  # 256 x 256 x 3
        self.decoder_block_17 = UpBlock(3, 3, stride=1, activation_fn="sigmoid")  #

    def forward(self, x):
        x = self.input_layer(x)
        x = self.encoder_block_1(x)
        x = self.encoder_block_2(x)
        x = self.encoder_block_3(x)
        x = self.encoder_block_4(x)
        x = self.encoder_block_5(x)
        x = self.encoder_block_6(x)
        x = self.encoder_block_7(x)
        x = self.encoder_block_8(x)
        x = self.encoder_block_9(x)
        x = self.encoder_block_10(x)

        x = self.decoder_block_1(x)
        x = self.decoder_block_2(x)
        x = self.decoder_block_3(x)
        x = self.decoder_block_4(x)
        x = self.decoder_block_5(x)
        x = self.decoder_block_6(x)
        x = self.decoder_block_7(x)
        x = self.decoder_block_8(x)
        x = self.decoder_block_9(x)
        x = self.decoder_block_10(x)
        x = self.decoder_block_11(x)
        x = self.decoder_block_12(x)
        x = self.decoder_block_13(x)
        x = self.decoder_block_14(x)
        x = self.decoder_block_15(x)
        x = self.decoder_block_16(x)
        x = self.decoder_block_17(x)
        return x
