using Plots, StatsPlots
include("libs/tidnabbil/binary_search.jl")
include("runit.jl")


# visualise the MAB: one panel per arm showing its pdf with vertical lines at mean, VaR and CVaR
function viz_mab(MAB)
    # Evaluation grid shared by all panels.
    grid = range(-5, 5, length=1000)

    # Common upper y-limit so every panel uses the same vertical scale.
    ytop = maximum(pdf.(arm, x) for arm in MAB, x in grid)

    # One panel per arm, titled by its index. The VaR/CVaR lines use the global level θ.
    panels = map(enumerate(MAB)) do (idx, arm)
        panel = plot(grid, pdf.(arm, grid), label="pdf", title="$idx", ylims=(0, ytop))
        vline!(panel, [mean(arm)], label="Mean")
        vline!(panel, [quantile(arm, θ)], label="VaR")
        vline!(panel, [cvar(arm, θ)], label="CVaR")
        panel
    end

    # Half of 2560x1440 looks nice when zoomed up.
    plot(panels...; size=(2560, 1440) .* (1/2))

    return gui()
end

# visualise the MAB jointly: the pdfs of all arms overlaid in a single plot
function viz_mab_joint(MAB)
    # Overlay the pdf of every arm on one shared canvas.
    # Each curve is labelled and coloured by the arm's index; the canvas is returned.
    grid = range(-5, 5, length=1000)
    canvas = plot(xlabel="x", legendtitle="Arm")
    for (idx, arm) in enumerate(MAB)
        plot!(canvas, grid, pdf.(arm, grid), label="$idx", color=idx)
    end
    return canvas
end




# visualise the stopping situation
# Diagnostic plot of the `ellipse` objective around the empirically best arm.
#
# `xs[k]` holds the samples of arm k; each arm is weighted uniformly when its CVaR is computed.
# Reads the globals `θ` (CVaR level) and `B` (bound) — they must be defined before calling.
# For every arm k ≠ ⋆ it does two things:
#   * plots the first output of `ellipse` over a grid of z values;
#   * marks the value and location returned by `ϕsect` on the full z interval.
# It also prints the grid minima. Returns the plot object.
function viz(xs)
    K = length(xs)
    # Empirical CVaR of each arm (uniform weights over its samples); ⋆ = arm with minimal CVaR.
    cvrs = [cvar(ones(length(xs[k]))./length(xs[k]), xs[k], θ) for k in 1:K];
    ⋆ = argmin(cvrs);

    # 200-point grid inside (-√(B/θ), √(B/(1-θ))), pulled in by 0.1 at both ends.
    # NOTE(review): presumably to avoid the interval endpoints themselves — confirm.
    zs = range(-sqrt(B/θ)+1e-1, sqrt(B/(1-θ))-1e-1, length=200);
    # Z: one row per grid point z, one column per arm k ≠ ⋆ (increasing k).
    # Entry = first output of `ellipse` at that z.
    Z = hcat([
        first.(ellipse.(length(xs[k]), Ref(xs[k]),
                        length(xs[⋆]), Ref(xs[⋆]),
                        zs, θ, B))
        for k in 1:K
        if k ≠ ⋆]...)

    # Debug output: overall grid minimum and per-arm (per-column) grid minima.
    println("minimum is ", minimum(Z))
    println("minima are ", minimum(Z, dims=1))

    # For each arm k ≠ ⋆, run `ϕsect` on the same objective over the full (unshrunk) interval.
    # Keep (v, z̃): the value and the z location it reports.
    # NOTE(review): `ϕsect` is defined elsewhere; presumably a golden-section style minimiser — confirm.
    # NOTE(review): the destructured `p` (and `kl`) are unused. Because `p` is also assigned
    # at function level below, Julia's scoping likely makes this comprehension write to
    # (and box) that outer `p`. Harmless since `p` is reassigned, but consider renaming to `_`.
    srcmin = [
    begin
        (v, kl, p), z̃ = ϕsect(z -> ellipse(length(xs[k]), xs[k],
                                           length(xs[⋆]), xs[⋆],
                                           z, θ, B),
                          -sqrt(B/θ),
                          sqrt(B/(1-θ)));
        v, z̃
        end
        for k in 1:K
        if k ≠ ⋆
    ];

    println("ϕsect gave ", first.(srcmin))
    # Arm indices matching the columns of Z and the entries of srcmin.
    lbl = [k for k in 1:K if k ≠ ⋆];

    # Widen the interval [lo, hi] by 10% about its midpoint; returns a (lo', hi') tuple.
    grow(lo,hi) = (lo+hi)/2 .+ (1.1*(hi-lo)/2) .* (-1,+1)

    # The y-range runs from the smallest Z value to the value one third of the way up sorted Z,
    # widened by `grow`; larger values are clipped out of view.
    p = plot(zs, Z, label=lbl',
             ylims=grow(sort(Z[:])[[1 div(end, 3)]]...),
             marker=true,
             color=lbl',
             xlabel="z"
             )
    # Horizontal line at each ϕsect value, plus a point at (z̃, v), coloured per arm.
    hline!(p, first.(srcmin), color=lbl, label="")
    scatter!(p, getindex.(srcmin,2), getindex.(srcmin,1), marker=true, color=lbl, label="")
    p
