# Distributed has to be loaded before the first `@everywhere`: the macro only exists
# once the package is loaded. Julia loads it implicitly when started with `-p`, but a
# plain `julia script.jl` would otherwise fail on the very first line below. With no
# worker processes, `@everywhere` simply evaluates on the master process, so the
# script runs unchanged with or without workers.
using Distributed;

# Make the algorithms available on every process.
@everywhere include("../Top2Algos.jl")

# Function handed to the algorithms through the keyword `b`.
# Computes log((log(t) + 1) / d); the argument `n` is accepted but not used.
@everywhere beta(t,n,d)=log((log(t)+1)/d);

using DataFrames;

# Experiment configuration. The defaults (Gaussian arms, 32 simulation runs) are used
# unless BOTH values are supplied on the command line as `<typeDistribution> <niter>`.
# niter is meant to be around 5000 for these experiments.
typeDistribution, niter = if length(ARGS) >= 2
    (ARGS[1], parse(Int64, ARGS[2]))
else
    ("Gaussian", 32)
end

# Path prefix prepended to the file name of the saved plot.
output = "Experiment_2/output/"

# Build the 4-armed bandit instance. Arms are unit-variance Gaussians unless
# typeDistribution is exactly "Bernoulli".
mu = [7.25, 7.05, 7, 7.1]

if typeDistribution == "Bernoulli"
    # Divide by 1.01 x the largest mean so every rescaled mean is a valid
    # success probability below 1, then keep 3 decimal digits.
    scale = maximum(mu)*1.01;
    mu = round.(mu ./ scale, digits = 3);
    arm1, arm2, arm3, arm4 = Bernoulli.(mu);
else
    arm1, arm2, arm3, arm4 = Normal.(mu, 1);
end

MAB = (arm1, arm2, arm3, arm4);
println(MAB);

K = length(MAB);             # number of arms
mu = collect(mean.(MAB));    # means as realised by the arm distributions, as a Vector
best = argmax(mu);           # index of the arm with the largest mean

# Problem parameters
α_local = 0.05     # forwarded to the algorithms as keyword `α`
δ = 0.001;         # forwarded to the algorithms as their second positional argument
seed = 1123;       # base seed: simulation run i uses seed + i
βs = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]   # β values swept by the Beta-EB-* variants

println("mu=$(mu), δ=$δ");

# Columns of the summary table; each algorithm variant appends one entry to every
# column, so all five vectors stay the same length.
Alg = String[];        # algorithm label
MeanST = Float64[];    # mean stopping time
Sd_ST = Float64[];     # sqrt(var / niter) of the stopping time
RunTime = Float64[];   # mean run time
Sd_RT = Float64[];     # sqrt(var / niter) of the run time

T = 10000; # Max time for which the dynamics should evolve

# BETA-EB-TCB: one batch of `niter` simulations per β in βs.
# Rounded per-β stopping-time mean and sqrt(var / niter), kept separately for plotting.
mean_st_BetaEBTCB = Float64[];
std_st_BetaEBTCB = Float64[];

for β_val in βs
    println("\n ~~~", β_val, "-EB-TCB ~~~");
    runs = pmap(1:niter) do i
        non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "BetaEBTCB"; check_stop = true, b=beta, β = β_val, α = α_local, dist=typeDistribution)
    end

    # Only the first component of each run's result is used below.
    first_parts = getindex.(runs, 1);

    # The length of element 1 is what the table reports as the stopping time.
    stop_counts = length.(getindex.(first_parts, 1));
    avg_st = round(mean(stop_counts), digits = 2);
    se_st = round(sqrt(var(stop_counts) / niter), digits = 2);

    # Element 5 is what the table reports as the run time.
    rt_samples = getindex.(first_parts, 5);
    avg_rt = round(mean(rt_samples), digits = 2);
    se_rt = round(sqrt(var(rt_samples) / niter), digits = 2);

    push!(mean_st_BetaEBTCB, avg_st);
    push!(std_st_BetaEBTCB, se_st);

    push!(Alg, string(β_val, "-EB-TCB"));
    push!(MeanST, avg_st);
    push!(Sd_ST, se_st);
    push!(RunTime, avg_rt);
    push!(Sd_RT, se_rt);
end

# BETA-EB-TCBI: one batch of `niter` simulations per β in βs.
# Rounded per-β stopping-time mean and sqrt(var / niter), kept separately for plotting.
mean_st_BetaEBTCBI = Float64[];
std_st_BetaEBTCBI = Float64[];

for β_val in βs
    println("\n ~~~", β_val, "-EB-TCBI ~~~");
    runs = pmap(1:niter) do i
        non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "BetaEBTCBI"; check_stop = true, b=beta, β = β_val, α = α_local, dist=typeDistribution)
    end

    # Only the first component of each run's result is used below.
    first_parts = getindex.(runs, 1);

    # The length of element 1 is what the table reports as the stopping time.
    stop_counts = length.(getindex.(first_parts, 1));
    avg_st = round(mean(stop_counts), digits = 2);
    se_st = round(sqrt(var(stop_counts) / niter), digits = 2);

    # Element 5 is what the table reports as the run time.
    rt_samples = getindex.(first_parts, 5);
    avg_rt = round(mean(rt_samples), digits = 2);
    se_rt = round(sqrt(var(rt_samples) / niter), digits = 2);

    push!(mean_st_BetaEBTCBI, avg_st);
    push!(std_st_BetaEBTCBI, se_st);

    push!(Alg, string(β_val, "-EB-TCBI"));
    push!(MeanST, avg_st);
    push!(Sd_ST, se_st);
    push!(RunTime, avg_rt);
    push!(Sd_RT, se_rt);
