import numpy as np
from trees.shap_model import ShapModel
import sklearn
import matplotlib.pyplot as plt
import numpy as np





def tree_to_fourier(tree):
    """Expand a fitted scikit-learn decision tree into Fourier coefficients.

    Each split is treated as a test on one binary variable: the left child is
    the branch where the parity character of the split feature is +1 and the
    right child the branch where it is -1. Under that reading a node with
    child expansions ``L`` and ``R`` equals
    ``0.5 * (L + R) + chi_f * 0.5 * (L - R)``, which is what is accumulated
    bottom-up from the leaves.

    Parameters
    ----------
    tree : fitted scikit-learn decision tree
        Must expose ``tree_.children_left``, ``tree_.children_right``,
        ``tree_.feature`` and ``tree_.value``.

    Returns
    -------
    dict
        Maps a ``frozenset`` of feature indices (the frequency) to its
        coefficient. For a ``DecisionTreeRegressor`` the coefficients are
        scalars built from ``value[node][0][0]``; for any other estimator they
        are arrays built from ``value[node][0]`` normalised to sum to one.
    """
    # Imported here because a bare ``import sklearn`` does not guarantee that
    # the ``sklearn.tree`` submodule has been loaded.
    from sklearn.tree import DecisionTreeRegressor

    children_left = tree.tree_.children_left
    children_right = tree.tree_.children_right
    feature = tree.tree_.feature
    value = tree.tree_.value
    # The estimator type cannot change between leaves, so decide it once.
    is_regressor = isinstance(tree, DecisionTreeRegressor)

    def add_coefficient(spectrum, freq, coefficient):
        # Accumulate instead of assigning so that two contributions landing on
        # the same frequency are summed rather than overwriting each other.
        # Not done in place: classifier coefficients are numpy arrays that may
        # still be referenced by a child spectrum.
        if freq in spectrum:
            spectrum[freq] = spectrum[freq] + coefficient
        else:
            spectrum[freq] = coefficient

    def fourier_tree_sum(left_fourier, right_fourier, split_feature):
        """Combine the spectra of two children under a split on split_feature."""
        split_set = frozenset({split_feature})
        final_fourier = {}
        for freq in set().union(left_fourier, right_fourier):
            left_coef = left_fourier.get(freq, 0)
            right_coef = right_fourier.get(freq, 0)
            add_coefficient(final_fourier, freq, 0.5 * (left_coef + right_coef))
            # chi_f * chi_S == chi_(S xor {f}): symmetric difference keeps the
            # result correct even if the feature already occurs in ``freq``.
            add_coefficient(final_fourier, freq ^ split_set,
                            0.5 * left_coef - 0.5 * right_coef)
        return final_fourier

    def dfs(node_id):
        # scikit-learn marks a leaf by giving it identical child ids.
        if children_left[node_id] != children_right[node_id]:
            left_fourier = dfs(children_left[node_id])
            right_fourier = dfs(children_right[node_id])
            return fourier_tree_sum(left_fourier, right_fourier, feature[node_id])

        # A leaf is a constant function: only the empty frequency is non-zero.
        leaf_value = value[node_id][0]
        if is_regressor:
            return {frozenset(): leaf_value[0]}
        return {frozenset(): leaf_value[:] / sum(leaf_value)}

    return dfs(0)


def decision_path(X, tree):
    """Mark, per sample, the features whose split threshold was exceeded.

    Parameters
    ----------
    X : array indexable as ``X[row, column]`` and exposing ``.shape``
        Samples to route through the tree.
    tree : fitted tree estimator
        Must provide ``tree_.feature``, ``tree_.threshold``,
        ``decision_path(X)`` (result exposing ``indices`` / ``indptr``) and
        ``apply(X)``.

    Returns
    -------
    numpy.ndarray
        ``int32`` array with the shape of ``X``. Entry ``[i, f]`` is 1 when,
        at some split node on sample ``i``'s path that tests feature ``f``,
        ``X[i, f]`` is strictly greater than that node's threshold; all other
        entries are 0.
    """
    split_feature = tree.tree_.feature
    split_threshold = tree.tree_.threshold
    exceeded = np.zeros(X.shape, dtype=np.int32)
    paths = tree.decision_path(X)
    leaves = tree.apply(X)

    # Row ``r`` of the path indicator stores its visited node ids in
    # ``indices[indptr[r]:indptr[r + 1]]``.
    row_bounds = zip(paths.indptr[:-1], paths.indptr[1:])
    for row, (start, stop) in enumerate(row_bounds):
        for node in paths.indices[start:stop]:
            if node == leaves[row]:
                # The leaf itself performs no test.
                continue
            column = split_feature[node]
            if X[row, column] > split_threshold[node]:
                exceeded[row, column] = 1

    return exceeded

# TODO: Parallelize the per-tree conversion across multiple cores
def forest_to_fourier(forest):
    """Convert every tree of a fitted ensemble with ``tree_to_fourier``.

    Parameters
    ----------
    forest : fitted ensemble exposing ``estimators_``
        Each element of ``estimators_`` is passed to ``tree_to_fourier``.

    Returns
    -------
    list
        One ``tree_to_fourier`` result per estimator, in ``estimators_`` order.
    """
    return [tree_to_fourier(estimator) for estimator in forest.estimators_]


if __name__ == "__main__":
    # Manual smoke run: build a model, trace two synthetic samples through its
    # first tree, then convert the whole forest.
    # NOTE(review): the meaning of ShapModel's arguments and the 101-column
    # width are not visible in this file - confirm against ShapModel.
    model = ShapModel("crimes", 10, 10)
    probe = np.zeros((2, 101))
    # One very large entry per sample, in columns 43 and 42 respectively.
    probe[[0, 1], [43, 42]] = 1000000
    path = decision_path(probe, model.rf.estimators_[0])
    fourier_all = forest_to_fourier(model.rf)