"""Tabular preprocessing utilities compatible with TabICL.

Provides a small wrapper that, when available, uses TabICL's
PreprocessingPipeline to normalize features (standard scale + normalization
and outlier handling). It also standardizes regression targets (y).

Designed to work with numpy arrays or torch tensors and with DataAttr batches
that include context (xc, yc), buffer (xb, yb), and targets (xt, yt).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np

try:
    # Prefer TabICL's preprocessing pipeline if available
    from tabicl.sklearn.preprocessing import PreprocessingPipeline as _TabICLPP  # type: ignore
    _HAS_TABICL = True
except Exception:  # pragma: no cover - optional dependency
    # Bind the same alias the success path defines so `_TabICLPP` always
    # exists at module level, whether or not tabicl could be imported.
    _TabICLPP = None
    _HAS_TABICL = False

try:
    import torch
    _HAS_TORCH = True
except Exception:  # pragma: no cover
    torch = None  # type: ignore
    _HAS_TORCH = False


def _to_numpy(x):
    """Return *x* as a NumPy array.

    Torch tensors (when torch is importable) are detached, moved to the CPU
    and converted with ``.numpy()``; every other input goes through
    ``np.asarray``.
    """
    is_tensor = _HAS_TORCH and isinstance(x, torch.Tensor)  # type: ignore[attr-defined]
    if not is_tensor:
        return np.asarray(x)
    return x.detach().cpu().numpy()


def _to_type(x_np: np.ndarray, like):
    """Convert *x_np* back to the container type of *like*.

    When *like* is a torch tensor, the array is wrapped with
    ``torch.from_numpy`` and moved to ``like``'s device and dtype. For any
    other *like*, the NumPy array is returned unchanged.
    """
    wants_tensor = _HAS_TORCH and isinstance(like, torch.Tensor)  # type: ignore[attr-defined]
    if not wants_tensor:
        return x_np
    as_tensor = torch.from_numpy(x_np)
    return as_tensor.to(device=like.device, dtype=like.dtype)


@dataclass
class FitStats:
    """Summary returned by the scaler's fit methods."""

    # Set to True by the fit methods once the feature scaler has been fitted.
    x_fitted: bool
    # Target mean recorded by the scaler; None when no targets were fitted.
    y_mean: Optional[np.ndarray] = None
    # Target standard deviation recorded by the scaler; None when no targets
    # were fitted.
    y_std: Optional[np.ndarray] = None


