module opt_func_rl

using Printf
using JuMP
# using CPLEX
# import Pkg
# Pkg.add("Gurobi")
using Gurobi

using Random, SparseArrays
using InteractiveUtils
using Trees, bound_rl, parallel

# Default solver time limit in seconds (5 minutes); used as the default `time`
# keyword of `global_OPT_DT_SG` and `global_OPT_DT_MILP`.
const time_lapse = 60 * 5
# Extended time budget in seconds: 300 * 60 * 5 = 90_000 s (25 hours).
# NOTE(review): the previous comment said "4 hours", which does not match the
# value, and this constant is not referenced inside this module — confirm intent.
const time_lapse_glb = time_lapse * 60 * 5

export global_OPT_DT_SG

"""
    global_OPT_DT_SG(states, Q_value, num_actions, D, alpha, L_hat; kwargs...)

Build and solve the MIP that fits a depth-`D` decision-tree policy over `states`.

The objective maximises `1/L_hat * Σ Pi[i,a] * Q_value[i,a]` minus `alpha` times
the number of active split nodes `Σ d[t]`.

# Arguments
- `states`: matrix with one sample per row (`n × p`).
- `Q_value`: `n × num_actions` matrix of action values.
- `num_actions`: number of actions (leaf labels).
- `D`: tree depth. The tree has `T = 2^(D+1)-1` nodes. Branch nodes are `1:Tb`
  with `Tb = 2^D-1`; leaves are `(Tb+1):T`.
- `alpha`: penalty per active split.
- `L_hat`: normaliser of the Q-value term.

# Keywords
- `lower`, `upper`: bound objects. `lower.b`/`upper.b` bound the thresholds
  directly. `lower.a`, `lower.d`, `lower.c`, `lower.z` are forwarded to the
  `var_*` helpers.
- `eps`: per-feature gaps used by the left-branch split constraint.
- `dtm_idx`: 7 index sets `(a_udt, a_dt, d_udt, d_dt, c_udt, c_dt, z_udt)` of
  undetermined (`_udt`) and determined (`_dt`) entries.
- `warm_start`: `Tree` whose `a`, `b`, `c`, `d`, `z` provide start values. A
  default tree is built when `nothing`.
- `mute`, `time`, `seed`, `mingap`: forwarded to `optimizer_init`.
- `rlx`: relax the split/routing variables to `[0, 1]`.
- `w_sos`, `lambda`: accepted for interface compatibility; unused.

# Returns
`(tree, objv, gap)`:
- `tree`: the fitted `Tree`, built with `nothing` for `z`.
- `objv`: the solver's objective bound.
- `gap`: the relative MIP gap.
"""
function global_OPT_DT_SG(states, Q_value, num_actions, D, alpha, L_hat; lower=nothing, upper=nothing, eps=nothing, dtm_idx=nothing, w_sos=nothing, lambda = nothing, warm_start = nothing, mute=false, rlx = false, time = time_lapse, seed=42, mingap=0.005)
    # --- dimensions ---
    n, p = size(states)
    T = 2^(D+1)-1          # total number of nodes
    Tb = Int(floor(T/2))   # branch nodes 1:Tb (== 2^D - 1); leaves are (Tb+1):T
    eps_max = maximum(eps) # big-M slack used in the left-branch split constraint

    # undetermined (*_udt) / determined (*_dt) index sets per variable family
    a_udt = dtm_idx[1]
    a_dt = dtm_idx[2]
    d_udt = dtm_idx[3]
    d_dt = dtm_idx[4]
    c_udt = dtm_idx[5]
    c_dt = dtm_idx[6]
    z_udt = dtm_idx[7]

    # Fallback warm start. The label matrix is read as c_w[k, t] for
    # k in 1:num_actions by var_c, so it must be num_actions × T. This previously
    # referenced an undefined `K` and threw an UndefVarError.
    if warm_start === nothing
        warm_start = Tree(zeros(p, Tb), rand(Tb), zeros(num_actions, T), zeros(Tb), spzeros(n, Tb+1), D)
    end

    m = optimizer_init(mute, time, seed, mingap)

    # --- variables ---
    # a[j,t]: feature j is used at branch node t (binary, or [0,1] if rlx)
    m, a = var_a(m, lower.a, a_udt, a_dt, warm_start.a, p, Tb, rlx)
    # b[t]: split threshold of branch node t, boxed by the supplied bounds
    @variable(m, lower.b[t] <= b[t in 1:Tb] <= upper.b[t], start = warm_start.b[t])
    # d[t]: 1 if branch node t performs a split
    m, d = var_d(m, lower.d, d_udt, d_dt, warm_start.d, Tb, rlx)
    # c[k,t]: leaf node t is labelled with action k (never relaxed)
    m, c = var_c(m, lower.c, c_udt, c_dt, warm_start.c, num_actions, T, Tb, false)
    # z[i,t]: sample i is routed to leaf t, with leaves numbered 1:(Tb+1)
    if z_udt isa Vector{Int64} # SG-only mode
        # NOTE(review): var_z_sg is not defined in this module's visible source;
        # presumably provided by an imported module — verify.
        m, z = var_z_sg(m, z_udt, warm_start.z, n, T, Tb, rlx)
    else                       # CF+SG mode
        m, z = var_z(m, lower.z, z_udt, warm_start.z, n, T, Tb, rlx)
    end
    # Pi[i,a]: the policy picks action a for sample i
    @variable(m, Pi[i in 1:n, a in 1:num_actions], Bin)

    # --- constraints ---
    m = cons_setup_node(m, states, p, z_udt, Tb, eps, eps_max, a, b, d, z)
    m = cons_setup_leaf_SG(m, n, num_actions, T, Tb, c, z, z_udt, Pi)
    m = cons_setup_rl(m, Pi, n, num_actions)

    # --- objective: normalised Q-value of the chosen actions minus split penalty ---
    @objective(m, Max, 1/L_hat*sum(Pi[i,a] * Q_value[i,a] for i in 1:n, a in 1:num_actions) - alpha*sum(d[t] for t in 1:Tb))

    optimize!(m)
    a = value.(a)
    b = value.(b)
    c = round.(value.(c))   # leaf labels are binary; round away solver noise
    d = value.(d)
    objv = objective_bound(m)  # best bound reported by the solver

    # A model without integer/binary variables is a pure LP and has no MIP gap.
    # (Pi is always declared binary above, so this normally falls through.)
    no_integers = MOI.get(m, MOI.NumberOfConstraints{MOI.VariableIndex,MOI.Integer}()) == 0 &&
                  MOI.get(m, MOI.NumberOfConstraints{MOI.VariableIndex,MOI.ZeroOne}()) == 0
    gap = no_integers ? 0.0 : relative_gap(m)

    # z is not stored in the returned tree
    tree = Tree(a, b, c, d, nothing, D)
    return tree, objv, gap
