using DataFrames, CSV
using Random, Distributions, StatsBase
using Plots

using MLDataUtils, Clustering
using Distributed, SharedArrays

# Make the project's `src` and `test` directories importable on every Distributed
# process, so the `using Trees, bound_rl, ...` lines below can find the local modules.
@everywhere begin
    # The script is run as test/test_rl.jl, so the project root is two levels up.
    project_dir = dirname(dirname(@__FILE__))

    src_path = joinpath(project_dir, "src")
    test_path = joinpath(project_dir, "test")

    # Append each directory at most once, src before test.
    for dir in (src_path, test_path)
        dir in LOAD_PATH || push!(LOAD_PATH, dir)
    end
end


using TimerOutputs: @timeit, get_timer

using Trees, bound_rl, parallel, Nodes_rl, Logger
using ub_func_rl, lb_func_rl, bb_func_rl, data_process_rl, utils_rl, opt_func_rl

# Command-line interface: arguments 1-7 are required, 8-11 are optional, e.g.
#   julia .\test\test_rl.jl 2 0.99 CF+MILP+SG 42 par frozenlake_4x4 4
#
#   1 D                  depth parameter passed to the tree builders / B&B, e.g. 2
#   2 gamma              discount factor, e.g. 0.99
#   3 LB_method          methods joined by "+", e.g. CF+MILP+SG
#   4 seed               random seed
#   5 scheme             "par" calls parallel.init(); e.g. "sl" otherwise
#   6 dataname           dataset name handed to read_data
#   7 Iteration_num      number of outer iterations
#   8 mingap             `mingap` passed to branch_bound            (default 0.0001)
#   9 start_action_idx   action of the constant CART initial policy (default 1)
#  10 warm_start_method  "CART", "Random" or "OMDT"                 (default "CART")
#  11 coffcient_method   Q-value weighting, "exp" or "linear"       (default "exp")
if length(ARGS) < 7
    # Fail with a readable message instead of a BoundsError on ARGS[k].
    error("expected at least 7 command-line arguments " *
          "(D gamma LB_method seed scheme dataname Iteration_num " *
          "[mingap start_action_idx warm_start_method coffcient_method]), got $(length(ARGS))")
end

D = parse(Int, ARGS[1])
gamma = parse(Float64, ARGS[2])
LB_method = ARGS[3]
seed = parse(Int, ARGS[4])
scheme = ARGS[5]
dataname = ARGS[6]
Iteration_num = parse(Int, ARGS[7])
mingap = length(ARGS) >= 8 ? parse(Float64, ARGS[8]) : 0.0001
start_action_idx = length(ARGS) >= 9 ? parse(Int, ARGS[9]) : 1
warm_start_method = length(ARGS) >= 10 ? ARGS[10] : "CART"
coffcient_method = length(ARGS) >= 11 ? ARGS[11] : "exp"

# Bring up MPI (only for the "par" scheme), create the communicator and print a
# short run banner from the root rank.
using MPI
scheme == "par" && parallel.init()

println("mpi_init: $(MPI.Initialized())")

parallel.create_world()
for banner_line in ("Running $(parallel.nprocs()) processes.",
                    "Start training $dataname with seed $seed.",
                    "gamma is set to $gamma")
    parallel.root_println(banner_line)
