# Load the Top-2 algorithms and the stopping threshold on every process.
# `@everywhere` lives in the Distributed stdlib, so it must be loaded first; with
# Distributed loaded the same lines also work when julia is started without `-p N`
# (they then simply run on the master process only).
using Distributed
@everywhere include("../Top2Algos.jl")
# Stopping threshold passed to the algorithms as `b=beta`: log((log(t) + 1) / d).
# The second argument `n` is accepted but unused here (presumably the number of arms or
# samples expected by the caller's signature — confirm against Top2Algos.jl).
@everywhere beta(t, n, d) = log((log(t) + 1) / d);

##########################################################################################
# TO RUN THIS FILE, set the output path and the vector of arm means below
# (or uncomment one of the provided instances), then adjust the remaining parameters as needed.
##########################################################################################


using Random, Distributions
using Distributed;
using DataFrames;

# Experiment configuration. The distribution type and the number of runs can be
# overridden from the command line:  julia -p N <this script> [distribution] [niter]
output = "Experiment_6/output/"  # directory that receives the generated PDFs
# Vector of means
#mu = [7.25, 7.05, 7, 7.1] # Instance from Figure 10
mu = [7.25, 7, 7, 7] # Instance from Figure 11


plotall = false;         # also run/plot TCB, ITCB, Beta-EB-TCB and Beta-EB-ITCB (not just AT2 and IAT2)
comparewithTaS = false;  # also run/plot TaS
plotAT2 = true;          # draw AT2 next to IAT2 in the sample-complexity plot
plothistograms = false;  # write density plots of the IAT2 stopping times

typeDistribution = "Gaussian"  # "Gaussian" or "Bernoulli"
niter = 100;  # Number of simulation runs to perform

# Optional CLI overrides: ARGS[1] = distribution type, ARGS[2] = number of runs.
if length(ARGS) >= 1
    typeDistribution = ARGS[1]
end
if length(ARGS) >= 2
    niter = parse(Int, ARGS[2])
end

# Report the settings only after the overrides have been applied.
println("Started $typeDistribution for $niter runs")

# Build the bandit instance MAB: one arm per entry of `mu`.
#   "Bernoulli": the means are rescaled into (0, 1) by dividing by 1.01 * maximum(mu),
#                rounded to 3 digits, and arm k is Binomial(1, mu[k]).
#   otherwise:   arm k is Normal(mu[k], 1). An unrecognised `typeDistribution` still gets
#                Gaussian arms (as before), but a warning is now emitted.
if typeDistribution == "Bernoulli"
    scale = maximum(mu) * 1.01
    mu = round.(mu ./ scale, digits = 3)
    MAB = Tuple(Binomial(1, m) for m in mu)
else
    if typeDistribution != "Gaussian"
        @warn "Unknown distribution type \"$typeDistribution\"; using unit-variance Gaussian arms."
    end
    MAB = Tuple(Normal(m, 1) for m in mu)
end
println(MAB);

K = length(mu);
# Replace `mu` by the arms' actual means (post-rescaling) as a plain Vector{Float64}.
mu = collect(mean.(MAB))
best = argmax(mu);

# Problem parameters shared by every algorithm sweep below.
δ = 0.001      # passed as δ to every algorithm call (presumably the confidence parameter)
seed = 1123    # run i of each sweep is seeded with `seed + i`
αs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]  # exploration exponents to sweep

println("mu=$(mu), δ=$δ")

# Result table: one row is appended per (algorithm, α), in sweep order.
Alg = String[]       # algorithm label
MeanST = Float64[]   # mean stopping time over the runs
Sd_ST = Float64[]    # standard error of the mean stopping time, sqrt(var / niter)
RunTime = Float64[]  # mean run time over the runs
Sd_RT = Float64[]    # standard error of the mean run time, sqrt(var / niter)

T = 1_000_000  # maximum time for which the dynamics may evolve in one run