end

"""
    direct_iteration(states, Q_value, D, num_actions, last_a, last_b, last_c, last_d,
                     a_udt, a_dt, b_udt, b_dt, c_udt, c_dt, d_udt, d_dt, alpha=0.0;
                     time_limit=60*5, seed=42, mingap=0.005)

Re-solve the depth-`D` tree-policy MIP with only part of the tree left free.

The tree columns listed in `a_udt`, `b_udt` and `c_udt` are decision variables,
warm-started from `last_a`, `last_b` and `last_c`. The columns listed in `a_dt`,
`b_dt` and `c_dt` are fixed to those same previous values. The objective maximises
`Σ Pi[i,a] * Q_value[i,a]`, with no split penalty.

`last_d`, `d_udt`, `d_dt` and `alpha` are accepted but unused; the `d` handling is
commented out.

Returns `(a, b, c, objv_value, gap, Pi_opt)`:
- `a`, `b`, `c`: the solved split features, thresholds and (rounded) leaf labels.
- `objv_value`: the solver's objective bound.
- `gap`: the relative MIP gap.
- `Pi_opt`: an `n × 1` Int matrix of chosen action indices.

NOTE(review): columns of `a`, `b`, `c` not covered by the `*_udt`/`*_dt` sets stay
`#undef`. The constraints below read every column of `a` and `b` (`1:Tb`) and every
leaf column of `c`, so the index sets presumably must cover them — confirm with
callers.
"""
function direct_iteration(states, Q_value, D, num_actions, last_a, last_b, last_c, last_d, a_udt, a_dt, b_udt, b_dt, c_udt, c_dt, d_udt, d_dt, alpha=0.0; time_limit=60*5, seed=42, mingap=0.005)
    n, p = size(states)
    T = 2^(D+1)-1   # total number of nodes
    Tb = 2^D-1      # number of branch nodes; leaves are (Tb+1):T
    # Tb = Int(floor(T/2))
    # Per-feature smallest non-zero gap between sorted values; it is added on the
    # left branch so that samples equal to a threshold are routed right.
    sortX = sort(states, dims=1)
    eps = vec(mapslices(opt_func_rl.mini_dist, sortX, dims=1)) # eps used in opt_func
    eps_min = minimum(eps)
    eps_max = maximum(eps)
    println("eps_min: $(eps_min)")
    println("eps_max: $(eps_max)")
    m = optimizer_init(true, time_limit, seed, mingap)
    
    
    # Var_a
    # Untyped containers mixing JuMP variables (free columns) and constants (fixed columns)
    a = Matrix(undef, p, Tb)
    b = Vector(undef, Tb)
    c = Matrix(undef, num_actions, T)
    # d = Vector(undef, Tb)

    # Decision variables for the undetermined (free) columns only
    @variable(m, a_var[j in 1:p, i in 1:length(a_udt)], Bin)
    @variable(m, 0 <= b_var[i in 1:length(b_udt)])
    @variable(m, c_var[j in 1:num_actions, i in 1:length(c_udt)], Bin)
    # @variable(m, d_var[i in 1:length(d_udt)], Bin)

    # Free columns of a: warm-start from last_a and place the variables into a
    for i in 1:length(a_udt)
        for j in 1:p
            set_start_value(a_var[j,i], last_a[j,a_udt[i]])
            a[j,a_udt[i]] = a_var[j,i]
        end
    end
    # println("a: $(a)")
    # Fixed columns of a: copy the previous values as constants
    for idx in a_dt
        for j in 1:p
            a[j,idx] = last_a[j,idx]
        end
    end
    # println("a: $(a)")
    # Var_b: free thresholds are variables, fixed thresholds are constants
    for i in 1:length(b_udt)
        set_start_value(b_var[i], last_b[b_udt[i]]);
        b[b_udt[i]] = b_var[i]
    end
    for idx in b_dt
        b[idx] = last_b[idx]
    end

    # Var_c: free leaf labels are variables, fixed labels are constants
    for i in 1:length(c_udt)
        for j in 1:num_actions
            set_start_value(c_var[j,i], last_c[j,c_udt[i]]);
            c[j,c_udt[i]] = c_var[j,i]
        end
    end
    for idx in c_dt
        for j in 1:num_actions
            c[j,idx] = last_c[j,idx]
        end
    end

    # Var_d
    # for i in 1:length(d_udt)
    #     set_start_value(d_var[i], last_d[d_udt[i]]);
    #     d[d_udt[i]] = d_var[i]
    # end
    # for idx in d_dt
    #     d[idx] = last_d[idx]
    # end
    
    # @variable(m, a[j in 1:p, t in 1:Tb], Bin);
    # @variable(m, 0 <= b[t in 1:Tb]);
    # @variable(m, c[k in 1:num_actions, t in 1:T], Bin);
    # @variable(m, d[t in 1:Tb], Bin);
    # z[i,t]: sample i lands in leaf t (leaf t corresponds to node t+Tb)
    @variable(m, z[i in 1:n, t in 1:(Tb+1)], Bin)
    # Pi[i,a]: policy chooses action a for sample i
    @variable(m, Pi[i in 1:n, a in 1:num_actions], Bin)

    # Every branch node splits on exactly one feature (no pruning in this variant)
    @constraint(m, [t in 1:Tb], sum(a[j,t] for j in 1:p) == 1);
    @constraint(m, [t in 1:Tb], b[t] <= 1);
    # @constraint(m, [t in 2:Tb], d[t] <= d[Int(floor(t/2))]);

    # @constraint(m, [t in 1:Tb], d[t] <= 1);
    # Routing constraints: if z[i,t] = 1, sample i must satisfy every split on the
    # path to leaf t. Al = ancestors where the path goes left, Ar = goes right.
    for i in 1:n
        for t in 1:(Tb+1)
            Al, Ar = node_direct(t+Tb)
            # @constraint(m, [i in 1:n, mt in Al], sum(a[j,mt]*(X[j,i]+eps[j]-eps_min) for j in 1:p)+eps_min <= b[mt]+(1+eps_max)*(1-z[i,t-Tb])) # eq 13
            # @constraint(m, [mt in Al], sum(a[j,mt]*(states[i,j]+eps[j]) for j in 1:p) + (1+eps_max)*(1-d[mt]) <= b[mt]+(1+eps_max)*(1-z[i,t])) # eq 13
            # left: a'x + eps < b (strict via eps shift), big-M (1+eps_max)
            @constraint(m, [mt in Al], sum(a[j,mt]*(states[i,j]+eps[j]-eps_min) for j in 1:p) + eps_min <= b[mt]+(1+eps_max)*(1-z[i,t])) # eq 13
            # right: a'x >= b, big-M 1 (assumes features scaled to [0, 1] — TODO confirm)
            @constraint(m, [mt in Ar], sum(a[j,mt]*states[i,j] for j in 1:p) >= b[mt]-(1-z[i,t]))
        end
    end
    # each sample lands in exactly one leaf
    @constraint(m, [i in 1:n], sum(z[i,t] for t in 1:(Tb+1)) == 1);
    # each leaf carries exactly one action label
    @constraint(m, [t in (Tb+1):T], sum(c[k,t] for k in 1:num_actions) == 1);
    # linearised Pi[i,a] >= z[i,t] AND c[a,t+Tb]: the leaf's label becomes the policy's action
    @constraint(m, [i in 1:n, a in 1:num_actions, t in 1:(Tb+1)], Pi[i,a] >= z[i,t] + c[a,t+Tb] - 1) # t+Tb is the node index for the corresponding leaf node
    # exactly one action per sample
    @constraint(m, [i in 1:n], sum(Pi[i,a] for a in 1:num_actions) == 1);

    # @objective(m, Max, sum(Pi[i,a] * Q_value[i,a] for i in 1:n, a in 1:num_actions) - alpha*sum(d[t] for t in 1:Tb));
    @objective(m, Max, sum(Pi[i,a] * Q_value[i,a] for i in 1:n, a in 1:num_actions));

    # check if the model is feasible
    # println("is_feasible: $(is_feasible(m))")
    optimize!(m);

    println("Termination status: ", termination_status(m))
    println("Has values: ", all(!isnothing, value.(c)))


    a = value.(a)
    b = value.(b)

    c = round.(value.(c))  # labels are binary; round away solver noise
    # d = value.(d)
    Pi = value.(Pi)
    # get the objective value
    # objv_value = sum(Pi[i,a] * Q_value[i,a] for i in 1:n, a in 1:num_actions) - alpha*sum(d[t] for t in 1:Tb)
    # NOTE(review): objective_bound/relative_gap presumably throw if no solution was
    # found within the time limit — consider checking has_values(m) first.
    objv_value = objective_bound(m)
    gap = relative_gap(m)
    # objv_value = objective_bound(m)
    # Extract the chosen action index for every sample
    Pi_opt = zeros(Int, n, 1)
    for i in 1:n
        Pi_opt[i] = Int(argmax(Pi[i,:]))
    end
    # return a, b, c, d, objv_value, gap, Pi_opt
    return a, b, c, objv_value, gap, Pi_opt

