module FPSoundIBP

import ..FPSoundVerification as FPSV
using ..FPSoundVerification

using Flux
using IntervalArithmetic

# Solver options for this module's interval bound-propagation algorithm (`Algorithm`).
@option struct Option <: AbstractOptions
    # Passed as the third argument to `get_input_domain` when the initial input intervals
    # are built; presumably the valid (lower, upper) range of input values — TODO confirm.
    input_range::Tuple
    # Forwarded to `apply_dense` / `apply_conv`; NOTE(review): its exact meaning is defined
    # by those helpers (possibly the numeric type used for the bound arithmetic) — confirm.
    u_type::Type = Float64
    # Forwarded to `apply_dense` / `apply_conv` as their floating-point-soundness flag.
    fp_sound::Bool = true
    # Layer implementation, forwarded as the `implementation` keyword of
    # `apply_dense` / `apply_conv`; only "textbook" and "flux" are accepted by
    # `init_bounding_algorithm!`.
    backend::String = "textbook"
end


"""
    Algorithm{T}(network, input_point, true_index, perturbation_radius, option)

Mutable state of one interval bound-propagation run over `network`.

The constructor only stores its arguments. `flattened_network` and `prop_expr` start out
empty and are filled in later by `init_bounding_algorithm!`. `prop_expr` is then replaced
layer by layer by `model_network_layer!`. `results` starts as an empty dictionary and is
populated by `decide_label_uniformity!`.
"""
mutable struct Algorithm{T}
    solve_options::Option
    perturbation_radius::Real
    network::Flux.Chain
    # Layer list produced by `normalize_network_layers`; empty until initialization.
    flattened_network::Vector

    input_point::AbstractArray{<:Real}
    # Current interval bounds; the input domain after init, the layer output afterwards.
    prop_expr::AbstractArray{Interval{T}}

    true_index::Int

    # Verdicts and diagnostics keyed by Symbol (e.g. :Status, :OutputBounds).
    results::Dict{Symbol, Any}

    function Algorithm{T}(
        network::Flux.Chain,
        input_point::AbstractArray{<:Real},
        true_index::Int,
        perturbation_radius::Real,
        option::Option,
    ) where T
        pending_layers = Any[]
        pending_bounds = Interval{T}[]
        empty_results = Dict{Symbol, Any}()
        return new{T}(
            option, perturbation_radius, network, pending_layers,
            input_point, pending_bounds, true_index, empty_results,
        )
    end

end


"""
    FPSV.init_bounding_algorithm!(algorithm::Algorithm)

Prepare `algorithm` for bound propagation.

Steps, in order:
1. Validate the configured backend.
2. Call `assert_type!`.
3. Set `algorithm.prop_expr` to the input domain built by `get_input_domain` from the
   input point, the perturbation radius and `solve_options.input_range`.
4. Set `algorithm.flattened_network` to `normalize_network_layers(algorithm.network)`.

# Returns
The normalized layer list (the value assigned to `algorithm.flattened_network`).

# Throws
- `ArgumentError` if `solve_options.backend` is neither `"textbook"` nor `"flux"`.
"""
function FPSV.init_bounding_algorithm!(algorithm::Algorithm)
    backend = algorithm.solve_options.backend
    # User-supplied configuration: validate explicitly rather than with `@assert`,
    # which is meant for internal invariants and may be disabled.
    backend in ("textbook", "flux") || throw(ArgumentError(
        "Unsupported backend: $(backend), only 'textbook' and 'flux' are supported."))

    assert_type!(algorithm)

    algorithm.prop_expr = get_input_domain(
        algorithm.input_point,
        algorithm.perturbation_radius,
        algorithm.solve_options.input_range,
    )

    layers = normalize_network_layers(algorithm.network)
    algorithm.flattened_network = layers
    return layers
end


"""
    model_network_layer!(layer::DotFunctor{ReLUType}, algorithm::Algorithm)

Apply `Flux.relu` element-wise to the current bounds in `algorithm.prop_expr`.

The result is written into a fresh array allocated with `similar`, so it keeps the
element type and shape of the input. That array replaces `prop_expr` and is returned.
"""
function model_network_layer!(layer::DotFunctor{ReLUType}, algorithm::Algorithm)
    pre_activation = algorithm.prop_expr
    activated = map!(Flux.relu, similar(pre_activation), pre_activation)
    algorithm.prop_expr = activated
    return activated