# Optional baselines (Beta-EB-TCB, Beta-EB-ITCB, TCB, ITCB), each swept over every α in `αs`.
# Every sweep appends one table row per α and stores per-α summaries for the plot.
# `std_st_*` holds the standard error of the mean, sqrt(var / niter), not the raw std.
# Indexing of each run's first return value: length of element 1 = stopping time,
# element 4 = recommended arm, element 5 = run time (as used throughout this script).
if plotall == true
    β_EB = 0.5  # β for both EB variants; also encoded in their table label

    # BETA-EB-TCB
    mean_st_BetaEBTCB = Vector{Float64}();
    std_st_BetaEBTCB = Vector{Float64}();

    for α in αs
        println("\n ~~~", α, "-EB-TCB ~~~");
        data_BetaEBTCB = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "BetaEBTCB";
                                                               check_stop = true, b = beta, β = β_EB,
                                                               α = α, dist = typeDistribution),
                              1:niter);
        runs = getindex.(data_BetaEBTCB, 1)
        stop_times = length.(getindex.(runs, 1))
        run_times = getindex.(runs, 5)
        m = round(mean(stop_times), digits = 2)
        s = round(sqrt(var(stop_times) / niter), digits = 2)

        push!(mean_st_BetaEBTCB, m);
        push!(std_st_BetaEBTCB, s);
        push!(Alg, string(β_EB, "-EB-TCB"));
        push!(MeanST, m);
        push!(Sd_ST, s);
        push!(RunTime, round(mean(run_times), digits = 2));
        push!(Sd_RT, round(sqrt(var(run_times) / niter), digits = 2));
    end

    # BETA-EB-TCBI
    mean_st_BetaEBTCBI = Vector{Float64}();
    std_st_BetaEBTCBI = Vector{Float64}();

    for α in αs
        println("\n ~~~", α, "-EB-TCBI ~~~");
        data_BetaEBTCBI = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "BetaEBTCBI";
                                                                check_stop = true, b = beta, β = β_EB,
                                                                α = α, dist = typeDistribution),
                               1:niter);
        runs = getindex.(data_BetaEBTCBI, 1)
        stop_times = length.(getindex.(runs, 1))
        run_times = getindex.(runs, 5)
        m = round(mean(stop_times), digits = 2)
        s = round(sqrt(var(stop_times) / niter), digits = 2)

        push!(mean_st_BetaEBTCBI, m);
        push!(std_st_BetaEBTCBI, s);
        push!(Alg, string(β_EB, "-EB-TCBI"));
        push!(MeanST, m);
        push!(Sd_ST, s);
        push!(RunTime, round(mean(run_times), digits = 2));
        push!(Sd_RT, round(sqrt(var(run_times) / niter), digits = 2));
    end

    # TCB
    mean_st_TCB = Vector{Float64}();
    std_st_TCB = Vector{Float64}();

    for α in αs
        println("\n", α, " ~~~ TCB ~~~")
        data_TCB = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "TCB";
                                                          check_stop = true, b = beta, α = α,
                                                          dist = typeDistribution),
                        1:niter);
        runs = getindex.(data_TCB, 1)
        stop_times = length.(getindex.(runs, 1))
        run_times = getindex.(runs, 5)
        # Error fraction = share of the niter runs whose recommended arm is not `best`.
        error_frac_TCB = mean(getindex.(runs, 4) .!= best)
        println("Error fraction: ", error_frac_TCB)
        m = round(mean(stop_times), digits = 2)
        s = round(sqrt(var(stop_times) / niter), digits = 2)

        push!(mean_st_TCB, m);
        push!(std_st_TCB, s);
        push!(Alg, "TCB");
        push!(MeanST, m);
        push!(Sd_ST, s);
        push!(RunTime, round(mean(run_times), digits = 2));
        push!(Sd_RT, round(sqrt(var(run_times) / niter), digits = 2));
    end

    # TCBI
    mean_st_TCBI = Vector{Float64}();
    std_st_TCBI = Vector{Float64}();

    for α in αs
        println("\n", α, " ~~~ TCBI ~~~")
        data_TCBI = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "TCBI";
                                                           check_stop = true, b = beta, α = α,
                                                           dist = typeDistribution),
                         1:niter);
        runs = getindex.(data_TCBI, 1)
        stop_times = length.(getindex.(runs, 1))
        run_times = getindex.(runs, 5)
        # Error fraction = share of the niter runs whose recommended arm is not `best`.
        error_frac_TCBI = mean(getindex.(runs, 4) .!= best)
        println("Error fraction: ", error_frac_TCBI)
        m = round(mean(stop_times), digits = 2)
        s = round(sqrt(var(stop_times) / niter), digits = 2)

        push!(mean_st_TCBI, m);
        push!(std_st_TCBI, s);
        push!(Alg, "TCBI");
        push!(MeanST, m);
        push!(Sd_ST, s);
        push!(RunTime, round(mean(run_times), digits = 2));
        push!(Sd_RT, round(sqrt(var(run_times) / niter), digits = 2));
    end
end


# AT2, swept over every α in `αs` (always run).
# st_AT2[k] keeps the raw stopping times for αs[k]; mean_st_AT2 / std_st_AT2 hold the
# rounded mean and standard error of the mean, sqrt(var / niter), for the plot.
# Indexing of each run's first return value: length of element 1 = stopping time,
# element 4 = recommended arm, element 5 = run time (as used throughout this script).
mean_st_AT2 = Vector{Float64}();
std_st_AT2 = Vector{Float64}();
st_AT2 = Vector{Vector{Float64}}();

