export apply_dense, apply_conv

"""
    extend_weight_with_theta(param, u_type) -> (widened_param, theta)

Widen every weight interval in `param` by a per-row / per-channel relative bound
derived from the number of nonzero weights in that row / channel.

For a 2-D `param` the count `n` is taken per row (sum over dimension 2). For a
4-D `param` it is taken per index of dimension 4 (sum over dimensions 1-3). With
`u = eps(u_type)`, the code computes `θ = n*u / (1 - n*u)`, which has the form of
Higham's `γ_n` rounding-error bound. Each weight `w` then becomes
`w + [-θ, θ] * w`.

# Arguments
- `param`: 2-D or 4-D array of intervals (e.g. a Flux `Dense` weight or `Conv` kernel).
- `u_type`: floating-point type whose `eps` is used as the rounding unit `u`.

# Returns
- `param` widened by `theta .* param` (same shape as `param`).
- `theta`: vector of symmetric intervals `[-θ, θ]`, one per row (2-D) or per
  index of dimension 4 (4-D).

# Throws
- `ArgumentError` if `ndims(param)` is neither 2 nor 4.
- `ArgumentError` if `n*u >= 1` for some row / channel, where the bound is invalid.
"""
function extend_weight_with_theta(param::AbstractArray{Interval{T}, K}, u_type::Type) where {T, K}
    if K == 4
        # One count per index of dim 4; theta is shaped to broadcast along dim 4.
        counts = reduce(+, param .!= 0, dims=[1,2,3]) |> vec
        t_s = (1, 1, 1, size(param, 4))
    elseif K == 2
        # One count per row; theta is shaped to broadcast along dim 1.
        counts = sum(param .!= 0, dims=2) |> vec
        t_s = (size(param, 1), 1)
    else
        throw(ArgumentError("Unsupported array dimension: $K"))
    end

    # Here we use n, instead of n-1, since the bias addition is not considered here.
    n = interval.(T, counts)
    u = interval(eps(u_type))
    nu = n .* u

    # The bound n*u / (1 - n*u) only holds while n*u < 1. Past that point the
    # denominator contains zero or is negative, and theta would stop being a
    # finite, nonnegative widening factor.
    if !all(x -> sup(x) < 1, nu)
        throw(ArgumentError("Rounding bound requires n*u < 1, but a row/channel has $(maximum(counts)) nonzero weights with u = eps($u_type)"))
    end

    _theta = nu ./ (interval(T, 1) .- nu)

    # Keep only the upper end of each bound and make it symmetric: [-θ, θ].
    # NOTE(review): if u_type differs from T, _theta (and hence theta) may carry a
    # promoted bound type — confirm callers always pass matching types.
    theta = interval(T, -1, 1) * interval.(sup.(_theta))

    theta = reshape(theta, t_s)

    widening_theta = theta .* param

    param = param + widening_theta

    return param, vec(theta)
end


"""
    extend_bias_with_theta(param, theta)

Return the bias vector `param` widened element-wise as `param[i] + theta[i] * param[i]`.

`theta` is the per-output vector produced alongside the widened weights (see
`bound_parameters`), so its length must equal that of `param`.
"""
function extend_bias_with_theta(param::AbstractArray{Interval{T}, 1}, theta::AbstractArray{Interval{T}, 1}) where T
    # Internal invariant: bound_parameters pairs each bias entry with its own theta.
    @assert length(theta) == length(param) "Bias and theta must have the same length."
    return param .+ theta .* param
end


"""
    fp_sound_dense_flux(x, layer, u_type)

Evaluate the affine part of a `Dense` layer soundly.

The layer's weights and bias are first widened by `bound_parameters`. The result
is computed with the generic `*` product and a broadcast bias add. No
activation is applied here.
"""
function fp_sound_dense_flux(x::AbstractArray{T}, layer::Flux.Dense, u_type::Type)  where T <: Union{IntervalAffExpr, Interval}
    weights, bias = bound_parameters(layer, u_type)
    affine = weights * x
    return affine .+ bias
end


"""
    fp_sound_conv_flux(x, layer, u_type)

Evaluate a 4-D convolution soundly using the `conv` kernel (presumably NNlib's).

The layer's kernel and bias are first widened by `bound_parameters`. The kernel
is convolved with `x` using the layer's `stride`, `pad`, `dilation` and
`groups`, and the bias is then added along dimension 3 of the result. No
activation is applied here.
"""
function fp_sound_conv_flux(x::AbstractArray{T, 4}, layer::Flux.Conv, u_type::Type) where T <: Union{IntervalAffExpr, Interval}
    weights, bias = bound_parameters(layer, u_type)

    conv_out = conv(x, weights; stride=layer.stride, pad=layer.pad,
                    dilation=layer.dilation, groups=layer.groups)

    # A (1, 1, C) bias broadcasts over dimensions 1, 2 and any trailing dimensions
    # of the result (assumed to be (W, H, C, N) — TODO confirm).
    return conv_out .+ reshape(bias, 1, 1, :)
end


"""
    textbook_conv(x, layer)

Evaluate `layer` on the 4-D input `x` with the file's `nnlib_conv` helper, without
rounding-error widening. No `s_kernel`/`s_bias` override is passed, so the helper
presumably falls back to the layer's own parameters (compare
`textbook_fp_sound_conv`).
"""
textbook_conv(x::AbstractArray{T, 4}, layer::Flux.Conv) where T <: Union{IntervalAffExpr, Interval} =
    nnlib_conv(layer, x)


