using DataFrames, CSV
using Random, Distributions, StatsBase
using Plots

using MLDataUtils, Clustering
using Distributed, SharedArrays

# Make the self-written project modules (branch & bound, data processing, ...)
# loadable on every process by appending `<project>/src` and `<project>/test`
# to LOAD_PATH.
@everywhere begin
    # Project root = parent of the directory containing this script
    # (an absolute path could be used here instead).
    project_dir = dirname(dirname(@__FILE__))

    src_path = joinpath(project_dir, "src")
    test_path = joinpath(project_dir, "test")

    # Only append a directory that is not already present, so LOAD_PATH does
    # not accumulate duplicates.
    for extra_dir in (src_path, test_path)
        extra_dir in LOAD_PATH || push!(LOAD_PATH, extra_dir)
    end
end


using TimerOutputs: @timeit, get_timer

using Trees, bound_rl, parallel, Nodes_rl
using ub_func_rl, lb_func_rl, bb_func_rl, data_process_rl, utils_rl, opt_func_rl

# Usage (7 required positional arguments, 1 optional):
#   julia test/test_rl.jl D gamma LB_method seed scheme dataname Iteration_num [start_action_idx]
# Example:
#   julia .\test\test_rl.jl 2 0.99 CF+MILP+SG 42 par frozenlake_4x4 4
if length(ARGS) < 7
    throw(ArgumentError("expected at least 7 arguments: D gamma LB_method seed scheme " *
                        "dataname Iteration_num [start_action_idx]; got $(length(ARGS))"))
end

D = parse(Int, ARGS[1])              # passed as `D` to CART_base / branch_bound, e.g. 2
gamma = parse(Float64, ARGS[2])      # gamma used in policy evaluation / Q-values, e.g. 0.99
LB_method = ARGS[3]                  # "+"-separated lower-bound methods, e.g. "CF+MILP+SG"
seed = parse(Int, ARGS[4])           # random seed, e.g. 42
scheme = ARGS[5]                     # "par" loads MPI; other values (e.g. "sl") skip it
dataname = ARGS[6]                   # dataset name passed to read_data
Iteration_num = parse(Int, ARGS[7])  # number of outer training iterations, e.g. 4
# Action index used by the constant initial policy; defaults to 1.
start_action_idx = length(ARGS) >= 8 ? parse(Int, ARGS[8]) : 1
println("gamma is set to $gamma")

# MPI is only loaded and initialized for the parallel scheme.
if scheme == "par"
    using MPI
    parallel.init()
end

parallel.create_world()
parallel.root_println("Running $(parallel.nprocs()) processes.")
parallel.root_println("Start training $dataname with seed $seed.")


#############################################################
################# Main Process Program Body #################
#############################################################
# Only the root process reads the dataset, evaluates the initial constant
# policy, computes the initial Q-values and fits the warm-start (CART) tree.
# Every other process gets placeholders so that all names used later exist
# there too; the shared scalars and the warm-start tree are then broadcast
# from the root.
if parallel.is_root()
    states, transition_prob, rewards, initial_state_p, num_states, num_actions, feature_dim = read_data(dataname)
    # Evaluate the constant policy that always takes action `start_action_idx`.
    init_policy = fill(start_action_idx, num_states)
    V_old, cumul_reward_init = utils_rl.policy_evaluation(init_policy, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-6, 10000)
    println("V:_init $(V_old)")
    println("cumul_reward_init: $cumul_reward_init")
    Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V_old, gamma, num_states, num_actions)

    Random.seed!(seed)
    alp = 0.00
    L_hat = 1  ### NOTE: need to change L_hat
    # Initial (warm-start) tree. Alternative:
    # time_w = @elapsed tree_w, objv_w = lb_func_rl.warm_start(states, Q_value, num_states, num_actions, feature_dim, D, alp, L_hat)
    time_w = @elapsed tree_w, objv_w = lb_func_rl.CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alp, L_hat)

    # Diagnostics are printed on the root only: `states` exists only here.
    println("initialize data")
    println("size of states: $num_states")
    println("size of actions: $num_actions")
    println("size of features: $feature_dim")
    println("size of states: $(size(states))")
else
    # The data stay on the root; other processes pass `nothing` for them to
    # branch_bound. Sizes and warm-start values are filled in by the broadcasts below.
    states = nothing
    Q_value = nothing
    num_states = 0
    num_actions = 0
    feature_dim = 0
    L_hat = 0
    alp = 0
    objv_w = 0
    tree_w = nothing
    time_w = 0.0