for α in αs
    println("\n", α, " ~~~ AT2 ~~~")
    data_AT2 = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "AT2";
                                                      check_stop = true, b = beta, α = α,
                                                      dist = typeDistribution),
                    1:niter);
    runs = getindex.(data_AT2, 1)
    stop_time_AT2 = length.(getindex.(runs, 1))
    run_time_AT2 = getindex.(runs, 5)
    # Error fraction = share of the niter runs whose recommended arm is not `best`.
    error_frac_AT2 = mean(getindex.(runs, 4) .!= best)
    println("Error fraction: ", error_frac_AT2)
    m = round(mean(stop_time_AT2), digits = 2)
    s = round(sqrt(var(stop_time_AT2) / niter), digits = 2)

    push!(st_AT2, stop_time_AT2);
    push!(mean_st_AT2, m);
    push!(std_st_AT2, s);
    push!(Alg, "AT2");
    push!(MeanST, m);
    push!(Sd_ST, s);
    push!(RunTime, round(mean(run_time_AT2), digits = 2));
    push!(Sd_RT, round(sqrt(var(run_time_AT2) / niter), digits = 2));
end


# IAT2 (algorithm name "AT2I"), swept over every α in `αs` (always run).
# st_AT2I[k] keeps the raw stopping times for αs[k]; mean_st_AT2I / std_st_AT2I hold the
# rounded mean and standard error of the mean, sqrt(var / niter), for the plots.
# Indexing of each run's first return value: length of element 1 = stopping time,
# element 4 = recommended arm, element 5 = run time (as used throughout this script).
mean_st_AT2I = Vector{Float64}();
std_st_AT2I = Vector{Float64}();
st_AT2I = Vector{Vector{Float64}}();

for α in αs
    println("\n", α, " ~~~ IAT2 ~~~")
    data_AT2I = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "AT2I";
                                                       check_stop = true, b = beta, α = α,
                                                       dist = typeDistribution),
                     1:niter);
    runs = getindex.(data_AT2I, 1)
    stop_time_AT2I = length.(getindex.(runs, 1))
    run_time_AT2I = getindex.(runs, 5)
    # Error fraction = share of the niter runs whose recommended arm is not `best`.
    error_frac_AT2I = mean(getindex.(runs, 4) .!= best)
    println("Error fraction: ", error_frac_AT2I)
    m = round(mean(stop_time_AT2I), digits = 2)
    s = round(sqrt(var(stop_time_AT2I) / niter), digits = 2)

    push!(st_AT2I, stop_time_AT2I);
    push!(mean_st_AT2I, m);
    push!(std_st_AT2I, s);
    push!(Alg, "AT2I");
    push!(MeanST, m);
    push!(Sd_ST, s);
    push!(RunTime, round(mean(run_time_AT2I), digits = 2));
    push!(Sd_RT, round(sqrt(var(run_time_AT2I) / niter), digits = 2));
end


# Optional TaS sweep over every α in `αs` (only when comparewithTaS is set).
# st_TaS[k] keeps the raw stopping times for αs[k]; mean_st_TaS / std_st_TaS hold the
# rounded mean and standard error of the mean, sqrt(var / niter), for the plot.
# Indexing of each run's first return value: length of element 1 = stopping time,
# element 4 = recommended arm, element 5 = run time (as used throughout this script).
if comparewithTaS == true
    mean_st_TaS = Vector{Float64}();
    std_st_TaS = Vector{Float64}();
    st_TaS = Vector{Vector{Float64}}();

    for α in αs
        println("\n", α, " ~~~ TaS ~~~")
        data_TaS = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "TaS";
                                                          check_stop = true, b = beta, α = α,
                                                          dist = typeDistribution),
                        1:niter);
        runs = getindex.(data_TaS, 1)
        stop_time_TaS = length.(getindex.(runs, 1))
        run_time_TaS = getindex.(runs, 5)
        # Error fraction = share of the niter runs whose recommended arm is not `best`.
        error_frac_TaS = mean(getindex.(runs, 4) .!= best)
        println("Error fraction: ", error_frac_TaS)
        m = round(mean(stop_time_TaS), digits = 2)
        s = round(sqrt(var(stop_time_TaS) / niter), digits = 2)

        push!(st_TaS, stop_time_TaS);
        push!(mean_st_TaS, m);
        push!(std_st_TaS, s);
        push!(Alg, "TaS");
        push!(MeanST, m);
        push!(Sd_ST, s);
        push!(RunTime, round(mean(run_time_TaS), digits = 2));
        push!(Sd_RT, round(sqrt(var(run_time_TaS) / niter), digits = 2));
    end
end

