import time
# from tqdm import tqdm
import pandas as pd
from itertools import product, combinations_with_replacement
from tree import *
import copy
from typing    import Dict, Iterable, List, Tuple

# # thresholds = [(1, 1.5), (22, 23), (3, 4), (1, 1.5), (58, 59)]
# thresholds = [(df[col].min(), df[col].max() + 1) for col in df.columns[:-1]]
# print(df)
# print(thresholds)
# print(count_instances_within_thresholds(df, thresholds))


def count_instances_within_thresholds(df, thresholds):
    """Count the class labels of the rows that fall inside a box.

    Args:
        df (pd.DataFrame): Data whose last column holds the class label.
        thresholds (dict): Maps column name -> ``(left, right)``. A row is
            inside the box when ``left < row[col] <= right`` holds for every
            listed column (the same convention as a tree split sending
            ``<= threshold`` left and ``> threshold`` right). Columns that
            are not listed are unrestricted; an empty dict selects all rows.

    Returns:
        dict: Class label -> number of rows inside the box with that label.
        Labels without any row inside the box are absent.
    """
    # The class column is assumed to be the last column.
    class_column = df.columns[-1]

    combined_mask = pd.Series(True, index=df.index)
    for col, (left, right) in thresholds.items():
        combined_mask &= (df[col] > left) & (df[col] <= right)

    return df.loc[combined_mask, class_column].value_counts().to_dict()


# Lower and upper bound, (left, right), of a single dimension.
Pair = Tuple[float, float]
# A threshold sequence: dimension -> its (left, right) bounds.
Pairs = Dict[str, Pair]
# Dimension -> small integer tag (1 or 2).
Counts = Dict[str, int]




