using DataFrames, CSV
using Random, Distributions, StatsBase
using Plots

using MLDataUtils, Clustering
using Distributed, SharedArrays

# load functions for branch&bound and data preprocess from self-created module
@everywhere begin
    # Project root: two directory levels above this script file.
    project_dir = dirname(dirname(@__FILE__))

    # Directories holding the project-local modules loaded below.
    src_path = joinpath(project_dir, "src")
    test_path = joinpath(project_dir, "test")

    # Register each directory on LOAD_PATH exactly once (src first, then test).
    for module_dir in (src_path, test_path)
        module_dir in LOAD_PATH || push!(LOAD_PATH, module_dir)
    end
end


using TimerOutputs: @timeit, get_timer

using Trees, bound_rl, parallel, Nodes_rl, Logger, best_return
using ub_func_rl, lb_func_rl, bb_func_rl, data_process_rl, utils_rl, opt_func_rl

# Example invocation:
#   julia .\test\test_rl.jl 2 0.99 CF+MILP+SG 42 par frozenlake_4x4 4
#
# Positional command-line arguments
#   required: D gamma LB_method seed scheme dataname total_time_limit
#   optional: mingap start_action_idx warm_start_method coffcient_method
#             limit_per_iteration threhold_upper decay_method

# Fail early with a usage message instead of an opaque BoundsError on ARGS[i].
if length(ARGS) < 7
    throw(ArgumentError(
        "expected at least 7 arguments " *
        "(D gamma LB_method seed scheme dataname total_time_limit " *
        "[mingap start_action_idx warm_start_method coffcient_method " *
        "limit_per_iteration threhold_upper decay_method]), got $(length(ARGS))"))
end

# Required parameters
D = parse(Int, ARGS[1])  # Depth of tree
gamma = parse(Float64, ARGS[2])  # Discount factor
LB_method = ARGS[3]  # Lower bound method; "+"-separated list, split later in the script
seed = parse(Int, ARGS[4])  # Random seed
scheme = ARGS[5]  # Scheme (sl/par); "par" loads MPI
dataname = ARGS[6]  # Dataset name
total_time_limit = parse(Int, ARGS[7])  # Total time budget for the main loop, in seconds

# Optional parameters with defaults
mingap = length(ARGS) >= 8 ? parse(Float64, ARGS[8]) : 0.0005  # Minimum gap passed to the per-iteration solver
start_action_idx = length(ARGS) >= 9 ? parse(Int, ARGS[9]) : 1  # Action used for every state of the CART initial policy
warm_start_method = length(ARGS) >= 10 ? ARGS[10] : "CART"  # Warm start method (CART/Random/OMDT)
coffcient_method = length(ARGS) >= 11 ? ARGS[11] : "exp"  # Coefficient method (linear/exp)
limit_per_iteration = length(ARGS) >= 12 ? parse(Int, ARGS[12]) : 300  # Time limit passed to each per-iteration solve

threhold_upper = length(ARGS) >= 13 ? parse(Float64, ARGS[13]) : 1  # Upper cap applied to the threshold in every iteration
decay_method = length(ARGS) >= 14 ? ARGS[14] : "sqrt"  # Epsilon decay method (sqrt/linear)

println("gamma is set to $gamma")

## Parse the coefficient multiplier `k` from the coefficient method string.
# Any digits left after removing "exp" become the multiplier (e.g. "exp2" -> 2);
# a bare "exp" or an unparsable suffix falls back to 1.
#
# `k` is assigned at global scope on purpose: `try`/`catch` bodies are local scopes
# in a script, so assigning `k` inside them never creates the global `k` that the
# weighting loops later in this file read.
k = 1
if occursin("exp", coffcient_method)
    # Remove all "exp" substrings
    coeff_str = replace(coffcient_method, "exp" => "")
    k = something(tryparse(Int, coeff_str), 1)
end


# Start the parallel runtime.
# MPI is loaded and `parallel.init()` is called only for the "par" scheme; the
# remaining `parallel.*` calls below run for every scheme.
if scheme == "par"
    using MPI
    parallel.init()
