using DataFrames, CSV
using Random, Distributions, StatsBase
using Plots

using MLDataUtils, Clustering
using Distributed, SharedArrays

# Make the project's own modules (branch & bound, data preprocessing, ...)
# loadable on every process by putting `src/` and `test/` on LOAD_PATH.
@everywhere begin
    # Project root: two directory levels above this file.
    project_dir = dirname(dirname(@__FILE__))

    src_path = joinpath(project_dir, "src")
    test_path = joinpath(project_dir, "test")

    # Append each directory once, `src` first, skipping entries already present.
    for module_dir in (src_path, test_path)
        module_dir in LOAD_PATH || push!(LOAD_PATH, module_dir)
    end
end


using TimerOutputs: @timeit, get_timer

using Trees, bound_rl, parallel, Nodes_rl, Logger, best_return
using ub_func_rl, lb_func_rl, bb_func_rl, data_process_rl, utils_rl, opt_func_rl

# Example invocation:
#   julia .\test\test_rl.jl 2 0.99 CF+MILP+SG 42 par frozenlake_4x4 4
#
# Positional command-line arguments (the first seven are required):
#   1  D                  depth of tree (Int)
#   2  gamma              discount factor (Float64)
#   3  LB_method          lower bound method, components joined by "+"
#   4  seed               random seed (Int)
#   5  scheme             sl/par
#   6  dataname           dataset name
#   7  Iteration_num      number of iterations (Int)
#   8  mingap             minimum gap (Float64, default 0.0005)
#   9  start_action_idx   starting action index (Int, default 1)
#   10 warm_start_method  warm start method (default "CART")
#   11 coffcient_method   coefficient method, linear/exp (default "exp")

# Fail with a usage message instead of a bare BoundsError on ARGS[k].
if length(ARGS) < 7
    throw(ArgumentError(
        "expected at least 7 arguments: D gamma LB_method seed scheme dataname " *
        "Iteration_num [mingap] [start_action_idx] [warm_start_method] " *
        "[coffcient_method]; got $(length(ARGS))",
    ))
end

D = parse(Int, ARGS[1])
gamma = parse(Float64, ARGS[2])
LB_method = ARGS[3]
seed = parse(Int, ARGS[4])
scheme = ARGS[5]
dataname = ARGS[6]
Iteration_num = parse(Int, ARGS[7])

# Optional parameters with defaults
mingap = length(ARGS) >= 8 ? parse(Float64, ARGS[8]) : 0.0005
start_action_idx = length(ARGS) >= 9 ? parse(Int, ARGS[9]) : 1
warm_start_method = length(ARGS) >= 10 ? ARGS[10] : "CART"
coffcient_method = length(ARGS) >= 11 ? ARGS[11] : "exp"

println("gamma is set to $gamma")

# MPI is only loaded and the parallel layer only initialised for the "par" scheme.
if scheme == "par"
    using MPI
    parallel.init()
end

# Set up the process world, then report the run configuration via the root printer.
parallel.create_world()
parallel.root_println(string("Running ", parallel.nprocs(), " processes."))
parallel.root_println(string("Start training ", dataname, " with seed ", seed, "."))


#############################################################
################# Main Process Program Body #################
#############################################################
# Root-only initialisation: load the dataset, build a warm-start policy with the
# method named by `warm_start_method`, evaluate it, and prepare the weighted
# Q-values used by the first iteration. Non-root processes only get placeholders.

# Reject unknown warm-start methods on every process up front; otherwise none of
# the branches below would run and later code would fail on undefined variables.
if !(warm_start_method in ("CART", "Random", "OMDT"))
    throw(ArgumentError(
        "unknown warm_start_method \"$warm_start_method\"; " *
        "expected \"CART\", \"Random\" or \"OMDT\"",
    ))
end

