module bb_func_rl

using Printf
using JuMP
using Random, LinearAlgebra, SparseArrays
using Statistics, StatsBase, Distances
using MPI
using Distributed, SharedArrays

@everywhere using Trees, Nodes_rl, branch_rl, parallel, opt_func_rl, ub_func_rl, lb_func_rl, groups_rl, bound_rl

export branch_bound

# Hard cap on the number of branch-and-bound iterations. It is very large, so
# presumably the wall-clock budget (`time_lapse` in `branch_bound`) ends the search first.
# `const` so the read in the main loop is type-stable.
const maxiter = 100000000000
# Numerical tolerance; not referenced by `branch_bound` itself (kept for external users).
const tol = 1e-6

"""
    time_finish(seconds)

Return the `time_ns()` reading, rounded to `Int`, that lies `seconds` seconds
after the current time. `branch_bound` uses it as an absolute deadline.
"""
time_finish(seconds) = round(Int, 10^9 * seconds + time_ns())

"""
    _rel_gap(lb, ub)

Relative gap `(ub - lb) / min(|lb|, |ub|)` used in the progress log and in
`calcInfo`. Not guarded against a zero bound (yields `Inf`/`NaN`), exactly like
the inline formula it replaces.
"""
_rel_gap(lb, ub) = (ub - lb) / min(abs(lb), abs(ub))

"""
    _prunable_indices(nodes, LB, mingap)

Return, in ascending order, the indices of `nodes` whose upper bound `n.UB` is
within `mingap` of the global lower bound `LB`, either absolutely or relative
to `min(|LB|, |n.UB|)`. This is a separate function (rather than a closure) so
that the repeatedly reassigned `LB` in `branch_bound` is not boxed.
"""
function _prunable_indices(nodes, LB, mingap)
    idx = Int[]
    for (i, n) in enumerate(nodes)
        gap = n.UB - LB
        if gap <= mingap || gap <= mingap * min(abs(LB), abs(n.UB))
            push!(idx, i)
        end
    end
    return idx
end

