module FPSoundSymbolic
import ..FPSoundVerification as FPSV
using ..FPSoundVerification

using Flux
using IntervalArithmetic


# Solver options for the FP-sound symbolic-interval bounding algorithm.
#
# NOTE(review): `@option` is presumably the Configurations.jl macro (keyword
# constructor using the defaults below) — confirm against the package setup.
@option struct Option <: AbstractOptions
    # Passed to `get_input_domain` together with the input point and the
    # perturbation radius; presumably the admissible input value range — TODO confirm.
    input_range::Tuple
    # Numeric type used for the `VariableManager` domains; also forwarded to
    # `apply_dense` / `apply_conv`.
    u_type::Type = Float64
    # Forwarded to `apply_dense` / `apply_conv`; presumably enables rounding-error
    # (floating-point-sound) handling in affine layers — verify there.
    fp_sound::Bool = true
    # Affine-layer implementation, forwarded as `implementation=`; only
    # "textbook" and "flux" are accepted by `init_bounding_algorithm!`.
    backend::String = "textbook"
end


"""
    Algorithm{T}(network, input_point, true_index, perturbation_radius, option)

Mutable state of one symbolic-interval bounding run of `network` around `input_point`.

The constructor only stores its arguments and creates an empty `VariableManager`
built from `option.u_type`. `flattened_network` and `prop_expr` start empty and are
filled by `init_bounding_algorithm!`; `results` starts as an empty dictionary and is
populated (`:OutputBounds`, `:HostileIndices`, `:Status`) by `decide_label_uniformity!`.
"""
mutable struct Algorithm{T}
    # Solver configuration.
    solve_options::Option
    # Radius handed to `get_input_domain` around `input_point`.
    perturbation_radius::Real
    # The model being analysed.
    network::Flux.Chain
    # Layer list produced by `normalize_network_layers`, walked in order during propagation.
    flattened_network::Vector
    # Pool of symbolic variables with interval domains (input and relaxation variables).
    vars::VariableManager{Symbol, <:Real}

    # Centre of the analysed input region.
    input_point::AbstractArray{<:Real}
    # Current symbolic expressions, one per element of the current layer's output.
    prop_expr::AbstractArray{IntervalAffExpr{T}}

    # Index (into the vectorised output) of the expected label.
    true_index::Int

    # Analysis outcome, keyed by Symbol.
    results::Dict{Symbol, Any}

    function Algorithm{T}(network::Flux.Chain, input_point::AbstractArray{<:Real}, true_index::Int, perturbation_radius::Real, option::Option) where T
        # Empty variable pool; input variables are registered during initialisation.
        # NOTE(review): the meaning of the leading `0` argument is defined by
        # `VariableManager` — presumably the initial variable count.
        var_pool = VariableManager(0, option.u_type)
        pending_layers = []
        pending_exprs = []
        outcome = Dict{Symbol, Any}()
        return new(option, perturbation_radius, network, pending_layers, var_pool,
                   input_point, pending_exprs, true_index, outcome)
    end

end


"""
    register_var_and_create_input_expr!(input_domain, vm)

Register one fresh variable with domain `input_domain` in `vm` (mutating `vm`) and
return an `IntervalExpr` whose constant part is `interval(T, 0)` and whose only
coefficient is `interval(T, 1)` on that new variable.
"""
function register_var_and_create_input_expr!(input_domain::Interval{T}, vm::VariableManager{S, T}) where {S, T<:Real}
    fresh_sym = register_new_var!(vm, input_domain)
    unit_coeff = OrderedDict{S, Interval{T}}(fresh_sym => interval(T, 1))
    constant_part = interval(T, 0)
    return IntervalExpr{S, T}(constant_part, unit_coeff)
end


"""
    init_bounding_algorithm!(algorithm::Algorithm)

Prepare `algorithm` for propagation:

1. check that `solve_options.backend` is "textbook" or "flux" (otherwise an
   `AssertionError` is thrown);
2. run `assert_type!` on the algorithm;
3. build the input domains with `get_input_domain` and register one variable per
   domain, storing the resulting input expressions in `algorithm.prop_expr`;
4. store `normalize_network_layers(algorithm.network)` in
   `algorithm.flattened_network` and return it.
"""
function FPSV.init_bounding_algorithm!(algorithm::Algorithm)
    algorithm.solve_options.backend in ("textbook", "flux") ||
        throw(AssertionError("Unsupported backend: $(algorithm.solve_options.backend), only 'textbook' and 'flux' are supported."))

    assert_type!(algorithm)

    # Fields are read only after `assert_type!`, which may mutate `algorithm`.
    domains = get_input_domain(
        algorithm.input_point,
        algorithm.perturbation_radius,
        algorithm.solve_options.input_range,
    )
    algorithm.prop_expr = register_var_and_create_input_expr!.(domains, Ref(algorithm.vars))

    layers = normalize_network_layers(algorithm.network)
    algorithm.flattened_network = layers
    return layers
end


