# Sampling rules are organised in two levels:
# - sampling rule: a factory for sampling rule states
# - sampling rule state: keeps track of per-run information (e.g. tracking state)

import Distributions;

include("../regret.jl");
include("../tracking.jl");
include("../expfam.jl");
include("helpers.jl");
include("envelope.jl");

```
Uniform sampling
```

# Round-robin sampling rule. The type has no fields, so one object serves both
# as the factory and as the per-run state.
struct RoundRobin end

# Display names.
long(sr::RoundRobin) = "Uniform";
abbrev(sr::RoundRobin) = "RR";

# No per-run bookkeeping is required: the rule is returned as its own state.
start(sr::RoundRobin, N) = sr;

# Cycle through the indices of `N`: the result is one plus the remainder of
# `sum(N)` divided by `length(N)`. Every other argument is ignored.
function nextsample(sr::RoundRobin, pep, astar, aalt, ξ, N, Zs, rng)
    K = length(N);
    return rem(sum(N), K) + 1;
end

```
Oracle sampling
```

# Sampling rule driven by a fixed weight vector `w`; used as factory and state.
# Construction asserts that `w` is entrywise nonnegative and sums to
# (approximately) one.
struct FixedWeights
    w;
    function FixedWeights(w)
        insimplex = all(w .≥ 0) && sum(w) ≈ 1;
        @assert insimplex "$w not in simplex";
        return new(w);
    end
end

# Display names.
long(sr::FixedWeights) = "Oracle Weigths";
abbrev(sr::FixedWeights) = "opt";

# No per-run bookkeeping is required: the rule is returned as its own state.
start(sr::FixedWeights, N) = sr;

# Return the index whose entry of `N` lags furthest behind its target
# `sum(N) * w`, i.e. the minimiser of `N - sum(N) * w` (first index on ties).
function nextsample(sr::FixedWeights, pep, astar, aalt, ξ, N, Zs, rng)
    target = sum(N) .* sr.w;
    return argmin(N .- target);
end


```
TaS
```

# TaS sampling rule (factory). `TrackingRule` is called on `N` in `start` to
# build the tracker, and `abbrev(TrackingRule)` is used in the display names.
struct TaS
    TrackingRule;
end

function long(sr::TaS)
    return "TaS " * abbrev(sr.TrackingRule);
end

function abbrev(sr::TaS)
    return "TaS-" * abbrev(sr.TrackingRule);
end

# Per-run state: `t` holds the tracker built from `N`, wrapped in
# `ForcedExploration`.
struct TaSState
    t;
    function TaSState(TrackingRule, N)
        tracker = TrackingRule(N);
        return new(ForcedExploration(tracker));
    end
end

# Build a fresh state for one run.
start(sr::TaS, N) = TaSState(sr.TrackingRule, N);

# Query `oracle(pep, ξ)` and hand its second output to the tracker, whose
# result is returned. `astar`, `aalt`, `Zs` and `rng` are not used.
function nextsample(sr::TaSState, pep, astar, aalt, ξ, N, Zs, rng)
    _, weights = oracle(pep, ξ);
    return track(sr.t, N, weights);
end


```
LUCB
```

# LUCB sampling rule. The type has no fields, so one object serves both as the
# factory and as the per-run state.
struct LUCB end

# Both display names are the same string.
long(sr::LUCB) = "LUCB";
abbrev(sr::LUCB) = "LUCB";

# No per-run bookkeeping is required: the rule is returned as its own state.
start(sr::LUCB, N) = sr;

# One LUCB step. Note the signature differs from the other sampling rules
# (no `astar`, `aalt`, `Zs`; takes a threshold `β` instead).
#
# Returns `(astar, k, ucb)` where `astar = argmax(ξ)`, `k` is the arm other than
# `astar` with the largest `dup` bound, and `ucb` is that bound minus the `ddn`
# bound of `astar`. `rng` is not used.
# NOTE(review): `dup` / `ddn` are presumably upper / lower confidence bounds at
# level `β / N[a]` - confirm against expfam.jl.
function nextsample(sr::LUCB, pep, ξ, N, β, rng)
    K = nanswers(pep, ξ);
    expfam = getexpfam(pep, 1);
    astar = argmax(ξ);

    LCB = ddn(expfam, ξ[astar], β / N[astar]);

    # -Inf at astar keeps it out of the argmax below
    UCBs = Float64[a == astar ? -Inf : dup(expfam, ξ[a], β / N[a]) for a in 1:K];

    # Best challenger to astar
    k = argmax(UCBs);
    @assert k != astar "Problem"

    # Gap between the challenger's upper bound and astar's lower bound
    return astar, k, UCBs[k] - LCB;