end

# AT2: `niter` independent simulations dispatched through pmap; run i uses seed + i.
println("\n ~~~ AT2 ~~~")
data_AT2 = pmap(i-> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "AT2"; check_stop = true, b=beta, α = α_local, dist=typeDistribution), 1:niter);

# Stopping time of a run = length of element 1 of the first component of its result.
stop_time_AT2 = length.(getindex.(getindex.(data_AT2,1),1));
mean_st_AT2_local = mean(stop_time_AT2);
std_st_AT2_local = sqrt(var(stop_time_AT2)/niter);   # standard error of the mean

# Fraction of runs whose element 4 (presumably the estimated best arm - confirm
# against Top2Algos.jl) differs from the true best arm. The count is taken over the
# niter runs, so it is normalised by niter, not by the time horizon T.
estimated_BA_AT2 = getindex.(getindex.(data_AT2,1),4);
error_frac_AT2 = sum(estimated_BA_AT2 .!= best)/niter;

# Element 5 is what the table reports as the run time.
run_time_AT2 = getindex.(getindex.(data_AT2,1),5);
mean_rt_AT2 = mean(run_time_AT2);
std_rt_AT2 = sqrt(var(run_time_AT2)/niter);

# For each arm l, the l-th entry of every item in element 2 of the FIRST run only.
# NOTE(review): element 2 looks like a per-step index trajectory - confirm.
index_AT2 = Vector{Float64}[getindex.(data_AT2[1][1][2], l) for l in 1:K];

push!(Alg, string("AT2"));
push!(MeanST, round(mean_st_AT2_local, digits=2));
push!(Sd_ST, round(std_st_AT2_local, digits = 2));
push!(RunTime, round(mean_rt_AT2, digits = 2));
push!(Sd_RT, round(std_rt_AT2, digits=2));

# IAT2: `niter` independent simulations dispatched through pmap; run i uses seed + i.
println("\n ~~~ IAT2 ~~~")
data_AT2I =  pmap(i-> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "AT2I"; check_stop = true, b=beta, α = α_local, dist=typeDistribution), 1:niter);

# Stopping time of a run = length of element 1 of the first component of its result.
stop_time_AT2I = length.(getindex.(getindex.(data_AT2I,1),1));
mean_st_AT2I_local = mean(stop_time_AT2I);
std_st_AT2I_local = sqrt(var(stop_time_AT2I)/niter);   # standard error of the mean

# Fraction of runs whose element 4 (presumably the estimated best arm - confirm
# against Top2Algos.jl) differs from the true best arm. The count is taken over the
# niter runs, so it is normalised by niter, not by the time horizon T.
estimated_BA_AT2I = getindex.(getindex.(data_AT2I,1),4);
error_frac_AT2I = sum(estimated_BA_AT2I .!= best)/niter;

# Element 5 is what the table reports as the run time.
run_time_AT2I = getindex.(getindex.(data_AT2I,1),5);
mean_rt_AT2I = mean(run_time_AT2I);
std_rt_AT2I = sqrt(var(run_time_AT2I)/niter);

push!(Alg, "AT2I");
push!(MeanST, round(mean_st_AT2I_local, digits = 2));
push!(Sd_ST, round(std_st_AT2I_local, digits=2));
push!(RunTime, round(mean_rt_AT2I, digits =2));
push!(Sd_RT, round(std_rt_AT2I, digits=2));

# TCB: `niter` independent simulations dispatched through pmap; run i uses seed + i.
println("\n ~~~ TCB ~~~")
data_TCB =  pmap(i-> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "TCB"; check_stop = true, b=beta, α = α_local, dist=typeDistribution), 1:niter);

# Stopping time of a run = length of element 1 of the first component of its result.
stop_time_TCB = length.(getindex.(getindex.(data_TCB,1),1));
mean_st_TCB_local = mean(stop_time_TCB);
std_st_TCB_local = sqrt(var(stop_time_TCB)/niter);   # standard error of the mean

# Fraction of runs whose element 4 (presumably the estimated best arm - confirm
# against Top2Algos.jl) differs from the true best arm. The count is taken over the
# niter runs, so it is normalised by niter, not by the time horizon T.
estimated_BA_TCB = getindex.(getindex.(data_TCB,1),4);
error_frac_TCB = sum(estimated_BA_TCB .!= best)/niter;

# Element 5 is what the table reports as the run time.
run_time_TCB = getindex.(getindex.(data_TCB,1),5);
mean_rt_TCB = mean(run_time_TCB);
std_rt_TCB = sqrt(var(run_time_TCB)/niter);

