from copy import deepcopy

import numpy as np
from sklearn.tree._tree import TREE_LEAF, TREE_UNDEFINED
from tqdm import tqdm

def prune_index(inner_tree, index):
    """Turn node ``index`` of ``inner_tree`` into a leaf, in place.

    ``inner_tree`` is the low-level sklearn ``Tree`` object (``estimator.tree_``).
    The node's child links are set to ``TREE_LEAF`` and its split feature to
    ``TREE_UNDEFINED``, which detaches the node's whole subtree from the root.

    The descendant nodes themselves are not modified: they keep their array slots
    and contents (``node_count`` does not change), but they can no longer be
    reached from the root.

    Args:
        inner_tree: sklearn ``Tree`` whose node arrays are written through.
        index: id of the node to turn into a leaf.
    """
    # Turn the node into a leaf by "unlinking" its children.
    inner_tree.children_left[index] = TREE_LEAF
    inner_tree.children_right[index] = TREE_LEAF
    inner_tree.feature[index] = TREE_UNDEFINED
    # No descent into the former children is needed: once unlinked they are
    # unreachable, and every traversal that starts at the root skips them.


# Module-level accumulators shared by the traversal helpers:
# postOrderTraversal appends internal node ids to `traversal`, and
# postOrderTraversalDepth appends the matching depths to `traversal_depth`.
# Neither helper clears them, so callers must clear() both before a new walk.
traversal: list = []
traversal_depth: list = []


def is_leaf(tree, node):
    """Return True when ``node`` of the fitted estimator ``tree`` has no left child.

    A missing child is encoded as -1 (sklearn's ``TREE_LEAF``). The result is
    always a plain Python ``bool``.
    """
    return bool(tree.tree_.children_left[node] == -1)


def postOrderTraversal(tree, root):
    """Append the ids of all internal nodes under ``root`` to the global ``traversal``.

    Nodes are visited in post-order (left subtree, right subtree, then the node
    itself). Leaves are skipped, and a child id of -1 marks a missing child.
    The global list is not cleared here.
    """
    if root == -1:
        return
    nodes = tree.tree_
    for child in (nodes.children_left[root], nodes.children_right[root]):
        postOrderTraversal(tree, child)
    if not is_leaf(tree, root):
        traversal.append(root)


def postOrderTraversalDepth(tree, root, current_depth=0):
    """Append the depth of every internal node under ``root`` to the global ``traversal_depth``.

    The visiting order is the same post-order used by postOrderTraversal (left
    subtree, right subtree, node). For the same tree and root, the i-th depth
    appended here therefore belongs to the i-th node id appended there.
    ``current_depth`` is the depth assigned to ``root``. The global list is not
    cleared here.
    """
    if root == -1:
        return
    child_depth = current_depth + 1
    nodes = tree.tree_
    for child in (nodes.children_left[root], nodes.children_right[root]):
        postOrderTraversalDepth(tree, child, child_depth)
    if not is_leaf(tree, root):
        traversal_depth.append(current_depth)


def scale_pooling_num_node_samples(estimator, tree_data, data):
    """Overwrite ``estimator.tree_.n_node_samples`` with weighted path counts.

    Every row ``i`` of ``tree_data`` is routed through ``estimator``. Each node
    on its decision path is credited with ``len(data[i].x)`` instead of 1, so
    the count of a node is the sum of those weights over all rows that reach it.
    Nodes reached by no row get 0. Only ``n_node_samples`` is rewritten;
    ``weighted_n_node_samples`` is left untouched.

    NOTE(review): ``data[i].x`` is presumably the node-feature matrix of a
    pooled graph, making the weight its node count; confirm against callers.

    Args:
        estimator: fitted sklearn tree estimator; modified in place.
        tree_data: samples accepted by ``estimator.decision_path``.
        data: indexable sequence aligned with the rows of ``tree_data``; each
            item must expose an ``x`` supporting ``len()``.

    Returns:
        ``estimator`` (the same object).
    """
    node_indicator = estimator.decision_path(tree_data)  # CSR, (n_samples, n_nodes)
    num_nodes = estimator.tree_.node_count
    # Row count of the path matrix; unlike len(tree_data) this also works for
    # sparse tree_data.
    num_samples = node_indicator.shape[0]

    weights = np.array([len(data[i].x) for i in range(num_samples)], dtype=float)
    # CSR layout: row i's node ids are indices[indptr[i]:indptr[i + 1]].
    # Repeat each sample's weight once per node on its path, then sum per node.
    path_lengths = np.diff(node_indicator.indptr)
    new_sample_count = np.bincount(
        node_indicator.indices,
        weights=np.repeat(weights, path_lengths),
        minlength=num_nodes,
    )

    estimator.tree_.n_node_samples[:num_nodes] = new_sample_count[:num_nodes]
    return estimator


