include("../src/algos.jl")

using ProjectRoot
using Printf
using Dates

"""
    eval_iters = get_eval_iters(eval_gap::Integer, K::Integer)

Return the list of iterations at which the performance is evaluated.

The list is `1, 1 + eval_gap, 1 + 2 * eval_gap, …` up to `K`. The last iteration `K`
is appended when the stride does not land on it, so the final iterate is always
evaluated.

# Arguments
- `eval_gap::Integer`: the stride between two evaluated iterations; must be positive.
- `K::Integer`: the total number of iterations.

# Returns
- `eval_iters::Vector{<:Integer}`: the increasing list of iteration indices. It is
    empty when `K < 1`.

# Throws
- `ArgumentError`: if `eval_gap` is not positive.
"""
function get_eval_iters(eval_gap::Integer, K::Integer)
    eval_gap > 0 ||
        throw(ArgumentError("eval_gap must be positive, got $eval_gap"))

    eval_iters = collect(1:eval_gap:K)

    # The range is empty when K < 1; indexing its last element would throw, so only
    # append K when there is something to compare against.
    if !isempty(eval_iters) && last(eval_iters) < K
        push!(eval_iters, K)
    end
    return eval_iters
end

"""
    res = perf_evaluation(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, β₀,
        x_path; all_ξ=nothing, eval_gap=100, inner_iter=100, num_sample=1_000,
        dy=0, num_η_up=10, num_η_low=10, callback=nothing)

Given the x_path and functions, evaluate the performance of the algorithm.

After obtaining the path of x by calling the algorithm, the performance is
evaluated separately.
The reason to separate the algorithm execution and the performance evaluation is
that y*(x; ξ) is usually not estimated accurately during the algorithm, so we
cannot estimate the objective function value and the stationarity during the
algorithm execution.
Moreover, it also allows a more precise estimation of the time complexity of the
algorithm.

To evaluate the performance of the algorithm, we receive the path of xₖ obtained
by algorithm 1 or 2. For every evaluated iteration `k` (see `get_eval_iters`), the
evaluation point is the average of the first `k` columns of `x_path`. We then call
the inner loop of algorithm 1 (`estimate_yz`) with `inner_iter` iterations to
estimate y*(xₖ, ξ) for each ξ (either every element of `all_ξ`, or `num_sample`
draws from `sample_ξ`), and compute the corresponding objective function value and
its average over ξ, i.e., the estimated true objective function value.
If ξ is infinite or we don't know all possible values of ξ, we estimate the
expectation by computing the average over a large number of samples, such as
10_000 ξ's.

We also need to compute the stationarity path.
Here, the stationarity is defined as 1/K ∑ₖ₌₁ᴷ E[‖ ∇ F(xₖ) ‖²] ≤ ϵ, where
F(x) = E_{ξ, η}[f(x, y*(x, ξ); η, ξ)].
Similar to the evaluation of the objective function values, we estimate y* and z,
and then aggregate the per-ξ gradient estimates returned by `estimate_∇F`.
Note that the two branches aggregate differently: with sampled ξ the squared norm
of the averaged gradient is reported, while with `all_ξ` the average of the squared
norms of the per-ξ gradients is reported.

# Arguments
- Input:
    - `f::Function`: the function inside the expectation of the objective function.
    - `g::Function`: the function inside the expectation of the objective function
        of the lower level optimization problem;
    - `sample_ξ::Function`: the function that samples ξ.
    - `sample_η_up::Function`: the function that samples η given ξ, it is used in
        the upper level optimization. It is called as `sample_η_up(ξ; num=num_η_up)`.
    - `sample_η_low::Function`: the function that samples η given ξ, it is used in
        the lower level optimization.
    - `lf₁::Real`: the Lipschitz constant of `f`.
    - `μg::Real`: the modulus of strong convexity of `g`.
    - `β₀::Real`: the constant for O(1) in the upper stepsize.
    - `x_path::Matrix{<:Real}`: the path of x, stored as a matrix, where k-th
        column is a point at the k-th iteration.
    - `all_ξ::Union{Vector, Nothing}=nothing`: all possible ξ's values, by
        default it is nothing, in this case we don't know all possible ξ's values.
    - `eval_gap::Integer=100`: the gap to evaluate the objective function value
        and stationarity.
    - `inner_iter::Integer=100`: the number of inner iteration to estimate y and z.
    - `num_sample::Integer=1_000`: the number to sample ξ if all_ξ is unknown.
    - `dy::Integer=0`: the dimension of `y`. By default it is set to 0, in this
        case `y` has the same dimension as `x`.
    - `num_η_low::Integer=10`: the number of sampled η in each iteration in the
        lower loop to estimate y and z.
    - `num_η_up::Integer=10`: the number of sampled η in each iteration in the
        upper loop to estimate ∇F.
    - `callback::Union{Function, Nothing}=nothing`: a callback function for
        computing extra information, such as prediction error, prediction accuracy,
        function value of g, and so on. The default value is `nothing`, meaning
        doing nothing.
        `res = callback(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, xₖ,
            y_ests, z_ests, ξs_, ηs_, λₖ)`

- Output:
    - `res::Matrix`: a matrix with one row per evaluated iteration and 2 or 3
        columns (depending on `callback`):
        - `f_path::Vector{<:Real}`: the path of the objective function value.
        - `stationarity_path::Vector{<:Real}`: the path of the stationarity.
        - `res_callback::Vector`: the path of results of callback function.
"""
function perf_evaluation(
    f::Function,
    g::Function,
    sample_ξ::Function,
    sample_η_up::Function,
    sample_η_low::Function,
    lf₁::Real,
    μg::Real,
    β₀::Real,
    x_path::Matrix{<:Real};
    all_ξ::Union{Vector, Nothing}=nothing,
    eval_gap::Integer=100,
    inner_iter::Integer=100,
    num_sample::Integer=1_000,
    dy::Integer=0,
    num_η_up::Integer=10,
    num_η_low::Integer=10,
    callback::Union{Function, Nothing}=nothing,
)
    # The dimension of x and the number of outer iterations.
    dx, K = size(x_path)

    # The dimension of y; `dy == 0` means "same dimension as x". A fresh local is
    # used instead of reassigning `dy` because the value is captured by the threaded
    # loop bodies below.
    ny = dy == 0 ? dx : dy

    # The gradient of f, g and L. `gradient` returns `nothing` for an argument the
    # function does not depend on; that case is mapped to a zero vector.
    ∇₁f(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat, ξ::Real) =
        (∇ = gradient(f, x, y, η, ξ)[1]) === nothing ? zero(x) : ∇

    ∇₂f(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat, ξ::Real) =
        (∇ = gradient(f, x, y, η, ξ)[2]) === nothing ? zero(y) : ∇

    ∇₁g(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat, ξ::Real) =
        (∇ = gradient(g, x, y, η, ξ)[1]) === nothing ? zero(x) : ∇

    ∇₂g(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat, ξ::Real) =
        (∇ = gradient(g, x, y, η, ξ)[2]) === nothing ? zero(y) : ∇

    ∇₁L(x::Vector{<:Real}, z::Vector{<:Real}, y::Vector{<:Real}, λ::Real,
        η::AbstractVecOrMat, ξ::Real) = ∇₁f(x, z, η, ξ) +
            λ * (∇₁g(x, z, η, ξ) - ∇₁g(x, y, η, ξ))

    ∇₂L(x::Vector{<:Real}, z::Vector{<:Real}, y::Vector{<:Real}, λ::Real,
        η::AbstractVecOrMat, ξ::Real) = ∇₂f(x, z, η, ξ) + λ * ∇₂g(x, z, η, ξ)

    # The list of iterations that I need to use to evaluate, as I will not
    # evaluate the performance at every iteration.
    eval_iters = get_eval_iters(eval_gap, K)
    # The number of evaluation.
    num_eval = length(eval_iters)

    # For every iteration in eval_iters, we compute the estimated paths of
    # objective values and stationarities.
    f_path = zeros(num_eval)
    stationarity_path = zeros(num_eval)

    # Compute λₖ.
    λₖ = 2 * lf₁ / μg * (inner_iter + 1)^(1 / 4)

    # Storage for the callback results; `nothing` when no callback is given.
    callback_res = isnothing(callback) ? nothing : Vector{Any}(undef, num_eval)

    # If `all_ξ` is nothing, that means we don't know all possible ξ, then we need
    # to estimate the expectation by sampling `num_sample` samples.
    if isnothing(all_ξ)
        # For each evaluation iteration.
        @showprogress desc="[perf]" Threads.@threads for i in eachindex(eval_iters)
            k = eval_iters[i]

            # The evaluation point: the average of the first k iterates.
            xₖ = vec(mean(x_path[:, 1:k], dims=2))

            # Estimate y and z for `num_sample` sampled ξ's.
            y_ests = zeros(ny, num_sample)
            z_ests = zeros(ny, num_sample)

            # Estimate ∇F for `num_sample` sampled ξ's.
            ∇F_ests = zeros(dx, num_sample)

            # The sampled ξ.
            ξs_ = Vector{Union{Real, AbstractArray}}(undef, num_sample)

            # The sampled η.
            ηs_ = Vector{AbstractVecOrMat}(undef, num_sample)

            for j in 1:num_sample
                # Sample ξ.
                ξs_[j] = sample_ξ()

                # Estimate y and z with `inner_iter` inner iterations, starting
                # from the zero vector stored in `y_ests[:, j]`.
                y_ests[:, j], z_ests[:, j] = estimate_yz(∇₂g, ∇₂L, sample_η_low,
                    xₖ, y_ests[:, j], ξs_[j], λₖ, β₀, μg, inner_iter; num_η=num_η_low)

                # Sample η given the ξ drawn for this sample.
                ηs_[j] = sample_η_up(ξs_[j]; num=num_η_up)

                # Estimate ∇F.
                ∇F_ests[:, j] = estimate_∇F(∇₁L, xₖ, y_ests[:, j], z_ests[:, j],
                                            ηs_[j], ξs_[j], λₖ)
            end

            # If callback function is given.
            if !isnothing(callback)
                callback_res[i] = callback(f, g, sample_ξ, sample_η_up, sample_η_low,
                    lf₁, μg, xₖ, y_ests, z_ests, ξs_, ηs_, λₖ)
            end

            # Compute the objective function value.
            f_path[i] = mean([f(xₖ, y_ests[:, j], ηs_[j], ξs_[j])
                                            for j in 1:num_sample])

            # Compute the stationarity: squared norm of the averaged gradient.
            # NOTE(review): the `all_ξ` branch below reports the mean of squared
            # norms instead; confirm which aggregation is intended.
            stationarity_path[i] = norm(mean(∇F_ests, dims=2)) ^ 2
        end
    else
        # If `all_ξ` is given.
        # For each evaluation iteration.
        @showprogress desc="[perf]" Threads.@threads for i in eachindex(eval_iters)
            k = eval_iters[i]

            # The evaluation point: the average of the first k iterates.
            xₖ = vec(mean(x_path[:, 1:k], dims=2))

            # Estimate y and z for all ξ.
            y_ests = zeros(ny, length(all_ξ))
            z_ests = zeros(ny, length(all_ξ))

            # Estimate ∇F for all ξ.
            ∇F_ests = zeros(dx, length(all_ξ))

            # The sampled η.
            ηs_ = Vector{AbstractVecOrMat}(undef, length(all_ξ))

            for (j, ξ) in enumerate(all_ξ)
                # Estimate y and z with `inner_iter` inner iterations, starting
                # from the zero vector stored in `y_ests[:, j]`.
                y_ests[:, j], z_ests[:, j] = estimate_yz(∇₂g, ∇₂L, sample_η_low,
                    xₖ, y_ests[:, j], ξ, λₖ, β₀, μg, inner_iter; num_η=num_η_low)

                # Sample η.
                ηs_[j] = sample_η_up(ξ; num=num_η_up)

                # Estimate ∇F.
                ∇F_ests[:, j] = estimate_∇F(∇₁L, xₖ, y_ests[:, j], z_ests[:, j],
                                            ηs_[j], ξ, λₖ)
            end

            # If callback function is given.
            if !isnothing(callback)
                callback_res[i] = callback(f, g, sample_ξ, sample_η_up, sample_η_low,
                    lf₁, μg, xₖ, y_ests, z_ests, all_ξ, ηs_, λₖ)
            end

            # Compute the objective function value.
            f_path[i] = mean([f(xₖ, y_ests[:, j], ηs_[j], ξ)
                                            for (j, ξ) in enumerate(all_ξ)])

            # Compute the stationarity: mean of the squared norms of the per-ξ
            # gradients.
            # NOTE(review): the sampled branch above reports
            # `norm(mean(∇F_ests, dims=2)) ^ 2` instead; confirm which aggregation
            # is intended.
            stationarity_path[i] = mean(norm.(eachcol(∇F_ests)) .^ 2)
        end
    end

    if isnothing(callback)
        return hcat(f_path, stationarity_path)
    else
        return hcat(f_path, stationarity_path, callback_res)
    end