end

```
LUCBhalf
```

# LUCBhalf sampling rule. The type has no fields, so one object serves both as
# the factory and as the per-run state.
struct LUCBhalf end

# Both display names are the same string.
long(sr::LUCBhalf) = "LUCBhalf";
abbrev(sr::LUCBhalf) = "LUCBhalf";

# No per-run bookkeeping is required: the rule is returned as its own state.
start(sr::LUCBhalf, N) = sr;

# One LUCBhalf step. Same bounds and gap as the LUCB rule, but the arm to pull
# is `astar` when a uniform draw from `rng` is at most 0.5 and the best
# challenger otherwise.
#
# Returns `(astar, k, ucb)` where `astar = argmax(ξ)`, `k` is the arm to pull and
# `ucb` is the challenger's `dup` bound minus the `ddn` bound of `astar`.
function nextsample(sr::LUCBhalf, pep, ξ, N, β, rng)
    K = nanswers(pep, ξ);
    expfam = getexpfam(pep, 1);
    astar = argmax(ξ);

    LCB = ddn(expfam, ξ[astar], β / N[astar]);

    # -Inf at astar keeps it out of the argmax below
    UCBs = Float64[a == astar ? -Inf : dup(expfam, ξ[a], β / N[a]) for a in 1:K];

    # Best challenger to astar
    challenger = argmax(UCBs);
    @assert challenger != astar "Problem"

    # Gap between the challenger's upper bound and astar's lower bound
    ucb = UCBs[challenger] - LCB;

    # Fair coin between the two candidates
    k = rand(rng) <= 0.5 ? astar : challenger;
    return astar, k, ucb;
end


```
DKM
```

# DKM sampling rule (factory). `TrackingRule` is called on `N` in `start` to
# build the tracker, and `abbrev(TrackingRule)` is used in the display names.
struct DKM
    TrackingRule;
end

function long(sr::DKM)
    return "DKM " * abbrev(sr.TrackingRule);
end

function abbrev(sr::DKM)
    return "DKM-" * abbrev(sr.TrackingRule);
end

# Per-run state: a single `AdaHedge` learner over `length(N)` coordinates (`h`)
# and the tracker built from `N` (`t`).
struct DKMState
    h; # one online learner in total
    t;
    function DKMState(TrackingRule, N)
        learner = AdaHedge(length(N));
        tracker = TrackingRule(N);
        return new(learner, tracker);
    end
end

# Build a fresh state for one run.
start(sr::DKM, N) = DKMState(sr.TrackingRule, N);

# Optimistic gradient, one entry per index `k` of `hμ`.
#
# With `width = log(t) / N[k]`, entry `k` is the largest of
#   d(expfam, dup(expfam, hμ[k], width), λs[k]),
#   d(expfam, ddn(expfam, hμ[k], width), λs[k]),
#   width.
# NOTE(review): `dup` / `ddn` presumably return the upper / lower end of a
# confidence interval around `hμ[k]`, and `d` a divergence - confirm in expfam.jl.
function optimistic_gradient(expfam, hμ, t, N, λs)
    return map(eachindex(hμ)) do k
        width = log(t) / N[k];
        upper = dup(expfam, hμ[k], width);
        lower = ddn(expfam, hμ[k], width);
        max(d(expfam, upper, λs[k]), d(expfam, lower, λs[k]), width)
    end
end

