import numpy as np
from sklearn.preprocessing import QuantileTransformer as SklearnQuantileTransformer
import pandas as pd

from .base import BaseColumnTransformer


# Numerical Transformer
class NumericalTransformer(BaseColumnTransformer):
    """Embed each scalar as concatenated sign / exponent / mantissa features.

    Every value ``v`` is decomposed with :func:`numpy.frexp` into
    ``v = mantissa * 2**exponent`` and encoded as the concatenation of

    * a one-hot sign embedding of length ``embedding_dim`` (range ``[-1, 1]``),
    * a one-hot exponent embedding of length ``embedding_dim`` covering the
      exponent range ``[-127, 127]`` (exponents outside it are clipped), and
    * ``soft_bins`` Gaussian soft-bin activations of the mantissa magnitude.

    Args:
        embedding_dim (int): Length of each one-hot (sign / exponent) embedding.
        soft_bins (int): Number of Gaussian bins used for the mantissa.
    """

    # Exponent range mapped onto the exponent embedding.
    _EXPONENT_RANGE = (-127, 127)

    def __init__(self, embedding_dim=16, soft_bins=10):
        self.embedding_dim = embedding_dim
        self.soft_bins = soft_bins

    def _transform(self, X):
        """Encode every value of ``X``.

        Args:
            X (iterable of float): Values to encode.

        Returns:
            list[list[float]]: One feature list of length
            ``2 * embedding_dim + soft_bins`` per input value, ordered
            sign, exponent, mantissa.

        Raises:
            ValueError: If a value is NaN (``int()`` cannot place it in a slot).
        """
        exp_lo, exp_hi = self._EXPONENT_RANGE
        transformed_data = []
        for value in X:
            sign = np.sign(value)
            mantissa, exponent = np.frexp(value)
            sign_embedding = self._embed_discrete_value(sign, -1, 1)
            exponent_embedding = self._embed_discrete_value(exponent, exp_lo, exp_hi)
            # np.frexp returns |mantissa| in [0.5, 1) (negative for negative
            # inputs), while the soft-bin centres span [1, 2]. Rescale the
            # magnitude into [1, 2) so the bins actually cover it; the sign is
            # already carried by sign_embedding.
            fractional_embedding = self._soft_bin(2.0 * abs(mantissa))
            transformed_data.append(sign_embedding + exponent_embedding + fractional_embedding)
        return transformed_data

    def _inverse_transform(self, X):
        # Implement inverse transformation logic
        raise NotImplementedError("NumericalTransformer inverse_transform is not implemented")

    def _embed_discrete_value(self, value, min_value, max_value):
        """One-hot encode ``value`` on ``embedding_dim`` evenly spaced slots over
        ``[min_value, max_value]``.

        The value is clipped to the range first. Without the clip, a float64
        exponent above the range raises ``IndexError``. One below it produces a
        negative index that silently wraps around to the wrong slot.
        """
        embed = np.zeros(self.embedding_dim)
        clipped = np.clip(value, min_value, max_value)
        position = (clipped - min_value) / (max_value - min_value)
        idx = int(position * (self.embedding_dim - 1))
        embed[idx] = 1
        return embed.tolist()

    def _soft_bin(self, mantissa):
        """Return Gaussian activations ``exp(-(c - mantissa)**2 / 0.1)`` for
        ``soft_bins`` centres ``c`` evenly spaced on ``[1, 2]``."""
        bin_idx = np.linspace(1, 2, self.soft_bins)
        mantissa_embedding = np.exp(-np.square(bin_idx - mantissa) / 0.1)
        return mantissa_embedding.tolist()

