using Revise
using Burgers#, Plots
using DataDeps, MAT, MLUtils
using NeuralOperators, Flux
using BSON
using DataDeps, MAT, MLUtils
using NeuralOperators, Flux
using CUDA, FluxTraining, BSON
import Flux: params
using BSON: @save, @load
using ProgressBars
using Zygote
using Optimisers, ParameterSchedulers

using Burgers
using FluxTraining
# using Test




"""
    my_get_data(file_path; n=50000, Δsamples=1, grid_size=div(51, Δsamples), T=Float32)

Load one MAT file and return `(x_loc_data, y, safe_labels, pf_labels)`.

The variables `"a"`, `"u"`, `"safe"` and `"pf"` are read from the file, restricted to
the first `n` rows and every `Δsamples`-th column, transposed (so the sample index
becomes the last dimension) and converted to element type `T`.

# Returns
- `x_loc_data`: `2 × grid_size × n` array. Channel 1 holds `LinRange(0, 5, grid_size)`
  repeated for every sample, channel 2 holds the transposed `"a"` data.
- `y`: the transposed `"u"` data reshaped to `1 × : × n`.
- `safe_labels`, `pf_labels`: the transposed `"safe"` / `"pf"` data.

`grid_size` must equal the number of columns kept by `1:Δsamples:end`, otherwise the
assignment into `x_loc_data` throws.
"""
function my_get_data(file_path; n = 50000, Δsamples = 1, grid_size = div(51, Δsamples), T = Float32)
    file = matopen(file_path)
    # Close the handle even when a variable is missing or the slice is out of bounds.
    x_data, y_data, safe_labels, pf_labels = try
        map(("a", "u", "safe", "pf")) do name
            T.(collect(read(file, name)[1:n, 1:Δsamples:end]'))
        end
    finally
        close(file)
    end

    x_loc_data = Array{T, 3}(undef, 2, grid_size, n)
    # A length-`grid_size` range broadcasts across all `n` samples.
    x_loc_data[1, :, :] .= LinRange(0, 5, grid_size)
    x_loc_data[2, :, :] .= x_data

    return x_loc_data, reshape(y_data, 1, :, n), safe_labels, pf_labels
end

"""
    my_get_dataloader(; ratio=0.9, batchsize=128)

Build `(loader_train, loader_test)` from the three hard-coded MAT files.

Each file is loaded with `my_get_data` and split with `splitobs(...; at = ratio)`.
The train parts (and, separately, the test parts) of all files are concatenated along
the sample dimension, in file order. Only the first three fields of each split
(inputs, outputs, safe labels) are kept; the fourth (`pf`) is omitted. The training
loader shuffles, the test loader does not.
"""
function my_get_dataloader(; ratio::Float64 = 0.9, batchsize = 128)
    sources = ("data_bcks_hyperbolic_1_minus.mat",
               "data_ppo_hyperbolic_1_minus.mat",
               "data_sac_hyperbolic_1_minus.mat")

    # One (train, test) pair per file; files are read in the order listed.
    splits = map(path -> splitobs(my_get_data(path), at = ratio), sources)

    # Concatenate field 1 and 2 along dim 3 and field 3 along dim 2 across files.
    function merged(side)
        parts = map(s -> s[side], splits)
        xs = cat(map(p -> p[1], parts)...; dims = 3)
        ys = cat(map(p -> p[2], parts)...; dims = 3)
        safe = cat(map(p -> p[3], parts)...; dims = 2)
        return (xs, ys, safe)
    end

    loader_train = DataLoader(merged(1), batchsize = batchsize, shuffle = true)
    loader_test = DataLoader(merged(2), batchsize = batchsize, shuffle = false)

    return loader_train, loader_test
end

"""
    train(; cuda=true, η₀=1.0f-3, λ=1.0f-4, epochs=500)

Train a `FourierNeuralOperator` on the loaders from `my_get_dataloader()` with a
FluxTraining `Learner`, save a CPU copy to `model/hyper_FNO_all.bson` and return the
learner.

# Keywords
- `cuda`: use the GPU when `CUDA.has_cuda()` is true, otherwise fall back to the CPU.
- `η₀`: learning rate passed to `Flux.Adam`.
- `λ`: coefficient passed to `Flux.Optimise.WeightDecay`.
- `epochs`: number of epochs passed to `fit!`.
"""
function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 500)
    if cuda && CUDA.has_cuda()
        device = gpu
        CUDA.allowscalar(false)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # Create the output directory before training so a missing folder cannot
    # discard the result of a long run at save time.
    mkpath("model")

    model = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),
                                  σ = gelu)
    data = my_get_dataloader()
    optimiser = Flux.Optimiser(Flux.Optimise.WeightDecay(λ), Flux.Adam(η₀))
    loss_func = l₂loss

    learner = Learner(model, data, optimiser, loss_func,
                      ToDevice(device, device))

    fit!(learner, epochs)
    model = learner.model |> cpu
    @save "model/hyper_FNO_all.bson" model

    return learner
