# Top2 algorithms and their helper functions

using Random, Distributions
include("KL.jl");
include("boundeddists.jl")

# GLRT stopping rule.
#
# Arguments:
#   index_values - per-arm index values; only their minimum is used.
#   δ            - confidence parameter of the threshold.
#   t            - time normaliser of the threshold (the threshold is divided by t).
#   beta         - either the sentinel -1 (use the built-in threshold) or a callable
#                  invoked as beta(t, 0, δ).
#
# Returns `true` when the smallest index value is at least the threshold
#   (log((K-1)/δ) + 4 log(log(t+1))) / t      when beta == -1,
#   beta(t, 0, δ) / t                         otherwise,
# where K is the number of index values.
function GLRT_Stop(index_values, δ, t; beta = -1)
    smallest_index = minimum(index_values)
    n_arms = length(index_values)

    threshold = if beta == -1
        (log((n_arms - 1) / δ) + 4 * log(log(t + 1))) / t
    else
        beta(t, 0, δ) / t
    end

    return smallest_index >= threshold
end

# Compute one index per arm, measured against the empirical best arm
# (KLinf-based, with an optimisation over the intermediate point x).
#
# Arguments:
#   samples - collection with one sample container per arm; `mean` is applied to each.
#   wt      - per-arm weights (indexed like `samples`).
#   dist    - "Gaussian" or "Bernoulli" selects the closed-form routine
#             `arm_indexes_exp`; any other value selects the empirical branch below.
#
# Returns a vector with one index per arm. In the empirical branch, the entry of the
# empirical best arm and of every arm whose mean is not strictly below the best mean is 0.
function arm_indexes(samples, wt, dist)
    mu = mean.(samples)

    if dist == "Gaussian" || dist == "Bernoulli"
        return arm_indexes_exp(mu, wt, dist)
    end

    best = argmax(mu)
    indexes = zeros(length(mu))

    for a in eachindex(mu)
        # Only arms strictly worse than the empirical best get a non-zero index.
        mu[a] < mu[best] || continue

        # Value returned by alt_λ (per the original note: the point at which the index
        # is infimised). Converted to Float64 to mirror storing it in a Float64 array.
        x = convert(Float64, alt_λ(samples[best], mu[best], wt[best], samples[a], mu[a], wt[a], 1, false))

        # Weighted sum of the two Kinf_emp terms: best arm first, then arm a.
        indexes[a] = wt[best] * Kinf_emp(false, samples[best], mu[best], x, 1, false) +
                     wt[a] * Kinf_emp(false, samples[a], mu[a], x, 1, true)
    end

    return indexes
end

# Closed-form arm indexes from empirical means.
#
# Arguments:
#   mu   - vector of empirical means.
#   wt   - vector of per-arm weights.
#   dist - "Bernoulli" selects dBernoulli; every other value uses dGaussian.
#
# For each arm a, with `best = argmax(mu)`, the weighted mean
#   x_a = (wt[best]*mu[best] + wt[a]*mu[a]) / (wt[best] + wt[a])
# is formed and the returned entry is
#   wt[best]*d(mu[best], x_a) + wt[a]*d(mu[a], x_a).
# The entry for `best` itself is computed by the same formula with x = mu[best]
# (presumably 0 -- confirm against the divergence definitions in KL.jl).
function arm_indexes_exp(mu, wt, dist)
    best = argmax(mu)
    divergence = dist == "Bernoulli" ? dBernoulli : dGaussian

    w_best = wt[best]
    mu_best = mu[best]

    # Weighted mean of the best arm and each arm.
    midpoint = (w_best * mu_best .+ wt .* mu) ./ (w_best .+ wt)

    return w_best * divergence.(mu_best, midpoint) .+ wt .* divergence.(mu, midpoint)
end

