using Random;
using CPUTime;

include("peps.jl");
include("../expfam.jl");
include("samplingrules.jl");
include("reco_stop_pairs.jl");

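# Run the learning algorithm, parameterised by a sampling rule.
# The stopping and recommendation rules are common to all sampling rules
# (except gap-based rules, which use their own UCB-based stopping test).
#
# The thresholds βs (built from δs by get_threshold) must be *in increasing order*.
# For non-anytime algorithms, βs should be a singleton.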

"""
    runit(seed, is, pep, μs, δs, Tau_max, history, freqHist)

Run one instance of a pure-exploration learning algorithm on the arms `μs`.

`is` is a pair `(sr, rsp)`: a sampling rule and a recommendation/stopping pair.
The thresholds `βs` are built from `δs` by `get_threshold` and consumed front
to back. Each time the stopping test fires for `βs[1]`, a result is recorded
and that threshold is dropped. The run returns once all thresholds are
consumed, or as soon as more than `Tau_max` samples have been drawn.

Three families of sampling rules are handled:
- gap-based rules (`LUCB`, `UGapEc`, `LUCBhalf`) stop when the `ucb` returned
  by their `nextsample` is at most `pep.ϵ`;
- elimination rules (`DSR`, `DSH`) have no stopping test here and only end
  at `Tau_max`;
- every other rule uses the GLR statistics from `glrt` together with
  `stopping_criterion` (only `GLRT` recommendation/stopping pairs are supported).

When `history == "partial"`, the current recommendation is appended to the
history whenever `t >= freqHist * length(recommendations)` (or `t >= Tau_max`).
During the initial round-robin pulls, the recorded recommendations are drawn
uniformly at random.

Returns a vector of tuples `(astar, N, cpu_us, recommendations)`:
- the recommended answer;
- a copy of the per-arm sample counts;
- the CPU time (microseconds) elapsed since just before the initial pulls;
- a copy of the recommendation history.

Throws an error if the sampling rule is neither gap-based nor elimination-based
and `rsp` is not a `GLRT`.
"""
function runit(seed, is, pep, μs, δs, Tau_max, history, freqHist)
    sr, rsp = is;
    gap_sr = typeof(sr) == LUCB || typeof(sr) == UGapEc || typeof(sr) == LUCBhalf;
    elim_sr = typeof(sr) == DSR || typeof(sr) == DSH;
    K = nanswers(pep, μs);

    # Get thresholds (consumed front to back with popfirst! below)
    βs = get_threshold(rsp.threshold, δs, 2, K, 2);

    rng = MersenneTwister(seed);

    N = zeros(Int64, K);        # counts
    S = zeros(K);               # sum of samples
    recommendations = Int64[];  # only filled when history == "partial"

    baseline = CPUtime_us();

    # Pull each arm once. No estimate exists yet, so any recommendation
    # recorded during this phase is drawn uniformly at random.
    for k in 1:K
        S[k] += sample(rng, getexpfam(pep, k), μs[k]);
        N[k] += 1;
        if history == "partial" && k >= freqHist * length(recommendations)
            push!(recommendations, Int64(rand(rng, 1:K)));
        end
    end

    state = start(sr, N);
    R = Tuple{Int64, Array{Int64,1}, UInt64, Array{Int64,1}}[]; # collect return values

    while true
        t = sum(N);

        # emp. estimates
        hμ = S ./ N;

        if gap_sr
            # invoke sampling rule
            astar, k, ucb = nextsample(state, pep, hμ, N, βs[1](t), rng);
            if typeof(sr) == LUCB
                ks = [astar, k];   # LUCB samples both returned arms
            elseif typeof(sr) == UGapEc || typeof(sr) == LUCBhalf
                ks = [k];
            end

            if history == "partial" && (t >= freqHist * length(recommendations) || t >= Tau_max)
                push!(recommendations, Int64(astar));
            end

            # test stopping criterion
            # NOTE(review): `ucb` was computed with βs[1] only and is not
            # recomputed after popfirst!, so once this fires every remaining
            # threshold is recorded as stopping at the same time. This is only
            # exact when βs is a singleton.
            while ucb <= pep.ϵ
                popfirst!(βs);
                push!(R, (astar, copy(N), CPUtime_us() - baseline, copy(recommendations)));
                if isempty(βs)
                    return R;
                end
            end
        elseif elim_sr
            ks, astar = nextsample(state, pep, S, N);

            if history == "partial" && (t >= freqHist * length(recommendations) || t >= Tau_max)
                push!(recommendations, Int64(astar));
            end

            # This algorithm does not stop because it is an anytime algo
            # We could add on top some GLR stopping as well.
        else
            if typeof(rsp) == GLRT
                # test stopping criterion
                Zs, (aalt, _), (astar, ξ) = glrt(pep, N, hμ);
            else
                # `@error` only logs; throw so the run aborts with a clear
                # message instead of an UndefVarError on `astar`/`Zs` below.
                error("Recommendation stopping pair undefined");
            end

            if history == "partial" && (t >= freqHist * length(recommendations) || t >= Tau_max)
                push!(recommendations, Int64(astar));
            end

            # Several thresholds may be passed at once: record one result per
            # threshold cleared at this time step.
            while stopping_criterion(Zs, βs[1], N, astar, aalt)
                popfirst!(βs);
                push!(R, (astar, copy(N), CPUtime_us() - baseline, copy(recommendations)));
                if isempty(βs)
                    return R;
                end
            end
            k = nextsample(state, pep, rsp, astar, aalt, ξ, N, Zs, rng);
            ks = [k];
        end

        # draw the samples requested by the sampling rule
        for k in ks
            S[k] += sample(rng, getexpfam(pep, k), μs[k]);
            N[k] += 1;
            t += 1;
        end

        if t > Tau_max
            #@warn "Finite-time horizon Tau_max = $(Tau_max) met without stopping for " * abbrev(sr) * ". Increase this if hard problem.";
            push!(R, (astar, copy(N), CPUtime_us() - baseline, copy(recommendations)));
            return R;
        end
    end
end

# Return `true` when the GLR statistics `Zs` clear the threshold `β`.
#
# If `abbrev(β) == "FTT"`, β is called with the sample counts of the recommended
# answer `astar` and of each other arm `a`. Every such arm must satisfy
# `Zs[a] > β(N[astar], N[a])`; with no other arm, the test passes.
# Otherwise β is called with the total sample count `sum(N)`, and only the
# statistic of the alternative `aalt` is compared against it.
function stopping_criterion(Zs, β, N, astar, aalt)
    if abbrev(β) != "FTT"
        total = sum(N);
        return Zs[aalt] > β(total);
    end

    all_beaten = true;
    for (a, n_a) in enumerate(N)
        a == astar && continue;
        # evaluate the comparison before folding it in, so β is called for
        # every competitor even after one comparison has already failed
        beaten = Zs[a] > β(N[astar], n_a);
        all_beaten = all_beaten && beaten;
    end
    return all_beaten;
end
