import os
import sys

import numpy as np
import pandas as pd
import shap

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import convert_boolean_and_encode


class KernelSHAP:
    """
    SHAP-based estimator using KernelExplainer.

    Args:
        data: DataFrame holding the feature columns plus the ``target``
            column. The feature columns are handed to the explainer as its
            background set.
        explicand: Instance(s) to explain. Either a Series describing one row
            (indexed by column name) or a DataFrame with one row per
            instance. Entries that are not features (e.g. the target) are
            ignored.
        target: Name of the target column in ``data``.
        model: Fitted model exposing ``predict_proba``.
        S_size: Forwarded to ``shap_values`` as ``nsamples``.
    """

    def __init__(self, data, explicand, target, model, S_size):
        self.data = data
        self.target = target
        self.features = data.drop(columns=[target]).columns.tolist()
        # A Series must become a single *row*; a bare ``to_frame()`` would
        # yield an (n_features, 1) column that the explainer reads as
        # n_features separate one-feature instances.
        self.explicand = self._as_feature_frame(explicand, self.features)
        self.m = S_size
        self.model = model
        # NOTE: the whole feature frame is used as background; an earlier
        # variant subsampled it with ``shap.sample(..., 100)``.
        self.explainer = shap.KernelExplainer(
            self.model.predict_proba,
            self.data[self.features],
        )

    def __call__(self):
        """
        Compute SHAP values for the explicand.

        Returns:
            dict mapping each feature name to a list with one entry per model
            output (class); each entry is a list holding one SHAP value per
            explicand row.
        """
        shap_values = self.explainer.shap_values(self.explicand, nsamples=self.m)
        return self._to_feature_dict(shap_values, self.features)

    @staticmethod
    def _as_feature_frame(explicand, features):
        """Return ``explicand`` as a DataFrame restricted to ``features``."""
        if isinstance(explicand, pd.Series):
            explicand = explicand.to_frame().T
        return explicand[features]

    @staticmethod
    def _to_feature_dict(shap_values, features):
        """
        Reshape explainer output into ``{feature: [per-class value lists]}``.

        Accepted layouts:
            * list of per-class arrays, each ``(n_samples, n_features)`` or
              ``(n_features,)``;
            * a single ``(n_samples, n_features)`` array (one model output);
            * a single ``(n_samples, n_features, n_classes)`` array.

        Raises:
            ValueError: if the feature axis does not match ``features``.
        """
        if isinstance(shap_values, list):
            per_class = np.stack(
                [np.atleast_2d(np.asarray(values)) for values in shap_values]
            )
        else:
            values = np.asarray(shap_values)
            if values.ndim == 3:
                # Bring the class axis to the front.
                per_class = np.moveaxis(values, -1, 0)
            else:
                per_class = np.atleast_2d(values)[np.newaxis]

        if per_class.ndim != 3 or per_class.shape[2] != len(features):
            raise ValueError(
                f"Unexpected SHAP value shape {per_class.shape} "
                f"for {len(features)} features"
            )

        n_classes = per_class.shape[0]
        return {
            feature: [list(per_class[c, :, j]) for c in range(n_classes)]
            for j, feature in enumerate(features)
        }


class TreeSHAP:
    """
    SHAP-based estimator using TreeExplainer with interventional feature perturbation.

    Args:
        data: Background DataFrame. With ``model_type="autoxgb"`` the
            ``target`` column is dropped and the remainder is passed through
            the wrapper's feature transformation; otherwise it is used as
            given.
        explicand: DataFrame of instance(s) to explain, preprocessed the same
            way as ``data``.
        target: Name of the target column.
        model: With ``model_type="autoxgb"``, a wrapper exposing
            ``model.transform_features`` and ``model._trainer.load_model``;
            otherwise a tree model accepted by ``shap.TreeExplainer``.
        model_type: ``"autoxgb"`` enables the wrapper-specific preprocessing
            and model extraction; any other value uses the inputs unchanged.
    """

    def __init__(self, data, explicand, target, model, model_type="autoxgb"):
        # ``_get_data`` reads ``self.target``, so it must be set first.
        self.target = target
        is_autoxgb = model_type == "autoxgb"
        self.explicand = self._get_data(explicand, model) if is_autoxgb else explicand
        self.data = self._get_data(data, model) if is_autoxgb else data
        self.model = self._get_model(model) if is_autoxgb else model
        # Feature names follow the columns the explainer actually sees;
        # ``__call__`` needs them to label the SHAP values.
        self.features = list(self.data.columns)
        self.explainer = shap.TreeExplainer(
            self.model, data=self.data, feature_perturbation="interventional"
        )

    def __call__(self):
        """
        Compute SHAP values for the explicand.

        Returns:
            dict mapping each feature name to a list with one entry per model
            output (class); each entry is a list holding one SHAP value per
            explicand row.
        """
        # Match the explicand's column dtypes to the background data. The
        # cast result is rebound rather than written column-by-column so a
        # caller-owned frame is never modified in place.
        self.explicand = self.explicand.astype(
            {col: self.data[col].dtype for col in self.explicand.columns}
        )

        shap_values = self.explainer.shap_values(self.explicand)
        return self._to_feature_dict(shap_values, self.features)

    @staticmethod
    def _to_feature_dict(shap_values, features):
        """
        Reshape explainer output into ``{feature: [per-class value lists]}``.

        Accepted layouts:
            * list of per-class arrays, each ``(n_samples, n_features)`` or
              ``(n_features,)``;
            * a single ``(n_samples, n_features)`` array (one model output);
            * a single ``(n_samples, n_features, n_classes)`` array.

        Raises:
            ValueError: if the feature axis does not match ``features``.
        """
        if isinstance(shap_values, list):
            per_class = np.stack(
                [np.atleast_2d(np.asarray(values)) for values in shap_values]
            )
        else:
            values = np.asarray(shap_values)
            if values.ndim == 3:
                # Bring the class axis to the front.
                per_class = np.moveaxis(values, -1, 0)
            else:
                per_class = np.atleast_2d(values)[np.newaxis]

        if per_class.ndim != 3 or per_class.shape[2] != len(features):
            raise ValueError(
                f"Unexpected SHAP value shape {per_class.shape} "
                f"for {len(features)} features"
            )

        n_classes = per_class.shape[0]
        return {
            feature: [list(per_class[c, :, j]) for c in range(n_classes)]
            for j, feature in enumerate(features)
        }

    def _get_data(self, data, xgbmodel):
        """Drop the target column and apply the wrapper's XGBoost feature transform."""
        data_transformed = xgbmodel.model.transform_features(
            data=data.drop(columns=[self.target]), model="XGBoost"
        )
        return convert_boolean_and_encode(data_transformed)

    def _get_model(self, model):
        """Load the wrapped XGBoost model and switch on ``enable_categorical``."""
        xgbmodel = model.model._trainer.load_model("XGBoost")
        model_params = xgbmodel.model.get_params()
        model_params["enable_categorical"] = True
        xgbmodel.model.set_params(**model_params)

        return xgbmodel.model
