import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from .perceiveResampler import PerceiverResampler
import torch.nn.functional as F
from itertools import combinations

class FuseLayer(nn.Module):
    """Gated fusion of metadata, dtype, column-name and column-value embeddings.

    Every active stream ``x`` is scored by a scalar gate
    ``sigmoid(v(relu(w(x) + b)))`` and the layer returns the sum of the raw
    (untransformed) streams, each multiplied by its own gate.

    Reference: https://arxiv.org/pdf/2307.09249v2
    """

    def __init__(self, embedding_dim, fusion_strategy="all"):
        """
        Create the gate parameters for every stream enabled by the strategy.

        Args:
            embedding_dim: Size of the last dimension of each input embedding.
            fusion_strategy: Matched with ``in``. A strategy containing "all"
                enables the data-context (``dc``) and data-type (``dt``)
                streams; one containing "dtype" enables the data-type stream.
                The column-name (``cn``) and column-value (``cv``) streams are
                always created.
        """
        super().__init__()

        self.fusion_strategy = fusion_strategy

        with_context, with_dtype = self._stream_flags(fusion_strategy)
        tags = [
            tag
            for tag, enabled in (("dc", with_context), ("dt", with_dtype), ("cn", True), ("cv", True))
            if enabled
        ]

        # Attributes are registered group by group (all w_*, then all b_*,
        # then all v_*) so parameter order and random initialisation order
        # stay stable for checkpoints and seeded runs.
        for tag in tags:
            # w_dc / w_dt / w_cn / w_cv: per-stream linear transform
            setattr(self, f"w_{tag}", nn.Linear(embedding_dim, embedding_dim))
        for tag in tags:
            # b_dc / b_dt / b_cn / b_cv: extra bias added after the transform
            setattr(self, f"b_{tag}", nn.Parameter(torch.zeros(embedding_dim)))
        for tag in tags:
            # v_dc / v_dt / v_cn / v_cv: projects the hidden state to one gate logit
            setattr(self, f"v_{tag}", nn.Linear(embedding_dim, 1))

    @staticmethod
    def _stream_flags(fusion_strategy):
        """Return ``(use_data_context, use_data_type)`` for a strategy."""
        with_context = "all" in fusion_strategy
        with_dtype = "dtype" in fusion_strategy or with_context
        return with_context, with_dtype

    @staticmethod
    def _fit_columns(x, n_batch, n_cols):
        """Expand ``x`` to ``(n_batch, n_cols, -1)`` unless dim 1 already equals ``n_cols``."""
        if x.size(1) == n_cols:
            return x
        return x.expand(n_batch, n_cols, -1)

    def _gate(self, tag, x):
        """Compute the sigmoid gate of stream ``tag`` for input ``x``."""
        transform = getattr(self, f"w_{tag}")
        bias = getattr(self, f"b_{tag}")
        scorer = getattr(self, f"v_{tag}")
        return torch.sigmoid(scorer(torch.relu(transform(x) + bias)))

    def forward(self, x_dc=None, x_dt=None, x_cn=None, x_cv=None):
        """
        Fuse the supplied streams into one embedding.

        Args:
            x_dc: Data context (metadata) embedding; used only when the
                strategy contains "all" and the tensor is provided.
            x_dt: Data type embedding; used only when the strategy contains
                "dtype" or "all" and the tensor is provided.
            x_cn: Column name embedding (required). Its first two dimensions
                are taken as batch size and sequence length.
            x_cv: Column value embedding (required).

        Returns:
            Sum over the active streams of ``gate * x``.

        Raises:
            ValueError: If ``x_cn`` or ``x_cv`` is missing.
        """
        if x_cn is None or x_cv is None:
            raise ValueError("Column name and value embeddings are required")

        n_batch, n_cols = x_cn.size(0), x_cn.size(1)
        with_context, with_dtype = self._stream_flags(self.fusion_strategy)

        # Streams are collected in the fixed order dc, dt, cn, cv.
        streams = []
        if with_context and x_dc is not None:
            streams.append(("dc", self._fit_columns(x_dc, n_batch, n_cols)))
        if with_dtype and x_dt is not None:
            streams.append(("dt", self._fit_columns(x_dt, n_batch, n_cols)))
        streams.append(("cn", x_cn))
        streams.append(("cv", x_cv))

        fused_embedding = 0
        for tag, x in streams:
            fused_embedding = fused_embedding + self._gate(tag, x) * x
        return fused_embedding

class MoEGate(nn.Module):
    """Two-layer MLP that turns a context vector into one logit per modality."""

    def __init__(self, in_dim, n_modal, hidden=128):
        """
        Args:
            in_dim: Size of the last dimension of the input context vector.
            n_modal: Number of modalities, i.e. number of output logits.
            hidden: Width of the single hidden layer.
        """
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_modal),  # one logit per modality
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, h):
        """Map ``h`` of shape (B, in_dim) to unnormalised logits of shape (B, n_modal)."""
        logits = self.mlp(h)
        return logits

