export VariableManager, register_new_var!, get_next_var_sym
export Decision, Unknown, ProvenSafe, PossiblySafe, PossiblyAdversarial, ProvenAdversarial
export normalize_network_layers, DotFunctor, f_interval_max
export FlattenType, LinearLayer, ReLUType
export get_input_domain
export assert_type!

import Setfield: @set

# Struct for managing variables.
#
# Keeps an insertion-ordered mapping from variable keys (of type `S`) to their
# interval domains (`Interval{T}`), together with a running counter that is used
# to name the next variable to be registered (see `register_new_var!`).
mutable struct VariableManager{S, T<:Real}
    # registered variables and their domains
    vars::OrderedDict{S, Interval{T}}
    # numeric suffix that the next generated variable symbol will receive
    next_var_ID::Int64
end

#VariableManager{T}() = VariableManager{Symbol, T}(OrderedDict{Symbol, T}(), 0)
#VariableManager{T}(first_var_id::Int) = VariableManager{Symbol, T}(OrderedDict{Symbol, T}(), first_var_id)
"""
    VariableManager(first_var_id::Int, ::Type{T}) where {T<:Real}

Create an empty `VariableManager{Symbol, T}` whose first generated variable will
carry the numeric suffix `first_var_id`.

The variable store is created directly with the value type `Interval{T}` that the
struct field requires, instead of building an `OrderedDict{Symbol, T}` and relying
on an implicit conversion to the field type.
"""
function VariableManager(first_var_id::Int, ::Type{T}) where {T<:Real}
    return VariableManager{Symbol, T}(OrderedDict{Symbol, Interval{T}}(), first_var_id)
end

"""
    register_new_var!(vm, domain)

Register a fresh variable with the interval `domain` in `vm` and return its symbol.

The symbol is `x_<id>`, where `<id>` is the manager's current `next_var_ID`;
the counter is incremented afterwards.
"""
function register_new_var!(vm::VariableManager{S, T}, domain::Interval{T}) where {S, T<:Real}
    id = vm.next_var_ID
    sym = Symbol("x_$(id)")
    vm.vars[sym] = domain
    vm.next_var_ID = id + 1
    return sym
end

"""
    get_next_var_sym(vm)

Return the symbol (`x_<next_var_ID>`) that the next call to `register_new_var!`
would produce, without modifying `vm`.
"""
function get_next_var_sym(vm::VariableManager{S, T}) where {S, T<:Real}
    upcoming_id = vm.next_var_ID
    return Symbol("x_", upcoming_id)
end


# Decision outputs
#
# The integer values are spelled out explicitly; they match the implicit
# numbering (0, 1, 2, ...) given by the declaration order.
@enum Decision begin
    ProvenSafe = 0          # safe with mathematical certainty
    ProvenAdversarial = 1   # adversarial with certainty (e.g. an example is given)
    Unknown = 2             # no statement can be made
    PossiblySafe = 3        # safe, but not rigorously proven
    PossiblyAdversarial = 4 # adversarial, but not rigorously proven
end



# From SoundVerificationSuite.jl - Utils.jl
"""
    f_interval_max(lhs::Interval, rhs::Interval)
    f_interval_max(monad::Tuple{Interval, Bool}, iv::Interval)

Folding step for finding the largest interval in a collection. The accumulator is
an `(interval, is_disjoint)` pair.

Example use:
```
    max_interval, is_disjoint = reduce(f_interval_max, intervals, init=(intervals[1], true))
```

`max_interval` is only a valid maximum if `is_disjoint` is `true`. Otherwise it is
a union of the intervals that overlapped.

The two-interval method wraps `lhs` into an accumulator flagged as disjoint and
forwards to the accumulator method.
"""
function f_interval_max(lhs::Interval, rhs::Interval)::Tuple{Interval, Bool}
    # lift the bare interval into the (interval, is_disjoint) accumulator form
    seed = (lhs, true)
    return f_interval_max(seed, rhs)
end