"""
    apply_relu_on_exp(exp, vars)

Return a symbolic over-approximation of `relu(exp)`, using the concrete bounds
`[l, u] = exp(vars)`:

- `u <= 0`: returns `zero(exp)`;
- `l >= 0`: returns `exp` itself (not a copy);
- otherwise, with `λ = u / (u - l)` in interval arithmetic, returns `λ⋅exp + v`, where
  `v` is a fresh variable registered in `vars` with domain `[0, sup(λ⋅|l|)]`. This
  encloses the band `λx ≤ relu(x) ≤ λx + λ(-l)` valid on `[l, u]`.

Only the last case mutates `vars` (one new variable).
"""
function apply_relu_on_exp(exp::IntervalExpr{S, T}, vars::VariableManager{S, T}) where {S, T<:Real}
    bounds = exp(vars)
    lo = inf(bounds)
    hi = sup(bounds)

    hi <= 0 && return zero(exp)   # stably inactive neuron
    lo >= 0 && return exp         # stably active neuron

    # Unstable neuron (lo < 0 < hi): slope of the relaxation lines.
    i_hi = interval(hi)
    i_lo = interval(lo)
    slope = i_hi / (i_hi - i_lo)

    relaxed = exp * slope
    # Height of the band above the lower line λx.
    band = slope * interval(abs(lo))

    # The symbol is queried before registration; presumably `register_new_var!`
    # then assigns exactly that symbol, so this call order must be kept.
    fresh = get_next_var_sym(vars)
    extend_exp_with_new_var!(relaxed, fresh, interval(T, 1))
    register_new_var!(vars, interval(T, 0, sup(band)))

    return relaxed
end


"""
    model_network_layer!(layer::DotFunctor{ReLUType}, algorithm::Algorithm)

Apply `apply_relu_on_exp` to every element of `algorithm.prop_expr` in index order
(relaxation variables are registered in `algorithm.vars` in that order), store the
new array in `algorithm.prop_expr`, and return it.
"""
function model_network_layer!(layer::DotFunctor{ReLUType}, algorithm::Algorithm)
    pre_activation = algorithm.prop_expr
    post_activation = similar(pre_activation)

    for k in eachindex(pre_activation)
        post_activation[k] = apply_relu_on_exp(pre_activation[k], algorithm.vars)
    end

    algorithm.prop_expr = post_activation
    return post_activation
end


"""
    model_network_layer!(layer::Flux.Dense, algorithm::Algorithm)

Propagate `algorithm.prop_expr` through a dense layer with `apply_dense`, forwarding
`u_type`, `fp_sound` and `backend` (as `implementation`) from the solve options;
store and return the result.
"""
function model_network_layer!(layer::Flux.Dense, algorithm::Algorithm)
    opts = algorithm.solve_options
    propagated = apply_dense(algorithm.prop_expr, layer, opts.u_type, opts.fp_sound;
                             implementation = opts.backend)
    algorithm.prop_expr = propagated
    return propagated
end


"""
    model_network_layer!(layer::Flux.Conv, algorithm::Algorithm)

Propagate `algorithm.prop_expr` through a convolution with `apply_conv`, forwarding
`u_type`, `fp_sound` and `backend` (as `implementation`) from the solve options;
store and return the result.
"""
function model_network_layer!(layer::Flux.Conv, algorithm::Algorithm)
    opts = algorithm.solve_options
    propagated = apply_conv(algorithm.prop_expr, layer, opts.u_type, opts.fp_sound;
                            implementation = opts.backend)
    algorithm.prop_expr = propagated
    return propagated
end


"""
    model_network_layer!(layer::FlattenType, algorithm::Algorithm)

Reshape `algorithm.prop_expr` with `Flux.flatten`, store the result, and return it.
"""
function model_network_layer!(layer::FlattenType, algorithm::Algorithm)
    flattened = Flux.flatten(algorithm.prop_expr)
    algorithm.prop_expr = flattened
    return flattened
end


"""
    model_network_layers!(algorithm::Algorithm)

Run `model_network_layer!` on each entry of `algorithm.flattened_network` in order,
updating `algorithm.prop_expr` layer by layer, and return the final expressions.
"""
function FPSV.model_network_layers!(algorithm::Algorithm)
    foreach(l -> model_network_layer!(l, algorithm), algorithm.flattened_network)
    return algorithm.prop_expr
end


"""
    concretize(x::IntervalAffExpr, vars::VariableManager)

Evaluate the symbolic expression `x` over the variable domains in `vars`, yielding
its concrete interval enclosure.
"""
function concretize(x::IntervalAffExpr, vars::VariableManager)
    return x(vars)
end


"""
    decide_label_uniformity!(algorithm::Algorithm)

Decide whether output `true_index` is guaranteed to dominate every other output over
the analysed input region, and record the outcome in `algorithm.results`:

- `:OutputBounds`   — concrete interval of each (vectorised) output expression;
- `:HostileIndices` — `Vector{Int}` of outputs `i != true_index` whose margin
  `out[i] - out[true_index]` has an upper bound `>= 0`, i.e. outputs that may reach
  or exceed the true output;
- `:Status`         — `ProvenSafe` if there are no such outputs, otherwise
  `PossiblyAdversarial`.

Returns `nothing`.
"""
function FPSV.decide_label_uniformity!(algorithm::Algorithm)
    out_bounds = concretize.(vec(algorithm.prop_expr), Ref(algorithm.vars))
    algorithm.results[:OutputBounds] = out_bounds

    t = algorithm.true_index
    # Subtracting concrete intervals drops the correlation between the outputs, so
    # these margins are (sound) over-approximations.
    margins = out_bounds .- out_bounds[t]

    # Safety needs a strictly negative margin: with `sup == 0` an exact tie is still
    # possible, and a tie-breaking label choice (e.g. argmax) could pick `i`.
    hostile = Int[i for i in eachindex(margins) if i != t && sup(margins[i]) >= 0]
    algorithm.results[:HostileIndices] = hostile

    # It is just PossiblyAdversarial, since overestimation is significant in this algorithm.
    algorithm.results[:Status] = isempty(hostile) ? ProvenSafe : PossiblyAdversarial

    return
end

end # module FPSoundSymbolic