import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tableLatent.perceive.perceiveResampler import PerceiverResampler,Transformer
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as F

class FuseLayer(nn.Module):
    """Gated fusion of four embedding streams into a single embedding.

    Streams: data context / metadata (``dc``), column dtype (``dt``),
    column name (``cn``) and column value (``cv``). Every stream ``x`` gets a
    scalar gate per position::

        g = sigmoid(v(relu(W(x) + b)))      # v maps D -> 1

    and the result is ``g_dc * x_dc + g_dt * x_dt + g_cn * x_cn + g_cv * x_cv``.
    With equally shaped inputs (e.g. (B, L, D)) the output has that same shape.

    Reference: https://arxiv.org/pdf/2307.09249v2

    Note: each ``w_*`` is an ``nn.Linear`` that already has its own bias, so the
    extra ``b_*`` parameters are redundant; removing them would change the
    state_dict keys, so they are kept.
    """

    def __init__(self, embedding_dim):
        super(FuseLayer, self).__init__()

        # Per-stream projections D -> D (creation order kept stable for seeding).
        self.w_dc = nn.Linear(embedding_dim, embedding_dim)
        self.w_dt = nn.Linear(embedding_dim, embedding_dim)
        self.w_cn = nn.Linear(embedding_dim, embedding_dim)
        self.w_cv = nn.Linear(embedding_dim, embedding_dim)

        # Additional per-stream biases added after the projection.
        self.b_dc = nn.Parameter(torch.zeros(embedding_dim))
        self.b_dt = nn.Parameter(torch.zeros(embedding_dim))
        self.b_cn = nn.Parameter(torch.zeros(embedding_dim))
        self.b_cv = nn.Parameter(torch.zeros(embedding_dim))

        # Per-stream gate heads D -> 1.
        self.v_dc = nn.Linear(embedding_dim, 1)
        self.v_dt = nn.Linear(embedding_dim, 1)
        self.v_cn = nn.Linear(embedding_dim, 1)
        self.v_cv = nn.Linear(embedding_dim, 1)

    def forward(self, x_dc, x_dt, x_cn, x_cv):
        streams = (
            (x_dc, self.w_dc, self.b_dc, self.v_dc),
            (x_dt, self.w_dt, self.b_dt, self.v_dt),
            (x_cn, self.w_cn, self.b_cn, self.v_cn),
            (x_cv, self.w_cv, self.b_cv, self.v_cv),
        )
        fused = None
        for x, proj, bias, gate_head in streams:
            hidden = torch.relu(proj(x) + bias)
            gate = torch.sigmoid(gate_head(hidden))  # (..., 1), broadcast over features
            weighted = gate * x
            fused = weighted if fused is None else fused + weighted
        return fused


def pick_most_likely_category(embedding_tensor_np,val_embeddings_this_column):
    """Map each query embedding to the category with the most cosine-similar embedding.

    Args:
        embedding_tensor_np: array-like of shape (N, D) with N query embeddings,
            or a single embedding of shape (D,), which is treated as N = 1.
        val_embeddings_this_column: mapping ``category -> embedding`` where each
            embedding has length D.

    Returns:
        pd.Series of length N (default RangeIndex) holding the selected category
        for every query. Ties resolve to the earliest category in mapping order,
        because ``np.argmax`` returns the first maximal index.

    Raises:
        ValueError: if ``val_embeddings_this_column`` is empty.
    """
    if not val_embeddings_this_column:
        raise ValueError("val_embeddings_this_column must contain at least one category embedding.")

    # Read the keys once so row i of the stacked matrix always belongs to categories[i].
    categories = list(val_embeddings_this_column)
    category_embeddings = np.stack([val_embeddings_this_column[c] for c in categories])

    # cosine_similarity only accepts 2-D input; promote a single vector to one row.
    queries = np.atleast_2d(embedding_tensor_np)

    # (N, number_of_categories) similarity matrix.
    similarities = cosine_similarity(queries, category_embeddings)
    best = np.argmax(similarities, axis=1)

    return pd.Series([categories[idx] for idx in best])