class PerceiveAggregator(nn.Module):
    """Fuse interleaved column name/value embeddings and pass them to a PerceiverResampler.

    The input sequence alternates column-name and column-value embeddings
    (name at even positions, value at odd positions). Depending on
    ``fuse_option`` each name/value pair is concatenated ("flatten"), merged
    by a gated :class:`FuseLayer` ("fuse"), or left untouched ("No") before
    being handed to ``perceiver_encoder``.
    """

    # Integer dtype codes that receive a row in ``dtype_embedding``.
    _DTYPE_SUPPORT = [0,1]

    def __init__(self, 
        *,
        dim, # dimension of input embedding
        dim_latent,
        depth,
        dim_head=64,
        num_latents=16,
        max_seq_len=64,
        ff_mult=4,
        legacy=False,
        l2_normalize_latents=False,
        fuse_option="flatten",
        fusion_strategy="name_value",
        **kwargs,
        ) -> None:
        """
        Args:
            dim: Width of each incoming name/value embedding.
            dim_latent, depth, dim_head, num_latents, max_seq_len, ff_mult,
            l2_normalize_latents: Forwarded unchanged to ``PerceiverResampler``.
            legacy: Accepted for call compatibility; not used here.
            fuse_option: "flatten", "fuse" or "No" (see class docstring).
            fusion_strategy: Strategy string given to ``FuseLayer`` and used
                to decide which extra inputs (dtype, meta) are fused.
            **kwargs: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()

        self.fuse_option = fuse_option
        self.fusion_strategy = fusion_strategy
        # Flattening concatenates name and value, so the encoder sees 2 * dim.
        self.dim = dim if fuse_option != "flatten" else dim * 2
        self.dim_latent = dim_latent
        self.max_seq_len = max_seq_len

        self.perceiver_encoder = PerceiverResampler(dim=self.dim, dim_latent=self.dim_latent, depth=depth, dim_head=dim_head,
                                                        num_latents=num_latents, max_seq_len=self.max_seq_len, ff_mult=ff_mult, l2_normalize_latents=l2_normalize_latents)

        if fuse_option == 'fuse':
            self.fuse_layer = FuseLayer(dim, fusion_strategy=fusion_strategy)
            self.dtype_embedding = nn.Embedding(len(self._DTYPE_SUPPORT), dim)

    def fuse_name_val_embedding(self, input_tensor, attention_mask, dtype, meta, fuse_option):
        """
        Merge each (column name, column value) embedding pair into one token.

        Args:
            input_tensor: (B, num_col*2, D), names at even and values at odd
                positions along dim 1.
            attention_mask: (B, num_col*2).
            dtype: Per-column dtype codes; only read when the fusion strategy
                contains "dtype" or "all". Non-integer tensors are cast to
                long, taking channel 0 first when they have more than 2 dims.
            meta: Table-level embedding, either (B, D) or (B, 1, D); only read
                when the fusion strategy contains "all".
            fuse_option: "flatten", "fuse" or "No".

        Returns:
            ``(tensor, mask)``. For "flatten" the tensor is (B, num_col, 2*D),
            for "fuse" it is whatever ``fuse_layer`` returns, and in both
            cases the mask keeps every second entry (the name positions).
            "No" returns the inputs unchanged.

        Raises:
            NotImplementedError: For any other ``fuse_option``.
        """
        batch_size, column_number_times_2, embedding_dim = input_tensor.shape
        column_number = column_number_times_2 // 2

        if fuse_option == "flatten":
            paired = input_tensor.view(batch_size, column_number, 2, embedding_dim)
            reshaped_tensor = paired.reshape(batch_size, column_number, embedding_dim * 2)
            reshape_attention_mask = attention_mask[:, ::2]

        elif fuse_option == "fuse":
            # De-interleave: even positions are names, odd positions are values.
            col_names_emb, col_val_emb = input_tensor[:, ::2, :], input_tensor[:, 1::2, :]
            fuse_inputs = {"x_cn": col_names_emb, "x_cv": col_val_emb}

            if "dtype" in self.fusion_strategy or "all" in self.fusion_strategy:
                # nn.Embedding needs integer indices.
                if dtype.dtype != torch.long and dtype.dtype != torch.int:
                    if dtype.dim() > 2:
                        dtype = dtype[..., 0].long()
                    else:
                        dtype = dtype.long()

                dtype_emb = self.dtype_embedding(dtype)
                if dtype_emb.shape[1] != column_number:
                    dtype_emb = dtype_emb.view(batch_size, column_number, -1)

                fuse_inputs["x_dt"] = dtype_emb

            if "all" in self.fusion_strategy:
                # Accept both (B, D) and (B, 1, D) metadata; unsqueezing an
                # already 3D tensor would make the expand below fail.
                if meta.dim() == 2:
                    meta = meta.unsqueeze(1)
                fuse_inputs["x_dc"] = meta.expand(-1, column_number, -1)

            reshaped_tensor = self.fuse_layer(**fuse_inputs)
            reshape_attention_mask = attention_mask[:, ::2]

        elif fuse_option == "No":
            return input_tensor, attention_mask
        else:
            raise NotImplementedError(f"Option {fuse_option} for fusing column name/value embeddings is not supported!!")

        return reshaped_tensor, reshape_attention_mask

    def forward(self, input_tensor, attention_mask, dtype, meta):
        """
        Fuse the name/value pairs and run the result through ``perceiver_encoder``.

        Args:
            input_tensor: (B, num_cols*2, D)
            attention_mask: (B, num_cols*2)
            dtype: (B, num_cols)
            meta: (B, D)

        Returns:
            The output of ``perceiver_encoder`` for the fused sequence; the
            mask is passed to it as a boolean tensor.
        """
        input_tensor, attention_mask = self.fuse_name_val_embedding(
            input_tensor, attention_mask, dtype, meta, self.fuse_option
        )
        return self.perceiver_encoder(input_tensor, mask=attention_mask.bool())

class DisentangledPerceiveAggregator(PerceiveAggregator):
    """PerceiveAggregator whose encoder output is projected by two separate linear heads."""

    def __init__(self, 
                 *,
                 dim, 
                 dim_latent,
                 depth,
                 dim_head=64,
                 num_latents=16,
                 max_seq_len=64,
                 ff_mult=4,
                 legacy=False,
                 l2_normalize_latents=False,
                 fuse_option="flatten",
                 fusion_strategy="name_value",
                 **kwargs) -> None:
        """Build the parent aggregator with the same arguments, then add the two heads."""
        parent_config = dict(
            dim=dim,
            dim_latent=dim_latent,
            depth=depth,
            dim_head=dim_head,
            num_latents=num_latents,
            max_seq_len=max_seq_len,
            ff_mult=ff_mult,
            legacy=legacy,
            l2_normalize_latents=l2_normalize_latents,
            fuse_option=fuse_option,
            fusion_strategy=fusion_strategy,
        )
        super().__init__(**parent_config, **kwargs)

        # Both heads map dim_latent -> dim_latent, so each projection keeps
        # the width of the encoder output.
        self.shared_head = nn.Linear(dim_latent, dim_latent)
        self.private_head = nn.Linear(dim_latent, dim_latent)

    def forward(self, input_tensor, attention_mask, dtype, meta):
        """
        Encode the row and return its shared and private projections.

        Args:
            input_tensor: (B, num_cols*2, D)
            attention_mask: (B, num_cols*2)
            dtype: (B, num_cols)
            meta: (B, D)

        Returns:
            tuple: (shared_latent, private_latent) both with shape (B, num_latents, dim_latent)
        """
        # The parent forward performs the name/value fusion followed by the
        # perceiver encoder, which is exactly the base representation needed.
        encoded = super().forward(input_tensor, attention_mask, dtype, meta)

        shared_latent = self.shared_head(encoded)
        private_latent = self.private_head(encoded)
        return (shared_latent, private_latent)
    
class TransformerRowDecoder(nn.Module):
    """Decode per-column embeddings from row latents, using column names as queries."""

    def __init__(self, aggregated_dim=128, output_dim=128, depth=4, lm_emb=768, num_heads=8, **kwargs):
        """
        Initialize the TransformerRowDecoder.

        Args:
        - aggregated_dim (int): Working width of the transformer and attention.
        - output_dim (int): Width of the decoded per-column output.
        - depth (int): Number of transformer encoder layers.
        - lm_emb (int): Width of the incoming column-name and metadata embeddings.
        - num_heads (int): Number of attention heads (encoder layers and cross-attention).
        - **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.output_dim = output_dim
        self.aggregated_dim = aggregated_dim
        self.lm_emb = lm_emb

        # Bring column names and metadata from lm_emb down to aggregated_dim.
        self.column_proj = nn.Linear(lm_emb, aggregated_dim)
        self.metadata_proj = nn.Linear(lm_emb, aggregated_dim)

        # Self-attention stack over [metadata ; row latents].
        self.metadata_row_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=aggregated_dim, nhead=num_heads, batch_first=True),
            num_layers=depth,
        )

        # Column names attend to the encoded context.
        self.cross_attention = nn.MultiheadAttention(embed_dim=aggregated_dim, num_heads=num_heads, batch_first=True)

        # Map the attended features to output_dim.
        self.final_proj = nn.Linear(aggregated_dim, output_dim)

    def forward(self, column_names_emb, metadata_emb, row_latent_emb):
        """
        Forward pass of TransformerRowDecoder.

        Args:
        - column_names_emb: Embedding of column names (B, num_cols, lm_emb).
        - metadata_emb: Embedding of table context metadata (B, 1, lm_emb).
        - row_latent_emb: Latent embedding of row information (B, num_latent, aggregated_dim).

        Returns:
        - Decoded embeddings for each column value (B, num_cols, output_dim).
        """
        queries = self.column_proj(column_names_emb)        # (B, num_cols, aggregated_dim)
        meta_token = self.metadata_proj(metadata_emb)       # (B, 1, aggregated_dim)

        # Prepend the metadata token to the row latents and contextualise them.
        context = torch.cat((meta_token, row_latent_emb), dim=1)   # (B, 1 + num_latent, aggregated_dim)
        context = self.metadata_row_transformer(context)

        # Queries are the column names; keys and values are the encoded context.
        attended, _ = self.cross_attention(queries, context, context)

        return self.final_proj(attended)
    
    
