include("src/algos.jl")
include("utils/ImageNet.jl")
include("utils/utils.jl")

using TimerOutputs
using LogExpFunctions

# Fix the global RNG so the generated tasks and every later random draw are
# reproducible across runs.
Random.seed!(42)

# Build the per-task train/test splits (helper comes from one of the included
# utility files above).
(train_features, test_features, train_labels, test_labels, cls) =
    create_meta_learning_tasks()

"""
    sample_ξ()

Draw one task index uniformly at random from `1:NUM_TASKS` using the global RNG.

A length-1 vector is drawn and unwrapped (rather than drawing a scalar) because
the two forms are not guaranteed to consume the RNG identically, and the
experiments below are seeded.
"""
sample_ξ() = only(rand(1:NUM_TASKS, 1))

"""
    _sample_task_columns(features, labels, ξ, num)

Return `Matrix[features[:, idx, ξ], labels[:, idx, ξ]]` for `num` distinct sample
indices `idx` drawn uniformly at random from the `size(features, 2)` samples
stored for task `ξ`. Features and labels share `idx`, so (feature, label) pairs
stay aligned. The labels are one-hot encoded.

Throws `DimensionMismatch` if `features` and `labels` store different numbers of
samples per task, and `ArgumentError` unless `0 ≤ num ≤ size(features, 2)`.
"""
function _sample_task_columns(
    features::AbstractArray,
    labels::AbstractArray,
    ξ::Int,
    num::Int,
)
    # Derive the pool size from the data actually passed in, so that callers
    # overriding `features`/`labels` get indices valid for *their* arrays.
    n_avail = size(features, 2)
    size(labels, 2) == n_avail || throw(DimensionMismatch(
        "features has $n_avail samples per task but labels has $(size(labels, 2))"))
    0 <= num <= n_avail || throw(ArgumentError(
        "cannot sample $num points from $n_avail available samples"))
    # A truncated random permutation = sampling without replacement.
    data_idx = randperm(n_avail)[1:num]
    # First entry: features; second entry: one-hot labels.
    return Matrix[features[:, data_idx, ξ], labels[:, data_idx, ξ]]
end

"""
    sample_η_up(ξ; num=10, features=test_features, labels=test_labels)

Sample `num` data points (without replacement) of task `ξ`, by default from the
test split. Returns a 2-element `Vector{Matrix}`: `[features, one-hot labels]`.
"""
function sample_η_up(
    ξ::Int;
    num::Int=10,
    features::AbstractArray=test_features,
    labels::AbstractArray=test_labels,
)
    return _sample_task_columns(features, labels, ξ, num)
end

"""
    sample_η_low(ξ; num=10, features=train_features, labels=train_labels)

Sample `num` data points (without replacement) of task `ξ`, by default from the
training split. Returns a 2-element `Vector{Matrix}`: `[features, one-hot labels]`.
"""
function sample_η_low(
    ξ::Int;
    num::Int=10,
    features::AbstractArray=train_features,
    labels::AbstractArray=train_labels,
)
    return _sample_task_columns(features, labels, ξ, num)
end

"""
    f(x, y, η, ξ)

Mean softmax cross-entropy of the linear classifier encoded by `y` on the batch
`η = [features, labels]` (one column per sample, labels one-hot).

`y` is reshaped into a `DIM × C` matrix `W` whose column `c` holds the weights of
class `c`, so the logits are `W' * features` (classes × samples). The per-sample
loss is `logsumexp(logits) - ⟨label, logits⟩`, which is the cross-entropy when
the label column is one-hot. `x` and `ξ` are accepted but not used.
"""
function f(
    x::Vector{<:Real},
    y::Vector{<:Real},
    η::Vector{<:AbstractMatrix},
    ξ::Real,
)
    inputs, onehot = η[1], η[2]
    W = reshape(y, DIM, :)
    logits = W' * inputs
    # Logit of the true class, one entry per sample.
    true_logit = eachcol(onehot) .⋅ eachcol(logits)
    # Log-partition over the class dimension, one entry per sample.
    log_norm = vec(LogExpFunctions.logsumexp(logits; dims=1))
    return mean(log_norm - true_logit)
end

"""
    g(x, y, η, ξ; λ=2.0)

Return `f(x, y, η, ξ) + (λ / 2) * ‖y - x‖²`: the classification loss of `y` plus
a proximity penalty pulling `y` towards `x`. It is passed as `g` to the bilevel
solvers below (presumably the lower-level objective).
"""
function g(
    x::Vector{<:Real},
    y::Vector{<:Real},
    η::Vector{<:AbstractMatrix},
    ξ::Real;
    λ::Real=2.0,
)
    loss = f(x, y, η, ξ)
    sq_dist = norm(y - x)^2
    return loss + λ / 2 * sq_dist
