# using Revise
# using ModelVerification
# using LazySets
using Flux
using LinearAlgebra
using Zygote
using ReverseDiff

# using Test
# using MLDatasets: CIFAR10, MNIST
# using MLUtils: splitobs, DataLoader
# import Random
# using JLD2
# using Plots
# using Profile

"""
    f_batch(A, x)

Apply the drift term `A` to a batch of states `x` (one state per column).

If `A` is a matrix it is shared by every sample and applied with `batched_mul`.
Otherwise `A` must be a 3-D array with one matrix per sample (`A[:, :, i]` acts on
`x[:, i]`); the per-sample products are returned as the columns of a matrix.
"""
function f_batch(A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray)
    if A isa AbstractMatrix
        return batched_mul(A, x)
    end
    @assert size(x, 2) == size(A, 3)
    @assert length(size(A)) == 3
    n_rows, n_batch = size(x, 1), size(x, 2)
    # View each sample as an (n_rows × 1) column so the batched product is per-sample.
    x_columns = reshape(x, (n_rows, 1, n_batch))
    products = A ⊠ x_columns            # (size(A, 1), 1, n_batch)
    return dropdims(products; dims=2)
end

"""
    g_batch(B, u)

Apply the input matrix `B` to a batch of controls `u` (one control per column).

A 2-D `B` is shared across samples and applied with `batched_mul`. A 3-D `B` holds one
matrix per sample (`B[:, :, i]` acts on `u[:, i]`), and the result has one column per
sample.
"""
function g_batch(B::Union{AbstractMatrix,AbstractArray}, u::AbstractArray)
    if B isa AbstractMatrix
        return batched_mul(B, u)
    end
    @assert size(u, 2) == size(B, 3)
    @assert length(size(B)) == 3
    n_rows, n_batch = size(u, 1), size(u, 2)
    # Lift every control to an (n_rows × 1) column for the batched product.
    u_columns = reshape(u, (n_rows, 1, n_batch))
    products = B ⊠ u_columns            # (size(B, 1), 1, n_batch)
    return dropdims(products; dims=2)
end

"""
    affine_dyn_batch(A, x, B, u; Δ=nothing)

Batched affine dynamics `ẋ = f_batch(A, x) + g_batch(B, u) (+ Δ)`.

`A`/`B` may be shared matrices or 3-D per-sample stacks (see `f_batch`/`g_batch`).
`Δ`, when given, is an additive disturbance with the same size as `ẋ`.
"""
function affine_dyn_batch(A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u::AbstractArray;Δ=nothing)
    ẋ = f_batch(A, x) + g_batch(B, u)
    # Only add a disturbance when one is supplied. Adding an explicit zeros(size(ẋ))
    # would allocate and silently promote Float32 dynamics to Float64.
    return isnothing(Δ) ? ẋ : ẋ + Δ
end

"""
    forward_invariance_func(ϕ, A, x, B, u; α=0, Δ=nothing)

Evaluate `ϕ̇ + α ϕ(x)` for a batch of states, where `ϕ̇ = ∇ϕ(x) ⋅ ẋ` and
`ẋ = affine_dyn_batch(A, x, B, u; Δ)`.

`x` is `(state_dim, batchsize)`. The per-sample gradient is reshaped to a row vector,
so `ϕ(x)` must hold exactly one value per sample (e.g. a `(1, batchsize)` output).
Returns an array of the same size as `ϕ(x)`.
"""
function forward_invariance_func(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u::AbstractArray; α=0,Δ=nothing)
    state_dim, batchsize = size(x)
    ẋ = reshape(affine_dyn_batch(A, x, B, u; Δ=Δ), (state_dim, 1, batchsize))
    y = ϕ(x)
    _, back = Zygote.pullback(ϕ, x)
    # Seed the pullback with ones shaped like ϕ(x) (the output), not like x. Assuming ϕ
    # processes columns independently, column i of the result is ∇ϕ(x[:, i]).
    ∇ϕ_x = reshape(back(ones(size(y)))[1], (1, state_dim, batchsize))
    ϕ̇ = reshape(batched_mul(∇ϕ_x, ẋ), size(y))
    return ϕ̇ .+ α .* y
end

