import Pkg
Pkg.activate("./../")
# Pkg.status()

include("./../FPSoundVerification.jl")
using .FPSoundVerification
using Flux
using MLDatasets
using ArgParse


# Scripts to read in the neural networks
include("./../WK17_a/wk17a_flux.jl");
include("./../WK17_a/read_wk17a.jl");


# Load the MNIST test split and prepare it for the verifier.
x_test, y_test = MLDatasets.MNIST(split=:test)[:]
# One 28x28 single-channel image per label, stacked along the 4th dimension.
dataset_shape = (28, 28, 1, length(y_test))
dataset = Float64.(reshape(x_test, dataset_shape))
# Labels are the integers 0..9, while Julia indexing starts at 1, so shift
# each label by one to obtain the corresponding output index.
output_indices = 1 .+ y_test


# Parse the command line arguments.
# The parsed values are stored in `parsed_args`, keyed by symbols derived from
# the long option names (e.g. `--net_id` -> `:net_id`).
argparser = ArgParseSettings()

@add_arg_table argparser begin
    "--verifier", "-v"
        help = "Algorithm to use: IBP or Symbolic"
        arg_type = String
        default = "IBP"
    "--net_id", "-i"
        help = "Network ID"
        arg_type = String
        required = true
    "--num_instance", "-n"
        help = "How many instances to run the verification on"
        arg_type = Int
        default = 20
    "--perturbation_radius", "-p"
        help = "Perturbation radius"
        arg_type = Float64
        default = 0.01
    "--fp_sound", "-s"
        help = "Whether to use floating-point sound verification"
        action = :store_true
    "--backend", "-b"
        help = "Backend linear algebra implementation to use: textbook or flux"
        arg_type = String
        default = "textbook"
    "--verified_precision", "-r"
        help = "The verified precision: Float32 or Float64"
        arg_type = String
        default = "Float64"

end

parsed_args = parse_args(argparser; as_symbols = true)

# Map the requested precision name to the floating-point type used below.
# An unrecognised name is rejected explicitly: silently falling back to
# Float64 would run the verification at a precision the user did not request.
net_type = let precision = parsed_args[:verified_precision]
    if precision == "Float32"
        Float32
    elseif precision == "Float64"
        Float64
    else
        error("Unknown verified precision: $(precision) (expected Float32 or Float64)")
    end
end

# Load the neural network selected by --net_id.
# "wk17a" is read with `read_wk17a_flux`; every "orderK_BB" ID maps to an
# adversary .mat file under WK17_a/MAT that is read with `read_wk17a_from_mat`
# at the requested precision (`net_type`).
nn = let id = parsed_args[:net_id], wk17a_dir = joinpath(".", "..", "WK17_a")
    adversary_files = Dict(
        "order3_64" => "wk17a_order_pattern_3_f64_adversary.mat",
        "order2_64" => "wk17a_order_pattern_2_f64_adversary.mat",
        "order1_64" => "wk17a_order_pattern_1_f64_adversary.mat",
        "order3_32" => "wk17a_order_pattern_3_f32_adversary.mat",
        "order2_32" => "wk17a_order_pattern_2_f32_adversary.mat",
        "order1_32" => "wk17a_order_pattern_1_f32_adversary.mat",
    )

    if id == "wk17a"
        read_wk17a_flux(joinpath(wk17a_dir, "wk17a.mat"))
    elseif haskey(adversary_files, id)
        read_wk17a_from_mat(joinpath(wk17a_dir, "MAT", adversary_files[id]), net_type)
    else
        error("Unknown network ID: $(parsed_args[:net_id])")
    end
end


# Determine the verifier selected by --verifier (FPSoundIBP or FPSoundSymbolic)
# and the algorithm it provides.
if parsed_args[:verifier] == "IBP"
    verifier = FPSoundIBP
elseif parsed_args[:verifier] == "Symbolic"
    verifier = FPSoundSymbolic
else
    # The option is stored under :verifier; the argument table defines no
    # :alg entry, so looking that key up would raise a KeyError instead of
    # reporting the offending value.
    error("Unknown algorithm: $(parsed_args[:verifier])")
end

# `verifier.Option` is used later to build the solve options; `alg` is the
# algorithm handed to `run_on_batch`.
alg = verifier.Algorithm

# Options forwarded to the verifier, collected as a string-keyed dictionary.
solve_options = Dict{String, Any}()
# Input bounds (0, 1), converted to the verified precision.
solve_options["input_range"] = map(net_type, (0.0, 1.0))
solve_options["u_type"] = net_type
solve_options["fp_sound"] = parsed_args[:fp_sound]
solve_options["backend"] = parsed_args[:backend]


# Run the verification on the first --num_instance test images and their
# output indices. The output path handed to `run_on_batch` is "./Results".
let n = parsed_args[:num_instance]
    run_on_batch(
        alg,
        nn,
        dataset[:, :, :, 1:n],
        output_indices[1:n],
        parsed_args[:perturbation_radius];
        # Keyword arguments
        network_id = parsed_args[:net_id],
        solve_options = from_dict(verifier.Option, solve_options),
        output_path = "./Results",
    )
end