# Accumulator step: combine the current `(interval, is_disjoint)` pair with `iv`.
function f_interval_max(monad::Tuple{Interval, Bool}, iv::Interval)::Tuple{Interval, Bool}
    current = monad[1]

    # Overlapping intervals cannot be ordered: merge them and flag the result.
    overlaps = !(intersect(current, iv) == ∅)
    if overlaps
        return (union(current, iv), false)
    end

    # Disjoint: keep the accumulator unchanged if it is the larger one,
    # otherwise `iv` becomes the new (disjoint) maximum.
    return current > iv ? monad : (iv, true)
end

"""
    flatten_network(netlayer::Flux.Chain)

Recursively flatten a (possibly nested) `Flux.Chain` into a single `Vector{Any}`
of layers, preserving the order in which the layers appear.

Each sublayer is flattened via `flatten_network`; vector results (flattened
sub-chains) are spliced into the output, anything else is appended as one element.
"""
function flatten_network(netlayer::Flux.Chain)
    flat_sublayers = Any[]
    # An explicit left-to-right loop is used because the accumulation is not an
    # associative operation, and `reduce` does not guarantee the fold order.
    for sublayer in netlayer
        flattened = flatten_network(sublayer)
        if flattened isa Vector
            # `append!` avoids splatting a collection of unknown length into `push!`
            # and is well-defined for an empty sub-chain.
            append!(flat_sublayers, flattened)
        else
            push!(flat_sublayers, flattened)
        end
    end
    return flat_sublayers
end

# Fallback: anything that is not a chain is already "flat" and is returned unchanged.
flatten_network(netlayer) = netlayer


# Callable wrapper that applies the wrapped function `f` element-wise
# (via broadcasting) to whatever argument it is called with.
struct DotFunctor{F <: Function}
    f::F
end

function (ftor::DotFunctor)(arg)
    return broadcast(ftor.f, arg)
end


# Wrap plain scalar functions (e.g. activations given as `x -> f(x)`) in a
# `DotFunctor` so that they are applied element-wise. Everything else, including
# `Flux.flatten` (currently the only function that does not need vectorization),
# is returned unchanged.
function functorize_layer(layer)
    already_array_level = (typeof(Flux.flatten),)
    is_function = layer isa Function
    is_exempt = typeof(layer) in already_array_level
    return (is_function && !is_exempt) ? DotFunctor(layer) : layer
end


# Append `layer` to `unwrapped_layers`. If the layer carries a non-identity
# activation `σ`, the layer is appended with `σ` replaced by the identity,
# followed by the activation itself as a separate entry.
function unwrap_embedded_activation!(unwrapped_layers::Vector{Any}, layer::Union{Flux.Dense, Flux.Conv})
    activation = layer.σ

    # nothing to split off when the layer is already purely linear
    if typeof(activation) == typeof(Flux.identity)
        return push!(unwrapped_layers, layer)
    end

    # `@set` creates a copy of the immutable layer with only `σ` changed
    linear_part = @set layer.σ = Flux.identity
    push!(unwrapped_layers, linear_part)
    return push!(unwrapped_layers, activation)
end

# Fallback: layers without an embedded activation are appended as they are.
unwrap_embedded_activation!(unwrapped_layers::Vector{Any}, layer) = push!(unwrapped_layers, layer)


"""
    normalize_network_layers(network::Flux.Chain)

Bring `network` into a normalized list of layers:

1. nested chains are flattened,
2. activations embedded in `Dense`/`Conv` layers are split into separate entries,
3. plain functions are wrapped for element-wise application where needed.

Returns the resulting collection of layers.
"""
function normalize_network_layers(network::Flux.Chain)
    unwrapped_layers = Any[]
    for layer in flatten_network(network)
        unwrap_embedded_activation!(unwrapped_layers, layer)
    end
    return map(functorize_layer, unwrapped_layers)
end

# Type aliases used to recognise layer kinds.