end

"""
    train_MNO(; cuda=true, η₀=1.0f-3, λ=1.0f-4, epochs=500)

Train a `MarkovNeuralOperator` on the loaders from `my_get_dataloader()` with a
FluxTraining `Learner`, save a CPU copy to `model/hyper_MNO_all.bson` and return the
learner.

# Keywords
- `cuda`: use the GPU when `CUDA.has_cuda()` is true, otherwise fall back to the CPU.
- `η₀`: learning rate passed to `Flux.Adam`.
- `λ`: coefficient passed to `Flux.Optimise.WeightDecay`.
- `epochs`: number of epochs passed to `fit!`.
"""
function train_MNO(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 500)
    if cuda && CUDA.has_cuda()
        device = gpu
        CUDA.allowscalar(false)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # Create the output directory before training so a missing folder cannot
    # discard the result of a long run at save time.
    mkpath("model")

    model = MarkovNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 1), modes = (16,),
                                  σ = gelu)
    data = my_get_dataloader()
    optimiser = Flux.Optimiser(Flux.Optimise.WeightDecay(λ), Flux.Adam(η₀))
    loss_func = l₂loss

    learner = Learner(model, data, optimiser, loss_func,
                      ToDevice(device, device))

    fit!(learner, epochs)
    model = learner.model |> cpu
    @save "model/hyper_MNO_all.bson" model

    return learner
end


"""
    delete_with_probability!(list, p=0.5)

Return a new collection holding a random subset of `list`, in the original order.

Each element is *kept* independently with probability `p` (one uniform draw per
element, kept when the draw is below `p`). Despite the trailing `!`, `list` itself is
not modified.
"""
function delete_with_probability!(list, p = 0.5)
    keep = rand(length(list)) .< p
    return list[keep]
end

using Flux, CUDA, BSON
using Logging

"""
    loss_naive_safeset(ϕ, x, y_init)

Hinge-style loss over the columns of `x` whose label in `y_init` equals 0.

For the selected columns it evaluates `relu((2y - 1) * ϕ(x)[1, :] + 1e-6)` and returns
the mean over them. Returns the integer `0` when no label equals 0.
(A comment elsewhere in this file describes the labels as "safe: 1; unsafe: 0".)
"""
function loss_naive_safeset(ϕ, x, y_init)
    picked = findall(label -> label == 0, y_init)
    isempty(picked) && return 0

    states = x[:, picked]
    signs = 2 .* y_init[picked] .- 1
    margin = signs .* ϕ(states)[1, :] .+ 1e-6
    hinge = relu(margin)
    return sum(hinge) / last(size(hinge))
end

"""
    loss_regularization(ϕ::Chain, x::AbstractArray, y_init::AbstractArray)

Sigmoid regulariser over the columns of `x` whose label in `y_init` equals 0.

For the selected columns it evaluates `sigmoid_fast((2y - 1) * ϕ(x)[1, :])` and returns
the mean over them. Returns the integer `0` when no label equals 0.
"""
function loss_regularization(ϕ::Chain, x::AbstractArray, y_init::AbstractArray)
    picked = findall(label -> label == 0, y_init)
    isempty(picked) && return 0

    states = x[:, picked]
    signs = 2 .* y_init[picked] .- 1
    squashed = sigmoid_fast(signs .* ϕ(states)[1, :])
    return sum(squashed) / last(size(squashed))
end

# function loss_naive_safeset_end(ϕ, x,y_init)
#     return relu((2 .* y_init[end] .- 1) .* ϕ(x)[1, end] .+ 1e-6)
# end

# function loss_regularization_end(ϕ::Chain, x::AbstractArray,y_init::AbstractArray)
#     return sigmoid_fast((2 .* y_init[end] .- 1) .* ϕ(x)[1, end])
# end

