import numpy as np


def get_decision_path_features(estimator, data):
    """Record, per sample, which features its decision path tests and which way it went.

    For every sample, the nodes on its path through ``estimator``'s tree are
    visited in the order given by ``decision_path``; the leaf is skipped. For
    each tested feature only the first test encountered (the one higher up the
    tree) is kept:

    * ``-1.0`` -- the sample went left (``value <= threshold``);
    * ``+1.0`` -- the sample went right (``value > threshold``);
    * ``0.0``  -- the feature is not tested on the sample's path.

    Args:
        estimator: A fitted single decision tree exposing ``tree_.feature``,
            ``tree_.threshold``, ``decision_path`` (CSR-like result with
            ``indices``/``indptr``) and ``apply`` -- presumably a scikit-learn
            tree; confirm for other estimators.
        data: 2-D array-like of shape (n_samples, n_features). Lists and
            DataFrames are accepted; ``data`` itself is passed unchanged to the
            estimator, while comparisons use a positional ndarray view.

    Returns:
        ``np.ndarray`` of shape (n_samples, n_features) with values in {-1, 0, 1}.
    """
    node_indicator = estimator.decision_path(data)
    leaf_ids = estimator.apply(data)
    tested_feature = estimator.tree_.feature
    threshold = estimator.tree_.threshold

    # ``data[i, j]`` only works on ndarrays; lists and DataFrames need a
    # positional array view for the threshold comparisons.
    values = np.asarray(data)
    features_used = np.zeros((values.shape[0], values.shape[1]))

    for sample_id in range(values.shape[0]):
        start = node_indicator.indptr[sample_id]
        stop = node_indicator.indptr[sample_id + 1]
        for node_id in node_indicator.indices[start:stop]:
            # Leaves hold no split test.
            if node_id == leaf_ids[sample_id]:
                continue

            feature_index = tested_feature[node_id]
            # Keep the decision made higher up the tree.
            if features_used[sample_id, feature_index] != 0.0:
                continue

            went_left = values[sample_id, feature_index] <= threshold[node_id]
            features_used[sample_id, feature_index] = -1.0 if went_left else 1.0

    return features_used


