import torch.nn as nn

class Autoencoder(nn.Module):

    def __init__(self, input_flatten=784, latent_space_input=10):
        """Build the fully connected encoder and decoder stacks.

        The encoder maps ``input_flatten -> 500 -> 500 -> 2000 ->
        latent_space_input``. The decoder mirrors it back to
        ``input_flatten``.

        Args:
            input_flatten: Number of features in one flattened input sample.
                ``nn.Flatten()`` (default ``start_dim=1``) collapses every
                dimension after the first, so the product of those
                dimensions must equal this value. The default 784 equals
                28 * 28.
            latent_space_input: Width of the code produced by the encoder
                and consumed by the decoder.
        """
        super().__init__()

        self.input_flatten = input_flatten
        self.latent_space_input = latent_space_input

        # Module positions inside each nn.Sequential determine the
        # state_dict keys (encoder.1 ... encoder.7, decoder.0 ... decoder.6).
        # Reordering or inserting modules would rename saved parameters.
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_flatten, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 2000),
            nn.ReLU(inplace=True),
            nn.Linear(2000, self.latent_space_input),
            # Sigmoid squashes every latent coordinate into the range [0, 1].
            nn.Sigmoid(),
        )

        self.decoder = nn.Sequential(
            nn.Linear(self.latent_space_input, 2000),
            nn.ReLU(inplace=True),
            nn.Linear(2000, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, self.input_flatten),
            # NOTE(review): this final ReLU clamps reconstructions to >= 0, so
            # inputs that contain negative values (e.g. mean/std-normalized
            # data) cannot be reproduced -- confirm this output activation is
            # intended.
            nn.ReLU(inplace=True),
        )