class PerceiveAggregator(nn.Module):
    """Aggregate interleaved column name/value embeddings into Perceiver latents.

    ``input_tensor`` holds, for each column, a name embedding at an even sequence
    position followed by its value embedding at the next (odd) position.
    ``fuse_option`` selects how each pair is combined before the sequence is
    passed to ``PerceiverResampler``:

    - ``"flatten"`` (default): concatenate name and value along the feature axis,
      so the resampler input width is ``2 * dim``.
    - ``"fuse"``: gate-fuse metadata, dtype, name and value embeddings with
      ``FuseLayer``; the width stays ``dim``.
    - ``"No"``: pass the interleaved sequence through unchanged.

    ``legacy`` and extra ``**kwargs`` are accepted but not used here.
    """

    # Integer dtype ids understood by ``dtype_embedding`` (one embedding row each).
    _DTYPE_SUPPORT = [0,1]

    def __init__(self, 
        *,
        dim, # dimension of input embedding
        dim_latent,
        depth,
        dim_head=64,
        num_latents=16,
        max_seq_len=64,
        ff_mult=4,
        legacy=False,
        l2_normalize_latents=False,
        fuse_option="flatten",
        **kwargs,
        ) -> None:
        super().__init__()

        self.fuse_option = fuse_option
        # "flatten" concatenates each name/value pair, doubling the feature width.
        self.dim = dim if fuse_option != "flatten" else dim * 2
        self.dim_latent = dim_latent
        self.max_seq_len = max_seq_len

        self.perceiver_encoder = PerceiverResampler(dim=self.dim, dim_latent=self.dim_latent, depth=depth, dim_head=dim_head,
                                                        num_latents=num_latents, max_seq_len=self.max_seq_len, ff_mult=ff_mult, l2_normalize_latents=l2_normalize_latents)

        if fuse_option == 'fuse':
            self.fuse_layer = FuseLayer(dim)
            self.dtype_embedding = nn.Embedding(len(self._DTYPE_SUPPORT), dim)

    def fuse_name_val_embedding(self, input_tensor, attention_mask, dtype, meta,fuse_option):
        """Combine interleaved column name/value embeddings into one token per column.

        Args:
            input_tensor: (B, num_col*2, D); names at even, values at odd positions.
            attention_mask: (B, num_col*2).
            dtype: integer dtype ids, used only for ``"fuse"``; shape (num_col,),
                (1, num_col) shared across the batch, or (B, num_col) per row.
            meta: metadata embedding, used only for ``"fuse"``; shape (D,), (1, D),
                (B, D), (1, 1, D) or (B, 1, D). It is broadcast to every column.
            fuse_option: ``"flatten"``, ``"fuse"`` or ``"No"``.

        Returns:
            ``(tensor, mask)``. For ``"flatten"``/``"fuse"`` these are (B, num_col, D')
            and (B, num_col), where D' = 2*D for ``"flatten"`` and D for ``"fuse"``.
            For ``"No"``, the inputs are returned unchanged.

        Raises:
            NotImplementedError: for an unknown ``fuse_option``.
            ValueError: if the sequence length is odd for ``"flatten"``/``"fuse"``.
        """
        if fuse_option == "No":
            return input_tensor,attention_mask
        if fuse_option not in ("flatten", "fuse"):
            raise NotImplementedError(f"Option {fuse_option} for fusing column name/value embeddings is not supported!!")

        batch_size, seq_len, embedding_dim = input_tensor.shape
        if seq_len % 2:
            raise ValueError(f"Expected an even number of name/value embeddings, got {seq_len}.")
        column_number = seq_len // 2

        # One mask entry per column, taken from the name position.
        reshape_attention_mask = attention_mask[:, ::2]

        if fuse_option == "flatten":
            # reshape (rather than view) also accepts non-contiguous inputs.
            reshaped_tensor = input_tensor.reshape(batch_size, column_number, 2 * embedding_dim)
        else:  # "fuse"
            # Broadcast dtype/meta embeddings to (B, num_col, D). expand() makes no copy
            # and, unlike repeat(), leaves already-batched inputs (first dim == B) intact.
            dtype_emb = self.dtype_embedding(dtype).expand(batch_size, -1, -1)
            meta_emb = meta.reshape(-1, 1, meta.shape[-1]).expand(batch_size, column_number, -1)

            col_names_emb, col_val_emb = input_tensor[:, ::2, :], input_tensor[:, 1::2, :]

            # The fuse layer reduces the four streams to (B, num_col, D).
            reshaped_tensor = self.fuse_layer(meta_emb, dtype_emb, col_names_emb, col_val_emb)

        return reshaped_tensor,reshape_attention_mask

    def forward(self, input_tensor, attention_mask, dtype, meta):
        """Fuse name/value pairs per ``self.fuse_option``, then run the Perceiver resampler.

        The (possibly reduced) attention mask is cast to bool and passed as ``mask``.
        """
        input_tensor,attention_mask = self.fuse_name_val_embedding(input_tensor,attention_mask,dtype, meta,self.fuse_option)
        return self.perceiver_encoder(input_tensor, mask=attention_mask.bool())

    