class TransformerCondDecoder(nn.Module):
    """Thin wrapper that owns a TransformerRowDecoder and exposes its projections."""

    def __init__(self, lm_emb=768, aggregated_dim=128, depth=4, output_dim=128,**kwargs):
        """
        Initialize the TransformerCondDecoder.

        Args:
        - lm_emb (int): Width of the column-name and metadata embeddings.
        - aggregated_dim (int): Dimension of the aggregated embeddings.
        - depth (int): Number of layers in TransformerRowDecoder.
        - output_dim (int): Width of the decoded output.
        - **kwargs: Accepted and ignored.
        """
        super().__init__()

        # NOTE(review): only these four settings reach the row decoder; any
        # other option supplied through **kwargs (e.g. num_heads) is dropped.
        row_decoder_config = {
            "aggregated_dim": aggregated_dim,
            "output_dim": output_dim,
            "depth": depth,
            "lm_emb": lm_emb,
        }
        self.transformer_row_decoder = TransformerRowDecoder(**row_decoder_config)

    def forward(self, column_names_emb, metadata_emb, row_latent_emb, dtype_tensor=None, attention_mask=None):
        """
        Forward pass through TransformerRowDecoder.

        Args:
        - column_names_emb (torch.Tensor): Embedding of column names (B, num_cols, lm_emb).
        - metadata_emb (torch.Tensor): Embedding of table context metadata (B, 1, lm_emb).
        - row_latent_emb (torch.Tensor): Latent embedding of row information (B, num_latent, aggregated_dim).
        - dtype_tensor (torch.Tensor): Accepted for interface compatibility; not used.
        - attention_mask (torch.Tensor): Accepted for interface compatibility; not used.

        Returns:
        - torch.Tensor: Whatever the wrapped row decoder returns for these inputs.
        """
        return self.transformer_row_decoder(column_names_emb, metadata_emb, row_latent_emb)

    def encode_embs(self, column_names_emb, metadata_emb, concat=True):
        """
        Project column-name and metadata embeddings with the row decoder's input layers.

        Returns the two projections concatenated along dim 1 (metadata first)
        when ``concat`` is true, otherwise the tuple
        ``(metadata_projection, column_projection)``.
        """
        row_decoder = self.transformer_row_decoder
        columns_proj = row_decoder.column_proj(column_names_emb)   # (B, num_cols, aggregated_dim)
        meta_proj = row_decoder.metadata_proj(metadata_emb)        # (B, 1, aggregated_dim)

        if not concat:
            return meta_proj, columns_proj
        return torch.concat([meta_proj, columns_proj], dim=1)
    