end

# Length of a flattened weight vector: one DIM-dimensional column per class.
dx = CLASSES_PER_TASK * DIM

"""
    callback_err(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, x, y, z,
                 ξs_, ηs_, λₖ; h=nothing, Dh=nothing, A=nothing)

Return the misclassification rate of the per-task linear classifiers stored in
the columns of `y`, averaged over the tasks in `ξs_`.

For the `j`-th task, column `j` of `y` is reshaped into a `DIM × C` weight matrix
`W`. The logits are `W' * ηs_[j][1]`, and each sample's prediction is the argmax
of its logit column. The task error is the fraction of predictions that differ
from `onecold(ηs_[j][2])`.

Only `y`, `ξs_` and `ηs_` are read. The other arguments are accepted presumably
so that this matches the callback signature `perf_evaluation` expects; that
function is not visible here.

Returns `NaN` when `ξs_` is empty. Throws `DimensionMismatch` if `y` has fewer
columns than there are tasks.
"""
function callback_err(
    f::Function,
    g::Function,
    sample_ξ::Function,
    sample_η_up::Function,
    sample_η_low::Function,
    lf₁::Real,
    μg::Real,
    x::Vector{<:Real},
    y::Matrix{<:Real},
    z::Matrix{<:Real},
    ξs_::Vector{<:Union{Real, AbstractArray}},
    ηs_::Vector{AbstractVecOrMat},
    λₖ::Real;
    h::Union{Function, Nothing}=nothing,
    Dh::Union{Function, Nothing}=nothing,
    A::Union{Matrix{<:Real}, Nothing}=nothing,
)
    # Average over the tasks actually evaluated, not over however many columns
    # `y` happens to have (presumably equal when all_ξ is passed as ξs_).
    num_tasks = length(ξs_)
    size(y, 2) >= num_tasks || throw(DimensionMismatch(
        "y has $(size(y, 2)) columns but $(num_tasks) tasks were given"))

    # Float accumulator keeps `total_err` type-stable across iterations.
    total_err = 0.0
    for j in 1:num_tasks
        inputs, onehot = ηs_[j][1], ηs_[j][2]
        true_lbls = onecold(onehot)

        # Column c of `W` holds the weights of class c.
        W = reshape(y[:, j], DIM, :)
        logits = W' * inputs
        # map + argmax over columns can be faster than argmax with dims.
        pred = map(argmax, eachcol(logits))

        # Misclassification rate of this task, averaged over its samples.
        total_err += mean(pred .!= true_lbls)
    end

    return total_err / num_tasks
end

# Common initial point handed to every run of every solver. It is drawn from the
# global RNG, so it depends on the seed and on every draw made before it.
x₀ = randn(dx)
# Constants passed to the solvers and performance evaluations.
lf₁ = 1000.0 # TODO
μg = 1000.0 # TODO
ϵ = 1e-4

# Every task index as a Vector{Int}; passed as `all_ξ` to each evaluation.
all_ξ = collect(Int, 1:NUM_TASKS)

# Inner-loop iterations used by each performance evaluation.
inner_iter = 100
# Minibatch size handed to the RT-MLMC solvers (including the Hessian-based one).
minibatch = 10

# Number of independent experiments per algorithm.
num_exps = 10

# ============= Set up experiments. ============
# Unconstrained Algorithm 1: `K` outer iterations, performance evaluated every
# `eval_gap` iterations; α₀ and β₀ are presumably its initial step sizes.
K = 1_000
eval_gap = 10
α₀ = 25
β₀ = 500
eval_iters₁ = get_eval_iters(eval_gap, K)

# Unconstrained Algorithm 2 (RT-MLMC). β₀_perf is presumably the inner step size
# for evaluating Algorithm 2; it currently has the same value as β₀.
K₂ = 14_000
eval_gap₂ = 140
α₀₂ = 1e16
β₀₂ = 25
β₀_perf = 500
eval_iters₂ = get_eval_iters(eval_gap₂, K₂)

# Hessian-based algorithm.
K_Hes = 10_000
eval_gap_Hes = 100
Lg₁ = 10
# Step-size schedule passed to the Hessian-based solver: init/√t for t ≤ bp,
# then init/t. The schedule is discontinuous: between t = bp and t = bp + 1 it
# drops by a factor of about √bp (≈31.6 for the default bp = 1000).
function ss_Hes(t::Int; bp::Int=1000, init::Real=5e-1)
    t > bp && return init / t
    return init / sqrt(t)
