using ArgParse, JLD2, Printf, JSON, Dates, IterTools, Random;
using Distributed;

@everywhere include("runit.jl");
@everywhere include("helpers.jl");
@everywhere include("../binary_search.jl");
@everywhere include("../expfam.jl");
include("helpers_experiments.jl");

"""
    parse_commandline()

Declare the command-line options of the random-instance experiment and parse
`ARGS` against them.

Returns the result of `ArgParse.parse_args`; values are looked up by option name
(e.g. `"seed"`, `"inst"`, `"Nruns"`).
"""
function parse_commandline()
    settings = ArgParseSettings()

    # Option order is kept as-is: it determines the order shown by `--help`.
    @add_arg_table! settings begin
        "--save_dir"
            help = "Directory for saving the experiment's data."
            arg_type = String
            default = "experiments/"
        "--data_dir"
            help = "Directory for loading the data."
            arg_type = String
            default = "data/"
        "--seed"
            help = "Seed."
            arg_type = Int64
            default = 42
        "--inst"
            help = "Instance considered."
            arg_type = String
            default = "uniform"
        "--K"
            help = "Number of arms."
            arg_type = Int64
            default = 5
        "--mu1"
            help = "Best arm."
            arg_type = Float64
            default = 0.6
        "--gapmin"
            help = "Min gap."
            arg_type = Float64
            default = 0.1
        "--gapmax"
            help = "Max gap."
            arg_type = Float64
            default = 0.4
        "--expe"
            help = "Experiment considered."
            arg_type = String
            default = "test"
        "--Nruns"
            help = "Number of runs of the experiment."
            arg_type = Int64
            default = 100
        "--isBer"
            help = "Bernoulli bandits."
            action = :store_true
        "--history"
            help = "History of recommendation to keep: 'none' or 'partial'."
            arg_type = String
            default = "partial"
        "--freqHist"
            help = "Frequence of storing recommendations."
            arg_type = Int64
            default = 200
    end

    return parse_args(settings)
end

# Draw the arm means of one random instance described by `param_inst`.
#
# `param_inst` must provide the keys "inst", "nK", "mu1", "gapmin" and "gapmax".
# Arm 1 keeps mean `mu1`; the other arms get `mu1 - gap` where, depending on "inst":
#   - "uniform":  gap ~ Uniform[gapmin, gapmax) independently per arm;
#   - "distinct": as "uniform", but a gap is redrawn while the resulting mean lies
#                 within 0.001 of a mean already drawn (arms 1..a-1);
#   - "unif2G":   arms 2-6 get gap ≈ 0.05 and arms 7-10 gap ≈ 0.1 (each scaled by a
#                 factor in [0.995, 1.005)); assumes nK == 10.
# Throws `ArgumentError` for any other "inst".
@everywhere function _rand_means(param_inst, rng)
    nK = param_inst["nK"]
    inst = param_inst["inst"]
    mu1 = param_inst["mu1"]
    gapmin = param_inst["gapmin"]
    gapmax = param_inst["gapmax"]
    draw_gap() = (gapmax - gapmin) * rand(rng) + gapmin

    μs = mu1 * ones(nK)
    if inst == "uniform"
        for a in 2:nK
            μs[a] -= draw_gap()
        end
    elseif inst == "distinct"
        for a in 2:nK
            val = draw_gap()
            # Compare the candidate *mean* (not the gap) with the means drawn so far.
            while minimum(abs.(view(μs, 1:(a - 1)) .- (mu1 - val))) <= 0.001
                val = draw_gap()
            end
            μs[a] -= val
        end
    elseif inst == "unif2G"
        # Each group is shifted exactly once; this must not sit inside a per-arm loop.
        for a in 2:6
            μs[a] -= 0.05 * (0.995 + 0.01 * rand(rng))
        end
        for a in 7:10
            μs[a] -= 0.1 * (0.995 + 0.01 * rand(rng))
        end
    else
        throw(ArgumentError("Not implemented: unknown instance type \"$inst\""))
    end
    return μs
end

# Draw a random instance: returns `(μs, dists)` where `μs` are the arm means from
# `_rand_means` and `dists` holds one `Bernoulli()` (if param_inst["isBer"]) or
# `Gaussian()` per arm. `rng` is the only source of randomness.
@everywhere function get_rand_instance(param_inst, rng)
    μs = _rand_means(param_inst, rng)
    dists = [param_inst["isBer"] ? Bernoulli() : Gaussian() for μ in μs]
    return μs, dists
end

