from torch import nn
import torch.nn.functional as F

from .conv2d import Conv2DMod, ConvTranspose2dMod


class ResBlock(nn.Module):
    """Basic two-convolution residual block built on ``Conv2DMod``.

    Main path: conv1 (``kernel_size``, ``stride``) -> BatchNorm -> ReLU ->
    dropout -> conv2 (``kernel_size``, stride 1) -> BatchNorm.
    Shortcut: the identity when ``stride == 1`` and
    ``in_channels == expansion * out_channels``; otherwise a 1x1
    ``Conv2DMod`` with the same ``stride`` followed by BatchNorm.
    The two paths are summed and passed through a final ReLU, or through
    ``nn.Identity`` when ``relu=False``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the main path.
        kernel_size: Kernel size of the two main-path convolutions.
        stride: Stride of the first main-path convolution and of the
            projection shortcut.
        relu: Whether to apply ReLU after the residual sum.
        dropout_p: Dropout probability applied between the two
            convolutions. ``0`` (the default) disables dropout.

    Raises:
        ValueError: If ``dropout_p`` is outside ``[0, 1]``.
    """

    # Channel multiplier of the block output relative to ``out_channels``.
    expansion = 1

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 2,
            relu: bool = True,
            dropout_p: float = 0.
        ) -> None:
        super().__init__()
        if not 0.0 <= dropout_p <= 1.0:
            raise ValueError(f"dropout_p must be in [0, 1], got {dropout_p}")

        self.conv1 = Conv2DMod(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False  # followed by BatchNorm, which has its own shift
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Previously ``dropout_p`` was accepted but ignored. Identity when
        # p == 0 keeps the default behaviour (and state_dict) unchanged.
        self.dropout = nn.Dropout(dropout_p) if dropout_p > 0.0 else nn.Identity()

        self.conv2 = Conv2DMod(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Project the input only when the main path changes stride or
        # channel count; otherwise pass it through unchanged.
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                Conv2DMod(
                    in_channels=in_channels,
                    out_channels=self.expansion * out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(self.expansion * out_channels)
            )
        else:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()

        # Final activation applied after the residual sum (optional).
        self.relu = nn.ReLU() if relu else nn.Identity()

    def forward(self, x):
        """Return ``act(main_path(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        # In-place add is safe here: ``out`` is a fresh BatchNorm output,
        # never the caller's ``x``.
        out += self.shortcut(x)
        return self.relu(out)
    


class ResTransposeBlock(nn.Module):
    """Residual block built on transposed convolutions (``ConvTranspose2dMod``).

    Mirrors ``ResBlock`` with every convolution replaced by
    ``ConvTranspose2dMod``.
    Main path: conv1 (``kernel_size``, ``stride``) -> BatchNorm -> ReLU ->
    dropout -> conv2 (``kernel_size``, stride 1) -> BatchNorm.
    Shortcut: the identity when ``stride == 1`` and
    ``in_channels == expansion * out_channels``; otherwise a 1x1
    ``ConvTranspose2dMod`` with the same ``stride`` followed by BatchNorm.
    The two paths are summed and passed through a final ReLU, or through
    ``nn.Identity`` when ``relu=False``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the main path.
        kernel_size: Kernel size of the two main-path transposed convolutions.
        stride: Stride of the first main-path transposed convolution and of
            the projection shortcut.
        relu: Whether to apply ReLU after the residual sum.
        dropout_p: Dropout probability applied between the two
            convolutions. ``0`` (the default) disables dropout.

    Raises:
        ValueError: If ``dropout_p`` is outside ``[0, 1]``.
    """

    # Channel multiplier of the block output relative to ``out_channels``.
    expansion = 1

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 2,
            relu: bool = True,
            dropout_p: float = 0.
        ) -> None:
        super().__init__()
        if not 0.0 <= dropout_p <= 1.0:
            raise ValueError(f"dropout_p must be in [0, 1], got {dropout_p}")

        self.conv1 = ConvTranspose2dMod(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False  # followed by BatchNorm, which has its own shift
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Previously ``dropout_p`` was accepted but ignored. Identity when
        # p == 0 keeps the default behaviour (and state_dict) unchanged.
        self.dropout = nn.Dropout(dropout_p) if dropout_p > 0.0 else nn.Identity()

        self.conv2 = ConvTranspose2dMod(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Project the input only when the main path changes stride or
        # channel count; otherwise pass it through unchanged.
        # NOTE(review): the sum in forward() requires the 1x1 strided
        # transposed shortcut to produce the same spatial size as the main
        # path; that depends on how ConvTranspose2dMod pads — confirm.
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                ConvTranspose2dMod(
                    in_channels=in_channels,
                    out_channels=self.expansion * out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(self.expansion * out_channels)
            )
        else:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()

        # Final activation applied after the residual sum (optional).
        self.relu = nn.ReLU() if relu else nn.Identity()

    def forward(self, x):
        """Return ``act(main_path(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        # In-place add is safe here: ``out`` is a fresh BatchNorm output,
        # never the caller's ``x``.
        out += self.shortcut(x)
        return self.relu(out)