using Random
using Printf
using LinearAlgebra
using DifferentialEquations
using Statistics

# A run of 200 blanks that is printed around carriage returns to wipe the previous contents
#   of the progress line.
const SPACE = repeat(" ", 200)

"""
    ODEParameters(diagonal, risk)

An instance of the ODE parameters passed to the ODE solver.

The struct is parametrized by the types of its two fields. The type parameters are inferred
from the arguments of the constructor, so `ODEParameters(diagonal, risk)` is called exactly
as before, while every instance stores concretely-typed fields that the functions receiving
it can be specialized on.

# Fields:
- `diagonal`: The diagonal of the game's payoff matrix.
- `risk`: A constant that is added to entries of the game's payoff matrix to introduce risk
to the optimal outcomes.
"""
struct ODEParameters{D,R}
    # The diagonal of the game's payoff matrix.
    diagonal::D
    # A constant that is added to entries of the game's payoff matrix to introduce risk to
    #   the optimal outcomes.
    risk::R
end

"""
    ODEParameters(m; risk=0.0, ϵ=1.0e-12)

Sample a random instance of the ODE parameters.

The sampled `diagonal` is built as the running sum of `m` uniform random increments, the
first of which is shifted by `ϵ`, and is then rescaled so that its last entry equals `m`.
Consequently:
- The entries of the `diagonal` are in non-decreasing order.
- The last entry of the `diagonal` is equal to `m`.

# Arguments:
- `m`: The size of the diagonal of the game's payoff matrix.
- `risk`: A constant that is added to entries of the game's payoff matrix to
introduce risk to the optimal outcomes. It is stored unchanged.
- `ϵ`: The offset added to the first random increment, which keeps the first entry of the
diagonal away from zero.
"""
function ODEParameters(m; risk=0., ϵ=1e-12)

    # Draw the `m` random gaps between consecutive entries of the diagonal and shift the
    #   first one by `ϵ`.
    steps = rand(m)
    steps[1] += ϵ

    # Turn the gaps into the diagonal itself by accumulating them from left to right.
    running = steps[1]
    for i ∈ 2:m
        running += steps[i]
        steps[i] = running
    end

    # Rescale the diagonal so that its last entry becomes `m`.
    peak = steps[m]
    normalized = map(s -> s / peak * m, steps)

    # Return the ODE parameters parametrized by the normalized diagonal and the given risk.
    return ODEParameters(normalized, risk)

end

"""
    payoffs(x, y, p)

Compute the expected payoff of each of the game's outcomes in the given state.

The vectors `x` and `y` hold the strategies of the two players. Each of them is expected to
have either as many entries as `p.diagonal`, or one entry less (compact form); in the
compact form the probability of the last outcome is taken to be one minus the sum of the
given entries.

When `p.risk` is non-zero, the payoff of every outcome except the first one is increased by
`p.risk` times the probability that the opponent does not play that outcome.

The function returns the tuple `(x_payoffs, y_payoffs)`, where `x_payoffs` is computed from
`y` and `y_payoffs` is computed from `x`. Both are `Vector{Float64}` with as many entries as
`p.diagonal`.
"""
function payoffs(x, y, p::ODEParameters)

    # Fetch the diagonal of the payoff matrix and the size of the game.
    diagonal = p.diagonal
    m = length(diagonal)

    # Expected payoffs of a player whose opponent plays the (possibly compact) strategy
    #   `opponent`.
    function against(opponent)

        k = length(opponent)
        compact = k < m
        out = Vector{Float64}(undef, m)

        # Payoffs when no risk is applied. In the compact form the last entry is derived
        #   from the implied probability of the last outcome.
        @views out[1:k] .= diagonal[1:k] .* opponent
        if compact
            out[m] = diagonal[m] * (1. - sum(opponent))
        end

        # Contribution of the risk to every outcome but the first one.
        if p.risk ≠ 0.
            @views out[2:k] .+= (1. .- opponent[2:k]) .* p.risk
            if compact
                out[m] += sum(opponent) * p.risk
            end
        end

        return out
    end

    # Return the expected payoffs of the two players.
    return against(y), against(x)

end

