import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW
from typing import List, Dict
import math

from .columnVectorizer import *

# Table vectorizer
class TableVectorizer(nn.Module):
    REQUIRED_VAR_TYPES = ["numerical", "categorical", "text", "datetime"]
    
    def __init__(self, transformer_mapping, output_dim, transformer_configs=None):
        """
        Instantiate one column vectorizer per variable type and freeze the non-trainable ones.

        Args:
            transformer_mapping (dict): Maps a variable type name to the class that
                builds its vectorizer. Every entry of REQUIRED_VAR_TYPES must be
                present as a key.
            output_dim (int): Passed to every vectorizer as the "output_dim" keyword;
                it overrides any "output_dim" found in the configs.
            transformer_configs (dict, optional): Per-type keyword overrides layered
                on top of the built-in defaults, for example
                {"numerical": {"num_bins": 16}, "categorical": {"projection_dim": 64}}.
        """
        super().__init__()

        # Every required variable type needs a vectorizer class before anything is built
        assert all(vtype in transformer_mapping.keys() for vtype in self.REQUIRED_VAR_TYPES), \
            f"Required variable types: {self.REQUIRED_VAR_TYPES} not matched with transformer_mapping: {transformer_mapping.keys()}"

        # Built-in keyword defaults, keyed by variable type
        builtin_defaults = {
            "numerical": {"num_bins": 32, "num_exponents": 64},
            "categorical": {"model_name": "Alibaba-NLP/gte-large-en-v1.5", "projection_dim": 128},
            "text": {},
            "datetime": {},
        }
        overrides = transformer_configs or {}

        # Build the vectorizers: defaults first, then user overrides, then output_dim
        self.transformer_mapping = nn.ModuleDict()
        for var_type, transformer_class in transformer_mapping.items():
            kwargs = {**builtin_defaults.get(var_type, {}), **overrides.get(var_type, {})}
            kwargs["output_dim"] = output_dim  # always wins over the configs
            self.transformer_mapping[var_type] = transformer_class(**kwargs)

        self._freeze_non_trainable()
        self.output_dim = output_dim

    def _freeze_non_trainable(self):
        """Turn off gradients for all parameters of vectorizers whose is_trainable() is falsy."""
        for module in self.transformer_mapping.values():
            if not module.is_trainable():
                for param in module.parameters():
                    param.requires_grad_(False)

    def forward(self, dataframe, config, batch_processing=True):
        """Module call entry point: delegate to vectorize() with the same arguments."""
        table_tensor = self.vectorize(dataframe, config, batch_processing=batch_processing)
        return table_tensor

    def vectorize(self, dataframe, config, batch_processing=True):
        """
        Turn a dataframe into its tensor representation.

        Args:
            dataframe (pd.DataFrame): Table to encode.
            config (dict): Table configuration; its "variables" entry describes the columns.
            batch_processing (bool): When truthy, columns of the same type are handed
                to their vectorizer together; otherwise columns are encoded one by one.

        Returns:
            torch.Tensor: Whatever the selected implementation returns
                (documented there as shape (N, num_columns, D)).
        """
        implementation = self._vectorize_batch if batch_processing else self._vectorize
        return implementation(dataframe, config)

    def _vectorize(self, dataframe, config):
        """
        Sequentially encode every column with the vectorizer registered for its type.

        Columns are paired with config["variables"] by position and are also
        *selected* by position, so a dataframe with duplicate column labels still
        yields exactly one Series per column (label-based selection would return
        a DataFrame for a repeated label).

        Args:
            dataframe (pd.DataFrame): The dataframe to transform.
            config (dict): Configuration dictionary; config["variables"] holds one
                dict per column, each optionally carrying a "variable_type".

        Returns:
            torch.Tensor: The per-column vectors stacked along dim=1
                (shape (N, num_columns, D) when each vector is (N, D)).

        Raises:
            ValueError: If the number of variables differs from the number of
                columns, or a column's variable type has no registered vectorizer.
        """
        variables = config["variables"]

        if len(variables) != dataframe.shape[1]:
            raise ValueError("Number of variables in config does not match number of columns in the dataframe.")

        column_vectors = []
        for position, column_config in enumerate(variables):
            variable_type = column_config.get("variable_type", "unknown")
            if variable_type not in self.transformer_mapping:
                raise ValueError(f"Unsupported variable type: {variable_type}")

            vectorizer = self.transformer_mapping[variable_type]
            # Positional selection: always a Series, even with duplicate labels
            column_data = dataframe.iloc[:, position]
            column_vector = vectorizer.vectorize(column_data, column_config)

            # A trainable vectorizer that returned a tensor without grad gets a
            # fresh leaf copy that requires grad. NOTE: detach() means this copy
            # is not connected to the vectorizer's own parameters.
            if vectorizer.is_trainable() and not column_vector.requires_grad:
                column_vector = column_vector.detach().clone().requires_grad_(True)

            column_vectors.append(column_vector)

        table_tensor = torch.stack(column_vectors, dim=1)

        # Keep the stacked tensor grad-enabled whenever one of its inputs was
        if not table_tensor.requires_grad and any(vec.requires_grad for vec in column_vectors):
            table_tensor = table_tensor.detach().clone().requires_grad_(True)

        return table_tensor

    def _vectorize_batch(self, dataframe, config):
        """
        Encode a dataframe by handing all columns of one variable type to their
        vectorizer in a single vectorize_batch() call.

        Columns are paired with config["variables"] by position and are also
        *selected* by position, so duplicate column labels still yield one Series
        per column. The output keeps the original column order.

        Args:
            dataframe (pd.DataFrame): The dataframe to transform.
            config (dict): Configuration dictionary; config["variables"] holds one
                dict per column, each optionally carrying a "variable_type".

        Returns:
            torch.Tensor: The per-column vectors stacked along dim=1
                (shape (N, num_columns, D) when each vector is (N, D)).

        Raises:
            ValueError: If the number of variables differs from the number of
                columns, or a column's variable type has no registered vectorizer.
        """
        variables = config["variables"]

        if len(variables) != dataframe.shape[1]:
            raise ValueError("Number of variables in config does not match number of columns in the dataframe.")

        # variable type -> (column Series, column configs, original positions)
        groups = {}
        for position, column_config in enumerate(variables):
            variable_type = column_config.get("variable_type", "unknown")
            if variable_type not in self.transformer_mapping:
                raise ValueError(f"Unsupported variable type: {variable_type}")

            columns, configs, positions = groups.setdefault(variable_type, ([], [], []))
            # Positional selection: always a Series, even with duplicate labels
            columns.append(dataframe.iloc[:, position])
            configs.append(column_config)
            positions.append(position)

        # Slots are filled by original position so the column order is preserved
        column_vectors = [None] * len(variables)
        for variable_type, (columns, configs, positions) in groups.items():
            vectorizer = self.transformer_mapping[variable_type]
            batch_vectors = vectorizer.vectorize_batch(columns, configs)

            # Trainable vectorizers: vectors that came back without grad are
            # replaced by leaf copies that require grad. NOTE: detach() means
            # these copies are not connected to the vectorizer's parameters.
            if vectorizer.is_trainable():
                batch_vectors = [
                    vec if vec.requires_grad else vec.detach().clone().requires_grad_(True)
                    for vec in batch_vectors
                ]

            for position, vec in zip(positions, batch_vectors):
                column_vectors[position] = vec

        table_tensor = torch.stack(column_vectors, dim=1)

        # Keep the stacked tensor grad-enabled whenever one of its inputs was
        if not table_tensor.requires_grad and any(vec.requires_grad for vec in column_vectors):
            table_tensor = table_tensor.detach().clone().requires_grad_(True)

        return table_tensor

    def inverse_vectorize(self, tensor, config, mode='inference', target_df=None, batch_processing=True):
        """
        Reverse the tensor transformation back to a dataframe.

        On failure, diagnostic files (target dataframe, config, tensor metadata)
        are written best-effort to the "debug_table_vectorizer" directory under
        the current working directory, and the error is re-raised with the list
        of written files appended to its message.

        Args:
            tensor (torch.Tensor): Tensor to transform back to dataframe.
            config (dict): Configuration dictionary with variables info.
            mode (str): 'inference' or 'train' mode.
            target_df (pd.DataFrame, optional): Target dataframe for training mode.
            batch_processing (bool): Whether to use batch processing for columns of the same type.

        Returns:
            pd.DataFrame or list:
                - If mode='inference': Reconstructed dataframe
                - If mode='train': List of reconstructed values for loss computation

        Raises:
            Exception: An exception of the same type as the underlying failure
                whenever that type can be built from a single message string;
                otherwise a RuntimeError. The original exception is always
                chained as __cause__.
        """
        try:
            if batch_processing:
                return self._inverse_vectorize_batch(tensor, config, mode, target_df)
            return self._inverse_vectorize(tensor, config, mode, target_df)
        except Exception as e:
            debug_dir = "debug_table_vectorizer"
            debug_files = self._dump_inverse_debug_files(
                debug_dir, tensor, config, target_df, batch_processing
            )
            message = (
                f"{str(e)}\nDebug files saved to {debug_dir}:\n" +
                "\n".join(debug_files)
            )
            try:
                enriched = type(e)(message)
            except Exception:
                # Some exception classes cannot be built from one string
                # (extra required constructor arguments); do not let that
                # construction failure hide the real error.
                enriched = RuntimeError(message)
            raise enriched from e

    @staticmethod
    def _dump_inverse_debug_files(debug_dir, tensor, config, target_df, batch_processing):
        """Write diagnostic files best-effort and return one description line per file; never raises."""
        import os
        import json
        from datetime import datetime

        debug_files = []
        try:
            os.makedirs(debug_dir, exist_ok=True)

            # Timestamp keeps the files of one failure together
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Target dataframe, when one was supplied
            if target_df is not None:
                target_path = os.path.join(debug_dir, f"target_df_{timestamp}.csv")
                target_df.to_csv(target_path)
                debug_files.append(f"- Target dataframe: {target_path}")

            # Config: top-level values json cannot encode are stored as str(value)
            serializable_config = {}
            for k, v in config.items():
                try:
                    json.dumps({k: v})
                    serializable_config[k] = v
                except (TypeError, OverflowError, ValueError):
                    serializable_config[k] = str(v)
            config_path = os.path.join(debug_dir, f"table_config_{timestamp}.json")
            with open(config_path, 'w', encoding="utf-8") as f:
                json.dump(serializable_config, f, indent=2)
            debug_files.append(f"- Table config: {config_path}")

            # Tensor metadata only (shape/device/dtype), not the values
            tensor_info = {
                "shape": list(tensor.shape),
                "device": str(tensor.device),
                "dtype": str(tensor.dtype),
                "batch_processing": batch_processing
            }
            tensor_info_path = os.path.join(debug_dir, f"tensor_info_{timestamp}.json")
            with open(tensor_info_path, 'w', encoding="utf-8") as f:
                json.dump(tensor_info, f, indent=2)
            debug_files.append(f"- Tensor info: {tensor_info_path}")
        except Exception as dump_error:
            # Diagnostics are optional: report the problem instead of masking
            # the error that triggered the dump.
            debug_files.append(f"- Debug dump incomplete: {dump_error!r}")
        return debug_files

    def _inverse_vectorize(self, tensor, config, mode='inference', target_df=None):
        """
        Sequentially decode each column slice tensor[:, i, :] with the vectorizer
        registered for that column's variable type.

        The mode is validated before any vectorizer runs, so an unsupported mode
        fails fast instead of first being passed to every column vectorizer.

        Args:
            tensor (torch.Tensor): Tensor to transform back to dataframe; its second
                dimension must match the number of configured variables.
            config (dict): Configuration dictionary; config["variables"] holds one
                dict per column with "variable_name" and optionally "variable_type".
            mode (str): 'inference' or 'train' mode.
            target_df (pd.DataFrame, optional): Target dataframe for training mode;
                columns are looked up by each variable's "variable_name".

        Returns:
            pd.DataFrame or list:
                - If mode='inference': the decoded columns concatenated into a
                  dataframe whose columns are the configured variable names
                - If mode='train': list with each vectorizer's return value, in
                  column order

        Raises:
            ValueError: On a variable/column count mismatch, an unsupported mode,
                or a variable type without a registered vectorizer.
            AssertionError: If mode is 'train' and target_df is None.
        """
        variables = config["variables"]
        if len(variables) != tensor.shape[1]:
            raise ValueError("Number of variables in config does not match number of columns in the tensor.")
        if mode not in ("inference", "train"):
            raise ValueError("Unsupported mode. Use 'inference' or 'train'.")
        if mode == "train":
            assert target_df is not None, "Target dataframe is required for training mode!"

        # Targets are only forwarded in train mode and only when available
        with_targets = target_df is not None and mode == "train"

        reconstructed_columns = []
        for position, column_config in enumerate(variables):
            variable_type = column_config.get("variable_type", "unknown")
            if variable_type not in self.transformer_mapping:
                raise ValueError(f"Unsupported variable type: {variable_type}")

            vectorizer = self.transformer_mapping[variable_type]
            column_tensor = tensor[:, position, :]
            if with_targets:
                target_column = target_df[column_config["variable_name"]]
                reconstructed = vectorizer.inverse_vectorize(column_tensor, column_config, mode, target_column)
            else:
                reconstructed = vectorizer.inverse_vectorize(column_tensor, column_config, mode)
            reconstructed_columns.append(reconstructed)

        if mode == "inference":
            dataframe = pd.concat(reconstructed_columns, axis=1)
            dataframe.columns = [var["variable_name"] for var in variables]
            return dataframe
        return reconstructed_columns

    def _inverse_vectorize_batch(self, tensor, config, mode='inference', target_df=None):
        """
        Decode a table tensor by handing all column slices of one variable type to
        their vectorizer in a single inverse_vectorize_batch() call.

        The mode is validated before any vectorizer runs, so an unsupported mode
        fails fast instead of first being passed to every vectorizer. Results are
        placed back at their original column positions.

        Args:
            tensor (torch.Tensor): Tensor to transform back to dataframe; its second
                dimension must match the number of configured variables.
            config (dict): Configuration dictionary; config["variables"] holds one
                dict per column with "variable_name" and optionally "variable_type".
            mode (str): 'inference' or 'train' mode.
            target_df (pd.DataFrame, optional): Target dataframe for training mode;
                columns are looked up by each variable's "variable_name".

        Returns:
            pd.DataFrame or list:
                - If mode='inference': Reconstructed dataframe
                - If mode='train': List of tuple(reconstructed_values, column_loss).
                  A tensor loss returned for a group is divided evenly among that
                  group's columns; a list of losses is indexed per column; any
                  other loss object is attached unchanged to every column.

        Raises:
            ValueError: On a variable/column count mismatch, an unsupported mode,
                or a variable type without a registered vectorizer.
            AssertionError: If mode is 'train' and target_df is None.
        """
        variables = config["variables"]
        if len(variables) != tensor.shape[1]:
            raise ValueError("Number of variables in config does not match number of columns in the tensor.")
        if mode not in ("inference", "train"):
            raise ValueError("Unsupported mode. Use 'inference' or 'train'.")
        if mode == "train":
            assert target_df is not None, "Target dataframe is required for training mode!"

        # Targets are only collected/forwarded in train mode and when available
        with_targets = target_df is not None and mode == "train"

        # variable type -> per-type lists, all aligned by insertion order
        groups = {}
        for position, column_config in enumerate(variables):
            variable_type = column_config.get("variable_type", "unknown")
            if variable_type not in self.transformer_mapping:
                raise ValueError(f"Unsupported variable type: {variable_type}")

            group = groups.setdefault(
                variable_type, {"tensors": [], "configs": [], "indices": [], "targets": []}
            )
            group["tensors"].append(tensor[:, position, :])
            group["configs"].append(column_config)
            group["indices"].append(position)
            if with_targets:
                group["targets"].append(target_df[column_config["variable_name"]])

        # Slots are filled by original position so the column order is preserved
        reconstructed_columns = [None] * len(variables)
        for variable_type, group in groups.items():
            vectorizer = self.transformer_mapping[variable_type]
            indices = group["indices"]

            if not with_targets:
                batch_results = vectorizer.inverse_vectorize_batch(
                    group["tensors"], group["configs"], mode
                )
                for position, result in zip(indices, batch_results):
                    reconstructed_columns[position] = result
                continue

            batch_results, batch_loss = vectorizer.inverse_vectorize_batch(
                group["tensors"], group["configs"], mode, group["targets"]
            )
            loss_is_tensor = isinstance(batch_loss, torch.Tensor)
            loss_is_list = isinstance(batch_loss, list)
            # Only needed to split a combined tensor loss
            num_results = len(batch_results) if loss_is_tensor else None

            for idx, (position, result) in enumerate(zip(indices, batch_results)):
                if loss_is_tensor:
                    # Combined loss: split evenly among the columns of this type
                    column_loss = batch_loss / num_results
                elif loss_is_list:
                    column_loss = batch_loss[idx]
                else:
                    column_loss = batch_loss
                reconstructed_columns[position] = (result, column_loss)

        if mode == "inference":
            dataframe = pd.concat(reconstructed_columns, axis=1)
            dataframe.columns = [var["variable_name"] for var in variables]
            return dataframe
        return reconstructed_columns

    def encode_meta(self, config):
        """
        Encode the table description, variable names, variable types and
        per-variable distribution embeddings using the categorical vectorizer.

        Args:
            config (dict): Configuration dictionary; "description" (default "") and
                "variables" (default []) are read. Every variable dict must carry
                "variable_name" and "variable_type".

        Returns:
            tuple: A tuple containing:
                - meta (torch.Tensor): The description embedding with its leading
                  dimension squeezed.
                - column_names (torch.Tensor): Embeddings of the variable names as
                  returned by the categorical vectorizer (one row per variable).
                - dtype_tensor (torch.Tensor): Shape (num_columns, self.output_dim);
                  each row is filled with the numeric code of the variable type
                  (numerical=0, categorical=1, text=2, datetime=3, unknown=-1).
                  It is created on the same device as the description embedding.
                - dist_tensor (torch.Tensor): The encode_distribution() embeddings
                  stacked along dim 0; embeddings containing NaN are replaced by zeros.

        Raises:
            ValueError: If the registered categorical vectorizer is None.
        """
        description = config.get("description", "")
        variables = config.get("variables", [])

        # Get the categorical vectorizer
        categorical_vectorizer = self.transformer_mapping["categorical"]
        if categorical_vectorizer is None:
            raise ValueError("Categorical vectorizer is required for encoding.")

        # Encode the table description as a one-element categorical column
        description_tensor = categorical_vectorizer.vectorize(
            pd.Series([description]), {"variable_type": "categorical", "categories": [description]}
        )

        # Encode variable names
        variable_names = [var["variable_name"] for var in variables]
        variable_name_tensor = categorical_vectorizer.vectorize(
            pd.Series(variable_names), {"variable_type": "categorical", "categories": variable_names}
        )

        # Map dtypes to numeric codes (-1 for unknown dtypes)
        dtype_mapping = {"numerical": 0, "categorical": 1, "text": 2, "datetime": 3}
        dtype_values = [dtype_mapping.get(var["variable_type"], -1) for var in variables]

        # Build the dtype tensor on the embeddings' device so all returned
        # tensors can be combined without an extra transfer.
        device = description_tensor.device
        dtype_codes = torch.tensor(dtype_values, device=device).view(-1, 1)
        dtype_tensor = torch.ones(len(dtype_values), self.output_dim, device=device) * dtype_codes

        # Create distribution-based embeddings
        dist_tensors = []
        for var_config in variables:
            dist_tensor = self.encode_distribution(var_config)

            # A NaN anywhere invalidates the embedding; fall back to zeros
            if torch.isnan(dist_tensor).any():
                print(f"WARNING: NaN detected in distribution embedding for variable {var_config.get('variable_name', 'unknown')}. Replacing with zeros.")
                dist_tensor = torch.zeros_like(dist_tensor)

            dist_tensors.append(dist_tensor.squeeze(0))  # Remove batch dimension

        # Stack distribution embeddings
        dist_tensor = torch.stack(dist_tensors, dim=0)

        return description_tensor.squeeze(0), variable_name_tensor, dtype_tensor, dist_tensor
    
    def encode_distribution(self, var_config):
        """
        Create an embedding for the variable type, blended with a summary of the
        variable's distribution when one is available.

        - "categorical": mean embedding of all entries of var_config["categories"].
        - "numerical": encode_bins() of ple_params["bin_edges_"], or, only when
          "ple_params" is absent, of quantile_params["quantiles"].
        - anything else (including "text" and "datetime"): type embedding only.

        The emptiness checks are length-based, so numpy arrays and tensors are
        accepted as categories / bin edges (plain truthiness raises for arrays
        with more than one element).

        Args:
            var_config (dict): Configuration for a single variable; must contain
                "variable_type".

        Returns:
            torch.Tensor: The type embedding returned by the categorical vectorizer,
                or 0.5 * type embedding + 0.5 * distribution embedding.
        """
        categorical_vectorizer = self.transformer_mapping["categorical"]
        variable_type = var_config["variable_type"]

        # Base embedding for the variable type name itself
        dtype_embedding = categorical_vectorizer.vectorize(
            pd.Series([variable_type]),
            {"variable_type": "categorical", "categories": [variable_type]}
        )

        if variable_type == "categorical":
            categories = var_config.get("categories")
            if self._has_entries(categories):
                categories_embeddings = categorical_vectorizer.vectorize(
                    pd.Series(categories),
                    {"variable_type": "categorical", "categories": categories}
                )
                # Average over categories, keeping a leading dimension of 1
                avg_categories_embedding = torch.mean(categories_embeddings, dim=0, keepdim=True)
                return 0.5 * dtype_embedding + 0.5 * avg_categories_embedding
            return dtype_embedding

        if variable_type == "numerical":
            # "ple_params" takes precedence: quantiles are only consulted when
            # no "ple_params" entry exists at all.
            if "ple_params" in var_config:
                bins = var_config["ple_params"].get("bin_edges_", [])
            elif "quantile_params" in var_config:
                bins = var_config["quantile_params"].get("quantiles", [])
            else:
                bins = None
            if self._has_entries(bins):
                bins_embedding = self.encode_bins(bins, dtype_embedding.device)
                return 0.5 * dtype_embedding + 0.5 * bins_embedding
            return dtype_embedding

        # text, datetime and any other type: just the type embedding
        return dtype_embedding

    @staticmethod
    def _has_entries(values):
        """Return True when `values` is a non-empty sized container (list, array, tensor, ...)."""
        if values is None:
            return False
        try:
            return len(values) > 0
        except TypeError:
            # Unsized objects (scalars, 0-d arrays) keep their plain truthiness
            return bool(values)
        
    def encode_bins(self, bins, device=None):
        """
        Encode numerical bin edges or quantiles into an embedding vector of length
        `self.output_dim` using a Discrete Cosine Transform (DCT-II) of the
        normalized bin widths (density pattern).

        Steps:
            1. Compute bin widths w_i = b_i - b_{i-1} (clamped to >= 1e-12) and
               normalize them to sum to 1
            2. Zero-pad (or truncate) the width vector to length D (=output_dim)
            3. Compute the 1-D DCT-II of that length-D vector and L2-normalize it

        Args:
            bins (list | np.ndarray | torch.Tensor): Bin edges or quantiles. The
                input is cast to float32 and flattened, so tensors of any floating
                or integer dtype and column/row vectors are accepted.
            device (torch.device, optional): Device for the returned tensor
                (defaults to CPU).

        Returns:
            torch.Tensor: float32 tensor of shape (1, D) containing the normalized
                DCT coefficients; a uniform vector 1/sqrt(D) when fewer than two
                edges are given.
        """
        D = self.output_dim
        if device is None:
            device = torch.device("cpu")

        # Cast to float32 in both branches: the cosine matrix below is float32
        # and matmul does not promote mixed dtypes.
        if isinstance(bins, torch.Tensor):
            bins_tensor = bins.to(device=device, dtype=torch.float32)
        else:
            bins_tensor = torch.tensor(bins, dtype=torch.float32, device=device)
        bins_tensor = bins_tensor.reshape(-1)

        if bins_tensor.numel() < 2:
            # Not enough edges to compute widths -> return uniform embedding
            return torch.ones(1, D, device=device) / math.sqrt(D)

        # Widths between consecutive edges; non-positive widths are clamped to a
        # tiny positive value to avoid degenerate cases
        widths = bins_tensor[1:] - bins_tensor[:-1]
        widths = torch.clamp(widths, min=1e-12)

        # Normalize widths to sum to 1 so the scale of the variable drops out
        widths = widths / widths.sum()

        # Pad or truncate to length D
        if widths.numel() < D:
            pad = torch.zeros(D - widths.numel(), device=device)
            widths_padded = torch.cat([widths, pad], dim=0)
        else:
            widths_padded = widths[:D]

        # DCT-II computed explicitly: c_k = sum_n w_n * cos(pi / D * (n + 0.5) * k)
        n = torch.arange(D, device=device).float()  # (D,)
        k = n.clone().view(-1, 1)                  # (D, 1)
        cos_matrix = torch.cos(math.pi / D * (n + 0.5) * k)  # broadcast -> (D, D)
        c = torch.matmul(cos_matrix, widths_padded)  # (D,)

        # L2-normalize the coefficients (epsilon guards against a zero norm)
        c = c / (c.norm() + 1e-8)

        return c.unsqueeze(0)  # (1, D)

    def to(self, *args, **kwargs):
        """
        Move and/or cast the module, explicitly forwarding the call to every vectorizer.

        Accepts the same arguments as torch.nn.Module.to (device, dtype,
        non_blocking, ...), so to(device), to(device=...), to(dtype=...) and
        to(device, dtype) all work. Each vectorizer's own to() is called with the
        same arguments in addition to the base-class call.

        Returns:
            TableVectorizer: self, to allow chaining.
        """
        super().to(*args, **kwargs)
        for vectorizer in self.transformer_mapping.values():
            vectorizer.to(*args, **kwargs)
        return self

