# For all adaptive methods, the step-size parameter is used only to compute an initial forward step; the resulting point is then used to estimate the local Lipschitz constant.
using LinearAlgebra: norm

include("utils.jl")


"""
    eg(VI, params, cb; γ=1., adaptive=false, universal=false, anchoring=missing,
       ada_heuristic=false)

Extra-gradient method for the variational inequality given by `VI.F` with projection
`VI.Π`, started at `params.z0`, with optional adaptive step sizes.

Each iteration performs an extrapolation step `u` and an update step `z_new`, both
optionally pulled towards `z0` by the coefficient returned from `AnchCoeff`. The loop
counter `k` is advanced by 2 per iteration and the loop runs while
`k <= params.n_grad_eval`.

# Keywords
- `γ`: scaling applied to the step size in the update step (not in the extrapolation).
- `adaptive`: if `true`, the initial step size `1/params.L` is replaced by
  `norm(z-u)/norm(F(z)-F(u))` for one forward step `u`, and afterwards the step size is
  the running minimum of `norm(u-z)/norm(Fu-Fz)`.
- `ada_heuristic`: with `adaptive`, additionally bounds the step size by
  `norm(u-z_new)/norm(Fu-F(z_new))` (costs one extra evaluation of `F`).
- `universal`: if `true`, the step size is set to `1/sqrt(1 + Σ norm(Fz-Fu)^2)`.
- `anchoring`: forwarded to `AnchCoeff(type=...)`.

The callback is invoked once before the loop as `cb(VI, z, lr, F(z), γ, 0.)` and once
per iteration as `cb(VI, u, lr, Fu, γ, k)`.

# Returns
The last iterate `z`.
"""
function eg(VI::AbstractVI, params::ProblemParams, cb::Callback; γ::Float64=1., adaptive=false, universal=false, anchoring=missing, ada_heuristic=false)
    F, Π = VI.F, VI.Π
    z = z0 = params.z0
    lr = 1/params.L

    if adaptive
        # more practical way to get a good initial learning rate
        u = Π(z -  lr * F(z))
        lr_init = norm(z-u)/ norm(F(z)-F(u))
        # A degenerate estimate (0/0 -> NaN, x/0 -> Inf) would poison every later
        # iterate, so fall back to 1/L in that case.
        if isfinite(lr_init)
            lr = lr_init
        end
    end
    sumofsquares = 0.0  # float accumulator: keeps the variable type-stable
    cb(VI, z, lr, F(z), γ, 0.)
    get_anch_coeff = AnchCoeff(type=anchoring, VI=VI, normalizing_factor=norm(F(z0))^2)

    k = 0
    while k <= params.n_grad_eval
        # extrapolation step
        Fz = F(z)
        anch_coeff = get_anch_coeff(Fz, z)
        u = Π(z - lr*Fz + anch_coeff * (z0-z + lr*Fz))

        # update step
        Fu = F(u)
        z_new = Π(z - γ*lr*Fu + anch_coeff * (z0-z))

        # update stepsize
        if adaptive
            # `min(lr, NaN)` is NaN, so a 0/0 estimate (u == z) must be skipped;
            # an Inf estimate is harmless because `min(lr, Inf) == lr`.
            lr_est = norm(u-z)/norm(Fu-Fz)
            if !isnan(lr_est)
                lr = min(lr, lr_est)
            end
            if ada_heuristic
                lr_est = norm(u-z_new)/norm(Fu-F(z_new))
                if !isnan(lr_est)
                    lr = min(lr, lr_est)
                end
            end
        end
        sumofsquares += norm(Fz-Fu)^2
        if universal; lr = 1/sqrt(1+sumofsquares) end

        z = z_new
        k+=2
        cb(VI, u, lr, Fu, γ, k)
    end

    return z
end

