from typing import Optional

import torch
import torch.nn as nn


class CAE(nn.Module):
    """Single-layer CAE: a linear+tanh encoder paired with a linear+tanh decoder."""

    def __init__(self, input_dim: int = 1024, latent_dim: int = 512) -> None:
        """
        Create the encoder/decoder pair of a single-layer CAE.

        The model is meant to be used with the standard loss proposed by the original CAE paper
        (https://icml.cc/2011/papers/455_icmlpaper.pdf); the loss itself is not computed in this class.

        NOTE: A single-layer instance of the DeepCAE class could theoretically be used instead,
        but its loss calculation differs from the one intended here.

        Parameters
        ----------
        input_dim : int
            Size of one input vector (a single instance).

        latent_dim : int
            Size of the bottleneck layer. Must be positive.
        """
        super().__init__()
        assert latent_dim > 0, "The latent dimension must be positive"

        # The layer sizes are stored first because the build_* helpers read them.
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.use_cuda = torch.cuda.is_available()

        # Encoder is created before the decoder so parameter registration order is stable.
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

    def build_encoder(self) -> nn.Sequential:
        """Return the encoder: linear map input_dim -> latent_dim followed by tanh."""
        return nn.Sequential(
            nn.Linear(self.input_dim, self.latent_dim, bias=True),
            nn.Tanh(),
        )

    def build_decoder(self) -> nn.Sequential:
        """Return the decoder: linear map latent_dim -> input_dim followed by tanh."""
        return nn.Sequential(
            nn.Linear(self.latent_dim, self.input_dim, bias=True),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode the input and decode it again.

        Returns
        -------
        tuple
            The decoder output and the latent code it was decoded from, in that order.
        """
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        return reconstruction, code


class StackedCAE(nn.Module):
    """Stack of single-layer CAE modules whose encoders (and decoders) are chained."""

    def __init__(self, input_dim: int = 1024, hidden_spec: Optional[list] = None) -> None:
        """
        Class for a stacked CAE with standard loss calculation as proposed by the original CAE paper
        (as it builds on the CAE class):
        https://icml.cc/2011/papers/455_icmlpaper.pdf

        One CAE is created per entry of ``hidden_spec``; the i-th CAE maps the previous
        layer's size (``input_dim`` for the first one) to ``hidden_spec[i]``.

        Parameters
        ----------
        input_dim : int
            Dimension of each input vector (one instance).

        hidden_spec : list, optional
            Defines the number and dimension of hidden layers.
            The last number is the latent dimension.
            Defaults to ``[512, 256]`` when omitted or ``None``.

        Raises
        ------
        AssertionError
            If ``hidden_spec`` is empty.
        """

        super().__init__()

        # A list literal as the parameter default is created once and shared by every call;
        # resolving the default here and copying keeps each instance's spec independent of
        # both the default and the caller's list.
        spec = [512, 256] if hidden_spec is None else list(hidden_spec)
        assert (
            len(spec) > 0
        ), "No hidden layers defined! Define at least one hidden layer for each encoder and decoder!"

        self.input_dim = input_dim
        self.encoder_spec = spec
        # Mirror of the encoder sizes ending at the input size. Kept as an attribute;
        # the layers themselves are built from encoder_spec below.
        self.decoder_spec = [*reversed(spec), input_dim]

        # Initialize stack: pair every layer's output size with the size feeding into it.
        layer_inputs = [input_dim, *spec[:-1]]
        self.caes = nn.ModuleList(
            [
                CAE(input_dim=n_in, latent_dim=n_out)
                for n_in, n_out in zip(layer_inputs, spec)
            ]
        )

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Run input through all encoders.

        Returns
        -------
        tuple
            The output of the last encoder, and a list with the output of every encoder
            in stack order (its last element is that same final output).
        """

        hidden = x
        latents = []
        for cae in self.caes:
            hidden = cae.encoder(hidden)
            latents.append(hidden)

        return hidden, latents

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Run a given embedding through all decoders, from the last CAE back to the first."""

        output = x
        for cae in reversed(self.caes):
            output = cae.decoder(output)

        return output

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Encode and decode the input.

        Returns
        -------
        tuple
            The decoded output and the list of per-layer encoder outputs from ``encode``.
        """

        encoded, latents = self.encode(x)
        output = self.decode(encoded)
        return output, latents
