# `@everywhere` and `pmap` are provided by the Distributed standard library, so it has
# to be loaded before the first `@everywhere` line. Julia only loads it implicitly when
# started with worker processes (`julia -p N`); loading it explicitly lets this script
# also run in a plain single-process session, where `@everywhere` simply evaluates on
# process 1. The `@everywhere` prefixes therefore no longer need to be removed by hand.
using Distributed;

# Load the algorithm implementations on every process.
@everywhere include("../Top2Algos.jl")

# Function handed to the algorithms as `b = beta` (presumably the stopping threshold --
# confirm against Top2Algos.jl). The second argument `n` is accepted but not used.
@everywhere beta(t, n, d) = log((log(t) + 1) / d);

using DataFrames;

# Defaults. Both can be overridden positionally on the command line:
#   ARGS[1] -> distribution type, ARGS[2] -> number of simulation runs.
typeDistribution = "Gaussian"
niter = 64;  # Number of simulation runs to perform

# Each argument is honoured independently, so passing only the distribution type
# is no longer silently ignored.
if length(ARGS) >= 1
    typeDistribution = ARGS[1];
end
if length(ARGS) >= 2
    niter = parse(Int64, ARGS[2]); # niter to be around 5000 for these experiments
end

# The run count is used as a divisor and as the upper end of `1:niter` later on.
niter >= 1 || throw(ArgumentError("niter must be a positive integer, got $niter"))

# Path prefix under which the generated figure is written.
output = "Experiment_4/output/"

# Building blocks of the bandit instance: two Gaussian arms, `Normal(μ, 1)`, whose
# means are taken from `mu`.
mu = [10, 8]
arm1, arm2 = Normal(first(mu), 1), Normal(last(mu), 1)

# The instance starts out holding only `arm1`; the experiment loop appends copies of
# `arm2` until the requested number of arms is reached.
global MAB = tuple(arm1)

# Problem parameters. `δ` is passed positionally and `α_local` / `β` as the keywords
# `α` / `β` to `non_fluid_top2_algorithms_np`; `seed` is offset by the run index.
narms = [2, 5, 10, 20, 30, 40, 60, 80]  # instance sizes K to sweep over
seed = 1123
δ = 0.001
β = 0.5
α_local = 0.05

# Column buffers for the results table. Every (K, algorithm) combination appends
# exactly one entry to each of them, so they always have equal length.
narmsvec = Float64[]  # number of arms K
Alg      = String[]   # algorithm label
MeanST   = Float64[]  # mean stopping time
Sd_ST    = Float64[]  # standard error of the mean stopping time
RunTime  = Float64[]  # mean run time
Sd_RT    = Float64[]  # standard error of the mean run time

"""
    summarize_runs(data, best)

Reduce the results of one batch of simulation runs to summary statistics.

`data` is the vector returned by `pmap`. For each entry only the first element is
used, and within it: the length of element 1 (treated by this script as the stopping
time), element 4 (the arm the run settled on, compared against `best`) and element 5
(the run time).

Returns a named tuple with the mean stopping time `mean_st`, its standard error
`se_st`, the mean run time `mean_rt`, its standard error `se_rt`, and `error_frac`,
the fraction of runs whose element 4 differs from `best`.
"""
function summarize_runs(data, best)
    runs = getindex.(data, 1)
    nruns = length(runs)
    stop_times = length.(getindex.(runs, 1))
    run_times = getindex.(runs, 5)
    chosen = getindex.(runs, 4)
    return (
        mean_st = mean(stop_times),
        se_st = sqrt(var(stop_times) / nruns),
        mean_rt = mean(run_times),
        se_rt = sqrt(var(run_times) / nruns),
        # Fraction of *runs* that ended on the wrong arm. The denominator is the
        # number of runs, not the time horizon.
        error_frac = sum(chosen .!= best) / nruns,
    )
end

"""
    record_result!(label, K, stats)

Append one row (algorithm `label`, instance size `K`, rounded `stats`) to the global
result columns `Alg`, `MeanST`, `Sd_ST`, `RunTime`, `Sd_RT` and `narmsvec`.
"""
function record_result!(label, K, stats)
    push!(Alg, label)
    push!(MeanST, round(stats.mean_st, digits = 2))
    push!(Sd_ST, round(stats.se_st, digits = 2))
    push!(RunTime, round(stats.mean_rt, digits = 2))
    push!(Sd_RT, round(stats.se_rt, digits = 2))
    push!(narmsvec, K)
    return nothing
end