"""
    Alg1(VI, params, cb; η, ρ, coef)

Averaged iteration `x = (1 - αk) * x + αk * J̃`, where `J̃` is the mean of `Mk`
independent `mlmc_fbf` estimates started at the current iterate `x`, using `VI.Π` as
the backward map and `z -> z + η * VI.F(z)` as the forward operator.

# Keywords
- `η`: scaling of `F` inside the forward operator; also sets `LB = 1 + η*L` and
  `mu = 1 - η*L` (with `L = params.L`), which are passed on to `mlmc_fbf`.
- `ρ`: enters only through `α = 1 - ρ/η`.
- `coef`: multiplier in `αk = coef * α / (sqrt(k+2) * log(k+3))`.

`mu` and `α` are replaced by `1e-3` when they are exactly zero to avoid division by
zero. The callback is invoked once per counted evaluation (as reported by `mlmc_fbf`)
with the evaluation counter as last argument; the function returns as soon as that
counter would exceed `params.n_grad_eval`.

# Returns
The current iterate `x` when the evaluation budget is exhausted.
"""
function Alg1(VI::AbstractVI, params::ProblemParams, cb::Callback; η::Float64, ρ::Float64, coef::Float64)
    F, Π = VI.F, VI.Π
    L = params.L
    α = 1 - ρ/η

    JτA = z -> Π(z)
    B = z -> z + η*F(z)  # loop-invariant, so built once
    x = params.z0
    eval_count = 0
    k = 0

    # Note: the initial callback reports α before the zero-guard below is applied.
    cb(VI, x, η, F(x), α, 0)

    LB = 1 + η*L
    mu = 1 - η*L

    mu = (mu == 0) ? 1e-3 : mu
    α  = (α  == 0) ? 1e-3 : α

    while eval_count < params.n_grad_eval
        αk = coef * α / (sqrt(k+2) * log(k+3))

        # Alg.6: N_k = ceil( 96/(1-ηL)^2  /  min{ α_k/(120 α (k+1)), 1/120 } )
        # (the constants 96 and 120 are replaced by 1 in this implementation)
        Nk = ceil(Int, 1 / (mu^2) / min(αk/(α*(k+1)), 1.0))

        # Alg.6: M_k = ceil( 672*120*log2(N_k) / (1-ηL)^2 )
        # (the constant 672*120 = 80640 is replaced by 0.1 in this implementation)
        # `log(2, 1) == 0` would give Mk == 0: the average below would be 0/0 and no
        # evaluations would be counted, so the outer loop could never finish.
        # Always draw at least one sample.
        Mk = max(1, ceil(Int, 0.1*log(2, Nk) / (mu^2)))

        total = zero.(x)
        count = 0
        for _ in 1:Mk
            a, c = mlmc_fbf(x, Nk, JτA, B, LB, mu)
            total .+= a
            count += c
        end
        J̃ = total ./ Mk

        x = (1-αk)*x + αk*J̃

        # Report the new iterate once per evaluation consumed in this outer step.
        Fx = F(x)
        for n in 1:count
            count_total = n+eval_count
            if  count_total > params.n_grad_eval
                return x
            end
            cb(VI, x, η, Fx, αk, count_total)
        end
        eval_count += count
        k += 1
    end
    return x
end

"""
    fbf(z0, N, A, Bin, LB, μ)

Run `N` steps of a forward-backward-forward style iteration started at `z0` and return
the final point.

The operator used is `w -> Bin(w) - z0`. In step `s = 1, …, N` the step size is
`2 / (s*μ + 6*LB)`; the first operator evaluation of each step is perturbed by standard
normal noise (`randn`), `A` is applied to the forward point, and the result is
corrected with a second, unperturbed operator evaluation. With `N == 0` the input `z0`
is returned unchanged.
"""
function fbf(z0, N::Int, A::Function, Bin::Function, LB::Float64, μ::Float64)
    shifted(w) = Bin(w) - z0

    point = z0
    for s in 1:N
        stepsize = 2 / (s*μ + 6*LB)

        # noisy forward evaluation at the current point
        noisy = shifted(point) + randn(eltype(point), size(point))

        # backward step followed by the correcting forward step
        mid = A(point - stepsize*noisy)
        point = mid + stepsize*(noisy - shifted(mid))
    end
    return point
end

"""
    mlmc_fbf(z0, N, A, B, LB, μ)

Multilevel-style randomized estimate built from `fbf` runs of different lengths.

A level `I` is drawn with `geom_half()` and a one-step run `y0 = fbf(z0, 1, …)` is
always computed. If `2^I > N`, the result is `(y0, 2)`. Otherwise two further
independent runs with `2^(I-1)` and `2^I` steps are computed and the result is
`(y0 + 2^I * (yI - yIm1), 3 * 2^I + 2)`.

# Returns
A tuple `(estimate, count)`; `count` is the evaluation count charged by the caller.
"""
function mlmc_fbf(z0, N::Int, A::Function, B::Function, LB::Float64, μ::Float64)
    level = geom_half()
    base = fbf(z0, 1, A, B, LB, μ)

    fine_steps = 1 << level
    if fine_steps > N
        # level too deep for the step cap: fall back to the one-step run only
        return (base, 2)
    end

    coarse = fbf(z0, 1 << (level - 1), A, B, LB, μ)
    fine = fbf(z0, fine_steps, A, B, LB, μ)

    estimate = base .+ (2.0^level) .* (fine .- coarse)
    return (estimate, 3*fine_steps + 2)
