using Printf
using Random
using Distributions
using LinearAlgebra
using Plots
using ProgressMeter
using Zygote
using SparseDiffTools

"""
    n = drawtrunGeoDist(N)

Draw a sample from `Geometric()` (from the `Distributions` package) conditioned on being
at most `N`, using rejection sampling.

# Arguments
- Input:
    - `N::Integer`: the upper bound of the support; must be at least 1.
- Output:
    - `n::Integer`: the accepted draw, which always satisfies `n ≤ N`.

# Throws
- `ErrorException` if `N < 1`.

# Notes
- Only the upper bound is enforced. `Geometric()` in `Distributions` has support
  starting at 0, so `0` can be returned; callers that need a strictly positive level
  have to handle that case themselves.
- The usual way to sample from a truncated distribution is the `truncated` function in
  the `Distributions` package, i.e.,
```julia
    trunGeoDist = truncated(Geometric(); lower=0, upper=N)
    n = rand(trunGeoDist)
```
  However, this may be slower than the give-up (rejection) strategy implemented here.
"""
function drawtrunGeoDist(N::Integer)
    if N < 1
        error("N must be a positive integer.")
    end

    dist = Geometric()

    # Scalar draws: no one-element array is allocated per attempt.
    n = rand(dist)
    # Give-up strategy: redraw until the sample falls below the upper bound.
    while n > N
        n = rand(dist)
    end

    return n
end

"""
    yₖ, zₖ = estimate_yz(∇₂g, ∇₂L, sample_η_low, x, y, ξ, λₖ, β₀, μg, Tₖ; max_iter=Inf, num_η=10)

Estimate y and z using the inner loop of Algorithm 1.

This function will be reused for performance evaluation.

# Arguments
- Input:
    - `∇₂g::Function`: a function that computes the gradient with respect to the second
        argument of g, called as `∇₂g(x, y, η, ξ)`.
    - `∇₂L::Function`: a function that computes the gradient with respect to the second
        argument of L, called as `∇₂L(x, z, y, λ, η, ξ)`.
    - `sample_η_low::Function`: a function that samples η given ξ, called as
        `η = sample_η_low(ξ; num=num_η)`.
    - `x::Vector{<:Real}`: the current x vector.
    - `y::Vector{<:Real}`: the initialization of y (also used to initialize z).
    - `ξ::Union{Real, AbstractArray}`: the random variable ξ.
    - `λₖ::Real`: the current λ value.
    - `β₀::Real`: the constant for O(1) in the step size.
    - `μg::Real`: the modulus of strong convexity of g.
    - `Tₖ::Integer`: the number of iterations in the inner loop.
    - `max_iter::Real=Inf`: the maximum number of iterations in the inner loop (default is
        `Inf`). Normally it should be an integer, here we use `Real` to allow `Inf`.
        Exactly `min(Tₖ, max_iter)` iterations are performed; a non-integer cap is rounded
        down and a negative cap means no iteration at all.
    - `num_η::Integer=10`: the number of sampled `η` when calling sample_η_low.

- Output:
    - `yₖ::Vector{<:Real}`: the estimated y vector.
    - `zₖ::Vector{<:Real}`: the estimated z vector.
"""
function estimate_yz(
    ∇₂g::Function,
    ∇₂L::Function,
    sample_η_low::Function,
    x::Vector{<:Real},
    y::Vector{<:Real},
    ξ::Union{Real, AbstractArray},
    λₖ::Real,
    β₀::Real,
    μg::Real,
    Tₖ::Integer;
    max_iter::Real=Inf,
    num_η::Integer=10,
)
    # Number of inner iterations. The cap is applied to the iteration COUNT (not to the
    # last index), so a finite `max_iter` yields exactly `max_iter` iterations, and the
    # loop index stays an integer even when `max_iter` is `Inf`.
    n_iter = max_iter >= Tₖ ? Tₖ : floor(Int, max(max_iter, 0))

    # Initialize y and z. Both are only ever rebound below, never mutated in place, so
    # sharing the input array is safe.
    yₖ = y
    zₖ = yₖ

    # The inner loop.
    for t in 0:(n_iter - 1)
        # Compute the inner step size, where β₀ is the constant for O(1).
        βₜ = β₀ / (μg * (t + 1))

        # Sample η given ξ; the same sample is used for both updates of this iteration.
        ηₜ = sample_η_low(ξ; num=num_η)

        # Update yₖ.
        yₖ = yₖ - βₜ * ∇₂g(x, yₖ, ηₜ, ξ)

        # Update zₖ (uses the freshly updated yₖ).
        zₖ = zₖ - βₜ / λₖ * ∇₂L(x, zₖ, yₖ, λₖ, ηₜ, ξ)
    end

    return yₖ, zₖ
end

