import LinearAlgebra
import Random
import Plots
import Arpack
using CSV
using DataFrames

include("../examples/plot_utils.jl")
include("../src/FrankWolfe.jl")

# Read the ratings table and pivot it into a dense matrix with one row per
# `userId` and one column per `movieId`.
file_path = "./cc/ratings.csv"
df_ratings = CSV.read(file_path, DataFrame)
# `let` keeps the intermediate tables out of the global namespace.
X = let wide = unstack(df_ratings, :userId, :movieId, :rating)
    # Entries left `missing` by the pivot are stored as 0.0.
    filled = coalesce.(wide, 0.0)
    # `userId` is only the row key of the pivot, so it is dropped before conversion.
    Matrix(select(filled, Not(:userId)))
end

# Problem dimensions. `size` is the public accessor; the `.size` field of `Array`
# is an implementation detail that only exists on recent Julia releases.
m, n = size(X)
r = 60    # number of columns of each factor block, see the `(m, r)` / `(n, r)` shapes
k = 1000  # passed as `max_iteration` to every run in this script
println(m, " ", n, " ", r)

"""
    dotwh(x)

Return the product of the first block of `x` with the adjoint of its second block,
i.e. `x.blocks[1] * x.blocks[2]'`.
"""
function dotwh(x)
    w = x.blocks[1]
    h = x.blocks[2]
    return w * adjoint(h)
end
"""
    concat(x, y)

Wrap `x` and `y` as the two blocks of a `FrankWolfe.BlockVector`, declaring the
block sizes `(m, r)` and `(n, r)` and the total length `(m + n) * r` taken from
the script-level `m`, `n` and `r`.
"""
function concat(x, y)
    block_sizes = [(m, r), (n, r)]
    total_length = (m + n) * r
    return FrankWolfe.BlockVector([x, y], block_sizes, total_length)
end

"""
    f(x)

Objective value at `x`: half the squared norm of the residual `X - dotwh(x)`.
"""
function f(x)
    residual = X - dotwh(x)
    return LinearAlgebra.norm(residual)^2 / 2
end

"""
    grad!(storage, x)

Overwrite `storage` with the gradient of `f` at `x` and return the result of that
in-place broadcast assignment.

With the residual `R = X - dotwh(x)`, the first block is `-R * x.blocks[2]` and
the second block is `-R' * x.blocks[1]`.
"""
function grad!(storage, x)
    residual = X - dotwh(x)
    grad_first = -residual * x.blocks[2]
    grad_second = -residual' * x.blocks[1]
    block_sizes = [(m, r), (n, r)]
    return storage .=
        FrankWolfe.BlockVector([grad_first, grad_second], block_sizes, (m + n) * r)
end


# Norm of the data matrix; it weights the quadratic term of `φ`.
normX = LinearAlgebra.norm(X)

"""
    φ(x)

Reference function handed to the Bregman line search: `3‖x‖⁴/4 + normX‖x‖²/2`.
"""
function φ(x)
    nx = LinearAlgebra.norm(x)
    return 3 * nx^4 / 4 + normX * nx^2 / 2
end

"""
    ∇φ(x)

Gradient of `φ`: `(3‖x‖² + normX) x`.
"""
function ∇φ(x)
    nx = LinearAlgebra.norm(x)
    return (3 * nx^2 + normX) * x
end

# Radius used for the nuclear-norm constraint: ten times the single singular
# value requested from `svds` (nsv=1, no Ritz vectors).
est = Arpack.svds(X, nsv=1, ritzvec=false)[1].S[1] * 10
# Constraint for the first (m × r) block, built from the bounds 0 and 5.
lmo1 = FrankWolfe.ScaledBoundLInfNormBall(zeros(m, r), fill(5.0, m, r))
# Constraint for the second block. The module is qualified once, like every other
# FrankWolfe reference in this script.
lmo2 = FrankWolfe.NuclearNormLMO(est)
lmo = FrankWolfe.ProductLMO(lmo1, lmo2);
# Shared starting point; each run below works on its own `deepcopy` of it.
x00 = FrankWolfe.BlockVector(
    [fill(5.0, m, r), fill(5.0, n, r)], [(m, r), (n, r)], (m + n) * r
);
println("\n==> Short Step rule - if you know L.\n")

# Fresh copy of the shared starting point for this run.
x0 = deepcopy(x00)

# Short-step line search constructed with the constant 2.0.
@time x, v, primal, dual_gap, trajectory_shortstep = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.Shortstep(2.0),
    memory_mode=FrankWolfe.InplaceEmphasis(),
    max_iteration=k,
    print_iter=k / 10,
    verbose=true,
    trajectory=true,
);


println("\n==> Adaptive if you do not know L.\n")

# First adaptive run, with progress output.
x0 = deepcopy(x00)

@time x, v, primal, dual_gap, trajectory_adaptive = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.AdaptiveZerothOrder(),
    memory_mode=FrankWolfe.InplaceEmphasis(),
    max_iteration=k,
    print_iter=k / 10,
    verbose=true,
    trajectory=true,
);

# Same settings again from a fresh copy, without `verbose`; the trajectory of
# this run replaces the one recorded above.
x0 = deepcopy(x00)

@time x, v, primal, dual_gap, trajectory_adaptive = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.AdaptiveZerothOrder(),
    memory_mode=FrankWolfe.InplaceEmphasis(),
    max_iteration=k,
    print_iter=k / 10,
    trajectory=true,
);

println("\n==> Adaptive Bregman if you do not know L.\n")

# First Bregman run, with progress output.
x0 = deepcopy(x00)

@time x, v, primal, dual_gap, trajectory_bregman = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.AdaptiveBregmanZerothOrder(φ=φ, ∇φ=∇φ),
    memory_mode=FrankWolfe.InplaceEmphasis(),
    max_iteration=k,
    print_iter=k / 10,
    verbose=true,
    trajectory=true,
);

# Same settings again from a fresh copy, with `verbose` left at its default; the
# trajectory of this run replaces the one recorded above.
x0 = deepcopy(x00)

@time x, v, primal, dual_gap, trajectory_bregman = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.AdaptiveBregmanZerothOrder(φ=φ, ∇φ=∇φ),
    memory_mode=FrankWolfe.InplaceEmphasis(),
    max_iteration=k,
    print_iter=k / 10,
    trajectory=true,
);


println("\n==> Agnostic if function is too expensive for adaptive.\n")

# Fresh copy of the shared starting point for this run.
x0 = deepcopy(x00)

@time x, v, primal, dual_gap, trajectory_agnostic = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0;
    line_search=FrankWolfe.Agnostic(),
    memory_mode=FrankWolfe.InplaceEmphasis(),
    max_iteration=k,
    print_iter=k / 10,
    verbose=true,
    trajectory=true,
);

# Plot the Bregman trajectory against the adaptive one; labels are matched to
# trajectories by position.
label = ["BregFW", "EucFW"]
data = [trajectory_bregman, trajectory_adaptive]

plot_trajectories(data, label; xscalelog=true)
