# Run the learning algorithm, parameterised by a sampling rule.
# The stopping and recommendation rules are common to all sampling rules.
#

using CPUTime
using Printf
using Random
using SparseArrays, LinearAlgebra
include("cvar.jl");
include("goldrat_search.jl");
include("libs/tidnabbil/regret.jl");
include("libs/tidnabbil/thresholds.jl");
include("libs/tidnabbil/tracking.jl");
include("myellipse.jl")
include("kolm_proj.jl")




# concave objective on the simplex
# minimise over z and over the KLInf problem
#
# Convenience method for unweighted samples: every sample of arm k receives
# the same probability weight 1/length(xs[k]); forwards to the weighted method.
function glrt(w, ⋆, xs, pp::ProblemParams; kwargs...)
    uniform_weights = [fill(1 / length(sample), length(sample)) for sample in xs]
    return glrt(w, ⋆, xs, uniform_weights, pp; kwargs...)
end

# GLRT statistic at weights `w` with arm `⋆` as the candidate best arm, for the
# instance whose arm k is described by support points `xs[k]` carrying
# probability weights `ηs[k]`.
#
# For every challenger k ≠ ⋆ the threshold z is optimised with golden-ratio
# search (`ϕsect` over the interval `zbds(pp)`, at most `maxiter` steps) on the
# pairwise `ellipse` objective; `tol` and any extra keywords go to `ellipse`.
# Golden-ratio search is a gradient-free bisection for quasiconvex functions.
# The "z objective" is not quite quasiconvex for small sample sizes; the hope
# is that forced exploration makes it quasiconvex quickly, so we don't worry
# too much about it.
#
# Returns the minimum over challengers of the tuple
# `(value, sparsevec([k, ⋆], kl, K))`, where `kl` are the two per-arm terms
# reported by `ellipse` (runit uses this vector as the gradient ∇). Tuples
# compare lexicographically, so equal values are tie-broken on that vector.
function glrt(w, ⋆, xs, ηs, pp::ProblemParams; maxiter=64, tol=1e-10, kwargs...)
    K = length(xs)

    # value and sparse per-arm vector of the inner problem against challenger k
    function versus(k)
        zobjective(z) = ellipse(w[k], xs[k], ηs[k],
                                w[⋆], xs[⋆], ηs[⋆],
                                z, pp; tol, kwargs...)
        # NOTE(review): ϕsect appears to return (objective at the optimum, optimal z)
        (value, kls, _), _ = ϕsect(zobjective, zbds(pp)...; maxiter=maxiter)
        return value, sparsevec([k, ⋆], [kls...], K)
    end

    return minimum(versus(k) for k in 1:K if k ≠ ⋆)
end



# Maximise the GLRT objective `w -> glrt(w, ⋆, xs, pp)` over the probability
# simplex with an ellipsoid method, returning `(v, w⋆)`: the value reported by
# `ellipse_maximise` and the corresponding weight vector.
#
# The simplex is enclosed by the ellipsoid {y : (y-u)' P⁻¹ (y-u) ≤ 1} centred
# at the uniform vector u with P = (1-1/K)·I; every vertex e_i lies on its
# boundary (asserted below). Non-negativity is imposed by the constraint
# `cons`, which returns the pair (maxᵢ -w[i], -e_i) for the maximising i.
# `unitsum_wrap` re-expresses objective, constraint and ellipsoid in new
# coordinates. NOTE(review): judging from their use here, `xmap` sends weight
# vectors to those coordinates and `wmap` sends points back; confirm in
# myellipse.jl. Extra keywords are forwarded to `ellipse_maximise`.
function wstar_ellipse(⋆, xs, pp::ProblemParams; kwargs...)
    K = length(xs)

    u = ones(K) ./ K   # uniform weights: centre of the enclosing ellipsoid
    ei(i) = I(K)[:, i] # i-th standard basis vector (simplex vertex)

    P = (1 - 1/K) * I(K)
    invP = inv(P) # loop-invariant: invert once, not once per vertex
    for i in 1:K
        d = ei(i) - u
        @assert d' * invP * d ≈ 1
    end

    # non-negativity constraints: (value, gradient) of maxᵢ -w[i]
    cons(w) = maximum((-w[i], .- ei(i)) for i in 1:K)

    print("Calling unitsum_wrap \r");
    fA, consA, cA, PA, wmap, xmap = unitsum_wrap(
        w -> glrt(w, ⋆, xs, pp),
        cons, u, P)

    # sanity check: mapped vertices lie on the boundary of the new ellipsoid
    invPA = inv(PA) # hoisted: PA is fixed for the whole loop
    for i in 1:K
        e = xmap(ei(i))
        d = e - cA
        @assert d' * invPA * d ≈ 1 "trouble for $i at $e, value $(d' * invPA * d)"
    end
    print("maximizing ellipse\r")
    # NOTE(review): presumably returns the maximiser and its (value, gradient)
    pt, (v, _) = ellipse_maximise(cA, PA, fA, consA; kwargs...)

    return v, wmap(pt)