"""
    branch_bound(states, Q_value, num_states, num_actions, feature_dim, D, warm_start,
                 LB_init, alpha, L_hat, method="CF", prob=false, obbt=false, val=0;
                 time_lapse=14400, seed=42, mingap=0.0005)

Run the (MPI-parallel) branch-and-bound search, starting from `warm_start` and
the global lower bound `LB_init`.

Nodes are popped from the end of the queue. Per the original comment, the queue
is kept sorted so the last node has the highest UB; presumably `branch!` maintains
this ordering (TODO confirm). Each popped node is re-bounded with `getBound`.
It is then either fathomed or split on the variable chosen by
`branch_rl.SelectVarSequential` on the root rank and broadcast to all ranks.

# Arguments
- `states`, `Q_value`: data forwarded to `groups_rl.proc_data_preparation`.
  `states` is sorted with `sort(states, dims=1)` on root. Presumably it is a
  samples × features matrix (TODO confirm).
- `num_states`, `num_actions`, `feature_dim`: sizes; broadcast from the root rank.
- `D`: forwarded to `getBound` and `SelectVarSequential`. Presumably the tree
  depth (TODO confirm).
- `LB_init`: initial global lower bound.
- `alpha`: scaled by `1 / num_states` before use.
- `L_hat`, `method`: forwarded to the bounding routines.
- `prob`, `obbt`: not referenced in this function.
- `val`: forwarded to the first `proc_data_preparation` call.
- `time_lapse`: wall-clock budget in seconds.
- `seed`: forwarded to the root-node `getBound` call only.
- `mingap`: absolute/relative gap below which queued nodes are discarded.

# Returns
`(tree, LB, calcInfo, min_UB)`. `calcInfo` is filled only on the root rank:
- on hitting the time/iteration limit it gets `[iter, nodes_left, level, LB, UB, rel_gap]`,
  and `min_UB` is set to that `UB`;
- on exhausting the queue it gets `[iter, nodes_left, LB, min_UB, rel_gap]`, where
  `min_UB` is the largest UB among fathomed nodes with `UB >= LB - 1e-10`
  (`-1e15` if there were none).
"""
function branch_bound(states, Q_value, num_states, num_actions, feature_dim, D, warm_start::Tree, LB_init, alpha, L_hat, method = "CF", prob = false, obbt = false, val=0; time_lapse::Int64=14400, seed=42, mingap=0.0005)
    # ---- parameter initialization ----
    n_all = num_states
    alpha_s = alpha / n_all
    if parallel.is_root()
        sortX = sort(states, dims=1) # sorted on each feature, used in lb_func
        eps = vec(mapslices(opt_func_rl.mini_dist, sortX, dims=1)) # eps used in opt_func
    else
        sortX = nothing # only used on root, no need to broadcast
        eps = nothing
    end
    feature_dim = parallel.bcast(feature_dim)
    num_states = parallel.bcast(num_states)
    num_actions = parallel.bcast(num_actions)
    alpha_s = parallel.bcast(alpha_s)
    eps = parallel.bcast(eps)

    # ---- state shared by all ranks ----
    LB = LB_init
    min_UB = -1e15 # largest UB among fathomed nodes with UB >= LB (see fathom branch)

    # distribute data to each process and build the root node for each process
    _, _, states_proc, Q_value_proc, node, tree = groups_rl.proc_data_preparation(states, Q_value, feature_dim, num_states, num_actions, D, warm_start, method, false, 0, val)
    # random subsets used for the UB computation
    _, _, states_rproc, Q_value_rproc, node_rand, _ = groups_rl.proc_data_preparation(states, Q_value, feature_dim, num_states, num_actions, D, warm_start, method, true)

    # bound the root node
    iter = 0
    node, LB, tree, fathom = getBound(states_proc, Q_value_proc, states_rproc, Q_value_rproc, node, node_rand, num_actions, D, eps, LB, tree, alpha_s, L_hat, mingap, method, iter, false; seed=seed)

    # nodeList[i] is a node on this rank's data; UB_list[i] is its upper bound
    nodeList = []
    UB_list = []
    push!(nodeList, node)
    push!(UB_list, node.UB)
    if parallel.is_root()
        println("Iter\tleft\tlev\tLB\tUB\tgap\t")
    end

    # the search stops once `time_lapse` seconds have elapsed (default 14400 s = 4 h)
    end_time = time_finish(time_lapse)
    calcInfo = [] # run summary, filled on root only
    limit_reached = false

    # ---- main loop ----
    while !isempty(nodeList)
        # pop the node with the highest UB (last in the sorted list)
        node = pop!(nodeList)
        UB = pop!(UB_list)
        if parallel.is_root()
            @printf "%-6d %-6d %-10d %-10.4f %-10.4f %-10.4f %s \n" iter length(nodeList) node.level LB UB _rel_gap(LB, UB)*100 "%"
        end

        # The limit check happens after retrieving the node. The root decides and
        # broadcasts the result: every rank computed its own `end_time`, so
        # independent clock checks could let one rank break while others enter
        # the next collective call.
        limit_reached = parallel.bcast((iter >= maxiter) || (time_ns() >= end_time))
        if limit_reached
            if parallel.is_root()
                push!(calcInfo, [iter, length(nodeList), node.level, LB, UB, _rel_gap(LB, UB)])
            end
            break
        end
        iter += 1

        # bound tightening: drop queued nodes whose UB is close to the global LB
        pruned = _prunable_indices(nodeList, LB, mingap)
        deleteat!(nodeList, pruned)
        deleteat!(UB_list, pruned)

        # lower and upper (inside the ub function) bound update on fresh random subsets
        _, _, states_rproc, Q_value_rproc, node_rand, _ = groups_rl.proc_data_preparation(states, Q_value, feature_dim, n_all, num_actions, D, warm_start, method, true, iter)
        node, LB, tree, fathom = getBound(states_proc, Q_value_proc, states_rproc, Q_value_rproc, node, node_rand, num_actions, D, eps, LB, tree, alpha_s, L_hat, mingap, method, iter, false)

        if fathom
            # remember the best UB of a fathomed node if it is not below LB
            parallel.root_println("UB close or lower than LB, fathomed")
            if node.UB > min_UB && node.UB >= LB-1e-10
                parallel.root_println("UB close or lower than LB, fathomed, UB: $(node.UB), LB: $LB")
                min_UB = node.UB
            end
            continue
        end

        # ---- branching: the root chooses, all ranks follow ----
        if parallel.is_root()
            bVarIdx, bVar = branch_rl.SelectVarSequential(node, D)
            println("branching on $bVarIdx with $bVar.")
        else
            bVarIdx = nothing
            bVar = nothing
        end
        bVarIdx = parallel.bcast(bVarIdx)
        bVar = parallel.bcast(bVar)
        if bVar != "stop" # "stop" means no further branching is possible
            # continuous "b" variables are split at the midpoint of their bounds
            bValue = bVar == "b" ? (node.upper.b[bVarIdx] + node.lower.b[bVarIdx]) / 2 : nothing
            branch!(nodeList, UB_list, bVar, bVarIdx, bValue, node, sortX)
        end
    end

    # ---- summary ----
    if parallel.is_root()
        if !limit_reached
            println("all node solved")
            push!(calcInfo, [iter, length(nodeList), LB, min_UB, _rel_gap(LB, min_UB)])
        else
            # limit entry layout: [iter, left, level, LB, UB, gap] -> UB is index 5
            min_UB = calcInfo[end][5]
        end
        println("solved nodes:  ", iter)
        @printf "%-52d  %-14.4e %-14.4e %-7.4f %s \n" iter  LB min_UB _rel_gap(LB, min_UB)*100 "%"
    end
    parallel.barrier()
    return tree, LB, calcInfo, min_UB
end


end