# A Pure Exploration problem (pep) is specified by
# - a query, as embodied by a correct-answer function istar
#
# Each problem type in this file implements:
# - nanswers: number of possible answers
# - istar: correct answer for feasible μ
# - glrt: value and best response (λ and ξ) to (N, μ) or (w, μ)
# - oracle: characteristic time and oracle weights at μ

"""
    BestArm(expfams)

Best Arm identification problem: the correct answer is the index of the
largest mean (see `istar`).

# Fields
- `expfams`: exponential families of the arms, indexed by arm; its length is
  the number of possible answers.
"""
struct BestArm
    expfams    # exponential family of each arm
end

# Number of possible answers: one per entry of `pep.expfams`.
function nanswers(pep::BestArm, μ)
    return length(pep.expfams)
end

# Correct answer at μ: index of the largest entry.
function istar(pep::BestArm, μ)
    return argmax(μ)
end

# Exponential family stored for arm k.
function getexpfam(pep::BestArm, k)
    return pep.expfams[k]
end

# Human-readable problem name; anything whose first family is not exactly of
# type `Bernoulli` is labelled "Gaussian".
function long(pep::BestArm)
    family = typeof(getexpfam(pep, 1)) == Bernoulli ? "Bernoulli" : "Gaussian"
    return "BAI for " * family * " bandits"
end

"""
    alt_λ(μ1, w1, μa, wa)

Alternative parameter shared by two arms: the weighted average
`(w1 * μ1 + wa * μa) / (w1 + wa)`, evaluated through the ratio `wa / w1`.

Degenerate cases, checked in this order:
- `w1 == 0`: returns `μa`;
- `wa == 0` or `μ1 == μa`: returns `μ1`.
"""
function alt_λ(μ1, w1, μa, wa)
    # No weight on the first arm: the alternative sits at the other mean.
    w1 == 0 && return μa
    # No weight on the other arm, or identical means: nothing to average.
    (wa == 0 || μ1 == μa) && return μ1

    ratio = wa / w1
    return (μ1 + ratio * μa) / (1 + ratio)
end

"""
    glrt(pep::BestArm, w, μ)

Generalized likelihood ratio statistics of the best arm of `μ` against every
other arm, under weights (or counts) `w`.

Returns `(vals, (k, λ), (astar, μ))` where
- `vals[a]` is the statistic for challenger `a` (`Inf` for `astar` itself,
  `0` for arms tied with it);
- `k` is the challenger with the smallest value, and `λ` is a copy of `μ` in
  which entries `astar` and `k` are both replaced by their common alternative;
- `astar` is `argmax(μ)`.

Only the exponential family of arm 1 is consulted, for all arms.
"""
function glrt(pep::BestArm, w, μ)
    @assert length(size(μ)) == 1
    fam = getexpfam(pep, 1)
    narms = length(μ)
    best = argmax(μ) # index of best arm among μ
    μbest = μ[best]

    vals = fill(Inf, narms)
    θs = zeros(narms)
    for a in 1:narms
        a == best && continue
        if μ[a] < μbest
            # Strictly worse arm: merge it with the best arm at their weighted mean.
            θ = alt_λ(μbest, w[best], μ[a], w[a])
            θs[a] = θ
            vals[a] = w[best] * d(fam, μbest, θ) + w[a] * d(fam, μ[a], θ)
        else
            # Tied with the best arm: the alternative costs nothing.
            θs[a] = μ[a]
            vals[a] = 0
        end
    end
    k = argmin(vals)

    λ = copy(μ)
    λ[best] = θs[k]
    λ[k] = θs[k]

    return vals, (k, λ), (best, μ)
end

"""
    X(expfam, μ1, μa, v)

Solve for `x` such that `d1(μx) + x * da(μx) == v`, where `μx` is the
alternative parameter `alt_λ(μ1, 1 - α, μa, α)` and `x = α / (1 - α)`.

The search is carried out over `α` in `[0, 1]` with `binary_search`
(presumably a root finder for the supplied function; confirm against its
definition). `v` must lie in `[0, d(expfam, μ1, μa)]`.

Returns the pair `(x, μx)`.
"""
function X(expfam, μ1, μa, v)
    vmax = d(expfam, μ1, μa) # range of V(x) is [0, vmax]
    @assert 0 ≤ v ≤ vmax "0 ≤ $v ≤ $vmax"
    α = binary_search(0, 1, ϵ = vmax * 1e-10) do z
        mid = alt_λ(μ1, 1 - z, μa, z)
        (1 - z) * d(expfam, μ1, mid) + z * d(expfam, μa, mid) - (1 - z) * v
    end
    return α / (1 - α), alt_λ(μ1, 1 - α, μa, α)
end