end


"""
    init_random_policy(states, D, p, num_actions)

Draw a random depth-`D` tree with `init_random_tree` and route every row of
`states` through it.

At branch node `t` a sample goes right when `a[:,t]' * x + 1e-12 >= b[t]`, and
left otherwise. Its action is the argmax of the label column of the leaf it
reaches.

The `p` argument is overridden by the column count of `states`.

Returns `(a, b, c, d, Pi)`, where `Pi` is an `n × 1` Int matrix of action indices.
"""
function init_random_policy(states, D, p, num_actions)
    n, p = size(states)
    n_branch = 2^D - 1
    a, b, c, d = init_random_tree(D, p, num_actions)
    Pi = zeros(Int, n, 1)
    for i in axes(states, 1)
        row = states[i, :]
        node = 1
        # descend until we leave the branch-node range, i.e. reach a leaf
        while node <= n_branch
            goes_right = a[:, node]' * row + 1e-12 >= b[node]
            node = goes_right ? 2 * node + 1 : 2 * node
        end
        Pi[i] = Int(argmax(c[:, node]))
    end
    return a, b, c, d, Pi
end

"""
    init_random_tree(D, p, num_actions)

Draw a random depth-`D` tree with `Tb = 2^D - 1` branch nodes and `T = 2^(D+1) - 1`
nodes in total.

Returns `(a, b, c, d)`:
- `a`: `p × Tb` Int matrix with exactly one 1 per column (the split feature).
- `b`: `Tb` uniform thresholds in `[0, 1)`.
- `c`: `num_actions × T` Int matrix with exactly one 1 per column (the node label).
- `d`: `ones(Tb)`, i.e. every branch node splits.
"""
function init_random_tree(D, p, num_actions)
    n_branch = 2^D - 1
    n_nodes = 2^(D + 1) - 1

    # one random 1 per column, drawn column by column
    function random_onehot(nrows, ncols)
        M = zeros(Int, nrows, ncols)
        for col in 1:ncols
            M[rand(1:nrows), col] = 1
        end
        return M
    end

    split_feature = random_onehot(p, n_branch)
    threshold = rand(Float64, n_branch)
    label = random_onehot(num_actions, n_nodes)
    active = ones(n_branch)
    return split_feature, threshold, label, active
