"""
A small, practical scaler toolkit with:
- StreamingStandardScaler (Welford) and StreamingMinMaxScaler
- A CompositeScaler to chain multiple transformers (streaming & batch)
- A set of convenient aliases/pipelines (log1p, yj, std, robust, minmax, etc.)

Design goals:
- Safe streaming updates (partial_fit) with lazy finalize in transform()
- Invertible round-trips when each step supports inverse_transform
- Fewer surprises in mixed pipelines (stateless + streaming + batch-only)
"""

from typing import List
import numpy as np

from sklearn import preprocessing
from sklearn.preprocessing import (
    FunctionTransformer, PowerTransformer,
    StandardScaler, RobustScaler, MinMaxScaler
)

from sklearn.pipeline import Pipeline
from sklearn.base import clone


def _as_2d(X: np.ndarray) -> np.ndarray:
    """
    Ensure 2D shape (n_samples, n_features).

    - A scalar (0-d) becomes a 1x1 array.
    - A 1-D array is treated as a single feature column, shape (n, 1).
    - Arrays with ndim >= 2 are returned as-is (no copy for ndarray input).
    """
    X = np.asarray(X)
    if X.ndim == 0:
        # Without this, callers indexing X.shape[1] would raise IndexError.
        return X.reshape(1, 1)
    if X.ndim == 1:
        return X.reshape(-1, 1)
    return X


def identity(x):
    """
    Return ``x`` unchanged.

    Serves as both the forward and the inverse function of the
    ``'identity'`` alias in SCALER_PIPELINES.
    """
    return x

def signed_log(x):
    """
    Sign-preserving log transform: sign(x) * log(1 + |x|).

    Defined for every real input and odd-symmetric around zero;
    signed_explog is its inverse.
    """
    magnitude = np.log1p(np.abs(x))
    return np.sign(x) * magnitude

def signed_explog(x):
    """
    Inverse of signed_log: sign(x) * (exp(|x|) - 1).
    """
    magnitude = np.expm1(np.abs(x))
    return np.sign(x) * magnitude