"""
    forward_invariance_func_noAB(ϕ, x, ẋ; α=0)

Evaluate `ϕ̇ + α ϕ(x)` for a batch of states `x` (`(state_dim, batchsize)`) when the
time derivatives `ẋ` are supplied directly rather than built from (A, B, u).

`ẋ` must contain `state_dim * batchsize` entries (it is reshaped to one column per
sample), and `ϕ(x)` must hold one value per sample. Returns an array of the same size
as `ϕ(x)`.
"""
function forward_invariance_func_noAB(ϕ::Chain,  x::AbstractArray, ẋ::AbstractArray; α=0)
    state_dim, batchsize = size(x)
    ẋ = reshape(ẋ, (state_dim, 1, batchsize))
    y = ϕ(x)
    _, back = Zygote.pullback(ϕ, x)
    # Seed with ones shaped like the output ϕ(x); assuming columns are processed
    # independently, column i of the gradient is ∇ϕ(x[:, i]).
    ∇ϕ_x = reshape(back(ones(size(y)))[1], (1, state_dim, batchsize))
    ϕ̇ = reshape(batched_mul(∇ϕ_x, ẋ), size(y))
    return ϕ̇ .+ α .* y
end

"""
    forward_invariance_func_ce(ϕ, A, x, B, u; α=0, Δ=nothing)

Two-output variant of `forward_invariance_func`: for each of the two rows of `ϕ(x)`,
compute `ϕ̇_k = ∇ϕ_k(x) ⋅ ẋ` with `ẋ = affine_dyn_batch(A, x, B, u; Δ)`, and return
`ϕ̇ .+ α .* ϕ(x)` with the same size as `ϕ(x)`.

The code assumes `ϕ(x)` has exactly two rows (one column per sample of `x`).
"""
function forward_invariance_func_ce(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u::AbstractArray; α=0,Δ=nothing)
    state_dim, batchsize = size(x)
    ẋ = reshape(affine_dyn_batch(A, x, B, u; Δ=Δ), (state_dim, 1, batchsize))
    y = ϕ(x)
    _, back = Zygote.pullback(ϕ, x)
    # One-hot cotangents shaped like ϕ(x) select output row 1 and row 2 respectively;
    # each pullback returns a (state_dim, batchsize) gradient.
    seed_1 = vcat(ones(1, batchsize), zeros(1, batchsize))
    seed_2 = vcat(zeros(1, batchsize), ones(1, batchsize))
    ∇ϕ_x_1 = back(seed_1)[1]
    ∇ϕ_x_2 = back(seed_2)[1]
    # Stack to (2, state_dim, batchsize): slice [k, :, i] is ∂ϕ_k/∂x for sample i.
    ∇ϕ_x = vcat(reshape(∇ϕ_x_1, (1, size(∇ϕ_x_1)...)), reshape(∇ϕ_x_2, (1, size(∇ϕ_x_2)...)))
    ϕ̇ = reshape(batched_mul(∇ϕ_x, ẋ), size(y))
    return ϕ̇ .+ α .* y
end

"""
    loss_safe_set(ϕ, x, y_init)

Map labels to signs `s = 2 y_init - 1` and penalise positive `s .* ϕ(x)` with
`mse(max.(0, s .* ϕ(x)), 0)`. Presumably `y_init` holds 0/1 labels broadcastable
against `ϕ(x)` — confirm against callers.
"""
function loss_safe_set(ϕ::Chain, x::AbstractArray,y_init::AbstractArray)
    signs = 2 .* y_init .- 1
    violation = max.(0, signs .* ϕ(x))
    return Flux.Losses.mse(violation, 0)
end

"""
    loss_safe_set_ce(ϕ, x, y_init)

Cross-entropy between `softmax(ϕ(x))` and one-hot targets built from the first row
of `y_init` over the label set `0:1`.
"""
function loss_safe_set_ce(ϕ::Chain, x::AbstractArray,y_init::AbstractArray)
    probabilities = softmax(ϕ(x))
    targets = Flux.onehotbatch(y_init[1, :], 0:1)
    return Flux.crossentropy(probabilities, targets)
end