"""
    project!(x)

Project the vector `x` to the simplex in place and return the size of the support of the
projected vector, i.e., the number of coordinates that the threshold search keeps.

The simplex's dimensionality is inferred by the dimensionality of `x`. The function is an
implementation of the well-known procedure described in https://arxiv.org/abs/1309.1541.
"""
@inline function project!(x)

    # The dimensionality of the vector and a sorted (ascending) copy of its coordinates.
    n = length(x)
    ascending = sort(x)

    # `deficit` is the total mass that must be distributed over the coordinates that are
    #   still considered part of the support; initially all of them are.
    deficit = 1. - sum(x)
    support = n
    shift = 0.

    # Discard the smallest coordinates one by one, until the uniform shift keeps the
    #   smallest remaining coordinate strictly positive.
    k = 1
    while k ≤ n
        shift = deficit / support
        ascending[k] + shift > 0 && break

        deficit += ascending[k]
        support -= 1
        k += 1
    end

    # Shift and clip every coordinate but the last one, keeping track of their total.
    total = 0.
    for i ∈ 1:(n - 1)
        x[i] += shift
        x[i] < 0 && (x[i] = 0.)
        total += x[i]
    end

    # To improve accuracy, the last coordinate is whatever is left for the ℓ₁-norm of the
    #   vector to be 1.0.
    x[n] = 1. - total

    # Return the size of the projected vector's support.
    return support

end

"""
    QRD0(u, p, t)

Compute the equations of motion of the 0-Replicator dynamics, i.e., Projection dynamics.

Following the FTRL representation, the equations of motion describe the evolution of the
accumulated payoff of the two players for each possible outcome. The state `u` stores the
accumulated payoffs of the first player in its first `length(p.diagonal)` entries, followed
by those of the second player. The time `t` is not used.
"""
@inline function QRD0(u, p::ODEParameters, t)

    # The size of the game.
    n = length(p.diagonal)

    # Copy each player's accumulated payoffs out of the state; working on copies keeps the
    #   state `u` untouched by the in-place projection below.
    first_player = u[1:n]
    second_player = u[(n + 1):(2n)]

    # Each player's best response is the projection of the respective accumulated payoffs
    #   to the simplex.
    foreach(project!, (first_player, second_player))

    # The time-derivatives are the expected payoffs of the two players, stacked in the same
    #   order as the state.
    return vcat(payoffs(first_player, second_player, p)...)

end

"""
    QRD1(u, p, t)

Compute the equations of motion of the Replicator dynamics.

The state `u` is compact: it stores the first `length(p.diagonal) - 1` probabilities of the
first player, followed by the first `length(p.diagonal) - 1` probabilities of the second
player. The probability of the last outcome of each player is implied by the rest. The time
`t` is not used.
"""
@inline function QRD1(u, p::ODEParameters, t)

    # The number of entries in each player's compact state.
    k = length(p.diagonal) - 1

    # Unpack each player's compact state.
    first_player = view(u, 1:k)
    second_player = view(u, (k + 1):(2k))

    # Compute the expected payoffs for each possible outcome.
    first_payoffs, second_payoffs = payoffs(first_player, second_player, p)

    # Time-derivative of a compact state: every probability grows in proportion to how much
    #   the payoff of its outcome exceeds the value of the game for that player.
    function growth(state, outcome_payoffs)
        head = view(outcome_payoffs, 1:k)
        value = state ⋅ head + (1. - sum(state)) * outcome_payoffs[k + 1]
        return state .* (head .- value)
    end

    # Return the time-derivatives.
    return vcat(growth(first_player, first_payoffs), growth(second_player, second_payoffs))

end

"""
    distribution(n)

Sample a probability distribution over `n` random variables.

The distribution is given by the lengths of the `n` segments into which `n - 1` uniformly
sampled points split the unit interval.
"""
function distribution(n::Int)

    # Sample `n - 1` cut points in the unit interval, add the two endpoints and arrange all
    #   of them in ascending order: [0, X₁, …, Xₙ₋₁, 1].
    cuts = sort!([0.0; rand(n - 1); 1.0])

    # The gaps between consecutive cut points are the probabilities of the sampled
    #   distribution.
    return diff(cuts)

end

