using CPUTime
using MAT
using Dates
using Statistics

export run_on_sample, run_on_batch

"""
    run_on_sample(network, input_point, true_index, perturbation_radius,
                  main_algorithm, solve_options)::Dict

Verify a single input with the algorithm type `main_algorithm` and return a
dictionary describing the outcome.

# Arguments
- `network::Flux.Chain`: network under verification.
- `input_point`: unperturbed input; it is fed to `network` as-is.
- `true_index::Int`: expected label index, compared with the `argmax` of the
  network output on `input_point`.
- `perturbation_radius::Real`: radius forwarded to the algorithm constructor.
- `main_algorithm`: parametric algorithm type, instantiated as
  `main_algorithm{solve_options.u_type}(network, input_point, true_index,
  perturbation_radius, solve_options)`.
- `solve_options::AbstractOptions`: options forwarded to the algorithm; its
  `u_type` field selects the algorithm's type parameter.

# Returns
A `Dict` with `:PredictedIndex`, `:TrueIndex`, `:MainAlgorithm`,
`:AlgorithmOptions`, every entry of `algorithm.results` (these win on key
collisions), `:SolveTime` (seconds, measured with `@elapsed`), `:Radius` and
`:VerifiedPrecision`.

If the unperturbed input is already misclassified, the bounding steps are
skipped and `:Status` is `ProvenAdversarial`.
"""
function run_on_sample(
    network::Flux.Chain,
    input_point::AbstractArray{<:Real},
    true_index::Int,
    perturbation_radius::Real,
    main_algorithm, # constructor
    solve_options::AbstractOptions
)::Dict
    res = Dict()

    # Label predicted for the unperturbed input.
    predicted_index = input_point |> network |> vec |> argmax
    res[:PredictedIndex] = predicted_index
    res[:TrueIndex] = true_index

    # Algorithm infos
    res[:MainAlgorithm] = main_algorithm
    res[:AlgorithmOptions] = solve_options

    algorithm = main_algorithm{solve_options.u_type}(
        network, input_point, true_index, perturbation_radius, solve_options
    )

    # A misclassified unperturbed input is already a counterexample, so the
    # verification itself does not need to run.
    is_adversarial = true_index != predicted_index

    # Timed inline rather than through a closure: `res` is never captured, so
    # it is not boxed.
    solve_time = @elapsed begin
        if !is_adversarial
            init_bounding_algorithm!(algorithm)
            model_network_layers!(algorithm)
            decide_label_uniformity!(algorithm)
        end
    end

    # Entries of `algorithm.results` take precedence on key collisions.
    merge!(res, algorithm.results)

    # Applied after merging so that a `:Status` entry possibly present in
    # `algorithm.results` (NOTE(review): e.g. a constructor default — confirm)
    # cannot overwrite the adversarial verdict.
    if is_adversarial
        res[:Status] = ProvenAdversarial
    end

    res[:SolveTime] = solve_time
    res[:Radius] = perturbation_radius
    res[:VerifiedPrecision] = solve_options.u_type

    return res
end


