using DataFrames, CSV, JSON
using Random, Distributions, StatsBase
using Plots

using MLDataUtils, Clustering
using Distributed, SharedArrays

# Make the project's own modules importable: put `<project>/src` and
# `<project>/test` on LOAD_PATH on every process that exists when this runs
# (`@everywhere`), so the project `using` statements further down resolve.
@everywhere begin
    # Project root = parent of the directory that contains this script.
    project_dir = dirname(dirname(@__FILE__))

    src_path = joinpath(project_dir, "src")
    test_path = joinpath(project_dir, "test")

    # Append each directory at most once, src first, then test.
    for dir in (src_path, test_path)
        dir in LOAD_PATH || push!(LOAD_PATH, dir)
    end
end

using TimerOutputs: @timeit, get_timer

using Trees, bound_rl, parallel, Nodes_rl, Logger, best_return, Dates
using ub_func_rl, lb_func_rl, bb_func_rl, data_process_rl, utils_rl, opt_func_rl

# Positional command-line arguments (seed and mingap are optional):
#   D           depth of tree (Int)
#   gamma       discount factor (Float64)
#   dataname    dataset name, passed to `read_data`
#   time_limit  time limit for OMDT (Int)
#   seed        seed for OMDT (Int, default 42)
#   mingap      minimum gap for OMDT (Float64, default 0.00001)
if length(ARGS) < 4
    # Fail with a usage message instead of an opaque BoundsError on ARGS[i].
    throw(ArgumentError(
        "usage: julia $(basename(PROGRAM_FILE)) D gamma dataname time_limit [seed] [mingap]" *
        " (got $(length(ARGS)) argument(s))",
    ))
end

"""
    _parse_positional_arg(T, i, name)

Parse the `i`-th command-line argument `ARGS[i]` as a value of type `T`.

Throws an `ArgumentError` that names the argument (instead of a bare parse
error) when the text is not a valid `T`.
"""
function _parse_positional_arg(::Type{T}, i::Integer, name::AbstractString) where {T}
    raw = ARGS[i]
    value = tryparse(T, raw)
    if value === nothing
        throw(ArgumentError("argument $i ($name) must be a valid $T, got \"$raw\""))
    end
    return value
end

D = _parse_positional_arg(Int, 1, "D")  # Depth of tree
gamma = _parse_positional_arg(Float64, 2, "gamma")  # Discount factor
dataname = ARGS[3]  # Dataset name
time_limit = _parse_positional_arg(Int, 4, "time_limit")  # Time limit for OMDT
seed = length(ARGS) > 4 ? _parse_positional_arg(Int, 5, "seed") : 42  # seed for OMDT
mingap = length(ARGS) > 5 ? _parse_positional_arg(Float64, 6, "mingap") : 0.00001  # minimum gap for OMDT

println("seed is $seed")
println("Mingap is $mingap")

# Load the dataset; `read_data`'s result is unpacked into the seven globals below.
states, transition_prob, rewards, initial_state_p, num_states, num_actions, feature_dim =
    read_data(dataname)

# Run OMDT; `@elapsed` stores the wall-clock duration of the call (seconds) in
# `time_w`, while the unpacked results become globals used by the logging code.
time_w = @elapsed begin
    a_init, b_init, c_init, objv_w, gap_warm, Pi_init = opt_func_rl.OMDT(
        states, initial_state_p, transition_prob, rewards, gamma,
        num_states, num_actions, D;
        time_limit = time_limit,
        seed = seed,
        mingap = mingap,
    )
end

# Log directory for this run: <project>/logs/<dataname>.
path = joinpath(project_dir, "logs", dataname)
# `mkpath` (unlike `mkdir`) also creates the parent `logs` directory when it is
# missing, and does nothing if `path` already exists.
mkpath(path)

# Timestamp suffix in "dd_HH-MM-SS" form (day of month, hour, minute, second;
# in Julia's Dates format `M` is minutes). It distinguishes repeated runs with
# identical settings at one-second resolution; month and year are not included.
timestamp = Dates.format(now(), "dd_HH-MM-SS")

# The log file name encodes every run setting plus the timestamp.
log_file = joinpath(
    path,
    "OMDT_base_$(dataname)_D$(D)_time_limit$(time_limit)_G$(gamma)_mingap$(mingap)_seed$(seed)_T$(timestamp).json",
)

# Run summary; the keys below are written to the JSON log exactly as spelled.
results = Dict(
    "seed" => seed,
    "Mingap" => mingap,
    "Time Limit" => time_limit,
    "Objective Value" => objv_w,
    "Gap" => gap_warm,
    "Time" => time_w,
)
println("results: $results")

# Pretty-print the summary as JSON (indent 4); `open(f, file, "w")` closes the
# file even if serialization throws.
open(io -> JSON.print(io, results, 4), log_file, "w")