class TransformerRowDecoder(nn.Module):
    def __init__(self, aggregated_dim=128, output_dim=128, depth=4, lm_emb=768, num_heads=8, **kwargs):
        """
        Initialize the TransformerRowDecoder.

        Args:
        - aggregated_dim (int): Dimension of the aggregated (latent) embeddings.
        - output_dim (int): Output dimension of the decoder.
        - depth (int): Number of TransformerEncoder layers over [metadata; latents].
        - lm_emb (int): Embedding size of the LM column-name / metadata embeddings.
        - num_heads (int): Number of attention heads (encoder and cross-attention).
        """
        super(TransformerRowDecoder, self).__init__()
        self.output_dim = output_dim
        self.aggregated_dim = aggregated_dim
        self.lm_emb = lm_emb

        # Project column-name and metadata embeddings from lm_emb to aggregated_dim.
        self.column_proj = nn.Linear(lm_emb, aggregated_dim)
        self.metadata_proj = nn.Linear(lm_emb, aggregated_dim)

        # Transformer encoder over the [metadata; row latents] sequence.
        encoder_layer = nn.TransformerEncoderLayer(d_model=aggregated_dim, nhead=num_heads, batch_first=True)
        self.metadata_row_transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Column-name queries attend to the encoded [metadata; latents] sequence.
        self.cross_attention = nn.MultiheadAttention(embed_dim=aggregated_dim, num_heads=num_heads, batch_first=True)

        # Final projection from aggregated_dim to output_dim.
        self.final_proj = nn.Linear(aggregated_dim, output_dim)

    @staticmethod
    def _match_batch(tensor, batch_size):
        """Broadcast a batch-1 tensor along dim 0 to ``batch_size`` without copying."""
        if tensor.shape[0] == 1 and batch_size != 1:
            return tensor.expand(batch_size, *tensor.shape[1:])
        return tensor

    def forward(self, column_names_emb, metadata_emb, row_latent_emb):
        """
        Forward pass of TransformerRowDecoder.

        Args:
        - column_names_emb: Column-name embeddings, (B, num_cols, lm_emb). A shared
          (1, num_cols, lm_emb) or (num_cols, lm_emb) tensor is broadcast over the batch.
        - metadata_emb: Table metadata embedding, (B, 1, lm_emb). (1, 1, lm_emb),
          (B, lm_emb) and (1, lm_emb) are also accepted and broadcast over the batch.
        - row_latent_emb: Row latents, (B, num_latent, aggregated_dim); defines B.

        Returns:
        - Decoded per-column embeddings, (B, num_cols, output_dim).
        """
        batch_size = row_latent_emb.shape[0]

        # Promote un-batched inputs: columns gain a batch dim; metadata gains a token dim.
        if column_names_emb.dim() == 2:
            column_names_emb = column_names_emb.unsqueeze(0)
        if metadata_emb.dim() == 2:
            metadata_emb = metadata_emb.unsqueeze(1)

        # Project first (cheaper on batch-1 inputs), then broadcast to the batch.
        column_names_emb_proj = self._match_batch(self.column_proj(column_names_emb), batch_size)
        metadata_emb_proj = self._match_batch(self.metadata_proj(metadata_emb), batch_size)

        # (B, 1 + num_latent, aggregated_dim)
        metadata_row_emb = torch.cat([metadata_emb_proj, row_latent_emb], dim=1)
        metadata_row_emb = self.metadata_row_transformer(metadata_row_emb)

        # Attention weights are not used, so skip computing/averaging them.
        attn_output, _ = self.cross_attention(
            query=column_names_emb_proj,  # (B, num_cols, aggregated_dim)
            key=metadata_row_emb,
            value=metadata_row_emb,
            need_weights=False,
        )

        # (B, num_cols, output_dim)
        return self.final_proj(attn_output)
    