"""
    ∇F = estimate_∇F(∇₁L, xₖ, yₖ, zₖ, ηₖ, ξₖ, λₖ)

Estimate ∇F using step 15 in algorithm 1.

This is a thin wrapper that evaluates `∇₁L(xₖ, zₖ, yₖ, λₖ, ηₖ, ξₖ)`; note that the
argument order expected by `∇₁L` differs from the order of this function's arguments.
It is kept as a separate function so that it can be reused in performance evaluation.

# Arguments
- Input:
    - `∇₁L::Function`: a function that computes the gradient with respect to the first
        argument of L, called as `∇₁L(x, z, y, λ, η, ξ)`.
    - `xₖ::Vector{<:Real}`: the current `x` vector.
    - `yₖ::Vector{<:Real}`: the current `y` vector.
    - `zₖ::Vector{<:Real}`: the current `z` vector.
    - `ηₖ::AbstractVecOrMat`: the current `η`.
    - `ξₖ::Union{Real, AbstractArray}`: the current `ξ`.
    - `λₖ::Real`: the current `λ`.

- Output:
    - `∇F`: whatever `∇₁L` returns, i.e. the estimation of `∇F`.
"""
function estimate_∇F(
    ∇₁L::Function,
    xₖ::Vector{<:Real},
    yₖ::Vector{<:Real},
    zₖ::Vector{<:Real},
    ηₖ::AbstractVecOrMat,
    ξₖ::Union{Real, AbstractArray},
    λₖ::Real,
)
    # Reorder the arguments into the (x, z, y, λ, η, ξ) convention of ∇₁L.
    return ∇₁L(xₖ, zₖ, yₖ, λₖ, ηₖ, ξₖ)
end


"""
    x_path, times = bilevel_solver(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁, μg,
        x₀, α₀, β₀, K; dy=0, Pₓ=identity, num_η_up=10, num_η_low=10)

Solve CSBO using Algorithm 1.

# Arguments
- Input:
    - `f::Function`: the function inside the expectation of the objective function,
        called as `f(x, y, η, ξ)`.
    - `g::Function`: the function inside the expectation of the objective function
        of the lower level optimization problem, called as `g(x, y, η, ξ)`.
    - `sample_ξ::Function`: the function that samples ξ, called without arguments.
    - `sample_η_up::Function`: the function that samples η given ξ, it is used in
        the upper level optimization.
    - `sample_η_low::Function`: the function that samples η given ξ, it is used in
        the lower level optimization.
    - `lf₁::Real`: the Lipschitz constant of `f`.
    - `μg::Real`: the modulus of strong convexity of `g`.
    - `x₀::Vector{<:Real}`: the initial point of the algorithm.
    - `α₀::Real`: the constant for O(1) in the outer (x) stepsize,
        `αₖ = α₀ / sqrt(k + 1)`.
    - `β₀::Real`: the constant for O(1) in the inner (y, z) stepsize, forwarded to
        `estimate_yz`.
    - `K::Integer`: the number of upper loop.
    - `dy::Integer=0`: the dimension of y, it is 0 if x and y share the same shape.
    - `Pₓ::Function=identity`: the projection onto the upper constraint set.
    - `num_η_up::Integer=10`: the number of sampled η in each iteration in the
        upper loop.
    - `num_η_low::Integer=10`: the number of sampled η in each iteration in the
        lower loop.

- Output:
    - `x_path::Matrix{<:Real}`: the path of x, stored as a matrix, where k-th
        column is a point at the k-th iteration.
    - `times::Vector{Float64}`: the wall-clock duration (in seconds, from `time()`)
        of each outer iteration.

# Notes
    - Any variable (i.e., x, y, z) should be restricted to be a Vector{<:Real},
        and every function (e.g., f, g, L, h) should receive Vector{<:Real}.
    - Since η can be data, it should be `AbstractVecOrMat`, so `sample_η_up` and
        `sample_η_low` should return `AbstractVecOrMat`.
    - ξ may be a `Real` or an `AbstractArray`.
    - `sample_η_up` and `sample_η_low` should have a keyword argument named by
       `num::Integer`, which is the sample number. This is used for setting up
       mini-batch when performing SGD-type descent, and the two keyword arguments
       `num_η_up` and `num_η_low` will be used in these two functions.
"""
function bilevel_solver(
    f::Function,
    g::Function,
    sample_ξ::Function,
    sample_η_up::Function,
    sample_η_low::Function,
    lf₁::Real,
    μg::Real,
    x₀::Vector{<:Real},
    α₀::Real,
    β₀::Real,
    K::Integer;
    dy::Integer=0,
    Pₓ::Function=identity,
    num_η_up::Integer=10,
    num_η_low::Integer=10,
)
    # A vector storing all times.
    times = zeros(Float64, K)

    # Dimension of x.
    dx = length(x₀)
    # Dimension of y (0 means "same as x").
    dy = dy == 0 ? dx : dy

    #=
            Use Zygote to compute the gradient of f, g, and L.
    If the function does not involve the input variable, the corresponding
    part of the output of Zygote.gradient will be `nothing`.

    For example, if f(x, y, η, ξ) = ||y - η||_2^2, then the output will be
        (nothing, ∇_y f, ∇_η f, nothing).

    This is made by design, but may cause some type-related issues.
    Here, we add some trick to avoid these issues.

    Specifically, we call Zygote.gradient and name it by ∇, then check if it
    is nothing, if so, then return `zero(x)`, otherwise return ∇ itself.

    ξ is annotated as `Union{Real, AbstractArray}` so that these closures accept
    every ξ that `estimate_yz` and `estimate_∇F` accept.
    =#
    ∇₁f(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat,
        ξ::Union{Real, AbstractArray}) =
        (∇ = gradient(f, x, y, η, ξ)[1]) === nothing ? zero(x) : ∇

    ∇₂f(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat,
        ξ::Union{Real, AbstractArray}) =
        (∇ = gradient(f, x, y, η, ξ)[2]) === nothing ? zero(y) : ∇

    ∇₁g(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat,
        ξ::Union{Real, AbstractArray}) =
        (∇ = gradient(g, x, y, η, ξ)[1]) === nothing ? zero(x) : ∇

    ∇₂g(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat,
        ξ::Union{Real, AbstractArray}) =
        (∇ = gradient(g, x, y, η, ξ)[2]) === nothing ? zero(y) : ∇

    # ∇ₓ L(x, z, y, λ) = ∇ₓ f(x, z) + λ (∇ₓ g(x, z) - ∇ₓ g(x, y)).
    ∇₁L(x::Vector{<:Real}, z::Vector{<:Real}, y::Vector{<:Real}, λ::Real,
        η::AbstractVecOrMat, ξ::Union{Real, AbstractArray}) =
        ∇₁f(x, z, η, ξ) + λ * (∇₁g(x, z, η, ξ) - ∇₁g(x, y, η, ξ))

    # Gradient of L with respect to z; y does not enter this expression.
    ∇₂L(x::Vector{<:Real}, z::Vector{<:Real}, y::Vector{<:Real}, λ::Real,
        η::AbstractVecOrMat, ξ::Union{Real, AbstractArray}) =
        ∇₂f(x, z, η, ξ) + λ * ∇₂g(x, z, η, ξ)

    # The path of xₖ.
    x_path = zeros(dx, K)

    # Initialize the outer loop. xₖ is only rebound below, so x₀ is never mutated.
    xₖ = x₀
    yₖ = zeros(dy)

    # The outer loop.
    @showprogress desc="[alg1]" for k = 1:K
        # The start time of iteration k.
        start_time = time()

        # Set up the parameters for the inner loop.
        λₖ = 2 * lf₁ / μg * (k + 1)^(1 / 4)
        # α₀ is the constant for O(1).
        αₖ = α₀ / sqrt(k + 1)
        # The inner loop length grows linearly with the outer iteration counter.
        Tₖ = k

        # Sample ξ.
        ξₖ = sample_ξ()

        # Inner loop, estimating yₖ and zₖ. The returned yₖ also serves as the
        # initialization of the inner loop in the next outer iteration.
        yₖ, zₖ = estimate_yz(∇₂g, ∇₂L, sample_η_low, xₖ, yₖ, ξₖ, λₖ, β₀, μg, Tₖ;
                            num_η=num_η_low)

        # Sample η given ξₖ.
        ηₖ = sample_η_up(ξₖ; num=num_η_up)

        # Estimate ∇F.
        ∇F = estimate_∇F(∇₁L, xₖ, yₖ, zₖ, ηₖ, ξₖ, λₖ)

        # Update xₖ.
        xₖ = Pₓ(xₖ - αₖ * ∇F)

        # Save xₖ to the path.
        x_path[:, k] = xₖ

        # The time duration of iteration k.
        times[k] = time() - start_time
    end

    return x_path, times
