using ArgParse, JLD2, Printf, JSON, Dates, IterTools, Random;
using Distributed;

@everywhere include("runiteps.jl");
@everywhere include("helpers.jl");
@everywhere include("../binary_search.jl");
@everywhere include("../expfam.jl");
include("helpers_experiments.jl");

"""
    parse_commandline()

Declare the command-line options of this experiment script and parse `ARGS`.

Returns the `Dict` produced by `ArgParse.parse_args`, keyed by option name without the
leading dashes (e.g. `"save_dir"`, `"K"`, `"eps"`).
"""
function parse_commandline()
    # One row per option, in the order shown by `--help`: (flag, help text, type, default).
    option_rows = [
        ("--save_dir", "Directory for saving the experiment's data.", String, "experiments/"),
        ("--data_dir", "Directory for loading the data.", String, "data/"),
        ("--seed", "Seed.", Int64, 42),
        ("--inst", "Instance considered.", String, "uniform"),
        ("--K", "Number of arms.", Int64, 5),
        ("--mu1", "Best arm.", Float64, 1.),
        ("--expe", "Experiment considered.", String, "epsTest"),
        ("--Nruns", "Number of runs of the experiment.", Int64, 100),
        ("--history", "History of recommendation to keep: 'none' or 'partial'.", String, "partial"),
        ("--freqHist", "Frequence of storing recommendations.", Int64, 50),
        ("--eps", "Approximation parameter.", Float64, 0.1),
        ("--reps", "Proportion of ϵ-opt arms.", Float64, 0.25),
        ("--opt", "ϵ-optimality considered: `mul` or `add`.", String, "add"),
    ];

    settings = ArgParseSettings();
    for (flag, helptext, argtype, fallback) in option_rows
        # Function form of `@add_arg_table!`: option settings are passed as a Dict.
        add_arg_table!(settings, flag,
                       Dict(:help => helptext, :arg_type => argtype, :default => fallback));
    end

    return parse_args(settings);
end

# Draw a random bandit instance described by `param_inst`, using the generator `rng`.
#
# `param_inst` must provide the keys "inst", "nK", "mu1", "eps" and "reps".
# Arm 1 always has mean param_inst["mu1"]. The other arms depend on "inst":
#   - "uniform":    arms 2..Kϵ, with Kϵ = ceil(nK * reps), get mean mu1 - ϵ * U, and
#                   arms Kϵ+1..nK get mean (1 - ϵ) * U, with U ~ Uniform[0, 1).
#   - "BAIuniform": arms 2..nK get mean 0.5 + 0.3 * U.
# One `rand(rng)` is consumed per arm 2..nK, in arm order.
#
# Returns `(μs, dists)`: the vector of means and one `Gaussian()` per arm.
# Throws an `ArgumentError` for any other value of "inst".
#
# NOTE(review): for "uniform", the arms after Kϵ only stay below mu1 - ϵ when mu1 ≥ 1
# (as with the default mu1 = 1), because their bound is (1 - ϵ) rather than mu1 - ϵ.
@everywhere function get_rand_instance(param_inst, rng)
    inst = param_inst["inst"];
    # Fail fast rather than returning an instance where every arm has mean mu1.
    if inst != "uniform" && inst != "BAIuniform"
        throw(ArgumentError("Instance \"$inst\" is not implemented; expected \"uniform\" or \"BAIuniform\"."));
    end

    nK = param_inst["nK"];
    ϵ = param_inst["eps"];
    Kϵ = Int64(ceil(nK * param_inst["reps"]));

    μs = param_inst["mu1"] * ones(nK);
    for a in 2:nK
        if inst == "uniform"
            if a <= Kϵ
                μs[a] -= ϵ * rand(rng);
            else
                μs[a] = (1 - ϵ) * rand(rng);
            end
        else  # "BAIuniform"
            μs[a] = 0.5 + 0.3 * rand(rng);
        end
    end

    dists = [Gaussian() for _ in μs];
    return μs, dists;
end

# Draw one random instance from `param_inst` (its RNG is seeded with `seed`), build the
# ϵ-best-arm problem on it, and run every identification strategy of `iss` through `runit`.
#
# Returns a vector of `(strategy, result, instance)` records, one per element returned
# by `runit`. Here `instance = (μs, ϵ, opt, Tstar, wstar, Tstar_beta, wstar_beta)` is
# shared by all records of this call.
@everywhere function run_inst(seed, iss, δs, param_inst, Tau_max, history, freqHist)
    # Random instance, reproducible from `seed`.
    μs, dists = get_rand_instance(param_inst, MersenneTwister(seed));

    # Pure exploration problem and the outputs of `oracle` / `oracle_beta_half`.
    pep = EspilonBestArm(dists, param_inst["eps"], param_inst["opt"]);
    Tstar, wstar = oracle(pep, μs);
    Tstar_beta, wstar_beta = oracle_beta_half(pep, μs);
    instance = (μs, pep.ϵ, pep.opt, Tstar, wstar, Tstar_beta, wstar_beta);

    # Concrete element types of the stored records (each result is converted on push!).
    result_t = Tuple{Int64, Array{Int64,1}, UInt64, Array{Int64,1}};
    instance_t = Tuple{Array{Float64,1}, Float64, String, Float64, Array{Float64,1}, Float64, Array{Float64,1}};
    records = Tuple{Any, result_t, instance_t}[];

    for strategy in iss
        for res in runit(seed, strategy, pep, μs, δs, Tau_max, history, freqHist)
            push!(records, (strategy, res, instance));
        end
    end

    return records;
