import numpy as np
import torch
from sklearn.preprocessing import (
    StandardScaler,
    MinMaxScaler,
    QuantileTransformer,
    RobustScaler,
)
from typing import Tuple, Optional, Literal, Union


class DataPreprocessor:
    """
    A class to handle fitting and transforming data using various scalers.
    It encapsulates the scalers for both X and Y, and handles the transformation
    from numpy arrays to PyTorch tensors.

    Attributes:
        scaler_x_type, scaler_y_type: The requested scaler kinds.
        scaler_x, scaler_y: Scaler instances created and fitted by
            `fit_transform`; None before fitting and for the "none" type.
        data_limits_x_: Set by `fit_transform`. A float32 tensor of shape
            (n_features, 2) holding the per-feature [min, max] of the scaled
            training features.
        cut_points_: Set by `fit_transform`. A float32 tensor of shape
            (n_features, n_samples), i.e. the transposed scaled training
            features (one row per feature).
    """

    ScalerObject = Union[
        StandardScaler, MinMaxScaler, QuantileTransformer, RobustScaler
    ]
    ScalerType = Literal[
        "standard", "minmax", "quantile_uniform", "quantile_normal", "robust", "none"
    ]

    def __init__(
        self,
        scaler_x_type: ScalerType = "standard",
        scaler_y_type: ScalerType = "standard",
    ):
        """
        Initializes the preprocessor with specified scaler types.
        Scaler instances are created during the `fit_transform` call.

        Args:
            scaler_x_type: The type of scaler to use for the features (X).
            scaler_y_type: The type of scaler to use for the target (Y).
        """
        self.scaler_x_type = scaler_x_type
        self.scaler_y_type = scaler_y_type
        self.scaler_x: Optional[DataPreprocessor.ScalerObject] = None
        self.scaler_y: Optional[DataPreprocessor.ScalerObject] = None
        # Both populated by `fit_transform`; initialized here so the attributes
        # always exist.
        self.data_limits_x_: Optional[torch.Tensor] = None
        self.cut_points_: Optional[torch.Tensor] = None

    def _create_scaler_instance(
        self, scaler_type: ScalerType, n_samples: int
    ) -> Optional[ScalerObject]:
        """Creates an unfitted scaler for `scaler_type`.

        Args:
            scaler_type: One of the `ScalerType` strings.
            n_samples: Number of rows the scaler will be fitted on; used to cap
                `n_quantiles` for the quantile transformers.

        Returns:
            A new scaler instance, or None for "none".

        Raises:
            ValueError: If `scaler_type` is not recognised.
        """
        if scaler_type == "standard":
            return StandardScaler()
        elif scaler_type == "minmax":
            return MinMaxScaler()
        elif scaler_type == "robust":
            return RobustScaler()
        elif scaler_type in ("quantile_uniform", "quantile_normal"):
            # n_quantiles must be <= n_samples. Default to 1000 or n_samples if smaller.
            output_distribution = (
                "uniform" if scaler_type == "quantile_uniform" else "normal"
            )
            return QuantileTransformer(
                output_distribution=output_distribution,
                n_quantiles=min(1000, n_samples),
            )
        elif scaler_type == "none":
            return None
        else:
            raise ValueError(f"Unknown scaler type: {scaler_type}")

    def _check_is_fitted(self) -> None:
        """Raises if a scaler that should exist (type != "none") was never created.

        Each scaler is checked independently, so a "none" scaler on one side
        does not mask or falsely trigger the check on the other side.
        """
        missing_x = self.scaler_x is None and self.scaler_x_type != "none"
        missing_y = self.scaler_y is None and self.scaler_y_type != "none"
        if missing_x or missing_y:
            raise RuntimeError(
                "Scalers have not been fitted. Call fit_transform first."
            )

    def _scale_y(self, Y: np.ndarray, fit: bool) -> np.ndarray:
        """Scales Y with `self.scaler_y`, fitting it first when `fit` is True.

        The scalers are applied to 2D data, so a 1D target is temporarily
        reshaped to a single column and flattened back afterwards; the result
        therefore has the same rank as `Y`.
        """
        was_1d = Y.ndim == 1
        if was_1d:
            Y = Y.reshape(-1, 1)
        if fit:
            Y_scaled = self._fit_transform_single(self.scaler_y, Y)
        else:
            Y_scaled = self._transform_single(self.scaler_y, Y)
        if was_1d and Y_scaled.ndim > 1:
            Y_scaled = Y_scaled.flatten()
        return Y_scaled

    def fit_transform(
        self, X: np.ndarray, Y: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]:
        """
        Fits the scalers to the training data and transforms it.

        Also records `data_limits_x_` and `cut_points_` from the scaled features.

        Args:
            X: The input features, a 2D array of shape (n_samples, n_features).
            Y: The target values (1D or 2D numpy array).

        Returns:
            A tuple containing:
            - X_tensor: Scaled features as a float32 PyTorch tensor.
            - Y_tensor: Scaled target as a float32 PyTorch tensor.
            - X_scaled: Scaled features as a numpy array.
            - Y_scaled: Scaled target as a numpy array (1D if Y was 1D).

        Raises:
            ValueError: If X is not 2D, or if a scaler type is unknown.
        """
        if X.ndim != 2:
            raise ValueError(
                "X must be a 2D array of shape (n_samples, n_features), "
                f"got an array with ndim={X.ndim}"
            )

        self.scaler_x = self._create_scaler_instance(self.scaler_x_type, X.shape[0])
        self.scaler_y = self._create_scaler_instance(self.scaler_y_type, Y.shape[0])

        X_scaled = self._fit_transform_single(self.scaler_x, X)
        Y_scaled = self._scale_y(Y, fit=True)

        X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
        Y_tensor = torch.tensor(Y_scaled, dtype=torch.float32)

        # Per-feature [min, max] of the scaled features, shape (n_features, 2).
        self.data_limits_x_ = torch.tensor(
            np.stack([X_scaled.min(axis=0), X_scaled.max(axis=0)], axis=1),
            dtype=torch.float32,
        )
        # One row per feature. Cloned so that in-place edits to the returned
        # X_tensor cannot silently modify this stored state.
        self.cut_points_ = X_tensor.T.clone()

        return X_tensor, Y_tensor, X_scaled, Y_scaled

    def transform(
        self, X: np.ndarray, Y: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]:
        """
        Transforms test data using the already fitted scalers.

        Args:
            X: The input features (numpy array).
            Y: The target values (1D or 2D numpy array).

        Returns:
            A tuple containing:
            - X_tensor: Scaled features as a float32 PyTorch tensor.
            - Y_tensor: Scaled target as a float32 PyTorch tensor.
            - X_scaled: Scaled features as a numpy array.
            - Y_scaled: Scaled target as a numpy array (1D if Y was 1D).

        Raises:
            RuntimeError: If a non-"none" scaler has not been fitted yet.
        """
        self._check_is_fitted()

        X_scaled = self._transform_single(self.scaler_x, X)
        Y_scaled = self._scale_y(Y, fit=False)

        X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
        Y_tensor = torch.tensor(Y_scaled, dtype=torch.float32)

        return X_tensor, Y_tensor, X_scaled, Y_scaled

    def _fit_transform_single(
        self, scaler: Optional[ScalerObject], data: np.ndarray
    ) -> np.ndarray:
        """Fits `scaler` on `data` and transforms it; returns a copy if scaler is None."""
        if scaler is None:
            return data.copy()
        return scaler.fit_transform(data)

    def _transform_single(
        self, scaler: Optional[ScalerObject], data: np.ndarray
    ) -> np.ndarray:
        """Transforms `data` with a fitted `scaler`; returns a copy if scaler is None."""
        if scaler is None:
            return data.copy()
        return scaler.transform(data)

    def inverse_transform_y(self, Y_scaled: np.ndarray) -> np.ndarray:
        """
        Inverse transforms the scaled Y data to its original scale.

        Args:
            Y_scaled: The scaled target values (numpy array).

        Returns:
            The target values in their original scale. With the "none" scaler
            the input is returned unchanged. Otherwise the scaler's output is
            returned as-is, which means a 1D input yields a 2D (n, 1) array.

        Raises:
            RuntimeError: If a non-"none" Y scaler has not been fitted yet.
        """
        if self.scaler_y is None:
            if self.scaler_y_type != "none":
                # Returning the input here would silently hand back scaled
                # values as if they were in original units.
                raise RuntimeError(
                    "Scalers have not been fitted. Call fit_transform first."
                )
            return Y_scaled

        if Y_scaled.ndim == 1:
            Y_scaled = Y_scaled.reshape(-1, 1)

        return self.scaler_y.inverse_transform(Y_scaled)