end

"""
    get_optimal_cumul_return(states, initial_state_p, transition_prob, rewards,
                             gamma, num_states, num_actions)

Solve the occupancy-measure LP of a discounted MDP with Gurobi.

The LP is
`max Σ_{s,a} x[s,a] Σ_{s'} P[s,s',a] R[s,s',a]` subject to
`Σ_a x[s,a] - γ Σ_{s',a} P[s',s,a] x[s',a] = p0[s]` and `x ≥ 0`.

`transition_prob` and `rewards` are indexed `[from, to, action]`. `states` is
unused.

Returns `(max_reward, x_opt)`. Throws an error when the termination status is not
`MOI.OPTIMAL`.
"""
function get_optimal_cumul_return(states, initial_state_p, transition_prob, rewards, gamma, num_states, num_actions,)
    S_range = 1:num_states
    A_range = 1:num_actions
    model = Model(Gurobi.Optimizer)

    # x[s, a]: discounted state-action occupancy
    @variable(model, x[1:num_states, 1:num_actions] >= 0)

    # expected one-step reward weighted by occupancy
    @objective(model, Max,
        sum(x[s,a] * sum(transition_prob[s,j,a] * rewards[s,j,a] for j in 1:num_states)
            for s in 1:num_states, a in 1:num_actions)
    )

    # flow conservation: outflow of s minus discounted inflow into s equals p0(s)
    @constraint(model, [s in S_range],
        sum(x[s,a] for a in A_range) -
        sum(gamma * transition_prob[j,s,a] * x[j,a]
            for j in S_range, a in A_range)
        == initial_state_p[s]
    )

    optimize!(model)

    status = termination_status(model)
    status != MOI.OPTIMAL && error("Optimization failed: ", status)

    return objective_value(model), value.(x)
end

"""
    global_OPT_DT_MILP(states, Q_value, num_actions, D, alpha, L_hat; kwargs...)

Solve only for the leaf labels `c` when routing is (partially) determined.

Sample `i` may take action `a` only if at least one of its candidate leaves
`z_udt[i]` is labelled `a`. The objective maximises
`1/L_hat * Σ Pi[i,a] * Q_value[i,a]` minus the constant `alpha * sum(lower.d)`.

`dtm_idx[5]`, `dtm_idx[6]` and `dtm_idx[7]` are `c_udt`, `c_dt` and `z_udt`.
`rlx`, `upper`, `eps`, `w_sos` and `lambda` are accepted but unused here; the
labels are always binary.

Returns `(c, objv)`: the rounded label matrix and the solver's objective bound.
"""
function global_OPT_DT_MILP(states, Q_value, num_actions, D, alpha, L_hat; lower=nothing, upper=nothing, eps=nothing, dtm_idx=nothing, w_sos=nothing, lambda = nothing, warm_start = nothing, mute=false, rlx = false, time = time_lapse, seed=42, mingap=0.005)
    # --- dimensions ---
    n, p = size(states)
    T = 2^(D+1)-1          # total number of nodes
    Tb = Int(floor(T/2))   # branch nodes 1:Tb; leaves are (Tb+1):T
    c_udt = dtm_idx[5]
    c_dt = dtm_idx[6]
    z_udt = dtm_idx[7]

    # Fallback warm start. The label matrix is read as c_w[k, t] for
    # k in 1:num_actions by var_c; this previously referenced an undefined `K`.
    if warm_start === nothing
        warm_start = Tree(zeros(p, Tb), rand(Tb), zeros(num_actions, T), zeros(Tb), spzeros(n, T), D)
    end

    m = optimizer_init(mute, time, seed, mingap)
    # c[k,t]: leaf node t is labelled with action k
    m, c = var_c(m, lower.c, c_udt, c_dt, warm_start.c, num_actions, T, Tb, false)
    # Pi[i,a]: the policy picks action a for sample i
    @variable(m, Pi[i in 1:n, a in 1:num_actions], Bin)

    m = cons_setup_leaf_MILP(m, states, Q_value, n, num_actions, T, Tb, c, z_udt, Pi)
    # the split penalty is a constant here because the split pattern is fixed by `lower.d`
    @objective(m, Max, 1/L_hat*sum(Pi[i,a] * Q_value[i,a] for i in 1:n, a in 1:num_actions) - alpha*sum(lower.d))

    optimize!(m)
    c = round.(value.(c))   # labels are binary; round away solver noise
    objv = objective_bound(m)
    return c, objv
end


"""
    optimizer_init(mute, time_lapse, seed, mingap)

Return a fresh JuMP `Model` backed by a newly created `Gurobi.Env`.

The model is configured as follows:
- `OutputFlag = 0` when `mute` is true.
- `Threads = 1`.
- `TimeLimit = time_lapse` seconds.
- `MIPGap = mingap / 2`, half the caller's branch-and-bound tolerance so the
  subproblem gap stays below it.
- `Seed = seed`.
"""
function optimizer_init(mute, time_lapse, seed, mingap)
    env = Gurobi.Env()
    model = Model(() -> Gurobi.Optimizer(env))
    mute && set_optimizer_attribute(model, "OutputFlag", 0)
    for (name, setting) in ("Threads" => 1,
                            "TimeLimit" => time_lapse,
                            "MIPGap" => mingap / 2,
                            "Seed" => seed)
        set_optimizer_attribute(model, name, setting)
    end
    return model
