using Test, Random, LinearAlgebra, Compat

include("ellipse_core.jl");


# This file considers the problem
# minimize
#   w₁ KL(P₁, Q_₁) + w₂ KL(P₂, Q_₂)
# such that
#   Q₁, Q₂ have 1+ϵ moment ≤ B
#   α mean(Q₁) + β CVaR(Q₁, θ)  ≤  α mean(Q₂) + β CVaR(Q₂, θ)
#
# TODO: actually implement α, β params

struct ProblemParams
    ϵ  # moment power
    B  # moment bound
    θ  # quantile
    α  # scaling of mean in objective
    β  # scaling of cvar

    function ProblemParams(ϵ, B, θ, α=0, β=1)
        @assert B > 0;
        @assert 0 < θ < 1;
        @assert ϵ > 0;
        @assert α ≥ 0;
        @assert β ≥ 0;
        @assert α + β > 0;
        new(ϵ, B, θ, α, β)
    end
end

function zbds(p::ProblemParams)
    zlu = ((1-p.θ)*(p.α+p.β/(1-p.θ))^(1+1/p.ϵ) +
              p.θ *(p.α            )^(1+1/p.ϵ))^(p.ϵ/(1+p.ϵ))
    Be = p.B^(1/(1+p.ϵ))

    # high limit is where the region becomes bounded
    # lower limit is where z cannot be the VaR in class L
    -(p.B/p.θ)^(1/(1+p.ϵ)), Be*(zlu+p.α)/p.β
end

function isfeasible(p::ProblemParams, z)
    l, u = zbds(p)
    l < z < u
end



# sum in the objective for arm 1
function f1(ν₁, ρ, λ₂, z, p::ProblemParams, xs, ηs)
    @assert ρ ≥ 0;
    @assert λ₂ ≥ 0;
    @assert length(ηs) == length(xs)

    v, d_ν₁, d_ρ, d_λ₂ = 0., 0., 0., 0., 0.;
    for (x, η) in zip(xs, ηs)
        q = p.β*z + p.α*x + p.β/(1-p.θ)*max(x-z,0);
        f = abs(x)^(1+p.ϵ)-p.B
        a = 1 - ν₁ + λ₂*f + ρ*q;
        @assert a > 0 "ln($a), ν₁: $ν₁, ρ:$ρ, λ₂:$λ₂, θ:$p.θ, z:$z, B:$p.B, ϵ:$p.ϵ, f:$f, q:$q, x:$x";
        v    += η*log(a);
        d_ν₁ += η*-1/a;
        d_ρ  += η*q/a;
        d_λ₂ += η*f/a;
    end
    v, (d_ν₁, d_ρ, d_λ₂)
end

# sum in the objective for arm 2 
function f2(ν₁, ρ, ν₂, κ, p::ProblemParams, xs, ηs)
    @assert ρ ≥ 0;
    @assert ν₂ ≥ 0;
    @assert length(ηs) == length(xs)

    v, d_ν₁, d_ρ, d_ν₂, d_κ = 0., 0., 0., 0., 0.;
    for (x, η) in zip(xs, ηs)
        f = abs(x)^(1+p.ϵ)-p.B
        a = 1 + ν₁ + ν₂*f - p.β*κ - p.α*ρ*x - p.β/(1-p.θ)*max(ρ*x-κ,0);
        @assert a > 0 "ln($a), ν₁:$ν₁, ν₂:$ν₂, f:$f, κ:$κ, x:$x, η:$η";
        q = p.β*(ρ*x≥κ)/(1-p.θ)/a
        v    += η*log(a);
        d_ν₁ += η*1/a;
        d_ρ  += η*(-p.α*x/a-x*q);
        d_ν₂ += η*f/a;
        d_κ  += η*(-p.β/a + q);
    end
    v, (d_ν₁, d_ρ, d_ν₂, d_κ)
end



# compute objective, sub-gradient and the two separate KLs
#
# computes the lagrange dual function to
# \min_{Q in alt} \sum_i w_i KL(P̂ᵢ‖Qᵢ)