"""
    loss_naive_safeset_end(ϕ, x, y_init; minus_safe=false)

Hinge-style loss with two modes.

- `minus_safe = false`: uses only the last entry,
  `relu((2 * y_init[end] - 1) * ϕ(x)[1, end] + 1e-6)`.
- `minus_safe = true`: mean of `relu((2y - 1) * ϕ(x)[1, :] + 1e-6)` over the columns
  whose label equals 1; returns the integer `0` when there are none.
"""
function loss_naive_safeset_end(ϕ, x, y_init; minus_safe = false)
    # Last-entry mode.
    minus_safe || return relu((2 .* y_init[end] .- 1) .* ϕ(x)[1, end] .+ 1e-6)

    picked = findall(label -> label == 1, y_init)
    isempty(picked) && return 0

    states = x[:, picked]
    signs = 2 .* y_init[picked] .- 1
    hinge = relu(signs .* ϕ(states)[1, :] .+ 1e-6)
    return sum(hinge) / last(size(hinge))
end

"""
    loss_regularization_end(ϕ::Chain, x::AbstractArray, y_init::AbstractArray; minus_safe=false)

Sigmoid regulariser with two modes.

- `minus_safe = false`: uses only the last entry,
  `sigmoid_fast((2 * y_init[end] - 1) * ϕ(x)[1, end])`.
- `minus_safe = true`: mean of `sigmoid_fast((2y - 1) * ϕ(x)[1, :])` over the columns
  whose label equals 1; returns the integer `0` when there are none.
"""
function loss_regularization_end(ϕ::Chain, x::AbstractArray, y_init::AbstractArray; minus_safe = false)
    # Last-entry mode.
    minus_safe || return sigmoid_fast((2 .* y_init[end] .- 1) .* ϕ(x)[1, end])

    picked = findall(label -> label == 1, y_init)
    isempty(picked) && return 0

    states = x[:, picked]
    signs = 2 .* y_init[picked] .- 1
    squashed = sigmoid_fast(signs .* ϕ(states)[1, :])
    return sum(squashed) / last(size(squashed))
end

"""
    find_derivative(vector)

Finite-difference derivative of channel 2 with respect to channel 1.

`vector` is indexed as `(channel, M, N)`: `vector[1, :, :]` holds the sample locations
and `vector[2, :, :]` the sampled values. Each of the `N` columns is differentiated
independently with central differences at interior points, a forward difference at
the first point and a backward difference at the last.

# Returns
A `1 × M × N` `Float64` array of derivative estimates.

# Throws
- `ArgumentError` if `vector` has fewer than 2 channels or fewer than 2 grid points.
"""
function find_derivative(vector)
    size(vector, 1) >= 2 ||
        throw(ArgumentError("find_derivative needs 2 channels (inputs, outputs), got $(size(vector, 1))"))
    M, N = size(vector, 2), size(vector, 3)
    M >= 2 ||
        throw(ArgumentError("find_derivative needs at least 2 grid points, got $M"))

    # Views avoid copying the channel slices and their shifted sub-ranges.
    inputs = @view vector[1, :, :]   # (M, N)
    outputs = @view vector[2, :, :]  # (M, N)

    derivatives = zeros(Float64, 1, M, N)

    @views begin
        # Central differences for the interior points (empty range when M == 2).
        derivatives[1, 2:M-1, :] .= (outputs[3:M, :] .- outputs[1:M-2, :]) ./
                                    (inputs[3:M, :] .- inputs[1:M-2, :])
        # Forward difference for the first point.
        derivatives[1, 1, :] .= (outputs[2, :] .- outputs[1, :]) ./
                                (inputs[2, :] .- inputs[1, :])
        # Backward difference for the last point.
        derivatives[1, M, :] .= (outputs[M, :] .- outputs[M-1, :]) ./
                                (inputs[M, :] .- inputs[M-1, :])
    end

    return derivatives
end