# Sweep over the instance sizes. For each K the instance is grown to K arms and every
# algorithm is run `niter` times with seeds `seed + 1, ..., seed + niter`.
for K in narms

    # Grow the instance by appending copies of `arm2`. It is never shrunk, so the
    # entries of `narms` have to be non-decreasing.
    while length(MAB) < K
        global MAB = (MAB..., arm2);
    end
    length(MAB) == K ||
        error("instance has $(length(MAB)) arms but K = $K; `narms` must be non-decreasing")
    println(MAB);

    # Loop-local name on purpose: assigning to the global `mu` here would trigger
    # Julia's soft-scope ambiguity warning when the file is run as a script.
    arm_means = collect(mean.(MAB));
    best = argmax(arm_means);

    T = 1000000; # Max time for which the dynamics should evolve

    # BETA-EB-TCB (the only variant that receives `β` explicitly)
    data_BetaEBTCB = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "BetaEBTCB"; check_stop = true, b = beta, β = β, α = α_local, dist = typeDistribution), 1:niter);
    stats_BetaEBTCB = summarize_runs(data_BetaEBTCB, best);
    println("Empirical error fraction (", string(β, "-EB-TCB"), "): ", stats_BetaEBTCB.error_frac);
    record_result!(string(β, "-EB-TCB"), K, stats_BetaEBTCB);

    # AT2
    println("\n ~~~ AT2 ~~~")
    data_AT2 = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "AT2"; check_stop = true, b = beta, α = α_local, dist = typeDistribution), 1:niter);
    stats_AT2 = summarize_runs(data_AT2, best);
    println("Empirical error fraction (AT2): ", stats_AT2.error_frac);
    record_result!("AT2", K, stats_AT2);

    # TCB
    println("\n ~~~ TCB ~~~")
    data_TCB = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "TCB"; check_stop = true, b = beta, α = α_local, dist = typeDistribution), 1:niter);
    stats_TCB = summarize_runs(data_TCB, best);
    println("Empirical error fraction (TCB): ", stats_TCB.error_frac);
    record_result!("TCB", K, stats_TCB);

end

################## PRINT DATA ##################
# Assemble the per-(K, algorithm) results into one table and print it.
df = DataFrame(;AlgoName=Alg, AvgST = MeanST, StdDevST = Sd_ST, RunTime=RunTime, StdDevRT = Sd_RT, K=narmsvec);
print(df)
println("");

# Per-algorithm slices used for plotting. Columns are addressed by name rather than
# by position so that reordering the table cannot silently pick the wrong data.
AT2data = df[df.AlgoName .== "AT2", :];
AT2sc = AT2data.AvgST;
AT2std = AT2data.StdDevST;

TCBdata = df[df.AlgoName .== "TCB", :];
TCBsc = TCBdata.AvgST;
TCBstd = TCBdata.StdDevST;

# The label is rebuilt from `β` exactly as it was when the rows were recorded, so the
# selection stays non-empty if `β` is changed from 0.5.
BetaTCBdata = df[df.AlgoName .== string(β, "-EB-TCB"), :];
BetaTCBsc = BetaTCBdata.AvgST;
BetaTCBstd = BetaTCBdata.StdDevST;

using Plots;
colors = [ :red, :green,   :purple, :blue, :orange, :black, :brown, :yellow, :cyan, :pink];
colors2 = [:pink, :orange, :blue]

# Mean stopping time against the number of arms, one curve per algorithm, each with a
# band of +/- 2 standard errors.
p_SC = plot(title = "Sample Complexity", ylabel="Number of Samples", xlabel = "K (number of arms)", legend=:topleft);
plot!(p_SC, narms, AT2sc, ribbon = (2 .* AT2std, 2 .* AT2std), fillalpha = 0.05, color = colors[1], line = (:solid, 1), marker = (:star6, 4), label = string("AT2 "));
plot!(p_SC, narms, TCBsc, ribbon = (2 .* TCBstd, 2 .* TCBstd), fillalpha = 0.05, color = colors[2], line = (:dashdot, 1), marker = (:diamond, 4), label = string("TCB"));
# Legend entry derived from `β` so it always matches the algorithm label in the table.
plot!(p_SC, narms, BetaTCBsc, ribbon = (2 .* BetaTCBstd, 2 .* BetaTCBstd), fillalpha = 0.05, color = colors[3], marker = (:utriangle, 4), line = (:dash, .5), label = string(β, "-EB-TCB"));

# Make sure the target directory exists, and write the file before trying to display
# it: a missing directory or an unavailable display must not cost the results of a
# long simulation run.
figpath = string(output, "SC_Function_K_", typeDistribution, "_", niter, ".pdf");
figdir = dirname(figpath);
isempty(figdir) || mkpath(figdir);
savefig(p_SC, figpath);
display(p_SC);