class StreamingStandardScaler:
    """
    Standardization in streaming/online fashion (Welford's algorithm).

    - partial_fit(X): merge a batch into the running count, mean and M2
      (sum of squared deviations from the mean)
    - fit(X): discard accumulated state, then partial_fit(X) + finalize()
    - finalize(): compute scale from the accumulated moments
    - transform()/inverse_transform(): lazily call finalize() if needed

    Notes
    -----
    * Internal state uses float64 (O(d) memory), for numerical stability.
    * Batches are merged with the pairwise (Chan et al.) form of Welford's
      update, so each partial_fit() is vectorized over rows.
    * Variance is the sample variance (ddof=1); it is 0 until two samples
      have been seen.
    * Output dtype is configurable via dtype_out (default float64).
    * eps keeps scale strictly positive; this also makes the mapping
      trivially invertible even for constant features.
    * partial_fit() invalidates any previously finalized scale, so the next
      transform() re-finalizes using all data seen so far.
    """

    def __init__(self, eps: float = 1e-6, dtype_out=np.float64):
        self.n = 0
        self.mean = None
        self.M2 = None
        self.scale = None  # set by finalize(); None means "finalize needed"
        self.eps = float(eps)
        self.dtype_out = dtype_out

    def partial_fit(self, X):
        """
        Merge batch X (n_samples, n_features) into the running moments.

        1-D input is treated as a single feature column. An empty batch only
        fixes the feature count on first use.

        Raises
        ------
        ValueError
            If X's feature count differs from previously seen batches.
        """
        X = _as_2d(np.asarray(X, dtype=np.float64))
        if self.mean is None:
            self.mean = np.zeros(X.shape[1], dtype=np.float64)
            self.M2 = np.zeros(X.shape[1], dtype=np.float64)
        elif X.shape[1] != self.mean.shape[0]:
            # Guard against silent broadcasting of e.g. a 1-column batch.
            raise ValueError(
                f"X has {X.shape[1]} features, but StreamingStandardScaler "
                f"was fitted with {self.mean.shape[0]} features."
            )
        n_b = X.shape[0]
        if n_b == 0:
            return self

        # Batch moments, then pairwise merge with the running (n, mean, M2).
        mean_b = X.mean(axis=0)
        M2_b = np.square(X - mean_b).sum(axis=0)
        n_a = self.n
        n_ab = n_a + n_b
        delta = mean_b - self.mean
        self.mean += delta * (n_b / n_ab)
        self.M2 += M2_b + np.square(delta) * (n_a * n_b / n_ab)
        self.n = n_ab

        # New data makes a previously finalized scale stale.
        self.scale = None
        return self

    def fit(self, X, y=None):
        """
        Fit from scratch on X, discarding previously accumulated state.

        ``y`` is accepted for scikit-learn API compatibility and ignored.
        """
        self.n = 0
        self.mean = None
        self.M2 = None
        self.scale = None
        return self.partial_fit(X).finalize()

    def finalize(self):
        """
        Compute var_/scale_ from the accumulated moments and cache them.

        Raises
        ------
        RuntimeError
            If partial_fit() has never been called.
        """
        if self.mean is None:
            raise RuntimeError("StreamingStandardScaler has not seen any data (partial_fit).")
        if self.n > 1:
            var = self.M2 / (self.n - 1)
        else:
            var = np.zeros_like(self.mean)
        self.scale = np.sqrt(var) + self.eps

        # sklearn-style fitted attributes (copies, so later updates don't alias).
        self.mean_ = self.mean.copy()
        self.var_  = var.copy()
        self.scale_ = self.scale.copy()
        self.n_samples_seen_ = int(self.n)
        return self

    def summary(self, head=6):
        """Return a dict of fitted statistics; arrays larger than ``head`` are truncated."""
        def _head(a):
            a = np.asarray(a)
            return a.tolist() if a.size <= head else {"shape": list(a.shape), "head": a.flat[:head].tolist()}
        self._ensure_ready()
        return {
            "type": "StreamingStandardScaler",
            "n_samples_seen_": self.n_samples_seen_,
            "mean_": _head(self.mean_),
            "var_":  _head(self.var_),
            "scale_": _head(self.scale_),
            "eps": self.eps,
            "dtype_out": str(self.dtype_out),
        }

    def __repr__(self):
        if self.mean is None:
            # Nothing to summarize yet; summary() would raise.
            return f"StreamingStandardScaler(n=0, eps={self.eps})"
        s = self.summary()
        return (f"StreamingStandardScaler(n={s['n_samples_seen_']}, "
                f"mean={s['mean_']}, scale={s['scale_']}, eps={s['eps']})")

    def _ensure_ready(self):
        # getattr keeps this working for instances created without __init__.
        if getattr(self, "scale", None) is None:
            self.finalize()

    def transform(self, X):
        """Return (X - mean) / scale cast to dtype_out."""
        self._ensure_ready()
        X = _as_2d(np.asarray(X, dtype=np.float64))
        Y = (X - self.mean) / self.scale
        return Y.astype(self.dtype_out, copy=False)

    def inverse_transform(self, X):
        """Return X * scale + mean cast to dtype_out."""
        self._ensure_ready()
        X = _as_2d(np.asarray(X, dtype=np.float64))
        Y = X * self.scale + self.mean
        return Y.astype(self.dtype_out, copy=False)