end
β_Hes = 70
# Lg₁, β_Hes = 100, 8 # init = 1e-2
eval_iters_Hes = get_eval_iters(eval_gap_Hes, K_Hes)

# ============= End experiments setup. ============

# Declare variables for storing results, grouped by algorithm. Each array has one
# column (or, for the x paths, one entry) per experiment:
#   x*      — x path returned by the solver;
#   obj*    — objective value at each evaluation iteration;
#   sta*    — stationarity at each evaluation iteration;
#   err*    — error metric (presumably from `callback_err`) at each evaluation;
#   times*  — per-iteration times returned by the solver;
#   nₖs*    — nₖ values returned by the RT-MLMC solvers.
# Names with a trailing `_` belong to Algorithm 2 without adaptive stepsize.

# Algorithm 1.
x₁ = Vector{Matrix}(undef, num_exps)
obj₁ = zeros(Float64, length(eval_iters₁), num_exps)
sta₁ = zeros(Float64, length(eval_iters₁), num_exps)
err₁ = zeros(Float64, length(eval_iters₁), num_exps)
times₁ = zeros(Float64, K, num_exps)

# Algorithm 2 with adaptive stepsize.
x₂ = Vector{Matrix}(undef, num_exps)
obj₂ = zeros(Float64, length(eval_iters₂), num_exps)
sta₂ = zeros(Float64, length(eval_iters₂), num_exps)
err₂ = zeros(Float64, length(eval_iters₂), num_exps)
times₂ = zeros(Float64, K₂, num_exps)
nₖs₂ = zeros(Float64, K₂, num_exps)

# Algorithm 2 without adaptive stepsize.
x₂_ = Vector{Matrix}(undef, num_exps)
obj₂_ = zeros(Float64, length(eval_iters₂), num_exps)
sta₂_ = zeros(Float64, length(eval_iters₂), num_exps)
err₂_ = zeros(Float64, length(eval_iters₂), num_exps)
times₂_ = zeros(Float64, K₂, num_exps)
nₖs₂_ = zeros(Float64, K₂, num_exps)

# Hessian-based algorithm.
x_Hes = Vector{Matrix}(undef, num_exps)
obj_Hes = zeros(Float64, length(eval_iters_Hes), num_exps)
sta_Hes = zeros(Float64, length(eval_iters_Hes), num_exps)
err_Hes = zeros(Float64, length(eval_iters_Hes), num_exps)
times_Hes = zeros(Float64, K_Hes, num_exps)
nₖs_Hes = zeros(Float64, K_Hes, num_exps)

# Algorithm 1: `num_exps` independent experiments, all starting from x₀.
Random.seed!(42)
for trial in 1:num_exps
    x₁[trial], times₁[:, trial] = bilevel_solver(
        f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, x₀, α₀, β₀, K;
        num_η_up=100, num_η_low=100)

    # Evaluation runs after the solve, so its time and memory are not counted
    # in `times₁`.
    perf = perf_evaluation(
        f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, β₀, x₁[trial];
        all_ξ=all_ξ, eval_gap=eval_gap, inner_iter=inner_iter,
        num_η_up=NUM_TEST_PER_TASK, num_η_low=NUM_TRAIN_PER_TASK,
        callback=callback_err)
    # Columns of `perf`: objective, stationarity, error.
    obj₁[:, trial] = perf[:, 1]
    sta₁[:, trial] = perf[:, 2]
    err₁[:, trial] = perf[:, 3]
end

# Save results for algorithm 1.
# JLD2.save("results/algo1.jld2", Dict(
#     "x₀" => x₀,
#     "x₁" => x₁,
#     "obj₁" => obj₁,
#     "sta₁" => sta₁,
#     "err₁" => err₁,
#     "times₁" => times₁,
#     "α₀" => α₀,
#     "β₀" => β₀))

