import LinearAlgebra
import Random
using CSV, DataFrames

include("../examples/plot_utils.jl")
include("../src/FrankWolfe.jl")

# Problem size: A is m×n, so n is the number of unknowns and m the number of
# measurements; k is the iteration cap shared by every solver run below.
n = Int(1e2)
m = Int(1e2)
k = 1000

rng = Random.MersenneTwister(1234)
# The ground truth xp is drawn from the *global* RNG (seeded here), while A is drawn
# from the separate `rng` stream above; both are seeded, so the data is reproducible.
Random.seed!(1234);
xpi = rand(n);
const xp = xpi;

# A and b are read by the objective and gradient on every oracle call. Declaring them
# `const` lets those functions infer concrete types instead of doing untyped
# (non-const) global lookups on each call.
# Measurement matrix with every column rescaled to unit Euclidean norm.
const A = let A0 = randn(rng, Float64, (m, n))
    A0 ./ LinearAlgebra.norm.(eachcol(A0))'
end
# Noise-free phaseless measurements b = |A xp|² (element-wise).
const b = abs.(A * xp).^2

"""
    f(x)

Phase-retrieval objective `¼ ‖|A x|² − b‖²`, where the square inside is element-wise
and the outer norm is Euclidean. Reads the globals `A` and `b`.
"""
function f(x)
    y = A * x
    return LinearAlgebra.norm(abs.(y) .^ 2 - b)^2 / 4
end