# loss_pf: hinge penalty on the quantity  ϕ̇ + α ϕ(Y) + C ϕ(U_0),
# evaluated on a random subset of the columns of Y whose first row is close to zero.
#
# Arguments (as used by the body):
#   ϕ           - model being trained; called on Y and U_0 and differentiated w.r.t. Y.
#   U           - not referenced in the body.
#   Y           - matrix whose columns are the evaluation points (state_dim × batch).
#   U_0         - matrix indexed with the same column indices as Y.
#   extended_U̇  - not referenced in the body.
#   ∇Y_t        - reshaped to size(Y); multiplied with the gradient of ϕ to form ϕ̇.
#   T, α        - scalars used in C = α e^(-αT) / (1 - e^(-αT)).
#   λ_pf_batch  - optional per-column weights; reshaped to size(U_0[1:1, :]) and
#                 multiplied into the penalty before the relu.
#
# Returns the mean of relu(l .+ 1e-6) over the selected columns, or the integer 0 when
# no column survives the selection.
#
# NOTE(review): the result is stochastic - `delete_with_probability!(index, 0.2)` keeps
# each candidate column with probability 0.2, so two calls with identical arguments
# generally differ.
function loss_pf(ϕ::Chain, U::AbstractArray, Y::AbstractArray, U_0,extended_U̇, ∇Y_t,T, α;λ_pf_batch=nothing)
    ∇Y_t = reshape(∇Y_t, size(Y))

    # Earlier variant kept for reference ("for ppo"): same selection with ϵ = 0.5.

    # for sac
    isnothing(λ_pf_batch) || (λ_pf_batch = reshape(λ_pf_batch, size(U_0[1:1,:])))
    ϵ = 0.1
    # Candidate columns: |first row of Y| below ϵ.
    mask = abs.(Y[1,:]) .< ϵ
    index = findall(x->x==true, mask)
    # Random subsample of the candidates (each kept with probability 0.2).
    index = delete_with_probability!(index, 0.2) 
    size(index)[1] == 0 && return 0
    Y = Y[:, index]

    # Restrict every per-column quantity to the same subset.
    ∇Y_t = ∇Y_t[:, index]
    U_0 = U_0[:, index]
    isnothing(λ_pf_batch) || (λ_pf_batch = λ_pf_batch[:, index])
    
    state_dim, batchsize = size(Y) # 1*51000
    # Vector-Jacobian product of ϕ at Y with an all-ones seed, i.e. the gradient of
    # sum(ϕ(Y)) with respect to Y, then scaled by 1 / state_dim.
    _, ∇ϕ = Zygote.pullback(ϕ, Y)
    ∇ϕ_Y = ∇ϕ(ones(size(Y)))[1] ./ state_dim
    ∇ϕ_Y = reshape(∇ϕ_Y, (1, state_dim, batchsize))

    # Shape both factors so batched_mul computes one (1×state_dim)·(state_dim×1)
    # product per column.
    ∇Y_t = reshape(∇Y_t, (state_dim, 1, batchsize))
    
    ϕ̇ = reshape(batched_mul(∇ϕ_Y, ∇Y_t), size(ϕ(Y)))
    
    C = (α * ℯ^(-α*T)) / (1-ℯ^(-α*T))
    l = ϕ̇ .+ α .* ϕ(Y) .+ C .* ϕ(U_0)
    isnothing(λ_pf_batch) || (l = l .* λ_pf_batch)
    loss = relu(l .+ 1e-6)
    return sum(loss) / size(loss)[end]
end