"""
    loss_naive_safeset(ϕ, x, y_init)

Mean hinge loss `relu(s_i * ϕ(x)[1, i] + 1e-6)` with `s_i = 2 y_init[1, i] - 1`
(label 1 = safe, 0 = unsafe).
"""
function loss_naive_safeset(ϕ::Chain, x::AbstractArray,y_init::AbstractArray)
    signs = 2 .* y_init[1, :] .- 1      # safe → +1, unsafe → -1
    hinge = relu.(signs .* ϕ(x)[1, :] .+ 1e-6)
    return sum(hinge) / length(hinge)
end

"""
    loss_regularization(ϕ, x, y_init)

Mean of `sigmoid_fast(s_i * ϕ(x)[1, i])` with `s_i = 2 y_init[1, i] - 1`
(label 1 = safe, 0 = unsafe).
"""
function loss_regularization(ϕ::Chain, x::AbstractArray,y_init::AbstractArray)
    signs = 2 .* y_init[1, :] .- 1      # safe → +1, unsafe → -1
    penalty = sigmoid_fast.(signs .* ϕ(x)[1, :])
    return sum(penalty) / length(penalty)
end

"""
    loss_forward_invariance(ϕ, A, x, B, u, y_cbf; α=0, Δ=nothing)

Squared hinge on the forward-invariance condition:
`mse(max.(0, forward_invariance_func(ϕ, A, x, B, u; α, Δ)), 0)`.

`y_cbf` is accepted for interface compatibility but is not used.
"""
function loss_forward_invariance(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u::AbstractArray, y_cbf::AbstractArray; α=0,Δ=nothing)
    condition = forward_invariance_func(ϕ, A, x, B, u; α = α, Δ = Δ)
    positive_part = max.(0, condition)
    return Flux.Losses.mse(positive_part, 0)
end

"""
    loss_naive_fi(ϕ, A, x, B, u, y_init; use_pgd=false, use_adv=false, α=0, lr=1,
                  num_iter=10, ϵ=0.1, Δ=nothing)

Mean hinge loss `relu(forward_invariance_func(...) + 1e-6)` restricted to samples that
are labelled safe (`y_init[1, i] == 1`) and satisfy `abs(ϕ(x)[1, i]) < ϵ`.

Returns the integer `0` when no sample survives either filter.

`A`/`B` may be shared 2-D matrices or 3-D per-sample stacks; only 3-D ones are sliced
together with `x`, and `Δ` is sliced only when provided. With `use_adv` the states are
replaced by `pgd_find_x_notce` within small boxes around them. With `use_pgd` the
controls are replaced by `pgd_find_u_notce`.
"""
function loss_naive_fi(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u::AbstractArray, y_init::AbstractArray; use_pgd=false, use_adv = false, α=0, lr =1, num_iter=10,ϵ=0.1,Δ=nothing)
    # Keep only the sample columns listed in `idx`. Whether A/B need slicing depends on
    # their own dimensionality, not on whether a disturbance Δ is present.
    take_samples(A_, x_, B_, u_, Δ_, idx) = (
        ndims(A_) == 3 ? A_[:, :, idx] : A_,
        x_[:, idx],
        ndims(B_) == 3 ? B_[:, :, idx] : B_,
        u_[:, idx],
        isnothing(Δ_) ? Δ_ : Δ_[:, idx],
    )

    safe = findall(==(1), y_init[1, :])
    isempty(safe) && return 0
    A, x, B, u, Δ = take_samples(A, x, B, u, Δ, safe)

    @assert α==0
    near_zero = findall((abs.(ϕ(x)) .< ϵ)[1, :])
    isempty(near_zero) && return 0
    A, x, B, u, Δ = take_samples(A, x, B, u, Δ, near_zero)

    if use_adv
        # NOTE(review): `X` is not an argument of this function, so it resolves to a
        # global binding; use_adv=true only works if such a global exists.
        X_local = [Hyperrectangle(x[:, i], radius_hyperrectangle(X) ./ 20) for i in 1:size(x, 2)]
        x = pgd_find_x_notce(ϕ, A, x, B, u, X_local; α = α, Δ = Δ)
    end
    # NOTE(review): likewise `U` resolves to a global binding when use_pgd=true.
    use_pgd && (u = pgd_find_u_notce(ϕ, A, x, B, u, U; α = α, lr = lr, num_iter = num_iter, Δ = Δ))
    loss = relu.(forward_invariance_func(ϕ, A, x, B, u; α, Δ = Δ) .+ 1e-6)
    return sum(loss) / size(loss)[end]
