import LinearAlgebra
import Random
import SparseArrays
import DataFrames
import CSV

# Project-local sources, loaded by path relative to this file.
# NOTE(review): `plot_trajectories`, `bpg` and the constant `L` used further down are not
# defined in this script — presumably they come from plot_utils.jl / bpg.jl; confirm.
include("../examples/plot_utils.jl")
include("../src/FrankWolfe.jl")
include("../cc/bpg.jl")

# Problem sizes: `A` is m × n, the unknown has n entries; `k` is the iteration budget
# passed to every solver below.
n = Int(1e3)
m = Int(1e2)
k = 1000
# Small positive shift added inside logarithms by ∇φ, grad!, grad and subproblem.
# NOTE: this script-level global shadows `Base.eps` for the rest of the script.
eps = 1e-8

# A single explicitly seeded generator drives all problem data, so `xp`, `A` and `b`
# are identical on every run (an unseeded generator made `A` differ between runs).
rng = Random.MersenneTwister(1234)

# The global RNG is seeded as well: later code still draws from it via `rand(n)`.
Random.seed!(1234);
xpi = rand(rng, n);
total = sum(xpi);
# Positive reference point, rescaled so that its entries sum to 0.8.
const xp = xpi ./ total * 0.8;

# Entrywise non-negative matrix whose rows are normalised to sum to one.
A = randn(rng, Float64, (m, n))
A = abs.(A)
A = A ./ sum.(eachrow(A))
# Data vector generated from the reference point.
b = A * xp

"""
    xlogx(x)

Return `x * log(x)` when `x > zero(x)`, and `zero(x)` otherwise, so that
non-positive arguments never reach `log`.
"""
function xlogx(x)
    x > zero(x) || return zero(x)
    return x .* log.(x)
end
"""
    φ(x)

Sum `xlogx` over the entries of `x` (the negative-entropy function of `x`).
"""
function φ(x)
    terms = xlogx.(x)
    return sum(terms)
end
"""
    ∇φ(x)

Return `log(x + eps) + 1` elementwise, where `eps` is the script-level shift
(not `Base.eps`) that keeps the logarithm finite at zero entries.
"""
∇φ(x) = @. log(x + eps) + 1.0

# Objective constants evaluated once at the data vector `b`:
# φ(b), and log(b) + 1 elementwise (no `eps` shift is applied here).
φb = φ(b)
∇φb = @. log(b) + 1.0

"""
    f(x)

Objective value at `x`: `φ(A*x) - φ(b) - ∇φb' * (A*x - b)`, i.e. the Bregman
divergence of `φ` between `A*x` and `b`, using the precomputed `φb` and `∇φb`.
"""
function f(x)
    Ax = A * x
    entropy_term = φ(Ax)
    linear_term = ∇φb' * (Ax - b)
    return entropy_term - φb - linear_term
end

"""
    grad!(storage, x)

Overwrite `storage` with `A' * (log(A*x + eps) + 1 - ∇φb)`, the gradient of `f`
at `x` (with the `eps` shift inside the logarithm), and return `storage`.
"""
function grad!(storage, x)
    Ax = A * x
    # The matrix-vector product is evaluated first; `.=` then copies it into `storage`.
    storage .= A' * (log.(Ax .+ eps) .+ 1.0 - ∇φb)
    return storage
end

# Out-of-place counterpart of `grad!`: returns the same expression as a new vector.
# It is not referenced anywhere else in this script.
grad = function (x)
    Ax = A * x
    return A' * (log.(Ax .+ eps) .+ 1.0 - ∇φb)
end

# A' * log(b), computed once and reused by `subproblem` on every call.
Alogb = A' * log.(b)

"""
    subproblem(x)

Compute the next point from `x` for the `bpg` baseline.

Forms `y = (Alogb - A' * log(A*x)) / L + log(x + eps) + 1`, exponentiates it after
subtracting `maximum(y)` for numerical stability, and returns either
`exp.(y .- 1)` (when those values sum to at most one) or the exponentials
normalised to sum to one.

NOTE(review): `L` is not defined in this script; presumably a step-size /
smoothness constant provided by an included file — confirm.
"""
function subproblem(x)
    Ax = A * x
    y = (-A' * log.(Ax) + Alogb) / L + log.(x .+ eps) .+ 1.0
    shift = maximum(y)
    weights = exp.(y .- shift)
    mass = sum(weights)
    # exp(1 - shift) rescales `weights` back to exp.(y .- 1).
    cap = exp(1 - shift)
    return mass <= cap ? weights ./ cap : weights ./ mass
end

# Linear minimisation oracle built with parameter 1.0, the uniform start point
# (every entry 1/n) shared by all runs, and the list collecting each run's last
# trajectory entry.
lmo = FrankWolfe.UnitSimplexOracle(1.0)
x00 = fill(1 / n, n)
results = []
# Benchmark the objective, gradient and oracle on random points from the global RNG.
FrankWolfe.benchmark_oracles(f, grad!, () -> rand(n), lmo; k=100)

println("\n==> Short Step rule - if you know L.\n")

# First pass: verbose run from a fresh copy of the start point.
x0 = copy(x00)

@time x, v, primal, dual_gap, trajectory_shortstep = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    x0;
    line_search=FrankWolfe.Shortstep(2.0),
    max_iteration=k,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    print_iter=k / 10,
    trajectory=true,
    verbose=true,
);