# my_train: jointly train a Fourier neural operator (model_NO) and a small dense
# network (model_CBF) on the loaders from my_get_dataloader().
#
# Keywords:
#   cuda          - selects the `device` variable (see review note below).
#   η₀, λ         - learning rate and decay passed to AdamW for model_NO.
#   total_epoch   - number of epochs.
#   pretrained_NO - `nothing` to train model_NO from scratch; otherwise a file name
#                   passed to get_model, whose `:model_NO` entry is used and kept
#                   frozen (the NO update step is skipped).
#
# Returns (training_losses, test_losses): per-epoch sums of the per-batch CBF loss.
#
# Side effects: every epoch writes model_CBF (and model_NO when it is being trained)
# to BSON files under "model/", and prints the individual loss terms for every batch.
#
# NOTE(review): `device` is assigned but never applied to the models or the batches in
# this function, so the "Training on GPU" message does not reflect where the work runs.
# NOTE(review): `lr_NO`, `loss_func`, `least_loss`, `test_loss`, `no_training_losses`
# and `no_test_losses` are assigned but not read afterwards; the per-epoch NO loss
# lists are filled but never aggregated or returned.
function my_train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, total_epoch = 500, pretrained_NO=nothing)
    if cuda && CUDA.has_cuda()
        device = gpu
        CUDA.allowscalar(false)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end
    @show 1
    lr_NO = η₀
    lr_CBF = 0.01

    lr_decay_rate = 0.2
    lr_decay_epoch =4

    train_loader, test_loader = my_get_dataloader()
    # NOTE(review): this first construction is overwritten by the branch just below.
    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), 
                                  σ = gelu)
    if isnothing(pretrained_NO)
        model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), 
                                  σ = gelu)
    else
        model_NO = get_model(pretrained_NO)[:model_NO]
    end
    # Scalar-in, scalar-out network evaluated column-wise on flattened samples.
    model_CBF = Chain(
            Dense(1 => 16, relu),   # activation function inside layer
            Dense(16 => 64, relu),   # activation function inside layer
            Dense(64 => 16, relu),   # activation function inside layer
            Dense(16 => 1)
        )
    # optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))
    optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)
    optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)
    sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch)) # setup schedule of your choice

    
    loss_func = l₂loss
    α = 0.00001
    # Weights of the loss_pf term and of the regularisation terms in the total loss.
    λ_pf = 1
    λ_reg = 0.1
    minus_safe_flag = true # cannot be false if the end can be not pf, for sac
    # minus_safe_flag = false # false for ppo
    training_losses = []
    test_losses = []
    no_training_losses = []
    no_test_losses = []
    least_loss = 1000
    test_loss = 0
    loss = 0
    for epoch in ProgressBar(1:total_epoch)
        training_loss_epoch = []
        test_loss_epoch = []
        no_training_loss_epoch = []
        no_test_loss_epoch = []
        for item in train_loader
            # Batch layout per the size note below: (2,51,bs), (1,51,bs), (51,bs).
            x_batch = item[1]
            y_batch = item[2]
            safe_batch = item[3]

            # Per-sample weight for loss_pf: λ_pf for samples whose last label is 1,
            # zero otherwise, replicated over every row.
            λ_pf_batch = zeros(size(safe_batch)) 
            pf_index = findall(x->x==1, safe_batch[end, :])
            # @show pf_index
            size(pf_index)[1] != 0 && (λ_pf_batch[end,pf_index] .= λ_pf)
            λ_pf_batch[:,:] .= λ_pf_batch[end:end,:]
            # @show size(λ_pf_batch), λ_pf_batch[1, :]
            # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)

            if isnothing(pretrained_NO)
                # train NO
                NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m 
                    l₂loss(m(x_batch), y_batch)
                end
                Flux.update!(optim_NO, model_NO, NO_grads[1])
                # Loss is re-evaluated after the update for logging.
                push!(no_training_loss_epoch, l₂loss(model_NO(x_batch), y_batch))
                # @show l₂loss(model_NO(x_batch), y_batch)
            end

            
            # train CBF
            # Flatten outputs and labels so every grid point of every sample becomes
            # one column / one entry: x is 1×(grid·bs), y_init a vector.
            x = copy(y_batch)
            y_init = copy(safe_batch)
            x = vcat(x[1,:,:]...)
            x = reshape(x, (1, size(x)[1]))
            y_init = vcat(y_init...)

            # U_0: channel 2 of x_batch at the first grid point, repeated over the
            # grid and flattened the same way as x.
            U_0 = copy(x_batch)
            U_0[2:2,:,:] .= x_batch[2:2,1:1,:]
            U_0 = vcat(U_0[2:2,:,:][1,:,:]...)
            U_0 = reshape(U_0, (1, size(U_0)[1]))
            U̇ = find_derivative(x_batch)
            extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)
            # Last grid coordinate of the first sample in the batch.
            T = x_batch[1,end,1]
            # NOTE(review): ∇ϕ is only referenced by the commented-out alternatives
            # below; the active code uses the empirical derivative instead.
            _, ∇ϕ = Zygote.pullback(model_NO, x_batch)
            # dG\du * du\dt + dG\dt
            # ∇Y_t = ∇ϕ(ones(size(Y)))[1][2:2, :,:] .* U̇ .+ ∇ϕ(ones(size(Y)))[1][1:1, :,:] # rewrite it below
            # ∇Y_t = batched_mul(reshape(∇ϕ(ones(size(y_batch)))[1], (1, size(∇ϕ(ones(size(y_batch)))[1])...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))
            # ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], model_NO(x_batch), dims=1)) # empirical derivative
            ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], y_batch, dims=1)) # empirical derivative

            # Total CBF loss: five terms, with the regularisers weighted by λ_reg and
            # loss_pf weighted by λ_pf.
            CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m 
                loss_naive_safeset(m, x, y_init)  +  λ_reg .* loss_regularization(m, x, y_init) + λ_pf .* loss_pf(m, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α;λ_pf_batch=λ_pf_batch) + loss_naive_safeset_end(m, x, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(m, x, y_init;minus_safe=minus_safe_flag)
                # sum(m(rand(1,12)))
            end
            Flux.update!(optim_CBF, model_CBF, CBF_grads[1])

            # NOTE(review): the loss is recomputed after the update, and again inside
            # @show; loss_pf subsamples randomly, so the logged, shown and optimised
            # values each come from a different subsample.
            loss = loss_naive_safeset(model_CBF, x, y_init)  +  λ_reg .* loss_regularization(model_CBF, x, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α;λ_pf_batch=λ_pf_batch) + loss_naive_safeset_end(model_CBF, x, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(model_CBF, x, y_init;minus_safe=minus_safe_flag)
            @show loss_naive_safeset(model_CBF, x, y_init)  , loss_regularization(model_CBF, x, y_init) , loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α;λ_pf_batch=λ_pf_batch), loss_naive_safeset_end(model_CBF, x, y_init;minus_safe=minus_safe_flag)  , loss_regularization_end(model_CBF, x, y_init;minus_safe=minus_safe_flag)
            push!(training_loss_epoch, loss)  # logging, outside gradient context
            
            # @show training_loss
        end
        # Evaluation pass: same preprocessing as the training loop, no parameter update.
        for item in test_loader
            x_batch = item[1]
            y_batch = item[2]
            safe_batch = item[3]

            λ_pf_batch = zeros(size(safe_batch)) 
            pf_index = findall(x->x==1, safe_batch[end, :])
            # @show pf_index
            size(pf_index)[1] != 0 && (λ_pf_batch[end,pf_index] .= λ_pf)
            λ_pf_batch[:,:] .= λ_pf_batch[end:end,:]
            # @show size(λ_pf_batch), λ_pf_batch[1, :]
            # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)

            if isnothing(pretrained_NO)
                push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))
            end

            x = copy(y_batch)
            y_init = copy(safe_batch)
            x = vcat(x[1,:,:]...)
            x = reshape(x, (1, size(x)[1]))
            y_init = vcat(y_init...)

            U_0 = copy(x_batch)
            U_0[2:2,:,:] .= x_batch[2:2,1:1,:]
            U_0 = vcat(U_0[2:2,:,:][1,:,:]...)
            U_0 = reshape(U_0, (1, size(U_0)[1]))
            U̇ = find_derivative(x_batch)
            extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)
            T = x_batch[1,end,1]
            # NOTE(review): ∇ϕ is unused here as well (see the training loop).
            _, ∇ϕ = Zygote.pullback(model_NO, x_batch)
            # dG\du * du\dt + dG\dt
            # ∇Y_t = ∇ϕ(ones(size(Y)))[1][2:2, :,:] .* U̇ .+ ∇ϕ(ones(size(Y)))[1][1:1, :,:] # rewrite it below
            # ∇Y_t = batched_mul(reshape(∇ϕ(ones(size(y_batch)))[1], (1, size(∇ϕ(ones(size(y_batch)))[1])...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))
            # ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], model_NO(x_batch), dims=1)) # empirical derivative
            ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], y_batch, dims=1)) # empirical derivative

            loss = loss_naive_safeset(model_CBF, x, y_init)  +  λ_reg .* loss_regularization(model_CBF, x, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α;λ_pf_batch=λ_pf_batch) + loss_naive_safeset_end(model_CBF, x, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(model_CBF, x, y_init;minus_safe=minus_safe_flag)
            @show loss_naive_safeset(model_CBF, x, y_init)  , loss_regularization(model_CBF, x, y_init) , loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α;λ_pf_batch=λ_pf_batch), loss_naive_safeset_end(model_CBF, x, y_init;minus_safe=minus_safe_flag)  , loss_regularization_end(model_CBF, x, y_init;minus_safe=minus_safe_flag)
            # loss = loss_naive_safeset(model_CBF, x, y_init)  +  λ_reg .* loss_regularization(model_CBF, x, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α)
            # @show loss_naive_safeset(model_CBF, x, y_init)  , loss_regularization(model_CBF, x, y_init) , loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α)
            push!(test_loss_epoch, loss)  # logging, outside gradient context
        end
        # Once per epoch: pull the next learning rate from the schedule and apply it
        # to the CBF optimiser state.
        nextlr = ParameterSchedulers.next!(sched_CBF) # advance schedule
        Optimisers.adjust!(optim_CBF, nextlr) # update optimizer state, by default this changes the learning rate `eta`

        # @show epoch, loss, test_loss
        # model_state = Flux.state(model)
        # jldsave("car_wd0.0001_naive_model_1_0_0.1_pgd_relu_$epoch.jld2"; model_state)
        # Per-epoch checkpoints; the "model" directory is assumed to exist already.
        if isnothing(pretrained_NO)
            @save "model/new_hyper_NO_$epoch.bson" model_NO
        end
        # @save "model/hyper_1reg_1pf_CBFnoNOfixed_pf52_addend_preNO20_alldata_$epoch.bson" model_CBF
        @save "model/hyper_0.1reg_1pf_time_CBFnoT_pf12_addsafe__le1safe_$epoch.bson" model_CBF
        push!(training_losses, sum(training_loss_epoch)) 
        push!(test_losses, sum(test_loss_epoch))
    
    end
    return training_losses, test_losses