end


"""
    model_network_layer!(layer::Flux.Dense, algorithm::Algorithm)

Replace `algorithm.prop_expr` with the output of `apply_dense` on the current bounds.

`u_type` and `fp_sound` are forwarded from the solve options, and `backend` is passed as
the `implementation` keyword. Returns the new bounds.
"""
function model_network_layer!(layer::Flux.Dense, algorithm::Algorithm)
    opts = algorithm.solve_options
    dense_out = apply_dense(algorithm.prop_expr, layer, opts.u_type, opts.fp_sound;
                            implementation = opts.backend)
    algorithm.prop_expr = dense_out
    return dense_out
end


"""
    model_network_layer!(layer::Flux.Conv, algorithm::Algorithm)

Replace `algorithm.prop_expr` with the output of `apply_conv` on the current bounds.

`u_type` and `fp_sound` are forwarded from the solve options, and `backend` is passed as
the `implementation` keyword. Returns the new bounds.
"""
function model_network_layer!(layer::Flux.Conv, algorithm::Algorithm)
    opts = algorithm.solve_options
    conv_out = apply_conv(algorithm.prop_expr, layer, opts.u_type, opts.fp_sound;
                          implementation = opts.backend)
    algorithm.prop_expr = conv_out
    return conv_out
end


"""
    model_network_layer!(layer::FlattenType, algorithm::Algorithm)

Reshape the current bounds with `Flux.flatten`, store the result in
`algorithm.prop_expr`, and return it.
"""
function model_network_layer!(layer::FlattenType, algorithm::Algorithm)
    flattened = Flux.flatten(algorithm.prop_expr)
    algorithm.prop_expr = flattened
    return flattened
end


"""
    FPSV.model_network_layers!(algorithm::Algorithm)

Propagate the bounds through every entry of `algorithm.flattened_network`, in order.

Each `model_network_layer!` call replaces `algorithm.prop_expr`. Returns the final
`algorithm.prop_expr`, i.e. the bounds after the last layer.
"""
function FPSV.model_network_layers!(algorithm::Algorithm)
    foreach(layer -> model_network_layer!(layer, algorithm), algorithm.flattened_network)
    return algorithm.prop_expr
end


"""
    FPSV.decide_label_uniformity!(algorithm::Algorithm)

Classify the propagated output bounds as `ProvenSafe` or `PossiblyAdversarial`.

Writes into `algorithm.results`:
- `:OutputBounds`: `vec(algorithm.prop_expr)`, which shares storage with `prop_expr`.
- `:HostileIndices`: a `Vector{Int}` of every `i != true_index` for which the interval
  difference `out[i] - out[true_index]` has a non-negative supremum. For such `i` the
  bounds cannot rule out `out[i] >= out[true_index]`.
- `:Status`: `PossiblyAdversarial` if any hostile index exists, otherwise `ProvenSafe`.

Returns `nothing`.
"""
function FPSV.decide_label_uniformity!(algorithm::Algorithm)
    out_bounds = vec(algorithm.prop_expr)
    algorithm.results[:OutputBounds] = out_bounds

    true_index = algorithm.true_index
    # Interval difference against the true class. This ignores the correlation between
    # the two outputs, so it over-approximates the real margin.
    margins = out_bounds .- out_bounds[true_index]

    # A competing class is only ruled out when out[i] < out[true] strictly. sup == 0
    # still admits a tie, and an argmax may break that tie in favour of i, so it must
    # not be reported as safe.
    hostile = Int[i for i in eachindex(margins) if i != true_index && sup(margins[i]) >= 0]
    algorithm.results[:HostileIndices] = hostile

    # Only PossiblyAdversarial (not a confirmed counterexample), since overestimation
    # is significant in this algorithm.
    algorithm.results[:Status] = isempty(hostile) ? ProvenSafe : PossiblyAdversarial

    return nothing
end

end # module FPSoundIBP