end

"""
    OMDT(states, initial_state_p, transition_prob, rewards, gamma, num_states,
         num_actions, D; time_limit=300, seed=42, mingap=0.005)

Jointly optimise a depth-`D` decision-tree policy and the discounted occupancy
measure `mu` of the MDP it induces.

The objective maximises `Σ_{s,a} mu[s,a] * Σ_{s'} P[s,s',a] R[s,s',a]`. The
constraint `mu[i,a] <= Pi[i,a]/(1-γ)` ties `mu` to the tree policy `Pi`.

NOTE(review): `mu` is indexed by sample `1:n`, but the flow constraints and the
objective iterate over `1:num_states`. This presumably assumes each row of
`states` is one MDP state (`n == num_states`) — confirm.

Returns `(a, b, c, objv_value, gap, Pi_opt)`:
- `objv_value`: the incumbent objective value.
- `gap`: the relative MIP gap.
- `Pi_opt`: an `n × 1` Int matrix of chosen action indices.
"""
function OMDT(states, initial_state_p, transition_prob, rewards, gamma, num_states, num_actions, D; time_limit=300, seed=42, mingap=0.005)
    ### Build and solve the OMDT model (tree policy + occupancy measure)
    n, p = size(states)
    T = 2^(D+1)-1   # total number of nodes
    Tb = 2^D-1      # number of branch nodes; leaves are (Tb+1):T
    # Tb = Int(floor(T/2))
    # Per-feature smallest non-zero gap between sorted values (strict left split)
    sortX = sort(states, dims=1)
    eps = vec(mapslices(opt_func_rl.mini_dist, sortX, dims=1)) # eps used in opt_func
    eps_min = minimum(eps)
    eps_max = maximum(eps)
    println("eps_min: $(eps_min)")
    println("eps_max: $(eps_max)")
    m = optimizer_init(true, time_limit, seed, mingap)
    
    # a: split feature, b: threshold in [0,1], c: node label, z: leaf routing,
    # Pi: policy, mu: occupancy bounded by the 1/(1-gamma) total discounted mass
    @variable(m, a[j in 1:p, t in 1:Tb], Bin);
    @variable(m, 0 <= b[t in 1:Tb] <= 1);
    @variable(m, c[k in 1:num_actions, t in 1:T], Bin);
    @variable(m, z[i in 1:n, t in 1:(Tb+1)], Bin)
    @variable(m, Pi[i in 1:n, a in 1:num_actions], Bin)
    @variable(m, 0<= mu[i in 1:n, a in 1:num_actions] <= 1/(1-gamma))

    
    # every branch node splits on exactly one feature
    @constraint(m, [t in 1:Tb], sum(a[j,t] for j in 1:p) == 1);
    # @constraint(m, [t in 1:Tb], b[t] <= 1);

    # occupancy only on actions the policy takes
    @constraint(m, [i in 1:n, a in 1:num_actions], mu[i,a] <= 1/(1-gamma)*Pi[i,a])
    # @constraint(m, [i in 1:n, a in 1:num_actions, t in 1:(Tb+1)], Pi[i,a] >= z[i,t] + c[a,t+Tb] - 1) # t+Tb is the node index for the corresponding leaf node
    # @constraint(m, [t in 2:Tb], d[t] <= d[Int(floor(t/2))]);
    # @constraint(m, [t in 1:Tb], d[t] <= 1);
    # Routing: if z[i,t] = 1, sample i satisfies every split on the path to leaf t
    # (Al = ancestors where the path goes left, Ar = where it goes right)
    for i in 1:n
        for t in 1:(Tb+1)
            Al, Ar = node_direct(t+Tb)
            # @constraint(m, [i in 1:n, mt in Al], sum(a[j,mt]*(X[j,i]+eps[j]-eps_min) for j in 1:p)+eps_min <= b[mt]+(1+eps_max)*(1-z[i,t-Tb])) # eq 13
            # @constraint(m, [mt in Al], sum(a[j,mt]*(states[i,j]+eps[j]) for j in 1:p) + (1+eps_max)*(1-d[mt]) <= b[mt]+(1+eps_max)*(1-z[i,t])) # eq 13
            # left: a'x + eps < b (strict via eps shift), big-M (1+eps_max)
            @constraint(m, [mt in Al], sum(a[j,mt]*(states[i,j]+eps[j]-eps_min) for j in 1:p) + eps_min <= b[mt]+(1+eps_max)*(1-z[i,t])) # eq 13
            # right: a'x >= b, big-M 1 (assumes features scaled to [0, 1] — TODO confirm)
            @constraint(m, [mt in Ar], sum(a[j,mt]*states[i,j] for j in 1:p) >= b[mt]-(1-z[i,t]))
        end
    end
    # each sample lands in exactly one leaf
    @constraint(m, [i in 1:n], sum(z[i,t] for t in 1:(Tb+1)) == 1);
    # each leaf carries exactly one action label
    @constraint(m, [t in (Tb+1):T], sum(c[k,t] for k in 1:num_actions) == 1);
    # linearised Pi[i,a] >= z[i,t] AND c[a,t+Tb]
    @constraint(m, [i in 1:n, a in 1:num_actions, t in 1:(Tb+1)], Pi[i,a] >= z[i,t] + c[a,t+Tb] - 1) # t+Tb is the node index for the corresponding leaf node
    # exactly one action per sample
    @constraint(m, [i in 1:n], sum(Pi[i,a] for a in 1:num_actions) == 1);

    
    # Bellman flow: outflow of each state minus discounted inflow equals its initial probability
    @constraint(m, [state in 1:num_states], 
        sum(mu[state, action] for action in 1:num_actions) - 
        gamma * sum(transition_prob[other_state, state, action] * mu[other_state, action] 
            for other_state in 1:num_states, action in 1:num_actions) 
        == initial_state_p[state]
    )

    # expected one-step reward weighted by occupancy
    @objective(m, Max, 
    sum(mu[s, a] * sum(transition_prob[s, :, a] .* rewards[s, :, a]) 
        for s in 1:num_states, a in 1:num_actions)
    )
    
    
    # @objective(m, Max, sum(Pi[i,a] * Q_value[i,a] for i in 1:n, a in 1:num_actions) - alpha*sum(d[t] for t in 1:Tb));
    # println("is_feasible: $(is_solved_and_feasible(m))")
    optimize!(m);

    a = value.(a)
    b = value.(b)
    println("Termination status: ", termination_status(m))
    println("Has values: ", all(!isnothing, value.(c)))

    c = round.(value.(c))  # labels are binary; round away solver noise
    # d = value.(d)
    Pi = value.(Pi)
    # get the objective value
    # objv_value = sum(Pi[i,a] * Q_value[i,a] for i in 1:n, a in 1:num_actions) - alpha*sum(d[t] for t in 1:Tb)
    # objv_value = objective_bound(m)
    # NOTE(review): objective_value/relative_gap presumably throw if no feasible
    # solution was found within time_limit — consider checking has_values(m).
    objv_value = objective_value(m)
    gap = relative_gap(m)
    # objv_value = objective_bound(m)
    # Extract the chosen action index for every sample
    Pi_opt = zeros(Int, n, 1)
    for i in 1:n
        Pi_opt[i] = Int(argmax(Pi[i,:]))
    end
    return a, b, c, objv_value, gap, Pi_opt