"""
    run_on_batch(main_algorithm, network, dataset, output_indices,
                 perturbation_radius; network_id = "", solve_options,
                 output_path = "./Results")

Run `run_on_sample` on every sample of `dataset` and append one CSV row per
sample to `joinpath(output_path, network_id, "<alg_id>.csv")`. Here `alg_id`
comes from `get_alg_id(main_algorithm, solve_options.fp_sound)`.

# Arguments
- `main_algorithm`: parametric algorithm type forwarded to `run_on_sample`.
- `network::Flux.Chain`: network under verification.
- `dataset`: 4-dimensional array; samples are taken along dimension 4 and
  passed as `dataset[:, :, :, i:i]`, which keeps the 4th dimension.
- `output_indices`: true label of each sample; entry `i` belongs to sample `i`.
- `perturbation_radius::Real`: radius forwarded to `run_on_sample`.

# Keywords
- `network_id`: sub-directory of `output_path` that holds the CSV file.
- `solve_options`: options forwarded to `run_on_sample`.
- `output_path`: root results directory. It is created, including any missing
  parents, if it does not exist.

# Throws
- `DimensionMismatch` if `output_indices` cannot be indexed for every sample.
"""
function run_on_batch(
    main_algorithm, # constructor
    network::Flux.Chain,
    dataset::AbstractArray{<:Real, 4},
    output_indices::AbstractVector{<:Int},
    perturbation_radius::Real;
    # Keyword arguments
    network_id::AbstractString = "",
    solve_options::AbstractOptions,
    output_path::AbstractString = "./Results"
)
    input_samples = size(dataset, 4)

    # Fail before doing any work (and before writing partial results) when
    # some sample would have no label.
    if !checkbounds(Bool, output_indices, 1:input_samples)
        throw(DimensionMismatch(
            "output_indices has $(length(output_indices)) entries but " *
            "dataset has $input_samples samples"
        ))
    end

    # mkpath also creates missing parent directories and is a no-op when the
    # directory already exists.
    output_dir = joinpath(output_path, network_id)
    mkpath(output_dir)

    # The output file depends only on the algorithm and options, not on the sample.
    alg_id = get_alg_id(main_algorithm, solve_options.fp_sound)
    output_file = joinpath(output_dir, "$alg_id.csv")

    for input_idx in 1:input_samples
        println("Processing sample $input_idx / $input_samples ...")
        res = run_on_sample(
            network,
            dataset[:, :, :, input_idx:input_idx],
            output_indices[input_idx],
            perturbation_radius,
            main_algorithm,
            solve_options
        )

        save_to_disk(input_idx, res, output_file)
    end

    return nothing
end


"""
    save_to_disk(input_idx::Int, res::Dict, output_file::AbstractString)

Append one comma-separated row summarising the result dictionary `res` of
sample `input_idx` to `output_file`.

The header row is written first when the file does not exist yet or is empty.
Columns, in order: sample index, `res[:PredictedIndex]`, `res[:TrueIndex]`,
`res[:Status]`, `res[:Radius]`, `res[:SolveTime]`, the mean, maximum and
minimum of `diam.(res[:OutputBounds])`, and `res[:VerifiedPrecision]`.

The three diameter columns are written as `NaN` when `res` has no
`:OutputBounds` entry (for example, when verification was skipped) or the
bounds are empty. Every value is rendered with `string`; no CSV quoting is
applied.
"""
function save_to_disk(
    input_idx::Int,
    res::Dict,
    output_file::AbstractString
)
    header = join(
        ["SampleIndex","PredictedIndex","TrueIndex","Status", "Radius", "SolveTime", "MeanOutputDiameter", "MaxOutputDiameter", "MinOutputDiameter", "VerifiedPrecision"],
        ","
    )

    # Bounds may be absent when the algorithm never ran; `maximum`/`minimum`
    # would also throw on an empty collection.
    bounds = get(res, :OutputBounds, nothing)
    if bounds === nothing || isempty(bounds)
        mean_diam = max_diam = min_diam = NaN
    else
        out_diams = diam.(bounds)
        mean_diam = mean(out_diams)
        max_diam = maximum(out_diams)
        min_diam = minimum(out_diams)
    end

    fields = Any[
        input_idx,
        res[:PredictedIndex],
        res[:TrueIndex],
        res[:Status],
        res[:Radius],
        res[:SolveTime],
        mean_diam,
        max_diam,
        min_diam,
        res[:VerifiedPrecision]
    ]
    record = join(string.(fields), ",")

    # An existing but empty file (e.g. pre-created) still needs its header.
    write_header = !isfile(output_file) || filesize(output_file) == 0

    open(output_file, "a") do io
        if write_header
            write(io, header, "\n")
        end
        write(io, record, "\n")
    end

    return
end

"""
    get_alg_id(algorithm, fp_sound::Bool) -> String

Return a short identifier for `algorithm`, used as the result file name.

The base name is `"IBP"` when `algorithm == FPSoundIBP.Algorithm` and
`"Symbolic"` for anything else. It is prefixed with `"FPSound"` when
`fp_sound` is `true`.
"""
function get_alg_id(algorithm, fp_sound::Bool)
    base_name = if algorithm == FPSoundIBP.Algorithm
        "IBP"
    else
        "Symbolic"
    end

    fp_tag = fp_sound ? "FPSound" : ""
    return string(fp_tag, base_name)
end