import torch.nn as nn


class Flatten(nn.Module):
    """Flatten every dimension except the first (batch) one.

    Maps an input of shape ``(N, d1, d2, ...)`` to ``(N, d1 * d2 * ...)``; a 1-D
    input of shape ``(N,)`` becomes ``(N, 1)``.
    """

    def forward(self, input):
        # ``input`` shadows the builtin but is kept for keyword-call compatibility.
        # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous tensors
        # (e.g. after ``transpose``/``permute``); ``reshape`` copies only when needed.
        # The trailing size is computed explicitly instead of passing -1, because
        # ``reshape(0, -1)`` is ambiguous and raises for an empty batch.
        trailing = input.shape[1:].numel()
        return input.reshape(input.size(0), trailing)


class Unflatten(nn.Module):
    """Reshape each sample of a batch into a ``(channel, height, width)`` tensor.

    The first dimension of the input is kept as the batch size ``N``; the remaining
    elements of each sample must total ``channel * height * width``. The output
    shape is ``(N, channel, height, width)``.

    Args:
        channel: number of channels in the output.
        height: output height.
        width: output width.
    """

    def __init__(self, channel, height, width):
        super(Unflatten, self).__init__()
        self.channel = channel
        self.height = height
        self.width = width

    def forward(self, input):
        # ``reshape`` rather than ``view`` so non-contiguous inputs are accepted too
        # (``view`` raises on them); contiguous inputs still get a zero-copy view.
        return input.reshape(input.size(0), self.channel, self.height, self.width)


class Conv1d(nn.Module):
    """Pointwise convolution block: 1x1 Conv2d -> BatchNorm2d -> ReLU.

    Despite its name, this wraps ``nn.Conv2d`` with a 1x1 kernel (stride 1,
    padding 0), not ``nn.Conv1d``. It mixes channels without changing the
    spatial size.

    Args:
        inp: number of input channels.
        oup: number of output channels.
    """

    def __init__(self, inp, oup):
        super().__init__()
        # bias=False: the BatchNorm2d that follows carries its own learnable shift.
        pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
        stages = [pointwise, nn.BatchNorm2d(oup), nn.ReLU()]
        # Stored as one Sequential so parameter names are layer.0 / layer.1.
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv/BN/ReLU stack."""
        activated = self.layer(x)
        return activated


class ConvDW(nn.Module):
    """Depthwise-separable convolution block.

    Depthwise stage: a 3x3 ``Conv2d`` with ``groups=inp`` (each input channel is
    filtered on its own), padding 1 and the given stride, then BatchNorm2d + ReLU.
    Pointwise stage: a 1x1 ``Conv2d`` mapping ``inp`` -> ``oup`` channels, then
    BatchNorm2d + ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise 3x3 convolution.
    """

    def __init__(self, inp, oup, stride):
        super().__init__()
        depthwise = [
            nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(),
        ]
        pointwise = [
            nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(),
        ]
        # One flat Sequential keeps parameter names as layer.0 ... layer.5.
        self.layer = nn.Sequential(*depthwise, *pointwise)

    def forward(self, x):
        """Apply the depthwise stage followed by the pointwise stage to ``x``."""
        out = self.layer(x)
        return out


class ConvTransposeDW(nn.Module):
    """Upsampling block: 1x1 Conv2d (inp -> oup) + BN + ReLU, then 3x3 ConvTranspose2d + BN + ReLU.

    With ``kernel_size=3``, ``padding=1`` and ``output_padding = stride - 1`` the
    transposed convolution maps each spatial size ``H`` to ``H * stride``.

    NOTE(review): despite the "DW" name, the transposed conv uses ``groups=1`` (a
    dense oup -> oup convolution), unlike the depthwise stage in ``ConvDW``. Changing
    it would alter the weight shape and break existing checkpoints, so it is kept.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: upsampling factor of the transposed conv; an int or a per-axis
            ``(h, w)`` pair, as accepted by ``nn.ConvTranspose2d``.
    """

    def __init__(self, inp, oup, stride):
        super(ConvTransposeDW, self).__init__()
        # output_padding must be one less than the stride on each axis; support
        # per-axis strides instead of failing on ``tuple - int``.
        if isinstance(stride, (tuple, list)):
            output_padding = tuple(s - 1 for s in stride)
        else:
            output_padding = stride - 1
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(),

            nn.ConvTranspose2d(in_channels=oup, out_channels=oup, kernel_size=3, stride=stride, padding=1,
                               output_padding=output_padding, groups=1, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU()
        )

    def forward(self, x):
        """Apply the pointwise projection followed by the upsampling stage to ``x``."""
        return self.layer(x)