end



# Reconstruct the final, sorted per-arm samples that the algorithm took, from
# the initial seed `seed`, the arm models `MAB` and the per-arm sample counts
# `N`. Each arm has its own MersenneTwister, seeded from MersenneTwister(seed),
# exactly as in `runit`.
function recover_xs(seed, MAB, N)
    K = length(N)
    arm_seeds = rand(MersenneTwister(seed), UInt64, K)
    rngs = MersenneTwister.(arm_seeds)
    # The batch call rand(rngs[k], MAB[k], N[k]) was observed not to give the
    # same values, so draw one value at a time exactly as the algorithm did.
    return [sort!(Float64[rand(rngs[k], MAB[k]) for _ in 1:N[k]]) for k in 1:K]
end

# Legacy positional interface: bundle the problem parameters and forward to
# the ProblemParams method. The legacy order is (B, θ, ϵ), whereas
# ProblemParams is constructed as ProblemParams(ϵ, B, θ).
runit(seed, MAB, βs, B, θ, ϵ; kwargs...) =
    runit(seed, MAB, βs, ProblemParams(ϵ, B, θ); kwargs...)

# Run the learning algorithm once and return its stopping records.
#
# Arguments
#   seed     master seed; arm k samples from its own MersenneTwister seeded
#            from MersenneTwister(seed), so (seed, final counts) suffice to
#            rebuild every sample (see `recover_xs`).
#   MAB      per-arm sample sources, sampled as `rand(rng, MAB[k])`
#            (presumably distributions; confirm at call sites).
#   βs       stopping thresholds, callables β(t), *in increasing order*.
#   pp       problem parameters (pp.θ, pp.ϵ and pp.B are used here).
#   verbose  print per-round diagnostics.
#
# Each time the GLRT statistic exceeds the smallest remaining threshold, that
# threshold is consumed and `(recommended arm, per-arm counts, CPU μs since
# start)` is appended to the result. Once every threshold is consumed, the
# vector of records is returned.
#
# Throws an ArgumentError if `βs` is empty, since the run could never stop.
function runit(seed, MAB, βs, pp::ProblemParams; verbose=false)
    #@assert all(Ex².(MAB) .≤ B) "second-moment constraint"

    βs = collect(βs) # mutable copy: thresholds are consumed with popfirst!
    isempty(βs) && throw(ArgumentError("βs must contain at least one threshold"))

    println("starting run $seed")

    K = length(MAB)

    # instantiate one rng per arm, so that given seed and N we can
    # reconstruct all samples
    rngs = MersenneTwister.(rand(MersenneTwister(seed), UInt64, K))
    xs = [Array{Float64,1}() for k in 1:K] # all samples, each kept sorted

    # one AdaHedge learner per candidate for the (projected) best-CVaR arm
    ahs = [AdaHedge(K) for k in 1:K]
    for ah in ahs
        ah.Δ = 1e-4 # lower bound guess for the scale of the gradients
    end

    tracking = ForcedExploration(CTracking(zeros(K)))

    baseline = CPUtime_us()
    R = Tuple{Int64, Array{Int64,1}, UInt64}[] # collect return values

    while true
        t = sum(length.(xs)) + 1 # index of this round

        # sampling rule: choose the arm to sample this round
        arm = if t ≤ K
            track(tracking, length.(xs), ones(K) ./ K)
        else
            # empirical best-looking arm (NOTE: best = *minimal* CVaR)
            cvrs = [cvar(ones(length(xs[k]))./length(xs[k]), xs[k], pp.θ) for k in 1:K]
            ⋆ = argmin(cvrs)

            # stopping statistic, weighted by the actual sample counts
            GLRT, _ = glrt(length.(xs), ⋆, xs, pp)
            @assert GLRT ≥ -1e-4 "GLRT is $GLRT"

            # consume every threshold crossed this round, recording the
            # recommendation; stop once none remain
            while GLRT > βs[1](t)
                popfirst!(βs)
                push!(R, (⋆, length.(xs), CPUtime_us()-baseline))
                if isempty(βs)
                    println("finished run $seed in ", (CPUtime_us()-baseline)/1e6, " seconds")
                    return R
                end
            end

            # Idea: we are executing a saddle point learner, but feed this
            # learner the projected empirical distribution. This allows our
            # continuity result to kick in, so that the learner converges to
            # w^*(μ)

            # Step 1: project the empirical distributions (in sup-norm),
            # compute their CVaRs and pick the best (again: minimal CVaR)
            emps = [kolm_proj(pp.ϵ, pp.B, xs[k], ones(length(xs[k]))./length(xs[k])) for k in 1:K]
            empcvrs = [cvar(ηs, ys, pp.θ) for (ys, ηs) in emps]
            empstar = argmin(empcvrs)

            # Step 2: query the learner for the projected best arm
            w = act(ahs[empstar])

            # Step 3: best response: GLRT value and gradient at w
            v, ∇ = glrt(w, empstar, getindex.(emps, 1), getindex.(emps, 2), pp)

            #∇ /= max(1,sum(∇)); # Menard clipping

            if verbose
                println("\n");
                println("t ", @sprintf("%6d", t), "  ",
                        "⋆ $(⋆)", "  ",
                        "empstar $(empstar)", "  ",
                        "GLRT ", @sprintf("%.4f", GLRT), "  ",
                        "proj ", @sprintf("%.0f", βs[1](t)/(GLRT/t)), "  ",
                        "proj' ", @sprintf("%.0f", βs[1](t)/v))
                println("N     ", join(map(x->@sprintf("%8d", x), length.(xs)), ", "));
                println("CVaR  ", join(map(x->@sprintf("%8.2f", x), empcvrs), ", "));
                println("ah w  ",
                        join(map(x->@sprintf("%8.5f", x), act(ahs[empstar])), ", ")
                        )
                println("ah L  ",
                        join(map(x->@sprintf("%8.5f", x), ahs[empstar].L), ", ")
                        )

                println("∇     ",
                        join(map(x->@sprintf("%8.5f", x), Vector(∇)), ", ")
                        )
            end

            # Step 4: update the learner with -∇ (sign flip as in the
            # original; presumably AdaHedge consumes losses while we
            # maximise the GLRT)
            incur!(ahs[empstar], -∇)

            # track the learner's weights; the chosen arm is this branch's value
            track(tracking, length.(xs), w)
        end

        # sample, inserting at the sorted position so xs[arm] stays sorted
        # (the invariant cvar relies on) without re-sorting the whole vector
        draw = rand(rngs[arm], MAB[arm])
        insert!(xs[arm], searchsortedlast(xs[arm], draw) + 1, draw)
    end
end
