import warnings

import numpy as np
from loguru import logger
from sklearn.feature_selection import SelectKBest
from sklearn.preprocessing import QuantileTransformer


class Preprocessor:
    """
    This class is used to preprocess the data before it is pushed through the model.
    The preprocessor assures that the data has the right shape and is normalized.
    This way the model always gets the same input distribution,
    no matter whether the input data is synthetic or real.

    Usage: call ``fit(X_train, y_train)`` once, then ``transform_X`` /
    ``transform_y`` on the data fed to the model, and ``inverse_transform_y`` /
    ``inverse_transform_y_logits`` on the model outputs. Input arrays are
    never modified in place.
    """

    def __init__(
            self,
            max_features: int,
            n_classes: int,   # Actual number of classes in the dataset, assumed to be numbered 0, ..., n_classes - 1
            max_classes: int,  # Maximum number of classes the model has been trained on
            use_quantile_transformer: bool,
            use_feature_count_scaling: bool,
            shuffle_classes: bool,
            shuffle_features: bool,
            random_mirror_x: bool,
        ):

        self.max_features = max_features
        self.n_classes = n_classes
        self.max_classes = max_classes
        self.use_quantile_transformer = use_quantile_transformer
        self.use_feature_count_scaling = use_feature_count_scaling
        self.shuffle_classes = shuffle_classes
        self.shuffle_features = shuffle_features
        self.random_mirror_x = random_mirror_x

        # Fitted SelectKBest instance; stays None when fit() saw at most
        # max_features usable features and no selection is needed.
        self.select_k_best = None


    def fit(self, X: np.ndarray, y: np.ndarray) -> "Preprocessor":
        """
        Learn all preprocessing state from the training data.

        Computes the NaN-imputation means, the singular (constant) features,
        the optional feature selection, the optional quantile transform plus
        mean/std normalization, and the class/feature orders and mirror signs.

        Args:
            X: 2-D feature matrix (rows are observations, columns are features).
            y: class labels, one per row of X.

        Returns:
            self, to allow chaining.

        Raises:
            AssertionError: if NaNs remain after preprocessing.
        """

        self.compute_pre_nan_mean(X)
        X = self.impute_nan_features_with_mean(X)

        self.determine_which_features_are_singular(X)
        X = self.cutoff_singular_features(X, self.singular_features)

        self.determine_which_features_to_select(X, y)
        X = self.select_features(X)

        if self.use_quantile_transformer:
            # If use quantile transform is off, it means that the preprocessing will happen on the GPU.

            X = self.fit_transform_quantile_transformer(X)

            self.mean, self.std = self.calc_mean_std(X)
            X = self.normalize_by_mean_std(X, self.mean, self.std)

        # Always set ``new_classes``: inverse_transform_y_logits relies on it
        # even when classes are not shuffled (the order is then the identity).
        self.determine_class_order()

        if self.shuffle_features:
            self.determine_feature_order(X)

        if self.random_mirror_x:
            self.determine_mirror(X)

        assert np.isnan(X).sum() == 0, "There are NaNs in the data after preprocessing"

        return self


    def transform_X(self, X: np.ndarray) -> np.ndarray:
        """
        Apply the fitted feature preprocessing to X.

        The input array is not modified; a (possibly new) array is returned.

        Raises:
            AssertionError: if NaNs remain after preprocessing.
        """

        X = self.impute_nan_features_with_mean(X)
        X = self.cutoff_singular_features(X, self.singular_features)
        X = self.select_features(X)

        if self.use_quantile_transformer:
            # If use quantile transform is off, it means that the preprocessing will happen on the GPU.

            X = self.quantile_transformer.transform(X)

            X = self.normalize_by_mean_std(X, self.mean, self.std)

            if self.use_feature_count_scaling:
                X = self.normalize_by_feature_count(X, self.max_features)

        if self.shuffle_features:
            X = self.randomize_feature_order(X)

        if self.random_mirror_x:
            X = self.apply_random_mirror_x(X)

        assert np.isnan(X).sum() == 0, "There are NaNs in the data after preprocessing"

        return X


    def transform_y(self, y: np.ndarray) -> np.ndarray:
        """Map original class labels to the (possibly shuffled) model labels."""

        if self.shuffle_classes:
            y = self.randomize_class_order(y)

        return y


    def inverse_transform_y(self, y: np.ndarray) -> np.ndarray:
        """Map model labels back to the original class labels."""

        if self.shuffle_classes:
            y = self.undo_randomize_class_order(y)

        return y


    def inverse_transform_y_logits(self, y_logits: np.ndarray) -> np.ndarray:
        """
        Reorder model logits into original-class order and keep only the
        ``n_classes`` real classes.
        """

        return self.extract_correct_classes(y_logits)


    def fit_transform_quantile_transformer(self, X: np.ndarray) -> np.ndarray:
        """
        Fit a QuantileTransformer with a normal output distribution on X and
        return the transformed data. The number of quantiles is capped at the
        number of observations (at most 1000).
        """

        n_quantiles = min(X.shape[0], 1000)
        self.quantile_transformer = QuantileTransformer(n_quantiles=n_quantiles, output_distribution='normal')
        X = self.quantile_transformer.fit_transform(X)

        return X


    def determine_which_features_are_singular(self, x: np.ndarray) -> None:
        """
        Mark features (columns) that take exactly one distinct value; the
        boolean mask is stored in ``self.singular_features``.
        """

        self.singular_features = np.array([len(np.unique(x_col)) for x_col in x.T]) == 1


    def determine_which_features_to_select(self, x: np.ndarray, y: np.ndarray) -> None:
        """
        Fit a SelectKBest selector when x has more than ``max_features``
        columns; otherwise clear any selector left over from a previous fit.
        """

        if x.shape[1] > self.max_features:
            logger.info(f"Number of features is capped at {self.max_features}, but the dataset has {x.shape[1]} features. A subset of {self.max_features} are selected using SelectKBest")

            self.select_k_best = SelectKBest(k=self.max_features)
            self.select_k_best.fit(x, y)
        else:
            # Reset so a previous fit on wider data cannot leave a stale selector.
            self.select_k_best = None


    def compute_pre_nan_mean(self, x: np.ndarray) -> None:
        """
        Computes the per-feature mean of the data, ignoring NaNs, before the
        NaNs are imputed. Stored in ``self.pre_nan_mean``.

        Features that are entirely NaN have no mean; they get 0 instead, so
        imputation makes them constant and they are later dropped as singular.
        """
        with warnings.catch_warnings():
            # np.nanmean warns ("Mean of empty slice") on all-NaN columns;
            # those columns are handled explicitly below.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            mean = np.nanmean(x, axis=0)

        self.pre_nan_mean = np.where(np.isnan(mean), 0.0, mean)


    def impute_nan_features_with_mean(self, x: np.ndarray) -> np.ndarray:
        """
        Replace every NaN with the fitted mean of its feature (column).

        Works on a copy so the caller's array is never mutated; when there are
        no NaNs the input is returned unchanged.
        """

        nan_rows, nan_cols = np.where(np.isnan(x))
        if nan_rows.size == 0:
            return x

        x = np.array(x, copy=True)
        x[nan_rows, nan_cols] = self.pre_nan_mean[nan_cols]
        return x


    def select_features(self, x: np.ndarray) -> np.ndarray:
        """Apply the feature selection decided in fit(), if any."""

        # getattr keeps instances created by older versions (without the
        # attribute initialised in __init__) working.
        selector = getattr(self, "select_k_best", None)
        if selector is not None:
            x = selector.transform(x)

        return x


    def cutoff_singular_features(self, x: np.ndarray, singular_features: np.ndarray) -> np.ndarray:
        """Drop the columns flagged in the boolean mask ``singular_features``."""

        if singular_features.any():
            x = x[:, ~singular_features]

        return x


    def calc_mean_std(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Calculates the per-feature mean and std of the training data
        """
        mean = x.mean(axis=0)
        std = x.std(axis=0)
        return mean, std


    def normalize_by_mean_std(self, x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
        """
        Normalizes the data by the mean and std.

        Features with zero std are only centred (divided by 1) instead of
        producing inf/NaN from a division by zero.
        """

        safe_std = np.where(std == 0, 1.0, std)
        return (x - mean) / safe_std


    def normalize_by_feature_count(self, x: np.ndarray, max_features) -> np.ndarray:
        """
        Scale x by ``max_features / n_used_features``, following the
        feature-count normalization described in the TabPFN paper.
        """

        x = x * max_features / x.shape[1]
        return x


    def extend_feature_dim_to_max_features(self, x: np.ndarray, max_features) -> np.ndarray:
        """
        Increases the number of features to the number of features the model
        has been trained on by appending zero-valued columns.
        """
        added_zeros = np.zeros((x.shape[0], max_features - x.shape[1]), dtype=np.float32)
        x = np.concatenate([x, added_zeros], axis=1)
        return x


    def determine_mirror(self, x: np.ndarray) -> None:
        """Draw a random sign (+1 or -1) per feature, stored as a (1, n_features) row."""

        n_features = x.shape[1]
        self.mirror = np.random.choice([1, -1], size=(1, n_features))


    def apply_random_mirror_x(self, x: np.ndarray) -> np.ndarray:
        """Multiply every feature by its fitted random sign."""

        x = x * self.mirror
        return x


    def determine_class_order(self) -> None:
        """
        Set ``self.new_classes``: a random permutation of the class indices
        when ``shuffle_classes`` is on, otherwise the identity order.
        ``new_classes[i]`` is the model label used for original class ``i``.
        """

        if self.shuffle_classes:
            self.new_classes = np.random.permutation(self.n_classes)
        else:
            self.new_classes = np.arange(self.n_classes)


    def randomize_class_order(self, y: np.ndarray) -> np.ndarray:
        """Map original labels to model labels; unknown labels raise KeyError."""

        mapping = dict(enumerate(self.new_classes))
        return np.array([mapping[label.item()] for label in y], dtype=np.int64)


    def undo_randomize_class_order(self, y: np.ndarray) -> np.ndarray:
        """Map model labels back to original labels; unknown labels raise KeyError."""

        mapping = {new: original for original, new in enumerate(self.new_classes)}
        return np.array([mapping[label.item()] for label in y], dtype=np.int64)


    def extract_correct_classes(self, y_logits: np.ndarray) -> np.ndarray:
        """
        Return logits in original-class order: column ``i`` of the result is
        the model's logit column ``new_classes[i]``. Because ``new_classes``
        has ``n_classes`` entries, any extra logit columns the model produces
        are dropped.
        """

        return y_logits[:, self.new_classes]


    def determine_feature_order(self, x: np.ndarray) -> None:
        """Draw a random permutation of the feature columns."""

        n_features = x.shape[1]
        self.new_feature_order = np.random.permutation(n_features)


    def randomize_feature_order(self, x: np.ndarray) -> np.ndarray:
        """Reorder the feature columns with the fitted permutation."""

        x = x[:, self.new_feature_order]

        return x
    


    