end
#############################################################
################# Main Process Program Body #################
#############################################################
# Root rank:
#   - loads the MDP;
#   - builds the initial policy selected by `warm_start_method`;
#   - turns that policy into occupancy-weighted Q-values;
#   - fits the warm-start tree used by the first branch-and-bound iteration.
# Other ranks only create placeholders. Their real values arrive through the
# `parallel.bcast` calls that follow this block.
if parallel.is_root()
    states, transition_prob, rewards, initial_state_p, num_states, num_actions, feature_dim = read_data(dataname)
    # Return of the unrestricted optimal policy, presumably an upper reference for
    # any tree policy. It is logged here and reported with the final results.
    cumul_return_upper, _ = opt_func_rl.get_optimal_cumul_return(states, initial_state_p, transition_prob, rewards, gamma, num_states, num_actions)

    log_data = Logger.init_log(dataname, seed, gamma, D, LB_method, scheme, Iteration_num, cumul_return_upper, "8SRBB", start_action_idx, mingap, warm_start_method, coffcient_method)
    Random.seed!(seed)

    alp = 0.00
    L_hat = 1      # NOTE: need to change L_hat
    gap_warm = 0.0
    time_w = 0.0   # total wall-clock seconds spent building the warm start

    # 1. Initial policy, plus the method-specific tree parameters where available.
    if warm_start_method == "CART"
        # Constant policy that always takes `start_action_idx`.
        init_policy = ones(Int, num_states) .* start_action_idx
    elseif warm_start_method == "Random"
        a_init, b_init, c_init, d_init, init_policy = opt_func_rl.init_random_policy(states, D, feature_dim, num_actions)
    elseif warm_start_method == "OMDT"
        time_w += @elapsed begin
            a_init, b_init, c_init, objv_omdt, gap_warm, init_policy = opt_func_rl.OMDT(states, initial_state_p, transition_prob, rewards, gamma, num_states, num_actions, D; time_limit=300, seed=seed, mingap=0.0005)
        end
        println("objv_w: $objv_omdt")
        println("gap_warm: $gap_warm")
        println("Pi_init: $init_policy")
    else
        # Fail fast: every later step needs an initial policy.
        error("unknown warm_start_method \"$warm_start_method\" (expected \"CART\", \"Random\" or \"OMDT\")")
    end

    # 2. Value of the initial policy, its Q-values and its discounted occupancy.
    V_old, cumul_reward_init = utils_rl.policy_evaluation(init_policy, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-10, 1000000)
    println("cumul_reward_init: $cumul_reward_init")
    Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V_old, gamma, num_states, num_actions)
    discounted_occupancy = utils_rl.get_discounted_occupancy(init_policy, transition_prob, initial_state_p, gamma, num_states, num_actions; normalize=true)

    # 3. Scale each state's Q-row by that state's discounted occupancy d:
    #    "exp" -> exp(d), "linear" -> d. Any other value leaves Q_value unscaled.
    for i in 1:num_states
        if coffcient_method == "exp"
            Q_value[i, :] = Q_value[i, :] .* exp(discounted_occupancy[i])
        elseif coffcient_method == "linear"
            Q_value[i, :] = Q_value[i, :] .* discounted_occupancy[i]
        end
    end

    # 4. Warm-start tree, fitted on the *scaled* Q-values so that `objv_w` is measured
    #    on the same Q_value that branch_bound receives. Iterations >= 2 likewise call
    #    CART_base on the scaled Q_value. Every method needs a `tree_w`/`objv_w`,
    #    because both are broadcast and passed to branch_bound.
    #    NOTE(review): for "OMDT"/"Random", a_init/b_init/c_init are not converted
    #    into a tree. The B&B warm start is a CART tree fitted on the scaled Q-values
    #    of that initial policy.
    time_w += @elapsed begin
        tree_w, objv_w = lb_func_rl.CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alp, L_hat)
    end
    if warm_start_method == "CART"
        a_init = tree_w.a
        b_init = tree_w.b
        c_init = tree_w.c
    end

    println("initialize data")
    println("size of states: $num_states")
    println("size of actions: $num_actions")
    println("size of features: $feature_dim")
    println("size of states: $(size(states))")
else
    # Placeholders, overwritten by the broadcasts below.
    println("rank: $(parallel.myid())")
    L_hat = 0
    num_actions = 0
    alp = 0
    mingap = 0
    objv_w = 0
    tree_w = nothing
    cumul_return_upper = 0
    num_states = 0
    feature_dim = 0
    states = nothing
    transition_prob = nothing
    rewards = nothing
    initial_state_p = nothing
    Q_value = nothing
    cumul_reward_init = 0
    log_data = nothing
    discounted_occupancy = nothing
    a_init = nothing
    b_init = nothing
    c_init = nothing
    d_init = nothing
end


##################### Testing of different ODT methods #####################
# Share the values prepared on root with every other rank. Each call replaces a
# rank's local value (a placeholder on non-root ranks) with the result of
# `parallel.bcast`, presumably root's value (TODO confirm against the parallel module).
# NOTE(review): these look like collective operations. Every rank must issue them in
# exactly this order, so do not reorder them or make any of them conditional.
# D and LB_method are already in all processes at the beginning (parsed from ARGS).
alp = parallel.bcast(alp)
L_hat = parallel.bcast(L_hat)
mingap = parallel.bcast(mingap)
objv_w = parallel.bcast(objv_w)
tree_w = parallel.bcast(tree_w)
num_states = parallel.bcast(num_states)
num_actions = parallel.bcast(num_actions)
feature_dim = parallel.bcast(feature_dim)
states = parallel.bcast(states)
transition_prob = parallel.bcast(transition_prob)
rewards = parallel.bcast(rewards)
initial_state_p = parallel.bcast(initial_state_p)
Q_value = parallel.bcast(Q_value)
cumul_return_upper = parallel.bcast(cumul_return_upper)
discounted_occupancy = parallel.bcast(discounted_occupancy)
a_init = parallel.bcast(a_init)
b_init = parallel.bcast(b_init)
c_init = parallel.bcast(c_init)

# LB_method arrives as e.g. "CF+MILP+SG"; split it into its component methods
# (CF must always be among them).
LB_method = split(LB_method, "+")
gap_sg = nothing

# Per-iteration history kept on the root rank only: the policy return, the warm-start
# tree and the branch-and-bound tree of every iteration.
if parallel.is_root()
    cumul_reward_list, tree_w_list, tree_sg_list = [], [], []
end

