{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf555e98-8e13-4f75-8272-4984ca0d448b",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Revise\n",
    "using Burgers#, Plots\n",
    "using DataDeps, MAT, MLUtils\n",
    "using NeuralOperators, Flux\n",
    "using BSON\n",
    "using DataDeps, MAT, MLUtils\n",
    "using NeuralOperators, Flux\n",
    "using CUDA, FluxTraining, BSON\n",
    "import Flux: params\n",
    "using BSON: @save, @load\n",
    "using ProgressBars\n",
    "using Zygote\n",
    "using Optimisers, ParameterSchedulers\n",
    "\n",
    "using Burgers\n",
    "using FluxTraining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38d0bc45-89b3-432f-a24d-85cb2efc2558",
   "metadata": {},
   "outputs": [],
   "source": [
    "# my_get_data(file_path; n, Δsamples, grid_size, T, t_max)\n",
    "#\n",
    "# Read the first `n` rows of the MAT variables `a`, `u`, `safe` and `pf` from `file_path`,\n",
    "# keep every `Δsamples`-th column, convert to element type `T` and put the sample index\n",
    "# last (each variable is indexed as (sample, column) in the file).\n",
    "#\n",
    "# Returns `(x_loc, y, safe, pf)`:\n",
    "#   x_loc    :: (2, grid_size, n)  row 1 = uniform grid on [0, t_max], row 2 = `a`\n",
    "#   y        :: (1, grid_size, n)  = `u`\n",
    "#   safe, pf :: (grid_size, n)\n",
    "# `grid_size` defaults to the number of retained columns. The former default\n",
    "# div(51, Δsamples) undercounted whenever Δsamples does not divide 51 (1:Δ:51 has\n",
    "# cld(51, Δ) entries), which made the assignment of row 2 below throw.\n",
    "# The file is closed even if one of the reads fails.\n",
    "function my_get_data(file_path; n = 50000, Δsamples = 1, grid_size = nothing, T = Float32, t_max = 5)\n",
    "    file = matopen(file_path)\n",
    "    x_data, y_data, safe_labels, pf_labels = try\n",
    "        read_field = key -> T.(permutedims(read(file, key)[1:n, 1:Δsamples:end]))\n",
    "        (read_field(\"a\"), read_field(\"u\"), read_field(\"safe\"), read_field(\"pf\"))\n",
    "    finally\n",
    "        close(file)\n",
    "    end\n",
    "\n",
    "    grid_size = something(grid_size, size(x_data, 1))\n",
    "    x_loc_data = Array{T, 3}(undef, 2, grid_size, n)\n",
    "    x_loc_data[1, :, :] .= LinRange(0, t_max, grid_size)  # same grid for every sample\n",
    "    x_loc_data[2, :, :] .= x_data\n",
    "\n",
    "    return x_loc_data, reshape(y_data, 1, :, n), safe_labels, pf_labels\n",
    "end\n",
    "\n",
    "# Keep only the samples whose pf flag (row 1 of the 4th array) equals 1 and drop the pf\n",
    "# array: `(x, y, safe, pf)` -> `(x, y, safe)`, samples along the last dimension.\n",
    "function _pf_subset(split)\n",
    "    x, y, safe, pf = split\n",
    "    keep = pf[1, :] .== 1\n",
    "    return (x[:, :, keep], y[:, :, keep], safe[:, keep])\n",
    "end\n",
    "\n",
    "# my_get_dataloader(; ratio = 0.9, batchsize = 128, files = (...))\n",
    "#\n",
    "# Load each file in `files` with my_get_data and split it into train/test at `ratio`\n",
    "# (per file, so every source contributes to both sets). Keep only pf == 1 samples,\n",
    "# concatenate the files along the sample dimension and wrap the result in DataLoaders\n",
    "# (train shuffled, test not). Each batch is a tuple `(x, y, safe)`.\n",
    "function my_get_dataloader(; ratio::Float64 = 0.9, batchsize = 128,\n",
    "                           files = (\"data_bcks_hyperbolic_1_minus.mat\",\n",
    "                                    \"data_ppo_hyperbolic_1_minus.mat\",\n",
    "                                    \"data_sac_hyperbolic_1_minus.mat\"))\n",
    "    isempty(files) && throw(ArgumentError(\"`files` must contain at least one path\"))\n",
    "\n",
    "    splits = [splitobs(my_get_data(file), at = ratio) for file in files]\n",
    "    train_parts = [_pf_subset(train) for (train, _) in splits]\n",
    "    test_parts = [_pf_subset(test) for (_, test) in splits]\n",
    "\n",
    "    # Concatenate the per-file (x, y, safe) tuples along their sample dimension.\n",
    "    merge_parts(parts) = (cat((p[1] for p in parts)...; dims = 3),\n",
    "                          cat((p[2] for p in parts)...; dims = 3),\n",
    "                          cat((p[3] for p in parts)...; dims = 2))\n",
    "    data_train = merge_parts(train_parts)\n",
    "    data_test = merge_parts(test_parts)\n",
    "\n",
    "    loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)\n",
    "    loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false)\n",
    "    return loader_train, loader_test\n",
    "end\n",
    "# delete_with_probability!(list, p = 0.5)\n",
    "#\n",
    "# Return a random, order-preserving subset of `list`. An element is KEPT when its uniform\n",
    "# draw is < p, so it survives with probability p; despite the name, `p` acts as a keep\n",
    "# probability. `list` itself is not modified.\n",
    "function delete_with_probability!(list, p = 0.5)\n",
    "    draws = rand(length(list))\n",
    "    kept_positions = findall(draws .< p)\n",
    "    return list[kept_positions]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a5135bc-4000-4f86-bf28-f9a56a9ded88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss_naive_safeset(ϕ, x, y_init)\n",
    "#\n",
    "# Hinge loss on the columns of `x` labelled 0 in `y_init` (labels: safe = 1, unsafe = 0).\n",
    "# For label 0 the sign factor 2y - 1 is -1, so each term is relu(1e-6 - ϕ(x_j)). A term is\n",
    "# zero once ϕ(x_j) >= 1e-6. Returns the mean over those columns, or 0 if there are none.\n",
    "#   ϕ      : model whose output row 1 holds one score per column of its input\n",
    "#   y_init : vector of 0/1 labels, one per column of `x`\n",
    "function loss_naive_safeset(ϕ, x, y_init)\n",
    "    index = findall(iszero, y_init)\n",
    "    isempty(index) && return 0\n",
    "    x = x[:, index]\n",
    "    y_init = y_init[index]\n",
    "    # Activations are scalar functions; broadcast so relu acts elementwise.\n",
    "    loss = relu.((2 .* y_init .- 1) .* ϕ(x)[1, :] .+ 1e-6)\n",
    "    return sum(loss) / length(loss)\n",
    "end\n",
    "\n",
    "# loss_regularization(ϕ, x, y_init)\n",
    "#\n",
    "# Smooth counterpart of loss_naive_safeset. It averages sigmoid((2y - 1) * ϕ(x_j)) over\n",
    "# the columns labelled 0 (safe = 1, unsafe = 0), i.e. mean(sigmoid(-ϕ(x_j))), which\n",
    "# shrinks as ϕ grows on those columns. Returns 0 if no column is labelled 0.\n",
    "function loss_regularization(ϕ::Chain, x::AbstractArray, y_init::AbstractArray)\n",
    "    index = findall(iszero, y_init)\n",
    "    isempty(index) && return 0\n",
    "    x = x[:, index]\n",
    "    y_init = y_init[index]\n",
    "    # sigmoid_fast is a scalar activation; broadcast it over the vector of margins.\n",
    "    loss = sigmoid_fast.((2 .* y_init .- 1) .* ϕ(x)[1, :])\n",
    "    return sum(loss) / length(loss)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6234f0c-bafc-4a91-b737-61ab953af6cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss_naive_safeset_end(ϕ, x, y_init; minus_safe = false)\n",
    "#\n",
    "# minus_safe = true : hinge relu(ϕ(x_j) + 1e-6) averaged over the columns labelled 1\n",
    "#                     (safe), pushing ϕ <= -1e-6 there; 0 if there are none.\n",
    "# minus_safe = false: the same (2y - 1)-signed hinge, for the LAST column of `x` only.\n",
    "# NOTE(review): for a flattened (t, y) batch this is the final time step of the final\n",
    "# trajectory only, not the end point of every trajectory. Confirm this is intended.\n",
    "function loss_naive_safeset_end(ϕ, x, y_init; minus_safe=false)\n",
    "    if minus_safe\n",
    "        index = findall(isone, y_init)\n",
    "        isempty(index) && return 0\n",
    "        x = x[:, index]\n",
    "        y_init = y_init[index]\n",
    "        loss = relu.((2 .* y_init .- 1) .* ϕ(x)[1, :] .+ 1e-6)  # elementwise hinge\n",
    "        return sum(loss) / length(loss)\n",
    "    end\n",
    "    # Only the last column is needed, so evaluate ϕ on that column alone. This assumes ϕ\n",
    "    # treats columns independently, which holds for a Chain of Dense layers.\n",
    "    ϕ_last = ϕ(x[:, end:end])[1, 1]\n",
    "    return relu((2 * y_init[end] - 1) * ϕ_last + 1e-6)\n",
    "end\n",
    "\n",
    "# loss_regularization_end(ϕ, x, y_init; minus_safe = false)\n",
    "#\n",
    "# Smooth counterpart of loss_naive_safeset_end:\n",
    "# minus_safe = true : mean of sigmoid(ϕ(x_j)) over the columns labelled 1 (safe); 0 if none.\n",
    "# minus_safe = false: sigmoid((2y - 1) * ϕ) for the LAST column of `x` only.\n",
    "# NOTE(review): for a flattened (t, y) batch that is a single sample of the whole batch.\n",
    "function loss_regularization_end(ϕ::Chain, x::AbstractArray,y_init::AbstractArray;minus_safe=false)\n",
    "    if minus_safe\n",
    "        index = findall(isone, y_init)\n",
    "        isempty(index) && return 0\n",
    "        x = x[:, index]\n",
    "        y_init = y_init[index]\n",
    "        # sigmoid_fast is a scalar activation; broadcast it over the vector of margins.\n",
    "        loss = sigmoid_fast.((2 .* y_init .- 1) .* ϕ(x)[1, :])\n",
    "        return sum(loss) / length(loss)\n",
    "    end\n",
    "    # Only the last column is needed, so evaluate ϕ on that column alone. This assumes ϕ\n",
    "    # treats columns independently, which holds for a Chain of Dense layers.\n",
    "    ϕ_last = ϕ(x[:, end:end])[1, 1]\n",
    "    return sigmoid_fast((2 * y_init[end] - 1) * ϕ_last)\n",
    "end\n",
    "\n",
    "# find_derivative(vector)\n",
    "#\n",
    "# `vector` is a (2, M, N) array: row 1 holds the abscissa t and row 2 the values f.\n",
    "# Returns the (1, M, N) Float64 array of df/dt along dimension 2:\n",
    "#   - central differences at the interior points 2..M-1\n",
    "#   - a forward difference at point 1 and a backward difference at point M.\n",
    "function find_derivative(vector)\n",
    "    dims = size(vector)\n",
    "    M, N = dims[2], dims[3]\n",
    "    t = vector[1, :, :]   # (M, N)\n",
    "    f = vector[2, :, :]   # (M, N)\n",
    "    dfdt = zeros(Float64, 1, M, N)\n",
    "\n",
    "    # boundary points: one-sided differences\n",
    "    dfdt[1, 1, :] = (f[2, :] .- f[1, :]) ./ (t[2, :] .- t[1, :])\n",
    "    dfdt[1, M, :] = (f[M, :] .- f[M-1, :]) ./ (t[M, :] .- t[M-1, :])\n",
    "\n",
    "    # interior points: central differences over [i-1, i+1]\n",
    "    ahead, behind = 3:M, 1:M-2\n",
    "    dfdt[1, 2:M-1, :] = (f[ahead, :] .- f[behind, :]) ./ (t[ahead, :] .- t[behind, :])\n",
    "\n",
    "    return dfdt\n",
    "end\n",
    "\n",
    "# find_derivative_1step(vector)\n",
    "#\n",
    "# Variant of find_derivative that uses the one-step forward difference\n",
    "# (f[i+1] - f[i]) / (t[i+1] - t[i]) at points 1..M-1 and the backward difference at\n",
    "# point M. Input (2, M, N) with row 1 = t and row 2 = f; output (1, M, N) Float64.\n",
    "function find_derivative_1step(vector)\n",
    "    dims = size(vector)\n",
    "    M, N = dims[2], dims[3]\n",
    "    t = vector[1, :, :]\n",
    "    f = vector[2, :, :]\n",
    "    slope = zeros(Float64, 1, M, N)\n",
    "    slope[1, 1:M-1, :] = diff(f; dims = 1) ./ diff(t; dims = 1)\n",
    "    slope[1, M, :] = (f[M, :] .- f[M-1, :]) ./ (t[M, :] .- t[M-1, :])\n",
    "    return slope\n",
    "end\n",
    "\n",
    "\n",
    "\n",
    "# loss_pf(ϕ, U, Yt, U_0, extended_U̇, ∇Y_t, T, α, y_init; all, ϵ, keep_prob, C)\n",
    "#\n",
    "# Hinge penalty on the CBF decrease condition along trajectories:\n",
    "#     mean_j relu(ϕ̇_j + α ϕ(Y_j) + C ϕ(U_0,j) + 1e-6),   ϕ̇_j = ∇ϕ(Y_j) ⋅ dY_j/dt\n",
    "#   ϕ      : CBF model; maps the (2, N) columns of `Yt` to a (1, N) output\n",
    "#   U, extended_U̇, y_init : unused; kept so existing calls keep working\n",
    "#   Yt     : (2, N) evaluation points, columns (t, y); row 2 drives the |y| < ϵ filter\n",
    "#   U_0    : (2, N) points for the C-term (only evaluated when C ≠ 0)\n",
    "#   ∇Y_t   : 2N numbers reshaped to size(Yt); column j must be d(Yt[:, j])/dt\n",
    "#   T      : horizon; currently unused (see the C formula below)\n",
    "#   α      : decay rate in the condition above\n",
    "#   all    : if false, keep only columns with |Yt[2, j]| < ϵ, then subsample them with\n",
    "#            delete_with_probability!(index, keep_prob)\n",
    "#   C      : weight of the ϕ(U_0) term. The default 0 reproduces the former hard-coded\n",
    "#            override of the formula C = α e^(-αT) / (1 - e^(-αT)).\n",
    "# Returns 0 when no column is selected.\n",
    "function loss_pf(ϕ::Chain, U::AbstractArray, Yt::AbstractArray, U_0,extended_U̇, ∇Y_t,T, α,y_init; all=false,\n",
    "                 ϵ=0.5, keep_prob=0.2, C=0)\n",
    "    ∇Y_t = reshape(∇Y_t, size(Yt))\n",
    "\n",
    "    if !all\n",
    "        index = findall(abs.(Yt[2, :]) .< ϵ)\n",
    "        index = delete_with_probability!(index, keep_prob)\n",
    "        isempty(index) && return 0\n",
    "        Yt = Yt[:, index]\n",
    "        ∇Y_t = ∇Y_t[:, index]\n",
    "        U_0 = U_0[:, index]\n",
    "    end\n",
    "\n",
    "    # Seeding the pullback with ones of the OUTPUT shape yields ∂ϕ/∂Y for every column.\n",
    "    # This assumes ϕ treats columns independently, which holds for the Dense-only CBF\n",
    "    # model in this notebook.\n",
    "    ϕ_Y = ϕ(Yt)\n",
    "    _, back = Zygote.pullback(ϕ, Yt)\n",
    "    ∇ϕ_Y = back(ones(size(ϕ_Y)))[1]           # same shape as Yt\n",
    "    ϕ̇ = sum(∇ϕ_Y .* ∇Y_t; dims = 1)            # chain rule, (1, N)\n",
    "\n",
    "    l = ϕ̇ .+ α .* ϕ_Y\n",
    "    if !iszero(C)\n",
    "        l = l .+ C .* ϕ(U_0)\n",
    "    end\n",
    "    loss = relu.(l .+ 1e-6)   # activations are scalar functions: broadcast\n",
    "    return sum(loss) / size(loss)[end]\n",
    "end\n",
    "\n",
    "# get_model(name)\n",
    "#\n",
    "# Load `model/<name>` (relative to this notebook's directory) with BSON.load, resolving\n",
    "# stored types in the current module. Returns the loaded collection indexed by saved\n",
    "# variable name (e.g. `[:model_NO]`). If the file does not exist, throws an ArgumentError\n",
    "# listing the available files; unlike @assert, this check cannot be disabled.\n",
    "function get_model(name)\n",
    "    model_dir = joinpath(@__DIR__, \"model\")\n",
    "    model_path = joinpath(model_dir, name)\n",
    "    if !isfile(model_path)\n",
    "        available = isdir(model_dir) ? join(readdir(model_dir), \", \") : \"(directory missing)\"\n",
    "        throw(ArgumentError(\"model file $(repr(name)) not found in $model_dir; available: $available\"))\n",
    "    end\n",
    "    return BSON.load(model_path, @__MODULE__)\n",
    "end\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36a2dff0-ab9b-4f60-a32f-998af53009a3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a80a615-18d5-4383-91cd-c9c642087d2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Configuration ----\n",
    "cuda = true\n",
    "η₀ = 1.0f-3      # NO learning rate (AdamW)\n",
    "λ = 1.0f-4       # NO weight decay (AdamW)\n",
    "total_epoch = 20\n",
    "pretrained_NO = \"hyper_NO_20.bson\"   # set to `nothing` to train the NO from scratch\n",
    "if cuda && CUDA.has_cuda()\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "# NOTE(review): `device` is never applied to the models or batches below, so everything\n",
    "# runs on the CPU whichever message is printed above.\n",
    "lr_NO = η₀\n",
    "lr_CBF = 0.01\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch = 4\n",
    "\n",
    "α = 0.00001\n",
    "λ_pf = 1\n",
    "λ_reg = 1\n",
    "all_flag = false         # passed to loss_pf as `all`\n",
    "minus_safe_flag = false  # passed to the *_end losses as `minus_safe`\n",
    "verbose = false          # print the five loss terms after every training step\n",
    "\n",
    "# ---- Data and models ----\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "model_NO = if isnothing(pretrained_NO)\n",
    "    FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), σ = gelu)\n",
    "else\n",
    "    get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "    Dense(2 => 16, relu),\n",
    "    Dense(16 => 64, relu),\n",
    "    Dense(64 => 16, relu),\n",
    "    Dense(16 => 1),\n",
    ")\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "# NOTE(review): the third positional argument of NADAM is its ϵ (denominator\n",
    "# stabiliser), not a weight decay; 0.1 is unusually large. Confirm it is intended.\n",
    "optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch))\n",
    "\n",
    "# ---- Per-batch helpers ----\n",
    "# Flatten one batch into the column layout the CBF losses expect. With x_batch (2, M, B)\n",
    "# (row 1 = time grid), y_batch (1, M, B) and safe_batch (M, B), column k of every\n",
    "# flattened matrix refers to the same (time step, sample) pair.\n",
    "function cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "    M, B = size(y_batch, 2), size(y_batch, 3)\n",
    "    yt = cat(x_batch[1:1, :, :], y_batch, dims = 1)              # (2, M, B): (t, y)\n",
    "    ytt = reshape(yt, size(yt, 1), M * B)\n",
    "    y_init = collect(vec(safe_batch))                          # one label per column of ytt\n",
    "    # (t, input at the first time step); only used by loss_pf's C-term\n",
    "    u_0 = repeat(x_batch[2:2, 1:1, :], 1, M, 1)\n",
    "    U_0tt = reshape(cat(x_batch[1:1, :, :], u_0, dims = 1), 2, M * B)\n",
    "    U̇ = find_derivative(x_batch)\n",
    "    extended_U̇ = cat(ones(size(U̇)), U̇, dims = 1)\n",
    "    # Empirical dY/dt from the data, stacked under dt/dt = 1 on the FIRST axis so that\n",
    "    # column k of the flattened result pairs with column k of ytt.\n",
    "    ∇Y_t = find_derivative(yt)                                   # (1, M, B)\n",
    "    extended_∇Y_tt = reshape(cat(ones(size(∇Y_t)), ∇Y_t, dims = 1), 2, 1, M * B)\n",
    "    return (x_batch = x_batch, ytt = ytt, y_init = y_init, U_0tt = U_0tt,\n",
    "            extended_U̇ = extended_U̇, extended_∇Y_tt = extended_∇Y_tt, T = x_batch[1, end, 1])\n",
    "end\n",
    "\n",
    "# The five unweighted CBF loss terms for batch `b`, in the order\n",
    "# (safeset, regularization, pf, safeset_end, regularization_end).\n",
    "function cbf_loss_terms(m, b; α, all_flag, minus_safe_flag)\n",
    "    return (loss_naive_safeset(m, b.ytt, b.y_init),\n",
    "            loss_regularization(m, b.ytt, b.y_init),\n",
    "            loss_pf(m, b.x_batch, b.ytt, b.U_0tt, b.extended_U̇, b.extended_∇Y_tt, b.T, α, b.y_init; all = all_flag),\n",
    "            loss_naive_safeset_end(m, b.ytt, b.y_init; minus_safe = minus_safe_flag),\n",
    "            loss_regularization_end(m, b.ytt, b.y_init; minus_safe = minus_safe_flag))\n",
    "end\n",
    "\n",
    "# Weighted CBF objective: regularization terms are scaled by λ_reg, the pf term by λ_pf.\n",
    "function cbf_total_loss(m, b; λ_reg, λ_pf, kw...)\n",
    "    safeset, reg, pf, safeset_end, reg_end = cbf_loss_terms(m, b; kw...)\n",
    "    return safeset + λ_reg * reg + λ_pf * pf + safeset_end + λ_reg * reg_end\n",
    "end\n",
    "\n",
    "# Average of per-batch losses (each already a batch mean); NaN if no batch was seen.\n",
    "batch_mean(v) = isempty(v) ? NaN : sum(v) / length(v)\n",
    "\n",
    "# ---- Training ----\n",
    "loss_kw = (α = α, all_flag = all_flag, minus_safe_flag = minus_safe_flag)\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "    for (x_batch, y_batch, safe_batch) in train_loader\n",
    "        # shapes: (2, 51, bs), (1, 51, bs), (51, bs)\n",
    "        if isnothing(pretrained_NO)\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m\n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        b = cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m\n",
    "            cbf_total_loss(m, b; λ_reg = λ_reg, λ_pf = λ_pf, loss_kw...)\n",
    "        end\n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "        verbose && @show cbf_loss_terms(model_CBF, b; loss_kw...)\n",
    "        # Log the loss from the gradient pass (pre-update) instead of recomputing every term.\n",
    "        push!(training_loss_epoch, CBF_training_loss)\n",
    "    end\n",
    "\n",
    "    for (x_batch, y_batch, safe_batch) in test_loader\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "        b = cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        push!(test_loss_epoch, cbf_total_loss(model_CBF, b; λ_reg = λ_reg, λ_pf = λ_pf, loss_kw...))\n",
    "    end\n",
    "\n",
    "    Optimisers.adjust!(optim_CBF, ParameterSchedulers.next!(sched_CBF))  # advance the CBF LR schedule\n",
    "\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "        push!(no_training_losses, batch_mean(no_training_loss_epoch))\n",
    "        push!(no_test_losses, batch_mean(no_test_loss_epoch))\n",
    "    end\n",
    "    @save \"model/hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_C0_$epoch.bson\" model_CBF\n",
    "    # Each logged value is already a per-batch mean, so average over batches.\n",
    "    push!(training_losses, batch_mean(training_loss_epoch))\n",
    "    push!(test_losses, batch_mean(test_loss_epoch))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ae6611b-1ffd-4d96-9680-bf141330510d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CBF loss history per epoch: (train, test).\n",
    "@show (training_losses, test_losses)\n",
    "# Sample per-term output from an earlier run (the original note was truncated):\n",
    "# (safeset, regularization, pf, safeset_end, regularization_end) =\n",
    "#   (0.061774470473719485, 0.18670177f0, 0.11567616420201923, 0.0, 0.090803966f…)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0d5d17d-50e6-450b-a1ad-961ba1959131",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_1.bson\")[:model_CBF]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8cd458a-3e94-4416-b30c-00b4c541a9ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_CBF([3, 0.1])# t,Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3b00453-3e2b-465c-beb3-dc9bfef5c5a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function my_train(; )\n",
    "cuda = true\n",
    "η₀ = 1.0f-3\n",
    "λ = 1.0f-4\n",
    "total_epoch = 20\n",
    "pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=nothing\n",
    "if cuda && CUDA.has_cuda()\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "@show 1\n",
    "lr_NO = η₀\n",
    "lr_CBF = 0.001 # NO CBF\n",
    "lr_CBF = 0.01\n",
    "\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch =4\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (2,), \n",
    "#                               σ = gelu)\n",
    "# optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "# optim_CBF = Flux.setup(Flux.Optimise.AdamW(lr_CBF, (0.9, 0.999), 0), model_CBF) # NO\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch)) # setup schedule of your choice\n",
    "\n",
    "\n",
    "loss_func = l₂loss\n",
    "α = 0.00001\n",
    "λ_pf = 1\n",
    "λ_reg = 0.1\n",
    "all_flag = false\n",
    "minus_safe_flag = false\n",
    "\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "least_loss = 1000\n",
    "test_loss = 0\n",
    "loss = 0\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "    for item in train_loader\n",
    "        # x_batch = reduce(hcat,item[1,:])\n",
    "        # u_batch = reduce(hcat,item[2,:])\n",
    "        # y_init_batch = reduce(hcat,item[3,:])\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            # train NO\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m \n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "            # @show l₂loss(model_NO(x_batch), y_batch)\n",
    "        end\n",
    "\n",
    "        \n",
    "        # train CBF\n",
    "        x = copy(y_batch)\n",
    "        y_init = copy(safe_batch)\n",
    "        x = vcat(x[1,:,:]...)\n",
    "        x = reshape(x, (1, size(x)[1]))\n",
    "        y_init = vcat(y_init...)\n",
    "\n",
    "        U_0 = copy(x_batch)\n",
    "        U_0[2:2,:,:] .= x_batch[2:2,1:1,:]\n",
    "        U_0 = vcat(U_0[2:2,:,:][1,:,:]...)\n",
    "        U_0 = reshape(U_0, (1, size(U_0)[1]))\n",
    "        U̇ = find_derivative(x_batch)\n",
    "        extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)\n",
    "        T = x_batch[1,end,1]\n",
    "        _, ∇ϕ = Zygote.pullback(model_NO, x_batch)\n",
    "        # dG\\du * du\\dt + dG\\dt\n",
    "        # ∇Y_t = ∇ϕ(ones(size(Y)))[1][2:2, :,:] .* U̇ .+ ∇ϕ(ones(size(Y)))[1][1:1, :,:] # rewrite it below\n",
    "        # ∇Y_t = batched_mul(reshape(∇ϕ(ones(size(y_batch)))[1], (1, size(∇ϕ(ones(size(y_batch)))[1])...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))\n",
    "        # ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], model_NO(x_batch), dims=1)) # empirical derivative, not working... sigmoid 0.5\n",
    "        ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], y_batch, dims=1)) # empirical derivative\n",
    "\n",
    "        # @show size(x_batch), size(∇Y_t[1,:,:,:]), size(x_batch[1:1,:,:]), size(reshape(U_0, size(y_batch)))\n",
    "        yt = cat(x_batch[1:1,:,:], y_batch, dims=1) # NO\n",
    "        ytt = reshape(yt, (size(yt)[1], size(yt)[2]*size(yt)[3]))\n",
    "        extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        U_0t = cat(x_batch[1:1,:,:], reshape(U_0, size(y_batch)), dims=1) # NO\n",
    "        U_0tt = reshape(U_0t, (size(U_0t)[1], size(U_0t)[2]*size(U_0t)[3]))\n",
    "        extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],1, size(extended_∇Y_t)[2:end]...))\n",
    "        extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],size(extended_∇Y_t)[2], size(extended_∇Y_t)[3]*size(extended_∇Y_t)[4]))\n",
    "        # @show size(extended_∇Y_tt), size(yt)\n",
    "        # @show ytt[:, 1:10], model_CBF(ytt)[1, 1:10],model_CBF(ytt[:, 1:10])\n",
    "        # @show model_CBF(ytt)\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m \n",
    "            # loss_naive_safeset_NO(m, yt, y_init)  +  λ_reg .* loss_regularization_NO(m, yt, y_init) + λ_pf .* loss_pf_NO(m, yt, extended_∇Y_t,U_0t,T, α,y_init) # NO\n",
    "            loss_naive_safeset(m, ytt, y_init)  +  λ_reg .* loss_regularization(m, ytt, y_init) + λ_pf .* loss_pf(m, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag) + loss_naive_safeset_end(m, ytt, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(m, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "            # sum(m(rand(1,12)))\n",
    "        end\n",
    "        \n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "        # @show CBF_training_loss\n",
    "        # loss = CBF_training_loss\n",
    "        # @show loss, loss_naive_safeset_NO(model_CBF, yt, y_init), loss_regularization_NO(model_CBF, yt, y_init)  # CBF_training_loss\n",
    "\n",
    "        loss = loss_naive_safeset(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization(model_CBF, ytt, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag) + loss_naive_safeset_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "        @show loss_naive_safeset(model_CBF, ytt, y_init), loss_regularization(model_CBF, ytt, y_init), loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag), loss_naive_safeset_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag), loss_regularization_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "#         @show size((2 .* y_init_batch[1, :] .- 1)), size(model(x_batch)), size(((2 .* y_init_batch[1, :] .- 1) .* model(x_batch)))\n",
    "#         @show loss,loss_naive_safeset(model, x_batch, y_init_batch), loss_naive_fi(model, A, x_batch, B, u_batch,y_init_batch;use_pgd=use_pgd, α=α,Δ=Δ), loss_regularization(model, x_batch, y_init_batch)\n",
    "        push!(training_loss_epoch, loss)  # logging, outside gradient context\n",
    "        \n",
    "        # @show training_loss\n",
    "    end\n",
    "    for item in test_loader\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        x = copy(y_batch)\n",
    "        y_init = copy(safe_batch)\n",
    "        x = vcat(x[1,:,:]...)\n",
    "        x = reshape(x, (1, size(x)[1]))\n",
    "        y_init = vcat(y_init...)\n",
    "\n",
    "        U_0 = copy(x_batch)\n",
    "        U_0[2:2,:,:] .= x_batch[2:2,1:1,:]\n",
    "        U_0 = vcat(U_0[2:2,:,:][1,:,:]...)\n",
    "        U_0 = reshape(U_0, (1, size(U_0)[1]))\n",
    "        U̇ = find_derivative(x_batch)\n",
    "        extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)\n",
    "        T = x_batch[1,end,1]\n",
    "        _, ∇ϕ = Zygote.pullback(model_NO, x_batch)\n",
    "        # dG\\du * du\\dt + dG\\dt\n",
    "        # ∇Y_t = ∇ϕ(ones(size(Y)))[1][2:2, :,:] .* U̇ .+ ∇ϕ(ones(size(Y)))[1][1:1, :,:] # rewrite it below\n",
    "        # ∇Y_t = batched_mul(reshape(∇ϕ(ones(size(y_batch)))[1], (1, size(∇ϕ(ones(size(y_batch)))[1])...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))\n",
    "        # ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], model_NO(x_batch), dims=1)) # empirical derivative not working... sigmoid 0.5\n",
    "        ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], y_batch, dims=1)) # empirical derivative\n",
    "\n",
    "        yt = cat(x_batch[1:1,:,:], y_batch, dims=1) # NO\n",
    "        ytt = reshape(yt, (size(yt)[1], size(yt)[2]*size(yt)[3]))\n",
    "        extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        U_0t = cat(x_batch[1:1,:,:], reshape(U_0, size(y_batch)), dims=1) # NO\n",
    "        U_0tt = reshape(U_0t, (size(U_0t)[1], size(U_0t)[2]*size(U_0t)[3]))\n",
    "        extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],1, size(extended_∇Y_t)[2:end]...))\n",
    "        extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],size(extended_∇Y_t)[2], size(extended_∇Y_t)[3]*size(extended_∇Y_t)[4]))\n",
    "        \n",
    "        loss = loss_naive_safeset(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization(model_CBF, ytt, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag) + loss_naive_safeset_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "        # @show loss_naive_safeset(model_CBF, ytt, y_init), loss_regularization(model_CBF, ytt, y_init), loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init)\n",
    "\n",
    "        # loss = loss_naive_safeset(model_CBF, x, y_init)  +  λ_reg .* loss_regularization(model_CBF, x, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α,y_init)\n",
    "        # @show loss_naive_safeset(model_CBF, x, y_init)  , loss_regularization(model_CBF, x, y_init) , loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α,y_init)\n",
    "#         @show size((2 .* y_init_batch[1, :] .- 1)), size(model(x_batch)), size(((2 .* y_init_batch[1, :] .- 1) .* model(x_batch)))\n",
    "#         @show loss,loss_naive_safeset(model, x_batch, y_init_batch), loss_naive_fi(model, A, x_batch, B, u_batch,y_init_batch;use_pgd=use_pgd, α=α,Δ=Δ), loss_regularization(model, x_batch, y_init_batch)\n",
    "        push!(test_loss_epoch, loss)  # logging, outside gradient context\n",
    "    end\n",
    "    nextlr = ParameterSchedulers.next!(sched_CBF) # advance schedule\n",
    "    Optimisers.adjust!(optim_CBF, nextlr) # update optimizer state, by default this changes the learning rate `eta`\n",
    "\n",
    "    # @show epoch, loss, test_loss\n",
    "    # model_state = Flux.state(model)\n",
    "    # jldsave(\"car_wd0.0001_naive_model_1_0_0.1_pgd_relu_$epoch.jld2\"; model_state)\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "    end\n",
    "    @save \"model/hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_$epoch.bson\" model_CBF\n",
    "    # @save \"model/hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_$epoch.bson\" model_CBF\n",
    "    push!(training_losses, sum(training_loss_epoch) ./ 45000) \n",
    "    push!(test_losses, sum(test_loss_epoch) ./ 5000)\n",
    "\n",
    "end\n",
    "# return training_losses, test_losses\n",
    "\n",
    "\n",
    "# learner = Learner(model, data, optimiser, loss_func,\n",
    "#                   ToDevice(device, device))\n",
    "\n",
    "# fit!(learner, epochs)\n",
    "# model = learner.model |> cpu\n",
    "# @save \"model/hyper_FNO_all_pf.bson\" model\n",
    "\n",
    "# return learner\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d9d3584-8cd7-40da-80a4-7d86208b20c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training_losses / test_losses are already normalised by the sample count inside the training loop\n",
    "@show training_losses, test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49ebead1-eb0c-4824-ac18-82d69c3d1d6c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6498963e-5f33-4669-9657-8120ec568592",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function my_train(; )\n",
    "cuda = true\n",
    "η₀ = 1.0f-3\n",
    "λ = 1.0f-4\n",
    "total_epoch = 20\n",
    "pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=nothing\n",
    "# NOTE(review): `device` is selected here, but nothing in this cell applies it to the models or batches;\n",
    "# confirm whether my_get_dataloader moves the data, otherwise everything runs on the CPU.\n",
    "if cuda && CUDA.has_cuda()\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "lr_NO = η₀\n",
    "lr_CBF = 0.01 # previously 0.001 (commented as \"NO CBF\")\n",
    "\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch = 4\n",
    "show_loss_terms = false # set to true to @show the five CBF loss terms after every training step\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# Build a fresh NO only when no pretrained checkpoint is given; a loaded NO is never updated below.\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF]\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "# NOTE(review): the third positional argument of Flux.Optimise.NADAM is ϵ (the stabilising epsilon),\n",
    "# not a weight decay, so this runs NADAM with ϵ = 0.1 -- confirm that is intended.\n",
    "optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "# optim_CBF = Flux.setup(Flux.Optimise.AdamW(lr_CBF, (0.9, 0.999), 0), model_CBF) # NO\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch)) # step decay of the CBF learning rate\n",
    "\n",
    "loss_func = l₂loss\n",
    "α = 0.00001\n",
    "λ_pf = 1\n",
    "λ_reg = 1\n",
    "all_flag = true\n",
    "minus_safe_flag = false\n",
    "# weights/flags shared by every CBF loss evaluation below\n",
    "cbf_weights = (; λ_reg, λ_pf, α, all_flag, minus_safe_flag)\n",
    "\n",
    "\"\"\"\n",
    "    _cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "\n",
    "Assemble the flattened arrays the CBF losses take from one dataloader batch.\n",
    "Batch sizes recorded in this notebook: x_batch (2, 51, bs), y_batch (1, 51, bs), safe_batch (51, bs).\n",
    "Row 1 of x_batch is the coordinate grid built in my_get_data (treated as time: T is its last value);\n",
    "it is prepended to the state so the CBF sees 2-row inputs, one column per (grid point, sample).\n",
    "Returns a NamedTuple (ytt, y_init, U_0tt, extended_U̇, extended_∇Y_tt, T).\n",
    "\"\"\"\n",
    "function _cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "    # safe labels flattened column-major (same order as the former vcat(safe_batch...))\n",
    "    y_init = vec(copy(safe_batch))\n",
    "    # U_0: the first input value x_batch[2, 1, :] held constant over the grid, flattened to (1, grid*bs)\n",
    "    U_0 = copy(x_batch)\n",
    "    U_0[2:2, :, :] .= x_batch[2:2, 1:1, :]\n",
    "    U_0 = reshape(U_0[2, :, :], 1, :)\n",
    "    U̇ = find_derivative(x_batch)\n",
    "    extended_U̇ = cat(ones(size(U̇)), U̇, dims=1)\n",
    "    T = x_batch[1, end, 1]\n",
    "    # Empirical dY/dt of the ground-truth trajectory. Earlier commented-out variants used the chain rule\n",
    "    # through Zygote.pullback(model_NO, x_batch) or differences of model_NO(x_batch) (noted as not working).\n",
    "    ∇Y_t = find_derivative(cat(x_batch[1:1, :, :], y_batch, dims=1))\n",
    "    yt = cat(x_batch[1:1, :, :], y_batch, dims=1)\n",
    "    ytt = reshape(yt, size(yt, 1), :)\n",
    "    extended_∇Y_t = cat(ones(size(∇Y_t[1, :, :, :])), ∇Y_t[1, :, :, :], dims=1)\n",
    "    extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t, 1), 1, size(extended_∇Y_t)[2:end]...))\n",
    "    extended_∇Y_tt = reshape(extended_∇Y_t, size(extended_∇Y_t, 1), size(extended_∇Y_t, 2), :)\n",
    "    U_0t = cat(x_batch[1:1, :, :], reshape(U_0, size(y_batch)), dims=1)\n",
    "    U_0tt = reshape(U_0t, size(U_0t, 1), :)\n",
    "    return (; ytt, y_init, U_0tt, extended_U̇, extended_∇Y_tt, T)\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    _cbf_loss_terms(m, x_batch, cbf_in, w)\n",
    "\n",
    "Evaluate the five CBF loss terms for model `m` on `cbf_in` (from `_cbf_batch_inputs`), using the\n",
    "weights/flags in `w` (fields λ_reg, λ_pf, α, all_flag, minus_safe_flag).\n",
    "Returns `(total, terms)` with total = safeset + λ_reg*reg + λ_pf*pf + safeset_end + λ_reg*reg_end\n",
    "and terms = (safeset, reg, pf, safeset_end, reg_end).\n",
    "\"\"\"\n",
    "function _cbf_loss_terms(m, x_batch, cbf_in, w)\n",
    "    safeset = loss_naive_safeset(m, cbf_in.ytt, cbf_in.y_init)\n",
    "    reg = loss_regularization(m, cbf_in.ytt, cbf_in.y_init)\n",
    "    pf = loss_pf(m, x_batch, cbf_in.ytt, cbf_in.U_0tt, cbf_in.extended_U̇, cbf_in.extended_∇Y_tt, cbf_in.T, w.α, cbf_in.y_init; all=w.all_flag)\n",
    "    safeset_end = loss_naive_safeset_end(m, cbf_in.ytt, cbf_in.y_init; minus_safe=w.minus_safe_flag)\n",
    "    reg_end = loss_regularization_end(m, cbf_in.ytt, cbf_in.y_init; minus_safe=w.minus_safe_flag)\n",
    "    total = safeset + w.λ_reg .* reg + w.λ_pf .* pf + safeset_end + w.λ_reg .* reg_end\n",
    "    return total, (safeset, reg, pf, safeset_end, reg_end)\n",
    "end\n",
    "\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "least_loss = 1000\n",
    "test_loss = 0\n",
    "loss = 0\n",
    "mkpath(\"model\") # the @save calls below write into this directory\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "    n_train = 0 # samples seen this epoch; used to normalise the summed batch losses\n",
    "    n_test = 0\n",
    "    for item in train_loader\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        # sizes: (2, 51, bs), (1, 51, bs), (51, bs)\n",
    "        n_train += size(x_batch, ndims(x_batch))\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            # train NO\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m \n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        # train CBF\n",
    "        cbf_in = _cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m \n",
    "            first(_cbf_loss_terms(m, x_batch, cbf_in, cbf_weights))\n",
    "        end\n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "\n",
    "        # log the post-update loss on this batch, outside the gradient context\n",
    "        loss, loss_terms = _cbf_loss_terms(model_CBF, x_batch, cbf_in, cbf_weights)\n",
    "        show_loss_terms && @show loss_terms\n",
    "        push!(training_loss_epoch, loss)\n",
    "    end\n",
    "    for item in test_loader\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        n_test += size(x_batch, ndims(x_batch))\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        cbf_in = _cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        loss = first(_cbf_loss_terms(model_CBF, x_batch, cbf_in, cbf_weights))\n",
    "        push!(test_loss_epoch, loss)\n",
    "    end\n",
    "    nextlr = ParameterSchedulers.next!(sched_CBF) # advance schedule\n",
    "    Optimisers.adjust!(optim_CBF, nextlr) # update optimiser state; by default this changes the learning rate `eta`\n",
    "\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "        push!(no_training_losses, sum(no_training_loss_epoch) ./ n_train)\n",
    "        push!(no_test_losses, sum(no_test_loss_epoch) ./ n_test)\n",
    "    end\n",
    "    @save \"model/hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_bcks_$epoch.bson\" model_CBF\n",
    "    # per-sample averages; reporting cells should not divide by the dataset size again\n",
    "    push!(training_losses, sum(training_loss_epoch) ./ n_train)\n",
    "    push!(test_losses, sum(test_loss_epoch) ./ n_test)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55a69db5-6650-4687-b8a3-be7a9ec5218a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training_losses / test_losses are already normalised by the sample count inside the training loop\n",
    "@show training_losses, test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3525e54b-66eb-4315-8fcd-517cfb794e87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function my_train(; )\n",
    "cuda = true\n",
    "η₀ = 1.0f-3\n",
    "λ = 1.0f-4\n",
    "total_epoch = 20\n",
    "pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=nothing\n",
    "if cuda && CUDA.has_cuda()\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "@show 1\n",
    "lr_NO = η₀\n",
    "lr_CBF = 0.001 # NO CBF\n",
    "lr_CBF = 0.01\n",
    "\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch =4\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (2,), \n",
    "#                               σ = gelu)\n",
    "# optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "# optim_CBF = Flux.setup(Flux.Optimise.AdamW(lr_CBF, (0.9, 0.999), 0), model_CBF) # NO\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch)) # setup schedule of your choice\n",
    "\n",
    "\n",
    "loss_func = l₂loss\n",
    "α = 0.00001\n",
    "λ_pf = 1\n",
    "λ_reg = 0.1\n",
    "all_flag = true\n",
    "minus_safe_flag = false\n",
    "\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "least_loss = 1000\n",
    "test_loss = 0\n",
    "loss = 0\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "    for item in train_loader\n",
    "        # x_batch = reduce(hcat,item[1,:])\n",
    "        # u_batch = reduce(hcat,item[2,:])\n",
    "        # y_init_batch = reduce(hcat,item[3,:])\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            # train NO\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m \n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "            # @show l₂loss(model_NO(x_batch), y_batch)\n",
    "        end\n",
    "\n",
    "        \n",
    "        # train CBF\n",
    "        x = copy(y_batch)\n",
    "        y_init = copy(safe_batch)\n",
    "        x = vcat(x[1,:,:]...)\n",
    "        x = reshape(x, (1, size(x)[1]))\n",
    "        y_init = vcat(y_init...)\n",
    "\n",
    "        U_0 = copy(x_batch)\n",
    "        U_0[2:2,:,:] .= x_batch[2:2,1:1,:]\n",
    "        U_0 = vcat(U_0[2:2,:,:][1,:,:]...)\n",
    "        U_0 = reshape(U_0, (1, size(U_0)[1]))\n",
    "        U̇ = find_derivative(x_batch)\n",
    "        extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)\n",
    "        T = x_batch[1,end,1]\n",
    "        _, ∇ϕ = Zygote.pullback(model_NO, x_batch)\n",
    "        # dG\\du * du\\dt + dG\\dt\n",
    "        # ∇Y_t = ∇ϕ(ones(size(Y)))[1][2:2, :,:] .* U̇ .+ ∇ϕ(ones(size(Y)))[1][1:1, :,:] # rewrite it below\n",
    "        # ∇Y_t = batched_mul(reshape(∇ϕ(ones(size(y_batch)))[1], (1, size(∇ϕ(ones(size(y_batch)))[1])...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))\n",
    "        # ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], model_NO(x_batch), dims=1)) # empirical derivative, not working... sigmoid 0.5\n",
    "        ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], y_batch, dims=1)) # empirical derivative\n",
    "\n",
    "        # @show size(x_batch), size(∇Y_t[1,:,:,:]), size(x_batch[1:1,:,:]), size(reshape(U_0, size(y_batch)))\n",
    "        yt = cat(x_batch[1:1,:,:], y_batch, dims=1) # NO\n",
    "        ytt = reshape(yt, (size(yt)[1], size(yt)[2]*size(yt)[3]))\n",
    "        extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        U_0t = cat(x_batch[1:1,:,:], reshape(U_0, size(y_batch)), dims=1) # NO\n",
    "        U_0tt = reshape(U_0t, (size(U_0t)[1], size(U_0t)[2]*size(U_0t)[3]))\n",
    "        extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],1, size(extended_∇Y_t)[2:end]...))\n",
    "        extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],size(extended_∇Y_t)[2], size(extended_∇Y_t)[3]*size(extended_∇Y_t)[4]))\n",
    "        # @show size(extended_∇Y_tt), size(yt)\n",
    "        # @show ytt[:, 1:10], model_CBF(ytt)[1, 1:10],model_CBF(ytt[:, 1:10])\n",
    "        # @show model_CBF(ytt)\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m \n",
    "            # loss_naive_safeset_NO(m, yt, y_init)  +  λ_reg .* loss_regularization_NO(m, yt, y_init) + λ_pf .* loss_pf_NO(m, yt, extended_∇Y_t,U_0t,T, α,y_init) # NO\n",
    "            loss_naive_safeset(m, ytt, y_init)  +  λ_reg .* loss_regularization(m, ytt, y_init) + λ_pf .* loss_pf(m, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag) + loss_naive_safeset_end(m, ytt, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(m, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "            # sum(m(rand(1,12)))\n",
    "        end\n",
    "        \n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "        # @show CBF_training_loss\n",
    "        # loss = CBF_training_loss\n",
    "        # @show loss, loss_naive_safeset_NO(model_CBF, yt, y_init), loss_regularization_NO(model_CBF, yt, y_init)  # CBF_training_loss\n",
    "\n",
    "        loss = loss_naive_safeset(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization(model_CBF, ytt, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag) + loss_naive_safeset_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "        @show loss_naive_safeset(model_CBF, ytt, y_init), loss_regularization(model_CBF, ytt, y_init), loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag), loss_naive_safeset_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag), loss_regularization_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "#         @show size((2 .* y_init_batch[1, :] .- 1)), size(model(x_batch)), size(((2 .* y_init_batch[1, :] .- 1) .* model(x_batch)))\n",
    "#         @show loss,loss_naive_safeset(model, x_batch, y_init_batch), loss_naive_fi(model, A, x_batch, B, u_batch,y_init_batch;use_pgd=use_pgd, α=α,Δ=Δ), loss_regularization(model, x_batch, y_init_batch)\n",
    "        push!(training_loss_epoch, loss)  # logging, outside gradient context\n",
    "        \n",
    "        # @show training_loss\n",
    "    end\n",
    "    for item in test_loader\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        x = copy(y_batch)\n",
    "        y_init = copy(safe_batch)\n",
    "        x = vcat(x[1,:,:]...)\n",
    "        x = reshape(x, (1, size(x)[1]))\n",
    "        y_init = vcat(y_init...)\n",
    "\n",
    "        U_0 = copy(x_batch)\n",
    "        U_0[2:2,:,:] .= x_batch[2:2,1:1,:]\n",
    "        U_0 = vcat(U_0[2:2,:,:][1,:,:]...)\n",
    "        U_0 = reshape(U_0, (1, size(U_0)[1]))\n",
    "        U̇ = find_derivative(x_batch)\n",
    "        extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)\n",
    "        T = x_batch[1,end,1]\n",
    "        _, ∇ϕ = Zygote.pullback(model_NO, x_batch)\n",
    "        # dG\\du * du\\dt + dG\\dt\n",
    "        # ∇Y_t = ∇ϕ(ones(size(Y)))[1][2:2, :,:] .* U̇ .+ ∇ϕ(ones(size(Y)))[1][1:1, :,:] # rewrite it below\n",
    "        # ∇Y_t = batched_mul(reshape(∇ϕ(ones(size(y_batch)))[1], (1, size(∇ϕ(ones(size(y_batch)))[1])...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))\n",
    "        # ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], model_NO(x_batch), dims=1)) # empirical derivative not working... sigmoid 0.5\n",
    "        ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], y_batch, dims=1)) # empirical derivative\n",
    "\n",
    "        yt = cat(x_batch[1:1,:,:], y_batch, dims=1) # NO\n",
    "        ytt = reshape(yt, (size(yt)[1], size(yt)[2]*size(yt)[3]))\n",
    "        extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        U_0t = cat(x_batch[1:1,:,:], reshape(U_0, size(y_batch)), dims=1) # NO\n",
    "        U_0tt = reshape(U_0t, (size(U_0t)[1], size(U_0t)[2]*size(U_0t)[3]))\n",
    "        extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],1, size(extended_∇Y_t)[2:end]...))\n",
    "        extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],size(extended_∇Y_t)[2], size(extended_∇Y_t)[3]*size(extended_∇Y_t)[4]))\n",
    "        \n",
    "        loss = loss_naive_safeset(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization(model_CBF, ytt, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag) + loss_naive_safeset_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "        # @show loss_naive_safeset(model_CBF, ytt, y_init), loss_regularization(model_CBF, ytt, y_init), loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init)\n",
    "\n",
    "        # loss = loss_naive_safeset(model_CBF, x, y_init)  +  λ_reg .* loss_regularization(model_CBF, x, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α,y_init)\n",
    "        # @show loss_naive_safeset(model_CBF, x, y_init)  , loss_regularization(model_CBF, x, y_init) , loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α,y_init)\n",
    "#         @show size((2 .* y_init_batch[1, :] .- 1)), size(model(x_batch)), size(((2 .* y_init_batch[1, :] .- 1) .* model(x_batch)))\n",
    "#         @show loss,loss_naive_safeset(model, x_batch, y_init_batch), loss_naive_fi(model, A, x_batch, B, u_batch,y_init_batch;use_pgd=use_pgd, α=α,Δ=Δ), loss_regularization(model, x_batch, y_init_batch)\n",
    "        push!(test_loss_epoch, loss)  # logging, outside gradient context\n",
    "    end\n",
    "    nextlr = ParameterSchedulers.next!(sched_CBF) # advance schedule\n",
    "    Optimisers.adjust!(optim_CBF, nextlr) # update optimizer state, by default this changes the learning rate `eta`\n",
    "\n",
    "    # @show epoch, loss, test_loss\n",
    "    # model_state = Flux.state(model)\n",
    "    # jldsave(\"car_wd0.0001_naive_model_1_0_0.1_pgd_relu_$epoch.jld2\"; model_state)\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "    end\n",
    "    @save \"model/hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_bcks_$epoch.bson\" model_CBF\n",
    "    # @save \"model/hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_$epoch.bson\" model_CBF\n",
    "    push!(training_losses, sum(training_loss_epoch) ./ 45000) \n",
    "    push!(test_losses, sum(test_loss_epoch) ./ 5000)\n",
    "\n",
    "end\n",
    "# return training_losses, test_losses\n",
    "\n",
    "\n",
    "# learner = Learner(model, data, optimiser, loss_func,\n",
    "#                   ToDevice(device, device))\n",
    "\n",
    "# fit!(learner, epochs)\n",
    "# model = learner.model |> cpu\n",
    "# @save \"model/hyper_FNO_all_pf.bson\" model\n",
    "\n",
    "# return learner\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3668d390-0e86-46d7-a168-032979de81c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch losses are already divided by the split sizes (45000 / 5000) inside the training loop.\n",
    "@show training_losses, test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89b0b541-a95c-4294-b723-4c21a85e357c",
   "metadata": {},
   "outputs": [],
   "source": [
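    "# NOTE: archived, fully commented-out copy of the live training cell below, kept for reference.\n",
    "# It differs from that cell only in λ_reg = 1 (instead of 0.1) and in the CBF checkpoint file name.\n",
    "# # function my_train(; )\n",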
    "# cuda = true\n",
    "# η₀ = 1.0f-3\n",
    "# λ = 1.0f-4\n",
    "# total_epoch = 20\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# # pretrained_NO=nothing\n",
    "# if cuda && CUDA.has_cuda()\n",
    "#     device = gpu\n",
    "#     CUDA.allowscalar(false)\n",
    "#     @info \"Training on GPU\"\n",
    "# else\n",
    "#     device = cpu\n",
    "#     @info \"Training on CPU\"\n",
    "# end\n",
    "# @show 1\n",
    "# lr_NO = η₀\n",
    "# lr_CBF = 0.001 # NO CBF\n",
    "# lr_CBF = 0.01\n",
    "\n",
    "# lr_decay_rate = 0.2\n",
    "# lr_decay_epoch =4\n",
    "\n",
    "# train_loader, test_loader = my_get_dataloader()\n",
    "# model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "#                               σ = gelu)\n",
    "# if isnothing(pretrained_NO)\n",
    "#     model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "#                               σ = gelu)\n",
    "# else\n",
    "#     model_NO = get_model(pretrained_NO)[:model_NO]\n",
    "# end\n",
    "# model_CBF = Chain(\n",
    "#         Dense(2 => 16, relu),   # activation function inside layer\n",
    "#         Dense(16 => 64, relu),   # activation function inside layer\n",
    "#         Dense(64 => 16, relu),   # activation function inside layer\n",
    "#         Dense(16 => 1)\n",
    "#     )\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF]\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF]\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF]\n",
    "# # model_CBF = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (2,), \n",
    "# #                               σ = gelu)\n",
    "# # optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))\n",
    "# optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "# optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "# # optim_CBF = Flux.setup(Flux.Optimise.AdamW(lr_CBF, (0.9, 0.999), 0), model_CBF) # NO\n",
    "# sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch)) # setup schedule of your choice\n",
    "\n",
    "\n",
    "# loss_func = l₂loss\n",
    "# α = 0.00001\n",
    "# λ_pf = 1\n",
    "# λ_reg = 1\n",
    "# all_flag = false\n",
    "# minus_safe_flag = false\n",
    "\n",
    "# training_losses = []\n",
    "# test_losses = []\n",
    "# no_training_losses = []\n",
    "# no_test_losses = []\n",
    "# least_loss = 1000\n",
    "# test_loss = 0\n",
    "# loss = 0\n",
    "# for epoch in ProgressBar(1:total_epoch)\n",
    "#     training_loss_epoch = []\n",
    "#     test_loss_epoch = []\n",
    "#     no_training_loss_epoch = []\n",
    "#     no_test_loss_epoch = []\n",
    "#     for item in train_loader\n",
    "#         # x_batch = reduce(hcat,item[1,:])\n",
    "#         # u_batch = reduce(hcat,item[2,:])\n",
    "#         # y_init_batch = reduce(hcat,item[3,:])\n",
    "#         x_batch = item[1]\n",
    "#         y_batch = item[2]\n",
    "#         safe_batch = item[3]\n",
    "#         # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "#         if isnothing(pretrained_NO)\n",
    "#             # train NO\n",
    "#             NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m \n",
    "#                 l₂loss(m(x_batch), y_batch)\n",
    "#             end\n",
    "#             Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "#             push!(no_training_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "#             # @show l₂loss(model_NO(x_batch), y_batch)\n",
    "#         end\n",
    "\n",
    "        \n",
    "#         # train CBF\n",
    "#         x = copy(y_batch)\n",
    "#         y_init = copy(safe_batch)\n",
    "#         x = vcat(x[1,:,:]...)\n",
    "#         x = reshape(x, (1, size(x)[1]))\n",
    "#         y_init = vcat(y_init...)\n",
    "\n",
    "#         U_0 = copy(x_batch)\n",
    "#         U_0[2:2,:,:] .= x_batch[2:2,1:1,:]\n",
    "#         U_0 = vcat(U_0[2:2,:,:][1,:,:]...)\n",
    "#         U_0 = reshape(U_0, (1, size(U_0)[1]))\n",
    "#         U̇ = find_derivative(x_batch)\n",
    "#         extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)\n",
    "#         T = x_batch[1,end,1]\n",
    "#         _, ∇ϕ = Zygote.pullback(model_NO, x_batch)\n",
    "#         # dG\\du * du\\dt + dG\\dt\n",
    "#         # ∇Y_t = ∇ϕ(ones(size(Y)))[1][2:2, :,:] .* U̇ .+ ∇ϕ(ones(size(Y)))[1][1:1, :,:] # rewrite it below\n",
    "#         # ∇Y_t = batched_mul(reshape(∇ϕ(ones(size(y_batch)))[1], (1, size(∇ϕ(ones(size(y_batch)))[1])...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))\n",
    "#         # ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], model_NO(x_batch), dims=1)) # empirical derivative, not working... sigmoid 0.5\n",
    "#         ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], y_batch, dims=1)) # empirical derivative\n",
    "\n",
    "#         # @show size(x_batch), size(∇Y_t[1,:,:,:]), size(x_batch[1:1,:,:]), size(reshape(U_0, size(y_batch)))\n",
    "#         yt = cat(x_batch[1:1,:,:], y_batch, dims=1) # NO\n",
    "#         ytt = reshape(yt, (size(yt)[1], size(yt)[2]*size(yt)[3]))\n",
    "#         extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "#         U_0t = cat(x_batch[1:1,:,:], reshape(U_0, size(y_batch)), dims=1) # NO\n",
    "#         U_0tt = reshape(U_0t, (size(U_0t)[1], size(U_0t)[2]*size(U_0t)[3]))\n",
    "#         extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],1, size(extended_∇Y_t)[2:end]...))\n",
    "#         extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],size(extended_∇Y_t)[2], size(extended_∇Y_t)[3]*size(extended_∇Y_t)[4]))\n",
    "#         # @show size(extended_∇Y_tt), size(yt)\n",
    "#         # @show ytt[:, 1:10], model_CBF(ytt)[1, 1:10],model_CBF(ytt[:, 1:10])\n",
    "#         # @show model_CBF(ytt)\n",
    "#         CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m \n",
    "#             # loss_naive_safeset_NO(m, yt, y_init)  +  λ_reg .* loss_regularization_NO(m, yt, y_init) + λ_pf .* loss_pf_NO(m, yt, extended_∇Y_t,U_0t,T, α,y_init) # NO\n",
    "#             loss_naive_safeset(m, ytt, y_init)  +  λ_reg .* loss_regularization(m, ytt, y_init) + λ_pf .* loss_pf(m, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag) + loss_naive_safeset_end(m, ytt, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(m, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "#             # sum(m(rand(1,12)))\n",
    "#         end\n",
    "        \n",
    "#         Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "#         # @show CBF_training_loss\n",
    "#         # loss = CBF_training_loss\n",
    "#         # @show loss, loss_naive_safeset_NO(model_CBF, yt, y_init), loss_regularization_NO(model_CBF, yt, y_init)  # CBF_training_loss\n",
    "\n",
    "#         loss = loss_naive_safeset(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization(model_CBF, ytt, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag) + loss_naive_safeset_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "#         @show loss_naive_safeset(model_CBF, ytt, y_init), loss_regularization(model_CBF, ytt, y_init), loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag), loss_naive_safeset_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag), loss_regularization_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "# #         @show size((2 .* y_init_batch[1, :] .- 1)), size(model(x_batch)), size(((2 .* y_init_batch[1, :] .- 1) .* model(x_batch)))\n",
    "# #         @show loss,loss_naive_safeset(model, x_batch, y_init_batch), loss_naive_fi(model, A, x_batch, B, u_batch,y_init_batch;use_pgd=use_pgd, α=α,Δ=Δ), loss_regularization(model, x_batch, y_init_batch)\n",
    "#         push!(training_loss_epoch, loss)  # logging, outside gradient context\n",
    "        \n",
    "#         # @show training_loss\n",
    "#     end\n",
    "#     for item in test_loader\n",
    "#         x_batch = item[1]\n",
    "#         y_batch = item[2]\n",
    "#         safe_batch = item[3]\n",
    "#         # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "#         if isnothing(pretrained_NO)\n",
    "#             push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "#         end\n",
    "\n",
    "#         x = copy(y_batch)\n",
    "#         y_init = copy(safe_batch)\n",
    "#         x = vcat(x[1,:,:]...)\n",
    "#         x = reshape(x, (1, size(x)[1]))\n",
    "#         y_init = vcat(y_init...)\n",
    "\n",
    "#         U_0 = copy(x_batch)\n",
    "#         U_0[2:2,:,:] .= x_batch[2:2,1:1,:]\n",
    "#         U_0 = vcat(U_0[2:2,:,:][1,:,:]...)\n",
    "#         U_0 = reshape(U_0, (1, size(U_0)[1]))\n",
    "#         U̇ = find_derivative(x_batch)\n",
    "#         extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)\n",
    "#         T = x_batch[1,end,1]\n",
    "#         _, ∇ϕ = Zygote.pullback(model_NO, x_batch)\n",
    "#         # dG\\du * du\\dt + dG\\dt\n",
    "#         # ∇Y_t = ∇ϕ(ones(size(Y)))[1][2:2, :,:] .* U̇ .+ ∇ϕ(ones(size(Y)))[1][1:1, :,:] # rewrite it below\n",
    "#         # ∇Y_t = batched_mul(reshape(∇ϕ(ones(size(y_batch)))[1], (1, size(∇ϕ(ones(size(y_batch)))[1])...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))\n",
    "#         # ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], model_NO(x_batch), dims=1)) # empirical derivative not working... sigmoid 0.5\n",
    "#         ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], y_batch, dims=1)) # empirical derivative\n",
    "\n",
    "#         yt = cat(x_batch[1:1,:,:], y_batch, dims=1) # NO\n",
    "#         ytt = reshape(yt, (size(yt)[1], size(yt)[2]*size(yt)[3]))\n",
    "#         extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "#         U_0t = cat(x_batch[1:1,:,:], reshape(U_0, size(y_batch)), dims=1) # NO\n",
    "#         U_0tt = reshape(U_0t, (size(U_0t)[1], size(U_0t)[2]*size(U_0t)[3]))\n",
    "#         extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],1, size(extended_∇Y_t)[2:end]...))\n",
    "#         extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],size(extended_∇Y_t)[2], size(extended_∇Y_t)[3]*size(extended_∇Y_t)[4]))\n",
    "        \n",
    "#         loss = loss_naive_safeset(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization(model_CBF, ytt, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init;all=all_flag) + loss_naive_safeset_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)  +  λ_reg .* loss_regularization_end(model_CBF, ytt, y_init;minus_safe=minus_safe_flag)\n",
    "#         # @show loss_naive_safeset(model_CBF, ytt, y_init), loss_regularization(model_CBF, ytt, y_init), loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init)\n",
    "\n",
    "#         # loss = loss_naive_safeset(model_CBF, x, y_init)  +  λ_reg .* loss_regularization(model_CBF, x, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α,y_init)\n",
    "#         # @show loss_naive_safeset(model_CBF, x, y_init)  , loss_regularization(model_CBF, x, y_init) , loss_pf(model_CBF, x_batch, x, U_0,extended_U̇, ∇Y_t,T, α,y_init)\n",
    "# #         @show size((2 .* y_init_batch[1, :] .- 1)), size(model(x_batch)), size(((2 .* y_init_batch[1, :] .- 1) .* model(x_batch)))\n",
    "# #         @show loss,loss_naive_safeset(model, x_batch, y_init_batch), loss_naive_fi(model, A, x_batch, B, u_batch,y_init_batch;use_pgd=use_pgd, α=α,Δ=Δ), loss_regularization(model, x_batch, y_init_batch)\n",
    "#         push!(test_loss_epoch, loss)  # logging, outside gradient context\n",
    "#     end\n",
    "#     nextlr = ParameterSchedulers.next!(sched_CBF) # advance schedule\n",
    "#     Optimisers.adjust!(optim_CBF, nextlr) # update optimizer state, by default this changes the learning rate `eta`\n",
    "\n",
    "#     # @show epoch, loss, test_loss\n",
    "#     # model_state = Flux.state(model)\n",
    "#     # jldsave(\"car_wd0.0001_naive_model_1_0_0.1_pgd_relu_$epoch.jld2\"; model_state)\n",
    "#     if isnothing(pretrained_NO)\n",
    "#         @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "#     end\n",
    "#     @save \"model/hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_$epoch.bson\" model_CBF\n",
    "#     # @save \"model/hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_$epoch.bson\" model_CBF\n",
    "#     push!(training_losses, sum(training_loss_epoch) ./ 45000) \n",
    "#     push!(test_losses, sum(test_loss_epoch) ./ 5000)\n",
    "\n",
    "# end\n",
    "# # return training_losses, test_losses\n",
    "\n",
    "\n",
    "# # learner = Learner(model, data, optimiser, loss_func,\n",
    "# #                   ToDevice(device, device))\n",
    "\n",
    "# # fit!(learner, epochs)\n",
    "# # model = learner.model |> cpu\n",
    "# # @save \"model/hyper_FNO_all_pf.bson\" model\n",
    "\n",
    "# # return learner\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53f2fb6b-8953-48d8-8ea6-65c030246ed4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch losses are already divided by the split sizes (45000 / 5000) inside the training loop.\n",
    "@show training_losses, test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8244380-00af-4c93-962b-a3979296d232",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the CBF network on top of a pretrained, frozen FNO (set `pretrained_NO = nothing`\n",
    "# to train the FNO from scratch as well). This run uses λ_reg = 0.1 and λ_pf = 1.\n",
    "cuda = true\n",
    "η₀ = 1.0f-3\n",
    "λ = 1.0f-4\n",
    "total_epoch = 20\n",
    "pretrained_NO = \"hyper_NO_20.bson\"\n",
    "# pretrained_NO = nothing\n",
    "if cuda && CUDA.has_cuda()\n",
    "    # NOTE(review): `device` is never applied to the models or batches below, so this\n",
    "    # branch does not by itself move training to the GPU — TODO wire it up.\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "lr_NO = η₀\n",
    "# lr_CBF = 0.001 # NO CBF\n",
    "lr_CBF = 0.01\n",
    "\n",
    "# Step decay of the CBF learning rate: multiplied by `lr_decay_rate` every `lr_decay_epoch`\n",
    "# scheduler steps (the scheduler is advanced once per epoch below).\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch = 4\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# Build a fresh FNO only when no checkpoint is given; otherwise load the pretrained one.\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),\n",
    "                                     σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "# CBF: maps a (coordinate, value) pair to a scalar.\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),\n",
    "        Dense(16 => 64, relu),\n",
    "        Dense(64 => 16, relu),\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "# Alternative initialisation from a checkpoint, e.g.:\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF]\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "# NOTE(review): NADAM's third positional argument is ϵ (not weight decay), so this sets\n",
    "# ϵ = 0.1 — confirm that is intended.\n",
    "optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "# optim_CBF = Flux.setup(Flux.Optimise.AdamW(lr_CBF, (0.9, 0.999), 0), model_CBF) # NO\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch))\n",
    "\n",
    "loss_func = l₂loss\n",
    "α = 0.00001\n",
    "λ_pf = 1\n",
    "λ_reg = 0.1\n",
    "all_flag = false\n",
    "minus_safe_flag = false\n",
    "\n",
    "# Divisors for the per-epoch loss logs; presumably the train/test split sizes\n",
    "# (0.9 / 0.1 of 50000 samples) — TODO confirm against my_get_dataloader.\n",
    "n_train_obs = 45000\n",
    "n_test_obs = 5000\n",
    "\n",
    "# Turn one loader batch into the flattened inputs used by the CBF loss terms.\n",
    "# Shapes per the loader comment: x_batch (2, G, B), y_batch (1, G, B), safe_batch (G, B);\n",
    "# channel 1 of x_batch is the coordinate grid built in my_get_data.\n",
    "function cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "    y_init = vec(copy(safe_batch))  # safe-set labels flattened in column-major order\n",
    "    # Value of channel 2 at the first grid point, broadcast over the grid -> (1, G, B).\n",
    "    U_0 = repeat(x_batch[2:2, 1:1, :], 1, size(x_batch, 2), 1)\n",
    "    U̇ = find_derivative(x_batch)\n",
    "    extended_U̇ = cat(ones(size(U̇)), U̇, dims=1)\n",
    "    T = x_batch[1, end, 1]  # last coordinate value of the grid\n",
    "    # Empirical derivative of the target (the NO-output-based derivative did not work well).\n",
    "    ∇Y_t = find_derivative(cat(x_batch[1:1, :, :], y_batch, dims=1))\n",
    "    yt = cat(x_batch[1:1, :, :], y_batch, dims=1)  # (coordinate, y) pairs\n",
    "    extended_∇Y_t = cat(ones(size(∇Y_t[1, :, :, :])), ∇Y_t[1, :, :, :], dims=1)\n",
    "    U_0t = cat(x_batch[1:1, :, :], U_0, dims=1)\n",
    "    return (ytt = reshape(yt, size(yt, 1), :),\n",
    "            y_init = y_init,\n",
    "            U_0tt = reshape(U_0t, size(U_0t, 1), :),\n",
    "            extended_U̇ = extended_U̇,\n",
    "            extended_∇Y_tt = reshape(extended_∇Y_t, size(extended_∇Y_t, 1), 1, :),\n",
    "            T = T)\n",
    "end\n",
    "\n",
    "# The five CBF loss terms of model `m` on one prepared batch, in summation order.\n",
    "function cbf_loss_terms(m, x_batch, b)\n",
    "    return (safe = loss_naive_safeset(m, b.ytt, b.y_init),\n",
    "            reg = loss_regularization(m, b.ytt, b.y_init),\n",
    "            pf = loss_pf(m, x_batch, b.ytt, b.U_0tt, b.extended_U̇, b.extended_∇Y_tt, b.T, α, b.y_init; all=all_flag),\n",
    "            safe_end = loss_naive_safeset_end(m, b.ytt, b.y_init; minus_safe=minus_safe_flag),\n",
    "            reg_end = loss_regularization_end(m, b.ytt, b.y_init; minus_safe=minus_safe_flag))\n",
    "end\n",
    "\n",
    "# Weighted CBF objective (same term order and weights as the original inline expression).\n",
    "cbf_total_loss(t) = t.safe + λ_reg .* t.reg + λ_pf .* t.pf + t.safe_end + λ_reg .* t.reg_end\n",
    "\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "least_loss = 1000\n",
    "test_loss = 0\n",
    "loss = 0\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "    for item in train_loader\n",
    "        x_batch, y_batch, safe_batch = item[1], item[2], item[3]\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            # train NO\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m\n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        # train CBF\n",
    "        cbf_in = cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m\n",
    "            cbf_total_loss(cbf_loss_terms(m, x_batch, cbf_in))\n",
    "        end\n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "\n",
    "        # Log the post-update loss outside the gradient context; each term is evaluated\n",
    "        # once and reused for both the total and the diagnostic print.\n",
    "        cbf_terms = cbf_loss_terms(model_CBF, x_batch, cbf_in)\n",
    "        loss = cbf_total_loss(cbf_terms)\n",
    "        @show cbf_terms\n",
    "        push!(training_loss_epoch, loss)\n",
    "    end\n",
    "    for item in test_loader\n",
    "        x_batch, y_batch, safe_batch = item[1], item[2], item[3]\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        cbf_in = cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        loss = cbf_total_loss(cbf_loss_terms(model_CBF, x_batch, cbf_in))\n",
    "        push!(test_loss_epoch, loss)\n",
    "    end\n",
    "    nextlr = ParameterSchedulers.next!(sched_CBF) # advance schedule\n",
    "    Optimisers.adjust!(optim_CBF, nextlr) # update optimizer state (learning rate `eta`)\n",
    "\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "        # Record the per-epoch NO losses (previously collected but never stored).\n",
    "        push!(no_training_losses, sum(no_training_loss_epoch) ./ n_train_obs)\n",
    "        push!(no_test_losses, sum(no_test_loss_epoch) ./ n_test_obs)\n",
    "    end\n",
    "    @save \"model/hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_$epoch.bson\" model_CBF\n",
    "    push!(training_losses, sum(training_loss_epoch) ./ n_train_obs)\n",
    "    push!(test_losses, sum(test_loss_epoch) ./ n_test_obs)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b558ee-6fc0-4014-b63b-73753eb9d975",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch losses are already normalized inside the training loop; do not divide again here.\n",
    "@show training_losses, test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71da76c0-ab8d-4f7e-941b-287f329f30a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CBF training run with λ_reg = 1 on top of a pretrained (or jointly trained) FNO.\n",
    "# NOTE(review): near-duplicate of the λ_reg = 0.1 cell below (only λ_reg and the checkpoint path differ);\n",
    "# consider a shared training function parameterized by those two values.\n",
    "cuda = true\n",
    "η₀ = 1.0f-3\n",
    "λ = 1.0f-4\n",
    "total_epoch = 20\n",
    "pretrained_NO = \"hyper_NO_20.bson\"  # set to `nothing` to train the FNO jointly\n",
    "if cuda && CUDA.has_cuda()\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "# NOTE(review): `device` is never applied to the models or batches below, so training stays on the CPU.\n",
    "lr_NO = η₀\n",
    "lr_CBF = 0.01  # 0.001 was used for the NO-CBF variant\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch = 4\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# Build a fresh FNO only when no checkpoint is given (previously one was always built and then discarded).\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),\n",
    "                                     σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "# CBF network: 2 inputs (t, Y) -> 1 output.\n",
    "model_CBF = Chain(\n",
    "    Dense(2 => 16, relu),\n",
    "    Dense(16 => 64, relu),\n",
    "    Dense(64 => 16, relu),\n",
    "    Dense(16 => 1)\n",
    ")\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "# NOTE(review): NADAM's 3rd positional argument is ϵ, not a weight decay as in AdamW; 0.1 is an unusually large ϵ - confirm intent.\n",
    "optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "# NOTE(review): `next!` returns the current step's rate before advancing, so each decay applies one epoch later than `lr_decay_epoch` suggests.\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch))\n",
    "\n",
    "loss_func = l₂loss\n",
    "α = 0.00001\n",
    "λ_pf = 1\n",
    "λ_reg = 1\n",
    "all_flag = true\n",
    "minus_safe_flag = false\n",
    "\n",
    "# Build the tensors consumed by the CBF losses for one batch.\n",
    "# Shapes noted for the data loader: x_batch (2, grid, bs), y_batch (1, grid, bs), safe_batch (grid, bs).\n",
    "function cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "    time_row = x_batch[1:1, :, :]\n",
    "    # Safe-set labels flattened column-major, matching the column order of `ytt`.\n",
    "    y_init = vec(copy(safe_batch))\n",
    "    # (t, Y) for every grid point of every sample, flattened to (2, grid * bs).\n",
    "    yt = cat(time_row, y_batch, dims=1)\n",
    "    ytt = reshape(yt, (size(yt, 1), :))\n",
    "    # Same layout with channel 2 held at its first-grid-point value (the initial condition).\n",
    "    U_0t = x_batch[1:2, :, :]\n",
    "    U_0t[2:2, :, :] .= x_batch[2:2, 1:1, :]\n",
    "    U_0tt = reshape(U_0t, (size(U_0t, 1), :))\n",
    "    U̇ = find_derivative(x_batch)\n",
    "    extended_U̇ = cat(ones(size(U̇)), U̇, dims=1)\n",
    "    # Empirical derivative of the data trajectory; using model_NO(x_batch) here was noted as not working (`sigmoid 0.5`).\n",
    "    ∇Y_t = find_derivative(cat(time_row, y_batch, dims=1))\n",
    "    ∇Y_1 = ∇Y_t[1, :, :, :]\n",
    "    extended_∇Y_t = cat(ones(size(∇Y_1)), ∇Y_1, dims=1)\n",
    "    # Insert a singleton 2nd axis and merge the trailing two: (r, a, b) -> (r, 1, a * b).\n",
    "    extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t, 1), 1, :))\n",
    "    T = x_batch[1, end, 1]\n",
    "    return (x_batch = x_batch, y_init = y_init, ytt = ytt, U_0tt = U_0tt,\n",
    "            extended_U̇ = extended_U̇, extended_∇Y_tt = extended_∇Y_tt, T = T)\n",
    "end\n",
    "\n",
    "# Full CBF objective for one prepared batch; weights and flags come from the globals set above.\n",
    "function cbf_objective(m, batch)\n",
    "    return loss_naive_safeset(m, batch.ytt, batch.y_init) +\n",
    "           λ_reg .* loss_regularization(m, batch.ytt, batch.y_init) +\n",
    "           λ_pf .* loss_pf(m, batch.x_batch, batch.ytt, batch.U_0tt, batch.extended_U̇, batch.extended_∇Y_tt, batch.T, α, batch.y_init; all=all_flag) +\n",
    "           loss_naive_safeset_end(m, batch.ytt, batch.y_init; minus_safe=minus_safe_flag) +\n",
    "           λ_reg .* loss_regularization_end(m, batch.ytt, batch.y_init; minus_safe=minus_safe_flag)\n",
    "end\n",
    "\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "    # Samples seen this epoch; replaces the hard-coded 45000 / 5000 split sizes in the normalization.\n",
    "    n_train = 0\n",
    "    n_test = 0\n",
    "    for item in train_loader\n",
    "        x_batch, y_batch, safe_batch = item[1], item[2], item[3]\n",
    "        if isnothing(pretrained_NO)\n",
    "            # Train the FNO jointly and log the loss of the step just taken.\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m\n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, NO_training_loss)\n",
    "        end\n",
    "\n",
    "        batch = cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m\n",
    "            cbf_objective(m, batch)\n",
    "        end\n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "        # Log the loss of the step itself rather than re-evaluating (and printing) every term after the update.\n",
    "        push!(training_loss_epoch, CBF_training_loss)\n",
    "        n_train += size(y_batch, ndims(y_batch))\n",
    "    end\n",
    "    for item in test_loader\n",
    "        x_batch, y_batch, safe_batch = item[1], item[2], item[3]\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "        batch = cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        push!(test_loss_epoch, cbf_objective(model_CBF, batch))\n",
    "        n_test += size(y_batch, ndims(y_batch))\n",
    "    end\n",
    "    # Advance the CBF learning-rate schedule once per epoch.\n",
    "    Optimisers.adjust!(optim_CBF, ParameterSchedulers.next!(sched_CBF))\n",
    "\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "        # Same per-sample normalization as the CBF losses; these lists were previously never filled.\n",
    "        push!(no_training_losses, sum(no_training_loss_epoch) / n_train)\n",
    "        push!(no_test_losses, sum(no_test_loss_epoch) / n_test)\n",
    "    end\n",
    "    @save \"model/hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_$epoch.bson\" model_CBF\n",
    "    push!(training_losses, sum(training_loss_epoch) / n_train)\n",
    "    push!(test_losses, sum(test_loss_epoch) / n_test)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8901499-4066-4374-a3f9-e287e234bfae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch losses are already normalized inside the training loop; do not divide again here.\n",
    "@show training_losses, test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2dda94d-6a23-498c-89ae-b7b7d50c5df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CBF training run with λ_reg = 0.1 on top of a pretrained (or jointly trained) FNO.\n",
    "# NOTE(review): near-duplicate of the λ_reg = 1 cell above (only λ_reg and the checkpoint path differ);\n",
    "# consider a shared training function parameterized by those two values.\n",
    "cuda = true\n",
    "η₀ = 1.0f-3\n",
    "λ = 1.0f-4\n",
    "total_epoch = 20\n",
    "pretrained_NO = \"hyper_NO_20.bson\"  # set to `nothing` to train the FNO jointly\n",
    "if cuda && CUDA.has_cuda()\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "# NOTE(review): `device` is never applied to the models or batches below, so training stays on the CPU.\n",
    "lr_NO = η₀\n",
    "lr_CBF = 0.01  # 0.001 was used for the NO-CBF variant\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch = 4\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# Build a fresh FNO only when no checkpoint is given (previously one was always built and then discarded).\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),\n",
    "                                     σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "# CBF network: 2 inputs (t, Y) -> 1 output.\n",
    "model_CBF = Chain(\n",
    "    Dense(2 => 16, relu),\n",
    "    Dense(16 => 64, relu),\n",
    "    Dense(64 => 16, relu),\n",
    "    Dense(16 => 1)\n",
    ")\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "# NOTE(review): NADAM's 3rd positional argument is ϵ, not a weight decay as in AdamW; 0.1 is an unusually large ϵ - confirm intent.\n",
    "optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "# NOTE(review): `next!` returns the current step's rate before advancing, so each decay applies one epoch later than `lr_decay_epoch` suggests.\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch))\n",
    "\n",
    "loss_func = l₂loss\n",
    "α = 0.00001\n",
    "λ_pf = 1\n",
    "λ_reg = 0.1\n",
    "all_flag = true\n",
    "minus_safe_flag = false\n",
    "\n",
    "# Build the tensors consumed by the CBF losses for one batch.\n",
    "# Shapes noted for the data loader: x_batch (2, grid, bs), y_batch (1, grid, bs), safe_batch (grid, bs).\n",
    "function cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "    time_row = x_batch[1:1, :, :]\n",
    "    # Safe-set labels flattened column-major, matching the column order of `ytt`.\n",
    "    y_init = vec(copy(safe_batch))\n",
    "    # (t, Y) for every grid point of every sample, flattened to (2, grid * bs).\n",
    "    yt = cat(time_row, y_batch, dims=1)\n",
    "    ytt = reshape(yt, (size(yt, 1), :))\n",
    "    # Same layout with channel 2 held at its first-grid-point value (the initial condition).\n",
    "    U_0t = x_batch[1:2, :, :]\n",
    "    U_0t[2:2, :, :] .= x_batch[2:2, 1:1, :]\n",
    "    U_0tt = reshape(U_0t, (size(U_0t, 1), :))\n",
    "    U̇ = find_derivative(x_batch)\n",
    "    extended_U̇ = cat(ones(size(U̇)), U̇, dims=1)\n",
    "    # Empirical derivative of the data trajectory; using model_NO(x_batch) here was noted as not working (`sigmoid 0.5`).\n",
    "    ∇Y_t = find_derivative(cat(time_row, y_batch, dims=1))\n",
    "    ∇Y_1 = ∇Y_t[1, :, :, :]\n",
    "    extended_∇Y_t = cat(ones(size(∇Y_1)), ∇Y_1, dims=1)\n",
    "    # Insert a singleton 2nd axis and merge the trailing two: (r, a, b) -> (r, 1, a * b).\n",
    "    extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t, 1), 1, :))\n",
    "    T = x_batch[1, end, 1]\n",
    "    return (x_batch = x_batch, y_init = y_init, ytt = ytt, U_0tt = U_0tt,\n",
    "            extended_U̇ = extended_U̇, extended_∇Y_tt = extended_∇Y_tt, T = T)\n",
    "end\n",
    "\n",
    "# Full CBF objective for one prepared batch; weights and flags come from the globals set above.\n",
    "function cbf_objective(m, batch)\n",
    "    return loss_naive_safeset(m, batch.ytt, batch.y_init) +\n",
    "           λ_reg .* loss_regularization(m, batch.ytt, batch.y_init) +\n",
    "           λ_pf .* loss_pf(m, batch.x_batch, batch.ytt, batch.U_0tt, batch.extended_U̇, batch.extended_∇Y_tt, batch.T, α, batch.y_init; all=all_flag) +\n",
    "           loss_naive_safeset_end(m, batch.ytt, batch.y_init; minus_safe=minus_safe_flag) +\n",
    "           λ_reg .* loss_regularization_end(m, batch.ytt, batch.y_init; minus_safe=minus_safe_flag)\n",
    "end\n",
    "\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "    # Samples seen this epoch; replaces the hard-coded 45000 / 5000 split sizes in the normalization.\n",
    "    n_train = 0\n",
    "    n_test = 0\n",
    "    for item in train_loader\n",
    "        x_batch, y_batch, safe_batch = item[1], item[2], item[3]\n",
    "        if isnothing(pretrained_NO)\n",
    "            # Train the FNO jointly and log the loss of the step just taken.\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m\n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, NO_training_loss)\n",
    "        end\n",
    "\n",
    "        batch = cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m\n",
    "            cbf_objective(m, batch)\n",
    "        end\n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "        # Log the loss of the step itself rather than re-evaluating (and printing) every term after the update.\n",
    "        push!(training_loss_epoch, CBF_training_loss)\n",
    "        n_train += size(y_batch, ndims(y_batch))\n",
    "    end\n",
    "    for item in test_loader\n",
    "        x_batch, y_batch, safe_batch = item[1], item[2], item[3]\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "        batch = cbf_batch_inputs(x_batch, y_batch, safe_batch)\n",
    "        push!(test_loss_epoch, cbf_objective(model_CBF, batch))\n",
    "        n_test += size(y_batch, ndims(y_batch))\n",
    "    end\n",
    "    # Advance the CBF learning-rate schedule once per epoch.\n",
    "    Optimisers.adjust!(optim_CBF, ParameterSchedulers.next!(sched_CBF))\n",
    "\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "        # Same per-sample normalization as the CBF losses; these lists were previously never filled.\n",
    "        push!(no_training_losses, sum(no_training_loss_epoch) / n_train)\n",
    "        push!(no_test_losses, sum(no_test_loss_epoch) / n_test)\n",
    "    end\n",
    "    @save \"model/hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_$epoch.bson\" model_CBF\n",
    "    push!(training_losses, sum(training_loss_epoch) / n_train)\n",
    "    push!(test_losses, sum(test_loss_epoch) / n_test)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40ecab45-68d4-4042-b1b1-78f41209453f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch losses are already normalized inside the training loop; do not divide again here.\n",
    "@show training_losses, test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "297de0e2-af66-48de-8bc6-76d6706cf093",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65c84c0b-4e82-4e67-b826-d0e59b59fc9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload a saved CBF network from a BSON checkpoint.\n",
    "# NOTE(review): this filename does not match the paths saved by the training cells shown above - confirm which run it refers to.\n",
    "model_CBF = let checkpoint = get_model(\"hyper_0.1reg_1pf_time_CBF_pfall_0addend_preNO20_1.bson\")\n",
    "    checkpoint[:model_CBF]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92b6bdc2-dcf9-4ec5-b627-4dbe00ce4f87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the CBF at a single (t, Y) input point.\n",
    "let t = 3, Y = 2\n",
    "    model_CBF([t, Y])\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0717b6c2-717e-4395-a4fa-bef9f35eda06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check on a random input and an untrained FNO: empirical time derivative of the NO output\n",
    "# vs. the chain-rule estimate [dY/dt, dY/du] · [1, du/dt] built from the NO pullback.\n",
    "# Wrapped in `let` so it no longer overwrites the global `model_NO` / `loss` used by the training cells.\n",
    "let U = rand(2, 3000, 4)\n",
    "    scratch_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (2,),\n",
    "                                       σ = gelu)\n",
    "    Y, ∇ϕ = Zygote.pullback(scratch_NO, U)\n",
    "    @show find_derivative(cat(U[1:1, :, :], Y, dims=1))[1]\n",
    "    # The cotangent must have the shape of the NO output, not a flattened (1, N) copy of it.\n",
    "    ∇U = ∇ϕ(ones(size(Y)))[1]\n",
    "    U̇ = find_derivative(U)\n",
    "    extended_U̇ = cat(ones(size(U̇)), U̇, dims=1)\n",
    "    ∇Y_t = batched_mul(reshape(∇U, (1, size(∇U)...)),\n",
    "                       reshape(extended_U̇, (size(extended_U̇, 1), 1, size(extended_U̇)[2:end]...)))\n",
    "    # `x[]` only works on one-element arrays; show the first entry, as for the empirical derivative.\n",
    "    @show ∇Y_t[1]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "239e7a51-a982-4c68-8296-ae086ca6d504",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb254334-6bad-47d6-aa3e-524737e5daf5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d6b8d0a-0274-433d-98d7-cdb3adeca313",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "809c59c2-65a2-4935-ae0b-da3425616c87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): model_CBF as built in the training cells takes a 2-element (t, Y) input (Dense(2 => 16));\n",
    "# a 1-element input like this will likely throw a DimensionMismatch - confirm which model this cell targets.\n",
    "model_CBF([1.6])\n",
    "# model_CBF([0.99])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "906899e1-3114-4743-8985-dacdb533f5bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time derivative of the CBF along a trajectory: dϕ/dt = ∇ϕ(Y) · dY/dt, per sample.\n",
    "# Y and ∇Y_t get a leading singleton axis, so a vector input is treated as a 1 x batch row of scalar states.\n",
    "# Returns an array shaped like ϕ(Y).\n",
    "function find_phi_dot(ϕ, Y, ∇Y_t)\n",
    "    Y = reshape(Y, (1, size(Y)...))\n",
    "    ∇Y_t = reshape(∇Y_t, (1, size(∇Y_t)...))\n",
    "    state_dim, batchsize = size(Y)\n",
    "\n",
    "    # Single forward pass: keep ϕ(Y) from the pullback instead of re-evaluating it for the output shape.\n",
    "    ϕ_Y, ∇ϕ = Zygote.pullback(ϕ, Y)\n",
    "    # The cotangent must match the shape of ϕ's output, not of its input.\n",
    "    ∇ϕ_Y = ∇ϕ(ones(size(ϕ_Y)))[1] ./ state_dim\n",
    "    ∇ϕ_Y = reshape(∇ϕ_Y, (1, state_dim, batchsize))\n",
    "    ∇Y_t = reshape(∇Y_t, (state_dim, 1, batchsize))\n",
    "\n",
    "    return reshape(batched_mul(∇ϕ_Y, ∇Y_t), size(ϕ_Y))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b81f709-75f6-4ccf-833c-048498d6b09d",
   "metadata": {},
   "outputs": [],
   "source": [
    "reshape([1, 1.5, 1.5, 0.9, 0.9], :, 1)  # 5×1 Matrix{Float64}, same as hcat of the vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "086225b2-dbad-45da-9a83-b84c7040e4e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ϕ̇ of the CBF at a single state for a given state derivative.\n",
    "let ϕ = model_CBF, state = [0.7], state_rate = [0.5]\n",
    "    find_phi_dot(ϕ, state, state_rate)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc4f9875-63e4-4ea0-8dc1-04f80533692f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aef814d5-1e69-4f04-b7be-704f5fdb0025",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    loss_naive_safeset_NO(ϕ, x, y_init; margin = 1e-6)\n",
    "\n",
    "Hinge loss on the sign of the barrier prediction `ϕ(x)`.\n",
    "\n",
    "`y_init` holds labels with 1 = safe and 0 = unsafe, so `2y - 1` is ±1 and each term\n",
    "`relu(±ϕ(x) + margin)` is zero only when safe points have `ϕ(x) ≤ -margin` and unsafe\n",
    "points have `ϕ(x) ≥ margin`. The prediction is reshaped to `size(y_init)`, so `ϕ(x)`\n",
    "must have `length(y_init)` elements. Returns `sum(loss) / size(loss)[end]`, which is the\n",
    "mean when `y_init` is a vector (as built in the training cell).\n",
    "\"\"\"\n",
    "function loss_naive_safeset_NO(ϕ, x, y_init; margin = 1e-6)\n",
    "    pred = reshape(ϕ(x), size(y_init))\n",
    "    # Match the model's float type so a Float32 model is not promoted to Float64.\n",
    "    ϵ = convert(eltype(pred), margin)\n",
    "    # NNlib activations are scalar functions; broadcast them over arrays.\n",
    "    loss = relu.((2 .* y_init .- 1) .* pred .+ ϵ)\n",
    "    return sum(loss) / size(loss)[end]\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    loss_regularization_NO(ϕ, x, y_init)\n",
    "\n",
    "Smooth sign loss `sigmoid((2y - 1) * ϕ(x))` with labels 1 = safe, 0 = unsafe: it keeps\n",
    "decreasing as safe predictions become more negative and unsafe ones more positive.\n",
    "The prediction is reshaped to `size(y_init)`; returns `sum(loss) / size(loss)[end]`.\n",
    "\"\"\"\n",
    "function loss_regularization_NO(ϕ, x::AbstractArray, y_init::AbstractArray)\n",
    "    pred = reshape(ϕ(x), size(y_init))\n",
    "    # Broadcast the scalar activation over the array.\n",
    "    loss = sigmoid_fast.((2 .* y_init .- 1) .* pred)\n",
    "    return sum(loss) / size(loss)[end]\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    loss_pf_NO(ϕ, yt, extended_∇Y_t, U_0t, T, α, y_init; margin = 1e-6)\n",
    "\n",
    "Hinge penalty on the barrier condition ϕ̇ + α ϕ(yt) + C ϕ(U_0t) ≤ -margin, with\n",
    "C = α e^{-αT} / (1 - e^{-αT}).\n",
    "\n",
    "- `yt`: model input with channels on the first axis (the training cell passes\n",
    "  `(2, grid, bs)`: the coordinate channel stacked on the state).\n",
    "- `extended_∇Y_t`: derivative of every input channel, in the same element order as `yt`.\n",
    "  Both `(C, grid, bs)` and `(C, 1, grid, bs)` layouts are accepted.\n",
    "- `U_0t`: model input for the initial-condition term.\n",
    "- `T`, `α`: horizon and rate entering `C`.\n",
    "- `y_init`: only its shape is used; ϕ̇ and the predictions are reshaped to `size(y_init)`.\n",
    "\n",
    "ϕ̇ is the channel-wise contraction of ∂(sum ϕ(yt))/∂yt with `extended_∇Y_t`.\n",
    "NOTE(review): with an all-ones seed this is the gradient of the *summed* output; it equals\n",
    "the pointwise derivative only if ϕ does not mix grid points. Confirm this is intended for the FNO.\n",
    "Returns `sum(loss) / size(loss)[end]`.\n",
    "\"\"\"\n",
    "function loss_pf_NO(ϕ, yt, extended_∇Y_t, U_0t, T, α, y_init; margin = 1e-6)\n",
    "    # One forward pass: pullback returns ϕ(yt) and its VJP closure.\n",
    "    pred, back = Zygote.pullback(ϕ, yt)\n",
    "    ∇ϕ_Y = back(ones(eltype(pred), size(pred)))[1]  # same shape as yt\n",
    "\n",
    "    # dG/du · du/dt + dG/dt: contract the channel axis at every (grid, sample) point.\n",
    "    # Folding all trailing axes into one batch axis makes the 3-D and 4-D layouts of\n",
    "    # `extended_∇Y_t` equivalent (reshape preserves column-major element order).\n",
    "    nch = size(∇ϕ_Y, 1)\n",
    "    ϕ̇ = batched_mul(reshape(∇ϕ_Y, 1, nch, :), reshape(extended_∇Y_t, nch, 1, :))\n",
    "    ϕ̇ = reshape(ϕ̇, size(y_init))\n",
    "\n",
    "    pred = reshape(pred, size(y_init))\n",
    "    pred0 = reshape(ϕ(U_0t), size(y_init))\n",
    "\n",
    "    R = eltype(pred)\n",
    "    # α / expm1(αT) == α e^{-αT} / (1 - e^{-αT}), without cancellation for small αT.\n",
    "    C = R(α / expm1(α * T))\n",
    "    l = ϕ̇ .+ R(α) .* pred .+ C .* pred0\n",
    "    loss = relu.(l .+ R(margin))\n",
    "    return sum(loss) / size(loss)[end]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b88604e-07ab-4139-91a9-44ee3e3858ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Joint training: optionally fit the Fourier neural operator `model_NO`, then fit the\n",
    "# barrier network `model_CBF` on every batch with the safe-set hinge loss, the sigmoid\n",
    "# regularizer and the pf (decrease-condition) loss.\n",
    "\n",
    "# ---- configuration ----\n",
    "cuda = true\n",
    "η₀ = 1.0f-3                          # NO learning rate\n",
    "λ = 1.0f-4                           # NO weight decay (AdamW)\n",
    "total_epoch = 20\n",
    "pretrained_NO = \"hyper_NO_20.bson\"   # `nothing` ⇒ train the NO from scratch in this loop\n",
    "lr_NO = η₀\n",
    "lr_CBF = 0.001\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch = 4\n",
    "use_lr_schedule = false              # step-decay the CBF learning rate after each epoch\n",
    "α = 0.00001\n",
    "λ_pf = 0.1\n",
    "λ_reg = 1\n",
    "\n",
    "if cuda && CUDA.has_cuda()\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "# NOTE(review): `device` is never applied; the models and batches below stay on the CPU.\n",
    "\n",
    "# ---- data, models, optimisers ----\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "\n",
    "model_NO = if isnothing(pretrained_NO)\n",
    "    FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), σ = gelu)\n",
    "else\n",
    "    get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "model_CBF = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (2,), σ = gelu)\n",
    "\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(lr_NO, (0.9, 0.999), λ), model_NO)\n",
    "optim_CBF = Flux.setup(Flux.Optimise.AdamW(lr_CBF, (0.9, 0.999), 0), model_CBF)\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch))\n",
    "\n",
    "# Build the CBF loss inputs for one batch. Shapes per the loader:\n",
    "# x_batch (2, grid, bs), y_batch (1, grid, bs), safe_batch (grid, bs).\n",
    "function cbf_batch_inputs(model_NO, x_batch, y_batch, safe_batch)\n",
    "    y_init = vec(collect(safe_batch))     # column-major flatten, same order as vcat(safe_batch...)\n",
    "    coord = x_batch[1:1, :, :]            # first input channel, reused as the time axis\n",
    "    T = x_batch[1, end, 1]                # horizon = last coordinate value\n",
    "\n",
    "    # Initial value x_batch[2, 1, b] held constant along the grid.\n",
    "    U_0 = repeat(x_batch[2:2, 1:1, :], 1, size(x_batch, 2), 1)\n",
    "\n",
    "    # dG/du · du/dt + dG/dt through the NO: gradient of its summed output w.r.t. the\n",
    "    # inputs, contracted over channels with the input derivative (1, U̇).\n",
    "    U̇ = find_derivative(x_batch)\n",
    "    extended_U̇ = cat(ones(eltype(U̇), size(U̇)), U̇; dims = 1)\n",
    "    _, back = Zygote.pullback(model_NO, x_batch)\n",
    "    ∇x = back(ones(eltype(y_batch), size(y_batch)))[1]\n",
    "    ∇Y_t = sum(∇x .* extended_U̇; dims = 1)\n",
    "\n",
    "    extended_∇Y_t = cat(ones(eltype(∇Y_t), size(∇Y_t)), ∇Y_t; dims = 1)\n",
    "    extended_∇Y_t = reshape(extended_∇Y_t, size(extended_∇Y_t, 1), 1, size(extended_∇Y_t)[2:end]...)\n",
    "\n",
    "    return (yt = cat(coord, y_batch; dims = 1), y_init = y_init,\n",
    "            extended_∇Y_t = extended_∇Y_t, U_0t = cat(coord, U_0; dims = 1), T = T)\n",
    "end\n",
    "\n",
    "# Weighted CBF objective on a batch prepared by `cbf_batch_inputs`.\n",
    "cbf_loss(m, b, λ_reg, λ_pf, α) =\n",
    "    loss_naive_safeset_NO(m, b.yt, b.y_init) +\n",
    "    λ_reg * loss_regularization_NO(m, b.yt, b.y_init) +\n",
    "    λ_pf * loss_pf_NO(m, b.yt, b.extended_∇Y_t, b.U_0t, b.T, α, b.y_init)\n",
    "\n",
    "# ---- training ----\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "mkpath(\"model\")   # checkpoint directory used by @save below\n",
    "\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "\n",
    "    for item in train_loader\n",
    "        x_batch, y_batch, safe_batch = item[1], item[2], item[3]\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m\n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, NO_training_loss)\n",
    "        end\n",
    "\n",
    "        batch = cbf_batch_inputs(model_NO, x_batch, y_batch, safe_batch)\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m\n",
    "            cbf_loss(m, batch, λ_reg, λ_pf, α)\n",
    "        end\n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "        # Log this batch's loss (previously a stale global `loss` was pushed here).\n",
    "        push!(training_loss_epoch, CBF_training_loss)\n",
    "    end\n",
    "\n",
    "    for item in test_loader\n",
    "        x_batch, y_batch, safe_batch = item[1], item[2], item[3]\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "        batch = cbf_batch_inputs(model_NO, x_batch, y_batch, safe_batch)\n",
    "        push!(test_loss_epoch, cbf_loss(model_CBF, batch, λ_reg, λ_pf, α))\n",
    "    end\n",
    "\n",
    "    if use_lr_schedule\n",
    "        Optimisers.adjust!(optim_CBF, ParameterSchedulers.next!(sched_CBF))\n",
    "    end\n",
    "\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "        push!(no_training_losses, sum(no_training_loss_epoch))\n",
    "        push!(no_test_losses, sum(no_test_loss_epoch))\n",
    "    end\n",
    "    @save \"model/hyper_1reg_1pf_NOCBF_preNO20_$epoch.bson\" model_CBF\n",
    "    push!(training_losses, sum(training_loss_epoch))\n",
    "    push!(test_losses, sum(test_loss_epoch))\n",
    "    @info \"epoch $epoch\" training_loss = training_losses[end] test_loss = test_losses[end]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5180ba38-433e-4e8b-be7f-d71a63afad92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth profiles u(x, T_end) of samples 1–4 as a 2×2 grid of panels.\n",
    "# NOTE(review): `plot` comes from Plots.jl, which is commented out in the import cell; also\n",
    "# confirm `input_data` / `ground_truth` are defined by an earlier cell.\n",
    "# To overlay a model `m`, add `plot!(panel, input_data[1, :, 1], m(view(input_data, :, :, k:k))[1, :, 1], label = \"predict\")`.\n",
    "truth_panel(k; kwargs...) = plot(input_data[1, :, 1], ground_truth[1, :, k]; label = \"ground_truth\", kwargs...)\n",
    "\n",
    "i = 1\n",
    "p1 = truth_panel(i; title = \"                                              Burgers equation u(x,T_end)\");\n",
    "p2 = truth_panel(i + 1);\n",
    "i = 3\n",
    "p3 = truth_panel(i);\n",
    "p4 = truth_panel(i + 1);\n",
    "p = plot(p1, p2, p3, p4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd29d7d7-e1e6-4be3-90a0-36115a91fc69",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Burgers\n",
    "using FluxTraining\n",
    "using Test\n",
    "\n",
    "# Package regression check: data shapes, then a 100-epoch training run whose final\n",
    "# validation loss must stay below 0.1.\n",
    "@testset \"Burgers\" begin\n",
    "    xs, ys = Burgers.get_data(n = 1000)\n",
    "\n",
    "    @test size(xs) == (2, 1024, 1000)\n",
    "    @test size(ys) == (1, 1024, 1000)\n",
    "\n",
    "    # A `@testset` body is a local scope: without `global`, `learner` disappears when\n",
    "    # the block ends and the following `learner` cell fails.\n",
    "    global learner = Burgers.train(epochs = 100)\n",
    "    val_loss = learner.cbstate.metricsepoch[ValidationPhase()][:Loss].values[end]\n",
    "    @test val_loss < 0.1\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "928686e9-3a7c-45ee-803e-37d7e80eea7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the FluxTraining learner returned by `Burgers.train`.\n",
    "# NOTE(review): `learner` is assigned inside a `@testset` body (a local scope); confirm it is defined globally before running this cell.\n",
    "learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "998eae74-ea3f-4a18-8daf-6995b7cfe5e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOMAD baseline trained on GPU for 100 epochs; a previous run reached ϵ ≈ 0.233.\n",
    "@testset \"Burger: NOMAD Training Accuracy\" begin\n",
    "    ϵ = Burgers.train_nomad(epochs = 100, cuda = true)\n",
    "    @test ϵ < 0.4\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9217960-00f6-4f0a-95c3-b071e04129fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run `Burgers.train_don()` with its defaults. Presumably this is the DeepONet example of the Burgers package; confirm its arguments there.\n",
    "Burgers.train_don()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c6dce6-f0bd-4999-914a-ac3c394c0334",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.9.4",
   "language": "julia",
   "name": "julia-1.9"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
