from typing import Optional, Tuple

import torch
import torch.nn as nn


class TransformerAE(nn.Module):
    """
    Transformer based autoencoder for sequences of scalar features.

    The encoder is a single transformer encoder block followed by a linear
    layer with tanh activation that projects the flattened sequence down to
    the latent dimension. The decoder mirrors it: the latent vector is
    treated as a sequence of scalars, passed through a transformer encoder
    block and projected back to ``input_dim`` values with a linear layer and
    tanh activation.
    """

    def __init__(
        self,
        input_dim: int = 1024,
        pre_proj: bool = False,
        hidden_spec: Optional[list] = None,
        nhead: int = 1,
        dropout: float = 0.1,
    ) -> None:
        """
        Build the encoder and decoder sub-networks.

        Parameters
        ----------
        input_dim : int
            Length of each input sequence (number of features).
            The dimension of each feature is assumed to be one.

        pre_proj : bool
            Only influences the validation of ``hidden_spec``: three entries
            are required if True, two otherwise. No additional projection
            layer is built by this class.

        hidden_spec : list, optional
            Determines the hidden dimension of each component of the encoder
            and decoder. ``hidden_spec[0]`` is the feed-forward width of the
            transformer blocks and ``hidden_spec[1]`` is the latent
            dimension. Defaults to ``[512, 256]``.

        nhead : int
            The number of heads in the multiheadattention models.

        dropout : float
            The dropout value used inside the transformer blocks.

        Raises
        ------
        AssertionError
            If the length of ``hidden_spec`` does not match the value
            required by ``pre_proj``.
        """
        super().__init__()

        # A fresh list per instance; a list literal in the signature would be
        # shared (and mutable) across all instances.
        if hidden_spec is None:
            hidden_spec = [512, 256]

        # The expected length must be computed first: writing
        # ``assert len(x) == 3 if pre_proj else 2`` asserts the constant 2
        # whenever pre_proj is False and therefore never fails.
        expected_len = 3 if pre_proj else 2
        assert len(hidden_spec) == expected_len, (
            f"hidden_spec must have {expected_len} entries when "
            f"pre_proj={pre_proj}, got {len(hidden_spec)}."
        )

        self.input_dim = input_dim
        # NOTE(review): with d_model fixed to 1 the LayerNorm inside the
        # transformer block normalizes over a single value, which appears to
        # make the block's output independent of its input - confirm that
        # this is intended.
        self.model_dim = 1
        # Copy so that later changes to the caller's list do not leak in.
        self.coder_spec = list(hidden_spec)
        self.nhead = nhead
        self.dropout = dropout
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

    def _build_transformer_block(self) -> nn.TransformerEncoderLayer:
        """Create one transformer encoder block operating on batch-first input."""
        return nn.TransformerEncoderLayer(
            d_model=self.model_dim,
            nhead=self.nhead,
            dim_feedforward=self.coder_spec[0],
            dropout=self.dropout,
            # forward() feeds (batch, seq, feature) tensors. The default
            # (batch_first=False) would treat the batch axis as the sequence
            # axis and attend across samples.
            batch_first=True,
        )

    def build_encoder(self) -> nn.Sequential:
        """
        Build the encoder consisting of a single transformer block, followed
        by a single linear layer with tanh activation which projects down to
        the latent dimension.
        """
        return nn.Sequential(
            self._build_transformer_block(),
            # (batch, input_dim, 1) -> (batch, input_dim)
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(self.input_dim, self.coder_spec[1]),
            nn.Tanh(),
        )

    def build_decoder(self) -> nn.Sequential:
        """
        Build the decoder, which reshapes the latent vector into a sequence
        of scalars, applies a single transformer block and projects back to
        ``input_dim`` values with a linear layer and tanh activation.
        """
        latent_dim = self.coder_spec[1]
        return nn.Sequential(
            # (batch, latent_dim) -> (batch, latent_dim, 1)
            nn.Unflatten(dim=1, unflattened_size=(latent_dim, 1)),
            self._build_transformer_block(),
            # (batch, latent_dim, 1) -> (batch, latent_dim)
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(latent_dim, self.input_dim),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the autoencoder.

        Parameters
        ----------
        x : torch.Tensor
            An input tensor of size (batch, seq, feature). E.g. (64, 10, 1)

        Returns
        -------
        tuple of torch.Tensor
            The decoder output of size (batch, input_dim) and the latent
            representation of size (batch, hidden_spec[1]).
        """
        assert x.shape[-1] == 1, "Feature dimension has to equal 1."
        assert len(x.shape) == 3, "Input shape not like (batch, seq, feature)"

        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction, latent