def threshold_seqs(
        node,
        total_dimens: List[str],
        total_threshs: Dict[str, List[float]],
        node_dimens: List[str],
        node_threshs: Dict[str, List[float]],
        overall_min, overall_max,
        k_exch: int, k_adj: int, k_rais: int, limit: int
) -> Iterable[Pairs]:
    """Enumerate the threshold sequences (boxes) that can arrive at ``node``.

    Process the nodes on the path from ``node``'s parent to the root. For
    each such node v, branch on what happened to v's cut:

    1. unchanged: v's own cut restricts the box;
    2. raised (one unit of ``k_rais``): v's cut is ignored;
    3. adjusted (one unit of ``k_adj``): v keeps its dimension and takes
       a candidate threshold from ``total_threshs``;
    4. exchanged (one unit of ``k_exch``): v takes any dimension of
       ``total_dimens`` and a candidate threshold for it.

    A box is a dict ``dim -> (left, right)`` meaning
    ``left < x[dim] <= right``. A dimension whose bounds are exactly
    ``(overall_min[dim], overall_max[dim])`` is removed from the dict, and
    branches whose box becomes empty (``left >= right``) are pruned. For
    options 3 and 4, a (dimension, direction) pair is guessed at most once
    along the path. A later node that would guess the same pair is instead
    skipped while still consuming one unit of the corresponding budget.

    Args:
        node: Starting tree node; ancestors are reached via ``.parent``.
        total_dimens: Candidate dimensions for exchanged cuts.
        total_threshs: Candidate thresholds per dimension.
        node_dimens, node_threshs: Dimensions/thresholds above ``node`` in
            the input tree; currently unused.
        overall_min, overall_max: Per-dimension bounds meaning
            "unrestricted".
        k_exch, k_adj, k_rais: Budgets available on the path.
        limit: Currently unused.

    Yields:
        Threshold sequences (dicts). The same sequence can be produced by
        several branches; the visible callers skip keys they already
        filled in.

    Assumes no dimension is "".
    """

    def _ts_update(current_seq, dim, thresh, is_upper_bound):
        """Intersect the box with one side of the cut ``dim``/``thresh``.

        ``is_upper_bound`` True adds ``x[dim] <= thresh`` (path came from
        the left child); False adds ``x[dim] > thresh``.

        Returns ``(nonempty_box, new_seq)``; ``current_seq`` is not modified.
        """
        new_seq = current_seq.copy()
        if dim in new_seq:
            left_thresh, right_thresh = new_seq[dim]
            if is_upper_bound:
                right_thresh = min(right_thresh, thresh)
            else:
                left_thresh = max(left_thresh, thresh)
        elif is_upper_bound:
            left_thresh, right_thresh = overall_min[dim], thresh
        else:
            left_thresh, right_thresh = thresh, overall_max[dim]
        new_seq[dim] = (left_thresh, right_thresh)
        nonempty_box = (left_thresh < right_thresh)
        # Drop unrestricted dimensions so that equal boxes get equal keys.
        if left_thresh == overall_min[dim] and right_thresh == overall_max[dim]:
            new_seq.pop(dim)
        return nonempty_box, new_seq

    def _useful_threshs(current_seq, dim):
        """Candidate thresholds of ``dim`` for an adjusted/exchanged cut.

        All candidates if ``dim`` is unrestricted. Min and max are kept on
        purpose: they serve as a proxy for the cut not being the tightest
        one in this box. Otherwise, the candidates within the current bounds.
        """
        if dim not in current_seq:
            return total_threshs[dim]
        left, right = current_seq[dim]
        return [thresh for thresh in total_threshs[dim] if left <= thresh <= right]

    def _threshold_seqs_generator(
            node,
            previous_node_is_left_child: bool,
            current_seq: Pairs,
            largest_selected_dim: str,
            was_guessed,
            k_exch: int, k_adj: int, k_rais: int
    ) -> Iterable[Pairs]:
        # ``largest_selected_dim`` is threaded through for a symmetry-breaking
        # check (``dim > largest_selected_dim``) that is currently disabled.
        # ``was_guessed`` holds the (dimension, direction) pairs already
        # guessed by an adjustment/exchange further down the path.
        if node is None:
            # Walked past the root: the box is complete.
            yield current_seq
            return

        current_node_is_left_child = (not node.parent) or (node.parent.left is node)

        # Option 1: Current node is not changed. Add its cut to the box.
        nonempty_box, new_seq = _ts_update(current_seq, node.dimension, node.threshold, previous_node_is_left_child)
        if nonempty_box:
            yield from _threshold_seqs_generator(
                node.parent, current_node_is_left_child, new_seq, largest_selected_dim, was_guessed,
                k_exch, k_adj, k_rais)

        # Option 2: Current node got raised. Then simply ignore it.
        if k_rais > 0:
            yield from _threshold_seqs_generator(
                node.parent, current_node_is_left_child, current_seq, largest_selected_dim, was_guessed,
                k_exch, k_adj, k_rais - 1)

        # Option 3: Current node is adjusted. Only thresholds that can move
        # the cut are tried; a weaker threshold would not change the box.
        if k_adj > 0:
            guess = (node.dimension, previous_node_is_left_child)
            # We should not do this twice in the same dimension and direction.
            if guess not in was_guessed:
                for thresh in _useful_threshs(current_seq, node.dimension):
                    nonempty_box, new_seq = _ts_update(current_seq, node.dimension, thresh, previous_node_is_left_child)
                    if nonempty_box:
                        new_was_guessed = was_guessed.copy()
                        new_was_guessed[guess] = True
                        yield from _threshold_seqs_generator(
                            node.parent, current_node_is_left_child, new_seq, largest_selected_dim, new_was_guessed,
                            k_exch, k_adj - 1, k_rais)
            else:
                # This node may also have been adjusted, but the tightest
                # threshold in this dimension/direction is already chosen,
                # so the node is simply ignored.
                yield from _threshold_seqs_generator(
                    node.parent, current_node_is_left_child, current_seq, largest_selected_dim, was_guessed,
                    k_exch, k_adj - 1, k_rais)

        # Option 4: Current node is exchanged.
        if k_exch > 0:
            # The "ignore this node" branch is identical for every already
            # guessed dimension, so it is yielded at most once.
            ignored_once = False
            for dim in total_dimens:
                guess = (dim, previous_node_is_left_child)
                if guess not in was_guessed:
                    for thresh in _useful_threshs(current_seq, dim):
                        nonempty_box, new_seq = _ts_update(current_seq, dim, thresh, previous_node_is_left_child)
                        if nonempty_box:
                            new_was_guessed = was_guessed.copy()
                            new_was_guessed[guess] = True
                            yield from _threshold_seqs_generator(
                                node.parent, current_node_is_left_child, new_seq, dim, new_was_guessed,
                                k_exch - 1, k_adj, k_rais)
                elif not ignored_once:
                    # This node may also have been exchanged, but the
                    # tightest threshold in that dimension is already chosen.
                    ignored_once = True
                    yield from _threshold_seqs_generator(
                        node.parent, current_node_is_left_child, current_seq, largest_selected_dim, was_guessed,
                        k_exch - 1, k_adj, k_rais)

    if not node.parent:
        # The root is reached by the unrestricted box only.
        yield {}
    else:
        yield from _threshold_seqs_generator(node.parent, node.parent.left is node, {}, "", {}, k_exch, k_adj, k_rais)


