import json

from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.tree import DecisionTreeClassifier

from verification.decision_tree_attack import DecisionTreeAttackWrapper

import numpy as np

def convert_numpy(obj):
    """
    Convert numpy integer and floating scalars to native python types.

    Intended as the ``default`` hook for ``json.dump``/``json.dumps``, which cannot
    serialize numpy scalars on their own.

    Parameters
    ----------
    obj : numpy integer or floating scalar
        Number to convert. Any subclass of ``np.integer`` (``int8`` ... ``int64``,
        unsigned variants, ``longlong``) or ``np.floating`` (``float16`` ... ``float64``)
        is accepted.

    Returns
    -------
    int or float
        The python equivalent of ``obj``.

    Raises
    ------
    TypeError
        If ``obj`` is not a numpy integer or floating scalar. ``json`` relies on this
        exception to report unserializable objects.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    raise TypeError(f"Cannot convert type {type(obj)} to int or float")


class KantchModel:
    """
    Decision tree (ensemble) model stored as a list of XGBoost-style JSON trees.

    Provides a single API for prediction, accuracy and adversarial evaluation,
    independent of the library the trees were originally trained with.
    """

    def __init__(self, json_model, n_classes):
        """
        General model class that exposes a common API for evaluating decision tree (ensemble) models. Usually you won't have to call this constructor manually, instead use `from_json_file`, `from_sklearn` or `from_groot`.

        Parameters
        ----------
        json_model : list of dicts
            List of decision trees encoded as dicts. See the XGBoost JSON format. When
            n_classes > 2, tree i contributes to the score of class i % n_classes.
        n_classes : int
            Number of classes that this model predicts.
        """
        self.json_model = json_model
        self.n_classes = n_classes

    @staticmethod
    def from_json_file(filename, n_classes):
        """
        Create a Model instance from a JSON file.

        Parameters
        ----------
        filename : str
            Path to JSON file that contains a list of decision trees encoded as dicts. See the XGBoost JSON format.
        n_classes : int
            Number of classes that this model predicts.

        Returns
        -------
        KantchModel
            Instantiated KantchModel object.
        """
        # JSON is UTF-8 by specification; don't depend on the platform default encoding.
        with open(filename, "r", encoding="utf-8") as file:
            json_model = json.load(file)

        return KantchModel(json_model, n_classes)

    @staticmethod
    def from_sklearn(classifier):
        """
        Create a KantchModel instance from a Scikit-learn classifier.

        Parameters
        ----------
        classifier : DecisionTreeClassifier, RandomForestClassifier or GradientBoostingClassifier
            Scikit-learn model to load.

        Returns
        -------
        KantchModel
            Instantiated KantchModel object.

        Raises
        ------
        ValueError
            If the classifier is of an unsupported type.
        """
        if isinstance(classifier, DecisionTreeClassifier):
            return _sklearn_tree_to_model(classifier)
        elif isinstance(classifier, RandomForestClassifier):
            return _sklearn_forest_to_model(classifier)
        elif isinstance(classifier, GradientBoostingClassifier):
            return _sklearn_booster_to_model(classifier)
        else:
            # Format the type with an f-string: concatenating str + type raised
            # TypeError instead of this ValueError.
            raise ValueError(
                "Only decision tree, random forest and gradient boosting classifiers are supported, not "
                f"{type(classifier).__name__}"
            )

    @staticmethod
    def from_groot(classifier):
        """
        Create a Model instance from a GrootTree, GrootRandomForest or GROOT OneVsRestClassifier.

        Parameters
        ----------
        classifier : GrootTree, GrootRandomForest or OneVsRestClassifier (of GROOT models)
            GROOT model to load. Anything that is not a OneVsRestClassifier is loaded
            as a binary (n_classes=2) model.

        Returns
        -------
        KantchModel
            Instantiated KantchModel object.
        """
        if isinstance(classifier, OneVsRestClassifier):
            one_vs_all_models = []
            for model in classifier.estimators_:
                json_model = model.to_xgboost_json(output_file=None)

                if not isinstance(json_model, list):
                    json_model = [json_model]

                one_vs_all_models.append(json_model)

            # Interleave the per-class trees (class0_tree0, class1_tree0, ...) so that
            # tree i belongs to class i % n_classes, as decision_function expects.
            json_trees = []
            for grouped_models in zip(*one_vs_all_models):
                json_trees.extend(grouped_models)

            return KantchModel(json_trees, classifier.n_classes_)

        json_trees = classifier.to_xgboost_json(output_file=None)
        if not isinstance(json_trees, list):
            json_trees = [json_trees]

        return KantchModel(json_trees, 2)

    def predict(self, X):
        """
        Predict classes for some samples. The raw prediction values are turned into class labels.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        ndarray of shape (n_samples)
            Predicted class labels. For binary models a non-negative score maps to class 1.
        """
        prediction_values = self.decision_function(X)
        if self.n_classes == 2:
            return (prediction_values >= 0).astype(int)
        else:
            return np.argmax(prediction_values, axis=1)

    def decision_function(self, X):
        """
        Compute prediction values for some samples. These values are the sum of leaf values in which the samples end up.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        ndarray of shape (n_samples) or ndarray of shape (n_samples, n_classes)
            Predicted values. Returns a 1-dimensional array if n_classes=2, else a 2-dimensional array.
        """
        values = []
        if self.n_classes == 2:
            for sample in X:
                value = 0
                for tree in self.json_model:
                    value += self.__predict_proba_tree_sample(tree, sample)
                values.append(value)
            return np.array(values)

        for sample in X:
            class_values = np.zeros(self.n_classes)
            for i, tree in enumerate(self.json_model):
                # Trees are interleaved per class: tree i scores class i % n_classes.
                class_values[i % self.n_classes] += self.__predict_proba_tree_sample(
                    tree, sample
                )
            values.append(class_values)

        # The reshape keeps the documented 2-D shape even when X has no samples
        # (np.array([]) would otherwise be 1-D and break predict's argmax).
        return np.array(values).reshape(-1, self.n_classes)

    def __predict_proba_tree_sample(self, json_tree, sample):
        """
        Follow the path of a sample through the JSON tree and return the resulting leaf's value.

        A sample goes to the "yes" child when its value for the split feature is
        <= split_condition, and to the "no" child otherwise.

        Raises
        ------
        ValueError
            If a decision node has no child with the node id it points to.
        """
        # Iterative walk: deep trees cannot hit Python's recursion limit.
        node = json_tree
        while "leaf" not in node:
            if sample[node["split"]] <= node["split_condition"]:
                next_node_id = node["yes"]
            else:
                next_node_id = node["no"]

            for sub_tree in node["children"]:
                if sub_tree["nodeid"] == next_node_id:
                    node = sub_tree
                    break
            else:
                raise ValueError(
                    f"Malformed tree: node {node.get('nodeid')} has no child "
                    f"with nodeid {next_node_id}"
                )
        return node["leaf"]

    def __get_attack_wrapper(self, attack_name):
        """
        Return the instantiated attack wrapper for the appropriate attack.

        Raises
        ------
        ValueError
            If the attack is not supported.
        """
        # If the attack is set to automatic then use MILP for ensembles and
        # tree attack for individual trees.
        if attack_name == "auto":
            if len(self.json_model) == 1:
                attack_name = "tree"
            else:
                attack_name = "milp"

        # NOTE(review): only the "tree" wrapper is wired up here. "milp", which "auto"
        # selects for multi-tree models, currently falls through to the ValueError below.
        if attack_name == "tree":
            return DecisionTreeAttackWrapper(self.json_model, self.n_classes)
        else:
            raise ValueError(f"Attack '{attack_name}' not supported.")

    def attack_feasibility(
        self, X, y, attack="auto", order=np.inf, epsilon=0.0, options=None
    ):
        """
        Determine whether an adversarial example is feasible for each sample given the maximum perturbation radius epsilon.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to attack.
        y : array-like of shape (n_samples,)
            True labels for the samples.
        attack : {"auto", "milp", "tree"}, optional
            The attack to use, if "auto" the attack is chosen automatically:
            - "milp" for optimal attacks on tree ensembles using a Mixed-Integer
              Linear Programming formulation.
            - "tree" for optimal attacks on single decision trees by enumerating
              all possible paths through the tree.
        order : {0, 1, 2, inf}, optional
            L-norm order to use. See numpy documentation for more explanation.
        epsilon : float, optional
            Maximum distance by which samples can move.
        options : dict, optional
            Extra attack-specific options. Defaults to an empty dict.

        Returns
        -------
        ndarray of shape (n_samples,) of booleans
            Vector of True/False. Whether an adversarial example is feasible.
        """
        # None instead of a shared mutable {} default.
        options = {} if options is None else options
        attack_wrapper = self.__get_attack_wrapper(attack)
        return attack_wrapper.attack_feasibility(
            X, y, order=order, epsilon=epsilon, options=options
        )

    def attack_distance(self, X, y, attack="auto", order=np.inf, options=None):
        """
        Determine the perturbation distance for each sample to make an adversarial example.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to attack.
        y : array-like of shape (n_samples,)
            True labels for the samples.
        attack : {"auto", "milp", "tree"}, optional
            The attack to use, if "auto" the attack is chosen automatically:
            - "milp" for optimal attacks on tree ensembles using a Mixed-Integer
              Linear Programming formulation.
            - "tree" for optimal attacks on single decision trees by enumerating
              all possible paths through the tree.
        order : {0, 1, 2, inf}, optional
            L-norm order to use. See numpy documentation for more explanation.
        options : dict, optional
            Extra attack-specific options. Defaults to an empty dict.

        Returns
        -------
        ndarray of shape (n_samples,) of floats
            Distances to create adversarial examples.
        """
        options = {} if options is None else options
        attack_wrapper = self.__get_attack_wrapper(attack)
        return attack_wrapper.attack_distance(X, y, order=order, options=options)

    def adversarial_examples(self, X, y, attack="auto", order=np.inf, options=None):
        """
        Create adversarial examples for each input sample.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to attack.
        y : array-like of shape (n_samples,)
            True labels for the samples.
        attack : {"auto", "milp", "tree"}, optional
            The attack to use, if "auto" the attack is chosen automatically:
            - "milp" for optimal attacks on tree ensembles using a Mixed-Integer
              Linear Programming formulation.
            - "tree" for optimal attacks on single decision trees by enumerating
              all possible paths through the tree.
        order : {0, 1, 2, inf}, optional
            L-norm order to use. See numpy documentation for more explanation.
        options : dict, optional
            Extra attack-specific options. Defaults to an empty dict.

        Returns
        -------
        ndarray of shape (n_samples, n_features)
            Adversarial examples.
        """
        options = {} if options is None else options
        attack_wrapper = self.__get_attack_wrapper(attack)
        return attack_wrapper.adversarial_examples(X, y, order=order, options=options)

    def get_bounding_boxes(self, X):
        """
        Return bounding boxes and predictions for X, as computed by the single-tree attack wrapper.

        Always uses the "tree" attack wrapper and returns its ``get_bounding_box(X)``
        result as a ``(bounding_boxes, predictions)`` tuple. The exact layout of both is
        defined by DecisionTreeAttackWrapper (not visible here).
        """
        attack_wrapper = self.__get_attack_wrapper("tree")
        bounding_boxes, predictions = attack_wrapper.get_bounding_box(X)
        return bounding_boxes, predictions

    def assign_bounding_boxes(self, X):
        """
        Delegate to the single-tree attack wrapper's ``assign_bounding_boxes(X)``.

        Always uses the "tree" attack wrapper; the return value is whatever that
        wrapper produces (presumably a bounding box per sample — confirm in
        DecisionTreeAttackWrapper).
        """
        attack_wrapper = self.__get_attack_wrapper("tree")
        return attack_wrapper.assign_bounding_boxes(X)

    def accuracy(self, X, y):
        """
        Determine the accuracy of the model on unperturbed samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input samples.
        y : array-like of shape (n_samples,)
            True labels.

        Returns
        -------
        float
            Accuracy on unperturbed samples.
        """
        y_pred = self.predict(X)
        return np.sum(y_pred == y) / len(y)

    def adversarial_accuracy(
        self, X, y, attack="auto", order=np.inf, epsilon=0.0, options=None
    ):
        """
        Determine the accuracy against adversarial examples within maximum perturbation radius epsilon.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to attack.
        y : array-like of shape (n_samples,)
            True labels for the samples.
        attack : {"auto", "milp", "tree"}, optional
            The attack to use, if "auto" the attack is chosen automatically:
            - "milp" for optimal attacks on tree ensembles using a Mixed-Integer
              Linear Programming formulation.
            - "tree" for optimal attacks on single decision trees by enumerating
              all possible paths through the tree.
        order : {0, 1, 2, inf}, optional
            L-norm order to use. See numpy documentation for more explanation.
        epsilon : float, optional
            Maximum distance by which samples can move.
        options : dict, optional
            Extra attack-specific options. Defaults to an empty dict.

        Returns
        -------
        float
            Adversarial accuracy given the maximum perturbation radius epsilon.
        """
        attacks_feasible = self.attack_feasibility(
            X, y, attack, order, epsilon, options
        )
        # Normalise to a boolean array so a plain list returned by the wrapper also works.
        feasible = np.asarray(attacks_feasible, dtype=bool)
        return np.sum(~feasible) / len(feasible)

    def to_json(self, filename, indent=2):
        """
        Export the model object to a JSON file.

        Parameters
        ----------
        filename : str
            Name of the JSON file to export to.
        indent : int, optional
            Number of spaces to use for indentation in the JSON file. Can be reduced to save storage.
        """
        with open(filename, "w", encoding="utf-8") as file:
            # convert_numpy turns numpy scalars in the trees into JSON-serializable types.
            json.dump(self.json_model, file, indent=indent, default=convert_numpy)


def _sklearn_tree_to_dict(tree, classifier=True, one_vs_all_class=1, learning_rate=1.0):
    if classifier:
        assert len(tree.classes_.shape) == 1, "Multi-output is not supported"

    n_nodes = tree.tree_.node_count
    children_left = tree.tree_.children_left
    children_right = tree.tree_.children_right
    feature = tree.tree_.feature
    threshold = tree.tree_.threshold
    value = tree.tree_.value

    def dfs(node_id, depth):
        left_id = children_left[node_id]
        right_id = children_right[node_id]

        if left_id == right_id:
            # If leaf node
            if classifier:
                # A decision tree classifier contains the counts of samples
                # that reach the leaf
                class_counts = value[node_id][0]

                # Map the prediction probability to a value in the range [-1, 1]
                leaf_value = (
                    class_counts[one_vs_all_class] / np.sum(class_counts)
                ) * 2 - 1
                return {
                    "nodeid": node_id,
                    "leaf": leaf_value,
                }
            else:
                # A decision tree regressor contains the raw prediction value
                prediction = value[node_id][0][0]
                return {
                    "nodeid": node_id,
                    "leaf": learning_rate * prediction,
                }
        else:
            # If decision node
            left_dict = dfs(left_id, depth + 1)
            right_dict = dfs(right_id, depth + 1)

            return {
                "nodeid": node_id,
                "depth": depth,
                "split": feature[node_id],
                "split_condition": threshold[node_id],
                "yes": left_id,
                "no": right_id,
                "missing": left_id,
                "children": [left_dict, right_dict],
            }

    return dfs(0, 0)


def _sklearn_tree_to_model(tree: DecisionTreeClassifier):
    """
    Wrap a fitted scikit-learn decision tree classifier in a KantchModel instance.

    A binary tree becomes one JSON tree whose leaves score class index 1. Any other
    number of classes is expanded into a one-vs-all representation. That is one JSON tree
    per class index, in class order, each with the same structure but leaves scored for
    its own class.

    Parameters
    ----------
    tree : sklearn.tree.DecisionTreeClassifier
        Decision tree to export

    Returns
    -------
    KantchModel
        Model built from the converted tree(s) with ``tree.n_classes_`` classes.
    """
    n_classes = tree.n_classes_

    # Binary trees only need the positive class (index 1); multiclass trees need
    # one scoring tree for every class index.
    target_classes = [1] if n_classes == 2 else range(n_classes)

    json_trees = [
        _sklearn_tree_to_dict(tree, classifier=True, one_vs_all_class=target)
        for target in target_classes
    ]
    return KantchModel(json_trees, n_classes)

def _sigmoid_inverse(proba: float):
    """
    Invert the sigmoid function that is used in the Scikit-learn binary gradient boosting classifier.
    """
    return np.log(proba / (1 - proba))