class LatentAutoEncoder(nn.Module):
    """Variational autoencoder over table rows.

    Two identically configured ``PerceiveAggregator`` encoders produce the
    mean and log-variance of the latent code; a ``TransformerCondDecoder``
    reconstructs per-column embeddings from the sampled code.
    """

    def __init__(self, encoder_params, decoder_params):
        """
        Args:
            encoder_params: Keyword arguments for both ``PerceiveAggregator`` encoders.
            decoder_params: Keyword arguments for ``TransformerCondDecoder``.
        """
        super().__init__()
        self.d_lm = encoder_params.get('dim', 1024)
        self.d_latent_len = encoder_params.get('num_latents', 16)
        self.d_latent_width = encoder_params.get('dim_latent', 64)

        # Create two encoders with identical structure for mu and logvar
        self.encoder_mu = PerceiveAggregator(**encoder_params)
        self.encoder_logvar = PerceiveAggregator(**encoder_params)
        self.decoder = TransformerCondDecoder(**decoder_params)

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    @staticmethod
    def _batch_column_names(column_name_embedding, batch_size):
        """Broadcast batch-less (num_cols, d_lm) column names to (batch_size, num_cols, d_lm).

        Tensors that are not 2D are returned unchanged.
        """
        if column_name_embedding.dim() == 2:
            return column_name_embedding.unsqueeze(0).expand(batch_size, -1, -1)
        return column_name_embedding

    def _prepare_encoder_input(self, 
                                value_embedding, 
                                column_name_embedding, 
                                dtype_embedding, 
                                metadata_embedding, 
                                attention_mask=None):
        """
        Interleave column names and values into one sequence for the encoders.

            value_embedding: (B, num_cols, d_lm)
            column_name_embedding: (B, num_cols, d_lm) or (num_cols, d_lm)
            dtype_embedding: (B, num_cols, d_lm)
            metadata_embedding: (B, d_lm)
            attention_mask: (B, num_cols)

            Returns:
            interleaved_tensor: (B, num_cols * 2, d_lm), name then value per column
            interleaved_mask: (B, num_cols * 2); all-True bool when no mask is given
            dtype_embedding: passed through unchanged
            metadata_embedding: passed through unchanged
        """
        column_name_embedding = self._batch_column_names(column_name_embedding, value_embedding.size(0))

        # (B, num_cols, 2, d_lm) -> (B, num_cols * 2, d_lm): name_0, value_0, name_1, ...
        paired = torch.stack([column_name_embedding, value_embedding], dim=2)
        interleaved_tensor = paired.reshape(paired.size(0), paired.size(1) * 2, paired.size(3))

        if attention_mask is not None:
            # A column's name and value share that column's mask entry.
            interleaved_mask = attention_mask.repeat_interleave(2, dim=1)
        else:
            interleaved_mask = torch.ones_like(interleaved_tensor[:, :, 0], dtype=torch.bool)

        return interleaved_tensor, interleaved_mask, dtype_embedding, metadata_embedding

    def encode(self, value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None, deterministic=False, dist=None):
        """
        Encode inputs to latent distribution parameters and sample latent code.
            value_embedding: (B, num_cols, d_lm)
            column_name_embedding: (B, num_cols, d_lm) or (num_cols, d_lm)
            dtype_embedding: (B, num_cols, d_lm)
            metadata_embedding: (B, d_lm)
            attention_mask: (B, num_cols)
            deterministic: bool, if True uses only mu for inference, if False uses reparameterization trick (default: False)
            dist: accepted for interface compatibility; not used by this class

        Returns:
            ``(z, mu, logvar)``; in deterministic mode ``(mu, mu, None)``.
        """
        prepared_input, prepared_mask, prepared_dtype, prepared_meta = self._prepare_encoder_input(
            value_embedding, 
            column_name_embedding=column_name_embedding, 
            dtype_embedding=dtype_embedding, 
            metadata_embedding=metadata_embedding, 
            attention_mask=attention_mask
        )

        mu = self.encoder_mu(prepared_input, prepared_mask, prepared_dtype, prepared_meta)

        if deterministic:
            # The logvar encoder is skipped entirely; mu doubles as the code.
            return mu, mu, None

        logvar = self.encoder_logvar(prepared_input, prepared_mask, prepared_dtype, prepared_meta)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def _prepare_decoder_input(self, compressed_embedding, column_name_embedding, dtype_embedding, metadata_embedding):
        """
        Shape the decoder inputs.

        Column names given without a batch dimension are broadcast to the
        batch size of ``compressed_embedding`` (mirroring the encoder side),
        and metadata with fewer than 3 dims gains a length-1 sequence axis.

        Returns:
            ``(compressed_embedding, column_name_embedding, metadata_embedding)``
        """
        column_name_embedding = self._batch_column_names(column_name_embedding, compressed_embedding.size(0))

        # Reshape metadata_embedding to match expected input shape (B, 1, d_lm)
        if metadata_embedding.dim() < 3:
            metadata_embedding = metadata_embedding.unsqueeze(1)

        return compressed_embedding, column_name_embedding, metadata_embedding

    def decode(self, latent_code, column_name_embedding, dtype_embedding, metadata_embedding, dist=None):
        """
        Decode latent code to column values.
        latent_code: (B, num_latents, d_latent_width)
        column_name_embedding: (B, num_cols, d_lm) or (num_cols, d_lm)
        dtype_embedding: (B, num_cols, d_lm)
        metadata_embedding: (B, d_lm)
        dist: accepted for interface compatibility; not used by this class
        """
        prepared_latent, prepared_column_names, prepared_metadata = self._prepare_decoder_input(
            latent_code, column_name_embedding, dtype_embedding, metadata_embedding
        )
        return self.decoder(prepared_column_names, prepared_metadata, prepared_latent, dtype_embedding)

    def forward(self, value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None, dist=None):
        """Encode the row (sampling the latent code) and decode it back to column embeddings."""
        latent_code, _, _ = self.encode(value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask, dist=dist)
        return self.decode(latent_code, column_name_embedding, dtype_embedding, metadata_embedding, dist=dist)
    