end

"""
    x_path, times, nₖs = RT_MLMC_solver(f, g, sample_ξ, sample_η_up, sample_η_low, lf₁,
        μg, x₀, α₀, β₀, ϵ, K; a₁=0.05, nN=4, c₀=0.3, dy=0, Pₓ=identity, num_η_up=10,
        num_η_low=10, minibatch=10)

Solve CSBO using Algorithm 2.

# Arguments
- Input:
    - `f::Function`: the function inside the expectation of the objective function,
        called as `f(x, y, η, ξ)`.
    - `g::Function`: the function inside the expectation of the objective function
        of the lower level optimization problem, called as `g(x, y, η, ξ)`.
    - `sample_ξ::Function`: the function that samples ξ, called without arguments.
    - `sample_η_up::Function`: the function that samples η given ξ, it is used in
        the upper level optimization.
    - `sample_η_low::Function`: the function that samples η given ξ, it is used in
        the lower level optimization.
    - `lf₁::Real`: the Lipschitz constant of `f`.
    - `μg::Real`: the modulus of strong convexity of `g`.
    - `x₀::Vector{<:Real}`: the initial point of the algorithm.
    - `α₀::Real`: the constant for O(1) in the outer (x) stepsize, `α = α₀ * ϵ^4`.
    - `β₀::Real`: the constant for O(1) in the inner (y, z) stepsize,
        `βₜ = β₀ / (μg * (t + 1))`.
    - `ϵ::Real`: the tolerance.
    - `K::Integer`: the number of upper loop.
    - `a₁::Real=0.05`: the constant to scale the stepsize; it is applied when the level
        nₖ satisfies `nₖ >= c₀ * N`.
    - `nN::Integer=4`: the parameter to determine N, where N = nN * log(1/ϵ), rounded up.
    - `c₀::Real=0.3`: the constant to determine the nₖ to truncate stepsize.
    - `dy::Integer=0`: the dimension of y, it is 0 if x and y share the same shape.
    - `Pₓ::Function=identity`: the projection onto the upper constraint set.
    - `num_η_up::Integer=10`: the number of sampled η in each iteration in the
        upper loop.
    - `num_η_low::Integer=10`: the number of sampled η in each iteration in the
        lower loop.
    - `minibatch::Integer=10`: the number of minibatch of ξ in each iteration.
        That is, in each iteration, we sample `minibatch` ξ's, and solve each
        corresponding inner loop individually, then take the average of the
        obtained estimated gradients to reduce variance.

- Output:
    - `x_path::Matrix{<:Real}`: the path of x, stored as a matrix, where k-th
        column is a point at the k-th iteration.
    - `times::Vector{Float64}`: the wall-clock duration (in seconds, from `time()`)
        of each outer iteration.
    - `nₖs::Vector`: the levels drawn by `drawtrunGeoDist(N)`, one per outer iteration
        (as drawn, i.e. before the `nₖ == 0` adjustment described below).

# Notes
    - Any variable (i.e., x, y, z) should be restricted to be a Vector{<:Real},
        and every function (e.g., f, g, L, h) should receive Vector{<:Real}.
    - Since η can be data, it should be `AbstractVecOrMat`, so `sample_η_up` and
        `sample_η_low` should return `AbstractVecOrMat`.
    - ξ may be a `Real` or an `AbstractArray`.
    - `sample_η_up` and `sample_η_low` should have a keyword argument named by
       `num::Integer`, which is the sample number. This is used for setting up
       mini-batch when performing SGD-type descent, and the two keyword arguments
       `num_η_up` and `num_η_low` will be used in these two functions.
    - The minibatch loop runs under `Threads.@threads`, so `f`, `g` and the three
       samplers are called concurrently from several tasks.
"""
function RT_MLMC_solver(
    f::Function,
    g::Function,
    sample_ξ::Function,
    sample_η_up::Function,
    sample_η_low::Function,
    lf₁::Real,
    μg::Real,
    x₀::Vector{<:Real},
    α₀::Real,
    β₀::Real,
    ϵ::Real,
    K::Integer;
    a₁::Real=0.05,
    nN::Integer=4,
    c₀::Real=0.3,
    dy::Integer=0,
    Pₓ::Function=identity,
    num_η_up::Integer=10,
    num_η_low::Integer=10,
    minibatch::Integer=10,
)
    # A vector storing all times.
    times = zeros(Float64, K)

    # Dimension of x.
    dx = length(x₀)

    # Dimension of y (0 means "same as x").
    dy = dy == 0 ? dx : dy

    #=
            Use Zygote to compute the gradient of f, g, and L.
    If the function does not involve the input variable, the corresponding
    part of the output of Zygote.gradient will be `nothing`.

    For example, if f(x, y, η, ξ) = ||y - η||_2^2, then the output will be
        (nothing, ∇_y f, ∇_η f, nothing).

    This is made by design, but may cause some type-related issues.
    Here, we add some trick to avoid these issues.

    Specifically, we call Zygote.gradient and name it by ∇, then check if it
    is nothing, if so, then return `zero(x)`, otherwise return ∇ itself.

    ξ is annotated as `Union{Real, AbstractArray}` so that an array-valued ξ
    returned by `sample_ξ` is accepted as well.
    =#
    ∇₁f(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat,
        ξ::Union{Real, AbstractArray}) =
        (∇ = gradient(f, x, y, η, ξ)[1]) === nothing ? zero(x) : ∇

    ∇₂f(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat,
        ξ::Union{Real, AbstractArray}) =
        (∇ = gradient(f, x, y, η, ξ)[2]) === nothing ? zero(y) : ∇

    ∇₁g(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat,
        ξ::Union{Real, AbstractArray}) =
        (∇ = gradient(g, x, y, η, ξ)[1]) === nothing ? zero(x) : ∇

    ∇₂g(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat,
        ξ::Union{Real, AbstractArray}) =
        (∇ = gradient(g, x, y, η, ξ)[2]) === nothing ? zero(y) : ∇

    # ∇ₓ L(x, z, y, λ) = ∇ₓ f(x, z) + λ (∇ₓ g(x, z) - ∇ₓ g(x, y)).
    ∇₁L(x::Vector{<:Real}, z::Vector{<:Real}, y::Vector{<:Real}, λ::Real,
        η::AbstractVecOrMat, ξ::Union{Real, AbstractArray}) =
        ∇₁f(x, z, η, ξ) + λ * (∇₁g(x, z, η, ξ) - ∇₁g(x, y, η, ξ))

    # Gradient of L with respect to z; y does not enter this expression.
    ∇₂L(x::Vector{<:Real}, z::Vector{<:Real}, y::Vector{<:Real}, λ::Real,
        η::AbstractVecOrMat, ξ::Union{Real, AbstractArray}) =
        ∇₂f(x, z, η, ξ) + λ * ∇₂g(x, z, η, ξ)

    # The path of xₖ.
    x_path = zeros(dx, K)

    # Initialize the outer loop. xₖ is only rebound below, so x₀ is never mutated.
    xₖ = x₀
    yₖ = zeros(dy)

    # The N and α.
    N = nN * log(1/ϵ) |> ceil |> Int
    # α₀ is the O(1) part.
    α = α₀ * ϵ ^ 4

    # λ₀
    λ₀ = 2 * lf₁ / μg

    # Minibatch of ξ.
    # - The estimated gradients, one column per minibatch member.
    ∇s = zeros(dx, minibatch)

    # Draw all levels up front from the truncated geometric distribution (give-up
    # strategy, see `drawtrunGeoDist`); they are returned to the caller unchanged.
    nₖs = [drawtrunGeoDist(N) for _ in 1:K]

    # The outer loop.
    @showprogress desc="[alg2]" for k = 1:K
        # The start time of iteration k.
        start_time = time()

        # The level of this iteration.
        nₖ = nₖs[k]
        # Compute the corresponding pₖ.
        pₖ = 2.0 ^ (-nₖ)

        # Compute λₖ.
        λₙₖ = 2 * lf₁ / μg * 2^(nₖ / 4)
        λₙₖ₋₁ = 2 * lf₁ / μg * 2^((nₖ - 1) / 4)

        # If nₖ == 0, theoretically speaking, we do not perform any updates,
        # but in practice, we perform 4 steps (2^2). Note that pₖ, λₙₖ and λₙₖ₋₁
        # above were computed from the drawn value 0, and that the stepsize test
        # `nₖ >= c₀ * N` below sees the adjusted value.
        if nₖ == 0
            nₖ = 2
        end

        # Initialize the inner loop; every minibatch member owns one column.
        # yₖ is yₖ²ⁿᵏ
        yₖs = repeat(yₖ, 1, minibatch)
        # zₙₖ is zₖ²ⁿᵏ(λₙₖ)
        zₙₖ = yₖs |> copy
        # zₙₖ₋₁ is zₖ²ⁿᵏ(λₙₖ₋₁)
        zₙₖ₋₁ = yₖs |> copy
        # Declare the two intermediate variables need to be saved.
        yₖ²⁽ⁿᵏ⁻¹⁾ = yₖs |> copy
        # zₙₖ₋₁²⁽ⁿᵏ⁻¹⁾ is zₖ²⁽ⁿᵏ⁻¹⁾(λₙₖ₋₁)
        zₙₖ₋₁²⁽ⁿᵏ⁻¹⁾ = zₙₖ₋₁ |> copy
        # yₖ⁰ is the initial value of yₖ in iteration k.
        yₖ⁰ = yₖs |> copy
        # zₖ⁰ is zₖ⁰(λ₀), which is the initial value of zₖ in iteration k.
        zₖ⁰ = zₙₖ |> copy

        # Run in parallel the minibatch steps. Iteration i only writes column i of
        # the shared matrices.
        Threads.@threads for i = 1:minibatch
            # Sample ξₖ.
            ξₖ = sample_ξ()

            # The inner loop.
            for t = 0:2^nₖ-1
                # Save yₖ²⁽ⁿᵏ⁻¹⁾ and zₙₖ₋₁²⁽ⁿᵏ⁻¹⁾, i.e. the state after the first
                # 2^(nₖ - 1) steps.
                if t == 2^(nₖ - 1)
                    yₖ²⁽ⁿᵏ⁻¹⁾[:, i] = yₖs[:, i]
                    zₙₖ₋₁²⁽ⁿᵏ⁻¹⁾[:, i] = zₙₖ₋₁[:, i]
                end

                # Compute the inner step size, where β₀ is the constant for O(1).
                βₜ = β₀ / (μg * (t + 1))

                # Sample η given ξₖ.
                ηₜ = sample_η_low(ξₖ; num=num_η_low)

                # Update yₖ. (Slices are copied on purpose: the gradient closures
                # only accept `Vector` arguments.)
                yₖs[:, i] = yₖs[:, i] - βₜ * ∇₂g(xₖ, yₖs[:, i], ηₜ, ξₖ)

                # Update zₙₖ.
                zₙₖ[:, i] = zₙₖ[:, i] - βₜ / λₙₖ * ∇₂L(xₖ, zₙₖ[:, i], yₖs[:, i], λₙₖ, ηₜ, ξₖ)

                # Update zₙₖ₋₁.
                zₙₖ₋₁[:, i] = zₙₖ₋₁[:, i] - βₜ / λₙₖ₋₁ * ∇₂L(xₖ, zₙₖ₋₁[:, i], yₖs[:, i], λₙₖ₋₁, ηₜ, ξₖ)
            end

            # Sample η given ξₖ.
            ηₖ = sample_η_up(ξₖ; num=num_η_up)

            # uₖ⁰ is uₖ(0, λ₀).
            uₖ⁰ = ∇₁L(xₖ, zₖ⁰[:, i], yₖ⁰[:, i], λ₀, ηₖ, ξₖ)

            # Compute uₖ's.
            # uₖⁿᵏ is uₖ(nₖ, λₙₖ).
            uₖⁿᵏ = ∇₁L(xₖ, zₙₖ[:, i], yₖs[:, i], λₙₖ, ηₖ, ξₖ)
            # uₖⁿᵏ⁻¹ is uₖ(nₖ - 1, λₙₖ₋₁).
            uₖⁿᵏ⁻¹ = ∇₁L(xₖ, zₙₖ₋₁²⁽ⁿᵏ⁻¹⁾[:, i], yₖ²⁽ⁿᵏ⁻¹⁾[:, i], λₙₖ₋₁, ηₖ, ξₖ)

            # Compute estimated gradient.
            ∇s[:, i] = uₖ⁰ + (uₖⁿᵏ - uₖⁿᵏ⁻¹) / pₖ
        end

        # Update xₖ using the averaged estimated gradient; the stepsize is scaled by
        # a₁ for the high levels.
        ∇mean = mean(∇s; dims=2)
        stepsize = nₖ >= c₀ * N ? a₁ * α : α
        xₖ = Pₓ(vec(xₖ - stepsize * ∇mean))

        # Compute the mean of yₖs as the initial yₖ in the next iteration.
        yₖ = mean(yₖs, dims=2) |> vec

        # Save xₖ to the path.
        x_path[:, k] = xₖ

        # The time duration of iteration k.
        times[k] = time() - start_time
    end

    return x_path, times, nₖs