# Algorithm 2 with adaptive stepsize.
Random.seed!(42)
for i = 1:num_exps
    # RT-MLMC Algorithm 2: returns the x path, per-iteration times and the nₖs.
    x₂[i], times₂[:, i], nₖs₂[:, i] = RT_MLMC_solver(f, g, sample_ξ, sample_η_up,
        sample_η_low, lf₁, μg, x₀, α₀₂, β₀₂, ϵ, K₂; num_η_up=100,
        num_η_low=100, minibatch=minibatch, a₁=0.05, nN=4, c₀=0.28)

    # Obtain res for algorithm 2.
    # We do not count the time and memory for evaluation.
    # Use β₀_perf, the evaluation step size declared with Algorithm 2's settings
    # (currently equal to β₀). Retuning Algorithm 1's β₀ then no longer silently
    # changes how Algorithm 2 is evaluated.
    res₂ = perf_evaluation(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁,
        μg, β₀_perf, x₂[i]; all_ξ=all_ξ, eval_gap=eval_gap₂,
        inner_iter=inner_iter, num_η_up=NUM_TEST_PER_TASK,
        num_η_low=NUM_TRAIN_PER_TASK, callback=callback_err)
    # Columns of `res₂`: objective, stationarity, error.
    obj₂[:, i] = res₂[:, 1]
    sta₂[:, i] = res₂[:, 2]
    err₂[:, i] = res₂[:, 3]
end

# Save results for algorithm 2 with adaptive stepsize.
# JLD2.save("results/algo2_14000it_seed42_tuness_nN4_c0.28.jld2", Dict(
#     "x₀" => x₀,
#     "x₂" => x₂,
#     "obj₂" => obj₂,
#     "sta₂" => sta₂,
#     "err₂" => err₂,
#     "times₂" => times₂,
#     "nₖs₂" => nₖs₂,
#     "α₀₂" => α₀₂,
#     "β₀₂" => β₀₂,
#     "a₁" => 0.05,
#     "nN" => 4,
#     "c₀" => 0.28))

# Algorithm 2 without adaptive stepsize.
Random.seed!(42)
for i = 1:num_exps
    # RT-MLMC Algorithm 2 (a₁ = 1): returns the x path, per-iteration times and
    # the nₖs.
    x₂_[i], times₂_[:, i], nₖs₂_[:, i] = RT_MLMC_solver(f, g, sample_ξ, sample_η_up,
        sample_η_low, lf₁, μg, x₀, α₀₂, β₀₂, ϵ, K₂; num_η_up=100,
        num_η_low=100, minibatch=minibatch, a₁=1, nN=4, c₀=0.28)

    # Obtain res for algorithm 2.
    # We do not count the time and memory for evaluation.
    # Use β₀_perf, the evaluation step size declared with Algorithm 2's settings
    # (currently equal to β₀). Retuning Algorithm 1's β₀ then no longer silently
    # changes how Algorithm 2 is evaluated.
    res₂ = perf_evaluation(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁,
        μg, β₀_perf, x₂_[i]; all_ξ=all_ξ, eval_gap=eval_gap₂,
        inner_iter=inner_iter, num_η_up=NUM_TEST_PER_TASK,
        num_η_low=NUM_TRAIN_PER_TASK, callback=callback_err)
    # Columns of `res₂`: objective, stationarity, error.
    obj₂_[:, i] = res₂[:, 1]
    sta₂_[:, i] = res₂[:, 2]
    err₂_[:, i] = res₂[:, 3]
end

# # Save results for algorithm 2 without adaptive stepsize.
# JLD2.save("results/algo2_14000it_seed42_nN4_c0.28.jld2", Dict(
#     "x₀" => x₀,
#     "x₂" => x₂_,
#     "obj₂" => obj₂_,
#     "sta₂" => sta₂_,
#     "err₂" => err₂_,
#     "times₂" => times₂_,
#     "nₖs₂" => nₖs₂_,
#     "α₀₂" => α₀₂,
#     "β₀₂" => β₀₂,
#     "a₁" => 1,
#     "nN" => 4,
#     "c₀" => 0.28))

# Hessian-based Algorithm: `num_exps` independent experiments, all from x₀.
Random.seed!(42)
for trial in 1:num_exps
    # The positional `12` is forwarded unchanged; its meaning is defined by
    # RT_MLMC_Hessian_solver, which is not visible here.
    x_Hes[trial], times_Hes[:, trial], nₖs_Hes[:, trial] = RT_MLMC_Hessian_solver(
        f, g, sample_ξ, sample_η_up, sample_η_low, Lg₁, x₀, ss_Hes, β_Hes,
        K_Hes, 12; num_η_up=100, num_η_low=100, minibatch=minibatch)

    # Evaluation runs after the solve, so its time and memory are not counted
    # in `times_Hes`.
    perf = perf_evaluation(
        f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg, β₀, x_Hes[trial];
        all_ξ=all_ξ, eval_gap=eval_gap_Hes, inner_iter=inner_iter,
        num_η_up=NUM_TEST_PER_TASK, num_η_low=NUM_TRAIN_PER_TASK,
        callback=callback_err)
    # Columns of `perf`: objective, stationarity, error.
    obj_Hes[:, trial] = perf[:, 1]
    sta_Hes[:, trial] = perf[:, 2]
    err_Hes[:, trial] = perf[:, 3]
