module Logger

    using Dates
    using CSV
    using DataFrames
    using Printf
    using JSON
    using parallel

    # NOTE(review): `save_log` is exported but not defined in this module; the JSON
    # writer that actually exists is `save_log_json`.
    export init_log, log_iteration, save_log, log_final_results

    # Module-level log store shared by every function below. The `init_*` functions
    # (re)populate it, the `log_*` functions append to it, and `save_log_json`
    # serializes it.
    const logs = Dict{String, Any}()

    """
        _log_root()

    Return the path `<parent of this file's directory>/logs`, the root under which all
    log files are written.
    """
    _log_root() = joinpath(dirname(dirname(@__FILE__)), "logs")

    """
        _reset_logs!(metadata)

    Store `metadata` in `logs` and reset the per-run entries (`"iterations"`,
    `"final_results"`, `"total_time"`). Returns `logs`.

    NOTE(review): keys not reset here (e.g. `"warm_start_results"`) survive a re-init
    within the same process.
    """
    function _reset_logs!(metadata::AbstractDict)
        logs["metadata"] = metadata
        logs["iterations"] = []
        logs["final_results"] = Dict()
        logs["total_time"] = 0.0  # accumulated by `log_iteration`
        return logs
    end

    """
        init_log(dataname, seed, gamma, D, LB_method, scheme, Iteration_num,
                 cumul_return_upper, method_name, start_action_idx, mingap,
                 warm_start_method, coffcient_method)

    Start a new log for an iteration-count-based run: record the run parameters plus a
    `"dd_HH-MM-SS"` timestamp as metadata, reset the per-run entries and, on the root
    process only, make sure the top-level `logs` directory exists.

    Returns the module-level `logs` dictionary.
    """
    function init_log(dataname, seed, gamma, D, LB_method, scheme, Iteration_num,
                      cumul_return_upper, method_name, start_action_idx, mingap,
                      warm_start_method, coffcient_method)
        _reset_logs!(Dict{String, Any}(
            "dataname" => dataname,
            "seed" => seed,
            "gamma" => gamma,
            "D" => D,
            "LB_method" => string(LB_method),
            "scheme" => scheme,
            "Iteration_num" => Iteration_num,
            "timestamp" => Dates.format(now(), "dd_HH-MM-SS"),
            "method_name" => method_name,
            "cumul_return_upper" => cumul_return_upper,
            "start_action_idx" => start_action_idx,
            "mingap" => mingap,
            "warm_start_method" => warm_start_method,
            "coffcient_method" => coffcient_method
        ))

        if parallel.is_root()
            # `mkpath` is a no-op when the directory already exists.
            mkpath(_log_root())
        end

        return logs
    end

    """
        init_log_timeLimit(dataname, seed, gamma, D, LB_method, scheme, total_time_limit,
                           cumul_return_upper, method_name, start_action_idx, mingap,
                           warm_start_method, coffcient_method, iteration_time_limit,
                           threhold_upper, decay_method)

    Start a new log for a time-limited run. This is like `init_log`, but it records
    time-limit, threshold and decay parameters instead of `Iteration_num`. It also
    makes sure the per-dataset directory `logs/<dataname>` exists.

    Returns the module-level `logs` dictionary.
    """
    function init_log_timeLimit(dataname, seed, gamma, D, LB_method, scheme, total_time_limit,
                                cumul_return_upper, method_name, start_action_idx, mingap,
                                warm_start_method, coffcient_method, iteration_time_limit,
                                threhold_upper, decay_method)
        _reset_logs!(Dict{String, Any}(
            "dataname" => dataname,
            "seed" => seed,
            "gamma" => gamma,
            "D" => D,
            "LB_method" => string(LB_method),
            "scheme" => scheme,
            "total_time_limit" => total_time_limit,
            "timestamp" => Dates.format(now(), "dd_HH-MM-SS"),
            "cumul_return_upper" => cumul_return_upper,
            "method_name" => method_name,
            "start_action_idx" => start_action_idx,
            "mingap" => mingap,
            "warm_start_method" => warm_start_method,
            "coffcient_method" => coffcient_method,
            "iteration_time_limit" => iteration_time_limit,
            "threhold_upper" => threhold_upper,
            "decay_method" => decay_method
        ))

        # `mkpath` also creates the parent `logs` directory when it is missing (a plain
        # `mkdir` threw in that case) and tolerates another process creating it first.
        mkpath(joinpath(_log_root(), string(dataname)))

        return logs
    end

    """
        log_warm_start_results(method_name, time_w, cumul_reward_init, gap)

    Record the warm-start outcome under `logs["warm_start_results"]`. `gap` is stored
    only when `method_name == "OMDT"`. Warm-start time is not added to
    `logs["total_time"]`; only `log_iteration` accumulates that counter.

    Returns the warm-start record.
    """
    function log_warm_start_results(method_name, time_w, cumul_reward_init, gap)
        # Explicit `Any` value type: a plain `Dict(...)` infers its value type from the
        # first entries (e.g. `Float64`) and would then reject a `gap` of another type.
        warm = Dict{String, Any}(
            "time_w" => time_w,
            "cumul_reward_init" => cumul_reward_init,
        )
        if method_name == "OMDT"
            warm["gap"] = gap
        end
        logs["warm_start_results"] = warm
        return warm
    end

    """
        log_iteration(iteration, time_w, objv_w, time_sg, objv_sg, UB_sg, gap,
                      cumul_return, V, threshold)

    Append one iteration record to `logs["iterations"]`. `time_sg` is first added to
    `logs["total_time"]`, and the record's `"cumulative_time"` holds the updated total.
    `V` is accepted but not stored.

    Returns the appended record.
    """
    function log_iteration(iteration, time_w, objv_w, time_sg, objv_sg, UB_sg, gap,
                           cumul_return, V, threshold)
        logs["total_time"] += time_sg

        record = Dict(
            "iteration" => iteration,
            "time_warmstart" => time_w,
            "objv_warmstart" => objv_w,
            "time_sg" => time_sg,
            "objv_sg" => objv_sg,
            "UB_sg" => UB_sg,
            "gap" => gap,
            "cumul_return" => cumul_return,
            "cumulative_time" => logs["total_time"],
            "threshold" => threshold
        )
        push!(logs["iterations"], record)
        return record
    end

    """
        log_final_results(best_cumul_return, best_iteration, gap_sg, cumul_return_upper)

    Fill `logs["final_results"]` with the best result, its ratio to
    `cumul_return_upper`, and the accumulated total time. When the metadata came from
    `init_log` (it has `"Iteration_num"`), also store the average time per iteration.

    Returns `logs["final_results"]`.
    """
    function log_final_results(best_cumul_return, best_iteration, gap_sg, cumul_return_upper)
        final = logs["final_results"]
        final["best_cumul_return"] = best_cumul_return
        final["best_iteration"] = best_iteration
        final["gap_sg"] = gap_sg
        final["cumul_return_upper"] = cumul_return_upper
        final["normalized_return"] = best_cumul_return / cumul_return_upper
        final["total_time"] = logs["total_time"]

        meta = logs["metadata"]
        if haskey(meta, "Iteration_num")
            final["aver_time"] = logs["total_time"] / Int(meta["Iteration_num"])
        end
        return final
    end

    """
        save_log_json()

    Write the whole `logs` dictionary as JSON to `logs/<dataname>/<name>.json`, creating
    the directory if needed. The file name joins the method name, `D`, warm-start method,
    coefficient method, LB method, threshold upper bound (time-limited runs only), gamma,
    start action index, mingap and timestamp with `_`.

    Returns the path of the written file.
    """
    function save_log_json()
        meta = logs["metadata"]

        parts = String[
            "$(meta["method_name"])",
            "D$(meta["D"])",
            "War-$(meta["warm_start_method"])",
            "Coff-$(meta["coffcient_method"])",
            "$(meta["LB_method"])",
        ]
        # Only metadata written by `init_log_timeLimit` carries a threshold.
        if haskey(meta, "threhold_upper")
            push!(parts, "threU$(meta["threhold_upper"])")
        end
        append!(parts, [
            "g$(meta["gamma"])",
            "start$(meta["start_action_idx"])",
            "mingap$(meta["mingap"])",
            "$(meta["timestamp"])",
        ])
        filename = join(parts, "_") * ".json"

        # `mkpath` creates missing parents too, unlike the previous `mkdir`.
        log_dir = joinpath(_log_root(), string(meta["dataname"]))
        mkpath(log_dir)
        filepath = joinpath(log_dir, filename)

        # Save the entire `logs` dictionary as a JSON file.
        open(filepath, "w") do io
            JSON.print(io, logs)
        end

        return filepath
    end
end