class StreamingMinMaxScaler:
    """
    MinMax scaling in streaming/online fashion (running min/max).

    - partial_fit(X): update feature-wise min_ and max_ (NaNs are ignored)
    - fit(X): discard accumulated state, then partial_fit(X) + finalize()
    - finalize(): compute scale and offset into the target feature_range
    - transform()/inverse_transform(): lazily call finalize() if needed

    Notes
    -----
    * ``min_``/``max_`` hold the running data min/max (unlike scikit-learn's
      MinMaxScaler, where ``min_`` is the offset); the offset is ``min__``.
    * eps is added to the data range so constant features never divide by
      zero; as a side effect the data max maps slightly below the upper bound.
    * partial_fit() invalidates any previously finalized scale/offset, so the
      next transform() re-finalizes using all data seen so far.
    """

    def __init__(self, feature_range=(0., 1.), eps: float = 1e-6, dtype_out=np.float64):
        low, high = feature_range
        if not (isinstance(low, (int, float)) and isinstance(high, (int, float)) and high > low):
            raise ValueError("feature_range must be a pair (low < high).")
        self.min_ = None
        self.max_ = None
        self.scale = None  # set by finalize(); None means "finalize needed"
        self.min_r = None  # offset, set by finalize()
        self.fr = (float(low), float(high))
        self.eps = float(eps)
        self.dtype_out = dtype_out

    def partial_fit(self, X):
        """
        Merge batch X (n_samples, n_features) into the running min/max.

        Empty batches are a no-op.

        Raises
        ------
        ValueError
            If X's feature count differs from previously seen batches.
        """
        X = _as_2d(np.asarray(X, dtype=np.float64))
        if self.min_ is not None and X.shape[1] != self.min_.shape[0]:
            # Guard against silent broadcasting of e.g. a 1-column batch.
            raise ValueError(
                f"X has {X.shape[1]} features, but StreamingMinMaxScaler "
                f"was fitted with {self.min_.shape[0]} features."
            )
        if X.shape[0] == 0:
            # nanmin/nanmax raise on a zero-length axis; nothing to merge anyway.
            return self
        mn, mx = np.nanmin(X, axis=0), np.nanmax(X, axis=0)
        if self.min_ is None:
            self.min_, self.max_ = mn, mx
        else:
            # fmin/fmax ignore NaN, so an all-NaN column in an earlier batch
            # does not poison the running extrema forever.
            self.min_ = np.fmin(self.min_, mn)
            self.max_ = np.fmax(self.max_, mx)
        # New data makes a previously finalized scale/offset stale.
        self.scale = None
        self.min_r = None
        return self

    def fit(self, X, y=None):
        """
        Fit from scratch on X, discarding previously accumulated state.

        ``y`` is accepted for scikit-learn API compatibility and ignored.
        """
        self.min_ = None
        self.max_ = None
        self.scale = None
        self.min_r = None
        return self.partial_fit(X).finalize()

    def finalize(self):
        """
        Compute scale/offset for the target feature_range and cache them.

        Raises
        ------
        RuntimeError
            If no non-empty batch has been seen.
        """
        if self.min_ is None or self.max_ is None:
            raise RuntimeError("StreamingMinMaxScaler has not seen any data (partial_fit).")
        data_range = self.max_ - self.min_
        fr_low, fr_high = self.fr
        self.scale = (fr_high - fr_low) / (data_range + self.eps)
        self.min_r = fr_low - self.min_ * self.scale

        # sklearn-style fitted attributes (copies, so later updates don't alias).
        self.data_min_   = self.min_.copy()
        self.data_max_   = self.max_.copy()
        self.data_range_ = data_range.copy()
        self.scale_      = self.scale.copy()
        self.min__       = self.min_r.copy()
        self.feature_range = self.fr
        return self

    def summary(self, head=6):
        """Return a dict of fitted statistics; arrays larger than ``head`` are truncated."""
        def _head(a):
            a = np.asarray(a)
            return a.tolist() if a.size <= head else {"shape": list(a.shape), "head": a.flat[:head].tolist()}
        self._ensure_ready()
        return {
            "type": "StreamingMinMaxScaler",
            "feature_range": tuple(self.feature_range),
            "data_min_": _head(self.data_min_),
            "data_max_": _head(self.data_max_),
            "data_range_": _head(self.data_range_),
            "scale_": _head(self.scale_),
            "offset(min_)": _head(self.min__),
            "eps": self.eps,
            "dtype_out": str(self.dtype_out),
        }

    def __repr__(self):
        if self.min_ is None:
            # Nothing to summarize yet; summary() would raise.
            return f"StreamingMinMaxScaler(range={self.fr})"
        s = self.summary()
        return (f"StreamingMinMaxScaler(range={s['feature_range']}, "
                f"data_min={s['data_min_']}, data_max={s['data_max_']})")

    def _ensure_ready(self):
        # getattr keeps this working for instances created without __init__.
        if getattr(self, "scale", None) is None or getattr(self, "min_r", None) is None:
            self.finalize()

    def transform(self, X):
        """Return X * scale + offset cast to dtype_out."""
        self._ensure_ready()
        X = _as_2d(np.asarray(X, dtype=np.float64))
        Y = X * self.scale + self.min_r
        return Y.astype(self.dtype_out, copy=False)

    def inverse_transform(self, X):
        """Return (X - offset) / scale cast to dtype_out."""
        self._ensure_ready()
        X = _as_2d(np.asarray(X, dtype=np.float64))
        Y = (X - self.min_r) / self.scale
        return Y.astype(self.dtype_out, copy=False)


