using DataFrames, CSV
using Random, Distributions, StatsBase
using Plots

using MLDataUtils, Clustering
using Distributed, SharedArrays

# Make the project's own modules (branch-and-bound, data preprocessing, ...) loadable on every process by adding src/ and test/ to LOAD_PATH
@everywhere begin
    # Project root: two directory levels above this script file.
    project_dir = dirname(dirname(@__FILE__))  # or use the absolute path

    # Directories holding the project's local modules.
    src_path = joinpath(project_dir, "src")
    test_path = joinpath(project_dir, "test")

    # Register each directory once, so `using <Module>` finds the local modules
    # without accumulating duplicate LOAD_PATH entries.
    for dir in (src_path, test_path)
        dir in LOAD_PATH || push!(LOAD_PATH, dir)
    end
end


using TimerOutputs: @timeit, get_timer

using Trees, bound_rl, parallel, Nodes_rl, Logger, best_return
using ub_func_rl, lb_func_rl, bb_func_rl, data_process_rl, utils_rl, opt_func_rl

# Command-line arguments (positional):
#   1  D                 Int, depth argument passed to CART_base / branch_bound
#   2  gamma             Float64, discount factor used by policy evaluation / Q computation
#   3  LB_method         lower-bound methods joined by '+', e.g. CF+MILP+SG
#   4  seed              Int, random seed
#   5  scheme            "par" loads MPI and runs in parallel; any other value (e.g. "sl") does not
#   6  dataname          dataset name passed to read_data
#   7  Iteration_num     Int >= 1, number of CART + branch-and-bound iterations
#   8  mingap            optional Float64 (default 0.0005), passed to branch_bound
#   9  start_action_idx  optional Int (default 1), action of the initial constant policy
# Example:
#   julia .\test\test_rl.jl 2 0.99 CF+MILP+SG 42 par frozenlake_4x4 4

# Fail with a readable message instead of a BoundsError on ARGS.
if length(ARGS) < 7
    throw(ArgumentError("expected at least 7 arguments (D gamma LB_method seed scheme dataname Iteration_num [mingap] [start_action_idx]), got $(length(ARGS))"))
end

D = parse(Int, ARGS[1])
gamma = parse(Float64, ARGS[2])
LB_method = ARGS[3]
seed = parse(Int, ARGS[4])
scheme = ARGS[5]
dataname = ARGS[6]
Iteration_num = parse(Int, ARGS[7])
mingap = length(ARGS) >= 8 ? parse(Float64, ARGS[8]) : 0.0005
start_action_idx = length(ARGS) >= 9 ? parse(Int, ARGS[9]) : 1

# The best iteration is selected with a maximum over per-iteration results at the
# end of the script, which is undefined when no iteration runs.
if Iteration_num < 1
    throw(ArgumentError("Iteration_num must be >= 1, got $Iteration_num"))
end

println("gamma is set to $gamma")

# Only the "par" scheme loads MPI and calls `parallel.init()` (presumably the MPI
# initialisation — see parallel.jl). It must run before any other `parallel` call below.
if scheme == "par"
    using MPI
    parallel.init()
end

# Executed for every scheme; `parallel.nprocs()` then reports the number of processes.
parallel.create_world()
# `root_println` presumably prints only on the root process (TODO confirm in parallel.jl).
parallel.root_println("Running $(parallel.nprocs()) processes.")
parallel.root_println("Start training $dataname with seed $seed.")

#############################################################
################# Main Process Program Body #################
#############################################################
# Data loading, the reference (upper) return, evaluation of the initial policy and the
# initial CART tree are computed on the root process only. The scalars and the initial
# tree are then broadcast so every process holds the same values.
if parallel.is_root()
    states, transition_prob, rewards, initial_state_p, num_states, num_actions, feature_dim = read_data(dataname)

    # Reference return for logging: use the tabulated value when one exists,
    # otherwise compute it with the optimisation helper.
    if dataname in keys(best_return.return_upper)
        cumul_return_upper = best_return.return_upper[dataname]
    else
        cumul_return_upper, _ = opt_func_rl.get_optimal_cumul_return(states, initial_state_p, transition_prob, rewards, gamma, num_states, num_actions)
    end

    log_data = Logger.init_log(dataname, seed, gamma, D, LB_method, scheme, Iteration_num, cumul_return_upper, "Space Reduced BB", start_action_idx, mingap)

    # Evaluate the constant policy that picks `start_action_idx` in every state;
    # its value function seeds the first Q-value estimate.
    init_policy = fill(start_action_idx, num_states)
    V_old, cumul_reward_init = utils_rl.policy_evaluation(init_policy, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-10, 1000000)
    println("V:_init $(V_old)")
    println("cumul_reward_init: $cumul_reward_init")
    Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V_old, gamma, num_states, num_actions)

    Random.seed!(seed)
    alp = 0.00
    L_hat = 1
    # Initial tree (lb_func_rl.warm_start with the same arguments is an alternative).
    time_w = @elapsed tree_w, objv_w = lb_func_rl.CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alp, L_hat)
    ### NOTE: need to change L_hat
else
    # Non-root processes hold no data. Every name used later in the script must still
    # be defined here: the size scalars are overwritten by the broadcasts below, and
    # `Q_value` must exist because the main loop calls `parallel.bcast(Q_value)` on
    # every process.
    states = nothing
    Q_value = nothing
    num_states = 0
    num_actions = 0
    feature_dim = 0
    L_hat = 0
    alp = 0
    objv_w = 0
    tree_w = nothing
    cumul_return_upper = 0