end

"""
    geom_half()

Draw uniform random numbers until one is below `0.5` and return how many draws were
needed (an integer `≥ 1`).
"""
function geom_half()
    draws = 1
    while true
        if rand() < 0.5
            return draws
        end
        draws += 1
    end
end


# Kotsalis et al., 2022
"""
    ftd(VI, params, cb; α=0.25, λ=1.0, γ=1.0, noise_std=1.0, noise_model=:gaussian,
        t_df=2.0)

Projected method with operator extrapolation driven by mini-batch averages of a noisy
oracle built with `make_stochF`.

With budget `B = params.n_grad_eval`, the mini-batch size is `floor(sqrt(B))` and at
most `floor(sqrt(B)) - 1` iterations are run (fewer if another batch would exceed the
budget). Each iteration steps along `g_curr + λ * (g_curr - g_prev)` with step size
`α / params.L` and projects with `VI.Π`.

From the second iteration on, every new iterate replaces the returned candidate with
probability `1 / (number of candidates seen so far)`.

# Keywords
- `α`: step-size numerator; also forwarded to the callback.
- `λ`: extrapolation weight.
- `γ`: accepted but not used by this method.
- `noise_std`, `noise_model`, `t_df`: forwarded to `make_stochF`.

# Returns
The randomly selected iterate (the starting point if no candidate was ever drawn).
"""
function ftd(VI::AbstractVI, params::ProblemParams, cb::Callback; α::Float64 = 0.25, λ::Float64 = 1.0,
    γ::Float64=1.0, noise_std::Float64=1.0, noise_model::Symbol=:gaussian, t_df::Float64=2.0)

    op, proj = VI.F, VI.Π
    stepsize = α / params.L
    point = params.z0

    noisy_op = make_stochF(op; noise_std=noise_std, noise_model=noise_model, t_df=t_df)

    # Mean of `nsamples` independent noisy operator evaluations at `w`.
    function averaged(w, nsamples)
        acc = zero(w)
        for _ in 1:nsamples
            acc .+= noisy_op(w)
        end
        acc ./= nsamples
        return acc
    end

    used = 0
    cb(VI, point, stepsize, op(point), α, used)

    budget = params.n_grad_eval
    max_iters = floor(Int, sqrt(budget)) - 1
    batch = max_iters + 1

    # The first extrapolation uses the same estimate as "previous" and "current".
    grad_now = averaged(point, batch)
    used += batch
    grad_before = grad_now

    selected = point
    n_candidates = 0
    iter = 0

    while iter < max_iters && used + batch <= budget
        direction = grad_now + λ * (grad_now - grad_before)
        point = proj(point - stepsize * direction)

        # Random output selection, skipping the very first iterate.
        if iter >= 1
            n_candidates += 1
            if rand() < 1 / n_candidates
                selected = point
            end
        end

        grad_before = grad_now
        grad_now = averaged(point, batch)
        used += batch

        iter += 1
        cb(VI, point, stepsize, op(point), α, used)
    end

    return selected
end


