module ub_func_rl

using Random, Distributions
using LinearAlgebra, SparseArrays, Statistics, StatsBase
using MLDataUtils, Clustering
using JuMP
using TimerOutputs: @timeit, get_timer

using Distributed, SharedArrays
using parallel

using lb_func_rl, opt_func_rl, Trees, Nodes_rl, bound_rl


export getBound


# Module-level numeric settings. Declared `const` so they are concretely typed
# (non-const globals are type-unstable in Julia). Neither is read inside this module.
const tol = 1e-6
const max_iter = 1000


# True when the gap between the node bound `UB` and the incumbent `LB` is closed, either
# absolutely (`UB - LB <= mingap`) or relative to the smaller magnitude of the two.
_gap_closed(UB, LB, mingap) = (UB - LB) <= mingap || (UB - LB) <= mingap * min(abs(LB), abs(UB))

"""
    ub_calc(states_proc, Q_value_proc, num_actions, D, lower, upper, dtm_idx, values,
            group_trees, groups, LB, tree, eps, UB_gp, lrg_gap, alpha_s, L_hat, mingap,
            UB_mtd, iter, updateUB=false, UB_small_test=true; seed=42)

Compute the upper bound `UB` of a branch-and-bound node in up to three levels. Each level
is selected by a method name in `UB_mtd`, which may be one name or a collection of names:

1. `"CF"` (required): closed-form bound from [`CF`](@ref), summed over ranks with
   `parallel.sum`.
2. `"MILP"`: runs only if the leaf actions are not fixed yet (`lower.c != upper.c`).
   - The undetermined samples (`values == -typemax(Float64)`) of all ranks are gathered
     on the root and solved with `opt_func_rl.global_OPT_DT_MILP`.
   - `(1/L_hat) * sum of determined values` is added to that result.
   - The result replaces `UB` when smaller; `UB` and `fathom` are then broadcast.
3. `"SG"`: grouping bound from [`SG`](@ref), summed over ranks.
   - The result replaces `UB` when smaller.
   - `SG` may also improve the incumbent `(LB, tree)`.

Levels 2 and 3 run only while the node is not fathomed (see `_gap_closed`).

If `UB_small_test` is set and `parallel.nprocs() == 1`, the SG level first solves 5
randomly chosen groups (global RNG). If their summed objective is `<= LB`, it is returned
immediately as the bound.

`updateUB` is accepted but not used.

Returns `(UB, lower, upper, values, UB_gp, lrg_gap, group_trees, fathom, LB, tree)`.

NOTE(review): the `parallel.*` calls appear to be cross-rank collectives, so every rank
presumably must call this with the same `UB_mtd` — confirm against the `parallel` module.

Throws `ArgumentError` if `"CF"` is not among the methods. Only the CF level initialises
`UB` and the per-sample reachable-leaf lists in `dtm_idx[7]` that the other levels read.
"""
function ub_calc(states_proc, Q_value_proc, num_actions, D, lower, upper, dtm_idx, values, group_trees, groups, LB, tree, eps, UB_gp, lrg_gap, alpha_s, L_hat, mingap, UB_mtd, iter, updateUB = false, UB_small_test=true; seed=42)
    # `in` between two strings throws in Julia, so a single method name is wrapped
    ub_methods = UB_mtd isa AbstractString ? (UB_mtd,) : UB_mtd
    "CF" in ub_methods || throw(ArgumentError("UB_mtd must include \"CF\"; got $(repr(UB_mtd))"))
    fathom = false

    ##### first level: closed-form bound #####
    parallel.root_println("Calling CF in ub_calc")
    UB, lower, upper, dtm_idx, values = CF(states_proc, Q_value_proc, num_actions, D, lower, upper, dtm_idx, values, alpha_s, L_hat)
    UB = parallel.sum(UB) # sum the per-rank bounds over all ranks
    fathom = _gap_closed(UB, LB, mingap)

    ##### second level: MILP over the undetermined samples #####
    # if lower.c == upper.c the leaf actions are fixed and the MILP cannot tighten the bound
    if !fathom && "MILP" in ub_methods && lower.c != upper.c
        values_udt = findall(x->x==-typemax(Float64), values)
        values_dt = findall(x->x!=-typemax(Float64), values)
        # samples are transposed into columns so that parallel.collect stacks them from each rank
        states_milp = parallel.collect(Matrix(states_proc[values_udt,:]'))
        Q_value_milp = parallel.collect(Matrix(Q_value_proc[values_udt,:]'))
        dtmz_all = parallel.collect(dtm_idx[7][values_udt])
        UB_dt = parallel.sum(sum(values[values_dt]))
        parallel.root_println("UB_dt: $(UB_dt)")
        if parallel.is_root()
            # transpose back to one sample per row
            states_milp = Matrix(states_milp')
            Q_value_milp = Matrix(Q_value_milp')
            n_all = size(states_milp, 1)
            lz_all = sparse(Int64[], Int64[], Float64[], n_all, 2^D)
            # upper z: sample i may reach every leaf listed in dtmz_all[i]
            uz_I = Int64[]
            uz_J = Int64[]
            for i in 1:n_all
                append!(uz_I, fill(i, length(dtmz_all[i])))
                append!(uz_J, dtmz_all[i])
            end
            uz_all = sparse(uz_I, uz_J, ones(length(uz_I)), n_all, 2^D)
            lwr = Tree(lower.a, lower.b, lower.c, lower.d, lz_all, D)
            upr = Tree(upper.a, upper.b, upper.c, upper.d, uz_all, D)
            # warm start: leaf assignment of the gathered samples under the current incumbent tree
            tz_all = bound_rl.warm_start_z(states_milp, tree.a, tree.b, D)
            ws_tree = Tree(tree.a, tree.b, tree.c, tree.d, tz_all, D)
            parallel.root_println("Calling global_OPT_DT_MILP in ub_calc")
            _, UB_MILP = opt_func_rl.global_OPT_DT_MILP(states_milp, Q_value_milp, num_actions, D, alpha_s*n_all, L_hat; lower=lwr, upper=upr, dtm_idx=vcat(dtm_idx[1:6], [dtmz_all]), warm_start=ws_tree, mute=true, seed=seed, mingap=mingap)
            UB_MILP += 1/L_hat*UB_dt # add the contribution of the determined samples
            if UB_MILP < UB
                UB = UB_MILP
                fathom = _gap_closed(UB, LB, mingap)
            end
        end
        UB = parallel.bcast(UB)
        fathom = parallel.bcast(fathom)
    end

    ##### third level: sample grouping (SG) #####
    # NOTE: this level relies on values and dtm_idx[7] produced by the CF level above
    if !fathom && "SG" in ub_methods
        if UB_small_test && parallel.nprocs() == 1
            ngroups = length(groups)
            n_sub = 5
            obj_trial = 0
            for _ in 1:n_sub
                i = rand(1:ngroups)
                lwr_i = Tree(lower.a, lower.b, lower.c, lower.d, lower.z[groups[i],:], D)
                upr_i = Tree(upper.a, upper.b, upper.c, upper.d, upper.z[groups[i],:], D)
                dtm = bound_rl.boundIdx_all(lwr_i, upr_i, Vector{Int64}[])
                println("dtm: $dtm")
                ws_i = group_trees[i]
                new_z = bound_rl.warm_start_z(states_proc[groups[i],:], ws_i.a, ws_i.b, D)
                ws_i = Tree(ws_i.a, ws_i.b, ws_i.c, ws_i.d, new_z, D)
                println("Calling global_OPT_DT_SG in ub_calc")
                # NOTE(review): alpha_s is passed unscaled here, whereas SG_solver scales it by the group size
                _, objv, _ = opt_func_rl.global_OPT_DT_SG(states_proc[groups[i],:], Q_value_proc[groups[i],:,:], num_actions, D, alpha_s, L_hat; lower=lwr_i, upper=upr_i, eps=eps, dtm_idx=dtm, w_sos=nothing, lambda = nothing, warm_start = ws_i, mute=true, rlx=false, seed=seed)
                obj_trial += objv
            end
            if obj_trial <= LB
                return obj_trial, lower, upper, values, UB_gp, lrg_gap, group_trees, fathom, LB, tree
            end
        end
        @timeit get_timer("Shared") "Bound Calculation (UB and LB1) " begin
            parallel.root_println("Calling SG in ub_calc")
            UB_SG, UB_gp, lrg_gap, group_trees, LB, tree = SG(states_proc, Q_value_proc, states_proc, Q_value_proc, num_actions, D, group_trees, groups, lower, upper, eps, dtm_idx, values, LB, tree, UB_gp, lrg_gap, alpha_s, L_hat, iter, false, 60*1; seed=seed, mingap=mingap)
            UB_SG = parallel.sum(UB_SG) # sum the per-rank bounds over all ranks
            if UB_SG < UB
                UB = UB_SG
                fathom = _gap_closed(UB, LB, mingap)
            end
        end
    end
    return UB, lower, upper, values, UB_gp, lrg_gap, group_trees, fathom, LB, tree
end

"""
    getBound(states_proc, Q_value_proc, states_rproc, Q_value_rproc, node, node_rand,
             num_actions, D, eps, LB, LB_tree, alpha_s, L_hat, mingap, UB_mtd="SG",
             iter=0, UB_small_test=true; seed=42)

Bound one branch-and-bound node and try to improve the incumbent `(LB, LB_tree)`.

The node's upper bound comes from `ub_calc` (called with `updateUB = true`) and is capped
by `node.UB`. If the gap to `LB` still exceeds `mingap`, both absolutely and relative to
`min(abs(LB), abs(UB))`, the incumbent is refined as follows:
- `lb_func_rl.LB_select` on the bound trees, when `a`, `d` and `c` are already fixed
  (`lower == upper` on those fields);
- with `"SG"`: `CF` + `SG` on the bootstrapped data (`states_rproc`, `Q_value_rproc`)
  of `node_rand`, with trees evaluated on `states_proc`;
- otherwise: the heuristic `lb_func_rl.getLowerBound`, accepted only if it beats `LB`.

`UB_mtd` may be one method name or a collection of names. A single string is wrapped in a
vector, because `in` between two strings throws in Julia.

NOTE(review): `ub_calc` only produces a bound when `"CF"` is among the methods, so the
default `"SG"` on its own is not usable. Callers presumably pass e.g. `["CF", "SG"]`.

Returns `(Node(...), LB, LB_tree, fathom)`. The node carries the updated bounds, values,
group trees, per-group bounds and `UB`.
"""
function getBound(states_proc, Q_value_proc, states_rproc, Q_value_rproc, node, node_rand, num_actions, D, eps, LB, LB_tree, alpha_s, L_hat, mingap, UB_mtd = "SG", iter = 0, UB_small_test=true; seed=42)
    ub_methods = UB_mtd isa AbstractString ? [UB_mtd] : UB_mtd
    lower = node.lower
    upper = node.upper
    values = node.values
    UB_gp = node.UB_gp
    lrg_gap = node.lrg_gap
    # determined/undetermined index sets of the bound variables; Vector{Int64}[] initialises z_udt
    dtm_idx = bound_rl.boundIdx_all(lower, upper, Vector{Int64}[])
    UB, lower, upper, values, UB_gp, lrg_gap, group_trees, fathom, LB, LB_tree = ub_calc(states_proc, Q_value_proc, num_actions, D, lower, upper, dtm_idx, values, node.group_trees, node.groups, LB, LB_tree, eps, UB_gp, lrg_gap, alpha_s, L_hat, mingap, ub_methods, iter, true, UB_small_test; seed=seed)
    UB = min(node.UB, UB) # a child's bound can never exceed its parent's

    @timeit get_timer("Shared") "LB2" begin
        if (UB-LB) > mingap && (UB-LB) > mingap*min(abs(LB), abs(UB))
            if lower.a == upper.a && lower.d == upper.d && lower.c == upper.c
                # LB_select receives the trees as a vector whose elements it may modify
                trees_lb2 = parallel.nprocs() <= 3 ? [lower, upper] : [lower]
                LB_tree, LB = lb_func_rl.LB_select(trees_lb2, LB, LB_tree, states_proc, Q_value_proc, num_actions, D, alpha_s, L_hat, nothing, lower.b, upper.b)
            end
            if "SG" in ub_methods
                # improve the incumbent using the bootstrapped node data
                dtm_idx_rand = bound_rl.boundIdx_all(node_rand.lower, node_rand.upper, Vector{Int64}[])
                _, lower_rand, upper_rand, dtm_idx_rand, values_rand = CF(states_rproc, Q_value_rproc, num_actions, D, node_rand.lower, node_rand.upper, dtm_idx_rand, node_rand.values, alpha_s, L_hat)
                _, _, _, _, LB, LB_tree = SG(states_rproc, Q_value_rproc, states_proc, Q_value_proc, num_actions, D, node_rand.group_trees, node_rand.groups, lower_rand, upper_rand, eps, dtm_idx_rand, values_rand, LB, LB_tree, node_rand.UB_gp, node_rand.lrg_gap, alpha_s, L_hat, iter, false, 60; seed=seed, mingap=mingap)
            else
                # NOTE(review): the original passed an undefined `UB_tree` here (UndefVarError);
                # the incumbent tree `LB_tree` is assumed — confirm against lb_func_rl.getLowerBound.
                node_tree, node_LB = lb_func_rl.getLowerBound(states_proc, Q_value_proc, num_actions, D, alpha_s, L_hat, UB, LB_tree, "heur", node_rand.group_trees, lower, upper)
                if node_LB > LB
                    LB = node_LB
                    LB_tree = node_tree
                end
            end
            parallel.root_println("LB: $LB")
        end
    end

    return Node(lower, upper, node.level, UB, values, node.groups, node.lambda, group_trees, UB_gp, lrg_gap, node.bch_var), LB, LB_tree, fathom
end


"""
    check_CF(state_s, Q_value_s, lower, upper, Tl)

Push a single sample through every branch that the bound trees `[lower, upper]` still
allow, and bound the sample's value.

Nodes are numbered in heap order: node `t` has children `2t` (left) and `2t + 1` (right).
Nodes `t < Tl` are branch nodes and `t >= Tl` are leaves. `Tl` is the number of leaves.

Routing at branch node `t`:
- `upper.d[t] == 0` (the node cannot split): the sample follows right children down to a
  leaf.
- Otherwise, with the candidate features `upper.a[:, t] .== 1`:
  - if every candidate feature value is `< lower.b[t]`, go left only;
  - if none is `< lower.b[t]`, go right, and also go left unless every value is
    `>= upper.b[t]`;
  - in the mixed case, go both ways.

At a reached leaf `t`, the action fixed by `lower.c[:, t] .== 1` supplies the leaf's value
and bound. `lower.c` is indexed by the leaf's heap index `t`; only the local arrays use
`t - Tl + 1`. If no action is fixed, the leaf's bound is `maximum(Q_value_s)`.

Returns `(z_udt, UB, value)`:
- `z_udt`: indices in `1:Tl` of the reachable leaves;
- `UB`: the maximum leaf bound over those leaves;
- `value`: the common Q value if all reachable leaves have a fixed action with the same
  Q value, otherwise `-typemax(Float64)` (the "undetermined" sentinel used by `CF`).

Throws an `ErrorException` if more than one action is fixed at a reached leaf.
"""
function check_CF(state_s, Q_value_s, lower, upper, Tl)
    max_Q = maximum(Q_value_s)
    min_Q = minimum(Q_value_s)
    reached = zeros(Tl)                  # reached[t_i] == 1 if leaf t_i is reachable
    value_leaf = ones(Tl) * (min_Q - 1)  # min_Q - 1 marks "no action fixed at this leaf"
    ub_list = ones(Tl) * min_Q

    queue = [1]                          # start at the root
    while !isempty(queue)
        t = popfirst!(queue)
        if t < Tl
            if upper.d[t] == 0
                # node cannot split: the sample always goes right down to a leaf
                while t < Tl
                    t = 2t + 1
                end
                push!(queue, t)
            else
                fset_t = findall(x -> x == 1, upper.a[:, t])
                below_lower = state_s[fset_t] .< lower.b[t]
                if all(below_lower)
                    push!(queue, 2t)
                elseif !any(below_lower)
                    push!(queue, 2t + 1)
                    # the left child stays possible unless every value clears the upper threshold
                    if !all(state_s[fset_t] .>= upper.b[t])
                        push!(queue, 2t)
                    end
                else
                    push!(queue, 2t)
                    push!(queue, 2t + 1)
                end
            end
        else
            t_i = t - Tl + 1             # position of leaf t among the Tl leaves
            reached[t_i] = 1
            action_indices = findall(@view(lower.c[:, t]) .== 1)
            if isempty(action_indices)
                # action not fixed: the sample may still earn the best Q value
                ub_list[t_i] = max_Q
            else
                if length(action_indices) > 1
                    error("Multiple actions have lower.c[k, t] == 1 for leaf $t")
                end
                k = action_indices[1]
                value_leaf[t_i] = Q_value_s[k]
                ub_list[t_i] = Q_value_s[k]
            end
        end
    end

    z_udt = findall(x -> x == 1, reached)
    reached_values = value_leaf[z_udt]
    # the value is determined only when every reachable leaf fixes the same Q value
    value = if length(unique(reached_values)) == 1 && reached_values[1] > min_Q - 1
        reached_values[1]
    else
        -typemax(Float64)
    end
    UB = maximum(ub_list[z_udt])
    return z_udt, UB, value
end



"""
    CF(states_proc, Q_value_proc, num_actions, D, lower, upper, dtm_idx, values, alpha_s, L_hat)

Closed-form ("CF") upper bound of this rank's share of the objective for the node bounded
by `[lower, upper]`.

For every sample `s` (row of `states_proc`):
- If its value is undetermined (`values[s] == -typemax(Float64)`), `check_CF` is called.
  - Its value is written back into `values[s]`; this may remain the sentinel.
  - Its reachable leaves are recorded via `bound_rl.update_zs!`.
  - Its per-sample bound is added to the total.
- Otherwise `values[s]` itself is added, and its reachable-leaf list is empty.

If either bound is `nothing`, fresh bounds are created with `bound_rl.init_bound`.

After the loop, `upper` is rebuilt with an `n × 2^D` sparse `z` (explicit zeros dropped).
`dtm_idx[7]` is replaced by the per-sample reachable-leaf lists. `values` and `dtm_idx`
are modified in place.

Returns `(UB, lower, upper, dtm_idx, values)`, where
`UB = (1/L_hat) * Σ_s UB_s - alpha_s * n * sum(lower.d)`.
For empty `states_proc`, `0` and the unchanged inputs are returned.
"""
function CF(states_proc, Q_value_proc, num_actions, D, lower, upper, dtm_idx, values, alpha_s, L_hat)
    if isempty(states_proc)
        return 0, lower, upper, dtm_idx, values
    end
    n, p = size(states_proc)
    Tl = 2^D
    if lower === nothing || upper === nothing
        # the bound helpers live in `bound_rl`; `bound` is not imported by this module
        lower, upper = bound_rl.init_bound(p, n, num_actions, D)
    end
    z_udt = Vector{Int64}[]
    UB = 0.0 # this rank's total, before scaling and the complexity penalty
    for s in 1:n
        if values[s] == -typemax(Float64)
            # value not determined yet: route the sample through the bound trees
            z_udt_s, UB_s, value_s = check_CF(states_proc[s, :], Q_value_proc[s, :], lower, upper, Tl)
            values[s] = value_s
            bound_rl.update_zs!(lower, upper, s, z_udt_s)
        else
            # determined samples contribute their value and need no reachable-leaf list
            z_udt_s = Int64[]
            UB_s = values[s]
        end
        push!(z_udt, z_udt_s)
        UB += UB_s
    end
    dropzeros!(upper.z) # remove stored zeros from the sparse matrix
    upr_I, upr_J, upr_V = findnz(upper.z)
    upper = Tree(upper.a, upper.b, upper.c, upper.d, sparse(upr_I, upr_J, upr_V, n, Tl), D) # lower is not rebuilt
    dtm_idx[7] = z_udt
    return 1/L_hat*UB - alpha_s*n*sum(lower.d), lower, upper, dtm_idx, values
end

################### grouping lower and upper bound calculation ###################
"""
    SG_solver(states_proc, Q_value_proc, num_actions, D, alpha, L_hat, group, lower, upper,
              eps, dtm_idx, tree, UB_gp_old, lrg_gap, values, values_udt, values_dt,
              iter, rlx, time_limit; seed=42, mingap=0.0005)

Solve, or reuse, the grouping subproblem for one group of sample indices `group` (row
indices into `states_proc`, `values`, `lower.z` and `dtm_idx[7]`).

Reuse path: the previous group bound `UB_gp_old` is returned unchanged, together with
`new_tree_gp = false`. This happens when all of the following hold:
- the previous group tree `tree` lies inside `[lower, upper]` in every field `a`, `b`,
  `c`, `d`;
- the previous solve was not flagged with a large gap (`!lrg_gap`);
- `iter > 0`.

Solve path:
- The group's undetermined samples (those in `values_udt`) are passed to
  `opt_func_rl.global_OPT_DT_SG`, with the penalty `alpha * length(group)`.
- Bounds restricted to those samples are used, and `tree` provides the warm start.
- `lrg_gap` is set when the returned gap exceeds `mingap`; it is never cleared here.
- `(1/L_hat) * sum(values)` over the group's determined samples is added to the objective.

Returns `(tree, objv, lrg_gap, new_tree_gp)`.
"""
function SG_solver(states_proc, Q_value_proc, num_actions, D, alpha, L_hat, group, lower, upper, eps, dtm_idx, tree, UB_gp_old, lrg_gap, values, values_udt, values_dt, iter, rlx, time_limit; seed=42, mingap=0.0005)
    # componentwise check that x lies inside [lo, hi]
    inside(x, lo, hi) = !any(x .< lo) && !any(x .> hi)
    reusable = inside(tree.a, lower.a, upper.a) && inside(tree.b, lower.b, upper.b) &&
               inside(tree.c, lower.c, upper.c) && inside(tree.d, lower.d, upper.d) &&
               !lrg_gap && iter > 0
    if reusable
        # the previous optimum is still feasible for the tighter bounds: keep its value
        return tree, UB_gp_old, lrg_gap, false
    end

    # Set lookups keep the membership filters linear in the group size
    udt_set = Set(values_udt)
    dt_set = Set(values_dt)
    idx = filter(x -> x in udt_set, group)    # group samples whose value is undetermined
    idx_dt = filter(x -> x in dt_set, group)  # group samples whose value is determined
    lwr_i = Tree(lower.a, lower.b, lower.c, lower.d, lower.z[idx, :], D)
    upr_i = Tree(upper.a, upper.b, upper.c, upper.d, upper.z[idx, :], D)
    dtm_idx_i = vcat(dtm_idx[1:6], [dtm_idx[7][idx]])
    states_i = states_proc[idx, :]
    # the warm-start z must describe the same rows that are handed to the solver
    new_z = bound_rl.warm_start_z(states_i, tree.a, tree.b, D)
    ws_i = Tree(tree.a, tree.b, tree.c, tree.d, new_z, D)
    tree, objv, gap = opt_func_rl.global_OPT_DT_SG(states_i, Q_value_proc[idx, :, :], num_actions, D, alpha*length(group), L_hat; lower=lwr_i, upper=upr_i, eps=eps, dtm_idx=dtm_idx_i, w_sos=nothing, lambda = nothing, warm_start = ws_i, mute=true, rlx=rlx, time=time_limit, seed=seed, mingap=mingap)
    if gap > mingap
        lrg_gap = true
    end
    # add the contribution of the samples whose value is already determined
    objv += 1/L_hat*sum(values[idx_dt])
    return tree, objv, lrg_gap, true
end


"""
    SG(states_ub, Q_value_ub, states_lb, Q_value_lb, num_actions, D, ws_trees, groups,
       lower, upper, eps, dtm_idx, values, LB_old, tree_old, UB_gp_old, lrg_gap_old,
       alpha_s, L_hat, iter=0, rlx=false, time_limit=60; seed=42, mingap=0.005)

Grouping ("SG") bound. For each group in `groups` (row indices into `states_ub`):
- `SG_solver` solves or reuses the group subproblem, warm-started from `ws_trees[i]`.
  Empty groups keep their previous bound `UB_gp_old[i]` and get an empty `Tree()`.
- The group bounds are summed into `UB`.
- Each group tree is passed to `lb_func_rl.distributed_LB`, which may improve the
  incumbent `(LB, tree)` on the data `states_lb` / `Q_value_lb`.

`parallel.barrier()` is called once per group. NOTE(review): every rank presumably must
therefore see the same number of groups.

The returned `tree` carries an `n_lb × 2^D` sparse `z` with a single 1 per row, built
from the per-sample leaf positions.

Returns `(UB, UB_gp, lrg_gap, trees_gp, LB, tree)`.
"""
function SG(states_ub, Q_value_ub, states_lb, Q_value_lb, num_actions, D, ws_trees, groups, lower, upper, eps, dtm_idx, values, LB_old, tree_old, UB_gp_old, lrg_gap_old, alpha_s, L_hat, iter = 0, rlx = false, time_limit = 60; seed=42, mingap=0.005)
    if isempty(states_ub)
        n_ub, p = 0, 0
    else
        n_ub, p = size(states_ub)
    end
    # the LB data set is sized on its own: it can be non-empty when the UB data is empty
    n_lb = isempty(states_lb) ? 0 : size(states_lb, 1)
    ngroups = length(groups)

    UB = 0
    UB_gp = Vector{Float64}(undef, ngroups) # every entry is written in the loop below
    lrg_gap = falses(ngroups)
    trees_gp = [Tree(p, D, n_ub, num_actions) for _ in 1:ngroups]
    values_udt = findall(x->x==-typemax(Float64), values) # samples with undetermined value
    values_dt = findall(x->x!=-typemax(Float64), values)
    LB = LB_old
    tree = Trees.copy_tree(tree_old)
    # leaf of every LB sample, ordered by sample (row). findnz returns entries in
    # column-major order, so they are sorted by row before being paired with 1:n_lb below.
    z_rows, z_cols, _ = findnz(tree_old.z)
    z_pos = z_cols[sortperm(z_rows)]

    for i in 1:ngroups
        parallel.root_println("starting group: $i.")
        group = groups[i]
        ws_tree = ws_trees[i]
        UB_gp_old_i = UB_gp_old[i]
        lrg_gap_old_i = lrg_gap_old[i]
        if length(group) > 0
            tree_i, UB_i, lrg_gap_i, _ = SG_solver(states_ub, Q_value_ub, num_actions, D, alpha_s, L_hat, group, lower, upper, eps, dtm_idx, ws_tree, UB_gp_old_i, lrg_gap_old_i, values, values_udt, values_dt, iter, rlx, time_limit; seed=seed, mingap=mingap)
            UB_gp[i] = UB_i
            lrg_gap[i] = lrg_gap_i
            UB += UB_i
        else
            # dummy (empty) group: keep its previous bound; Tree() marks "no tree" (D = 0)
            tree_i = Tree()
            UB_gp[i] = UB_gp_old_i
            lrg_gap[i] = lrg_gap_old_i
            UB += UB_gp_old_i
        end
        trees_gp[i] = tree_i
        parallel.barrier()
        # try the group tree as a new incumbent on the LB data
        LB, tree, z_pos = lb_func_rl.distributed_LB(tree_i, LB, tree, z_pos, states_lb, Q_value_lb, n_lb, alpha_s, L_hat)
    end

    z = sparse(1:n_lb, z_pos, ones(n_lb), n_lb, 2^D)
    tree = Tree(tree.a, tree.b, tree.c, tree.d, z, D) # z only covers this rank's samples
    return UB, UB_gp, lrg_gap, trees_gp, LB, tree
end

end