push!(Alg, "TCB");
push!(MeanST, round(mean_st_TCB_local, digits=2));
push!(Sd_ST, round(std_st_TCB_local, digits=2));
push!(RunTime, round(mean_rt_TCB, digits =2 ));
push!(Sd_RT, round(std_rt_TCB, digits = 2));


# TCBI: `niter` independent simulations dispatched through pmap; run i uses seed + i.
println("\n ~~~ TCBI ~~~")
data_TCBI =  pmap(i-> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "TCBI"; check_stop = true, b=beta, α = α_local, dist=typeDistribution), 1:niter);

# Stopping time of a run = length of element 1 of the first component of its result.
stop_time_TCBI = length.(getindex.(getindex.(data_TCBI,1),1));
mean_st_TCBI_local = mean(stop_time_TCBI);
std_st_TCBI_local = sqrt(var(stop_time_TCBI)/niter);   # standard error of the mean

# Fraction of runs whose element 4 (presumably the estimated best arm - confirm
# against Top2Algos.jl) differs from the true best arm. The count is taken over the
# niter runs, so it is normalised by niter, not by the time horizon T.
estimated_BA_TCBI = getindex.(getindex.(data_TCBI,1),4);
error_frac_TCBI = sum(estimated_BA_TCBI .!= best)/niter;

# Element 5 is what the table reports as the run time.
run_time_TCBI = getindex.(getindex.(data_TCBI,1),5);
mean_rt_TCBI = mean(run_time_TCBI);
std_rt_TCBI = sqrt(var(run_time_TCBI)/niter);

push!(Alg, "TCBI");
push!(MeanST, round(mean_st_TCBI_local, digits =2 ));
push!(Sd_ST, round(std_st_TCBI_local, digits = 2));
push!(RunTime, round(mean_rt_TCBI, digits =2 ));
push!(Sd_RT, round(std_rt_TCBI, digits = 2));


################## PRINT DATA ##################
# Summary table: one row per algorithm variant, in the order results were appended.
df = DataFrame(; AlgoName = Alg, AvgST = MeanST, StdDevST = Sd_ST, RunTime = RunTime, StdDevRT = Sd_RT);
print(df)
println();

#### Baselines for AT2, AT2I, TCB and TCBI across all βs ####
# These four algorithms are run once, without a β argument, so their single result is
# repeated once per entry of βs to give vectors of the same length as the β sweeps.
mean_st_AT2 = fill(convert(Float64, mean_st_AT2_local), length(βs));
mean_st_AT2I = fill(convert(Float64, mean_st_AT2I_local), length(βs));
mean_st_TCB = fill(convert(Float64, mean_st_TCB_local), length(βs));
mean_st_TCBI = fill(convert(Float64, mean_st_TCBI_local), length(βs));
std_st_AT2 = fill(convert(Float64, std_st_AT2_local), length(βs));
std_st_AT2I = fill(convert(Float64, std_st_AT2I_local), length(βs));
std_st_TCB = fill(convert(Float64, std_st_TCB_local), length(βs));
std_st_TCBI = fill(convert(Float64, std_st_TCBI_local), length(βs));

using Plots;

# Colour palettes; only colors[1:3] and colors2[1:3] are referenced below.
colors = [ :red, :green,   :purple, :blue, :orange, :black, :brown, :yellow, :cyan, :pink];
colors2 = [:pink, :orange, :blue]

# One entry per curve: (mean stopping times, their standard errors, colour,
# line style, marker, legend label). Every vector has one value per entry of βs.
plot_specs = [
    (mean_st_AT2,        std_st_AT2,        colors[1],  (:solid, 1),   (:star6, 4),     "AT2 "),
    (mean_st_AT2I,       std_st_AT2I,       colors2[1], (:solid, 1),   (:star6, 4),     "IAT2"),
    (mean_st_TCB,        std_st_TCB,        colors[2],  (:dashdot, 1), (:diamond, 4),   "TCB"),
    (mean_st_TCBI,       std_st_TCBI,       colors2[2], (:dashdot, 1), (:diamond, 4),   "ITCB"),
    (mean_st_BetaEBTCB,  std_st_BetaEBTCB,  colors[3],  (:dash, .5),   (:utriangle, 4), "Beta-EB-TCB"),
    (mean_st_BetaEBTCBI, std_st_BetaEBTCBI, colors2[3], (:dash, .5),   (:utriangle, 4), "Beta-EB-ITCB"),
];

p_SC = plot(title = "Sample Complexity", xlabel="Beta", ylabel="Number of Samples");
for (avg, se, col, ln, mk, lbl) in plot_specs
    # Shaded band of +/- 2 standard errors around each curve.
    band = 2 .* se
    plot!(p_SC, βs, avg, ribbon = (band, band), fillalpha = 0.05, color = col, line = ln, marker = mk, label = lbl);
end

# Create the output directory if needed, so a long simulation run is not lost to a
# missing folder when the figure is written.
mkpath(output);
savefig(p_SC, string(output,"SC_",typeDistribution,"_",niter,".pdf"));