# Iusem et al., 2017
"""
    iusem(VI, params, cb; α=1.0, γ=1.0, noise_std=1.0, noise_model=:gaussian, t_df=2.0)

Stochastic extra-gradient iteration with growing mini-batches.

In iteration `k` the batch size is `Nk = ceil(4 * (k+1) * log(k+2)^1.1)`. Two averaged
noisy operator evaluations (built with `make_stochF`) are used: one at `x` to form
`z = Π(x - lr * g)`, one at `z` to form the next `x = Π(x - lr * g)`, with the fixed
step size `lr = 1 / (4 * params.L)`. Iterations continue while the number of consumed
evaluations is below `params.n_grad_eval`; the last iteration may exceed the budget.

# Keywords
- `α`: only forwarded to the callback.
- `γ`: accepted but not used by this method.
- `noise_std`, `noise_model`, `t_df`: forwarded to `make_stochF`.

# Returns
The last iterate `x`.
"""
function iusem(VI::AbstractVI, params::ProblemParams, cb::Callback; α::Float64 = 1.0,
    γ::Float64=1.0, noise_std::Float64=1.0, noise_model::Symbol=:gaussian, t_df::Float64=2.0)

    F, Π = VI.F, VI.Π
    x = params.z0
    L = params.L
    lr = 1 / (4*L)

    stochF = make_stochF(F; noise_std=noise_std, noise_model=noise_model, t_df=t_df)

    eval_count = 0
    k = 0

    Fx = F(x)
    cb(VI, x, lr, Fx, α, eval_count)

    while eval_count < params.n_grad_eval
        # Integer batch size: keeps `eval_count` an `Int` (as in the other solvers)
        # instead of silently turning it into a `Float64` after the first iteration.
        Nk = ceil(Int, 4.0*(k+1)*log(k+2)^1.1)

        # extrapolation step with a mini-batch estimate at x
        g = zero(x)
        for _ in 1:Nk
            g .+= stochF(x)
        end
        eval_count += Nk
        g ./= Nk
        z = Π(x - lr * g)

        # update step with a fresh mini-batch estimate at z (buffer reused)
        fill!(g, 0)
        for _ in 1:Nk
            g .+= stochF(z)
        end
        eval_count += Nk
        g ./= Nk
        x = Π(x - lr * g)

        k += 1
        Fx = F(x)
        cb(VI, x, lr, Fx, α, eval_count)
    end

    return x
end


# Alacaoglu et al., 2025
"""
    halpern_storm_eg(VI, params, cb; B=nothing, γ=1.0, noise_std=1.0,
                     noise_model=:gaussian, t_df=2.0)

Anchored two-step projected iteration with a recursively corrected operator estimate.

With `βk = 1/(k+3)`, `αk = 2βk`, `γk = (1-βk)/(6L)` and `τk = τ_const/sqrt(k+3)`, each
iteration

1. forms `zbar = βk*z0 + (1-βk)*z`,
2. takes `z_half = Π(zbar - γk*g)` using the running estimate `g`,
3. evaluates the noisy operator at `z_half` and sets `z_next = Π(zbar - τk*γk*F̃(z_half))`,
4. evaluates the operator at `z_next` and `z` with one shared noise draw and updates
   `g = F̃(z_next) + (1-αk)*(g - F̃(z))`.

Three evaluations are charged per iteration and the loop runs while three more fit
into `params.n_grad_eval`.

# Keywords
- `B`: constant entering `τ_const = min(1/4, L^2/(12 B^2))` (`1/4` when `B == 0`);
  defaults to `params.L` when `nothing`.
- `γ`: accepted but not used by this method.
- `noise_std`, `noise_model`, `t_df`: noise configuration, forwarded to `make_stochF`
  and also used for the shared noise draw.

# Returns
The last iterate `z`.
"""
function halpern_storm_eg(VI::AbstractVI, params::ProblemParams, cb::Callback; B::Union{Nothing,Float64}=nothing, γ::Float64=1.0,
    noise_std::Float64=1.0, noise_model::Symbol=:gaussian, t_df::Float64=2.0)

    op, proj = VI.F, VI.Π
    lip = params.L
    anchor = params.z0

    bconst = (B === nothing) ? lip : B
    τ_base = bconst == 0.0 ? 1/4 : min(1/4, (lip^2) / (12.0 * bconst^2))

    noisy_op = make_stochF(op; noise_std=noise_std, noise_model=noise_model, t_df=t_df)

    # Evaluate the operator at two points perturbed by the SAME noise realisation.
    function shared_noise_pair(a, b)
        if noise_std == 0.0
            shared = zero(a)
        elseif noise_model == :gaussian
            shared = noise_std .* randn(eltype(a), size(a))
        elseif noise_model == :laplace
            shared = rand(Laplace(0.0, noise_std), size(a))
        elseif noise_model == :studentt
            shared = noise_std .* rand(TDist(t_df), size(a))
        end
        return op(a) + shared, op(b) + shared
    end

    # Parameter schedules as functions of the iteration index.
    β_at(i) = 1.0 / (i + 3)
    α_at(i) = 2.0 * β_at(i)
    γ_at(i) = (1.0 - β_at(i)) / (6.0 * lip)
    τ_at(i) = τ_base / sqrt(i + 3.0)

    point = anchor
    cb(VI, anchor, γ_at(0), op(point), α_at(0), 0)

    # NOTE(review): the running operator estimate is initialised with the starting
    # POINT rather than an operator value at it — confirm this is intended.
    estimate = anchor
    used = 0
    iter = 0

    while used + 3 <= params.n_grad_eval
        weight = β_at(iter)
        mix = α_at(iter)
        step = γ_at(iter)
        damp = τ_at(iter)

        pivot = weight * anchor + (1.0 - weight) * point
        probe = proj(pivot - step * estimate)

        probe_value = noisy_op(probe)
        used += 1

        candidate = proj(pivot - (damp * step) * probe_value)

        value_new, value_old = shared_noise_pair(candidate, point)
        used += 2

        estimate = value_new + (1.0 - mix) * (estimate - value_old)

        point = candidate
        iter += 1

        cb(VI, point, step, op(point), mix, used)
    end

    return point