end

##################### Testing of different ODT methods #####################
# D and LB_method are already known to all processes from ARGS; everything
# else the non-root branch_bound calls need comes from the root.
alp = parallel.bcast(alp)
L_hat = parallel.bcast(L_hat)
objv_w = parallel.bcast(objv_w)
tree_w = parallel.bcast(tree_w)
num_states = parallel.bcast(num_states)
num_actions = parallel.bcast(num_actions)
feature_dim = parallel.bcast(feature_dim)

# Lower-bound methods, e.g. "CF+MILP+SG" -> ["CF", "MILP", "SG"]; the list must include "CF".
LB_method = split(LB_method, "+")

gap_sg = nothing
# Each iteration:
#   1. (iterations >= 2) refits the warm-start tree against the current Q-values;
#   2. runs branch & bound on all processes;
#   3. on the root, extracts the greedy policy from the resulting tree,
#      evaluates it and recomputes Q-values;
#   4. broadcasts the new Q-values.
for i in 1:Iteration_num
    # These script-level names are rebound inside the loop body.
    global Q_value, tree_w, objv_w, time_w, time_sg
    global tree_sg, objv_sg, calc, UB_sg
    println("Iteration $i")

    # Refit the warm-start tree on the root from the second iteration on, and
    # share it so non-root processes do not keep using the first iteration's tree.
    if i != 1
        if parallel.is_root()
            time_w = @elapsed tree_w, objv_w = lb_func_rl.CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alp, L_hat)
        end
        tree_w = parallel.bcast(tree_w)
        objv_w = parallel.bcast(objv_w)
    end

    ##################### Start training and optimization #####################
    if parallel.is_root()
        time_sg = @elapsed begin
            tree_sg, objv_sg, calc, UB_sg = bb_func_rl.branch_bound(states, Q_value, num_states, num_actions, feature_dim, D, tree_w, objv_w, alp, L_hat, LB_method, true, false, seed)
        end

        # Ground-truth optimal objective value, printed for comparison.
        _, _, _, _, objv_optimal = opt_func_rl.optimal_iteration(states, Q_value, D, num_actions)
        println("objv_optimal: $objv_optimal")

        # Policy evaluation happens on the root only: extract the tree's policy,
        # take the argmax action per state (as a num_states×1 matrix), evaluate it.
        Pi_opt = utils_rl.extract_policy(tree_sg, states, num_states, num_actions)
        pi_ = zeros(Int, num_states, 1)
        for j in 1:num_states
            pi_[j] = argmax(Pi_opt[j, :])
        end
        println("pi_: $(pi_)")
        V, cumul_reward = utils_rl.policy_evaluation(pi_, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-6, 1000000)
        println("V: $(V)")
        println("Iteration $i: cumul_reward: $cumul_reward")

        # Q-values for the next iteration.
        Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V, gamma, num_states, num_actions)
    else
        # The data live on the root; other processes pass `nothing` for them but
        # otherwise use exactly the same argument list as the root call above.
        tree_sg, objv_sg, calc, UB_sg = bb_func_rl.branch_bound(nothing, nothing, num_states, num_actions, feature_dim, D, tree_w, objv_w, alp, L_hat, LB_method, true, false, seed)
    end

    # Share the root's updated Q-values with all processes for the next iteration.
    Q_value = parallel.bcast(Q_value)
end

# Root-only reporting: print the trees, the final gap and optionally plot them.
# The branch-and-bound results (`tree_sg`, `calc`) only exist when at least one
# iteration ran, so they are checked before use; `parallel.finalize()` always runs.
if parallel.is_root()
    has_sg_tree = @isdefined(tree_sg) && tree_sg !== nothing

    Trees.printTree(tree_w)
    has_sg_tree && Trees.printTree(tree_sg)

    # Take the gap from the last entry of `calc` (last element of its last record).
    if gap_sg === nothing && @isdefined(calc) && calc !== nothing && !isempty(calc)
        gap_sg = round(calc[end][end], digits=4)
    end

    println("Final gap: $gap_sg")

    ##################### Tree structure plot #####################
    make_plots = true
    if make_plots
        tree_plot(tree_w, "CART", dataname)
        has_sg_tree && tree_plot(tree_sg, "sglb", dataname)
    end
end

parallel.finalize()