end

# Save results for Hessian-based algorithm.
# JLD2.save("results/Hes_seed42_it10000.jld2", Dict(
#     "x₀" => x₀,
#     "x_Hes" => x_Hes,
#     "obj_Hes" => obj_Hes,
#     "sta_Hes" => sta_Hes,
#     "err_Hes" => err_Hes,
#     "times_Hes" => times_Hes,
#     "nₖs_Hes" => nₖs_Hes,
#     "Lg₁" => Lg₁,
#     "β_Hes" => β_Hes))

# # Load results from saved files.
# data_ = JLD2.load("results/algo1.jld2")
# obj₁ = data_["obj₁"]
# sta₁ = data_["sta₁"]
# err₁ = data_["err₁"]
# times₁ = data_["times₁"]
# data_ = JLD2.load("results/Hes_seed42_it10000.jld2")
# obj_Hes = data_["obj_Hes"]
# sta_Hes = data_["sta_Hes"]
# err_Hes = data_["err_Hes"]
# times_Hes = data_["times_Hes"]
# nₖs_Hes = data_["nₖs_Hes"]
# data_ = JLD2.load("results/algo2_14000it_seed42_tuness_nN4_c0.28.jld2")
# obj₂ = data_["obj₂"]
# sta₂ = data_["sta₂"]
# err₂ = data_["err₂"]
# times₂ = data_["times₂"]
# nₖs₂ = data_["nₖs₂"]
# data_ = JLD2.load("results/algo2_14000it_seed42_nN4_c0.28.jld2")
# obj₂_ = data_["obj₂"]
# sta₂_ = data_["sta₂"]
# err₂_ = data_["err₂"]
# times₂_ = data_["times₂"]
# nₖs₂_ = data_["nₖs₂"]

# ============ Plot. ============
"""
    _log_std_band(A) -> (m, ls, lower, upper)

Summarize a matrix whose columns are independent experiments:
- `m`: the row-wise arithmetic mean;
- `ls`: the row-wise standard deviation of `log10.(A)`;
- `lower`/`upper`: `m` divided/multiplied by `10^ls`, i.e. a band symmetric in
  log10 space around the mean.

A zero entry gives `log10 = -Inf` and hence a `NaN` spread for that row. A
negative entry makes `log10` throw.
"""
function _log_std_band(A::AbstractMatrix)
    m = vec(mean(A; dims=2))
    ls = vec(std(log10.(A); dims=2))
    log_m = log10.(m)
    return m, ls, 10 .^ (log_m .- ls), 10 .^ (log_m .+ ls)
end

# Objective function value.
obj₁_mean, ls_obj₁, lower_obj₁, upper_obj₁ = _log_std_band(obj₁)
obj₂_mean, ls_obj₂, lower_obj₂, upper_obj₂ = _log_std_band(obj₂)
obj₂_mean_, ls_obj₂_, lower_obj₂_, upper_obj₂_ = _log_std_band(obj₂_)
obj_Hes_mean, ls_obj_Hes, lower_obj_Hes, upper_obj_Hes = _log_std_band(obj_Hes)

# Stationarity.
sta₁_mean, ls_sta₁, lower_sta₁, upper_sta₁ = _log_std_band(sta₁)
sta₂_mean, ls_sta₂, lower_sta₂, upper_sta₂ = _log_std_band(sta₂)
sta₂_mean_, ls_sta₂_, lower_sta₂_, upper_sta₂_ = _log_std_band(sta₂_)
sta_Hes_mean, ls_sta_Hes, lower_sta_Hes, upper_sta_Hes = _log_std_band(sta_Hes)

# Error.
err₁_mean, ls_err₁, lower_err₁, upper_err₁ = _log_std_band(err₁)
err₂_mean, ls_err₂, lower_err₂, upper_err₂ = _log_std_band(err₂)
err₂_mean_, ls_err₂_, lower_err₂_, upper_err₂_ = _log_std_band(err₂_)
err_Hes_mean, ls_err_Hes, lower_err_Hes, upper_err_Hes = _log_std_band(err_Hes)