# One DKM step:
#   1. ask the learner `sr.h` for its current play `w`;
#   2. compute the response `λs` to `w` via `glrt(pep, w, ξ)`;
#   3. feed the negated optimistic gradient back to the learner as its loss;
#   4. return the arm chosen by the tracker for `w`.
# `astar`, `aalt`, `Zs` and `rng` are not used.
function nextsample(sr::DKMState, pep, astar, aalt, ξ, N, Zs, rng)
    expfam = getexpfam(pep, 1);

    # query the learner
    w = act(sr.h);

    # best response of the λ-player to w (only the alternative parameters are kept)
    _, (_, λs), (_, _) = glrt(pep, w, ξ);

    # the learner incurs minus the optimistic gradient
    grad = optimistic_gradient(expfam, ξ, sum(N), N, λs);
    incur!(sr.h, -grad);

    # tracking
    return track(sr.t, N, w);
end

```
AdaTopTwo
```

# Adaptive Top Two sampling rule (factory). `βtype`, `leader` and `challenger`
# are the string codes compared against in `nextsample`; `is_tracking` is a
# Bool selecting tracking instead of randomisation between the two candidates.
struct AdaTopTwo
    βtype;
    leader;
    challenger;
    is_tracking;
end

# Display name: the three codes joined by "-", with "-T" appended when tracking.
function long(sr::AdaTopTwo)
    suffix = sr.is_tracking ? "-T" : "";
    return sr.βtype * "-" * sr.leader * "-" * sr.challenger * suffix;
end

# The short name is built exactly like the long one.
function abbrev(sr::AdaTopTwo)
    suffix = sr.is_tracking ? "-T" : "";
    return sr.βtype * "-" * sr.leader * "-" * sr.challenger * suffix;
end

# Per-run state: a copy of the configuration plus two zero-initialised
# length-K vectors that `nextsample` updates in place.
struct AdaTopTwoState
    βtype;
    leader;
    challenger;
    is_tracking;
    counts;     # pulling counts when leader
    βs;         # cumulative sums of β depending on leader
    function AdaTopTwoState(βtype, leader, challenger, is_tracking, K)
        return new(βtype, leader, challenger, is_tracking, zeros(K), zeros(K));
    end
end

# Build a fresh state for one run, sized by `length(N)`.
start(sr::AdaTopTwo, N) =
    AdaTopTwoState(sr.βtype, sr.leader, sr.challenger, sr.is_tracking, length(N));

# Draw one random vector of K means centred on `ξ`:
#   Gaussian  -> Normal(ξ[k], 1 / sqrt(N[k]))
#   Bernoulli -> Beta(1 + ξ[k] N[k], 1 + N[k] - ξ[k] N[k])
# `who` only labels the error raised for any other exponential family.
function _adatoptwo_sample_means(expfam, ξ, N, K, rng, who)
    if typeof(expfam) == Gaussian
        return [rand(rng, Distributions.Normal(ξ[k], 1. / sqrt(N[k]))) for k in 1:K];
    elseif typeof(expfam) == Bernoulli
        return [rand(rng, Distributions.Beta(1 + ξ[k] * N[k], 1 + N[k] - ξ[k] * N[k])) for k in 1:K];
    end
    error("Undefined Exfam for " * who);
end

# Transportation cost of every arm `a` against the leader `B`:
#   Inf for a == B (so B is never its own challenger),
#   0 when ξ[a] >= ξ[B],
#   otherwise N[B] d(ξ[B], λ) + N[a] d(ξ[a], λ) with λ the N-weighted mean of the two.
# It is not enough to look at arms with higher mean: under-sampled arms with lower mean
# can have lower indices than over-sampled arms with higher mean.
function _adatoptwo_costs(expfam, ξ, N, B, K)
    function cost(a)
        a == B && return Inf;
        ξ[a] >= ξ[B] && return 0;
        λ = (N[B] * ξ[B] + N[a] * ξ[a]) / (N[B] + N[a]);
        return N[B] * d(expfam, ξ[B], λ) + N[a] * d(expfam, ξ[a], λ);
    end
    return [cost(a) for a in 1:K];
end

