using LinearAlgebra
using CairoMakie

include("utils.jl")

# Global Makie theme shared by every figure in this file.
# `covary = true` makes color and marker advance together through the cycle, so each
# successive series gets the next (color, marker) pair. The same cycle is attached to
# both `Lines` and `Scatter`, so a series' line and its markers share one pair.
# NOTE: the global name `cycle` shadows `Base.cycle` in this scope.
cycle = Cycle([:color, :marker]; covary = true)
fontsize_theme = Theme(; fontsize = 28)
set_theme!(
    fontsize_theme;
    Lines = (; linewidth = 4, cycle = cycle),
    Scatter = (; cycle = cycle),
)

"""
    my_scatterline!(ax, x, y, k, num_algs, num_mark, label; linewidth=3, markersize=20)

Draw `y` against `x` on `ax` as a line plus a sparse set of markers, both tagged
with `label`. A legend built with `merge = true` can therefore combine them into
one entry.

# Marker placement
- Markers are placed every `step = max(1, floor(length(y) / num_mark))` points,
  i.e. roughly `num_mark` markers.
- A series shorter than `num_mark` gets a marker on every point instead of being
  skipped.
- The first marker sits at index `k * ceil(step / num_algs)`. This staggers the
  markers of the `k`-th of `num_algs` series drawn on the same axis.

# Clamping and errors
- `y` is clamped to `[-100000, 1000000]` for plotting only; the caller's array is
  left untouched.
- An empty `y` is skipped with a warning.
- Any error raised while plotting is reported via `@warn` (including the
  exception) instead of being thrown. Plotting is best-effort.

# Throws
- `ArgumentError` if `num_mark < 1`.

# Returns
- The value returned by `scatter!` on success, `nothing` otherwise.
"""
function my_scatterline!(ax, x, y, k, num_algs, num_mark, label; linewidth=3, markersize=20)
    num_mark >= 1 || throw(ArgumentError("num_mark must be at least 1, got $num_mark"))

    if isempty(y)
        @warn "my_scatterline!: empty series, nothing plotted" label
        return nothing
    end

    # Never let the stride drop to 0: a short series still gets its line and markers.
    step = max(1, floor(Int, length(y) / num_mark))
    start = k * ceil(Int, step / num_algs)
    # Marker indices must be valid for both coordinate arrays.
    idx = start:step:min(length(x), length(y))

    try
        # Non-mutating clamp so the caller's stored data is not altered by plotting.
        yc = clamp.(y, -100000, 1000000)
        lines!(ax, x, yc; label = label, linewidth = linewidth)
        return scatter!(ax, x[idx], yc[idx]; markersize = markersize, label = label)
    catch e
        @warn "Some problems with plotting" label exception = (e, catch_backtrace())
        return nothing
    end
end


# Return the per-figure presets `(yticks, ylimits, legend_position)` for `nfig`.
# For `nfig` outside 1:3 the ticks and limits are `nothing` (let Makie choose) and the
# legend goes top-right.
function _norm_grad_presets(nfig)
    ticks_1e2_to_1em4 = ([1e2, 1e0, 1e-2, 1e-4], ["10²", "10⁰", "10⁻²", "10⁻⁴"])
    if nfig == 1
        return ticks_1e2_to_1em4, (0.8 * 1e-4, 1.3 * 100), :rt
    elseif nfig == 2
        ticks_1e0_to_1em6 = ([1e0, 1e-2, 1e-4, 1e-6], ["10⁰", "10⁻²", "10⁻⁴", "10⁻⁶"])
        return ticks_1e0_to_1em6, (0.8 * 1e-6, 2.5 * 10), :rt
    elseif nfig == 3
        return ticks_1e2_to_1em4, (0.8 * 1e-4, 1.3 * 100), :rb
    end
    return nothing, nothing, :rt
end