"""
    textbook_fp_sound_conv(x, layer, u_type)

Evaluate `layer` on the 4-D input `x` with the file's `nnlib_conv` helper. The
kernel and bias are replaced by their widened versions from `bound_parameters`.
"""
function textbook_fp_sound_conv(x::AbstractArray{T, 4}, layer::Flux.Conv, u_type::Type) where T <: Union{IntervalAffExpr, Interval}
    kernel, bias = bound_parameters(layer, u_type)
    result = nnlib_conv(layer, x; s_kernel=kernel, s_bias=bias)
    return result
end


"""
    textbook_dense(x, layer)

Compute `layer.weight * x .+ layer.bias` using the file's `naive_matmul` helper,
without rounding-error widening. No activation is applied here.
"""
function textbook_dense(x::AbstractArray{T}, layer::Flux.Dense) where T <: Union{IntervalAffExpr, Interval}
    product = naive_matmul(layer.weight, x)
    return product .+ layer.bias
end


"""
    textbook_fp_sound_dense(x, layer, u_type)

Compute `W * x .+ b` with the file's `naive_matmul` helper. `W` and `b` are the
layer's weights and bias after widening by `bound_parameters`. No activation is
applied here.
"""
function textbook_fp_sound_dense(x::AbstractArray{T}, layer::Flux.Dense, u_type::Type) where T <: Union{IntervalAffExpr, Interval}
    weights, bias = bound_parameters(layer, u_type)
    product = naive_matmul(weights, x)
    return product .+ bias
end


"""
    bound_parameters(layer, u_type) -> (interval_weights, interval_bias)

Convert `layer`'s weights and bias to intervals and widen them for rounding error.

The weights are widened by `extend_weight_with_theta`. The bias is widened with
the same per-output `theta` by `extend_bias_with_theta`.

Flux layers constructed with `bias=false` store the Bool `false` instead of a
vector. In that case a zero bias vector (one entry per output) is used. Widening
zero leaves it zero, so the affine result is unchanged. Previously the scalar
`false` reached `extend_bias_with_theta` and failed with a `MethodError`.
"""
function bound_parameters(layer::Union{Flux.Dense, Flux.Conv}, u_type::Type)
    interval_weights, theta = extend_weight_with_theta(interval.(layer.weight), u_type)

    # `theta` has exactly one entry per output row / channel.
    raw_bias = layer.bias isa AbstractArray ? layer.bias : zeros(eltype(layer.weight), length(theta))
    interval_bias = extend_bias_with_theta(interval.(raw_bias), theta)

    return interval_weights, interval_bias
end


"""
    apply_dense(x, layer, u_type, fp_sound; implementation="textbook")

Evaluate the `Dense` layer `layer` on `x`.

# Arguments
- `u_type`: floating-point type whose `eps` drives the widening (used only when `fp_sound`).
- `fp_sound`: if `true`, use weights/bias widened by `bound_parameters`.

# Keywords
- `implementation`: `"textbook"` (via `naive_matmul`) or `"flux"` (generic `*` / `layer(x)`).

# Throws
- `ArgumentError` for any other `implementation`.

NOTE(review): the `"flux"` + non-sound path calls `layer(x)`, which in Flux
presumably also applies `layer.σ`. The other three paths build only `W*x .+ b`.
Confirm this asymmetry is intended.
"""
function apply_dense(x::AbstractArray{T}, layer::Flux.Dense, u_type::Type, fp_sound::Bool; implementation::String="textbook")  where T <: Union{IntervalAffExpr, Interval}
    if implementation == "textbook"
        return fp_sound ? textbook_fp_sound_dense(x, layer, u_type) : textbook_dense(x, layer)
    elseif implementation == "flux"
        return fp_sound ? fp_sound_dense_flux(x, layer, u_type) : layer(x)
    end

    throw(ArgumentError("Unsupported backend implementation: $implementation"))
end


"""
    apply_conv(x, layer, u_type, fp_sound; implementation="textbook")

Evaluate the `Conv` layer `layer` on `x`.

All four backends below are only defined for 4-D `x`. Any other rank raises a
`MethodError` from the callee.

# Arguments
- `u_type`: floating-point type whose `eps` drives the widening (used only when `fp_sound`).
- `fp_sound`: if `true`, use kernel/bias widened by `bound_parameters`.

# Keywords
- `implementation`: `"textbook"` (via `nnlib_conv`) or `"flux"` (via `conv` / `layer(x)`).

# Throws
- `ArgumentError` for any other `implementation`.

NOTE(review): the `"flux"` + non-sound path calls `layer(x)`, which in Flux
presumably also applies `layer.σ`. The fp-sound `"flux"` path adds only the
bias. Confirm this asymmetry is intended.
"""
function apply_conv(x::AbstractArray{T}, layer::Flux.Conv, u_type::Type, fp_sound::Bool; implementation::String="textbook")  where T <: Union{IntervalAffExpr, Interval}
    if implementation == "textbook"
        return fp_sound ? textbook_fp_sound_conv(x, layer, u_type) : textbook_conv(x, layer)
    elseif implementation == "flux"
        return fp_sound ? fp_sound_conv_flux(x, layer, u_type) : layer(x)
    end

    throw(ArgumentError("Unsupported backend implementation: $implementation"))
end