if parallel.is_root()
    states, transition_prob, rewards, initial_state_p, num_states, num_actions, feature_dim = read_data(dataname)

    # Upper bound on the cumulative return: use the tabulated value when the
    # dataset has one, otherwise compute it.
    if dataname in keys(best_return.return_upper)
        cumul_return_upper = best_return.return_upper[dataname]
    else
        cumul_return_upper, _ = opt_func_rl.get_optimal_cumul_return(states, initial_state_p, transition_prob, rewards, gamma, num_states, num_actions)
    end

    log_data = Logger.init_log(dataname, seed, gamma, D, LB_method, scheme, Iteration_num, cumul_return_upper, "8Direct", start_action_idx, mingap, warm_start_method, coffcient_method)

    Random.seed!(seed)

    # Shared by all warm-start methods.
    ### NOTE: need to change L_hat
    alp = 0.00
    L_hat = 1
    policy_eval_max_iter = 100000

    # Phase 1: obtain the initial policy (and, where the method already yields a
    # tree, its parameters and timing).
    if warm_start_method == "CART"
        # Constant policy: every state takes `start_action_idx`.
        init_policy = fill(start_action_idx, num_states)
    elseif warm_start_method == "Random"
        time_w = @elapsed a_init, b_init, c_init, d_init, init_policy = opt_func_rl.init_random_policy(states, D, feature_dim, num_actions)
        # A random tree has no warm-start objective; use the same 0 placeholder
        # that non-root processes use so logging and broadcasting do not fail.
        objv_w = 0.0
        gap_warm = 0.0
    else  # "OMDT"
        # NOTE(review): mingap is hard-coded here rather than taken from the
        # command-line `mingap`; confirm this is intended.
        time_w = @elapsed a_init, b_init, c_init, objv_w, gap_warm, Pi_init = opt_func_rl.OMDT(states, initial_state_p, transition_prob, rewards, gamma, num_states, num_actions, D; time_limit=300, seed=seed, mingap=0.0005)
        init_policy = Pi_init
        println("objv_w: $objv_w")
        println("gap_warm: $gap_warm")
        println("Pi_init: $Pi_init")
        # The OMDT policy is evaluated with a larger iteration cap.
        policy_eval_max_iter = 1000000
    end

    # Phase 2: evaluate the initial policy.
    V_old, cumul_reward_init = utils_rl.policy_evaluation(init_policy, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-10, policy_eval_max_iter)
    Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V_old, gamma, num_states, num_actions)
    discounted_occupancy = utils_rl.get_discounted_occupancy(init_policy, transition_prob, initial_state_p, gamma, num_states, num_actions; normalize=true)

    # Phase 3: method-specific work that needs the evaluated policy.
    if warm_start_method == "CART"
        # The CART tree is fitted on the Q-values of the constant policy.
        time_w = @elapsed tree_w, objv_w = lb_func_rl.CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alp, L_hat)
        a_init = tree_w.a
        b_init = tree_w.b
        c_init = tree_w.c
        gap_warm = 0.0
    elseif warm_start_method == "OMDT"
        println("discounted_occupancy: $discounted_occupancy")
    end

    println("V:_init $(V_old)")
    println("cumul_reward_init: $cumul_reward_init")
    Logger.log_warm_start_results(warm_start_method, time_w, cumul_reward_init, gap_warm)

    # The solver objective and the evaluated return of its policy should agree.
    if warm_start_method == "OMDT" && abs(cumul_reward_init - objv_w) >= 1e-6
        @warn "cumul_reward_init and objv_w are not equal"
    end

    # Weight each state's Q-values by its discounted occupancy (indexed per
    # state). Any other `coffcient_method` leaves Q_value unweighted.
    if coffcient_method == "exp"
        for i in 1:num_states
            Q_value[i, :] = Q_value[i, :] .* exp(discounted_occupancy[i])
        end
    elseif coffcient_method == "linear"
        for i in 1:num_states
            Q_value[i, :] = Q_value[i, :] .* discounted_occupancy[i]
        end
    end
