using DifferentialEquations
using LazySets
# using ProgressMeter
using ProgressBars
using JLD2
# import Random

# f(u, p, t) = 1.01 * u
# u0 = 1 / 2
# tspan = (0.0, 1.0)
# prob = ODEProblem(f, u0, tspan)
# sol = solve(prob, Tsit5(), reltol = 1e-8, abstol = 1e-8)
# using Plots
# plot(sol, linewidth = 5, title = "Solution to the linear ODE with a thick line",
#     xaxis = "Time (t)", yaxis = "u(t) (in μm)", label = "My Thick Line!") # legend=false
# plot!(sol.t, t -> 0.5 * exp(1.01t), lw = 3, ls = :dash, label = "True Solution!")

"""
    find_ada_input_area(U::Hyperrectangle, random_point)

Return the sub-box of the input box `U` selected by the velocity entries of
`random_point` (indices 3 and 4).

For each input axis `i ∈ (1, 2)` with velocity `v = random_point[i + 2]`:
- `v > 0`  → the interval on that axis is `[low(U, i), 0]`;
- otherwise → the interval on that axis is `[0, high(U, i)]`.

If `random_point` does not have length 4 or `U` is not 2-dimensional, a message
is printed and `U` is returned unchanged.

NOTE(review): the `Hyperrectangle` constructor presumably requires `low ≤ high`,
so this assumes `low(U, i) ≤ 0 ≤ high(U, i)` — confirm against callers.
"""
function find_ada_input_area(U::Hyperrectangle, random_point)
    is_supported = length(random_point) == 4 && dim(U) == 2
    if !is_supported
        println("only support 2D double-integrator")
        return U
    end

    ux_min, ux_max = random_point[3] > 0 ? (low(U, 1), 0) : (0, high(U, 1))
    uy_min, uy_max = random_point[4] > 0 ? (low(U, 2), 0) : (0, high(U, 2))

    return Hyperrectangle(low=[ux_min, uy_min], high=[ux_max, uy_max])
end

"""
    find_ada_non_admissible_area(non_admissible_area::Hyperrectangle, random_point, input_bound)

Enlarge the position extent (axes 1 and 2) of the 4-D box `non_admissible_area`
according to the velocity entries of `random_point` (indices 3 and 4) and the
input box `input_bound`. Axes 3 and 4 are copied unchanged.

For each position axis `i ∈ (1, 2)` with velocity `v = random_point[i + 2]`:
- `v > 0`: the lower bound becomes `low(non_admissible_area, i) + v^2 / (2 * low(input_bound, i))`;
- `v < 0`: the upper bound becomes `high(non_admissible_area, i) + v^2 / (2 * high(input_bound, i))`;
- otherwise the bound is unchanged.

Presumably `v^2 / (2|a|)` is the braking distance under the input bound opposing
the motion — TODO confirm. The assertions require the result to contain the
original box. That holds when `low(input_bound, i) < 0 < high(input_bound, i)`.

If `random_point` does not have length 4, `non_admissible_area` is not
4-dimensional, or `input_bound` is not 2-dimensional, a message is printed and
`non_admissible_area` is returned unchanged.
"""
function find_ada_non_admissible_area(non_admissible_area::Hyperrectangle, random_point, input_bound)
    if length(random_point) != 4 || dim(non_admissible_area) != 4 || dim(input_bound) != 2
        println("only support 2D double-integrator")
        return non_admissible_area
    end
    vx, vy = random_point[3], random_point[4]

    # A bound is only moved when the velocity points towards the box from that side.
    # Gating on the sign (instead of multiplying by max(0, v)) also avoids 0/0 = NaN
    # when the opposing input bound is zero and the velocity does not point that way.
    x_min = low(non_admissible_area, 1) + (vx > 0 ? vx^2 / (2 * low(input_bound, 1)) : zero(vx))
    x_max = high(non_admissible_area, 1) + (vx < 0 ? vx^2 / (2 * high(input_bound, 1)) : zero(vx))
    y_min = low(non_admissible_area, 2) + (vy > 0 ? vy^2 / (2 * low(input_bound, 2)) : zero(vy))
    y_max = high(non_admissible_area, 2) + (vy < 0 ? vy^2 / (2 * high(input_bound, 2)) : zero(vy))

    # Internal invariant: the adapted area may only grow.
    @assert x_min ≤ low(non_admissible_area, 1)
    @assert y_min ≤ low(non_admissible_area, 2)
    @assert x_max ≥ high(non_admissible_area, 1)
    @assert y_max ≥ high(non_admissible_area, 2)

    output = Hyperrectangle(
        low=[x_min, y_min, low(non_admissible_area, 3), low(non_admissible_area, 4)],
        high=[x_max, y_max, high(non_admissible_area, 3), high(non_admissible_area, 4)],
    )
    return output