end






# theoretical lower bound Tstar·kl(δ, 1-δ) and its practical counterpart under the threshold β
function solve(Tstar, δ, β)
    # Return (theoretical, practical) lower bounds for characteristic time `Tstar`
    # at confidence `δ`. `β` is a threshold object: callable as β(t) and carrying its
    # own `δ` field, which must agree with the `δ` argument.
    @assert δ == β.δ

    # Binary relative entropy kl(δ, 1-δ) = (1-2δ)·log((1-δ)/δ).
    divergence = (1 - 2δ) * log((1 - δ) / δ)
    theoretical = Tstar * divergence

    # Practical bound with the employed threshold β: where t crosses Tstar·β(1+t).
    # The +1 makes t ≤ Tstar·β(1+t) for small t, so the two curves cross once.
    # NOTE(review): reported to error inside `binary_search` (libs/tidnabbil) — search range [0, 1e10].
    practical = binary_search(t -> t - Tstar * β(1 + t), 0, 1e10)

    return theoretical, practical
end


function crack(MAB, B, θ, ϵ; alot=1000)
    # Discretise the arms of `MAB`, sanity-check the discretisation, and compute the
    # optimal weights via `wstar_ellipse`.
    # `B` bounds the (1+ϵ)-th absolute moment, `θ` is the CVaR level, and `alot` is the
    # number of quantile points per arm.
    # Returns (best arm index, D, w) where (D, w) come from `wstar_ellipse`.

    # Inverse-cdf discretisation on `alot` equally spaced interior levels.
    # Levels 0 and 1 are dropped — presumably because they can be ±Inf for unbounded arms.
    levels = range(0, 1, length=alot+2)[2:end-1]
    xs = [quantile.(arm, levels) for arm in MAB]

    # Exact CVaRs of the arms, and CVaRs of their uniformly weighted discretisations.
    cvrs = cvar.(MAB, θ)
    cvrs_emp = [cvar(ones(length(s)) ./ length(s), s, θ) for s in xs]

    # Empirical (1+ϵ)-th absolute moments of the discretised arms; each must be ≤ B.
    mcstrts = [mean(x -> abs(x)^(1+ϵ), s) for s in xs]
    @assert all(mcstrts .≤ B) "The input distributions have $ϵ-moments $(mcstrts) exceeding bound B=$B. If you're close to the boundary, perhaps increasing 'alot' would help."

    # The discretisation must agree with the exact model on which arm is best.
    @assert argmin(cvrs) == argmin(cvrs_emp) "CVaR has not concentrated; definitely not enough samples. Increase 'alot'"

    # We are looking for the arm with *minimal* CVaR.
    best = argmin(cvrs)

    print("Computing wstar\r")
    D, w = wstar_ellipse(best, xs, ProblemParams(ϵ, B, θ))

    return best, D, w
end