class NumericalQuantileTransformer(BaseColumnTransformer):
    """Single-column wrapper around scikit-learn's ``QuantileTransformer``.

    Inputs are reshaped to a one-feature ``(n_samples, 1)`` array before being
    handed to scikit-learn, and outputs are flattened back to 1-D.
    """

    def __init__(self, n_quantiles=1000, output_distribution='uniform', random_state=None):
        """
        Initialize the NumericalQuantileTransformer.

        Args:
            n_quantiles (int): Number of quantiles to be computed.
            output_distribution (str): Marginal distribution for the transformed data. Options are 'uniform' or 'normal'.
            random_state (int, RandomState instance, or None): Determines random number generation for subsampling.
        """
        self.n_quantiles = n_quantiles
        self.output_distribution = output_distribution
        self.random_state = random_state
        self.quantile_transformer = SklearnQuantileTransformer(
            n_quantiles=self.n_quantiles,
            output_distribution=self.output_distribution,
            random_state=self.random_state
        )

    @staticmethod
    def _as_column(X):
        """Return ``X`` as an ``(n_samples, 1)`` array, the shape scikit-learn expects."""
        return np.asarray(X).reshape(-1, 1)

    def fit(self, X):
        """
        Fit the quantile transformer to the data.

        Args:
            X (array-like): The data to fit.

        Returns:
            NumericalQuantileTransformer: ``self``.
        """
        self.quantile_transformer.fit(self._as_column(X))
        return self

    def _transform(self, X):
        """
        Transform the data using the fitted quantile transformer.

        Args:
            X (array-like): The data to transform.

        Returns:
            numpy.ndarray: 1-D transformed data.
        """
        return self.quantile_transformer.transform(self._as_column(X)).ravel()

    def _inverse_transform(self, X):
        """
        Inverse transform the data back to the original distribution.

        Args:
            X (array-like): The data to inverse transform.

        Returns:
            numpy.ndarray: 1-D inverse transformed data.
        """
        return self.quantile_transformer.inverse_transform(self._as_column(X)).ravel()

    def set_fit(self, config):
        """
        Set the fitted parameters from config.

        The wrapper's own ``n_quantiles`` / ``output_distribution`` /
        ``random_state`` attributes are updated too. Otherwise they would keep
        the constructor defaults and disagree with the restored estimator.

        Args:
            config (dict): Configuration produced by ``get_config()``.

        Returns:
            NumericalQuantileTransformer: ``self``.
        """
        loaded_params = config['quantile_params']
        self.n_quantiles = int(loaded_params['n_quantiles'])
        self.output_distribution = loaded_params['output_distribution']
        self.random_state = loaded_params['random_state']
        self.quantile_transformer = SklearnQuantileTransformer(
            n_quantiles=self.n_quantiles,
            output_distribution=self.output_distribution,
            random_state=self.random_state,
        )
        # Populate the fitted attributes directly instead of calling fit().
        self.quantile_transformer.n_quantiles_ = self.n_quantiles
        # quantiles_ is (n_quantiles, n_features); force the single-feature
        # column shape even if the stored list was flattened.
        self.quantile_transformer.quantiles_ = np.asarray(
            loaded_params['quantiles'], dtype=float
        ).reshape(-1, 1)
        self.quantile_transformer.references_ = np.asarray(
            loaded_params['references'], dtype=float
        )
        self.quantile_transformer.n_features_in_ = 1
        return self

    def get_config(self):
        """
        Return a JSON-serializable config containing fitted parameters.

        Note: ``random_state`` is stored as-is. It is only JSON-serializable
        when it is an int or None, not a RandomState instance.

        Returns:
            dict: Configuration containing fitted parameters for quantile transformer

        Raises:
            RuntimeError: If the transformer has not been fitted or restored.
        """
        if not hasattr(self.quantile_transformer, 'n_quantiles_'):
            raise RuntimeError("The transformer is not fitted yet.")

        return {
            'quantile_params': {
                'output_distribution': self.quantile_transformer.output_distribution,
                'random_state': self.quantile_transformer.random_state,
                # int() guards against numpy integer types, which json rejects.
                'n_quantiles': int(self.quantile_transformer.n_quantiles_),
                'quantiles': self.quantile_transformer.quantiles_.tolist(),
                'references': self.quantile_transformer.references_.tolist(),
            }
        }

# ple_column_transformer.py
from typing import Sequence, Union, List
import numpy as np
import pandas as pd


