import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Union, Tuple, Optional

from .base import ColumnVectorizer

class QuantileEmbeddingVectorizer(ColumnVectorizer):
    """
    Embed normalised numeric columns and decode embeddings back into [0, 1].

    The forward direction either tiles each scalar across ``output_dim``
    entries (``use_simple_embedding=True``) or passes it through an MLP.
    The inverse direction is always an MLP that ends in a sigmoid.
    """

    def __init__(self, output_dim, hidden_dims=(1024, 1024), use_simple_embedding=True, *args, **kwargs):
        """
        Initialize the QuantileEmbeddingVectorizer.

        Args:
            output_dim (int): The dimension of the output vectors.
            hidden_dims (sequence of int): Sizes of the hidden layers. The inverse
                network uses the same sizes in reverse order.
            use_simple_embedding (bool): If True, the forward direction tiles the
                scalar instead of using a neural network.
            *args, **kwargs: Accepted but not used by this constructor.
        """
        super().__init__(output_dim=output_dim, accepted_dtype=[
            "int8", "int16", "int32", "int64",
            "uint8", "uint16", "uint32", "uint64",
            "float16", "float32", "float64"
        ])

        # Work on a private copy: the default is immutable and a caller-owned
        # list is never shared with (or mutated through) this instance.
        hidden_dims = list(hidden_dims)

        self.use_simple_embedding = use_simple_embedding

        # The forward MLP is built first (when requested) so parameter
        # initialisation consumes the RNG in the same order as before.
        if not use_simple_embedding:
            self.embedding_layer = self._build_mlp(1, hidden_dims, self.output_dim)

        # Sigmoid keeps decoded values inside [0, 1].
        self.inverse_layer = self._build_mlp(
            self.output_dim, hidden_dims[::-1], 1, final_activation=nn.Sigmoid()
        )

    @staticmethod
    def _build_mlp(in_features, hidden_dims, out_features, final_activation=None):
        """Build Linear+ReLU blocks for each hidden size, then a final Linear (and optional activation)."""
        layers = []
        width = in_features
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(width, hidden_dim))
            layers.append(nn.ReLU())
            width = hidden_dim
        layers.append(nn.Linear(width, out_features))
        if final_activation is not None:
            layers.append(final_activation)
        return nn.Sequential(*layers)

    def is_trainable(self):
        """Report that this vectorizer is trainable; always returns True."""
        return True

    def to(self, device):
        """
        Record ``device`` and move the owned layers onto it.

        Returns:
            QuantileEmbeddingVectorizer: ``self``, to allow chaining.
        """
        # NOTE(review): unlike the sibling vectorizers in this file, this does not
        # call super().to(device) - confirm that is intentional.
        self.device = device
        if not self.use_simple_embedding:
            self.embedding_layer = self.embedding_layer.to(device)
        self.inverse_layer = self.inverse_layer.to(device)
        return self

    def _vectorize(self, column_data, config):
        """
        Embed numerical values into ``output_dim`` dimensions.

        Args:
            column_data (pd.Series): Numerical values, expected to be already
                quantile-transformed by the caller.
            config (dict): Configuration dictionary (not read here).

        Returns:
            torch.Tensor: Tensor of shape (N, output_dim).
        """
        values = torch.tensor(
            column_data.values, dtype=torch.float32, device=self.device
        ).reshape(-1, 1)

        if self.use_simple_embedding:
            # Every output coordinate is a copy of the scalar.
            return values.repeat(1, self.output_dim)

        return self.embedding_layer(values)

    def _inverse_vectorize(self, embeddings, config, mode='inference'):
        """
        Convert embeddings back to normalized values.

        Args:
            embeddings (torch.Tensor): The embedded values.
            config (dict): Configuration dictionary (not read here).
            mode (str): 'train' returns the raw tensor; any other value is
                treated as inference.

        Returns:
            Union[pd.Series, torch.Tensor]: The (N, 1) decoder output in train
                mode, otherwise a flattened pandas Series.
        """
        normalized_values = self.inverse_layer(embeddings)

        if mode == 'train':
            return normalized_values

        # Detach so the conversion to NumPy works on grad-tracking outputs.
        return pd.Series(normalized_values.detach().cpu().numpy().flatten())

    def _compute_loss(self, predictions, targets, config):
        """
        Compute MSE loss between predicted and target normalized values.

        Args:
            predictions (torch.Tensor): Predicted normalized values.
            targets (pd.Series): Target normalized values.
            config (dict): Configuration dictionary (not read here).

        Returns:
            torch.Tensor: Scalar MSE loss.
        """
        # Both sides are shaped (N, 1) so the loss never broadcasts.
        target_tensor = torch.tensor(
            targets.values, dtype=torch.float32, device=predictions.device
        ).reshape(-1, 1)
        return nn.functional.mse_loss(predictions.reshape(-1, 1), target_tensor)