# Anchor value: (sum over arms of the KLinf ratio) - 1.
#
# Arguments:
#   samples - collection with one sample container per arm; `mean` is applied to each.
#   wt      - per-arm weights (indexed like `samples`).
#   dist    - "Gaussian" or "Bernoulli" selects the closed-form routine `anchor_exp`;
#             any other value selects the empirical branch below.
#
# In the empirical branch the per-arm ratio is
#   Kinf_emp(best arm, x) / Kinf_emp(arm a, x)   for arms strictly worse than the best,
#   1                                            for other non-best arms,
#   0                                            for the empirical best arm,
# and NaN ratios are dropped before summing.
function anchor(samples, wt, dist)
    mu = mean.(samples)

    if dist == "Gaussian" || dist == "Bernoulli"
        return anchor_exp(mu, wt, dist)
    end

    best = argmax(mu)
    kl_ratio = ones(length(mu))

    for a in eachindex(mu)
        # Arms that are not strictly worse than the best keep the default ratio 1.
        mu[a] < mu[best] || continue

        # Value returned by alt_λ (per the original note: the point at which the index
        # is infimised). Converted to Float64 to mirror storing it in a Float64 array.
        x = convert(Float64, alt_λ(samples[best], mu[best], wt[best], samples[a], mu[a], wt[a], 1, false))

        kl_ratio[a] = Kinf_emp(false, samples[best], mu[best], x, 1, false) /
                      Kinf_emp(false, samples[a], mu[a], x, 1, true)
    end

    # The best arm does not contribute to the sum.
    kl_ratio[best] = 0

    return sum(filter(!isnan, kl_ratio)) - 1
end


# Closed-form anchor value from empirical means.
#
# Arguments:
#   mu   - vector of empirical means.
#   wt   - vector of per-arm weights.
#   dist - "Bernoulli" selects dBernoulli; every other value uses dGaussian.
#
# For each arm a, with `best = argmax(mu)` and the weighted mean
#   x_a = (wt[best]*mu[best] + wt[a]*mu[a]) / (wt[best] + wt[a]),
# the ratio d(mu[best], x_a) / d(mu[a], x_a) is formed. The best arm's ratio is forced
# to 0, NaN ratios are dropped, and the function returns (sum of remaining ratios) - 1.
function anchor_exp(mu, wt, dist)
    best = argmax(mu)
    divergence = dist == "Bernoulli" ? dBernoulli : dGaussian

    # Weighted mean of the best arm and each arm.
    midpoint = (wt[best] * mu[best] .+ wt .* mu) ./ (wt[best] .+ wt)

    ratio = divergence.(mu[best], midpoint) ./ divergence.(mu, midpoint)
    ratio[best] = 0

    return sum(filter(!isnan, ratio)) - 1
end