def test_batch_processing(table_vectorizer, dataframe, config):
    """
    Compare batch processing against sequential processing, for both
    vectorize() and inverse_vectorize(), and time each of the four calls.

    Args:
        table_vectorizer (TableVectorizer): The table vectorizer to test.
        dataframe (pd.DataFrame): The dataframe to transform.
        config (dict): Configuration dictionary with variables info.

    Returns:
        dict: {"vectorize": {...}, "inverse_vectorize": {...}} with, per section,
            "sequential_time", "batch_time" (seconds) and "speedup"
            (sequential / batch, or inf when the batch time measured as 0), plus
            "tensor_diff"/"tensors_equal" (max abs difference, < 1e-6) for
            vectorize and "dataframes_equal"/"column_diffs" for inverse_vectorize.
    """
    import time
    import numpy as np

    def _timed(func, *args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike the wall clock
        start = time.perf_counter()
        result = func(*args, **kwargs)
        return result, time.perf_counter() - start

    def _speedup(sequential, batch):
        # A call can finish below timer resolution; avoid ZeroDivisionError
        return sequential / batch if batch > 0 else float("inf")

    # Forward direction
    sequential_tensor, sequential_time = _timed(
        table_vectorizer.vectorize, dataframe, config, batch_processing=False
    )
    batch_tensor, batch_time = _timed(
        table_vectorizer.vectorize, dataframe, config, batch_processing=True
    )
    tensor_diff = torch.abs(sequential_tensor - batch_tensor).max().item()

    # Inverse direction, each path decoding its own forward result
    sequential_df, sequential_inverse_time = _timed(
        table_vectorizer.inverse_vectorize, sequential_tensor, config, batch_processing=False
    )
    batch_df, batch_inverse_time = _timed(
        table_vectorizer.inverse_vectorize, batch_tensor, config, batch_processing=True
    )

    # Compare reconstructed dataframes, column by column when they differ
    df_match = sequential_df.equals(batch_df)
    column_diffs = {}
    if not df_match:
        for col in sequential_df.columns:
            if sequential_df[col].equals(batch_df[col]):
                continue
            if np.issubdtype(sequential_df[col].dtype, np.number):
                # Numerical columns: largest absolute difference
                max_diff = np.abs(sequential_df[col] - batch_df[col]).max()
                column_diffs[col] = f"Max diff: {max_diff}"
            else:
                # Non-numerical columns: number of mismatching rows
                mismatch_count = (sequential_df[col] != batch_df[col]).sum()
                column_diffs[col] = f"Mismatches: {mismatch_count}/{len(sequential_df[col])}"

    return {
        "vectorize": {
            "sequential_time": sequential_time,
            "batch_time": batch_time,
            "speedup": _speedup(sequential_time, batch_time),
            "tensor_diff": tensor_diff,
            "tensors_equal": tensor_diff < 1e-6,
        },
        "inverse_vectorize": {
            "sequential_time": sequential_inverse_time,
            "batch_time": batch_inverse_time,
            "speedup": _speedup(sequential_inverse_time, batch_inverse_time),
            "dataframes_equal": df_match,
            "column_diffs": column_diffs,
        }
    }


# This is a benchmark helper that needs explicit arguments, not a pytest test:
# keep pytest from collecting it when the module is star-imported into a test file.
test_batch_processing.__test__ = False