class NumericalVectorizer(ColumnVectorizer):
    """
    Vectorize numerical columns as raw scalars or through one linear layer.

    With ``output_dim == 1`` values pass through unchanged as an (N, 1) tensor.
    With ``output_dim > 1`` a ``nn.Linear(1, output_dim)`` produces the vectors
    and the inverse direction mean-pools them back to one scalar per row.
    """

    def __init__(self, output_dim=1, accepted_dtype=None):
        """
        Initialize the numerical vectorizer.

        Args:
            output_dim (int): The dimension of the output vectors (default: 1).
            accepted_dtype (list or str, optional): The acceptable pandas data types for the column.
                If None, defaults to numerical types.
        """
        if accepted_dtype is None:
            accepted_dtype = ["int64", "float64", "int32", "float32", "int16", "float16"]

        super().__init__(output_dim=output_dim, accepted_dtype=accepted_dtype)

        # A learned layer is only needed when the output is wider than one value.
        self.embedding = nn.Linear(1, output_dim) if output_dim > 1 else None

    def to(self, device):
        """Override to() to ensure model components are moved to the correct device"""
        super().to(device)
        if self.embedding is not None:
            self.embedding = self.embedding.to(device)
        return self

    def is_trainable(self):
        """Return True only when a linear embedding layer exists."""
        return self.embedding is not None

    def required_config_keys(self):
        """
        Define the required keys for the configuration.

        Returns:
            list: Always empty; no configuration is needed here.
        """
        return []

    @staticmethod
    def _project_to_scalar(tensor):
        """Reduce a tensor to one value per row: mean-pool (N, D>1), flatten otherwise."""
        if tensor.dim() > 1 and tensor.shape[1] > 1:
            return tensor.mean(dim=1)
        # reshape(-1) keeps a single-row input 1-D, where squeeze() would make it 0-d.
        return tensor.reshape(-1)

    def _vectorize(self, column, config):
        """
        Transform the numerical column into vectors.

        Args:
            column (pandas.Series): The numerical column to transform.
            config (dict): Configuration dictionary (not used for numerical).

        Returns:
            torch.Tensor: Transformed tensor of shape (N, D).
        """
        tensor = torch.tensor(column.values, dtype=torch.float32).unsqueeze(1).to(self.device)
        if self.embedding is None:
            return tensor
        return self.embedding(tensor)

    def _vectorize_batch(self, columns: List[pd.Series], configs: List[Dict]) -> List[torch.Tensor]:
        """
        Batch implementation for transforming multiple numerical columns into vectors.

        All columns are concatenated, embedded in one pass, and split back so
        the i-th result has as many rows as the i-th column.

        Args:
            columns (List[pandas.Series]): The columns to transform.
            configs (List[dict]): The configuration dictionaries (not used for numerical).

        Returns:
            List[torch.Tensor]: List of transformed tensors.
        """
        if not columns:
            return []

        lengths = [len(col) for col in columns]
        stacked = torch.cat(
            [torch.tensor(col.values, dtype=torch.float32) for col in columns]
        ).unsqueeze(1).to(self.device)

        if self.embedding is not None:
            stacked = self.embedding(stacked)

        return list(torch.split(stacked, lengths, dim=0))

    def _compute_loss(self, predictions, targets, config):
        """
        Compute MSE loss between predictions and targets.

        Args:
            predictions (torch.Tensor): Predicted values; (N, D>1) inputs are
                mean-pooled to one value per row first.
            targets (pd.Series): Target values.
            config (dict): Configuration dictionary (not used for numerical).

        Returns:
            torch.Tensor: Computed MSE loss
        """
        target_tensor = torch.tensor(targets.values, dtype=torch.float32).to(self.device)
        return nn.functional.mse_loss(self._project_to_scalar(predictions), target_tensor)

    def _compute_batch_loss(self, reconstructed_values_list: List[torch.Tensor],
                           target_columns: List[pd.Series],
                           configs: List[Dict]) -> torch.Tensor:
        """
        Compute the combined loss for multiple columns in batch.

        Each prediction is reduced to one value per row with the same rule as
        ``_compute_loss`` before concatenation, so single-row and (N, D>1)
        predictions are handled.

        Args:
            reconstructed_values_list (List[torch.Tensor]): List of reconstructed values
            target_columns (List[pd.Series]): List of target columns
            configs (List[Dict]): List of configurations (not used for numerical)

        Returns:
            torch.Tensor: The combined loss
        """
        all_predictions = torch.cat(
            [self._project_to_scalar(pred) for pred in reconstructed_values_list]
        )
        all_targets = torch.cat(
            [torch.tensor(col.values, dtype=torch.float32) for col in target_columns]
        ).to(self.device)

        return nn.functional.mse_loss(all_predictions, all_targets)

    def _inverse_vectorize(self, tensor, config, mode='inference'):
        """
        Inverse transform the vectors back to numerical values.

        Args:
            tensor (torch.Tensor): The tensor to inverse transform.
            config (dict): Configuration dictionary (not used for numerical).
            mode (str): 'train' returns a tensor; any other value is treated
                as inference.

        Returns:
            Union[pd.Series, torch.Tensor]: 1-D tensor in train mode, otherwise
                a pandas Series.
        """
        values = self._project_to_scalar(tensor)

        if mode == 'train':
            return values

        # detach() is required: .numpy() raises on tensors that track gradients,
        # which is the case for anything produced by the embedding layer.
        return pd.Series(values.detach().cpu().numpy())

    def _inverse_vectorize_batch(self, tensors: List[torch.Tensor], configs: List[Dict], mode: str = 'inference') -> List:
        """
        Batch implementation of the inverse transformation.

        Applies ``_inverse_vectorize`` to every tensor. This class defines no
        ``decoder``, ``input_dim`` or ``column_batch_size``, so nothing beyond
        the per-column projection is involved.

        Args:
            tensors (List[torch.Tensor]): The tensors to inverse transform.
            configs (List[dict]): The configuration dictionaries (not used for
                numerical); may be empty or None.
            mode (str): 'inference' or 'train'

        Returns:
            List: One entry per input tensor - pd.Series in inference mode,
                1-D torch.Tensor in train mode.
        """
        if not tensors:
            return []

        results = []
        for position, tensor in enumerate(tensors):
            # Fall back to an empty config when none was supplied for this column.
            config = configs[position] if configs and position < len(configs) else {}
            results.append(self._inverse_vectorize(tensor, config, mode=mode))
        return results