class MultiModalLatentAutoEncoder(LatentAutoEncoder):
    def __init__(self, encoder_params, decoder_params, num_modalities, combination_method='mopoe'):
        """
        Build per-modality encoders/decoders, the modality gate and the schema encoder.

        Args:
            encoder_params: Keyword arguments for every ``PerceiveAggregator``;
                must contain 'dim' and 'dim_latent'. The dict is not modified.
            decoder_params: Keyword arguments for every ``TransformerCondDecoder``.
            num_modalities: Number of modalities; one mu encoder, one logvar
                encoder and one decoder are created per modality.
            combination_method: Stored on the instance as ``combination_method``.
        """
        super().__init__(encoder_params, decoder_params)

        # Create separate encoders for each modality
        self.encoder_mus = nn.ModuleList([
            PerceiveAggregator(**encoder_params)
            for _ in range(num_modalities)
        ])
        self.encoder_logvars = nn.ModuleList([
            PerceiveAggregator(**encoder_params)
            for _ in range(num_modalities)
        ])

        # Create separate decoders for each modality
        self.decoders = nn.ModuleList([
            TransformerCondDecoder(**decoder_params)
            for _ in range(num_modalities)
        ])

        self.num_modalities = num_modalities
        self.combination_method = combination_method

        # Trainable gating network for modality weighting. Its latent input is
        # sized for the concatenation of one dim_latent slot per modality.
        latent_dim = encoder_params['dim_latent'] * num_modalities
        self.moe_gate = SchemaMoEGate(
            latent_dim=latent_dim,
            schema_dim=encoder_params['dim_latent'],  # aggregate schema is same as latent dimension
            n_modal=num_modalities
        )

        # ------------------------------------------------------------------ #
        # Schema encoder components
        # ------------------------------------------------------------------ #
        d_lm = encoder_params['dim']
        # Fuse layer to combine name, dtype, and distribution embeddings
        self.schema_fuse_layer = FuseLayer(d_lm, fusion_strategy="dtype_name_value")

        # Build the schema config on a fresh dict: assigning into
        # ``encoder_params`` directly would leak 'num_layers'/'num_heads'
        # into the caller's dictionary.
        # NOTE(review): PerceiveAggregator as defined in this file has no
        # num_layers/num_heads parameters, so these two keys end up in its
        # **kwargs - confirm whether they are meant to have an effect.
        schema_encoder_params = {
            **encoder_params,
            'num_layers': 2,
            'num_heads': 4,
            'fuse_option': 'No',   # We already fuse outside, so no internal name/value fusion
        }
        self.schema_encoder = PerceiveAggregator(**schema_encoder_params)

    def _split_by_modality(self, tensor, dtype_embedding, attention_mask=None):
        """Split input tensors by modality based on dtype - optimized version"""
        # Pre-allocate lists
        split_tensors = []
        split_masks = []
        
        # Create a one-hot encoding of the dtype_embedding for all modalities at once
        # Shape: (B, num_cols, num_modalities)
        dtype_one_hot = F.one_hot(dtype_embedding[..., 0].long(), num_classes=self.num_modalities)
        
        # Process all modalities in a vectorized way
        for i in range(self.num_modalities):
            # Extract mask for this modality (B, num_cols)
            modality_mask = dtype_one_hot[..., i].bool()
            
            # Expand mask to match interleaved tensor structure (B, num_cols*2)
            expanded_mask = modality_mask.repeat_interleave(2, dim=1).unsqueeze(-1)
            
            # Apply mask to input tensor - more efficient than multiplication
            masked_tensor = tensor.clone()
            masked_tensor = masked_tensor * expanded_mask
            
            # Create attention mask for this modality
            valid_cols = expanded_mask[..., 0]
            if attention_mask is not None:
                valid_cols = valid_cols & attention_mask
                
            split_tensors.append(masked_tensor)
            split_masks.append(valid_cols)
        
        return split_tensors, split_masks

    def _combine_distributions(self, mus, logvars, method='mopoe', schema_vec=None):
        """
        Combine multiple gaussian distributions using either PoE, MoE, MoPoE or Schema-Aware MoPoE
        
        Args:
            mus: List of mean vectors for each modality, None for missing modalities
            logvars: List of log variance vectors for each modality, None for missing modalities
            method: 'poe' for Product of Experts, 'moe' for Mixture of Experts,
                    'mopoe' for Mixture of Products of Experts, or 'samopoe' for Schema-Aware MoPoE
            schema_vec: Schema vector for 'samopoe' method (required if method='samopoe')
        """
        # Filter out None values
        valid_mus = [mu for mu in mus if mu is not None]
        valid_logvars = [logvar for logvar in logvars if logvar is not None]
        
        if not valid_mus:  # No valid modalities
            raise ValueError("No valid modalities found for distribution combination")
            
        if method == 'poe':
            # Product of Experts approach
            # Initialize with zeros
            combined_precision = torch.zeros_like(valid_mus[0])
            combined_mu_precision = torch.zeros_like(valid_mus[0])
            
            # Combine distributions for each modality
            for mu, logvar in zip(valid_mus, valid_logvars):
                precision = torch.exp(-logvar)
                combined_precision += precision
                combined_mu_precision += precision * mu
                
            # Calculate combined parameters
            combined_var = 1.0 / (combined_precision + 1e-8)
            combined_mu = combined_var * combined_mu_precision
            combined_logvar = torch.log(combined_var)
            
        elif method == 'moe':
            # Mixture of Experts approach with trainable gate
            # Build context vector by concatenating means
            batch_shape = valid_mus[0].shape[:-1]  # Get batch dimensions
            latent_dim = valid_mus[0].shape[-1]
            device = valid_mus[0].device
            
            # Create a tensor to hold concatenated means, padded with zeros for missing modalities
            h_context = torch.zeros(*batch_shape, self.num_modalities * latent_dim, device=device)
            
            # Keep track of which modalities are present
            mask = torch.zeros(self.num_modalities, dtype=torch.bool, device=device)
            
            # Fill in the context vector with valid modality means
            valid_idx = 0
            for i, mu in enumerate(mus):
                if mu is not None:
                    # Fill in this modality's section in the context vector
                    h_context[..., i*latent_dim:(i+1)*latent_dim] = mu
                    mask[i] = True
                    valid_idx += 1
            
            # Compute gate logits
            gate_logits = self.moe_gate(h_context)  # (B, n_modal)
            
            # Zero-out logits of missing modalities
            gate_logits = gate_logits.masked_fill(~mask.unsqueeze(0), -1e4)
            
            # Add precision heuristic as additional information
            precision_logits = torch.stack([-lv.mean(-1) for lv in valid_logvars], dim=-1)  # (B, n_valid)
            
            # Map precision_logits to the correct positions based on the valid modalities
            precision_extended = torch.zeros_like(gate_logits)
            valid_idx = 0
            for i in range(self.num_modalities):
                if mask[i]:
                    precision_extended[..., i] = precision_logits[..., valid_idx]
                    valid_idx += 1
            
            # Combine gate logits with precision logits
            gate_logits = gate_logits + 0.5 * precision_extended
            
            # Apply softmax to get final weights
            weights = torch.softmax(gate_logits, dim=-1)
            
            # Weighted combination of means
            combined_mu = torch.zeros_like(valid_mus[0])
            valid_idx = 0
            for i in range(self.num_modalities):
                if mask[i]:
                    w = weights[..., i].unsqueeze(-1)
                    combined_mu += w * valid_mus[valid_idx]
                    valid_idx += 1
            
            # For variance, account for both individual variances and distances from combined mean
            combined_var = torch.zeros_like(valid_mus[0])
            valid_idx = 0
            for i in range(self.num_modalities):
                if mask[i]:
                    w = weights[..., i].unsqueeze(-1)
                    mu = valid_mus[valid_idx]
                    logvar = valid_logvars[valid_idx]
                    var = torch.exp(logvar)
                    
                    # Add weighted individual variance
                    combined_var += w * var
                    # Add weighted squared distance from combined mean
                    combined_var += w * (mu - combined_mu).pow(2)
                    valid_idx += 1
                    
            combined_logvar = torch.log(combined_var)
        
        elif method in ['mopoe', 'samopoe']:
            # Mixture of Products of Experts approach with trainable gate
            if schema_vec is None and method == 'samopoe':
                raise ValueError("schema_vec must be provided for 'samopoe' method")
            
            # Step 1: Generate all non-empty subsets of modalities
            n = len(valid_mus)
            subsets = []
            for k in range(1, n + 1):  # Start from 1 to exclude empty set
                # Generate all combinations of k elements
                for subset_indices in combinations(range(n), k):
                    subset_mus = [valid_mus[i] for i in subset_indices]
                    subset_logvars = [valid_logvars[i] for i in subset_indices]
                    subsets.append((subset_mus, subset_logvars, subset_indices))
            
            # Step 2: For each subset, compute the PoE
            subset_mus = []
            subset_logvars = []
            subset_contexts = []
            subset_masks = []
            
            batch_shape = valid_mus[0].shape[:-1]  # Get batch dimensions
            latent_dim = valid_mus[0].shape[-1]
            device = valid_mus[0].device
            
            for sub_mus, sub_logvars, indices in subsets:
                # Compute PoE for this subset
                sub_precision = torch.zeros_like(valid_mus[0])
                sub_mu_precision = torch.zeros_like(valid_mus[0])
                
                for mu, logvar in zip(sub_mus, sub_logvars):
                    precision = torch.exp(-logvar)
                    sub_precision += precision
                    sub_mu_precision += precision * mu
                    
                # Calculate parameters for this subset
                sub_var = 1.0 / (sub_precision + 1e-8)
                sub_mu = sub_var * sub_mu_precision
                sub_logvar = torch.log(sub_var)
                
                # Create context vector for this subset
                sub_context = torch.zeros(*batch_shape, self.num_modalities * latent_dim, device=device)
                sub_mask = torch.zeros(self.num_modalities, dtype=torch.bool, device=device)
                
                # Fill context with modality means from this subset
                for idx in indices:
                    # Map from subset index to original modality index
                    orig_idx = idx
                    # Put the mean into the appropriate slot
                    #print("Debugging combine dist")
                    #print(sub_context.shape, valid_mus[idx].shape, orig_idx, n, len(valid_mus))
                    sub_context[..., orig_idx*latent_dim:(orig_idx+1)*latent_dim] = valid_mus[idx]
                    sub_mask[orig_idx] = True
                
                subset_mus.append(sub_mu)
                subset_logvars.append(sub_logvar)
                subset_contexts.append(sub_context)
                subset_masks.append(sub_mask)
            
            # Step 3: Form a mixture of these PoEs using the gate
            subset_weights = []
            
            for i, (context, mask) in enumerate(zip(subset_contexts, subset_masks)):
                # Get gate logits for this subset
                
                gate_logits = self.moe_gate(context, schema_vec)
                
                # Zero-out logits of missing modalities
                gate_logits = gate_logits.masked_fill(~mask.unsqueeze(0), -1e4)
                
                # Get the weight for this subset by averaging over the valid modalities
                subset_weight = gate_logits.masked_fill(~mask.unsqueeze(0), 0).sum(dim=-1) / mask.sum()
                subset_weights.append(subset_weight.unsqueeze(-1))
            
            # Stack all weights
            subset_weights = torch.cat(subset_weights, dim=-1)
            
            # Apply softmax to normalize weights
            subset_weights = F.softmax(subset_weights, dim=-1)
            
            # Compute the weighted combination
            combined_mu = torch.zeros_like(subset_mus[0])
            for i, mu in enumerate(subset_mus):
                w = subset_weights[..., i].unsqueeze(-1)
                combined_mu += w * mu
            
            combined_var = torch.zeros_like(subset_mus[0])
            for i, (mu, logvar) in enumerate(zip(subset_mus, subset_logvars)):
                w = subset_weights[..., i].unsqueeze(-1)
                var = torch.exp(logvar)
                # Add weighted individual variance
                combined_var += w * var
                # Add weighted squared distance from combined mean
                combined_var += w * (mu - combined_mu).pow(2)
                
            combined_logvar = torch.log(combined_var)
        
        else:
            raise ValueError(f"Unknown combination method: {method}")
            
        return combined_mu, combined_logvar

    def encode(self, value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None, deterministic=False, dist=None):
        """
        Encode inputs with one mu/logvar encoder pair per modality and fuse the results.

        Args:
            value_embedding: Input value embeddings
            column_name_embedding: Column name embeddings
            dtype_embedding: Data type embeddings; also used to split the input by modality
            metadata_embedding: Metadata embeddings
            attention_mask: Attention mask
            deterministic: bool, if True only the mu encoders are run and nothing is sampled;
                if False the reparameterization trick is used (default: False)
            dist: Distribution-based embeddings for columns (optional); only read when
                ``self.combination_method`` is 'samopoe'

        Returns:
            Tuple ``(z, mu, logvar)``. In deterministic mode ``z`` is ``mu`` itself and
            ``logvar`` is ``None``.
        """
        # Only the schema-aware gate takes a schema vector, so it is built (and
        # forwarded) for 'samopoe' alone.
        combine_kwargs = {}
        if self.combination_method == 'samopoe':
            combine_kwargs['schema_vec'] = self._create_schema_vector(
                metadata_embedding, column_name_embedding, dtype_embedding, dist
            )

        # Prepare input for each modality
        prepared_input, prepared_mask, prepared_dtype, prepared_meta = self._prepare_encoder_input(
            value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask
        )

        # Split inputs by modality
        split_inputs, split_masks = self._split_by_modality(prepared_input, dtype_embedding, prepared_mask)

        # One slot per modality; a slot stays None when the modality is absent.
        # The logvar list is kept at the same length as ``mus`` in both modes so the
        # combiner always receives aligned lists (in deterministic mode it is all None).
        mus = [None] * self.num_modalities
        logvars = [None] * self.num_modalities

        for i in range(self.num_modalities):
            # Skip modalities whose mask has no active position
            if not split_masks[i].any():
                continue
            mus[i] = self.encoder_mus[i](split_inputs[i], split_masks[i], prepared_dtype, prepared_meta)
            # The logvar encoder is only needed when we are going to sample
            if not deterministic:
                logvars[i] = self.encoder_logvars[i](split_inputs[i], split_masks[i], prepared_dtype, prepared_meta)

        mu, logvar = self._combine_distributions(
            mus,
            logvars,
            method=self.combination_method,
            **combine_kwargs
        )

        if deterministic:
            # No sampling: the combined mean doubles as the latent code
            return mu, mu, None

        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def decode(self, latent_code, column_name_embedding, dtype_embedding, metadata_embedding, dist=None):
        """
        Decode a latent code with the per-modality decoders and stitch the results together.

        Channel 0 of ``dtype_embedding`` is compared against each modality index to decide
        which decoder owns which column. Only decoders whose index occurs at least once
        are run. Columns that no decoder claims stay zero in the output.

        Args:
            latent_code: Latent representation to decode
            column_name_embedding: Column name embeddings; unpacked as a 3-D tensor whose
                shape/dtype the output copies
            dtype_embedding: Data type embeddings; ``dtype_embedding[..., 0]`` selects the modality
            metadata_embedding: Metadata embeddings
            dist: Accepted for signature compatibility; not read by this method

        Returns:
            Tensor shaped like ``column_name_embedding`` holding, per column, the output of
            the decoder selected for that column.
        """
        dec_latent, dec_names, dec_meta = self._prepare_decoder_input(
            latent_code, column_name_embedding, dtype_embedding, metadata_embedding
        )

        # Modality selector per column (raw values, compared without casting)
        dtype_channel = dtype_embedding[..., 0]
        active = [m for m in range(self.num_modalities) if (dtype_channel == m).any()]

        # Run every needed decoder once on the full sequence
        decoded = {
            m: self.decoders[m](dec_names, dec_meta, dec_latent, dtype_embedding)
            for m in active
        }

        _, _, width = column_name_embedding.shape
        merged = torch.zeros_like(column_name_embedding, device=column_name_embedding.device)

        # Integer view of the selector, used to build the per-modality column masks
        modality_of_column = dtype_channel.long()

        for m in active:
            # (B, num_cols) -> (B, num_cols, width) so whole embedding rows are selected
            select = (modality_of_column == m).unsqueeze(-1).expand(-1, -1, width)
            # Copy this decoder's values into exactly the positions it owns
            merged.masked_scatter_(select, decoded[m][select])

        return merged

    def _create_schema_vector(self, metadata_embedding, column_name_embedding, dtype_embedding, dist=None):
        """
        Create a schema vector for conditioning the gating network using a dedicated
        Perceiver-based schema encoder.
        
        Steps:
            1. Fuse column-level information (name, dtype, distribution) using `schema_fuse_layer`.
            2. Concatenate the global metadata embedding as an additional token.
            3. Pass the concatenated sequence through `schema_encoder` to obtain a latent
               representation.
            4. Aggregate (mean-pool) the latent representation across sequence dimension to
               obtain a fixed-size schema vector.
        Args:
            metadata_embedding (torch.Tensor): (B, d_lm)
            column_name_embedding (torch.Tensor): (B, C, d_lm)
            dtype_embedding (torch.Tensor): (B, C, d_lm)
            dist (torch.Tensor): (B, C, d_lm) distribution-based embeddings
        Returns:
            torch.Tensor: (B, dim_latent) schema vector for gating
        """
        B, C, d_lm = column_name_embedding.shape

        # 1. Fuse column-level embeddings
        fused_columns = self.schema_fuse_layer(
            x_cn=column_name_embedding,
            x_dt=dtype_embedding,
            x_cv=dist
        )  # -> (B, C, d_lm)

        # 2. Concatenate metadata as an additional token
        meta_token = metadata_embedding.unsqueeze(1)  # (B, 1, d_lm)
        schema_sequence = torch.cat([meta_token, fused_columns], dim=1)  # (B, C+1, d_lm)

        # 3. Prepare inputs for schema encoder (duplicate each token to match name/value expectation)
        interleaved = torch.repeat_interleave(schema_sequence, repeats=2, dim=1)  # (B, 2*(C+1), d_lm)
        attn_mask = torch.ones(B, interleaved.shape[1], dtype=torch.bool, device=interleaved.device)

        # Dummy dtype tensor (unused because fuse_option='No')
        dummy_dtype = torch.zeros(B, schema_sequence.shape[1], dtype=torch.long, device=interleaved.device)

        # Pass through schema encoder to obtain latent representation
        schema_vec = self.schema_encoder(interleaved, attn_mask, dummy_dtype, metadata_embedding)  # (B, num_latents, dim_latent)

        # 4. Aggregate latent representation (mean pooling)
        #schema_vec = schema_latent.mean(dim=1)  # (B, dim_latent)

        return schema_vec