end

# NOTE(review): `create_world` presumably sets up the communicator/world object used by
# the later `parallel.*` calls — confirm against the `parallel` module.
parallel.create_world()
# Announce the run; `root_println` is presumably a root-only print — confirm in `parallel`.
parallel.root_println("Running $(parallel.nprocs()) processes.")
parallel.root_println("Start training $dataname with seed $seed.")


#############################################################
################# Main Process Program Body #################
#############################################################
# function initialize_data(dataname, seed, D)
# Root process: load the data, build the warm-start policy/tree and derive the initial
# (occupancy-weighted) Q values. Other processes only receive placeholders, which the
# later `parallel.bcast` calls and diagnostics read.
if parallel.is_root()
    states, transition_prob, rewards, initial_state_p, num_states, num_actions, feature_dim = read_data(dataname)
    parallel.root_println("num_states: $num_states")
    parallel.root_println("num_actions: $num_actions")
    parallel.root_println("feature_dim: $feature_dim")

    # Reference return used for logging and gap normalisation: use the stored value for
    # this dataset when there is one, otherwise compute it.
    if dataname in keys(best_return.return_upper)
        cumul_return_upper = best_return.return_upper[dataname]
    else
        cumul_return_upper, _ = opt_func_rl.get_optimal_cumul_return(states, initial_state_p, transition_prob, rewards, gamma, num_states, num_actions)
    end

    log_data = Logger.init_log_timeLimit(dataname, seed, gamma, D, LB_method, scheme, total_time_limit, cumul_return_upper, "10Direct_TL", start_action_idx, mingap, warm_start_method, coffcient_method, limit_per_iteration, threhold_upper, decay_method)

    # Identical for every warm-start method.
    ### NOTE: need to change L_hat
    alp = 0.00
    L_hat = 1

    if warm_start_method == "CART"
        # Constant initial policy: every state takes action `start_action_idx`.
        init_policy = fill(start_action_idx, num_states)
        _, cumul_reward_init = utils_rl.policy_evaluation(init_policy, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-5, 100000)
        ## Now use the epsilon-greedy policy evaluation
        V_old, _ = utils_rl.policy_evaluation_with_epsilon(init_policy, transition_prob, rewards, initial_state_p, gamma, num_states, 1, 1e-5, 100000,)

        Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V_old, gamma, num_states, num_actions)
        discounted_occupancy = utils_rl.get_discounted_occupancy(init_policy, transition_prob, initial_state_p, gamma, num_states, num_actions; normalize=true)
        Random.seed!(seed)
        # get initial tree
        time_w = @elapsed tree_w, objv_w = lb_func_rl.CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alp, L_hat)
        a_init = tree_w.a
        b_init = tree_w.b
        c_init = tree_w.c
        gap_warm = 0.0
    elseif warm_start_method == "Random"
        # Seed before drawing the random policy so that a given `seed` reproduces the run
        # (previously only the CART branch seeded the RNG).
        Random.seed!(seed)
        time_w = @elapsed a_init, b_init, c_init, d_init, init_policy = opt_func_rl.init_random_policy(states, D, feature_dim, num_actions)
        _, cumul_reward_init = utils_rl.policy_evaluation(init_policy, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-5, 100000)
        V_old, _ = utils_rl.policy_evaluation_with_epsilon(init_policy, transition_prob, rewards, initial_state_p, gamma, num_states, 1, 1e-5, 100000,)

        Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V_old, gamma, num_states, num_actions)
        discounted_occupancy = utils_rl.get_discounted_occupancy(init_policy, transition_prob, initial_state_p, gamma, num_states, num_actions; normalize=true)
        gap_warm = 0.0
        objv_w = cumul_reward_init
    elseif warm_start_method == "OMDT"
        # Seed the global RNG as well: the main loop draws from it with `rand()`.
        Random.seed!(seed)
        # NOTE(review): time_limit and mingap are hard-coded here rather than taken from
        # the command-line values — confirm this is intended.
        time_w = @elapsed a_init, b_init, c_init, objv_w, gap_warm, Pi_init = opt_func_rl.OMDT(states, initial_state_p, transition_prob, rewards, gamma, num_states, num_actions, D; time_limit=300, seed=seed, mingap=0.0005)

        init_policy = Pi_init
        println("objv_w: $objv_w")
        println("gap_warm: $gap_warm")
        println("Pi_init: $Pi_init")

        V_old, cumul_reward_init = utils_rl.policy_evaluation(init_policy, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-5, 1000000)
        Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V_old, gamma, num_states, num_actions)
        discounted_occupancy = utils_rl.get_discounted_occupancy(init_policy, transition_prob, initial_state_p, gamma, num_states, num_actions; normalize=true)
        println("discounted_occupancy: $discounted_occupancy")
    else
        # Without this guard an unknown method only fails later with an UndefVarError.
        throw(ArgumentError("Unsupported warm_start_method \"$warm_start_method\"; expected \"CART\", \"Random\" or \"OMDT\""))
    end

    println("V:_init_greedy $(V_old)")
    println("cumul_reward_init: $cumul_reward_init")
    Logger.log_warm_start_results(warm_start_method, time_w, cumul_reward_init, gap_warm)

    # Sanity check: the re-evaluated OMDT policy should reproduce the solver objective.
    if warm_start_method == "OMDT" && abs(cumul_reward_init - objv_w) >= 1e-6
        @warn "cumul_reward_init and objv_w are not equal"
    end

    # Weight each state's row of Q values by that state's discounted occupancy.
    # NOTE(review): only the exact strings "exp" and "linear" are matched here; a
    # variant such as "exp2" matches neither branch and leaves Q_value unweighted —
    # confirm whether that is intended.
    for i in 1:num_states
        if coffcient_method == "exp"
            Q_value[i, :] = Q_value[i, :] .* exp(discounted_occupancy[i] * k)
        elseif coffcient_method == "linear"
            Q_value[i, :] = Q_value[i, :] .* discounted_occupancy[i]
        end
    end
