module lb_func_rl

using DecisionTree, StatsBase
using Printf
using JuMP
using Random, SparseArrays
using Trees, bound_rl, parallel
using opt_func_rl, bound_rl


export warm_start, getLowerBound, CART_base, LB_update, predict_oct

"""
    getLowerBound(states_proc, Q_value_proc, num_actions, D, alpha_s, L_hat, LB_old,
                  LB_tree_old, method = "heur", trees=nothing, lower=nothing,
                  upper=nothing, new_tree_gp = nothing)

Compute a lower-bound tree and its objective value with one of three strategies.

- `"heur"`: fit a CART heuristic tree on the local data via `CART_base`
  (`lower`/`upper` are forwarded but `CART_base` does not use them).
- `"merg"`: merge the candidate `trees` by voting via `LB_update`, then scale the
  raw objective by `1/L_hat` and subtract `alpha_s * sum(tree.d)`.
- any other value (`"slt"`): select the best of `trees` / `LB_tree_old` across
  processes via `LB_select`.

Returns `(tree, objv)`.

NOTE(review): the `"slt"` path penalizes with `alpha_s*n` per process and sums the
scores across processes. The `"heur"`/`"merg"` paths subtract `alpha_s` once and use
only local data. Confirm that this difference in scaling is intended.
"""
function getLowerBound(states_proc, Q_value_proc, num_actions, D, alpha_s, L_hat, LB_old, LB_tree_old, method = "heur", trees=nothing, lower=nothing, upper=nothing, new_tree_gp = nothing)
    if method == "heur"
        n_samples, n_features = size(states_proc)
        return CART_base(states_proc, Q_value_proc, n_samples, num_actions, n_features,
                         D, alpha_s, L_hat; lower=lower, upper=upper)
    elseif method == "merg"
        merged_tree, raw_objv = LB_update(trees, states_proc, Q_value_proc, num_actions, D)
        return merged_tree, 1/L_hat*raw_objv - sum(merged_tree.d)*alpha_s
    end
    # "slt": pick the best candidate tree (possibly with randomized split values)
    return LB_select(trees, LB_old, LB_tree_old, states_proc, Q_value_proc, num_actions,
                     D, alpha_s, L_hat, new_tree_gp, lower, upper)
end

"""
    LB_update(trees, states_proc, Q_value_proc, num_actions, D)

Merge a collection of depth-`D` trees into one tree by voting, then evaluate it.

1. Split indicator `d`: per branch node, the average of the trees' `d` rounded to
   0/1. A small offset makes an exact 50% tie count as a split.
2. Split feature `a`: on every node with `d == 1`, the feature with the largest summed
   `a` across trees (first one on ties).
3. Split value `b`: the mean `b` of the trees that split on that same feature there.
4. Leaf labels `c`: every sample (row of `states_proc`) is routed through the merged
   tree. Each reached leaf is labelled with the most frequent greedy action
   `argmax(Q_value_proc[i, :])` among its samples, stored as a one-hot column.

Returns `(Tree(a, b, c, d, z, D), objv)`. Here `z` is an `n × 2^D` sparse one-hot
sample-to-leaf assignment. `objv = Σ_i Q_value_proc[i, label(leaf(i))]` is the
unscaled, unpenalized objective of the merged tree on these samples.
"""
function LB_update(trees, states_proc, Q_value_proc, num_actions, D)
    ngroups = length(trees)
    n, p = size(states_proc)
    Tb = 2^D - 1   # number of branch nodes
    Tl = 2^D       # number of leaves (also the index of the first leaf)
    T = Tb + Tl

    # one-hot greedy action of every sample
    y = zeros(n, num_actions)
    for i in 1:n
        y[i, argmax(Q_value_proc[i, :])] = 1
    end

    # first, determine node split from d (majority vote, ties -> split)
    d = round.(sum(map(tr -> tr.d, trees)) / ngroups .+ 0.000001)

    # second, determine split feature from a on split nodes
    a = zeros(p, Tb)
    a_sum = sum(map(tr -> tr.a, trees)) # same dims as a
    a_idx = argmax(a_sum, dims = 1)[findall(x -> x == 1.0, d)]
    a[a_idx] .= 1

    # third, determine split value from b of the trees using the chosen feature
    b = zeros(Tb)
    for t in 1:Tb
        # the a_sum factor guards the case d[t] == 1 while no tree splits at t
        if maximum(a[:, t]) * maximum(a_sum[:, t]) != 0
            fea = findall(x -> round(x) == 1.0, a[:, t])[1]
            slt_trees = trees[map(tr -> round(tr.a[fea, t]) == 1, trees)]
            b[t] = mean(map(tr -> tr.b[t], slt_trees))
        end
    end

    # fourth, route samples to leaves and accumulate greedy-action counts per leaf
    c = zeros(num_actions, T)
    leaf = Vector{Int}(undef, n)
    @views for i in 1:n
        t = 1
        while t <= Tb
            if a[:, t]' * states_proc[i, :] + 1e-12 >= b[t]
                t = 2t + 1
            else
                t = 2t
            end
        end
        leaf[i] = t
        c[:, t] .+= y[i, :]
    end
    z = sparse(1:n, leaf .- Tb, ones(n), n, Tl)

    # turn the leaf counts into one-hot labels (majority greedy action)
    for t in (Tb + 1):T
        if maximum(c[:, t]) != 0
            c_idx = argmax(c[:, t])
            c[:, t] .= 0
            c[c_idx, t] = 1
        end
    end

    # objective: Q-value of the action chosen at each sample's leaf. Every reached
    # leaf received at least one count, so it always carries a label.
    objv = 0.0
    for i in 1:n
        objv += Q_value_proc[i, argmax(c[:, leaf[i]])]
    end
    return Tree(a, b, c, d, z, D), objv