def gen_threshold_key(threshold_seq):
    """Build a hashable key for a threshold sequence.

    The ``(dimension, (left, right))`` items are returned as a tuple ordered
    by dimension. Two dicts with the same content therefore yield the same
    key, whatever their insertion order.
    """
    ordered_items = sorted(threshold_seq.items())
    return tuple(ordered_items)


def initialize_table(
        Q, tree, df,
        total_dimens, total_threshs, tree_dimens, tree_threshs, overall_min, overall_max,
        k_exch, k_adj, k_rais
):
    """Fill the base-case entries of ``Q`` for every leaf of ``tree``.

    For each leaf and each threshold sequence that can reach it (see
    ``threshold_seqs``), ``Q[leaf, key, 0, 0, 0, 0]`` receives the number of
    rows inside that box misclassified by the best single label. That is
    the number of rows minus the count of the most frequent class (0 for an
    empty box). For two classes this equals the minority-class count.

    Args:
        Q (dict): DP table, mutated in place.
        tree: Input tree; ``tree.leaves`` is iterated.
        df (pd.DataFrame): Data, class in the last column.
        total_dimens, total_threshs, tree_dimens, tree_threshs,
        overall_min, overall_max, k_exch, k_adj, k_rais:
            Forwarded to ``threshold_seqs``.
    """
    leaves = tree.leaves
    num_leaves = len(leaves)
    for leaf_index, leaf in enumerate(leaves):
        print(f"Working on leaf {leaf.to_string()} (index {leaf_index + 1} of {num_leaves})")
        duplicate_entries = 0

        seqs = threshold_seqs(
            leaf, total_dimens, total_threshs, tree_dimens[leaf], tree_threshs[leaf], overall_min, overall_max,
            k_exch, k_adj, k_rais, leaf.depth)
        for seq_index, threshold_seq in enumerate(seqs):
            if seq_index % 1000 == 0:
                print(f"Working on threshold-sequence {seq_index + 1}")
            threshold_key = gen_threshold_key(threshold_seq)
            entry = (leaf, threshold_key, 0, 0, 0, 0)

            # The same box can be generated several times; compute it once.
            if entry in Q:
                duplicate_entries += 1
                continue

            class_counts = count_instances_within_thresholds(df, threshold_seq)
            # Errors of the best label: every row except the majority class.
            Q[entry] = sum(class_counts.values()) - max(class_counts.values(), default=0)
        print(f"Duplicate entries: {duplicate_entries}")


def compute_errors(df, tree, tree_thresh):
    """Count the rows of ``df`` misclassified by ``tree``.

    Each row is routed from the root to a leaf (left when
    ``row[dimension] <= threshold``, right when ``> threshold``). It counts
    as an error when its class, taken from the last column of ``df``,
    differs from the leaf's ``class_label``.

    Args:
        df (pd.DataFrame): Data, class in the last column.
        tree: Object with ``.root``. Inner nodes provide ``is_inner()``,
            ``left``, ``right``, ``dimension`` and ``threshold``; leaves
            provide ``is_inner()`` and ``class_label``.
        tree_thresh: Unused; kept for interface compatibility.

    Returns:
        int: Number of misclassified rows.
    """
    class_column = df.columns[-1]

    def count_errors(node, mask):
        # ``mask`` selects the rows that reach ``node``.
        if node.is_inner():
            goes_left = df[node.dimension] <= node.threshold
            goes_right = df[node.dimension] > node.threshold
            return (count_errors(node.left, mask & goes_left)
                    + count_errors(node.right, mask & goes_right))
        class_counts = df.loc[mask, class_column].value_counts()
        # Every labelled row at this leaf that does not carry the leaf's label.
        return int(class_counts.sum() - class_counts.get(node.class_label, 0))

    return count_errors(tree.root, pd.Series(True, index=df.index))