else
    # Placeholders for non-root processes. Everything read on all processes
    # after this block must exist here too.
    # NOTE(review): presumably the values that matter are replaced by the
    # root's through parallel.bcast; confirm against the `parallel` module.
    L_hat = 0
    num_states = 0
    num_actions = 0
    feature_dim = 0
    states = zeros(0, 0)  # empty array so `size(states)` is still defined
    alp = 0
    objv_w = 0
    time_w = 0.0
    gap_warm = 0.0
    cumul_reward_init = 0.0
    Q_value = nothing
    tree_w = nothing
    a_init = nothing
    b_init = nothing
    c_init = nothing
    d_init = nothing
end

# Report the problem dimensions (printed by every process).
println("initialize data")
println("size of states: ", num_states)
println("size of actions: ", num_actions)
println("size of features: ", feature_dim)
println("size of states: ", size(states))

##################### Testing of different ODT methods #####################
# D and LB_method are already known to all processes from the command line;
# the values below are passed through parallel.bcast.
alp = parallel.bcast(alp)
L_hat = parallel.bcast(L_hat)
objv_w = parallel.bcast(objv_w)
num_states = parallel.bcast(num_states)

# Split the "+"-joined lower bound specification into its components
# (CF must be one of them).
LB_method = split(LB_method, '+')

# Bookkeeping for the iteration loop: the return history starts with the
# warm-start return, and the "last accepted" tree parameters start at the
# warm-start tree.
gap_sg = nothing
cumul_reward_list = [cumul_reward_init]
tree_sg_list = Any[]
last_a, last_b, last_c = a_init, b_init, c_init

# Initial threshold, compared against rand() per node in the iteration loop:
# for an OMDT warm start it depends on the warm-start gap, otherwise it is 1.
threshold = warm_start_method == "OMDT" ? (gap_warm < 0.2 ? 0.2 : 0.4) : 1

# NOTE: the value chosen above is unconditionally overridden here, so the
# first iteration always starts with threshold 1.
threshold = 1