end

"""
    make_stochF(F; noise_std=1.0, noise_model=:gaussian, t_df=2.0)

Return a function `z -> F(z) + noise` that perturbs `F` with freshly drawn additive
noise of the same size as `z` on every call.

# Keywords
- `noise_std`: if exactly `0.0`, the returned function is a noise-free wrapper of `F`
  and `noise_model` is ignored. Otherwise it multiplies the Gaussian and Student-t
  draws and is passed as the second argument of `Laplace(0.0, noise_std)`.
- `noise_model`: one of `:gaussian` (`randn`), `:laplace` (`Laplace`) or `:studentt`
  (`TDist(t_df)`).
- `t_df`: parameter passed to `TDist`; only used for `:studentt`.

# Throws
- `ArgumentError` if `noise_std != 0` and `noise_model` is not one of the supported
  symbols (previously `nothing` was returned silently, which only failed later when
  the result was called).
"""
function make_stochF(F; noise_std=1.0, noise_model=:gaussian, t_df=2.0)
    if noise_std == 0.0
        return z -> F(z)
    end

    if noise_model == :gaussian
        return z -> F(z) .+ noise_std .* randn(eltype(z), size(z))
    elseif noise_model == :laplace
        dist = Laplace(0.0, noise_std)
        return z -> F(z) .+ rand(dist, size(z))
    elseif noise_model == :studentt
        dist = TDist(t_df)
        return z -> F(z) .+ noise_std .* rand(dist, size(z))
    end

    throw(ArgumentError(
        "unsupported noise_model $(repr(noise_model)); expected :gaussian, :laplace or :studentt",
    ))
end

# Callable, mutable state used to compute anchoring coefficients (see the call method
# `(o::AnchCoeff)(Fz, z)` in this file).
#
# Fields
# - `type`: anchoring rule; the call method checks for `missing`, "normal", "acc" and
#   "adaptive".
# - `VI`: problem whose projection `VI.Π` is used by the "adaptive" rule.
# - `normalizing_factor`: divisor used by the "adaptive" rule.
# - `G`: running sum accumulated by the "adaptive" rule (starts at 0).
# - `index`: call counter, incremented on every call (starts at 1).
#
# `G` and `index` are concretely typed so that the call method does not have to work
# with `Any`-typed fields.
Base.@kwdef mutable struct AnchCoeff
    type
    VI::AbstractVI
    normalizing_factor::Float64
    G::Float64 = 0.
    index::Int = 1
end

# Compute the anchoring coefficient for the current call and advance `o.index`.
# This unifies the different anchoring rules across all methods:
#   - `missing`    -> 0 (no anchoring)
#   - "normal"     -> 1/(index + 1)
#   - "acc"        -> 1/(index + 10)
#   - "adaptive"   -> min(1/index, 1/(normalizing_factor * G)), where `G` accumulates
#                     the inverse squared norm of `z - Π(z - Fz)` over all calls
# Any other `type` raises an `ArgumentError` (before any state is modified).
function (o::AnchCoeff)(Fz, z)
    if ismissing(o.type)
        anch_coeff = 0.0  # float, like every other branch
    elseif o.type == "normal"
        anch_coeff = 1/(o.index+1)
    elseif o.type == "acc"
        anch_coeff = 1/(o.index+10)
    elseif o.type == "adaptive"
        fixed_point_residual = norm(z - o.VI.Π(z - Fz))^2
        o.G += 1/fixed_point_residual
        # scalar `min` instead of allocating a temporary array for `minimum`
        anch_coeff = min(1/o.index, 1/o.normalizing_factor * 1/o.G)
    else
        throw(ArgumentError(
            "unknown anchoring type $(repr(o.type)); expected missing, \"normal\", \"acc\" or \"adaptive\"",
        ))
    end

    o.index += 1

    return anch_coeff
end