def compute_tree_dimens_threshs(tree):
    """Record, for every node of ``tree``, the cuts made strictly above it.

    Returns:
        tuple: ``(dimens, threshs)``, two dicts keyed by node.
        ``dimens[node]`` lists the distinct dimensions used by the strict
        ancestors of ``node`` in the input tree. ``threshs[node]`` maps each
        of those dimensions to the sorted, distinct thresholds the
        ancestors use for it. The root gets ``[]`` and ``{}``.
    """
    dimens = {}
    threshs = {}

    def visit(node, ancestors):
        if node is None:
            return

        node_dimens = list(set([anc.dimension for anc in ancestors]))
        dimens[node] = node_dimens
        threshs[node] = {
            dim: sorted(list(set([anc.threshold for anc in ancestors if anc.dimension == dim])))
            for dim in node_dimens
        }

        # Children are only followed for TreeNode instances.
        if isinstance(node, TreeNode):
            below = ancestors + [node]
            visit(node.left, below)
            visit(node.right, below)

    visit(tree.root, [])
    return (dimens, threshs)


def compute_relevant_threshs(df, total_dimens, tree_dimens, tree_threshs):
    """Build the candidate cut thresholds for every dimension.

    For each column ``dim`` in ``total_dimens`` (except ``"class"``), the
    result lists, in increasing order:

    * one threshold strictly below the column's minimum;
    * one threshold per gap between consecutive distinct values
      ``u_i < u_{i+1}``. This is the largest input-tree threshold ``t``
      with ``u_i <= t < u_{i+1}`` if one exists (it induces the same split
      as any other value in the gap); otherwise it is the midpoint;
    * one threshold strictly above the column's maximum.

    Assumes all processed columns are numerical.

    Args:
        df (pd.DataFrame): Input data.
        total_dimens: Column names to process.
        tree_dimens: Not needed (dimensions are read from ``tree_threshs``);
            kept for interface compatibility.
        tree_threshs: The input tree's thresholds, in one of two shapes.
            Either per node (``node -> {dim: [thresholds]}``, as returned by
            ``compute_tree_dimens_threshs``) or per dimension
            (``dim -> [thresholds]``).

    Returns:
        dict: Column name -> list of thresholds.
    """
    from bisect import bisect_left

    # Union of the input tree's thresholds per dimension, sorted.
    known = {}
    for key, value in tree_threshs.items():
        if isinstance(value, dict):
            # node -> {dim: [thresholds]}
            for dim, threshs in value.items():
                known.setdefault(dim, set()).update(threshs)
        else:
            # dim -> [thresholds]
            known.setdefault(key, set()).update(value)
    known_sorted = {dim: sorted(threshs) for dim, threshs in known.items()}

    thresholds = {}

    for dim in total_dimens:
        assert dim != ""
        if dim == "class":
            continue

        unique_values = sorted(df[dim].unique())
        column_thresholds = []

        # Threshold below the minimum.
        if len(unique_values) > 0:
            min_val = unique_values[0]
            min_threshold = -0.1 if min_val == 0 else min_val - abs(min_val) * 0.1
            column_thresholds.append(min_threshold)
            assert min_threshold not in unique_values, f"Min threshold {min_threshold} is in unique values {unique_values}"

        # One threshold per gap between consecutive unique values.
        candidates = known_sorted.get(dim, [])
        for lower, upper in zip(unique_values, unique_values[1:]):
            # Largest tree threshold t < upper; it separates the gap iff t >= lower.
            pos = bisect_left(candidates, upper) - 1
            if pos >= 0 and candidates[pos] >= lower:
                relevant_threshold = candidates[pos]
            else:
                midpoint = (lower + upper) / 2
                # Safeguard: only possible for adjacent floating-point values.
                assert midpoint != lower and midpoint != upper, \
                    f"Midpoint {midpoint} equals one of the adjacent values {lower} or {upper}"
                relevant_threshold = midpoint
            column_thresholds.append(relevant_threshold)

        # Threshold above the maximum.
        if len(unique_values) > 0:
            max_val = unique_values[-1]
            max_threshold = 0.1 if max_val == 0 else max_val + abs(max_val) * 0.1
            column_thresholds.append(max_threshold)
            assert max_threshold not in unique_values, f"Max threshold {max_threshold} is in unique values {unique_values}"

        thresholds[dim] = column_thresholds

    return thresholds