class SimpleAutoEncoder(LatentAutoEncoder):
    """Plain (non-variational) autoencoder built on top of ``LatentAutoEncoder``.

    A single ``PerceiveAggregator`` maps inputs straight to the latent code; the
    ``encoder_logvar`` attribute created by the parent constructor is removed.
    """

    def __init__(self, encoder_params, decoder_params):
        super().__init__(encoder_params, decoder_params)
        # The one encoder this model actually uses
        self.encoder = PerceiveAggregator(**encoder_params)
        # Drop the parent's log-variance encoder; nothing here samples
        del self.encoder_logvar

    def encode(self, value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None, deterministic=False):
        """
        Map inputs to the latent code without any sampling.

        The return value keeps the ``(z, mu, logvar)`` layout for API compatibility.

        Args:
            value_embedding: (B, num_cols, d_lm)
            column_name_embedding: (B, num_cols, d_lm)
            dtype_embedding: (B, num_cols, d_lm)
            metadata_embedding: (B, d_lm)
            attention_mask: (B, num_cols)
            deterministic: accepted for API compatibility; not read by this method

        Returns:
            z: The latent representation (B, num_latents, d_latent_width)
            mu: The same tensor object as z
            logvar: Zeros shaped like z
        """
        enc_input, enc_mask, enc_dtype, enc_meta = self._prepare_encoder_input(
            value_embedding,
            column_name_embedding=column_name_embedding,
            dtype_embedding=dtype_embedding,
            metadata_embedding=metadata_embedding,
            attention_mask=attention_mask,
        )

        latent = self.encoder(enc_input, enc_mask, enc_dtype, enc_meta)

        # logvar == 0 corresponds to unit variance.
        # NOTE(review): under the usual Gaussian KL this cancels only the variance
        # terms; a mu-dependent term may remain — confirm how the loss uses it.
        return latent, latent, torch.zeros_like(latent)