function obj(ν₁, ρ, λ₂, ν₂, κ, z, p::ProblemParams, w₁, xs₁, ηs₁, w₂, xs₂, ηs₂)
    v₁, (d1_ν₁, d1_ρ, d1_λ₂)       = f1(ν₁, ρ, λ₂,    z, p, xs₁, ηs₁);
    v₂, (d2_ν₁, d2_ρ, d2_ν₂, d2_κ) = f2(ν₁, ρ, ν₂, κ,    p, xs₂, ηs₂);

    @assert sum(ηs₁) ≈ 1
    @assert sum(ηs₂) ≈ 1

    kl₁ = v₁+log((w₁+w₂)/2w₁);
    kl₂ = v₂+log((w₁+w₂)/2w₂);

    # objective value and sub-gradient
    w₁*kl₁ + w₂*kl₂,
        w₁ .* (d1_ν₁, d1_ρ, d1_λ₂, 0.,    0.  ) .+
        w₂ .* (d2_ν₁, d2_ρ, 0.,    d2_ν₂, d2_κ),
    (kl₁, kl₂)
end












function inRegion1(ν₁, ρ, λ₂, z, p::ProblemParams)
    @assert isfeasible(p, z) "Bad input for Region 1: B $p.B  θ $p.θ  z $z"
    @assert ρ ≥ 0;
    @assert λ₂ ≥ 0;

    # TODO: make this case distinction work for λ₂ = 0
    1 - ν₁ - λ₂*p.B + if z + (p.α*ρ/λ₂/(1+p.ϵ))^(1/p.ϵ) ≥ 0
        # function decreasing already from zero. Give value at 0
        z*p.β*ρ - p.ϵ*λ₂^(-1/p.ϵ)*(p.α*ρ/(1+p.ϵ))^(1+1/p.ϵ)
    elseif z + ((p.α+p.β/(1-p.θ))*ρ/λ₂/(1+p.ϵ))^(1/p.ϵ) ≤ 0
        # function increasing still at 1. Give value at 1
        -z*p.β*ρ*p.θ/(1-p.θ) - p.ϵ*λ₂^(-1/p.ϵ)*((p.α+p.β/(1-p.θ))*ρ/(1+p.ϵ))^(1+1/p.ϵ)
    else
        # proper maximum on (0,1).
        @assert z < 0 "got $z for $((ν₁, ρ, λ₂, z, p))"
        λ₂*(-z)^(1+p.ϵ) + z*ρ*(p.α+p.β)
    end ≥ 0
end

function inRegion2(ν₁, ρ, ν₂, κ, p::ProblemParams)
    @assert ρ ≥ 0;
    @assert ν₂ ≥ 0;

    1 + ν₁ - ν₂*p.B ≥ max(
    p.β*κ               + ν₂^(-1/p.ϵ)*p.ϵ*(ρ*p.α/(1+p.ϵ))^(1+1/p.ϵ)
    ,
    - p.β*κ*p.θ/(1-p.θ) + ν₂^(-1/p.ϵ)*p.ϵ*(ρ*(p.α+p.β/(1-p.θ))/(1+p.ϵ))^(1+1/p.ϵ)
    )
end



# are parameters in region? (for testing)
function inRegion(ν₁, ρ, λ₂, ν₂, κ, z, p::ProblemParams)
    ρ ≥ 0 && λ₂ ≥ 0 && ν₂ ≥ 0 && inRegion1(ν₁, ρ, λ₂, z, p) && inRegion2(ν₁, ρ, ν₂, κ, p)
end

# get a random problem (for testing)
function randPrbl(rng)
    θ = rand(rng)
    ϵ = .05 + .95*rand(rng) # don't make it overly small
    B = 4*rand(rng)
    α = rand();
    β = rand();

    p = ProblemParams(ϵ, B, θ, α, β)

    z = let α = rand(rng), (lbd, ubd) = zbds(p)
        α*lbd + (1-α)*ubd
    end

    @assert isfeasible(p, z)

    z, p
end

# get a feasible point by rejection sampling (brutal!, used in testing)
function randPoint(rng, z, p::ProblemParams)

    if false
        # this is what we want to do:
        # sample uniformly from the region
        # but it accepts with much too small probability
        mins, maxs = enclosing_box(z, p)
        while true
            ν₁, ρ, λ₂, ν₂, κ = mins .+ rand(rng, 5).*(maxs .- mins);
            if inRegion(ν₁, ρ, λ₂, ν₂, κ, z, p)
                println("Got one")
                return  ν₁, ρ, λ₂, ν₂, κ
            end
        end
    else
        while true
            # TODO: I don't trust this covers the region well
            # see above
            ν₁ = randn(rng);
            ρ = rand(rng)
            λ₂ = rand(rng)
            ν₂ = rand(rng)
            κ = randn(rng);

            if inRegion(ν₁, ρ, λ₂, ν₂, κ, z, p)
                return  ν₁, ρ, λ₂, ν₂, κ
            end
        end
    end