end


"""
    warm_start(states, labels, num_states, num_actions, feature_dim, D = 4, prune_val = 1.0)

Fit a depth-`D` CART classifier (DecisionTree.jl) to `labels` and convert it into the
fixed-depth `Tree` representation used by this module.

Arguments:
- `states`: sample matrix handed to `DecisionTree.fit!`.
- `labels`: class (action) label per sample. These are assumed to lie in
  `1:num_actions`, since they index rows of the one-hot label matrix.
- `num_states`: number of samples. The minimum leaf size is 5% of it, rounded up.
- `num_actions`: number of rows of the returned one-hot leaf-label matrix.
- `feature_dim`: number of rows of the split-feature matrix `a`.
- `D`: maximum tree depth. `prune_val`: DecisionTree `pruning_purity_threshold`.

Side effects: reseeds the global RNG with `Random.seed!(1)` and prints diagnostics via
`parallel.root_println`.

Returns `Tree(a, b, c_bin, d, z, D)`, where `c_bin[k, t] == 1` iff leaf node `t`
predicts label `k`.
"""
function warm_start(states, labels, num_states, num_actions, feature_dim, D = 4, prune_val = 1.0)
    Tb = 2^D - 1
    T = 2^(D + 1) - 1
    Random.seed!(1)
    # Older DecisionTree (v0.10.12) equivalent: DecisionTree.build_tree(labels, states, 0, D)
    # min_samples_leaf is a sample count: round up to an Int instead of passing a Float.
    cart_model = DecisionTreeClassifier(pruning_purity_threshold=prune_val,
                                        max_depth=D,
                                        min_samples_leaf=ceil(Int, 0.05*num_states))
    DecisionTree.fit!(cart_model, states, labels)

    # leaf nodes occupy indices 2^D:T; older DecisionTree versions exposed `.node`
    a, b, c, d = warm_start_params(zeros(feature_dim, Tb), zeros(Tb), zeros(T), zeros(Tb),
                                   1, cart_model.root, 2^D:T)

    parallel.root_println("cart_model: labels: $labels")
    parallel.root_println("cart_mode is fitted!")
    parallel.root_println("c: $c")
    z = bound_rl.warm_start_z(states, a, b, D)

    # convert per-node integer labels (0 = no label) to one-hot columns
    c_bin = zeros(num_actions, T)
    for t in 1:T
        k = Int(c[t])
        if k != 0
            c_bin[k, t] = 1
        end
    end
    return Tree(a, b, c_bin, d, z, D)