end

"""
    x_path, res = eval_algo1(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg,
        x₀, α₀, β₀, K, eval_gap; Pₓ=identity, all_ξ=nothing, dy=0, inner_iter=100,
        num_η_up=10, num_η_low=10, num_η_perf_up=10, num_η_perf_low=10,
        callback=nothing)

Run algorithm 1 (`bilevel_solver`) and evaluate its performance with
`perf_evaluation`.

# Arguments
- Input:
    - `f::Function`: the function `f` in the objective function.
    - `g::Function`: the objective function `g` in the constraint.
    - `sample_ξ::Function`: a function that samples `ξ`.
    - `sample_η_up::Function`: a function that samples `η` in the upper level.
    - `sample_η_low::Function`: a function that samples `η` in the lower level.
    - `lf₁::Real`: the Lipschitz constant of f.
    - `μg::Real`: the modulus of strong convexity of g.
    - `x₀::Vector{<:Real}`: the initial point.
    - `α₀::Real`: the constant for O(1) in the lower stepsize.
    - `β₀::Real`: the constant for O(1) in the upper stepsize.
    - `K::Integer`: the number of upper loop.
    - `eval_gap::Integer`: the gap to evaluate the objective function value and
        stationarity.
    - `Pₓ::Function=identity`: the function that computes the projection onto the upper
        feasible set.
    - `all_ξ::Union{Vector, Nothing}=nothing`: all possible ξ's values, by
        default it is nothing, in this case we don't know all possible ξ's values.
    - `dy::Integer=0`: the dimension of `y`. By default it is set to 0, in this
        case `y` has the same dimension as `x`.
    - `inner_iter::Integer=100`: the number of inner iteration to estimate y and z
        during the performance evaluation.
    - `num_η_up::Integer = 10`: the number of sampled η in each iteration in the
        upper loop.
    - `num_η_low::Integer = 10`: the number of sampled η in each iteration in the
        lower loop.
    - `num_η_perf_up::Integer = 10`: the number of sampled η in the upper loop when
        evaluating the performance.
    - `num_η_perf_low::Integer = 10`: the number of sampled η in the lower loop when
        evaluating the performance.
    - `callback::Union{Function, Nothing}=nothing`: a callback function for
        computing extra information, such as prediction error, prediction accuracy,
        function value of g, and so on. The default value is `nothing`, meaning
        doing nothing. It is forwarded to `perf_evaluation`, which calls it with
        positional arguments only:
        `res = callback(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, xₖ,
            y_ests, z_ests, ξs_, ηs_, λₖ)`

- Output:
    - `x_path::Matrix{<:Real}`: the path of x, stored as a matrix, where k-th
        column is a point at the k-th iteration.
    - `res::Matrix`: a Matrix containing 2 or 3 columns (depending on `callback`):
        - `f_path::Vector{<:Real}`: the path of the objective function values, stored
            as a Vector, where k-th item is the objective function of the k-th evaluated
            iteration.
            Here, we note that we do not evaluate the objective function value at
            every iteration, we only evaluate the objective function value per `eval_gap`
            iterations. This is to save computation resource and time.
        - `stationarity_path::Vector{<:Real}`: the path of the stationarity.
            Similar to `f_path`, we do not evaluate the stationarity at every iteration.
        - `res_callback::Vector`: the path of results of callback function.

# Notes
If all possible ξ's are known beforehand, make sure `all_ξ` is provided correctly.
"""
function eval_algo1(
    f::Function,
    g::Function,
    sample_ξ::Function,
    sample_η_up::Function,
    sample_η_low::Function,
    lf₁::Real,
    μg::Real,
    x₀::Vector{<:Real},
    α₀::Real,
    β₀::Real,
    K::Integer,
    eval_gap::Integer;
    Pₓ::Function=identity,
    all_ξ::Union{Vector, Nothing}=nothing,
    dy::Integer=0,
    inner_iter::Integer=100,
    num_η_up::Integer=10,
    num_η_low::Integer=10,
    num_η_perf_up::Integer=10,
    num_η_perf_low::Integer=10,
    callback::Union{Function, Nothing}=nothing,
)
    # Run the solver.
    x_path = bilevel_solver(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁,
        μg, x₀, α₀, β₀, K; dy=dy, Pₓ=Pₓ, num_η_up=num_η_up, num_η_low=num_η_low)

    # Evaluate the performance. The evaluation uses its own η sample sizes
    # (`num_η_perf_up`, `num_η_perf_low`), independent of those used by the solver.
    res = perf_evaluation(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, β₀,
        x_path; all_ξ=all_ξ, eval_gap=eval_gap, dy=dy, inner_iter=inner_iter,
        num_η_up=num_η_perf_up, num_η_low=num_η_perf_low, callback=callback)
    return x_path, res