"""
    grad!(storage, x)

Write the gradient of the objective `f(x) = ¼ ‖|A x|² − b‖²` into `storage` and
return `storage`. Reads the globals `A` and `b`.

With `y = A x` and residual `r = |y|² − b`, the chain rule gives
`∂f/∂yᵢ = ½ rᵢ · 2yᵢ = rᵢ yᵢ`, hence `∇f(x) = Aᵀ (r .* y)`.
Note this is *not* the linear least-squares gradient `Aᵀ(A x − b)`.
"""
function grad!(storage, x)
    y = A * x
    # Residual of each squared measurement, weighted by yᵢ (derivative of |yᵢ|²/2).
    w = @. (abs2(y) - b) * y
    # storage = Aᵀ w, written in place without allocating the result vector.
    LinearAlgebra.mul!(storage, A', w)
    return storage
end

# Bregman kernel φ(x) = ‖x‖⁴/4 + ‖x‖²/2 used by the Bregman line search, and its
# gradient ∇φ(x) = (‖x‖² + 1) x.
φ(x) = let r = LinearAlgebra.norm(x)
    r^4 / 4 + r^2 / 2
end
∇φ(x) = x * (LinearAlgebra.norm(x)^2 + 1)

# Feasible region built from lower corner zeros(n) and upper corner ones(n)
# (presumably the box [0, 1]^n).
lmo = FrankWolfe.ScaledBoundLInfNormBall(zeros(n), ones(n))

# Common starting vertex for every run: the LMO's answer for an all-zero direction.
x00 = FrankWolfe.compute_extreme_point(lmo, zeros(n));

# One summary entry per method, appended by the experiment sections.
results = []

# Exercise the objective, gradient and LMO on random directions before the real runs
# (presumably a timing/sanity benchmark of the oracles).
FrankWolfe.benchmark_oracles(f, grad!, () -> randn(n), lmo; k=100)

# --- Experiment 1: away-step Frank-Wolfe with the short-step rule. ---
# NOTE(review): Shortstep(2.0) presumably supplies the smoothness constant L the banner
# refers to; f is quartic in x, so confirm 2.0 is a valid bound on the feasible region.
println("\n==> Short Step rule - if you know L.\n")

# Start from a fresh copy of x00; presumably needed because InplaceEmphasis lets the
# solver update the iterate in place.
x0 = deepcopy(x00)

# First run: prints progress every k / 10 iterations. Its outputs are overwritten by
# the second run below (presumably this run also absorbs JIT compilation so that the
# second `@time` measures the solver alone).
@time x, v, primal, dual_gap, trajectory_shortstep = FrankWolfe.away_frank_wolfe(
    f,
    grad!,
    lmo,
    x0,
    max_iteration=k,
    line_search=FrankWolfe.Shortstep(2.0),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    verbose=true,
    trajectory=true,
);

x0 = deepcopy(x00)

# Second run: same settings without `verbose`; this trajectory is the one recorded.
@time x, v, primal, dual_gap, trajectory_shortstep = FrankWolfe.away_frank_wolfe(
    f,
    grad!,
    lmo,
    x0,
    max_iteration=k,
    line_search=FrankWolfe.Shortstep(2.0),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
);

# Store the method label followed by the fields of the last trajectory entry.
push!(results, ("ShortAFW", trajectory_shortstep[end]...))

# --- Experiment 2: away-step Frank-Wolfe with FrankWolfe.AdaptiveZerothOrder(). ---
# Per the banner, this rule is meant for when L is unknown (presumably it estimates a
# local smoothness constant from function values — confirm in FrankWolfe docs).
println("\n==> Adaptive if you do not know L.\n")

# Fresh copy of the shared starting vertex for this run.
x0 = deepcopy(x00)

# First run: prints progress every k / 10 iterations; results are overwritten below
# (presumably a warm-up so the second `@time` excludes compilation).
@time x, v, primal, dual_gap, trajectory_adaptive = FrankWolfe.away_frank_wolfe(
    f,
    grad!,
    lmo,
    x0,
    max_iteration=k,
    line_search=FrankWolfe.AdaptiveZerothOrder(),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    verbose=true,
    trajectory=true,
);

x0 = deepcopy(x00)

# Second run: same settings without `verbose`; this trajectory is the one recorded.
@time x, v, primal, dual_gap, trajectory_adaptive = FrankWolfe.away_frank_wolfe(
    f,
    grad!,
    lmo,
    x0,
    max_iteration=k,
    line_search=FrankWolfe.AdaptiveZerothOrder(),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
);

# Store the method label followed by the fields of the last trajectory entry.
push!(results, ("EucAFW", trajectory_adaptive[end]...))

# --- Experiment 3: away-step Frank-Wolfe with the adaptive Bregman line search. ---
# The line search is given the kernel φ(x) = ‖x‖⁴/4 + ‖x‖²/2 and its gradient ∇φ,
# both defined near the top of this script.
println("\n==> Adaptive Bregman if you do not know L.\n")

# Fresh copy of the shared starting vertex for this run.
x0 = deepcopy(x00)

# First run: prints progress every k / 10 iterations; results are overwritten below
# (presumably a warm-up so the second `@time` excludes compilation).
@time x, v, primal, dual_gap, trajectory_bregman = FrankWolfe.away_frank_wolfe(
    f,
    grad!,
    lmo,
    x0,
    max_iteration=k,
    line_search=FrankWolfe.AdaptiveBregmanZerothOrder(φ=φ, ∇φ=∇φ),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    verbose=true,
    trajectory=true,
);

x0 = deepcopy(x00)

# Second run: same settings without `verbose`; this trajectory is the one recorded.
@time x, v, primal, dual_gap, trajectory_bregman = FrankWolfe.away_frank_wolfe(
    f,
    grad!,
    lmo,
    x0,
    max_iteration=k,
    line_search=FrankWolfe.AdaptiveBregmanZerothOrder(φ=φ, ∇φ=∇φ),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
);

# Store the method label followed by the fields of the last trajectory entry.
push!(results, ("BregAFW", trajectory_bregman[end]...))

# --- Experiment 4: away-step Frank-Wolfe with FrankWolfe.Agnostic(). ---
# Per the banner, intended for when evaluating f is too expensive for the adaptive rules.
# NOTE(review): presumably an open-loop step size that needs no f evaluations (hence the
# "OpenAFW" label) — confirm in the FrankWolfe docs.
println("\n==> Agnostic if function is too expensive for adaptive.\n")

# Fresh copy of the shared starting vertex for this run.
x0 = deepcopy(x00)

# First run: prints progress every k / 10 iterations; results are overwritten below
# (presumably a warm-up so the second `@time` excludes compilation).
@time x, v, primal, dual_gap, trajectory_agnostic = FrankWolfe.away_frank_wolfe(
    f,
    grad!,
    lmo,
    x0,
    max_iteration=k,
    line_search=FrankWolfe.Agnostic(),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    verbose=true,
    trajectory=true,
);

x0 = deepcopy(x00)

# Second run: same settings without `verbose`; this trajectory is the one recorded.
@time x, v, primal, dual_gap, trajectory_agnostic = FrankWolfe.away_frank_wolfe(
    f,
    grad!,
    lmo,
    x0,
    max_iteration=k,
    line_search=FrankWolfe.Agnostic(),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
);

# Store the method label followed by the fields of the last trajectory entry.
push!(results, ("OpenAFW", trajectory_agnostic[end]...))

# Recorded trajectories and their legend labels, kept in the same order.
data = [
    trajectory_bregman,
    trajectory_adaptive,
    trajectory_shortstep,
    trajectory_agnostic
]
label = [
    "BregAFW",
    "EucAFW",
    "ShortAFW",
    "OpenAFW"
]

# Presumably a linear x-axis and a logarithmic y-axis, per the keyword names.
plot_trajectories(data, label; xscalelog=false, yscalelog=true)