"""
    oracle(pep::BestArm, μs)

Oracle solution at `μs`: returns `(T, ws)`, the characteristic time and the
oracle weights (normalised to sum to one).

When every entry of `μs` equals the maximum, returns `Inf` and uniform weights.
Only the exponential family of arm 1 is consulted, for all arms.
"""
function oracle(pep::BestArm, μs)
    μstar = maximum(μs)
    expfam = getexpfam(pep, 1)

    narms = length(μs)
    if all(==(μstar), μs) # yes, this happens
        return Inf, fill(1 / narms, narms)
    end

    astar = argmax(μs)
    others = [k for k in eachindex(μs) if k != astar]

    # determine upper range for subsequent binary search
    hi = minimum(d(expfam, μs[astar], μs[k]) for k in others)

    # Ratio of the two divergences at the alternative that solves X for level z.
    function ratio(k, z)
        ux = X(expfam, μs[astar], μs[k], z)[2]
        return d(expfam, μs[astar], ux) / d(expfam, μs[k], ux)
    end

    val = binary_search(z -> sum(ratio(k, z) for k in others) - 1.0, 0, hi)

    # Unnormalised weights: 1 for the best arm, the solved ratio x for the others.
    ws = [k == astar ? 1.0 : X(expfam, μs[astar], μs[k], val)[1] for k in eachindex(μs)]
    total = sum(ws)
    return total / val, ws ./ total
end

"""
    oracle_beta_half(pep::BestArm, μs)

Oracle solution for β = 1/2: the best arm receives weight 1/2 and the other
weights are derived from the squared gaps to the best arm.

Returns `(T, ws)`. When every entry of `μs` equals the maximum, returns `Inf`
and uniform weights.
"""
function oracle_beta_half(pep::BestArm, μs)
    μstar = maximum(μs)
    expfam = getexpfam(pep, 1) # not used below; the formulas use squared gaps only

    narms = length(μs)
    if all(==(μstar), μs) # yes, this happens
        return Inf, fill(1 / narms, narms)
    end

    astar = argmax(μs)
    gaps2 = [(μs[astar] - μs[k])^2 for k in eachindex(μs) if k != astar]

    # determine upper range for subsequent binary search
    hi = minimum(gaps2)

    # Binary search
    val = binary_search(z -> sum([1 / (g / z - 1) for g in gaps2]) - 1.0, 0, hi)
    inv_val = 1 / val

    ws = [
        k == astar ? 1 / 2 : 0.5 / (inv_val * (μs[astar] - μs[k])^2 - 1)
        for k in eachindex(μs)
    ]
    return 4 * inv_val, ws
end


"""
    EspilonBestArm(expfams, ϵ, opt)

Epsilon Best Arm identification problem.

Note: This is only for Gaussian distributions.

# Fields
- `expfams`: exponential families of the arms, indexed by arm.
- `ϵ`: slack parameter.
- `opt`: slack type, `"mul"` (multiplicative) or `"add"` (additive).
- `BAIpep`: `BestArm` problem built on the same `expfams`; the additive oracle
  routines delegate to it.
"""
struct EspilonBestArm
    expfams    # exponential family of each arm
    ϵ
    opt
    BAIpep
    EspilonBestArm(expfams, ϵ, opt) = new(expfams, ϵ, opt, BestArm(expfams))
end

# Number of possible answers: one per entry of `pep.expfams`.
function nanswers(pep::EspilonBestArm, μ)
    return length(pep.expfams)
end

# Correct answer at μ: index of the largest entry.
function istar(pep::EspilonBestArm, μ)
    return argmax(μ)
end

# Index of the largest entry of μ (same computation as `istar`).
function iF(pep::EspilonBestArm, μ)
    return argmax(μ)
end
# Indices of the ϵ-good arms: those whose mean is at least
# `maximum(μ) * (1 - ϵ)` when `pep.opt == "mul"`, and at least
# `maximum(μ) - ϵ` otherwise. Returns an empty index vector for empty μ.
function iepss(pep::EspilonBestArm, μ)
    isempty(μ) && return Int[]
    # The threshold does not depend on the arm, so compute it once instead of
    # once per index inside the filter predicate.
    μmax = maximum(μ)
    threshold = pep.opt == "mul" ? μmax * (1 - pep.ϵ) : μmax - pep.ϵ
    return filter(k -> μ[k] >= threshold, 1:length(μ))
end
# Exponential family stored for arm k.
function getexpfam(pep::EspilonBestArm, k)
    return pep.expfams[k]
end
# Human-readable problem name, including the slack type and the value of ϵ.
function long(pep::EspilonBestArm)
    kind = pep.opt == "mul" ? "Multiplicative" : "Additive"
    # ϵ is read from the problem itself; a bare `ϵ` is not defined in this scope.
    return kind * " ϵ-BAI for Gaussian bandits with ϵ=" * string(pep.ϵ)