end

"""
    CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alpha, L_hat;
              lower=nothing, upper=nothing, prune_val=0.0)

Fit a CART heuristic tree to the greedy actions and score it.

The steps are:
1. Label each of the first `num_states` samples with its greedy action
   `argmax(Q_value[i, :])`.
2. Fit a depth-`D` tree on those labels with `warm_start`.
3. Route the samples through the tree with `predict_oct`.
4. Compute the objective `1/L_hat * Σ_i Q_value[i, pred_i] - alpha * sum(tree.d)`.

`lower` and `upper` are accepted for interface compatibility but are unused.

NOTE(review): `prune_val` defaults to 0.0 here, while `warm_start` defaults to 1.0. In
DecisionTree.jl a purity threshold below 1.0 enables pruning, and 0.0 may collapse the
tree to a single leaf. Confirm this default is intended.

Returns `(tree, objv)`.
"""
function CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alpha, L_hat; lower=nothing, upper=nothing, prune_val=0.0)
    greedy = Int64[Int(argmax(Q_value[k, :])) for k in 1:num_states]
    tree = warm_start(states, greedy, num_states, num_actions, feature_dim, D, prune_val)
    pred, _ = predict_oct(states, greedy, tree.a, tree.b, round.(tree.c))

    q_total = 0
    for k in 1:num_states
        q_total += Q_value[k, Int(pred[k])]
    end
    return tree, 1/L_hat*q_total - alpha*sum(tree.d)
end


"""
    predict_oct(X, y, a, b, c)

Route every row of `X` through the tree `(a, b, c)` and return leaf predictions.

At branch node `t` a sample goes right when `a[:, t]' * x + 1e-12 >= b[t]`, otherwise
left. `y` is not used.

The number of leaves is derived from `c`:
- JuMP `DenseAxisArray`: `c` is indexed over leaves only. Use its column count when the
  underlying data is a `Matrix`, otherwise its length.
- `Vector` / matrix: `c` spans all nodes, so the branch count `length(b)` is subtracted.

Returns `(pred, z_pos)`, both `Float64` vectors. `pred[i] = leaf_label(c, leaf)` and
`z_pos[i]` is the 1-based leaf offset (`leaf - length(b)`).
"""
function predict_oct(X, y, a, b, c)
    n, _ = size(X)
    Tb_idx = length(b)

    Tl_idx = if c isa JuMP.Containers.DenseAxisArray
        c.data isa Matrix ? size(c)[2] : length(c)
    elseif c isa Vector
        length(c) - Tb_idx
    else # matrix over all nodes
        size(c)[2] - Tb_idx
    end

    leaf_nodes = (Tb_idx + 1):(Tb_idx + Tl_idx)
    pred = Array{Float64}(undef, n)
    z_pos = Array{Float64}(undef, n)
    for row in 1:n
        node = 1
        while node ∉ leaf_nodes
            go_right = a[:, node]' * X[row, :] + 1e-12 >= b[node]
            node = go_right ? 2node + 1 : 2node
        end
        pred[row] = leaf_label(c, node)
        z_pos[row] = node - Tb_idx
    end
    return pred, z_pos
end

