export IntervalExpr, IntervalAffExpr
export extend_exp_with_new_var!

# Interval extension of ReLU. Since relu(x) = max(0, x) is monotone non-decreasing,
# applying it to both endpoints gives the image of the whole interval.
# NOTE(review): this extends a Flux function on an IntervalArithmetic type (type piracy).
function Flux.relu(iv::Interval{T}) where T
    lo = Flux.relu(inf(iv))
    hi = Flux.relu(sup(iv))
    return interval(lo, hi)
end

"""
    IntervalExpr{S, T<:Real}

Interval linear expression `constant + Σ coeffs[k] * x[k]`, where each variable is
identified by a key of type `S` and the constant and every coefficient are `Interval{T}`.

# Fields
- `constant::Interval{T}`: the constant term.
- `coeffs::OrderedDict{S, Interval{T}}`: per-variable coefficients, kept in insertion order.

The struct itself is immutable, but `coeffs` is a mutable dictionary that can be updated in
place (e.g. by `extend_exp_with_new_var!`); expressions sharing one dictionary see each
other's updates.
"""
struct IntervalExpr{S, T<:Real}
    constant::Interval{T}
    coeffs::OrderedDict{S, Interval{T}}
end

# The main type for interval linear expressions: variables are keyed by `Symbol`.
# `IntervalAffExpr{T}` is exactly `IntervalExpr{Symbol, T}`; the bound on `T` is stated on
# the alias parameter itself, so there is a single type variable.
const IntervalAffExpr{T<:Real} = IntervalExpr{Symbol, T}

# Convenience constructors for a `Symbol`-keyed, constant-only expression, built from
# either an interval or a plain number (the number is wrapped with `interval(T, c)`).
function IntervalExpr(constant::Interval{T}) where {T<:Real}
    no_terms = OrderedDict{Symbol, Interval{T}}()
    return IntervalExpr{Symbol, T}(convert(Interval{T}, constant), no_terms)
end
IntervalExpr(constant::T) where {T<:Real} = IntervalExpr(interval(T, constant))

# One-argument constructors: a constant expression with no variable terms.
# Defined for every key type `S` (not only `Symbol`), so generic callers such as
# `zero(IntervalExpr{S, T})` or `convert(IntervalExpr{S, T}, x)` work for any `S`.
# A plain number is wrapped with `interval(T, ...)`, matching `convert` for scalars.
IntervalExpr{S, T}(constant::Interval{T}) where {S, T<:Real} =
    IntervalExpr{S, T}(constant, OrderedDict{S, Interval{T}}())
IntervalExpr{S, T}(constant::T) where {S, T<:Real} =
    IntervalExpr{S, T}(interval(T, constant), OrderedDict{S, Interval{T}}())


# Evaluate the expression on interval values: constant + Σ coeff * x[var].
# Every variable appearing in the expression must have an entry in `x`;
# otherwise an `ArgumentError` is thrown.
function (expr::IntervalExpr{S, T})(x::OrderedDict{S, Interval{T}}) where {S, T<:Real}
    acc = expr.constant
    for (var, c) in expr.coeffs
        haskey(x, var) ||
            throw(ArgumentError("Variable $var not found in the provided dictionary."))
        acc += c * x[var]
    end
    return acc
end

# Evaluate against a `VariableManager` by delegating to its `vars` dictionary.
function (expr::IntervalExpr{S, T})(vm::VariableManager{S, T}) where {S, T<:Real}
    return expr(vm.vars)
end

# Operators for interval linear expressions
# Addition
# Sum of two expressions: constants add, and coefficients of shared variables add.
# Works on a copy of `a`'s terms so neither operand is mutated; `a`'s variables keep their
# order, followed by variables only present in `b`.
function Base.:+(a::IntervalExpr{S, T}, b::IntervalExpr{S, T}) where {S, T<:Real}
    summed_constant = a.constant + b.constant
    summed_terms = mergewith!(+, copy(a.coeffs), b.coeffs)
    return IntervalExpr{S, T}(summed_constant, summed_terms)
end

# Add an interval or scalar to an expression: only the constant term changes.
# The coefficient dictionary is copied so the result never aliases `a.coeffs`; otherwise
# an in-place update (e.g. `extend_exp_with_new_var!`) on one expression would silently
# modify the other.
function Base.:+(a::IntervalExpr{S, T}, b::Union{Interval{T}, T}) where {S, T<:Real}
    return IntervalExpr{S, T}(a.constant + interval(T, b), copy(a.coeffs))
end
Base.:+(a::Union{Interval{T}, T}, b::IntervalExpr{S, T}) where {S, T<:Real} = b + a

# Multiplication
# Scale an expression by an interval or scalar. Every term is multiplied by the factor;
# when the factor is zero, all variable terms are dropped and only the scaled constant
# remains.
function Base.:*(a::IntervalExpr{S, T}, b::Union{Interval{T}, T}) where {S, T<:Real}
    factor = interval(T, b)
    scaled_constant = a.constant * factor
    scaled_terms = if iszero(factor)
        OrderedDict{S, Interval{T}}()
    else
        OrderedDict{S, Interval{T}}(var => c * factor for (var, c) in a.coeffs)
    end
    return IntervalExpr{S, T}(scaled_constant, scaled_terms)
end

# Scalar-times-expression: reuse the expression-times-scalar method.
function Base.:*(a::Union{Interval{T}, T}, b::IntervalExpr{S, T}) where {S, T<:Real}
    return b * a
end

# Product of two expressions. The result is only linear when at least one factor has no
# variable terms; that factor's constant then scales the other expression.
function Base.:*(a::IntervalExpr{S, T}, b::IntervalExpr{S, T}) where {S, T<:Real}
    is_constant(a) && return b * a.constant
    is_constant(b) && return a * b.constant
    msg = "Multiplication of two IntervalExprs is not supported unless one is constant."
    throw(ArgumentError(msg))
