import LinearAlgebra
import Random
import Distributions
import Plots
import Arpack
import Noise

include("../examples/plot_utils.jl")
include("../src/FrankWolfe.jl")

# Problem size: factors are n×r matrices; k is the iteration budget handed to
# every Frank-Wolfe run further down.
n = 1_000
r = 20
k = 1_000

# A dedicated generator is created here, while the global RNG is seeded
# separately; the `rand` calls below draw from the global RNG.
rngw = Random.MersenneTwister(42)
Random.seed!(42)

# Random n×r factor whose columns are rescaled to unit Euclidean norm,
# and the symmetric n×n target matrix built from it.
Y = let raw = rand(n, r)
    raw ./ LinearAlgebra.norm.(eachcol(raw))'
end
M = Y * Y'

"""
    f(x)

Objective value at `x`: half the squared norm of the residual `M - x * x'`,
where `M` is the global target matrix.
"""
function f(x)
    residual = M - x * x'
    return LinearAlgebra.norm(residual)^2 / 2
end

"""
    grad!(storage, x)

Overwrite `storage` with `2 * x * (x' * x) - 2 * M * x`, the gradient of `f`
at `x` (using the global `M`), and return `storage`.
"""
function grad!(storage, x)
    gram = x' * x
    # The right-hand side is fully evaluated before being copied into `storage`.
    storage .= 2 * x * gram - 2 * M * x
    return storage
end

"""
    φ(x)

Reference function `‖x‖⁴ / 4 + ‖x‖² / 2` passed to the Bregman line search.
"""
function φ(x)
    nrm = LinearAlgebra.norm(x)
    return nrm^4 / 4 + nrm^2 / 2
end

"""
    ∇φ(x)

Gradient of `φ`: `x` scaled by `‖x‖² + 1`.
"""
function ∇φ(x)
    scale = LinearAlgebra.norm(x)^2 + 1
    return scale * x
end

# Bound handed to the LMO: ten times the single singular value of `M` that
# `svds` is asked for (`nsv=1`; singular vectors are skipped via `ritzvec=false`).
est = 10 * Arpack.svds(M, nsv=1, ritzvec=false)[1].S[1]

# Nuclear-norm linear minimization oracle parameterised by `est`.
# Referenced with a single `FrankWolfe.` prefix, like every other use in this file.
lmo = FrankWolfe.NuclearNormLMO(est)

# Common random starting point; every run below starts from its own copy of it.
x00 = rand(n, r)

# Accumulates one `(label, final trajectory entry...)` tuple per method.
results = []
println("\n==> Short Step rule - if you know L.\n")

# One Frank-Wolfe solve with the short-step rule, started from `start`.
# The line-search and memory-mode objects are built anew on every call.
solve_shortstep(start) = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    start;
    max_iteration=k,
    line_search=FrankWolfe.Shortstep(2.0),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
)

# First timed run on a fresh copy of the start point. Its outputs are
# overwritten by the second run — presumably a warm-up so that the second
# timing is not dominated by compilation.
x0 = deepcopy(x00)
@time x, v, primal, dual_gap, trajectory_shortstep = solve_shortstep(x0)

# Second timed run, again from a fresh copy; these are the results that are kept.
x0 = deepcopy(x00)
@time x, v, primal, dual_gap, trajectory_shortstep = solve_shortstep(x0)

# Record the last trajectory entry under this method's label.
push!(results, ("ShortFW", last(trajectory_shortstep)...))

println("\n==> Adaptive if you do not know L.\n")

# One Frank-Wolfe solve with the adaptive zeroth-order line search, started
# from `start`. A new line-search object is constructed on every call, so no
# state is shared between the two runs below.
solve_adaptive(start) = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    start;
    max_iteration=k,
    line_search=FrankWolfe.AdaptiveZerothOrder(),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
)

# First timed run on a fresh copy of the start point. Its outputs are
# overwritten by the second run — presumably a warm-up so that the second
# timing is not dominated by compilation.
x0 = deepcopy(x00)
@time x, v, primal, dual_gap, trajectory_adaptive = solve_adaptive(x0)

# Second timed run, again from a fresh copy; these are the results that are kept.
x0 = deepcopy(x00)
@time x, v, primal, dual_gap, trajectory_adaptive = solve_adaptive(x0)

# Record the last trajectory entry under this method's label.
push!(results, ("EucFW", last(trajectory_adaptive)...))

println("\n==> Adaptive Bregman if you do not know L.\n")

# One Frank-Wolfe solve with the adaptive Bregman zeroth-order line search
# (reference function `φ` and its gradient `∇φ`), started from `start`.
# A new line-search object is constructed on every call.
solve_bregman(start) = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    start;
    max_iteration=k,
    line_search=FrankWolfe.AdaptiveBregmanZerothOrder(φ=φ, ∇φ=∇φ),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
)

# First timed run on a fresh copy of the start point. Its outputs are
# overwritten by the second run — presumably a warm-up so that the second
# timing is not dominated by compilation.
x0 = deepcopy(x00)
@time x, v, primal, dual_gap, trajectory_bregman = solve_bregman(x0)

# Second timed run, again from a fresh copy; these are the results that are kept.
x0 = deepcopy(x00)
@time x, v, primal, dual_gap, trajectory_bregman = solve_bregman(x0)

# Record the last trajectory entry under this method's label.
push!(results, ("BregFW", last(trajectory_bregman)...))

println("\n==> Agnostic if function is too expensive for adaptive.\n")

# One Frank-Wolfe solve with the agnostic step-size rule, started from `start`.
# The line-search and memory-mode objects are built anew on every call.
solve_agnostic(start) = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    start;
    max_iteration=k,
    line_search=FrankWolfe.Agnostic(),
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
)

# First timed run on a fresh copy of the start point. Its outputs are
# overwritten by the second run — presumably a warm-up so that the second
# timing is not dominated by compilation.
x0 = deepcopy(x00)
@time x, v, primal, dual_gap, trajectory_agnostic = solve_agnostic(x0)

# Second timed run, again from a fresh copy; these are the results that are kept.
x0 = deepcopy(x00)
@time x, v, primal, dual_gap, trajectory_agnostic = solve_agnostic(x0)

# Record the last trajectory entry under this method's label.
push!(results, ("OpenFW", last(trajectory_agnostic)...))

# Legend labels and the matching trajectories, kept in the same order.
label = ["BregFW", "EucFW", "ShortFW", "OpenFW"]
data = [
    trajectory_bregman,
    trajectory_adaptive,
    trajectory_shortstep,
    trajectory_agnostic,
]

# Plot all four trajectories with the helper from the included plot utilities.
plot_trajectories(data, label; xscalelog=true, primal_offset=1e-20)