end
"""
    var_a(m, l_a, a_udt, a_dt, a_w, p, Tb, rlx = false)

Create the split-feature variables `a[j,t]` (`p × Tb`) in model `m` and return `(m, a)`.

When `a_dt` is empty, the whole matrix is one JuMP variable `a`, warm-started from
`a_w`.

Otherwise only the entries listed in `a_udt` become variables (`a_var`). They are
warm-started from `a_w` and placed into an untyped matrix. The entries listed in
`a_dt` are copied from `l_a`.

Variables are binary, or continuous in `[0, 1]` when `rlx` is true.
"""
function var_a(m, l_a, a_udt, a_dt, a_w, p, Tb, rlx = false)
    if isempty(a_dt)
        # nothing fixed: the full matrix is a decision variable
        if rlx
            @variable(m, 0 <= a[j in 1:p, t in 1:Tb] <= 1, start = a_w[j, t])
        else
            @variable(m, a[j in 1:p, t in 1:Tb], Bin, start = a_w[j, t])
        end
        return m, a
    end

    nfree = length(a_udt)
    if rlx
        @variable(m, 0 <= a_var[i in 1:nfree] <= 1)
    else
        @variable(m, a_var[i in 1:nfree], Bin)
    end
    a = Matrix(undef, p, Tb)
    for (k, idx) in enumerate(a_udt)
        set_start_value(a_var[k], a_w[idx])
        a[idx] = a_var[k]
    end
    for idx in a_dt
        a[idx] = l_a[idx]
    end
    return m, a
end

"""
    var_d(m, l_d, d_udt, d_dt, d_w, Tb, rlx = false)

Create the split-indicator variables `d[t]` (`t in 1:Tb`) in model `m` and return
`(m, d)`.

When `d_dt` is empty, `d` is one JuMP variable vector, warm-started from `d_w`.

Otherwise only the entries in `d_udt` are variables (`d_var`), and the entries in
`d_dt` are copied from `l_d`.

Variables are binary, or continuous in `[0, 1]` when `rlx` is true.
"""
function var_d(m, l_d, d_udt, d_dt, d_w, Tb, rlx = false)
    if isempty(d_dt)
        if rlx
            @variable(m, 0 <= d[t in 1:Tb] <= 1, start = d_w[t])
        else
            @variable(m, d[t in 1:Tb], Bin, start = d_w[t])
        end
        return m, d
    end

    nfree = length(d_udt)
    if rlx
        @variable(m, 0 <= d_var[t in 1:nfree] <= 1)
    else
        @variable(m, d_var[t in 1:nfree], Bin)
    end
    d = Vector(undef, Tb)
    for (k, idx) in enumerate(d_udt)
        set_start_value(d_var[k], d_w[idx])
        d[idx] = d_var[k]
    end
    for idx in d_dt
        d[idx] = l_d[idx]
    end
    return m, d
end

"""
    var_c(m, l_c, c_udt, c_dt, c_w, K, T, Tb, rlx = false)

Create the leaf-label variables `c[k,t]` (`k in 1:K`) in model `m` and return
`(m, c)`.

When `c_dt` is empty, `c` is one JuMP container over the leaf columns
`(Tb+1):T`, warm-started from `c_w`.

Otherwise `c` is an untyped `K × T` matrix:
- the entries in `c_udt` become variables (`c_var`);
- the entries in `c_dt` are copied from `l_c`;
- branch columns that neither set covers stay unassigned.

Variables are binary, or continuous in `[0, 1]` when `rlx` is true.
"""
function var_c(m, l_c, c_udt, c_dt, c_w, K, T, Tb, rlx = false)
    if isempty(c_dt)
        if rlx
            @variable(m, 0 <= c[k in 1:K, t in (Tb+1):T] <= 1, start = c_w[k, t])
        else
            @variable(m, c[k in 1:K, t in (Tb+1):T], Bin, start = c_w[k, t])
        end
        return m, c
    end

    nfree = length(c_udt)
    if rlx
        @variable(m, 0 <= c_var[i in 1:nfree] <= 1)
    else
        @variable(m, c_var[i in 1:nfree], Bin)
    end
    # full K × T so leaf columns keep their node index (branch columns are unused)
    c = Matrix(undef, K, T)
    for (k, idx) in enumerate(c_udt)
        set_start_value(c_var[k], c_w[idx])
        c[idx] = c_var[k]
    end
    for idx in c_dt
        c[idx] = l_c[idx]
    end
    return m, c
