using LinearAlgebra

"""
    ellipse_minimise(cin, Pin, obj, cons; feasible_init=nothing, ubd=Inf,
                     verbose=false, tol=1e-5, maxiter=5000)

Minimise a convex function subject to convex `≤ 0` constraints with the
(deep-cut) ellipsoid method.

The search starts from the ellipsoid `{x : (x - cin)' Pin⁻¹ (x - cin) ≤ 1}`,
which should contain a minimiser for the returned bounds to be meaningful.

# Arguments
- `cin`: initial centre (not modified; a floating-point copy is iterated).
- `Pin`: positive-definite shape matrix of the initial ellipsoid (factorised
  with `cholesky(...; check=true)`).
- `obj(x)`: returns `(value, gradient, extra...)`; the whole tuple is kept for
  the best point found.
- `cons(x)`: returns `(viol, subgradient)`; `viol > 0` means `x` is infeasible.
  `viol == Inf` is accepted and produces a central cut.

# Keywords
- `feasible_init`: optional known feasible point; seeds the incumbent and `ubd`.
- `ubd`: known upper bound on the optimal value (enables deeper cuts).
- `verbose`: print one progress line per iteration.
- `tol`: stop once the gap `ubd - lbd ≤ tol`.
- `maxiter`: iteration budget.

# Returns
`(bestpt, bestinfo)`, where `bestinfo` is the `obj` tuple at `bestpt`. When the
method bails out early (detected infeasibility or a degenerate ellipsoid),
`bestinfo` is `nothing` if no feasible point was ever evaluated.

# Throws
`AssertionError` when the iteration budget is exhausted. In non-verbose mode
the identical run is first replayed verbosely to print a trace.
"""
function ellipse_minimise(cin, Pin, obj, cons; feasible_init=nothing, ubd=Inf, verbose=false, tol = 1e-5, maxiter=5000)
    n = length(cin)
    ubd_init = ubd  # remembered so that a verbose replay reproduces this exact run

    # Iterate in floating point so the in-place centre update also works for
    # integer-valued `cin`.
    c = float.(collect(cin))  # current centre (a private copy we mutate)
    Pg = similar(c)           # buffer for P*g / √(g'Pg)
    g = similar(c)            # cut direction
    zq = similar(c)           # buffer for U*g

    if feasible_init === nothing
        # No incumbent yet. WARNING: we may return `bestinfo === nothing` if
        # no feasible point is ever encountered.
        bestpt = copy(c)
        bestinfo = nothing
        @info "Seriously consider passing in an initial feasible point. Please." maxlog=1
    else
        # A known feasible point gives an incumbent and an upper bound up front.
        bestpt = float.([feasible_init...])
        bestinfo = obj(bestpt)
        ubd = min(ubd, bestinfo[1])
    end

    # Cholesky factorisation of the current shape matrix: P = L*U, U = L'.
    C = cholesky(Matrix(Pin), check=true)

    lbd = -Inf  # best lower bound on the optimal value

    for it in 1:maxiter
        # Constraints take precedence: cut on a violated constraint first.
        viol, cg = cons(c)

        # Cut direction: constraint subgradient if infeasible, else objective gradient.
        g .= if viol > 0  # strict violation
            @assert norm(cg) > 0
            cg
        else
            info = obj(c)  # (value, gradient, extra...)
            val, grad = info

            ubd = min(ubd, val)

            # Store the best point so far (NB: may still be worse than a user-supplied ubd).
            if bestinfo === nothing || val < bestinfo[1]
                bestpt .= c
                bestinfo = info
            end

            grad
        end

        # Cut offset (deep cut), i.e. the minimum progress needed from here
        # towards the optimum. An infinite violation carries no usable
        # magnitude, so it falls back to a central cut.
        h = if viol > 0
            isfinite(viol) ? viol : zero(viol)
        else
            val - ubd
        end

        @assert h ≥ 0

        # ‖g‖_P = √(g'Pg) = ‖U*g‖
        zq .= C.U * g
        sqgPg = norm(zq)

        # Minimum over the ellipsoid of the linearisation val + g'(x - c).
        if viol ≤ 0
            lbd = max(lbd, val - sqgPg)
        end

        if verbose
            if viol ≤ 0
                println("$it  Objective round  ‖g‖_P is $(sqgPg), current gap is $ubd - $lbd = $(ubd-lbd), best so far $(ubd)");
            else
                println("$it  Constraint round ‖g‖_P is $(sqgPg), current gap is $ubd - $lbd = $(ubd-lbd), current viol is $viol")
            end
        end

        if ubd - lbd ≤ tol
            # √(g'Pg) ≥ sub-optimality of bestpt. We are done!
            @assert bestinfo !== nothing
            return bestpt, bestinfo
        end

        # The whole ellipsoid violates the linearised constraint.
        if Inf > viol > sqgPg
            @warn "Remaining ellipse is infeasible since $(sqgPg) < $viol. Mathematically this can only happen if your problem is infeasible. Anway, bailing now. Bye!"
            return bestpt, bestinfo
        end

        if sqgPg ≤ 0
            @warn "$g has zero norm in $(Matrix(C)) at $c. Infeasible problem? Bailing now. Bye!"
            return bestpt, bestinfo
        end

        Pg .= C.L * zq ./ sqgPg  # P*g / √(g'Pg)
        α = h / sqgPg            # cut depth in [0, 1]

        # Update the centre.
        c .-= (1 + n*α) / (n + 1) .* Pg

        # Update the shape matrix.
        if n == 1
            # Exact deep cut on an interval [c - r, c + r]: the surviving part
            # [c - r, c - α r] (for g > 0) has half-width (1 - α) r / 2.
            C.UL .*= (1 - α) / 2
        else
            lowrankdowndate!(C, sqrt(2*(1+n*α)/((n+1)*(1+α))) .* Pg)
            C.UL .*= sqrt((1-α^2)*n^2/(n^2-1))
        end
    end

    if !verbose
        println(" ---- ERROR ---- ")
        # Replay the identical run verbosely to print a trace (it throws itself).
        ellipse_minimise(cin, Pin, obj, cons; feasible_init=feasible_init, ubd=ubd_init,
                         verbose=true, tol=tol, maxiter=maxiter)
    end

    throw(AssertionError("Ellipse ran out of iteration budget $maxiter. Current volume is $(sqrt(det(C)))"))