# Layers treated as linear: dense and convolutional Flux layers.
const LinearLayer = Union{Flux.Dense, Flux.Conv} # others can be added, but algorithms might not support them
# Type of the `Flux.relu` activation function.
const ReLUType = typeof(Flux.relu)
# Type of the `Flux.flatten` function.
const FlattenType = typeof(Flux.flatten)


"""
    get_input_domain(input_point, perturbation_radius, input_range)

Build the element-wise interval domain around `input_point`.

Each element `p` is mapped to the interval from `p - perturbation_radius` to
`p + perturbation_radius`, with the lower bound clamped from below by
`input_range[1]` and the upper bound clamped from above by `input_range[2]`.
"""
function get_input_domain(input_point::AbstractArray{T}, perturbation_radius::T, input_range::Tuple{T,T}) where T<:Real
    range_lo, range_hi = input_range

    lower_bounds = max.(input_point .- perturbation_radius, range_lo)
    upper_bounds = min.(input_point .+ perturbation_radius, range_hi)

    return interval.(lower_bounds, upper_bounds)
end


"""
    assert_type(x::Flux.Chain, ::Type{T}) where T <: Union{Float32, Float64}

Ensure that the parameters of the network `x` have element type `T`.

The top-level layers of the chain are inspected; as soon as a `Dense` or `Conv`
layer is found whose weight (or array-valued bias) does not have element type `T`,
a warning is emitted and a converted copy of the whole network (via `f32`/`f64`)
is returned. Otherwise `x` itself is returned unchanged.
"""
function assert_type(x::Flux.Chain, ::Type{T}) where T <: Union{Float32, Float64}
    conversion = T == Float32 ? f32 : f64

    for l in x
        # The layer `l` has to be inspected here, not the chain `x` itself:
        # a chain is never a Dense/Conv layer, so checking `x` would skip every layer.
        (isa(l, Flux.Dense) || isa(l, Flux.Conv)) || continue

        weight_mismatch = eltype(l.weight) != T
        # Only array-valued biases carry a meaningful element type.
        # NOTE(review): a disabled bias appears to be stored as a non-array
        # placeholder (e.g. `false`) in current Flux versions; confirm against the Flux version in use.
        bias_mismatch = l.bias isa AbstractArray && eltype(l.bias) != T

        if weight_mismatch || bias_mismatch
            @warn "The network is not of type $T. Converting layer weights and biases to $T. You should consider defining the network with $T from the start!"
            return conversion(x)
        end
    end

    return x
end


"""
    assert_type(x::AbstractArray{<:Real}, ::Type{T}) where T <: Union{Float32, Float64}

Return `x` unchanged if its element type already is `T`; otherwise emit a warning
and return a copy with every element converted to `T`.
"""
function assert_type(x::AbstractArray{<:Real}, ::Type{T}) where T <: Union{Float32, Float64}
    # fast path: nothing to do when the element type already matches
    eltype(x) == T && return x

    @warn "The input is not of type $T. Converting input to $T. You should consider defining the input with $T from the start!"
    return convert.(T, x)
end


"""
    assert_type(x::Real, ::Type{T}) where T <: Union{Float32, Float64}

Return the scalar `x` unchanged if it already is a `T`; otherwise emit a warning
and return `x` converted to `T`.
"""
function assert_type(x::K, ::Type{T}) where {T <: Union{Float32, Float64}, K <: Real}
    # fast path: the value already has the requested type
    x isa T && return x

    @warn "The perturbation radius is not of type $T. Converting perturbation radius to $T. You should consider defining the perturbation radius with $T from the start!"
    return convert(T, x)
end


"""
    assert_type!(alg)

Bring the network, the input point and the perturbation radius stored in `alg` to
the floating-point type given by `alg.solve_options.u_type`, writing the
(possibly converted) values back into `alg`. Returns `nothing`.
"""
function assert_type!(alg)
    target_type = alg.solve_options.u_type

    alg.network = assert_type(alg.network, target_type)
    alg.input_point = assert_type(alg.input_point, target_type)
    alg.perturbation_radius = assert_type(alg.perturbation_radius, target_type)

    return nothing
end