import torch
import torch.nn as nn
from abc import ABC, abstractmethod

class AbstractLatentAutoencoder(nn.Module, ABC):
    """
    Abstract encoder-decoder latent model for table data.

    Subclass contract:
      * implement ``_prepare_encoder_input`` and ``_prepare_decoder_input``;
      * assign ``self.encoder`` and ``self.decoder`` (callables, typically
        ``nn.Module`` instances). This base class never creates them, but
        ``encode`` and ``decode`` call them. If they are missing, the first call
        fails with ``AttributeError``.
    """

    def __init__(self, d_lm: int, d_latent_len: int, d_latent_width: int):
        """
        Store the dimensions shared by the encoder and decoder.

        Args:
            d_lm (int): Dimension of input embeddings (value, column name, dtype, metadata).
            d_latent_len (int): Length of the latent compressed representation.
            d_latent_width (int): Width of the latent compressed representation.
        """
        super().__init__()
        self.d_lm = d_lm
        self.d_latent_len = d_latent_len
        self.d_latent_width = d_latent_width

    def encode(self, value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None) -> torch.Tensor:
        """
        Encode input embeddings into a compressed latent representation.

        The inputs are first combined by ``_prepare_encoder_input``. The result
        is then passed to ``self.encoder``.

        Args:
            value_embedding (torch.Tensor): Input value embeddings, expected shape (B, n_cols, d_lm).
            column_name_embedding (torch.Tensor): Column name embeddings, expected shape (n_cols, d_lm).
            dtype_embedding (torch.Tensor): Dtype embeddings, expected shape (n_cols, d_lm).
            metadata_embedding (torch.Tensor): Metadata embedding, expected shape (d_lm).
            attention_mask (torch.Tensor, optional): Mask for the input, expected shape (B, n_cols).
                Defaults to None.

        Returns:
            torch.Tensor: Whatever ``self.encoder`` returns. The intended latent
            shape is (B, d_latent_len, d_latent_width).
        """
        # Positional forwarding keeps compatibility with subclasses that
        # named the hook parameters differently.
        encoder_input = self._prepare_encoder_input(
            value_embedding,
            column_name_embedding,
            dtype_embedding,
            metadata_embedding,
            attention_mask,
        )
        return self.encoder(encoder_input)

    def decode(self, compressed_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None) -> torch.Tensor:
        """
        Decode a compressed latent representation back into value embeddings.

        The inputs are first combined by ``_prepare_decoder_input``. The result
        is then passed to ``self.decoder``.

        Args:
            compressed_embedding (torch.Tensor): Latent embeddings, expected shape
                (B, d_latent_len, d_latent_width).
            column_name_embedding (torch.Tensor): Column name embeddings, expected shape (n_cols, d_lm).
            dtype_embedding (torch.Tensor): Dtype embeddings, expected shape (n_cols, d_lm).
            metadata_embedding (torch.Tensor): Metadata embedding, expected shape (d_lm).
            attention_mask (torch.Tensor, optional): Mask for the input, expected shape (B, n_cols).
                Defaults to None.

        Returns:
            torch.Tensor: Decoded value embeddings, i.e. whatever ``self.decoder``
            returns. This is presumably (B, n_cols, d_lm), mirroring
            ``value_embedding`` in ``encode``; confirm against the concrete
            subclass.
        """
        decoder_input = self._prepare_decoder_input(
            compressed_embedding,
            column_name_embedding,
            dtype_embedding,
            metadata_embedding,
            attention_mask,
        )
        return self.decoder(decoder_input)

    @abstractmethod
    def _prepare_encoder_input(self, value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None):
        """
        Build the input fed to ``self.encoder``. Must be implemented in a derived class.

        Args:
            value_embedding (torch.Tensor): Input value embeddings.
            column_name_embedding (torch.Tensor): Column name embeddings.
            dtype_embedding (torch.Tensor): Dtype embeddings.
            metadata_embedding (torch.Tensor): Metadata embedding.
            attention_mask (torch.Tensor, optional): Mask for input tensor. Defaults to None.

        Returns:
            The object passed directly to ``self.encoder`` by ``encode``.
        """

    @abstractmethod
    def _prepare_decoder_input(self, compressed_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None):
        """
        Build the input fed to ``self.decoder``. Must be implemented in a derived class.

        Args:
            compressed_embedding (torch.Tensor): Compressed latent embeddings.
            column_name_embedding (torch.Tensor): Column name embeddings.
            dtype_embedding (torch.Tensor): Dtype embeddings.
            metadata_embedding (torch.Tensor): Metadata embedding.
            attention_mask (torch.Tensor, optional): Mask for input tensor. Defaults to None.

        Returns:
            The object passed directly to ``self.decoder`` by ``decode``.
        """

    def forward(self, value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None) -> torch.Tensor:
        """
        Run a full encode -> decode round trip.

        The same column-name, dtype, metadata embeddings and attention mask are
        given to both the encoder and the decoder side.

        Args:
            value_embedding (torch.Tensor): Input value embeddings, expected shape (B, n_cols, d_lm).
            column_name_embedding (torch.Tensor): Column name embeddings, expected shape (n_cols, d_lm).
            dtype_embedding (torch.Tensor): Dtype embeddings, expected shape (n_cols, d_lm).
            metadata_embedding (torch.Tensor): Metadata embedding, expected shape (d_lm).
            attention_mask (torch.Tensor, optional): Mask for the input, expected shape (B, n_cols).
                Defaults to None.

        Returns:
            torch.Tensor: Decoded value embeddings as returned by ``decode``.
            This is presumably (B, n_cols, d_lm), a reconstruction of
            ``value_embedding``; confirm against the concrete subclass.
        """
        latent = self.encode(
            value_embedding,
            column_name_embedding,
            dtype_embedding,
            metadata_embedding,
            attention_mask,
        )
        return self.decode(
            latent,
            column_name_embedding,
            dtype_embedding,
            metadata_embedding,
            attention_mask,
        )