# Named recipes usable as tokens in build_composite_scaler(). Each value is a
# template Pipeline; _expand_steps() clones its steps, so building a scaler
# never fits these module-level estimators.
SCALER_PIPELINES = {
    # Single-step aliases (stateless or single estimator)
    'identity': Pipeline([
        ('id', FunctionTransformer(identity, inverse_func=identity))
    ]),
    'log1p': Pipeline([
        ('log1p', FunctionTransformer(np.log1p, inverse_func=np.expm1))
    ]),
    'signedlog': Pipeline([
        ('slog', FunctionTransformer(signed_log, inverse_func=signed_explog))
    ]),
    'yj': Pipeline([
        ('yj', PowerTransformer(method='yeo-johnson', standardize=True))
    ]),
    'std': Pipeline([
        ('std', StandardScaler())
    ]),
    'robust': Pipeline([
        ('robust', RobustScaler())
    ]),
    'minmax_01': Pipeline([
        ('mm01', MinMaxScaler(feature_range=(0, 1)))
    ]),
    'minmax_pm1': Pipeline([
        ('mmpm1', MinMaxScaler(feature_range=(-1, 1)))
    ]),

    # Common two-step recipes
    'log1p_std': Pipeline([
        ('log1p', FunctionTransformer(np.log1p, inverse_func=np.expm1)),
        ('std',   StandardScaler())
    ]),
    'yj_std': Pipeline([
        ('yj',    PowerTransformer(method='yeo-johnson', standardize=False)),
        ('std',   StandardScaler())
    ]),
    'signedlog_robust': Pipeline([
        ('slog',   FunctionTransformer(signed_log, inverse_func=signed_explog)),
        ('robust', RobustScaler())
    ]),
    'asinh': Pipeline([
        # Plain ufuncs instead of lambdas: same math, but picklable.
        ('asinh', FunctionTransformer(np.arcsinh, inverse_func=np.sinh))
    ])
}

def _expand_steps(obj) -> List[object]:
    """
    Return fresh (cloned) step estimators.

    If obj is a Pipeline, return clones of its step estimators in order;
    otherwise return [clone(obj)]. Pipeline steps set to None or
    'passthrough' act as identity steps and are skipped, since clone()
    rejects them.
    """
    if isinstance(obj, Pipeline):
        return [
            clone(est)
            for _, est in obj.steps
            if not (est is None or (isinstance(est, str) and est == "passthrough"))
        ]
    return [clone(obj)]


class CompositeScaler:
    """
    Compose a list of transformers/scalers into a single object.

    Supports:
      * Streaming steps: have partial_fit() (e.g. StreamingStandardScaler, MinMaxScaler, StandardScaler)
      * Stateless steps: can transform without fit (e.g. FunctionTransformer without learned params)
      * Batch-only steps: require fit() (e.g. PowerTransformer, RobustScaler)

    Behavior:
      * partial_fit(X): walk through the chain, updating any streaming steps.
        - Stateless steps are applied on-the-fly.
        - If a batch-only step is encountered, we stop (cannot safely transform further).
        - Each step sees the batch as transformed by the upstream steps'
          statistics at that moment, so downstream statistics combine
          slightly different upstream transforms across batches.
      * fit(X): fits every step from scratch, in order. Steps exposing fit()
        are refit, so repeated fit() calls do not accumulate statistics;
        steps that only offer partial_fit() still accumulate.
      * fit_transform(X): same fitting as fit(), then returns the transformed X.
      * transform(X): apply all steps in order.
      * inverse_transform(X): apply inverse steps in reverse order (skip if not available).
    """

    def __init__(self, scalers: List[object]):
        self.scalers: List[object] = scalers

    def _is_stateless(self, s) -> bool:
        # FunctionTransformer can transform without being fitted.
        return isinstance(s, FunctionTransformer)

    def partial_fit(self, X):
        """Update streaming steps with batch X; stop at the first batch-only step."""
        Y = _as_2d(np.asarray(X))
        for s in self.scalers:
            if hasattr(s, "partial_fit"):
                s.partial_fit(Y)
                if hasattr(s, "finalize"):
                    s.finalize()
                if hasattr(s, "transform"):
                    Y = s.transform(Y)
            elif self._is_stateless(s) and hasattr(s, "transform"):
                Y = s.transform(Y)
            else:
                # Batch-only step ahead: cannot proceed further in streaming mode.
                break
        return self

    @staticmethod
    def _fit_step(s, Y, y, prefer_fit_transform: bool):
        """Fit a single step on Y from scratch and return its output for the next step."""
        # Steps with finalize() are handled explicitly below so finalize()
        # always runs before their transform().
        if prefer_fit_transform and hasattr(s, "fit_transform") and not hasattr(s, "finalize"):
            return s.fit_transform(Y, y)
        if hasattr(s, "fit"):
            # scikit-learn's fit() discards previously learned state, whereas
            # partial_fit() would merge Y into it.
            s.fit(Y, y)
        elif hasattr(s, "partial_fit"):
            # No fit() available to reset with: statistics accumulate.
            s.partial_fit(Y)
        if hasattr(s, "finalize"):
            s.finalize()
        return s.transform(Y) if hasattr(s, "transform") else Y

    def _fit_chain(self, X, y, prefer_fit_transform: bool):
        """Fit every step in order on the output of the previous one; return the final output."""
        Y = _as_2d(np.asarray(X))
        for s in self.scalers:
            Y = self._fit_step(s, Y, y, prefer_fit_transform)
        return Y

    def fit(self, X, y=None):
        """Fit the whole chain on X from scratch and return self."""
        self._fit_chain(X, y, prefer_fit_transform=False)
        return self

    def fit_transform(self, X, y=None):
        """Fit the whole chain on X from scratch and return the transformed X."""
        return self._fit_chain(X, y, prefer_fit_transform=True)

    def transform(self, X):
        """
        Apply every step's transform() in order.

        Raises
        ------
        RuntimeError
            If a step has no transform().
        """
        Y = _as_2d(np.asarray(X))
        for s in self.scalers:
            if not hasattr(s, "transform"):
                raise RuntimeError(f"{s.__class__.__name__} has no transform().")
            Y = s.transform(Y)
        return Y

    def inverse_transform(self, X):
        """Apply inverse_transform() in reverse order, skipping steps without one."""
        Y = _as_2d(np.asarray(X))
        for s in reversed(self.scalers):
            if hasattr(s, "inverse_transform"):
                Y = s.inverse_transform(Y)
        return Y

    def roundtrip_error(self, X) -> float:
        """Max absolute difference between X and inverse_transform(transform(X)); 0.0 for empty X."""
        X = _as_2d(np.asarray(X))
        if X.size == 0:
            # np.max on an empty array raises; an empty round-trip has no error.
            return 0.0
        Xb = self.inverse_transform(self.transform(X))
        return float(np.max(np.abs(X - Xb)))

    def summary(self, head=4):
        """
        Return a per-step description of fitted statistics.

        Steps with their own summary() contribute it; otherwise common
        sklearn fitted attributes are reported when present. Steps whose
        summary() raises RuntimeError (nothing seen yet) get ``fitted=False``.
        """
        def _head(a):
            a = np.asarray(a)
            return a.tolist() if a.size <= head else {"shape": list(a.shape), "head": a.flat[:head].tolist()}

        comps = []
        for s in self.scalers:
            entry = {"type": s.__class__.__name__}
            if hasattr(s, "summary"):
                try:
                    entry.update(s.summary(head=head))
                except RuntimeError:
                    # Streaming scalers raise from finalize() until they see data.
                    entry["fitted"] = False
            else:
                for k in ("mean_", "var_", "scale_", "data_min_", "data_max_", "n_samples_seen_", "feature_range"):
                    if hasattr(s, k):
                        v = getattr(s, k)
                        try:
                            entry[k] = _head(v)
                        except (TypeError, ValueError):
                            entry[k] = str(v)
            comps.append(entry)
        return {"CompositeScaler": comps}

    def __repr__(self):
        names = " + ".join(s.__class__.__name__ for s in self.scalers)
        return f"CompositeScaler[{names}]"


