def set_keys(η):
    """
        Assign heap-style integer keys to every node of the subtree rooted at η.

        This function must be called on the Root node of the tree before printing it.
        The root receives key 1; a left child receives ``2 * parent.key`` and a
        right child ``2 * parent.key + 1``. Only the root gets ``keys_set = True``.

        Parameters:
            η : Node to label. Must provide is_root(), is_leaf(), is_left_child(),
                is_right_child() and the attributes parent, left and right.
                A non-root node is expected to be either a left or a right child,
                and a non-leaf node is expected to have both children.

        Returns:
            1 (always).
    """
    if η.is_root():
        η.key = 1
        η.keys_set = True  # Indicate that the keys have been set.
    elif η.is_right_child():
        η.key = 2 * η.parent.key + 1
    elif η.is_left_child():
        η.key = 2 * η.parent.key

    # Children derive their keys from η.key, so recurse only after it is set.
    if not η.is_leaf():
        set_keys(η.left)
        set_keys(η.right)
    return 1



def logistic(x):
    """
        Element-wise logistic (sigmoid) function 1 / (1 + exp(-x)).

        Implemented with torch.sigmoid, which returns the same values but has a
        stable backward pass. With the hand-written 1/(1 + torch.exp(-x)), a large
        negative x makes exp(-x) overflow to inf. The forward value is still 0,
        but autograd then evaluates the gradient as 0 * inf = NaN.

        Parameters:
            x : Input values; expected to be a torch.Tensor.

        Returns:
            Tensor of the same shape as x, element-wise sigmoid of x.
    """
    return torch.sigmoid(x)


def create_tree(α, max_depth, η, tree_cursor=1):
    """
        Creates tree structure given an array α of Bernoulli splits.

        Nodes are addressed by 1-based heap indices. The root is tree_cursor = 1,
        and the children of node c are 2c (left) and 2c + 1 (right). Node c
        therefore sits at depth c.bit_length() - 1, with the root at depth 0.

        Recurses until
            (a) The split was unsuccessful (α[tree_cursor - 1] == 0)
            (b) The maximal depth of the tree in the variational family is
                reached, i.e. the node is at depth max_depth.
        In either case the node stays a leaf and its split_reference is set to
        tree_cursor - 1.

        Parameters:
            α : Array used to construct the splits. Entry tree_cursor - 1 decides
                whether node tree_cursor is split (nonzero) or not (zero). It is
                only read for nodes above depth max_depth, so at most
                2**max_depth - 1 entries are ever accessed.
            max_depth : maximum depth that is allowed in the tree
            η : Node currently being split. The root of a tree is passed in when
                calling this function. It is modified in place.
            tree_cursor : Heap index of η (1 for the root).

        Returns:
            η, the same node object that was passed in, after its subtree has been built.
    """
    # Base cases (don't split). floor(log2(2 * c)) == c.bit_length() for a
    # positive integer c, so this stops exactly at depth max_depth using integer
    # arithmetic. `or` short-circuits, so α is not indexed once the depth limit
    # is hit.
    if int(tree_cursor).bit_length() >= max_depth + 1 or int(α[tree_cursor - 1]) == 0:
        η.split_reference = tree_cursor - 1
        return η

    # Split this node. η.split receives the index tree_cursor - 1, presumably so
    # the node can be tied to its split parameters ϕ[tree_cursor - 1].
    η.split(tree_cursor - 1)

    # Recurse on the children that η.split is expected to have created.
    create_tree(α, max_depth, η.left, 2 * tree_cursor)
    create_tree(α, max_depth, η.right, 2 * tree_cursor + 1)

    return η




# Module-level dependencies of logistic() (torch) and create_tree() (numpy).
# Imported unconditionally so they are available no matter which name this
# module is imported under (e.g. as part of a package) or whether it is run directly.
import torch
import numpy as np
