# Load Distributed before using `@everywhere`. Starting julia with `-p N` loads it
# implicitly, but without workers `@everywhere` would otherwise be undefined. With
# Distributed loaded explicitly, the script runs unchanged in both cases: without
# workers `@everywhere` just evaluates on the master process, and `pmap` below runs
# the simulations serially.
using Distributed;

# Define the simulation code and the threshold function on every process, since
# `pmap` executes `non_fluid_top2_algorithms_np` (with `b = beta`) on the workers.
# NOTE(review): relative include path; verify it resolves on workers when julia is
# not started from this script's directory.
@everywhere include("../Top2Algos.jl")
# Threshold passed as the `b` keyword to every run (presumably the stopping
# threshold): log((log(t) + 1) / d). The second argument `n` is accepted but unused.
@everywhere beta(t,n,d)=log((log(t)+1)/d);

using DataFrames;

# Command-line interface: `julia <script> [typeDistribution] [niter]`.
#   typeDistribution — forwarded as the `dist` keyword of every simulation run.
#   niter            — number of independent simulation runs per (algorithm, δ) pair
#                      (around 5000 was used for these experiments).
# Each argument is optional and independent: giving only the first one no longer
# silently ignores it.
typeDistribution = length(ARGS) >= 1 ? ARGS[1] : "Bdd"
niter = length(ARGS) >= 2 ? parse(Int, ARGS[2]) : 32  # Number of simulation runs to perform
niter > 0 || throw(ArgumentError("niter must be a positive integer, got $niter"))

# Directory (with trailing slash) where the figure is written.
output = "Experiment_7/output/"

# Default bandit instance: four arms with Beta-distributed rewards.
arm1, arm2, arm3, arm4 = Beta(1.5, 1), Beta(2, 6), Beta(1, 1.5), Beta(1, 7);
MAB = (arm1, arm2, arm3, arm4);  # top-level assignment is already global
println(MAB);

# Vector of arm means, the number of arms, and the index of the arm with the
# highest mean (the arm a correct run should recommend).
mu = collect(mean.(MAB));
K = length(mu);
best = argmax(mu);

# Problem parameters.
α_local = 0.05  # forwarded as the `α` keyword of every simulation run
β = 0.5         # forwarded as `β` to the β-EB-TCB runs only
seed = 1123     # base seed; run i of every experiment uses seed + i
# Values of δ swept by the experiment: 10^-1 down to 10^-10.
δs = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10]

# Result-table columns. The experiment loop appends one entry to each per
# (δ, algorithm) pair; they are assembled into a DataFrame afterwards.
δVec = Float64[];     # δ used for the row
Alg = String[];       # algorithm label
MeanST = Float64[];   # mean stopping time
Sd_ST = Float64[];    # standard error of the stopping time (sd / sqrt(niter))
RunTime = Float64[];  # mean run time
Sd_RT = Float64[];    # standard error of the run time (sd / sqrt(niter))

# Max time for which the dynamics should evolve; it does not depend on δ, so it is
# set once instead of on every iteration.
T = 10000;

"""
    summarize_runs(results, best)

Summarize the outputs of repeated `non_fluid_top2_algorithms_np` runs.

Each element `r` of `results` is read as the experiment has always read it:
`length(r[1][1])` is the stopping time, `r[1][4]` the recommended arm and `r[1][5]`
the run time. Return a named tuple with the mean stopping time (`mean_st`) and its
standard error (`se_st`), the mean run time (`mean_rt`) and its standard error
(`se_rt`), and the fraction of runs whose recommended arm differs from `best`
(`error_frac`).
"""
function summarize_runs(results, best)
    n = length(results)
    outputs = getindex.(results, 1)
    stop_times = length.(getindex.(outputs, 1))
    run_times = getindex.(outputs, 5)
    recommended = getindex.(outputs, 4)
    return (
        mean_st = mean(stop_times),
        se_st = sqrt(var(stop_times) / n),
        mean_rt = mean(run_times),
        se_rt = sqrt(var(run_times) / n),
        # A fraction of runs, so normalize by the number of runs, not by the horizon T.
        error_frac = count(recommended .!= best) / n,
    )