# Plot comparison of algorithm 2 with and without adaptive stepsize.
# Each figure draws the mean over experiments as a line and shades the ribbon
# from lower_* to upper_* (computed above from the spread of log10 values).
# `plot!` adds the second curve to the current figure; `title!(...)` returns the
# figure, which is then displayed, and `savefig` writes the current figure.
# Objective function value.
plot(eval_iters₂, obj₂_mean, yscale=:log10, label="Algorithm 2 with adaptive stepsize",
        ribbon=(obj₂_mean .- lower_obj₂, upper_obj₂ .- obj₂_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=14, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:red, xlabel="Iteration", ylabel="Objective Function Value")
plot!(eval_iters₂, obj₂_mean_, label="Algorithm 2 without adaptive stepsize",
        ribbon=(obj₂_mean_ .- lower_obj₂_, upper_obj₂_ .- obj₂_mean_),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:orange)
title!("Objective function values v.s. iterations") |> display
savefig("obj_algo2_comp.png")

# Stationarity.
plot(eval_iters₂, sta₂_mean, yscale=:log10, label="Algorithm 2 with adaptive stepsize",
        ribbon=(sta₂_mean .- lower_sta₂, upper_sta₂ .- sta₂_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=16, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:red, xlabel="Iteration", ylabel="Stationarity")
plot!(eval_iters₂, sta₂_mean_, label="Algorithm 2 without adaptive stepsize",
        ribbon=(sta₂_mean_ .- lower_sta₂_, upper_sta₂_ .- sta₂_mean_),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:orange)
title!("Stationarities v.s iterations") |> display
savefig("sta_algo2_comp.png")

# Error.
# NOTE(review): this axis is linear (no yscale), while the ribbon is still the
# log10-space band computed above.
plot(eval_iters₂, err₂_mean, label="Algorithm 2 with adaptive stepsize",
        ribbon=(err₂_mean .- lower_err₂, upper_err₂ .- err₂_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=16, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:red, xlabel="Iteration", ylabel="Error")
plot!(eval_iters₂, err₂_mean_, label="Algorithm 2 without adaptive stepsize",
        ribbon=(err₂_mean_ .- lower_err₂_, upper_err₂_ .- err₂_mean_),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:orange)
title!("Errors v.s. iterations") |> display
savefig("err_algo2_comp.png")

# Three-algorithm comparison against outer iteration count. "Algorithm 2" here
# is the adaptive-stepsize run (obj₂/sta₂/err₂). Each algorithm is plotted at
# its own evaluation iterations (eval_iters₁/₂/_Hes). The savefig calls are
# commented out, so these figures are only displayed.
# Plot error bars for objective function value of problem.
plot(eval_iters₁, obj₁_mean, yscale=:log10, label="Algorithm 1",
        ribbon=(obj₁_mean .- lower_obj₁, upper_obj₁ .- obj₁_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=14, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:blue, xlabel="Iteration", ylabel="Objective Function Value")
plot!(eval_iters₂, obj₂_mean, label="Algorithm 2",
        ribbon=(obj₂_mean .- lower_obj₂, upper_obj₂ .- obj₂_mean),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:red)
plot!(eval_iters_Hes, obj_Hes_mean, label="Hessian-based Algorithm",
        ribbon=(obj_Hes_mean .- lower_obj_Hes, upper_obj_Hes .- obj_Hes_mean),
        fillalpha=0.15, linestyle=:dot, linewidth=1.5, color=:green)
title!("Objective function values v.s. iterations") |> display
# savefig("obj.png")

# Plot error bars for stationarity of problem.
plot(eval_iters₁, sta₁_mean, yscale=:log10, label="Algorithm 1",
        ribbon=(sta₁_mean .- lower_sta₁, upper_sta₁ .- sta₁_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=16, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:blue, xlabel="Iteration", ylabel="Stationarity")
plot!(eval_iters₂, sta₂_mean, label="Algorithm 2",
        ribbon=(sta₂_mean .- lower_sta₂, upper_sta₂ .- sta₂_mean),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:red)
plot!(eval_iters_Hes, sta_Hes_mean, label="Hessian-based Algorithm",
        ribbon=(sta_Hes_mean .- lower_sta_Hes, upper_sta_Hes .- sta_Hes_mean),
        fillalpha=0.15, linestyle=:dot, linewidth=1.5, color=:green)
title!("Stationarities v.s iterations") |> display
# savefig("sta.png")

# Plot error bars for err of problem.
# NOTE(review): linear y axis, but the ribbon is the log10-space band.
plot(eval_iters₁, err₁_mean, label="Algorithm 1",
        ribbon=(err₁_mean .- lower_err₁, upper_err₁ .- err₁_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=16, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:blue, xlabel="Iteration", ylabel="Error")
plot!(eval_iters₂, err₂_mean, label="Algorithm 2",
        ribbon=(err₂_mean .- lower_err₂, upper_err₂ .- err₂_mean),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:red)
plot!(eval_iters_Hes, err_Hes_mean, label="Hessian-based Algorithm",
        ribbon=(err_Hes_mean .- lower_err_Hes, upper_err_Hes .- err_Hes_mean),
        fillalpha=0.15, linestyle=:dot, linewidth=1.5, color=:green)
title!("Errors v.s. iterations") |> display
# savefig("err.png")


# ============ Plot inner-iter v.s. performance figures. ==============
# Cumulative inner iterations for Algorithm 1: outer iteration k is counted as
# k inner iterations.
mean_initer₁ = cumsum(1:K)

# Cumulative inner iterations for Algorithm 2. Outer iteration k costs
# `minibatch * 2^nₖ` inner iterations. Recorded levels equal to 0 are counted as
# 2, per the solver's bookkeeping "tricks"; the reason lives in RT_MLMC_solver,
# which is not visible here.
# The cost 2^nₖ itself is averaged over experiments. 2^(mean nₖ) would be the
# geometric mean, which by Jensen's inequality undercounts the expected cost.
# `replace` leaves the stored `nₖs₂` results untouched.
mean_initer₂ = let levels = replace(nₖs₂, 0 => 2)
    minibatch * cumsum(vec(mean(2.0 .^ levels; dims=2)))
end

# Cumulative inner iterations for the Hessian-based method (same accounting,
# with its own minibatch of size `minibatch`).
mean_initer_Hes = minibatch * cumsum(vec(mean(2.0 .^ nₖs_Hes; dims=2)))

# Same three-algorithm figures, with the x axis replaced by the cumulative
# inner-iteration counts (mean_initer₁/₂/_Hes) sampled at each algorithm's
# evaluation iterations. "Algorithm 2" is the adaptive-stepsize run. The
# savefig calls are commented out, so these figures are only displayed.
# Plot error bars for objective function value of problem.
plot(mean_initer₁[eval_iters₁], obj₁_mean, yscale=:log10, label="Algorithm 1",
        ribbon=(obj₁_mean .- lower_obj₁, upper_obj₁ .- obj₁_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=14, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:blue, xlabel="Inner iterations", ylabel="Objective Function Value")
plot!(mean_initer₂[eval_iters₂], obj₂_mean, label="Algorithm 2",
        ribbon=(obj₂_mean .- lower_obj₂, upper_obj₂ .- obj₂_mean),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:red)
plot!(mean_initer_Hes[eval_iters_Hes], obj_Hes_mean, label="Hessian-based Algorithm",
        ribbon=(obj_Hes_mean .- lower_obj_Hes, upper_obj_Hes .- obj_Hes_mean),
        fillalpha=0.15, linestyle=:dot, linewidth=1.5, color=:green)
title!("Objective function values v.s. inner iterations") |> display
# savefig("obj_initer.png")

# Plot error bars for stationarity of problem.
plot(mean_initer₁[eval_iters₁], sta₁_mean, yscale=:log10, label="Algorithm 1",
        ribbon=(sta₁_mean .- lower_sta₁, upper_sta₁ .- sta₁_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=16, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:blue, xlabel="Inner iterations", ylabel="Stationarity")
plot!(mean_initer₂[eval_iters₂], sta₂_mean, label="Algorithm 2",
        ribbon=(sta₂_mean .- lower_sta₂, upper_sta₂ .- sta₂_mean),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:red)
plot!(mean_initer_Hes[eval_iters_Hes], sta_Hes_mean, label="Hessian-based Algorithm",
        ribbon=(sta_Hes_mean .- lower_sta_Hes, upper_sta_Hes .- sta_Hes_mean),
        fillalpha=0.15, linestyle=:dot, linewidth=1.5, color=:green)
title!("Stationarities v.s. inner iterations") |> display
# savefig("sta_initer.png")

# Plot error bars for err of problem.
# NOTE(review): linear y axis, but the ribbon is the log10-space band.
plot(mean_initer₁[eval_iters₁], err₁_mean, label="Algorithm 1",
        ribbon=(err₁_mean .- lower_err₁, upper_err₁ .- err₁_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=16, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:blue, xlabel="Inner iterations", ylabel="Error")
plot!(mean_initer₂[eval_iters₂], err₂_mean, label="Algorithm 2",
        ribbon=(err₂_mean .- lower_err₂, upper_err₂ .- err₂_mean),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:red)
plot!(mean_initer_Hes[eval_iters_Hes], err_Hes_mean, label="Hessian-based Algorithm",
        ribbon=(err_Hes_mean .- lower_err_Hes, upper_err_Hes .- err_Hes_mean),
        fillalpha=0.15, linestyle=:dot, linewidth=1.5, color=:green)
title!("Errors v.s. inner iterations") |> display
# savefig("err_initer.png")

# ============ Plot time-performance figures. ============
# Elapsed computational time per algorithm: the per-iteration times are averaged
# over experiments, then accumulated. For Algorithm 2 this uses the
# adaptive-stepsize run (times₂).
mean_time₁ = cumsum(vec(mean(times₁; dims=2)))
mean_time₂ = cumsum(vec(mean(times₂; dims=2)))
mean_time_Hes = cumsum(vec(mean(times_Hes; dims=2)))

# Same three-algorithm figures, with the x axis replaced by the cumulative mean
# computational time (mean_time₁/₂/_Hes) at each algorithm's evaluation
# iterations. Evaluation cost is excluded, because the solvers' timings do not
# include perf_evaluation. The savefig calls are commented out, so these
# figures are only displayed.
# Plot error bars for objective function value of problem.
plot(mean_time₁[eval_iters₁], obj₁_mean, yscale=:log10, label="Algorithm 1",
        ribbon=(obj₁_mean .- lower_obj₁, upper_obj₁ .- obj₁_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=14, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:blue, xlabel="Computational Time (s)", ylabel="Objective Function Value")
plot!(mean_time₂[eval_iters₂], obj₂_mean, label="Algorithm 2",
        ribbon=(obj₂_mean .- lower_obj₂, upper_obj₂ .- obj₂_mean),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:red)
plot!(mean_time_Hes[eval_iters_Hes], obj_Hes_mean, label="Hessian-based Algorithm",
        ribbon=(obj_Hes_mean .- lower_obj_Hes, upper_obj_Hes .- obj_Hes_mean),
        fillalpha=0.15, linestyle=:dot, linewidth=1.5, color=:green)
title!("Objective function values v.s. computational time") |> display
# savefig("obj_time.png")

# Plot error bars for stationarity of problem.
plot(mean_time₁[eval_iters₁], sta₁_mean, yscale=:log10, label="Algorithm 1",
        ribbon=(sta₁_mean .- lower_sta₁, upper_sta₁ .- sta₁_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=16, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:blue, xlabel="Computational Time (s)", ylabel="Stationarity")
plot!(mean_time₂[eval_iters₂], sta₂_mean, label="Algorithm 2",
        ribbon=(sta₂_mean .- lower_sta₂, upper_sta₂ .- sta₂_mean),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:red)
plot!(mean_time_Hes[eval_iters_Hes], sta_Hes_mean, label="Hessian-based Algorithm",
        ribbon=(sta_Hes_mean .- lower_sta_Hes, upper_sta_Hes .- sta_Hes_mean),
        fillalpha=0.15, linestyle=:dot, linewidth=1.5, color=:green)
title!("Stationarities v.s. computational time") |> display
# savefig("sta_time.png")

# Plot error bars for err of problem.
# NOTE(review): linear y axis, but the ribbon is the log10-space band.
plot(mean_time₁[eval_iters₁], err₁_mean, label="Algorithm 1",
        ribbon=(err₁_mean .- lower_err₁, upper_err₁ .- err₁_mean),
        fillalpha=0.15, dpi=300, legend=:best, tickfontsize=10, guidefontsize=14,
        titlefontsize=16, legendfontsize=12, linestyle=:solid, linewidth=1.5,
        color=:blue, xlabel="Computational Time (s)", ylabel="Error")
plot!(mean_time₂[eval_iters₂], err₂_mean, label="Algorithm 2",
        ribbon=(err₂_mean .- lower_err₂, upper_err₂ .- err₂_mean),
        fillalpha=0.15, linestyle=:dash, linewidth=1.5, color=:red)
plot!(mean_time_Hes[eval_iters_Hes], err_Hes_mean, label="Hessian-based Algorithm",
        ribbon=(err_Hes_mean .- lower_err_Hes, upper_err_Hes .- err_Hes_mean),
        fillalpha=0.15, linestyle=:dot, linewidth=1.5, color=:green)
title!("Errors v.s. computational time") |> display
# savefig("err_time.png")