import pandas as pd
import numpy as np

from .transformers import *

# Data Transformer
class DataTransformer:
    """Apply a per-column transformer to a DataFrame, and invert it again.

    Every column listed in ``col_types`` is handled by a transformer *class*
    looked up in ``transformer_mapping``. A column-name key takes precedence
    over a type key. Transformer instances are expected to provide
    ``fit(series)``, ``transform(series)`` and ``inverse_transform(series)``,
    which are the methods this class calls on them.
    """

    def __init__(self, col_types, transformer_mapping=None, transformer_config=None):
        """
        Parameters:
        - col_types (dict): column name -> type string (e.g. "numerical").
        - transformer_mapping (dict, optional): column name or type string ->
          transformer class. A falsy value (None or {}) selects the default mapping.
        - transformer_config (dict, optional): column name -> keyword arguments
          passed to that column's transformer constructor in ``fit``.
        """
        self.col_types = col_types
        # Store transformer classes, not instances; instances are built in fit().
        self.transformer_mapping = transformer_mapping or self._default_transformer_mapping()
        # Column-specific constructor kwargs.
        self.transformer_config = transformer_config or {}
        # Column name -> fitted (or pre-configured) transformer instance.
        self.column_transformers = {}

    @staticmethod
    def _default_transformer_mapping():
        """Return a fresh default mapping of type string -> transformer class."""
        # Built inside a function (not a class attribute) so the transformer
        # classes are only resolved when the default is actually needed.
        # NOTE: NumericalQuantileTransformer was used for "numerical" previously
        # and remains a drop-in alternative via a custom mapping.
        return {
            "numerical": PiecewiseLinearEncoderColumn,
            "categorical": StringTransformer,
            "datetime": DatetimeTransformer,
            "text": StringTransformer,
            "unknown": StringTransformer,
        }

    @staticmethod
    def _resolve_transformer_class(mapping, col_name, col_type):
        """Return the transformer class for a column, or None if none is mapped.

        A column-name entry wins over a type entry.
        """
        if col_name in mapping:
            return mapping[col_name]
        return mapping.get(col_type)

    def fit_transform(self, df):
        """Fit the per-column transformers on ``df`` and return ``transform(df)``."""
        self.fit(df)
        return self.transform(df)

    def fit(self, df):
        """Instantiate and fit one transformer per known column present in ``df``.

        Columns missing from ``df``, or with no mapped transformer class, are
        skipped. Transformers already stored for other columns are kept.
        """
        for col_name, col_type in self.col_types.items():
            if col_name not in df.columns:
                continue
            transformer_class = self._resolve_transformer_class(
                self.transformer_mapping, col_name, col_type
            )
            if not transformer_class:
                continue
            # Column-specific kwargs if configured (a None entry means "no kwargs").
            config = self.transformer_config.get(col_name) or {}
            transformer = transformer_class(**config)
            transformer.fit(df[col_name])
            self.column_transformers[col_name] = transformer

    def _apply(self, df, method_name):
        """Call ``method_name`` of each stored transformer on its column of ``df``.

        Only columns that are in ``col_types``, present in ``df`` and have a stored
        transformer are included, in ``col_types`` order. The result keeps
        ``df.index``.
        """
        result = {}
        for col_name in self.col_types:
            if col_name in df.columns and col_name in self.column_transformers:
                method = getattr(self.column_transformers[col_name], method_name)
                result[col_name] = method(df[col_name])
        return pd.DataFrame(result, index=df.index)

    def transform(self, df):
        """Return a DataFrame of the transformed columns of ``df``."""
        return self._apply(df, "transform")

    def inverse_transform(self, df):
        """Return a DataFrame with each transformed column mapped back."""
        return self._apply(df, "inverse_transform")

    @classmethod
    def _transformer_from_var_config(cls, var_name, var_type, var_config, mapping):
        """Build the pre-configured transformer for one variable entry, or None."""
        if var_type == "numerical":
            # Numerical columns are only pre-configured when fitted parameters
            # are present in the config; otherwise nothing is stored for them.
            if "quantile_params" in var_config:
                transformer = NumericalQuantileTransformer()
                transformer.set_fit({"quantile_params": var_config.get("quantile_params")})
                return transformer
            if "ple_params" in var_config:
                transformer = PiecewiseLinearEncoderColumn()
                transformer.set_fit(var_config.get("ple_params"))
                return transformer
            return None
        transformer_class = cls._resolve_transformer_class(mapping, var_name, var_type)
        if not transformer_class:
            # Unmapped types are skipped, matching fit(), instead of raising KeyError.
            return None
        # NOTE(review): these instances are NOT fitted here; presumably the
        # non-numerical transformers need no fitted state — confirm.
        return transformer_class()

    @classmethod
    def from_config(cls, config, transformer_mapping=None):
        """
        Creates a DataTransformer instance from a config dictionary.

        Parameters:
        - config (dict): Configuration dictionary with a "variables" list; each entry
          may hold "variable_name", "variable_type" and, for numerical variables,
          "quantile_params" or "ple_params".
        - transformer_mapping (dict, optional): Custom transformer mapping

        Returns:
        - DataTransformer: Configured transformer instance whose
          ``column_transformers`` are pre-populated from the config.
        """
        if transformer_mapping is None:
            type_transformers = cls._default_transformer_mapping()
        else:
            type_transformers = transformer_mapping

        col_types = {}
        column_transformers = {}
        for var_config in config.get("variables") or []:
            var_name = var_config.get("variable_name")
            # "or ''" guards against an explicit None type.
            var_type = (var_config.get("variable_type") or "").lower()
            if not (var_name and var_type):
                continue
            col_types[var_name] = var_type
            transformer = cls._transformer_from_var_config(
                var_name, var_type, var_config, type_transformers
            )
            if transformer is not None:
                column_transformers[var_name] = transformer

        dt = cls(col_types, type_transformers)
        # Install the pre-configured transformers directly; no fit() needed.
        dt.column_transformers = column_transformers
        return dt