using LBFGSB

"""
    generate_bpd_data(m, n, k; noise = 1e-3)

Generate a random sparse-recovery instance.

A standard-normal `m`×`n` matrix `B` is drawn, then a vector `z0` of length `n`
with `k` randomly placed standard-normal entries (all other entries zero), and
finally the measurements `b = B * z0 + noise * randn(m)`.

Returns the tuple `(BPD(b, B), z0)`.
"""
function generate_bpd_data(m, n, k; noise = 1e-3)
    # The draws happen in the order: matrix, support, coefficients, noise.
    B = Random.randn(m, n)

    active = Random.randperm(n)[1:k]
    z0 = zeros(n)
    z0[active] = Random.randn(k)

    clean = B * z0
    perturbation = noise * Random.randn(m)
    b = clean + perturbation

    # Sanity check: the residual at the ground truth is just the added noise.
    @assert LA.norm(b - B * z0) <= (noise * 10 + 1e-12) * m

    return BPD(b, B), z0
end

"""
    BPD{Tb, TB}

Problem data holding the vector `b`, the matrix `B`, and `f_verif`, which is
always set to `ProximalOperators.NormL1()` by the inner constructor.
"""
struct BPD{Tb, TB}
    b::Tb
    B::TB
    f_verif

    # Only (b, B) are supplied by the caller; f_verif is fixed here.
    function BPD{Tb, TB}(b::Tb, B::TB) where {Tb, TB}
        return new{Tb, TB}(b, B, ProximalOperators.NormL1())
    end
end

# Convenience constructor: infer the type parameters from the arguments.
function BPD(b::Tb, B::TB) where {Tb, TB}
    return BPD{Tb, TB}(b, B)
end

"""
    A_eval(bpd::BPD, x)

Evaluate `B * x[1:n].^2 - B * x[n+1:2n].^2 - b`, where `n` is the number of
columns of `bpd.B`.
"""
function A_eval(bpd::BPD, x)
    mat = bpd.B
    n = size(mat)[2]

    # x stacks two blocks of length n; each is squared entrywise.
    first_sq = x[1:n] .^ 2
    second_sq = x[n+1:2*n] .^ 2

    return mat * first_sq - mat * second_sq - bpd.b
end

"""
    JA_eval(bpd::BPD, x)

Return the tuple `(J, r)` where `r = B * x[1:n].^2 - B * x[n+1:2n].^2 - b` and
`J = 2 * [B .* x[1:n]'  -B .* x[n+1:2n]']`, with `n` the number of columns of
`bpd.B`.
"""
function JA_eval(bpd::BPD, x)
    mat = bpd.B
    n = size(mat)[2]

    first_block = x[1:n]
    second_block = x[n+1:2*n]

    # Scale the columns of B by the entries of each block of x.
    jac = 2 * hcat(mat .* first_block', -mat .* second_block')
    res = mat * first_block .^ 2 - mat * second_block .^ 2 - bpd.b

    return jac, res
end

"""
    BPDSolver(; problem, x0, y0, σ0, initial_inner_tol, λ, β, f, g, A, JA)

Keyword-constructed solver configuration for a `BPD` problem.

# Fields
- `problem`: the `BPD` instance.
- `f`: defaults to `ProximalOperators.SqrNormL2()`.
- `g`: defaults to `ProximalOperators.Zero()`.
- `A`: defaults to `x -> A_eval(problem, x)`.
- `JA`: defaults to `x -> JA_eval(problem, x)`.
- `x0`, `y0`: presumably the initial primal and dual iterates (TODO confirm
  against the outer solver loop).
- `σ0`, `initial_inner_tol`, `λ`: scalars that must all share the same type `R`.
- `β`: called as `β(k)` by `solve_primal!`.
"""
Base.@kwdef struct BPDSolver{Tf, Tg, TA, TJA, Tx, Ty, R, Tβ}
    problem::BPD
    f::Tf = ProximalOperators.SqrNormL2()
    g::Tg = ProximalOperators.Zero()
    A::TA = x -> A_eval(problem, x)
    JA::TJA = x -> JA_eval(problem, x)
    x0::Tx
    y0::Ty
    σ0::R
    initial_inner_tol::R
    λ::R
    β::Tβ
end

"""
    ProximalOperatorsBPD(solver, state, k)

Bundle of a `BPDSolver`, an `ALMState` and the value `k`. Instances are
callable and have a `ProximalOperators.gradient!` method; both forward to the
augmented-Lagrangian routines using `state.y` and `k`.
"""
Base.@kwdef mutable struct ProximalOperatorsBPD
    solver::BPDSolver
    state::ALMState
    k
end