class DisentangledMultiModalLatentAutoEncoder(MultiModalLatentAutoEncoder):
    """Multi-modal autoencoder whose encoders emit a shared and a private latent each.

    The shared latents of all present modalities are fused with a product of experts;
    that fused result then joins the private latents as one additional expert in the
    final combination (``combination_method``).
    """

    def __init__(self, encoder_params, decoder_params, num_modalities, combination_method='mopoe'):
        super().__init__(
            encoder_params=encoder_params,
            decoder_params=decoder_params,
            num_modalities=num_modalities,
            combination_method=combination_method,
        )

        def build_encoders():
            # One disentangled encoder per modality
            return nn.ModuleList(
                [DisentangledPerceiveAggregator(**encoder_params) for _ in range(num_modalities)]
            )

        # Override the parent's encoders with disentangled ones (mu first, then logvar)
        self.encoder_mus = build_encoders()
        self.encoder_logvars = build_encoders()

        # Gate over the private experts plus the one extra "shared" expert
        n_experts = num_modalities + 1
        dim_latent = encoder_params['dim_latent']
        self.moe_gate = SchemaMoEGate(
            latent_dim=dim_latent * n_experts,
            schema_dim=dim_latent,  # schema representation uses the latent width
            n_modal=n_experts,
        )

    def encode(self, value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask=None, deterministic=False, dist=None):
        """
        Encode inputs with the disentangled per-modality encoders.

        Args:
            value_embedding: Input value embeddings
            column_name_embedding: Column name embeddings
            dtype_embedding: Data type embeddings
            metadata_embedding: Metadata embeddings
            attention_mask: Attention mask
            deterministic: bool, if True only mu is computed and returned as the latent code
            dist: Distribution-based embeddings for columns (optional); only read when
                ``self.combination_method`` is 'samopoe'

        Returns:
            Tuple ``(z, mu, logvar)``; in deterministic mode ``(mu, mu, None)``.
        """
        stochastic = not deterministic
        n_mod = self.num_modalities

        # Input preparation and per-modality split are inherited from the parent
        enc_input, enc_mask, enc_dtype, enc_meta = self._prepare_encoder_input(
            value_embedding, column_name_embedding, dtype_embedding, metadata_embedding, attention_mask
        )
        per_mod_inputs, per_mod_masks = self._split_by_modality(enc_input, dtype_embedding, enc_mask)

        # One slot per modality; absent modalities keep None
        shared_mus = [None] * n_mod
        private_mus = [None] * n_mod
        shared_logvars = [None] * n_mod if stochastic else None
        private_logvars = [None] * n_mod if stochastic else None

        # Modalities with at least one active mask position
        active = [m for m in range(n_mod) if per_mod_masks[m].any()]

        for m in active:
            # Every encoder yields a (shared, private) pair
            shared_mus[m], private_mus[m] = self.encoder_mus[m](
                per_mod_inputs[m], per_mod_masks[m], enc_dtype, enc_meta
            )
            if stochastic:
                shared_logvars[m], private_logvars[m] = self.encoder_logvars[m](
                    per_mod_inputs[m], per_mod_masks[m], enc_dtype, enc_meta
                )

        # Schema conditioning is only built and forwarded for the schema-aware gate
        gate_kwargs = {}
        if self.combination_method == 'samopoe':
            gate_kwargs['schema_vec'] = self._create_schema_vector(
                metadata_embedding, column_name_embedding, dtype_embedding, dist
            )

        # Stage 1: fuse the shared latents across modalities with a product of experts
        shared_logvar_inputs = shared_logvars if stochastic else [None] * n_mod
        fused_shared_mu, fused_shared_logvar = self._combine_distributions(
            shared_mus, shared_logvar_inputs, method='poe'
        )

        # Stage 2: the fused shared latent joins the private latents as one more expert
        expert_mus = private_mus + [fused_shared_mu]

        if deterministic:
            mu, _ = self._combine_distributions(
                expert_mus,
                [None] * (n_mod + 1),
                method=self.combination_method,
                **gate_kwargs
            )
            return mu, mu, None

        expert_logvars = private_logvars + [fused_shared_logvar]
        mu, logvar = self._combine_distributions(
            expert_mus,
            expert_logvars,
            method=self.combination_method,
            **gate_kwargs
        )
        return self.reparameterize(mu, logvar), mu, logvar