end

"""
    y₁⁰, yₖ⁰, yₖ₊₁⁰ = EpochSGD(∇₂g, sample_η_low, ξ, x, y, β, K; num_η=10)

Run `K` epochs of stochastic gradient steps on `y`, where epoch `k` performs `2^k` steps.

In epoch `k`, the iterate restarts from the weighted sum produced by the previous epoch
(or from the input `y` for `k = 1`). The `j`-th step (counting from 1) uses the step size
`β / j` and a fresh sample `sample_η_low(ξ; num=num_η)`. The epoch's weighted sum starts
at `2^(-k)` times the last iterate of the previous epoch and then accumulates `2^(-k)`
times every new iterate.

# Arguments
- Input:
    - `∇₂g::Function`: gradient oracle, called as `∇₂g(x, y, η, ξ)`.
    - `sample_η_low::Function`: sampler of η given ξ, with keyword argument `num`.
    - `ξ::Union{Real, AbstractArray}`: the random variable ξ.
    - `x::Vector{<:Real}`: the fixed x vector.
    - `y::Vector{<:Real}`: the initial y vector (not modified).
    - `β::Real`: the constant of the step size.
    - `K::Integer`: the number of epochs.
    - `num_η::Integer=10`: the number of sampled η per step.

- Output:
    - `y₁⁰`: a copy of the initial `y`.
    - `yₖ⁰`: the weighted sum after epoch `K - 1`, or the initial `y` if `K ≤ 1`.
    - `yₖ₊₁⁰`: the weighted sum after epoch `K`, or the initial `y` if `K ≤ 0`.
"""
function EpochSGD(
    ∇₂g::Function,
    sample_η_low::Function,
    ξ::Union{Real, AbstractArray},
    x::Vector{<:Real},
    y::Vector{<:Real},
    β::Real,
    K::Integer;
    num_η::Integer=10,
)
    # Snapshot of the starting point.
    start = deepcopy(y)
    # Weighted sum produced by the most recent epoch, and the one produced by epoch K - 1.
    acc = y
    acc_before_last = y

    for k in 1:K
        weight = 2.0^(-k)

        # The right-hand side is evaluated first: the new iterate is the previous epoch's
        # weighted sum, and the new sum is seeded with the previous epoch's last iterate.
        y, acc = acc, weight * y

        for step in 1:2^k
            η = sample_η_low(ξ; num=num_η)
            y = y - β / step * ∇₂g(x, y, η, ξ)
            acc = acc + weight * y
        end

        # Remember the output of the second-to-last epoch.
        if k + 1 == K
            acc_before_last = acc
        end
    end

    return start, acc_before_last, acc