def _restrict_threshold_seq(threshold_seq, dimension, threshold, is_upper_bound, overall_min, overall_max):
    """Intersect a box with one side of a cut; return ``(key, nonempty)``.

    ``is_upper_bound`` True adds ``x[dimension] <= threshold`` (left child);
    False adds ``x[dimension] > threshold`` (right child). A dimension whose
    bounds become exactly ``(overall_min, overall_max)`` is dropped from the
    sequence before the key is built. ``threshold_seq`` is not modified.
    """
    new_seq = dict(threshold_seq)
    if dimension in new_seq:
        left_thresh, right_thresh = new_seq[dimension]
        if is_upper_bound:
            right_thresh = min(right_thresh, threshold)
        else:
            left_thresh = max(left_thresh, threshold)
    elif is_upper_bound:
        left_thresh, right_thresh = overall_min[dimension], threshold
    else:
        left_thresh, right_thresh = threshold, overall_max[dimension]
    new_seq[dimension] = (left_thresh, right_thresh)
    if left_thresh == overall_min[dimension] and right_thresh == overall_max[dimension]:
        new_seq.pop(dimension)
    return gen_threshold_key(new_seq), left_thresh < right_thresh


def ts_update(node, overall_min, overall_max, threshold_seq, dim=None, thresh=None):
    """Split a box at a (possibly modified) cut of ``node``.

    The cut used depends on which optional arguments are given:

    * neither ``dim`` nor ``thresh``: no change, the node's own cut;
    * only ``thresh``: threshold adjustment, ``node.dimension`` at ``thresh``;
    * both: cut exchange, ``dim`` at ``thresh``.

    Returns:
        tuple: ``(left_key, right_key, check_left, check_right)``. These are
        the threshold keys of the boxes sent to the left (``<=``) and right
        (``>``) child, and whether each of those boxes is non-empty.

    Raises:
        ValueError: If ``dim`` is given without ``thresh``.
    """
    if dim is None and thresh is None:
        dimension, threshold = node.dimension, node.threshold
    elif dim is None:
        dimension, threshold = node.dimension, thresh
    elif thresh is not None:
        dimension, threshold = dim, thresh
    else:
        raise ValueError(
            f"Invalid update: dim={dim!r} given without thresh. Expected 'no change' "
            "(neither), 'threshold adjustment' (thresh only), or 'cut exchange' (dim and thresh).")

    left_threshold_key, check_left = _restrict_threshold_seq(
        threshold_seq, dimension, threshold, True, overall_min, overall_max)
    right_threshold_key, check_right = _restrict_threshold_seq(
        threshold_seq, dimension, threshold, False, overall_min, overall_max)

    return (left_threshold_key, right_threshold_key, check_left, check_right)