else
    # Placeholders for non-root processes: every variable that is broadcast or read
    # outside a root-only guard later in the script must at least be defined here.
    L_hat = 0
    num_actions = 0
    alp = 0
    objv_w = 0
    tree_w = nothing
    a_init = nothing
    b_init = nothing
    c_init = nothing
    d_init = nothing
    num_states = 0
    feature_dim = 0
    states = Matrix{Float64}(undef, 0, 0)
    cumul_reward_init = 0.0
    gap_warm = 0.0
    Q_value = nothing
end

# Problem-size diagnostics.
println("initialize data")
println("size of states: " * string(num_states))
println("size of actions: " * string(num_actions))
println("size of features: " * string(feature_dim))
println("size of states: " * string(size(states)))

##################### Testing of different ODT methods #####################
# D and LB_method are already in all processes at the beginning; the values below are
# passed through parallel.bcast one after another, in this order.
alp, L_hat, objv_w, num_states = map(parallel.bcast, (alp, L_hat, objv_w, num_states))

# Split the "+"-separated method list (CF must be in the method).
LB_method = split(LB_method, '+')
gap_sg = nothing

# History of evaluated returns; the first entry is the warm-start return.
cumul_reward_list = [cumul_reward_init]
tree_sg_list = Any[]

# Tree parameters carried from one iteration to the next, starting at the warm start.
last_a, last_b, last_c = a_init, b_init, c_init

# Initial threshold: the probability with which the main loop marks a tree node as
# undetermined. For an OMDT warm start it is the warm-start gap, but never below 0.3;
# for every other warm start it is 1.
threshold = if warm_start_method == "OMDT"
    gap_warm < 0.3 ? 0.3 : gap_warm
else
    1
end

# Never exceed the user-supplied cap.
threshold = min(threhold_upper, threshold)

