# Loading Distributed before the @everywhere lines makes @everywhere and pmap available
# whether or not worker processes exist. Start Julia with `julia -p N` to spread the
# simulation runs over N worker processes; with no workers everything runs on the master.
using Distributed;

@everywhere include("../Top2Algos.jl")
# NOTE(review): beta is not referenced in this script; presumably a threshold used by code
# in Top2Algos.jl — confirm before removing.
@everywhere beta(t,n,d)=log((log(t)+1)/d);

using DataFrames;

typeDistribution = "Gaussian"  # "Gaussian" or "Bernoulli"
niter = 4000;  # Number of simulation runs to perform
trunc_digits = 2; # Number of digits after the decimal point kept when rounding statistics
T = 500; # Time for which the dynamics should evolve

# Optional command-line overrides, in order: distribution type, number of runs, horizon T.
# Each argument is read only when present, so a partial argument list no longer indexes
# past the end of ARGS.
if length(ARGS) >= 1
    typeDistribution = ARGS[1];
end
if length(ARGS) >= 2
    niter = parse(Int64, ARGS[2]);
end
if length(ARGS) >= 3
    T = parse(Int64, ARGS[3]);
end
# At least one run is needed: statistics below average over runs and plot sample path I.
niter >= 1 || throw(ArgumentError("number of runs must be positive, got $niter"))

# Create default MAB instance: one arm per entry of `mu`.
# Gaussian arms have unit standard deviation; for "Bernoulli" the means are rescaled into
# (0, 1) and rounded to two digits. MAB is a tuple of distributions, one per arm.
mu = [10, 8, 7, 6.5]  # arm means

if typeDistribution == "Bernoulli"
    dist = "Bernoulli"
    scale = maximum(mu)*1.1;  # the largest rescaled mean becomes 1/1.1 < 1
    mu = round.(mu ./ scale, digits = 2);
    MAB = Tuple(Bernoulli(m) for m in mu);

    trunc_digits = 5;  # rescaled statistics are small, so keep more digits when rounding
else
    if typeDistribution != "Gaussian"
        @warn "Unknown distribution type; falling back to Gaussian" typeDistribution
    end
    dist = "Gaussian"
    MAB = Tuple(Normal(m, 1) for m in mu);
end
println(MAB)

a_star = argmax(mu);  # index of the arm with the largest mean
K = length(mu);       # number of arms

# Set problem parameters
δ = 0.0001;  # δ passed to every algorithm run below
N = ones(Int64, K); # Initial number of samples to each arm: one per arm, for any K
seed = 1123;  # base seed; simulation run i uses seed + i
quant = 0.001;  # NOTE(review): not referenced in this script — confirm whether still needed

# Collect data: simulate each algorithm niter times, with pmap spreading the runs over the
# available worker processes. Run i of every algorithm uses seed + i, so the algorithms are
# compared on identical seeds, and all three share the keyword settings below.
# NOTE(review): check_stop=false presumably disables the stopping rule so each run covers
# the whole horizon T — confirm against Top2Algos.jl.
shared_kwargs = (check_stop = false, b = -1, InitialSamples = N, β = 0.5, α = 0.5,
                 αTCB = 1, dist = dist)

# non-fluid AT2
data_non_fluid = pmap(r -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + r, "AT2";
                                                         shared_kwargs...), 1:niter);

# non-fluid TCB
data_TCB = pmap(r -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + r, "TCB";
                                                   shared_kwargs...), 1:niter);

# non-fluid BetaEBTCB
data_BetaEBTCB = pmap(r -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + r, "BetaEBTCB";
                                                         shared_kwargs...), 1:niter);


# Visualize the collected data.
# For every run, element [1] of the returned value is indexed below as:
# 1 - number of samples till time t to each arm,
# 2 - index for each arm at time t,
# 3 - sum_ratio - 1 at time t.

using Plots;
I=1 # Pick a sample path to plot (must lie in 1:niter)
1 <= I <= niter || throw(ArgumentError("sample path I=$I must lie in 1:$niter"))
save_dir = "Experiment_1/output/"
mkpath(save_dir)  # ensure the output directory exists before any savefig call

################## CLEAN DATA FOR ALGOS. ##################

