import torch.nn as nn


class Autoencoder(nn.Module):
    """Fully connected autoencoder built from ``nn.Linear`` + ``nn.ReLU`` stacks.

    The encoder maps ``input_size`` features through ``hidden_layers`` hidden
    layers of width ``base_width * 2**i`` (i = 0 .. hidden_layers - 1), followed
    by a linear bottleneck (no activation) to ``encoding_size`` features.
    Note: the hidden widths *grow* with depth (128, 256, 512, 1024 by default)
    rather than narrowing toward the bottleneck.

    The decoder mirrors the hidden widths in reverse order and ends with
    ``nn.Sigmoid``, so reconstructions lie in [0, 1].

    Args:
        input_size: Number of features in the last dimension of the input.
        encoding_size: Width of the latent (bottleneck) representation.
        hidden_layers: Number of hidden Linear+ReLU layers on each side.
            ``0`` yields a single Linear in each direction.
        base_width: Width of the outermost hidden layer; each deeper hidden
            layer doubles it. Defaults to 128.
    """

    def __init__(
        self,
        input_size,
        encoding_size=512,
        hidden_layers=4,
        *,
        base_width=128,
    ):
        super().__init__()

        # Hidden widths, outermost first: base_width, 2*base_width, 4*base_width, ...
        layer_sizes = [base_width * (2**i) for i in range(hidden_layers)]

        # Encoder is built before the decoder so nn.Linear weight initialisation
        # consumes the RNG in the same order as before.
        self.encoder = self._build_stack(input_size, layer_sizes, encoding_size)
        # Decoder mirrors the hidden widths; Sigmoid squashes the reconstruction.
        self.decoder = self._build_stack(
            encoding_size, layer_sizes[::-1], input_size, nn.Sigmoid()
        )
        self.encoding_size = encoding_size

    @staticmethod
    def _build_stack(in_features, hidden_sizes, out_features, final_activation=None):
        """Return an ``nn.Sequential`` of Linear+ReLU per hidden size, then a Linear.

        ``final_activation``, when given, is appended after the last Linear.
        The module order (and therefore the ``state_dict`` key indices) is
        ``[Linear, ReLU] * len(hidden_sizes) + [Linear] (+ [final_activation])``.
        """
        modules = []
        prev = in_features
        for size in hidden_sizes:
            modules += [nn.Linear(prev, size), nn.ReLU()]
            prev = size
        modules.append(nn.Linear(prev, out_features))
        if final_activation is not None:
            modules.append(final_activation)
        return nn.Sequential(*modules)

    def forward(self, x):
        """Reconstruct ``x`` by encoding it and decoding the latent code."""
        return self.decode(self.encode(x))

    def encode(self, x):
        """Map ``x`` (last dim ``input_size``) to its latent code (last dim ``encoding_size``)."""
        return self.encoder(x)

    def decode(self, z):
        """Map a latent code (last dim ``encoding_size``) back to input space in [0, 1]."""
        return self.decoder(z)