def fill_table(
        Q, QQ, tree, df,
        total_dimens, total_threshs, tree_dimens, tree_threshs, overall_min, overall_max,
        k_exch, k_adj, k_rais, k_repl):
    """Compute all necessary table entries of Q and QQ for inner nodes.

    Nodes are visited in post order, so both children's entries exist when a
    node is processed; leaf entries must already be set by
    ``initialize_table``. The budget tuples considered are those
    ``(kk_exch, kk_adj, kk_rais, kk_repl)`` whose sum fits the node's
    ``subtree_size``, combined with every threshold sequence that can reach
    the node with the remaining budgets. For each such pair,
    ``Q[node, key, kk_exch, kk_adj, kk_rais, kk_repl]`` receives the minimum
    error over:

    1. no operation at the node (budget split between the children);
    2. exchanging the node's cut (one ``exch`` charged to the left side);
    3. adjusting the node's threshold (one ``adj`` charged to the left side);
    4. replacing the subtree by a leaf (only ``repl``, equal to the size);
    5. raising one child in place of the node.

    ``QQ`` stores, under the same key, a tuple describing the option that
    achieved the minimum (``None`` with ``float("inf")`` in ``Q`` if no
    option applied). Q and QQ are mutated in place.
    """

    def _useful_threshs(threshold_seq, dim):
        """Thresholds of ``dim`` that restrict the box further."""
        if dim not in threshold_seq:
            # Ignore the min and max thresholds.
            return total_threshs[dim][1:-1]
        low, high = threshold_seq[dim]
        return [thresh for thresh in total_threshs[dim] if low <= thresh <= high]

    def _exchange_splits(node, threshold_seq):
        """``ts_update`` results for every useful exchanged cut, in order."""
        return [ts_update(node, overall_min, overall_max, threshold_seq, dim=dim, thresh=thresh)
                for dim in total_dimens
                for thresh in _useful_threshs(threshold_seq, dim)]

    def _adjustment_splits(node, threshold_seq):
        """``ts_update`` results for every useful adjusted threshold, in order."""
        return [ts_update(node, overall_min, overall_max, threshold_seq, thresh=thresh)
                for thresh in _useful_threshs(threshold_seq, node.dimension)]

    nodes = tree.post_order_traversal()
    num_nodes = len(nodes)
    for node_index, node in enumerate(nodes):
        if node.is_leaf():
            continue
        print(f"Working on node {node.to_string()} (index {node_index + 1} of {num_nodes})")

        subtree_size = node.subtree_size
        left_subtree_size = 0 if node.left.is_leaf() else node.left.subtree_size
        right_subtree_size = 0 if node.right.is_leaf() else node.right.subtree_size

        # Iterate over the budgets for operations at or below the node.
        for (kk_exch, kk_adj, kk_rais, kk_repl) in product(range(k_exch + 1), range(k_adj + 1), range(k_rais + 1), range(k_repl + 1)):
            print(f"Working on (kk_exch, kk_adj, kk_rais, kk_repl) = {(kk_exch, kk_adj, kk_rais, kk_repl)}")

            kk_total = kk_exch + kk_adj + kk_rais + kk_repl
            if kk_total > subtree_size:
                continue

            for seq_index, threshold_seq in enumerate(threshold_seqs(
                    node, total_dimens, total_threshs, tree_dimens[node], tree_threshs[node], overall_min, overall_max,
                    k_exch - kk_exch, k_adj - kk_adj, k_rais - kk_rais, node.depth)):
                if seq_index % 1000 == 0:
                    print(f"Working on seq {seq_index}")

                threshold_key = gen_threshold_key(threshold_seq)
                if (node, threshold_key, kk_exch, kk_adj, kk_rais, kk_repl) in Q:
                    continue

                accu = float("inf")
                accu_ref = None

                # The children's boxes for a cut do not depend on how the
                # budget is split, so compute them lazily, once per box.
                noop_split = exch_splits = adj_splits = None

                # Options 1-3 iterate over the budget given to the left
                # subtree; the right subtree receives the rest.
                for (kkk_exch, kkk_adj, kkk_rais, kkk_repl) in \
                        product(range(kk_exch + 1), range(kk_adj + 1), range(kk_rais + 1), range(kk_repl + 1)):
                    left_total = kkk_exch + kkk_adj + kkk_rais + kkk_repl
                    right_budget = (kk_exch - kkk_exch, kk_adj - kkk_adj, kk_rais - kkk_rais, kk_repl - kkk_repl)
                    # Skip splits that give a subtree more operations than it has nodes.
                    right_fits = kk_total - left_total <= right_subtree_size

                    # First option: no operation at the current node.
                    if left_total <= left_subtree_size and right_fits:
                        if noop_split is None:
                            noop_split = ts_update(node, overall_min, overall_max, threshold_seq)
                        left_threshold_key, right_threshold_key, check_left, check_right = noop_split
                        left_ref = (node.left, left_threshold_key, kkk_exch, kkk_adj, kkk_rais, kkk_repl)
                        right_ref = (node.right, right_threshold_key) + right_budget
                        left_misclassifications = Q[left_ref] if check_left else 0
                        right_misclassifications = Q[right_ref] if check_right else 0
                        compare = left_misclassifications + right_misclassifications
                        if compare < accu:
                            accu = compare
                            accu_ref = ("no op", left_ref, right_ref)

                    # Second option: cut exchange. Its cost is charged to the
                    # left budget; all splits are tried, so this is w.l.o.g.
                    if kkk_exch >= 1 and left_total - 1 <= left_subtree_size and right_fits:
                        if exch_splits is None:
                            exch_splits = _exchange_splits(node, threshold_seq)
                        for left_threshold_key, right_threshold_key, check_left, check_right in exch_splits:
                            left_ref = (node.left, left_threshold_key, kkk_exch - 1, kkk_adj, kkk_rais, kkk_repl)
                            right_ref = (node.right, right_threshold_key) + right_budget
                            left_misclassifications = Q[left_ref] if check_left else 0
                            right_misclassifications = Q[right_ref] if check_right else 0
                            compare = left_misclassifications + right_misclassifications
                            if compare < accu:
                                accu = compare
                                accu_ref = ("cut exch", left_ref, right_ref)

                    # Third option: threshold adjustment (charged to the left budget).
                    if kkk_adj >= 1 and left_total - 1 <= left_subtree_size and right_fits:
                        if adj_splits is None:
                            adj_splits = _adjustment_splits(node, threshold_seq)
                        for left_threshold_key, right_threshold_key, check_left, check_right in adj_splits:
                            left_ref = (node.left, left_threshold_key, kkk_exch, kkk_adj - 1, kkk_rais, kkk_repl)
                            right_ref = (node.right, right_threshold_key) + right_budget
                            left_misclassifications = Q[left_ref] if check_left else 0
                            right_misclassifications = Q[right_ref] if check_right else 0
                            compare = left_misclassifications + right_misclassifications
                            if compare < accu:
                                accu = compare
                                accu_ref = ("thresh adj", left_ref, right_ref)

                # Fourth option: replace the whole subtree by its best leaf.
                if kk_exch == 0 and kk_adj == 0 and kk_rais == 0 and kk_repl == subtree_size:
                    class_counts = count_instances_within_thresholds(df, threshold_seq)
                    # Errors of the best label: every row except the majority class.
                    compare = sum(class_counts.values()) - max(class_counts.values(), default=0)
                    if compare < accu:
                        accu = compare
                        accu_ref = ("replacement",)

                # Fifth option: subtree raising. All examples arriving here go
                # to the raised child, so the box itself is unchanged.
                # NOTE(review): raising is only considered when kk_rais equals
                # the removed part's size + 1 exactly, so the raised child always
                # gets zero raisings; confirm whether `<=` was intended.
                if left_subtree_size + 1 == kk_rais and right_subtree_size >= kk_exch + kk_adj + kk_repl + kk_rais - left_subtree_size - 1:
                    compare = Q[node.right, threshold_key, kk_exch, kk_adj, kk_rais - left_subtree_size - 1, kk_repl]
                    if compare < accu:
                        accu = compare
                        accu_ref = ("raising", node.right, threshold_key, kk_exch, kk_adj, kk_rais - left_subtree_size - 1, kk_repl)

                if right_subtree_size + 1 == kk_rais and left_subtree_size >= kk_exch + kk_adj + kk_repl + kk_rais - right_subtree_size - 1:
                    compare = Q[node.left, threshold_key, kk_exch, kk_adj, kk_rais - right_subtree_size - 1, kk_repl]
                    if compare < accu:
                        accu = compare
                        accu_ref = ("raising", node.left, threshold_key, kk_exch, kk_adj, kk_rais - right_subtree_size - 1, kk_repl)

                # Update table with best value.
                Q[node, threshold_key, kk_exch, kk_adj, kk_rais, kk_repl] = accu
                QQ[node, threshold_key, kk_exch, kk_adj, kk_rais, kk_repl] = accu_ref


