module utils_rl

    export compute_Q_value, policy_evaluation

    """
        compute_Q_value(transition_prob, rewards, V_old, gamma, num_states, num_actions)

    Return the `num_states × num_actions` matrix of one-step look-ahead action values

        Q[i, k] = Σ_j transition_prob[i, j, k] * (rewards[i, j, k] + gamma * V_old[j])

    `transition_prob` and `rewards` are indexed `[state, next_state, action]`. `V_old` is
    broadcast element-wise against the slices `transition_prob[i, :, k]` and
    `rewards[i, :, k]`, so it must match their length. The result is a `Float64` matrix.
    """
    function compute_Q_value(transition_prob, rewards, V_old, gamma, num_states, num_actions)
        Q = zeros(num_states, num_actions)
        for k in 1:num_actions
            for i in 1:num_states
                # Views avoid copying the slices; the summed values are unchanged.
                probs = @view transition_prob[i, :, k]
                rews = @view rewards[i, :, k]
                Q[i, k] = sum(probs .* (rews .+ gamma .* V_old))
            end
        end
        return Q
    end

    # function policy_evaluation(policy, transition_prob, rewards, initial_state_p, gamma, num_states, threshold=1e-10, max_iterations=1000)
    #     V_new = zeros(Float64, num_states)
    #     V_old = zeros(Float64, num_states)  # Initialize V_old
    #     # println("policy: $policy")
    #     # println("gamma: $gamma")
    #     fixed_trans_probs = zeros(Float64, num_states, num_states)
    #     fixed_rewards = zeros(Float64, num_states, num_states)
    #     for i in 1:num_states
    #         action = policy[i]
    #         fixed_trans_probs[i,:] = transition_prob[i,:,action]
    #         fixed_rewards[i,:] = rewards[i,:,action]
    #     end
    #     # Policy's shape is [num_states, num_actions]
    #     for sim in 1:max_iterations
    #         for i in 1:num_states
    #             # Get the action chosen by policy for state i
    #             # Calculate expected value for state i
    #             V_new[i] = sum(fixed_trans_probs[i,:] .* (fixed_rewards[i,:] .+ gamma .* V_old))
    #         end
    #         # if sim < 20
    #         #     println("V_new: $V_new")
    #         # end
    #         # Calculate the maximum difference between V_new and V_old
    #         max_diff = maximum(abs.(V_new - V_old))
    #         # println("V_new: $V_new")
    #         # Check if the difference is below threshold
    #         if max_diff < threshold
    #             break
    #         end
            
    #         # Update V_old for next iteration
    #         V_old = copy(V_new)
    #     end
        
    #     # Calculate expected return from initial state distribution
    #     expected_return = sum(V_new .* initial_state_p)
        
    #     return V_new, expected_return
    # end


    # function policy_evaluation_with_epsilon(
    #     policy, transition_prob, rewards, initial_state_p,
    #     gamma, num_states, epsilon, threshold=1e-10, max_iterations=1000)

    #     num_actions = size(transition_prob, 3)

    #     # Initialize V vectors
    #     V_new = zeros(Float64, num_states)
    #     V_old = zeros(Float64, num_states)

    #     # Build ε-greedy weighted fixed trans_probs and rewards
    #     fixed_trans_probs = zeros(Float64, num_states, num_states)
    #     fixed_rewards = zeros(Float64, num_states, num_states)

    #     for i in 1:num_states
    #         for a in 1:num_actions
    #             π = (a == policy[i]) ? (1 - epsilon + epsilon / num_actions) : (epsilon / num_actions)
    #             fixed_trans_probs[i, :] += π * transition_prob[i, :, a]
    #             fixed_rewards[i, :] += π * rewards[i, :, a]
    #         end
    #     end

    #     # Standard value iteration loop
    #     for sim in 1:max_iterations
    #         for i in 1:num_states
    #             V_new[i] = sum(fixed_trans_probs[i, :] .* (fixed_rewards[i, :] .+ gamma .* V_old))
    #         end

    #         max_diff = maximum(abs.(V_new - V_old))
    #         if max_diff < threshold
    #             break
    #         end

    #         V_old = copy(V_new)
    #     end

    #     expected_return = sum(V_new .* initial_state_p)
    #     return V_new, expected_return
    # end

    #### Vectorized Policy Evaluation ####
    # function policy_evaluation_vectorized(
    """
        policy_evaluation(policy, transition_prob, rewards, initial_state_p, gamma,
                          num_states, threshold=1e-5, max_iterations=1000)

    Evaluate a deterministic policy by repeated Bellman expectation backups

        V[i] ← Σ_j P[i, j, policy[i]] * (R[i, j, policy[i]] + gamma * V[j])

    starting from `V = 0`. Iteration stops when the largest per-state change between two
    consecutive sweeps is below `threshold`, or after `max_iterations` sweeps.

    # Arguments
    - `policy`: indexable by state; `policy[i]` is the action index used in state `i`.
    - `transition_prob`, `rewards`: arrays indexed `[state, next_state, action]`.
    - `initial_state_p`: start-state weights used to aggregate `V`.
    - `gamma`: discount factor.
    - `num_states`: number of states (the model is treated as `num_states × num_states`
      per action).
    - `threshold`: sup-norm convergence tolerance.
    - `max_iterations`: maximum number of sweeps.

    # Returns
    `(V, expected_return)`, where `V` is the most recent backup and
    `expected_return = sum(V .* initial_state_p)`.
    """
    function policy_evaluation(
        policy,
        transition_prob,
        rewards,
        initial_state_p,
        gamma,
        num_states,
        threshold::Float64=1e-5,
        max_iterations::Int=1000
    )
        # Restrict the model to the action the policy picks in each state:
        # row i holds the (next-state) slice for action policy[i].
        P_pi = [transition_prob[i, j, policy[i]] for i in 1:num_states, j in 1:num_states]
        R_pi = [rewards[i, j, policy[i]] for i in 1:num_states, j in 1:num_states]

        # The expected immediate reward does not depend on V, so form it once instead of
        # rebuilding an n×n temporary on every sweep.
        r_pi = vec(sum(P_pi .* R_pi; dims=2))

        V = zeros(num_states)
        for _ in 1:max_iterations
            V_new = r_pi .+ gamma .* (P_pi * V)
            delta = maximum(abs.(V_new .- V))
            # Always keep the freshest backup. Otherwise an early break would return an
            # older iterate (the all-zero guess if the first sweep already converges).
            V = V_new
            delta < threshold && break
        end

        expected_return = sum(V .* initial_state_p)
        return V, expected_return
    end


    #### Vectorized Policy Evaluation with Epsilon ####
    # function policy_evaluation_with_epsilon_vectorized(
    """
        policy_evaluation_with_epsilon(policy, transition_prob, rewards, initial_state_p,
                                       gamma, num_states, epsilon,
                                       threshold=1e-10, max_iterations=1000)

    Evaluate the ε-greedy policy built around the deterministic `policy`. In state `i` it
    takes action `a` with probability

        w(a | i) = 1 - epsilon + epsilon / num_actions   if a == policy[i]
        w(a | i) = epsilon / num_actions                 otherwise

    where `num_actions = size(transition_prob, 3)`. The value is found by iterating

        V[i] ← r_pi[i] + gamma * Σ_j P_pi[i, j] * V[j]
        P_pi[i, j] = Σ_a w(a | i) * P[i, j, a]
        r_pi[i]    = Σ_a w(a | i) * Σ_j P[i, j, a] * R[i, j, a]

    from `V = 0`. Iteration stops when the sup-norm change drops below `threshold` or after
    `max_iterations` sweeps. `transition_prob` and `rewards` are indexed
    `[state, next_state, action]`.

    # Returns
    `(V, expected_return)`, where `V` is the most recent backup and
    `expected_return = sum(V .* initial_state_p)`.
    """
    function policy_evaluation_with_epsilon(
        policy,
        transition_prob,
        rewards,
        initial_state_p,
        gamma,
        num_states,
        epsilon,
        threshold::Float64=1e-10,
        max_iterations::Int=1000
    )
        num_actions = size(transition_prob, 3)

        P_pi = zeros(num_states, num_states)
        r_pi = zeros(num_states)
        for i in 1:num_states, a in 1:num_actions
            # Action weight under ε-greedy (named `w` to avoid shadowing Base's π).
            w = a == policy[i] ? (1 - epsilon + epsilon / num_actions) : (epsilon / num_actions)
            for j in 1:num_states
                p = transition_prob[i, j, a]
                P_pi[i, j] += w * p
                # Reward must be averaged jointly with its own action's transition
                # probability; E[P * R] ≠ E[P] * E[R] for a stochastic policy.
                r_pi[i] += w * p * rewards[i, j, a]
            end
        end

        V = zeros(num_states)
        for _ in 1:max_iterations
            V_new = r_pi .+ gamma .* (P_pi * V)
            delta = maximum(abs.(V_new .- V))
            # Keep the freshest backup so an early break never returns a stale iterate.
            V = V_new
            delta < threshold && break
        end

        expected_return = sum(V .* initial_state_p)
        return V, expected_return
    end


    """
        get_discounted_occupancy(policy, transition_prob, initial_state_p, gamma,
                                 num_states, num_actions;
                                 threshold=1e-10, max_iterations=1000000, normalize=true)

    Approximate the discounted state-occupancy measure of a deterministic policy:

        occupancy = Σ_t gamma^t * μ_t,   μ_0 = initial_state_p,   μ_{t+1} = μ_t * P_pi

    where `P_pi[i, :] = transition_prob[i, :, policy[i]]`. `transition_prob` is indexed
    `[state, next_state, action]`.

    # Keywords
    - `threshold`: the series is truncated once the term just added has sup-norm below
      this value.
    - `max_iterations`: maximum number of series terms.
    - `normalize`: when `true` (default), return `(1 - gamma) * occupancy`; when `false`,
      return the raw discounted sum.

    `num_actions` is accepted for interface compatibility and is not used.
    """
    function get_discounted_occupancy(
        policy, transition_prob, initial_state_p, gamma, num_states, num_actions;
        threshold=1e-10, max_iterations=1000000, normalize=true
    )
        # Transition matrix induced by the deterministic policy.
        P_pi = zeros(num_states, num_states)
        for i in 1:num_states
            P_pi[i, :] = transition_prob[i, :, policy[i]]
        end

        occupancy_sum = zeros(num_states)
        # Row vector so each propagation step is the left product mu * P_pi.
        mu = reshape(initial_state_p, 1, num_states)
        discount = 1.0

        for _ in 1:max_iterations
            increment = discount .* vec(mu)
            occupancy_sum .+= increment
            # Stop once the newly added discounted term is negligible.
            maximum(abs, increment) < threshold && break
            mu = mu * P_pi
            discount *= gamma
        end

        return normalize ? (1 - gamma) .* occupancy_sum : occupancy_sum
    end


    """
        extract_policy(tree, states, num_states, num_actions)

    Route each state through a depth-`tree.D` binary decision tree and stack the reached
    leaf outputs into a `num_states × num_actions` `Float64` matrix.

    Nodes are numbered in heap order: node `t` has children `2t` (left) and `2t + 1`
    (right). The branch nodes are `1:(2^tree.D - 1)`. At branch `t`, the feature row
    `states[i, :]` goes right when `tree.a[:, t]' * states[i, :] + 1e-12 >= tree.b[t]`,
    and left otherwise. The `1e-12` slack presumably absorbs rounding at the split
    boundary; confirm. Row `i` of the result is `tree.c[:, leaf]` for the leaf reached.
    """
    function extract_policy(tree, states, num_states, num_actions)
        n_nodes = 2^(tree.D + 1) - 1
        last_branch = Int(floor(n_nodes / 2))
        policy = zeros(num_states, num_actions)

        for s in 1:num_states
            node = 1
            while 1 <= node <= last_branch
                goes_right = tree.a[:, node]' * states[s, :] + 1e-12 >= tree.b[node]
                node = goes_right ? 2 * node + 1 : 2 * node
            end
            # `node` is now a leaf; copy its action vector into this state's row.
            policy[s, :] = tree.c[:, node]
        end

        return policy
    end



    end