end

"""
    v = Hes⁻¹Vec(g, x, y, r₀, ∇₁f_, ξ, η′, ηₙ, N, Lg₁)

Return `∇₁f_ - c`, where `c` is assembled from Hessian-vector and Jacobian-vector products
of `g`, without forming any Hessian explicitly.

Starting from `r = r₀`, the update `r ← r - (1 / Lg₁) * (Hyy * r)` is applied once for
every sample in `ηₙ`, where `Hyy` is the `HesVecGrad` operator built on the gradient of
`g` with respect to its second argument at `y`, evaluated with the current sample. Then
`c = Jxy * ((N / Lg₁) * r)`, where `Jxy` is the `JacVec` operator of the map
`y_ -> ∇ₓ g(x, y_, η′, ξ)` at `y`.

# Arguments
- Input:
    - `g::Function`: the lower-level function, called as `g(x, y, η, ξ)`.
    - `x::Vector{<:Real}`: the current x vector.
    - `y::Vector{<:Real}`: the point at which the operators are built.
    - `r₀::Vector{<:Real}`: the initial vector of the recursion. It is NOT modified; the
        recursion works on a copy so the caller can reuse `r₀` across calls.
    - `∇₁f_::Vector{<:Real}`: the vector from which `c` is subtracted.
    - `ξ::Union{Real, AbstractArray}`: the random variable ξ.
    - `η′::AbstractVecOrMat`: the sample used in the Jacobian-vector product. It is stored
        in the same `Ref` as the elements of `ηₙ`, so it must be convertible to the type
        of `ηₙ[1]`.
    - `ηₙ::Vector`: the samples used in the Hessian-vector products; must be non-empty
        because `ηₙ[1]` is read unconditionally.
    - `N::Integer`: the scaling constant in `r′ = (N / Lg₁) * r`.
    - `Lg₁::Real`: the constant used as `1 / Lg₁` step in the recursion.

- Output:
    - `∇₁f_ - c`.
"""
function Hes⁻¹Vec(
    g::Function,
    x::Vector{<:Real},
    y::Vector{<:Real},
    r₀::Vector{<:Real},
    ∇₁f_::Vector{<:Real},
    ξ::Union{Real, AbstractArray},
    η′::AbstractVecOrMat,
    ηₙ::Vector,
    N::Integer,
    Lg₁::Real,
)
    # Obtain the number of η's in ηₙ, i.e., the N̂.
    N̂ = length(ηₙ)
    # Declare two reference variables, to be replaced in the Jacobian/Hessian-vector
    # multiplication. The closures below read them at call time, so swapping the content
    # of `η_ref` changes the sample used by operators that were already constructed.
    ξ_ref = Ref(ξ)
    η_ref = Ref(ηₙ[1])

    # Get the gradient of g with respect to z.
    # Notice that here the η and ξ are replaced by the reference variables.
    gy(y_) = gradient(z -> g(x, z, η_ref[], ξ_ref[]), y_)[1]
    # Get the gradient of g with respect to x, but the input is y.
    # This used to get the Hessian of g with respect to x and y.
    gx_of_y(y_) = Zygote.gradient(x_ -> g(x_, y_, η_ref[], ξ_ref[]), x)[1]
    # Get the Jacobian-vector multiplication operator.
    Jxy = JacVec(gx_of_y, y)

    # Get the Hessian-vector multiplication operator.
    Hyy = HesVecGrad(gy, y)

    # Initialize the r. Work on a copy: the in-place update below must not leak into the
    # caller's `r₀`, which may be reused for further calls.
    r = copy(r₀)

    # For each η in ηₙ.
    for n in 1:N̂
        # Replace the reference variable, this means we don't need to redeclare
        # the gradient, the Jacobian-vector, and the Hessian-vector multiplication
        # operator, but reuse them.
        η_ref[] = ηₙ[n]
        # Use the Hessian-vector multiplication operator to get vn.
        vn = Hyy * r
        # Update r in place.
        @. r = r - (1 / Lg₁) * vn
    end

    # Compute r′.
    rprime = (N / Lg₁) * r
    # Update reference of η.
    η_ref[] = η′
    # Compute c using the Jacobian-vector multiplication operator.
    c = Jxy * rprime

    # return.
    return ∇₁f_ - c
