import torch
import torch.nn as nn


class ConvAE(nn.Module):
    """
    Convolutional autoencoder built from kernel-size-1 convolutions and
    fully connected layers.

    The encoder applies two ``Conv1d`` layers (kernel size 1), flattens the
    result and projects it through a stack of ``Linear`` layers down to the
    latent code. The decoder mirrors that structure using ``Linear`` layers
    followed by two ``ConvTranspose1d`` layers (kernel size 1).
    """

    def __init__(
        self,
        input_dim: int = 1024,
        hidden_spec: list = [512],
        channel_spec: list = [32, 64],
    ) -> None:
        """
        Build the encoder and decoder networks.

        Parameters
        ----------
        input_dim : int
            Dimension of each input vector (one instance).

        hidden_spec : list
            Output sizes of the fully connected encoder layers, in order.
            The last entry is the dimension of the bottleneck (latent)
            layer. The decoder uses the same sizes in reverse order.

        channel_spec : list
            Number of channels for each of the two convolutional layers.
            Must contain exactly two entries.

        Raises
        ------
        AssertionError
            If ``input_dim`` or the latent dimension is not positive, if
            ``hidden_spec`` is empty, or if ``channel_spec`` does not have
            exactly two entries.
        """

        super().__init__()
        assert (
            input_dim > 0 and len(hidden_spec) > 0 and hidden_spec[-1] > 0
        ), "Impossible definition of input or latent dimensionality."
        assert (
            len(channel_spec) == 2
        ), "The definition of the hidden layers number of channels must be of length 2."

        self.input_dim = input_dim
        # Copy the specs so that neither the shared default lists nor the
        # caller's lists are aliased by the model's attributes.
        self.channel_definition = list(channel_spec)
        self.encoder_spec = list(hidden_spec)
        # Decoder layer sizes: the hidden sizes mirrored, ending at input_dim.
        self.decoder_spec = list(reversed(self.encoder_spec))
        self.decoder_spec.append(self.input_dim)
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

    def build_encoder(self):
        """
        Assemble the encoder.

        Returns
        -------
        nn.Sequential
            Two kernel-size-1 ``Conv1d`` + ``ReLU`` stages, a ``Flatten``,
            a ``Linear`` projection back to ``input_dim``, and one
            ``Linear`` layer per entry of ``encoder_spec``. Intermediate
            linear layers use ``ReLU``; the final (latent) layer uses
            ``Tanh``.
        """
        first_channels, last_channels = self.channel_definition

        encoder_layers = [
            # The first convolution takes a single input channel.
            nn.Conv1d(1, first_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(first_channels, last_channels, kernel_size=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(last_channels * self.input_dim, self.input_dim),
            nn.ReLU(),
        ]

        # Consecutive (in_features, out_features) pairs of the dense stack.
        sizes = [self.input_dim, *self.encoder_spec]
        last_index = len(self.encoder_spec) - 1
        for i, (in_size, out_size) in enumerate(zip(sizes, sizes[1:])):
            encoder_layers.append(nn.Linear(in_size, out_size))
            # The latent layer is squashed with Tanh, the others use ReLU.
            encoder_layers.append(nn.Tanh() if i == last_index else nn.ReLU())

        return nn.Sequential(*encoder_layers)

    def build_decoder(self):
        """
        Assemble the decoder.

        Returns
        -------
        nn.Sequential
            One ``Linear`` + ``ReLU`` per consecutive pair of sizes in
            ``decoder_spec``, a ``Linear`` expansion to
            ``channel_definition[-1] * input_dim`` features, an
            ``Unflatten`` to ``(channel_definition[-1], input_dim)``, and
            two kernel-size-1 ``ConvTranspose1d`` layers ending in a single
            output channel with ``Tanh``.
        """
        first_channels, last_channels = self.channel_definition

        decoder_layers = []
        # Chain each layer's input width to the previous layer's output
        # width; using decoder_spec[0] for every layer only works when
        # there is a single hidden size.
        sizes = self.decoder_spec
        for in_size, out_size in zip(sizes, sizes[1:]):
            decoder_layers.extend((nn.Linear(in_size, out_size), nn.ReLU()))

        decoder_layers.extend(
            [
                nn.Linear(self.input_dim, last_channels * self.input_dim),
                nn.ReLU(),
                # Restore the (channels, length) layout flattened by the encoder.
                nn.Unflatten(1, (last_channels, self.input_dim)),
                nn.ConvTranspose1d(last_channels, first_channels, kernel_size=1),
                nn.ReLU(),
                # Collapse back to a single output channel.
                nn.ConvTranspose1d(first_channels, 1, kernel_size=1),
                nn.Tanh(),
            ]
        )
        return nn.Sequential(*decoder_layers)

    def forward(self, x: torch.Tensor):
        """
        Encode ``x`` and reconstruct it from the latent code.

        Parameters
        ----------
        x : torch.Tensor
            Input batch. The first encoder layer is ``Conv1d`` with one
            input channel and the flattened features feed a ``Linear`` of
            width ``channel_definition[-1] * input_dim``, so a batched
            input is expected to be shaped ``(batch, 1, input_dim)``.

        Returns
        -------
        tuple
            ``(reconstruction, latent)``: the decoder output and the
            encoder output, in that order.
        """
        latent = self.encoder(x)
        return self.decoder(latent), latent