end

# Routing variables z[i, t]; leaf indices t are in 1:(Tb+1).
"""
    var_z(m, l_z, z_udt, z_w, n, T, Tb, rlx = false)

Create the routing variables `z[i,t]` (sample `i` reaches leaf `t`, for
`t in 1:(Tb+1)`) in model `m` and return `(m, z)`.

`z_udt[i]` holds the leaves that sample `i` may still reach.

- When the candidate counts of all multi-leaf samples add up to `n*(Tb+1)`
  (presumably: every sample may reach every leaf), a full `n × (Tb+1)` JuMP
  variable `z` is created and warm-started from `z_w`.
- Otherwise only the candidate entries of samples with more than one candidate
  leaf become variables (`z_var`, consumed in row-major order via the counter `j`).
  A sample with exactly one candidate is fixed to that leaf (1 there, 0 elsewhere).
  Every other entry is copied from `l_z`.

Variables are binary, or continuous in `[0, 1]` when `rlx` is true. `T` is unused,
and `n` must be an `Int` (asserted by `1:n::Int`).
"""
function var_z(m, l_z, z_udt, z_w, n, T, Tb, rlx = false)
    # number of free z entries: candidate counts of samples with > 1 candidate leaf
    z_var_num = sum(filter(x->x>1, length.(z_udt)))
    if z_var_num == n*(Tb+1)
        if rlx
            @variable(m, 0<=z[i in 1:n, t in 1:(Tb+1)]<=1, start = z_w[i,t])
        else
            @variable(m, z[i in 1:n, t in 1:(Tb+1)], Bin, start = z_w[i,t])
        end
    else
        if rlx
            @variable(m, 0<=z_var[i in 1:z_var_num]<=1) # create udtm decision var
        else
            @variable(m, z_var[i in 1:z_var_num], Bin) # create udtm decision var
        end
        z = Matrix(undef, n, Tb+1) # set to 1:Tl
        j = 1  # next unused entry of z_var
        for i in 1:n::Int
            z_i = z_udt[i] # here z_udt idxs are in 1:Tl
            if length(z_i) == 1 # if sample i only reaches one leaf then z[i] is all determined
                z[i,:] .= 0
                t = z_i[1]
                z[i,t] = 1 #l_z[i,t]
            else
                # NOTE(review): a sample with zero candidates also lands here and gets
                # all entries from l_z — confirm that is intended.
                for t in 1:(Tb+1)
                    if t in z_i
                        set_start_value.(z_var[j], z_w[i,t])
                        z[i,t] = z_var[j]
                        j += 1
                    else
                        z[i,t] = l_z[i,t]
                    end
                end
            end
        end
    end
    return m, z
end

# ---- constraint builders ----
"""
    cons_setup_node(m, X, p, z_udt, Tb, eps, eps_max, a, b, d, z)

Add the branch-node constraints to `m` and return it:
- a splitting node uses exactly one feature: `Σ_j a[j,t] == d[t]`;
- the threshold is zero unless the node splits: `b[t] <= d[t]`;
- a node can split only if its parent (`floor(t/2)`) splits;
- routing constraints for every sample `i` with more than one candidate leaf in
  `z_udt[i]`, described below.

The routing constraints apply to each candidate leaf `t` (node `t+Tb`). When
`z[i,t] = 1`, sample `X[i,:]` must go left at the ancestors in `Al` and right at
the ancestors in `Ar`.

NOTE(review): when `z_udt` is a `Vector{Int}` (the SG-only path of
`global_OPT_DT_SG`), every element has `length == 1`, so `n_idx` is empty and no
routing constraints are added here — confirm this is intended.
"""
function cons_setup_node(m, X, p, z_udt, Tb, eps, eps_max, a, b, d, z)
    @constraint(m, [t in 1:Tb], sum(a[j,t] for j in 1:p) == d[t])
    @constraint(m, [t in 1:Tb], b[t] <= d[t])
    @constraint(m, [t in 2:Tb], d[t] <= d[Int(floor(t/2))]) # p(t) is floor(t/2) for t's parent node index
    # routing constraints, only for samples whose leaf is not yet determined
    n_idx = findall(x->length(x)>1, z_udt)
    for i in n_idx
        for t in z_udt[i]
            Al, Ar = node_direct(t+Tb)
            #@constraint(m, [i in 1:n, mt in Al], sum(a[j,mt]*(X[j,i]+eps[j]-eps_min) for j in 1:p)+eps_min <= b[mt]+(1+eps_max)*(1-z[i,t-Tb])) # eq 13
            # left: a'(x+eps) <= b when z=1 and d=1; relaxed by (1+eps_max) otherwise
            @constraint(m, [mt in Al], sum(a[j,mt]*(X[i,j]+eps[j]) for j in 1:p) + (1+eps_max)*(1-d[mt]) <= b[mt]+(1+eps_max)*(1-z[i,t])) # eq 13
            # right: a'x >= b when z=1; big-M 1 presumes features scaled to [0, 1] — TODO confirm
            @constraint(m, [mt in Ar], sum(a[j,mt]*X[i,j] for j in 1:p) >= b[mt]-(1-z[i,t]))
        end
    end
    return m
end

