module data_process_rl

using DataFrames, CSV
using CategoricalArrays
using Random, Distributions
using LinearAlgebra
using MLDataUtils, StatsBase
using ub_func_rl
using NPZ

export read_data

"""
    convert_to_float(x)

Convert a scalar value to `Float64` where that is meaningful.

- `x::Number`: returns `Float64(x)`.
- `x::AbstractString`: returns `tryparse(Float64, x)`, i.e. the parsed value, or
  `nothing` when the text is not a valid floating-point literal. Any string type is
  accepted (for example the `SubString`s produced by `strip` or `split`), not only
  `String`.
- anything else: returned unchanged.
"""
convert_to_float(x::Number) = Float64(x)
convert_to_float(x::AbstractString) = tryparse(Float64, x)
convert_to_float(x) = x

"""
    deep_float_convert(data)

Recursively apply `convert_to_float` to the leaves of a nested container.

- `AbstractArray`: every element is converted recursively via `map`.
- `AbstractDict`: a new `Dict` is built with the same keys and recursively
  converted values.
- anything else: passed directly to `convert_to_float`.
"""
deep_float_convert(data::AbstractArray) = map(deep_float_convert, data)

function deep_float_convert(data::AbstractDict)
    return Dict(key => deep_float_convert(value) for (key, value) in data)
end

deep_float_convert(data) = convert_to_float(data)

"""
    read_data(dataname)

Load the arrays stored in `../RL_data/<dataname>.npz` (relative to this source file)
using `npzread`.

# Returns
A 7-tuple `(states, transition_prob, rewards, initial_state_p, num_states,
num_actions, feature_dim)` where

- `states` is a `Float64` matrix,
- `transition_prob` and `rewards` are 3-dimensional `Float64` arrays,
- `initial_state_p` is a `Float64` vector,
- `num_states == size(states, 1)`,
- `num_actions == size(transition_prob, 3)`,
- `feature_dim == size(states, 2)`.

NOTE(review): the original author's notes give the expected layouts as
`states` (n_states, n_features), `transition_prob` (n_states, n_states, n_actions) and
`initial_state_p` (n_states,); this function does not check them — confirm against the
data files.
"""
function read_data(dataname)
    # Locate the archive next to the package sources.
    archive_path = joinpath(@__DIR__, "..", "RL_data", string(dataname, ".npz"))
    archive = npzread(archive_path)

    # Force each entry into an Array of the required dimensionality and element type.
    states = Matrix{Float64}(archive["states"])
    transition_prob = Array{Float64,3}(archive["transition_prob"])
    rewards = Array{Float64,3}(archive["rewards"])
    initial_state_p = Vector{Float64}(archive["initial_state_p"])

    # Problem sizes are read off the array shapes.
    num_states, feature_dim = size(states)
    num_actions = size(transition_prob, 3)

    return (
        states,
        transition_prob,
        rewards,
        initial_state_p,
        num_states,
        num_actions,
        feature_dim,
    )
end
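# NOTE: the commented-out lines from here down to the lone `# end` are a headerless
# duplicate of the body of the old JSON-based `read_data`, which is also kept
# (commented out) further down in this file.
#     # Extract number of states and actions from transition probability matrix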
#     num_states = Int(data["num_states"])
#     num_actions = Int(data["num_actions"])
#     feature_dim = Int(data["num_features"])

#     # Get the dimensions of the MDP
#     # states = convert_nested_array(data["states"])
#     # Get the dimensions of the MDP
#     # println("data['states']: $(data["states"])")
#     # states = reshape(Float64.(data["states"]), :, feature_dim)
#     states = deep_float_convert(data["states"])
#     array_states = vcat(states...)
#     states = reshape(array_states, num_states, feature_dim)
#     # println("states: $(size(states))")
#     # println("states: $(states[1,:])")
#     rewards = deep_float_convert(data["rewards"]["data"])
#     rewards = reshape(rewards, num_states, num_states, num_actions)
#     transition_prob = deep_float_convert(data["transition_prob"]["data"])
#     transition_prob = reshape(transition_prob, num_states, num_states, num_actions)
#     println("transition_prob: $(transition_prob[1,:,1])")
#     initial_state_p = deep_float_convert(data["initial_state_p"])
#     # println("states: $(size(states))")
#     # println("rewards: $(size(rewards))")
#     # println("transition_prob: $(size(transition_prob))")
#     # println("initial_state_p: $(size(initial_state_p))")
#     # Convert transition_prob
#     # tp_data = Float64.(data["transition_prob"]["data"])
#     # tp_shape = Tuple(Float64.(data["transition_prob"]["shape"]))
#     # transition_prob = reshape(tp_data, tp_shape)


#     # # Convert rewards
#     # r_data = Float64.(data["rewards"]["data"])
#     # r_shape = Tuple(Float64.(data["rewards"]["shape"]))
#     # rewards = reshape(r_data, r_shape)

#     # Initial state distribution
#     # initial_state_p = Float64.(data["initial_state_p"])

# end

# function read_data(dataname)
#     # read the data from the file
#     file_dir = joinpath(@__DIR__, "..", "RL_data")
#     # println("file_dir: $file_dir")
#     file_path = joinpath(file_dir, string(dataname) * ".json")
#     # println("file_path: $file_path")
#     data = JSON.parsefile(file_path)
    
#     # Extract number of states and actions from transition probability matrix
#     num_states = Int(data["num_states"])
#     num_actions = Int(data["num_actions"])
#     feature_dim = Int(data["num_features"])

#     # Get the dimensions of the MDP
#     # states = convert_nested_array(data["states"])
#     # Get the dimensions of the MDP
#     # println("data['states']: $(data["states"])")
#     # states = reshape(Float64.(data["states"]), :, feature_dim)
#     states = deep_float_convert(data["states"])
#     array_states = vcat(states...)
#     states = reshape(array_states, num_states, feature_dim)
#     # println("states: $(size(states))")
#     # println("states: $(states[1,:])")
#     rewards = deep_float_convert(data["rewards"]["data"])
#     rewards = reshape(rewards, num_states, num_states, num_actions)
#     transition_prob = deep_float_convert(data["transition_prob"]["data"])
#     transition_prob = reshape(transition_prob, num_states, num_states, num_actions)
#     println("transition_prob: $(transition_prob[1,:,1])")
#     initial_state_p = deep_float_convert(data["initial_state_p"])
#     # println("states: $(size(states))")
#     # println("rewards: $(size(rewards))")
#     # println("transition_prob: $(size(transition_prob))")
#     # println("initial_state_p: $(size(initial_state_p))")
#     # Convert transition_prob
#     # tp_data = Float64.(data["transition_prob"]["data"])
#     # tp_shape = Tuple(Float64.(data["transition_prob"]["shape"]))
#     # transition_prob = reshape(tp_data, tp_shape)


#     # # Convert rewards
#     # r_data = Float64.(data["rewards"]["data"])
#     # r_shape = Tuple(Float64.(data["rewards"]["shape"]))
#     # rewards = reshape(r_data, r_shape)

#     # Initial state distribution
#     # initial_state_p = Float64.(data["initial_state_p"])

#     return states, transition_prob, rewards, initial_state_p, num_states, num_actions, feature_dim
# end

end