# Evaluate every identification strategy of `iss` on one random instance.
#
# The instance is drawn by `get_rand_instance` from `MersenneTwister(seed)`; the
# same `seed` is also forwarded to `runit`. The returned vector holds one
# `(strategy, result, instance)` entry per element yielded by `runit`, where
# `instance = (μs, Tstar, wstar, Tstar_beta, wstar_beta)` comes from `oracle`
# and `oracle_beta_half`.
@everywhere function run_inst(seed, iss, δs, param_inst, Tau_max, history, freqHist)
    μs, dists = get_rand_instance(param_inst, MersenneTwister(seed))

    # Pure exploration problem and its oracle quantities
    pep = BestArm(dists)
    Tstar, wstar = oracle(pep, μs)
    Tstar_beta, wstar_beta = oracle_beta_half(pep, μs)
    instance = (μs, Tstar, wstar, Tstar_beta, wstar_beta)

    RunRecord = Tuple{Any,
                      Tuple{Int64, Array{Int64,1}, UInt64, Array{Int64,1}},
                      Tuple{Array{Float64,1}, Float64, Array{Float64,1}, Float64, Array{Float64,1}}}
    records = RunRecord[]

    # `runit` is called once per strategy; each of its results becomes one record.
    for is in iss, result in runit(seed, is, pep, μs, δs, Tau_max, history, freqHist)
        push!(records, (is, result, instance))
    end

    return records
end

# Read the command-line parameters into globals used by the rest of the script.
parsed_args = parse_commandline();
save_dir, data_dir, seed, inst = (parsed_args[key] for key in ("save_dir", "data_dir", "seed", "inst"));
# The "unif2G" instance always has 10 arms, so --K is ignored for it.
if inst == "unif2G"
    nK = 10;
else
    nK = parsed_args["K"];
end
mu1, gapmin, gapmax = (parsed_args[key] for key in ("mu1", "gapmin", "gapmax"));
expe, Nruns, isBer = (parsed_args[key] for key in ("expe", "Nruns", "isBer"));
history, freqHist = (parsed_args[key] for key in ("history", "freqHist"));


# Tau_max is forwarded unchanged to `runit` (presumably a cap on the number of
# samples per run — confirm in runit.jl).
Tau_max = 1e6;

# Everything `get_rand_instance` needs to draw a random instance
param_inst = Dict("inst" => inst, "nK" => nK, "mu1" => mu1,
                  "gapmin" => gapmin, "gapmax" => gapmax,
                  "isBer" => isBer);

# Values of δ forwarded to `runit` for every run
δs = [0.1];

# Name the experiment and create its output folder: <save_dir>/<date>:<name>/
now_str = Dates.format(now(), "dd-mm_HHhMM");
experiment_name = "exp_random_" * (isBer ? "ber_" : "gau_") * expe * "_" * history *
                  (history == "partial" ? string(freqHist) : "") * inst *
                  "_K" * string(nK) * "_N" * string(Nruns);
# `joinpath` inserts the separator even when --save_dir has no trailing slash; the
# final "" keeps a trailing separator because later paths are built by concatenation.
experiment_dir = joinpath(save_dir, now_str * ":" * experiment_name, "");
# Create save_dir if missing. The experiment dir itself still uses `mkdir`, which
# errors if it already exists, so an earlier run's results are never overwritten.
mkpath(save_dir);
mkdir(experiment_dir);
open("$(experiment_dir)parsed_args.json","w") do f
    JSON.print(f, parsed_args)
end

# Identification strategies evaluated on every instance (tuples (sr, rsp)),
# built by `everybody` from uniform weights over the nK arms.
iss = everybody(expe, fill(1 / nK, nK));
# Next free result column for each strategy, and each strategy's row in `data`.
iss_counters = Dict(strat => 1 for strat in iss);
iss_index = Dict(strat => row for (row, strat) in enumerate(iss));

# Run the experiments in parallel; run i uses seed `seed + i`.
@time _data = pmap(
    i -> run_inst(seed + i, iss, δs, param_inst, Tau_max, history, freqHist),
    1:Nruns
);

# Gather results into a (strategy × (run, δ)) matrix of (result, instance) pairs.
n_results = Nruns * length(δs);
data = Array{Tuple{Tuple{Int64, Array{Int64,1}, UInt64, Array{Int64,1}}, Tuple{Array{Float64,1}, Float64, Array{Float64,1}, Float64, Array{Float64,1}}}}(undef, length(iss), n_results);
for chunk in _data
    for (is, res, instance) in chunk
        data[iss_index[is], iss_counters[is]] = (res, instance);
        iss_counters[is] += 1;
    end
end
# `data` holds non-isbits tuples: an unfilled slot would be an #undef reference that
# only fails later (saving / summary), so verify every strategy filled its row.
for (is, next_col) in iss_counters
    next_col == n_results + 1 ||
        error("strategy $is produced $(next_col - 1) results, expected $n_results");
end

# Persist the raw results together with the experiment configuration (JLD2).
data_file = "$(experiment_dir)$(experiment_name).dat";
@save data_file iss data iss_index δs param_inst Nruns seed;

# Write a human-readable summary of the experiment next to the data file.
file = experiment_dir * "summary_" * experiment_name * ".txt";
print_rand_summary(δs, iss, data, iss_index, param_inst, Nruns, file);