class PLEVectorizer(ColumnVectorizer):
    def __init__(self, output_dim, input_dim=None, seed=42, column_batch_size=32, **kwargs):
        """
        Set up the PLE vectorizer, its projection matrix, decoder and projection function.

        Args:
            output_dim (int): The dimension of the output vectors.
            input_dim (int, optional): Number of PLE bins. Falls back to 32 when None.
            seed (int): Seed used when the projection matrix is initialised.
            column_batch_size (int): Maximum number of columns handled per chunk
                by the batch methods.
            **kwargs: Accepted for call-site flexibility; this constructor does
                not forward them to the parent class.
        """
        numeric_dtypes = [
            "int8", "int16", "int32", "int64",
            "uint8", "uint16", "uint32", "uint64",
            "float16", "float32", "float64"
        ]
        # Only output_dim and the accepted dtypes reach the parent constructor.
        super().__init__(output_dim=output_dim, accepted_dtype=numeric_dtypes)

        self.seed = seed
        self.column_batch_size = column_batch_size
        self.input_dim = 32 if input_dim is None else input_dim

        # Projection matrices handed out by _get_projection_matrix, keyed by input_dim.
        self._projection_cache = {}

        # Builds both self.projection_matrix and self.decoder.
        self._initialize_projection_matrix(self.input_dim)

        # Creates self.project_embeddings.
        self._setup_jit_functions()
        
    def _setup_jit_functions(self):
        """Set up JIT-compiled functions for performance"""
        from torch import jit
        
        # Define JIT-compiled function for projection
        @jit.script
        def project_embeddings(tensor, projection_matrix):
            return torch.matmul(tensor, projection_matrix)
        
        self.project_embeddings = project_embeddings

    def to(self, device):
        """
        Move the projection matrix, cached projection matrices and decoder to ``device``.

        Args:
            device: Target device, passed on to the parent ``to`` and to
                ``Tensor.to`` / ``Module.to``.

        Returns:
            PLEVectorizer: ``self``, to allow chaining.
        """
        super().to(device)

        previous_matrix = self.projection_matrix
        if previous_matrix is not None:
            self.projection_matrix = previous_matrix.to(device)

        # Matrices returned earlier by _get_projection_matrix live in the cache;
        # without this they would stay on the old device after a move.
        cache = getattr(self, '_projection_cache', None)
        if cache:
            for dim in list(cache):
                cached = cache[dim]
                if cached is previous_matrix:
                    # Keep the cache entry aliased to the active matrix.
                    cache[dim] = self.projection_matrix
                else:
                    cache[dim] = cached.to(device)

        if getattr(self, 'decoder', None) is not None:
            self.decoder = self.decoder.to(device)
        return self

    def is_trainable(self):
        """Report that this vectorizer is trainable; always returns True."""
        return True

    def _get_projection_matrix(self, input_dim):
        """
        Get cached projection matrix or create a new one.
        
        Args:
            input_dim (int): The dimension of the input PLE encoded values.
            
        Returns:
            torch.Tensor: The projection matrix.
        """
        if input_dim in self._projection_cache:
            return self._projection_cache[input_dim]
            
        # Create a new projection matrix and cache it
        self._initialize_projection_matrix(input_dim)
        self._projection_cache[input_dim] = self.projection_matrix
        return self.projection_matrix

    def _initialize_projection_matrix(self, input_dim):
        """
        Initialize an orthogonal projection matrix.
        
        Args:
            input_dim (int): The dimension of the input PLE encoded values.
        """
        torch.manual_seed(self.seed)
        
        # For an orthogonal projection from input_dim to output_dim dimensions:
        # 1. If input_dim >= output_dim, we can directly get an orthogonal matrix via QR decomposition
        # 2. If input_dim < output_dim, we need to pad with zeros to reach output_dim dimensions
        
        if input_dim >= self.output_dim:
            # Case 1: input_dim >= output_dim - standard orthogonal projection
            random_matrix = torch.randn(input_dim, input_dim)
            q, r = torch.linalg.qr(random_matrix)
            # Take only the first output_dim columns
            self.projection_matrix = q[:, :self.output_dim]
        else:
            # Case 2: input_dim < output_dim - need to pad
            # First create an orthogonal matrix for the input dimensions
            random_matrix = torch.randn(input_dim, input_dim)
            q, r = torch.linalg.qr(random_matrix)
            
            # Create a padded matrix to output_dim
            padded_matrix = torch.zeros(input_dim, self.output_dim)
            padded_matrix[:, :input_dim] = q
            
            # Add random orthogonal vectors for the remaining dimensions
            if self.output_dim > input_dim:
                # Create additional orthogonal vectors
                for i in range(input_dim, self.output_dim):
                    # Generate a random vector
                    v = torch.randn(input_dim)
                    
                    # Make it orthogonal to all previous vectors through Gram-Schmidt
                    for j in range(i):
                        v = v - torch.dot(v, padded_matrix[:, j]) * padded_matrix[:, j]
                    
                    # Normalize
                    norm = torch.norm(v)
                    if norm > 1e-6:  # Avoid division by near-zero
                        v = v / norm
                        padded_matrix[:, i] = v
            
            self.projection_matrix = padded_matrix
        
        # Simplified decoder architecture for better performance
        # The decoder outputs a tensor of size input_dim * 2, which we'll split into:
        # - First half: bin probabilities
        # - Second half: fractions
        self.decoder = nn.Sequential(
            nn.Linear(self.output_dim, input_dim * 2),
            nn.Sigmoid()  # Constrain output to [0,1] range - crucial for BCE loss
        )
        
        # Apply kaiming initialization for better convergence
        for m in self.decoder.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        
        # Set but don't override input_dim in case it was already provided
        if not hasattr(self, 'input_dim') or self.input_dim is None:
            self.input_dim = input_dim
        
    def _scalar_to_ple_vector(self, scalar_values: torch.Tensor) -> torch.Tensor:
        """
        Convert scalar PLE encoding to a vector encoding (N, n_bins) without loops.
        """
        # Integer bin index and fractional part
        idx  = torch.floor(scalar_values).long()           # (N,)
        frac = scalar_values - idx                         # (N,)

        # Special-value masks
        is_min = torch.isclose(scalar_values,
                            torch.tensor(0.0,
                                            device=scalar_values.device))
        is_max = torch.isclose(scalar_values,
                            torch.tensor(self.input_dim - 1 + 0.999999,
                                            device=scalar_values.device))

        batch_size = scalar_values.size(0)
        ple_vectors = torch.zeros(batch_size, self.input_dim,
                                device=scalar_values.device)  # (N, n_bins)

        # 1) Max → all ones
        ple_vectors[is_max] = 1.0

        # 2) Regular values (neither min nor max)
        regular_mask = ~(is_min | is_max)
        if regular_mask.any():
            regular_indices = torch.where(regular_mask)[0]     # (R,)
            regular_idx     = idx[regular_mask]               # (R,)
            regular_frac    = frac[regular_mask]              # (R,)

            # Create a (R, n_bins) mask: 1.0 before the active bin, 0 elsewhere
            bin_range   = torch.arange(self.input_dim,
                                    device=scalar_values.device)  # (n_bins,)
            bins_before = (bin_range.unsqueeze(0) <
                        regular_idx.unsqueeze(1))                # (R, n_bins)

            # Set all bins before the active bin to 1
            ple_vectors[regular_indices] = bins_before.float()

            # Set the active bin itself to the fractional part
            ple_vectors[regular_indices, regular_idx] = regular_frac

        # 3) Min values are already zeros → nothing to do
        return ple_vectors


    def _vectorize(self, column, config):
        """
        Transform a column of scalar PLE encoded values into vectors using orthogonal projection.
        Imputes missing values with column mean.

        Args:
            column (pandas.Series): The column where each value is a scalar from PLE encoding.
            config (dict): Configuration dictionary (not used).

        Returns:
            torch.Tensor: Transformed tensor of shape (N, output_dim).
        """
        # Handle NaN values by imputing with mean
        has_nans = column.isna().any()
        if has_nans:
            # Calculate mean of non-NaN values
            mean_value = column.mean()
            # Create a copy to avoid modifying the original data
            column = column.fillna(mean_value)
        
        # Convert column to tensor
        scalar_values = torch.tensor(column.values, dtype=torch.float32).to(self.device)
        
        # Convert scalar PLE to vector PLE
        ple_tensor = self._scalar_to_ple_vector(scalar_values)
        
        # Project PLE encodings to output dimension using JIT-compiled function
        embedded = self.project_embeddings(ple_tensor, self.projection_matrix)
        
        return embedded

    def _vectorize_batch(self, columns: List[pd.Series], configs: List[Dict]) -> List[torch.Tensor]:
        """
        Batch implementation for transforming multiple columns of scalar PLE values into vectors.
        Imputes missing values with column means.
        
        Args:
            columns (List[pandas.Series]): The columns to transform, each containing scalar PLE values.
            configs (List[dict]): The configuration dictionaries (not used).
            
        Returns:
            List[torch.Tensor]: List of transformed tensors.
        """
        if not columns:
            return []
        
        # Handle NaN values in each column by imputing with mean
        imputed_columns = []
        for col in columns:
            if col.isna().any():
                mean_value = col.mean()
                imputed_columns.append(col.fillna(mean_value))
            else:
                imputed_columns.append(col)
        
        result_tensors = []
        
        # Process columns in batches to avoid memory issues with large datasets
        for batch_start in range(0, len(imputed_columns), self.column_batch_size):
            batch_end = min(batch_start + self.column_batch_size, len(imputed_columns))
            column_batch = imputed_columns[batch_start:batch_end]
            config_batch = configs[batch_start:batch_end] if configs else [{}] * len(column_batch)
            
            # Track batch sizes for later splitting
            batch_sizes = [len(col) for col in column_batch]
            
            # Concatenate all scalar values into a single tensor
            all_scalar_values = torch.cat([
                torch.tensor(col.values, dtype=torch.float32) 
                for col in column_batch
            ]).to(self.device)
            
            # Convert all scalar PLE values to vector PLE in a single operation
            all_ple_vectors = self._scalar_to_ple_vector(all_scalar_values)
            
            # Project all PLE vectors at once
            all_embedded = self.project_embeddings(all_ple_vectors, self.projection_matrix)
            
            # Use torch.split to efficiently split the tensor in one operation
            batch_result_tensors = torch.split(all_embedded, batch_sizes, dim=0)
            result_tensors.extend(list(batch_result_tensors))
        
        return result_tensors

    def _inverse_vectorize(self, embedded, config, mode='inference'):
        """
        Inverse transform embeddings back to scalar values.

        Args:
            embedded (torch.Tensor): The embedded tensor to inverse transform.
            config (dict): Configuration dictionary (not used).
            mode (str): 'inference' or 'train'.

        Returns:
            torch.Tensor or tuple: Reconstructed values or (bin_probs, fractions) in training mode.
        """
        # Use decoder to get combined output
        decoder_output = self.decoder(embedded)
        
        # Split the output into bin probabilities and fractions
        # First half of channels is bin probs, second half is fractions
        bin_probs = decoder_output[:, :self.input_dim]
        fractions = decoder_output[:, self.input_dim:]
        
        if mode == 'train':
            # During training, return the raw decoder outputs
            return (bin_probs, fractions)
        else:
            # In inference mode, reconstruct scalar PLE values
            batch_size = embedded.size(0)
            results = torch.zeros(batch_size, device=embedded.device)
            
            for i in range(batch_size):
                # Find the first bin that's not predicted as 1 (active bin)
                # Using a threshold to account for numerical imprecision
                mask = bin_probs[i] < 0.5
                active_bins = torch.where(mask)[0]
                
                if len(active_bins) == 0:
                    # All bins predicted as 1, use the last bin with fraction=1
                    results[i] = self.input_dim - 1 + 0.999999  # Max value encoding
                else:
                    active_bin = active_bins[0]
                    # Sum active bin index and predicted fraction
                    results[i] = active_bin + fractions[i, active_bin]
            
            return pd.Series(results.detach().cpu().numpy())

    def _inverse_vectorize_batch(self, tensors: List[torch.Tensor], configs: List[Dict], mode: str = 'inference') -> List:
        """
        Batch implementation for inverse transformation of vectors back to scalar PLE values.
        
        Args:
            tensors (List[torch.Tensor]): The embedded tensors to inverse transform.
            configs (List[dict]): The configuration dictionaries (not used).
            mode (str): 'inference' or 'train'
            target_columns (List[pd.Series], optional): Target columns for training mode.
            
        Returns:
            List: For inference mode, list of reconstructed pd.Series.
                  For training mode, tuple of (decoder_outputs, loss).
        """
        if not tensors:
            return []
        
        result = []
        
        # Process tensors in batches to avoid memory issues with large datasets
        for batch_start in range(0, len(tensors), self.column_batch_size):
            batch_end = min(batch_start + self.column_batch_size, len(tensors))
            tensor_batch = tensors[batch_start:batch_end]
            config_batch = configs[batch_start:batch_end] if configs else [{}] * len(tensor_batch)
            
            # Save original shapes for later splitting
            batch_sizes = [tensor.shape[0] for tensor in tensor_batch]
            
            # Concatenate all tensors for batch processing
            all_embeddings = torch.cat(tensor_batch)
            
            # Run all embeddings through decoder at once
            all_decoder_output = self.decoder(all_embeddings)
            
            # Split outputs into bin probabilities and fractions
            all_bin_probs = all_decoder_output[:, :self.input_dim]
            all_fractions = all_decoder_output[:, self.input_dim:]
            
            if mode == 'train':
                # Split results back according to original tensor batches
                split_results = []
                start_idx = 0
                
                for size in batch_sizes:
                    end_idx = start_idx + size
                    # Get bin probs and fractions for this tensor
                    tensor_bin_probs = all_bin_probs[start_idx:end_idx]
                    tensor_fractions = all_fractions[start_idx:end_idx]
                    # Add tuple of (bin_probs, fractions) for this tensor to results
                    split_results.append((tensor_bin_probs, tensor_fractions))
                    start_idx = end_idx
                
                result.extend(split_results)
            else:
                # For inference mode, reconstruct scalar values
                all_results = torch.zeros(all_embeddings.shape[0], device=all_embeddings.device)
                
                for i in range(all_embeddings.shape[0]):
                    # Find the first bin that's not predicted as 1 (active bin)
                    mask = all_bin_probs[i] < 0.5
                    active_bins = torch.where(mask)[0]
                    
                    if len(active_bins) == 0:
                        # All bins predicted as 1, use the last bin with fraction=1
                        all_results[i] = self.input_dim - 1 + 0.999999  # Max value encoding
                    else:
                        active_bin = active_bins[0]
                        # Sum active bin index and predicted fraction
                        all_results[i] = active_bin + all_fractions[i, active_bin]
                
                # Split results back into original batch sizes
                start_idx = 0
                for size in batch_sizes:
                    end_idx = start_idx + size
                    # Convert tensor slice to pandas Series
                    result.append(pd.Series(all_results[start_idx:end_idx].detach().cpu().numpy()))
                    start_idx = end_idx
        
        return result

    def _compute_loss(self, predictions, targets, config=None):
        """
        Compute loss between predicted and target values.

        Args:
            predictions (tuple): Tuple of (bin_probs, fractions) from the decoder.
            targets (torch.Tensor or pd.Series): Target scalar PLE values.
            config (dict, optional): Configuration dictionary.

        Returns:
            torch.Tensor: Computed loss.
        """
        bin_probs, fractions = predictions
        
        # Convert targets to tensor if it's a pandas Series
        if isinstance(targets, pd.Series):
            targets = torch.tensor(targets.values, dtype=torch.float32, device=bin_probs.device)
        
        # Filter out NaN values in targets
        valid_mask = ~torch.isnan(targets)
        if not valid_mask.all():
            # Filter targets and predictions by mask
            targets = targets[valid_mask]
            bin_probs = bin_probs[valid_mask]
            fractions = fractions[valid_mask]
            
            # If no valid samples remain, return zero loss
            if targets.numel() == 0:
                return torch.tensor(0.0, device=bin_probs.device)
        
        # Convert scalar targets to PLE vectors
        target_vectors = self._scalar_to_ple_vector(targets)
        
        # Compute binary classification loss for each bin
        bin_targets = (target_vectors > 0.5).float()
        bin_loss = F.binary_cross_entropy(bin_probs, bin_targets)
        
        # Compute regression loss for the fractional parts
        # Extract bin indices and fractions from targets
        active_bins = torch.floor(targets).long()
        active_fractions = targets - active_bins
        
        # Special cases for min/max
        is_min = torch.isclose(targets, torch.tensor(0.0, device=targets.device))
        is_max = torch.isclose(targets, torch.tensor(self.input_dim - 1 + 0.999999, device=targets.device))
        
        # Handle active bins and their fractions
        frac_loss = torch.tensor(0.0, device=bin_loss.device)
        mask = ~(is_min | is_max)  # Only regular values
        
        if mask.sum() > 0:
            # Get predicted fractions for active bins
            batch_indices = torch.where(mask)[0]
            bin_indices = active_bins[mask]
            
            # Extract predicted fractions at the active bin locations
            pred_fractions = fractions[batch_indices, bin_indices]
            
            # Calculate MSE loss for fractions
            frac_loss = F.mse_loss(pred_fractions, active_fractions[mask])
        
        # Combine losses
        total_loss = bin_loss + frac_loss
        
        return total_loss

    def _compute_batch_loss(self, reconstructed_values_list: List[Tuple[torch.Tensor, torch.Tensor]], 
                           target_columns: List[pd.Series], 
                           configs: List[Dict]) -> torch.Tensor:
        """
        Compute the combined loss for multiple columns in batch.
        
        Args:
            reconstructed_values_list (List[Tuple[torch.Tensor, torch.Tensor]]): 
                List of tuples of (bin_probabilities, fractions) for each column
            target_columns (List[pd.Series]): List of target columns with PLE encoded lists
            configs (List[Dict]): List of configurations (not used)
            
        Returns:
            torch.Tensor: The combined loss
        """
        try:
            # Validate and prepare inputs
            valid_items = []
            for reconstructed, target, config in zip(reconstructed_values_list, target_columns, configs):
                bin_probabilities, fractions = reconstructed
                # Handle NaN or Inf values
                if torch.isnan(bin_probabilities).any() or torch.isinf(bin_probabilities).any():
                    bin_probabilities = torch.where(
                        torch.isnan(bin_probabilities) | torch.isinf(bin_probabilities),
                        torch.tensor(0.5, device=self.device),
                        bin_probabilities
                    )
                    bin_probabilities = torch.clamp(bin_probabilities, 0.0, 1.0)
                    reconstructed = (bin_probabilities, fractions)
                valid_items.append((reconstructed, target))
            
            if not valid_items:
                return torch.tensor(1e-5, device=self.device)
            
            # Concatenate all inputs for batch processing
            all_bin_probs = torch.cat([item[0][0] for item in valid_items], dim=0)
            all_fractions = torch.cat([item[0][1] for item in valid_items], dim=0)
            all_targets = pd.concat([item[1] for item in valid_items], ignore_index=True)
            
            # Process in a single batch
            total_loss = self._compute_loss((all_bin_probs, all_fractions), all_targets)
            
            return total_loss
            
        except Exception as e:
            print(f"Error in _compute_batch_loss: {str(e)}. Returning fallback loss.")
            return torch.tensor(1e-5, device=self.device)