class PiecewiseLinearEncoderColumn(BaseColumnTransformer):
    """
    Column transformer that implements the Piecewise Linear Encoding (PLE)
    described in:
        Gorishniy et al., "On Embeddings for Numerical Features in Tabular Deep Learning",
        NeurIPS 2022.

    Each value is compressed into a single scalar ``bin_index + fraction``.
    The integer part is the index of the bin containing the value, and the
    fractional part is the value's relative position inside that bin. A value
    equal to the top edge encodes as exactly ``n_bins``.

    Parameters
    ----------
    n_bins : int, default=32
        Number of bins (T in the paper).  Must be ≥ 1.
    strategy : {"quantile", "uniform"}, default="quantile"
        - "quantile": bin edges are empirical quantiles (equal-frequency bins).
        - "uniform": bin edges are spaced uniformly between min and max.
    clip_outside : bool, default=True
        If True, values falling outside the fitted [min, max] range are clipped to
        that range before encoding.  This preserves the "loss-less" property for
        in-range data while preventing extreme extrapolation.  If False, such
        values extrapolate linearly from the first/last bin (fraction < 0 or
        > 1), and ``inverse_transform`` maps them back.
    """

    # Minimum gap fit() enforces between consecutive bin edges. _transform()
    # treats a bin as degenerate when its width is within a small multiple of
    # this, scaled by the edge magnitude.
    _EDGE_EPS = 1e-12

    def __init__(
        self,
        n_bins: int = 32,
        strategy: str = "quantile",
        clip_outside: bool = True,
    ):
        if n_bins < 1:
            raise ValueError("n_bins must be ≥ 1")
        if strategy not in {"quantile", "uniform"}:
            raise ValueError("strategy must be 'quantile' or 'uniform'")
        self.n_bins = n_bins
        self.strategy = strategy
        self.clip_outside = clip_outside
        # fitted attributes
        self.bin_edges_: np.ndarray | None = None  # shape (n_bins + 1,)

    # --------------------------------------------------------------------- #
    # BaseColumnTransformer public API                                      #
    # --------------------------------------------------------------------- #
    def fit(self, X: Union[pd.Series, Sequence[float], np.ndarray]):
        """Compute ``n_bins + 1`` strictly increasing bin edges from a 1-D column ``X``.

        Non-finite values (NaN, ±inf) are ignored when estimating the edges.

        Raises
        ------
        ValueError
            If ``X`` contains no finite values.
        """
        x = self._to_numpy_1d(X)
        # NaN/inf would propagate through np.quantile / min / max into every edge.
        x = x[np.isfinite(x)]
        if x.size == 0:
            raise ValueError("cannot fit bin edges: X contains no finite values")
        if self.strategy == "quantile":
            # include 0 and 1 so we get n_bins + 1 edges
            q = np.linspace(0.0, 1.0, self.n_bins + 1)
            edges = np.quantile(x, q, method="linear")
        else:  # uniform
            edges = np.linspace(x.min(), x.max(), self.n_bins + 1)
        # Force strictly increasing edges. Repeated values (or a constant
        # column) produce duplicate edges, which are nudged apart. Adding
        # _EDGE_EPS alone is lost to rounding once |edge| exceeds a few
        # thousand, so never step by less than one ulp (np.nextafter).
        for i in range(1, len(edges)):
            prev = edges[i - 1]
            min_next = max(prev + self._EDGE_EPS, np.nextafter(prev, np.inf))
            if edges[i] < min_next:
                edges[i] = min_next
        self.bin_edges_ = edges
        return self

    def set_fit(self, config: dict):
        """Restore a fitted encoder from a config dict produced by `get_config()`."""
        self.n_bins = int(config["n_bins"])
        self.strategy = str(config["strategy"])
        self.clip_outside = bool(config["clip_outside"])
        edges = config["bin_edges_"]
        # get_config() stores None for an unfitted encoder. np.asarray(None,
        # dtype=float) would silently produce a 0-d NaN array instead.
        self.bin_edges_ = None if edges is None else np.asarray(edges, dtype=float)
        return self

    def get_config(self) -> dict:
        """Return a JSON-serialisable dict of fitted parameters."""
        return {
            "n_bins": self.n_bins,
            "strategy": self.strategy,
            "clip_outside": self.clip_outside,
            "bin_edges_": self.bin_edges_.tolist() if self.bin_edges_ is not None else None,
        }

    # --------------------------------------------------------------------- #
    # Internal helpers                                                      #
    # --------------------------------------------------------------------- #
    def _to_numpy_1d(self, X) -> np.ndarray:
        """Convert Series / list / ndarray to 1-D float64 ndarray."""
        if isinstance(X, pd.Series):
            X = X.values
        X = np.asarray(X, dtype=float).ravel()
        return X

    @staticmethod
    def _index_of(X):
        """Return ``X.index`` for pandas inputs, else ``None``.

        ``getattr(X, "index", None)`` is unsafe here: lists and tuples have an
        ``index`` *method*, which pandas rejects as a Series index.
        """
        return X.index if isinstance(X, (pd.Series, pd.DataFrame)) else None

    # --------------------------------------------------------------------- #
    # Required abstract method impls                                        #
    # --------------------------------------------------------------------- #
    def _transform(self, X: Union[pd.Series, Sequence[float], np.ndarray]) -> pd.Series:
        """
        Encode each scalar value into a single value where:
        - The integer part represents the bin index
        - The fractional part represents the position within the bin

        A value equal to the top edge encodes as exactly ``n_bins``. Values in
        a degenerate (collapsed) bin encode at the bin midpoint, ``idx + 0.5``.

        Raises
        ------
        RuntimeError
            If the encoder has not been fitted.
        """
        if self.bin_edges_ is None:
            raise RuntimeError("The transformer is not fitted yet.")
        edges = self.bin_edges_
        # Derive the bin count from the edges actually in use.
        n_bins = len(edges) - 1
        x = self._to_numpy_1d(X)

        if self.clip_outside:
            x = np.clip(x, edges[0], edges[-1])

        # side="right" returns the index of the first edge > x; subtract 1 to get the bin.
        idx = np.searchsorted(edges, x, side="right") - 1
        idx = np.clip(idx, 0, n_bins - 1)  # out-of-range values use the end bins

        left = edges[idx]
        bin_width = edges[idx + 1] - left
        # A bin is degenerate when fit() had to separate duplicate edges (width
        # about _EDGE_EPS, or one ulp at large magnitudes). The tolerance is
        # relative to |left|. np.isclose's default atol=1e-8 would also flag
        # genuine bins of small-scale data and squash their fractions to 0.5.
        degenerate = bin_width <= 2 * self._EDGE_EPS * (1.0 + np.abs(left))
        # The dummy width only avoids dividing by ~0 where the result is discarded.
        safe_width = np.where(degenerate, 1.0, bin_width)
        frac = np.where(degenerate, 0.5, (x - left) / safe_width)

        # Combine bin index and fraction into a single value
        return pd.Series(idx + frac, index=self._index_of(X))

    def _inverse_transform(
        self, X: Union[pd.Series, Sequence[float], np.ndarray]
    ) -> pd.Series:
        """
        Recover original scalars from the transformed values.
        The integer part is the bin index, the fractional part is the position within the bin.

        Raises
        ------
        RuntimeError
            If the encoder has not been fitted.
        """
        if self.bin_edges_ is None:
            raise RuntimeError("The transformer is not fitted yet.")
        edges = self.bin_edges_
        n_bins = len(edges) - 1
        transformed = self._to_numpy_1d(X)

        # Clip the bin index BEFORE deriving the fraction. The top edge encodes
        # as exactly n_bins, i.e. bin n_bins - 1 with fraction 1.0. Taking the
        # fraction from the unclipped floor would give 0.0 and decode it to
        # edges[n_bins - 1]. Clipping first also inverts extrapolated codes
        # (fraction < 0 or > 1).
        idx = np.clip(np.floor(transformed).astype(int), 0, n_bins - 1)
        frac = transformed - idx

        left = edges[idx]
        right = edges[idx + 1]
        x_rec = left + frac * (right - left)

        return pd.Series(x_rec, index=self._index_of(X))