# Main improvement loop. In every iteration the root process
#   1. samples which tree parameters are left undetermined (free) and which are
#      kept determined, each node being free with probability `threshold`,
#   2. calls opt_func_rl.direct_iteration with the current weighted Q-values,
#   3. evaluates the returned policy and rebuilds the weighted Q-values,
#   4. accepts the new tree if the return improved, or otherwise with a
#      temperature-controlled probability,
#   5. adapts `threshold` from the normalised gap to the return upper bound.
# The loop state is then passed through parallel.bcast on all processes.
for i in 1:Iteration_num
    # Script-level state reassigned inside the loop body.
    global Q_value, discounted_occupancy, last_a, last_b, last_c, threshold

    println("Iteration $i")

    if parallel.is_root()
        branch_nodes = 1:(2^D - 1)
        leaf_nodes = (2^D):(2^(D + 1) - 1)
        all_nodes = 1:(2^(D + 1) - 1)

        # Each node joins the undetermined set with probability `threshold`;
        # the complement is the determined set. The order of the rand() draws
        # (a, b, c, d) is kept so results stay reproducible for a given seed.
        a_udt = [t for t in branch_nodes if rand() <= threshold]
        a_dt = setdiff(branch_nodes, a_udt)
        b_udt = [t for t in branch_nodes if rand() <= threshold]
        b_dt = setdiff(branch_nodes, b_udt)
        c_udt = [t for t in leaf_nodes if rand() <= threshold]
        # NOTE(review): c_dt is taken relative to ALL nodes while c_udt only
        # samples leaf nodes; confirm this is what direct_iteration expects.
        c_dt = setdiff(all_nodes, c_udt)
        d_udt = [t for t in branch_nodes if rand() <= threshold]
        d_dt = setdiff(branch_nodes, d_udt)

        time_direct = @elapsed a_opt, b_opt, c_opt, objv_optimal, gap_optimal, Pi_opt = opt_func_rl.direct_iteration(states, Q_value, D, num_actions, last_a, last_b, last_c, nothing, a_udt, a_dt, b_udt, b_dt, c_udt, c_dt, d_udt, d_dt, alp; seed=seed, mingap=mingap, time_limit=18000)
        println("objv_optimal: $objv_optimal")
        println("time_direct: $time_direct")

        println("pi_: $(Pi_opt)")
        V, cumul_reward = utils_rl.policy_evaluation(Pi_opt, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-10, 1000000)
        println("V: $(V)")
        println("Iteration $i: cumul_reward: $cumul_reward")

        Logger.log_iteration(i, 0.0, nothing, time_direct, objv_optimal, objv_optimal, gap_optimal, cumul_reward, V, threshold)

        # Rebuild the Q-values and occupancy for the next iteration. The call is
        # module-qualified like every other utils_rl call in this script, so it
        # does not depend on the name being exported.
        Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V, gamma, num_states, num_actions)
        discounted_occupancy = utils_rl.get_discounted_occupancy(Pi_opt, transition_prob, initial_state_p, gamma, num_states, num_actions; normalize=true)

        # Weight each state's Q-values by its discounted occupancy; any other
        # `coffcient_method` leaves them unweighted.
        if coffcient_method == "exp"
            for j in 1:num_states
                Q_value[j, :] = Q_value[j, :] .* exp(discounted_occupancy[j])
            end
        elseif coffcient_method == "linear"
            for j in 1:num_states
                Q_value[j, :] = Q_value[j, :] .* discounted_occupancy[j]
            end
        end

        # Acceptance rule: always keep an improving tree. A non-improving tree
        # is kept with probability exp(delta / temperature), and only if the
        # solve took more than 2 seconds.
        accept = cumul_reward > cumul_reward_list[end]
        if !accept
            # 0.02 for OMDT, 0.05 for CART. Other warm-start methods (e.g.
            # "Random") previously left Init_temp undefined and crashed here;
            # they now fall back to the CART value.
            Init_temp = warm_start_method == "OMDT" ? 0.02 : 0.05
            temperature = Init_temp / log(i + 1)
            # rand() is evaluated first so the random stream does not depend on
            # the solve time.
            accept = rand() < exp((cumul_reward - cumul_reward_list[end]) / temperature) && time_direct > 2
        end
        if accept
            # Starting point for the next iteration.
            last_a, last_b, last_c = a_opt, b_opt, c_opt
        end

        # Next threshold: the relative gap to the return upper bound, but not
        # below 0.3, raised by 0.1 when the solve was quick (< 5 seconds).
        gap_norm = abs(cumul_reward - cumul_return_upper) / abs(cumul_return_upper)
        threshold = gap_norm < 0.3 ? 0.3 : gap_norm
        if time_direct < 5
            threshold += 0.1
        end

        push!(cumul_reward_list, cumul_reward)
    else
        # All data live on the root process. Non-root processes only need
        # placeholders for the values passed to parallel.bcast below; Pi_opt
        # was previously left undefined here.
        Pi_opt = nothing
        threshold = nothing
    end

    # Synchronize all processes before continuing to the next iteration
    parallel.barrier()

    # Pass the loop state through parallel.bcast for the next iteration.
    Q_value = parallel.bcast(Q_value)
    last_a = parallel.bcast(last_a)
    last_b = parallel.bcast(last_b)
    last_c = parallel.bcast(last_c)
    Pi_opt = parallel.bcast(Pi_opt)
    threshold = parallel.bcast(threshold)
end
# Pick the best return seen so far. Entry 1 of `cumul_reward_list` is the
# warm-start return, so the 1-based position is shifted down by one: iteration 0
# denotes the warm start, and ties resolve to the earliest iteration.
best_cumul_reward = maximum(cumul_reward_list)
best_iteration = findfirst(==(best_cumul_reward), cumul_reward_list) - 1
println("Best cumul_reward: ", best_cumul_reward, " at iteration ", best_iteration)

if parallel.is_root()
    # No gap is computed by this script; a fixed 0.0 is logged.
    gap_sg = 0.0

    # Log final results and persist the log as JSON.
    Logger.log_final_results(best_cumul_reward, best_iteration, gap_sg, cumul_return_upper)
    log_file = Logger.save_log_json()

    println("Final gap: ", gap_sg)
    println("Log saved to: ", log_file)

    # Tree structure plotting is currently disabled in this script.
end

parallel.finalize()
