using ArgParse, JLD2, Printf, JSON, Dates;
using Random, LinearAlgebra;
using Distributed;

@everywhere include("../binary_search.jl");
@everywhere include("peps.jl");
@everywhere include("../expfam.jl");

# Build the command-line interface of this script and parse `ARGS`.
#
# Options (name => type, default):
#   --save_dir => String, "simulations/"  (parent directory for experiment data)
#   --seed     => Int64,  42
#   --N        => Int64,  1000            (number of computations)
#
# Returns the Dict produced by ArgParse's `parse_args`, keyed by option name
# without the leading dashes (e.g. "save_dir", "seed", "N").
function parse_commandline()
    settings = ArgParseSettings()
    add_arg_table!(
        settings,
        "--save_dir", Dict(
            :help => "Directory for saving the experiment's data.",
            :arg_type => String,
            :default => "simulations/",
        ),
        "--seed", Dict(
            :help => "Seed.",
            :arg_type => Int64,
            :default => 42,
        ),
        "--N", Dict(
            :help => "Number of computations.",
            :arg_type => Int64,
            :default => 1000,
        ),
    )
    return parse_args(settings)
end

# Parse arguments
parsed_args = parse_commandline();
save_dir = parsed_args["save_dir"];
seed = parsed_args["seed"];
N = parsed_args["N"];

# `seed` is recorded in parsed_args.json below; also seed this (master)
# process's global RNG so the parameter is not silently ignored.
# NOTE(review): this does not seed the Distributed worker processes.
Random.seed!(seed);

# Create experiment folder: <save_dir>/<dd-mm_HHhMM>:sim_N<N>/
# NOTE(review): ':' is not a valid file-name character on Windows; it is kept
# so folder names stay compatible with existing results.
now_str = Dates.format(now(), "dd-mm_HHhMM");
experiment_name = "sim_N$(N)";
# joinpath tolerates a --save_dir given without a trailing slash. The final ""
# keeps the trailing separator that later "$(experiment_dir)<file>"
# interpolations rely on.
experiment_dir = joinpath(save_dir, now_str * ":" * experiment_name, "");
# Create save_dir (and any missing parents) if needed. Plain `mkdir` is kept for
# the run folder itself, so a second run within the same minute errors instead
# of overwriting the previous run's files.
mkpath(save_dir);
mkdir(experiment_dir);
open(joinpath(experiment_dir, "parsed_args.json"), "w") do f
    JSON.print(f, parsed_args)
end

# Oracle computations for the 3-armed instance with arm means [0, -1, -rΔ].
# Arm 1 (mean 0) is best whenever rΔ > 0. Arm 2 has gap 1 and arm 3 has gap
# rΔ, so rΔ presumably denotes the ratio of the two gaps (naming assumption,
# TODO confirm).
#
# Returns the tuple (Tstar_beta / Tstar, Tstar, wstar, Tstar_beta), where
# (Tstar, wstar) are returned by `oracle` and Tstar_beta by
# `oracle_beta_half`. Both functions come from the included files, not this one.
@everywhere function compute_characteristic_times_ratio_K3(rΔ)
    # 3-armed bandit; every arm uses the default `Gaussian()` distribution
    # (its parameters are defined in expfam.jl — presumably unit variance).
    μs = [0.0, -1, -rΔ];
    dists = [Gaussian() for _ in μs];
    pep = BestArm(dists);

    # Oracle computation
    Tstar, wstar = oracle(pep, μs);
    Tstar_beta, _ = oracle_beta_half(pep, μs);

    return Tstar_beta / Tstar, Tstar, wstar, Tstar_beta;
end

# Oracle computations for the 4-armed instance with arm means
# [0, -1, -rΔ1, -rΔ2]. Arm 1 (mean 0) is best whenever rΔ1, rΔ2 > 0. Arm 2
# has gap 1 and arms 3 and 4 have gaps rΔ1 and rΔ2. Those are presumably gap
# ratios relative to arm 2 (naming assumption, TODO confirm).
#
# Returns the tuple (Tstar_beta / Tstar, Tstar, wstar, Tstar_beta), where
# (Tstar, wstar) are returned by `oracle` and Tstar_beta by
# `oracle_beta_half`. Both functions come from the included files, not this one.
@everywhere function compute_characteristic_times_ratio_K4(rΔ1, rΔ2)
    # 4-armed bandit; every arm uses the default `Gaussian()` distribution
    # (its parameters are defined in expfam.jl — presumably unit variance).
    μs = [0.0, -1, -rΔ1, -rΔ2];
    dists = [Gaussian() for _ in μs];
    pep = BestArm(dists);

    # Oracle computation
    Tstar, wstar = oracle(pep, μs);
    Tstar_beta, _ = oracle_beta_half(pep, μs);

    return Tstar_beta / Tstar, Tstar, wstar, Tstar_beta;
end

# K = 3 sweep: evaluate compute_characteristic_times_ratio_K3 on N evenly
# spaced values of rΔ in [1, 6], distributing the calls over the workers.
rΔs = range(1, stop=6, length=N);
_rΔs = collect(rΔs);  # materialized grid, saved next to the results

# Computations: the named (@everywhere-defined) function is passed straight to
# pmap; wrapping it in a one-argument lambda added nothing.
@time vals = pmap(compute_characteristic_times_ratio_K3, rΔs);

# Save data. JLD2's @save uses the variable names (`vals`, `_rΔs`) as the
# dataset keys, so these names are part of the output format.
file_str = "beta_one_half_K3_N$(N)";
file_name = "$(experiment_dir)$(file_str).dat";
@save file_name vals _rΔs;


# K = 4 sweep: evaluate compute_characteristic_times_ratio_K4 on the full 2-D
# grid rΔs × rΔs, with rΔs spanning [1, 4].
# The grid has N ÷ 10 points per axis, and at least 2. Integer division
# avoids the InexactError that `Int64(N / 10)` raised whenever N was not a
# multiple of 10. The lower bound of 2 is needed because `range` cannot build
# a 1-point grid between two distinct endpoints.
rΔs = range(1, stop=4, length=max(N ÷ 10, 2));
_rΔs = collect(rΔs);  # materialized per-axis grid, saved next to the results

# Computations: Iterators.product yields (rΔ1, rΔ2) tuples, which the lambda
# destructures.
# NOTE(review): both (a, b) and (b, a) are computed. If the oracle is
# symmetric in arms 3 and 4, half of this work is redundant. The full grid is
# kept because the saved `vals` layout depends on it.
@time vals = pmap(
    ((rΔ1, rΔ2),) -> compute_characteristic_times_ratio_K4(rΔ1, rΔ2),
    Iterators.product(rΔs, rΔs)
);

# Save data. JLD2's @save uses the variable names (`vals`, `_rΔs`) as the
# dataset keys, so these names are part of the output format.
file_str = "beta_one_half_K4_N$(N)";
file_name = "$(experiment_dir)$(file_str).dat";
@save file_name vals _rΔs;