end


"""
    alt_eps_λ(μ1, w1, μa, wa, ϵ, opt)

Alternative parameters for the ϵ-relaxed comparison of two arms.

Returns a pair `(λ1, λa)`:
- `opt == "add"`: `λ1 = alt_λ(μ1, w1, μa - ϵ, wa)` and `λa = λ1 + ϵ`;
- otherwise (multiplicative): `λ1 = alt_λ(μ1, w1, μa * (1 - ϵ), wa / (1 - ϵ)^2)`
  and `λa = λ1 / (1 - ϵ)`.
"""
function alt_eps_λ(μ1, w1, μa, wa, ϵ, opt)
    if opt != "add"
        # Multiplicative slack: rescale the challenger's mean and weight.
        shrink = 1 - ϵ
        λ = alt_λ(μ1, w1, μa * shrink, wa / shrink^2)
        return λ, λ / shrink
    end
    # Additive slack: shift the challenger's mean down by ϵ.
    λ = alt_λ(μ1, w1, μa - ϵ, wa)
    return λ, λ + ϵ
end

# Statistics of candidate answer `i` against every arm `j`.
# Returns `(vals[jalt], vals, i, jalt, θs[jalt])` where `jalt` minimises `vals`.
function _higher_glrt_candidate(pep::EspilonBestArm, w, μ, i)
    K = length(μ)
    vals = fill(Inf, K)
    # Float pairs (as in `glrt` for this problem type): an integer-typed
    # container would reject non-integer alternative parameters on assignment.
    θs = [(0., 0.) for k in 1:K]
    for j in 1:K
        if pep.opt == "mul"
            if μ[i] == μ[j] * (1 - pep.ϵ)
                # Already on the boundary: the alternative is μ itself.
                θs[j] = μ[i], μ[j]
                vals[j] = 0
            elseif j != i
                θs[j] = alt_eps_λ(μ[i], w[i], μ[j], w[j], pep.ϵ, "mul")
                vals[j] = (μ[i] - μ[j] * (1 - pep.ϵ))^2 /
                          (1 / w[i] + (1 - pep.ϵ)^2 / w[j])
            end
        elseif pep.opt == "add"
            if μ[i] == μ[j] - pep.ϵ
                # Already on the boundary: the alternative is μ itself.
                θs[j] = μ[i], μ[j]
                vals[j] = 0
            elseif j != i
                θs[j] = alt_eps_λ(μ[i], w[i], μ[j], w[j], pep.ϵ, "add")
                vals[j] = (μ[i] - μ[j] + pep.ϵ)^2 / (1 / w[i] + 1 / w[j])
            end
        end
    end
    jalt = argmin(vals)
    return vals[jalt], vals, i, jalt, θs[jalt]
end

"""
    higher_glrt(pep::EspilonBestArm, w, μ)

Compute the highest GLR over the ϵ-good arms of `μ` (see `iepss`).

For every ϵ-good candidate `i`, the statistic against each arm `j` is
evaluated and the smallest one is kept; the candidate whose smallest statistic
is largest is selected. Ties are broken by the ordering of the tuples
`(value, statistics, i, j, alternative)`.

Note that previous computations cannot be re-used since the alternative
parameter is different than for BAI.

Returns `(0.5 * vals, (jF, λ), (iF, μ))` where `vals` are the statistics of
the selected candidate `iF`, `jF` is its closest challenger, and `λ` is a copy
of `μ` with entries `iF` and `jF` replaced by the alternative parameters.
"""
function higher_glrt(pep::EspilonBestArm, w, μ)
    @assert length(size(μ)) == 1
    ieps = iepss(pep, μ)

    _, vals, iF, jF, λs = maximum(
        _higher_glrt_candidate(pep, w, μ, i) for i in ieps
    )

    λ = copy(μ)
    λ[iF] = λs[1]
    λ[jF] = λs[2]

    return 0.5 * vals, (jF, λ), (iF, μ)
end

"""
    glrt(pep::EspilonBestArm, w, μ)

Compute the GLR with respect to the empirical best arm `argmax(μ)`.

Note that previous computations cannot be re-used since the alternative
parameter is different than for BAI.

Returns `(0.5 * vals, (jalt, λ), (astar, μ))` where `vals[j]` is the statistic
against arm `j` (`Inf` for `astar`, and for every arm when `pep.opt` is neither
`"mul"` nor `"add"`), `jalt` minimises `vals`, and `λ` is a copy of `μ` with
entries `astar` and `jalt` replaced by the alternative parameters.
"""
function glrt(pep::EspilonBestArm, w, μ)
    @assert length(size(μ)) == 1
    narms = length(μ)
    best = argmax(μ)

    vals = fill(Inf, narms)
    θs = fill((0.0, 0.0), narms)
    for j in 1:narms
        j == best && continue
        if pep.opt == "mul"
            shrink = 1 - pep.ϵ
            θs[j] = alt_eps_λ(μ[best], w[best], μ[j], w[j], pep.ϵ, "mul")
            vals[j] = (μ[best] - μ[j] * shrink)^2 / (1 / w[best] + shrink^2 / w[j])
        elseif pep.opt == "add"
            θs[j] = alt_eps_λ(μ[best], w[best], μ[j], w[j], pep.ϵ, "add")
            vals[j] = (μ[best] - μ[j] + pep.ϵ)^2 / (1 / w[best] + 1 / w[j])
        end
    end
    jalt = argmin(vals)

    λ = copy(μ)
    λ[best] = θs[jalt][1]
    λ[jalt] = θs[jalt][2]

    return 0.5 * vals, (jalt, λ), (best, μ)