class CondReconstructor(nn.Module):
    def __init__(self, lm_emb=768, output_dim=128, num_decoder_dim=1):
        """
        Initialize the Reconstructor.

        Args:
        - lm_emb (int): Width of the raw per-category (unique value) embeddings.
        - output_dim (int): Dimension of the decoder output embeddings.
        - num_decoder_dim (int): Output width of the numerical column decoder (default 1).
        """
        super(CondReconstructor, self).__init__()

        # Shared decoder for numerical columns; Sigmoid keeps outputs in (0, 1).
        self.num_decoder = nn.Sequential(
            nn.Linear(output_dim, num_decoder_dim),
            nn.Sigmoid()
        )

        # Projects categorical level embeddings from lm_emb into the decoder's output_dim space.
        self.refine_embeddings = nn.Linear(lm_emb, output_dim)

    def forward(self, decoder_output, dtype_tensor, unique_embedding_list, return_prob=True):
        """
        Forward pass of the Reconstructor.

        Args:
        - decoder_output (torch.Tensor): Per-column decoder embeddings, (B, num_cols, output_dim).
        - dtype_tensor (torch.Tensor): Per-column data type (0 = categorical,
          1 = numerical), shape (num_cols,) or (1, num_cols). Columns with any other
          dtype value are skipped.
        - unique_embedding_list (list): One tensor per *categorical* column, in column
          order, each of shape (num_unique_levels, lm_emb).
        - return_prob (bool): Currently unused; probabilities are always returned.

        Returns:
        - recon_num (torch.Tensor): (B, num_numerical_cols * num_decoder_dim), or an empty
          1-D tensor on decoder_output's device/dtype when there are no numerical columns.
        - recon_cat (list): One (B, num_unique_levels) probability tensor per categorical column.
        """
        # Flatten rather than squeeze: squeeze() turns the (1, 1) dtype tensor of a
        # single-column table into a 0-d tensor, on which len() raises.
        # tolist() also reads all dtypes in one go instead of one .item() per column.
        dtypes = dtype_tensor.reshape(-1).tolist()
        recon_num = []  # Reconstructed numerical columns
        recon_cat = []  # Reconstructed categorical columns
        cat_col_idx = 0  # Index into unique_embedding_list (categorical columns only)

        for col, dtype in enumerate(dtypes):
            col_embedding = decoder_output[:, col, :]  # (B, output_dim)

            if dtype == 1:  # Numerical column
                recon_num.append(self.num_decoder(col_embedding))  # (B, num_decoder_dim)

            elif dtype == 0:  # Categorical column
                unique_embeddings = unique_embedding_list[cat_col_idx]  # (num_unique_levels, lm_emb)
                refined_unique_embeddings = self.refine_embeddings(unique_embeddings)  # (num_unique_levels, output_dim)

                # Cosine similarity of every row's embedding to every level: (B, num_unique_levels).
                similarity = F.cosine_similarity(col_embedding.unsqueeze(1), refined_unique_embeddings.unsqueeze(0), dim=2)

                recon_cat.append(F.softmax(similarity, dim=1))
                cat_col_idx += 1

        if recon_num:
            recon_num = torch.cat(recon_num, dim=1)
        else:
            # Empty result on the same device/dtype as the inputs.
            recon_num = decoder_output.new_zeros(0)

        return recon_num, recon_cat
    