"""
    norm_grad_plot(data; nfig=0, num_mark=6)

Plot `alg.store_grad` against `alg.x_ticks` for every `alg` in `data`. The y axis
("squared operator norm") is log-scaled and the x axis is labelled "operator
evaluations". Returns the `Figure`.

# Arguments
- `data`: each element must provide the fields `x_ticks`, `store_grad` and `label`.

# Keywords
- `nfig`: a value in `1:3` selects a preset of y ticks, y limits and legend
  position. Any other value (including the default `0`) uses Makie's automatic
  ticks and limits with a top-right legend.
- `num_mark`: forwarded to `my_scatterline!` as the approximate number of markers
  per series.
"""
function norm_grad_plot(data; nfig=0, num_mark=6)
    yticks, ylimits, pos = _norm_grad_presets(nfig)

    fig = Figure()
    # Only forward `yticks` when a preset defines them; otherwise Makie picks ticks.
    tick_kw = isnothing(yticks) ? (;) : (; yticks = yticks)
    ax = Axis(fig[1, 1];
        xlabel = "operator evaluations",
        ylabel = "squared operator norm",
        xlabelsize = 24,
        ylabelsize = 26,
        xticklabelsize = 19,
        yticklabelsize = 19,
        yscale = log10,
        tick_kw...,
    )

    n = length(data)
    for (k, alg) in enumerate(data)
        my_scatterline!(ax, alg.x_ticks, alg.store_grad, k, n, num_mark, alg.label;
            linewidth = 3, markersize = 18)
    end

    isnothing(ylimits) || ylims!(ax, ylimits...)

    axislegend(ax; labelsize = 20, merge = true, position = pos)
    return fig
end


"""
    plot_iterates(data, VI::AbstractVI, path, num_mark=10; val=3)

Plot the trajectory of the iterates of each algorithm in `data` on top of a
streamplot of the vector field `u -> -VI.F(u)`. Returns the `Figure`.

Only possible if the problem is 2-d: `alg.iterates` is reshaped to a `2 × N`
matrix, and each column is treated as one `(x, y)` point. Both coordinates are
clamped to `[-10, 10]` before plotting.

# Plot window
The plotted window is `[-val/2, val/2]` on both axes by default. It changes based
on `path`:
- if `path` contains `"nonconvex-linear"`, the upper bound becomes `2`;
- if `path` contains `"follow-the-ridge"`, the window becomes `[-5, 5]`. This
  check runs last, so it wins when both substrings are present.

`num_mark` is forwarded to `my_scatterline!` as the approximate number of markers
per trajectory.
"""
function plot_iterates(data, VI::AbstractVI, path, num_mark=10; val=3)
    F = VI.F
    fig = Figure()
    ax = Axis(fig[1, 1];
        xticklabelsize = 18,
        yticklabelsize = 18,
    )

    n = length(data)
    for (k, alg) in enumerate(data)
        # Each column of `mat` is one 2-d iterate.
        mat = reshape(alg.iterates, 2, :)
        x = convert(Vector{Float64}, mat[1, :])
        y = convert(Vector{Float64}, mat[2, :])
        # Keep diverging trajectories within a drawable range.
        clamp!(x, -10, 10)
        clamp!(y, -10, 10)
        my_scatterline!(ax, x, y, k, n, num_mark, alg.label;
            linewidth = 3, markersize = 12)
    end

    minval = -val / 2
    maxval = val / 2
    if occursin("nonconvex-linear", path)
        maxval = 2
    end
    if occursin("follow-the-ridge", path)
        minval, maxval = -5, 5
    end

    # Stream lines follow the negated operator.
    f(u) = Point2f(-F(u))
    streamplot!(ax, f, minval..maxval, minval..maxval, linewidth = 0.4, arrow_size = 8)
    xlims!(ax, minval, maxval)
    ylims!(ax, minval, maxval)

    axislegend(ax; labelsize = 20, merge = true)

    return fig
end