"""
    estimate_APoA(p; sample_size=100, Δt=10., ϵ=1.0e-2, max_deviation=1.0e-6, silent=true)

Estimate the APoA of the Replicator and Projection dynamics in the given game.

The APoA is estimated by bootstrap estimates of the expected social welfare of the game at
the points of convergence of uniformly-distributed initial conditions. Both dynamics are
run from the same sample of initial conditions. For each dynamics, the estimate is the ratio
of the optimal social welfare to the average social welfare over the initial conditions that
pass the convergence check.

# Arguments:
- `p::ODEParameters`: The ODE parameters that define a two-player game.
- `sample_size`: The size of the sample of initial conditions.
- `Δt`: The time horizon over which the ODE solver is called.
- `ϵ`: The distance between the point-of-convergence and a reference point used to determine
the convergence rate. The state is saved at times `Δt - ϵ` and `Δt`.
- `max_deviation`: The maximum allowed deviation between the point-of-convergence and the
reference point. Initial conditions that exceed it are skipped.
- `silent`: Indicates whether the verbose output should be suppressed. The warning about
skipped initial conditions is printed regardless of this flag.

# Returns:
The tuple `(QRD1_APoA, QRD0_APoA)`, i.e., the estimate for the Replicator dynamics followed
by the estimate for the Projection dynamics. If no initial condition passes the convergence
check for one of the dynamics, the respective average is `0.0 / 0`, and so the respective
estimate is `NaN`.
"""
function estimate_APoA(p::ODEParameters; sample_size=100, Δt=10., ϵ=1e-2,
    max_deviation=1e-6, silent=true)

    # Compute the game's size.
    m = length(p.diagonal)

    # Initialize the estimate of the optimal social welfare of the game to twice the last
    #   entry of the diagonal. The estimate is shared by both dynamics and is raised below
    #   whenever a larger social welfare is observed.
    optimal_value = 2 * p.diagonal[m]

    # Sample `sample_size` i.i.d. initial conditions. Each one stacks a distribution of size
    #   `m` for the first player on top of a distribution of size `m` for the second player.
    sample = [vcat(distribution(m), distribution(m)) for _ ∈ 1:sample_size]

    # Signal the beginning of the computations.
    if (!silent) println("Computing APoA...") end

    # Estimate the APoA of the Replicator dynamics in the given game.
    QRD1_valid_points = 0
    QRD1_accumulated_value = 0.
    QRD1_APoA = 1.

    for i ∈ eachindex(sample)

        # Set the initial condition of the current iteration. The Replicator dynamics use
        #   the compact state, so the last probability of each player is dropped.
        u₀ = vcat(sample[i][1:(m - 1)], sample[i][(m + 1):(2m - 1)])

        # Estimate the point-of-convergence for the current initial condition.
        ode_problem  = ODEProblem(QRD1, u₀, (0.0, Δt), p)
        ode_solution = solve(ode_problem, saveat=[Δt - ϵ, Δt])

        # Compute the actual state of the game at times `Δt` and `Δt` - `ϵ`. The first saved
        #   point corresponds to time `Δt` - `ϵ` and the second one to time `Δt`. The last
        #   probability of each player is recovered as one minus the sum of the rest.
        x_ϵ = @view ode_solution[1:(m - 1), 1]
        x_ϵ = vcat(x_ϵ, 1. - sum(x_ϵ))
        x   = @view ode_solution[1:(m - 1), 2]
        x   = vcat(x, 1. - sum(x))

        y_ϵ = @view ode_solution[m:2(m - 1), 1]
        y_ϵ = vcat(y_ϵ, 1. - sum(y_ϵ))
        y = @view ode_solution[m:2(m - 1), 2]
        y = vcat(y, 1. - sum(y))

        # Compute the maximum absolute deviation between the two states and skip iteration
        #   if the threshold is exceeded. The warning is printed even when `silent` is set.
        deviation = norm(vcat(x - x_ϵ, y - y_ϵ), Inf)
        if deviation > max_deviation
            println("\r$(SPACE)\r Warning! No point-wise convergence: deviation=$(deviation).")
            continue
        end

        # Compute the social welfare at the estimated point-of-convergence.
        x_payoffs, y_payoffs = payoffs(x, y, p)
        value = x ⋅ x_payoffs + y ⋅ y_payoffs

        # Update the estimated optimal social welfare of the game if needed.
        if (value > optimal_value) optimal_value = value; end

        # Compute information about the APoA.
        QRD1_valid_points += 1
        QRD1_accumulated_value += value

        # Display the current progress to the user. The APoA shown here is provisional; the
        #   final estimate is recomputed after both loops have finished.
        if (!silent)
            QRD1_APoA =  optimal_value / (QRD1_accumulated_value / QRD1_valid_points)
            stats = @sprintf(
                "RD: APoA=%.5f | Progress=%.0f%%",
                QRD1_APoA, i * 100. / sample_size
            )
            print("\r $(stats)" * SPACE)
        end
    end

    if (!silent) println(); end

    # Estimate the APoA of the Projection dynamics in the given game.
    QRD0_valid_points = 0
    QRD0_accumulated_value = 0.
    QRD0_APoA = 1.

    for i ∈ eachindex(sample)

        # Set the initial condition of the current iteration. The Projection dynamics use
        #   the full state of size `2m`.
        u₀ = sample[i]

        # Estimate the point-of-convergence for the current initial condition.
        ode_problem  = ODEProblem(QRD0, u₀, (0.0, Δt), p)
        ode_solution = solve(ode_problem, saveat=[Δt - ϵ, Δt])

        # Compute the actual state of the game at times `Δt` and `Δt` - `ϵ`. The solution
        #   holds accumulated payoffs, which are projected to the simplex in place to obtain
        #   each player's strategy.
        x_ϵ = ode_solution[1:m, 1]
        project!(x_ϵ)
        x = ode_solution[1:m, 2]
        project!(x)

        y_ϵ = ode_solution[(m + 1):(2m), 1]
        project!(y_ϵ)
        y = ode_solution[(m + 1):(2m), 2]
        project!(y)

        # Compute the maximum absolute deviation between the two states and skip iteration
        #   if the threshold is exceeded. The warning is printed even when `silent` is set.
        deviation = norm(vcat(x - x_ϵ, y - y_ϵ), Inf)
        if deviation > max_deviation
            println("\r$(SPACE)\r Warning! No point-wise convergence: deviation=$(deviation).")
            continue
        end

        # Compute the social welfare at the estimated point-of-convergence.
        x_payoffs, y_payoffs = payoffs(x, y, p)
        value = x ⋅ x_payoffs + y ⋅ y_payoffs

        # Update the estimated optimal social welfare of the game if needed.
        if (value > optimal_value) optimal_value = value end

        # Compute information about the APoA.
        QRD0_valid_points += 1
        QRD0_accumulated_value += value

        # Display the current progress to the user. The APoA shown here is provisional; the
        #   final estimate is recomputed after both loops have finished.
        if (!silent)
            QRD0_APoA =  optimal_value / (QRD0_accumulated_value / QRD0_valid_points)

            stats = @sprintf(
                "PD: APoA=%.5f | Progress=%.0f%%",
                QRD0_APoA, i * 100. / sample_size
            )
            print("\r$(SPACE)\r $(stats)")
        end
    end

    if (!silent) println() end

    # Estimate the APoA. Both estimates use the final value of `optimal_value`, i.e., the
    #   largest one observed across both dynamics.
    QRD0_APoA =  optimal_value / (QRD0_accumulated_value / QRD0_valid_points)
    QRD1_APoA =  optimal_value / (QRD1_accumulated_value / QRD1_valid_points)

    if (!silent)
        stats = @sprintf("APoA: RD=%.5f | PD=%.5f", QRD1_APoA, QRD0_APoA)
        println(" $(stats)")
    end

    # Return the estimated APoA: the Replicator dynamics first, the Projection dynamics
    #   second.
    return QRD1_APoA, QRD0_APoA