class TransformerCondDecoder(nn.Module):
    def __init__(self, lm_emb=768, aggregated_dim=128, depth=4, output_dim=128,num_decoder_dim=1,**kwargs):
        """
        Build a conditional decoder: TransformerRowDecoder followed by CondReconstructor.

        Args:
        - lm_emb (int): Width of the LM column-name / metadata / category embeddings.
        - aggregated_dim (int): Width of the row latent embeddings.
        - depth (int): Number of encoder layers inside TransformerRowDecoder.
        - output_dim (int): Width of the per-column decoder embeddings.
        - num_decoder_dim (int): Output width of the numerical column decoder (default 1).
        """
        super(TransformerCondDecoder, self).__init__()

        # NOTE(review): extra keyword arguments (e.g. ``num_heads``) are accepted but
        # not forwarded, so TransformerRowDecoder always uses its own defaults for
        # them; confirm this is intended.
        self.transformer_row_decoder = TransformerRowDecoder(
            lm_emb=lm_emb,
            aggregated_dim=aggregated_dim,
            output_dim=output_dim,
            depth=depth,
        )

        self.reconstructor = CondReconstructor(
            lm_emb=lm_emb,
            output_dim=output_dim,
            num_decoder_dim=num_decoder_dim,
        )

    def forward(self, column_names_emb, metadata_emb, row_latent_emb, dtype_tensor, unique_embedding_list):
        """
        Decode row latents into per-column reconstructions.

        Args:
        - column_names_emb (torch.Tensor): Column-name embeddings (B, num_cols, lm_emb).
        - metadata_emb (torch.Tensor): Table metadata embedding (B, 1, lm_emb).
        - row_latent_emb (torch.Tensor): Row latents (B, num_latent, aggregated_dim).
        - dtype_tensor (torch.Tensor): Per-column type (0 = categorical, 1 = numerical).
        - unique_embedding_list (list): Level embeddings, one tensor per categorical column.

        Returns:
        - recon_num (torch.Tensor): Reconstructed numerical columns.
        - recon_cat (list): One (B, number_unique_levels_this_column) tensor per categorical column.
        """
        per_column = self.transformer_row_decoder(column_names_emb, metadata_emb, row_latent_emb)
        recon_num, recon_cat = self.reconstructor(per_column, dtype_tensor, unique_embedding_list)
        return recon_num, recon_cat

    def encode_embs(self, column_names_emb, metadata_emb, concat=True):
        """Project column-name and metadata embeddings into aggregated_dim.

        Returns ``cat([metadata_proj, columns_proj], dim=1)`` when ``concat`` is true,
        otherwise the tuple ``(metadata_proj, columns_proj)``.
        """
        row_decoder = self.transformer_row_decoder
        columns_proj = row_decoder.column_proj(column_names_emb)
        metadata_proj = row_decoder.metadata_proj(metadata_emb)

        if not concat:
            return metadata_proj, columns_proj
        return torch.cat([metadata_proj, columns_proj], dim=1)

class LatentAutoEncoder(nn.Module):
    def __init__(self, encoder_params, decoder_params):
        """
        Build the autoencoder from keyword-argument dictionaries.

        Args:
        - encoder_params (dict): Keyword arguments for PerceiveAggregator.
        - decoder_params (dict): Keyword arguments for TransformerCondDecoder.
        """
        super(LatentAutoEncoder, self).__init__()
        self.encoder = PerceiveAggregator(**encoder_params)
        self.decoder = TransformerCondDecoder(**decoder_params)

    def encode(self, input_tensor, attention_mask, dtype, meta):
        """
        Run the encoder and return its latent code.

        Args:
        - input_tensor (torch.Tensor): Interleaved column name/value embeddings.
        - attention_mask (torch.Tensor): Mask over ``input_tensor``'s sequence positions.
        - dtype (torch.Tensor): Per-column data type ids.
        - meta (torch.Tensor): Table metadata embedding.

        Returns:
        - torch.Tensor: The encoder's latent representation.
        """
        return self.encoder(input_tensor, attention_mask, dtype, meta)

    def decode(self, column_names_emb, metadata_emb, latent_code, dtype_tensor, unique_embedding_list):
        """
        Decode a latent code into numerical and categorical reconstructions.

        Args:
        - column_names_emb (torch.Tensor): Column-name embeddings.
        - metadata_emb (torch.Tensor): Table metadata embedding.
        - latent_code (torch.Tensor): Latent representation from ``encode``.
        - dtype_tensor (torch.Tensor): Per-column data type ids.
        - unique_embedding_list (list): Level embeddings for the categorical columns.

        Returns:
        - (recon_num, recon_cat) as produced by the decoder.
        """
        recon_num, recon_cat = self.decoder(
            column_names_emb,
            metadata_emb,
            latent_code,
            dtype_tensor,
            unique_embedding_list,
        )
        return recon_num, recon_cat

    def forward(self, input_tensor, attention_mask, dtype, meta, column_names_emb, unique_embedding_list):
        """
        Encode, then decode.

        ``meta`` is used both as the encoder's metadata input and as the decoder's
        ``metadata_emb``, and ``dtype`` is reused as the decoder's ``dtype_tensor``.
        NOTE(review): the encoder and decoder may expect different ``meta`` shapes
        (the decoder concatenates it along dim 1 with the latents); confirm.

        Args:
        - input_tensor (torch.Tensor): Interleaved column name/value embeddings.
        - attention_mask (torch.Tensor): Mask for the encoder.
        - dtype (torch.Tensor): Per-column data type ids.
        - meta (torch.Tensor): Table metadata embedding.
        - column_names_emb (torch.Tensor): Column-name embeddings for the decoder.
        - unique_embedding_list (list): Level embeddings for the categorical columns.

        Returns:
        - (recon_num, recon_cat) from ``decode``.
        """
        latent_code = self.encode(input_tensor, attention_mask, dtype, meta)
        recon_num, recon_cat = self.decode(
            column_names_emb,
            metadata_emb=meta,
            latent_code=latent_code,
            dtype_tensor=dtype,
            unique_embedding_list=unique_embedding_list,
        )
        return recon_num, recon_cat
    