def test_quantile():
    """Demo: round-trip a normal sample through NumericalQuantileTransformer.

    Case 1 fits a transformer and round-trips the sample. Case 2 rebuilds a
    transformer from ``get_config()`` and repeats the round trip. Results are
    printed, not asserted.
    """

    def summarize(raw, encoded, decoded, trailer=""):
        # Mean/std of each stage plus the worst absolute round-trip error.
        print(f"Original data mean: {raw.mean():.3f}, std: {raw.std():.3f}")
        print(f"Transformed data mean: {encoded.mean():.3f}, std: {encoded.std():.3f}")
        print(f"Inverse transformed data mean: {decoded.mean():.3f}, std: {decoded.std():.3f}")
        print(f"Max absolute difference after round trip: {np.abs(raw - decoded).max():.6f}{trailer}")

    np.random.seed(42)
    samples = np.random.normal(100, 15, size=1000)

    print("Test Case 1: Normal fitting and transformation")
    fitted = NumericalQuantileTransformer(n_quantiles=100, output_distribution='normal')
    fitted.fit(samples)
    encoded = fitted.transform(samples)
    summarize(samples, encoded, fitted.inverse_transform(encoded), trailer="\n")

    print("Test Case 2: Setting from config")
    snapshot = fitted.get_config()
    restored = NumericalQuantileTransformer()
    restored.set_fit(snapshot)
    re_encoded = restored.transform(samples)
    re_decoded = restored.inverse_transform(re_encoded)
    summarize(samples, re_encoded, re_decoded)
    print("Original data:", samples[:5])
    print("Transformed data:", re_encoded[:5])
    print("Inverse transformed data:", re_decoded[:5])


    