"""
    distributed_LB(tree_i, LB, tree, z_pos, states_proc, Q_value_proc, n, alpha_s, L_hat)

Collect one candidate tree from every process and keep the best one if it improves `LB`.

The steps are:
1. Gather every process's `tree_i` with `parallel.allcollect`.
2. Score each candidate on this process's data with `objv_cost`, using penalty
   `alpha_s*n`. A dummy candidate (`D == 0`) scores `-Inf` with leaf vector `[-1]`,
   so it is never chosen.
3. Combine the per-process scores with `parallel.sumv!`. The original author described
   this as an all-reduce; presumably every process then sees the same totals.
4. Pick the highest total. If it exceeds `LB`, return that score, that candidate tree
   and this process's leaf assignment for it. Otherwise `LB`, `tree` and `z_pos` are
   returned unchanged.

The function calls `parallel.barrier()` before the reduction and before returning.
"""
function distributed_LB(tree_i, LB, tree, z_pos, states_proc, Q_value_proc, n, alpha_s, L_hat)
    candidates = parallel.allcollect([tree_i])::Vector{Tree}

    scores = Float64[]
    leaf_sets = Vector{Int64}[]
    for cand in candidates
        score, leaves = cand.D == 0 ? (-Inf, [-1]) :
            objv_cost(states_proc, Q_value_proc, n, cand, alpha_s*n, L_hat)
        push!(scores, score)
        push!(leaf_sets, leaves)
    end
    parallel.barrier()

    # sum local scores across processes, then take the best (maximum) lower bound
    parallel.sumv!(scores)
    best_score, best_idx = findmax(scores)
    if best_score > LB
        LB, tree, z_pos = best_score, candidates[best_idx], leaf_sets[best_idx]
    end
    parallel.barrier()
    return LB, tree, z_pos
end

"""
    objv_cost(states_proc, Q_value_proc, n, tree, alpha, L_hat)

Evaluate the penalized objective of `tree` on the first `n` rows of `states_proc`.

Each sample `i` is routed from the root. At branch node `t` it goes right when
`tree.a[:, t]' * x + 1e-12 >= tree.b[t]`. The number of branch nodes is
`length(tree.d)`. At the reached leaf `t` the sample contributes
`sum(tree.c[:, t] .* Q_value_proc[i, :])`. If the leaf's action column is all zeros,
it instead contributes the worst value `minimum(Q_value_proc[i, :])`.

Returns `(1/L_hat * total - alpha * sum(tree.d), z_pos)`. `z_pos[i]` is the 1-based
leaf offset of sample `i`. When `states_proc === nothing` it returns `(-Inf, [-1])`.
"""
function objv_cost(states_proc, Q_value_proc, n, tree, alpha, L_hat)
    states_proc === nothing && return -Inf, [-1]

    Tb = length(tree.d)
    objv = 0.0   # Float accumulator: keeps the running sum type-stable
    z_pos = Vector{Int64}(undef, n)
    @views for i in 1:n
        t = 1
        while t <= Tb
            if tree.a[:, t]' * states_proc[i, :] + 1e-12 >= tree.b[t]
                t = 2t + 1
            else
                t = 2t
            end
        end
        # TODO: revisit action selection; currently the leaf column weights Q directly
        action = tree.c[:, t]
        q_row = Q_value_proc[i, :]
        if sum(action) == 0
            # leaf without an action: charge the worst Q-value (conservative bound)
            objv += minimum(q_row)
        else
            objv += sum(action .* q_row)
        end
        z_pos[i] = t - Tb
    end
    return 1/L_hat*objv - alpha*sum(tree.d), z_pos
end


# function distributed_UB(tree_i, UB, tree, z_pos, states_proc, Q_value_proc, n, alpha_s, L_hat)
#     #println(tree)
#     #println(z_pos)
#     tree_list = parallel.allcollect([tree_i])
#     UB_list = Float64[]
#     z_pos_list = Vector{Int64}[]
#     for t_i in tree_list::Vector{Tree}
#         if t_i.D != 0 
#             UB_gp_i, z_pos_gp_i = objv_cost(states_proc, Q_value_proc, n, t_i, alpha_s*n, L_hat) # here X, Y are global data and label.
#             # upper bound info gathering
#             push!(UB_list, UB_gp_i)
#             push!(z_pos_list, z_pos_gp_i)
#         else # only dummy tree has D = 0 and will not be used in UB calculation, set UB to Inf so that not to be selected
#             push!(UB_list, Inf)
#             push!(z_pos_list, [-1])
#         end
#     end
#     parallel.barrier()
#     # allreduce the ub results and get the best(minimum) UB
#     parallel.sumv!(UB_list)
#     UB_i = minimum(UB_list)
#     proc_idx = argmin(UB_list)
#     # update the UB, tree and z_pos 
#     if UB_i < UB
#         UB = UB_i
#         tree = tree_list[proc_idx]
#         z_pos = z_pos_list[proc_idx]
#     end
#     parallel.barrier()
#     return UB, tree, z_pos
# end

