"""
    Implementation of GOSDT-Guesses
    ThresholdGuess class will take a dataset and find the best thresholds
    to use for the features using a Gradient Boosting Decision Tree Classifier.

    Dataset then can be binarized using the found thresholds.
"""
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

class ThresholdGuess:
    """
    Find split thresholds for features with a Gradient Boosting Decision Tree
    classifier and binarize datasets with them.

    Attributes:
        model_param (dict): Keyword arguments forwarded to the GBDT classifier.
        back_select (bool): Whether fit() drops the least important thresholds.
        thresholds (list | None): ``(feature index, threshold)`` pairs; None
            until fit() has run.
        rs (int): Random seed passed to the GBDT classifier.
        num_features (int | None): Number of thresholds kept by fit().
        feature_names_out (list | None): Names of the binarized columns; set
            by fit() and refreshed by every transform() call.
        feature_importances_: Initialised to None; nothing in this class
            assigns it afterwards.
    """
    def __init__(self, guess_model_param, back_select=True, random_state=42):
        """ Store the configuration; no model is trained here.

        Args:
            guess_model_param (dict): GBDT parameters.
            back_select (bool, optional): Back selection of features. Defaults to True.
            random_state (int, optional): Random seed. Defaults to 42.
        """
        self.model_param = guess_model_param
        self.back_select = back_select
        self.thresholds = None
        self.rs = random_state
        self.num_features = None
        self.feature_names_out = None
        self.feature_importances_ = None
        # Maps feature index -> True when that column held only 0/1 values in
        # the data given to fit(). None until fit() has run.
        self._binary_features = None

    @staticmethod
    def _is_binary(column) -> bool:
        """ Return True when every value of ``column`` equals its boolean cast. """
        return bool(np.array_equal(column, column.astype(bool)))

    @staticmethod
    def _input_names(feat, X_arr):
        """ Return ``feat``, or generated ``feature_<i>`` names when it is None. """
        if feat is None:
            return [f"feature_{i}" for i in range(X_arr.shape[1])]
        return feat

    @staticmethod
    def _output_name(names_in, f, t, is_binary):
        """ Name one output column: the input name, or ``"<name> <= <t>"``. """
        if is_binary:
            return names_in[f]
        return f"{names_in[f]} <= {t}"

    def _column_is_binary(self, f, column) -> bool:
        """ Decide whether feature ``f`` is copied through instead of thresholded.

        The flag recorded by fit() wins, so data seen later is encoded the
        same way as the training data. When no flag is recorded (thresholds
        assigned by hand, or fit() never called) the column itself is checked.
        """
        recorded = getattr(self, "_binary_features", None) or {}
        flag = recorded.get(int(f))
        return self._is_binary(column) if flag is None else flag

    def fit_gbdt(self, X, y) -> tuple:
        """ Fit the GBDT model and return the model and its accuracy.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.

        Returns:
            tuple: (model, accuracy), where accuracy is ``model.score`` on the
                same ``X`` and ``y`` the model was fitted on.
        """
        clf = GradientBoostingClassifier(**self.model_param, random_state=self.rs)
        clf.fit(X, y)
        return clf, clf.score(X, y)

    def fit(self, X, y, feat=None) -> None:
        """ Fit the model to find the best thresholds for features.

        A GBDT is fitted on ``X`` and every split ``(feature, threshold)`` of
        every tree is collected. With ``back_select`` enabled, the data is
        binarized with those thresholds, a second GBDT is fitted on the
        binarized data, and the least important 20% of thresholds (at least
        one, but never all of them) are discarded.

        Sets ``thresholds`` (sorted by feature, then threshold),
        ``num_features`` and ``feature_names_out``.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
            feat (list): List of feature names.

        Returns:
            None
        """
        clf, _ = self.fit_gbdt(X, y)

        # estimators_ holds one row per boosting stage and may hold several
        # trees per row, so flatten it to visit every tree.
        found = set()
        for est in np.ravel(clf.estimators_):
            tree = est.tree_
            # Negative feature ids mark leaves, which carry no split.
            found.update(
                (f, t) for f, t in zip(tree.feature, tree.threshold) if f >= 0
            )
        self.thresholds = list(found)

        # Remember which split features are 0/1-valued in the training data so
        # that transform() encodes later data the same way.
        X_arr = np.asarray(X)
        self._binary_features = {
            f: self._is_binary(X_arr[:, f]) for f in {int(f) for f, _ in found}
        }

        # With no thresholds there is nothing to refit on, so skip selection.
        if self.back_select and self.thresholds:
            X_bin = self.transform(X, feat)
            clf, _ = self.fit_gbdt(X_bin, y)

            n_cols = X_bin.shape[1]
            # Drop 20% (at least one) but always keep at least one threshold.
            n_drop = min(max(1, int(0.2 * n_cols)), n_cols - 1)
            dropped = set(np.argsort(clf.feature_importances_)[:n_drop].tolist())
            # Column i of X_bin was produced by self.thresholds[i].
            self.thresholds = [
                pair for i, pair in enumerate(self.thresholds) if i not in dropped
            ]

        self.thresholds.sort(key=lambda x: (x[0], x[1]))
        self.num_features = len(self.thresholds)

        # Keep the published names in step with the final, sorted thresholds.
        names_in = self._input_names(feat, X_arr)
        self.feature_names_out = [
            self._output_name(names_in, f, t, self._binary_features[int(f)])
            for f, t in self.thresholds
        ]

    def transform(self, X, feat=None) -> np.ndarray:
        """ Transform the feature matrix using the found thresholds.
                Need to call fit() before calling this function.

        One output column is produced per entry of ``self.thresholds``. A
        feature treated as binary is copied through unchanged; any other
        feature becomes 1 where ``value <= threshold`` and 0 elsewhere.
        ``feature_names_out`` is updated to name the produced columns.

        Args:
            X (np.ndarray): Feature matrix.
            feat (list, optional): List of feature names. Defaults to None.

        Returns:
            np.ndarray: Transformed feature matrix.

        Raises:
            TypeError: If ``self.thresholds`` is still None (fit() not called).
        """
        # Only read from X, so a view is enough; also accepts DataFrames.
        X_arr = np.asarray(X)
        names_in = self._input_names(feat, X_arr)

        names_out = []
        X_new = np.zeros((X_arr.shape[0], len(self.thresholds)))
        for i, (f, t) in enumerate(self.thresholds):
            column = X_arr[:, f]
            is_binary = self._column_is_binary(f, column)
            if is_binary:
                X_new[:, i] = column
            else:
                X_new[column <= t, i] = 1
            names_out.append(self._output_name(names_in, f, t, is_binary))

        self.feature_names_out = names_out

        return X_new