end

"""
    v̂, k̂ = ∇RTMLMC(f, g, sample_ξ, sample_η_up, sample_η_low, x, y; β=1, N=10,
        num_η_up=10, num_η_low=10, Lg₁=10, minibatch=1)

Compute the RT-MLMC gradient estimator based on Hessian-vector products, averaged over a
minibatch of ξ.

One level `k̂ = drawtrunGeoDist(N)` and one length `N̂ ∈ {1, ..., N}` are drawn per call
and shared by all minibatch members. For each member, `EpochSGD` is run for `k̂` epochs
and `Hes⁻¹Vec` is evaluated at its three outputs, giving `v̂¹`, `v̂ᵏ`, `v̂ᵏ⁺¹`; the member's
estimate is `v̂¹ + (v̂ᵏ⁺¹ - v̂ᵏ) / pₖ` with `pₖ = 2^(-k̂)`.

# Arguments
- Input:
    - `f::Function`: the upper-level function, called as `f(x, y, η, ξ)`.
    - `g::Function`: the lower-level function, called as `g(x, y, η, ξ)`.
    - `sample_ξ::Function`: the function that samples ξ, called without arguments.
    - `sample_η_up::Function`: sampler of η given ξ with keyword `num`; used here for the
        samples `ηₙ` passed to `Hes⁻¹Vec`.
    - `sample_η_low::Function`: sampler of η given ξ with keyword `num`; used here for
        `η′`, `η′′` and inside `EpochSGD`.
    - `x::Vector{<:Real}`: the current x vector.
    - `y::Vector{<:Real}`: the initial y vector for `EpochSGD`.
    - `β::Real=1`: the step size constant forwarded to `EpochSGD`.
    - `N::Integer=10`: the upper bound for both `k̂` and `N̂`.
    - `num_η_up::Integer=10`: the number of sampled η per call of `sample_η_up`.
    - `num_η_low::Integer=10`: the number of sampled η per call of `sample_η_low`.
    - `Lg₁::Real=10`: the constant forwarded to `Hes⁻¹Vec`.
    - `minibatch::Integer=1`: the number of ξ samples whose estimates are averaged.

- Output:
    - `v̂::Vector`: the averaged gradient estimate, of length `length(x)`.
    - `k̂`: the level drawn for this call.

# Notes
    - ξ may be a `Real` or an `AbstractArray`.
    - The minibatch loop runs under `Threads.@threads`, so `f`, `g` and the samplers are
       called concurrently from several tasks.
"""
function ∇RTMLMC(
    f::Function,
    g::Function,
    sample_ξ::Function,
    sample_η_up::Function,
    sample_η_low::Function,
    x::Vector{<:Real},
    y::Vector{<:Real};
    β::Real=1,
    N::Integer=10,
    num_η_up::Integer=10,
    num_η_low::Integer=10,
    Lg₁::Real=10,
    minibatch::Integer=1,
)
    # The function computing gradient of g with respect to the second argument.
    # ξ is annotated as `Union{Real, AbstractArray}` to match what `EpochSGD` accepts.
    ∇₂g(x::Vector{<:Real}, y::Vector{<:Real}, η::AbstractVecOrMat,
        ξ::Union{Real, AbstractArray}) =
        (∇ = gradient(g, x, y, η, ξ)[2]) === nothing ? zero(y) : ∇

    # Sample N̂ uniformly from {1, ..., N}.
    N̂ = rand(1:N)
    # The estimated gradients, one column per minibatch member.
    v̂ = zeros(length(x), minibatch)

    # Sample k̂ from truncated geometric distribution.
    k̂ = drawtrunGeoDist(N)
    # Compute the corresponding probability.
    pₖ = 2. ^ (-k̂)

    # Iteration i only writes column i of v̂.
    Threads.@threads for i = 1:minibatch
        # Sample ξ from P(ξ).
        ξ = sample_ξ()

        # Sample η′, η′′ with the lower-level sampler. η′′ is used in the gradients of f
        # below and η′ in the Jacobian-vector product inside Hes⁻¹Vec.
        # NOTE(review): f is the upper-level function while these samples come from
        # `sample_η_low` (and ηₙ, used for the Hessian of g, from `sample_η_up`) —
        # confirm this pairing is intended.
        η′ = sample_η_low(ξ; num=num_η_low)
        η′′ = sample_η_low(ξ; num=num_η_low)
        # Sample the N̂ η's used in the Hessian-vector products with the upper-level
        # sampler.
        ηₙ = [sample_η_up(ξ; num=num_η_up) for _ = 1:N̂]

        # Run EpochSGD to obtain y₁⁰, yₖ⁰, yₖ₊₁⁰, with the number of epochs being k̂.
        y₁⁰, yₖ⁰, yₖ₊₁⁰ = EpochSGD(∇₂g, sample_η_low, ξ, x, y, β, k̂; num_η=num_η_low)

        # Compute r₀ used in Hes⁻¹Vec (gradient of f in its second argument at the
        # initial y; zero if f does not depend on it).
        r₀ = (temp = gradient(f, x, y, η′′, ξ)[2]) === nothing ? zero(y) : temp |> Vector
        # Compute the first term in (6).
        ∇₁f¹ = (temp = gradient(f, x, y₁⁰, η′′, ξ)[1]) === nothing ? zero(x) : temp
        ∇₁fᵏ = (temp = gradient(f, x, yₖ⁰, η′′, ξ)[1]) === nothing ? zero(x) : temp
        ∇₁fᵏ⁺¹ = (temp = gradient(f, x, yₖ₊₁⁰, η′′, ξ)[1]) === nothing ? zero(x) : temp

        # Use Hessian-Vector products to compute v̂¹, v̂ᵏ, v̂ᵏ⁺¹.
        # Each call receives its own copy of r₀: the routine runs an in-place recursion
        # on that vector, and all three estimates must start from the same r₀.
        v̂¹ = Hes⁻¹Vec(g, x, y₁⁰, copy(r₀), ∇₁f¹, ξ, η′, ηₙ, N, Lg₁)
        v̂ᵏ = Hes⁻¹Vec(g, x, yₖ⁰, copy(r₀), ∇₁fᵏ, ξ, η′, ηₙ, N, Lg₁)
        v̂ᵏ⁺¹ = Hes⁻¹Vec(g, x, yₖ₊₁⁰, copy(r₀), ∇₁fᵏ⁺¹, ξ, η′, ηₙ, N, Lg₁)

        # Compute v̂, i.e., the RT-MLMC gradient estimator.
        v̂[:, i] = v̂¹ + (v̂ᵏ⁺¹ - v̂ᵏ) / pₖ
    end

    return mean(v̂; dims=2) |> vec, k̂