end

"""
    x_path, res = eval_algo2(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg,
        x₀, α₀, β₀, ϵ, K, eval_gap; Pₓ=identity, all_ξ=nothing, dy=0,
        inner_iter=100, num_η_up=10, num_η_low=10, num_η_perf_up=10,
        num_η_perf_low=10, callback=nothing)

Run algorithm 2 (`RT_MLMC_solver`) and evaluate its performance with
`perf_evaluation`.

# Arguments
- Input:
    - `f::Function`: the function `f` in the objective function.
    - `g::Function`: the objective function `g` in the constraint.
    - `sample_ξ::Function`: a function that samples `ξ`.
    - `sample_η_up::Function`: a function that samples `η` in the upper level.
    - `sample_η_low::Function`: a function that samples `η` in the lower level.
    - `lf₁::Real`: the Lipschitz constant of f.
    - `μg::Real`: the modulus of strong convexity of g.
    - `x₀::Vector{<:Real}`: the initial point.
    - `α₀::Real`: the constant for O(1) in the lower stepsize.
    - `β₀::Real`: the constant for O(1) in the upper stepsize.
    - `ϵ::Real`: the tolerance.
    - `K::Integer`: the number of upper loop.
    - `eval_gap::Integer`: the gap to evaluate the objective function value and
        stationarity.
    - `Pₓ::Function=identity`: the function that computes the projection onto the upper
        feasible set.
    - `all_ξ::Union{Vector, Nothing}=nothing`: all possible ξ's values, by
        default it is nothing, in this case we don't know all possible ξ's values.
    - `dy::Integer=0`: the dimension of `y`. By default it is set to 0, in this
        case `y` has the same dimension as `x`.
    - `inner_iter::Integer=100`: the number of inner iteration to estimate y and z
        during the performance evaluation.
    - `num_η_up::Integer = 10`: the number of sampled η in each iteration in the
        upper loop.
    - `num_η_low::Integer = 10`: the number of sampled η in each iteration in the
        lower loop.
    - `num_η_perf_up::Integer = 10`: the number of sampled η in the upper loop when
        evaluating the performance.
    - `num_η_perf_low::Integer = 10`: the number of sampled η in the lower loop when
        evaluating the performance.
    - `callback::Union{Function, Nothing}=nothing`: a callback function for
        computing extra information, such as prediction error, prediction accuracy,
        function value of g, and so on. The default value is `nothing`, meaning
        doing nothing. It is forwarded to `perf_evaluation`, which calls it with
        positional arguments only:
        `res = callback(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, xₖ,
            y_ests, z_ests, ξs_, ηs_, λₖ)`

- Output:
    - `x_path::Matrix{<:Real}`: the path of x, stored as a matrix, where k-th
        column is a point at the k-th iteration.
    - `res::Matrix`: a Matrix containing 2 or 3 columns (depending on `callback`):
        - `f_path::Vector{<:Real}`: the path of the objective function values, stored
            as a Vector, where k-th item is the objective function of the k-th evaluated
            iteration.
            Here, we note that we do not evaluate the objective function value at
            every iteration, we only evaluate the objective function value per `eval_gap`
            iterations. This is to save computation resource and time.
        - `stationarity_path::Vector{<:Real}`: the path of the stationarity.
            Similar to `f_path`, we do not evaluate the stationarity at every iteration.
        - `res_callback::Vector`: the path of results of callback function.

# Notes
If all possible ξ's are known beforehand, make sure `all_ξ` is provided correctly.
"""
function eval_algo2(
    f::Function,
    g::Function,
    sample_ξ::Function,
    sample_η_up::Function,
    sample_η_low::Function,
    lf₁::Real,
    μg::Real,
    x₀::Vector{<:Real},
    α₀::Real,
    β₀::Real,
    ϵ::Real,
    K::Integer,
    eval_gap::Integer;
    Pₓ::Function=identity,
    all_ξ::Union{Vector, Nothing}=nothing,
    dy::Integer=0,
    inner_iter::Integer=100,
    num_η_up::Integer=10,
    num_η_low::Integer=10,
    num_η_perf_up::Integer=10,
    num_η_perf_low::Integer=10,
    callback::Union{Function, Nothing}=nothing,
)
    # Run the solver.
    x_path = RT_MLMC_solver(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁,
        μg, x₀, α₀, β₀, ϵ, K; dy=dy, Pₓ=Pₓ, num_η_up=num_η_up, num_η_low=num_η_low)

    # Evaluate the performance. The evaluation uses its own η sample sizes
    # (`num_η_perf_up`, `num_η_perf_low`), independent of those used by the solver.
    res = perf_evaluation(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, β₀,
        x_path; all_ξ=all_ξ, eval_gap=eval_gap, dy=dy, inner_iter=inner_iter,
        num_η_up=num_η_perf_up, num_η_low=num_η_perf_low, callback=callback)
    return x_path, res
end