end

function randRegion(rng)
    pbl = randPrbl(rng)
    randPoint(rng, pbl...)..., pbl...
end






# These are the dual region constraints.

function c1(ν₁, ρ, λ₂, ν₂, κ, z, p::ProblemParams)
    if λ₂ < 0
        -λ₂, # necessary to get feasible
        (0., 0., -1., 0., 0.)
    elseif ρ < 0
        -ρ,
        (0., -1., 0., 0., 0.)
    else
        # TODO: as for Region1, try to make this case distinction work for λ₂ = 0
        if z + (p.α*ρ/λ₂/(1+p.ϵ))^(1/p.ϵ) ≥ 0
            - 1 + ν₁ + λ₂*p.B - z*p.β*ρ + p.ϵ*λ₂^(-1/p.ϵ)*(p.α*ρ/(1+p.ϵ))^(1+1/p.ϵ), (
                1.,
                -z*p.β + p.α*(p.α*ρ/λ₂/(1+p.ϵ))^(1/p.ϵ),
                p.B - (p.α*ρ/λ₂/(1+p.ϵ))^(1+1/p.ϵ),
                0.,
                0.
            )
        elseif z + ((p.α+p.β/(1-p.θ))*ρ/λ₂/(1+p.ϵ))^(1/p.ϵ) ≤ 0
            - 1 + ν₁ + λ₂*p.B + z*p.β*ρ*p.θ/(1-p.θ) + p.ϵ*λ₂^(-1/p.ϵ)*((p.α+p.β/(1-p.θ))*ρ/(1+p.ϵ))^(1+1/p.ϵ), (
                1.,
                z*p.β*p.θ/(1-p.θ) + (p.α+p.β/(1-p.θ))*((p.α+p.β/(1-p.θ))*ρ/λ₂/(1+p.ϵ))^(1/p.ϵ),
                p.B - ((p.α+p.β/(1-p.θ))*ρ/λ₂/(1+p.ϵ))^(1+1/p.ϵ),
                0.,
                0.
            )
        else
            @assert z < 0 "got $z with $((ν₁, ρ, λ₂, ν₂, κ, z, p))"
            - 1 + ν₁ + λ₂*p.B - λ₂*(-z)^(1+p.ϵ) - z*ρ*(p.α+p.β), (
                1.,
                -z*(p.α+p.β),
                p.B - (-z)^(1+p.ϵ),
                0.,
                0.
            )
        end
    end
end


function c2(ν₁, ρ, λ₂, ν₂, κ, z, p::ProblemParams)
    if ν₂ < 0
        -ν₂, # necessary to get feasible
        (0., 0., 0., -1., 0.)
    elseif ρ < 0
        -ρ,
        (0., -1., 0., 0., 0.)
    else
        -1 - ν₁ + ν₂*p.B - p.β*κ*p.θ/(1-p.θ) + ν₂^(-1/p.ϵ)*p.ϵ*(ρ*(p.α+p.β/(1-p.θ))/(1+p.ϵ))^(1+1/p.ϵ), (
            -1.,
            (p.α+p.β/(1-p.θ))*(ρ/ν₂*(p.α+p.β/(1-p.θ))/(1+p.ϵ))^(1/p.ϵ),
            0.,
            p.B - (ρ/ν₂*(p.α+p.β/(1-p.θ))/(1+p.ϵ))^(1+1/p.ϵ),
            - p.β*p.θ/(1-p.θ)
        )
    end
end

# variable order is ν₁, ρ, λ₂, ν₂, κ
function c3(ν₁, ρ, λ₂, ν₂, κ, z, p::ProblemParams)
    if ν₂ < 0
        -ν₂, # necessary to get feasible
        (0., 0., 0., -1., 0.)
    elseif ρ < 0
        -ρ,
        (0., -1., 0., 0., 0.)
    else
        -1 - ν₁ + ν₂*p.B + p.β*κ + ν₂^(-1/p.ϵ)*p.ϵ*(ρ*p.α/(1+p.ϵ))^(1+1/p.ϵ), (
            -1.,
            p.α*(ρ/ν₂*p.α/(1+p.ϵ))^(1/p.ϵ),
            0.,
            p.B - (ρ/ν₂*p.α/(1+p.ϵ))^(1+1/p.ϵ),
            p.β
        )
    end
