import math

import torch.nn as nn


class Conv2DMod(nn.Module):
    """Thin wrapper around ``nn.Conv2d`` that can derive its own padding.

    When ``padding`` is ``None`` it is set to ``ceil((kernel_size - stride) / 2)``,
    clamped at 0. For example, kernel 3 / stride 2 gives padding 1, which halves
    the spatial size (rounding up). Kernel 3 / stride 1 gives padding 1, which
    keeps the size unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size. This is a scalar, because the padding
            arithmetic below assumes an ``int``.
        stride: convolution stride.
        padding: explicit padding, or ``None`` to derive it as described above.
        **kwargs: forwarded verbatim to ``nn.Conv2d`` (e.g. ``bias``). The derived
            padding does not account for a non-default ``dilation``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 2,
            padding: int = None,
            **kwargs
        ):
        super().__init__()

        if padding is None:
            # When kernel_size < stride - 1 the raw formula goes negative
            # (e.g. kernel 1 / stride 3 -> -1), and Conv2d cannot use negative
            # padding. Clamp it to zero.
            padding = max(0, math.ceil((kernel_size - stride) / 2))

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            **kwargs
        )

    def forward(self, x):
        """Apply the wrapped ``nn.Conv2d`` to ``x``."""
        return self.conv(x)
    


class ConvTranspose2dMod(nn.Module):
    """Thin wrapper around ``nn.ConvTranspose2d`` that can derive its padding.

    PyTorch's transposed-convolution size rule (with dilation 1) is::

        out = (in - 1) * stride - 2 * padding + kernel_size + output_padding

    When both ``padding`` and ``output_padding`` are ``None``, they are chosen so
    that ``out == in * stride``. When only ``padding`` is given,
    ``output_padding`` is solved from the same rule. When only
    ``output_padding`` is given, ``padding`` is derived from it.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size. This is a scalar, because the padding
            arithmetic below assumes an ``int``.
        stride: transposed-convolution stride (upsampling factor).
        padding: explicit padding, or ``None`` to derive it.
        output_padding: explicit output padding, or ``None`` to derive it.
        **kwargs: forwarded verbatim to ``nn.ConvTranspose2d``. The derived values
            assume the default ``dilation`` of 1.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 2,
            padding: int = None,
            output_padding: int = None,
            **kwargs
        ):
        super().__init__()

        if padding is None:
            if output_padding is None:
                if kernel_size < stride:
                    # The parity formula below would give a negative padding
                    # here (e.g. kernel 1 / stride 3 -> padding -1). Padding 0
                    # with output_padding = stride - kernel_size still gives
                    # out == in * stride, and keeps output_padding < stride.
                    padding = 0
                    output_padding = stride - kernel_size
                else:
                    # Choose output_padding so that (kernel - stride + op) is
                    # even, which makes the padding an exact integer.
                    output_padding = (kernel_size - stride) % 2
                    padding = (kernel_size - stride + output_padding) // 2
            else:
                padding = int((kernel_size - stride + output_padding) / 2)
        elif output_padding is None:
            # Solve the size rule for output_padding given the caller's padding.
            output_padding = 2 * padding - kernel_size + stride

        self.deconv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            **kwargs
        )

    def forward(self, x):
        """Apply the wrapped ``nn.ConvTranspose2d`` to ``x``."""
        return self.deconv(x)


class Conv2DBlock(nn.Module):
    """Convolution stage: ``Conv2DMod`` -> ``BatchNorm2d`` -> ReLU -> ``Dropout2d``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution, which are also the
            features normalised by ``BatchNorm2d``.
        kernel_size: forwarded to ``Conv2DMod``.
        stride: forwarded to ``Conv2DMod``.
        padding: forwarded to ``Conv2DMod``. ``None`` lets it derive the padding.
        relu: if False, ``nn.Identity`` replaces the ReLU activation.
        dropout_p: drop probability passed to ``nn.Dropout2d``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 2,
            padding: int = None,
            relu: bool = True,
            dropout_p: float = 0.
        ):
        super().__init__()

        # Submodules are registered in the order they are applied. Attribute
        # names are kept stable because they form the state_dict keys.
        self.conv = Conv2DMod(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = nn.Identity()
        self.dropout = nn.Dropout2d(dropout_p)

    def forward(self, x):
        """Run ``x`` through conv, batch-norm, activation and dropout in turn."""
        for stage in (self.conv, self.bn, self.relu, self.dropout):
            x = stage(x)
        return x



class ConvTranspose2DBlock(nn.Module):
    """Upsampling stage: ``ConvTranspose2dMod`` -> ``BatchNorm2d`` -> ReLU -> ``Dropout2d``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution, which are
            also the features normalised by ``BatchNorm2d``.
        kernel_size: forwarded to ``ConvTranspose2dMod``.
        stride: forwarded to ``ConvTranspose2dMod``.
        padding: forwarded to ``ConvTranspose2dMod``. ``None`` lets it derive
            the value.
        output_padding: forwarded to ``ConvTranspose2dMod``. ``None`` lets it
            derive the value.
        relu: if False, ``nn.Identity`` replaces the ReLU activation.
        dropout_p: drop probability passed to ``nn.Dropout2d``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 2,
            padding: int = None,
            output_padding: int = None,
            relu: bool = True,
            dropout_p: float = 0.
        ):
        super().__init__()

        # Registration order mirrors the forward pass. Attribute names are kept
        # stable because they form the state_dict keys.
        self.deconv = ConvTranspose2dMod(
            in_channels, out_channels, kernel_size, stride, padding, output_padding
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = nn.Identity()
        self.dropout = nn.Dropout2d(dropout_p)

    def forward(self, x):
        """Run ``x`` through deconv, batch-norm, activation and dropout in turn."""
        for stage in (self.deconv, self.bn, self.relu, self.dropout):
            x = stage(x)
        return x