module branch_rl

using Random
using Nodes_rl, Trees, bound_rl, parallel

export branch!

"""
    sltVar_b(lb, ub, D)

Select the `b` entry with the largest depth-weighted interval width `ub - lb`.

The index range `2^(i-1):2^i-1` (presumably level `i` of a heap-ordered tree of depth `D`)
gets weight `D - i + 1`, so entries in earlier ranges are favored. `lb` and `ub` are expected
to be indexable like a vector of length `2^D - 1`.

Returns `(index, "b")`, where `index` is the first position of the maximum weighted width.
"""
function sltVar_b(lb, ub, D)
    # ndigits(k; base=2) is the range number i such that 2^(i-1) <= k <= 2^i - 1.
    depth_weight = [float(D - ndigits(k; base=2) + 1) for k in 1:(2^D - 1)]
    weighted_width = depth_weight .* (ub - lb)
    _, best = findmax(weighted_width)
    return best, "b"
end

"""
    SelectVarSequential(node, D)

Choose the next branching variable for `node`. Returns `(index, kind)` with `kind` one of
`"d"`, `"a"`, `"c"` or `"b"`.

Precedence:
1. the first index returned by `bound_rl.bound_idx` for `d` (presumably the first
   undetermined `d` entry);
2. otherwise, if a uniform random `token` satisfies
   `token <= 1 - 0.5 * maximum(abs.(upr.b - lwr.b))`, the first such index of `a`, then of
   `c`. This becomes less likely as the widest `b` interval grows;
3. otherwise the `b` entry chosen by `sltVar_b(lwr.b, upr.b, D)`.

The random token is drawn on the root process and broadcast, so every process makes the
same choice. This is therefore a collective call: whenever step 1 does not decide the
branch, all processes must call this function together.
"""
function SelectVarSequential(node, D)
    lwr = node.lower
    upr = node.upper

    # `d` always takes precedence.
    d_udt, _ = bound_rl.bound_idx(lwr.d, upr.d)
    if !isempty(d_udt)
        return d_udt[1], "d"
    end

    # rand() differs on each process; draw it on root and broadcast it so that every
    # process selects the same branching variable.
    token = parallel.is_root() ? rand() : nothing
    token = parallel.bcast(token)
    threshold = 1 - 0.5 * maximum(abs.(upr.b - lwr.b))
    if token <= threshold
        a_udt, _ = bound_rl.bound_idx(lwr.a, upr.a)
        if !isempty(a_udt)
            return a_udt[1], "a"
        end
        c_udt, _ = bound_rl.bound_idx(lwr.c, upr.c)
        if !isempty(c_udt)
            return c_udt[1], "c"
        end
    end
    return sltVar_b(lwr.b, upr.b, D)
end


"""
    branch!(nodeList, UB_list, bVar, bVarIdx, bValue, node, sortX)

Split `node` on variable `bVar` (entry `bVarIdx`) at `bValue` and insert the surviving
children into `nodeList`.

`UB_list` holds the `UB` of each entry of `nodeList` at the same position. It must be sorted
ascending, as `searchsortedfirst` requires. Both children inherit `node.UB` and are inserted
at the first position whose UB is `>= node.UB`, which keeps `UB_list` sorted. When both
children survive, the left child ends up directly before the right child.

A child is dropped when any lower bound exceeds its upper bound in `a`, `b`, `c` or `d`, or
when `bound_rl.check_bound_b` reports that it can be fathomed. The fathoming check runs on
the root process and its result is broadcast, so this must be called on all processes.

Returns `nothing`.
"""
function branch!(nodeList, UB_list, bVar, bVarIdx, bValue, node,
                 sortX::Union{Nothing, Matrix{Float64}})
    insert_id = searchsortedfirst(UB_list, node.UB)
    @debug "insert_id: $insert_id"
    # Insert right first, then left at the same position, so left precedes right.
    _insert_child!(nodeList, UB_list, insert_id, node, bVar, bVarIdx, bValue, "right", sortX)
    _insert_child!(nodeList, UB_list, insert_id, node, bVar, bVarIdx, bValue, "left", sortX)
    return nothing
end

# Build the `side` ("left"/"right") child of `node` split at `bValue`, and insert it (and
# its UB) at `insert_id` unless its bounds are inconsistent or it is fathomed.
# Collective: contains a broadcast, so every process must reach it.
function _insert_child!(nodeList, UB_list, insert_id, node, bVar, bVarIdx, bValue, side,
                        sortX)
    lower = Trees.copy_tree(node.lower)
    upper = Trees.copy_tree(node.upper)
    bound_rl.update_bound!(lower, upper, bVar, bVarIdx, bValue, side)
    # Only the root evaluates the fathoming test; broadcasting its result makes all
    # processes take the same keep/drop decision.
    fathom = parallel.is_root() ? bound_rl.check_bound_b(lower, upper, sortX) : nothing
    fathom = parallel.bcast(fathom)
    if _bounds_consistent(lower, upper) && !fathom
        child = Node(lower, upper, node.level + 1, node.UB, copy(node.values), node.groups,
                     nothing, copy(node.group_trees), copy(node.UB_gp), copy(node.lrg_gap),
                     bVar)
        insert!(nodeList, insert_id, child)
        insert!(UB_list, insert_id, node.UB)
    end
    return nothing
end

# True when no lower bound exceeds its upper bound in any of the a/b/c/d components.
_bounds_consistent(lower, upper) =
    !any(lower.a .> upper.a) && !any(lower.b .> upper.b) &&
    !any(lower.c .> upper.c) && !any(lower.d .> upper.d)

end