import numpy as np
from linear_tree_shap.utils import copy_tree, get_N


def psi(E, D_power, D, q, Ns, d):
    """Return a 1/d-scaled weighted sum over the first ``d`` entries.

    Computes ``sum_{i < d} (E[i] * D_power[i] / (D[i] + q)) * Ns[d, i] / d``.
    Row ``d`` of ``Ns`` supplies the weights (``TreeExplainer`` builds ``Ns``
    with ``get_N``). ``d`` must be positive; ``d == 0`` divides by zero.
    """
    # Form the full element-wise term first, then truncate, so scalar or
    # broadcastable operands behave exactly as before.
    terms = E * D_power / (D + q)
    weights = Ns[d, :d]
    return np.dot(terms[:d], weights) / d

def inference(tree, x, activation, result, Base, Offset, Ns, C, E, node=0, edge_feature=-1, depth=0):
    """Recursively accumulate per-feature contributions for one sample.

    Walks the subtree rooted at ``node`` depth-first and adds each edge's
    contribution into ``result[edge_feature]`` in place. Returns None.

    Parameters
    ----------
    tree : object exposing ``children_left``, ``children_right``, ``parents``,
        ``features``, ``thresholds``, ``weights``, ``edge_heights`` and
        ``leaf_predictions`` arrays indexed by node id (built by ``copy_tree``
        in ``TreeExplainer``). A negative ``children_left`` marks a leaf.
    x : feature vector of the sample being explained.
    activation : boolean array with one entry per node; overwritten here.
    result : per-feature accumulator, updated in place.
    Base : array of base points (``TreeExplainer`` obtains it from
        ``base_func``, Chebyshev points by default).
    Offset : table indexed as ``Offset[k]``; ``TreeExplainer`` builds it so
        that ``Offset[k] == (Base + 1) ** k``.
    Ns : coefficient table forwarded to ``psi``.
    C, E : scratch arrays with one row per depth; only row ``depth`` (and
        deeper rows, via recursion) are written. ``C[0]`` must be initialised
        by the caller (``py_shap_values`` sets it to ones).
    node : node id to process (0, the root, by default).
    edge_feature : feature of the edge entering ``node``; -1 for the root call.
    depth : recursion depth of ``node``; selects the row of ``C``/``E``.
    """
    left, right, parent, child_edge_feature = (
                            tree.children_left[node], 
                            tree.children_right[node], 
                            tree.parents[node], 
                            tree.features[node]
                            )
    # For a leaf (left == -1) or when parent == -1 these lookups hit index -1;
    # the resulting values are only used in the branches guarded below.
    left_height, right_height, parent_height, current_height = ( 
                            tree.edge_heights[left], 
                            tree.edge_heights[right], 
                            tree.edge_heights[parent], 
                            tree.edge_heights[node]
                            )
    # Internal node: record which child the sample is routed to.
    if left >= 0:
        if x[child_edge_feature] <= tree.thresholds[node]:
            activation[left], activation[right] = True, False
        else:
            activation[left], activation[right] = False, True
    # Only nodes reached through an edge (not the initial root call) contribute.
    if edge_feature >= 0:
        # NOTE(review): tree.parents presumably holds the nearest ancestor that
        # splits on the same feature (as built by copy_tree), not necessarily
        # the direct parent — confirm. The edge stays active only if that
        # node is active too.
        if parent >= 0:
            activation[node] &= activation[parent]

        if activation[node]:
            q_eff = 1./tree.weights[node]
        else:
            q_eff = 0. 
        # Extend the previous depth's row by the factor (Base + q_eff).
        C[depth] = C[depth-1]*(Base+q_eff)

        # Divide out the factor contributed by tree.parents[node]; presumably
        # so a feature split on repeatedly is not double-counted — confirm.
        if parent >= 0:
            if activation[parent]:
                s_eff = 1./tree.weights[parent]
            else:
                s_eff = 0.
            C[depth] = C[depth]/(Base+s_eff)
    if left < 0:
        E[depth] = C[depth]*tree.leaf_predictions[node]
    else:
        # The right-child call below overwrites E[depth+1], so the left
        # child's result must be folded into E[depth] before it runs.
        # Offset[h] == (Base + 1) ** h rescales a child's row by the height
        # difference between this edge and the child's edge.
        inference(tree, x, activation, result, Base, Offset, Ns, C, E, left, child_edge_feature, depth+1)
        E[depth] = E[depth+1]*Offset[current_height-left_height]
        inference(tree, x, activation, result, Base, Offset, Ns, C, E, right, child_edge_feature, depth+1)
        E[depth] += E[depth+1]*Offset[current_height-right_height]
    # Add this edge's contribution; q_eff/s_eff were bound above under the
    # same edge_feature/parent conditions.
    if edge_feature >= 0:
        value = (q_eff-1)*psi(E[depth], Offset[0], Base, q_eff, Ns, current_height)
        result[edge_feature] += value
        # Remove the part attributed to tree.parents[node], evaluated at its height.
        if parent >= 0:
            value = (s_eff-1)*psi(E[depth], Offset[parent_height-current_height], Base, s_eff, Ns, parent_height)
            result[edge_feature] -= value