# Outer iteration loop. Each round:
#   (1) fits a warm-start tree on the current occupancy-scaled Q-values on root
#       (iteration 1 reuses the tree built during initialization) and shares it;
#   (2) runs branch-and-bound on every rank (only root holds the data);
#   (3) on root, evaluates the B&B tree policy, logs the round and recomputes the
#       scaled Q-values for the next round, which are then broadcast.
for i in 1:Iteration_num
    # Script-level variables updated inside this loop body.
    global time_w, tree_w, objv_w, time_sg
    global Q_value, discounted_occupancy
    parallel.root_println("Iteration $i")

    # --- (1) warm-start tree ---------------------------------------------------
    if parallel.is_root()
        if i != 1
            time_w = @elapsed tree_w, objv_w = lb_func_rl.CART_base(states, Q_value, num_states, num_actions, feature_dim, D, alp, L_hat)
        end
        push!(tree_w_list, tree_w)
    else
        tree_w = nothing
        objv_w = nothing
    end
    tree_w = parallel.bcast(tree_w)
    objv_w = parallel.bcast(objv_w)
    parallel.barrier()

    # --- (2) branch-and-bound ----------------------------------------------------
    if parallel.is_root()
        time_sg = @elapsed begin
            global tree_sg, objv_sg, calc, UB_sg
            tree_sg, objv_sg, calc, UB_sg = bb_func_rl.branch_bound(states, Q_value, num_states, num_actions, feature_dim, D, tree_w, objv_w, alp, L_hat, LB_method, true, false; seed=seed, mingap=mingap)
        end
        push!(tree_sg_list, tree_sg)

        # --- (3) evaluate the tree policy (root only) ---------------------------
        Pi_opt = utils_rl.extract_policy(tree_sg, states, num_states, num_actions)
        # Collapse each state's row of Pi_opt to its argmax action index, stored
        # as a (num_states, 1) matrix.
        pi_ = zeros(Int, num_states, 1)
        for j in 1:num_states
            pi_[j] = Int(argmax(Pi_opt[j, :]))
        end
        V, cumul_reward = utils_rl.policy_evaluation(pi_, transition_prob, rewards, initial_state_p, gamma, num_states, 1e-10, 1000000)
        println("Iteration $i: cumul_reward: $cumul_reward")
        push!(cumul_reward_list, cumul_reward)

        gap_ = !isempty(calc) ? calc[end][end] : nothing
        Logger.log_iteration(i, time_w, objv_w, time_sg, objv_sg, UB_sg, gap_, cumul_reward, V, 1)

        # Q-values of the new policy, scaled by its discounted occupancy, drive the
        # next round. Qualified like every other utils_rl call, so this does not
        # depend on `compute_Q_value` being exported.
        Q_value = utils_rl.compute_Q_value(transition_prob, rewards, V, gamma, num_states, num_actions)
        discounted_occupancy = utils_rl.get_discounted_occupancy(pi_, transition_prob, initial_state_p, gamma, num_states, num_actions; normalize=true)
        for j in 1:num_states
            if coffcient_method == "exp"
                Q_value[j, :] = Q_value[j, :] .* exp(discounted_occupancy[j])
            elseif coffcient_method == "linear"
                Q_value[j, :] = Q_value[j, :] .* discounted_occupancy[j]
            end
        end
    else
        # The data live on root and are distributed inside branch_bound, so other
        # ranks pass `nothing` for states and Q_value.
        tree_sg, objv_sg, calc, UB_sg = bb_func_rl.branch_bound(nothing, nothing, num_states, num_actions, feature_dim, D, tree_w, objv_w, alp, L_hat, LB_method, true, false; seed=seed, mingap=mingap)
        Q_value = nothing
    end
    # Synchronize all ranks, then give everyone root's Q_value for the next round.
    parallel.barrier()
    Q_value = parallel.bcast(Q_value)
end



# Root reports the best iteration and saves the log. Every rank then shuts down the
# parallel runtime.
if parallel.is_root()
    if isempty(cumul_reward_list)
        # Each iteration pushes exactly one reward, so this means Iteration_num < 1.
        # There is nothing to select, and `calc` was never assigned.
        println("No iterations were run (Iteration_num = $Iteration_num); skipping final results.")
    else
        # findmax returns the first maximiser and does not break on NaN rewards.
        best_cumul_reward, best_iteration = findmax(cumul_reward_list)
        println("Best cumul_reward: $best_cumul_reward at iteration $best_iteration")
        best_tree_w = tree_w_list[best_iteration]
        best_tree_sg = tree_sg_list[best_iteration]
        println("rank: $(parallel.myid())")

        # Final gap: last entry of the last record in the most recent `calc`.
        if gap_sg === nothing && calc !== nothing && !isempty(calc)
            gap_sg = round(calc[end][end], digits=4)
        end

        Logger.log_final_results(best_cumul_reward, best_iteration, gap_sg, cumul_return_upper)
        log_file = Logger.save_log_json()
        println("Final gap: $gap_sg")
    end
end

parallel.finalize()