def get_traversal_scheme(trees, layer_names, sort_by="num_samples", decending=False):
    """List every internal node of the selected trees as a pruning candidate.

    For each name in ``layer_names``, the tree ``trees[name]`` is walked from
    its root (node 0) in post-order. One dict is produced per internal node,
    with the keys ``"tree"``, ``"node_id"``, ``"num_samples"``, ``"impurity"``
    and ``"depth"``. The combined list is sorted stably by ``entry[sort_by]``,
    in descending order when ``decending`` is true.

    Side effect: the module-level ``traversal`` / ``traversal_depth`` lists
    are cleared and refilled, and afterwards hold the walk of the last tree.
    """
    entries = []
    for name in layer_names:
        estimator = trees[name]
        traversal.clear()
        postOrderTraversal(estimator, 0)
        traversal_depth.clear()
        postOrderTraversalDepth(estimator, 0)
        nodes = estimator.tree_
        entries.extend(
            {
                "tree": name,
                "node_id": node_id,
                "num_samples": nodes.n_node_samples[node_id],
                "impurity": nodes.impurity[node_id],
                "depth": depth,
            }
            for node_id, depth in zip(traversal, traversal_depth)
        )
    entries.sort(key=lambda entry: entry[sort_by], reverse=decending)
    return entries


def score_model(m, d, mask_index=0, use_pooling=True, mask=None):
    """Return the accuracy of model ``m`` over the batches yielded by ``d``.

    ``m`` is switched to eval mode and called as ``m(x, edge_index, batch)``
    for every batch. The arg-max over the last axis is compared with ``y``.

    - ``use_pooling=True``: correct predictions are summed over all batches
      and divided by ``len(d.dataset)``.
    - ``use_pooling=False``: each batch contributes its own accuracy
      (correct / ``len(y)``), and the sum is divided by ``len(d.dataset)``.

    ``mask_index`` and ``mask`` are accepted for call compatibility but unused.

    NOTE(review): in the non-pooling mode the result is the mean per-batch
    accuracy only when the number of batches equals ``len(d.dataset)``.
    Also, no gradient-disabling context is used here; confirm whether that
    matters for the model type.
    """
    m.eval()
    running = 0
    for batch_data in d:
        predictions = m(batch_data.x, batch_data.edge_index, batch_data.batch).argmax(-1)
        hits = int((predictions == batch_data.y).sum())
        running += hits if use_pooling else hits / len(batch_data.y)
    return running / len(d.dataset)


def getMaxDepth(tree, root, current_depth=0):
    """Return ``current_depth`` plus the number of node levels below and including ``root``.

    The recursion stops at the -1 child markers, so the missing-child level is
    counted: a root-only tree yields 1 and a single split yields 2. This is one
    more than sklearn's ``tree_.max_depth`` convention, where a root-only tree
    has depth 0. A ``root`` of -1 returns ``current_depth`` unchanged.
    """
    if root == -1:
        return current_depth
    deeper = current_depth + 1
    nodes = tree.tree_
    return max(
        getMaxDepth(tree, child, deeper)
        for child in (nodes.children_left[root], nodes.children_right[root])
    )


