import torch.nn as nn
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d


class conv_block(nn.Module):
    """Two stacked ``Conv2d -> SynchronizedBatchNorm2d -> ReLU`` stages.

    Both convolutions use stride 1 and ``padding = kernel_size // 2``.
    An odd ``kernel_size`` therefore keeps height and width unchanged.
    An even one grows each of them by 1 per convolution.

    Args:
        in_ch: Number of input channels seen by the first convolution.
        out_ch: Number of channels produced by both convolutions.
        kernel_size: Side length of the square convolution kernels.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()

        pad = kernel_size // 2
        stages = []
        # Stage 1 maps in_ch -> out_ch; stage 2 maps out_ch -> out_ch.
        # Layers are created in the same order as a hand-written
        # Sequential(conv, bn, relu, conv, bn, relu), so parameter
        # initialisation order is unchanged.
        for src_ch in (in_ch, out_ch):
            stages.append(nn.Conv2d(src_ch, out_ch, kernel_size=kernel_size,
                                    stride=1, padding=pad, bias=True))
            stages.append(SynchronizedBatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        # A single Sequential named ``conv`` keeps state_dict keys as
        # conv.0.*, conv.1.*, ... for checkpoint compatibility.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages and return the result."""
        return self.conv(x)


class up_conv(nn.Module):
    """Up-convolution block: bilinear upsample, then Conv2d -> SyncBN -> ReLU.

    The input is upsampled by ``scale_factor`` with bilinear interpolation
    (``align_corners=True``). A stride-1 convolution with
    ``padding = kernel_size // 2`` follows, then batch norm and an in-place
    ReLU. With the default odd ``kernel_size`` the convolution keeps the
    upsampled spatial size, so the output is ``scale_factor`` times larger
    in height and width than the input.

    Bilinear ``nn.Upsample`` only accepts 4-D input, so ``x`` is expected to
    be shaped (N, C, H, W).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Side length of the square convolution kernel. Defaults
            to 3, which gives padding 1 as in the original fixed layout.
        scale_factor: Spatial upsampling factor. Defaults to 2.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, scale_factor=2):
        super(up_conv, self).__init__()
        BatchNorm = SynchronizedBatchNorm2d
        # Same "same-size" padding rule as conv_block; equals 1 for the
        # default kernel_size=3, so default behaviour is unchanged.
        padding = kernel_size // 2
        # Layer order (and hence state_dict keys up.1.*, up.2.*) matches the
        # original fixed implementation.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear',
                        align_corners=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                      stride=1, padding=padding, bias=True),
            BatchNorm(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Upsample ``x`` and apply conv/BN/ReLU; return the result."""
        x = self.up(x)
        return x