# Manual smoke tests / demo. Runs only when the module is executed as a
# script; every section prints its results instead of asserting on them.
if __name__ == "__main__":
    # Test all vectorizers
    config = {}

    # Test QuantileEmbeddingVectorizer
    print("\nTesting QuantileEmbeddingVectorizer:")
    # First transform data to [0,1] range to simulate quantile transformation
    quantile_data = pd.Series([0.2, 0.5, 0.8])  # Simulated quantile-transformed data
    vectorizer = QuantileEmbeddingVectorizer(output_dim=4, hidden_dims=[64, 32])
    vectors = vectorizer.vectorize(quantile_data, config)
    print("Vectors:\n", vectors)
    reconstructed = vectorizer.inverse_vectorize(vectors, config)
    print("Reconstructed (should be in [0,1] range):\n", reconstructed)

    # Test the numerical vectorizer
    vectorizer = QuantileEmbeddingVectorizer(output_dim=4)  # Using higher dimension for testing

    # Sample data
    column1 = pd.Series([1.0, 2.5, 3.7, -1.2, 0.8])
    column2 = pd.Series([0.5, -2.0, 3.0, 1.5])
    column3 = pd.Series([-5.0, 0.0, 7.5])

    # Regular processing tests
    print("\nTest 1: Basic vectorization")
    vectors = vectorizer.vectorize(column1, {})
    print("Vectors shape:", vectors.shape)

    decoded_values = vectorizer.inverse_vectorize(vectors, {})
    print("Decoded values:", decoded_values)

    # Test training mode
    vectors_train, loss = vectorizer.inverse_vectorize(vectors, {}, mode='train', target_column=column1)
    print("Training loss:", loss.item())

    # Batch processing tests
    print("\nTest 2: Batch vectorization")
    columns = [column1, column2, column3]
    configs = [{}, {}, {}]

    batch_vectors = vectorizer.vectorize_batch(columns, configs)
    print(f"Number of tensors: {len(batch_vectors)}")
    for i, tensor in enumerate(batch_vectors):
        print(f"Tensor {i} shape: {tensor.shape}")

    # Batch inverse vectorization
    decoded_batch = vectorizer.inverse_vectorize_batch(batch_vectors, configs)
    print("\nDecoded batch:")
    for i, col in enumerate(decoded_batch):
        print(f"Column {i}:\n{col}")

    # Batch training mode
    batch_values, batch_loss = vectorizer.inverse_vectorize_batch(
        batch_vectors, configs, mode='train', target_columns=columns
    )
    print(f"Batch training loss: {batch_loss.item()}")

    # Performance comparison
    print("\nTest 3: Performance comparison")
    import time

    # Create larger test data
    num_columns = 100
    column_length = 1000

    large_columns = []
    large_configs = []

    for i in range(num_columns):
        data = np.random.randn(column_length)
        large_columns.append(pd.Series(data))
        large_configs.append({})

    # Sequential processing time.
    # perf_counter is used for interval timing: time.time() can be too coarse
    # on some platforms and report a zero-length interval for a short run.
    start_time = time.perf_counter()
    sequential_vectors = [vectorizer._vectorize(col, cfg) for col, cfg in zip(large_columns, large_configs)]
    sequential_time = time.perf_counter() - start_time
    print(f"Sequential vectorization time: {sequential_time:.4f}s")

    # Batch processing time
    start_time = time.perf_counter()
    batch_vectors = vectorizer._vectorize_batch(large_columns, large_configs)
    batch_time = time.perf_counter() - start_time
    print(f"Batch vectorization time: {batch_time:.4f}s")
    # Guard the ratio so a zero-length measurement cannot abort the
    # remaining tests with a ZeroDivisionError.
    if batch_time > 0:
        print(f"Speedup: {sequential_time / batch_time:.2f}x")
    else:
        print("Speedup: n/a (batch time below timer resolution)")

    # Test 4: Verify batch decoding accuracy
    print("\nTest 4: Verify batch decoding accuracy")

    # Create test data
    test_columns = [
        pd.Series([1.1, 2.2, 3.3, 4.4, 5.5]),
        pd.Series([10.1, 20.2, 30.3, 40.4]),
        pd.Series([-5.5, -4.4, -3.3, -2.2, -1.1, 0.0])
    ]
    test_configs = [{}, {}, {}]

    # 1. Process each column individually
    individual_vectors = [vectorizer.vectorize(col, {}) for col in test_columns]
    individual_decoded = [vectorizer.inverse_vectorize(vec, {}) for vec in individual_vectors]

    # 2. Process columns in batch
    batch_vectors = vectorizer.vectorize_batch(test_columns, test_configs)
    batch_decoded = vectorizer.inverse_vectorize_batch(batch_vectors, test_configs)

    # 3. Compare results
    all_match = True
    max_diff = 0.0

    for i, (ind_col, batch_col) in enumerate(zip(individual_decoded, batch_decoded)):
        # Check if columns have same length
        if len(ind_col) != len(batch_col):
            print(f"❌ Column {i}: Length mismatch - individual: {len(ind_col)}, batch: {len(batch_col)}")
            all_match = False
            continue

        # Calculate maximum absolute difference
        diff = np.abs(ind_col.values - batch_col.values).max()
        max_diff = max(max_diff, diff)

        # Check if differences are within tolerance
        if diff > 1e-5:
            print(f"❌ Column {i}: Values differ by {diff:.8f}")
            print(f"  Individual: {ind_col.values[:3]}...")
            print(f"  Batch:      {batch_col.values[:3]}...")
            all_match = False
        else:
            print(f"✓ Column {i}: Match (max diff: {diff:.8f})")

    if all_match:
        print(f"✅ All columns match between individual and batch processing (max diff: {max_diff:.8f})")
    else:
        print(f"❌ Some columns have differences (max diff: {max_diff:.8f})")

    # Test 5: Verify round-trip accuracy (original → vectors → decoded)
    print("\nTest 5: Verify round-trip accuracy")

    # Create simple test data with predictable values
    simple_column = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])

    # Process individually
    vectors = vectorizer.vectorize(simple_column, {})
    decoded = vectorizer.inverse_vectorize(vectors, {})

    # Calculate mean absolute error
    ind_mae = np.abs(simple_column.values - decoded.values).mean()
    print(f"Individual round-trip MAE: {ind_mae:.8f}")
    print(f"Original: {simple_column.values}")
    print(f"Decoded:  {decoded.values}")

    # Process in batch
    batch_vectors = vectorizer.vectorize_batch([simple_column], [{}])
    batch_decoded = vectorizer.inverse_vectorize_batch(batch_vectors, [{}])

    # Calculate mean absolute error for batch
    batch_mae = np.abs(simple_column.values - batch_decoded[0].values).mean()
    print(f"Batch round-trip MAE: {batch_mae:.8f}")
    print(f"Original: {simple_column.values}")
    print(f"Decoded:  {batch_decoded[0].values}")

    if ind_mae < 1e-5 and batch_mae < 1e-5:
        print("✅ Both individual and batch processing preserve values accurately")
    else:
        print("❌ Round-trip processing introduces errors")

    # Test the PLEVectorizer
    print("\nTesting PLEVectorizer:")

    # Create sample PLE encoded data using scalar encoding format
    # Example: 3 samples with 4 bins
    # First sample: 2.5 - active bin is 2 with fraction 0.5
    # Second sample: 0.7 - active bin is 0 with fraction 0.7
    # Third sample: 3.8 - active bin is 3 with fraction 0.8
    ple_data = pd.Series([2.5, 0.7, 3.8])

    # Create and test the vectorizer
    # Use input_dim=4 to match the number of bins
    ple_vectorizer = PLEVectorizer(output_dim=4, input_dim=4)

    # Test vectorization
    print("\nTest 1: Basic PLE vectorization")
    vectors = ple_vectorizer.vectorize(ple_data, {})
    print("Vectors shape:", vectors.shape)

    # Test reconstruction
    reconstructed = ple_vectorizer.inverse_vectorize(vectors, {})
    print("Original PLE data:")
    print(ple_data)
    print("Reconstructed PLE data:")
    print(reconstructed)

    # Test training mode
    print("\nTest 2: PLE training mode")
    predictions, loss = ple_vectorizer.inverse_vectorize(vectors, {}, mode='train', target_column=ple_data)
    print("Training loss:", loss.item())
    bin_probs, fractions = predictions
    print("Bin probabilities shape:", bin_probs.shape)
    print("Fractions shape:", fractions.shape)

    # Test batch processing
    print("\nTest 3: PLE batch processing")
    ple_data2 = pd.Series([0.3, 2.4])

    ple_columns = [ple_data, ple_data2]
    ple_configs = [{}, {}]

    # Batch vectorization
    batch_vectors = ple_vectorizer.vectorize_batch(ple_columns, ple_configs)
    print(f"Number of tensors: {len(batch_vectors)}")
    for i, tensor in enumerate(batch_vectors):
        print(f"Tensor {i} shape: {tensor.shape}")

    # Batch reconstruction
    reconstructed_batch = ple_vectorizer.inverse_vectorize_batch(batch_vectors, ple_configs)
    print("\nBatch reconstruction results:")
    for i, reconstructed in enumerate(reconstructed_batch):
        print(f"Column {i} original: {ple_columns[i].values}")
        print(f"Column {i} reconstructed: {reconstructed.values}")

    # Batch training mode
    print("\nBatch training mode:")
    batch_predictions, batch_loss = ple_vectorizer.inverse_vectorize_batch(
        batch_vectors, ple_configs, mode='train', target_columns=ple_columns
    )
    print(f"Batch training loss: {batch_loss.item()}")