################## PRINT DATA ##################
# Every sweep above appends exactly one row per α, in the order of `αs`, so the α of
# each row is recovered by tiling `αs`. StdDevST / StdDevRT are standard errors of the
# mean, sqrt(var / niter).
nsweeps, leftover = divrem(length(Alg), length(αs))
leftover == 0 || error("result table has $(length(Alg)) rows, not a multiple of length(αs) = $(length(αs))")
df = DataFrame(; AlgoName = Alg, Alpha = repeat(αs, nsweeps), AvgST = MeanST,
               StdDevST = Sd_ST, RunTime = RunTime, StdDevRT = Sd_RT);
print(df)

#=
# Print all stopping times for specific parameters.
println("");
println("");
for a in 1:length(αs)
    if (αs[a] == 0 || αs[a] == 0.5)
        println("Samples for α=",αs[a],": ",st_AT2I[a]);
        println("Average: ",mean(st_AT2I[a])," Maximum: ",maximum(st_AT2I[a]));
    end
    println("");
    println("");
end
println("");
println("");
=#

using Plots;
using StatsPlots;
colors = [ :red, :green,   :purple, :brown, :blue, :orange, :black, :yellow, :cyan, :pink];
colors2 = [:pink, :orange, :blue]

# Sample-complexity plot: mean stopping time against α for every algorithm that was run.
# Ribbons span ±2 standard errors of the mean (std_st_* = sqrt(var / niter)).
plot(title = "Sample Complexity", xlabel="Exploration exponent (alpha)", ylabel="Number of Samples");
if plotall == true
    plot!(αs, mean_st_TCB, ribbon = (2 .* std_st_TCB, 2 .* std_st_TCB), fillalpha = 0.05, color = colors[2], line = (:dashdot, 1), marker = (:diamond, 4), label = "TCB");
    plot!(αs, mean_st_TCBI, ribbon = (2 .* std_st_TCBI, 2 .* std_st_TCBI), fillalpha = 0.05, color = colors2[2], line = (:dashdot, 1), marker = (:diamond, 4), label = "ITCB");
    plot!(αs, mean_st_BetaEBTCB, ribbon = (2 .* std_st_BetaEBTCB, 2 .* std_st_BetaEBTCB), fillalpha = 0.05, color = colors[3], marker = (:utriangle, 4), line = (:dash, .5), label = "Beta-EB-TCB");
    plot!(αs, mean_st_BetaEBTCBI, ribbon = (2 .* std_st_BetaEBTCBI, 2 .* std_st_BetaEBTCBI), fillalpha = 0.05, color = colors2[3], marker = (:utriangle, 4), line = (:dash, .5), label = "Beta-EB-ITCB");
end

if comparewithTaS == true
    plot!(αs, mean_st_TaS, ribbon = (2 .* std_st_TaS, 2 .* std_st_TaS), fillalpha = 0.05, color = colors[4], line = (:solid, 1), marker = (:circle, 4), label = "TaS ");
end
if plotAT2 == true
    plot!(αs, mean_st_AT2, ribbon = (2 .* std_st_AT2, 2 .* std_st_AT2), fillalpha = 0.05, color = colors[1], line = (:solid, 1), marker = (:star6, 4), label = "AT2 ");
end
p_SC = plot!(αs, mean_st_AT2I, ribbon = (2 .* std_st_AT2I, 2 .* std_st_AT2I), fillalpha = 0.05, color = colors2[1], line = (:solid, 1), marker = (:star6, 4), label = "IAT2");

# savefig does not create missing directories, so make sure `output` exists first;
# joinpath also works when `output` is given without a trailing slash.
mkpath(output)
savefig(p_SC, joinpath(output, string("SC_alpha_", typeDistribution, "_same_", niter, ".pdf")));
display(p_SC);

################## Make Histograms #################
# Optional density plots of the IAT2 stopping times (st_AT2I[k] holds the niter stopping
# times for αs[k]): one PDF per α, plus one PDF overlaying all α values.
# To overlay AT2 as well, pass `[st_AT2[k], st_AT2I[k]]` with `label=["AT2" "IAT2"]`.
if plothistograms == true
    mkpath(output)  # savefig does not create missing directories

    for (k, α) in enumerate(αs)
        q = density([st_AT2I[k]],
                    normalize=:pdf,
                    xlabel="T",
                    ylabel="fraction of runs",
                    label=["IAT2"],
                    title=string("α=", α),
                    left_margin=10Plots.mm,
                    bottom_margin=5Plots.mm
                    )
        savefig(q, joinpath(output, string("SC_hist_alpha_", k, "_", typeDistribution, "_same_", niter, ".pdf")));
    end

    comb = density([st_AT2I[k] for k in eachindex(αs)],
                   label=[αs...]',
                   normalize=:pdf,
                   xlabel="T",
                   ylabel="fraction of runs",
                   legendtitle="α",
                   left_margin=10Plots.mm,
                   bottom_margin=5Plots.mm
                   )

    savefig(comb, joinpath(output, string("SC_combhist_alpha_", typeDistribution, "_same_", niter, ".pdf")));
    display(comb);
end