end

"""
    get_min_u_noAB_vertices(ϕ, x, u, U, y_init, dyn_model; use_pgd=false, α=0,
                            use_adv=false, ϵ=0.1, same_x=false)

For each state column `x[:, i]`, evaluate `ẋ = dyn_model(x[:, i], v)` at every vertex
`v` of the control box `U`. Keep the `ẋ` that gives the smallest
`forward_invariance_func_noAB(ϕ, x[:, i:i], ẋ; α)[1, 1]`.

Unless `same_x` is `true`, `x` is first restricted to samples with `y_init[1, i] == 1`
and `abs(ϕ(x)[1, i]) < ϵ`. `(nothing, nothing)` is returned if none remain.

Returns `(x, ẋ_batch)`, where `ẋ_batch` has the same size as the (filtered) `x`.
`u`, `use_pgd` and `use_adv` are accepted for interface compatibility but unused.
"""
function get_min_u_noAB_vertices(ϕ::Chain, x::AbstractArray, u::AbstractArray, U::Hyperrectangle,y_init::AbstractArray, dyn_model; use_pgd=false, α=0,use_adv = false, ϵ=0.1, same_x=false)
    if !same_x
        safe = findall(==(1), y_init[1, :])
        isempty(safe) && return (nothing, nothing)
        x = x[:, safe]

        @assert α==0
        near_zero = findall((abs.(ϕ(x)) .< ϵ)[1, :])
        isempty(near_zero) && return (nothing, nothing)
        x = x[:, near_zero]
    end

    u_cand = vertices_list(U)
    ẋ_batch = zeros(size(x))
    for i in axes(x, 2)
        x_i = x[:, i:i]
        best_ϕ̇ = Inf
        best_ẋ = nothing
        for vertex in u_cand
            cand_ẋ = dyn_model(x[:, i], vertex)
            # Evaluate the (Zygote-based) condition once and reuse it for compare + record.
            cand_ϕ̇ = forward_invariance_func_noAB(ϕ, x_i, cand_ẋ; α)[1, 1]
            if cand_ϕ̇ < best_ϕ̇
                best_ϕ̇ = cand_ϕ̇
                best_ẋ = cand_ẋ
            end
        end
        ẋ_batch[:, i] .= best_ẋ
    end
    return x, ẋ_batch
end

"""
    loss_naive_fi_noAB(ϕ, x, ẋ_batch; α=0)

Mean hinge loss `relu(forward_invariance_func_noAB(ϕ, x, ẋ_batch; α) + 1e-6)`,
averaged over the last dimension of the condition array.
"""
function loss_naive_fi_noAB(ϕ::Chain, x::AbstractMatrix, ẋ_batch::AbstractMatrix; α=0)
    condition = forward_invariance_func_noAB(ϕ, x, ẋ_batch; α = α)
    hinge = relu.(condition .+ 1e-6)
    return sum(hinge) / size(hinge, ndims(hinge))
end

"""
    loss_forward_invariance_ce(ϕ, A, x, B, u, y_cbf; α=0, Δ=nothing)

Cross-entropy between `softmax` of the two-row `forward_invariance_func_ce` output and
the target label `1` (from the label set `0:1`) for every sample.

`y_cbf` is accepted for interface compatibility but ignored: all samples get label 1.
"""
function loss_forward_invariance_ce(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u::AbstractArray, y_cbf::AbstractArray; α=0,Δ=nothing)
    scores = forward_invariance_func_ce(ϕ, A, x, B, u; α = α, Δ = Δ)
    probabilities = softmax(scores)
    targets = Flux.onehotbatch(ones(size(x[1, :])), 0:1)
    return Flux.crossentropy(probabilities, targets)
end

