{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf555e98-8e13-4f75-8272-4984ca0d448b",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Revise\n",
    "using Burgers#, Plots\n",
    "using DataDeps, MAT, MLUtils\n",
    "using NeuralOperators, Flux\n",
    "using BSON\n",
    "using DataDeps, MAT, MLUtils\n",
    "using NeuralOperators, Flux\n",
    "using CUDA, FluxTraining, BSON\n",
    "import Flux: params\n",
    "using BSON: @save, @load\n",
    "using ProgressBars\n",
    "using Zygote\n",
    "using Optimisers, ParameterSchedulers\n",
    "using LazySets\n",
    "using Burgers\n",
    "using FluxTraining\n",
    "using Plots\n",
    "using NPZ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38d0bc45-89b3-432f-a24d-85cb2efc2558",
   "metadata": {},
   "outputs": [],
   "source": [
    "function my_get_data(file_path; n = 50000, Δsamples = 1, grid_size = div(51, Δsamples), T = Float32)\n",
    "# function my_get_data(file_path; n = 2048, Δsamples = 2^3, grid_size = div(2^13, Δsamples), T = Float32)\n",
    "    # file = matopen(joinpath(datadep\"Burgers\", \"burgers_data_R10.mat\"))\n",
    "    file = matopen(file_path)\n",
    "    \n",
    "    x_data = T.(collect(read(file, \"a\")[1:n, 1:Δsamples:end]'))\n",
    "    y_data = T.(collect(read(file, \"u\")[1:n, 1:Δsamples:end]'))\n",
    "    safe_labels = T.(collect(read(file, \"safe\")[1:n, 1:Δsamples:end]'))\n",
    "    pf_labels = T.(collect(read(file, \"pf\")[1:n, 1:Δsamples:end]'))\n",
    "    close(file)\n",
    "\n",
    "    x_loc_data = Array{T, 3}(undef, 2, grid_size, n)\n",
    "    x_loc_data[1, :, :] .= reshape(repeat(LinRange(0, 5, grid_size), n), (grid_size, n))\n",
    "    x_loc_data[2, :, :] .= x_data\n",
    "\n",
    "    return x_loc_data, reshape(y_data, 1, :, n), safe_labels, pf_labels\n",
    "end\n",
    "\n",
    "function my_get_dataloader(; ratio::Float64 = 0.9, batchsize = 1)\n",
    "    𝐱1, 𝐲1, safe1, pf1 = my_get_data(\"data_bcks_hyperbolic_1.mat\") # data_bcks_hyperbolic_1_new.mat _minus\n",
    "    \n",
    "    data_train1, data_test1 = splitobs((𝐱1, 𝐲1, safe1, pf1), at = ratio)\n",
    "    𝐱2, 𝐲2, safe2, pf2 = my_get_data(\"data_ppo_hyperbolic_1.mat\")\n",
    "    \n",
    "    data_train2, data_test2 = splitobs((𝐱2, 𝐲2, safe2, pf2), at = ratio)\n",
    "    𝐱3, 𝐲3, safe3, pf3 = my_get_data(\"data_sac_hyperbolic_1.mat\")\n",
    "    \n",
    "    data_train3, data_test3 = splitobs((𝐱3, 𝐲3, safe3, pf3), at = ratio)\n",
    "\n",
    "    @show size(data_train3[1]), size(data_test3[2])\n",
    "    data_train1_x_pf = data_train1[1][:,:,:]\n",
    "    data_test1_x_pf = data_test1[1][:,:,:]\n",
    "    data_train1_y_pf = data_train1[2][:,:,:]\n",
    "    data_test1_y_pf = data_test1[2][:,:,:]\n",
    "    data_train1_safe_pf = data_train1[3][:,:]\n",
    "    data_test1_safe_pf = data_test1[3][:,:]\n",
    "\n",
    "    data_train2_x_pf = data_train2[1][:,:,:]\n",
    "    data_test2_x_pf = data_test2[1][:,:,:]\n",
    "    data_train2_y_pf = data_train2[2][:,:,:]\n",
    "    data_test2_y_pf = data_test2[2][:,:,:]\n",
    "    data_train2_safe_pf = data_train2[3][:,:]\n",
    "    data_test2_safe_pf = data_test2[3][:,:]\n",
    "\n",
    "    data_train3_x_pf = data_train3[1][:,:,:]\n",
    "    data_test3_x_pf = data_test3[1][:,:,:]\n",
    "    data_train3_y_pf = data_train3[2][:,:,:]\n",
    "    data_test3_y_pf = data_test3[2][:,:,:]\n",
    "    data_train3_safe_pf = data_train3[3][:,:]\n",
    "    data_test3_safe_pf = data_test3[3][:,:]\n",
    "    \n",
    "    # data_train1_x_pf = data_train1[1][:,:,(data_train1[4][1,:].==0)]\n",
    "    # data_test1_x_pf = data_test1[1][:,:,(data_test1[4][1,:].==0)]\n",
    "    # data_train1_y_pf = data_train1[2][:,:,(data_train1[4][1,:].==0)]\n",
    "    # data_test1_y_pf = data_test1[2][:,:,(data_test1[4][1,:].==0)]\n",
    "    # data_train1_safe_pf = data_train1[3][:,(data_train1[4][1,:].==0)]\n",
    "    # data_test1_safe_pf = data_test1[3][:,(data_test1[4][1,:].==0)]\n",
    "\n",
    "    # data_train2_x_pf = data_train2[1][:,:,(data_train2[4][1,:].==0)]\n",
    "    # data_test2_x_pf = data_test2[1][:,:,(data_test2[4][1,:].==0)]\n",
    "    # data_train2_y_pf = data_train2[2][:,:,(data_train2[4][1,:].==0)]\n",
    "    # data_test2_y_pf = data_test2[2][:,:,(data_test2[4][1,:].==0)]\n",
    "    # data_train2_safe_pf = data_train2[3][:,(data_train2[4][1,:].==0)]\n",
    "    # data_test2_safe_pf = data_test2[3][:,(data_test2[4][1,:].==0)]\n",
    "\n",
    "    # data_train3_x_pf = data_train3[1][:,:,(data_train3[4][1,:].==0)]\n",
    "    # data_test3_x_pf = data_test3[1][:,:,(data_test3[4][1,:].==0)]\n",
    "    # data_train3_y_pf = data_train3[2][:,:,(data_train3[4][1,:].==0)]\n",
    "    # data_test3_y_pf = data_test3[2][:,:,(data_test3[4][1,:].==0)]\n",
    "    # data_train3_safe_pf = data_train3[3][:,(data_train3[4][1,:].==0)]\n",
    "    # data_test3_safe_pf = data_test3[3][:,(data_test3[4][1,:].==0)]\n",
    "\n",
    "    @show size(data_train1_x_pf), size(data_train2_x_pf), size(data_train3_x_pf)\n",
    "    @show size(data_test1_x_pf), size(data_test2_x_pf), size(data_test3_x_pf)\n",
    "\n",
    "\n",
    "\n",
    "    data_train = (cat(cat(data_train1_x_pf, data_train2_x_pf, dims=3), data_train3_x_pf, dims=3), \n",
    "                    cat(cat(data_train1_y_pf, data_train2_y_pf, dims=3), data_train3_y_pf, dims=3), \n",
    "                    cat(cat(data_train1_safe_pf, data_train2_safe_pf, dims=2), data_train3_safe_pf, dims=2)) # omit the last pf tumple\n",
    "    data_test = (cat(cat(data_test1_x_pf, data_test2_x_pf, dims=3), data_test3_x_pf, dims=3), \n",
    "                cat(cat(data_test1_y_pf, data_test2_y_pf, dims=3), data_test3_y_pf, dims=3), \n",
    "                cat(cat(data_test1_safe_pf, data_test2_safe_pf, dims=2), data_test3_safe_pf, dims=2)) # # omit the last pf tumple\n",
    "    loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = false)\n",
    "    loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false)\n",
    "\n",
    "    return loader_train, loader_test\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a5135bc-4000-4f86-bf28-f9a56a9ded88",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6234f0c-bafc-4a91-b737-61ab953af6cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "function find_derivative_1step(vector)\n",
    "    M, N = size(vector)[2], size(vector)[3]\n",
    "\n",
    "    # Assume `vector` is the (2, M, N) array\n",
    "    inputs = vector[1, :, :]  # Shape (M, N)\n",
    "    outputs = vector[2, :, :]  # Shape (M, N)\n",
    "\n",
    "    # Preallocate the derivative array with shape (1, M, N)\n",
    "    derivatives = zeros(Float64, 1, M, N)\n",
    "\n",
    "    # 1-step forward finite difference for all points from 1 to M-1\n",
    "    derivatives[1, 1:M-1, :] = (outputs[2:M, :] .- outputs[1:M-1, :]) ./ (inputs[2:M, :] .- inputs[1:M-1, :])\n",
    "\n",
    "    # 1-step backward finite difference for the last point\n",
    "    derivatives[1, M, :] = (outputs[M, :] .- outputs[M-1, :]) ./ (inputs[M, :] .- inputs[M-1, :])\n",
    "\n",
    "    # `derivatives` now contains the derivative of the output with respect to the input\n",
    "    # with shape (1, M, N)\n",
    "    return derivatives\n",
    "end\n",
    "\n",
    "function get_model(name)\n",
    "    model_path = joinpath(@__DIR__, \"./model/\")\n",
    "    @assert name in readdir(model_path)\n",
    "    model_file = name\n",
    "    return BSON.load(joinpath(model_path, model_file), @__MODULE__)\n",
    "end\n",
    "\n",
    "function reconstruct_traj(U_0, T_traj, U_dot_traj)\n",
    "    # Get the dimensions of the input trajectory\n",
    "    _, M, N = size(T_traj)\n",
    "\n",
    "    # Preallocate the reconstructed trajectory array with shape (1, M, N)\n",
    "    U_traj = zeros(Float64, 1, M, N)\n",
    "\n",
    "    # Set the initial value for all trajectories\n",
    "    U_traj[1, 1, :] .= U_0\n",
    "\n",
    "    # Compute the time differences dt between consecutive time steps\n",
    "    dt = T_traj[1, 2:end, :] .- T_traj[1, 1:end-1, :]\n",
    "    # @show size(dt)\n",
    "\n",
    "    # Calculate the cumulative sum of derivatives multiplied by dt\n",
    "    # This gives the increment to add to the initial value at each step\n",
    "    changes = U_dot_traj[1, 1:end-1, :] .* dt\n",
    "    # @show changes\n",
    "    # Compute the cumulative sum of changes\n",
    "    cumulative_changes = cumsum(changes, dims=1)\n",
    "    # @show cumulative_changes\n",
    "    # Reconstruct the trajectory\n",
    "    U_traj[1, 2:end, :] .= U_0 .+ cumulative_changes\n",
    "    # @show size(U_traj)\n",
    "    # Return the reconstructed trajectory\n",
    "    return U_traj\n",
    "end\n",
    "\n",
    "function reconstruct_traj_central(U_0, T_traj, U_dot_traj)\n",
    "    # Get the dimensions of the input trajectory\n",
    "    _, M, N = size(T_traj)\n",
    "\n",
    "    # Preallocate the reconstructed trajectory array with shape (1, M, N)\n",
    "    U_traj = zeros(Float64, 1, M, N)\n",
    "\n",
    "    # Set the initial value for all trajectories\n",
    "    U_traj[1, 1, :] .= U_0\n",
    "\n",
    "    # Compute the time differences dt for central differences\n",
    "    dt_central = T_traj[1, 3:end, :] .- T_traj[1, 1:end-2, :]  # Time differences for central points\n",
    "\n",
    "    # Calculate the changes using central differences for the interior points (2 to M-1)\n",
    "    changes_central = U_dot_traj[1, 2:M-1, :] .* dt_central\n",
    "\n",
    "    # Forward difference for the first point\n",
    "    dt_forward = T_traj[1, 2, :] .- T_traj[1, 1, :]\n",
    "    changes_forward = U_dot_traj[1, 1, :] .* dt_forward\n",
    "\n",
    "    # Backward difference for the last point\n",
    "    dt_backward = T_traj[1, M, :] .- T_traj[1, M-1, :]\n",
    "    changes_backward = U_dot_traj[1, M, :] .* dt_backward\n",
    "\n",
    "    # Reconstruct the trajectory using both forward, central, and backward differences\n",
    "    U_traj[1, 2, :] .= U_traj[1, 1, :] .+ changes_forward  # Forward step for the first point\n",
    "    for i in eachindex(changes_central)\n",
    "        U_traj[1, 2+i, :] .= U_traj[1, i, :] .+ changes_central[i,:]\n",
    "    end\n",
    "    # even_numbers = filter(iseven, 1:size(changes_central)[1])\n",
    "    # odd_numbers = filter(isodd, 1:size(changes_central)[1])\n",
    "    # U_traj[1, 3:M, :][odd_numbers,:] .= U_traj[1, 1, :] .+ cumsum(changes_central[odd_numbers,:], dims=1) \n",
    "    # U_traj[1, 3:M, :][even_numbers,:] .= U_traj[1, 2, :] .+ cumsum(changes_central[even_numbers,:], dims=1) # Central steps for the interior points\n",
    "    # @show U_traj[1, M, :] .- (U_traj[1, M-1, :] .+ changes_backward)\n",
    "    U_traj[1, M, :] .= U_traj[1, M-1, :] .+ changes_backward  # Backward step for the last point\n",
    "\n",
    "    # Return the reconstructed trajectory\n",
    "    return U_traj\n",
    "end\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06a466a1-3eb5-4c46-a339-9c85f5e87e74",
   "metadata": {},
   "outputs": [],
   "source": [
    "a=\"www\"\n",
    "basename(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37ab8103-dfce-41d7-85d4-fa6ed3878002",
   "metadata": {},
   "outputs": [],
   "source": [
    "using JuMP\n",
    "using CPLEX\n",
    "\n",
    "# Define a simple QP problem for CPLEX\n",
    "function solve_multistep_qp_with_cplex(U_dot_nominal,G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=false,diff_U_dot_threshold=Inf)\n",
    "    n_steps = size(U_dot_nominal)[1]\n",
    "    # @show phi_t .+ phi_Y .* (G_t .+ G_u .* U_dot_nominal) .+ α * phiY .+ C * phiU_0\n",
    "    # Define the model using CPLEX\n",
    "    model = Model(CPLEX.Optimizer)\n",
    "    set_silent(model)\n",
    "    # @show size(phi_Y)\n",
    "    # @show size(phiY)\n",
    "    # @show size(phiU_0)\n",
    "    if opt_alpha\n",
    "        @variable(model, x[1:n_steps+1])  # Decision variables x1\n",
    "        @objective(model, Min, (x[1:end-1]' .- U_dot_nominal') * (x[1:end-1] .- U_dot_nominal))\n",
    "        @constraint(model, phi_t .+ phi_Y .* (G_t .+ G_u .* x[1:end-1]) .+ x[end] .* phiY .+ (1/T) * phiU_0 .<= 0)\n",
    "        @constraint(model, x[end]  .>= 0)\n",
    "    else\n",
    "        C = (α * ℯ^(-α*T)) / (1-ℯ^(-α*T))\n",
    "        @variable(model, x[1:n_steps])  # Decision variables x1\n",
    "        @objective(model, Min, (x' .- U_dot_nominal') * (x .- U_dot_nominal))\n",
    "        @constraint(model, phi_t .+ phi_Y .* (G_t .+ G_u .* x) .+ α .* phiY .+ C * phiU_0 .<= 0)\n",
    "    end\n",
    "    # @constraint(model, x .- U_dot_nominal  .<= 20)\n",
    "    # @constraint(model, U_dot_nominal .- x   .<= 20)\n",
    "    \n",
    "    # Solve the QP using CPLEX\n",
    "    # println(\"Solving with CPLEX...\")\n",
    "    optimize!(model)\n",
    "    \n",
    "    # Get the solution\n",
    "    solution_cplex = value.(x)\n",
    "    # println(\"Solution using CPLEX: \", solution_cplex)\n",
    "    # opt_alpha && (@show solution_cplex[end])\n",
    "    if abs(solution_cplex[1] - U_dot_nominal[1]) > diff_U_dot_threshold\n",
    "        return U_dot_nominal[1]\n",
    "    else\n",
    "        return solution_cplex[1]\n",
    "    end\n",
    "    # return solution_cplex[1]\n",
    "end\n",
    "\n",
    "function solve_multistep_qp_with_cplex_cbfnoT(U_dot_nominal,G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=false,diff_U_dot_threshold=Inf)\n",
    "    n_steps = size(U_dot_nominal)[1]\n",
    "    # @show phi_t .+ phi_Y .* (G_t .+ G_u .* U_dot_nominal) .+ α * phiY .+ C * phiU_0\n",
    "    # Define the model using CPLEX\n",
    "    model = Model(CPLEX.Optimizer)\n",
    "    set_silent(model)\n",
    "    # @show size(phi_Y)\n",
    "    # @show size(phiY)\n",
    "    # @show size(phiU_0)\n",
    "    if opt_alpha\n",
    "        @variable(model, x[1:n_steps+1])  # Decision variables x1\n",
    "        @objective(model, Min, (x[1:end-1]' .- U_dot_nominal') * (x[1:end-1] .- U_dot_nominal))\n",
    "        @constraint(model, phi_Y .* (G_t .+ G_u .* x[1:end-1]) .+ x[end] .* phiY .+ (1/T) * phiU_0 .<= 0)\n",
    "        @constraint(model, x[end]  .>= 0)\n",
    "    else\n",
    "        C = (α * ℯ^(-α*T)) / (1-ℯ^(-α*T))\n",
    "        @variable(model, x[1:n_steps])  # Decision variables x1\n",
    "        @objective(model, Min, (x' .- U_dot_nominal') * (x .- U_dot_nominal))\n",
    "        @constraint(model, phi_t .+ phi_Y .* (G_t .+ G_u .* x) .+ α .* phiY .+ C * phiU_0 .<= 0)\n",
    "    end\n",
    "    # @constraint(model, x .- U_dot_nominal  .<= 20)\n",
    "    # @constraint(model, U_dot_nominal .- x   .<= 20)\n",
    "    \n",
    "    # Solve the QP using CPLEX\n",
    "    # println(\"Solving with CPLEX...\")\n",
    "    optimize!(model)\n",
    "    \n",
    "    # Get the solution\n",
    "    solution_cplex = value.(x)\n",
    "    # println(\"Solution using CPLEX: \", solution_cplex)\n",
    "    # opt_alpha && (@show solution_cplex[end])\n",
    "    if abs(solution_cplex[1] - U_dot_nominal[1]) > diff_U_dot_threshold\n",
    "        return U_dot_nominal[1]\n",
    "    else\n",
    "        return solution_cplex[1]\n",
    "    end\n",
    "    # return solution_cplex[1]\n",
    "end\n",
    "\n",
    "\n",
    "# function solve_multistep_qp_with_cplex_adapt(U_dot_nominal,G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "#     C = (α * ℯ^(-α*T)) / (1-ℯ^(-α*T))\n",
    "#     n_steps = size(U_dot_nominal)[1]\n",
    "#     # @show phi_t .+ phi_Y .* (G_t .+ G_u .* U_dot_nominal) .+ α * phiY .+ C * phiU_0\n",
    "#     # Define the model using CPLEX\n",
    "#     model = Model(CPLEX.Optimizer)\n",
    "#     set_silent(model)\n",
    "#     @variable(model, x[1:n_steps+1])  # Decision variables x1\n",
    "#     @objective(model, Min, (x[1:end-1]' .- U_dot_nominal') * (x[1:end-1] .- U_dot_nominal))\n",
    "#     @constraint(model, phi_t .+ phi_Y .* (G_t .+ G_u .* x[1:end-1]) .+ x[end] .* phiY .+ (1/T) * phiU_0 .<= 0)\n",
    "#     @constraint(model, x[end]  .>= 0)\n",
    "    \n",
    "#     # @constraint(model, x[1:end-1] .- U_dot_nominal  .<= 20)\n",
    "#     # @constraint(model, U_dot_nominal .- x[1:end-1]   .<= 20)\n",
    "    \n",
    "#     # Solve the QP using CPLEX\n",
    "#     # println(\"Solving with CPLEX...\")\n",
    "#     optimize!(model)\n",
    "    \n",
    "#     # Get the solution\n",
    "#     solution_cplex = value.(x)\n",
    "#     # println(\"Solution using CPLEX: \", solution_cplex)\n",
    "#     if abs(solution_cplex[1] - U_dot_nominal[1]) > 5\n",
    "#         return U_dot_nominal[1]\n",
    "#     else\n",
    "#         return solution_cplex[1]\n",
    "#     end\n",
    "#     # return solution_cplex[1]\n",
    "# end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b24558b-3608-4444-9c8a-4ad13ea2721a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "get_model(\"hyper_FNO_all.bson\")[:model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fb6b3ee-3b42-4914-bc58-0b24836426b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "# pretrained_NO=\"hyper_MNO_all.bson\"\n",
    "# pretrained_NO=\"hyper_NOMAD.bson\" #cannot work with zygote\n",
    "# pretrained_NO=\"hyper_DON.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "\n",
    "# model_NO = NOMAD((51, 51), (102, 51), gelu, gelu)\n",
    "# model_NO = get_model(pretrained_NO)[:m]\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 20\n",
    "\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_C0_20.bson\")[:model_CBF] \n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 2 # 2: best for unsafe trajectories\n",
    "\n",
    "# for sac\n",
    "# max_mpc_step = 1\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 20\n",
    "\n",
    "\n",
    "filter_result_list_ppo = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "for item in test_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 45000)\n",
    "    #     continue\n",
    "    # elseif (j_index > 45100) \n",
    "    #     break\n",
    "    # end\n",
    "    # for test_laoder\n",
    "    if (j_index <= 5000)\n",
    "        # @show j_index\n",
    "        continue\n",
    "    elseif (j_index > 5100) \n",
    "        # @show j_index\n",
    "        break\n",
    "    end\n",
    "    @show j_index, ind\n",
    "    @show ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        # @show Ut_pred_traj\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show Yt_pred_traj\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        @show phiU_0\n",
    "        phiU_0 = zeros(size(phiU_0))\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_ppo, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_ppo, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "\n",
    "U_safe = [filter_result_list_ppo[i][1][1,:,1] for i in 1:100]\n",
    "U_safe = reduce(hcat, U_safe)\n",
    "@show U_safe[:,1]\n",
    "U_nominal = [filter_result_list_ppo[i][2][2,:,1] for i in 1:100]\n",
    "U_nominal = reduce(hcat, U_nominal)\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = [filter_result_list_ppo[i][3][1,:,1] for i in 1:100]\n",
    "Y_nominal = reduce(hcat, Y_nominal)\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = [filter_result_list_ppo[i][4][:,1] for i in 1:100]\n",
    "safe_label = reduce(hcat, safe_label)\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# matwrite(\"hyperbolic_ppo_1000.mat\", Dict(\n",
    "# \t\"U_safe\" => U_safe,\n",
    "# \t\"U_nominal\" => U_nominal,\n",
    "#     \"Y_nominal\" => Y_nominal,\n",
    "#     \"safe_label\" => safe_label\n",
    "# ))\n",
    "npzwrite(\"hyperbolic_ppo_all_nominal_100_2__trainCO_C0.npy\", Dict( # the default cbf is hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\n",
    "\t\"U_safe\" => U_safe,\n",
    "\t\"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dc1c25f-67d0-45a0-8c22-7c5997c865c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ablation ppo\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "# pretrained_NO=\"hyper_NOMAD.bson\" #cannot work with zygote\n",
    "# pretrained_NO=\"hyper_DON.bson\" #cannot work with zygote\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "# Neural operator: use the pretrained checkpoint when one is named, otherwise\n",
    "# start from a freshly initialised FNO. Only one of the two is ever built.\n",
    "model_NO = if isnothing(pretrained_NO)\n",
    "    FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),\n",
    "                          σ = gelu)\n",
    "else\n",
    "    get_model(pretrained_NO)[:model]\n",
    "end\n",
    "\n",
    "# model_NO = NOMAD((51, 51), (102, 51), gelu, gelu)\n",
    "# model_NO = get_model(pretrained_NO)[:m]\n",
    "\n",
    "# Placeholder CBF network (1 input -> 1 output). A checkpoint loaded with\n",
    "# `get_model` later in this cell overrides it.\n",
    "model_CBF = Chain(\n",
    "    Dense(1 => 16, relu),\n",
    "    Dense(16 => 64, relu),\n",
    "    Dense(64 => 16, relu),\n",
    "    Dense(16 => 1),\n",
    ")\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "# NOTE(review): `abs_flag` is not read anywhere in this cell; presumably it is\n",
    "# consumed by code defined elsewhere — TODO confirm or remove.\n",
    "abs_flag = true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 20\n",
    "\n",
    "# CBF checkpoint used by this ablation run (replaces the placeholder network).\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_CBFnoNOfixed_pf52_addend_preNO20_alldata_20.bson\")[:model_CBF]\n",
    "\n",
    "# true : central differences + `reconstruct_traj_central`\n",
    "# false: forward differences + `reconstruct_traj`\n",
    "use_central_flag = true\n",
    "\n",
    "# Passed positionally to the QP solver.\n",
    "α = 1e-5\n",
    "\n",
    "# Settings for PPO (active).\n",
    "max_mpc_step = 1             # how many grid steps past `i` each QP looks ahead\n",
    "opt_alpha = true             # forwarded to the solver as keyword `opt_alpha`\n",
    "diff_U_dot_threshold = 0.5   # 2: best for unsafe trajectories\n",
    "\n",
    "# Settings for SAC (inactive in this cell):\n",
    "# max_mpc_step = 1\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 20\n",
    "\n",
    "# One (filtered U, x_batch, y_batch, safe_batch) tuple is pushed per processed sample.\n",
    "filter_result_list_ppo = Any[]\n",
    "ind, j_index = 1, 0   # 1-based index of the sample being processed / position in the loader\n",
    "# Safety-filter loop: for every selected test sample, walk along the time grid and\n",
    "# replace the nominal input rate at step i by the QP solution, then store the\n",
    "# reconstructed trajectory. Only sample 1 of each batch is used (the loader's\n",
    "# default batch size is 1).\n",
    "for item in test_loader\n",
    "    j_index += 1\n",
    "    # Evaluation window: batches 5001..5100 of the test loader.\n",
    "    # (alternative window used before: 45001..45100)\n",
    "    if j_index <= 5000\n",
    "        continue\n",
    "    elseif j_index > 5100\n",
    "        break\n",
    "    end\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]     # (2, 51, bs): row 1 = time grid, row 2 = nominal input\n",
    "    y_batch = item[2]     # (1, 51, bs)\n",
    "    safe_batch = item[3]  # (51, bs)\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]  # final time, passed to the solver\n",
    "\n",
    "    # Finite-difference estimate of the nominal input rate along the time grid.\n",
    "    U_dot_nominal_traj = zeros(size(y_batch))\n",
    "    if use_central_flag\n",
    "        # central differences in the interior, forward difference at the first point\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "    else\n",
    "        # forward differences\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "    end\n",
    "    # last point: backward difference in both schemes\n",
    "    U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    for i in 1:(size(T_traj, 2) - 1) # 1,2,3,4...50\n",
    "        mpc_end = min(i + max_mpc_step, size(T_traj, 2))\n",
    "\n",
    "        # Input trajectory implied by the rates chosen so far.\n",
    "        U_pred_traj = if use_central_flag\n",
    "            reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "\n",
    "        # Operator prediction plus its pullback. The pullback is evaluated once and\n",
    "        # both sensitivities (w.r.t. U, row 2, and w.r.t. t, row 1) are sliced from it.\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        NO_input_grad = ∇ϕ(ones(size(T_traj)))[1]\n",
    "        G_u = NO_input_grad[2, i:mpc_end, 1]\n",
    "        G_t = NO_input_grad[1, i:mpc_end, 1]\n",
    "\n",
    "        # CBF input: the predicted output only (no time row), flattened to 1 x (51*bs).\n",
    "        Yt = reshape(Yt_pred_traj, 1, :)\n",
    "        state_dim = size(Yt, 1)\n",
    "\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        phi_t = nothing                      # this ablation has no time input to the CBF\n",
    "        phi_Y = ∇ϕ_Y[1, i:mpc_end]\n",
    "\n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]   # predicted output over the QP horizon\n",
    "        phiU_0 = model_CBF([U_0])[1]\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_cbfnoT(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "\n",
    "        # Linear index i addresses grid point i because the leading dimension is 1.\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "    end\n",
    "\n",
    "    U_safe_traj = if use_central_flag\n",
    "        reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "    else\n",
    "        reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "    end\n",
    "    push!(filter_result_list_ppo, (U_safe_traj, x_batch, y_batch, safe_batch))\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    ind += 1\n",
    "end\n",
    "\n",
    "\n",
    "# Stack sample 1 of every stored result column-wise (one column per trajectory).\n",
    "# Iterating the list itself avoids assuming that exactly 100 results were collected.\n",
    "isempty(filter_result_list_ppo) && error(\"filter_result_list_ppo is empty; run the filtering loop before exporting\")\n",
    "U_safe = reduce(hcat, [res[1][1,:,1] for res in filter_result_list_ppo])\n",
    "@show U_safe[:,1]\n",
    "U_nominal = reduce(hcat, [res[2][2,:,1] for res in filter_result_list_ppo])\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = reduce(hcat, [res[3][1,:,1] for res in filter_result_list_ppo])\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = reduce(hcat, [res[4][:,1] for res in filter_result_list_ppo])\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# matwrite(\"hyperbolic_ppo_1000.mat\", Dict(\n",
    "#     \"U_safe\" => U_safe,\n",
    "#     \"U_nominal\" => U_nominal,\n",
    "#     \"Y_nominal\" => Y_nominal,\n",
    "#     \"safe_label\" => safe_label\n",
    "# ))\n",
    "# NOTE(review): `npzwrite` given a Dict writes a multi-array .npz archive even though\n",
    "# the file is named .npy; the name is kept so existing consumers still find it.\n",
    "npzwrite(\"hyperbolic_ppo_all_nominal_100_0.5_abl_noT.npy\", Dict( # the default cbf is hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\n",
    "    \"U_safe\" => U_safe,\n",
    "    \"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n",
    "\n",
    "\n",
    "# U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "# U_safe = reduce(hcat, U_safe)\n",
    "# @show U_safe[:,1]\n",
    "# U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "# U_nominal = reduce(hcat, U_nominal)\n",
    "# @show U_nominal[:,1]\n",
    "# Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "# Y_nominal = reduce(hcat, Y_nominal)\n",
    "# @show Y_nominal[:,1]\n",
    "# safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "# safe_label = reduce(hcat, safe_label)\n",
    "# @show safe_label[:,1]\n",
    "# # @show typeof(U_safe), typeof(U_nominal)\n",
    "# @show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# npzwrite(\"hyperbolic_sac_allmixed10_train_nonominal_100_0.5.npy\", Dict(\n",
    "# \t\"U_safe\" => U_safe,\n",
    "# \t\"U_nominal\" => U_nominal,\n",
    "#     \"Y_nominal\" => Y_nominal,\n",
    "#     \"safe_label\" => safe_label\n",
    "# ))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3be06a79-ec8f-47de-8f47-35d21491d48e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ablation sac\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "# Neural operator: use the pretrained checkpoint when one is named, otherwise\n",
    "# start from a freshly initialised FNO. Only one of the two is ever built.\n",
    "model_NO = if isnothing(pretrained_NO)\n",
    "    FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),\n",
    "                          σ = gelu)\n",
    "else\n",
    "    get_model(pretrained_NO)[:model]\n",
    "end\n",
    "\n",
    "# Placeholder CBF network. This cell evaluates the CBF on one-row inputs\n",
    "# (`Yt[2:2, :]` and `[U_0]`), so the first layer takes a single feature.\n",
    "# A checkpoint loaded with `get_model` later in this cell overrides it.\n",
    "model_CBF = Chain(\n",
    "    Dense(1 => 16, relu),\n",
    "    Dense(16 => 64, relu),\n",
    "    Dense(64 => 16, relu),\n",
    "    Dense(16 => 1),\n",
    ")\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "# NOTE(review): `abs_flag` is not read anywhere in this cell; presumably it is\n",
    "# consumed by code defined elsewhere — TODO confirm or remove.\n",
    "abs_flag = true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# CBF checkpoint used by this ablation run (replaces the placeholder network).\n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoT_pf12_addsafe__le1safe_20.bson\")[:model_CBF]\n",
    "\n",
    "# true : central differences + `reconstruct_traj_central`\n",
    "# false: forward differences + `reconstruct_traj`\n",
    "use_central_flag = true\n",
    "\n",
    "# Passed positionally to the QP solver.\n",
    "α = 1e-5\n",
    "\n",
    "# Settings for PPO (inactive in this cell):\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# Settings for SAC (active).\n",
    "max_mpc_step = 1             # how many grid steps past `i` each QP looks ahead\n",
    "opt_alpha = true             # forwarded to the solver as keyword `opt_alpha`\n",
    "diff_U_dot_threshold = 0.1   # forwarded to the solver as keyword `diff_U_dot_threshold`\n",
    "\n",
    "# One (filtered U, x_batch, y_batch, safe_batch) tuple is pushed per processed sample.\n",
    "filter_result_list_sac = Any[]\n",
    "ind, j_index = 1, 0   # 1-based index of the sample being processed / position in the loader\n",
    "tested_num = 0        # number of samples that entered the filter\n",
    "# Safety-filter loop: for every selected training sample, walk along the time grid\n",
    "# and replace the nominal input rate at step i by the QP solution, then store the\n",
    "# reconstructed trajectory. Only sample 1 of each batch is used (the loader's\n",
    "# default batch size is 1).\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # Evaluation window: batches 90001..90100 of the train loader.\n",
    "    # (alternative window used before: 1894..1993)\n",
    "    # NOTE(review): an earlier form kept scanning past the window and looked at\n",
    "    # `item[3][end, 1]`, but `tested_num` is already 100 there, so no further sample\n",
    "    # could ever be admitted. Stopping at the window edge yields the same result list.\n",
    "    # If the goal is 100 *unsafe* trajectories, the selection must be redesigned.\n",
    "    if j_index <= 90000\n",
    "        continue\n",
    "    elseif j_index > 90100\n",
    "        break\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]     # (2, 51, bs): row 1 = time grid, row 2 = nominal input\n",
    "    y_batch = item[2]     # (1, 51, bs)\n",
    "    safe_batch = item[3]  # (51, bs)\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]  # final time, passed to the solver\n",
    "\n",
    "    # Finite-difference estimate of the nominal input rate along the time grid.\n",
    "    U_dot_nominal_traj = zeros(size(y_batch))\n",
    "    if use_central_flag\n",
    "        # central differences in the interior, forward difference at the first point\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "    else\n",
    "        # forward differences\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "    end\n",
    "    # last point: backward difference in both schemes\n",
    "    U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    for i in 1:(size(T_traj, 2) - 1) # 1,2,3,4...50\n",
    "        mpc_end = min(i + max_mpc_step, size(T_traj, 2))\n",
    "\n",
    "        # Input trajectory implied by the rates chosen so far.\n",
    "        U_pred_traj = if use_central_flag\n",
    "            reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "\n",
    "        # Operator prediction plus its pullback. The pullback is evaluated once and\n",
    "        # both sensitivities (w.r.t. U, row 2, and w.r.t. t, row 1) are sliced from it.\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        NO_input_grad = ∇ϕ(ones(size(T_traj)))[1]\n",
    "        G_u = NO_input_grad[2, i:mpc_end, 1]\n",
    "        G_t = NO_input_grad[1, i:mpc_end, 1]\n",
    "\n",
    "        # CBF input: the predicted output only (no time row), flattened to 1 x (51*bs).\n",
    "        Yt = reshape(Yt_pred_traj, 1, :)\n",
    "        state_dim = size(Yt, 1)\n",
    "\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        phi_t = nothing                      # this ablation has no time input to the CBF\n",
    "        phi_Y = ∇ϕ_Y[1, i:mpc_end]\n",
    "\n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]   # predicted output over the QP horizon\n",
    "        phiU_0 = model_CBF([U_0])[1]\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_nonominal_cbfnoT(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "\n",
    "        # Linear index i addresses grid point i because the leading dimension is 1.\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "    end\n",
    "\n",
    "    U_safe_traj = if use_central_flag\n",
    "        reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "    else\n",
    "        reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "    end\n",
    "    push!(filter_result_list_sac, (U_safe_traj, x_batch, y_batch, safe_batch))\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    ind += 1\n",
    "end\n",
    "# Stack sample 1 of every stored result column-wise (one column per trajectory).\n",
    "# Iterating the list itself avoids assuming that exactly 100 results were collected.\n",
    "isempty(filter_result_list_sac) && error(\"filter_result_list_sac is empty; run the filtering loop before exporting\")\n",
    "U_safe = reduce(hcat, [res[1][1,:,1] for res in filter_result_list_sac])\n",
    "@show U_safe[:,1]\n",
    "U_nominal = reduce(hcat, [res[2][2,:,1] for res in filter_result_list_sac])\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = reduce(hcat, [res[3][1,:,1] for res in filter_result_list_sac])\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = reduce(hcat, [res[4][:,1] for res in filter_result_list_sac])\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# NOTE(review): `npzwrite` given a Dict writes a multi-array .npz archive even though\n",
    "# the file is named .npy; the name is kept so existing consumers still find it.\n",
    "npzwrite(\"hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoT_pf12_addsafe__le1safe_20_abl_noT.npy\", Dict(\n",
    "    \"U_safe\" => U_safe,\n",
    "    \"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4376a8ef-6b0f-48c9-929b-db9d9c5d868d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ablation ppo\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "# pretrained_NO=\"hyper_MNO_all.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "# pretrained_NO=\"hyper_NOMAD.bson\" #cannot work with zygote\n",
    "# pretrained_NO=\"hyper_DON.bson\" #cannot work with zygote\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "\n",
    "# model_NO = NOMAD((51, 51), (102, 51), gelu, gelu)\n",
    "# model_NO = get_model(pretrained_NO)[:m]\n",
    "\n",
    "model_CBF = Chain(\n",
    "        Dense(1 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 20\n",
    "\n",
    "# Active CBF network for this cell: the :model_CBF entry of whatever get_model returns for this\n",
    "# file name (presumably a BSON checkpoint — get_model is defined elsewhere in the notebook).\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_CBFnoNOfixed_pf52_addend_preNO20_alldata_20.bson\")[:model_CBF] \n",
    "# true  -> central finite differences + reconstruct_traj_central in the loop below;\n",
    "# false -> forward differences + reconstruct_traj.\n",
    "use_central_flag = true\n",
    "\n",
    "# Passed positionally to the QP solver in the loop below; its meaning is defined by the solver.\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step: look-ahead of each QP window; step i uses samples i:min(i + max_mpc_step, end).\n",
    "# opt_alpha / diff_U_dot_threshold: forwarded unchanged as keyword arguments to the QP solver.\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 2 # 2: best for unsafe trajectories\n",
    "\n",
    "# for sac\n",
    "# max_mpc_step = 1\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 20\n",
    "\n",
    "\n",
    "filter_result_list_ppo = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "# Safety-filter the test_loader items 5001..5100.\n",
    "# For every item the nominal input rate dU/dt is corrected step by step by a QP that is fed the\n",
    "# neural-operator and CBF gradients, and the tuple\n",
    "#   (filtered input trajectory, x_batch, y_batch, safe_batch)\n",
    "# is pushed onto filter_result_list_ppo.\n",
    "# All indexing below reads batch element 1 only — assumes batchsize == 1; TODO confirm.\n",
    "for item in test_loader\n",
    "    j_index += 1\n",
    "    # Evaluation window: skip the first 5000 items, stop after item 5100.\n",
    "    if j_index <= 5000\n",
    "        continue\n",
    "    elseif j_index > 5100\n",
    "        break\n",
    "    end\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]      # (2, 51, bs): row 1 = time grid, row 2 = nominal input U\n",
    "    y_batch = item[2]      # (1, 51, bs)\n",
    "    safe_batch = item[3]   # (51, bs)\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]   # last time sample, forwarded to the QP solver\n",
    "\n",
    "    # Nominal input rate dU/dt from finite differences on the time grid.\n",
    "    U_dot_nominal_traj = zeros(size(y_batch))\n",
    "    if use_central_flag\n",
    "        # central differences in the interior, forward difference at the first sample\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "    else\n",
    "        # forward differences everywhere except the last sample\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "    end\n",
    "    # backward difference at the last sample (shared by both schemes)\n",
    "    U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    n_steps = size(T_traj, 2)\n",
    "    # Integrator that matches the difference scheme chosen above.\n",
    "    rebuild_traj = use_central_flag ? reconstruct_traj_central : reconstruct_traj\n",
    "\n",
    "    for i in 1:(n_steps - 1)\n",
    "        mpc_end = min(i + max_mpc_step, n_steps)   # QP window, clipped at the trajectory end\n",
    "        U_pred_traj = rebuild_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "\n",
    "        # One forward and one backward pass through the operator. Seeding the pullback with ones\n",
    "        # yields the gradient of the summed output; both gradient rows are read from that single\n",
    "        # result instead of running the backward pass once per row.\n",
    "        Yt_pred_traj, back_NO = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        ∇NO = back_NO(ones(size(T_traj)))[1]\n",
    "        G_u = ∇NO[2, i:mpc_end, 1]   # row 2 of the input: U\n",
    "        G_t = ∇NO[1, i:mpc_end, 1]   # row 1 of the input: time\n",
    "\n",
    "        # Flatten (rows, 51, bs) to (rows, 51*bs) and keep only the predicted-output row as CBF input.\n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1)\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        Yt = Yt[2:2, :]\n",
    "        state_dim = size(Yt, 1)   # 1 after the row selection above\n",
    "\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        phi_t = nothing   # no time gradient is passed to this solver variant\n",
    "        phi_Y = ∇ϕ_Y[1, i:mpc_end]\n",
    "\n",
    "        # NOTE(review): despite its name, phiY holds the operator output on the window, not the\n",
    "        # value of model_CBF — confirm this is what the solver expects.\n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # model_CBF is evaluated at U_0 only to obtain the shape; the value handed to the solver is zero.\n",
    "        phiU_0 = model_CBF([U_0])[1]\n",
    "        phiU_0 = zeros(size(phiU_0))\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_cbfnoT(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # Linear index i addresses element [1, i, 1] because the first dimension has size 1.\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "    end\n",
    "\n",
    "    push!(filter_result_list_ppo, (rebuild_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    ind += 1\n",
    "end\n",
    "\n",
    "\n",
    "# Stack the per-item results column-wise (one column per filtered item) and write them to disk.\n",
    "# Every collected result is exported, so a run that gathered fewer than 100 items no longer\n",
    "# fails with a BoundsError.\n",
    "U_safe = reduce(hcat, [res[1][1, :, 1] for res in filter_result_list_ppo])       # filtered input\n",
    "@show U_safe[:,1]\n",
    "U_nominal = reduce(hcat, [res[2][2, :, 1] for res in filter_result_list_ppo])    # row 2 of x_batch\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = reduce(hcat, [res[3][1, :, 1] for res in filter_result_list_ppo])    # y_batch\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = reduce(hcat, [res[4][:, 1] for res in filter_result_list_ppo])      # safe_batch\n",
    "@show safe_label[:,1]\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "\n",
    "# NOTE(review): an earlier note on this line named\n",
    "# hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson as the CBF behind this file;\n",
    "# confirm it against the checkpoint actually loaded into model_CBF in this cell.\n",
    "# NOTE(review): with a Dict argument NPZ.jl presumably writes a multi-array .npz archive even\n",
    "# though the name ends in .npy — confirm the readers of this file expect that.\n",
    "npzwrite(\"hyperbolic_ppo_all_nominal_100_2_abl_noT_C0.npy\", Dict(\n",
    "    \"U_safe\" => U_safe,\n",
    "    \"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n",
    "\n",
    "\n",
    "# U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "# U_safe = reduce(hcat, U_safe)\n",
    "# @show U_safe[:,1]\n",
    "# U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "# U_nominal = reduce(hcat, U_nominal)\n",
    "# @show U_nominal[:,1]\n",
    "# Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "# Y_nominal = reduce(hcat, Y_nominal)\n",
    "# @show Y_nominal[:,1]\n",
    "# safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "# safe_label = reduce(hcat, safe_label)\n",
    "# @show safe_label[:,1]\n",
    "# # @show typeof(U_safe), typeof(U_nominal)\n",
    "# @show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# npzwrite(\"hyperbolic_sac_allmixed10_train_nonominal_100_0.5.npy\", Dict(\n",
    "# \t\"U_safe\" => U_safe,\n",
    "# \t\"U_nominal\" => U_nominal,\n",
    "#     \"Y_nominal\" => Y_nominal,\n",
    "#     \"safe_label\" => safe_label\n",
    "# ))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4aad1f06-6756-4c35-bda2-4b7e21c50b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ablation sac\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "# pretrained_NO=\"hyper_MNO_all.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "# Neural operator: a freshly initialised FNO when no checkpoint name is configured,\n",
    "# otherwise the :model entry returned by get_model for that checkpoint.\n",
    "# The model is built exactly once; an unconditional construction that was overwritten\n",
    "# immediately afterwards has been dropped.\n",
    "model_NO = isnothing(pretrained_NO) ?\n",
    "    FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), σ = gelu) :\n",
    "    get_model(pretrained_NO)[:model]\n",
    "# Freshly initialised CBF network: an MLP 2 -> 16 -> 64 -> 16 -> 1 with ReLU hidden layers.\n",
    "# It only acts as a placeholder: `model_CBF` is reassigned further down in this cell by the\n",
    "# active `model_CBF = get_model(...)[:model_CBF]` line.\n",
    "# NOTE(review): the filter loop in this cell feeds 1-row inputs to model_CBF, so this 2-input\n",
    "# placeholder would not fit if the checkpoint load were ever removed.\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# Active CBF network for this cell: the :model_CBF entry of whatever get_model returns for this\n",
    "# file name (presumably a BSON checkpoint — get_model is defined elsewhere in the notebook).\n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoT_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "\n",
    "# true  -> central finite differences + reconstruct_traj_central in the loop below;\n",
    "# false -> forward differences + reconstruct_traj.\n",
    "use_central_flag = true\n",
    "\n",
    "# Passed positionally to the QP solver in the loop below; its meaning is defined by the solver.\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "# max_mpc_step: look-ahead of each QP window; step i uses samples i:min(i + max_mpc_step, end).\n",
    "# opt_alpha / diff_U_dot_threshold: forwarded unchanged as keyword arguments to the QP solver.\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 0.1\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "# Safety-filter the train_loader items 90001..90100.\n",
    "# For every item the nominal input rate dU/dt is corrected step by step by a QP that is fed the\n",
    "# neural-operator and CBF gradients, and the tuple\n",
    "#   (filtered input trajectory, x_batch, y_batch, safe_batch)\n",
    "# is pushed onto filter_result_list_sac.\n",
    "# All indexing below reads batch element 1 only — assumes batchsize == 1; TODO confirm.\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # Evaluation window: skip the first 90000 items, stop after item 90100.\n",
    "    # Exactly the 100 items inside the window are ever filtered, so the loop stops as soon as the\n",
    "    # window is passed instead of scanning the remainder of the loader.\n",
    "    if j_index <= 90000\n",
    "        continue\n",
    "    elseif j_index > 90100\n",
    "        break\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]      # (2, 51, bs): row 1 = time grid, row 2 = nominal input U\n",
    "    y_batch = item[2]      # (1, 51, bs)\n",
    "    safe_batch = item[3]   # (51, bs)\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]   # last time sample, forwarded to the QP solver\n",
    "\n",
    "    # Nominal input rate dU/dt from finite differences on the time grid.\n",
    "    U_dot_nominal_traj = zeros(size(y_batch))\n",
    "    if use_central_flag\n",
    "        # central differences in the interior, forward difference at the first sample\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "    else\n",
    "        # forward differences everywhere except the last sample\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "    end\n",
    "    # backward difference at the last sample (shared by both schemes)\n",
    "    U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    n_steps = size(T_traj, 2)\n",
    "    # Integrator that matches the difference scheme chosen above.\n",
    "    rebuild_traj = use_central_flag ? reconstruct_traj_central : reconstruct_traj\n",
    "\n",
    "    for i in 1:(n_steps - 1)\n",
    "        mpc_end = min(i + max_mpc_step, n_steps)   # QP window, clipped at the trajectory end\n",
    "        U_pred_traj = rebuild_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "\n",
    "        # One forward and one backward pass through the operator. Seeding the pullback with ones\n",
    "        # yields the gradient of the summed output; both gradient rows are read from that single\n",
    "        # result instead of running the backward pass once per row.\n",
    "        Yt_pred_traj, back_NO = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        ∇NO = back_NO(ones(size(T_traj)))[1]\n",
    "        G_u = ∇NO[2, i:mpc_end, 1]   # row 2 of the input: U\n",
    "        G_t = ∇NO[1, i:mpc_end, 1]   # row 1 of the input: time\n",
    "\n",
    "        # Flatten (rows, 51, bs) to (rows, 51*bs) and keep only the predicted-output row as CBF input.\n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1)\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        Yt = Yt[2:2, :]\n",
    "        state_dim = size(Yt, 1)   # 1 after the row selection above\n",
    "\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        phi_t = nothing   # no time gradient is passed to this solver variant\n",
    "        phi_Y = ∇ϕ_Y[1, i:mpc_end]\n",
    "\n",
    "        # NOTE(review): despite its name, phiY holds the operator output on the window, not the\n",
    "        # value of model_CBF — confirm this is what the solver expects.\n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # model_CBF is evaluated at U_0 only to obtain the shape; the value handed to the solver is zero.\n",
    "        phiU_0 = model_CBF([U_0])[1]\n",
    "        phiU_0 = zeros(size(phiU_0))\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_nonominal_cbfnoT(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # Linear index i addresses element [1, i, 1] because the first dimension has size 1.\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "    end\n",
    "\n",
    "    push!(filter_result_list_sac, (rebuild_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    ind += 1\n",
    "end\n",
    "# Stack the per-item results column-wise (one column per filtered item) and write them to disk.\n",
    "# Every collected result is exported, so a run that gathered fewer than 100 items no longer\n",
    "# fails with a BoundsError.\n",
    "U_safe = reduce(hcat, [res[1][1, :, 1] for res in filter_result_list_sac])       # filtered input\n",
    "@show U_safe[:,1]\n",
    "U_nominal = reduce(hcat, [res[2][2, :, 1] for res in filter_result_list_sac])    # row 2 of x_batch\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = reduce(hcat, [res[3][1, :, 1] for res in filter_result_list_sac])    # y_batch\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = reduce(hcat, [res[4][:, 1] for res in filter_result_list_sac])      # safe_batch\n",
    "@show safe_label[:,1]\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "\n",
    "# NOTE(review): with a Dict argument NPZ.jl presumably writes a multi-array .npz archive even\n",
    "# though the name ends in .npy — confirm the readers of this file expect that.\n",
    "npzwrite(\"hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoT_pf12_addsafe__le1safe_20_abl_noT_C0.npy\", Dict(\n",
    "    \"U_safe\" => U_safe,\n",
    "    \"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc6970b3-46e0-48ca-8bd9-93d027705281",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "\n",
    "\n",
    "# matwrite(\"hyperbolic_sac_all.mat\", )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8415413-a9c2-4e82-9d23-04a1ae631659",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 20\n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 0.5\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    if (j_index <= 90000)\n",
    "        continue\n",
    "    elseif (j_index > 90010) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_nonominal(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1c4ccce-c8d6-4211-9fd3-478385076d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 20\n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 0.1\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    if (j_index <= 90000)\n",
    "        continue\n",
    "    elseif (j_index > 90050) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_nonominal(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "U_safe = reduce(hcat, U_safe)\n",
    "@show U_safe[:,1]\n",
    "U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "U_nominal = reduce(hcat, U_nominal)\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "Y_nominal = reduce(hcat, Y_nominal)\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "safe_label = reduce(hcat, safe_label)\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "npzwrite(\"hyperbolic_sac_allmixed50_train_nonominal_100_0.1.npy\", Dict(\n",
    "\t\"U_safe\" => U_safe,\n",
    "\t\"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c75e8d62-26c9-4742-b9c2-208043ebdffe",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 20\n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 0.1\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    if (j_index <= 90000)\n",
    "        continue\n",
    "    elseif (j_index > 90030) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_nonominal(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "# Export at most the first 100 filtered trajectories. Every entry of\n",
    "# filter_result_list_sac is the tuple pushed by the loop above:\n",
    "# (filtered U trajectory, x_batch, y_batch, safe_batch).\n",
    "# The count is clamped to what was actually collected: indexing 1:100\n",
    "# unconditionally throws a BoundsError when fewer results exist.\n",
    "if isempty(filter_result_list_sac)\n",
    "    error(\"filter_result_list_sac is empty: nothing to export\")\n",
    "end\n",
    "n_export = min(length(filter_result_list_sac), 100)\n",
    "if n_export < 100\n",
    "    @warn \"Exporting fewer than 100 trajectories\" n_export\n",
    "end\n",
    "exported_results = filter_result_list_sac[1:n_export]\n",
    "\n",
    "# One column per trajectory; batch element 1 of each stored array is used.\n",
    "U_safe = reduce(hcat, [r[1][1,:,1] for r in exported_results])\n",
    "@show U_safe[:,1]\n",
    "U_nominal = reduce(hcat, [r[2][2,:,1] for r in exported_results])\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = reduce(hcat, [r[3][1,:,1] for r in exported_results])\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = reduce(hcat, [r[4][:,1] for r in exported_results])\n",
    "@show safe_label[:,1]\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# NOTE(review): npzwrite is handed a Dict of named arrays while the file name ends\n",
    "# in .npy - confirm the consumer expects an archive of several arrays.\n",
    "npzwrite(\"hyperbolic_sac_allmixed30_train_nonominal_100_0.1.npy\", Dict(\n",
    "    \"U_safe\" => U_safe,\n",
    "    \"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a67b4b25-d1b4-4791-a67f-5a6603df207f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Checkpoint of the pretrained neural operator. Set pretrained_NO to `nothing`\n",
    "# to start from a freshly initialised FourierNeuralOperator instead.\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "\n",
    "# The operator is built exactly once. Previously a randomly initialised FNO was\n",
    "# constructed first and then unconditionally overwritten by this if/else.\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),\n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "\n",
    "# Default (untrained) CBF network: 2 inputs -> 1 scalar output.\n",
    "# It is replaced further down in this cell by a checkpoint loaded with get_model.\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "# NOTE(review): abs_flag is not read by any statement in this cell - presumably\n",
    "# consumed elsewhere as a global; confirm before removing.\n",
    "abs_flag = true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# CBF checkpoint used for the SAC runs.\n",
    "# Original note on this checkpoint: sac max_mpc_step = 1, diff_U_dot_threshold = 20.\n",
    "# NOTE(review): diff_U_dot_threshold is assigned 5 below, not 20 - confirm which is intended.\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF]\n",
    "\n",
    "# true selects the central-difference derivative / reconstruct_traj_central in the\n",
    "# loop below, false the one-sided variant with reconstruct_traj.\n",
    "use_central_flag = true\n",
    "\n",
    "# Forwarded unchanged to the solver call in the loop below.\n",
    "α = 1e-5\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step, opt_alpha, diff_U_dot_threshold = 1, true, 5\n",
    "\n",
    "# Result accumulator and loop counters:\n",
    "#   ind        - incremented after every filtered trajectory (starts at 1)\n",
    "#   j_index    - position in train_loader\n",
    "#   tested_num - number of trajectories handed to the filter\n",
    "filter_result_list_sac = Any[]\n",
    "ind, j_index, tested_num = 1, 0, 0\n",
    "# Run the safety filter on selected trajectories from train_loader.\n",
    "# Each item is read as (x_batch, y_batch, safe_batch) with sizes (2, 51, bs),\n",
    "# (1, 51, bs) and (51, bs) according to the original size note; x_batch[1, :, :]\n",
    "# is used as time and x_batch[2, :, :] as the nominal input U.\n",
    "# Only batch element 1 is filtered, so the loop presumes batchsize == 1.\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # Items 1-90000 are skipped and items 90001-90050 are always processed.\n",
    "    # Afterwards only items whose final safe label differs from 1 are processed,\n",
    "    # until 100 trajectories have been filtered in total.\n",
    "    if j_index <= 90000\n",
    "        continue\n",
    "    elseif j_index > 90050\n",
    "        # Test the quota first so the loader is not scanned past the 100th result.\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "\n",
    "    # Initial time / input / output and final time of batch element 1.\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "\n",
    "    # Nominal dU/dt from finite differences of U over t.\n",
    "    U_dot_nominal_traj = zeros(size(y_batch))\n",
    "    if use_central_flag\n",
    "        # central differences in the interior, forward difference at the first sample\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "    else\n",
    "        # forward differences\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "    end\n",
    "    # backward difference at the last sample (shared by both schemes)\n",
    "    U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    n_time_samples = size(T_traj, 2)\n",
    "    # Reconstruction routine matching the differencing scheme chosen above.\n",
    "    reconstruct_fn = use_central_flag ? reconstruct_traj_central : reconstruct_traj\n",
    "    # CBF value at the initial point; independent of the step index, so computed once.\n",
    "    phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "\n",
    "    for i in 1:(n_time_samples - 1) # 1,2,3,4...50\n",
    "        # Horizon i:mpc_end, clipped at the end of the trajectory.\n",
    "        mpc_end = min(i + max_mpc_step, n_time_samples)\n",
    "\n",
    "        # Trajectory implied by the derivatives fixed so far, pushed through the\n",
    "        # neural operator together with its pullback.\n",
    "        U_pred_traj = reconstruct_fn(U_0, T_traj, U_dot_safe_traj)\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # Evaluate the pullback once and slice both input channels from the result\n",
    "        # (it used to be evaluated separately for G_u and G_t).\n",
    "        no_input_grad = ∇ϕ(ones(size(T_traj)))[1]\n",
    "        G_u = no_input_grad[2, i:mpc_end,1]\n",
    "        G_t = no_input_grad[1, i:mpc_end,1]\n",
    "\n",
    "        # Gradient of the CBF network with respect to its (t, Y) inputs.\n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        # NOTE(review): the seed has the shape of the input (state_dim x N), not of the\n",
    "        # model output; the division by state_dim presumably compensates for the seed\n",
    "        # being summed over its first dimension - confirm.\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "\n",
    "        # NOTE(review): phiY holds the operator prediction Y itself, not model_CBF\n",
    "        # evaluated on it - confirm this is what the solver expects.\n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # A drift guard (accept U_dot_safe only if the reconstructed U stays within\n",
    "        # 0.001 of the nominal U at step i) was tried here and is disabled.\n",
    "        # Only the first step of the horizon is applied (receding horizon).\n",
    "        U_dot_safe_traj[1, i, 1] = U_dot_safe\n",
    "    end\n",
    "\n",
    "    push!(filter_result_list_sac, (reconstruct_fn(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    ind += 1\n",
    "end\n",
    "# Export at most the first 100 filtered trajectories. Every entry of\n",
    "# filter_result_list_sac is the tuple pushed by the loop above:\n",
    "# (filtered U trajectory, x_batch, y_batch, safe_batch).\n",
    "# The count is clamped to what was actually collected: indexing 1:100\n",
    "# unconditionally throws a BoundsError when the loader runs out before\n",
    "# 100 trajectories were filtered.\n",
    "if isempty(filter_result_list_sac)\n",
    "    error(\"filter_result_list_sac is empty: nothing to export\")\n",
    "end\n",
    "n_export = min(length(filter_result_list_sac), 100)\n",
    "if n_export < 100\n",
    "    @warn \"Exporting fewer than 100 trajectories\" n_export\n",
    "end\n",
    "exported_results = filter_result_list_sac[1:n_export]\n",
    "\n",
    "# One column per trajectory; batch element 1 of each stored array is used.\n",
    "U_safe = reduce(hcat, [r[1][1,:,1] for r in exported_results])\n",
    "@show U_safe[:,1]\n",
    "U_nominal = reduce(hcat, [r[2][2,:,1] for r in exported_results])\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = reduce(hcat, [r[3][1,:,1] for r in exported_results])\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = reduce(hcat, [r[4][:,1] for r in exported_results])\n",
    "@show safe_label[:,1]\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# NOTE(review): npzwrite is handed a Dict of named arrays while the file name ends\n",
    "# in .npy - confirm the consumer expects an archive of several arrays.\n",
    "npzwrite(\"hyperbolic_sac_allmixed50_train_nominal_100_5_pf52_addend_1reg_abs.npy\", Dict(\n",
    "    \"U_safe\" => U_safe,\n",
    "    \"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028651b6-58ca-4638-95b7-f5e737bd5fb8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57818ff7-8c46-4478-92c3-34349a4c6296",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36a2dff0-ab9b-4f60-a32f-998af53009a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    solve_multistep_qp_with_cplex_nonominal(U_dot_nominal, G_u, G_t, phi_t, phi_Y, phiY, α, T, phiU_0;\n",
    "                                            opt_alpha=false, diff_U_dot_threshold=Inf)\n",
    "\n",
    "Pick control derivatives over a short horizon with CPLEX and return the value\n",
    "for the first step.\n",
    "\n",
    "Despite the name the program is linear: it minimises `sum(phi_Y .* G_u .* u)`\n",
    "subject to `abs.(u .- U_dot_nominal) .<= diff_U_dot_threshold`. The barrier\n",
    "constraint exists only as a commented-out line.\n",
    "\n",
    "# Arguments\n",
    "- `U_dot_nominal`: nominal control derivatives, one entry per horizon step.\n",
    "- `G_u`, `phi_Y`: per-step factors whose elementwise product is the cost vector.\n",
    "- `G_t`, `phi_t`, `phiY`, `α`, `T`, `phiU_0`: part of the signature but not used by\n",
    "  the active formulation.\n",
    "- `opt_alpha`: when `true` one extra decision variable constrained to be\n",
    "  non-negative is added; it appears in no active cost or coupling constraint.\n",
    "- `diff_U_dot_threshold`: half-width of the box around `U_dot_nominal`.\n",
    "\n",
    "# Returns\n",
    "The first component of the solution. If the solver reports no feasible primal\n",
    "point a warning is logged and `U_dot_nominal[1]` is returned instead of failing\n",
    "inside `value`.\n",
    "\"\"\"\n",
    "function solve_multistep_qp_with_cplex_nonominal(U_dot_nominal,G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=false,diff_U_dot_threshold=Inf)\n",
    "    n_steps = size(U_dot_nominal)[1]\n",
    "\n",
    "    model = Model(CPLEX.Optimizer)\n",
    "    set_silent(model)\n",
    "\n",
    "    # ones(n_steps) keeps the Float64 promotion and the implicit length check of\n",
    "    # the original expression.\n",
    "    obj_weight = ones(n_steps) .* phi_Y .* G_u\n",
    "\n",
    "    # With opt_alpha the last variable is the extra (alpha) slot; the first\n",
    "    # n_steps variables are always the control derivatives.\n",
    "    n_vars = opt_alpha ? n_steps + 1 : n_steps\n",
    "    @variable(model, x[1:n_vars])\n",
    "    u = x[1:n_steps]\n",
    "\n",
    "    @objective(model, Min, u' * obj_weight)\n",
    "    # Box around the nominal derivatives.\n",
    "    @constraint(model, (u .- U_dot_nominal) .<= diff_U_dot_threshold)\n",
    "    @constraint(model, (U_dot_nominal .- u) .<= diff_U_dot_threshold)\n",
    "    if opt_alpha\n",
    "        # Disabled barrier constraint, kept for reference:\n",
    "        # @constraint(model, phi_t .+ phi_Y .* (G_t .+ G_u .* u) .+ x[end] .* phiY .+ (1/T) * phiU_0 .<= 0)\n",
    "        @constraint(model, x[end]  .>= 0)\n",
    "    end\n",
    "\n",
    "    optimize!(model)\n",
    "\n",
    "    # Querying values without a feasible point throws; fall back to the nominal\n",
    "    # derivative so a single failed solve does not abort the whole run.\n",
    "    if primal_status(model) != FEASIBLE_POINT\n",
    "        @warn \"CPLEX returned no feasible point; using the nominal control derivative\" status = termination_status(model)\n",
    "        return U_dot_nominal[1]\n",
    "    end\n",
    "    return value(x[1])\n",
    "end\n",
    "\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "    solve_multistep_qp_with_cplex_nonominal_cbfnoT(U_dot_nominal, G_u, G_t, phi_t, phi_Y, phiY, α, T, phiU_0;\n",
    "                                                   opt_alpha=false, diff_U_dot_threshold=Inf)\n",
    "\n",
    "Same optimisation as `solve_multistep_qp_with_cplex_nonominal`, kept under a\n",
    "separate name for callers of the `cbfnoT` variant.\n",
    "\n",
    "The body used to be a line-for-line copy of that function; it now forwards to it\n",
    "so the two cannot drift apart. Arguments and return value are exactly those of\n",
    "`solve_multistep_qp_with_cplex_nonominal`.\n",
    "\"\"\"\n",
    "function solve_multistep_qp_with_cplex_nonominal_cbfnoT(U_dot_nominal,G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=false,diff_U_dot_threshold=Inf)\n",
    "    return solve_multistep_qp_with_cplex_nonominal(\n",
    "        U_dot_nominal, G_u, G_t, phi_t, phi_Y, phiY, α, T, phiU_0;\n",
    "        opt_alpha = opt_alpha,\n",
    "        diff_U_dot_threshold = diff_U_dot_threshold,\n",
    "    )\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18505cab-39d4-4178-ab16-5a6fb1824788",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Checkpoint of the pretrained neural operator. Set pretrained_NO to `nothing`\n",
    "# to start from a freshly initialised FourierNeuralOperator instead.\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "\n",
    "# The operator is built exactly once. Previously a randomly initialised FNO was\n",
    "# constructed first and then unconditionally overwritten by this if/else.\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),\n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "\n",
    "# Default (untrained) CBF network: 2 inputs -> 1 scalar output.\n",
    "# It is replaced further down in this cell by a checkpoint loaded with get_model.\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "# NOTE(review): abs_flag is not read in the part of this cell that is visible\n",
    "# here - presumably consumed elsewhere as a global; confirm before removing.\n",
    "abs_flag = true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# CBF checkpoint used for the SAC runs.\n",
    "# Original note on this checkpoint: sac max_mpc_step = 1, diff_U_dot_threshold = 20.\n",
    "# NOTE(review): diff_U_dot_threshold is assigned 5 below, not 20 - confirm which is intended.\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF]\n",
    "\n",
    "# true selects the central-difference derivative / reconstruct_traj_central in the\n",
    "# loop below, false the one-sided variant with reconstruct_traj.\n",
    "use_central_flag = true\n",
    "\n",
    "# Forwarded unchanged to the solver call in the loop below.\n",
    "α = 1e-5\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step, opt_alpha, diff_U_dot_threshold = 1, true, 5\n",
    "\n",
    "# Result accumulator and loop counters:\n",
    "#   ind        - starts at 1; shown next to j_index in the loop below\n",
    "#   j_index    - position in train_loader\n",
    "#   tested_num - number of trajectories handed to the filter\n",
    "filter_result_list_sac = Any[]\n",
    "ind, j_index, tested_num = 1, 0, 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    if (j_index <= 90000)\n",
    "        continue\n",
    "    elseif (j_index > 90050) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "U_safe = reduce(hcat, U_safe)\n",
    "@show U_safe[:,1]\n",
    "U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "U_nominal = reduce(hcat, U_nominal)\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "Y_nominal = reduce(hcat, Y_nominal)\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "safe_label = reduce(hcat, safe_label)\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "npzwrite(\"hyperbolic_sac_allmixed50_train_nominal_100_5_hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.npy\", Dict(\n",
    "\t\"U_safe\" => U_safe,\n",
    "\t\"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a80a615-18d5-4383-91cd-c9c642087d2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# # pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# # pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "# train_loader, test_loader = my_get_dataloader()\n",
    "# # @show size(te)\n",
    "# model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "#                               σ = gelu)\n",
    "# if isnothing(pretrained_NO)\n",
    "#     model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "#                               σ = gelu)\n",
    "# else\n",
    "#     model_NO = get_model(pretrained_NO)[:model]\n",
    "# end\n",
    "# model_CBF = Chain(\n",
    "#         Dense(2 => 16, relu),   # activation function inside layer\n",
    "#         Dense(16 => 64, relu),   # activation function inside layer\n",
    "#         Dense(64 => 16, relu),   # activation function inside layer\n",
    "#         Dense(16 => 1)\n",
    "#     )\n",
    "\n",
    "# # abs_flag=false\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# # use_central_flag = true\n",
    "\n",
    "# # # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# # use_central_flag = false\n",
    "\n",
    "# abs_flag=true\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# # \n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# # model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 20\n",
    "\n",
    "# use_central_flag = true\n",
    "\n",
    "# α = 0.00001\n",
    "\n",
    "# # for ppo\n",
    "# # max_mpc_step = 10\n",
    "# # opt_alpha = true\n",
    "# # diff_U_dot_threshold = 5\n",
    "\n",
    "# # for sac\n",
    "# max_mpc_step = 1\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 20\n",
    "\n",
    "\n",
    "# filter_result_list_sac = []\n",
    "# ind = 1\n",
    "# j_index = 0\n",
    "# tested_num = 0\n",
    "# for item in train_loader\n",
    "#     j_index += 1\n",
    "#     # if (j_index <= 1893)\n",
    "#     #     continue\n",
    "#     # elseif (j_index > 1993) \n",
    "#     #     break\n",
    "#     # end\n",
    "#     if (j_index <= 90000)\n",
    "#         continue\n",
    "#     elseif (j_index > 90050) \n",
    "#         if item[3][end, 1] == 1\n",
    "#             continue\n",
    "#         end\n",
    "#         if tested_num == 100\n",
    "#             break\n",
    "#         end\n",
    "#     end\n",
    "#     tested_num += 1\n",
    "#     @show j_index, ind\n",
    "#     x_batch = item[1]\n",
    "#     y_batch = item[2]\n",
    "#     safe_batch = item[3]\n",
    "#     # @show size(safe_batch)\n",
    "#     # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "#     t_0 = x_batch[1,1,1]\n",
    "#     U_0 = x_batch[2,1,1]\n",
    "#     Y_0 = y_batch[1,1,1]\n",
    "#     @assert U_0 == Y_0\n",
    "#     T = x_batch[1,end,1]\n",
    "#     if use_central_flag\n",
    "#         U_dot_nominal_traj = zeros(size(y_batch))\n",
    "#         U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "#         U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "#         U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "#     else\n",
    "#         U_dot_nominal_traj = zeros(size(y_batch))\n",
    "#         U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "#         U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "#     end\n",
    "#     # @show min.(U_dot_nominal_traj)\n",
    "#     # @show U_0\n",
    "#     T_traj = x_batch[1:1,:,:]\n",
    "#     U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "#     # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "#     # break\n",
    "#     for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "#         if i + max_mpc_step  > size(T_traj,2)\n",
    "#             mpc_end = size(T_traj,2)\n",
    "#         else\n",
    "#             mpc_end = i + max_mpc_step\n",
    "#         end\n",
    "#         if use_central_flag\n",
    "#             U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "#         else\n",
    "#             U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "#         end\n",
    "#         # @show U_pred_traj[1,1,1]\n",
    "#         Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "#         Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "#         # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "#         G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "#         G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "#         # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "#         Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "#         Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "#         # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "#         state_dim, batchsize = size(Yt) # 2*51000\n",
    "#         _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "#         ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "#         # @show size(∇ϕ_Y)\n",
    "#         phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "#         phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "#         phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "#         # @show size(phiY)\n",
    "#         phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "#         # @show size(phiU_0)\n",
    "#         U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "#         # try to avoid drift\n",
    "#         # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "#         # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "#         # if use_central_flag\n",
    "#         #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "#         # else\n",
    "#         #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "#         # end\n",
    "        \n",
    "#         # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "#         #     U_dot_safe_traj[i] = U_dot_safe\n",
    "#         # end\n",
    "\n",
    "#         # if i > 40\n",
    "#         #     @show i\n",
    "#         #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "#         #     U_dot_safe_traj[i] = U_dot_safe\n",
    "#         # else\n",
    "#         #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "#         # end\n",
    "#         U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "#     end\n",
    "#     if use_central_flag\n",
    "#         push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "#     else\n",
    "#         push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "#     end\n",
    "#     # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "#     # if ind > 0\n",
    "#     #     @assert 1==2\n",
    "#     # end\n",
    "#     ind += 1\n",
    "# end\n",
    "# U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "# U_safe = reduce(hcat, U_safe)\n",
    "# @show U_safe[:,1]\n",
    "# U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "# U_nominal = reduce(hcat, U_nominal)\n",
    "# @show U_nominal[:,1]\n",
    "# Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "# Y_nominal = reduce(hcat, Y_nominal)\n",
    "# @show Y_nominal[:,1]\n",
    "# safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "# safe_label = reduce(hcat, safe_label)\n",
    "# @show safe_label[:,1]\n",
    "# # @show typeof(U_safe), typeof(U_nominal)\n",
    "# @show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# npzwrite(\"hyperbolic_sac_allmixed50_train_nominal_100_20_hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.npy\", Dict(\n",
    "# \t\"U_safe\" => U_safe,\n",
    "# \t\"U_nominal\" => U_nominal,\n",
    "#     \"Y_nominal\" => Y_nominal,\n",
    "#     \"safe_label\" => safe_label\n",
    "# ))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac91713e-0b86-4c4f-8a2d-21d2f8b4f74b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_C0_20.bson\")[:model_CBF] \n",
    "\n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 0.1\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    if (j_index <= 90000)\n",
    "        continue\n",
    "    elseif (j_index > 90100) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        phiU_0 = ones(size(phiU_0))\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_nonominal(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "U_safe = reduce(hcat, U_safe)\n",
    "@show U_safe[:,1]\n",
    "U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "U_nominal = reduce(hcat, U_nominal)\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "Y_nominal = reduce(hcat, Y_nominal)\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "safe_label = reduce(hcat, safe_label)\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "npzwrite(\"hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_trainC0_C0_20.npy\", Dict(\n",
    "\t\"U_safe\" => U_safe,\n",
    "\t\"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b1e8249-9cae-40cc-ad41-9b445b97d671",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed8faa59-364c-4db4-a911-cf46ba625de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 0.1\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    if (j_index <= 90000)\n",
    "        continue\n",
    "    elseif (j_index > 90100) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_nonominal(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "U_safe = reduce(hcat, U_safe)\n",
    "@show U_safe[:,1]\n",
    "U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "U_nominal = reduce(hcat, U_nominal)\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "Y_nominal = reduce(hcat, Y_nominal)\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "safe_label = reduce(hcat, safe_label)\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "npzwrite(\"hyperbolic_sac_all_train_nonominal_100_0.1_hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\", Dict(\n",
    "\t\"U_safe\" => U_safe,\n",
    "\t\"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "984ec6ea-da3f-4b75-9fa0-ccd669cefc1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "# pretrained_NO=\"hyper_MNO_all.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 1\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    if (j_index <= 90000)\n",
    "        continue\n",
    "    elseif (j_index > 90100) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_nonominal(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "U_safe = reduce(hcat, U_safe)\n",
    "@show U_safe[:,1]\n",
    "U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "U_nominal = reduce(hcat, U_nominal)\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "Y_nominal = reduce(hcat, Y_nominal)\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "safe_label = reduce(hcat, safe_label)\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "npzwrite(\"hyperbolic_sac_all_train_nonominal_100_1_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\", Dict(\n",
    "\t\"U_safe\" => U_safe,\n",
    "\t\"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e34d724e-11be-4ce2-9516-4c36518e2eb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 0.1\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "# Safety-filter loop: for every selected item, start from the nominal input derivative,\n",
    "# re-solve the QP at each time index, and store\n",
    "# (filtered U trajectory, x_batch, y_batch, safe_batch) in filter_result_list_sac.\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    # Selection window: items 1..90000 are skipped and items 90001..90100 are processed.\n",
    "    # Past 90100, items whose last safe label equals 1 are skipped; the first remaining\n",
    "    # item triggers `break` when 100 items have already been processed.\n",
    "    if (j_index <= 90000)\n",
    "        continue\n",
    "    elseif (j_index > 90100) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    # Row 1 of x_batch is the time grid, row 2 the input U; only sample 1 of the batch is used below.\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    # Nominal dU/dt by finite differences: central differences on interior points with a\n",
    "    # forward difference at the first point, or forward differences everywhere.\n",
    "    # Both variants use a backward difference at the last point.\n",
    "    U_dot_nominal_traj = zeros(size(y_batch))\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "    end\n",
    "    U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    n_steps = size(T_traj, 2)\n",
    "    for i in 1:n_steps-1 # 1,2,3,4...50\n",
    "        # End of the look-ahead window, clamped to the last time index.\n",
    "        mpc_end = min(i + max_mpc_step, n_steps)\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # Evaluate the model_NO pullback once and slice both rows from the same result\n",
    "        # (it used to be evaluated twice with an identical cotangent).\n",
    "        grad_NO = ∇ϕ(ones(size(T_traj)))[1]\n",
    "        G_u = grad_NO[2, i:mpc_end,1]\n",
    "        G_t = grad_NO[1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim = size(Yt, 1)\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        # NOTE(review): the cotangent has state_dim rows; the division by state_dim presumably\n",
    "        # compensates for it being summed down to the model_CBF output shape - confirm.\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex_nonominal(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        # Linear index i addresses time step i of sample 1 because the first dimension has size 1.\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "# Stack sample 1 of every stored result column-wise (one column per filtered trajectory).\n",
    "# Iterating over the list itself avoids a BoundsError when fewer than 100 results were collected.\n",
    "@assert !isempty(filter_result_list_sac) \"filter_result_list_sac is empty; nothing to export\"\n",
    "U_safe = reduce(hcat, [res[1][1,:,1] for res in filter_result_list_sac])\n",
    "@show U_safe[:,1]\n",
    "U_nominal = reduce(hcat, [res[2][2,:,1] for res in filter_result_list_sac])\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = reduce(hcat, [res[3][1,:,1] for res in filter_result_list_sac])\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = reduce(hcat, [res[4][:,1] for res in filter_result_list_sac])\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# NOTE(review): npzwrite receives a Dict, which NPZ.jl stores as a multi-array archive, so the\n",
    "# \".npy\" suffix may be misleading - confirm with downstream readers before renaming.\n",
    "npzwrite(\"hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\", Dict(\n",
    "    \"U_safe\" => U_safe,\n",
    "    \"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "749d01be-369f-4210-8327-6b104e6ce5de",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "# Build the neural operator exactly once: a fresh FNO when no checkpoint name is given,\n",
    "# otherwise the model stored under :model in the checkpoint. (An unconditional FNO used to be\n",
    "# constructed first and was always overwritten by one of the two branches.)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "# Default CBF network: 2 inputs -> 1 output.\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "# max_mpc_step = 10\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 5\n",
    "\n",
    "# for sac\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 2\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "# Safety-filter loop: for every selected item, start from the nominal input derivative,\n",
    "# re-solve the QP at each time index, and store\n",
    "# (filtered U trajectory, x_batch, y_batch, safe_batch) in filter_result_list_sac.\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    # Selection window: items 1..90000 are skipped and items 90001..90100 are processed.\n",
    "    # Past 90100, items whose last safe label equals 1 are skipped; the first remaining\n",
    "    # item triggers `break` when 100 items have already been processed.\n",
    "    if (j_index <= 90000)\n",
    "        continue\n",
    "    elseif (j_index > 90100) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    # Row 1 of x_batch is the time grid, row 2 the input U; only sample 1 of the batch is used below.\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    # Nominal dU/dt by finite differences: central differences on interior points with a\n",
    "    # forward difference at the first point, or forward differences everywhere.\n",
    "    # Both variants use a backward difference at the last point.\n",
    "    U_dot_nominal_traj = zeros(size(y_batch))\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "    end\n",
    "    U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    n_steps = size(T_traj, 2)\n",
    "    for i in 1:n_steps-1 # 1,2,3,4...50\n",
    "        # End of the look-ahead window, clamped to the last time index.\n",
    "        mpc_end = min(i + max_mpc_step, n_steps)\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # Evaluate the model_NO pullback once and slice both rows from the same result\n",
    "        # (it used to be evaluated twice with an identical cotangent).\n",
    "        grad_NO = ∇ϕ(ones(size(T_traj)))[1]\n",
    "        G_u = grad_NO[2, i:mpc_end,1]\n",
    "        G_t = grad_NO[1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim = size(Yt, 1)\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        # NOTE(review): the cotangent has state_dim rows; the division by state_dim presumably\n",
    "        # compensates for it being summed down to the model_CBF output shape - confirm.\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        # Linear index i addresses time step i of sample 1 because the first dimension has size 1.\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "# Stack sample 1 of every stored result column-wise (one column per filtered trajectory).\n",
    "# Iterating over the list itself avoids a BoundsError when fewer than 100 results were collected.\n",
    "@assert !isempty(filter_result_list_sac) \"filter_result_list_sac is empty; nothing to export\"\n",
    "U_safe = reduce(hcat, [res[1][1,:,1] for res in filter_result_list_sac])\n",
    "@show U_safe[:,1]\n",
    "U_nominal = reduce(hcat, [res[2][2,:,1] for res in filter_result_list_sac])\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = reduce(hcat, [res[3][1,:,1] for res in filter_result_list_sac])\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = reduce(hcat, [res[4][:,1] for res in filter_result_list_sac])\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# NOTE(review): npzwrite receives a Dict, which NPZ.jl stores as a multi-array archive, so the\n",
    "# \".npy\" suffix may be misleading - confirm with downstream readers before renaming.\n",
    "npzwrite(\"hyperbolic_sac_all_train_nominal_100_2_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\", Dict(\n",
    "    \"U_safe\" => U_safe,\n",
    "    \"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37e4f9f0-94b4-4b00-9fe5-8de7e3eef0aa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "956a8e8d-66cd-4703-86c8-f485cc8c18a1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e61a80f-f7cf-477a-b318-5fa60e98ea39",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "# Build the neural operator exactly once: a fresh FNO when no checkpoint name is given,\n",
    "# otherwise the model stored under :model in the checkpoint. (An unconditional FNO used to be\n",
    "# constructed first and was always overwritten by one of the two branches.)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "# Default CBF network: 2 inputs -> 1 output.\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 2\n",
    "\n",
    "# for sac\n",
    "# max_mpc_step = 1\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 0.1\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    if (j_index <= 45000)\n",
    "        continue\n",
    "    elseif (j_index > 45100) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "U_safe = reduce(hcat, U_safe)\n",
    "@show U_safe[:,1]\n",
    "U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "U_nominal = reduce(hcat, U_nominal)\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "Y_nominal = reduce(hcat, Y_nominal)\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "safe_label = reduce(hcat, safe_label)\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "npzwrite(\"hyperbolic_ppo_all_train_nominal_100_2_hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\", Dict(\n",
    "\t\"U_safe\" => U_safe,\n",
    "\t\"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64cb77f6-241f-40ce-8397-b4acc416c76a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter evaluation for one model_CBF checkpoint.\n",
    "# Runs the step-wise filter (solve_multistep_qp_with_cplex) on the loader samples in\n",
    "# `eval_window` for the checkpoint selected below and exports the filtered and\n",
    "# nominal trajectories with npzwrite.\n",
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# A randomly initialised FNO is only built when no checkpoint is configured.\n",
    "model_NO = if isnothing(pretrained_NO)\n",
    "    FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),\n",
    "                          σ = gelu)\n",
    "else\n",
    "    get_model(pretrained_NO)[:model]\n",
    "end\n",
    "# Default model_CBF architecture; it is replaced by the checkpoint selected below.\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# ---- checkpoint log: exactly one model_CBF line should be active ----\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "#\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependent\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF]\n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF]\n",
    "\n",
    "# ---- filter configuration ----\n",
    "# true: central differences / reconstruct_traj_central; false: forward differences / reconstruct_traj\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 2\n",
    "\n",
    "# for sac\n",
    "# max_mpc_step = 1\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 0.1\n",
    "\n",
    "# 1-based positions in train_loader of the trajectories that are filtered and exported.\n",
    "eval_window = 45001:45100\n",
    "# The same routine rebuilds U from U_0 and the rates, inside the loop and for the final result.\n",
    "reconstruct = use_central_flag ? reconstruct_traj_central : reconstruct_traj\n",
    "\n",
    "# One entry per evaluated trajectory: (filtered U trajectory, x_batch, y_batch, safe_batch).\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    j_index < first(eval_window) && continue\n",
    "    # Every position inside the window is evaluated, so nothing behind it can be\n",
    "    # admitted any more: stop here instead of scanning the rest of the loader.\n",
    "    j_index > last(eval_window) && break\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # shapes noted from an earlier @show: x (2, 51, bs), y (1, 51, bs), safe (51, bs)\n",
    "    # Only the first batch element is filtered below (my_get_dataloader defaults to batchsize = 1).\n",
    "    t_0 = x_batch[1, 1, 1]\n",
    "    U_0 = x_batch[2, 1, 1]\n",
    "    Y_0 = y_batch[1, 1, 1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1, end, 1]\n",
    "\n",
    "    # Nominal rate dU/dt by finite differences of row 2 of x_batch over row 1 (the time grid).\n",
    "    U_dot_nominal_traj = zeros(size(y_batch))\n",
    "    if use_central_flag\n",
    "        # central differences in the interior, forward difference at the first node\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2, 3:end, :] .- x_batch[2, 1:end-2, :]) ./ (x_batch[1, 3:end, :] .- x_batch[1, 1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2, 2, :] .- x_batch[2, 1, :]) ./ (x_batch[1, 2, :] .- x_batch[1, 1, :])\n",
    "    else\n",
    "        # forward differences everywhere except the last node\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2, 2:end, :] .- x_batch[2, 1:end-1, :]) ./ (x_batch[1, 2:end, :] .- x_batch[1, 1:end-1, :])\n",
    "    end\n",
    "    # backward difference at the last node in both modes\n",
    "    U_dot_nominal_traj[1, end, :] = (x_batch[2, end, :] .- x_batch[2, end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "\n",
    "    T_traj = x_batch[1:1, :, :]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    n_steps = size(T_traj, 2)\n",
    "    # Depends only on the initial point, so it is evaluated once per trajectory.\n",
    "    phiU_0 = model_CBF([t_0, U_0])[1]\n",
    "\n",
    "    for i in 1:(n_steps - 1)\n",
    "        mpc_end = min(i + max_mpc_step, n_steps)\n",
    "\n",
    "        # Roll out the current rates and push the trajectory through the neural operator.\n",
    "        U_pred_traj = reconstruct(U_0, T_traj, U_dot_safe_traj)\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, back_NO = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # One backward pass with an all-ones cotangent; row 1 is the time channel, row 2 the U channel.\n",
    "        ∇NO = back_NO(ones(size(T_traj)))[1]\n",
    "        G_u = ∇NO[2, i:mpc_end, 1]\n",
    "        G_t = ∇NO[1, i:mpc_end, 1]\n",
    "\n",
    "        # Stack (t, Y) columns and differentiate model_CBF with respect to them.\n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1)\n",
    "        Yt = reshape(Yt, (size(Yt, 1), size(Yt, 2) * size(Yt, 3)))\n",
    "        state_dim = size(Yt, 1)\n",
    "        _, back_CBF = Zygote.pullback(model_CBF, Yt)\n",
    "        # NOTE(review): the cotangent has the shape of the model_CBF *input* (state_dim x N), not of\n",
    "        # its 1 x N output; the division by state_dim presumably compensates for that - confirm.\n",
    "        ∇ϕ_Y = back_CBF(ones(size(Yt)))[1] ./ state_dim\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "\n",
    "        phiY = Yt_pred_traj[1, i:mpc_end, 1]\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1, i:mpc_end, 1], G_u, G_t, phi_t, phi_Y, phiY, α, T, phiU_0; opt_alpha=opt_alpha, diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # Only the rate at step i is replaced; later steps keep their nominal value until visited.\n",
    "        U_dot_safe_traj[1, i, 1] = U_dot_safe\n",
    "    end\n",
    "\n",
    "    push!(filter_result_list_sac, (reconstruct(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    ind += 1\n",
    "end\n",
    "\n",
    "# ---- export: one column per evaluated trajectory ----\n",
    "@assert !isempty(filter_result_list_sac) \"no trajectory was evaluated; check that train_loader reaches eval_window\"\n",
    "U_safe = reduce(hcat, [res[1][1, :, 1] for res in filter_result_list_sac])\n",
    "@show U_safe[:,1]\n",
    "U_nominal = reduce(hcat, [res[2][2, :, 1] for res in filter_result_list_sac])\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = reduce(hcat, [res[3][1, :, 1] for res in filter_result_list_sac])\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = reduce(hcat, [res[4][:, 1] for res in filter_result_list_sac])\n",
    "@show safe_label[:,1]\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "# NOTE(review): npzwrite is given a Dict, which in NPZ.jl should produce a multi-array (.npz-style)\n",
    "# archive despite the .npy suffix; the file name is kept so existing consumers still find it.\n",
    "npzwrite(\"hyperbolic_ppo_all_train_nominal_100_2_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\", Dict(\n",
    "    \"U_safe\" => U_safe,\n",
    "    \"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04b82690-0475-4877-8138-7fce4b40a7de",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# pretrained_NO=\"hyper_FNO_all_pf.bson\"\n",
    "pretrained_NO=\"hyper_FNO_all.bson\"\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "# @show size(te)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "\n",
    "# abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF] # too minus\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good # opt_alpha=true,diff_U_dot_threshold=5\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# use_central_flag = true\n",
    "\n",
    "# # model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] #OK\n",
    "# use_central_flag = false\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK,  only the last step falls in\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF]  # not feasible\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_1.bson\")[:model_CBF] # not working\n",
    "\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # OK not good enough\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # not working too minus\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_2.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_10.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_5.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1.bson\")[:model_CBF] # sac max_mpc_step = 1,diff_U_dot_threshold = 50\n",
    "\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.bson\")[:model_CBF] \n",
    "\n",
    "use_central_flag = true\n",
    "\n",
    "α = 0.00001\n",
    "\n",
    "# for ppo\n",
    "max_mpc_step = 1\n",
    "opt_alpha = true\n",
    "diff_U_dot_threshold = 2\n",
    "\n",
    "# for sac\n",
    "# max_mpc_step = 1\n",
    "# opt_alpha = true\n",
    "# diff_U_dot_threshold = 0.1\n",
    "\n",
    "\n",
    "filter_result_list_sac = []\n",
    "ind = 1\n",
    "j_index = 0\n",
    "tested_num = 0\n",
    "for item in train_loader\n",
    "    j_index += 1\n",
    "    # if (j_index <= 1893)\n",
    "    #     continue\n",
    "    # elseif (j_index > 1993) \n",
    "    #     break\n",
    "    # end\n",
    "    if (j_index <= 45000)\n",
    "        continue\n",
    "    elseif (j_index > 45100) \n",
    "        if item[3][end, 1] == 1\n",
    "            continue\n",
    "        end\n",
    "        if tested_num == 100\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    tested_num += 1\n",
    "    @show j_index, ind\n",
    "    x_batch = item[1]\n",
    "    y_batch = item[2]\n",
    "    safe_batch = item[3]\n",
    "    # @show size(safe_batch)\n",
    "    # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "    t_0 = x_batch[1,1,1]\n",
    "    U_0 = x_batch[2,1,1]\n",
    "    Y_0 = y_batch[1,1,1]\n",
    "    @assert U_0 == Y_0\n",
    "    T = x_batch[1,end,1]\n",
    "    if use_central_flag\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 2:end-1, :] = (x_batch[2,3:end, :] .- x_batch[2,1:end-2, :]) ./ (x_batch[1,3:end, :] .- x_batch[1,1:end-2, :])\n",
    "        U_dot_nominal_traj[1, 1, :] = (x_batch[2,2, :] .- x_batch[2,1, :]) ./ (x_batch[1,2, :] .- x_batch[1,1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    else\n",
    "        U_dot_nominal_traj = zeros(size(y_batch))\n",
    "        U_dot_nominal_traj[1, 1:end-1, :] = (x_batch[2,2:end, :] .- x_batch[2,1:end-1, :]) ./ (x_batch[1,2:end, :] .- x_batch[1,1:end-1, :])\n",
    "        U_dot_nominal_traj[1, end, :] = (x_batch[2,end, :] .- x_batch[2,end-1, :]) ./ (x_batch[1, end, :] .- x_batch[1, end-1, :])\n",
    "    end\n",
    "    # @show min.(U_dot_nominal_traj)\n",
    "    # @show U_0\n",
    "    T_traj = x_batch[1:1,:,:]\n",
    "    U_dot_safe_traj = copy(U_dot_nominal_traj)\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj) - x_batch[2:2,:,:]\n",
    "    # break\n",
    "    for i in eachindex(x_batch[1,2:end,1]) # 1,2,3,4...50\n",
    "        if i + max_mpc_step  > size(T_traj,2)\n",
    "            mpc_end = size(T_traj,2)\n",
    "        else\n",
    "            mpc_end = i + max_mpc_step\n",
    "        end\n",
    "        if use_central_flag\n",
    "            U_pred_traj = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj)\n",
    "        else\n",
    "            U_pred_traj = reconstruct_traj(U_0, T_traj, U_dot_safe_traj)\n",
    "        end\n",
    "        # @show U_pred_traj[1,1,1]\n",
    "        Ut_pred_traj = cat(T_traj, U_pred_traj, dims=1)\n",
    "        Yt_pred_traj, ∇ϕ = Zygote.pullback(model_NO, Ut_pred_traj)\n",
    "        # @show size(∇ϕ(ones(size(T_traj)))[1][2:2, :,:]), size(∇ϕ(ones(size(T_traj)))[1][1:1, :,:] )\n",
    "        G_u = ∇ϕ(ones(size(T_traj)))[1][2, i:mpc_end,1]\n",
    "        G_t = ∇ϕ(ones(size(T_traj)))[1][1, i:mpc_end,1]\n",
    "        # @show Yt_pred_traj .- model_NO(Ut_pred_traj)\n",
    "        \n",
    "        Yt = cat(T_traj, Yt_pred_traj, dims=1) # NO\n",
    "        Yt = reshape(Yt, (size(Yt)[1], size(Yt)[2]*size(Yt)[3]))\n",
    "        # extended_∇Y_t = cat(ones(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        state_dim, batchsize = size(Yt) # 2*51000\n",
    "        _, ∇ϕ = Zygote.pullback(model_CBF, Yt)\n",
    "        ∇ϕ_Y = ∇ϕ(ones(size(Yt)))[1] ./ state_dim\n",
    "        # @show size(∇ϕ_Y)\n",
    "        phi_t = ∇ϕ_Y[1, i:mpc_end]\n",
    "        phi_Y = ∇ϕ_Y[2, i:mpc_end]\n",
    "        \n",
    "        phiY = Yt_pred_traj[1,i:mpc_end,1]\n",
    "        # @show size(phiY)\n",
    "        phiU_0 = model_CBF([t_0,U_0])[1]\n",
    "        # @show size(phiU_0)\n",
    "        U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0;opt_alpha=opt_alpha,diff_U_dot_threshold=diff_U_dot_threshold)\n",
    "\n",
    "        # try to avoid drift\n",
    "        # U_dot_safe_traj_try = copy(U_dot_safe_traj)\n",
    "        # U_dot_safe_traj_try[i] = U_dot_safe\n",
    "        # if use_central_flag\n",
    "        #     U_pred_traj_try = reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # else\n",
    "        #     U_pred_traj_try = reconstruct_traj(U_0, T_traj, U_dot_safe_traj_try)\n",
    "        # end\n",
    "        \n",
    "        # if abs(U_pred_traj_try[1,i,1] - x_batch[2,i,1]) < 0.001\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # end\n",
    "\n",
    "        # if i > 40\n",
    "        #     @show i\n",
    "        #     U_dot_safe = solve_multistep_qp_with_cplex(U_dot_nominal_traj[1,i:mpc_end,1],G_u,G_t,phi_t,phi_Y,phiY,α,T,phiU_0)\n",
    "        #     U_dot_safe_traj[i] = U_dot_safe\n",
    "        # else\n",
    "        #     @show U_dot_safe_traj - U_dot_nominal_traj\n",
    "        # end\n",
    "        U_dot_safe_traj[i] = U_dot_safe\n",
    "        \n",
    "    end\n",
    "    if use_central_flag\n",
    "        push!(filter_result_list_sac, (reconstruct_traj_central(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    else\n",
    "        push!(filter_result_list_sac, (reconstruct_traj(U_0, T_traj, U_dot_safe_traj), x_batch, y_batch, safe_batch))\n",
    "    end\n",
    "    # @show reconstruct_traj_central(U_0, T_traj, U_dot_nominal_traj) - x_batch[2:2,:,:]\n",
    "    # if ind > 0\n",
    "    #     @assert 1==2\n",
    "    # end\n",
    "    ind += 1\n",
    "end\n",
    "U_safe = [filter_result_list_sac[i][1][1,:,1] for i in 1:100]\n",
    "U_safe = reduce(hcat, U_safe)\n",
    "@show U_safe[:,1]\n",
    "U_nominal = [filter_result_list_sac[i][2][2,:,1] for i in 1:100]\n",
    "U_nominal = reduce(hcat, U_nominal)\n",
    "@show U_nominal[:,1]\n",
    "Y_nominal = [filter_result_list_sac[i][3][1,:,1] for i in 1:100]\n",
    "Y_nominal = reduce(hcat, Y_nominal)\n",
    "@show Y_nominal[:,1]\n",
    "safe_label = [filter_result_list_sac[i][4][:,1] for i in 1:100]\n",
    "safe_label = reduce(hcat, safe_label)\n",
    "@show safe_label[:,1]\n",
    "# @show typeof(U_safe), typeof(U_nominal)\n",
    "@show size(U_safe), size(U_nominal), size(Y_nominal), size(safe_label)\n",
    "npzwrite(\"hyperbolic_ppo_all_train_nominal_100_2_hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\", Dict(\n",
    "\t\"U_safe\" => U_safe,\n",
    "\t\"U_nominal\" => U_nominal,\n",
    "    \"Y_nominal\" => Y_nominal,\n",
    "    \"safe_label\" => safe_label\n",
    "))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "701fd60e-e20e-4897-8205-976df1dc9d17",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0d1d773-1fb3-4e3d-a7d9-0e536e9f3aad",
   "metadata": {},
   "outputs": [],
   "source": [
    "filter_result_list = filter_result_list_ppo\n",
    "item = filter_result_list[1]\n",
    "# @show item[2][1:1,:,:],item[1][1:1,:,:]\n",
    "t = item[2][1,:,1]\n",
    "U = item[2][2,:,1]\n",
    "Y = item[3][1,:,1]\n",
    "U_safe = item[1][1,:,1]\n",
    "Y_NO_safe = model_NO(cat(item[2][1:1,:,:], item[1][1:1,:,:], dims=1))[1,:,1]\n",
    "Y_NO = model_NO(item[2])[1,:,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "613db23c-e715-4a03-825d-0ff344dbdd36",
   "metadata": {},
   "outputs": [],
   "source": [
    "@show item[1][1,:,1]#, U, Y\n",
    "@show U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e592ae-5c5c-42df-b2c9-f43d1dd7584c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the (t, y) background used by the trajectory / CBF plots.\n",
    "# The base rectangle spans t in [0, 5], y in [-1, 10]; the band y in [1, 10] is filled red.\n",
    "# With `abs=true` a second red band y in [-2, -1] is added. Returns the plot object.\n",
    "function plot_env(; abs=false)\n",
    "    env_plot = plot(Hyperrectangle(low=[0,-1], high=[5,10]))\n",
    "    upper_band = Hyperrectangle(low=[0,1], high=[5,10])\n",
    "    plot!(env_plot, upper_band, fillcolor=:red)\n",
    "    # Without the `abs` flag only the upper band is drawn.\n",
    "    abs || return env_plot\n",
    "    lower_band = Hyperrectangle(low=[0,-2], high=[5,-1])\n",
    "    plot!(env_plot, lower_band, fillcolor=:red)\n",
    "    return env_plot\n",
    "end\n",
    "plt1 = plot_env(;abs=abs_flag)\n",
    "plot!(plt1,t,Y_NO, label=\"Y_NO\", xlabel=\"t\", ylabel=\"y\", title=\"Plot of y vs. t\", lw=2)\n",
    "plot!(plt1,t, Y_NO_safe, label=\"Y_NO_safe\", lw=2)   # Add the second line\n",
    "plot!(plt1,t, Y, label=\"Y\", lw=2)   # Add the second line\n",
    "# plot!(plt1,t, U, label=\"U\", lw=2)  # Add the third line\n",
    "# plot!(plt1,t, U_safe, label=\"U_safe\", lw=2)  # Add the third line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c65eb5b-4a00-4af0-a988-6657357dc3c7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f53e9ba-fde1-45bb-9692-bcc64649cafd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "487e7c0d-54ed-41ce-99f2-54212c1c522e",
   "metadata": {},
   "outputs": [],
   "source": [
    "b=rand(2,3,4)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (2,), \n",
    "                                         σ = gelu)\n",
    "loss, nabla = Zygote.pullback(model_NO, b)\n",
    "\n",
    "\n",
    "Y=model_NO(b)\n",
    "Y = vcat(Y[1,:,:]...)\n",
    "Y = reshape(Y, (1, size(Y)[1]))\n",
    "# @show find_derivative_1step(cat(b[1:1,:,:], model_NO(b), dims=1))\n",
    "# @show find_derivative(cat(b[1:1,:,:], model_NO(b), dims=1))\n",
    "@show size(Y), size(nabla(ones(size(Y)))[1])\n",
    "Y_known = zeros(size(Y))\n",
    "Y_known\n",
    "@show nabla(Y_known)[1]\n",
    "\n",
    "@show nabla(ones(size(Y)))[1][2:2, :,:] .* find_derivative_1step(b) .+ nabla(ones(size(Y)))[1][1:1, :,:] \n",
    "# batched_mul(reshape(nabla(ones(size(Y)))[1], (1, size(nabla(ones(size(Y)))[1])...)),  reshape(cat(ones(size(find_derivative(b))),find_derivative(b),dims=1), (size(cat(ones(size(find_derivative(b))),find_derivative(b),dims=1))[1], 1, size(cat(ones(size(find_derivative(b))),find_derivative(b),dims=1))[2:end]...)))[1,:,:,:]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f7b46c9-7cd0-481c-8a7b-38ca5c87b32e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: gradient of a small MLP ϕ w.r.t. its (2, batch) input, via Zygote.\n",
    "ϕ = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "Yt = rand(2,3)\n",
    "@show size(ϕ(Yt))\n",
    "ϕ_Yt, ∇ϕ = Zygote.pullback(ϕ, Yt)\n",
    "# The pullback seed (cotangent) must have the shape of the OUTPUT, here (1, batch).\n",
    "# Seeding with an input-shaped array of ones and dividing by the number of input rows only\n",
    "# gives the same numbers when the extra rows happen to be summed away inside the pullback.\n",
    "∇ϕ_Y = ∇ϕ(ones(size(ϕ_Yt)))[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53940da2-4b03-4a99-a8a2-dd831b4ba707",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "Y_t1 = copy(Yt)\n",
    "for i in 1:2\n",
    "    for j in 1:3\n",
    "        Y_t1[i,j] += 0.001\n",
    "        @show i, j, (ϕ(Y_t1) - ϕ(Yt)) / 0.001\n",
    "        Y_t1[i,j] -= 0.001\n",
    "    end\n",
    "end\n",
    "\n",
    "# (ϕ(Y_t1) - ϕ(Y_t)) /"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94237f15-eafa-4fb5-adf8-af4751860624",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05b754e2-80d3-400a-91e4-b976148ddba4",
   "metadata": {},
   "outputs": [],
   "source": [
    "using JuMP\n",
    "using CPLEX\n",
    "\n",
    "# Define a simple QP problem for CPLEX\n",
    "# Smoke test for the JuMP + CPLEX installation on a one-variable QP:\n",
    "#     minimize (x - 1) * x   subject to   2x <= 1\n",
    "# Prints the solver's solution (analytically x = 0.5) and returns nothing.\n",
    "function solve_qp_with_cplex()\n",
    "    # Define the model using CPLEX\n",
    "    model = Model(CPLEX.Optimizer)\n",
    "    @variable(model, x[1:1])  # a single decision variable, stored as a length-1 vector\n",
    "    @objective(model, Min, (x' .-1) * x )\n",
    "    @constraint(model,2 * x .<= 1)\n",
    "    \n",
    "    # Solve the QP using CPLEX\n",
    "    println(\"Solving with CPLEX...\")\n",
    "    optimize!(model)\n",
    "    \n",
    "    # Get the solution\n",
    "    solution_cplex = value.(x)\n",
    "    println(\"Solution using CPLEX: \", solution_cplex)\n",
    "end\n",
    "\n",
    "solve_qp_with_cplex()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d30238c8-1284-44e1-ad2d-9be72376c4b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "using JuMP\n",
    "using Gurobi\n",
    "\n",
    "# Define a simple QP problem for Gurobi\n",
    "# Smoke test for the JuMP + Gurobi installation on a two-variable QP:\n",
    "#     minimize 0.5 x'Hx + f'x   subject to   Gx <= h,  x >= 0\n",
    "# Prints the solver's solution and returns nothing.\n",
    "function solve_qp_with_gurobi()\n",
    "    # Problem data\n",
    "    H = [2.0 0.0; 0.0 2.0]  # quadratic term\n",
    "    f = [-1.0, -1.0]        # linear term\n",
    "    G = [1.0 2.0; 2.0 1.0]  # inequality matrix\n",
    "    h = [1.0, 1.0]          # inequality right-hand side\n",
    "\n",
    "    qp = Model(Gurobi.Optimizer)\n",
    "    set_optimizer_attribute(qp, \"OutputFlag\", 0)  # Suppress Gurobi output\n",
    "    @variable(qp, x[1:2] >= 0)\n",
    "    @objective(qp, Min, 0.5 * x' * H * x + f' * x)\n",
    "    @constraint(qp, G * x .<= h)\n",
    "\n",
    "    println(\"Solving with Gurobi...\")\n",
    "    optimize!(qp)\n",
    "\n",
    "    solution_gurobi = value.(x)\n",
    "    println(\"Solution using Gurobi: \", solution_gurobi)\n",
    "end\n",
    "\n",
    "solve_qp_with_gurobi()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79f9bcbf-0e9e-47d2-9ae6-bc9f13060e3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "a=[2,5,6]\n",
    "for i in eachindex(a[2:end])\n",
    "    @show i\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a3aef41-5049-48f3-ab87-4776fa0fba2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "α = 0.00001\n",
    "T=51\n",
    "C = (α * ℯ^(-α*T)) / (1-ℯ^(-α*T))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fd4ed65-01ba-4ccb-8064-ac71d7ad91f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Draw the (t, y) background used by the trajectory / CBF plots.\n",
    "# The base rectangle spans t in [0, 5], y in [-1, 10]; the band y in [1, 10] is filled red.\n",
    "# With `abs=true` a second red band y in [-2, -1] is added. Returns the plot object.\n",
    "function plot_env(; abs=false)\n",
    "    env_plot = plot(Hyperrectangle(low=[0,-1], high=[5,10]))\n",
    "    upper_band = Hyperrectangle(low=[0,1], high=[5,10])\n",
    "    plot!(env_plot, upper_band, fillcolor=:red)\n",
    "    # Without the `abs` flag only the upper band is drawn.\n",
    "    abs || return env_plot\n",
    "    lower_band = Hyperrectangle(low=[0,-2], high=[5,-1])\n",
    "    plot!(env_plot, lower_band, fillcolor=:red)\n",
    "    return env_plot\n",
    "end\n",
    "# Evaluate `model` at the point(s) (t, y).\n",
    "# `t` and `y` are stacked into a 2-row array (row 1 = t, row 2 = y, one column per point)\n",
    "# and passed to `model`; the model output is returned unchanged.\n",
    "function Phi(model, t,y)\n",
    "    stacked = reduce(hcat, [t, y])'\n",
    "    return model(stacked)\n",
    "end\n",
    "\n",
    "# Evaluate  ϕ̇ + α·ϕ(t, y) + C·ϕ(t, y₁)  for the network `model`, where\n",
    "#   ϕ̇ = ∂ϕ/∂t + ∂ϕ/∂y · y_t   (gradient obtained with Zygote),\n",
    "#   y₁ is the y value of the first column of the stacked input,\n",
    "#   C  = α e^(-αT) / (1 - e^(-αT)).\n",
    "# Arguments:\n",
    "#   model : callable taking a 2-row array (row 1 = t, row 2 = y, one column per point)\n",
    "#   t, y  : scalars, or vectors of equal length\n",
    "#   y_t   : value used for dy/dt (a scalar, or one value per point)\n",
    "#   α, T  : constants entering the formula above\n",
    "# Returns an array with the same size as `model(input)`.\n",
    "function Phi_dot(model, t,y,y_t; α=0.02,T=51)\n",
    "    input = reduce(hcat, [t, y])'\n",
    "    state_dim, batchsize = size(input)\n",
    "\n",
    "    # Reference input: each column keeps its own t, but y is replaced by the first column's y.\n",
    "    U_0 = copy(input)\n",
    "    U_0[2:2,:] .= input[2:2,1:1]\n",
    "\n",
    "    # Per-point time derivative of (t, y), i.e. the column [1; y_t].\n",
    "    # It is assembled as a (2, batch) matrix so that the reshape to (2, 1, batch) keeps every\n",
    "    # point's pair together; concatenating two length-batch vectors end to end and reshaping\n",
    "    # would mix entries of different points as soon as batch > 1.\n",
    "    ∇Y_t = vcat(ones(1, batchsize), reshape(y_t .* ones(batchsize), 1, batchsize))\n",
    "    ∇Y_t = reshape(∇Y_t, (state_dim, 1, batchsize))\n",
    "\n",
    "    # Gradient of ϕ w.r.t. (t, y). The pullback is seeded with ones shaped like the OUTPUT;\n",
    "    # an input-shaped seed followed by `./ state_dim` only agrees with this when the extra\n",
    "    # rows are summed away inside the pullback.\n",
    "    ϕ_val, ∇ϕ = Zygote.pullback(model, input)\n",
    "    ∇ϕ_Y = reshape(copy(∇ϕ(ones(size(ϕ_val)))[1]), (1, state_dim, batchsize))\n",
    "\n",
    "    # (1, 2, batch) × (2, 1, batch) -> (1, 1, batch): one directional derivative per point.\n",
    "    ϕ̇ = reshape(batched_mul(∇ϕ_Y, ∇Y_t), size(ϕ_val))\n",
    "    C = (α * ℯ^(-α*T)) / (1-ℯ^(-α*T))\n",
    "    return ϕ̇ .+ α .* ϕ_val .+ C .* model(U_0)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0218512-c3bb-4aad-b4d9-9eb113b5619f",
   "metadata": {},
   "outputs": [],
   "source": [
    "abs_flag=false\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_addend_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pfall_unsafe_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_minus_pfall_preNO20_1.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_6.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_8.bson\")[:model_CBF] # OK\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20.bson\")[:model_CBF] # not working\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_1step_5.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_1step_1.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_10.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_1step_5.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_1step_20.bson\")[:model_CBF] # not working\n",
    "\n",
    "abs_flag=true\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # OK\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_3.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_10.bson\")[:model_CBF] # not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pf52_addend_preNO20_abs1_20.bson\")[:model_CBF] # not working\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_1.bson\")[:model_CBF] # good\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # good\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_1.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_1.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_20.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNO_pf52_addminussafe_preNO20_abs1_10.bson\")[:model_CBF] # explode, not working\n",
    "# \n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_10.bson\")[:model_CBF] # explode, not working\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addminussafe_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_abs1_5pf_20.bson\")[:model_CBF]\n",
    "\n",
    "model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_3.bson\")[:model_CBF] #OK\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20.bson\")[:model_CBF] # good\n",
    "# model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_10.bson\")[:model_CBF] # too time-dependant\n",
    "# \n",
    "# \n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_10.bson\")[:model_CBF]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ae6611b-1ffd-4d96-9680-bf141330510d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "t = range(0, 5, length=10)\n",
    "y = range(-2, 10, length=1000)\n",
    "\n",
    "@show Phi(model_CBF,2, 1)[1]\n",
    "# @show Phi(model,1, 1;vx=vx, vy=vy)[1]\n",
    "# @show Phi_dot(model, A, B,1, 1;α=α, vx=vx, vy=vy)[1] \n",
    "\n",
    "Phi_contour(t, y) = Phi(model_CBF,t, y)[1]\n",
    "z1 = @. Phi_contour(t',y)\n",
    "plt1 = plot_env()\n",
    "contour!(plt1,t, y, z1, levels=10, color=:turbo, clabels=true, cbar=false, lw=1)\n",
    "contour!(plt1,t, y, z1, levels=[0], color=:black, clabels=true, cbar=false, lw=1)\n",
    "# plot(plt1)\n",
    "\n",
    "# @show Phi_dot(model, A, B,2, 1;α=α, vx=vx, vy=vy)[1]\n",
    "\n",
    "plt2 = plot_env(;abs=abs_flag)\n",
    "# z .= 1 ./ (z1 .+ 1e-18)\n",
    "heatmap!(plt2,t,y, z1)\n",
    "plot(plt1, plt2, layout = (1, 2), size=(1000,500))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0d5d17d-50e6-450b-a1ad-961ba1959131",
   "metadata": {},
   "outputs": [],
   "source": [
    "t = range(0, 5, length=10)\n",
    "y = range(-2, 10, length=1000)\n",
    "y_t = -1\n",
    "# @show size(x)\n",
    "Phi_dot_contour(t, y) = Phi_dot(model_CBF,t, y, y_t; α=0.00001)[1]\n",
    "z1 = @. Phi_dot_contour(t',y)\n",
    "\n",
    "plt1 = plot_env(;abs=abs_flag)\n",
    "contour!(plt1,t, y, z1, levels=10, color=:turbo, clabels=true, cbar=false, lw=1)\n",
    "contour!(plt1,t, y, z1, levels=[0], color=:black, clabels=true, cbar=false, lw=1)\n",
    "# plot(plt1)\n",
    "\n",
    "# @show Phi_dot(model, A, B,2, 1;α=α, vx=vx, vy=vy)[1]\n",
    "\n",
    "plt2 = plot_env()\n",
    "# z .= 1 ./ (z1 .+ 1e-18)\n",
    "heatmap!(plt2,t,y, z1)\n",
    "plot(plt1, plt2, layout = (1, 2), size=(1000,500))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8cd458a-3e94-4416-b30c-00b4c541a9ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# a=[0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9; 4.847742 5.3145413 3.9163132 1.4816904 0.041546613 1.2740334 4.6864047 7.2294483 5.85791 2.6027904]\n",
    "model_CBF([3, 1])# t,Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7167b328-e807-465e-bcfe-5badac018bfa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3b00453-3e2b-465c-beb3-dc9bfef5c5a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function my_train(; )\n",
    "cuda = true\n",
    "η₀ = 1.0f-3\n",
    "λ = 1.0f-4\n",
    "total_epoch = 20\n",
    "pretrained_NO=\"hyper_NO_20.bson\"\n",
    "if cuda && CUDA.has_cuda()\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "@show 1\n",
    "lr_NO = η₀\n",
    "lr_CBF = 0.001 # NO CBF\n",
    "lr_CBF = 0.01\n",
    "\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch =4\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "model_CBF = Chain(\n",
    "        Dense(2 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 64, relu),   # activation function inside layer\n",
    "        Dense(64 => 16, relu),   # activation function inside layer\n",
    "        Dense(16 => 1)\n",
    "    )\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_time_CBF_preNO20_20.bson\")[:model_CBF]\n",
    "# model_CBF = get_model(\"hyper_1reg_1pf_timenograd_CBF_pfall_addend_preNO20_1.bson\")[:model_CBF]\n",
    "\n",
    "# model_CBF = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (2,), \n",
    "#                               σ = gelu)\n",
    "# optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "# optim_CBF = Flux.setup(Flux.Optimise.AdamW(lr_CBF, (0.9, 0.999), 0), model_CBF) # NO\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch)) # setup schedule of your choice\n",
    "\n",
    "\n",
    "loss_func = l₂loss\n",
    "α = 0.00001\n",
    "λ_pf = 1\n",
    "λ_reg = 0.1\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "least_loss = 1000\n",
    "test_loss = 0\n",
    "loss = 0\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "    # One optimisation step per training batch: optionally fit the neural operator,\n",
    "    # then fit the CBF network on the combined loss and log that loss.\n",
    "    for item in train_loader\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            # train NO\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m \n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        # train CBF\n",
    "        # Flatten the operator output and the safety labels over (grid, batch).\n",
    "        x = copy(y_batch)\n",
    "        y_init = copy(safe_batch)\n",
    "        x = vcat(x[1,:,:]...)\n",
    "        x = reshape(x, (1, size(x)[1]))\n",
    "        y_init = vcat(y_init...)\n",
    "\n",
    "        # U_0: row 2 of x_batch at the first grid point, repeated along the grid, then flattened.\n",
    "        U_0 = copy(x_batch)\n",
    "        U_0[2:2,:,:] .= x_batch[2:2,1:1,:]\n",
    "        U_0 = vcat(U_0[2:2,:,:][1,:,:]...)\n",
    "        U_0 = reshape(U_0, (1, size(U_0)[1]))\n",
    "        U̇ = find_derivative(x_batch)\n",
    "        extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)\n",
    "        T = x_batch[1,end,1]\n",
    "        _, ∇ϕ = Zygote.pullback(model_NO, x_batch)\n",
    "        # chain rule: dG/du * du/dt + dG/dt\n",
    "        # Run the backward pass through the neural operator ONCE and reuse the result;\n",
    "        # calling the pullback inside both `reshape` arguments would back-propagate twice.\n",
    "        ∇NO_x = ∇ϕ(ones(size(y_batch)))[1]\n",
    "        ∇Y_t = batched_mul(reshape(∇NO_x, (1, size(∇NO_x)...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))\n",
    "        # ∇Y_t = find_derivative(cat(x_batch[1:1,:,:], model_NO(x_batch), dims=1)) # empirical derivative\n",
    "\n",
    "        # Stack time with the operator output / U_0 and flatten (grid, batch) into one axis.\n",
    "        yt = cat(x_batch[1:1,:,:], y_batch, dims=1) # NO\n",
    "        ytt = reshape(yt, (size(yt)[1], size(yt)[2]*size(yt)[3]))\n",
    "        extended_∇Y_t = cat(zeros(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        U_0t = cat(x_batch[1:1,:,:], reshape(U_0, size(y_batch)), dims=1) # NO\n",
    "        U_0tt = reshape(U_0t, (size(U_0t)[1], size(U_0t)[2]*size(U_0t)[3]))\n",
    "        extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],1, size(extended_∇Y_t)[2:end]...))\n",
    "        extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],size(extended_∇Y_t)[2], size(extended_∇Y_t)[3]*size(extended_∇Y_t)[4]))\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m \n",
    "            loss_naive_safeset(m, ytt, y_init)  +  λ_reg .* loss_regularization(m, ytt, y_init) + λ_pf .* loss_pf(m, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init) + loss_naive_safeset_end(m, ytt, y_init)  +  λ_reg .* loss_regularization_end(m, ytt, y_init)\n",
    "        end\n",
    "        \n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "\n",
    "        # Loss after the parameter update (this is the value that gets logged).\n",
    "        loss = loss_naive_safeset(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization(model_CBF, ytt, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init) + loss_naive_safeset_end(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization_end(model_CBF, ytt, y_init)\n",
    "        @show loss_naive_safeset(model_CBF, ytt, y_init), loss_regularization(model_CBF, ytt, y_init), loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init), loss_naive_safeset_end(model_CBF, ytt, y_init), loss_regularization_end(model_CBF, ytt, y_init)\n",
    "        push!(training_loss_epoch, loss)  # logging, outside gradient context\n",
    "    end\n",
    "    # Evaluate the combined CBF loss on every test batch (no parameter update).\n",
    "    for item in test_loader\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        # Flatten the operator output and the safety labels over (grid, batch).\n",
    "        x = copy(y_batch)\n",
    "        y_init = copy(safe_batch)\n",
    "        x = vcat(x[1,:,:]...)\n",
    "        x = reshape(x, (1, size(x)[1]))\n",
    "        y_init = vcat(y_init...)\n",
    "\n",
    "        # U_0: row 2 of x_batch at the first grid point, repeated along the grid, then flattened.\n",
    "        U_0 = copy(x_batch)\n",
    "        U_0[2:2,:,:] .= x_batch[2:2,1:1,:]\n",
    "        U_0 = vcat(U_0[2:2,:,:][1,:,:]...)\n",
    "        U_0 = reshape(U_0, (1, size(U_0)[1]))\n",
    "        U̇ = find_derivative(x_batch)\n",
    "        extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)\n",
    "        T = x_batch[1,end,1]\n",
    "        _, ∇ϕ = Zygote.pullback(model_NO, x_batch)\n",
    "        # chain rule: dG/du * du/dt + dG/dt\n",
    "        # Run the backward pass through the neural operator ONCE and reuse the result;\n",
    "        # calling the pullback inside both `reshape` arguments would back-propagate twice.\n",
    "        ∇NO_x = ∇ϕ(ones(size(y_batch)))[1]\n",
    "        ∇Y_t = batched_mul(reshape(∇NO_x, (1, size(∇NO_x)...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))\n",
    "        # ∇Y_t = find_derivative(cat(U[1:1,:,:], model_NO(U), dims=1)) # empirical derivative\n",
    "\n",
    "        # Stack time with the operator output / U_0 and flatten (grid, batch) into one axis.\n",
    "        yt = cat(x_batch[1:1,:,:], y_batch, dims=1) # NO\n",
    "        ytt = reshape(yt, (size(yt)[1], size(yt)[2]*size(yt)[3]))\n",
    "        extended_∇Y_t = cat(zeros(size(∇Y_t[1,:,:,:])),∇Y_t[1,:,:,:],dims=1) # NO\n",
    "        U_0t = cat(x_batch[1:1,:,:], reshape(U_0, size(y_batch)), dims=1) # NO\n",
    "        U_0tt = reshape(U_0t, (size(U_0t)[1], size(U_0t)[2]*size(U_0t)[3]))\n",
    "        extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],1, size(extended_∇Y_t)[2:end]...))\n",
    "        extended_∇Y_tt = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],size(extended_∇Y_t)[2], size(extended_∇Y_t)[3]*size(extended_∇Y_t)[4]))\n",
    "        \n",
    "        loss = loss_naive_safeset(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization(model_CBF, ytt, y_init) + λ_pf .* loss_pf(model_CBF, x_batch, ytt, U_0tt,extended_U̇, extended_∇Y_tt,T, α,y_init) + loss_naive_safeset_end(model_CBF, ytt, y_init)  +  λ_reg .* loss_regularization_end(model_CBF, ytt, y_init)\n",
    "        push!(test_loss_epoch, loss)  # logging, outside gradient context\n",
    "    end\n",
    "    nextlr = ParameterSchedulers.next!(sched_CBF) # advance schedule\n",
    "    Optimisers.adjust!(optim_CBF, nextlr) # update optimizer state, by default this changes the learning rate `eta`\n",
    "\n",
    "    # @show epoch, loss, test_loss\n",
    "    # model_state = Flux.state(model)\n",
    "    # jldsave(\"car_wd0.0001_naive_model_1_0_0.1_pgd_relu_$epoch.jld2\"; model_state)\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "    end\n",
    "    @save \"model/hyper_0.1reg_1pf_timenograd_CBF_pfall_addend_preNO20_$epoch.bson\" model_CBF\n",
    "    push!(training_losses, sum(training_loss_epoch) ./ 45000) \n",
    "    push!(test_losses, sum(test_loss_epoch) ./ 5000)\n",
    "\n",
    "end\n",
    "# return training_losses, test_losses\n",
    "\n",
    "\n",
    "# learner = Learner(model, data, optimiser, loss_func,\n",
    "#                   ToDevice(device, device))\n",
    "\n",
    "# fit!(learner, epochs)\n",
    "# model = learner.model |> cpu\n",
    "# @save \"model/hyper_FNO_all_pf.bson\" model\n",
    "\n",
    "# return learner\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d9d3584-8cd7-40da-80a4-7d86208b20c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the per-epoch loss histories that the training loop pushes into these two vectors.\n",
    "@show training_losses, test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6498963e-5f33-4669-9657-8120ec568592",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55a69db5-6650-4687-b8a3-be7a9ec5218a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3525e54b-66eb-4315-8fcd-517cfb794e87",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3668d390-0e86-46d7-a168-032979de81c4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65c84c0b-4e82-4e67-b826-d0e59b59fc9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload a saved CBF checkpoint via `get_model` and take its :model_CBF entry.\n",
    "# NOTE(review): this file name does not match the pattern written by the training cell visible\n",
    "# in this notebook; presumably it comes from an earlier run. Confirm the file exists.\n",
    "model_CBF = get_model(\"hyper_0.1reg_1pf_time_CBF_pfall_0addend_preNO20_1.bson\")[:model_CBF]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92b6bdc2-dcf9-4ec5-b627-4dbe00ce4f87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the loaded CBF at one point; the two entries are ordered (t, Y).\n",
    "model_CBF([3, 2])# t,Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0717b6c2-717e-4395-a4fa-bef9f35eda06",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Sanity check: compare a finite-difference time derivative of the operator output\n",
    "# with the chain-rule value dG/dt + dG/du * du/dt obtained from the pullback.\n",
    "b = rand(2, 3000, 4)\n",
    "model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (2,), \n",
    "                                         σ = gelu)\n",
    "# A single pullback call returns the forward output together with its backward map.\n",
    "Y_pred, nabla = Zygote.pullback(model_NO, b)\n",
    "# Seed the backward map with ones shaped like the unflattened output and reuse the result.\n",
    "∇b = nabla(ones(size(Y_pred)))[1]\n",
    "db_dt = find_derivative(b)\n",
    "extended_db_dt = cat(ones(size(db_dt)), db_dt, dims=1)\n",
    "chain_rule = batched_mul(reshape(∇b, (1, size(∇b)...)), reshape(extended_db_dt, (size(extended_db_dt)[1], 1, size(extended_db_dt)[2:end]...)))\n",
    "# Flattened (1, N) copy of the prediction, kept under the name `Y` as before.\n",
    "Y = reshape(Y_pred[1, :, :], 1, :)\n",
    "@show find_derivative(cat(b[1:1,:,:], Y_pred, dims=1))[1]\n",
    "# `[1]` picks the first entry; an empty index `[]` is only valid on one-element arrays.\n",
    "@show chain_rule[1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "239e7a51-a982-4c68-8296-ae086ca6d504",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb254334-6bad-47d6-aa3e-524737e5daf5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d6b8d0a-0274-433d-98d7-cdb3adeca313",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "809c59c2-65a2-4935-ae0b-da3425616c87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the CBF on a one-element input.\n",
    "# NOTE(review): another cell calls model_CBF with a two-element (t, Y) input; which call is\n",
    "# valid depends on which checkpoint is currently bound to `model_CBF`. Confirm.\n",
    "model_CBF([1.6])\n",
    "# model_CBF([0.99])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "906899e1-3114-4743-8985-dacdb533f5bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time derivative of ϕ along a trajectory via the chain rule: ϕ̇ = ∇ϕ(Y) * Ẏ.\n",
    "#\n",
    "# Arguments\n",
    "#   ϕ     callable differentiated with Zygote\n",
    "#   Y     states; a leading singleton axis is prepended before ϕ is called\n",
    "#   ∇Y_t  time derivative of Y, holding as many elements as Y\n",
    "#\n",
    "# Returns an array shaped like ϕ(Y).\n",
    "function find_phi_dot(ϕ, Y, ∇Y_t)\n",
    "    Y = reshape(Y, (1, size(Y)...))\n",
    "    state_dim, batchsize = size(Y)\n",
    "\n",
    "    # The pullback already yields ϕ(Y), so no second forward pass is needed for the output shape.\n",
    "    ϕ_Y, ∇ϕ = Zygote.pullback(ϕ, Y)\n",
    "    ∇ϕ_Y = ∇ϕ(ones(size(Y)))[1] ./ state_dim\n",
    "\n",
    "    # Lay both factors out as (rows, cols, batch) so batched_mul contracts the state axis.\n",
    "    ∇ϕ_Y = reshape(∇ϕ_Y, (1, state_dim, batchsize))\n",
    "    ∇Y_t = reshape(∇Y_t, (state_dim, 1, batchsize))\n",
    "\n",
    "    return reshape(batched_mul(∇ϕ_Y, ∇Y_t), size(ϕ_Y))\n",
    "end\n",
    "    \n",
    "# ∇Y_t = find_derivative(cat(U[1:1,:,:], model_NO(U), dims=1)) # empirical derivative\n",
    "# "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b81f709-75f6-4ccf-833c-048498d6b09d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: hcat applied to a single vector returns it as a 5x1 matrix.\n",
    "hcat([1,1.5,1.5, 0.9, 0.9])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "086225b2-dbad-45da-9a83-b84c7040e4e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test of find_phi_dot: one state value (0.7) with time derivative 0.5.\n",
    "find_phi_dot(model_CBF, [0.7], [0.5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc4f9875-63e4-4ea0-8dc1-04f80533692f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aef814d5-1e69-4f04-b7be-704f5fdb0025",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hinge-style safe-set loss for the operator-shaped CBF.\n",
    "#\n",
    "# Labels are mapped to signs with 2y - 1 (1 -> +1, 0 -> -1), multiplied with ϕ(x),\n",
    "# shifted by 1e-6 and clipped with relu. The summed penalty is divided by the size\n",
    "# of the last axis of the loss array.\n",
    "function loss_naive_safeset_NO(ϕ, x,y_init)\n",
    "    signs = 2 .* y_init .- 1\n",
    "    scores = reshape(ϕ(x), size(y_init)) # NO\n",
    "    margin = signs .* scores .+ 1e-6\n",
    "    penalty = relu(margin)\n",
    "    return sum(penalty) / last(size(penalty))\n",
    "end\n",
    "\n",
    "# Sigmoid regulariser for the operator-shaped CBF.\n",
    "#\n",
    "# Uses the same label-to-sign mapping (2y - 1) as the safe-set loss, but passes the\n",
    "# signed prediction through sigmoid_fast instead of a hinge. The sum is divided by\n",
    "# the size of the last axis of the loss array.\n",
    "function loss_regularization_NO(ϕ, x::AbstractArray,y_init::AbstractArray)\n",
    "    signs = 2 .* y_init .- 1\n",
    "    scores = reshape(ϕ(x), size(y_init)) # NO\n",
    "    penalty = sigmoid_fast(signs .* scores)\n",
    "    return sum(penalty) / last(size(penalty))\n",
    "end\n",
    "\n",
    "# Forward-invariance style loss for the operator-shaped CBF ϕ.\n",
    "#\n",
    "# As written, the function evaluates ϕ(yt), pulls back a ones seed shaped like that\n",
    "# output, and returns the plain sum of the resulting input gradient. Everything after\n",
    "# the early `return` is unreachable.\n",
    "#\n",
    "# NOTE(review): the early return looks like a debugging short-circuit (presumably to test\n",
    "# differentiating through the nested pullback). While it is in place, `extended_∇Y_t`,\n",
    "# `U_0t`, `T`, `α` and `y_init` are ignored. Confirm whether it should be removed.\n",
    "function loss_pf_NO(ϕ, yt, extended_∇Y_t, U_0t,T, α,y_init)\n",
    "    # U = copy(U_)\n",
    "    # Y = copy(Y_)\n",
    "    # Y = vcat(Y[1,:,:]...)\n",
    "    # Y = reshape(Y, (1, size(Y)[1]))\n",
    "    # y_init = vcat(y_init[1,:,:]...)\n",
    "    \n",
    "    # @show size(ϕ(x)), size(x)\n",
    "    # index = findall(x1->[1 -1] * softmax(ϕ(x1))>ϵ, x)\n",
    "    # ∇Y_t = reshape(∇Y_t, size(Y))\n",
    "    pred = ϕ(yt)\n",
    "    _, ∇ϕ = Zygote.pullback(ϕ, yt)\n",
    "    # dG\\du * du\\dt + dG\\dt\n",
    "    # ∇Y_t = ∇ϕ(ones(size(Y)))[1][2:2, :,:] .* U̇ .+ ∇ϕ(ones(size(Y)))[1][1:1, :,:] # rewrite it below\n",
    "    # Gradient of sum(ϕ(yt)) with respect to yt (ones seed shaped like the output).\n",
    "    ∇ϕ_Y = ∇ϕ(ones(size(pred)))[1]\n",
    "    # Early exit: the code below this line never runs.\n",
    "    return sum(∇ϕ_Y)\n",
    "    # Unreachable: contract the input gradient with the trajectory derivative to get ϕ̇.\n",
    "    ∇ϕ_Y = reshape(∇ϕ_Y, (1, size(∇ϕ_Y)...))\n",
    "    # extended_∇Y_t = reshape(extended_∇Y_t, (NO_input_dim, size(pred)...))\n",
    "    @show size(∇ϕ_Y), size(extended_∇Y_t)\n",
    "    ϕ̇ = batched_mul(∇ϕ_Y,  extended_∇Y_t)\n",
    "    ϕ̇ = reshape(ϕ̇, (size(y_init)...))\n",
    "    @show size(ϕ̇)\n",
    "    # pred = ϕ(yt)\n",
    "    pred = reshape(pred, (size(y_init)...))\n",
    "\n",
    "    # Unreachable: ϕ evaluated on the U_0t input.\n",
    "    pred0 = ϕ(U_0t)\n",
    "    pred0 = reshape(pred0, (size(y_init)...))\n",
    "    \n",
    "    # Unreachable: weight on the pred0 term, α e^(-αT) / (1 - e^(-αT)).\n",
    "    C = (α * ℯ^(-α*T)) / (1-ℯ^(-α*T))\n",
    "    # @show C, C .* ϕ(U_0)\n",
    "    # Unreachable: hinge on ϕ̇ + α ϕ + C ϕ₀, normalised by the last axis of the loss array.\n",
    "    l = ϕ̇ .+ α .* pred .+ C .* pred0 # \n",
    "    loss = relu(l .+ 1e-6)\n",
    "    # @show loss\n",
    "    return sum(loss) / size(loss)[end]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b88604e-07ab-4139-91a9-44ee3e3858ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function my_train(; )\n",
    "# --- Hyper-parameters and run configuration ---\n",
    "cuda = true\n",
    "η₀ = 1.0f-3\n",
    "λ = 1.0f-4\n",
    "total_epoch = 20\n",
    "# Set to `nothing` to train the neural operator from scratch instead of loading it.\n",
    "pretrained_NO=\"hyper_NO_20.bson\"\n",
    "# NOTE(review): `device` is chosen here, but nothing in this cell moves the models or the\n",
    "# batches onto it, so the log line below may not reflect where the computation runs.\n",
    "if cuda && CUDA.has_cuda()\n",
    "    device = gpu\n",
    "    CUDA.allowscalar(false)\n",
    "    @info \"Training on GPU\"\n",
    "else\n",
    "    device = cpu\n",
    "    @info \"Training on CPU\"\n",
    "end\n",
    "lr_NO = η₀\n",
    "lr_CBF = 0.001 # NO CBF\n",
    "\n",
    "lr_decay_rate = 0.2\n",
    "lr_decay_epoch =4\n",
    "\n",
    "train_loader, test_loader = my_get_dataloader()\n",
    "\n",
    "# --- Models ---\n",
    "# Build the operator only when no checkpoint is given; otherwise load it. (A fresh\n",
    "# operator used to be constructed unconditionally first and then always overwritten.)\n",
    "if isnothing(pretrained_NO)\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                              σ = gelu)\n",
    "else\n",
    "    model_NO = get_model(pretrained_NO)[:model_NO]\n",
    "end\n",
    "# model_CBF = Chain(\n",
    "#         Dense(2 => 16, relu),   # activation function inside layer\n",
    "#         Dense(16 => 64, relu),   # activation function inside layer\n",
    "#         Dense(64 => 16, relu),   # activation function inside layer\n",
    "#         Dense(16 => 1)\n",
    "#     )\n",
    "# The CBF itself is a Fourier neural operator in this experiment.\n",
    "model_CBF = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (2,), \n",
    "                              σ = gelu)\n",
    "\n",
    "# --- Optimisers and learning-rate schedule ---\n",
    "# optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))\n",
    "optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "# optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "optim_CBF = Flux.setup(Flux.Optimise.AdamW(lr_CBF, (0.9, 0.999), 0), model_CBF) # NO\n",
    "sched_CBF = ParameterSchedulers.Stateful(Step(lr_CBF, lr_decay_rate, lr_decay_epoch)) # setup schedule of your choice\n",
    "\n",
    "\n",
    "# --- Loss weights and bookkeeping ---\n",
    "loss_func = l₂loss\n",
    "α = 0.00001\n",
    "λ_pf = 0.1\n",
    "λ_reg = 1\n",
    "training_losses = []\n",
    "test_losses = []\n",
    "no_training_losses = []\n",
    "no_test_losses = []\n",
    "least_loss = 1000\n",
    "test_loss = 0\n",
    "loss = 0\n",
    "# Build every tensor the CBF losses need from one (x_batch, y_batch, safe_batch) item.\n",
    "# Shared by the training and test loops so both feed the losses identically shaped inputs.\n",
    "#\n",
    "# Returns (yt, y_init, extended_∇Y_t, U_0t, T):\n",
    "#   yt             row 1 of x_batch stacked on top of y_batch\n",
    "#   y_init         safe_batch flattened column-major into a vector\n",
    "#   extended_∇Y_t  [1; chain-rule derivative of the operator output], with a singleton\n",
    "#                  second axis so it can be fed to batched_mul\n",
    "#   U_0t           row 1 of x_batch stacked on row 2 of x_batch frozen at its first grid point\n",
    "#   T              x_batch[1, end, 1]\n",
    "function prepare_cbf_batch(model_NO, x_batch, y_batch, safe_batch)\n",
    "    y_init = vec(copy(safe_batch))\n",
    "\n",
    "    # Row 2 at the first grid point, repeated along the grid axis.\n",
    "    U_0 = reshape(repeat(x_batch[2:2, 1:1, :], 1, size(x_batch, 2), 1), size(y_batch))\n",
    "    U_0t = cat(x_batch[1:1,:,:], U_0, dims=1) # NO\n",
    "\n",
    "    U̇ = find_derivative(x_batch)\n",
    "    extended_U̇ = cat(ones(size(U̇)),U̇,dims=1)\n",
    "    T = x_batch[1,end,1]\n",
    "\n",
    "    # dG/du * du/dt + dG/dt, using a single backward pass through the operator.\n",
    "    _, ∇ϕ = Zygote.pullback(model_NO, x_batch)\n",
    "    ∇NO = ∇ϕ(ones(size(y_batch)))[1]\n",
    "    ∇Y_t = batched_mul(reshape(∇NO, (1, size(∇NO)...)),  reshape(extended_U̇, (size(extended_U̇)[1], 1, size(extended_U̇)[2:end]...)))\n",
    "    ∇Y_t = ∇Y_t[1,:,:,:]\n",
    "\n",
    "    yt = cat(x_batch[1:1,:,:], y_batch, dims=1) # NO\n",
    "    extended_∇Y_t = cat(ones(size(∇Y_t)),∇Y_t,dims=1) # NO\n",
    "    extended_∇Y_t = reshape(extended_∇Y_t, (size(extended_∇Y_t)[1],1, size(extended_∇Y_t)[2:end]...))\n",
    "\n",
    "    return yt, y_init, extended_∇Y_t, U_0t, T\n",
    "end\n",
    "\n",
    "# Checkpoints are written into this directory at the end of every epoch.\n",
    "mkpath(\"model\")\n",
    "\n",
    "for epoch in ProgressBar(1:total_epoch)\n",
    "    training_loss_epoch = []\n",
    "    test_loss_epoch = []\n",
    "    no_training_loss_epoch = []\n",
    "    no_test_loss_epoch = []\n",
    "    for item in train_loader\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            # train NO\n",
    "            NO_training_loss, NO_grads = Flux.withgradient(model_NO) do m \n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "            push!(no_training_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        # train CBF\n",
    "        yt, y_init, extended_∇Y_t, U_0t, T = prepare_cbf_batch(model_NO, x_batch, y_batch, safe_batch)\n",
    "        CBF_training_loss, CBF_grads = Flux.withgradient(model_CBF) do m \n",
    "            loss_naive_safeset_NO(m, yt, y_init)  +  λ_reg .* loss_regularization_NO(m, yt, y_init) + λ_pf .* loss_pf_NO(m, yt, extended_∇Y_t,U_0t,T, α,y_init) # NO\n",
    "        end\n",
    "\n",
    "        Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "        @show CBF_training_loss\n",
    "        # Log the loss of this batch. (The stale global `loss` was being logged here before.)\n",
    "        push!(training_loss_epoch, CBF_training_loss)  # logging, outside gradient context\n",
    "    end\n",
    "    for item in test_loader\n",
    "        x_batch = item[1]\n",
    "        y_batch = item[2]\n",
    "        safe_batch = item[3]\n",
    "        # @show size(x_batch), size(y_batch), size(safe_batch) # (2,51, bs), (1, 51, bs), (51, bs)\n",
    "\n",
    "        if isnothing(pretrained_NO)\n",
    "            push!(no_test_loss_epoch, l₂loss(model_NO(x_batch), y_batch))\n",
    "        end\n",
    "\n",
    "        yt, y_init, extended_∇Y_t, U_0t, T = prepare_cbf_batch(model_NO, x_batch, y_batch, safe_batch)\n",
    "\n",
    "        # Evaluate each term once and reuse it for both the total and the printout.\n",
    "        safeset_term = loss_naive_safeset_NO(model_CBF, yt, y_init)\n",
    "        reg_term = loss_regularization_NO(model_CBF, yt, y_init)\n",
    "        pf_term = loss_pf_NO(model_CBF, yt, extended_∇Y_t,U_0t,T, α,y_init)\n",
    "        loss = safeset_term  +  λ_reg .* reg_term + λ_pf .* pf_term\n",
    "        @show safeset_term, reg_term, pf_term\n",
    "        push!(test_loss_epoch, loss)  # logging, outside gradient context\n",
    "    end\n",
    "    # nextlr = ParameterSchedulers.next!(sched_CBF) # advance schedule\n",
    "    # Optimisers.adjust!(optim_CBF, nextlr) # update optimizer state, by default this changes the learning rate `eta`\n",
    "\n",
    "    if isnothing(pretrained_NO)\n",
    "        @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "    end\n",
    "    @save \"model/hyper_1reg_1pf_NOCBF_preNO20_$epoch.bson\" model_CBF\n",
    "    push!(training_losses, sum(training_loss_epoch)) \n",
    "    push!(test_losses, sum(test_loss_epoch))\n",
    "\n",
    "end\n",
    "# return training_losses, test_losses\n",
    "\n",
    "\n",
    "# learner = Learner(model, data, optimiser, loss_func,\n",
    "#                   ToDevice(device, device))\n",
    "\n",
    "# fit!(learner, epochs)\n",
    "# model = learner.model |> cpu\n",
    "# @save \"model/hyper_FNO_all_pf.bson\" model\n",
    "\n",
    "# return learner\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5180ba38-433e-4e8b-be7f-d71a63afad92",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Ground-truth curves for samples 1 through 4, arranged on a 2x2 grid.\n",
    "# Overlaying model predictions on the panels is currently disabled.\n",
    "truth_panel(sample; kwargs...) = plot(input_data[1, :, 1], ground_truth[1, :, sample]; label = \"ground_truth\", kwargs...)\n",
    "\n",
    "i = 1\n",
    "p1 = truth_panel(i; title = \"                                              Burgers equation u(x,T_end)\");\n",
    "p2 = truth_panel(i + 1);\n",
    "\n",
    "i = 3\n",
    "p3 = truth_panel(i);\n",
    "p4 = truth_panel(i + 1);\n",
    "\n",
    "p = plot(p1, p2, p3, p4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd29d7d7-e1e6-4be3-90a0-36115a91fc69",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Burgers\n",
    "using FluxTraining\n",
    "using Test\n",
    "\n",
    "# Package smoke test: checks the sizes of the arrays returned by `Burgers.get_data`\n",
    "# and that 100 epochs of `Burgers.train` end with a validation loss below 0.1.\n",
    "@testset \"Burgers\" begin\n",
    "    inputs, targets = Burgers.get_data(n = 1000)\n",
    "\n",
    "    @test size(inputs) == (2, 1024, 1000)\n",
    "    @test size(targets) == (1, 1024, 1000)\n",
    "\n",
    "    learner = Burgers.train(epochs = 100)\n",
    "    validation_losses = learner.cbstate.metricsepoch[ValidationPhase()][:Loss].values\n",
    "    @test validation_losses[end] < 0.1\n",
    "\n",
    "    # include(\"deeponet.jl\")\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "928686e9-3a7c-45ee-803e-37d7e80eea7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the `learner` object.\n",
    "# NOTE(review): the only assignment to `learner` visible nearby is inside the @testset body,\n",
    "# which presumably has its own scope. Confirm a global `learner` exists on a fresh kernel.\n",
    "learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "998eae74-ea3f-4a18-8daf-6995b7cfe5e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy check: the value returned by `Burgers.train_nomad` after 100 epochs must stay below 0.4.\n",
    "@testset \"Burger: NOMAD Training Accuracy\" begin\n",
    "    nomad_error = Burgers.train_nomad(; cuda = true, epochs = 100)\n",
    "    @test nomad_error < 0.4 # epoch=100 returns 0.233\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9217960-00f6-4f0a-95c3-b071e04129fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the package's `train_don` entry point with its default arguments.\n",
    "Burgers.train_don()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c6dce6-f0bd-4999-914a-ac3c394c0334",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.9.4",
   "language": "julia",
   "name": "julia-1.9"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