# Accumulated per-iteration solve time (seconds) and 1-based iteration counter.
cumul_time = 0
iteration = 1
# for i in 1:Iteration_num
# Main loop: repeat until the accumulated per-iteration solve time reaches
# `total_time_limit`. Each pass, on the root process:
#   1. randomly marks tree nodes as undetermined with probability `threshold`,
#   2. calls `opt_func_rl.direct_iteration` with the previous tree parameters,
#   3. evaluates the returned policy and rebuilds the occupancy-weighted Q values,
#   4. keeps or rejects the new tree parameters and updates `threshold`.
# The shared state is then broadcast from root to all processes.
while cumul_time < total_time_limit
    # Script-level globals rebound inside the loop body.
    global iteration, threshold, cumul_time
    global Q_value, discounted_occupancy
    global last_a, last_b, last_c

    if parallel.is_root()
        # NOTE(review): presumably the branch-node and all-node index ranges of a
        # depth-D binary tree — confirm against opt_func_rl.direct_iteration.
        node_range = 1:(2^D-1)
        node_range_ext = 1:(2^(D+1)-1)

        # Each index joins the undetermined set (`*_udt`) with probability `threshold`;
        # the remaining indices form the determined set (`*_dt`).
        a_udt = [t for t in node_range if rand() <= threshold]
        a_dt = setdiff(node_range, a_udt)
        b_udt = [t for t in node_range if rand() <= threshold]
        b_dt = setdiff(node_range, b_udt)
        c_udt = [t for t in node_range_ext if rand() <= threshold]
        c_dt = setdiff(node_range_ext, c_udt)
        d_udt = [t for t in node_range if rand() <= threshold]
        d_dt = setdiff(node_range, d_udt)

        # Check that each undetermined/determined pair together covers the full index range.
        if length(a_udt) + length(a_dt) != 2^D-1
            @warn "a_udt and a_dt do not cover all states"
        end
        if length(b_udt) + length(b_dt) != 2^D-1
            @warn "b_udt and b_dt do not cover all states"
        end
        if length(c_udt) + length(c_dt) != 2^(D+1)-1
            @warn "c_udt and c_dt do not cover all states"
        end

        # Solve this iteration's problem, starting from the previously kept parameters.
        time_direct = @elapsed a_opt, b_opt, c_opt, objv_optimal, gap_optimal, Pi_opt = opt_func_rl.direct_iteration(states, Q_value, D, num_actions, last_a, last_b, last_c, nothing, a_udt, a_dt, b_udt, b_dt, c_udt, c_dt, d_udt, d_dt, alp; seed=seed, mingap=mingap, time_limit=limit_per_iteration)
        println("objv_optimal: $objv_optimal")
        println("time_direct: $time_direct")

        # The wlan datasets use a much looser policy-evaluation tolerance.
        tol_for_eval = dataname in ("wlan0", "wlan1") ? 5.0 : 1e-5

        _, cumul_reward = utils_rl.policy_evaluation(Pi_opt, transition_prob, rewards, initial_state_p, gamma, num_states, tol_for_eval, 100000)

        # Epsilon shrinks as the iteration count grows.
        if decay_method == "sqrt"
            epsilon = 1 / sqrt(iteration+1)
        elseif decay_method == "linear"
            epsilon = 1 / (iteration+1)
        else
            # Without this guard an unknown method only fails below with an UndefVarError.
            throw(ArgumentError("Unsupported decay_method \"$decay_method\"; expected \"sqrt\" or \"linear\""))
        end
        V, _ = utils_rl.policy_evaluation_with_epsilon(Pi_opt, transition_prob, rewards, initial_state_p, gamma, num_states, epsilon, tol_for_eval, 100000)
        println("Iteration $iteration: cumul_reward: $cumul_reward")

        Logger.log_iteration(iteration, 0.0, nothing, time_direct, objv_optimal, objv_optimal, gap_optimal, cumul_reward, V, threshold)

        # Update Q_value for next iteration (module-qualified like every other utils_rl call).
        Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V, gamma, num_states, num_actions)
        discounted_occupancy = utils_rl.get_discounted_occupancy(Pi_opt, transition_prob, initial_state_p, gamma, num_states, num_actions; normalize=true)

        # Weight each state's row of Q values by that state's discounted occupancy.
        # NOTE(review): only the exact strings "exp" and "linear" are matched; a variant
        # such as "exp2" leaves Q_value unweighted — confirm whether that is intended.
        for j in 1:num_states
            if coffcient_method == "exp"
                Q_value[j, :] = Q_value[j, :] .* exp(discounted_occupancy[j] * k)
            elseif coffcient_method == "linear"
                Q_value[j, :] = Q_value[j, :] .* discounted_occupancy[j]
            end
        end

        if cumul_reward > cumul_reward_list[end] # it's a nice update
            # Store optimal values for next iteration
            last_a = a_opt
            last_b = b_opt
            last_c = c_opt
        else
            # Not an improvement: accept it only with a probability that shrinks with
            # the size of the loss and with the iteration count.
            if warm_start_method == "OMDT"
                Init_temp = 0.02
            elseif warm_start_method == "CART" || warm_start_method == "Random"
                Init_temp = 0.05
            end
            temperature = Init_temp/log(iteration+1)
            # use heuristic from hill climbing
            # If the time is too short, and the result is not good, do not update
            normalized_gap = (cumul_reward - cumul_reward_list[end])/abs(cumul_return_upper)
            if rand() < exp(normalized_gap/temperature) && time_direct > 2
                # Store optimal values for next iteration
                last_a = a_opt
                last_b = b_opt
                last_c = c_opt
            end
        end
        gap_norm = abs(cumul_reward - cumul_return_upper)/abs(cumul_return_upper)

        # if solving the problem is very quick, set threshold larger
        if time_direct < 5
            threshold = threshold + 0.1 # set threshold larger than last threshold
        else
            if gap_norm < 0.3
                threshold = max(1/sqrt(iteration+1), 0.3) # ensure that we have enough exploration
            else
                threshold = max(1/sqrt(iteration+1), gap_norm)
            end
        end

        threshold = min(threhold_upper, threshold)
        push!(cumul_reward_list, cumul_reward)

        cumul_time += time_direct
    else
        # All data lives on root. Non-root processes only need defined placeholders for
        # the values handed to parallel.bcast below.
        Q_value = nothing
        Pi_opt = nothing
        threshold = nothing
    end

    # Synchronize all processes before continuing to the next iteration
    parallel.barrier()

    # Broadcast Q_value and optimal values from root to all processes for next iteration
    Q_value = parallel.bcast(Q_value)
    last_a = parallel.bcast(last_a)
    last_b = parallel.bcast(last_b)
    last_c = parallel.bcast(last_c)
    Pi_opt = parallel.bcast(Pi_opt)
    threshold = parallel.bcast(threshold)
    # The elapsed time is only accumulated on root; share it so that every process
    # evaluates the same loop condition and leaves the loop in the same iteration.
    cumul_time = parallel.bcast(cumul_time)

    iteration += 1

end
# Best return over the whole history. Entry 1 of the list is the warm start, so the
# reported iteration number is the list index minus one (0 = warm start).
best_cumul_reward = maximum(cumul_reward_list)
best_iteration = findfirst(==(best_cumul_reward), cumul_reward_list) - 1
println("Best cumul_reward: " * string(best_cumul_reward) * " at iteration " * string(best_iteration))

if parallel.is_root()
    # The final gap is always reported as 0.0 by this script.
    gap_sg = 0.0

    # Log final results and write the JSON log.
    Logger.log_final_results(best_cumul_reward, best_iteration, gap_sg, cumul_return_upper)
    log_file = Logger.save_log_json()

    println("Final gap: " * string(gap_sg))
    println("Log saved to: " * string(log_file))

    ##################### Tree structure plot #####################
    # plt = true
    # if plt
    #     tree_plot(tree_w, "CART", dataname)
    #     tree_plot(tree_sg, "sglb", dataname)
    # end
end

parallel.finalize()