"""
    cons_setup_leaf_SG(m, n, num_actions, T, Tb, c, z, z_udt, Pi)

Add the leaf constraints for the SG solver to `m` and return it:
- each sample with more than one candidate leaf lands in exactly one leaf;
- each leaf node `(Tb+1):T` has exactly one action label;
- `Pi[i,a] >= z[i,t] + c[a,t+Tb] - 1`: if sample `i` reaches leaf `t` and that
  leaf is labelled `a`, the policy must choose `a`.

For the last constraint, `t` ranges over all leaves when `z_udt` is a
`Vector{Int64}`, and over `z_udt[i]` otherwise.

NOTE(review): in the `Vector{Int64}` case `n_idx` is empty (each element has
length 1), so the "exactly one leaf" constraint is not added here — confirm it is
enforced elsewhere (e.g. by var_z_sg).
"""
function cons_setup_leaf_SG(m, n, num_actions, T, Tb, c, z, z_udt, Pi)
    # samples whose leaf is not yet determined
    n_idx = findall(x->length(x)>1, z_udt)
    @constraint(m, [i in n_idx], sum(z[i,t] for t in 1:(Tb+1)) == 1) # guarantee a sample only in one leaf
    # exactly one action label per leaf node
    @constraint(m, [t in (Tb+1):T], sum(c[k,t] for k in 1:num_actions) == 1) # for each node, label should at most choose one label
    # linearised Pi[i,a] >= z[i,t] AND c[a,t+Tb]
    # @constraint(m, [i in 1:n, a in 1:num_actions, t in 1:(Tb+1)], Pi[i,a] >= z[i,t] + c[a,t+Tb] - 1) # t+Tb is the node index for the corresponding leaf node
    if z_udt isa Vector{Int64} # for only SG mode
        @constraint(m, [i in 1:n, a in 1:num_actions, t in 1:(Tb+1)], Pi[i,a] >= z[i,t] + c[a,t+Tb] - 1) # t+Tb is the node index for the corresponding leaf node
    else
        @constraint(m, [i in 1:n, a in 1:num_actions, t in z_udt[i]], Pi[i,a] >= z[i,t] + c[a,t+Tb] - 1) # t+Tb is the node index for the corresponding leaf node
    end
    
    ###################### NOTE: Need to be modified for SG mode!
    # if z_udt isa Vector{Int64}# for only SG mode
    #     @constraint(m, [i in 1:n, t in (Tb+1):T], costs[i] >= z[i,t-Tb]-sum(y[k,i]*c[k,t] for k in 1:K))
    # else
    #     @constraint(m, [i in 1:n, t in z_udt[i]], costs[i] >= z[i,t]-sum(y[k,i]*c[k,t+Tb] for k in 1:K))
    # end        
    return m
end

"""
    cons_setup_leaf_MILP(m, states, Q_value, n, num_actions, T, Tb, c, z_udt, Pi)

Add the leaf constraints for the label-only MILP to `m` and return it:
- each leaf node `(Tb+1):T` has exactly one action label;
- `Pi[i,a] <= Σ_{t ∈ z_udt[i]} c[a,t+Tb]`: sample `i` may choose action `a` only
  if at least one of its candidate leaves is labelled `a`.

`states` and `Q_value` are unused.
"""
function cons_setup_leaf_MILP(m, states, Q_value, n, num_actions, T, Tb, c, z_udt, Pi)
    # exactly one action label per leaf node
    @constraint(m, [t in (Tb+1):T], sum(c[k,t] for k in 1:num_actions) == 1) # for each node, label should at most choose one label
    # constraints for determine which node the point i will be allocated
    # @constraint(m, [i in 1:n], costs[i] >= 1-sum(y[k,i]*c[k,t+Tb] for k in 1:K for t in z_udt[i]))
    # for i in 1:n
    #     println("z_udt[i]: $(z_udt[i])")
    #     for a in 1:num_actions
    #         # println("c[a, t+Tb]: $(sum(c[a, t+Tb] for t in z_udt[i]))")
    #     end
    # end
    @constraint(m, [i in 1:n, a in 1:num_actions], Pi[i,a] <= sum(c[a, t+Tb] for t in z_udt[i])) # if any leaf choose action a, then the action should be considered
    return m
end

"""
    cons_setup_rl(m, Pi, n, num_actions)

Require the policy to choose exactly one action per sample
(`Σ_a Pi[i,a] == 1` for `i in 1:n`) and return `m`.
"""
function cons_setup_rl(m, Pi, n, num_actions)
    for i in 1:n
        @constraint(m, sum(Pi[i, k] for k in 1:num_actions) == 1)
    end
    return m
end

"""
    node_direct(t)

Return the ancestors of node `t` in a heap-indexed binary tree, split by which
child the path to `t` takes at each one. In this tree the root is `1` and the
children of `k` are `2k` and `2k+1`.

Returns `(Al, Ar)`:
- `Al`: ancestors where the path goes to the left child (`2k`).
- `Ar`: ancestors where the path goes to the right child (`2k+1`).

Both vectors are ordered from the nearest ancestor up to the root.

Throws `ArgumentError` when `t < 1`. Such input used to loop forever (`t == 0`) or
raise an `InexactError` (negative `t`).
"""
function node_direct(t)
    t >= 1 || throw(ArgumentError("node index must be >= 1, got $t"))
    Al = Int64[]
    Ar = Int64[]
    idx = t
    while idx != 1
        parent = idx ÷ 2   # integer division; avoids the Float64 round-trip of Int(idx/2)
        # odd index => idx is the right child of its parent
        push!(isodd(idx) ? Ar : Al, parent)
        idx = parent
    end
    return Al, Ar
end

"""
    mini_dist(v)

Return the smallest non-zero difference `v[k+1] - v[k]` between consecutive
entries of `v`.

Callers pass sorted columns, so this is the minimum positive spacing. Returns
`1e-10` when every difference is zero or `v` has fewer than two entries.
"""
function mini_dist(v)
    nonzero_gaps = filter(!iszero, diff(vec(v)))
    isempty(nonzero_gaps) && return 1e-10
    return minimum(nonzero_gaps)
end

end