from collections.abc import Sequence

import torch
import torch.nn as nn


class DeepCAE(nn.Module):
    """Multi-layer contractive autoencoder (CAE) built from Linear + Tanh layers.

    Architecture for the enhanced loss calculation proposed in
    https://arxiv.org/abs/2402.18164. This class does not compute the loss
    itself; ``forward`` exposes the encoder's hidden activations so that the
    caller can build it from them.
    """

    def __init__(self, input_dim: int = 1024, hidden_spec: Sequence[int] = (512, 256)) -> None:
        """
        Class for a multi-layer CAE with enhanced loss calculation as proposed in:
        https://arxiv.org/abs/2402.18164

        Parameters
        ----------
        input_dim : int
            Dimension of each input vector (one instance).

        hidden_spec : Sequence[int]
            Defines the number and dimension of hidden layers.
            The last number is the latent dimension. Defaults to ``(512, 256)``.
            The sequence is copied, so mutating it after construction does not
            affect the model.

        Raises
        ------
        ValueError
            If ``hidden_spec`` is empty.
        """
        super().__init__()

        # Copy into a fresh list so later mutation of the caller's sequence
        # cannot desync ``encoder_spec`` from the layers built below.
        encoder_spec = list(hidden_spec)
        # Explicit raise rather than ``assert``: asserts vanish under ``python -O``.
        if not encoder_spec:
            raise ValueError(
                "No hidden layers defined! Define at least one hidden layer for each encoder and decoder!"
            )

        self.input_dim = input_dim
        self.encoder_spec = encoder_spec
        # The decoder mirrors the encoder: latent -> ... -> first hidden -> input_dim.
        self.decoder_spec = [*reversed(encoder_spec), input_dim]
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

    @staticmethod
    def _tanh_mlp(sizes: Sequence[int]) -> nn.Sequential:
        """Return Linear+Tanh pairs connecting each consecutive pair in ``sizes``."""
        layers: list[nn.Module] = []
        for in_size, out_size in zip(sizes[:-1], sizes[1:]):
            layers.extend([nn.Linear(in_size, out_size), nn.Tanh()])
        return nn.Sequential(*layers)

    def build_encoder(self) -> nn.Sequential:
        """Build the encoder: ``input_dim`` through every size in ``encoder_spec``."""
        return self._tanh_mlp([self.input_dim, *self.encoder_spec])

    def build_decoder(self) -> nn.Sequential:
        """Build the decoder along ``decoder_spec`` (latent back to ``input_dim``).

        The final Linear layer is also followed by Tanh, so reconstructions lie
        in [-1, 1].
        """
        return self._tanh_mlp(self.decoder_spec)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Encode and reconstruct ``x``, also returning the encoder's hidden states.

        Parameters
        ----------
        x : torch.Tensor
            Input whose last dimension is ``input_dim`` (the first encoder
            ``nn.Linear`` acts on the last dimension).

        Returns
        -------
        reconstruction : torch.Tensor
            Decoder output with the same last dimension as ``x``.
        hidden_states : list[torch.Tensor]
            The Tanh output of every encoder layer, in order; the last entry is
            the latent code that is fed to the decoder.
        """
        hidden_states: list[torch.Tensor] = []
        hidden = x

        # Step through the encoder manually (instead of calling it as a whole)
        # so each post-activation output can be captured.
        for layer in self.encoder:
            hidden = layer(hidden)
            if isinstance(layer, nn.Tanh):
                hidden_states.append(hidden)

        return self.decoder(hidden), hidden_states