end


## Self-Contained Tests ####################################################################

"""
    test₁(m=2; p=nothing, risk=0.0, sample_size=1000, Δt=100.0, ϵ=1.0e-2, max_deviation=1.0e-4, silent=false)

Estimate the APoA of the Replicator and Projection dynamics in a single game.

If the argument `p` is an instance of the [`ODEParameters`](@ref) struct, then it provides
the game's description and the arguments `m` and `risk` are ignored; otherwise, a game of
size `m` and risk `risk` is sampled at random. The rest of the parameters, `silent`
included, are passed unchanged to [`estimate_APoA`](@ref), whose result is returned.
"""
function test₁(m=2; p=nothing, risk=0., sample_size=1000, Δt=100., ϵ=1e-2,
    max_deviation=1e-4, silent=false)

    # Use the given game if there is one; otherwise, sample a game of the given size.
    game = p isa ODEParameters ? p : ODEParameters(m, risk=risk)

    # Collect the settings that are forwarded to the estimator as they are.
    settings = (
        sample_size=sample_size,
        Δt=Δt,
        ϵ=ϵ,
        max_deviation=max_deviation,
        silent=silent
    )

    # Estimate the APoA of the game.
    return estimate_APoA(game; settings...)
end

"""
    test₂(m=2; game_sample_size=100, risk=0.0, initial_points_sample_size=250, Δt=30.0,
          ϵ=1.0e-2, max_deviation=1.0e-4, silent=true)

Compute statistics about the APoA of the Replicator and Projection dynamics in a random
sample of games of size `m` and risk `risk`.

The statistics are computed from a sample of `game_sample_size` random games of the given
size. The rest of the parameters are passed unchanged to [`estimate_APoA`](@ref).
To avoid confusion the `sample_size` argument of [`estimate_APoA`](@ref) is given by the
`initial_points_sample_size` argument.

The statistics are recomputed and printed after every game. Games in which the APoA of the
Replicator dynamics is smaller than the APoA of the Projection dynamics are reported as
irregular, together with their diagonal.

# Returns:
A tuple with the following entries, in order, where RD stands for the Replicator dynamics
and PD for the Projection dynamics:
1. the mean of the APoA of RD,
2. the standard deviation of the APoA of RD,
3. the mean of the APoA of PD,
4. the standard deviation of the APoA of PD,
5. the maximum of the APoA of RD,
6. the maximum of the APoA of PD,
7. the mean of the difference PD - RD,
8. the standard deviation of the difference PD - RD,
9. the minimum of the difference PD - RD,
10. the maximum of the difference PD - RD,
11. the number of games in which the APoA of PD does not exceed the APoA of RD.

If `game_sample_size` is zero, no game is evaluated and the initial values of the
statistics are returned. After a single game, the standard deviations are `NaN`.
"""
function test₂(m=2; game_sample_size=100, risk=0., initial_points_sample_size=250, Δt=30.,
    ϵ=1e-2, max_deviation=1e-4, silent=true)

    # Construct a sample of two-player games of size `m`.
    sample = [ODEParameters(m, risk=risk) for _ ∈ 1:game_sample_size]

    # Signal the beginning of the computations.
    println("Test Paramaters: m=$(m) | risk=$(risk). Commencing computations...")

    # The APoA estimates of the games evaluated so far; entry `i` belongs to game `i`.
    sample_QRD1_APoA = zeros(game_sample_size)
    sample_QRD0_APoA = zeros(game_sample_size)

    # The running statistics. They are initialized here, so that they are defined after the
    #   loop even if the sample of games is empty.
    mean_QRD1_APoA = 1.
    std_QRD1_APoA = 0.
    mean_QRD0_APoA = 1.
    std_QRD0_APoA = 0.
    mean_QRD0_QRD1_APoA_Difference = 0.
    std_QRD0_QRD1_APoA_Difference = 0.
    min_QRD0_QRD1_APoA_Difference = 0.
    max_QRD0_QRD1_APoA_Difference = 0.
    count_QRD0_QRD1_APoA_Difference_NonPositive = 0
    max_QRD1_APoA = 1.
    max_QRD0_APoA = 1.

    for i ∈ eachindex(sample)

        # Set the game parameters of the current iteration.
        p = sample[i]

        # Estimate the APoA of the dynamics—Replicator dynamics and Projection dynamics.
        QRD1_APoA, QRD0_APoA = estimate_APoA(
            p,
            sample_size=initial_points_sample_size,
            Δt=Δt,
            ϵ=ϵ,
            max_deviation=max_deviation,
            silent=silent
        )

        sample_QRD1_APoA[i] = QRD1_APoA
        sample_QRD0_APoA[i] = QRD0_APoA

        # Look at the estimates collected so far through views and compute their difference
        #   once, instead of slicing the arrays anew for every statistic.
        seen_QRD1_APoA = @view sample_QRD1_APoA[1:i]
        seen_QRD0_APoA = @view sample_QRD0_APoA[1:i]
        difference = seen_QRD0_APoA - seen_QRD1_APoA

        # Update the running statistics. The extrema are computed by reductions over the
        #   collections rather than by splatting them into `min` and `max`.
        mean_QRD1_APoA = mean(seen_QRD1_APoA)
        std_QRD1_APoA = std(seen_QRD1_APoA)
        mean_QRD0_APoA = mean(seen_QRD0_APoA)
        std_QRD0_APoA = std(seen_QRD0_APoA)
        mean_QRD0_QRD1_APoA_Difference = mean(difference)
        std_QRD0_QRD1_APoA_Difference = std(difference)
        min_QRD0_QRD1_APoA_Difference = minimum(difference)
        max_QRD0_QRD1_APoA_Difference = maximum(difference)
        count_QRD0_QRD1_APoA_Difference_NonPositive = count(seen_QRD0_APoA .≤ seen_QRD1_APoA)
        max_QRD1_APoA = maximum(seen_QRD1_APoA)
        max_QRD0_APoA = maximum(seen_QRD0_APoA)

        # Report the games in which the Replicator dynamics have the smaller APoA.
        if (QRD1_APoA < QRD0_APoA)
            println("\r$(SPACE)\r Irregular Game: Diagonal=$(p.diagonal)")
        end

        # Display statistics progressively.
        stats = @sprintf(
            "APoA: RD=%.5f±%.5f  PD=%.5f±%.5f | (PD-RD)=%.5f±%.5f Min(PD-RD)=%.5f, Max(PD-RD)=%.5f, (PD≤RD)=%.0f%% Max(RD)=%.5f Max(PD)=%.5f | Progress=%.0f%%",
            mean_QRD1_APoA,
            std_QRD1_APoA,
            mean_QRD0_APoA,
            std_QRD0_APoA,
            mean_QRD0_QRD1_APoA_Difference,
            std_QRD0_QRD1_APoA_Difference,
            min_QRD0_QRD1_APoA_Difference,
            max_QRD0_QRD1_APoA_Difference,
            count_QRD0_QRD1_APoA_Difference_NonPositive * 100. / i,
            max_QRD1_APoA,
            max_QRD0_APoA,
            i * 100. / game_sample_size
        )
        print("\r$(SPACE)\r $(stats)")

    end

    println()

    return (
        mean_QRD1_APoA,
        std_QRD1_APoA,
        mean_QRD0_APoA,
        std_QRD0_APoA,
        max_QRD1_APoA,
        max_QRD0_APoA,
        mean_QRD0_QRD1_APoA_Difference,
        std_QRD0_QRD1_APoA_Difference,
        min_QRD0_QRD1_APoA_Difference,
        max_QRD0_QRD1_APoA_Difference,
        count_QRD0_QRD1_APoA_Difference_NonPositive
    )

end

## Examples ################################################################################

# Seed the global random number generator, so that runs of the examples below are
#   reproducible.
Random.seed!(1)

# Estimate the APoA of the Replicator and Projection dynamics in a Stag Hunt game.
# @time test₁(p=ODEParameters([1.0, 2.0], 0.0))
# Estimate the APoA of the Replicator and Projection dynamics in a variation of a Stag Hunt
#   game, where risk has been introduced.
# @time test₁(p=ODEParameters([1.0, 2.0], -10.0))

# Collect statistics about the APoA in "safe" games of size 5.
# @time test₂(5)
# Collect statistics about the APoA in "risky" games of size 10.
# @time test₂(10, risk=-5.)

# for m ∈ 2:20
#     Random.seed!(1)
#     @time println(test₂(m, Δt=100., initial_points_sample_size=1000))
# end

# for m ∈ [2, 3, 5, 10, 20]
#     for seed ∈ 1:10
#         Random.seed!(seed)
#         @time println(test₂(m, Δt=100., initial_points_sample_size=1000))
#     end
# end
