"""
    GEVP_TYPE

Families of test problems understood by `generate_gevp_data`.

The base name selects how the matrix `C` is drawn (Gaussian entries, or a
prescribed spectrum with polynomial or exponential decay); the `_pd` variants
additionally add `0.5 * I` to `C`.
"""
@enum GEVP_TYPE::Int32 begin
    # Variants that receive the additional 0.5 * I shift.
    gevp_gaussian_pd    = 1
    gevp_polynomial_pd  = 2
    gevp_exponential_pd = 3
    # Unshifted variants.
    gevp_gaussian       = 4
    gevp_polynomial     = 5
    gevp_exponential    = 6
end

export gevp_gaussian, gevp_polynomial, gevp_exponential, gevp_gaussian_pd, gevp_polynomial_pd, gevp_exponential_pd 

"""
    generate_gevp_data(n, gevp_type::GEVP_TYPE)

Draw a random `n`-dimensional generalized eigenvalue test problem and compute
a reference solution for it.

`B` is built as `Q' * I * Q` with `Q` the orthogonal factor of a random
matrix, so it equals the identity up to rounding. `C` depends on `gevp_type`:

- Gaussian: symmetrized matrix with `0.1 * randn` entries;
- polynomial: orthogonal conjugation of the shuffled spectrum `i^(-1)`;
- exponential: orthogonal conjugation of the shuffled spectrum `10^(-0.0025 * i)`;
- `_pd` variants: the same, followed by `C += 0.5 * I`.

# Returns
A tuple `(GEVP(C, B), opt_val, res)` where `res` is `B^(-1/2)` applied to the
eigenvector of the smallest eigenvalue of `B^(-1/2) * C * B^(-1/2)`, and
`opt_val = 0.5 * res' * C * res`.

# Throws
- `ArgumentError` if `n < 1`.
"""
function generate_gevp_data(n, gevp_type :: GEVP_TYPE)
    n >= 1 || throw(ArgumentError("n must be at least 1, got $n"))

    # Generate B
    Q = LA.qr(Random.rand(n, n)).Q
    B = Q' * LA.diagm(ones(n)) * Q

    # Generate C. The orthogonal factor is drawn for every family (even the
    # Gaussian one, which does not use it) so the random stream consumed per
    # call stays the same regardless of `gevp_type`.
    Q = LA.qr(Random.rand(n, n)).Q
    if gevp_type === gevp_gaussian || gevp_type === gevp_gaussian_pd
        C = 0.1 * Random.randn(n, n)
        C = 0.5 * (C + C')

    elseif gevp_type === gevp_polynomial || gevp_type === gevp_polynomial_pd
        p = 1.0
        spectrum = Random.shuffle([i^(-p) for i = 1:n])
        C = Q' * LA.diagm(spectrum) * Q

    elseif gevp_type === gevp_exponential || gevp_type === gevp_exponential_pd
        p = 0.0025
        spectrum = Random.shuffle([10^(-i * p) for i = 1:n])
        C = Q' * LA.diagm(spectrum) * Q

    else
        error("GEVP type not recognized.")
    end

    if gevp_type === gevp_gaussian_pd || gevp_type === gevp_polynomial_pd || gevp_type === gevp_exponential_pd
        # Fixed diagonal shift.
        # NOTE(review): a constant shift of 0.5 does not by itself guarantee a
        # positive definite C for the Gaussian family -- verify for the sizes in use.
        C += 0.5 * LA.diagm(ones(n))
    end

    # Symmetric reduction of the pencil (C, B): substituting x = B^(-1/2) * z
    # turns "minimize x'Cx subject to x'Bx = 1" into an ordinary symmetric
    # eigenproblem for B^(-1/2) * C * B^(-1/2). Both outer factors must be the
    # inverse square root; using B^(1/2) on one side would only yield a
    # similarity transform of C.
    B12 = B^0.5
    B12inv = inv(B12)
    S = B12inv * C * B12inv
    S = 0.5 * (S + S')
    S = real(S)

    # The Symmetric wrapper selects the symmetric eigensolver, whose
    # eigenvalues come back in ascending order, so column 1 belongs to the
    # smallest eigenvalue.
    eigvec = LA.eigvecs(LA.Symmetric(S))[:, 1]
    res = real(B12inv * eigvec)

    opt_val = 0.5 * res' * C * res

    return GEVP(C, B), opt_val, res
end

"""
    GEVP{TC}

Pair of matrices `(C, B)` describing a generalized eigenvalue problem. Both
fields share the same type `TC`. Elsewhere in this file `C` defines the
quadratic objective and `B` the constraint `x' * B * x - 1`.
"""
struct GEVP{TC}
    C::TC
    B::TC

    # Explicit inner constructor: arguments must already be of type `TC`
    # (no conversion is attempted).
    GEVP{TC}(C::TC, B::TC) where {TC} = new{TC}(C, B)
end

# Convenience constructor that infers `TC` from the (identically typed) arguments.
function GEVP(C::TC, B::TC) where {TC}
    return GEVP{TC}(C, B)
end

"""
    A_eval(gevp::GEVP, x)

Evaluate the constraint residual `x' * B * x - 1` of `gevp` at `x`, returned
as a one-element vector.
"""
function A_eval(gevp::GEVP, x)
    B = gevp.B
    residual = x' * B * x - 1.0
    return [residual]
end

"""
    JA_eval(gevp::GEVP, x)

Return the tuple `(J, a)` where `J` is the `1 x length(x)` row `2 * (B * x)'`
and `a` is the one-element vector `[x' * B * x - 1]` (the same value
`A_eval` produces).
"""
function JA_eval(gevp::GEVP, x)
    B = gevp.B

    # NOTE(review): 2 * B * x is the derivative of x' * B * x only when B is
    # symmetric -- confirm callers always pass a symmetric B.
    jac_row = 2 * reshape(B * x, (1, length(x)))
    residual = x' * B * x - 1.0

    return jac_row, [residual]
end

"""
    GEVPSolver(; problem, x0, y0, σ0, initial_inner_tol, λ, β, kwargs...)

Keyword-constructed bundle of a `GEVP` problem and solver settings.

# Fields
- `problem`: the `GEVP` instance.
- `f`: smooth objective; defaults to `ProximalOperators.Quadratic` built from
  `problem.C` with a zero linear term.
- `g`: second objective term; defaults to `ProximalOperators.Zero()`.
- `A`: constraint map; defaults to `x -> A_eval(problem, x)`.
- `JA`: constraint Jacobian map; defaults to `x -> JA_eval(problem, x)`.
- `β`: callable; `solve_primal!` evaluates it at the outer iteration index.
- `x0`, `y0`, `σ0`, `initial_inner_tol`, `λ`: not used in this file;
  presumably initial iterates and algorithm parameters consumed by the outer
  loop -- confirm against the caller.
"""
Base.@kwdef struct GEVPSolver{Tf,Tg,TA,TJA,Tx,Ty,R,Tβ}
    problem::GEVP
    f::Tf = ProximalOperators.Quadratic(problem.C, zeros(size(problem.C, 1)))
    g::Tg = ProximalOperators.Zero()
    A::TA = z -> A_eval(problem, z)
    JA::TJA = z -> JA_eval(problem, z)
    x0::Tx
    y0::Ty
    σ0::R
    initial_inner_tol::R
    λ::R
    β::Tβ
end

"""
    ProximalOperatorsGEVP(solver, state, k)

Callable wrapper that bundles a `GEVPSolver`, an `ALMState` and an iteration
index `k`. Calling it, or asking for its gradient, forwards these three
values together with `state.y` to `al` / `al_gradx!`.
"""
Base.@kwdef mutable struct ProximalOperatorsGEVP
    solver::GEVPSolver
    state::ALMState
    k::Any
end

# Evaluate the wrapped function at `x`: forwards to `al` with the stored
# solver, state, `state.y` and iteration index.
function (f::ProximalOperatorsGEVP)(x)
    state = f.state
    return al(f.solver, state, x, state.y, f.k)
end

# `ProximalOperators.gradient!` method for the wrapper: delegates to
# `al_gradx!` with `grad` as the output buffer, and returns whatever
# `al_gradx!` returns.
function ProximalOperators.gradient!(grad, gevp::ProximalOperatorsGEVP, x)
    return al_gradx!(grad, gevp.solver, gevp.state, x, gevp.state.y, gevp.k)
end

# Adapter to AdaProx's `eval_with_pullback` interface: returns the result of
# `gradient!` together with a zero-argument closure yielding the buffer that
# `gradient!` was given.
# (presumably `gradient!` fills the buffer with the gradient and returns the
# function value -- confirm against `al_gradx!`)
function AdaProx.eval_with_pullback(f::ProximalOperatorsGEVP, x)
    grad_buf = copy(x)
    value = ProximalOperators.gradient!(grad_buf, f, x)
    pullback() = grad_buf
    return value, pullback
end

# Same adapter for `ProximalOperators.Regularize` objects: returns the result
# of `gradient!` and a closure yielding the buffer passed to it.
# NOTE(review): both the function (AdaProx) and the argument type
# (ProximalOperators) appear to be owned by other packages, which makes this
# method type piracy -- confirm no other package defines the same method.
function AdaProx.eval_with_pullback(f::ProximalOperators.Regularize, x)
    grad_buf = copy(x)
    value = ProximalOperators.gradient!(grad_buf, f, x)
    pullback() = grad_buf
    return value, pullback
end

# Construct the inner solver via `state.inner_solver(; ...)`.
#
# The keyword set is the same for every configuration, except that a fixed
# `gamma = state.γ_rule(solver.β(k), state.y)` is supplied only when
# `state.adaptive_gamma` is false.
function _gevp_inner_solver(solver :: GEVPSolver, state :: ALMState, k)
    # Stopping test handed to the inner solver; `state.inner_tol` is read
    # each time the test runs, not when the solver is built.
    function should_stop(inner_iter, inner_state)
        return LA.norm(inner_state.res, Inf) / inner_state.gamma < state.inner_tol
    end

    common = (
        tol = state.inner_tol,
        maxit = round(state.inner_max_it),
        minimum_gamma = state.minimum_gamma,
        stop = should_stop,
        adaptive = state.adaptive_gamma,
    )

    if state.adaptive_gamma
        return state.inner_solver(; common...)
    end
    return state.inner_solver(; common..., gamma = state.γ_rule(solver.β(k), state.y))
end

"""
    solve_primal!(solver::GEVPSolver, state::ALMState, k)

Run the inner minimisation for outer iteration `k`, overwrite `state.x` with
the result and return the number of inner iterations reported by the inner
solver(s).

With `state.triple_loop` set, up to `state.ippm_max_it` regularised
subproblems (`ProximalOperators.Regularize` around the current centre with
weight `2ρ`) are solved in sequence, moving the centre to each new solution,
until `2ρ * norm(x_new - x_curr, Inf) <= 2 * state.inner_tol`. Otherwise a
single inner solve is performed on the unregularised function.

When `state.logging` is set, objective, feasibility and cumulative iteration
counts are written to `state.hist` at index `k + 1` (and the initial values
at index `1` when `k == 1`); when `state.verbose` is set a summary is printed.
"""
function solve_primal!(solver :: GEVPSolver, state :: ALMState, k)
    gevp = ProximalOperatorsGEVP(solver, state, k)

    if state.logging && k == 1
        state.hist.objective[1] = solver.f(state.x) + solver.g(state.x)
        state.hist.feasibility[1] = LA.norm(solver.A(state.x))
        state.hist.nit[1] = 0
    end

    if state.triple_loop
        use_adapg = state.inner_solver == AdaProx.adaptive_proxgrad
        # `adaptive_proxgrad` is driven through `run_adapg` and needs no
        # pre-built solver object.
        ffb = use_adapg ? nothing : _gevp_inner_solver(solver, state, k)

        x_curr = copy(state.x)
        x_new = copy(state.x)
        iters = 0.
        ρ = state.ρ_rule(solver.β(k), state.y)

        # The loop index gets its own name: a loop variable called `k` would
        # shadow the outer iteration index inside the body, and `solver.β(k)`
        # below would then receive the loop counter instead.
        for ippm_it = 1:state.ippm_max_it
            F = ProximalOperators.Regularize(gevp, 2 * ρ, x_curr)

            if use_adapg
                x_ffb, iters_ffb = run_adapg(
                    F,
                    solver.g,
                    x_curr,
                    (x, grad_x, temp, tol) -> LA.norm(grad_x) <= tol,
                    state.inner_tol / 4.,
                    state.inner_max_it,
                    state.γ_rule(solver.β(k), state.y)
                )
            else
                x_ffb, iters_ffb = ffb(x0 = x_curr, f = F, g = solver.g)
            end
            copyto!(x_new, x_ffb)
            iters += iters_ffb

            if 2 * ρ * LA.norm(x_new - x_curr, Inf) <= 2 * state.inner_tol
                break
            end

            # Advance the proximal centre; without this every pass would
            # re-solve the identical subproblem from the identical start.
            copyto!(x_curr, x_new)
        end

        copyto!(state.x, x_new)
    else
        ffb = _gevp_inner_solver(solver, state, k)

        x_curr, iters = ffb(x0 = state.x, f = gevp, g = solver.g)
        copyto!(state.x, x_curr)
    end

    if state.verbose
        println("-----------------\tk=$(k)")
        println("Inner iterations: \t$(iters)")
        println("Objective: \t\t$(solver.f(state.x) + solver.g(state.x))")
        println("Constraint violation: \t$(LA.norm(solver.A(state.x)))")
        println()
    end

    if state.logging
        state.hist.objective[k+1] = solver.f(state.x) + solver.g(state.x)
        state.hist.feasibility[k+1] = LA.norm(solver.A(state.x))
        state.hist.nit[k+1] = state.nit + iters
    end

    return iters
end

export generate_gevp_data, GEVPSolver