end

# Parameters: unpack the parsed command-line options into script-level variables.
parsed_args = parse_commandline();
(save_dir, data_dir, seed, inst, nK, mu1, expe,
 Nruns, history, freqHist, ϵ, opt, reps) = getindex.(Ref(parsed_args),
    ("save_dir", "data_dir", "seed", "inst", "K", "mu1", "expe",
     "Nruns", "history", "freqHist", "eps", "opt", "reps"));

# Passed as `Tau_max` to every `runit` call; presumably the maximal stopping time.
# Hard-coded here, not a command-line option.
Tau_max = 1e6;

# Parameters defining the random instances (read by `get_rand_instance` and `run_inst`).
param_inst = Dict("inst" => inst, "nK" => nK, "mu1" => mu1,
                  "eps" => ϵ, "reps" => reps, "opt" => opt);

# Confidence levels δ passed to `runit`; `data` is sized with Nruns * length(δs) columns.
δs = [0.01];

# Naming files and folder: every invocation writes into its own directory under
# `save_dir`, named by timestamp plus the main parameters.
now_str = Dates.format(now(), "dd-mm_HHhMM");
# Fractional digits of ϵ and reps, e.g. 0.1 -> "1", 0.25 -> "25".
# NOTE(review): values sharing a fractional part (0.1 and 1.1) yield the same tag.
_ϵ = split(string(ϵ), ".")[2];
_rϵ = split(string(reps), ".")[2];
# NOTE(review): no separator between the history tag and `inst` (e.g. "partial50uniform").
experiment_name = "exp_eps_random_" * expe * "_" * opt * "_e" * _ϵ * "_re" * _rϵ * "_" * history * (history == "partial" ? string(freqHist) : "") * inst * "_K" * string(nK) * "_N" * string(Nruns);
# `joinpath` also handles a `--save_dir` given without a trailing slash. The trailing "/"
# is kept because the file paths below are built by plain string concatenation.
# NOTE(review): ':' is not a valid file-name character on Windows.
experiment_dir = joinpath(save_dir, now_str * ":" * experiment_name) * "/";
# Create `save_dir` if missing. Plain `mkdir` for the run directory still errors on a
# name collision (same parameters within the same minute) instead of overwriting.
mkpath(save_dir);
mkdir(experiment_dir);
open("$(experiment_dir)parsed_args.json", "w") do f
    JSON.print(f, parsed_args)
end

# Identification strategies evaluated on every instance (each presumably an (sr, rsp)
# tuple), built by `everybody` from `expe` and the uniform vector of entries 1/nK.
iss = everybody(expe, fill(1 / nK, nK));
# Row of each strategy in `data`, and the next column to fill for each strategy.
iss_index = Dict(strategy => row for (row, strategy) in enumerate(iss));
iss_counters = Dict(strategy => 1 for strategy in iss);

# Run the experiments in parallel: run `run_idx` (1..Nruns) seeds its instance with
# `seed + run_idx`. Each element of `_data` is the record vector of one run.
@time _data = pmap(1:Nruns) do run_idx
    run_inst(seed + run_idx, iss, δs, param_inst, Tau_max, history, freqHist)
end;

# Regroup all records by strategy. Row = `iss_index[strategy]`, and columns are filled
# left to right in arrival order (Nruns * length(δs) slots per strategy).
data = Array{Tuple{Tuple{Int64, Array{Int64,1}, UInt64, Array{Int64,1}}, Tuple{Array{Float64,1}, Float64, String, Float64, Array{Float64,1}, Float64, Array{Float64,1}}}}(undef, length(iss), Nruns * length(δs));
for (strategy, res, instance) in Iterators.flatten(_data)
    col = iss_counters[strategy];
    data[iss_index[strategy], col] = (res, instance);
    iss_counters[strategy] = col + 1;
end

# Persist results and the metadata needed to interpret them (JLD2 `.dat` file).
@save "$(experiment_dir)$(experiment_name).dat" iss data iss_index δs param_inst Nruns seed;

# Write a human-readable summary next to the data file.
file = experiment_dir * "summary_" * experiment_name * ".txt";
print_eps_rand_summary(δs, iss, data, iss_index, param_inst, Nruns, file);