def test_piecewise_linear():
    """Demo: exercise PiecewiseLinearEncoderColumn on a skewed sample and print diagnostics.

    Covers the quantile and uniform strategies, config save/restore, a plain
    ndarray passed to ``inverse_transform``, and a hand-computed encode/decode
    of the first sample. Nothing is asserted; results are printed.
    """

    def print_round_trip(reference, decoded):
        # Shared summary for cases 1 and 2: input stats, decoded stats, worst error.
        print(f"Original data mean: {reference.mean():.3f}, std: {reference.std():.3f}")
        print(f"Inverse transformed data mean: {decoded.mean():.3f}, std: {decoded.std():.3f}")
        worst = np.abs(reference - decoded).max()
        print(f"Max absolute difference after round trip: {worst:.6f}\n")

    np.random.seed(42)
    series = pd.Series(np.random.exponential(size=1000))  # skewed distribution

    print("\n--- Testing PiecewiseLinearEncoderColumn ---")

    # Case 1: equal-frequency (quantile) bin edges.
    print("Test Case 1: Quantile strategy")
    quantile_ple = PiecewiseLinearEncoderColumn(n_bins=10, strategy="quantile")
    quantile_ple.fit(series)
    encoded = quantile_ple.transform(series)
    print(f"Original shape: {series.shape}, Transformed shape: {encoded.shape}")

    head = encoded.iloc[0]
    head_bin = int(np.floor(head))
    print(f"First value transformed to: {head:.4f} (bin_idx={head_bin}, fraction={head - head_bin:.4f})")

    decoded = quantile_ple.inverse_transform(encoded)
    print_round_trip(series, decoded)

    # Case 2: evenly spaced bin edges.
    print("Test Case 2: Uniform strategy")
    uniform_ple = PiecewiseLinearEncoderColumn(n_bins=10, strategy="uniform")
    uniform_ple.fit(series)
    uniform_decoded = uniform_ple.inverse_transform(uniform_ple.transform(series))
    print_round_trip(series, uniform_decoded)

    # Case 3: persist the quantile encoder's config and rebuild from it.
    print("Test Case 3: Config saving and loading")
    saved = quantile_ple.get_config()
    print(f"Config keys: {list(saved)}")

    restored = PiecewiseLinearEncoderColumn()
    restored.set_fit(saved)
    restored_encoded = restored.transform(series)
    restored_decoded = restored.inverse_transform(restored_encoded)

    print(f"Original data mean: {series.mean():.3f}, std: {series.std():.3f}")
    print(f"New inverse transformed mean: {restored_decoded.mean():.3f}, std: {restored_decoded.std():.3f}")
    print(f"Max absolute difference after config reload: {np.abs(series - restored_decoded).max():.6f}")
    print("Original data (first 5):", series.iloc[:5].values)
    print("Transformed data (first 5):", restored_encoded.iloc[:5].values)
    print("Inverse transformed data (first 5):", restored_decoded.iloc[:5].values)

    # Case 4: inverse_transform on a bare ndarray slice of case-1 codes.
    print("\nTest Case 4: Test with numpy array input")
    head_codes = encoded.iloc[:5].values
    head_decoded = quantile_ple.inverse_transform(head_codes)
    original_head = series.iloc[:5].values
    print(f"Original values: {original_head}")
    print(f"Transformed values: {head_codes}")
    print(f"Array input inverse result: {head_decoded.values}")
    print(f"Difference: {np.abs(original_head - head_decoded.values).max():.6f}")

    # Case 5: reproduce the encode/decode arithmetic by hand for one sample.
    print("\nTest Case 5: Manual encoding/decoding verification")
    probe = series.iloc[0]
    print(f"Original test value: {probe}")

    edges = quantile_ple.bin_edges_
    probe_bin = np.clip(
        np.searchsorted(edges, probe, side="right") - 1, 0, quantile_ple.n_bins - 1
    )
    lo, hi = edges[probe_bin], edges[probe_bin + 1]
    probe_frac = (probe - lo) / (hi - lo)
    by_hand = probe_bin + probe_frac

    print(f"Manually encoded: {by_hand} (bin={probe_bin}, fraction={probe_frac:.4f})")
    print(f"Transformed by PLE: {encoded.iloc[0]}")

    back_bin = int(np.floor(by_hand))
    lo, hi = edges[back_bin], edges[back_bin + 1]
    by_hand_decoded = lo + (by_hand - back_bin) * (hi - lo)

    print(f"Manually decoded: {by_hand_decoded}")
    print(f"Inverse transformed by PLE: {decoded.iloc[0]}")
    print(f"Original vs Manual decoded difference: {abs(probe - by_hand_decoded):.10f}")
    print(f"Original vs PLE round-trip difference: {abs(probe - decoded.iloc[0]):.10f}")

if __name__ == "__main__":
    # Only the PiecewiseLinearEncoderColumn demo runs by default; call
    # test_quantile() here as well to exercise NumericalQuantileTransformer.
    banner = "\n=== Testing PiecewiseLinearEncoderColumn ==="
    print(banner)
    test_piecewise_linear()