def build_composite_scaler(name: str) -> CompositeScaler:
    """
    Build a CompositeScaler from a '+'-separated string.

    Examples
    --------
    'log1p+StreamingStandardScaler'
    'log1p_std'                      # alias: log1p + StandardScaler

    Rules
    -----
    - If token matches SCALER_PIPELINES: expand its steps.
    - If token is 'StreamingStandardScaler' / 'StreamingMinMaxScaler': create that instance.
    - Otherwise, treat token as a class name under sklearn.preprocessing (e.g., 'StandardScaler').
    - Blank tokens (e.g. from 'a++b') are ignored.

    Raises
    ------
    ValueError
        If a token is neither an alias, a streaming scaler name, nor a class
        in sklearn.preprocessing (module-level functions such as 'scale'
        are rejected too).
    """
    streaming_classes = {
        "StreamingStandardScaler": StreamingStandardScaler,
        "StreamingMinMaxScaler": StreamingMinMaxScaler,
    }
    scalers: List[object] = []
    for token in (part.strip() for part in name.split('+')):
        if not token:
            continue
        if token in SCALER_PIPELINES:
            scalers.extend(_expand_steps(SCALER_PIPELINES[token]))
        elif token in streaming_classes:
            scalers.append(streaming_classes[token]())
        else:
            cls = getattr(preprocessing, token, None)
            # Only classes can be instantiated with no arguments; functions
            # like preprocessing.scale would otherwise fail with a TypeError.
            if not isinstance(cls, type):
                raise ValueError(f"Unknown scaler token '{token}'. "
                                 f"Available aliases: {list(SCALER_PIPELINES.keys())} "
                                 f"or any sklearn.preprocessing class name.")
            scalers.append(cls())
    return CompositeScaler(scalers)