end


# merge above constraints by taking max
function cons(ν₁, ρ, λ₂, ν₂, κ, z, p::ProblemParams)
    argmax(first, (c1(ν₁, ρ, λ₂, ν₂, κ, z, p),
                   c2(ν₁, ρ, λ₂, ν₂, κ, z, p),
                   c3(ν₁, ρ, λ₂, ν₂, κ, z, p)));
end




function enclosing_box(z, p::ProblemParams)
    # variable order is ν₁, ρ, λ₂, ν₂, κ

    @assert isfeasible(p, z)

    Be = p.B^(1/(1+p.ϵ));
    c1 = Be*(p.α + p.β/(1-p.θ));
    c2 = Be*p.α;
    c3 = -Be*p.α + p.β*z + p.β/(1-p.θ)*max(0, -(Be+z))


    Z = p.ϵ*((1-p.θ)*((p.α + p.β/(1-p.θ))/(1+p.ϵ))^(1+1/p.ϵ) +
                p.θ *( p.α               /(1+p.ϵ))^(1+1/p.ϵ))
    zlu = (1+p.ϵ)*(Z/p.ϵ)^(p.ϵ/(1+p.ϵ))
    zlu2 = ((1-p.θ)*(p.α+p.β/(1-p.θ))^(1+1/p.ϵ) +
               p.θ *(p.α            )^(1+1/p.ϵ))^(p.ϵ/(1+p.ϵ))
    @assert zlu ≈ zlu2

    BBλ = zlu*Be;

    @assert c3 < BBλ # or region unbounded

    offλ = if -BBλ*(1-p.θ)/p.θ ≤ z*p.β ≤ BBλ
        0
    elseif z*p.β > BBλ
        (-BBλ + z*p.β)/p.α
    else
        (-BBλ*(1-p.θ) - z*p.β*p.θ)/(p.α*(1-p.θ)+p.β)
    end
    @assert offλ ≥ 0 "hey! got $offλ"

    offν = if -Be*(p.α*(1-p.θ)+p.β)/p.θ ≤ z*p.β ≤ Be*p.α
        0
    elseif z*p.β > Be*p.α
        z*p.β - Be*p.α
    else
        (-z*p.β*p.θ - Be*(p.α*(1-p.θ)+p.β))/(1-p.θ)
    end
    @assert offν ≥ 0 "hey! got $offν"

    b = (max(0,c3/Be)^(1+1/p.ϵ) - p.α^(1+1/p.ϵ)) / ((p.α + p.β/(1-p.θ))^(1+1/p.ϵ) - p.α^(1+1/p.ϵ))

    @assert b < 1-p.θ
    # b < 1-p.θ is equivalent to max(0,c3) < BBλ

    mins = (-1.,
            0.,
            0.,
            0.,
            -2*(1-p.θ)/p.θ/p.β
            );

    maxs = (1 + max(2c3,0)/(BBλ-c3),
            2/(BBλ - c3),
            2/(p.B - offλ^(1+p.ϵ)),
            2/(p.B - (Z/p.ϵ)^(-p.ϵ)*(offν/(1+p.ϵ))^(1+p.ϵ)),
            2/p.β*(1-p.θ)/(1-p.θ-max(0,b))
    )

    
    @assert all(mins .≤ maxs) "not-a-box for $p, $z:\n$(join(
("$v: $l $u" for (v,l,u) in zip((:ν₁, :ρ, :λ₂, :ν₂, :κ), mins, maxs)),'\n'))"
    mins, maxs
end



function enclosing_ellipse(z, p::ProblemParams)
    mins, maxs = enclosing_box(z, p)

    @assert all(mins .≤ maxs)

    c = [(mins .+ maxs) ./ 2 ...]
    P = diagm([5 ./ 4 .* (maxs .- mins).^2 ...])

    @assert collect(c .- mins)'inv(P)*collect(c .- mins) ≈ 1
    @assert collect(c .- maxs)'inv(P)*collect(c .- maxs) ≈ 1
    
    c, P
end