"""
    loss_forward_invariance_phi(ϕ, A, x, B, u, y_init; use_pgd=false, α=0, lr=1,
                                num_iter=10, ϵ=0.5, Δ=nothing)

Return `[1 -1] * forward_invariance_func_ce(ϕ, A, x, B, u; α, Δ)`, a `1 × n` row of
(row 1 − row 2) values. It is computed only for samples with `y_init[1, i] == 1`
whose softmax margin `p₁ − p₂` of `ϕ(x)` exceeds `-ϵ`. Returns the integer `0` when no
sample survives either filter.

`A`/`B` may be shared 2-D matrices or 3-D per-sample stacks; only 3-D ones are sliced.
`Δ` is sliced only when provided, so the default `Δ=nothing` now works. With `use_pgd`
the controls are replaced by `pgd_find_u`.
"""
function loss_forward_invariance_phi(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u::AbstractArray, y_init::AbstractArray; use_pgd=false, α=0, lr =1, num_iter=10,ϵ=0.5,Δ=nothing)
    # Keep only the sample columns listed in `idx`.
    take_samples(A_, x_, B_, u_, Δ_, idx) = (
        ndims(A_) == 3 ? A_[:, :, idx] : A_,
        x_[:, idx],
        ndims(B_) == 3 ? B_[:, :, idx] : B_,
        u_[:, idx],
        isnothing(Δ_) ? Δ_ : Δ_[:, idx],
    )

    safe = findall(==(1), y_init[1, :])
    isempty(safe) && return 0
    A, x, B, u, Δ = take_samples(A, x, B, u, Δ, safe)

    @assert α==0
    mask = [1 -1] * softmax(ϕ(x)) .> -ϵ
    kept = findall(mask[1, :])
    isempty(kept) && return 0
    A, x, B, u, Δ = take_samples(A, x, B, u, Δ, kept)

    # NOTE(review): `U` is not an argument of this function and resolves to a global
    # binding; use_pgd=true only works if such a global exists.
    use_pgd && (u = pgd_find_u(ϕ, A, x, B, u, U; α = α, lr = lr, num_iter = num_iter, Δ = Δ))
    return [1 -1] * forward_invariance_func_ce(ϕ, A, x, B, u; α, Δ = Δ)
end

"""
    verification_forward(ϕ, A, x, B, u_0, U; α=0, lr=1, num_iter=10, Δ=nothing)

Check the forward-invariance condition `forward_invariance_func(...) ≤ 0` before and
after searching for a control with `pgd_find_u_notce` inside the box `U`.

Returns `(original_condition, pgd_condition, u, l)`:
- `original_condition` is the elementwise check at `u_0`.
- `pgd_condition` is the elementwise check at the PGD control `u`.
- `u` is the control found by PGD.
- `l` is the condition value at `u`.
"""
function verification_forward(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u_0::AbstractArray, U::Hyperrectangle; α=0, lr = 1,num_iter=10,Δ=nothing)
    original_condition = forward_invariance_func(ϕ, A, x, B, u_0; α, Δ = Δ) .≤ 0
    u = pgd_find_u_notce(ϕ, A, x, B, u_0, U; α = α, lr = lr, num_iter = num_iter, Δ = Δ)
    # Evaluate once at the PGD control and derive both returned quantities from it.
    l = forward_invariance_func(ϕ, A, x, B, u; α, Δ = Δ)
    return original_condition, l .≤ 0, u, l
end

"""
    verification_forward_ce(ϕ, A, x, B, u_0, U; α=0, lr=1, num_iter=10, Δ=nothing)

Two-output variant of `verification_forward`. It checks
`forward_invariance_func_ce(...) ≤ 0` elementwise at `u_0`, then again at the control
found by `pgd_find_u` inside the box `U`.

Returns `(original_condition, pgd_condition, u, l)`, where `l` is the
`forward_invariance_func_ce` value at the PGD control `u`.
"""
function verification_forward_ce(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u_0::AbstractArray, U::Hyperrectangle; α=0, lr = 1,num_iter=10,Δ=nothing)
    original_condition = forward_invariance_func_ce(ϕ, A, x, B, u_0; α, Δ = Δ) .≤ 0
    u = pgd_find_u(ϕ, A, x, B, u_0, U; α = α, lr = lr, num_iter = num_iter, Δ = Δ)
    # Evaluate once at the PGD control and derive both returned quantities from it.
    l = forward_invariance_func_ce(ϕ, A, x, B, u; α, Δ = Δ)
    return original_condition, l .≤ 0, u, l
