import LinearAlgebra
import Random
import SparseArrays
import FrankWolfe
using DataFrames, CSV
using DelimitedFiles

# include("../src/FrankWolfe.jl")
include("../examples/plot_utils.jl")

# Directory that is scanned for the `.dat` batch files of the data set.
data_dir = "./cc/gas+sensor+array+drift+dataset+at+different+concentrations/"

# Full paths (join=true) of every `.dat` file found directly inside `data_dir`.
batch_files = filter(f -> endswith(f, ".dat"), readdir(data_dir, join=true))

# Fail early: with no input files the script would otherwise continue with an
# empty (0-row) problem and break much later in a confusing way.
isempty(batch_files) && error("no .dat files found in $(repr(data_dir))")

# Concatenate the lines of all batch files into one vector of sample lines.
all_lines = String[]
for file in batch_files
    # Drop blank / whitespace-only lines: the parser below indexes the first
    # token of every line and would throw a BoundsError on an empty one.
    append!(all_lines, filter(l -> !isempty(strip(l)), readlines(file)))
end

isempty(all_lines) && error("the .dat files in $(repr(data_dir)) contain no data lines")

# Problem dimensions: one row per loaded line, 128 feature columns.
n = 128
m = length(all_lines)

# Design matrix and right-hand side, filled by the parsing loop below.
A = zeros(Float64, m, n)
b = zeros(Float64, m)

# Value passed as `timeout` to every frank_wolfe call further down.
t = 100

# Each line is whitespace-separated. The first token contains a ';' and the
# part after it is stored as the target value (presumably "<class>;<concentration>"
# -- confirm against the data set description). Every following token has the
# form "<column index>:<value>" and sets one entry of the current row of A.
for (row, line) in enumerate(all_lines)
    fields = split(line)
    b[row] = parse(Float64, split(fields[1], ';')[2])

    for entry in Iterators.drop(fields, 1)
        parts = split(entry, ':')
        col = parse(Int, parts[1])
        A[row, col] = parse(Float64, parts[2])
    end
end

# Iteration budget (used as `max_iteration`) and exponent of the ℓ_p loss.
k = 1000
p = 1.1

"""
    f(x)

Return the ℓ_p regression loss `sum(abs.(A*x - b).^p)`.

Reads the script-level globals `A`, `b` and `p`.
"""
function f(x)
    residual = A * x - b
    powered = @. abs(residual)^p
    return sum(powered)
end

# Reference function handed to the Bregman line search below; it simply
# forwards to the objective `f`.
function φ(x)
    return f(x)
end
"""
    ∇φ(x)

Return the gradient of `φ` at `x`, i.e. `p * A' * (|r|.^(p-1) .* sign.(r))`
with residual `r = A*x - b`.

Reads the script-level globals `A`, `b` and `p`; allocates a new vector.
"""
function ∇φ(x)
    residual = A * x - b
    q = p - 1
    weights = @. abs(residual)^q * sign(residual)
    return p * A' * weights
end

"""
    grad!(storage, x)

Overwrite `storage` element-wise with `∇φ(x)` and return `storage`.
"""
function grad!(storage, x)
    storage .= ∇φ(x)
    return storage
end

# Linear minimization oracle for the feasible set: an Lp-norm LMO with p = 2,
# constructed with the value 130 (the ball radius, per FrankWolfe's LpNormLMO).
lmo = FrankWolfe.LpNormLMO{Float64,2}(130)

# Common starting point; every run below starts from a fresh copy of it.
x00 = ones(n)

# Collects one ("label", last trajectory entry...) tuple per method.
results = Any[]

# Time the objective, gradient and LMO on random points (100 repetitions).
FrankWolfe.benchmark_oracles(f, grad!, () -> rand(n), lmo; k=100)

println("\n==> Short Step rule - if you know L.\n")

# First pass with verbose output. Its results are overwritten by the second
# pass (presumably it also serves as a warm-up so that the second @time is not
# dominated by compilation).
x0 = copy(x00)
@time x, v, primal, dual_gap, trajectory_shortstep = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.Shortstep(2.0),
    max_iteration=k,
    timeout=t,
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
    verbose=true,
);

# Second, quiet pass from the same starting point; this trajectory is kept.
x0 = copy(x00)
@time x, v, primal, dual_gap, trajectory_shortstep = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.Shortstep(2.0),
    max_iteration=k,
    timeout=t,
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
);

# Record the final trajectory entry under this method's label.
push!(results, ("ShortFW", last(trajectory_shortstep)...))

println("\n==> Adaptive if you do not know L.\n")

# First pass with verbose output. Its results are overwritten by the second
# pass (presumably it also serves as a warm-up so that the second @time is not
# dominated by compilation).
x0 = copy(x00)
@time x, v, primal, dual_gap, trajectory_adaptive = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.AdaptiveZerothOrder(),
    max_iteration=k,
    timeout=t,
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
    verbose=true,
);

# Second, quiet pass from the same starting point; this trajectory is kept.
x0 = copy(x00)
@time x, v, primal, dual_gap, trajectory_adaptive = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.AdaptiveZerothOrder(),
    max_iteration=k,
    timeout=t,
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
);

# Record the final trajectory entry under this method's label.
push!(results, ("EucFW", last(trajectory_adaptive)...))

println("\n==> Adaptive Bregman if you do not know L.\n")

# First pass with verbose output. Its results are overwritten by the second
# pass (presumably it also serves as a warm-up so that the second @time is not
# dominated by compilation). A fresh line-search object is built for each pass.
x0 = copy(x00)
@time x, v, primal, dual_gap, trajectory_bregman = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.AdaptiveBregmanZerothOrder(φ=φ, ∇φ=∇φ),
    max_iteration=k,
    timeout=t,
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
    verbose=true,
);

# Second, quiet pass from the same starting point; this trajectory is kept.
x0 = copy(x00)
@time x, v, primal, dual_gap, trajectory_bregman = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.AdaptiveBregmanZerothOrder(φ=φ, ∇φ=∇φ),
    max_iteration=k,
    timeout=t,
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
);

# Record the final trajectory entry under this method's label.
push!(results, ("BregFW", last(trajectory_bregman)...))

println("\n==> Agnostic if function is too expensive for adaptive.\n")

# First pass with verbose output and the regular budget of k iterations. Its
# results are overwritten by the second pass.
x0 = copy(x00)
@time x, v, primal, dual_gap, trajectory_agnostic = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.Agnostic(),
    max_iteration=k,
    timeout=t,
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
    verbose=true,
);

# Second, quiet pass from the same starting point. Unlike the other methods it
# is given five times the iteration budget (still bounded by `timeout`).
x0 = copy(x00)
@time x, v, primal, dual_gap, trajectory_agnostic = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.Agnostic(),
    max_iteration=k * 5,
    timeout=t,
    print_iter=k / 10,
    memory_mode=FrankWolfe.InplaceEmphasis(),
    trajectory=true,
);

# Record the final trajectory entry under this method's label.
push!(results, ("OpenFW", last(trajectory_agnostic)...))

# Trajectories to compare and their legend labels, in matching order.
# (The short-step run is recorded in `results` but not plotted.)
label = ["BregFW", "EucFW", "OpenFW"]
data = [trajectory_bregman, trajectory_adaptive, trajectory_agnostic]

# Linear x-axis, logarithmic y-axis.
plot_trajectories(
    data,
    label;
    xscalelog=false,
    yscalelog=true,
    primal_offset=1e-25,
    legend_position=:topright,
)