# Non-fluid AT2: per-arm time series of the index and of the number of samples. For each,
# keep the single sample path I, plus the mean and the standard error of the mean across
# the niter runs. Per-run fields are extracted once here. Previously every comprehension
# re-broadcast over all runs just to pick one, which was O(niter^2) work.
samples_runs_nfl = [res[1][1] for res in data_non_fluid]  # per run: samples per arm over time
index_runs_nfl = [res[1][2] for res in data_non_fluid]    # per run: index per arm over time
g_runs_nfl = [res[1][3] for res in data_non_fluid]        # per run: ratio-sum - 1 over time

index_nfl_exp = Vector{Vector{Float64}}();
mean_index = Vector{Vector{Float64}}();
std_index = Vector{Vector{Float64}}();    # standard error of the mean: sqrt(var / niter)
samples_nfl = Vector{Vector{Int64}}();
mean_samples= Vector{Vector{Float64}}();
std_samples = Vector{Vector{Float64}}();  # standard error of the mean: sqrt(var / niter)

for l = 1:K
    arm_index = [getindex.(series, l) for series in index_runs_nfl]      # index of arm l, per run
    arm_samples = [getindex.(series, l) for series in samples_runs_nfl]  # samples of arm l, per run
    push!(index_nfl_exp, arm_index[I]);  # time series of index values of arm l in sample path I
    push!(mean_index, round.(mean(arm_index), digits=trunc_digits));
    push!(std_index, round.(sqrt.(var(arm_index) ./ niter), digits=trunc_digits));
    push!(samples_nfl, arm_samples[I]);
    push!(mean_samples, round.(mean(arm_samples), digits=trunc_digits));
    push!(std_samples, round.(sqrt.(var(arm_samples) ./ niter), digits=trunc_digits));
end

# time series of mean and standard error of ratio-sum - 1 in non-fluid AT2
mean_g_nfl = round.(mean(g_runs_nfl), digits = trunc_digits);
std_g_nfl = round.(sqrt.(var(g_runs_nfl) ./ niter), digits = trunc_digits);


# COLLECT ALL ABOVE FOR OTHER ALGORITHMS: TCB
# Mean and standard error of the mean (across the niter runs) of the per-arm index time
# series, and of ratio-sum - 1, in non-fluid TCB. Per-run fields are extracted once so the
# comprehensions below are linear in niter.
index_runs_TCB = [res[1][2] for res in data_TCB]  # per run: index per arm over time
g_runs_TCB = [res[1][3] for res in data_TCB]      # per run: ratio-sum - 1 over time

mean_index_TCB = Vector{Vector{Float64}}();
std_index_TCB = Vector{Vector{Float64}}();
for l = 1:K
    arm_index = [getindex.(series, l) for series in index_runs_TCB]  # index of arm l, per run
    push!(mean_index_TCB, round.(mean(arm_index), digits=trunc_digits));
    push!(std_index_TCB, round.(sqrt.(var(arm_index) ./ niter), digits=trunc_digits));
end

# time series of mean and standard error of ratio-sum - 1 in non-fluid TCB
mean_g_TCB = round.(mean(g_runs_TCB), digits = trunc_digits);
std_g_TCB = round.(sqrt.(var(g_runs_TCB) ./ niter), digits = trunc_digits);


# COLLECT ALL ABOVE FOR OTHER ALGORITHMS: BetaEBTCB
# Mean and standard error of the mean (across the niter runs) of the per-arm index time
# series, and of ratio-sum - 1, in non-fluid BetaEBTCB. Per-run fields are extracted once
# so the comprehensions below are linear in niter.
index_runs_BetaEBTCB = [res[1][2] for res in data_BetaEBTCB]  # per run: index per arm over time
g_runs_BetaEBTCB = [res[1][3] for res in data_BetaEBTCB]      # per run: ratio-sum - 1 over time

mean_index_BetaEBTCB = Vector{Vector{Float64}}();
std_index_BetaEBTCB = Vector{Vector{Float64}}();
for l = 1:K
    arm_index = [getindex.(series, l) for series in index_runs_BetaEBTCB]  # index of arm l, per run
    push!(mean_index_BetaEBTCB, round.(mean(arm_index), digits=trunc_digits));
    push!(std_index_BetaEBTCB, round.(sqrt.(var(arm_index) ./ niter), digits=trunc_digits));
end

