from abc import ABC, abstractmethod
from typing import Union

import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import MinMaxScaler, StandardScaler, QuantileTransformer
from sklearn.utils.validation import check_is_fitted


def is_fitted(scaler: Union[MinMaxScaler, StandardScaler, QuantileTransformer]) -> bool:
    """Return whether ``scaler`` has already been fitted.

    Delegates to scikit-learn's ``check_is_fitted`` and converts its
    ``NotFittedError`` into a boolean instead of propagating it. Any other
    exception raised by ``check_is_fitted`` is left to propagate.

    Args:
        scaler: The scikit-learn transformer to inspect. The accepted types
            mirror those taken by ``ScalerPipeline``, which calls this helper.

    Returns:
        ``True`` if ``check_is_fitted`` succeeds, ``False`` if it raises
        ``NotFittedError``.
    """
    try:
        check_is_fitted(scaler)
    except NotFittedError:
        return False
    return True


class Pipeline(ABC):
    """Abstract base for preprocessing pipelines with an invertible transform.

    Subclasses implement :meth:`preprocess` (the forward transform) and
    :meth:`inverse_transform` (its inverse). The base class supplies a batched
    inverse and a ``repr`` that shows the ``scaler`` attribute.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Extra positional/keyword arguments are accepted and ignored here.
        # ``scaler`` starts as None; concrete subclasses are expected to set it.
        self.scaler = None

    @abstractmethod
    def preprocess(self, x: np.ndarray) -> np.ndarray:
        """Apply the forward transform to ``x`` and return the result."""

    @abstractmethod
    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Undo :meth:`preprocess` on ``x`` and return the result."""

    def batched_inverse_transform(self, batch: np.ndarray) -> np.ndarray:
        """Apply :meth:`inverse_transform` to each element of ``batch``.

        Args:
            batch: Any iterable of arrays; iteration runs over the first axis
                when ``batch`` is an ndarray.

        Returns:
            The per-element inverse transforms stacked along a new axis 0.

        Raises:
            ValueError: If ``batch`` is empty (``np.stack`` needs at least one
                array), or if the per-element results differ in shape.
        """
        return np.stack([self.inverse_transform(item) for item in batch])

    def __repr__(self) -> str:
        return f"{type(self).__name__}(scaler={self.scaler})"


class ScalerPipeline(Pipeline):
    """Pipeline that delegates its transforms to a scikit-learn scaler.

    The wrapped scaler is fitted lazily. If it is not yet fitted, the first
    call to :meth:`preprocess` fits it on that call's input. Later calls reuse
    the fitted parameters. A scaler that arrives already fitted is never
    refitted.
    """

    def __init__(self, scaler: Union[MinMaxScaler, StandardScaler, QuantileTransformer]) -> None:
        """Wrap ``scaler``, which may be fitted or unfitted."""
        super().__init__()
        self.scaler = scaler

    def preprocess(self, x: np.ndarray) -> np.ndarray:
        """Transform ``x`` with the wrapped scaler, fitting it first if needed.

        ``x`` is passed straight to the scaler, so it must meet the scaler's
        input requirements. scikit-learn scalers expect a 2-D
        ``(n_samples, n_features)`` array.

        Side effect: fits ``self.scaler`` on ``x`` when it is not yet fitted.
        """
        # Fit only on first use so later calls reuse the parameters learned
        # from that first input rather than refitting on every batch.
        if not is_fitted(self.scaler):
            self.scaler.fit(x)
        return self.scaler.transform(x)

    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Map ``x`` back to the original scale via the wrapped scaler.

        The wrapped scaler must already be fitted. scikit-learn raises its
        ``NotFittedError`` otherwise.
        """
        return self.scaler.inverse_transform(x)
