import torch.nn as nn


class MLPEmbeddingLayer(nn.Module):
    """Input projection: a single linear layer followed by a ReLU.

    The linear layer maps ``config.input_size`` features to
    ``config.hidden_size[0]`` features.
    """

    def __init__(self, config):
        """Build the layer from ``config``.

        Args:
            config: Object exposing ``input_size`` and an indexable
                ``hidden_size``; only ``hidden_size[0]`` is read here.
        """
        super().__init__()
        in_features = config.input_size
        out_features = config.hidden_size[0]
        self.embedding = nn.Linear(in_features, out_features)
        # In-place ReLU: overwrites the linear layer's output tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        """Return ``relu(embedding(input))``.

        Args:
            input: Tensor accepted by ``nn.Linear`` (presumably with
                ``config.input_size`` as its last dimension -- confirm
                against callers).
        """
        return self.relu(self.embedding(input))


class MLPFullyConnectedLayer(nn.Module):
    """Hidden MLP block: a linear layer followed by a ReLU."""

    def __init__(self, hidden_input, hidden_output):
        """Create the linear layer and its activation.

        Args:
            hidden_input: Number of input features of the linear layer.
            hidden_output: Number of output features of the linear layer.
        """
        super().__init__()
        self.dense = nn.Linear(
            in_features=hidden_input,
            out_features=hidden_output,
        )
        # In-place ReLU: overwrites the linear layer's output tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        """Return ``relu(dense(input))``.

        Args:
            input: Tensor accepted by ``nn.Linear`` (presumably with
                ``hidden_input`` as its last dimension -- confirm against
                callers).
        """
        return self.relu(self.dense(input))


class CNNLayer(nn.Module):
    """2-D convolution followed by a ReLU."""

    def __init__(self, in_channels, out_channels, padding=0, stride=2, kernel_size=4):
        """Create the convolution and its activation.

        Args:
            in_channels: Number of channels of the input.
            out_channels: Number of channels produced by the convolution.
            padding: Padding forwarded to ``nn.Conv2d``. Defaults to 0.
            stride: Stride forwarded to ``nn.Conv2d``. Defaults to 2.
            kernel_size: Kernel size forwarded to ``nn.Conv2d``. Defaults
                to 4, the value that used to be hard-coded; it is placed
                last so existing positional calls keep their meaning.
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # NOTE: no normalization is applied; a BatchNorm2d(out_channels)
        # stage only ever existed here as commented-out code.
        # In-place ReLU: overwrites the convolution's output tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        """Return ``relu(conv(input))``.

        Args:
            input: Tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels.
        """
        return self.relu(self.conv(input))


class CNNTrasnposedLayer(nn.Module):
    """2-D transposed convolution followed by a ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        """Create the transposed convolution and its activation.

        Args:
            in_channels: Number of channels of the input.
            out_channels: Number of channels produced by the layer.
            kernel_size: Kernel size forwarded to ``nn.ConvTranspose2d``.
                Defaults to 4.
            stride: Stride forwarded to ``nn.ConvTranspose2d``. Defaults
                to 2.
            padding: Padding forwarded to ``nn.ConvTranspose2d``. Defaults
                to 1, the value that used to be hard-coded; it is placed
                last so existing positional calls keep their meaning.
        """
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # NOTE: no normalization is applied; a BatchNorm2d(out_channels)
        # stage only ever existed here as commented-out code.
        # In-place ReLU: overwrites the transposed convolution's output.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        """Return ``relu(conv(input))``.

        Args:
            input: Tensor accepted by ``nn.ConvTranspose2d`` with
                ``in_channels`` channels.
        """
        return self.relu(self.conv(input))