# Penalised index minimised by the TCI family of challengers, given costs `Z`:
#   "TCI"    : Z + log N
#   "TCIs"   : √(2Z) + log N
#   "TCIs+"  : √(2Z) + (log N)^(1.2 / 2)     (poly-log penalisation, κ = 1.2)
#   "TCIsp"  : √(2Z) + N^(1 / (2 * 1.2))     (polynomial penalisation, α = 1.2)
#   "TCIsp+" : √(2Z) + N^(1 / (2 * 2))       (polynomial penalisation, α = 2)
function _adatoptwo_penalised(challenger, Z, N)
    if challenger == "TCI"
        return Z .+ log.(N);
    elseif challenger == "TCIs"
        return sqrt.(2 * Z) .+ log.(N);
    elseif challenger == "TCIs+"
        return sqrt.(2 * Z) .+ log.(N) .^ (1.2 / 2);
    elseif challenger == "TCIsp"
        return sqrt.(2 * Z) .+ N .^ (1 / (2 * 1.2));
    elseif challenger == "TCIsp+"
        return sqrt.(2 * Z) .+ N .^ (1 / (2 * 2));
    end
    error("Undefined Adaptive Top Two challenger");
end

# One Adaptive Top Two step: pick a leader `B`, a challenger `C`, a proportion `β`,
# then return `B` or `C`.
#
# Arguments as used below: `astar` is taken as the leader for "EB" and `aalt` as the
# challenger for "TC" when the leader equals `astar`; `Zs` are the costs re-used by the
# TCI family in that same case (presumably coming from the stopping rule - confirm at
# the call site); `ξ` and `N` are indexed per arm; `rng` drives every random draw.
#
# Unknown `leader`, `challenger` or `βtype` codes, and exponential families without a
# sampler, raise an error immediately. They used to be only logged, which surfaced
# later as an unrelated UndefVarError or, for an unknown challenger, could silently
# return the leader.
#
# Side effects: in tracking mode, `ssr.βs` and `ssr.counts` are updated in place.
function nextsample(ssr::AdaTopTwoState, pep, astar, aalt, ξ, N, Zs, rng)
    K = nanswers(pep, ξ);
    expfam = getexpfam(pep, 1);

    # Leader
    if ssr.leader == "EB"
        B = astar;
    elseif ssr.leader == "TS"
        B = argmax(_adatoptwo_sample_means(expfam, ξ, N, K, rng, "TS"));
    elseif ssr.leader == "UCB"
        t = sum(N);
        UCBs = [dup(expfam, ξ[k], 2 * 1.2 * (1 + 1.2) * log(t) / N[k]) for k in 1:K];
        B = argmax(UCBs);
    elseif ssr.leader == "UCBI"
        t = sum(N);
        UCBs = [dup(expfam, ξ[k], bonus_ucb(t, 1.2, 1.2) / N[k]) for k in 1:K];
        B = argmax(UCBs);
    elseif ssr.leader == "IMED"
        Kinfs = [N[k] * d(expfam, ξ[k], ξ[astar]) for k in 1:K];
        B = argmin(Kinfs .+ log.(N));
    else
        error("Undefined Adaptive Top Two leader");
    end

    # Challenger
    if ssr.challenger == "TC"
        if B != astar
            # Sample uniformly at random among the arms whose mean is at least that of B,
            # since they have null transportation cost.
            # Most of the time this is a singleton, hence returning astar.
            ks = [a for a in 1:K if ξ[a] >= ξ[B] && a != B];
            C = ks[rand(rng, 1:length(ks))];
        else
            # When B == astar, the computations from the stopping rule can be used as
            # they rely on the same transportation costs.
            C = aalt;
        end
    elseif ssr.challenger in ("TCI", "TCIs", "TCIs+", "TCIsp", "TCIsp+")
        # When B == astar the supplied costs Zs are re-used; otherwise the costs have
        # to be recomputed against the actual leader.
        Z = (B != astar) ? _adatoptwo_costs(expfam, ξ, N, B, K) : Zs;
        C = argmin(_adatoptwo_penalised(ssr.challenger, Z, N));
    elseif ssr.challenger == "RS"
        # Re-sample until the set of maximisers no longer contains the leader.
        ks = [B];
        draws = 0;
        while B in ks
            hξ = _adatoptwo_sample_means(expfam, ξ, N, K, rng, "RS");
            max_hξ = maximum(hξ);
            ks = [a for a in 1:K if hξ[a] == max_hξ];
            draws += 1;

            if draws > 1e7
                @warn "RS challenger is taking too much time, hence sample uniformly.";
                return 1 + (sum(N) % length(N));
            end
        end
        C = ks[rand(rng, 1:length(ks))];
    else
        error("Undefined Adaptive Top Two challenger");
    end

    # β choice
    if ssr.βtype == "cst"
        β = 0.5;
    elseif ssr.βtype == "BC"
        if ξ[B] > ξ[C]
            altλ = (N[B] * ξ[B] + N[C] * ξ[C]) / (N[B] + N[C]);
            ratio = N[C] * d(expfam, ξ[C], altλ) / (N[B] * d(expfam, ξ[B], altλ));
            β = 1 / (1 + ratio);
        else
            β = 0.5;
        end
    else
        error("Undefined Adaptive Top Two β choice");
    end

    if ssr.is_tracking
        # Deterministic choice: pull B while its pulls-as-leader do not exceed the
        # cumulated β for that leader, otherwise pull the challenger.
        ssr.βs[B] += β;
        if ssr.counts[B] <= ssr.βs[B]
            k = B;
            ssr.counts[B] += 1;
        else
            k = C;
        end
    else
        # Randomised choice: B with probability β, C otherwise.
        u = rand(rng);
        k = (u <= β) ? B : C;
    end

    return k;