class TreeExplainer:
    """Per-feature attribution explainer for a single fitted decision tree.

    ``clf`` must expose a ``tree_`` attribute (e.g. a fitted scikit-learn
    tree estimator); it is converted once with ``copy_tree``.
    """

    def __init__(self, clf, base_func=np.polynomial.chebyshev.chebpts2):
        """Pre-compute the tables shared by every explanation.

        Parameters
        ----------
        clf : fitted tree model exposing ``tree_``.
        base_func : callable mapping ``max_depth`` to the array of base points.
            NOTE(review): the default ``chebpts2`` raises ``ValueError`` for
            fewer than 2 points, so a depth-1 tree fails with the default —
            confirm whether stumps need support.
        """
        self.clf = clf
        self.tree = copy_tree(clf.tree_)
        self.Base = base_func(self.tree.max_depth)
        # Offset[k] == (Base + 1) ** k: np.vander lists decreasing powers per
        # column; transposing and reversing the rows gives increasing powers.
        self.Offset = np.vander(self.Base+1).T[::-1]
        self.N = get_N(self.Base)

    def py_shap_values(self, x):
        """Explain one sample with the pure-Python recursion.

        Parameters
        ----------
        x : array_like, 1-D
            Feature vector of the sample.

        Returns
        -------
        numpy.ndarray
            float64 per-feature values with the same shape as ``x``.
        """
        x = np.asarray(x)
        activation = np.zeros_like(self.tree.children_left, dtype=bool)
        C = np.zeros((self.tree.max_depth+1, self.tree.max_depth))
        E = np.zeros((self.tree.max_depth+1, self.tree.max_depth))
        C[0, :] = 1
        # Always accumulate in float64 (as shap_values does): np.zeros_like(x)
        # would inherit an integer dtype from integer input and silently
        # truncate the fractional contributions added by inference().
        result = np.zeros(x.shape, dtype=np.float64)
        inference(self.tree, x.astype(np.float32), activation, result, self.Base, self.Offset, self.N, C, E)
        return result

    def shap_values(self, X):
        """Explain a batch of samples with the compiled extension.

        Parameters
        ----------
        X : array_like, 2-D
            One sample per row.

        Returns
        -------
        numpy.ndarray
            float64 array with the same shape as ``X``.
        """
        from linear_tree_shap import _cext
        # Pass plain C-contiguous buffers: astype / zeros_like keep X's memory
        # layout (e.g. Fortran order), and the extension's stride handling is
        # not visible here.
        X = np.ascontiguousarray(X, dtype=np.float32)
        V = np.zeros(X.shape, dtype=np.float64)
        _cext.linear_tree_shap(
                               self.tree.weights,
                               self.tree.leaf_predictions,
                               self.tree.thresholds,
                               self.tree.parents.astype(np.int32),
                               self.tree.edge_heights.astype(np.int32),
                               self.tree.features.astype(np.int32),
                               self.tree.children_left.astype(np.int32),
                               self.tree.children_right.astype(np.int32),
                               self.tree.max_depth,
                               self.tree.num_nodes,
                               self.Base, self.Offset, self.N, X, V)
        return V


if __name__ == "__main__":
    # Sanity check: the pure-Python recursion must agree with the reference
    # shap.TreeExplainer on a small synthetic regression tree.
    from sklearn.datasets import make_regression
    from sklearn.tree import DecisionTreeRegressor
    from shap import TreeExplainer as Truth

    np.random.seed(10)
    x, y = make_regression(1000, n_features=10)
    clf = DecisionTreeRegressor(max_depth=6).fit(x, y)
    sim = Truth(clf)
    mine = TreeExplainer(clf)
    # Check several rows (including row 4, the one originally checked)
    # rather than a single sample.
    for row in x[:5]:
        a = mine.py_shap_values(row)
        b = sim.shap_values(row[None, :])[0]
        np.testing.assert_array_almost_equal(a, b)
