# Copyright © 2023 Gurobi Optimization, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Module for formulating an
XGBoost gradient boosting classifier
into a :external+gurobi:py:class:`Model`.
"""

import json

import numpy as np
import xgboost as xgb
from gurobipy import GRB
import gurobipy as gp

try:
    from gurobipy import nlfunc

    HAS_NLFUNC = True
except ImportError:
    HAS_NLFUNC = False

from gurobi_ml.exceptions import NoModel, NoSolution
from gurobi_ml.modeling import AbstractPredictorConstr
from gurobi_ml.modeling.decision_tree import AbstractTreeEstimator


def add_xgbclassifier_constr(
    gp_model, xgboost_classifier, input_vars, output_vars=None, epsilon=0.0, **kwargs
):
    """Formulate an XGBoost classifier into ``gp_model``.

    The booster is extracted from ``xgboost_classifier`` with
    ``get_booster()`` and handed to :class:`XGBoostClassifierConstr`
    together with all remaining arguments.

    Parameters
    ----------
    gp_model
        Gurobi model the formulation is added to.
    xgboost_classifier
        Object exposing ``get_booster()``.
    input_vars
        Input variables of the predictor.
    output_vars : optional
        Output variables of the predictor.
    epsilon : float, optional
        Value subtracted from the split thresholds of the trees.
    **kwargs
        Forwarded unchanged to :class:`XGBoostClassifierConstr`.

    Returns
    -------
    XGBoostClassifierConstr
        The object built for the formulation.
    """
    booster = xgboost_classifier.get_booster()
    return XGBoostClassifierConstr(
        gp_model, booster, input_vars, output_vars, epsilon=epsilon, **kwargs
    )


class XGBoostClassifierConstr(AbstractPredictorConstr):
    """Formulate a trained XGBoost booster as constraints of a Gurobi model.

    Every tree of the booster is embedded with an
    :class:`AbstractTreeEstimator`; the per-tree outputs are then combined
    according to the objective stored in the booster.

    Only ``gbtree`` boosters with a ``reg:logistic``, ``binary:logistic`` or
    ``reg:squarederror`` objective are handled. Anything else raises
    :class:`NoModel`.
    """

    def __init__(
        self, gp_model, xgb_classifier, input_vars, output_vars, epsilon=0.0, **kwargs
    ):
        """Store the booster and build the formulation.

        Parameters
        ----------
        gp_model
            Gurobi model the formulation is added to.
        xgb_classifier
            Booster to formulate; it must provide ``save_raw`` and ``predict``.
        input_vars
            Input variables of the predictor.
        output_vars
            Output variables of the predictor.
        epsilon : float, optional
            Value subtracted from the split thresholds of the trees and also
            passed on to every tree estimator.
        **kwargs
            Forwarded to :class:`AbstractPredictorConstr`.
        """
        self._output_shape = 1
        self.estimators_ = []
        self.xgb_classifier = xgb_classifier
        self._default_name = "xgb_clf"
        self.epsilon = epsilon
        AbstractPredictorConstr.__init__(
            self, gp_model, input_vars, output_vars, **kwargs
        )

    def _mip_model(self, **kwargs):
        """Add the formulation of the booster to the Gurobi model.

        One block of variables is created per tree and per input row, each
        tree is formulated with :class:`AbstractTreeEstimator`, and the output
        variables are finally linked to the sum of the tree outputs.

        Raises
        ------
        NoModel
            If the booster is not of type ``gbtree``, if its objective is not
            supported, or if a logistic objective is used with Gurobi < 11.
        """
        model = self.gp_model
        # The stored object is used through the Booster API (``save_raw``).
        xgb_classifier: xgb.Booster = self.xgb_classifier

        _input = self._input
        output = self._output
        nex = _input.shape[0]
        timer = AbstractPredictorConstr._ModelingTimer()
        outdim = output.shape[1]
        assert (
            outdim == 1
        ), "Output dimension of gradient boosting classifier should be 1"

        xgb_raw = json.loads(xgb_classifier.save_raw("json").decode())
        booster_type = xgb_raw["learner"]["gradient_booster"]["name"]
        if booster_type != "gbtree":
            raise NoModel(xgb_classifier, f"model not implemented for {booster_type}")
        trees = xgb_raw["learner"]["gradient_booster"]["model"]["trees"]
        n_estimators = len(trees)

        estimators = []
        if self._no_debug:
            kwargs["no_record"] = True

        # One free variable per (input row, tree) holding that tree's output.
        tree_vars = model.addMVar(
            (nex, n_estimators, 1),
            lb=-GRB.INFINITY,
            name=self._name_var("estimator"),
        )

        for i, tree in enumerate(trees):
            if self.verbose:
                self._timer.timing(f"Estimator {i}")
            # XGBoost keeps split thresholds (internal nodes) and leaf values
            # (leaves) in the same ``split_conditions`` array.
            node_values = np.array(tree["split_conditions"], dtype=np.float32)
            # Only the thresholds are shifted by epsilon; the leaf values must
            # stay exactly as stored in the booster.
            tree["threshold"] = node_values - self.epsilon
            tree["children_left"] = np.array(tree["left_children"])
            tree["children_right"] = np.array(tree["right_children"])
            tree["feature"] = np.array(tree["split_indices"])
            tree["value"] = node_values.reshape(-1, 1)
            tree["capacity"] = len(tree["split_conditions"])
            tree["n_features"] = int(tree["tree_param"]["num_feature"])

            estimators.append(
                AbstractTreeEstimator(
                    self.gp_model,
                    tree,
                    self.input,
                    tree_vars[:, i, :],
                    self.epsilon,
                    timer,
                    **kwargs,
                )
            )

        self.estimators_ = estimators

        constant = float(xgb_raw["learner"]["learner_model_param"]["base_score"])
        learning_rate = 1.0
        objective = xgb_raw["learner"]["objective"]["name"]

        if objective in ("reg:logistic", "binary:logistic"):
            if gp.gurobi.version()[0] < 11:
                raise NoModel(
                    xgb_classifier,
                    f"Option objective:{objective} only supported with Gurobi >= 11",
                )

            # NOTE(review): ``constant`` (base_score) is not added to the sum
            # of the trees in this branch - confirm this is correct for
            # boosters whose base_score differs from 0.5.
            if HAS_NLFUNC:
                model.addConstr(
                    output == nlfunc.logistic(learning_rate * tree_vars.sum(axis=1))
                )
            else:
                # Without nlfunc: introduce a variable for the sum of the
                # trees and tie it to the output with general constraints.
                affinevar = model.addMVar(output.shape, lb=-float("infinity"))
                model.addConstr(affinevar == learning_rate * tree_vars.sum(axis=1))
                for index in np.ndindex(self.output.shape):
                    self.gp_model.addGenConstrLogistic(
                        affinevar[index],
                        output[index],
                        name=self._indexed_name(index, "logistic"),
                    )
                # Count is read before update() so that the slice below
                # selects only the general constraints added just above.
                num_gc = self.gp_model.NumGenConstrs
                self.gp_model.update()
                for gen_constr in self.gp_model.getGenConstrs()[num_gc:]:
                    gen_constr.setAttr("FuncNonLinear", 1)
        elif objective == "reg:squarederror":
            model.addConstr(output == learning_rate * tree_vars.sum(axis=1) + constant)
        else:
            raise NoModel(
                xgb_classifier, f"objective type '{objective}' not implemented"
            )

    def get_error(self, eps=None):
        """Return the absolute gap between the booster and the solution.

        The booster is evaluated on ``self.input_values`` and its prediction
        is compared entry-wise with the values of the output variables.

        Parameters
        ----------
        eps : float, optional
            If given and the largest gap exceeds it, both the solution values
            and the booster prediction are printed.

        Returns
        -------
        Array of absolute differences, one row per prediction.

        Raises
        ------
        NoSolution
            If no solution is available.
        """
        if not self._has_solution:
            raise NoSolution()

        xgb_in = xgb.DMatrix(self.input_values)
        # The booster is stored as ``xgb_classifier`` by ``__init__``.
        xgb_out = self.xgb_classifier.predict(xgb_in)
        r_val = np.abs(xgb_out.reshape(-1, 1) - self.output.X)
        if eps is not None and np.max(r_val) > eps:
            print(f"{self.output.X} != {xgb_out.reshape(-1, 1)}")
        return r_val