end

"""
    get_model(name)

Load the BSON file `name` from the `model` directory next to this source file and
return what `BSON.load` produces (callers in this file index the result with a symbol
key such as `:model_NO`).

# Throws
- `ArgumentError` if no such file exists. An explicit check is used rather than
  `@assert`, which is not meant for validating input.
"""
function get_model(name)
    model_path = joinpath(@__DIR__, "./model/")
    model_file = joinpath(model_path, name)
    isfile(model_file) ||
        throw(ArgumentError("model file not found: $(model_file)"))
    return BSON.load(model_file, @__MODULE__)
end



"""
    train_nomad(; n=50000, cuda=true, learning_rate=0.001, epochs=400)

Train ten `NOMAD` models on the data from `get_data_don()` and save each one to
`model/hyper_NOMAD_<i>.bson`.

The first 149000 rows are used for training and the remaining rows for validation.
Every iteration builds a fresh model and fresh random sensor grids, and multiplies the
learning rate by 0.2 before creating the optimiser, so model `i` is trained with
`learning_rate * 0.2^i`.

# Keywords
- `n`: currently unused; `get_data_don()` is called with its own defaults.
- `cuda`: use the GPU when `has_cuda()` is true.
- `learning_rate`: base learning rate (decayed as described above).
- `epochs`: number of passes given to `Flux.@epochs`.

Returns `nothing`.
"""
function train_nomad(; n = 50000, cuda = true, learning_rate = 0.001, epochs = 400)
    if cuda && has_cuda()
        @info "Training on GPU"
        device = gpu
    else
        @info "Training on CPU"
        device = cpu
    end

    # The dataset and its split do not change between iterations: load them once
    # instead of re-reading the MAT files ten times.
    x, y = get_data_don()

    xtrain = x[1:149000, :]'
    ytrain = y[1:149000, :]

    xval = x[149001:end, :]' |> device
    yval = y[149001:end, :] |> device

    # Make sure the checkpoint directory exists before any training time is spent.
    mkpath("model")

    lr = learning_rate
    for i = 1:10
        grid = rand(collect(0:0.1:5), (149000, 51)) |> device
        gridval = rand(collect(0:0.1:5), (1000, 51)) |> device

        lr = lr * 0.2
        opt = Flux.Adam(lr)

        m = NOMAD((51, 51), (102, 51), gelu, gelu) |> device

        loss(X, y, sensor) = Flux.Losses.mse(m(X, sensor), y)
        evalcb() = @show(loss(xval, yval, gridval))

        data = [(xtrain, ytrain, grid)] |> device
        Flux.@epochs epochs Flux.train!(loss, params(m), data, opt, cb = evalcb)
        ỹ = m(xval |> device, gridval |> device)

        # Save a CPU copy (as `train` does) so the file loads without a GPU; the
        # stored key stays `m`.
        let m = cpu(m)
            @save "model/hyper_NOMAD_$i.bson" m
        end

        diffvec = vec(abs.(cpu(yval) .- cpu(ỹ)))
        mean_diff = sum(diffvec) / length(diffvec)
        @info "NOMAD validation mean absolute error" i mean_diff
    end
    return nothing