def _prune_is_acceptable(new_train, new_val, base_train, base_val, rep_val, rep_train):
    """Return True when a trial prune with scores (new_train, new_val) should be kept."""
    if rep_train and new_train >= base_train:
        return True
    if rep_val and new_val >= base_val:
        return True
    if rep_val or rep_train:
        # An explicit REP mode was requested and its criterion failed.
        return False
    # Default mode: neither score may drop, or train >= val >= baseline val.
    return (new_train >= base_train and new_val >= base_val) or (new_train >= new_val >= base_val)


def prune_trees_all_val(trees, dt_model, layer_names, data_train, data_val, score_model=score_model,
                        sort_by="num_samples", decending=False, debug=False, REP_val=False, REP_train=False):
    """Greedily prune internal tree nodes while the model's scores stay acceptable.

    Candidates come from get_traversal_scheme (ordered by ``sort_by`` /
    ``decending``). Each candidate is handled as follows:

    1. The node is turned into a leaf on a deep copy of its tree.
    2. ``dt_model`` is re-synced, and the model is scored on ``data_train``
       and ``data_val`` (higher scores are treated as better).
    3. The prune is kept when the acceptance rule holds:

       - ``REP_train``: train score >= the unpruned train score;
       - ``REP_val``: validation score >= the unpruned validation score;
       - neither flag: both scores >= their baselines, or
         train >= val >= baseline val.

       Otherwise the previous tree is restored.

    The baselines are the scores before any pruning and are not updated after
    a prune is accepted. Finally, each tree's ``tree_.max_depth`` is recomputed
    using sklearn's convention (a root-only tree has depth 0), and ``dt_model``
    is synced with the final trees.

    Args:
        trees: dict mapping layer name -> fitted sklearn tree estimator;
            updated in place.
        dt_model: model exposing ``update_trees(trees)``.
        layer_names: keys of ``trees`` whose internal nodes are candidates.
        data_train: evaluation data passed to ``score_model`` for the train
            score.
        data_val: evaluation data passed to ``score_model`` for the
            validation score.
        score_model: callable ``score_model(dt_model, data)`` returning a
            score.
        sort_by: candidate ordering key, forwarded to get_traversal_scheme.
        decending: candidate ordering direction, forwarded to
            get_traversal_scheme.
        debug: show a progress bar and print how many nodes were removed.
        REP_val: acceptance-rule flag (see above).
        REP_train: acceptance-rule flag (see above).

    Returns:
        ``(trees, n_removed)``.
    """
    score_before_prune_train = score_model(dt_model, data_train)
    score_before_prune_val = score_model(dt_model, data_val)
    traversal_scheme = get_traversal_scheme(trees, layer_names, sort_by=sort_by, decending=decending)
    to_remove = []
    tq = tqdm(traversal_scheme, disable=(not debug))
    for index_counter, traversal_element in enumerate(tq, start=1):
        node_index = traversal_element["node_id"]
        tree_key = traversal_element["tree"]
        # Only the copy is pruned, so the current object is enough to roll back.
        original_tree = trees[tree_key]
        tree_copy = deepcopy(original_tree)
        prune_index(tree_copy.tree_, node_index)
        trees[tree_key] = tree_copy
        dt_model.update_trees(trees)
        new_score_train = score_model(dt_model, data_train)
        new_score_val = score_model(dt_model, data_val)
        if _prune_is_acceptable(new_score_train, new_score_val,
                                score_before_prune_train, score_before_prune_val,
                                REP_val, REP_train):
            # trees[tree_key] already holds the pruned copy and dt_model is in sync.
            to_remove.append(node_index)
        else:
            trees[tree_key] = original_tree
            dt_model.update_trees(trees)
        tq.set_postfix(removed=f'{len(to_remove)}/{index_counter}')
    if to_remove and debug:
        print(f'Pruning: {len(to_remove)} Nodes Removed')
    for layer_name in layer_names:
        # getMaxDepth also counts the missing-child level below the deepest
        # leaf, so subtract 1 to match sklearn's max_depth (root-only tree -> 0).
        trees[layer_name].tree_.max_depth = getMaxDepth(trees[layer_name], 0, 0) - 1
    # Sync the model with the final trees (including max_depth), as
    # prune_trees_least_influencial does.
    dt_model.update_trees(trees)
    return trees, len(to_remove)