end

##################### Testing of different ODT methods #####################
# D and LB_method are already known on all processes (parsed from ARGS).
alp = parallel.bcast(alp)
L_hat = parallel.bcast(L_hat)
objv_w = parallel.bcast(objv_w)
tree_w = parallel.bcast(tree_w)
num_states = parallel.bcast(num_states)
num_actions = parallel.bcast(num_actions)
feature_dim = parallel.bcast(feature_dim)

# `states` exists only on the root, so the size report is printed there.
if parallel.is_root()
    println("initialize data")
    println("size of states: $num_states")
    println("size of actions: $num_actions")
    println("size of features: $feature_dim")
    println("size of states: $(size(states))")
end
# Lower-bound methods arrive as one '+'-joined string (e.g. "CF+MILP+SG");
# CF is expected to be one of them.
LB_method = split(LB_method, '+')
gap_sg = nothing

# Per-iteration history, appended to on the root process only: the return of each
# extracted policy, the warm-start tree, and the branch-and-bound tree.
cumul_reward_list = Any[]
tree_w_list = Any[]
tree_sg_list = Any[]


# Each iteration: (re)fit the warm-start CART tree to the current Q-values, refine it
# with branch-and-bound (all processes take part), evaluate the resulting policy on the
# root, and update the Q-values for the next iteration.
for i in 1:Iteration_num
    println("Iteration $i")
    ##################### Start training and optimization #####################
    if parallel.is_root()
        global time_w, tree_w, objv_w
        if i != 1
            # Warm start: refit CART to the Q-values from the previous iteration.
            time_w = @elapsed tree_w, objv_w = lb_func_rl.CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alp, L_hat)
        end
        push!(tree_w_list, tree_w)

        global time_sg
        time_sg = @elapsed begin
            global tree_sg, objv_sg, calc, UB_sg
            tree_sg, objv_sg, calc, UB_sg = bb_func_rl.branch_bound(states, Q_value, num_states, num_actions, feature_dim, D, tree_w, objv_w, alp, L_hat, LB_method, true, false; seed=seed, mingap=mingap)
        end
        push!(tree_sg_list, tree_sg)

        # Policy evaluation happens on the root only. Reduce the tree's policy to one
        # action index per state (argmax over each row), stored as a num_states×1 matrix.
        Pi_opt = utils_rl.extract_policy(tree_sg, states, num_states, num_actions)
        pi_ = zeros(Int, num_states, 1)
        for j in 1:num_states
            pi_[j] = Int(argmax(Pi_opt[j, :]))
        end
        println("pi_: $(pi_)")
        V, cumul_reward = utils_rl.policy_evaluation(pi_, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-10, 1000000)
        println("V: $(V)")
        println("Iteration $i: cumul_reward: $cumul_reward")
        push!(cumul_reward_list, cumul_reward)

        # Gap of this branch-and-bound run: last entry of the last `calc` record, if any.
        gap_ = !isempty(calc) ? calc[end][end] : nothing
        Logger.log_iteration(i, time_w, objv_w, time_sg, objv_sg, UB_sg, gap_, cumul_reward, V)

        # Next iteration's Q-values come from the new policy's value function. The call
        # is module-qualified, like the initial computation, so it does not depend on
        # `utils_rl` exporting `compute_Q_value`.
        global Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V, gamma, num_states, num_actions)
    else
        # Non-root processes hold no data (branch_bound presumably receives it from the
        # root — TODO confirm). Pass the same options as the root call, including
        # `mingap`, so every process uses the same stopping tolerance.
        tree_sg, objv_sg, calc, UB_sg = bb_func_rl.branch_bound(nothing, nothing, num_states, num_actions, feature_dim, D, tree_w, objv_w, alp, L_hat, LB_method, true, false; seed=seed, mingap=mingap)
    end

    # Synchronize all processes before continuing to the next iteration.
    parallel.barrier()

    # Broadcast the root's updated Q-values to all processes for the next iteration.
    Q_value = parallel.bcast(Q_value)
end


# Select the best iteration and write the final log on the root process only. The
# per-iteration histories are filled only on the root; on other processes they are
# empty and a maximum over them would throw.
if parallel.is_root()
    # `findmax` returns the first index attaining the maximum.
    best_cumul_reward, best_iteration = findmax(cumul_reward_list)
    println("Best cumul_reward: $best_cumul_reward at iteration $best_iteration")
    best_tree_w = tree_w_list[best_iteration]
    best_tree_sg = tree_sg_list[best_iteration]
    # Trees.printTree(best_tree_w)
    # Trees.printTree(best_tree_sg)

    # Report the last branch-and-bound run's gap, rounded, if it was recorded.
    if gap_sg === nothing && calc !== nothing && !isempty(calc)
        gap_sg = round(calc[end][end], digits=4)
    end

    # Log final results
    Logger.log_final_results(best_cumul_reward, best_iteration, gap_sg, cumul_return_upper)
    log_file = Logger.save_log_json("SP_BB")

    println("Final gap: $gap_sg")
    println("Log saved to: $log_file")

    ##################### Tree structure plot #####################
    # plt = true
    # if plt
    #     tree_plot(best_tree_w, "CART", dataname)
    #     tree_plot(best_tree_sg, "sglb", dataname)
    # end
end

# Every process must reach this call.
parallel.finalize()