# def construct_solution(QQ, k, tree):
#     root_threshold_seq = {}      # dimens,
#     root_threshold_key = gen_threshold_key(root_threshold_seq)

#     stack = [(None, True, tree.root, QQ[tree.root, root_threshold_key, k])]
#     # Solution operations: List of (node, left (True) / right (False))
#     # of nodes plus which subtree to remove (raise the other).
#     # solution_ops = []
#     solution_tree = Tree()
    
#     while(len(stack) > 0):
#         # print(stack)
#         parent, attachment_side, current_node, current_key = stack.pop()

#         if len(current_key) == 2:
#             # The minimum at the current came from /not/ cutting the
#             # current node.

#             # Insert current node into solution tree. If parent is
#             # None, this will insert the root.
#             current_sol_node = solution_tree.insert_inner(current_node.dimension, current_node.threshold, parent = parent, left = attachment_side)

#             # Push keys etc. for children onto stack
#             left_key, right_key = current_key
#             stack.append((current_sol_node, False, current_node.right, right_key))
#             stack.append((current_sol_node, True, current_node.left, left_key))
#         else:
#             # The minimum at the current node came from pruning the
#             # current node and either the left or right child. Find
#             # out which child it is, and push it onto the stack. Keep
#             # the parent the same.
#             assert(len(current_key) == 3)

#             preserved_node, threshold_key, kk = current_key

#             if preserved_node.is_leaf():
#                 # print(f"Adding leaf {preserved_node} below {parent}")
#                 solution_tree.insert_leaf(parent, left = attachment_side, class_label = preserved_node.class_label)
#             else:
#                 new_key = QQ[preserved_node, threshold_key, kk]
#                 stack.append((parent, attachment_side, preserved_node, new_key))

#     return solution_tree