end

"""
    pgd_find_u(ϕ, A, x, B, u_0, U; α=0, lr=1, num_iter=10, Δ=nothing)

Projected gradient descent on the controls, starting from `u_0` and projecting onto
the box `U`. For every sample it decreases `l[1, i] - l[2, i]`, where
`l = forward_invariance_func_ce(ϕ, A, x, B, u; α, Δ)` (assumed to have two rows).

Stops after `num_iter` steps, or earlier once a step leaves `u` unchanged.
Returns the final controls (same size as `u_0`).
"""
function pgd_find_u(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u_0::AbstractArray, U::Hyperrectangle; α=0, lr = 1,num_iter=10,Δ=nothing)
    objective(u1) = forward_invariance_func_ce(ϕ, A, x, B, u1; α, Δ = Δ)
    lo, hi = low(U), high(U)
    u = u_0
    for _ in 1:num_iter
        val, back = Zygote.pullback(objective, u)
        # Seeding with [1; -1] in every column gives ∂(l₁ - l₂)/∂u per sample in one
        # pullback, for any control dimension.
        n = size(val, 2)
        ∇l_u = back(vcat(ones(1, n), -ones(1, n)))[1]
        u_old = u
        # clamp is an exact projection onto the box (unlike low + relu(u - low)).
        u = clamp.(u .- lr .* ∇l_u, lo, hi)
        u == u_old && break
    end
    return u
end

"""
    pgd_find_u_notce(ϕ, A, x, B, u_0, U; α=0, lr=1, num_iter=10, Δ=nothing)

Projected gradient descent on the controls, starting from `u_0` and projecting onto
the box `U`. It decreases `forward_invariance_func(ϕ, A, x, B, u; α, Δ)` for every
sample.

Stops after `num_iter` steps, or earlier once a step leaves `u` unchanged.
Returns the final controls (same size as `u_0`).
"""
function pgd_find_u_notce(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u_0::AbstractArray, U::Hyperrectangle; α=0, lr = 1,num_iter=10,Δ=nothing)
    objective(u1) = forward_invariance_func(ϕ, A, x, B, u1; α, Δ = Δ)
    lo, hi = low(U), high(U)
    u = u_0
    for _ in 1:num_iter
        val, back = Zygote.pullback(objective, u)
        # Seed with ones shaped like the objective's output (one value per sample), so
        # the result is the per-sample gradient with respect to u.
        ∇l_u = back(ones(size(val)))[1]
        u_old = u
        # clamp is an exact projection onto the box (unlike low + relu(u - low)).
        u = clamp.(u .- lr .* ∇l_u, lo, hi)
        u == u_old && break
    end
    return u
end

"""
    pgd_find_x_notce(ϕ, A, x_0, B, u, X; α=0, lr=0.01, num_iter=10, Δ=nothing)

Projected gradient ascent on the states, starting from `x_0`. It increases
`forward_invariance_func(ϕ, A, x, B, u; α, Δ)` per sample. Column `i` is projected onto
the box `X[i]`, assuming `length(X) == size(x_0, 2)`.

Stops after `num_iter` steps, or earlier once a step leaves `x` unchanged.
Returns the final states (same size as `x_0`).
"""
function pgd_find_x_notce(ϕ::Chain, A::Union{AbstractMatrix,AbstractArray}, x_0::AbstractArray, B::Union{AbstractMatrix,AbstractArray}, u::AbstractArray, X::Vector; α=0, lr = 0.01,num_iter=10,Δ=nothing)
    # Column i of lo/hi holds the bounds of the box X[i].
    lo = reduce(hcat, map(low, X))
    hi = reduce(hcat, map(high, X))
    objective(x1) = forward_invariance_func(ϕ, A, x1, B, u; α, Δ = Δ)
    x = x_0
    for _ in 1:num_iter
        val, back = Zygote.pullback(objective, x)
        # Seed with ones shaped like the objective's output (one value per sample).
        ∇l_x = back(ones(size(val)))[1]
        x_old = x
        # Ascent step followed by an exact projection onto the per-sample boxes.
        x = clamp.(x .+ lr .* ∇l_x, lo, hi)
        x == x_old && break
    end
    return x