end



"""
    neg_val_grad(v, g, info...)

Negate both the value `v` and the gradient `g` (element-wise) of a
`(value, gradient, extra...)` tuple; extra entries pass through unchanged.
"""
neg_val_grad(v, g, info...) = (-v, broadcast(-, g), info...)


"""
    neg_val(v, g, info...)

Negate only the value `v` of a `(value, gradient, extra...)` tuple; the
gradient `g` and any extra entries are returned unchanged.
"""
neg_val(v, g, info...) = (-v, g, info...)



"""
    ellipse_maximise(c, P, obj, cons; lbd=-Inf, kwargs...)

Maximise a concave function subject to convex `≤ 0` constraints by
minimising its negation with `ellipse_minimise`.

`obj(x)` returns `(value, gradient, extra...)`. `lbd` is a known lower bound on
the optimal value; it is passed on as the upper bound `ubd = -lbd` of the
negated problem. Remaining keywords are forwarded to `ellipse_minimise`.

Returns `(pt, info)`, where `info[1]` is the objective value at `pt`, or
`info === nothing` when `ellipse_minimise` gave up without any feasible point.
"""
function ellipse_maximise(c, P, obj, cons; lbd=-Inf, kwargs...)
    negobj(x) = neg_val_grad(obj(x)...)
    pt, vginfo = ellipse_minimise(c, P, negobj, cons; ubd=-lbd, kwargs...)
    # ellipse_minimise may bail out with `nothing`; splatting it would throw a
    # MethodError, so pass it through unchanged.
    vginfo === nothing && return pt, nothing
    # NOTE(review): only the value is negated back, so info[2] is -∇obj(pt)
    # (the gradient of the minimised negation) — confirm callers expect this sign.
    return pt, neg_val(vginfo...)
end






"""
    equality_wrap(f, cons, c, P, A, b)

Eliminate the linear equality constraints `A*y == b` by reparametrising
`y = w(x) = Q*x + z`. The columns of `Q` are the trailing right singular vectors
of `A` (a null-space basis), and `z` is the minimum-norm solution of `A*z == b`.
This assumes `A` has full row rank `m = length(b)`, since `z` divides by the
first `m` singular values. `b` may be a scalar when `A` has a single row.

Returns `(fA, consA, cA, PA, w, x)`:
- `fA(x)`, `consA(x)`: `f` and `cons` evaluated at `w(x)`, with the gradient
  (second tuple entry) mapped to reduced coordinates by `Q'`. Any further
  tuple entries are passed through unchanged.
- `cA = x(c)`, `PA = Q'*P*Q`: starting centre and shape matrix in reduced
  coordinates.
- `w`, `x`: maps reduced → original and original → reduced coordinates.
"""
function equality_wrap(f, cons, c, P, A, b)
    bv = vcat(b)  # treat a scalar right-hand side as a 1-vector so z stays a Vector
    m = length(bv)

    q = svd(A, full=true)

    Q = q.V[:, m+1:end]                     # null-space basis of A
    z = q.V[:, 1:m] * ((q.U' * bv) ./ q.S)  # minimum-norm solution of A z = b

    w(x) = Q * x .+ z
    x(w) = Q' * w

    @assert isapprox(x(w(x(c))), x(c), atol=1e-5)

    # Map the gradient (second entry) into reduced coordinates and keep any
    # extra entries, matching the (value, gradient, extra...) convention.
    function warp∇(vg)
        v, ∇ = vg
        return (v, vec(Q' * ∇), Iterators.drop(vg, 2)...)
    end
    fA(x) = warp∇(f(w(x)))
    consA(x) = warp∇(cons(w(x)))

    cA = x(c)
    PA = Matrix(Symmetric(Q' * P * Q))

    return fA, consA, cA, PA, w, x
end


"""
    unitsum_wrap(f, cons, c, P)

Restrict the problem to the hyperplane `sum(y) == 1` by delegating to
`equality_wrap` with a single all-ones constraint row and right-hand side `1.0`.
"""
function unitsum_wrap(f, cons, c, P)
    onesrow = ones(1, length(c))  # 1×n constraint matrix encoding sum(y) == 1
    return equality_wrap(f, cons, c, P, onesrow, 1.0)
end