class SchemaMoEGate(nn.Module):
    """Gating network that can optionally condition on a schema vector.

    If *schema_vec* is ``None`` the gate reverts to a vanilla MoE gate that
    only sees the latent context.  This makes the gate backward compatible
    with earlier usages where no schema embedding was provided (e.g. plain
    *moe* or *mopoe* modes).

    Args:
        latent_dim: Size of the last axis of the latent context.
        schema_dim: Size of the last axis of the schema vector.
        n_modal: Number of logits produced (one per expert/modality).
        hidden: Width of the hidden layer shared by both projections.
    """

    def __init__(self, latent_dim, schema_dim, n_modal, hidden=128):
        super().__init__()
        self.mlp_latent = nn.Linear(latent_dim, hidden)
        self.mlp_schema = nn.Linear(schema_dim, hidden)
        self.out = nn.Linear(hidden, n_modal)

    def forward(self, latent_ctx, schema_vec=None):  # schema_vec can be None for vanilla MoE use-cases
        """Return gate logits of shape ``(*latent_ctx.shape[:-1], n_modal)``.

        Args:
            latent_ctx: Tensor ``(..., latent_dim)``, e.g. ``(B, latent_len, latent_dim)``.
            schema_vec: Optional tensor ``(..., schema_dim)``. It may have the same rank
                as ``latent_ctx`` or be a lower-rank, batch-first tensor such as
                ``(B, schema_dim)``; the latter is broadcast over the axes of
                ``latent_ctx`` that sit between batch and feature.
        """
        h = self.mlp_latent(latent_ctx)  # (..., hidden)

        if schema_vec is not None:
            h_schema = self.mlp_schema(schema_vec)  # (..., hidden)

            # Plain broadcasting aligns trailing axes, so a (B, hidden) projection
            # would line its batch axis up with the latent axis of (B, L, hidden).
            # Insert singleton axes right after the batch axis instead. Rank-1
            # (unbatched) schema vectors are left to ordinary broadcasting.
            missing_axes = h.dim() - h_schema.dim()
            if missing_axes > 0 and h_schema.dim() >= 2:
                h_schema = h_schema.reshape(
                    h_schema.shape[0], *([1] * missing_axes), *h_schema.shape[1:]
                )

            h = h + h_schema

        return self.out(F.relu(h))  # -> logits