end

# function pgd_find_u_notce_noAB(ϕ::Chain, x::AbstractArray, u_0::AbstractArray, U::Hyperrectangle, ẋ::AbstractArray; α=0, lr = 1,num_iter=10)
#     u = u_0
    
#     # u_1
#     for i in 1:num_iter
#         # l_min_u_function_ce = nothing
#         # @show i, u
#         function l_min_u_function(u1::AbstractArray)
#             # @show size(forward_invariance_func_ce(ϕ, A, x, B, u, α))
#             return forward_invariance_func_noAB(ϕ, x, u1,ẋ; α)
#         end
        

#         val, ∇l = Zygote.pullback(l_min_u_function, u)
#         state_dim, batchsize = size(u)
#         # _, ∇ϕ = Zygote.pullback(ϕ, x)
#         # @show ∇ϕ
#         # g = ∇ϕ(ones(size(x)))[1]
#         # val = ϕ(x)
#         # @show g[:, 1]
#         # x[1,1] = x[1,1] + 1e-6
#         # @show (ϕ(x)[1] - val[1]) / 1e-6
#         # @show size(ϕ(x)), size(x), size(ẋ)
#         # @show size(∇ϕ(ẋ)[1])
#         ∇l_u = ∇l(ones(size(u)))[1] ./ state_dim
#         ∇l_u = reshape(∇l_u, (state_dim, batchsize))

#         # ∇l_u = vcat(reshape(∇l_u_1, (1, size(∇l_u_1)...)), reshape(∇l_u_2, (1, size(∇l_u_2)...)))

#         # # @show size(∇l(ones(size(u)))[1])
#         # # ∇l_u = ∇l(ones(size(u)))[1] ./ size(u)[1]

#         # ones_batch = ones(1,size(u)...)
#         # ones_batch[:, 2, :] = (-1) .* ones_batch[:, 2, :]
#         # # @show ones_batch[:,:,1]
#         # # @show ∇l_u
#         # ∇l_u = batched_mul(ones_batch, ∇l_u)[1, :, :]
#         # # @show size(∇l_u)
#         # # u[1,1] = u[1,1] + 1
#         # # @show l_min_u_function(u)[1]
#         # # @show (l_min_u_function(u)[1] - val[1]) / lr
#         u_old = copy(u)
#         u = u - lr .* ∇l_u
#         u = low(U) .+ relu(u .- low(U))
#         u = high(U) .- relu(high(U) .- u)
#         if u == u_old
#             # @show i 
#             break
#         end
#         # @show "ddd", u
#         # @show val[1]
#         # @show l_min_u_function(u)[1]
#     end
#     return u
# end

"""
    pgd_find_x_notce_noAB(ϕ, x_0, X, ẋ; α=0, lr=0.01, num_iter=10)

Projected gradient ascent on the states, starting from `x_0`. It increases
`forward_invariance_func_noAB(ϕ, x, ẋ; α)` per sample, with `ẋ` held fixed. Column `i`
is projected onto the box `X[i]`, assuming `length(X) == size(x_0, 2)`.

Stops after `num_iter` steps, or earlier once a step leaves `x` unchanged.
Returns the final states (same size as `x_0`).
"""
function pgd_find_x_notce_noAB(ϕ::Chain, x_0::AbstractArray, X::Vector,ẋ::AbstractArray; α=0, lr = 0.01,num_iter=10)
    # Column i of lo/hi holds the bounds of the box X[i].
    lo = reduce(hcat, map(low, X))
    hi = reduce(hcat, map(high, X))
    objective(x1) = forward_invariance_func_noAB(ϕ, x1, ẋ; α)
    x = x_0
    for _ in 1:num_iter
        val, back = Zygote.pullback(objective, x)
        # Seed with ones shaped like the objective's output (one value per sample).
        ∇l_x = back(ones(size(val)))[1]
        x_old = x
        # Ascent step followed by an exact projection onto the per-sample boxes.
        x = clamp.(x .+ lr .* ∇l_x, lo, hi)
        x == x_old && break
    end
    return x
end