class ExplanationGenerator:
    """Compute node-importance explanations for a layered, state-based graph model.

    ``stone_age_model`` is duck-typed. From the way it is indexed below
    (NOTE(review): assumptions inferred from usage -- confirm against the model):

    * ``outputs[layer][node]`` is the integer state of ``node`` at ``layer``;
    * ``inputs[layer][node]`` is the feature vector of ``node``'s tree at ``layer``;
    * ``features_used[layer][node][state]`` is a 1-D vector of signed feature
      usage (nonzero = used), laid out as ``[message features (state_size) |
      own-state features (state_size) | optional pairwise features]``;
    * ``state_neighbours[layer][node][state]`` divides message contributions,
      presumably the number of neighbours of ``node`` in ``state``;
    * ``explanation_forward_v2[layer]`` is indexed ``[node, state]`` to give a
      length-``num_nodes`` importance vector (cf. ``forward_explain``).

    ``edge_index`` holds directed edges ``edge_index[0][k] -> edge_index[1][k]``.
    """

    def __init__(self, x, edge_index, stone_age_model):
        super().__init__()
        self.x = x  # stored for callers; not read by the methods of this class
        self.edge_index = edge_index
        self.layers_input = stone_age_model.inputs
        self.layers_output = stone_age_model.outputs
        self.model = stone_age_model
        self.state_size = stone_age_model.state_size
        self.contributions = stone_age_model.contributions
        self.features_used = stone_age_model.features_used
        # Memo for backpropagate_explanation: key -> (importance sums, credit counts).
        self.backprop_cache = {}

    def _average_importance(self, sums, counts):
        """Divide each node's summed importance by its credit count (at least 1)."""
        return [sums[i] / max(1, counts[i]) for i in range(self.model.num_nodes)]

    @staticmethod
    def _signed_min_max_normalize(values):
        """Min-max scale the magnitudes of ``values`` into [0, 1], then restore signs.

        When all magnitudes are equal they are left unscaled (times their sign).
        """
        signs = np.array([0 if x == 0 else (1 if x > 0 else -1) for x in values])
        magnitudes = np.array([abs(x) for x in values])
        high, low = max(magnitudes), min(magnitudes)
        normalized = magnitudes
        if high - low != 0:
            normalized = np.array((magnitudes - low) / (high - low))
        return normalized * signs

    def backpropagate_explanation(self, node_id, importance, layer_index, state=0, return_avg=True):
        """Distribute ``importance`` of ``node_id`` being in ``state`` back to layer 0.

        Recursively follows the features the node's tree used at ``layer_index``:
        its own state at the previous layer and the previous-layer states of its
        incoming neighbours (messages, split evenly over neighbours in that state).
        At layer 0 the importance is credited to the node itself.

        Args:
            node_id: Node whose decision is explained.
            importance: Signed weight to distribute.
            layer_index: Layer whose decision is explained.
            state: State of ``node_id`` at ``layer_index``.
            return_avg: If true, return per-node importance divided by the number
                of times each node was credited (at least 1); otherwise return the
                raw ``(sums, counts)`` arrays.

        Returns:
            A length-``num_nodes`` sequence of averaged importances, or the
            ``(sums, counts)`` tuple when ``return_avg`` is false.
        """
        cache_key = f'l{layer_index}-n{node_id}-s{state}-i{importance}'
        if cache_key in self.backprop_cache:
            cached_sums, cached_counts = self.backprop_cache[cache_key]
            if return_avg:
                return self._average_importance(cached_sums, cached_counts)
            return cached_sums, cached_counts

        aggregated_importance = np.zeros(self.model.num_nodes)
        aggregated_importance_counts = np.zeros(self.model.num_nodes)

        # Nodes outside the explained subset contribute nothing.
        if self.model.subset is not None and node_id not in self.model.subset:
            if return_avg:
                return aggregated_importance
            return aggregated_importance, aggregated_importance_counts

        # Base case: credit the node itself.
        if layer_index == 0:
            aggregated_importance[node_id] = importance
            aggregated_importance_counts[node_id] = 1
            if return_avg:
                return aggregated_importance
            return aggregated_importance, aggregated_importance_counts

        state_size = self.model.state_size
        features_used = self.model.features_used[layer_index][node_id][state]
        # No feature used: pass the importance through to the node's own state at
        # the previous layer. Test for *any* nonzero entry -- usage values are
        # signed, so e.g. [+1, -1] sums to zero without being empty.
        if not np.any(features_used):
            current_node_state = self.model.outputs[layer_index - 1][node_id]
            return self.backpropagate_explanation(
                node_id, importance, layer_index - 1, state=current_node_state, return_avg=return_avg)

        # Same layout as forward_explain: [messages | own state | pairwise...].
        # Slicing by state_size keeps the split right when pairwise features are
        # appended; those pairwise features are not backpropagated here.
        features_used_messages = features_used[:state_size]
        features_used_state = features_used[state_size:2 * state_size]

        for prev_state, weight in enumerate(features_used_state):
            if abs(weight) > 0.0:
                sums, counts = self.backpropagate_explanation(
                    node_id, weight * importance, layer_index - 1, state=prev_state, return_avg=False)
                aggregated_importance += sums
                aggregated_importance_counts += counts

        for message_from, message_to in zip(self.edge_index[0], self.edge_index[1]):
            if message_to != node_id:
                continue

            from_node_state = self.model.outputs[layer_index - 1][message_from]
            weight = features_used_messages[from_node_state]
            if abs(weight) > 0.0:
                num_neighbours_for_state = self.model.state_neighbours[layer_index][node_id][from_node_state]
                sums, counts = self.backpropagate_explanation(
                    message_from, (weight * importance) / num_neighbours_for_state, layer_index - 1,
                    state=from_node_state, return_avg=False)
                aggregated_importance += sums
                aggregated_importance_counts += counts

        self.backprop_cache[cache_key] = (aggregated_importance, aggregated_importance_counts)

        if return_avg:
            return self._average_importance(aggregated_importance, aggregated_importance_counts)
        return aggregated_importance, aggregated_importance_counts

    def forward_explain(self, layer_index):
        """Propagate explanations one layer forward.

        Returns an array of shape (num_nodes, state_size, num_nodes). At layer 0,
        entry ``[v, :, v]`` is 1 (each node explains itself). For later layers the
        result starts from ``explanation_forward_v2[layer_index - 1]`` and adds,
        for each (node, state), the previous-layer explanations of the features
        its tree used. The previous layer's array is read but never modified.
        """
        num_nodes = self.model.num_nodes
        state_size = self.model.state_size

        if layer_index == 0:
            explanations = np.zeros((num_nodes, state_size, num_nodes))
            for i in range(num_nodes):
                explanations[i, :, i] = 1.0
            return explanations

        previous = self.model.explanation_forward_v2[layer_index - 1]
        # NOTE(review): presumably a depth-1 tree makes no informative split.
        if self.model.tree_depths[layer_index] == 1:
            return np.copy(previous)

        # Work on a copy: updating the previous layer's array in place would
        # corrupt the stored history and feed already-updated values back into
        # the reads below.
        explanations = np.copy(previous)

        # Contributions from the node's own previous state.
        for node_id in range(num_nodes):
            for state in range(state_size):
                features_used_state = self.model.features_used[layer_index][node_id][state][
                                      state_size:2 * state_size]
                for prev_state, weight in enumerate(features_used_state):
                    if abs(weight) > 0.0:
                        explanations[node_id, state] += previous[node_id, prev_state] * weight

        # Contributions from incoming messages, split over neighbours in that state.
        for message_from, node_id in zip(self.edge_index[0], self.edge_index[1]):
            from_node_state = self.model.outputs[layer_index - 1][message_from]
            for state in range(state_size):
                features_used_messages = self.model.features_used[layer_index][node_id][state][:state_size]
                weight = features_used_messages[from_node_state]
                if abs(weight) > 0.0:
                    num_neighbours_for_state = self.model.state_neighbours[layer_index][node_id][from_node_state]
                    explanations[node_id, state] += \
                        previous[message_from, from_node_state] * (1 / num_neighbours_for_state) * weight

        # Contributions from pairwise (linear-combination) features, if present.
        # Entry a * state_size + b pairs states a and b; its sign flips when the
        # corresponding input is 0.
        for message_from, node_id in zip(self.edge_index[0], self.edge_index[1]):
            from_node_state = self.model.outputs[layer_index - 1][message_from]
            for state in range(state_size):
                features_used = self.model.features_used[layer_index][node_id][state]
                if len(features_used) <= state_size * 2:
                    continue
                features_used_linear_combination = features_used[2 * state_size:]
                for lin_comb_feat_index in range(state_size):
                    num_neighbours_for_states = (
                        self.model.state_neighbours[layer_index][node_id][from_node_state]
                        + self.model.state_neighbours[layer_index][node_id][lin_comb_feat_index])
                    feature_index_larger = from_node_state * state_size + lin_comb_feat_index
                    feature_index_smaller = lin_comb_feat_index * state_size + from_node_state

                    if abs(features_used_linear_combination[feature_index_larger]) > 0:
                        input_value = self.model.inputs[layer_index][node_id][2 * state_size + feature_index_larger]
                        negative_multiplier = -1.0 if input_value == 0 else 1.0
                        explanations[node_id, state] += \
                            previous[message_from, from_node_state] * (1 / num_neighbours_for_states) * \
                            features_used_linear_combination[feature_index_larger] * negative_multiplier

                    if abs(features_used_linear_combination[feature_index_smaller]) > 0:
                        input_value = self.model.inputs[layer_index][node_id][2 * state_size + feature_index_smaller]
                        negative_multiplier = 1.0 if input_value == 0 else -1.0
                        explanations[node_id, state] += \
                            previous[message_from, from_node_state] * (1 / num_neighbours_for_states) * \
                            features_used_linear_combination[feature_index_smaller] * negative_multiplier

        return explanations

    def get_forward_explanation(self):
        """Combine forward explanations into one importance row per node.

        For each node, the final layer's usage vector for the node's final state
        is cut into per-layer chunks of ``state_size``. Each nonzero entry ``i``
        of chunk ``l`` adds ``explanation_forward_v2[l][node, i]`` times that
        entry.

        Returns:
            Array of shape (num_nodes, num_nodes); row ``v`` explains node ``v``.
        """
        state_size = self.model.state_size
        explanations = np.zeros((self.model.num_nodes, self.model.num_nodes))

        for node_id in range(self.model.num_nodes):
            current_node_state = self.model.outputs[-1][node_id]
            for layer_index in range(len(self.model.features_used) - 1):
                layer_features_used = self.model.features_used[-1][node_id][current_node_state][
                                      layer_index * state_size:(layer_index + 1) * state_size]
                for i, weight in enumerate(layer_features_used):
                    if abs(weight) > 0:
                        explanations[node_id] += self.model.explanation_forward_v2[layer_index][node_id, i] * weight

        return explanations

    def get_forward_explanation_gc(self, skip_connection=False):
        """Graph-level forward explanation read off node/graph index 0's final decision.

        With ``skip_connection`` the usage vector is restricted to the pooling
        tree's decision path for ``inputs[-1]``. Per-layer state features credit
        every node in that state at that layer, split evenly. Pairwise features
        credit both paired states with opposite signs. Without it, the usage
        vector is applied to ``explanation_forward_v2[-2]`` for every node.

        Returns:
            Array of shape (num_nodes, num_nodes) whose rows are all the same
            graph-level importance vector.
        """
        state_size = self.model.state_size
        num_nodes = self.model.num_nodes
        explanations = np.zeros((num_nodes, ))

        current_node_state = self.model.outputs[-1][0]
        if skip_connection:
            # get_decision_path_features marks tested features with -1/+1 and
            # others with 0, so this product zeroes features off the path.
            decision_path_features = get_decision_path_features(self.model.pooling.tree, self.model.inputs[-1])[0]
            features_used = self.model.features_used[-1][0][current_node_state]
            features_used_path_combined = np.multiply(decision_path_features, features_used)

            for layer_index in range(len(self.model.features_used) - 1):
                layer_features_used = features_used_path_combined[
                                      layer_index * state_size:(layer_index + 1) * state_size]
                for i, weight in enumerate(layer_features_used):
                    if abs(weight) > 0:
                        nodes_in_state = (self.model.outputs[layer_index] == i).sum()
                        for node_id in range(num_nodes):
                            if self.model.outputs[layer_index][node_id] != i:
                                continue
                            explanations += self.model.explanation_forward_v2[layer_index][node_id, i] * \
                                            (weight / nodes_in_state)

            num_pooling_features = state_size * (len(self.model.features_used) - 1)
            features_used_linear_combination = features_used_path_combined[num_pooling_features:]

            for i1 in range(num_pooling_features):
                for i2 in range(num_pooling_features):
                    feature_index = i1 * num_pooling_features + i2
                    feat_value = features_used_linear_combination[feature_index]
                    # Unused pairwise feature (zero, or not comparable such as NaN).
                    if not abs(feat_value) > 0:
                        continue
                    layer_index_i1, state_i1 = divmod(i1, state_size)
                    layer_index_i2, state_i2 = divmod(i2, state_size)
                    num_nodes_for_states = (self.model.outputs[layer_index_i1] == state_i1).sum() + \
                                           (self.model.outputs[layer_index_i2] == state_i2).sum()
                    pooled_input = self.model.inputs[-1][0][num_pooling_features + feature_index]
                    negative_multiplier = -1.0 if pooled_input == 0 else 1.0

                    for node_id in range(num_nodes):
                        if self.model.outputs[layer_index_i1][node_id] == state_i1:
                            explanations += self.model.explanation_forward_v2[layer_index_i1][node_id, state_i1] * \
                                            feat_value * negative_multiplier * (1 / num_nodes_for_states)
                        if self.model.outputs[layer_index_i2][node_id] == state_i2:
                            explanations += self.model.explanation_forward_v2[layer_index_i2][node_id, state_i2] * \
                                            feat_value * negative_multiplier * -1 * (1 / num_nodes_for_states)
        else:
            layer_features_used = self.model.features_used[-1][0][current_node_state]
            for i, weight in enumerate(layer_features_used):
                if abs(weight) > 0:
                    # NOTE: unlike get_explanation_gc, every node is credited here,
                    # not only the nodes in state i at the second-to-last layer.
                    for node_id in range(num_nodes):
                        explanations += self.model.explanation_forward_v2[-2][node_id, i] * weight

        return np.array([explanations for _ in range(num_nodes)])

    def get_explanation(self, node_id, importance):
        """Backpropagation-based explanation for a single node's final decision.

        Each nonzero entry of the final layer's usage vector (chunked per layer)
        is backpropagated from that layer. The summed per-node importances are
        then min-max scaled by magnitude with their signs kept.

        Returns:
            Length-``num_nodes`` array; all zeros if ``node_id`` is outside
            ``model.subset``.
        """
        aggregated_importance = np.zeros(self.model.num_nodes)

        if self.model.subset is not None and node_id not in self.model.subset:
            return aggregated_importance

        state_size = self.model.state_size
        current_node_state = self.model.outputs[-1][node_id]
        for layer_index in range(len(self.model.features_used) - 1):
            layer_features_used = self.model.features_used[-1][node_id][current_node_state][
                                  layer_index * state_size:(layer_index + 1) * state_size]
            for i, weight in enumerate(layer_features_used):
                if abs(weight) > 0:
                    sums, _ = self.backpropagate_explanation(
                        node_id, weight * importance, layer_index, state=i, return_avg=False)
                    aggregated_importance += sums

        return self._signed_min_max_normalize(aggregated_importance)

    def get_explanation_gc(self, importance, skip_connection=True):
        """Backpropagation-based graph-level explanation from index 0's final decision.

        With ``skip_connection`` each per-layer usage entry ``i`` is
        backpropagated from every node in state ``i`` at that layer. Otherwise
        the whole usage vector is backpropagated from the second-to-last layer,
        for the nodes in the matching state there. The result is min-max scaled
        by magnitude with signs kept.

        Returns:
            Length-``num_nodes`` array.
        """
        num_nodes = self.model.num_nodes
        state_size = self.model.state_size
        aggregated_importance = np.zeros(num_nodes)

        current_node_state = self.model.outputs[-1][0]
        if skip_connection:
            for layer_index in range(len(self.model.features_used) - 1):
                layer_features_used = self.model.features_used[-1][0][current_node_state][
                                      layer_index * state_size:(layer_index + 1) * state_size]
                for i, weight in enumerate(layer_features_used):
                    if abs(weight) > 0:
                        for node_id in range(num_nodes):
                            if self.model.outputs[layer_index][node_id] != i:
                                continue
                            sums, _ = self.backpropagate_explanation(
                                node_id, weight * importance, layer_index, state=i, return_avg=False)
                            aggregated_importance += sums
        else:
            last_hidden_layer = len(self.model.features_used) - 2
            layer_features_used = self.model.features_used[-1][0][current_node_state]
            for i, weight in enumerate(layer_features_used):
                if abs(weight) > 0:
                    for node_id in range(num_nodes):
                        if self.model.outputs[-2][node_id] != i:
                            continue
                        sums, _ = self.backpropagate_explanation(
                            node_id, weight * importance, last_hidden_layer, state=i, return_avg=False)
                        aggregated_importance += sums

        return self._signed_min_max_normalize(aggregated_importance)