# Calling the object at x returns al(solver, state, x, state.y, k).
function (f::ProximalOperatorsBPD)(x)
    st = f.state
    return al(f.solver, st, x, st.y, f.k)
end

# Write the gradient with respect to x into `grad` by delegating to
# `al_gradx!`, and return whatever `al_gradx!` returns.
function ProximalOperators.gradient!(grad, bpd::ProximalOperatorsBPD, x)
    st = bpd.state
    return al_gradx!(grad, bpd.solver, st, x, st.y, bpd.k)
end

# Map the stacked variable x = [u; v] to z = u.^2 - v.^2, splitting x in half.
function _bpd_signed_part(x)
    half = size(x)[1] ÷ 2
    return x[1:half] .^ 2 - x[half+1:2*half] .^ 2
end

"""
    solve_primal!(solver::BPDSolver, state::ALMState, k)

Run the inner solver for outer iteration `k`, overwrite `state.x` with the
result, and return the inner iteration count.

When `state.inner_solver == "LBFGSB"` the subproblem is minimized with
`L_BFGS_B` (all variables unbounded) and the returned count is the number of
gradient evaluations; `state.res` is then overwritten with the gradient at the
new `state.x`. Otherwise `state.inner_solver` is called with keyword arguments
to build an iterative solver, which is run from `state.x` with `f` set to the
wrapped problem and `g = solver.g`.

If `state.verbose` is set, a summary is printed. If `state.logging` is set,
`state.hist.objective`, `state.hist.feasibility` and `state.hist.nit` are
written at index `k + 1` (and at index `1` with the initial values when
`k == 1`).
"""
function solve_primal!(solver :: BPDSolver, state :: ALMState, k)
    bpd = ProximalOperatorsBPD(solver, state, k)
    objective = solver.problem.f_verif

    # Record the starting point once, before the first inner solve.
    if state.logging && k == 1
        state.hist.objective[1] = objective(_bpd_signed_part(state.x))
        state.hist.feasibility[1] = LA.norm(solver.A(state.x))
        state.hist.nit[1] = 0
    end

    if state.inner_solver == "LBFGSB"
        # A Ref avoids boxing a reassigned variable captured by the closure.
        grad_calls = Ref(0)
        nvar = size(state.x)[1]
        optimizer = L_BFGS_B(nvar, 20)
        # All-zero bounds: the first row (bound type) being 0 means unbounded.
        bounds = zeros(3, nvar)
        x0 = fill(Cdouble(3e0), nvar)
        copyto!(x0, state.x)
        fout, xout = optimizer(
            bpd,
            (z, x) -> begin
                grad_calls[] += 1
                ProximalOperators.gradient!(z, bpd, x)
            end,
            x0,
            bounds,
            m=5,
            factr=1e1,  # tight setting; 1e7 is the looser alternative
            pgtol=state.inner_tol,
            iprint=-1,  # iprint=0 enables solver output
            maxfun=state.inner_max_it,
            maxiter=state.inner_max_it
        )
        copyto!(state.x, xout)
        iters_ffb = grad_calls[]

        ProximalOperators.gradient!(state.res, bpd, state.x)
        println("LBFGSB Grad norm = $(LA.norm(state.res))")
    else
        # A fixed step size is passed only when the adaptive rule is off.
        gamma_kw = if state.adaptive_gamma
            NamedTuple()
        else
            (gamma = state.γ_rule(solver.β(k), state.y),)
        end

        ffb = state.inner_solver(;
            tol = state.inner_tol,
            maxit = round(state.inner_max_it),
            minimum_gamma = state.minimum_gamma,
            stop = (ffb_iter, ffb_state) -> (LA.norm(ffb_state.res, Inf) < 1e-10),
            adaptive = state.adaptive_gamma,
            gamma_kw...,
        )

        x_ffb, iters_ffb = ffb(x0 = state.x, f = bpd, g = solver.g)
        copyto!(state.x, x_ffb)
    end

    # Always derive the split from state.x itself. The previous code reused a
    # local `n` that was unset (or already halved) on the non-LBFGSB path.
    z = _bpd_signed_part(state.x)

    if state.verbose
        println("-----------------\tk=$(k)")
        println("Inner iterations: \t$(iters_ffb)")
        println("Objective: \t\t$(objective(z))")
        println("Constraint violation: \t$(LA.norm(solver.A(state.x)))")
        println()
    end

    if state.logging
        state.hist.objective[k+1] = objective(z)
        state.hist.feasibility[k+1] = LA.norm(solver.A(state.x))
        state.hist.nit[k+1] = state.nit + iters_ffb
    end

    return iters_ffb
end

export generate_bpd_data, BPDSolver