# time series of mean and standard error of ratio-sum - 1 in non-fluid BetaEBTCB
mean_g_BetaEBTCB = round.(mean(g_runs_BetaEBTCB), digits = trunc_digits);
std_g_BetaEBTCB = round.(sqrt.(var(g_runs_BetaEBTCB) ./ niter), digits = trunc_digits);

################## ALGOS COMPLETED ##################

################## FINALLY, GENERATE PLOTS ##################

colors = [ :red, :blue, :orange, :green, :purple, :black, :brown, :yellow, :cyan, :pink]

######### non-fluid AT2 indexes #########
# Mean index of every arm except a_star, with a ribbon of ±2 standard errors.
# mod1 cycles through the palette, so more arms than colors no longer raise a BoundsError.
p_alg_exp = plot(title = "AT2 Index Values", xlabel="Time", ylabel="Index Values");
for l in 1:K
    l == a_star && continue
    band = 2 .* std_index[l]
    plot!(p_alg_exp, mean_index[l], ribbon = (band, band), fillalpha = 0.2,
          color = colors[mod1(l, length(colors))], line = (:solid, 2),
          label = string("Arm ", l));
end

savefig(p_alg_exp,string(save_dir, K, "AT2Indexes_",dist,"_",T,"_",niter,".pdf"));

######### non-fluid TCB indexes #########
# Mean index of every arm except a_star, with a ribbon of ±2 standard errors.
# The palette index wraps with mod1, so K may exceed the number of colors.
p_TCB = plot(title = "TCB Index Values", xlabel="Time", ylabel="Index Values");
for l in 1:K
    l == a_star && continue
    band = 2 .* std_index_TCB[l]
    plot!(p_TCB, mean_index_TCB[l], ribbon = (band, band), fillalpha = 0.2,
          color = colors[mod1(l, length(colors))], line = (:dashdotdot, 2),
          label = string("Arm ", l));
end

savefig(p_TCB,string(save_dir, K, "TCBIndexes_",dist,"_",T,"_",niter,".pdf"));

######### non-fluid BetaEBTCB indexes #########
# Mean index of every arm except a_star, with a ribbon of ±2 standard errors.
# The palette index wraps with mod1, so K may exceed the number of colors.
p_BetaEBTCB = plot(title = "Beta-EB-TCB Index Values", xlabel="Time", ylabel="Index Values");
for l in 1:K
    l == a_star && continue
    band = 2 .* std_index_BetaEBTCB[l]
    plot!(p_BetaEBTCB, mean_index_BetaEBTCB[l], ribbon = (band, band), fillalpha = 0.2,
          color = colors[mod1(l, length(colors))], line = (:dashdotdot, 2),
          label = string("Arm ", l));
end

savefig(p_BetaEBTCB,string(save_dir, K, "BetaEBTCBIndexes_",dist,"_",T,"_",niter,".pdf"));

######### Combined ratio-sum-1 condition for all algos #########
# Overlay the mean anchor function (ratio-sum - 1) of all three algorithms on one figure,
# each with a ribbon of ±2 standard errors. The figure object is held explicitly.
g_comb = plot(mean_g_nfl; ribbon = (2 .* std_g_nfl, 2 .* std_g_nfl), fillalpha = 0.2,
              color = :green, line = (:dashdotdot, 2), title = "Anchor function value",
              label = "AT2");
plot!(g_comb, mean_g_TCB; ribbon = (2 .* std_g_TCB, 2 .* std_g_TCB), fillalpha = 0.2,
      color = :blue, line = (:dashdotdot, 2), label = "TCB");
plot!(g_comb, mean_g_BetaEBTCB; ribbon = (2 .* std_g_BetaEBTCB, 2 .* std_g_BetaEBTCB),
      fillalpha = 0.2, color = :red, line = (:dashdotdot, 2), label = "Beta-EB-TCB");

savefig(g_comb, "$(save_dir)$(K)Arms_Algos_combined_g_$(dist)_$(T)_$(niter).pdf");

# AT2 anchor function (ratio-sum - 1) on its own: solid line, ribbon of ±2 standard errors.
g = plot(mean_g_nfl; ribbon = (2 .* std_g_nfl, 2 .* std_g_nfl), fillalpha = 0.2,
         color = :green, line = (:solid, 2), title = "Anchor function value", label = "AT2");
savefig(g, "$(save_dir)$(K)AT2_g_$(dist)_$(T)_$(niter).pdf");