end

"""
    get_data_don(; n=50000, Δsamples=1, grid_size=div(51, Δsamples), T=Float32)

Read the `"a"` and `"u"` variables from the three hard-coded MAT files and return
`(x_data, y_data)`, each the row-wise concatenation (file order: bcks, ppo, sac) of
the first `n` rows and every `Δsamples`-th column of the corresponding variable.

All three files are converted to element type `T`, so the result has element type `T`.
`grid_size` is accepted for signature compatibility but is not used.
"""
function get_data_don(; n = 50000, Δsamples = 1, grid_size = div(51, Δsamples), T = Float32)
    paths = ("/data_bcks_hyperbolic.mat",
             "/data_ppo_hyperbolic.mat",
             "/data_sac_hyperbolic.mat")

    parts = map(paths) do path
        file = matopen(path)
        try
            a = T.(read(file, "a")[1:n, 1:Δsamples:end])
            u = T.(read(file, "u")[1:n, 1:Δsamples:end])
            (a, u)
        finally
            # Release the handle even if a variable is missing or too short.
            close(file)
        end
    end

    x_data = vcat(map(first, parts)...)
    y_data = vcat(map(last, parts)...)

    return x_data, y_data
end

"""
    train_don(; n=50000, cuda=true, learning_rate=0.001, epochs=400)

Train ten `DeepONet` models on the data from `get_data_don()` and save each one to
`model/hyper_DON_<i>.bson`.

The first 149000 rows are used for training and the remaining rows for validation; the
sensor grid is 51 evenly spaced points on `[0, 5]`. Every iteration builds a fresh
model and multiplies the learning rate by 0.2 before creating the optimiser, so model
`i` is trained with `learning_rate * 0.2^i`.

# Keywords
- `n`: currently unused; `get_data_don()` is called with its own defaults.
- `cuda`: use the GPU when `has_cuda()` is true.
- `learning_rate`: base learning rate (decayed as described above).
- `epochs`: number of passes given to `Flux.@epochs`.

Returns `nothing`.
"""
function train_don(; n = 50000, cuda = true, learning_rate = 0.001, epochs = 400)
    if cuda && has_cuda()
        @info "Training on GPU"
        device = gpu
    else
        @info "Training on CPU"
        device = cpu
    end

    # Dataset, split and sensor grid are identical for every model: build them once
    # instead of re-reading the MAT files ten times.
    x, y = get_data_don()

    xtrain = x[1:149000, :]'
    ytrain = y[1:149000, :]

    xval = x[149001:end, :]' |> device
    yval = y[149001:end, :] |> device

    grid = collect(range(0, 5, length=51)') |> device
    data = [(xtrain, ytrain, grid)] |> device

    # Make sure the checkpoint directory exists before any training time is spent.
    mkpath("model")

    lr = learning_rate
    for i = 1:10
        lr = lr * 0.2
        opt = Flux.Adam(lr)

        m = DeepONet((51, 51, 51), (1, 51, 51), gelu, gelu) |> device

        loss(X, y, sensor) = Flux.Losses.mse(m(X, sensor), y)
        evalcb() = @show(loss(xval, yval, grid))

        Flux.@epochs epochs Flux.train!(loss, params(m), data, opt, cb = evalcb)
        ỹ = m(xval |> device, grid |> device)

        diffvec = vec(abs.(cpu(yval) .- cpu(ỹ)))
        mean_diff = sum(diffvec) / length(diffvec)
        @info "DeepONet validation mean absolute error" i mean_diff

        # Save a CPU copy (as `train` does) so the file loads without a GPU; the
        # stored key stays `m`.
        let m = cpu(m)
            @save "model/hyper_DON_$i.bson" m
        end
    end
    return nothing
end

# Script driver: the uncommented calls below run whenever this file is executed or
# included. Currently this trains the FNO and then the MNO for 100 epochs each.
# The commented-out calls are alternative experiments; the trailing numbers look like
# manually recorded results (NOTE(review): unverified, meaning not stated here).
# train(epochs=100) 
# training_losses, test_losses = my_train(total_epoch=20,pretrained_NO="hyper_NO_20.bson") # 0.005 # pf all , 0.012, T=1
train(epochs=100) 
train_MNO(epochs=100) 
# train_nomad(epochs=500) # 1.76 # 10-3 4.7
# train_don(epochs=500) # 2.56 10-3 5.536155f0
# model/hyper_DON_500-4.bson 7.2