end

# `true * e` scales `e` by one and `false * e` scales it by zero (the scalar-multiplication
# method then drops all variable terms). `one(T)`/`zero(T)` keep the factor in the
# expression's own number type, so this works for any `T`, not only `Float64`.
Base.:*(a::Bool, b::IntervalExpr{S, T}) where {S, T<:Real} = b * (a ? one(T) : zero(T))

"""
    is_constant(expr::IntervalExpr) -> Bool

Return `true` when the expression's coefficient dictionary has no entries, i.e. the
expression consists of its constant term only.
"""
function is_constant(expr::IntervalExpr{S, T}) where {S, T<:Real}
    return length(expr.coeffs) == 0
end

# Additive and multiplicative identities: constant expressions with no variable terms.
# Built with the two-argument constructor so they do not depend on a one-argument
# `IntervalExpr{S, T}` constructor and therefore work for any key type `S`.
Base.zero(::Type{IntervalExpr{S, T}}) where {S, T<:Real} =
    IntervalExpr{S, T}(interval(T, 0.0), OrderedDict{S, Interval{T}}())
Base.zero(x::IntervalExpr{S, T}) where {S, T<:Real} = zero(typeof(x))

Base.one(::Type{IntervalExpr{S, T}}) where {S, T<:Real} =
    IntervalExpr{S, T}(interval(T, 1.0), OrderedDict{S, Interval{T}}())
Base.one(x::IntervalExpr{S, T}) where {S, T<:Real} = one(typeof(x))

# Conversion rules: an interval or scalar of the same number type becomes a constant
# expression with no variable terms. Uses the two-argument constructor so conversion
# works for any key type `S`.
Base.convert(::Type{IntervalExpr{S, T}}, x::Interval{T}) where {S, T<:Real} =
    IntervalExpr{S, T}(x, OrderedDict{S, Interval{T}}())
Base.convert(::Type{IntervalExpr{S, T}}, x::T) where {S, T<:Real} =
    IntervalExpr{S, T}(interval(T, x), OrderedDict{S, Interval{T}}())

# Promotion: pairing an expression type with `Interval{T}` or with the scalar type `T`
# promotes to the expression type.
function Base.promote_rule(
    ::Type{IntervalExpr{S, T}}, ::Type{Interval{T}}
) where {S, T<:Real}
    return IntervalExpr{S, T}
end

function Base.promote_rule(::Type{IntervalExpr{S, T}}, ::Type{T}) where {S, T<:Real}
    return IntervalExpr{S, T}
end

# An expression is zero when its constant is zero and it has no variable terms.
function Base.iszero(x::IntervalExpr{S, T}) where {S, T<:Real}
    iszero(x.constant) || return false
    return isempty(x.coeffs)
end

# Print as `IntervalExpr(constant: <c>, coeffs: {k1 => v1, k2 => v2})`, listing terms in
# dictionary order.
function Base.show(io::IO, expr::IntervalExpr{S, T}) where {S, T<:Real}
    print(io, "IntervalExpr(constant: ", expr.constant, ", coeffs: {")
    join(io, ("$key => $coeff" for (key, coeff) in expr.coeffs), ", ")
    print(io, "})")
end


"""
    extend_exp_with_new_var!(exp, new_var, coeff)

Add the term `coeff * new_var` to `exp` in place and return `exp`.

The coefficient dictionary is mutated, so any other expression sharing that dictionary
observes the new term. Throws `ArgumentError` if `new_var` already has a coefficient.
"""
function extend_exp_with_new_var!(
    exp::IntervalExpr{S, T}, new_var::S, coeff::Interval{T}
) where {S, T<:Real}
    terms = exp.coeffs
    haskey(terms, new_var) &&
        throw(ArgumentError("Variable $new_var already exists in the expression."))
    terms[new_var] = coeff
    return exp
end

# function bound_forward_error(input::Array{Interval{Float64}}, layer::Union{Flux.Dense, Flux.Conv}, u_type::Type)
#     K = ndims(layer.weight)

#     # Here we use an upper bound on the exact n, since the zero input elements are not considered.
#     if K == 4
#         # We can compute theta with 'n' (instead of n-1), since the bias addition is not considered here.
#         n = reduce(+, layer.weight .!= 0, dims=[1,2,3]) |> vec
#         t_s = (1, 1, size(layer.weight, 4), 1)
#     elseif K == 2
#         n = sum(layer.weight .!= 0, dims=2) |> vec
#         t_s = (size(layer.weight, 1),)
#     else
#         throw(ArgumentError("Unsupported array dimension: $K"))
#     end

#     max_inp = map(x -> max(abs(inf(x)), abs(sup(x))), input) .|> interval
#     abs_weight = abs.(layer.weight)
#     abs_bias = abs.(layer.bias)

#     if layer isa Flux.Conv
#         reshaped_bias = reshape(abs_bias, (1, 1, length(abs_bias)))
#         out = conv(max_inp, abs_weight; stride=layer.stride, pad=layer.pad, dilation=layer.dilation, groups=layer.groups)
#         out = out .+ reshaped_bias
#     else
#         out = abs_weight * max_inp .+ abs_bias
#     end

#     # TODO: This can be skipped ... 
#     out_upper = out .|> sup .|> interval

#     # TODO: Write a separate function for this
#     n = interval.(n)
#     u = interval(eps(u_type))
#     theta = (n*u) ./ (interval(1) .- n*u)
#     theta = reshape(theta, t_s)

#     error_upper_bound = sup.(out_upper .* theta)

#     return error_upper_bound
# end