# Second pass: same configuration without `verbose` (presumably so that `@time`
# measures an already-compiled call). Its trajectory overwrites the first one.
x0 = copy(x00)

@time x, v, primal, dual_gap, trajectory_shortstep = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    x0;
    line_search=FrankWolfe.Shortstep(2.0),
    max_iteration=k,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    print_iter=k / 10,
    trajectory=true,
);

# Record the final trajectory entry of the second pass under its plot label.
push!(results, ("ShortFW", trajectory_shortstep[end]...))

println("\n==> Adaptive if you do not know L.\n")

# First pass: verbose run from a fresh copy of the start point.
x0 = copy(x00)

@time x, v, primal, dual_gap, trajectory_adaptive = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    x0;
    line_search=FrankWolfe.AdaptiveZerothOrder(),
    max_iteration=k,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    print_iter=k / 10,
    trajectory=true,
    verbose=true,
);

# Second pass: same configuration without `verbose` (presumably so that `@time`
# measures an already-compiled call). Its trajectory overwrites the first one.
x0 = copy(x00)

@time x, v, primal, dual_gap, trajectory_adaptive = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    x0;
    line_search=FrankWolfe.AdaptiveZerothOrder(),
    max_iteration=k,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    print_iter=k / 10,
    trajectory=true,
);

# Record the final trajectory entry of the second pass under its plot label.
push!(results, ("EucFW", trajectory_adaptive[end]...))

println("\n==> Adaptive Bregman if you do not know L.\n")

# First pass: verbose run from a fresh copy of the start point; the line search
# is given the reference function φ and its gradient ∇φ defined in this script.
x0 = copy(x00)

@time x, v, primal, dual_gap, trajectory_bregman = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    x0;
    line_search=FrankWolfe.AdaptiveBregmanZerothOrder(φ=φ, ∇φ=∇φ),
    max_iteration=k,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    print_iter=k / 10,
    trajectory=true,
    verbose=true,
);

# Second pass: same configuration without `verbose` (presumably so that `@time`
# measures an already-compiled call). Its trajectory overwrites the first one.
x0 = copy(x00)

@time x, v, primal, dual_gap, trajectory_bregman = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    x0;
    line_search=FrankWolfe.AdaptiveBregmanZerothOrder(φ=φ, ∇φ=∇φ),
    max_iteration=k,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    print_iter=k / 10,
    trajectory=true,
);

# Record the final trajectory entry of the second pass under its plot label.
push!(results, ("BregFW", trajectory_bregman[end]...))

println("\n==> Agnostic if function is too expensive for adaptive.\n")

# First pass: verbose run from a fresh copy of the start point.
x0 = copy(x00)

@time x, v, primal, dual_gap, trajectory_agnostic = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    x0;
    line_search=FrankWolfe.Agnostic(),
    max_iteration=k,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    print_iter=k / 10,
    trajectory=true,
    verbose=true,
);

# Second pass: same configuration without `verbose` (presumably so that `@time`
# measures an already-compiled call). Its trajectory overwrites the first one.
x0 = copy(x00)

@time x, v, primal, dual_gap, trajectory_agnostic = FrankWolfe.frank_wolfe(
    f,
    grad!,
    lmo,
    x0;
    line_search=FrankWolfe.Agnostic(),
    max_iteration=k,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    print_iter=k / 10,
    trajectory=true,
);

# Record the final trajectory entry of the second pass under its plot label.
push!(results, ("OpenFW", trajectory_agnostic[end]...))

# Baseline: `bpg` driven by the `subproblem` update, run twice from a fresh copy of
# the start point; the second call's outputs overwrite the first.
# NOTE(review): the meaning of the two trailing positional `true` flags is defined by
# `bpg`, which is not visible in this script — confirm against its definition.
x0 = copy(x00)

@time x, trajectory_bpg = bpg(f, grad!, subproblem, x0, lmo, k, true, true);

x0 = copy(x00)

@time x, trajectory_bpg = bpg(f, grad!, subproblem, x0, lmo, k, true, true);

# Record the final trajectory entry of the second call under its plot label.
push!(results, ("MD", trajectory_bpg[end]...))

# Plot labels and the matching trajectories, in the same order.
label = ["BregFW", "EucFW", "ShortFW", "OpenFW", "MD"]
data = [
    trajectory_bregman,
    trajectory_adaptive,
    trajectory_shortstep,
    trajectory_agnostic,
    trajectory_bpg,
]

plot_trajectories(data, label; xscalelog=true, primal_offset=1e-20)