def get_num_nodes(trees, layer_names):
    """Return the number of internal (non-leaf) nodes reachable from the roots of the named trees.

    Leaves are not counted. The value equals the number of pruning candidates
    that get_traversal_scheme produces for ``layer_names``.
    """
    candidates = get_traversal_scheme(trees, layer_names)
    return len(candidates)


def prune_trees_least_influencial(trees, dt_model, layer_names, data, remove_nodes=1, score_model=score_model,
                                  debug=False, mask=None):
    """Prune the ``remove_nodes`` internal nodes whose removal costs the least score.

    Every internal node from get_traversal_scheme (default ordering) is
    evaluated in isolation:

    1. The node is pruned on a deep copy of its tree.
    2. ``dt_model`` is scored on ``data``, and the drop
       ``score_before - score_after`` is recorded.
    3. The tree is restored.

    The nodes with the smallest drops (i.e. the largest improvements first)
    are then pruned in place. Ties are broken by traversal-scheme order.
    Finally, each tree's ``tree_.max_depth`` is recomputed using sklearn's
    convention (a root-only tree has depth 0), and ``dt_model`` is synced with
    the final trees.

    Args:
        trees: dict mapping layer name -> fitted sklearn tree estimator;
            updated in place.
        dt_model: model exposing ``update_trees(trees)``.
        layer_names: keys of ``trees`` whose internal nodes are candidates.
        data: evaluation data passed to ``score_model``.
        remove_nodes: how many nodes to prune. Values <= 0 prune nothing.
        score_model: callable ``score_model(dt_model, data, mask=mask)``
            returning a score.
        debug: show a progress bar.
        mask: forwarded to ``score_model``.

    Returns:
        ``trees``.
    """
    score_before_prune = score_model(dt_model, data, mask=mask)
    traversal_scheme = get_traversal_scheme(trees, layer_names)
    score_changes = []
    for traversal_element in tqdm(traversal_scheme, disable=(not debug)):
        node_index = traversal_element["node_id"]
        tree_key = traversal_element["tree"]
        # Restore from an independent copy: the final in-place prunes below then
        # act on copies rather than on the estimator objects the caller passed in.
        original_tree = deepcopy(trees[tree_key])
        tree_copy = deepcopy(trees[tree_key])
        prune_index(tree_copy.tree_, node_index)
        trees[tree_key] = tree_copy
        dt_model.update_trees(trees)
        new_score = score_model(dt_model, data, mask=mask)
        score_changes.append(score_before_prune - new_score)
        trees[tree_key] = original_tree
        dt_model.update_trees(trees)

    # Stable sort so that equal score changes keep traversal-scheme order.
    best_remove_indexes = np.argsort(score_changes, kind="stable")
    # max(..., 0): a negative count prunes nothing (as range() did) instead of
    # slicing from the end.
    for scheme_index in best_remove_indexes[:max(remove_nodes, 0)]:
        traversal_element = traversal_scheme[scheme_index]
        prune_index(trees[traversal_element["tree"]].tree_, traversal_element["node_id"])
    for layer_name in layer_names:
        # getMaxDepth also counts the missing-child level below the deepest
        # leaf, so subtract 1 to match sklearn's max_depth (root-only tree -> 0).
        trees[layer_name].tree_.max_depth = getMaxDepth(trees[layer_name], 0, 0) - 1
    dt_model.update_trees(trees)
    return trees