class TabICLScaler:
    """Feature/target scaler aligned with TabICL preprocessing semantics.

    - X: Fit TabICL PreprocessingPipeline (if available) on training/context
         features. Otherwise, or if the pipeline fails to fit, fall back to
         simple standard scaling using per-feature mean/std.
    - y: Standardize using per-dimension mean/std computed from
         training/context targets.

    Inputs may be numpy arrays or torch tensors; transformed outputs are
    returned in the same container type as the input.
    """

    def __init__(
        self,
        normalization_method: str = "power",
        outlier_threshold: float = 4.0,
        random_state: Optional[int] = None,
    ) -> None:
        """Store the pipeline configuration; nothing is fitted yet.

        Args:
            normalization_method: Forwarded to the TabICL pipeline.
            outlier_threshold: Forwarded to the TabICL pipeline.
            random_state: Forwarded to the TabICL pipeline.
        """
        self.normalization_method = normalization_method
        self.outlier_threshold = outlier_threshold
        self.random_state = random_state

        self.x_pipeline = None
        self.x_mean: Optional[np.ndarray] = None
        self.x_std: Optional[np.ndarray] = None
        self.y_mean: Optional[np.ndarray] = None
        self.y_std: Optional[np.ndarray] = None

    @staticmethod
    def _mean_std(arr: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return per-feature ``(mean, std)`` of *arr* with dims kept.

        The last axis is treated as the feature axis and every leading axis
        is reduced, so statistics fitted on a ``(..., D)`` array broadcast
        against any other ``(..., D)`` array regardless of how many rows it
        has. 1-D input is reduced over its only axis. Zero standard
        deviations are replaced by 1.0 so constant features do not cause a
        division by zero.
        """
        reduce_axes = tuple(range(arr.ndim - 1)) if arr.ndim > 1 else 0
        mean = arr.mean(axis=reduce_axes, keepdims=True)
        std = arr.std(axis=reduce_axes, keepdims=True)
        std[std == 0] = 1.0
        return mean, std

    def fit_x(self, X: np.ndarray) -> None:
        """Fit the feature scaler on *X*.

        Uses the TabICL pipeline when it is importable and fits
        successfully; otherwise falls back to mean/std scaling. A pipeline
        failure is reported with a ``RuntimeWarning`` instead of being
        silently ignored.
        """
        Xn = _to_numpy(X)

        # Drop anything left over from a previous fit so transform_x can
        # never mix a stale pipeline with fresh statistics (or vice versa).
        self.x_pipeline = None
        self.x_mean = None
        self.x_std = None

        if _HAS_TABICL:
            try:
                pipeline = _TabICLPP(  # type: ignore[name-defined]
                    normalization_method=self.normalization_method,
                    outlier_threshold=self.outlier_threshold,
                    random_state=self.random_state,
                )
                pipeline.fit(Xn)
            except Exception as exc:
                # Best effort: keep going with standard scaling, but make
                # the downgrade visible to the caller.
                import warnings

                warnings.warn(
                    f"TabICL preprocessing pipeline failed ({exc!r}); "
                    "falling back to standard scaling.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            else:
                self.x_pipeline = pipeline
                return

        # Simple standard scaling fallback
        self.x_mean, self.x_std = self._mean_std(Xn)

    def transform_x(self, X: np.ndarray):
        """Scale *X* with the fitted pipeline or the fallback statistics.

        Raises:
            RuntimeError: If called before ``fit_x``.
        """
        Xn = _to_numpy(X)
        if self.x_pipeline is not None:
            Xo = self.x_pipeline.transform(Xn)
        else:
            if self.x_mean is None or self.x_std is None:
                raise RuntimeError("TabICLScaler.transform_x called before fit_x")
            Xo = (Xn - self.x_mean) / self.x_std
        return _to_type(Xo, X)

    def fit_y(self, y: np.ndarray) -> None:
        """Compute per-dimension target statistics from *y* of shape (..., Dy)."""
        self.y_mean, self.y_std = self._mean_std(_to_numpy(y))

    def transform_y(self, y: np.ndarray):
        """Standardize *y* with the fitted target statistics.

        Raises:
            RuntimeError: If called before ``fit_y``.
        """
        yn = _to_numpy(y)
        if self.y_mean is None or self.y_std is None:
            raise RuntimeError("TabICLScaler.transform_y called before fit_y")
        yo = (yn - self.y_mean) / self.y_std
        return _to_type(yo, y)

    def fit(self, X_context: np.ndarray, y_context: Optional[np.ndarray] = None) -> FitStats:
        """Fit features and, when given, targets; return the fitted stats.

        When *y_context* is None the target statistics are left untouched.
        """
        self.fit_x(X_context)
        if y_context is not None:
            self.fit_y(y_context)
        return FitStats(x_fitted=True, y_mean=self.y_mean, y_std=self.y_std)

    def transform_xy(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Transform *X* and, when given, *y*; the second item is None otherwise."""
        Xo = self.transform_x(X)
        yo_out = None
        if y is not None:
            yo_out = self.transform_y(y)
        return Xo, yo_out

    # Integration with DataAttr batches -------------------------------------
    def fit_from_batch(self, batch) -> FitStats:
        """Fit using context statistics from a DataAttr-like batch.

        Raises:
            ValueError: If the batch has no ``xc``.
        """
        xc = getattr(batch, "xc", None)
        if xc is None:
            raise ValueError("Batch has no xc for fitting")
        return self.fit(xc, getattr(batch, "yc", None))

    def transform_batch(self, batch):
        """Return a new batch with scaled xc/xb/xt and yc/yb/yt (when present).

        Fields that are missing from *batch* or set to None are passed
        through as None.
        """
        from src.utils import DataAttr  # lazy import to avoid cycles

        def tX(x):
            return None if x is None else self.transform_x(x)

        def tY(y):
            return None if y is None else self.transform_y(y)

        # Every field is read with getattr so the context fields are treated
        # exactly like the optional buffer/target fields.
        return DataAttr(
            xc=tX(getattr(batch, "xc", None)),
            yc=tY(getattr(batch, "yc", None)),
            xb=tX(getattr(batch, "xb", None)),
            yb=tY(getattr(batch, "yb", None)),
            xt=tX(getattr(batch, "xt", None)),
            yt=tY(getattr(batch, "yt", None)),
        )
