# Code from the implementation of "Top Two Algorithms Revisited" (Jourdan et al., 2022).

using Statistics;
import Optim;

"""
    rel_entr(x, y)

Return the relative-entropy term `x * log(x / y)`.

When `x == 0` the term is defined as `0.` and `y` is never inspected, so the
`0 * log(0)` indeterminate form is avoided.
"""
function rel_entr(x, y)
    if x == 0
        return 0.
    end
    return x * log(x / y)
end
"""
    Kinf_emp(_isBer, samples, μ, u, B, is_Kinfp)

Evaluate a one-sided empirical divergence of an arm with mean `μ` against the
threshold `u`.

The direction is selected by `is_Kinfp`: when `true` the value is zero for
`u <= μ`, when `false` it is zero for `u >= μ`. (Presumably these are the
K_inf^+ / K_inf^- quantities of the referenced paper — confirm against it.)

# Arguments
- `_isBer`: when `true`, use the closed-form binary relative entropy between `μ`
  and `u`; `samples` is not read in that case.
- `samples`: observations, only read when `_isBer` is `false`.
- `μ`: mean that `u` is compared with.
- `u`: threshold; must satisfy `0 <= u <= B`.
- `B`: upper end of the admissible range `[0, B]`.
- `is_Kinfp`: selects the direction as described above.

# Returns
A floating-point value: `0.0` in the trivial case (always a float, never the
integer `0`, so every branch yields the same kind of value), the divergence
otherwise, or `Inf` when the one-dimensional optimization throws or does not
converge (the failure is printed to standard output).

# Throws
- `AssertionError` if `u` lies outside `[0, B]`.
"""
function Kinf_emp(_isBer, samples, μ, u, B, is_Kinfp)
    @assert u <= B && u >= 0 "Domain violation for u: $(u) ∈ [0, $(B)]"

    # In both the Bernoulli and the general case nothing has to be paid when `u`
    # already lies on the near side of `μ`.
    on_near_side = is_Kinfp ? u <= μ : u >= μ
    if on_near_side
        return 0.0
    end

    if _isBer
        # Binary relative entropy between μ and u, clamped at zero
        # (presumably to absorb negative rounding error).
        return max(0, rel_entr(μ, u) + rel_entr(1 - μ, 1 - u))
    end

    # Closed forms: each equals the objective of the optimization below evaluated
    # at its boundary x = 1, and is used only when the mean-ratio condition holds.
    if is_Kinfp
        if maximum(samples) < B && mean((B - u) ./ (B .- samples)) <= 1
            return mean(log.((B .- samples) / (B - u)))
        end
    else
        if minimum(samples) > 0 && mean(u ./ samples) <= 1
            return mean(log.(samples / u))
        end
    end

    # General case: maximize mean(log(1 - x * Y)) over x ∈ [0, 1] by minimizing
    # its negative. Failures are deliberately best-effort: report and return Inf.
    try
        Y = is_Kinfp ? (samples .- u) / (B - u) : (u .- samples) / u
        res = Optim.optimize(x -> -mean(log.(1 .- x * Y)), 0., 1.)

        if Optim.converged(res)
            return -Optim.minimum(res)
        else
            println("Optimization failed")
            return Inf
        end
    catch e
        println(e)
        return Inf
    end
end

"""
    alt_λ(samples1, μ1, w1, samplesa, μa, wa, B, _isBer)

Return the threshold between `μa` and `μ1` associated with the pair of arms
`1` (the better one) and `a`.

When `_isBer` is `true` this is the `w1`/`wa`-weighted average of `μ1` and `μa`.
Otherwise it is the minimizer over `[μa, μ1]` of
`w1 * Kinf_emp(..., μ1, u, B, false) + wa * Kinf_emp(..., μa, u, B, true)`.

# Returns
The threshold, or `Inf` when the optimization throws or does not converge
(the failure is printed to standard output).

# Throws
- `AssertionError` if `μa` or `μ1` lies outside `[0, B]`, or if `μa >= μ1`.
"""
function alt_λ(samples1, μ1, w1, samplesa, μa, wa, B, _isBer)
    @assert μa <= B && μa >= 0 "Domain violation for μa: $(μa) ∈ [0, $(B)]"
    @assert μ1 <= B && μ1 >= 0 "Domain violation for μ1: $(μ1) ∈ [0, $(B)]"
    @assert μa < μ1 "1 should be the best arm: $(μa) < $(μ1)"

    # Bernoulli case: weighted average of the two means.
    if _isBer
        return (w1 * μ1 + wa * μa) / (w1 + wa)
    end

    # Weighted sum of the two one-sided divergences at a common threshold θ.
    cost(θ) = w1 * Kinf_emp(_isBer, samples1, μ1, θ, B, false) +
              wa * Kinf_emp(_isBer, samplesa, μa, θ, B, true)

    try
        outcome = Optim.optimize(cost, μa, μ1)
        if !Optim.converged(outcome)
            println("Optimization failed")
            return Inf
        end
        return Optim.minimizer(outcome)
    catch err
        println(err)
        return Inf
    end
end


"""
    Kinf_dn(_isBer, samples, μ, v, B)

Return `0.` when `μ == 0`; otherwise hand `x -> v - Kinf_emp(_isBer, samples, μ, x, B, false)`
to `binary_search` on the interval `[0, μ]` and return its result.

NOTE(review): `binary_search` is defined elsewhere; presumably it locates the point
below `μ` where the divergence reaches the level `v` — confirm against its definition.
"""
function Kinf_dn(_isBer, samples, μ, v, B)
    if μ == 0
        return 0.
    end
    gap(x) = v - Kinf_emp(_isBer, samples, μ, x, B, false)
    return binary_search(gap, 0, μ)
end
"""
    Kinf_up(_isBer, samples, μ, v, B)

Return `B` when `μ == B`; otherwise hand `x -> Kinf_emp(_isBer, samples, μ, x, B, true) - v`
to `binary_search` on the interval `[μ, B]` and return its result.

NOTE(review): `binary_search` is defined elsewhere; presumably it locates the point
above `μ` where the divergence reaches the level `v` — confirm against its definition.
"""
function Kinf_up(_isBer, samples, μ, v, B)
    if μ == B
        return B
    end
    gap(x) = Kinf_emp(_isBer, samples, μ, x, B, true) - v
    return binary_search(gap, μ, B)
end