# Solve the instance with `crack` and print its summary table via `table_dump`.
#
# Keywords:
#   ϵ    — moment exponent forwarded to `crack`. Required: `crack(MAB, B, θ, ϵ; alot)`
#          has no default for it, so the old 3-argument call could never succeed.
#   seed — accepted for backward compatibility; currently unused.
#   alot — number of quantile points per arm, forwarded to `crack` (previously ignored).
# Returns whatever `table_dump` returns.
function inspect(MAB, B, θ, δs, βs; ϵ, seed = 1, alot=1000)
    best, D, w = crack(MAB, B, θ, ϵ; alot=alot)
    return table_dump(MAB, B, θ, best, D, w, δs, βs)
end

function table_dump(MAB, B, θ, ⋆, D, w, δs, βs)
    # Print a human-readable summary of a solved instance to stdout:
    # arm descriptions, the parameters, per-arm statistics as " & "-separated rows
    # (LaTeX-table friendly), and the lower bounds from `solve` for each (δ, β) pair.
    # Returns `nothing`.

    # NOTE(review): reported to error inside `binary_search` (called by `solve`).
    bounds = solve.(1/D, δs, βs)

    for (idx, arm) in enumerate(MAB)
        println(" Arm $idx is ", short(arm))
    end

    println("B           ", B)
    println("θ           ", θ)
    println("Best Arm ⋆  ", ⋆)
    println("T^*          ", 1/D)

    # Format numbers as %7.4f joined by " & ", and print that row after its label.
    fmtrow(values) = join((@sprintf("%7.4f", v) for v in values), " & ")
    row(label, values) = println(label, fmtrow(values))

    println()
    row("Mean        ", mean.(MAB))
    row("Variance    ", var.(MAB))
    row("VaR         ", quantile.(MAB, θ))
    row("CVaR        ", cvar.(MAB, θ))
    row("w^*         ", w)

    # The bounds help decide whether this experiment is worth running.
    println()
    println("Lower bd (th)  ", getindex.(bounds, 1))
    println("Lower bd (pr)  ", getindex.(bounds, 2))
end


# plot the density of run lengths T for each δ, with theoretical (solid) and practical (dashed) lower bounds
function mkplot(MAB, B, θ, ⋆, D, w, δs, βs, data)
    # `data[run][j]` is a per-run record for the j-th δ. Element 2 of the record is summed
    # to give that run's length T — presumably the per-arm pull counts; confirm with the producer.
    # Returns the plot object.

    # NOTE(review): reported to error inside `binary_search` (called by `solve`).
    bounds = solve.(1/D, δs, βs)

    # Run lengths of every run, one vector per δ.
    runlengths(j) = sum.(getindex.(getindex.(data, j), 2))
    samples = [runlengths(j) for j in eachindex(δs)]

    p = density(samples;
                label = [δs...]',
                normalize = :pdf,
                xlabel = "T",
                ylabel = "fraction of runs",
                legendtitle = "δ",
                left_margin = 10Plots.mm,
                bottom_margin = 5Plots.mm)

    # Lower bounds coloured per δ: theoretical solid, practical dashed.
    theory = getindex.(bounds, 1)
    practice = getindex.(bounds, 2)
    vline!(p, [theory...], color=1:length(δs), label="")
    vline!(p, [practice...], color=1:length(δs), label="", linestyle=:dash)
    return p
end


# Compact label for a generalized extreme value arm: "F" followed by its parameter tuple.
short(w::GeneralizedExtremeValue) = "F$(params(w))"

function short(w::MixtureModel)
    # Compact label for a mixture: "weight component" for each component, joined by " + ".
    # Each component is described recursively via `short`.
    pieces = map((weight, comp) -> "$weight $(short(comp))", w.prior.p, w.components)
    return join(pieces, " + ")
end

# Compact label for a generalized Pareto arm: "P" followed by its parameter tuple.
short(w::GeneralizedPareto) = "P$(params(w))"

function reorder(data, ix)
    # Permute the per-arm entry `N` of every (istar, N, time) record by the index vector `ix`.
    # `data` is a collection of collections of records; `istar` and `time` are kept unchanged.
    # NOTE(review): the original author flagged this as erroring. With nested data of
    # 3-element records it works, so the failure likely comes from the caller's data shape — confirm.
    permute_arms((istar, N, time)) = (istar, N[ix], time)
    return map.(permute_arms, data)
end
