import torch
import torch.nn as nn


class StandardAE(nn.Module):
    """
    Fully connected autoencoder with a mirrored encoder/decoder layout.

    The encoder maps ``input_dim`` through every size in ``hidden_spec``;
    the decoder walks the same sizes in reverse order and projects back to
    ``input_dim``.

    Parameters
    ----------
    input_dim : int
        Dimension of each input vector (one instance).

    hidden_spec : list
        Defines the number and dimension of hidden layers.
        The last number is the latent dimension.
        ``None`` (the default) selects ``[512, 256]``.

    Raises
    ------
    AssertionError
        If ``hidden_spec`` is empty.
    """

    def __init__(self, input_dim: int = 1024, hidden_spec: list = None) -> None:
        """Validate the layer specification and build both sub-networks."""
        super().__init__()

        # A fresh list is created per instance: a literal list default would
        # be shared (and mutable) across every instance constructed with it.
        if hidden_spec is None:
            hidden_spec = [512, 256]

        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        if len(hidden_spec) == 0:
            raise AssertionError(
                "No hidden layers defined! Define at least one hidden layer for each encoder and decoder!"
            )

        self.input_dim = input_dim
        # Copy so that later mutation of the caller's list cannot make the
        # stored spec disagree with the layers that were actually built.
        self.encoder_spec = list(hidden_spec)
        self.decoder_spec = list(reversed(self.encoder_spec))
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

    def build_encoder(self) -> nn.Sequential:
        """
        Build the encoder from ``self.input_dim`` and ``self.encoder_spec``.

        Every linear layer is followed by ReLU, except the last one (the
        latent layer), which is followed by Tanh.

        Returns
        -------
        nn.Sequential
            The stacked encoder layers.
        """
        encoder_layers = []
        prev_size = self.input_dim
        last_index = len(self.encoder_spec) - 1

        for i, hidden_size in enumerate(self.encoder_spec):
            encoder_layers.append(nn.Linear(prev_size, hidden_size))
            # Tanh bounds the latent code; intermediate layers use ReLU.
            encoder_layers.append(nn.Tanh() if i == last_index else nn.ReLU())
            prev_size = hidden_size

        return nn.Sequential(*encoder_layers)

    def build_decoder(self) -> nn.Sequential:
        """
        Build the decoder from ``self.decoder_spec`` and ``self.input_dim``.

        Each linear layer, including the final projection back to
        ``input_dim``, is followed by Tanh.

        Returns
        -------
        nn.Sequential
            The stacked decoder layers.
        """
        decoder_layers = []
        prev_size = self.decoder_spec[0]

        # Starting with the second element since the first one is the latent dimension.
        for hidden_size in self.decoder_spec[1:]:
            decoder_layers.extend((nn.Linear(prev_size, hidden_size), nn.Tanh()))
            prev_size = hidden_size

        # Final projection back to the input dimension.
        decoder_layers.extend((nn.Linear(prev_size, self.input_dim), nn.Tanh()))
        return nn.Sequential(*decoder_layers)

    def forward(self, x: torch.Tensor):
        """
        Encode ``x`` and decode the resulting latent code.

        Parameters
        ----------
        x : torch.Tensor
            Input passed to the encoder; its last dimension must match
            ``input_dim`` for the first linear layer to accept it.

        Returns
        -------
        tuple
            ``(reconstruction, latent)`` - the decoder output and the
            encoder output, in that order.
        """
        latent = self.encoder(x)
        return self.decoder(latent), latent