end

"""
    x, times, nₖs = RT_MLMC_Hessian_solver(f, g, sample_ξ, sample_η_up, sample_η_low,
        Lg₁, x₀, αₜ, β₀, K, N; num_η_up=10, num_η_low=10, minibatch=1)

Run `K` iterations driven by the gradient estimator `∇RTMLMC`.

At iteration `k` the estimator is evaluated with the current iterate passed as both its
`x` and its `y` argument, and the iterate is updated as `xₖ ← xₖ + αₜ(k) * v`.

# Arguments
- Input:
    - `f::Function`, `g::Function`: the upper- and lower-level functions, forwarded to
        `∇RTMLMC`.
    - `sample_ξ::Function`, `sample_η_up::Function`, `sample_η_low::Function`: the
        samplers, forwarded to `∇RTMLMC`.
    - `Lg₁::Real`: forwarded to `∇RTMLMC` as `Lg₁`.
    - `x₀::Vector{<:Real}`: the initial point (not modified).
    - `αₜ::Function`: the step size schedule, called as `αₜ(k)`. The step is ADDED to the
        iterate, so the sign of the step has to be chosen by the schedule.
    - `β₀::Real`: forwarded to `∇RTMLMC` as `β`.
    - `K::Integer`: the number of iterations.
    - `N::Integer`: forwarded to `∇RTMLMC` as `N`.
    - `num_η_up::Integer=10`, `num_η_low::Integer=10`, `minibatch::Integer=1`: forwarded
        to `∇RTMLMC`.

- Output:
    - `x::Matrix{Float64}`: the path of x, the k-th column is the iterate after
        iteration k.
    - `times::Vector{Float64}`: the wall-clock duration (in seconds, from `time()`) of
        each iteration.
    - `nₖs::Vector{Float64}`: the second return value of `∇RTMLMC` at each iteration.
"""
function RT_MLMC_Hessian_solver(
    f::Function,
    g::Function,
    sample_ξ::Function,
    sample_η_up::Function,
    sample_η_low::Function,
    Lg₁::Real,
    x₀::Vector{<:Real},
    αₜ::Function,
    β₀::Real,
    K::Integer,
    N::Integer;
    num_η_up::Integer=10,
    num_η_low::Integer=10,
    minibatch::Integer=1,
)
    # Storage for the iterates (one column per iteration), the per-iteration wall-clock
    # durations, and the levels reported by the estimator.
    iterates = zeros(length(x₀), K)
    elapsed = zeros(Float64, K)
    levels = zeros(K)

    # The running iterate; it is rebound (never mutated) below.
    current = x₀

    @showprogress desc="[Hess]" for k in 1:K
        tic = time()

        # Estimate the gradient; the current iterate serves as both x and y.
        direction, level = ∇RTMLMC(
            f, g, sample_ξ, sample_η_up, sample_η_low, current, current;
            N=N, num_η_up=num_η_up, num_η_low=num_η_low, Lg₁=Lg₁, β=β₀,
            minibatch=minibatch,
        )
        levels[k] = level

        # NOTE(review): the step is added, not subtracted — this is a descent step only
        # if αₜ(k) is negative; confirm against the callers.
        current = current + αₜ(k) * direction

        iterates[:, k] = current
        elapsed[k] = time() - tic
    end

    return iterates, elapsed, levels
end