end


"""
    random_point_in_hyperrectangle(hyperrectangle::Hyperrectangle, non_admissible_area=nothing, input_bound=nothing)

Draw a point uniformly from `hyperrectangle` and report whether it is admissible.

The point is a `Vector{Float32}`. Each coordinate `i` is
`rand() * (high - low) + low` using that axis's bounds, with one `rand()` call
per coordinate in index order.

Returns `(point, flag)`:
- if `non_admissible_area` is `nothing`, `flag` is always `true`;
- otherwise `flag` is `true` iff `point` lies outside
  `find_ada_non_admissible_area(non_admissible_area, point, input_bound)`.
"""
function random_point_in_hyperrectangle(hyperrectangle::Hyperrectangle, non_admissible_area=nothing, input_bound=nothing)
    point = Float32[
        rand() * (high(hyperrectangle, k) - low(hyperrectangle, k)) + low(hyperrectangle, k)
        for k in 1:dim(hyperrectangle)
    ]

    if isnothing(non_admissible_area)
        return point, true
    end

    adapted_area = find_ada_non_admissible_area(non_admissible_area, point, input_bound)
    is_admissible = point ∉ adapted_area
    return point, is_admissible
end

"""
    affine_dyn(A, x, B, u)

Return the state derivative `A * x + B * u` of the linear system with state `x`
and input `u`.
"""
affine_dyn(A::AbstractMatrix, x::AbstractArray, B::AbstractMatrix, u::AbstractArray) =
    A * x + B * u

"""
    generate_data(A, X, B, U, X_unsafe, num_x, num_u, violated_unknown_list=nothing;
                  num_test=10000)

Sample labelled states, split them into disjoint training and test sets, and save
both with JLD2's `save_object`.

Each sample is a 3-element vector `[x0, u0, [safe_flag]]`:
- `x0` is drawn with `random_point_in_hyperrectangle` from the sampling region with
  probability 0.8 (`rand() > 0.2`), and from `X_unsafe` otherwise;
- `u0` is always a zero vector. A control is still drawn from
  `find_ada_input_area(U, x0)`, but only its length is used;
- `safe_flag` is the admissibility flag returned by `random_point_in_hyperrectangle`
  (computed against `X_unsafe` and `U`).
The three vectors promote to `Vector{Float64}`, so the flag is stored as `1.0`/`0.0`.
Samples are concatenated column-wise (`hcat`), giving a 3-row matrix.

# Arguments
- `A`, `B`: system matrices; currently unused (the ODE-based filtering that used them
  was disabled) and kept for interface compatibility.
- `X`: sampling region used when `violated_unknown_list` is `nothing`.
- `U`: input box.
- `X_unsafe`: non-admissible region, also sampled from 20% of the time.
- `num_x`: number of samples (per region when `violated_unknown_list` is given).
- `num_u`: currently unused.
- `violated_unknown_list`: `nothing`, or a collection of regions sampled in place of `X`.

# Keywords
- `num_test`: number of trailing columns placed in the test set (default 10000).

# Returns
`(training_data, test_data)`: disjoint column splits of the sample matrix.

Files written: `new_ada_training_data.jld2` and `new_ada_test_data.jld2`. When a list
is given, `length(violated_unknown_list)` is inserted before `.jld2`.

# Throws
- `ArgumentError` if there are not more than `num_test` samples.
"""
function generate_data(A::AbstractMatrix, X::Hyperrectangle, B::AbstractMatrix,
                       U::Hyperrectangle, X_unsafe::Hyperrectangle, num_x::Int, num_u::Int,
                       violated_unknown_list::Union{Nothing,AbstractVector}=nothing;
                       num_test::Int=10000)
    # Each sample `[Float32 state, Float64 input, Bool flag]` promotes to Vector{Vector{Float64}}.
    data = Vector{Vector{Float64}}[]
    if isnothing(violated_unknown_list)
        for _ in ProgressBar(1:num_x)
            push!(data, _sample_labelled_state(X, X_unsafe, U))
        end
        suffix = ""
    else
        for region in violated_unknown_list, _ in 1:num_x
            push!(data, _sample_labelled_state(region, X_unsafe, U))
        end
        suffix = string(length(violated_unknown_list))
    end
    return _split_and_save(data, "new_ada_training_data$(suffix).jld2",
                           "new_ada_test_data$(suffix).jld2", num_test)