# Different top2 algorithms for Gaussian, Bernoulli, and bounded-support distributions:
# TCB, AT2, BetaEBTCB, and their I versions.
#
# Arguments:
#   MAB  - collection of arms; arm a is sampled with `rand(rng, MAB[a])`.
#   δ    - confidence parameter forwarded to `GLRT_Stop`.
#   T    - sample budget: the run ends as soon as T ≤ (total samples drawn + 1).
#   sd   - seed of the master MersenneTwister from which K+1 independent streams are
#          derived (one per arm, plus one for the β-coin of BetaEBTCB).
#   Alg  - algorithm name. "AT2"/"AT2I" pull the leader when the anchor value is
#          positive; "BetaEBTCB"/"BetaEBTCBI" pull the leader when β - U(0,1) > 0;
#          every other name (e.g. "TCB"/"TCBI") pulls the leader when the smallest
#          index under the perturbed ("guess") weights exceeds the smallest index
#          under the current weights. The "...I" variants add log(N_a)/(t_eff - 1)
#          to the index before choosing the challenger.
#
# Keywords:
#   check_stop     - also stop when `GLRT_Stop` fires.
#   b              - forwarded to `GLRT_Stop` as `beta`.
#   InitialSamples - per-arm number of initial pulls (each arm is pulled at least once).
#   β              - leader probability used by the BetaEBTCB variants.
#   α              - forced-exploration exponent.
#   αTCB           - exponent of the weight perturbation used for the TCB guess index.
#   dist           - forwarded to `arm_indexes` and `anchor`.
#
# Returns a one-element vector holding the tuple
#   (NSamples, index_t, opt_cond_t, a_star, t_run)
# where
#   NSamples   - history of per-arm sample counts,
#   index_t    - history of per-arm indexes (the first K entries are -Inf placeholders),
#   opt_cond_t - history of the anchor value (the first K entries are Inf placeholders),
#   a_star     - empirical best arm at the stopping time,
#   t_run      - wall-clock duration of the main loop in microseconds.
function non_fluid_top2_algorithms_np(
    MAB, δ, T, sd, Alg;
    check_stop=false, b=-1, InitialSamples = zeros(length(MAB)),
    β = 0.5, α = 0.5, αTCB = 1, dist = "Gaussian",
)
    K = length(MAB)

    # Independent random streams: rngs[1:K] for the arms, rngs[K+1] for the β-coin.
    rngs = MersenneTwister.(rand(MersenneTwister(sd), UInt64, K+1))

    R = Tuple{Vector{Vector{Int64}}, Vector{Vector{Float64}}, Vector{Float64}, Int64, Float64}[]
    NSamples = Vector{Vector{Int64}}()
    index_t = Vector{Vector{Float64}}()
    opt_cond_t = Vector{Float64}() # ratio_sum - 1
    xs = [Float64[] for _ in 1:K]  # all samples, per arm

    # Initial phase: pull each arm max(1, InitialSamples[a]) times.
    for a in 1:K
        # Placeholder history entries so all histories have K leading entries.
        push!(index_t, fill(-Inf, K))
        push!(opt_cond_t, Inf)

        init_samples = max(1, InitialSamples[a])
        for _ in 1:init_samples
            push!(xs[a], rand(rngs[a], MAB[a]))
        end
        push!(NSamples, length.(xs))
    end

    tot_init_samples = sum(length, xs)

    tbig = time()
    while true
        t = sum(length, xs) + 1 # total number of samples generated + 1
        # Position in the history vectors: the initial phase always occupies K entries.
        t_eff = t - tot_init_samples + K
        counts = NSamples[t_eff-1]

        # Empirical leader and current sampling proportions.
        mu = mean.(xs)
        a_star = argmax(mu)
        wt = counts ./ sum(counts)

        # Perturbed weights: move mass 1/(t_eff-1)^αTCB towards the leader, taking an
        # equal share from each of the other K-1 arms.
        wt_guess = wt .- (1/(K-1))*(1/(t_eff-1)^αTCB)
        wt_guess[a_star] = wt[a_star] + 1/(t_eff-1)^αTCB

        obj = arm_indexes(xs, wt, dist)
        obj_guess = arm_indexes(xs, wt_guess, dist)

        # o/w index for optimal arm will be minimum (0)
        obj[a_star] = Inf
        obj_guess[a_star] = Inf

        sum_ratio = anchor(xs, wt, dist)

        # Condition for sampling the best arm (leader is pulled when cond > 0).
        # The β-coin is drawn here, before the stopping check, on every iteration.
        cond = if Alg == "AT2" || Alg == "AT2I"
            sum_ratio
        elseif Alg == "BetaEBTCB" || Alg == "BetaEBTCBI"
            β - rand(rngs[K+1], Uniform(0,1))
        else
            minimum(obj_guess) - minimum(obj)
        end

        # Stopping condition: GLRT (optional) or exhausted budget.
        if (check_stop && GLRT_Stop(obj, δ, t-1; beta = b)) || (T ≤ t)
            # time() is in seconds, so scale by 1e6 to report microseconds.
            t_run = (time() - tbig) * 1e6
            push!(R, (NSamples, index_t, opt_cond_t, a_star, t_run))
            return R
        end

        # Record the indexes and the anchor value for analysis.
        push!(index_t, obj)
        push!(opt_cond_t, sum_ratio)

        # Forced exploration: the arm with the largest deficit w.r.t. floor(t^α) + 1
        # is pulled whenever that deficit is positive.
        min_samples = floor(t^α) + 1
        exp_vec = min_samples .- counts
        starved = argmax(exp_vec)

        arm = if exp_vec[starved] > 0
            starved
        elseif cond > 0
            a_star
        else
            # Challenger: the sub-optimal arm with the smallest (possibly adjusted) index.
            obj_l = obj
            if Alg == "AT2I" || Alg == "TCBI" || Alg == "BetaEBTCBI"
                obj_l = obj_l .+ (log.(counts) ./ (t_eff - 1))
            end
            argmin(obj_l)
        end

        # Pull the selected arm and record the new counts.
        push!(xs[arm], rand(rngs[arm], MAB[arm]))
        push!(NSamples, length.(xs))
    end
end