"""
    LB_select(trees, LB_old, LB_tree_old, states_proc, Q_value_proc, num_actions, D,
              alpha_s, L_hat, new_tree_gp = nothing, lwr_b=nothing, upr_b=nothing)

Starting from `(LB_old, LB_tree_old)`, try every candidate `trees[i]` with
`new_tree_gp[i] == true` (all candidates when `new_tree_gp` is `nothing`). Each is
passed to `distributed_LB`, which keeps the best lower bound found across processes.

When `lwr_b` is given, each candidate's split values are replaced by uniform draws in
`[lwr_b, upr_b]`. The global RNG is reseeded with `i*(parallel.myid()+1)` for this.

Returns `(tree, LB)`. The returned tree carries an `n × 2^D` sparse one-hot
sample-to-leaf matrix `z` built from the winning leaf assignment.
"""
function LB_select(trees, LB_old, LB_tree_old, states_proc, Q_value_proc, num_actions, D, alpha_s, L_hat, new_tree_gp = nothing, lwr_b=nothing, upr_b=nothing)
    n, _ = size(states_proc)
    Tb = 2^D - 1
    Tl = 2^D
    n_trees = length(trees)
    if new_tree_gp === nothing
        new_tree_gp = trues(n_trees)
    end

    LB = LB_old
    tree = Trees.copy_tree(LB_tree_old)
    # findnz yields entries in column-major order, so scatter the column (leaf)
    # indices back by row to recover the per-sample leaf assignment.
    old_rows, old_cols, _ = findnz(LB_tree_old.z)
    z_pos = zeros(Int, size(LB_tree_old.z, 1))
    z_pos[old_rows] = old_cols

    for i in 1:n_trees
        new_tree_gp[i] || continue
        tree_i = Trees.copy_tree(trees[i])
        if lwr_b !== nothing
            Random.seed!(i*(parallel.myid()+1))
            tree_i.b[:] = rand(Tb).*(upr_b .- lwr_b) .+ lwr_b
        end
        LB, tree, z_pos = distributed_LB(tree_i, LB, tree, z_pos, states_proc, Q_value_proc,
                                         n, alpha_s, L_hat)
    end

    # explicit dims: keep all 2^D leaf columns even when the last leaves are empty
    z = sparse(1:n, z_pos, ones(n), n, Tl)
    tree = Tree(tree.a, tree.b, tree.c, tree.d, z, tree.D)
    return tree, LB
end

"""
    warm_start_params(a, b, c, d, t, node, Tl)

Recursively copy a DecisionTree.jl tree rooted at `node` into the fixed-depth arrays
`(a, b, c, d)`, with `node` placed at node index `t` (children at `2t` and `2t+1`).

- Internal node: set `a[featid, t] = 1`, `b[t] = featval` and `d[t] = 1`, then recurse
  into the left child and then the right child.
- Leaf: follow right children from `t` until an index in `Tl` (the leaf index range) is
  reached, and store the rounded majority label there in `c`. The intermediate nodes
  keep whatever `a`/`b` they already hold. When these start as zeros, as in
  `warm_start`, the router's rule `0 + 1e-12 >= 0` always goes right, so samples reach
  that same right-most leaf.

Returns the (mutated) `a, b, c, d`.
"""
function warm_start_params(a,b,c,d,t,node,Tl)
    if !(node isa Leaf)
        a[node.featid, t] = 1
        b[t] = node.featval
        d[t] = 1
        for (child_idx, child) in ((2t, node.left), (2t + 1, node.right))
            a, b, c, d = warm_start_params(a, b, c, d, child_idx, child, Tl)
        end
        return a, b, c, d
    end

    leaf_idx = t
    while leaf_idx ∉ Tl
        leaf_idx = 2leaf_idx + 1
    end
    c[leaf_idx] = convert(Int, round(node.majority))
    return a, b, c, d
end

end