using Random, CairoMakie, Dates

include("../utils/utils.jl")
include("../utils/runner.jl")
include("../utils/methods.jl")

include("../problems/problems.jl")

# All figures produced by this script go to `<this script's dir>/../figs`.
const OUTDIR = normpath(joinpath(@__DIR__, "..", "figs"))
isdir(OUTDIR) || mkpath(OUTDIR)

"""
    run_lax_single()

Run Alg.1 and EG on one `lax` problem instance and save two figures to `OUTDIR`:
a gradient-norm plot (`lax_alg1_norm_<timestamp>.png`) and an iterate-trajectory
plot (`lax_alg1_traj_<timestamp>.png`).

The global RNG is reseeded with `1234` before the problem is built. Both files
written by one call share a single timestamp, so they can be paired afterwards.
"""
function run_lax_single()
    Random.seed!(1234)
    # Positional arguments are forwarded unchanged to `lax` (defined in problems.jl).
    VI, params = lax(2_000_000, 1, 2π/3, true)
    # Alg.1 step size is chosen relative to the ρ reported for this instance.
    η = 1.35 * params.ρ
    algs = [
        algorithm(Alg1, "Alg.1", (; η=η, ρ=params.ρ, coef=5.9)),
        algorithm(eg,   "EG",    (; γ=1.0)),
    ]
    cbs = Callback[]
    for a in algs
        cb = Callback(label=a.label)
        # presumably each method records its trace into `cb` — see methods.jl
        a.method(VI, params, cb; a.params...)
        push!(cbs, cb)
    end

    fig_norm = norm_grad_plot(cbs; nfig=1)
    fig_traj = plot_iterates(cbs, VI, params.path)
    # Compute the timestamp once: calling `now()` separately for each file could
    # straddle a second boundary and give the two figures of one run different names.
    timestamp = Dates.format(now(), "yyyymmdd_HHMMSS")
    save(joinpath(OUTDIR, "lax_alg1_norm_$(timestamp).png"), fig_norm)
    save(joinpath(OUTDIR, "lax_alg1_traj_$(timestamp).png"), fig_traj)
end

"""
    run_lax_grid()

Run Alg.1 once per `lax` instance in a small sweep over the angle argument and
save one combined gradient-norm figure to
`OUTDIR/lax_alg1_norm_regime_<timestamp>.png`.

Each sweep entry pairs an angle with its legend label and the `coef` passed to
Alg.1. The global RNG is reseeded with `1234` once, before the sweep starts.
"""
function run_lax_grid()
    Random.seed!(1234)
    horizon = 1_000_000
    # (angle θ passed to `lax`, legend label, Alg.1 `coef`)
    regimes = [
        (1.7π/3, "ρ ≈ 1/5L", 2.0),
        (2π/3, "ρ = 1/2L", 3.5),
        (2.2π/3, "ρ ≈ 1/1.5L", 4.0),
        (3π/3, "ρ = 1/L", 1.0),
    ]
    callbacks = Callback[]
    for (θ, label, c) in regimes
        vi, prm = lax(horizon, 1, θ, true)
        # Step size: a fixed convex mix of ρ and 1/L for this instance.
        stepsize = 0.25*prm.ρ + 0.75/prm.L
        alg = algorithm(Alg1, "Alg.1", (; η=stepsize, ρ=prm.ρ, coef=c))
        cb = Callback(label=label)
        alg.method(vi, prm, cb; alg.params...)
        push!(callbacks, cb)
    end
    fig = norm_grad_plot(callbacks; nfig=2)
    stamp = Dates.format(now(), "yyyymmdd_HHMMSS")
    return save(joinpath(OUTDIR, "lax_alg1_norm_regime_$(stamp).png"), fig)
end

"""
    run_lax_three_regime()

Run the KLL+22 (`ftd`), IJO+17 (`iusem`) and AMW+25 (`halpern_storm_eg`) methods
on one `lax` instance with the same noise settings. Save their gradient-norm
curves to `OUTDIR/lax_three_regime_norm_<timestamp>.png`.

The global RNG is reseeded with `1234` before the problem is built.
"""
function run_lax_three_regime()
    Random.seed!(1234)
    noise_model = :gaussian
    noise_std = 0.03

    VI, params = lax(1_000_000, 1, 2π/3, true)
    # Every method receives identical noise keywords, so the comparison is like-for-like.
    noise_kw = (; noise_std=noise_std, noise_model=noise_model)
    algs = [
        algorithm(ftd, "KLL+22", noise_kw),
        algorithm(iusem, "IJO+17", noise_kw),
        algorithm(halpern_storm_eg, "AMW+25", noise_kw),
    ]
    cbs = Callback[]
    for a in algs
        cb = Callback(label=a.label)
        a.method(VI, params, cb; a.params...)
        push!(cbs, cb)
    end

    fig_norm = norm_grad_plot(cbs; nfig=3)
    timestamp = Dates.format(now(), "yyyymmdd_HHMMSS")
    # Use a distinct prefix: this figure has no Alg.1 curve, and reusing
    # `lax_alg1_norm_` would overwrite run_lax_single's output from the same second.
    save(joinpath(OUTDIR, "lax_three_regime_norm_$(timestamp).png"), fig_norm)
end