end

"""
    oracle(pep::EspilonBestArm, μs)

Oracle solution at `μs`: returns `(T, ws)`, the characteristic time and the
oracle weights.

- `pep.opt == "mul"`: solved directly from the squared gaps
  `(μs[astar] - (1 - ϵ) * μs[k])^2`; when every entry of `μs` equals the
  maximum, returns `Inf` and uniform weights.
- otherwise: every arm except the best one is shifted down by `ϵ` and the
  problem is delegated to `oracle(pep.BAIpep, ...)`.
"""
function oracle(pep::EspilonBestArm, μs)
    astar = argmax(μs)
    if pep.opt != "mul"
        shifted = [(k == astar) ? μs[k] : μs[k] - pep.ϵ for k in eachindex(μs)]
        return oracle(pep.BAIpep, shifted)
    end

    μstar = maximum(μs)
    expfam = getexpfam(pep, 1) # not used below; the formulas use squared gaps only

    narms = length(μs)
    # NOTE(review): for ϵ > 0 and a nonzero common mean the gaps below are
    # nonzero even when all means are equal; confirm the Inf return is intended.
    if all(==(μstar), μs) # yes, this happens
        return Inf, fill(1 / narms, narms)
    end

    shrink2 = (1 - pep.ϵ)^2
    gap2(k) = (μs[astar] - (1 - pep.ϵ) * μs[k])^2
    others = [k for k in eachindex(μs) if k != astar]

    # determine upper range for subsequent binary search
    hi = minimum([gap2(k) for k in others])

    # Binary search
    val = binary_search(
        z -> sum([1 / (gap2(k) / z - 1)^2 for k in others]) - 1 / shrink2,
        0, hi)
    inv_val = 1 / val

    # Weight of the best arm, chosen so that all weights sum to one.
    wstar = 1 / (1 + sum([shrink2 / (inv_val * gap2(k) - 1) for k in others]))

    ws = [
        k == astar ? wstar : wstar * shrink2 / (inv_val * gap2(k) - 1)
        for k in eachindex(μs)
    ]
    return 2 * inv_val / wstar, ws
end

"""
    oracle_beta_half(pep::EspilonBestArm, μs)

Oracle solution for β = 1/2: the best arm receives weight 1/2.

- `pep.opt == "mul"`: solved directly from the squared gaps
  `(μs[astar] - (1 - ϵ) * μs[k])^2`; when every entry of `μs` equals the
  maximum, returns `Inf` and uniform weights.
- otherwise: every arm except the best one is shifted down by `ϵ` and the
  problem is delegated to `oracle_beta_half(pep.BAIpep, ...)`.

Returns `(T, ws)`.
"""
function oracle_beta_half(pep::EspilonBestArm, μs)
    if pep.opt != "mul"
        best = argmax(μs)
        shifted = [(k == best) ? μs[k] : μs[k] - pep.ϵ for k in eachindex(μs)]
        return oracle_beta_half(pep.BAIpep, shifted)
    end

    μstar = maximum(μs)
    expfam = getexpfam(pep, 1) # not used below; the formulas use squared gaps only

    narms = length(μs)
    if all(==(μstar), μs) # yes, this happens
        return Inf, fill(1 / narms, narms)
    end

    astar = argmax(μs)
    shrink2 = (1 - pep.ϵ)^2
    gap2(k) = (μs[astar] - (1 - pep.ϵ) * μs[k])^2
    others = [k for k in eachindex(μs) if k != astar]

    # determine upper range for subsequent binary search
    hi = minimum([gap2(k) for k in others])

    # Binary search
    val = binary_search(
        z -> sum([1 / (gap2(k) / z - 1) for k in others]) - 1 / shrink2,
        0, hi)
    inv_val = 1 / val

    ws = [
        k == astar ? 0.5 : 0.5 * shrink2 / (inv_val * gap2(k) - 1)
        for k in eachindex(μs)
    ]
    return 4 * inv_val, ws
end