end



```
Frank-Wolfe based Sampling
```

# Frank-Wolfe based sampling rule (factory). `TrackingRule` is called on `N` in
# `start` to build the tracker, and `abbrev(TrackingRule)` is used in the
# display names.
struct FWSampling
    TrackingRule;
end

function long(sr::FWSampling)
    return "FW-Sampling " * abbrev(sr.TrackingRule);
end

function abbrev(sr::FWSampling)
    return "FWS-" * abbrev(sr.TrackingRule);
end

# Per-run state. `x` starts as the uniform vector of length `length(N)` and is
# reassigned at every step (hence the mutable struct); `t` is the tracker.
mutable struct FWSamplingState
    x;
    t;
    function FWSamplingState(TrackingRule, N)
        K = length(N);
        return new(ones(K) / K, TrackingRule(N));
    end
end

# Build a fresh state for one run.
start(sr::FWSampling, N) = FWSamplingState(sr.TrackingRule, N);

# Compute f and ∇f for FWSampling.
#
# For every arm k other than `astar`:
#   λ      = weighted mean of ξ[astar] and ξ[k] with weights hw[astar], hw[k]
#   ∇f[k]  = vector that is zero except d(ξ[astar], λ) at `astar` and d(ξ[k], λ) at k
#   f      = hw' * ∇f[k]   (stored in the order of the non-astar arms)
# `∇f[astar]` stays all-zero.
#
# `fidx` lists the arms retained depending on `r`:
#   r > eps()       -> arms whose f is below min(f) + r
#   abs(r) < eps()  -> only the arm minimising f
#   otherwise       -> every arm other than `astar`
#
# Returns `(f, ∇f, fidx)`.
function compute_f_∇f_bai(expfam, hw, ξ, astar, r, K)
    others = filter(k -> k != astar, 1:K);

    # construct ∇f, one length-K row per arm
    ∇f = [zeros(K) for _ in 1:K];
    for k in others
        # alternative parameter shared by astar and k
        λ = (ξ[astar] * hw[astar] + ξ[k] * hw[k]) / (hw[astar] + hw[k]);
        ∇f[k][astar] = d(expfam, ξ[astar], λ);
        ∇f[k][k] = d(expfam, ξ[k], λ);
    end

    # construct f
    f = [hw' * ∇f[k] for k in others];
    fmin = minimum(f);

    if r > eps()
        fidx = others[findall(v -> v < fmin + r, f)];
    elseif abs(r) < eps()
        fidx = [others[argmin(f)]];
    else
        fidx = others;
    end
    return f, ∇f, fidx;
end

# One Frank-Wolfe sampling step.
#
# A direction `z` is chosen, the running iterate `ssr.x` is replaced by
# `x * (t - 1) / t + z / t` with `t = sum(N)`, and the tracker is asked for the
# arm matching the new iterate.
#
# `z` is uniform when `hμ_in_lambda(ξ, astar, K)` is false or when
# `floor(t / K)` passes `is_complete_square`; otherwise it comes from
# `compute_f_∇f_bai` with `r = t^(-9/10) / K`: a single retained arm gives the
# indicator of the largest entry of its gradient row, several retained arms
# give the solution of a zero-sum game.
# `aalt`, `Zs` and `rng` are not used.
function nextsample(ssr::FWSamplingState, pep, astar, aalt, ξ, N, Zs, rng)
    expfam = getexpfam(pep, 1);
    K = length(N);
    t = sum(N);
    r = t^(-9.0/10)/K;

    uniform_step = !hμ_in_lambda(ξ, astar, K) || is_complete_square(floor(Int, t/K));
    if uniform_step
        z = ones(K) / K;
    else
        _, ∇f, fidx = compute_f_∇f_bai(expfam, ssr.x, ξ, astar, r, K);
        if length(fidx) == 1 # best challenger
            target = argmax(∇f[fidx[1]]);
            z = [(target==j) ? 1 : 0 for j=1:K];
        else # solve LP of the zero-sum matrix game
            # directions from the current iterate towards each vertex
            Σ = [[(i==j) ? 1 : 0 for j=1:K] - ssr.x for i=1:K];
            # payoff matrix: one row per retained arm, one column per vertex
            A = [[Σ[i]'∇f[j] for i=1:K] for j in fidx];
            z = solveZeroSumGame(A, K, length(fidx));
        end
    end

    ssr.x = ssr.x*((t-1.0)/t) + z*1.0/t;
    return track(ssr.t, N, ssr.x);
end



```
UGapEc
```

# UGapEc sampling rule. The type has no fields, so one object serves both as the
# factory and as the per-run state.
struct UGapEc end

# Both display names are the same string.
long(sr::UGapEc) = "UGapEc";
abbrev(sr::UGapEc) = "UGapEc";

# No per-run bookkeeping is required: the rule is returned as its own state.
start(sr::UGapEc, N) = sr;

# One UGapEc step. Note the signature differs from most sampling rules
# (no `astar`, `aalt`, `Zs`; takes a threshold `β` instead).
#
# For every arm a, `dup` / `ddn` at level `β / N[a]` give the bounds UCBs[a] and
# LCBs[a], and the arm's gap index is `max_{k != a} UCBs[k] - LCBs[a]`.
# The leader `B` minimises the gap index, the challenger `C` maximises the UCB
# among the other arms, and the arm returned is `B` when `N[B] < N[C]` and `C`
# otherwise.
#
# Returns `(astar, k, gap)` with `astar = argmax(ξ)`, `k` the arm to pull and
# `gap` the smallest gap index. `rng` is not used.
function nextsample(sr::UGapEc, pep, ξ, N, β, rng)
    K = nanswers(pep, ξ);
    expfam = getexpfam(pep, 1);
    astar = argmax(ξ);

    # UCB and LCB
    UCBs = zeros(K);
    LCBs = zeros(K);
    for a in 1:K
        UCBs[a] = dup(expfam, ξ[a], β / N[a]);
        LCBs[a] = ddn(expfam, ξ[a], β / N[a]);
    end

    # Compute gaps
    gaps = zeros(K);
    for a in 1:K
        idx = [k for k in 1:K if k != a];
        gaps[a] = maximum(UCBs[idx]) - LCBs[a];
    end

    # Compute leader (smallest gap, first index on ties)
    gap, B = findmin(gaps);

    # Compute challenger: best UCB among the arms other than the leader.
    # The leader is masked with -Inf rather than 0, otherwise it would win the
    # argmax again whenever every other upper bound is nonpositive
    # (e.g. negative means), making the challenger equal to the leader.
    UCBs[B] = -Inf;
    C = argmax(UCBs);

    # Choose arm to sample from: the less sampled of the two
    k = (N[B] < N[C]) ? B : C;
    return astar, k, gap;
end