class ResidualBlock1(nn.Module):
    """Two-layer MLP block with a residual connection.

    Computes ``relu(fc(x) + shortcut(x))``, where ``fc`` is Linear -> ReLU -> Linear.
    ``shortcut`` is the identity when ``input_dim == output_dim``; otherwise it is a
    learned linear projection to ``output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        super(ResidualBlock1, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(inplace=True),
            nn.Linear(output_dim, output_dim),
        )
        # An empty nn.Sequential acts as the identity. Both branches stay nn.Sequential
        # so state_dict keys (e.g. "shortcut.0.weight") are unchanged.
        self.shortcut = nn.Sequential()
        if input_dim != output_dim:
            self.shortcut = nn.Sequential(nn.Linear(input_dim, output_dim))

    def forward(self, x):
        # Activation after adding the shortcut; use the functional ReLU instead of
        # instantiating a new nn.ReLU module on every call.
        return torch.relu(self.fc(x) + self.shortcut(x))

class FullyConnectedNetwork(nn.Module):
    """Two residual blocks (input_dim -> hidden_dim -> hidden_dim), then a linear head.

    NOTE(review): ``num_layers`` is accepted but unused; the network always has
    exactly two residual blocks. Wiring it up would likely change the state_dict
    keys, so confirm with callers first.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=1024, num_layers=2):
        super(FullyConnectedNetwork, self).__init__()
        self.resblock1 = ResidualBlock1(input_dim, hidden_dim)
        self.resblock2 = ResidualBlock1(hidden_dim, hidden_dim)
        self.fc_final = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = x
        for stage in (self.resblock1, self.resblock2, self.fc_final):
            hidden = stage(hidden)
        return hidden

class PerceiveCondDecoder(nn.Module):
    def __init__(self, aggregated_dim=128, output_dim=128, depth=4, max_seq_len=3, lm_emb=768, decoder_type="categorical",**kwargs):
        """
        Perceiver-based decoder conditioned on a column-name embedding.

        Args:
        - aggregated_dim (int): Width of the aggregated embeddings.
        - output_dim (int): Output width; a linear projection is added when it differs
          from ``aggregated_dim``.
        - depth (int): Depth of the PerceiverResampler.
        - max_seq_len (int): ``max_seq_len`` passed to the PerceiverResampler.
        - lm_emb (int): Width of the LM column-name embedding.
        - decoder_type (str): "numerical" applies a final sigmoid; any other value
          (default "categorical") leaves the output unsquashed.
        """
        super().__init__()
        self.output_dim = output_dim
        self.aggregated_dim = aggregated_dim
        self.decoder_type = decoder_type

        # Resampler with a single latent (num_latents=1).
        self.model = PerceiverResampler(dim=aggregated_dim, dim_latent=aggregated_dim, depth=depth, max_seq_len=max_seq_len, num_latents=1)
        # Residual MLP that downsizes the LM embedding to aggregated_dim.
        self.name_layer = FullyConnectedNetwork(lm_emb, aggregated_dim)

        needs_projection = self.aggregated_dim != self.output_dim
        if needs_projection:
            self.linear = nn.Linear(aggregated_dim, output_dim)

        if self.decoder_type == "numerical":
            self.sigmoid = nn.Sigmoid()

    def forward(self, aggregated, col_name_emb):
        """Prepend the projected column-name token to ``aggregated`` and resample.

        ``aggregated`` is presumably (bsz, num_latent, aggregated_dim), and
        ``col_name_emb`` is (lm_emb,) or (1, lm_emb), so that the name token becomes
        (bsz, 1, aggregated_dim). Confirm these shapes against callers.
        """
        batch = len(aggregated)
        name_token = self.name_layer(col_name_emb).unsqueeze(0).repeat((batch, 1, 1))

        sequence = torch.cat([name_token, aggregated], dim=1)
        decoded = self.model(sequence)

        if self.aggregated_dim != self.output_dim:
            decoded = self.linear(decoded)

        # Numerical targets are squashed into (0, 1).
        if self.decoder_type == "numerical":
            decoded = self.sigmoid(decoded)

        return decoded

