import json
import os
import warnings
from datetime import datetime
from typing import List, Dict, Union, Tuple, Optional

import pandas as pd
import torch
import torch.nn as nn

class ColumnVectorizer(nn.Module):
    """
    Base class for modules that turn a single pandas column into a tensor and back.

    Subclasses implement ``_vectorize`` (column -> tensor), ``_inverse_vectorize``
    (tensor -> reconstructed values), ``_compute_loss`` (used when ``mode='train'``)
    and ``is_trainable``. They may override ``required_config_keys`` to declare
    mandatory config entries, and ``_vectorize_batch`` / ``_inverse_vectorize_batch``
    for more efficient batch processing.
    """

    # Directory (relative to the current working directory) that receives debug
    # dumps when ``_compute_loss`` fails inside ``inverse_vectorize``.
    _DEBUG_DIR = "debug_vectorizer"

    def __init__(self, output_dim, accepted_dtype):
        """
        Initialize the vectorizer.

        Args:
            output_dim (int): The dimension of the output vectors (D).
            accepted_dtype (str or list/tuple of str): The acceptable pandas data
                types for the column.

        Raises:
            ValueError: If ``accepted_dtype`` is not a string, list or tuple.
        """
        super().__init__()
        self.output_dim = output_dim
        if isinstance(accepted_dtype, str):
            self.accepted_dtype = [accepted_dtype]
        elif isinstance(accepted_dtype, (list, tuple)):
            # Copy so later mutation of the caller's sequence does not leak in.
            self.accepted_dtype = list(accepted_dtype)
        else:
            raise ValueError("accepted_dtype must be a string or a list of strings.")
        # Current device; updated by ``to()``.
        self.device = 'cpu'

    def to(self, *args, **kwargs):
        """
        Move/cast the module like ``nn.Module.to`` and remember the target device.

        Accepts the same arguments as ``nn.Module.to`` (e.g. ``to('cuda')``,
        ``to(device='cuda')``, ``to(torch.float16)``, ``to(some_tensor)``).
        ``self.device`` is updated only when a device can be determined from the
        arguments; dtype-only calls leave it unchanged.

        NOTE: ``nn.Module.cuda()`` / ``.cpu()`` do not route through this method,
        so they do not update ``self.device``.

        Returns:
            ColumnVectorizer: ``self``, matching ``nn.Module.to`` so calls chain.
        """
        super().to(*args, **kwargs)
        device = kwargs.get('device')
        if device is None and args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                device = first.device
            elif not isinstance(first, torch.dtype):
                device = first
        if device is not None:
            self.device = device
        return self

    def validate(self, column, config):
        """
        Validate if the column and configuration are acceptable.

        Args:
            column (pandas.Series): The column to validate.
            config (dict): The configuration dictionary for the transformation.

        Returns:
            bool: True if validation passes.

        Raises:
            ValueError: If the column type or configuration is invalid.
        """
        if not isinstance(column, pd.Series):
            raise ValueError("Input column must be a pandas Series.")

        # Check if the column has one of the accepted dtypes. The dtype is part of
        # the error message, so no need to dump the (possibly huge) column.
        if not any(pd.api.types.is_dtype_equal(column.dtype, dtype) for dtype in self.accepted_dtype):
            raise ValueError(
                f"Column dtype {column.dtype} is not accepted. Accepted dtypes are: {self.accepted_dtype}."
            )

        missing_keys = [key for key in self.required_config_keys() if key not in config]
        if missing_keys:
            raise ValueError(f"Missing required keys in config: {missing_keys}")

        # Additional validation checks can be implemented in subclasses.
        return True

    def validate_batch(self, columns: List[pd.Series], configs: List[Dict]):
        """
        Validate a batch of columns and their configurations.

        Args:
            columns (List[pandas.Series]): List of columns to validate.
            configs (List[dict]): List of configuration dictionaries for transformations.

        Raises:
            ValueError: If input lists don't match or if any column/config is invalid.
                The message names the index of the first failing column.

        Returns:
            bool: True if all validations pass.
        """
        if len(columns) != len(configs):
            raise ValueError(f"Number of columns ({len(columns)}) must match number of configs ({len(configs)})")

        for i, (column, config) in enumerate(zip(columns, configs)):
            try:
                self.validate(column, config)
            except ValueError as e:
                raise ValueError(f"Validation failed for column at index {i}: {str(e)}") from e

        return True

    def is_trainable(self):
        """
        Report whether this vectorizer has trainable parameters.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("is_trainable() must be implemented in subclasses.")

    def required_config_keys(self):
        """
        Define the required keys for the configuration.

        Override in subclasses to specify required config keys; the base class
        requires none.

        Returns:
            list: List of required keys.
        """
        return []

    def vectorize(self, column, config):
        """
        Validate the column/config, then transform the column into vectors.

        Args:
            column (pandas.Series): The column to transform.
            config (dict): The configuration dictionary for the transformation.

        Returns:
            torch.Tensor: Transformed tensor of shape (N, D), as produced by the
                subclass's ``_vectorize``.
        """
        self.validate(column, config)
        # The actual transformation is supplied by the subclass.
        return self._vectorize(column, config)

    def vectorize_batch(self, columns: List[pd.Series], configs: List[Dict]) -> List[torch.Tensor]:
        """
        Transform multiple columns into vectors in batch.

        Args:
            columns (List[pandas.Series]): The columns to transform.
            configs (List[dict]): The configuration dictionaries for each column.

        Returns:
            List[torch.Tensor]: List of transformed tensors.

        Raises:
            ValueError: If validation of the batch fails.
        """
        self.validate_batch(columns, configs)
        # ``_vectorize_batch`` always exists (the base class falls back to
        # per-column ``_vectorize``); subclasses override it for efficiency.
        return self._vectorize_batch(columns, configs)

    def _vectorize_batch(self, columns: List[pd.Series], configs: List[Dict]) -> List[torch.Tensor]:
        """
        Batch implementation for transforming multiple columns into vectors.
        Subclasses should override this method for efficient batch processing.

        Args:
            columns (List[pandas.Series]): The columns to transform.
            configs (List[dict]): The configuration dictionaries for each column.

        Returns:
            List[torch.Tensor]: List of transformed tensors.
        """
        # Default implementation processes columns sequentially.
        return [self._vectorize(column, config) for column, config in zip(columns, configs)]

    def _vectorize(self, column, config):
        """
        Transform one already-validated column into a tensor.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("_vectorize() must be implemented in subclasses.")

    def _compute_loss(self, reconstructed_values, target_column, config):
        """
        Compute the loss between reconstructed values and target column.
        Should be implemented by subclasses based on their specific needs.

        Args:
            reconstructed_values: The output of ``_inverse_vectorize`` in train mode.
            target_column (pd.Series): The target column to compare against.
            config (dict): The configuration dictionary for the transformation.

        Returns:
            torch.Tensor: The computed loss.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("_compute_loss() must be implemented in subclasses.")

    def _compute_batch_loss(self, reconstructed_values_list: List, target_columns: List[pd.Series],
                            configs: List[Dict]) -> torch.Tensor:
        """
        Compute the combined loss for multiple columns.

        Args:
            reconstructed_values_list (List): List of reconstructed values for each column.
            target_columns (List[pd.Series]): List of target columns to compare against.
            configs (List[Dict]): List of configuration dictionaries.

        Returns:
            torch.Tensor: The sum of the per-column losses; a zero scalar tensor on
                ``self.device`` when the batch is empty.
        """
        losses = [
            self._compute_loss(reconstructed, target, config)
            for reconstructed, target, config in zip(reconstructed_values_list, target_columns, configs)
        ]
        if not losses:
            # sum([]) would be the int 0, which is not a tensor (no .backward()).
            return torch.zeros((), device=self.device)
        # Combine individual losses (default: sum them).
        return sum(losses)

    def inverse_vectorize(self, tensor, config, mode='inference', target_column=None):
        """
        Reverse the vector transformation to the original values.

        Args:
            tensor (torch.Tensor): The tensor to inverse transform.
            config (dict): The configuration dictionary for the transformation.
            mode (str): 'inference' or 'train'. If 'inference', reconstruct original values in pandas.series.
                       If 'train', return reconstructed values and computed loss.
                       Any value other than 'train' skips loss computation.
            target_column (pd.Series, optional): Required if mode='train'. Target values for loss computation.

        Returns:
            Union[pd.Series, Tuple[torch.Tensor, torch.Tensor]]:
                - If mode='inference': reconstructed pandas Series
                - If mode='train': tuple of (reconstructed_values, loss)

        Raises:
            ValueError: If ``tensor`` is not a tensor, or if mode='train' and
                target_column is not provided or invalid.
            Exception: If ``_compute_loss`` fails, its exception is re-raised. When
                debug files could be written (see ``_save_loss_debug_files``) and
                the exception type accepts a single message argument, a new
                exception of the same type is raised whose message lists the
                saved files, chained ``from`` the original.
        """
        if not isinstance(tensor, torch.Tensor):
            raise ValueError("Input tensor must be a torch.Tensor.")

        if mode == 'train':
            if target_column is None:
                raise ValueError("target_column must be provided when mode='train'")
            if not isinstance(target_column, pd.Series):
                raise ValueError("target_column must be a pandas Series")

        reconstructed = self._inverse_vectorize(tensor, config, mode)

        if mode != 'train':
            return reconstructed

        try:
            loss = self._compute_loss(reconstructed, target_column, config)
        except Exception as e:
            saved = self._save_loss_debug_files(target_column, config, reconstructed)
            if not saved:
                # Nothing could be written; surface the original error untouched.
                raise
            listing = "\n".join(f"- {label}: {path}" for label, path in saved)
            message = f"{str(e)}\nDebug files saved to {self._DEBUG_DIR}:\n{listing}"
            try:
                augmented = type(e)(message)
            except Exception:
                # Some exception types need other constructor arguments; rather
                # than masking the real error with a TypeError, re-raise it as is.
                augmented = None
            if augmented is None:
                raise
            raise augmented from e

        return reconstructed, loss

    def _save_loss_debug_files(self, target_column, config, reconstructed):
        """
        Best-effort dump of the inputs of a failed ``_compute_loss`` call.

        Writes the target column (CSV), the config (JSON, with non-serializable
        values stringified) and, when its type is supported, the reconstructed
        values into ``self._DEBUG_DIR``. Any failure is reported with
        ``warnings.warn`` instead of being raised, so it cannot mask the loss error.

        Returns:
            list of (label, path) tuples for the files actually written.
        """
        saved = []
        try:
            os.makedirs(self._DEBUG_DIR, exist_ok=True)
        except OSError as err:
            warnings.warn(f"Could not create debug directory {self._DEBUG_DIR}: {err}", RuntimeWarning)
            return saved

        # Microseconds keep dumps from failures within the same second apart.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")

        def path_for(stem, ext):
            # One file name per artefact, all sharing the same timestamp.
            return os.path.join(self._DEBUG_DIR, f"{stem}_{timestamp}.{ext}")

        try:
            target_path = path_for("target_column", "csv")
            target_column.to_csv(target_path)
            saved.append(("Target column", target_path))
        except Exception as err:
            warnings.warn(f"Could not save target column: {err}", RuntimeWarning)

        try:
            # Build the JSON payload first so a failure leaves no empty file behind.
            serializable_config = {}
            for key, value in config.items():
                try:
                    json.dumps({key: value})
                    serializable_config[key] = value
                except (TypeError, ValueError, OverflowError):
                    serializable_config[key] = str(value)
            config_path = path_for("config", "json")
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(serializable_config, f, indent=2)
            saved.append(("Config", config_path))
        except Exception as err:
            warnings.warn(f"Could not save config: {err}", RuntimeWarning)

        try:
            reconstructed_path = None
            if isinstance(reconstructed, (pd.Series, pd.DataFrame)):
                reconstructed_path = path_for("reconstructed", "csv")
                reconstructed.to_csv(reconstructed_path)
            elif isinstance(reconstructed, torch.Tensor):
                reconstructed_path = path_for("reconstructed", "pt")
                torch.save(reconstructed.detach().cpu(), reconstructed_path)
            elif isinstance(reconstructed, dict):
                reconstructed_path = path_for("reconstructed", "json")
                with open(reconstructed_path, 'w', encoding='utf-8') as f:
                    json.dump({k: str(v) for k, v in reconstructed.items()}, f, indent=2)
            if reconstructed_path is not None:
                saved.append(("Reconstructed values", reconstructed_path))
        except Exception as err:
            warnings.warn(f"Could not save reconstructed values: {err}", RuntimeWarning)

        return saved

    def inverse_vectorize_batch(self, tensors: List[torch.Tensor], configs: List[Dict],
                                mode='inference', target_columns: Optional[List[pd.Series]] = None):
        """
        Batch reverse transformation of vectors to original values.

        Args:
            tensors (List[torch.Tensor]): The tensors to inverse transform.
            configs (List[dict]): The configuration dictionaries for transformations.
            mode (str): 'inference' or 'train'. If 'inference', reconstruct original values.
                       If 'train', also return computed loss.
            target_columns (List[pd.Series], optional): Required if mode='train'. Target columns for loss computation.

        Returns:
            Union[List[pd.Series], Tuple[List, torch.Tensor]]:
                - If mode='inference': list of reconstructed pandas Series
                - If mode='train': tuple of (list of reconstructed values, combined loss)

        Raises:
            ValueError: If inputs are invalid or don't match, if sequential fallback
                processing fails (message names the tensor index), or if the batch
                loss computation fails. The underlying error is chained as the cause.
        """
        if len(tensors) != len(configs):
            raise ValueError(f"Number of tensors ({len(tensors)}) must match number of configs ({len(configs)})")

        if not all(isinstance(tensor, torch.Tensor) for tensor in tensors):
            raise ValueError("All input tensors must be torch.Tensor objects")

        if mode == 'train':
            if target_columns is None:
                raise ValueError("target_columns must be provided when mode='train'")
            if len(target_columns) != len(tensors):
                raise ValueError(f"Number of target columns ({len(target_columns)}) must match number of tensors ({len(tensors)})")
            if not all(isinstance(col, pd.Series) for col in target_columns):
                raise ValueError("All target columns must be pandas Series")

        # Use the batch implementation; if it opts out with NotImplementedError,
        # process tensors one at a time so failures can be attributed to an index.
        try:
            reconstructed_values = self._inverse_vectorize_batch(tensors, configs, mode)
        except NotImplementedError:
            reconstructed_values = []
            for i, (tensor, config) in enumerate(zip(tensors, configs)):
                try:
                    reconstructed_values.append(self._inverse_vectorize(tensor, config, mode))
                except Exception as e:
                    raise ValueError(f"Error processing tensor at index {i}: {str(e)}") from e

        if mode != 'train':
            return reconstructed_values

        try:
            combined_loss = self._compute_batch_loss(reconstructed_values, target_columns, configs)
        except Exception as e:
            raise ValueError(f"Error computing batch loss: {str(e)}") from e
        return reconstructed_values, combined_loss

    def _inverse_vectorize_batch(self, tensors: List[torch.Tensor], configs: List[Dict], mode: str) -> List:
        """
        Batch implementation for inverse transformation of vectors to original values.
        Subclasses should override this method for efficient batch processing.

        Args:
            tensors (List[torch.Tensor]): The tensors to inverse transform.
            configs (List[dict]): The configuration dictionaries for transformations.
            mode (str): 'inference' or 'train'

        Returns:
            List: List of reconstructed values for each tensor
        """
        # Default implementation processes tensors sequentially.
        return [self._inverse_vectorize(tensor, config, mode) for tensor, config in zip(tensors, configs)]

    def _inverse_vectorize(self, tensor, config, mode):
        """
        Reconstruct values from one tensor.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("_inverse_vectorize() must be implemented in subclasses.")