def local_search_tree(tree, df, k_exch, k_adj, k_rais, k_repl):
    """Run the local-search dynamic program on ``tree`` and report timings.

    Leaves are initialised by ``initialize_table`` and inner nodes are filled
    by ``fill_table``. A budget passed as ``None`` defaults to the root's
    ``subtree_size``.

    Args:
        tree: Input decision tree.
        df (pd.DataFrame): Data; the last column is the class.
        k_exch, k_adj, k_rais, k_repl: Maximum numbers of cut exchanges,
            threshold adjustments, subtree raisings and replacements, or
            ``None``.

    Returns:
        tuple: ``(min_errors, min_error_pruned_trees)``.

        * ``min_errors`` maps each budget tuple
          ``(kk_exch, kk_adj, kk_rais, kk_repl)`` whose sum is at most the
          root's ``subtree_size`` to the root's table value for the empty
          threshold sequence.
        * ``min_error_pruned_trees`` is currently always empty, because
          solution reconstruction is not implemented.
    """
    start_time = time.time()

    if k_exch is None:
        k_exch = tree.root.subtree_size
    if k_adj is None:
        k_adj = tree.root.subtree_size
    if k_rais is None:
        k_rais = tree.root.subtree_size
    if k_repl is None:
        k_repl = tree.root.subtree_size

    compute_dimens_threshs_start = time.time()
    tree_dimens, tree_threshs = compute_tree_dimens_threshs(tree)
    compute_dimens_threshs_end = time.time()

    total_dimens = df.columns[:-1]  # The last column is the class.
    total_threshs = compute_relevant_threshs(df, total_dimens, tree_dimens, tree_threshs)
    overall_min = {dim: min(total_threshs[dim]) for dim in total_dimens}
    overall_max = {dim: max(total_threshs[dim]) for dim in total_dimens}

    # Every dimension used by the tree must be a data column.
    assert all(set(dimens).issubset(set(total_dimens)) for dimens in tree_dimens.values())

    compute_errors_start = time.time()
    errors = compute_errors(df, tree, tree_threshs)
    compute_errors_end = time.time()

    tree_size = len(tree.inner_nodes)
    print(f"The initial tree has {tree_size} inner nodes and {errors} errors.")

    # The table maps (node, threshold key, k_exch, k_adj, k_rais, k_repl) to
    # the minimum number of errors when doing *exactly* k_... of the
    # corresponding operation in the node's subtree.
    Q = {}
    # For each inner-node entry of Q, which other entry/entries lead to the
    # minimum; intended for constructing the modified tree.
    QQ = {}

    # Compute table for all leaf nodes.
    initialize_table_start = time.time()
    print("Initializing table")
    initialize_table(Q, tree, df, total_dimens, total_threshs, tree_dimens, tree_threshs, overall_min, overall_max, k_exch, k_adj, k_rais)
    initialize_table_end = time.time()

    # Compute table for all inner nodes using recurrence.
    fill_table_start = time.time()
    print("Filling table with recurrence")
    fill_table(Q, QQ, tree, df, total_dimens, total_threshs, tree_dimens, tree_threshs, overall_min, overall_max, k_exch, k_adj, k_rais, k_repl)
    fill_table_end = time.time()

    # The root is reached by the empty (unrestricted) threshold sequence.
    sol_threshold_key = gen_threshold_key({})

    # TODO: reconstruct the modified trees from QQ (construct_solution is
    # commented out) and also allow pruning all inner nodes of the tree.
    construct_solution_start = time.time()
    min_errors = {}
    min_error_pruned_trees = {}
    for (kk_exch, kk_adj, kk_rais, kk_repl) in product(range(k_exch + 1), range(k_adj + 1), range(k_rais + 1), range(k_repl + 1)):
        if kk_exch + kk_adj + kk_rais + kk_repl > tree.root.subtree_size:
            continue
        min_errors[kk_exch, kk_adj, kk_rais, kk_repl] = Q[tree.root, sol_threshold_key, kk_exch, kk_adj, kk_rais, kk_repl]
    construct_solution_end = time.time()

    end_time = time.time()

    # Report timings
    print(f"Total execution time: {end_time - start_time:.4f} seconds")
    print(f"Compute dimensions and thresholds time: {compute_dimens_threshs_end - compute_dimens_threshs_start:.4f} seconds")
    print(f"Compute errors time: {compute_errors_end - compute_errors_start:.4f} seconds")
    print(f"Initialize table time: {initialize_table_end - initialize_table_start:.4f} seconds")
    print(f"Fill table time: {fill_table_end - fill_table_start:.4f} seconds")
    print(f"Construct solution time: {construct_solution_end - construct_solution_start:.4f} seconds")

    return min_errors, min_error_pruned_trees