end

# Draw one labelled sample `[x0, u0, [safe_flag]]` from `region` (or, 20% of the time,
# from `X_unsafe`); `u0` is stored as zeros.
function _sample_labelled_state(region, X_unsafe::Hyperrectangle, U::Hyperrectangle)
    source = rand() > 0.2 ? region : X_unsafe
    x0, safe_flag = random_point_in_hyperrectangle(source, X_unsafe, U)
    # The control value itself is not stored; it is still sampled so the random
    # stream and the `find_ada_input_area` call match the original pipeline.
    worstcase_U = find_ada_input_area(U, x0)
    u0, _ = random_point_in_hyperrectangle(worstcase_U)
    return [x0, zeros(size(u0)), [safe_flag]]
end

# Concatenate samples column-wise, keep the last `num_test` columns as a disjoint test
# set, save both splits to the given paths, and return `(training_data, test_data)`.
function _split_and_save(data::AbstractVector, train_path::AbstractString,
                         test_path::AbstractString, num_test::Integer)
    num_samples = length(data)
    0 <= num_test < num_samples || throw(ArgumentError(
        "need more than num_test = $num_test samples to build a training set, got $num_samples"))
    samples = reduce(hcat, data)
    num_train = num_samples - num_test
    training_data = samples[:, 1:num_train]
    test_data = samples[:, num_train+1:end]
    save_object(train_path, training_data)
    save_object(test_path, test_data)
    return training_data, test_data
end

# # ∇ϕ = Flux.jacobian(ϕ, x)[1]
#     # g = x->ReverseDiff.jacobian(ϕ,x)
#     # ∇ϕ = g(x)
#     # @show size(g(x))
#     _, ∇ϕ = Zygote.pullback(ϕ, x)
#     ∇ϕ_x = ∇ϕ(x)[1]
#     # @show size(b(x)[1])
#     # ∇ϕ = Zygote.jacobian(ϕ, Zygote.Params([x]))
#     # @show ϕ
#     # ∇ϕ = reshape(∇ϕ, (size(∇ϕ)[1], state_dim, batchsize))
#     # @show ∇ϕ[:, 1]
#     # return Flux.Losses.mse(max.(0, ∇ϕ), 0)
#     # @show size(a)
#     # @show a[1, :, 1]
#     # x_ = Matrix{Bool}(I, batchsize, batchsize)
#     # index_ = reshape(repeat(x_, inner=(state_dim, 1)), (state_dim, batchsize, batchsize))
#     # index_ = permutedims(index_, [3, 1, 2])
#     # ∇ϕ = reshape(∇ϕ[index_], (1, state_dim, batchsize))
#     ∇ϕ_x = reshape(∇ϕ_x, (1, state_dim, batchsize))
#     # @show ∇ϕ[:, 1]
#     # @show sum(a), sum(a[index_])
    
#     # @show size(ϕ(x))
#     ẋ = reshape(ẋ, (state_dim, 1, batchsize))
#     # @show size(∇ϕ)
#     # @show size(ẋ)
#     ϕ̇ = batched_mul(∇ϕ_x, ẋ)
#     # @show size(ϕ̇)
#     # @show size(ϕ(x))