using Random
using LinearAlgebra
using Test

# ========== Implementations: loop-based originals and vectorized variants ==========

"""
    policy_evaluation(policy, transition_prob, rewards, initial_state_p, gamma, num_states;
                      threshold=1e-5, max_iterations=100000, verbose=true)

Evaluate a deterministic `policy` with the loop-based Bellman expectation backup

    V(i) ← Σⱼ P[i, j, policy[i]] · (R[i, j, policy[i]] + γ · V(j))

starting from `V = 0`. Iteration stops once the largest per-state change is below
`threshold`, or after `max_iterations` sweeps.

# Arguments
- `policy`: `policy[i]` is the action index used in state `i`.
- `transition_prob`: indexed as `transition_prob[state, next_state, action]`.
- `rewards`: indexed as `rewards[state, next_state, action]`.
- `initial_state_p`: weights used to collapse `V` into a scalar expected return.
- `gamma`: discount factor.
- `num_states`: number of states (rows of the model that are used).

# Keywords
- `threshold`: convergence tolerance on `maximum(abs.(V_new - V_old))`.
- `max_iterations`: upper bound on the number of sweeps.
- `verbose`: print `"Iteration k, error: e"` after every sweep. Defaults to `true`,
  which matches the historical output.

# Returns
`(V, expected_return)`, where:
- `V` is the most recent iterate. It is all zeros if `max_iterations < 1`.
- `expected_return = sum(V .* initial_state_p)`.
"""
function policy_evaluation(policy, transition_prob, rewards, initial_state_p, gamma,
                           num_states; threshold=1e-5, max_iterations=100000,
                           verbose::Bool=true)
    # Freeze the model under the policy: row i holds the slice for action policy[i].
    fixed_trans_probs = zeros(Float64, num_states, num_states)
    fixed_rewards = zeros(Float64, num_states, num_states)
    for i in 1:num_states
        action = policy[i]
        @views fixed_trans_probs[i, :] .= transition_prob[i, :, action]
        @views fixed_rewards[i, :] .= rewards[i, :, action]
    end

    V_old = zeros(Float64, num_states)
    # Zero-initialised (not `similar`) so that `max_iterations < 1` returns zeros rather
    # than uninitialised memory.
    V_new = zeros(Float64, num_states)
    for iter in 1:max_iterations
        # `err` is tracked during the sweep. Starting at 0.0 means an empty state space
        # converges immediately instead of calling `maximum` on an empty array.
        err = 0.0
        for i in 1:num_states
            @views V_new[i] = sum(fixed_trans_probs[i, :] .*
                                  (fixed_rewards[i, :] .+ gamma .* V_old))
            err = max(err, abs(V_new[i] - V_old[i]))
        end
        verbose && println("Iteration $iter, error: $err")
        err < threshold && break
        copyto!(V_old, V_new)   # reuse the buffer instead of allocating a copy
    end

    expected_return = sum(V_new .* initial_state_p)
    return V_new, expected_return
end

# Helper: collapse the model under the ε-greedy version of `policy` into
#   mixed_trans[i, j]  = Σₐ π(a | i) · transition_prob[i, j, a]
#   expected_reward[i] = Σₐ π(a | i) · Σⱼ transition_prob[i, j, a] · rewards[i, j, a]
# where π(a | i) = 1 - ε + ε/|A| if a == policy[i], and ε/|A| otherwise.
# The reward must be averaged jointly with the transitions. Averaging P and R separately
# and multiplying them (the earlier approach) gives the wrong expected reward.
function _epsilon_greedy_model(policy, transition_prob, rewards, num_states, epsilon)
    num_actions = size(transition_prob, 3)
    explore = epsilon / num_actions      # probability of each non-greedy action
    greedy = 1 - epsilon + explore       # probability of the action `policy` picks
    mixed_trans = zeros(Float64, num_states, num_states)
    expected_reward = zeros(Float64, num_states)
    for a in 1:num_actions
        weight = [policy[i] == a ? greedy : explore for i in 1:num_states]  # π(a | i)
        trans_a = @view transition_prob[1:num_states, :, a]
        reward_a = @view rewards[1:num_states, :, a]
        mixed_trans .+= weight .* trans_a
        # Accumulate column by column: column-major friendly and allocation-free.
        for j in axes(trans_a, 2)
            @views expected_reward .+= weight .* trans_a[:, j] .* reward_a[:, j]
        end
    end
    return mixed_trans, expected_reward
end

# Helper: run V ← reward + γ · trans · V from V = 0. Stops when the largest per-state
# change is below `threshold`, or after `max_iterations` sweeps. Returns the most recent
# iterate.
function _iterate_linear_bellman(trans, reward, gamma, threshold, max_iterations, verbose)
    V = zeros(Float64, length(reward))
    for iter in 1:max_iterations
        V_next = reward .+ gamma .* (trans * V)
        err = maximum(abs, V_next .- V; init=0.0)
        verbose && println("Iteration $iter, error: $err")
        # Adopt the new iterate *before* testing convergence so the returned value
        # matches the loop-based implementations (which return the latest sweep).
        V = V_next
        err < threshold && break
    end
    return V
end

"""
    policy_evaluation_with_epsilon(policy, transition_prob, rewards, initial_state_p, gamma,
                                   num_states, epsilon; threshold=1e-10,
                                   max_iterations=100000, verbose=true)

Evaluate the ε-greedy policy derived from `policy` with a per-state loop. The ε-greedy
policy picks `policy[i]` with probability `1 - ε + ε/|A|` and every other action with
probability `ε/|A|`. The backup is

    V(i) ← Σₐ π(a | i) Σⱼ P[i, j, a] · (R[i, j, a] + γ · V(j))

starting from `V = 0`. Iteration stops once the largest per-state change is below
`threshold`, or after `max_iterations` sweeps. `verbose` controls the per-iteration
`println`.

Returns `(V, sum(V .* initial_state_p))`. `V` is the most recent iterate, or zeros if
`max_iterations < 1`.
"""
function policy_evaluation_with_epsilon(policy, transition_prob, rewards, initial_state_p,
                                        gamma, num_states, epsilon; threshold=1e-10,
                                        max_iterations=100000, verbose::Bool=true)
    mixed_trans, expected_reward =
        _epsilon_greedy_model(policy, transition_prob, rewards, num_states, epsilon)

    V_old = zeros(Float64, num_states)
    V_new = zeros(Float64, num_states)   # initialised: safe when max_iterations < 1
    for iter in 1:max_iterations
        err = 0.0
        for i in 1:num_states
            row = @view mixed_trans[i, :]
            V_new[i] = expected_reward[i] + gamma * dot(row, V_old)
            err = max(err, abs(V_new[i] - V_old[i]))
        end
        verbose && println("Iteration $iter, error: $err")
        err < threshold && break
        copyto!(V_old, V_new)
    end

    expected_return = sum(V_new .* initial_state_p)
    return V_new, expected_return
end

"""
    policy_evaluation_vectorized(policy, transition_prob, rewards, initial_state_p, gamma,
                                 num_states; threshold=1e-5, max_iterations=100000,
                                 verbose=true)

Vectorized counterpart of `policy_evaluation` for a deterministic `policy`. It freezes
the model into a transition matrix `P` and an expected one-step reward vector `r`, then
iterates `V ← r + γ·P·V`.

Returns `(V, sum(V .* initial_state_p))`, where `V` is the latest iterate.
"""
function policy_evaluation_vectorized(
    policy::AbstractVector{<:Integer},
    transition_prob::AbstractArray{<:Real,3},
    rewards::AbstractArray{<:Real,3},
    initial_state_p::AbstractVector{<:Real},
    gamma::Real,
    num_states::Integer;
    threshold::Real=1e-5,
    max_iterations::Integer=100000,
    verbose::Bool=true,
)
    # Row i: next-state distribution and expected reward under action policy[i].
    trans = zeros(Float64, num_states, num_states)
    reward = zeros(Float64, num_states)
    for i in 1:num_states
        a = policy[i]
        trans_row = @view transition_prob[i, :, a]
        reward_row = @view rewards[i, :, a]
        trans[i, :] .= trans_row
        reward[i] = dot(trans_row, reward_row)
    end

    V = _iterate_linear_bellman(trans, reward, gamma, threshold, max_iterations, verbose)
    expected_return = sum(V .* initial_state_p)
    return V, expected_return
end

"""
    policy_evaluation_with_epsilon_vectorized(policy, transition_prob, rewards,
                                              initial_state_p, gamma, num_states, epsilon;
                                              threshold=1e-10, max_iterations=1000,
                                              verbose=false)

Vectorized counterpart of `policy_evaluation_with_epsilon`. It evaluates the same
ε-greedy policy by iterating `V ← r̄ + γ·P̄·V`, where `P̄` and `r̄` are the
policy-averaged transition matrix and expected one-step reward.

Silent by default. Note that the default `max_iterations` (1000) is lower than the loop
version's 100000; it is kept unchanged for compatibility.

Returns `(V, sum(V .* initial_state_p))`, where `V` is the latest iterate.
"""
function policy_evaluation_with_epsilon_vectorized(
    policy::AbstractVector{<:Integer},
    transition_prob::AbstractArray{<:Real,3},
    rewards::AbstractArray{<:Real,3},
    initial_state_p::AbstractVector{<:Real},
    gamma::Real,
    num_states::Integer,
    epsilon::Real;
    threshold::Real=1e-10,
    max_iterations::Integer=1000,
    verbose::Bool=false,
)
    mixed_trans, expected_reward =
        _epsilon_greedy_model(policy, transition_prob, rewards, num_states, epsilon)
    V = _iterate_linear_bellman(mixed_trans, expected_reward, gamma, threshold,
                                max_iterations, verbose)
    expected_return = sum(V .* initial_state_p)
    return V, expected_return
end

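# ========== Vectorized versions ==========

# The vectorized implementations (`policy_evaluation_vectorized` and
# `policy_evaluation_with_epsilon_vectorized`) are already defined above in this file,
# so no separate `include("policy_eval_vectorized.jl")` is needed.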

# ========== 随机测试 ==========

"""
    random_test(num_states, num_actions, epsilon)

Benchmark and cross-check the four policy evaluators on a random MDP with `num_states`
states, `num_actions` actions and discount 0.95.

- Prints the elapsed time of each call.
- Asserts, inside a `@testset`, that each vectorized variant agrees with its loop-based
  counterpart.

Data come from a local `MersenneTwister(1234)`. Runs are reproducible and the caller's
global RNG is not reseeded. In a fresh session the timings also include JIT compilation
of each method's first call.
"""
function random_test(num_states::Int, num_actions::Int, epsilon::Float64)
    rng = MersenneTwister(1234)
    policy = rand(rng, 1:num_actions, num_states)
    transition_prob = rand(rng, num_states, num_states, num_actions)
    # Normalise every (state, :, action) row into a probability distribution.
    transition_prob ./= sum(transition_prob; dims=2)
    rewards = randn(rng, num_states, num_states, num_actions)
    initial_state_p = rand(rng, num_states)
    initial_state_p ./= sum(initial_state_p)
    gamma = 0.95

    t1 = @elapsed V1, R1 = policy_evaluation(policy, transition_prob, rewards,
                                             initial_state_p, gamma, num_states)
    t2 = @elapsed V2, R2 = policy_evaluation_vectorized(policy, transition_prob, rewards,
                                                        initial_state_p, gamma, num_states)

    t3 = @elapsed V3, R3 = policy_evaluation_with_epsilon(policy, transition_prob, rewards,
                                                          initial_state_p, gamma,
                                                          num_states, epsilon)
    t4 = @elapsed V4, R4 = policy_evaluation_with_epsilon_vectorized(
        policy, transition_prob, rewards, initial_state_p, gamma, num_states, epsilon)

    println("Standard policy evaluation time: $t1 seconds")
    println("Vectorized policy evaluation time: $t2 seconds")
    println("Standard epsilon policy evaluation time: $t3 seconds") 
    println("Vectorized epsilon policy evaluation time: $t4 seconds")

    # A @testset reports every comparison before throwing on failure, instead of
    # aborting at the first failing @test.
    @testset "policy evaluation ($num_states states, $num_actions actions)" begin
        @test isapprox(V1, V2; atol=1e-8)
        @test isapprox(R1, R2; atol=1e-8)

        @test isapprox(V3, V4; atol=1e-8)
        @test isapprox(R3, R4; atol=1e-8)
    end

    println("✅ All tests passed!")
    return nothing
end

# Run the benchmark and cross-check on a 4000-state, 4-action MDP with ε = 0.1.
let num_states = 4000, num_actions = 4, epsilon = 0.1
    random_test(num_states, num_actions, epsilon)
end