end

# Algorithms compared at each δ, in table-row order:
# (row label, algorithm identifier for the simulator, keyword arguments specific to it).
algorithms = [
    (string(β, "-EB-TCB"), "BetaEBTCB", (β = β,)),
    ("AT2", "AT2", NamedTuple()),
    ("TCB", "TCB", NamedTuple()),
];

for δ in δs
    for (label, algo, extra_kwargs) in algorithms
        println("\n ~~~ ", label, " ~~~")
        # niter independent runs distributed over the workers; run i uses seed + i,
        # so every algorithm and every δ sees the same set of seeds.
        results = pmap(
            i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, algo;
                check_stop = true, b = beta, α = α_local, dist = typeDistribution,
                extra_kwargs...),
            1:niter,
        );
        run_stats = summarize_runs(results, best)
        println("fraction of runs recommending a suboptimal arm: ", run_stats.error_frac)

        push!(Alg, label);
        push!(MeanST, round(run_stats.mean_st, digits = 2));
        push!(Sd_ST, round(run_stats.se_st, digits = 2));
        push!(RunTime, round(run_stats.mean_rt, digits = 2));
        push!(Sd_RT, round(run_stats.se_rt, digits = 2));
        push!(δVec, δ);
    end
end

################## PRINT DATA ##################
# One row per (algorithm, δ) pair. Note: StdDevST / StdDevRT hold the standard error
# of the mean (sd / sqrt(niter)) as computed in the experiment loop.
df = DataFrame(;AlgoName=Alg, AvgST = MeanST, StdDevST = Sd_ST, RunTime=RunTime, StdDevRT = Sd_RT, δ=δVec);
print(df)
println("");

# Per-algorithm series for plotting: average stopping time and its standard error,
# one entry per δ in the order the experiments were run. Columns are selected by
# name so a change in DataFrame column order cannot silently swap the series.
AT2data = df[df.AlgoName .== "AT2", :];
AT2sc = AT2data.AvgST;
AT2std = AT2data.StdDevST;

TCBdata = df[df.AlgoName .== "TCB", :];
TCBsc = TCBdata.AvgST;
TCBstd = TCBdata.StdDevST;

# Build the label from β exactly as the experiment loop does, so the filter keeps
# matching when β is changed.
BetaTCBdata = df[df.AlgoName .== string(β, "-EB-TCB"), :];
BetaTCBsc = BetaTCBdata.AvgST;
BetaTCBstd = BetaTCBdata.StdDevST;

using Plots;
colors = [ :red, :green,   :purple, :blue, :orange, :black, :brown, :yellow, :cyan, :pink];
colors2 = [:pink, :orange, :blue]

# Average stopping time against 1/δ on a log-scaled x-axis; each shaded ribbon spans
# ±2 standard errors around its curve.
invδ = 1 ./ δs
betaLabel = string(β, "-EB-TCB")  # same label the experiment loop uses for this algorithm

plot(title = "Sample Complexity", ylabel = "Number of Samples", xlabel = "1/delta", legend = :topleft);
plot!(invδ, AT2sc, ribbon = (2 .* AT2std, 2 .* AT2std), fillalpha = 0.05, color = colors[1],
      line = (:solid, 1), marker = (:star6, 4), label = "AT2 ", xaxis = :log);
plot!(invδ, TCBsc, ribbon = (2 .* TCBstd, 2 .* TCBstd), fillalpha = 0.05, color = colors[2],
      line = (:dashdot, 1), marker = (:diamond, 4), label = "TCB", xaxis = :log);
p_SC = plot!(invδ, BetaTCBsc, ribbon = (2 .* BetaTCBstd, 2 .* BetaTCBstd), fillalpha = 0.05,
             color = colors[3], marker = (:utriangle, 4), line = (:dash, 0.5),
             label = betaLabel, xaxis = :log);

display(p_SC);
# Ensure the output directory exists so the figure is not lost at the end of a long run.
mkpath(output)
savefig(p_SC, string(output, "SC_Function_delta_", typeDistribution, "_", niter, ".pdf"));