# pristine ellipsoid call without side effect exploitation
function nice_ellipse(w₁, xs₁, ηs₁, w₂, xs₂, ηs₂, z, p::ProblemParams; kwargs...)
    c, P = enclosing_ellipse(z, p);

    feasible_init = ((w₂-w₁)/(w₁+w₂), 0., 0., 0., 0.)
    
    pt, (val, g, kls) = ellipse_maximise(
        c, P,
        ((ν₁, ρ, λ₂, ν₂, κ),) -> obj(ν₁, ρ, λ₂, ν₂, κ, z, p, w₁, xs₁, ηs₁, w₂, xs₂, ηs₂),
        ((ν₁, ρ, λ₂, ν₂, κ),) -> cons(ν₁, ρ, λ₂, ν₂, κ, z, p);
        feasible_init, kwargs...
    )

    if any(kls .< 0)
        @warn "Got negative KLs $kls. Carrying on."
    end

    val, kls, pt
end


function ellipse(w₁, xs₁, ηs₁, w₂, xs₂, ηs₂, z, p::ProblemParams; kwargs...)

    @assert sum(ηs₁) ≈ 1
    @assert sum(ηs₂) ≈ 1

    c, P = enclosing_ellipse(z, p);

    # shared state between constraints and objective
    cache_pt = (0., 0., 0., 0., 0.);
    v₁, d1_ν₁, d1_ρ, d1_λ₂, v₂, d2_ν₁, d2_ρ, d2_ν₂, d2_κ, kl₁, kl₂ = zeros(11)


    function evil_cons((ν₁, ρ, λ₂, ν₂, κ))
        cache_pt = (ν₁, ρ, λ₂, ν₂, κ)
        viol, ∇ = cons(ν₁, ρ, λ₂, ν₂, κ, z, p);

        # Constraint evaluations are cheap. Always process them first.
        if viol > 0 return viol, ∇ end

        # we evaluate parts of the objective already here
        # THESE GET CACHED IN THE STATE ABOVE!!!!
        v₁, (d1_ν₁, d1_ρ, d1_λ₂)       = f1(ν₁, ρ, λ₂,    z, p, xs₁, ηs₁);
        v₂, (d2_ν₁, d2_ρ, d2_ν₂, d2_κ) = f2(ν₁, ρ, ν₂, κ,    p, xs₂, ηs₂);

        kl₁ = v₁+log((w₁+w₂)/2w₁);
        kl₂ = v₂+log((w₁+w₂)/2w₂);

        # constrain the individual KLs to be postive
        argmax(first, (
            (-kl₁, -1 .* (d1_ν₁, d1_ρ, d1_λ₂, 0.,    0.  )),
            (-kl₂, -1 .* (d2_ν₁, d2_ρ, 0.,    d2_ν₂, d2_κ)),
            (viol, ∇))) # i.e. no constraint violated.
    end

    function evil_obj((ν₁, ρ, λ₂, ν₂, κ))
        @assert cache_pt == (ν₁, ρ, λ₂, ν₂, κ)
        # if we get here, the above cons function has evaluated the
        # vs, kls and derivatives already AT OUR CURRENT ARGUMENT VALUE

        w₁*kl₁ + w₂*kl₂,
            w₁ .* (d1_ν₁, d1_ρ, d1_λ₂, 0.,    0.  ) .+
            w₂ .* (d2_ν₁, d2_ρ, 0.,    d2_ν₂, d2_κ),
        (kl₁, kl₂)
    end

    # TODO: I fudged this horribly :)
    feasible_init = ((w₂-w₁)/(w₁+w₂), eps(), eps(), eps(), 0.)
    evil_cons(feasible_init); # we need to make the first call to 'obj' by ellipsoid actually work!
    pt, (val, g, kls) = ellipse_maximise(c, P, evil_obj, evil_cons; feasible_init, kwargs...)

    if any(kls .< 0)
        @warn "Got negative KLs $kls. Carrying on."
    end

    val, kls, pt
end









# find the feasible point pt + λ (pt-c) of highest λ
function bastardize(pt, c, z, p::ProblemParams)
    @assert inRegion(pt..., z, p)

    # find initial bracket
    lo = 0;
    hi = 1;

    loc(λ) = c .+ λ .* (pt.-c)

    while inRegion(loc(hi)..., z, p)
        lo = hi;
        hi *= 2;
    end

    # binary search refine bracket
    # doing 64 steps here works too well; it gives λ₂ actually zero
    # which breaks the region partition check
    for it = 1:32
        mid = (lo+hi)/2;
        if inRegion(loc(mid)..., z, p)
            @assert cons(loc(mid)..., z, p)[1] ≤ 0
            lo = mid;
        else
            @assert cons(loc(mid)..., z, p)[1] > 0
            hi = mid;
        end
    end

    # lo is feasible
    loc(lo)
end
