{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6da8ab6d-b602-496c-968e-948d3d5dbcdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import pde_control_gym\n",
    "import numpy as np\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import stable_baselines3\n",
    "import time\n",
    "from utils import set_size\n",
    "from utils import linestyle_tuple\n",
    "from utils import load_csv\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3 import SAC\n",
    "from stable_baselines3.common.env_checker import check_env\n",
    "from stable_baselines3.common.callbacks import CheckpointCallback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "307a8313-126e-479f-b94c-69dbf2188e70",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Report library versions so results can be tied to a specific environment\n",
    "for label, module in ((\"Gym version\", gym),\n",
    "                      (\"Numpy version\", np),\n",
    "                      (\"Stable Baselines3 version\", stable_baselines3)):\n",
    "    print(label, module.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec9efd4f-b0e3-4c70-9c15-6e677066c310",
   "metadata": {},
   "source": [
    "This Jupyter notebook has an accompanying tutorial at https://pdecontrolgym.readthedocs.io/en/latest/guide/tutorials.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32177259-8979-4957-baf5-e8f15c9c563e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sensing-noise model with no noise\n",
    "def noiseFunc(state):\n",
    "    \"\"\"Identity noise model: the observation is the state itself, unchanged.\"\"\"\n",
    "    return state\n",
    "\n",
    "# Recirculation coefficient beta(x) = 5 * cos(gamma * arccos(x)): the Chebyshev form\n",
    "# 5 * T_gamma(x), evaluated pointwise (gamma need not be an integer)\n",
    "def solveBetaFunction(x, gamma):\n",
    "    \"\"\"Evaluate beta at every point of x and return a float32 array of the same length.\n",
    "\n",
    "    math.acos is used, so each point must lie in [-1, 1] (ValueError otherwise).\n",
    "    \"\"\"\n",
    "    return np.array([5*math.cos(gamma*math.acos(point)) for point in x], dtype=np.float32)\n",
    "\n",
    "# Kernel function solver for backstepping\n",
    "def solveKernelFunction(theta, step=None):\n",
    "    \"\"\"Solve the discretized backstepping kernel equation for the recirculation profile theta.\n",
    "\n",
    "    kappa is filled front to back with\n",
    "        kappa[i] = step * sum_{j < i} kappa[i - j] * theta[j] - theta[i]\n",
    "    and returned reversed (np.flip).\n",
    "\n",
    "    Args:\n",
    "        theta: 1-D sequence of recirculation values on the spatial grid.\n",
    "        step: quadrature step. Defaults to the notebook-level global `dx`, which is what this\n",
    "            function always used before, so existing one-argument calls are unchanged.\n",
    "\n",
    "    Returns:\n",
    "        float64 array of length len(theta): the kernel, reversed.\n",
    "    \"\"\"\n",
    "    if step is None:\n",
    "        step = dx\n",
    "    theta = np.asarray(theta, dtype=np.float64)\n",
    "    kappa = np.zeros(len(theta))\n",
    "    for i in range(len(theta)):\n",
    "        # kappa[i:0:-1] = kappa[i], kappa[i-1], ..., kappa[1] is paired with theta[0], ..., theta[i-1].\n",
    "        # NOTE(review): kappa[i] is still 0 when paired with theta[0], so that term never\n",
    "        # contributes; confirm this quadrature offset is intended.\n",
    "        kappa[i] = np.dot(kappa[i:0:-1], theta[:i])*step - theta[i]\n",
    "    return np.flip(kappa)\n",
    "\n",
    "# Control convolution solver\n",
    "def solveControl(kernel, u, step=None):\n",
    "    \"\"\"Approximate the integral of kernel * u over the grid with a rectangle rule.\n",
    "\n",
    "    Args:\n",
    "        kernel: 1-D kernel values; only the first len(u) entries are used.\n",
    "        u: 1-D state samples.\n",
    "        step: quadrature weight. Defaults to the notebook-level global `dx` (1e-2 in this\n",
    "            notebook), which the previously hard-coded 1e-2 duplicated; tying them keeps this\n",
    "            consistent with solveKernelFunction if dx is changed.\n",
    "\n",
    "    Returns:\n",
    "        The scalar control value.\n",
    "    \"\"\"\n",
    "    if step is None:\n",
    "        step = dx\n",
    "    u = np.asarray(u)\n",
    "    kernel = np.asarray(kernel)\n",
    "    return np.dot(kernel[:len(u)], u)*step\n",
    "\n",
    "# Initial condition applied at every environment reset\n",
    "def getInitialCondition(nx):\n",
    "    \"\"\"Constant profile of length nx whose level is drawn uniformly from [1, 10).\"\"\"\n",
    "    level = np.random.uniform(1, 10)\n",
    "    return np.full(nx, level)\n",
    "\n",
    "# Recirculation profile handed to the PDE environment on reset. gamma is fixed at 7.35\n",
    "# here; change it to pose other problems.\n",
    "def getBetaFunction(nx):\n",
    "    \"\"\"Beta evaluated on nx evenly spaced points of [0, 1] with gamma = 7.35.\"\"\"\n",
    "    grid = np.linspace(0, 1, nx)\n",
    "    # NOTE(review): this grid includes x = 0, while the backstepping controller's beta is built\n",
    "    # on np.linspace(dx, X, ...); confirm the two grids are meant to differ.\n",
    "    return solveBetaFunction(grid, 7.35)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da1f129a-e88d-48d1-96ba-da0c85abda38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discretization of the PDE solver\n",
    "T = 5      # final simulation time\n",
    "dt = 1e-4  # time step\n",
    "X = 1      # spatial domain length\n",
    "dx = 1e-2  # spatial step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ac9fc8b-2b8f-4e1b-a46b-64222b995b65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment settings shared by all controllers. Backstepping does not need its actions\n",
    "# normalized to [-1, 1] (normalize=False); the RL agents use normalize=True. Everything else is identical.\n",
    "from pde_control_gym.src import TunedReward1D,NormReward\n",
    "reward_class =  TunedReward1D(int(round(T/dt)), -1e3, 3e2) # with penalize\n",
    "# Variant without the penalty term:\n",
    "# reward_class =  TunedReward1D(int(round(T/dt)), -1e-4, 1e2)\n",
    "hyperbolicParameters = dict(\n",
    "    T=T,\n",
    "    dt=dt,\n",
    "    X=X,\n",
    "    dx=dx,\n",
    "    reward_class=reward_class,\n",
    "    normalize=None,  # overridden per controller below\n",
    "    sensing_loc=\"full\",\n",
    "    # NOTE(review): value spelled as in the original; confirm what the environment expects before changing it\n",
    "    control_type=\"Dirchilet\",\n",
    "    sensing_type=None,\n",
    "    sensing_noise_func=lambda state: state,\n",
    "    limit_pde_state_size=True,\n",
    "    max_state_value=1e10,\n",
    "    max_control_value=20,\n",
    "    reset_init_condition_func=getInitialCondition,\n",
    "    reset_recirculation_func=getBetaFunction,\n",
    "    control_sample_rate=0.1,\n",
    ")\n",
    "\n",
    "# Per-controller copies that differ only in the normalize flag\n",
    "hyperbolicParametersBackstepping = {**hyperbolicParameters, \"normalize\": False}\n",
    "hyperbolicParametersRL = {**hyperbolicParameters, \"normalize\": True}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8f53c44-2173-4da4-bbad-387f37f640e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One transport-PDE environment per controller family (they differ only in the normalize flag)\n",
    "envRL, envBcks = (gym.make(\"PDEControlGym-TransportPDE1D\", **params)\n",
    "                  for params in (hyperbolicParametersRL, hyperbolicParametersBackstepping))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf23d021-501d-4ca7-a507-8c1c2d93aa1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save a checkpoint every 10000 steps.\n",
    "# NOTE(review): the save directories are named *_high_reso and earlier comments said\n",
    "# dt = 1e-5, control_sample_rate = 0.001, but this notebook configures dt = 1e-4 and\n",
    "# control_sample_rate = 0.1; confirm which settings the saved checkpoints were trained with.\n",
    "checkpoint_callbackPPO = CheckpointCallback(\n",
    "    save_freq=10000,\n",
    "    save_path=\"./logsPPO_high_reso\",\n",
    "    name_prefix=\"rl_model\",\n",
    "    save_replay_buffer=False,\n",
    "    save_vecnormalize=False,\n",
    ")\n",
    "\n",
    "checkpoint_callbackSAC = CheckpointCallback(\n",
    "    save_freq=10000,\n",
    "    save_path=\"./logsSAC_high_reso\",\n",
    "    name_prefix=\"rl_model\",\n",
    "    save_replay_buffer=False,\n",
    "    save_vecnormalize=False,\n",
    ")\n",
    "\n",
    "# TRAINING is expensive. Set RUN_TRAINING = False to skip it and use preloaded models.\n",
    "RUN_TRAINING = True\n",
    "PPO_TIMESTEPS = 1_000_000\n",
    "SAC_TIMESTEPS = 500_000\n",
    "\n",
    "if RUN_TRAINING:\n",
    "    # Train PPO for 1M timesteps\n",
    "    model = PPO(\"MlpPolicy\", envRL, verbose=1, tensorboard_log=\"./tb/\")\n",
    "    model.learn(total_timesteps=PPO_TIMESTEPS, callback=checkpoint_callbackPPO)\n",
    "\n",
    "    # Train SAC for 500k timesteps\n",
    "    model = SAC(\"MlpPolicy\", envRL, verbose=1, tensorboard_log=\"./tb/\")\n",
    "    model.learn(total_timesteps=SAC_TIMESTEPS, callback=checkpoint_callbackSAC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cad8d344-d209-49c7-a629-dcb72a0d1815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize Rewards\n",
    "\n",
    "# In TensorBoard, export each run's average-reward plot as a CSV and list the paths here.\n",
    "# Set your tensorboard avg_rew paths. WILL NEED UPDATING FOR USE\n",
    "filenamesPPO = [\"PPOData/test1.csv\", \"PPOData/test2.csv\", \"PPOData/test3.csv\", \"PPOData/test4.csv\", \"PPOData/test5.csv\"]\n",
    "filenamesSAC = [\"SACData/SAC_18.csv\", \"SACData/SAC_19.csv\", \"SACData/SAC_20.csv\", \"SACData/SAC_21.csv\", \"SACData/SAC_23.csv\"]\n",
    "\n",
    "def loadRuns(filenames):\n",
    "    \"\"\"Load every CSV with utils.load_csv; return (list of time series, list of reward series).\"\"\"\n",
    "    times, rewards = [], []\n",
    "    for name in filenames:\n",
    "        t, r = load_csv(name)\n",
    "        times.append(t)\n",
    "        rewards.append(r)\n",
    "    return times, rewards\n",
    "\n",
    "def firstIndexAtOrAfter(times, horizon):\n",
    "    \"\"\"Index of the first entry of times that is >= horizon (len(times) if there is none).\"\"\"\n",
    "    return next((i for i, value in enumerate(times) if value >= horizon), len(times))\n",
    "\n",
    "def plotMeanBand(ax, x, mean, std, label):\n",
    "    \"\"\"Plot the mean curve plus a shaded mean +/- 2 std band across runs.\"\"\"\n",
    "    ax.plot(x, mean, label=label)\n",
    "    ax.fill_between(x, mean - 2*std, mean + 2*std, alpha=0.2)\n",
    "\n",
    "timePPOArr, rewardPPOArr = loadRuns(filenamesPPO)\n",
    "timeSACArr, rewardSACArr = loadRuns(filenamesSAC)\n",
    "\n",
    "# Common horizon: the smallest final timestep over all runs, so every run has data up to it\n",
    "maxTimestep = min((data[-1] for data in timePPOArr + timeSACArr), default=np.inf)\n",
    "print(maxTimestep)\n",
    "\n",
    "# Truncate every run at the common horizon\n",
    "maxDataSeqPPO = [firstIndexAtOrAfter(data, maxTimestep) for data in timePPOArr]\n",
    "maxDataSeqSAC = [firstIndexAtOrAfter(data, maxTimestep) for data in timeSACArr]\n",
    "cutPPO = min(maxDataSeqPPO)\n",
    "cutSAC = min(maxDataSeqSAC)\n",
    "\n",
    "# Mean and std across runs at each logged step\n",
    "rewardArrPPO = np.array([data[:cutPPO] for data in rewardPPOArr])\n",
    "meanArrPPO = rewardArrPPO.mean(axis=0)\n",
    "stdArrPPO = rewardArrPPO.std(axis=0)\n",
    "\n",
    "rewardArrSAC = np.array([data[:cutSAC] for data in rewardSACArr])\n",
    "meanArrSAC = rewardArrSAC.mean(axis=0)\n",
    "stdArrSAC = rewardArrSAC.std(axis=0)\n",
    "\n",
    "# Set size according to latex textwidth\n",
    "fig = plt.figure(figsize=set_size(432, 0.99, (1, 1), height_add=0))\n",
    "ax = fig.subplots(ncols=1)\n",
    "# Cut the x-axis with the same index as the rewards so the lengths always match (the first run's\n",
    "# own cut can differ from the minimum). The band is mean +/- 2 std across runs, roughly 95% of\n",
    "# runs if rewards are approximately normal.\n",
    "plotMeanBand(ax, timePPOArr[0][:cutPPO], meanArrPPO, stdArrPPO, \"PPO\")\n",
    "plotMeanBand(ax, timeSACArr[0][:cutSAC], meanArrSAC, stdArrSAC, \"SAC\")\n",
    "\n",
    "plt.legend()\n",
    "plt.title(\"Training Reward for Hyperbolic PDE\")\n",
    "plt.xlabel(\"Episode Number\")\n",
    "plt.ylabel(\"Average Reward\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "648bd8f0-7f37-404d-bb6c-f857f0fd3802",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Paths to the trained models. WILL NEED UPDATING FOR USE. They default to the last checkpoints\n",
    "# the training cell writes (CheckpointCallback saves <save_path>/<name_prefix>_<num_timesteps>_steps.zip);\n",
    "# confirm the files exist.\n",
    "ppoModelPath = \"./logsPPO_high_reso/rl_model_1000000_steps.zip\"\n",
    "sacModelPath = \"./logsSAC_high_reso/rl_model_500000_steps.zip\"\n",
    "\n",
    "ppoModel = PPO.load(ppoModelPath)\n",
    "sacModel = SAC.load(sacModelPath)\n",
    "\n",
    "# For backstepping controller: beta sampled on the grid dx, 2*dx, ..., X\n",
    "spatial = np.linspace(dx, X, int(round(X/dx)))\n",
    "beta = solveBetaFunction(spatial, 7.35)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c93c80f-723e-4fc0-98a5-deeb0db023ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run one episode and record every observation.\n",
    "# `model` is a controller callable model(obs, parameter); `parameter` is what that controller needs:\n",
    "# the trained model for SAC/PPO, the beta function for backstepping.\n",
    "def runSingleEpisode(model, env, parameter):\n",
    "    \"\"\"Roll out env until it terminates or truncates.\n",
    "\n",
    "    Returns:\n",
    "        (summed reward, np.array of the reset observation followed by each step's observation)\n",
    "    \"\"\"\n",
    "    observation, _info = env.reset()\n",
    "    history = [observation]\n",
    "    totalReward = 0\n",
    "    finished = False\n",
    "    while not finished:\n",
    "        action = model(observation, parameter)\n",
    "        observation, reward, terminated, truncated, _info = env.step(action)\n",
    "        history.append(observation)\n",
    "        totalReward += reward\n",
    "        finished = terminated or truncated\n",
    "    return totalReward, np.array(history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49278e04-7771-4a5c-9d75-0ca82a512b78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Controllers\n",
    "# Kernels already solved by bcksController, keyed on beta's contents and the global dx read by\n",
    "# solveKernelFunction. runSingleEpisode passes the same beta at every step, so this avoids\n",
    "# re-solving the O(n^2) kernel equation at every control step.\n",
    "_bcksKernelCache = {}\n",
    "\n",
    "def bcksController(obs, beta):\n",
    "    \"\"\"Backstepping feedback: solveControl(kernel derived from beta, obs), with the kernel memoized.\"\"\"\n",
    "    betaArr = np.asarray(beta)\n",
    "    key = (betaArr.dtype.str, betaArr.shape, betaArr.tobytes(), dx)\n",
    "    kernel = _bcksKernelCache.get(key)\n",
    "    if kernel is None:\n",
    "        kernel = solveKernelFunction(beta)\n",
    "        _bcksKernelCache[key] = kernel\n",
    "    return solveControl(kernel, obs)\n",
    "\n",
    "def RLController(obs, model, deterministic=False):\n",
    "    \"\"\"Return the action a Stable-Baselines3 model predicts for obs.\n",
    "\n",
    "    deterministic is forwarded to model.predict and defaults to False (stochastic actions),\n",
    "    matching the previous behavior. NOTE(review): SB3 evaluation typically uses\n",
    "    deterministic=True; confirm which one the reported rewards should use.\n",
    "    \"\"\"\n",
    "    action, _state = model.predict(obs, deterministic=deterministic)\n",
    "    return action\n",
    "\n",
    "def openLoopController(_, _a):\n",
    "    \"\"\"Open-loop baseline: always apply zero control, ignoring the observation and parameter.\"\"\"\n",
    "    zeroControl = 0\n",
    "    return zeroControl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47c16b5c-42bb-409b-ace7-46bf5e4731f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run comparisons\n",
    "\n",
    "# Fresh environments with the penalized reward (same arguments as the training reward).\n",
    "# For the unpenalized test variant use TunedReward1D(int(round(T/dt)), -1e-4, 3e2) in both lines.\n",
    "hyperbolicParametersRL[\"reward_class\"] = TunedReward1D(int(round(T/dt)), -1e3, 3e2)\n",
    "hyperbolicParametersBackstepping[\"reward_class\"] = TunedReward1D(int(round(T/dt)), -1e3, 3e2)\n",
    "\n",
    "envRL, envBcks = (gym.make(\"PDEControlGym-TransportPDE1D\", **params)\n",
    "                  for params in (hyperbolicParametersRL, hyperbolicParametersBackstepping))\n",
    "\n",
    "num_instances = 50\n",
    "# Backstepping. Controller is slow so this will take some time.\n",
    "total_bcks_reward = sum(runSingleEpisode(bcksController, envBcks, beta)[0]\n",
    "                        for _ in range(num_instances))\n",
    "print(\"Backstepping Reward Average:\", total_bcks_reward/num_instances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a6fba2-974b-4867-9f35-e51ff86f33c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PPO: average episode reward over num_instances rollouts\n",
    "total_ppo_reward = sum(runSingleEpisode(RLController, envRL, ppoModel)[0]\n",
    "                       for _ in range(num_instances))\n",
    "print(\"PPO Reward Average:\", total_ppo_reward/num_instances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9512075b-3361-4a79-8096-da1a1fa8d153",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAC: prints each episode's reward, then the average\n",
    "total_sac_reward = 0\n",
    "for _ in range(num_instances):\n",
    "    episodeReward, _states = runSingleEpisode(RLController, envRL, sacModel)\n",
    "    print(episodeReward)\n",
    "    total_sac_reward += episodeReward\n",
    "print(\"SAC Reward Average:\", total_sac_reward/num_instances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39b8970f-92cc-4b57-8e29-c7e30b790c95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the amplitude formula used by getInitialConditionRandom in the next cell:\n",
    "# 1 + 9 * U[0, 1) lies in [1, 10).\n",
    "sampleAmplitude = 1 + np.random.rand() * 9\n",
    "assert 1 <= sampleAmplitude < 10\n",
    "sampleAmplitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02a9473d-b1c4-49f8-a805-e697760e0e7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# collect dataset\n",
    "# For each episode and controller (backstepping, PPO, SAC), keep the last column of the state\n",
    "# history (saved under \"a\") and the first column (saved under \"u\").\n",
    "import scipy.io  # savemat lives in scipy.io, which a bare `import scipy` may not load\n",
    "from tqdm import tqdm\n",
    "\n",
    "# The full dataset uses 50000 episodes per controller. The loop previously stopped after one\n",
    "# episode via a leftover `break`; NUM_DATASET_EPISODES = 1 keeps that quick default.\n",
    "NUM_DATASET_EPISODES = 1\n",
    "# Set to True to write the collected trajectories to .mat files.\n",
    "SAVE_DATASET = False\n",
    "\n",
    "def getInitialConditionRandom(nx):\n",
    "    \"\"\"Constant profile of length nx with level 1 + 9 * U[0, 1), i.e. in [1, 10).\"\"\"\n",
    "    return np.ones(nx) * (1 + np.random.rand() * 9)\n",
    "\n",
    "hyperbolicParametersBacksteppingRandom = {**hyperbolicParametersBackstepping,\n",
    "                                          \"reset_init_condition_func\": getInitialConditionRandom}\n",
    "hyperbolicParametersRLRandom = {**hyperbolicParametersRL,\n",
    "                                \"reset_init_condition_func\": getInitialConditionRandom}\n",
    "\n",
    "# Make environments\n",
    "envBcksRandom = gym.make(\"PDEControlGym-TransportPDE1D\", **hyperbolicParametersBacksteppingRandom)\n",
    "envRLRandom = gym.make(\"PDEControlGym-TransportPDE1D\", **hyperbolicParametersRLRandom)\n",
    "\n",
    "xs_bcks, ys_bcks = [], []\n",
    "xs_ppo, ys_ppo = [], []\n",
    "xs_sac, ys_sac = [], []\n",
    "for i in tqdm(range(NUM_DATASET_EPISODES), desc=\"collecting\"):\n",
    "    rewBcksRandom, uBcksRandom = runSingleEpisode(bcksController, envBcksRandom, beta)\n",
    "    xs_bcks.append(uBcksRandom[:, -1])\n",
    "    ys_bcks.append(uBcksRandom[:, 0])\n",
    "\n",
    "    rewPPORandom, uPPORandom = runSingleEpisode(RLController, envRLRandom, ppoModel)\n",
    "    xs_ppo.append(uPPORandom[:, -1])\n",
    "    ys_ppo.append(uPPORandom[:, 0])\n",
    "\n",
    "    rewSACRandom, uSACRandom = runSingleEpisode(RLController, envRLRandom, sacModel)\n",
    "    xs_sac.append(uSACRandom[:, -1])\n",
    "    ys_sac.append(uSACRandom[:, 0])\n",
    "\n",
    "if SAVE_DATASET:\n",
    "    for tag, xs, ys in ((\"bcks\", xs_bcks, ys_bcks), (\"ppo\", xs_ppo, ys_ppo), (\"sac\", xs_sac, ys_sac)):\n",
    "        scipy.io.savemat(f\"data_{tag}_hyperbolic.mat\", {\"a\": np.stack(xs), \"u\": np.stack(ys)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3ee5be3-bbf2-48d8-a559-605e11cd91f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the last PPO state history: one row per observation (reset + each step); the dataset input is its last column\n",
    "print(np.shape(uPPORandom))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12b0b47e-dae4-4af7-8895-2654b7e8632c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (number of collected backstepping episodes, length of each stored column)\n",
    "np.shape(np.stack(xs_bcks))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91612c05-8e83-4a5a-8d6a-ba3189f516d2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24841a76-17b3-45e5-9df9-65cf572ff4d9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ca9f7b8-75f8-4011-9c2f-210daffc1b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rewards of the stored nominal (before QP) and safe (after QP) control sequences, each run by\n",
    "# runSingleEpisodeQP with QP_filter_Controller. NOTE(review): those two are not defined earlier in\n",
    "# this notebook; run their defining cell first.\n",
    "\n",
    "# Unpenalized reward. Defined here too because the cell that originally defines it comes later,\n",
    "# which made this cell fail on Restart & Run All.\n",
    "reward_class_no_penalty = TunedReward1D(int(round(T/dt)), -1e-4, 3e2)  # no penalize\n",
    "\n",
    "def makeFixedInitEnv(initialValue):\n",
    "    \"\"\"Backstepping env with the unpenalized reward whose reset state is constant at initialValue.\"\"\"\n",
    "    def getInitialConditionFixed(nx):\n",
    "        return np.ones(nx) * initialValue\n",
    "    params = {**hyperbolicParametersBackstepping,\n",
    "              \"reset_init_condition_func\": getInitialConditionFixed,\n",
    "              \"reward_class\": reward_class_no_penalty}\n",
    "    return gym.make(\"PDEControlGym-TransportPDE1D\", **params)\n",
    "\n",
    "RL_1000 = np.load(\"../../../verify-pde-control/hyperbolic_sac_unsafe_nonominal_100_0.5.npy\")\n",
    "RL_reward_beforeQP = []\n",
    "RL_reward_afterQP = []\n",
    "uBcks_beforeQP, uBcks_afterQP = 0, 0\n",
    "for i in range(RL_1000[\"safe_label\"].transpose().shape[0]):\n",
    "    U_list = RL_1000[\"U_nominal\"][:, i]\n",
    "    U_safe_list = RL_1000[\"U_safe\"][:, i]\n",
    "    print(U_list[0])\n",
    "    # Both runs start from the nominal trajectory's initial value U_list[0].\n",
    "    reward_beforeQP, uBcks_beforeQP = runSingleEpisodeQP(QP_filter_Controller, makeFixedInitEnv(U_list[0]), U_list)\n",
    "    RL_reward_beforeQP.append(reward_beforeQP)\n",
    "    reward_afterQP, uBcks_afterQP = runSingleEpisodeQP(QP_filter_Controller, makeFixedInitEnv(U_list[0]), U_safe_list)\n",
    "    RL_reward_afterQP.append(reward_afterQP)\n",
    "\n",
    "result = np.array([RL_reward_beforeQP, RL_reward_afterQP])\n",
    "print(np.mean(result, axis=1))\n",
    "print(np.std(result, axis=1))\n",
    "\n",
    "# Recorded results (first row: mean, second row: std; columns: before QP, after QP)\n",
    "# hyperbolic_ppo_all_nonominal_100_0.5.npy\n",
    "# [157.89960567 182.75855243]\n",
    "# [37.46392797 57.35514341]\n",
    "\n",
    "# hyperbolic_sac_all_nonominal_100_2.npy\n",
    "# [107.87008005  -3.27868209]\n",
    "# [96.64765878  3.04253213]\n",
    "\n",
    "# hyperbolic_sac_all_nonominal_100_0.5.npy\n",
    "# [107.87008005  -3.25188291]\n",
    "# [96.64765878 25.38824463]\n",
    "\n",
    "# hyperbolic_sac_all_nonominal_100_0.1.npy\n",
    "# [107.87008005 100.87794536]\n",
    "# [96.64765878 94.12486543]\n",
    "\n",
    "# unsafe only\n",
    "# # ppo with nominal obj 0.5\n",
    "# [149.11256748 149.90658671]\n",
    "# [43.71928074 44.00617739]\n",
    "\n",
    "# # ppo nonominal obj 0.5\n",
    "# [149.11256748 183.00950318]\n",
    "# [43.71928074 41.67842278]\n",
    "\n",
    "# # ppo with nominal obj 2\n",
    "# [149.11256748 159.30184598]\n",
    "# [43.71928074 44.87083828]\n",
    "\n",
    "# # ppo nonominal obj 2\n",
    "# [149.11256748  -2.25891204]\n",
    "# [43.71928074  1.80812215]\n",
    "\n",
    "\n",
    "# # sac with nominal obj 2\n",
    "# [-10.0837676   -9.41883304]\n",
    "# [19.90388617 18.72091389]\n",
    "\n",
    "# sac nonominal 2\n",
    "# [-10.0837676   -5.79452949]\n",
    "# [19.90388617  3.11867209]\n",
    "\n",
    "# hyperbolic_sac_unsafe_nonominal_100_0.5.npy\n",
    "# [-10.0837676   -8.10151297]\n",
    "# [19.90388617 16.06321464]\n",
    "\n",
    "\n",
    "# sac with nominal obj 5\n",
    "# [-10.0837676   -8.82507658]\n",
    "# [19.90388617 17.92041767]\n",
    "\n",
    "# sac nonominal 5\n",
    "# [-10.0837676   -9.92833236]\n",
    "# [19.90388617  4.9081061 ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1faca857-ada9-4c66-ad89-4f9f6c641107",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count samples where before-QP reward minus after-QP reward is below 0.01\n",
    "qp_works = 0\n",
    "print(result.shape)\n",
    "for column in result.T:\n",
    "    improved = column[0] - column[1] < 0.01\n",
    "    print(\"working\" if improved else \"not working\", column)\n",
    "    qp_works += int(improved)\n",
    "print(qp_works)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e37fcec6-60d2-428e-b137-c67d841e752b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74cdaba5-9c29-4ae2-8a7f-8c5783a271d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each slice [i, j, :], find the earliest index from which the condition holds through the end.\n",
    "def find_earliest_true(condition):\n",
    "    \"\"\"Earliest index along axis 2 after which `condition` stays True to the end.\n",
    "\n",
    "    Args:\n",
    "        condition: boolean array of shape (A, B, K).\n",
    "\n",
    "    Returns:\n",
    "        Integer array of shape (A, B). Entry [i, j] is one past the last False in\n",
    "        condition[i, j, :]; it is 0 when the whole slice is True (or K == 0) and -1 when the\n",
    "        final entry is False, i.e. the condition never settles.\n",
    "    \"\"\"\n",
    "    condition = np.asarray(condition, dtype=bool)\n",
    "    numSteps = condition.shape[2]\n",
    "    if numSteps == 0:\n",
    "        return np.zeros(condition.shape[:2], dtype=int)\n",
    "    # Scan each slice from the end: stepsFromEnd is how far back the last False sits.\n",
    "    reversedFalse = ~condition[:, :, ::-1]\n",
    "    hasFalse = reversedFalse.any(axis=2)\n",
    "    stepsFromEnd = reversedFalse.argmax(axis=2)\n",
    "    earliest = np.where(hasFalse, numSteps - stepsFromEnd, 0)\n",
    "    # A False in the final position means the condition does not hold at the end.\n",
    "    return np.where(hasFalse & (stepsFromEnd == 0), -1, earliest)\n",
    "# Unpenalized reward used for the before/after-QP evaluation\n",
    "reward_class_no_penalty =  TunedReward1D(int(round(T/dt)), -1e-4, 3e2) # no penalize\n",
    "\n",
    "RL_1000 = np.load(\"../../../verify-pde-control/hyperbolic_ppo_all_nominal_100_2__trainCO_C0.npy\")\n",
    "\n",
    "def _runFromStart(U_start, controls):\n",
    "    \"\"\"runSingleEpisodeQP on a fresh unpenalized backstepping env whose reset state is constant U_start.\"\"\"\n",
    "    def getInitialConditionFixed(nx):\n",
    "        return np.ones(nx) * U_start\n",
    "    params = hyperbolicParametersBackstepping.copy()\n",
    "    params[\"reset_init_condition_func\"] = getInitialConditionFixed\n",
    "    params[\"reward_class\"] = reward_class_no_penalty\n",
    "    env = gym.make(\"PDEControlGym-TransportPDE1D\", **params)\n",
    "    return runSingleEpisodeQP(QP_filter_Controller, env, controls)\n",
    "\n",
    "RL_reward_beforeQP, RL_reward_afterQP = [], []\n",
    "uBcks_beforeQP_list, uBcks_afterQP_list = [], []\n",
    "for i in range(RL_1000[\"safe_label\"].transpose().shape[0]):\n",
    "    U_list = RL_1000[\"U_nominal\"][:, i]\n",
    "    reward_beforeQP, uBcks_beforeQP = _runFromStart(U_list[0], U_list)\n",
    "    uBcks_beforeQP_list.append(uBcks_beforeQP)\n",
    "    RL_reward_beforeQP.append(reward_beforeQP)\n",
    "\n",
    "    # The safe sequence is also run from the nominal initial value U_list[0].\n",
    "    U_safe_list = RL_1000[\"U_safe\"][:, i]\n",
    "    reward_afterQP, uBcks_afterQP = _runFromStart(U_list[0], U_safe_list)\n",
    "    uBcks_afterQP_list.append(uBcks_afterQP)\n",
    "    RL_reward_afterQP.append(reward_afterQP)\n",
    "\n",
    "# Axes: (before/after QP, sample, time step, spatial point), e.g. (2, 100, 51, 100)\n",
    "result = np.array([uBcks_beforeQP_list, uBcks_afterQP_list])\n",
    "\n",
    "# Per time step: is the first state entry below 1?\n",
    "condition = result[:, :, :, 0] < 1\n",
    "earliest_index = find_earliest_true(condition)\n",
    "# Keep only trajectories whose condition holds at the end (find_earliest_true marks the rest -1)\n",
    "valid_earliest_index_beforeQP = earliest_index[0, earliest_index[0, :] >= 0]\n",
    "valid_earliest_index_afterQP = earliest_index[1, earliest_index[1, :] >= 0]\n",
    "for tag, valid in ((\"beforeQP\", valid_earliest_index_beforeQP), (\"afterQP\", valid_earliest_index_afterQP)):\n",
    "    steps = result.shape[2] - valid\n",
    "    print(f\"{tag} PF steps among {valid.shape[0]} PF trajectories\", np.mean(steps), np.std(steps))\n",
    "\n",
    "reward_result = np.array([RL_reward_beforeQP, RL_reward_afterQP])\n",
    "print(\"reward: beforeQP and afterQP\")\n",
    "print(np.mean(reward_result, axis=1))\n",
    "print(np.std(reward_result, axis=1))\n",
    "\n",
    "# < 1, data is test data\n",
    "# hyperbolic_ppo_all_nonominal_100_0.5.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 34 PF trajectories 30.441176470588236 6.47252644767999\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 182.75855243]\n",
    "# [37.46392797 57.35514341]\n",
    "\n",
    "# hyperbolic_ppo_all_nonominal_100_2.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 25 PF trajectories 38.88 3.1153811965793206\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567  -2.42227559]\n",
    "# [37.46392797  1.73217119]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_0.5.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 68 PF trajectories 7.1911764705882355 8.169630549896324\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 158.59629352]\n",
    "# [37.46392797 37.75723808]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_2.npy\n",
    "# use this \n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 71 PF trajectories 9.80281690140845 8.759091402988757\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 165.04356112]\n",
    "# [37.46392797 43.73399149]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_5.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 73 PF trajectories 12.602739726027398 9.332135855054242\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 127.17568828]\n",
    "# [37.46392797 82.66816544]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_10.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 57 PF trajectories 13.736842105263158 10.694360402271178\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567  28.61075924]\n",
    "# [37.46392797 64.03013391]\n",
    "\n",
    "# train loader\n",
    "# hyperbolic_ppo_all_train_nominal_100_2.npy\n",
    "\n",
    "# beforeQP PF steps among 61 PF trajectories 9.01639344262295 10.424579723067016\n",
    "# afterQP PF steps among 65 PF trajectories 10.923076923076923 10.822068183268557\n",
    "# reward: beforeQP and afterQP\n",
    "# [161.40440344 167.97460973]\n",
    "# [33.12636148 40.60239363]\n",
    "\n",
    "# result is < 1, train loader\n",
    "# hyperbolic_ppo_all_train_nominal_100_2_hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 61 PF trajectories 9.01639344262295 10.424579723067016\n",
    "# afterQP PF steps among 50 PF trajectories 9.84 11.614404849151763\n",
    "# reward: beforeQP and afterQP\n",
    "# [161.40440344 130.4425012 ]\n",
    "# [33.12636148 67.97618437]\n",
    "\n",
    "# hyperbolic_ppo_all_train_nominal_100_2_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 61 PF trajectories 9.01639344262295 10.424579723067016\n",
    "# afterQP PF steps among 57 PF trajectories 9.789473684210526 11.279132793131613\n",
    "# reward: beforeQP and afterQP\n",
    "# [161.40440344 154.958119  ]\n",
    "# [33.12636148 52.51863401]\n",
    "\n",
    "# hyperbolic_ppo_all_train_nominal_100_2_hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 61 PF trajectories 9.01639344262295 10.424579723067016\n",
    "# afterQP PF steps among 63 PF trajectories 9.444444444444445 10.955578093283107\n",
    "# reward: beforeQP and afterQP\n",
    "# [161.40440344 159.99623399]\n",
    "# [33.12636148 42.52320282]\n",
    "\n",
    "# hyperbolic_ppo_all_train_nominal_100_2_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 61 PF trajectories 9.01639344262295 10.424579723067016\n",
    "# afterQP PF steps among 54 PF trajectories 10.722222222222221 10.492354653405474\n",
    "# reward: beforeQP and afterQP\n",
    "# [161.40440344 152.72139094]\n",
    "# [33.12636148 53.27888331]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_0.5_abl_noT.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 65 PF trajectories 7.492307692307692 8.212791293366852\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 158.45019986]\n",
    "# [37.46392797 37.81559006]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_2_abl_noT.npy todo\n",
    "# used for ablation\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 63 PF trajectories 8.492063492063492 8.983005146179435\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 162.26152911]\n",
    "# [37.46392797 44.527945  ]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_5_abl_noT.npy todo\n",
    "# used for ablation\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 67 PF trajectories 11.014925373134329 10.994561121216352\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 114.39988034]\n",
    "# [37.46392797 83.24960517]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_10_abl_noT.npy todo\n",
    "# used for ablation\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 57 PF trajectories 11.298245614035087 10.526432747888245\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567  27.276813  ]\n",
    "# [37.46392797 57.61620396]\n",
    "\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_2_MNO.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 78 PF trajectories 8.974358974358974 9.149715364739649\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 163.81594818]\n",
    "# [37.46392797 47.23151087]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_2_abl_noT_MNO.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 68 PF trajectories 8.691176470588236 8.494388031381758\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 162.92331239]\n",
    "# [37.46392797 45.22144822]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_2_C0.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 72 PF trajectories 10.01388888888889 8.669059053700304\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 167.80181346]\n",
    "# [37.46392797 38.90309391]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_5_C0.npy\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 72 PF trajectories 12.555555555555555 9.624749475655758\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 135.1272597 ]\n",
    "# [37.46392797 80.25424976]\n",
    "\n",
    "# hyperbolic_ppo_all_nominal_100_2__trainCO_C0\n",
    "# beforeQP PF steps among 63 PF trajectories 7.555555555555555 8.303332239448183\n",
    "# afterQP PF steps among 71 PF trajectories 8.056338028169014 8.206639915497595\n",
    "# reward: beforeQP and afterQP\n",
    "# [157.89960567 163.80849354]\n",
    "# [37.46392797 40.58306125]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d2a811c-2eb0-4e53-afb0-ea6d6f3740c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# x / (exp(50*x) - 1) at x = 1e-5; math.expm1 computes exp(.) - 1 without cancellation for small arguments\n",
    "(0.00001)/math.expm1(50*0.00001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d2e0cf-da74-4e20-b94f-8bba1d7bef62",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Replay every stored trajectory twice through the transport PDE env: once with the\n",
    "# nominal controls (beforeQP) and once with the QP-filtered ones (afterQP).\n",
    "RL_1000 = np.load(\"../../../verify-pde-control/hyperbolic_sac_all_train_nonominal_100_2.npy\")\n",
    "pf_threshold = 1  # threshold used below: condition = result[..., 0] < pf_threshold\n",
    "\n",
    "def _make_fixed_env(initial_value):\n",
    "    \"\"\"Create a TransportPDE1D env whose initial state is np.ones(nx) * initial_value.\n",
    "\n",
    "    Uses a shallow copy of hyperbolicParametersBackstepping with the no-penalty reward.\n",
    "    initial_value is bound as an argument, so the env's reset does not depend on the\n",
    "    loop variable U_list still holding the right trajectory.\n",
    "    \"\"\"\n",
    "    def getInitialConditionFixed(nx):\n",
    "        return np.ones(nx) * initial_value\n",
    "    params = hyperbolicParametersBackstepping.copy()  # leave the shared base dict untouched\n",
    "    params[\"reset_init_condition_func\"] = getInitialConditionFixed\n",
    "    params[\"reward_class\"] = reward_class_no_penalty\n",
    "    return gym.make(\"PDEControlGym-TransportPDE1D\", **params)\n",
    "\n",
    "RL_reward_beforeQP = []\n",
    "RL_reward_afterQP = []\n",
    "uBcks_beforeQP_list = []\n",
    "uBcks_afterQP_list = []\n",
    "for i in range(RL_1000[\"safe_label\"].transpose().shape[0]):\n",
    "    # Optional sample filter: if RL_1000[\"Y_nominal\"][-1, i] < 0.5: continue\n",
    "    U_list = RL_1000[\"U_nominal\"][:, i]\n",
    "    U_safe_list = RL_1000[\"U_safe\"][:, i]\n",
    "\n",
    "    # Nominal (unfiltered) control sequence.\n",
    "    envBcksFixed = _make_fixed_env(U_list[0])\n",
    "    reward_beforeQP, uBcks_beforeQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_list)\n",
    "    uBcks_beforeQP_list.append(uBcks_beforeQP)\n",
    "    RL_reward_beforeQP.append(reward_beforeQP)\n",
    "\n",
    "    # QP-filtered control sequence, started from the same initial value U_nominal[0, i]\n",
    "    # (not U_safe[0, i]) so both runs share an initial condition.\n",
    "    envBcksFixed = _make_fixed_env(U_list[0])\n",
    "    reward_afterQP, uBcks_afterQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_safe_list)\n",
    "    uBcks_afterQP_list.append(uBcks_afterQP)\n",
    "    RL_reward_afterQP.append(reward_afterQP)\n",
    "\n",
    "# (2, n_samples, n_steps, n_x); an earlier debug print recorded (51, 100) per trajectory.\n",
    "result = np.array([uBcks_beforeQP_list, uBcks_afterQP_list])\n",
    "\n",
    "# PF steps = n_steps - (first step where result[..., 0] < pf_threshold). Samples whose\n",
    "# index is negative (presumably 'never crossed' from find_earliest_true) are dropped;\n",
    "# with no remaining samples np.mean/np.std print nan.\n",
    "condition = result[:, :, :, 0] < pf_threshold\n",
    "earliest_index = find_earliest_true(condition)\n",
    "valid_earliest_index_beforeQP = earliest_index[0, earliest_index[0, :] >= 0]\n",
    "valid_earliest_index_afterQP = earliest_index[1, earliest_index[1, :] >= 0]\n",
    "n_steps = result.shape[2]\n",
    "pf_steps_beforeQP = n_steps - valid_earliest_index_beforeQP\n",
    "pf_steps_afterQP = n_steps - valid_earliest_index_afterQP\n",
    "print(f\"beforeQP PF steps among {valid_earliest_index_beforeQP.shape[0]} PF trajectories\", np.mean(pf_steps_beforeQP), np.std(pf_steps_beforeQP))\n",
    "print(f\"afterQP PF steps among {valid_earliest_index_afterQP.shape[0]} PF trajectories\", np.mean(pf_steps_afterQP), np.std(pf_steps_afterQP))\n",
    "\n",
    "# Episode rewards: row 0 = beforeQP, row 1 = afterQP.\n",
    "reward_result = np.array([RL_reward_beforeQP, RL_reward_afterQP])\n",
    "print(\"reward: beforeQP and afterQP\")\n",
    "print(np.mean(reward_result, axis=1))\n",
    "print(np.std(reward_result, axis=1))\n",
    "\n",
    "# hyperbolic_sac_all_nominal_100_2.npy\n",
    "# beforeQP PF steps among 88 PF trajectories 27.25 13.329947486768281\n",
    "# afterQP PF steps among 86 PF trajectories 26.53488372093023 13.408343180532945\n",
    "# reward: beforeQP and afterQP\n",
    "# [107.87008005 102.95409101]\n",
    "# [96.64765878 95.17185426]\n",
    "\n",
    "# hyperbolic_sac_all_nominal_100_5.npy\n",
    "# beforeQP PF steps among 88 PF trajectories 27.25 13.329947486768281\n",
    "# afterQP PF steps among 81 PF trajectories 23.320987654320987 14.200353986627468\n",
    "# reward: beforeQP and afterQP\n",
    "# [107.87008005  71.02155631]\n",
    "# [96.64765878 87.07771603]\n",
    "\n",
    "# hyperbolic_sac_all_nominal_100_10.npy\n",
    "# beforeQP PF steps among 88 PF trajectories 27.25 13.329947486768281\n",
    "# afterQP PF steps among 63 PF trajectories 17.19047619047619 14.685979862955941\n",
    "# reward: beforeQP and afterQP\n",
    "# [107.87008005  17.02524672]\n",
    "# [96.64765878 60.38958789]\n",
    "\n",
    "# hyperbolic_sac_all_nonominal_100_0.1.npy\n",
    "# beforeQP PF steps among 88 PF trajectories 27.25 13.329947486768281\n",
    "# afterQP PF steps among 88 PF trajectories 27.397727272727273 13.382094357019206\n",
    "# reward: beforeQP and afterQP\n",
    "# [107.87008005 100.87794536]\n",
    "# [96.64765878 94.12486543]\n",
    "# beforeQP PF steps among 88 PF trajectories 25.03409090909091 13.704220371540131\n",
    "# afterQP PF steps among 86 PF trajectories 26.186046511627907 12.981661150653826\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403  99.81501708]\n",
    "# [98.67610304 95.30119322]\n",
    "\n",
    "# hyperbolic_sac_all_nonominal_100_0.5.npy\n",
    "# beforeQP PF steps among 88 PF trajectories 27.25 13.329947486768281\n",
    "# afterQP PF steps among 87 PF trajectories 28.057471264367816 12.929833916170343\n",
    "# reward: beforeQP and afterQP\n",
    "# [107.87008005  -3.25188291]\n",
    "# [96.64765878 25.38824463]\n",
    "\n",
    "# hyperbolic_sac_all_nonominal_100_2.npy\n",
    "# beforeQP PF steps among 88 PF trajectories 27.25 13.329947486768281\n",
    "# afterQP PF steps among 65 PF trajectories 27.107692307692307 13.62364807906636\n",
    "# reward: beforeQP and afterQP\n",
    "# [107.87008005  -3.27868209]\n",
    "# [96.64765878  3.04253213]\n",
    "\n",
    "# hyperbolic_sac_all_train_nominal_100_2.npy\n",
    "# beforeQP PF steps among 90 PF trajectories 26.68888888888889 13.99098475165303\n",
    "# afterQP PF steps among 89 PF trajectories 25.60674157303371 14.34300360119881\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 101.37464236]\n",
    "# [98.67610304 97.23563785]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1.npy\n",
    "# beforeQP PF steps among 90 PF trajectories 26.68888888888889 13.99098475165303\n",
    "# afterQP PF steps among 91 PF trajectories 26.604395604395606 14.04222761896551\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403  99.81501708]\n",
    "# [98.67610304 95.30119322]\n",
    "# <0 usethiswrong\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 78 PF trajectories 17.26923076923077 11.788383109288606\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403  99.81501708]\n",
    "# [98.67610304 95.30119322]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.5.npy\n",
    "# beforeQP PF steps among 90 PF trajectories 26.68888888888889 13.99098475165303\n",
    "# afterQP PF steps among 87 PF trajectories 28.482758620689655 13.152388943410529\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403   4.8252939 ]\n",
    "# [98.67610304 31.54242982]\n",
    "# < 0.5\n",
    "# beforeQP PF steps among 88 PF trajectories 25.03409090909091 13.704220371540131\n",
    "# afterQP PF steps among 81 PF trajectories 27.17283950617284 13.001225130739321\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403   4.8252939 ]\n",
    "# [98.67610304 31.54242982]\n",
    "# <0\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 70 PF trajectories 22.5 11.973482605920708\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403   4.8252939 ]\n",
    "# [98.67610304 31.54242982]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_2.npy\n",
    "# beforeQP PF steps among 90 PF trajectories 26.68888888888889 13.99098475165303\n",
    "# afterQP PF steps among 62 PF trajectories 28.258064516129032 13.494246663971122\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403  -2.89114265]\n",
    "# [98.67610304  2.58230291]\n",
    "# <0.5\n",
    "# beforeQP PF steps among 88 PF trajectories 25.03409090909091 13.704220371540131\n",
    "# afterQP PF steps among 53 PF trajectories 24.169811320754718 14.39960147385717\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403  -2.89114265]\n",
    "# [98.67610304  2.58230291]\n",
    "# <0\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 50 PF trajectories 19.28 13.815990735376163\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403  -2.89114265]\n",
    "# [98.67610304  2.58230291]\n",
    "\n",
    "# hyperbolic_sac_unsafe_nonominal_100_2.npy\n",
    "# beforeQP PF steps among 0 PF trajectories nan nan\n",
    "# afterQP PF steps among 13 PF trajectories 6.769230769230769 6.191232587033503 # mixed to be tested\n",
    "# reward: beforeQP and afterQP\n",
    "# [-10.0837676   -5.79452949]\n",
    "# [19.90388617  3.11867209] \n",
    "# < 2\n",
    "# beforeQP PF steps among 26 PF trajectories 3.0384615384615383 2.8213576823220468\n",
    "# afterQP PF steps among 19 PF trajectories 6.7894736842105265 5.680798196032343\n",
    "# reward: beforeQP and afterQP\n",
    "# [-10.0837676   -5.79452949]\n",
    "# [19.90388617  3.11867209]\n",
    "\n",
    "\n",
    "# hyperbolic_sac_unsafe_nonominal_100_5.npy\n",
    "# beforeQP PF steps among 0 PF trajectories nan nan\n",
    "# afterQP PF steps among 14 PF trajectories 8.071428571428571 6.2044256752082765\n",
    "# reward: beforeQP and afterQP\n",
    "# [-10.0837676   -9.92833236]\n",
    "# [19.90388617  4.9081061 ]\n",
    "\n",
    "# hyperbolic_sac_unsafe_nominal_100_2.npy\n",
    "# beforeQP PF steps among 0 PF trajectories nan nan\n",
    "# afterQP PF steps among 3 PF trajectories 1.3333333333333333 0.4714045207910317\n",
    "# reward: beforeQP and afterQP\n",
    "# [-10.0837676   -9.41883304]\n",
    "# [19.90388617 18.72091389]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nonominal_100_2.npy\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 40 PF trajectories 24.525 14.956582998800227\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 -4.57934999]\n",
    "# [95.86539638  3.11861912]\n",
    "\n",
    "# hyperbolic_sac_allmixed30_train_nonominal_100_2.npy\n",
    "# beforeQP PF steps among 30 PF trajectories 29.133333333333333 13.674144295794974\n",
    "# afterQP PF steps among 33 PF trajectories 21.848484848484848 15.6478698562344\n",
    "# reward: beforeQP and afterQP\n",
    "# [37.13475063 -4.79385091]\n",
    "# [84.80599846  3.30738471]\n",
    "\n",
    "# hyperbolic_sac_allmixed10_train_nonominal_100_2.npy\n",
    "# beforeQP PF steps among 10 PF trajectories 28.6 13.200000000000001\n",
    "# afterQP PF steps among 21 PF trajectories 12.80952380952381 12.57593714246926\n",
    "# reward: beforeQP and afterQP\n",
    "# [ 5.18676352 -5.26608422]\n",
    "# [49.98382512  3.26901378]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nonominal_100_0.5.npy\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 48 PF trajectories 28.125 13.67726976897558\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 -0.4700795 ]\n",
    "# [95.86539638 24.43963561]\n",
    "\n",
    "# hyperbolic_sac_allmixed30_train_nonominal_100_0.5.npy\n",
    "# beforeQP PF steps among 30 PF trajectories 29.133333333333333 13.674144295794974\n",
    "# afterQP PF steps among 32 PF trajectories 28.25 14.1178079034955\n",
    "# reward: beforeQP and afterQP\n",
    "# [37.13475063 -0.93608058]\n",
    "# [84.80599846 24.48733125]\n",
    "\n",
    "# hyperbolic_sac_allmixed10_train_nonominal_100_0.5.npy\n",
    "# beforeQP PF steps among 10 PF trajectories 28.6 13.200000000000001\n",
    "# afterQP PF steps among 15 PF trajectories 20.866666666666667 15.978596795574871\n",
    "# reward: beforeQP and afterQP\n",
    "# [ 5.18676352 -5.82718389]\n",
    "# [49.98382512 10.41191905]\n",
    "\n",
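    "# hyperbolic_sac_allmixed50_train_nonominal_100_0.1.npy  (originally labelled _0.5, duplicating the entry above with different numbers; 0.1 inferred from the 30/10 entries below - confirm)\n",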
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 49 PF trajectories 26.73469387755102 14.467099211665296\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 52.98968431]\n",
    "# [95.86539638 91.92103097]\n",
    "\n",
    "# hyperbolic_sac_allmixed30_train_nonominal_100_0.1.npy\n",
    "# beforeQP PF steps among 30 PF trajectories 29.133333333333333 13.674144295794974\n",
    "# afterQP PF steps among 32 PF trajectories 27.75 14.517231140957975\n",
    "# reward: beforeQP and afterQP\n",
    "# [37.13475063 34.88808109]\n",
    "# [84.80599846 81.29617585]\n",
    "\n",
    "# hyperbolic_sac_allmixed10_train_nonominal_100_0.1.npy\n",
    "# beforeQP PF steps among 10 PF trajectories 28.6 13.200000000000001\n",
    "# afterQP PF steps among 12 PF trajectories 24.416666666666668 15.386456887651411\n",
    "# reward: beforeQP and afterQP\n",
    "# [5.18676352 4.39070588]\n",
    "# [49.98382512 47.70829633]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_2.npy\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 47 PF trajectories 24.78723404255319 15.088302035249898\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 52.50388052]\n",
    "# [95.86539638 92.86883621]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_2_pfallend.npy\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 55.93067345]\n",
    "# [95.86539638 95.46002411]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_2_pf52_addend_1reg_abs.npy\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 49 PF trajectories 25.775510204081634 15.079740844038579\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 55.35356744]\n",
    "# [95.86539638 94.73617795]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_5hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 40 PF trajectories 22.325 14.687388297447574\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 37.17831453]\n",
    "# [95.86539638 79.26894931]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_5hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 46 PF trajectories 22.17391304347826 15.385346427301881\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 42.02858353]\n",
    "# [95.86539638 83.66237159]\n",
    "\n",
    "\n",
    "#hyperbolic_sac_allmixed50_train_nominal_100_5hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 46 PF trajectories 23.58695652173913 15.309241436006413\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 38.78003051]\n",
    "# [95.86539638 80.49237041]\n",
    "\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_5hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 45 PF trajectories 23.733333333333334 14.803002698401729\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 43.58442153]\n",
    "# [95.86539638 84.81109471]\n",
    "\n",
    "#hyperbolic_sac_allmixed50_train_nominal_100_5_hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 46 PF trajectories 24.065217391304348 15.04399659825783\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 49.00764289]\n",
    "# [95.86539638 87.94930882]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_5_pf52_addend_1reg_abs\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 46 PF trajectories 24.065217391304348 15.04399659825783\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 49.00764289]\n",
    "# [95.86539638 87.94930882]\n",
    "\n",
    "#hyperbolic_sac_allmixed50_train_nominal_100_2hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_5pf_20\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 48 PF trajectories 27.0 14.26826315054966\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 55.84979727]\n",
    "# [95.86539638 95.18905366]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_2hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_5pf_20\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 47 PF trajectories 26.78723404255319 14.380743877739391\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 55.84661233]\n",
    "# [95.86539638 94.43460522]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_2hyper_1reg_1pf_time_CBFnoNOfixed_pfall_addend_preNO20_20\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 48 PF trajectories 26.895833333333332 14.556269530305109\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 55.22448778]\n",
    "# [95.86539638 94.6270296 ]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_2hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 50 PF trajectories 26.26 14.50077239322099\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 55.32415433]\n",
    "# [95.86539638 94.61859919]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_20_hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 44 PF trajectories 7.0227272727272725 11.850371371939893\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 -4.09719812]\n",
    "# [95.86539638 31.94383816]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_2_hyper_1reg_1pf_time_CBFnoNOfixed_pf52_addend_bcks_1\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 48 PF trajectories 27.041666666666668 14.193540686601855\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 55.64651084]\n",
    "# [95.86539638 95.07150836]\n",
    "\n",
    "# hyperbolic_sac_allmixed50_train_nominal_100_2_pf52_addend_1reg_abs\n",
    "# beforeQP PF steps among 47 PF trajectories 27.48936170212766 14.014995525956325\n",
    "# afterQP PF steps among 49 PF trajectories 25.775510204081634 15.079740844038579\n",
    "# reward: beforeQP and afterQP\n",
    "# [56.34765177 55.35356744]\n",
    "# [95.86539638 94.73617795]\n",
    "\n",
    "# hyperbolic_sac_all_train_nominal_100_2_hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.npy\n",
    "# beforeQP PF steps among 90 PF trajectories 26.68888888888889 13.99098475165303\n",
    "# afterQP PF steps among 88 PF trajectories 26.363636363636363 13.970083147229149\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 102.29131553]\n",
    "# [98.67610304 97.68625281]\n",
    "\n",
    "# hyperbolic_sac_unsafe_train_nominal_100_2_hyper_0.1reg_1pf_time_CBFnoNOfixed_pf52_addend_preNO20_20.npy\n",
    "# beforeQP PF steps among 0 PF trajectories nan nan\n",
    "# afterQP PF steps among 4 PF trajectories 5.25 4.380353866983808\n",
    "# reward: beforeQP and afterQP\n",
    "# [-8.58994355 -8.60569665]\n",
    "# [15.71175902 15.79449384]\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1497a3e-0580-4655-bcb7-7e5716d5638f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative settings (disabled):\n",
    "# reward_class_no_penalty =  TunedReward1D(int(round(T/dt)), 1e-4, 3e2) # no penalize\n",
    "# hyperbolicParametersBackstepping[\"normalize\"] = False\n",
    "\n",
    "# Replay every stored trajectory twice through the transport PDE env: once with the\n",
    "# nominal controls (beforeQP) and once with the QP-filtered ones (afterQP).\n",
    "RL_1000 = np.load(\"../../../verify-pde-control/hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_trainC0_C0_20.npy\")\n",
    "pf_threshold = 0  # threshold used below: condition = result[..., 0] < pf_threshold\n",
    "\n",
    "def _make_fixed_env(initial_value):\n",
    "    \"\"\"Create a TransportPDE1D env whose initial state is np.ones(nx) * initial_value.\n",
    "\n",
    "    Uses a shallow copy of hyperbolicParametersBackstepping with the no-penalty reward.\n",
    "    Defined in this cell as well so the cell can be run on its own.\n",
    "    \"\"\"\n",
    "    def getInitialConditionFixed(nx):\n",
    "        return np.ones(nx) * initial_value\n",
    "    params = hyperbolicParametersBackstepping.copy()  # leave the shared base dict untouched\n",
    "    params[\"reset_init_condition_func\"] = getInitialConditionFixed\n",
    "    params[\"reward_class\"] = reward_class_no_penalty\n",
    "    return gym.make(\"PDEControlGym-TransportPDE1D\", **params)\n",
    "\n",
    "RL_reward_beforeQP = []\n",
    "RL_reward_afterQP = []\n",
    "uBcks_beforeQP_list = []\n",
    "uBcks_afterQP_list = []\n",
    "for i in range(RL_1000[\"safe_label\"].transpose().shape[0]):\n",
    "    # Optional sample filter: if RL_1000[\"Y_nominal\"][-1, i] < 0.5: continue\n",
    "    U_list = RL_1000[\"U_nominal\"][:, i]\n",
    "    U_safe_list = RL_1000[\"U_safe\"][:, i]\n",
    "\n",
    "    # Nominal (unfiltered) control sequence.\n",
    "    envBcksFixed = _make_fixed_env(U_list[0])\n",
    "    reward_beforeQP, uBcks_beforeQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_list)\n",
    "    uBcks_beforeQP_list.append(uBcks_beforeQP)\n",
    "    RL_reward_beforeQP.append(reward_beforeQP)\n",
    "\n",
    "    # QP-filtered control sequence, started from the same initial value U_nominal[0, i]\n",
    "    # (not U_safe[0, i]) so both runs share an initial condition.\n",
    "    envBcksFixed = _make_fixed_env(U_list[0])\n",
    "    reward_afterQP, uBcks_afterQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_safe_list)\n",
    "    uBcks_afterQP_list.append(uBcks_afterQP)\n",
    "    RL_reward_afterQP.append(reward_afterQP)\n",
    "\n",
    "# (2, n_samples, n_steps, n_x); an earlier debug print recorded (51, 100) per trajectory.\n",
    "result = np.array([uBcks_beforeQP_list, uBcks_afterQP_list])\n",
    "\n",
    "# PF steps = n_steps - (first step where result[..., 0] < pf_threshold). Samples whose\n",
    "# index is negative (presumably 'never crossed' from find_earliest_true) are dropped;\n",
    "# with no remaining samples np.mean/np.std print nan.\n",
    "condition = result[:, :, :, 0] < pf_threshold\n",
    "earliest_index = find_earliest_true(condition)\n",
    "valid_earliest_index_beforeQP = earliest_index[0, earliest_index[0, :] >= 0]\n",
    "valid_earliest_index_afterQP = earliest_index[1, earliest_index[1, :] >= 0]\n",
    "n_steps = result.shape[2]\n",
    "pf_steps_beforeQP = n_steps - valid_earliest_index_beforeQP\n",
    "pf_steps_afterQP = n_steps - valid_earliest_index_afterQP\n",
    "print(f\"beforeQP PF steps among {valid_earliest_index_beforeQP.shape[0]} PF trajectories\", np.mean(pf_steps_beforeQP), np.std(pf_steps_beforeQP))\n",
    "print(f\"afterQP PF steps among {valid_earliest_index_afterQP.shape[0]} PF trajectories\", np.mean(pf_steps_afterQP), np.std(pf_steps_afterQP))\n",
    "\n",
    "# Episode rewards: row 0 = beforeQP, row 1 = afterQP.\n",
    "reward_result = np.array([RL_reward_beforeQP, RL_reward_afterQP])\n",
    "print(\"reward: beforeQP and afterQP\")\n",
    "print(np.mean(reward_result, axis=1))\n",
    "print(np.std(reward_result, axis=1))\n",
    "\n",
    "# result is < 0\n",
    "# hyperbolic_sac_all_train_nominal_100_2_hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 75 PF trajectories 11.773333333333333 10.116090593153508\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 102.99656446]\n",
    "# [98.67610304 97.87447163]\n",
    "\n",
    "# hyperbolic_sac_all_train_nominal_100_2_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 75 PF trajectories 11.613333333333333 10.675696178183832\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 105.00601158]\n",
    "# [98.67610304 97.82105918]\n",
    "\n",
    "# hyperbolic_sac_all_train_nominal_100_2_hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 69 PF trajectories 11.173913043478262 10.472817155896706\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 104.61862378]\n",
    "# [98.67610304 97.77209434]\n",
    "\n",
    "# hyperbolic_sac_all_train_nominal_100_2_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 75 PF trajectories 11.32 10.254312913761378\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 104.5205977 ]\n",
    "# [98.67610304 97.95271223]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_2_hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 92 PF trajectories 11.945652173913043 16.46033713165416\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403  -1.469558  ]\n",
    "# [98.67610304 25.69986247]\n",
    "# < 1: \n",
    "# beforeQP PF steps among 90 PF trajectories 26.68888888888889 13.99098475165303\n",
    "# afterQP PF steps among 94 PF trajectories 12.936170212765957 17.15995369645981\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403  -1.469558  ]\n",
    "# [98.67610304 25.69986247]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_2_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 73 PF trajectories 8.58904109589041 12.813558333370738\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403  -4.50263998]\n",
    "# [98.67610304  2.60080025]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_2_hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 20 PF trajectories 8.4 11.425410277097274\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403   0.63630712]\n",
    "# [98.67610304 23.87476031]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 83 PF trajectories 10.891566265060241 11.271044090775769\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 105.31074014]\n",
    "# [98.67610304 97.71464102]\n",
    "\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\n",
    "# use this\n",
    "# <0\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 85 PF trajectories 13.941176470588236 12.039496707456593\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 103.3820142 ]\n",
    "# [98.67610304 96.37758918]\n",
    "# <1:\n",
    "# beforeQP PF steps among 90 PF trajectories 26.68888888888889 13.99098475165303\n",
    "# afterQP PF steps among 91 PF trajectories 26.483516483516482 14.14552516309798\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 103.3820142 ]\n",
    "# [98.67610304 96.37758918]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 72 PF trajectories 15.541666666666666 11.685624478154535\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 102.9337309 ]\n",
    "# [98.67610304 95.57772234]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pfall_addsafe__le1safe_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 62 PF trajectories 18.14516129032258 12.917663293877318\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 101.02695975]\n",
    "# [98.67610304 94.61860176]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoT_pf12_addsafe__le1safe_20_abl_noT.npy todo\n",
    "# used for ablation\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 57 PF trajectories 15.68421052631579 11.179803053120205\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 103.29112165]\n",
    "# [98.67610304 98.37415634]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20_MNO.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 84 PF trajectories 14.666666666666666 12.3030516281563\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 103.3005425 ]\n",
    "# [98.67610304 96.36969509]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoT_pf12_addsafe__le1safe_20_abl_noT_MNO.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 59 PF trajectories 15.423728813559322 11.169272907507906\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 103.29410782]\n",
    "# [98.67610304 98.36459644]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20_C0.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 85 PF trajectories 13.941176470588236 12.039496707456593\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 103.3820142 ]\n",
    "# [98.67610304 96.37758918]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoT_pf12_addsafe__le1safe_20_abl_noT_C0.npy \n",
    "# used for ablation\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 57 PF trajectories 15.68421052631579 11.179803053120205\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 103.29112165]\n",
    "# [98.67610304 98.37415634]\n",
    "\n",
    "# hyperbolic_sac_all_train_nonominal_100_0.1_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_trainC0_C0_20.npy\n",
    "# beforeQP PF steps among 78 PF trajectories 12.35897435897436 10.475826090642867\n",
    "# afterQP PF steps among 56 PF trajectories 14.714285714285714 10.86724154399405\n",
    "# reward: beforeQP and afterQP\n",
    "# [106.16656403 104.58665957]\n",
    "# [98.67610304 98.61703421]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894e35c3-0868-4f85-93d4-6141b1f6f9d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably the number of control steps per episode (T and dt come from earlier cells).\n",
    "T / dt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c95538e-09c9-4211-be3f-68a16781824a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69926c7b-983e-458c-a2f0-b738a0c94131",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82f72d61-1f56-4866-9b26-3879b3a9871f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Example result array\n",
    "result = np.random.randn(2, 3, 4, 2)  # Replace with your actual array\n",
    "print(result[:, :, :, 0])\n",
    "# Step 1: Check the condition for the first element along the last dimension (result[:,:,:,0] < 1)\n",
    "condition = result[:, :, :, 0] < 1  # boolean array of shape result.shape[:3]\n",
    "print(condition)\n",
    "\n",
    "# Step 2: For every (i, j), find the earliest index along the third axis from which the\n",
    "# condition stays True until the end of the series.\n",
    "def find_earliest_true(condition):\n",
    "    \"\"\"Start index of the trailing run of True values along axis 2 of `condition`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    condition : array-like of bool, shape (n0, n1, n_steps)\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray of int, shape (n0, n1)\n",
    "        For each (i, j): 0 if condition[i, j, :] is entirely True (or empty); -1 if its\n",
    "        last element is False (no trailing True run); otherwise the index just after the\n",
    "        last False entry.\n",
    "    \"\"\"\n",
    "    condition = np.asarray(condition, dtype=bool)\n",
    "    n_steps = condition.shape[2]\n",
    "    if n_steps == 0:\n",
    "        # argmax is undefined on an empty axis; an empty series counts as 'all True'.\n",
    "        return np.zeros(condition.shape[:2], dtype=int)\n",
    "    # Reverse the step axis so that argmax locates the *last* False of the original series.\n",
    "    violated_from_end = ~condition[:, :, ::-1]\n",
    "    has_violation = violated_from_end.any(axis=2)\n",
    "    steps_from_end = violated_from_end.argmax(axis=2)  # 0 => the final step is False\n",
    "    earliest_indices = np.where(has_violation, n_steps - steps_from_end, 0).astype(int)\n",
    "    earliest_indices[has_violation & (steps_from_end == 0)] = -1\n",
    "    return earliest_indices\n",
    "\n",
    "# Apply the function\n",
    "earliest_index = find_earliest_true(condition)\n",
    "\n",
    "print(earliest_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1e611ab-a9f4-493c-a46c-fcbb16e43d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop trajectories whose final step violates the condition (earliest index == -1).\n",
    "valid_earliest_index_beforeQP = earliest_index[0, earliest_index[0, :] >= 0]\n",
    "valid_earliest_index_afterQP = earliest_index[1, earliest_index[1, :] >= 0]\n",
    "\n",
    "# Row 0 of earliest_index is the beforeQP run and row 1 the afterQP run; the number of\n",
    "# trailing steps that satisfy the condition is result.shape[2] - earliest_index.\n",
    "for label, valid_idx in ((\"beforeQP\", valid_earliest_index_beforeQP), (\"afterQP\", valid_earliest_index_afterQP)):\n",
    "    pf_steps = result.shape[2] - valid_idx\n",
    "    if pf_steps.size == 0:\n",
    "        # np.mean/np.std of an empty array warn and return nan; report nan without the warning.\n",
    "        print(f\"{label} PF steps among 0 PF trajectories\", np.nan, np.nan)\n",
    "        continue\n",
    "    print(f\"{label} PF steps among {valid_idx.shape[0]} PF trajectories\", np.mean(pf_steps), np.std(pf_steps))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38f19f76-a877-4458-97a2-051e603236c6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f95ee3da-16fe-41e3-9393-693bc712ca2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def runSingleEpisodeQP(model, env, parameter):\n",
    "    \"\"\"Roll out one episode, asking `model(obs, parameter, step)` for every action.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (total_reward, states)\n",
    "        total_reward is the sum of the per-step rewards; states is np.array of the reset\n",
    "        observation followed by the observation returned by every env.step call.\n",
    "    \"\"\"\n",
    "    observation, _ = env.reset()\n",
    "    trajectory = [observation]\n",
    "    total_reward = 0\n",
    "    step = 0\n",
    "    finished = False\n",
    "    while not finished:\n",
    "        action = model(observation, parameter, step)\n",
    "        observation, reward, terminated, truncated, _ = env.step(action)\n",
    "        trajectory.append(observation)\n",
    "        total_reward += reward\n",
    "        step += 1\n",
    "        # Stop as soon as the environment reports termination or truncation.\n",
    "        finished = terminated or truncated\n",
    "    return total_reward, np.array(trajectory)\n",
    "def QP_filter_Controller(obs, parameter, index):\n",
    "    \"\"\"Open-loop replay: ignore `obs` and return the stored control for step `index`.\n",
    "\n",
    "    parameter[0] is skipped (in this notebook it is the trajectory's initial value), so\n",
    "    step `index` uses parameter[index + 1].\n",
    "    \"\"\"\n",
    "    next_entry = index + 1\n",
    "    return parameter[next_entry]\n",
    "\n",
    "reward_class_no_penalty =  TunedReward1D(int(round(T/dt)), -1e-4, 3e2) # no penalize\n",
    "# reward_class =  TunedReward1D(int(round(T/dt)), -1e3, 3e2) # with penalize\n",
    "# reward_class = NormReward(int(round(T/dt)),\"2\", \"temporal\", -1e-4, 3e2)\n",
    "\n",
    "def getInitialConditionFixed(nx):\n",
    "    \"\"\"Constant initial state of length nx equal to U_list[0].\n",
    "\n",
    "    The replay controller skips U_list[0] and applies U_list[1:]; the active list starts with\n",
    "    4.617858, the value previously hard-coded here. Reading it from U_list keeps the reset\n",
    "    state consistent when one of the alternative lists below is swapped in. U_list is looked\n",
    "    up when the env resets, so it must be defined before an episode is run.\n",
    "    \"\"\"\n",
    "    return np.ones(nx) * U_list[0]\n",
    "\n",
    "# Shared environment settings for both replays: fixed initial condition, no-penalty reward.\n",
    "hyperbolicParametersBacksteppingFixed = hyperbolicParametersBackstepping.copy()\n",
    "hyperbolicParametersBacksteppingFixed[\"reset_init_condition_func\"] = getInitialConditionFixed\n",
    "hyperbolicParametersBacksteppingFixed[\"reward_class\"] = reward_class_no_penalty\n",
    "\n",
    "# ppo\n",
    "# U_list = [3.4483314, 7.519741, 3.210167, -1.233036, -2.05723, -2.7463036, 1.0309582, 4.518032, 8.707178, -1.8257523, 0.47780037, 2.5584679, 6.1201687, 1.3186092, 3.391739, 0.39335442, 1.716732, 3.2754917, 2.5711555, 0.39237595, 1.4837856, 1.3000488, 0.107624054, 2.7288265, 1.0334282, 2.1155167, 1.4288158, 0.37491608, 1.246748, 1.9155388, 1.4339046, 3.2164078, 0.46427155, 2.5861073, 0.28009224, 2.3356724, 1.6124992, 1.0692406, 0.8970852, 1.512373, 0.779047, 3.195963, 1.9200859, 1.9334564, 1.8800068, 1.9017906, 1.9094696, 2.035246, 0.1969738, 1.558815, 1.9237404]\n",
    "# U_list = [3.3722398, 7.4964485, 3.1820717, -1.0928841, -1.0235615, -0.97184944, 1.3208351, 5.8789654, 6.6949234, 0.3586483, 1.4759865, -0.0071411133, 3.3761787, 3.4062023, 1.6935387, -0.6621933, -0.28075218, 1.1501484, 4.305546, 3.4884377, 1.4927292, 0.9863682, 1.3623695, 1.9568596, 2.2611618, 1.6984978, 0.84568405, 1.8136616, 1.2548599, 4.0963764, 3.3163376, 1.3278103, 2.462883, 0.5718384, 1.6981659, 1.9384155, 1.647728, 1.2026386, 1.8850422, 2.4998322, 2.936737, 1.4727974, 2.9161453, -0.012018204, 1.4811478, 1.9514446, 0.60920906, 1.9053936, 1.1167717, 1.1801071, 2.643156]\n",
    "U_list = [4.617858, 6.524477, 4.2477837, -1.6751938, -6.231449, -2.5804176, 1.6056061, 10.603195, 6.774357, -1.0105801, 0.41062355, -1.5105667, 4.3493633, 3.0550537, 1.3159599, 0.32970047, 1.2317257, 2.541359, 1.2924461, 2.2199593, -1.2209549, 0.060626984, 2.0285816, 0.79698753, 2.8629513, 3.132225, 1.542284, 2.9536533, 3.5463924, 1.8504238, 2.3651009, 1.1760731, 0.82167625, 1.4998417, 2.5814934, 1.8735085, 1.7619438, 1.9039497, 1.1589928, 1.3399811, 1.5291691, 1.2556839, 1.6269836, 2.0162888, 3.2156734, 1.2802029, 2.2987404, -0.23257828, 2.0404243, 0.3293705, 1.2808723]\n",
    "\n",
    "# sac\n",
    "# U_list = [5.414479, 13.6943245, 6.737919, -1.4116383, -15.878218, 2.0415401, 11.883478, -12.460169, 18.540043, 15.989639, 10.549917, 7.4516907, 15.486492, 12.833923, 12.351456, -14.793518, 2.3114376, -12.606421, 9.233444, -16.279337, 10.991688, 2.8182907, 11.371559, 7.0601273, -0.18647385, 7.8772087, 3.0575066, 4.0379906, 0.07736397, 2.9936218, 11.991655, -7.8768234, -8.862894, 2.7116585, 3.4345264, 13.078705, 9.8074, 13.228119, 13.207127, 10.606209, 2.3662605, -0.11846161, 1.1179047, 1.814764, -3.8936005, 8.674536, -1.7833538, -7.6276445, 10.394348, 2.2616348, 1.0796776]\n",
    "# U_list = [9.9052515, 15.933079, 16.819916, -15.480906, -12.982727, -11.169411, -1.3884354, 12.963764, 4.570591, 1.6753864, -9.258332, -1.1977978, 10.823498, -4.6682215, -8.100311, -0.76138306, 11.713013, -1.1016998, -0.3859768, 4.751175, 1.9874783, 11.862745, -6.792221, 0.42554474, 2.7659988, 6.0947113, 5.412056, 4.699177, 8.645184, 5.9337215, 2.0537071, 14.013359, 11.627483, 0.76327324, 7.638443, 13.869091, -8.363653, -4.835432, -2.5752602, -11.574808, 11.517239, 0.5907364, 6.616297, 10.301319, 1.0391026, 3.1400795, -2.3702354, -5.737729, -5.26517, -0.05091858, 2.6315594]\n",
    "# U_list = [4.375922, 15.102211, 8.371391, -12.569988, -1.5775452, -7.982729, 3.2732048, -10.349854, 11.546528, -0.99699783, -5.4987326, 12.436901, -10.3670845, -1.860075, 7.447201, 11.71246, -6.266899, 12.430904, 11.550989, -11.672353, -10.813607, -0.11060524, -14.260563, -9.212543, -5.192293, 3.0454369, 10.099113, -0.16654778, 3.718258, 5.310917, -0.33210754, -5.513237, 3.4668922, -5.8947906, 2.6177197, 8.546547, -14.535984, -4.0717907, -4.958953, 14.49345, 10.750582, 12.249825, 14.439644, 15.610119, 13.553192, -11.512323, 8.821808, 3.2025948, -17.850409, 10.7373905, -1.5921249]\n",
    "\n",
    "# Nominal (unfiltered) controls, replayed in a fresh environment.\n",
    "envBcksFixed = gym.make(\"PDEControlGym-TransportPDE1D\", **hyperbolicParametersBacksteppingFixed)\n",
    "reward_beforeQP, uBcks_beforeQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_list)\n",
    "\n",
    "U_safe_list = [4.617857933044434, 6.474477032670663, 4.147783653884019, -1.825194119533947, -6.431449291844771, -2.8304180014209095, 1.3056058432611763, 10.253194805160732, 6.374356685549245, -1.4605801314094577, -0.089376602277909, -1.860566779965822, 3.949363374304383, 2.605053571245506, 0.8159599538815456, -0.22029977113326993, 0.6317256901671797, 1.8913586379038207, 0.592446171782079, 1.4699589806781432, -2.020954910590696, -0.7893732496514758, 1.1285815916599162, -0.15301273873944865, 1.8629512354362063, 2.0822247250297323, 0.44228408366895877, 1.8036530052469857, 2.3463923943941367, 0.6004234817719976, 1.0651007948125866, -0.17392713816269634, -0.5783238218637976, 0.04984145938496454, 1.081493315810711, 0.3235081925704151, 0.16194386927590898, 0.2539494542055465, -0.5410072050741874, -0.4100192194930017, -0.2708309207189714, -0.5943163013952599, -0.2730162633309092, 0.06628840527897495, 1.2156733815124525, -0.7697974043460718, 0.39874024260208074, -2.1825786307115678, 0.2404243337184866, -1.720630004267754, -0.7691281117037434]\n",
    "# U_safe_list = [3.448331356048584, 7.519741023650909, 3.210166147651017, -1.2330360968926914, -2.0572308600469027, -2.746303632166083, 1.0309573967104741, 4.518032330109398, 8.707177642297065, -1.825752007198397, 0.4777996222686163, 2.558468018529701, 6.120167758285174, 1.3186093571276252, 3.391737985769518, 0.39335450544682715, 1.716731101325325, 3.275491728246256, 1.573214637560762, 0.3923760372222107, 0.485844745590692, 1.3000488807953632, -0.8903168717268151, 2.728826534674842, 0.035487296450777084, 2.115516671272066, -0.029515764355225904, 0.3749160518128649, -0.21158363117361184, 1.9155387488377658, -0.2900844490380859, 3.216407736044961, -1.2597175918987158, 2.586107206970202, -1.4438968941570305, 2.3356723432991844, -0.11148992738327013, 1.0692405189889955, -0.8269039544575729, 1.0365827945672488, -0.9449421362484383, 2.7201726501509533, 0.1960967192661418, 1.4576661711193815, -0.035679224052888175, 0.6929926975287598, -0.006216410997698235, 0.8264479778752415, -1.718712185018834, 0.3500170709640228, 0.7149424650930681]\n",
    "# U_safe_list = [3.372239828109741, 7.496448339882619, 3.1820713407990673, -1.0928842458665322, -1.0235618886618592, -0.9718496388674968, 1.320834721395173, 5.878965535067465, 6.694923069897769, 0.358648400481246, 1.4759860776769171, -0.00714100419210667, 3.376178409990426, 3.406202349075358, 1.6935383615412927, -0.6621930911283971, -0.2807525755241198, 1.1501486495611926, 4.305545299057868, 3.488437837127226, 1.4927287707213113, 0.9863683038346807, 1.362369125745741, 1.9568597537847126, 1.2677068459334944, 1.6984979473057962, -0.147770891990044, 1.638106513786373, -0.2199820256250003, 3.9208213295732763, 1.8414955547976608, 1.152255168006353, 0.9880409522093889, 0.3962832699341776, 0.22332385857075154, 1.7628604580866964, 0.17288593174234324, 1.0270835721536047, 0.1707110228629105, 2.324277122547959, 1.2224058574992611, 1.2972423888097397, 1.017028368023169, -0.18757323468369247, -0.4179691452645571, 1.7758896297488933, -1.2899078774694104, 1.7298386035725075, -0.7823452293228925, 1.0045521313204704, 2.4676011014949113]\n",
    "# U_safe_list = [5.414478778839111, 13.694325093624684, 5.692215931220867, -1.4116378658402482, -16.92392057521942, -0.15775705626695768, 10.837776242806974, -14.65946670215156, 17.494340940038917, 13.790342393123199, 9.50421537499438, 5.252393923875207, 14.440790343469196, 10.634626572402386, 11.305754314320584, -16.99281531722586, 1.2657365060472596, -14.805718772221962, 8.187743307887585, -18.4786342525078, 8.014235835996242, 0.6189938053410522, 7.741018946722455, 4.860830220115805, -3.8170137641394577, 5.67791167560431, -0.5730333983705096, 1.838693383554574, -3.5531759393360263, -0.6264642029796845, 8.361115231329794, -11.496909768722242, -12.493433808997167, -0.9084276299741649, -1.4898662767457491, 9.458618909104564, 4.883007055392875, 8.4206661419543, 8.28273397068707, 5.798755993832103, -2.558132414025101, -4.925914106684857, -3.806488233386778, -2.9926883812106215, -8.81799328142032, 3.8670836717654193, -6.707746537501668, -12.435098046872747, 5.469955666751696, -2.545818458270526, -3.7277757203493938]\n",
    "# U_safe_list = [9.905251502990723, 15.9330786649559, 16.819916065255768, -15.48090424621185, -12.982727201103842, -11.169409294456756, -1.3884352404211882, 12.96376492627357, 4.5705911225327895, 1.6753874558027917, -9.258332617642225, -1.1977966703796312, 10.823495900706405, -4.668220494502825, -8.100313250096214, -0.7613820996612821, 11.7130108707399, -1.1016988828188823, -0.38597848518410416, 4.751175687237319, 1.9874765560943715, 11.862746172931528, -6.792222955361503, 0.42554556096572504, 2.7659969443447707, 4.839684031144672, 2.2922194537828426, 3.4441494691448193, 1.7116968739354688, 2.7958411749133005, -4.879779612145396, 10.875479007892944, 4.693996626719016, -2.374606548633812, 0.7049560638414532, 10.731211700547727, -15.297140491761748, -7.973310886398314, -9.508747616950696, -14.712687035426745, 4.5837514705011735, -2.547142301956148, -0.31719028198369514, 7.163440589368372, -5.894384463215484, 0.0022009283867081564, -9.303722642478188, -8.875607377528388, -12.198657241778275, -3.188796950622338, -0.5063190782320017]\n",
    "# U_safe_list = [4.375922203063965, 14.902210964654216, 7.971391159228517, -13.169986979765952, -2.377545457384997, -8.982727555273812, 2.0732045285557206, -11.749853092244564, 9.946527881290535, -2.7969964176244844, -7.498733224886905, 10.236902563921888, -12.767085082155859, -3.660074277653564, 5.447199785741873, 10.312460181587142, -8.666899648656795, 11.430905081729009, 8.750988104152498, -12.272351822182088, -13.213609112594781, -1.1106042575390234, -17.060564756953113, -10.612541337429946, -8.39229486801139, 1.245438021714392, 6.499112439211444, -2.366546710909631, -0.2817432118939678, 3.5109182218810133, -3.9321083857975054, -7.7132356552203305, 0.2668916387593754, -7.694789205357267, -0.9822810604947563, 6.346548676049167, -18.535984412736184, -6.6717893887012565, -8.558953378021329, 11.493450843107382, 6.750580911703651, 8.849825625314622, 10.039643326764121, 11.810119242977123, 8.753191005067151, -14.912322313508014, 4.421806280820988, 0.20259552274586667, -21.85040842585329, 8.137391572032755, -4.192123952328586]\n",
    "\n",
    "# Filtered controls: same initial condition (U_list[0]) and reward settings, fresh environment.\n",
    "envBcksFixed = gym.make(\"PDEControlGym-TransportPDE1D\", **hyperbolicParametersBacksteppingFixed)\n",
    "reward_afterQP, uBcks_afterQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_safe_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb5f467b-1dce-4bab-af70-9c324e90df10",
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_class_no_penalty =  TunedReward1D(int(round(T/dt)), -1e-4, 3e2) # no penalize\n",
    "\n",
    "# Replay every stored trajectory twice -- once with its nominal controls (beforeQP) and once\n",
    "# with its filtered controls (afterQP) -- and keep the rewards and state histories of both runs.\n",
    "# NOTE(review): RL_1000 is indexed by field name; confirm the .npy holds a structured array.\n",
    "RL_1000 = np.load(\"../../../verify-pde-control/hyperbolic_ppo_all_nominal_100_10.npy\")\n",
    "RL_reward_beforeQP = []\n",
    "RL_reward_afterQP = []\n",
    "uBcks_beforeQP_list = []\n",
    "uBcks_afterQP_list = []\n",
    "n_trajectories = RL_1000[\"safe_label\"].transpose().shape[0]\n",
    "for i in range(n_trajectories):\n",
    "    U_list = RL_1000[\"U_nominal\"][:, i]\n",
    "    U_safe_list = RL_1000[\"U_safe\"][:, i]\n",
    "\n",
    "    # Bind this trajectory's initial value now (default argument) rather than looking up the\n",
    "    # global U_list whenever an env resets; both runs start from the nominal initial value.\n",
    "    def getInitialConditionFixed(nx, initial_value=U_list[0]):\n",
    "        return np.ones(nx) * initial_value\n",
    "\n",
    "    hyperbolicParametersBacksteppingFixed = hyperbolicParametersBackstepping.copy()\n",
    "    hyperbolicParametersBacksteppingFixed[\"reset_init_condition_func\"] = getInitialConditionFixed\n",
    "    hyperbolicParametersBacksteppingFixed[\"reward_class\"] = reward_class_no_penalty\n",
    "\n",
    "    # Nominal (unfiltered) controls.\n",
    "    envBcksFixed = gym.make(\"PDEControlGym-TransportPDE1D\", **hyperbolicParametersBacksteppingFixed)\n",
    "    reward_beforeQP, uBcks_beforeQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_list)\n",
    "    uBcks_beforeQP_list.append(uBcks_beforeQP)\n",
    "    RL_reward_beforeQP.append(reward_beforeQP)\n",
    "\n",
    "    # Filtered controls, replayed in a fresh environment built from the same settings.\n",
    "    envBcksFixed = gym.make(\"PDEControlGym-TransportPDE1D\", **hyperbolicParametersBacksteppingFixed)\n",
    "    reward_afterQP, uBcks_afterQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_safe_list)\n",
    "    uBcks_afterQP_list.append(uBcks_afterQP)\n",
    "    RL_reward_afterQP.append(reward_afterQP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcdda126-0cdd-448e-9b25-9a26b70a5701",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88354a90-8e20-4b03-b459-b15b951c498e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "mpl.rcParams.update(mpl.rcParamsDefault)\n",
    "\n",
    "# Top row of the figure: three example trajectories of u(x, t) before and after QP filtering.\n",
    "# PLOTS ARE NOT SCALED THE SAME ON Z SO MAY HAVE TO ADJUST\n",
    "fig = plt.figure(figsize=set_size(433, 0.99, (2, 3), height_add=1))\n",
    "subfigs = fig.subfigures(nrows=2, ncols=1, hspace=-0.05)\n",
    "\n",
    "subfig = subfigs[0]\n",
    "subfig.suptitle(r\"PPO controller before and after filtering for hyperbolic equation\", y=1.1)\n",
    "subfig.subplots_adjust(left=0.03, bottom=0.05, right=1, top=0.95, wspace=0, hspace=0)\n",
    "\n",
    "# Plotting grid. NOTE(review): this overwrites any earlier X, dx and T (T also feeds T/dt for\n",
    "# the reward horizon) -- confirm these values match the environment settings.\n",
    "X = 1\n",
    "dx = 1e-2\n",
    "T = 5\n",
    "spatial = np.linspace(dx, X, int(round(X/dx)))\n",
    "# Assumes every stored trajectory has as many time samples as the last replayed one.\n",
    "temporal = np.linspace(0, T, len(uBcks_afterQP))\n",
    "print(spatial.shape, temporal.shape)\n",
    "meshx, mesht = np.meshgrid(spatial, temporal)\n",
    "\n",
    "ax = subfig.subplots(nrows=1, ncols=3, subplot_kw={\"projection\": \"3d\", \"computed_zorder\": False})\n",
    "\n",
    "# Common look for all three 3-D panels, plus shared axis labels and tick padding.\n",
    "for axes in ax:\n",
    "    for axis in (axes.xaxis, axes.yaxis, axes.zaxis):\n",
    "        axis._axinfo['axisline'].update(linewidth=1, color=\"b\")\n",
    "        axis._axinfo['grid'].update(linewidth=0.2, linestyle=\"--\", color=\"#d1d1d1\")\n",
    "        axis.set_pane_color((1,1,1))\n",
    "    axes.set_xlabel(\"x\", labelpad=-3)\n",
    "    axes.set_ylabel(\"Time\", labelpad=-3)\n",
    "    # Negative pads pull the tick labels in towards the axes.\n",
    "    axes.tick_params(axis='x', which='major', pad=-3)\n",
    "    axes.tick_params(axis='y', which='major', pad=-3)\n",
    "    axes.tick_params(axis='z', which='major', pad=-1)\n",
    "\n",
    "# Panel 0 additionally gets its camera angle, the z-label and a reversed x-axis (1 -> 0).\n",
    "ax[0].view_init(10, 35)\n",
    "ax[0].set_zlabel(r\"$u(x, t)$\", rotation=90, labelpad=-7)\n",
    "ax[0].zaxis.set_rotate_label(False)\n",
    "ax[0].set_xlim(1, 0)\n",
    "ax[0].set_xticks([1, 0.5, 0])\n",
    "# ax[0].xaxis.set_rotate_label(True)\n",
    "\n",
    "# Indices into uBcks_afterQP_list / uBcks_beforeQP_list of the trajectories drawn in panels 0, 1, 2.\n",
    "ind0 = 6\n",
    "ind1 = 2\n",
    "ind2 = 11\n",
    "# ind0 = 12\n",
    "# ind0 = 4\n",
    "# ind0 = 8\n",
    "# Panel 0: state surfaces after (blue) and before (red) filtering. Draw order matters because\n",
    "# the axes were created with computed_zorder=False.\n",
    "ax[0].plot_surface(meshx, mesht, uBcks_afterQP_list[ind0], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "# test = np.zeros(len(temporal))\n",
    "# vals = (uBcks_afterQP_list[ind0].transpose())[0] \n",
    "# ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=0.1, antialiased=False, rasterized=False, alpha=0.8,label=r'$Y(t)_{safe}$')\n",
    "\n",
    "# \n",
    "\n",
    "\n",
    "ax[0].plot_surface(meshx, mesht, uBcks_beforeQP_list[ind0], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True, label='Before filtering')\n",
    "# test = np.zeros(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[ind0].transpose())[0] \n",
    "# ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False, alpha=0.8,label=r'$Y(t)_{nominal}$')\n",
    "# Dashed curves: the state at the last spatial grid point (x = 1, the last column) over time,\n",
    "# labelled as the boundary control; the t = 0 sample is omitted.\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_afterQP_list[ind0].transpose())[-1] \n",
    "ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", alpha=1,lw=1, antialiased=False, rasterized=False,linestyle='--',label=r'$U_{safe}(t)$')\n",
    "\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_beforeQP_list[ind0].transpose())[-1] \n",
    "ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"red\",alpha=1, lw=1, antialiased=False, rasterized=False,linestyle='--', label=r'$U_{nominal}(t)$')\n",
    "# Translucent vertical plane at x = 1 spanning all times and z in [-11, 10] as a visual guide.\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-11, 10, len(vals)))\n",
    "ax[0].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "\n",
    "# Panel 1: same layout as panel 0 for trajectory ind1 (after = blue, before = red).\n",
    "ax[1].plot_surface(meshx, mesht, uBcks_afterQP_list[ind1], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "ax[1].view_init(10, 35)\n",
    "ax[1].zaxis.set_rotate_label(False)\n",
    "ax[1].set_xlim(1, 0)\n",
    "ax[1].set_xticks([1, 0.5, 0])\n",
    "# Dashed curves: state at the last grid point (x = 1) over time, labelled as the boundary control.\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_afterQP_list[ind1].transpose())[-1] \n",
    "# Same label text as the other panels' filtered-control curve.\n",
    "ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False,linestyle='--',label=r'$U_{safe}(t)$')\n",
    "\n",
    "ax[1].plot_surface(meshx, mesht, uBcks_beforeQP_list[ind1], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True,label='Before filtering')\n",
    "vals = (uBcks_beforeQP_list[ind1].transpose())[-1] \n",
    "ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False,linestyle='--', label=r'$U_{nominal}(t)$')\n",
    "# Translucent vertical plane at x = 1 spanning z in [-6, 10] as a visual guide.\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-6, 10, len(vals)))\n",
    "ax[1].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "# Panel 2: same layout as panel 0 for trajectory ind2 (after = blue, before = red).\n",
    "ax[2].plot_surface(meshx, mesht, uBcks_afterQP_list[ind2], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "ax[2].view_init(10, 35)\n",
    "ax[2].zaxis.set_rotate_label(False)\n",
    "ax[2].set_xlim(1, 0)\n",
    "ax[2].set_xticks([1, 0.5, 0])\n",
    "# Dashed curves: state at the last grid point (x = 1) over time, labelled as the boundary control.\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_afterQP_list[ind2].transpose())[-1] \n",
    "ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False, linestyle='--',label=r'$U_{safe}(t)$')\n",
    "\n",
    "ax[2].plot_surface(meshx, mesht, uBcks_beforeQP_list[ind2], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True, label='Before filtering')\n",
    "vals = (uBcks_beforeQP_list[ind2].transpose())[-1] \n",
    "ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False, linestyle='--',label=r'$U_{nominal}(t)$')\n",
    "# Translucent vertical plane at x = 1 spanning z in [-11, 10] as a visual guide.\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-11, 10, len(vals)))\n",
    "ax[2].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "\n",
    "# One legend for the whole row, built from panel 0's handles (the other panels draw the same series).\n",
    "handles, labels = ax[0].get_legend_handles_labels()\n",
    "subfig.legend(handles, labels, loc='upper center', ncol=4, framealpha=0.1)\n",
    "# plt.show()\n",
    "#plt.savefig(\"hyperbolicExamples.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de41295a-53d0-4bdf-96c9-9097d722a700",
   "metadata": {},
   "outputs": [],
   "source": [
    "subfig = subfigs[1]\n",
    "# subfig.suptitle(r\"Example trajectories for $u(0, x)=1$ with backstepping, PPO, and SAC\", y=1.1)\n",
    "subfig.subplots_adjust(left=0.03, bottom=0.05, right=1, top=0.95, wspace=0, hspace=0)\n",
    "X = 1\n",
    "dx = 1e-2\n",
    "T = 5\n",
    "spatial = np.linspace(dx, X, int(round(X/dx)))\n",
    "temporal = np.linspace(0, T, len(uBcks_afterQP))\n",
    "# print(uBcks_afterQP[0].shape)\n",
    "print(spatial.shape, temporal.shape)\n",
    "meshx, mesht = np.meshgrid(spatial, temporal)\n",
    "\n",
    "ax = subfig.subplots(nrows=1, ncols=3, subplot_kw={\"projection\": \"3d\", \"computed_zorder\": False})\n",
    "\n",
    "for axes in ax:\n",
    "    for axis in [axes.xaxis, axes.yaxis, axes.zaxis]:\n",
    "        axis._axinfo['axisline']['linewidth'] = 1\n",
    "        axis._axinfo['axisline']['color'] = \"b\"\n",
    "        axis._axinfo['grid']['linewidth'] = 0.2\n",
    "        axis._axinfo['grid']['linestyle'] = \"--\"\n",
    "        axis._axinfo['grid']['color'] = \"#d1d1d1\"\n",
    "        axis.set_pane_color((1,1,1))\n",
    "\n",
    "ax[0].view_init(10, 35)\n",
    "ax[0].set_xlabel(\"x\", labelpad=-3)\n",
    "ax[1].set_xlabel(\"x\", labelpad=-3)\n",
    "ax[2].set_xlabel(\"x\", labelpad=-3)\n",
    "ax[0].set_ylabel(\"Time\", labelpad=-3)\n",
    "ax[2].set_ylabel(\"Time\", labelpad=-3)\n",
    "ax[1].set_ylabel(\"Time\", labelpad=-3)\n",
    "ax[0].set_zlabel(r\"$u(x, t)$\", rotation=90, labelpad=-7)\n",
    "\n",
    "ax[0].zaxis.set_rotate_label(False)\n",
    "ax[0].set_xlim(1, 0)\n",
    "ax[0].set_xticks([1, 0.5, 0])\n",
    "ax[0].tick_params(axis='x', which='major', pad=-3)\n",
    "ax[1].tick_params(axis='x', which='major', pad=-3)\n",
    "ax[2].tick_params(axis='x', which='major', pad=-3)\n",
    "ax[0].tick_params(axis='y', which='major', pad=-3)\n",
    "ax[1].tick_params(axis='y', which='major', pad=-3)\n",
    "ax[2].tick_params(axis='y', which='major', pad=-3)\n",
    "ax[0].tick_params(axis='z', which='major', pad=-1)\n",
    "ax[1].tick_params(axis='z', which='major', pad=-1)\n",
    "ax[2].tick_params(axis='z', which='major', pad=-1)\n",
    "# ax[0].xaxis.set_rotate_label(True)\n",
    "\n",
    "ind0 = 6\n",
    "ind1 = 2\n",
    "ind2 = 11\n",
    "# ind0 = 12\n",
    "# ind0 = 4\n",
    "# ind0 = 8\n",
    "#\n",
    "ax[0].plot_surface(meshx, mesht, uBcks_afterQP_list[ind0], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True)\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_afterQP_list[ind0].transpose())[0] \n",
    "ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False,label=r'$Y_{safe}(t)$')\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_afterQP_list[ind0].transpose())[-1] \n",
    "# ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"green\", lw=1, antialiased=False, rasterized=False,alpha=0.8,linestyle='--', label=r'$U(t)_{safe}$')\n",
    "# \n",
    "\n",
    "\n",
    "ax[0].plot_surface(meshx, mesht, uBcks_beforeQP_list[ind0], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True)\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_beforeQP_list[ind0].transpose())[0] \n",
    "ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False,label=r'$Y_{nominal}(t)$')\n",
    "# vals[1:]\n",
    "# ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"yellow\", lw=1, antialiased=False, rasterized=False,label=r'$\\phi(t,Y)=0$')\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[ind0].transpose())[-1] \n",
    "# ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"orange\", lw=0.1, antialiased=False, rasterized=False,alpha=0.8,linestyle='--', label=r'$U(t)_{nominal}$')\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-10, 10, len(vals)))\n",
    "ax[0].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-10, 1, len(vals)))\n",
    "ax[0].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"green\", shade=False, rasterized=True, antialiased=True, label='Safety constraint')\n",
    "#\n",
    "ax[1].plot_surface(meshx, mesht, uBcks_afterQP_list[ind1], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "ax[1].view_init(10, 35)\n",
    "ax[1].zaxis.set_rotate_label(False)\n",
    "ax[1].set_xlim(1, 0)\n",
    "ax[1].set_xticks([1, 0.5, 0])\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_afterQP_list[ind1].transpose())[0] \n",
    "ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False)\n",
    "\n",
    "# \n",
    "\n",
    "\n",
    "ax[1].plot_surface(meshx, mesht, uBcks_beforeQP_list[ind1], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True, label='Before filtering')\n",
    "ax[1].view_init(10, 35)\n",
    "ax[1].zaxis.set_rotate_label(False)\n",
    "ax[1].set_xlim(1, 0)\n",
    "ax[1].set_xticks([1, 0.5, 0])\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_beforeQP_list[ind1].transpose())[0] \n",
    "ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-5, 10, len(vals)))\n",
    "ax[1].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-5, 1, len(vals)))\n",
    "ax[1].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"green\", shade=False, rasterized=True, antialiased=True, label='Safety constraint')\n",
    " \n",
    "# ax[1].plot_surface(meshx, mesht, uBcks_beforeQP_list[1], edgecolor=\"black\",lw=0.2, rstride=50, cstride=1, \n",
    "#                         alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "# ax[1].view_init(10, 35)\n",
    "# ax[1].zaxis.set_rotate_label(False)\n",
    "# ax[1].set_xticks([0, 0.5, 1])\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[1].transpose())[-1] \n",
    "# ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    " \n",
    "# ax[2].plot_surface(meshx, mesht, uBcks_beforeQP_list[2], edgecolor=\"black\",lw=0.2, rstride=50, cstride=1, \n",
    "#                         alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "# ax[2].view_init(10, 35)\n",
    "# ax[2].zaxis.set_rotate_label(False)\n",
    "# ax[2].set_xticks([0, 0.5, 1])\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[2].transpose())[-1] \n",
    "# ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    "#\n",
    "ax[2].plot_surface(meshx, mesht, uBcks_afterQP_list[ind2], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "ax[2].view_init(10, 35)\n",
    "ax[2].zaxis.set_rotate_label(False)\n",
    "ax[2].set_xlim(1, 0)\n",
    "ax[2].set_xticks([1, 0.5, 0])\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_afterQP_list[ind2].transpose())[0] \n",
    "ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False)\n",
    "\n",
    "# \n",
    "\n",
    "\n",
    "ax[2].plot_surface(meshx, mesht, uBcks_beforeQP_list[ind2], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True, label='Before filtering')\n",
    "ax[2].view_init(10, 35)\n",
    "ax[2].zaxis.set_rotate_label(False)\n",
    "ax[2].set_xlim(1, 0)\n",
    "ax[2].set_xticks([1, 0.5, 0])\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_beforeQP_list[ind2].transpose())[0] \n",
    "ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-10, 10, len(vals)))\n",
    "ax[2].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-10, 1, len(vals)))\n",
    "ax[2].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"green\", shade=False, rasterized=True, antialiased=True, label='Safety constraint')\n",
    "\n",
    "\n",
    "handles, labels = ax[0].get_legend_handles_labels()\n",
    "subfig.legend(handles, labels, loc='upper center', ncol=4, framealpha=0.1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c29baf7b-666f-444e-b5e2-3739d8e4e9be",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "plt.savefig(\"hyperbolicExamples.png\", dpi=400, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6030090-64be-49e4-9abd-05b5c8ba829f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.savefig(\"hyperbolicExamples_highreso.png\", dpi=600, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4826321e-c6a7-42df-838d-da61b61998f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_class_no_penalty =  TunedReward1D(int(round(T/dt)), -1e-4, 3e2) # no penalize\n",
    "# Apply the function\n",
    "\n",
    "RL_1000 = np.load(\"../../../verify-pde-control/hyperbolic_sac_all_train_nonominal_100_0.5_hyper_0.1reg_1pf_time_CBFnoNOnotpf_pf12_addsafe__le1safe_20.npy\")\n",
    "RL_reward_beforeQP = []\n",
    "RL_reward_afterQP = []\n",
    "uBcks_beforeQP_list_sac = []\n",
    "uBcks_afterQP_list_sac = []\n",
    "# uBcks_beforeQP,uBcks_afterQP = 0,0\n",
    "for i in range(RL_1000[\"safe_label\"].transpose().shape[0]):\n",
    "    # if i < 1:continue\n",
    "    # if RL_1000[\"Y_nominal\"][-1, i] < 0.5: continue\n",
    "    U_list = RL_1000[\"U_nominal\"][:, i]\n",
    "    # print(U_list[0])\n",
    "    def getInitialConditionFixed(nx):\n",
    "        return np.ones(nx) * U_list[0]\n",
    "    hyperbolicParametersBacksteppingFixed = hyperbolicParametersBackstepping.copy()\n",
    "    hyperbolicParametersBacksteppingFixed[\"reset_init_condition_func\"] = getInitialConditionFixed\n",
    "    hyperbolicParametersBacksteppingFixed[\"reward_class\"] = reward_class_no_penalty\n",
    "    envBcksFixed = gym.make(\"PDEControlGym-TransportPDE1D\", **hyperbolicParametersBacksteppingFixed)\n",
    "    reward_beforeQP, uBcks_beforeQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_list)\n",
    "    uBcks_beforeQP_list_sac.append(uBcks_beforeQP)\n",
    "    RL_reward_beforeQP.append(reward_beforeQP)\n",
    "\n",
    "    U_safe_list = RL_1000[\"U_safe\"][:, i]\n",
    "    def getInitialConditionFixed(nx):\n",
    "        return np.ones(nx) * U_list[0]\n",
    "    hyperbolicParametersBacksteppingFixed = hyperbolicParametersBackstepping.copy()\n",
    "    hyperbolicParametersBacksteppingFixed[\"reset_init_condition_func\"] = getInitialConditionFixed\n",
    "    hyperbolicParametersBacksteppingFixed[\"reward_class\"] = reward_class_no_penalty\n",
    "    envBcksFixed = gym.make(\"PDEControlGym-TransportPDE1D\", **hyperbolicParametersBacksteppingFixed)\n",
    "    reward_afterQP, uBcks_afterQP = runSingleEpisodeQP(QP_filter_Controller, envBcksFixed, U_safe_list)\n",
    "    # print(uBcks_afterQP.shape) (51, 100)\n",
    "    uBcks_afterQP_list_sac.append(uBcks_afterQP)\n",
    "    RL_reward_afterQP.append(reward_afterQP)\n",
    "    # print(uBcks_beforeQP,uBcks_afterQP)\n",
    "    # if i > 2: break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49662a54-42e7-4301-b9df-7cbe70469bd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams.update(mpl.rcParamsDefault)\n",
    "\n",
    "# %matplotlib inline\n",
    "# PLOT EACH EXAMPLE. PLOTS ARE NOT SCALED THE SAME ON Z SO MAY HAVE TO ADJUST\n",
    "fig = plt.figure(figsize=set_size(433, 0.99, (2, 3), height_add=1))\n",
    "subfigs = fig.subfigures(nrows=2, ncols=1, hspace=-0.05)\n",
    "\n",
    "subfig = subfigs[0]\n",
    "subfig.suptitle(r\"SAC controller before and aftering filtering for hyperbolic equation\", y=1.1)\n",
    "subfig.subplots_adjust(left=0.05, bottom=0.05, right=1, top=0.95, wspace=0, hspace=0)\n",
    "X = 1\n",
    "dx = 1e-2\n",
    "T = 5\n",
    "spatial = np.linspace(dx, X, int(round(X/dx)))\n",
    "temporal = np.linspace(0, T, len(uBcks_afterQP))\n",
    "# print(uBcks_afterQP[0].shape)\n",
    "print(spatial.shape, temporal.shape)\n",
    "meshx, mesht = np.meshgrid(spatial, temporal)\n",
    "\n",
    "ax = subfig.subplots(nrows=1, ncols=3, subplot_kw={\"projection\": \"3d\", \"computed_zorder\": False})\n",
    "\n",
    "for axes in ax:\n",
    "    for axis in [axes.xaxis, axes.yaxis, axes.zaxis]:\n",
    "        axis._axinfo['axisline']['linewidth'] = 1\n",
    "        axis._axinfo['axisline']['color'] = \"b\"\n",
    "        axis._axinfo['grid']['linewidth'] = 0.2\n",
    "        axis._axinfo['grid']['linestyle'] = \"--\"\n",
    "        axis._axinfo['grid']['color'] = \"#d1d1d1\"\n",
    "        axis.set_pane_color((1,1,1))\n",
    "\n",
    "ax[0].view_init(10, 35)\n",
    "ax[0].set_xlabel(\"x\", labelpad=-3)\n",
    "ax[1].set_xlabel(\"x\", labelpad=-3)\n",
    "ax[2].set_xlabel(\"x\", labelpad=-3)\n",
    "ax[0].set_ylabel(\"Time\", labelpad=-3)\n",
    "ax[2].set_ylabel(\"Time\", labelpad=-3)\n",
    "ax[1].set_ylabel(\"Time\", labelpad=-3)\n",
    "ax[0].set_zlabel(r\"$u(x, t)$\", rotation=90, labelpad=-7)\n",
    "\n",
    "ax[0].zaxis.set_rotate_label(False)\n",
    "ax[0].set_xlim(1, 0)\n",
    "ax[0].set_xticks([1, 0.5, 0])\n",
    "ax[0].tick_params(axis='x', which='major', pad=-3)\n",
    "ax[1].tick_params(axis='x', which='major', pad=-3)\n",
    "ax[2].tick_params(axis='x', which='major', pad=-3)\n",
    "ax[0].tick_params(axis='y', which='major', pad=-3)\n",
    "ax[1].tick_params(axis='y', which='major', pad=-3)\n",
    "ax[2].tick_params(axis='y', which='major', pad=-3)\n",
    "ax[0].tick_params(axis='z', which='major', pad=-1)\n",
    "ax[1].tick_params(axis='z', which='major', pad=-1)\n",
    "ax[2].tick_params(axis='z', which='major', pad=-1)\n",
    "# ax[0].xaxis.set_rotate_label(True)\n",
    "\n",
    "ind0 = 13\n",
    "ind1 = 6\n",
    "ind2 = 24\n",
    "\n",
    "# ind0 = 24\n",
    "# ind1 = 6\n",
    "# ind2 = 13\n",
    "# ind0 = 12\n",
    "# ind0 = 4\n",
    "# ind0 = 8\n",
    "#\n",
    "ax[0].plot_surface(meshx, mesht, uBcks_afterQP_list_sac[ind0], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "# test = np.zeros(len(temporal))\n",
    "# vals = (uBcks_afterQP_list[ind0].transpose())[0] \n",
    "# ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=0.1, antialiased=False, rasterized=False, alpha=0.8,label=r'$Y(t)_{safe}$')\n",
    "\n",
    "\n",
    "\n",
    "ax[0].plot_surface(meshx, mesht, uBcks_beforeQP_list_sac[ind0], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True, label='Before filtering')\n",
    "# test = np.zeros(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[ind0].transpose())[0] \n",
    "# ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False, alpha=0.8,label=r'$Y(t)_{nominal}$')\n",
    "\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_afterQP_list_sac[ind0].transpose())[-1] \n",
    "ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", alpha=1,lw=1, antialiased=False, rasterized=False,linestyle='--',label=r'$U_{safe}(t)$')\n",
    "# \n",
    "\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_beforeQP_list_sac[ind0].transpose())[-1] \n",
    "ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"red\",alpha=1, lw=1, antialiased=False, rasterized=False, linestyle='--', label=r'$U_{nominal}(t)$')\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-8, 5, len(vals)))\n",
    "ax[0].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "\n",
    "#\n",
    "ax[1].plot_surface(meshx, mesht, uBcks_afterQP_list_sac[ind1], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "ax[1].view_init(10, 35)\n",
    "ax[1].zaxis.set_rotate_label(False)\n",
    "ax[1].set_xlim(1, 0)\n",
    "ax[1].set_xticks([1, 0.5, 0])\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_afterQP_list_sac[ind1].transpose())[-1] \n",
    "ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False,linestyle='--',label=r'$U(t)_{safe}$')\n",
    "\n",
    "# \n",
    "\n",
    "\n",
    "ax[1].plot_surface(meshx, mesht, uBcks_beforeQP_list_sac[ind1], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True, label='Before filtering')\n",
    "ax[1].view_init(10, 35)\n",
    "ax[1].zaxis.set_rotate_label(False)\n",
    "ax[1].set_xlim(1, 0)\n",
    "ax[1].set_xticks([1, 0.5, 0])\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_beforeQP_list_sac[ind1].transpose())[-1] \n",
    "ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False, linestyle='--',label=r'$U(t)_{nominal}$')\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-15, 10, len(vals)))\n",
    "ax[1].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "# ax[1].plot_surface(meshx, mesht, np.zeros_like(uBcks_beforeQP_list_sac[ind0]), edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "#                         alpha=0.2, color=\"green\", shade=False, rasterized=True, antialiased=True, label='Safety boundary')\n",
    " \n",
    "# ax[1].plot_surface(meshx, mesht, uBcks_beforeQP_list[1], edgecolor=\"black\",lw=0.2, rstride=50, cstride=1, \n",
    "#                         alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "# ax[1].view_init(10, 35)\n",
    "# ax[1].zaxis.set_rotate_label(False)\n",
    "# ax[1].set_xticks([0, 0.5, 1])\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[1].transpose())[-1] \n",
    "# ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    " \n",
    "# ax[2].plot_surface(meshx, mesht, uBcks_beforeQP_list[2], edgecolor=\"black\",lw=0.2, rstride=50, cstride=1, \n",
    "#                         alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "# ax[2].view_init(10, 35)\n",
    "# ax[2].zaxis.set_rotate_label(False)\n",
    "# ax[2].set_xticks([0, 0.5, 1])\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[2].transpose())[-1] \n",
    "# ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    "#\n",
    "ax[2].plot_surface(meshx, mesht, uBcks_afterQP_list_sac[ind2], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "ax[2].view_init(10, 35)\n",
    "ax[2].zaxis.set_rotate_label(False)\n",
    "ax[2].set_xlim(1, 0)\n",
    "ax[2].set_xticks([1, 0.5, 0])\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_afterQP_list_sac[ind2].transpose())[-1] \n",
    "ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False, linestyle='--',label=r'$U(t)_{safe}$')\n",
    "\n",
    "# \n",
    "\n",
    "\n",
    "ax[2].plot_surface(meshx, mesht, uBcks_beforeQP_list_sac[ind2], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True, label='Before filtering')\n",
    "ax[2].view_init(10, 35)\n",
    "ax[2].zaxis.set_rotate_label(False)\n",
    "ax[2].set_xlim(1, 0)\n",
    "ax[2].set_xticks([1, 0.5, 0])\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uBcks_beforeQP_list_sac[ind2].transpose())[-1] \n",
    "ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False, linestyle='--',label=r'$U(t)_{nominal}$')\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-15, 10, len(vals)))\n",
    "ax[2].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "# ax[2].plot_surface(meshx, mesht, np.zeros_like(uBcks_beforeQP_list_sac[ind0]), edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "#                         alpha=0.2, color=\"green\", shade=False, rasterized=True, antialiased=True, label='Safety boundary')\n",
    "\n",
    "handles, labels = ax[0].get_legend_handles_labels()\n",
    "# print()\n",
    "# handles[1].set_alpha(1)\n",
    "# print(labels)\n",
    "subfig.legend(handles, labels, loc='upper center', ncol=4, framealpha=0.1)\n",
    "# plt.show()\n",
    "#plt.savefig(\"hyperbolicExamples.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad0ec6ea-f831-4f30-a963-e384cf5aa248",
   "metadata": {},
   "outputs": [],
   "source": [
    "subfig = subfigs[1]\n",
    "# subfig.suptitle(r\"Example trajectories for $u(0, x)=1$ with backstepping, PPO, and SAC\", y=1.1)\n",
    "subfig.subplots_adjust(left=0.05, bottom=0.05, right=1, top=0.95, wspace=0, hspace=0)\n",
    "X = 1\n",
    "dx = 1e-2\n",
    "T = 5\n",
    "spatial = np.linspace(dx, X, int(round(X/dx)))\n",
    "temporal = np.linspace(0, T, len(uBcks_afterQP))\n",
    "# print(uBcks_afterQP[0].shape)\n",
    "print(spatial.shape, temporal.shape)\n",
    "meshx, mesht = np.meshgrid(spatial, temporal)\n",
    "\n",
    "ax = subfig.subplots(nrows=1, ncols=3, subplot_kw={\"projection\": \"3d\", \"computed_zorder\": False})\n",
    "\n",
    "for axes in ax:\n",
    "    for axis in [axes.xaxis, axes.yaxis, axes.zaxis]:\n",
    "        axis._axinfo['axisline']['linewidth'] = 1\n",
    "        axis._axinfo['axisline']['color'] = \"b\"\n",
    "        axis._axinfo['grid']['linewidth'] = 0.2\n",
    "        axis._axinfo['grid']['linestyle'] = \"--\"\n",
    "        axis._axinfo['grid']['color'] = \"#d1d1d1\"\n",
    "        axis.set_pane_color((1,1,1))\n",
    "\n",
    "ax[0].view_init(10, 35)\n",
    "ax[0].set_xlabel(\"x\", labelpad=-3)\n",
    "ax[1].set_xlabel(\"x\", labelpad=-3)\n",
    "ax[2].set_xlabel(\"x\", labelpad=-3)\n",
    "ax[0].set_ylabel(\"Time\", labelpad=-3)\n",
    "ax[2].set_ylabel(\"Time\", labelpad=-3)\n",
    "ax[1].set_ylabel(\"Time\", labelpad=-3)\n",
    "ax[0].set_zlabel(r\"$u(x, t)$\", rotation=90, labelpad=-7)\n",
    "\n",
    "ax[0].zaxis.set_rotate_label(False)\n",
    "ax[0].set_xlim(1, 0)\n",
    "ax[0].set_xticks([1, 0.5, 0])\n",
    "ax[0].tick_params(axis='x', which='major', pad=-3)\n",
    "ax[1].tick_params(axis='x', which='major', pad=-3)\n",
    "ax[2].tick_params(axis='x', which='major', pad=-3)\n",
    "ax[0].tick_params(axis='y', which='major', pad=-3)\n",
    "ax[1].tick_params(axis='y', which='major', pad=-3)\n",
    "ax[2].tick_params(axis='y', which='major', pad=-3)\n",
    "ax[0].tick_params(axis='z', which='major', pad=-1)\n",
    "ax[1].tick_params(axis='z', which='major', pad=-1)\n",
    "ax[2].tick_params(axis='z', which='major', pad=-1)\n",
    "# ax[0].xaxis.set_rotate_label(True)\n",
    "\n",
    "\n",
    "ind0 = 13\n",
    "ind1 = 6\n",
    "ind2 = 24\n",
    "\n",
    "#\n",
    "ax[0].plot_surface(meshx, mesht, uBcks_afterQP_list_sac[ind0], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True)\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_afterQP_list_sac[ind0].transpose())[0] \n",
    "ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False,label=r'$Y_{safe}(t)$')\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_afterQP_list[ind0].transpose())[-1] \n",
    "# ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"green\", lw=1, antialiased=False, rasterized=False,alpha=0.8,linestyle='--', label=r'$U(t)_{safe}$')\n",
    "# \n",
    "\n",
    "\n",
    "ax[0].plot_surface(meshx, mesht, uBcks_beforeQP_list_sac[ind0], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True)\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_beforeQP_list_sac[ind0].transpose())[0] \n",
    "ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False,label=r'$Y_{nominal}(t)$')\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-8, 5, len(vals)))\n",
    "ax[0].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-8, -0, len(vals)))\n",
    "ax[0].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"green\", shade=False, rasterized=True, antialiased=True, label='Safety constraint')\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[ind0].transpose())[-1] \n",
    "# ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"orange\", lw=0.1, antialiased=False, rasterized=False,alpha=0.8,linestyle='--', label=r'$U(t)_{nominal}$')\n",
    "\n",
    "#\n",
    "ax[1].plot_surface(meshx, mesht, uBcks_afterQP_list_sac[ind1], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "ax[1].view_init(10, 35)\n",
    "ax[1].zaxis.set_rotate_label(False)\n",
    "ax[1].set_xlim(1, 0)\n",
    "ax[1].set_xticks([1, 0.5, 0])\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_afterQP_list_sac[ind1].transpose())[0] \n",
    "ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False)\n",
    "\n",
    "# \n",
    "\n",
    "\n",
    "ax[1].plot_surface(meshx, mesht, uBcks_beforeQP_list_sac[ind1], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True, label='Before filtering')\n",
    "ax[1].view_init(10, 35)\n",
    "ax[1].zaxis.set_rotate_label(False)\n",
    "ax[1].set_xlim(1, 0)\n",
    "ax[1].set_xticks([1, 0.5, 0])\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_beforeQP_list_sac[ind1].transpose())[0] \n",
    "ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False)\n",
    "# ax[1].plot(test[1:], temporal[1:], np.linspace(-5, 0, len(vals[1:])), color=\"green\", lw=1, antialiased=False, rasterized=False, label='Safety boundary')\n",
    "# print(temporal.shape,np.linspace(-5, 0, len(vals)).shape)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-15, 10, len(vals)))\n",
    "ax[1].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-15, -0, len(vals)))\n",
    "ax[1].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"green\", shade=False, rasterized=True, antialiased=True, label='Safety constraint')\n",
    " \n",
    "# ax[1].plot_surface(meshx, mesht, uBcks_beforeQP_list[1], edgecolor=\"black\",lw=0.2, rstride=50, cstride=1, \n",
    "#                         alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "# ax[1].view_init(10, 35)\n",
    "# ax[1].zaxis.set_rotate_label(False)\n",
    "# ax[1].set_xticks([0, 0.5, 1])\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[1].transpose())[-1] \n",
    "# ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    " \n",
    "# ax[2].plot_surface(meshx, mesht, uBcks_beforeQP_list[2], edgecolor=\"black\",lw=0.2, rstride=50, cstride=1, \n",
    "#                         alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "# ax[2].view_init(10, 35)\n",
    "# ax[2].zaxis.set_rotate_label(False)\n",
    "# ax[2].set_xticks([0, 0.5, 1])\n",
    "# test = np.ones(len(temporal))\n",
    "# vals = (uBcks_beforeQP_list[2].transpose())[-1] \n",
    "# ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    "#\n",
    "ax[2].plot_surface(meshx, mesht, uBcks_afterQP_list_sac[ind2], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"blue\", shade=False, rasterized=True, antialiased=True, label='After filtering')\n",
    "ax[2].view_init(10, 35)\n",
    "ax[2].zaxis.set_rotate_label(False)\n",
    "ax[2].set_xlim(1, 0)\n",
    "ax[2].set_xticks([1, 0.5, 0])\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_afterQP_list_sac[ind2].transpose())[0] \n",
    "ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"blue\", lw=1, antialiased=False, rasterized=False)\n",
    "\n",
    "# \n",
    "\n",
    "\n",
    "ax[2].plot_surface(meshx, mesht, uBcks_beforeQP_list_sac[ind2], edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"red\", shade=False, rasterized=True, antialiased=True, label='Before filtering')\n",
    "ax[2].view_init(10, 35)\n",
    "ax[2].zaxis.set_rotate_label(False)\n",
    "ax[2].set_xlim(1, 0)\n",
    "ax[2].set_xticks([1, 0.5, 0])\n",
    "test = np.zeros(len(temporal))\n",
    "vals = (uBcks_beforeQP_list_sac[ind2].transpose())[0] \n",
    "ax[2].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=1, antialiased=False, rasterized=False)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-15, 10, len(vals)))\n",
    "ax[2].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.1, color=\"black\", shade=False, rasterized=True, antialiased=True)\n",
    "mesht_new, mesh_safe = np.meshgrid(temporal, np.linspace(-15, 0, len(vals)))\n",
    "ax[2].plot_surface(test, mesht_new, mesh_safe, edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "                        alpha=0.2, color=\"green\", shade=False, rasterized=True, antialiased=True, label='Safety constraint')\n",
    "# ax[2].plot(test[1:], temporal[1:], np.zeros_like(vals[1:]), color=\"green\", lw=1, antialiased=False, rasterized=False, label='Safety boundary')\n",
    "# ax[2].plot_surface(meshx, mesht, np.zeros_like(uBcks_beforeQP_list_sac[ind0]), edgecolor=\"black\",lw=0.0, rstride=50, cstride=1, \n",
    "#                         alpha=0.2, color=\"green\", shade=False, rasterized=True, antialiased=True, label='Safety boundary')\n",
    "\n",
    "handles, labels = ax[0].get_legend_handles_labels()\n",
    "subfig.legend(handles, labels, loc='upper center', ncol=3, framealpha=0.1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78ff3e25-1a28-4aab-905e-dcf7348463cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the figure built in the previous cell. With the default inline backend, pyplot\n",
    "# closes each figure when its cell finishes, so a bare plt.savefig() here would write a\n",
    "# new, blank canvas. Saving through the Figure object (reached from `subfig`, defined in\n",
    "# the previous cell) writes the actual plot.\n",
    "subfig.figure.savefig(\"hyperbolicExamples_sac.png\", dpi=400, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be93e43-099e-43fe-aa6c-6995bb756841",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e79e26d-a1d8-4e16-bd26-6608c3c850a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "mpl.rcParams.update(mpl.rcParamsDefault)\n",
    "\n",
    "# Compare the backstepping control signal before and after the QP safety filter.\n",
    "# Row 0 of each transposed array, i.e. column 0 of the original array at every step.\n",
    "# NOTE(review): assumes the arrays are (time, space) like the 3D plots above - confirm.\n",
    "y1 = (uBcks_beforeQP.transpose())[0]  # before the QP filter\n",
    "y2 = (uBcks_afterQP.transpose())[0]   # after the QP filter\n",
    "# Horizontal grid on [0, 5] with exactly one point per sample, so it always matches the\n",
    "# data length instead of assuming 51 samples (presumably 5 is the horizon T used below).\n",
    "x = np.linspace(0, 5, len(y1))\n",
    "\n",
    "# Create the plot\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.plot(x, y1, label='before QP', linestyle='-', marker='o')  # solid line, circle markers\n",
    "plt.plot(x, y2, label='after QP', linestyle='--', marker='x')  # dashed line, cross markers\n",
    "\n",
    "# Add titles and labels\n",
    "plt.title('Plot of Two Lines')\n",
    "plt.xlabel('x-axis')\n",
    "plt.ylabel('y-axis')\n",
    "\n",
    "# Add a legend to differentiate the lines\n",
    "plt.legend()\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2efb7302-fe2c-4686-b5f9-1168c5a6815d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PLOT EXAMPLE PROBLEMS.\n",
    "\n",
    "# Build environments that all start from the same constant initial condition,\n",
    "# so backstepping, PPO and SAC can be compared on identical problems.\n",
    "def getInitialConditionTen(nx):\n",
    "    \"\"\"Constant initial profile equal to 10 at each of the nx grid points.\"\"\"\n",
    "    return np.full(nx, 10.0)\n",
    "\n",
    "def getInitialConditionOne(nx):\n",
    "    \"\"\"Constant initial profile equal to 1 at each of the nx grid points.\"\"\"\n",
    "    return np.ones(nx)\n",
    "\n",
    "def _with_initial_condition(base_params, init_func):\n",
    "    \"\"\"Copy base_params (via .copy()) and set init_func as the reset initial condition.\"\"\"\n",
    "    params = base_params.copy()\n",
    "    params[\"reset_init_condition_func\"] = init_func\n",
    "    return params\n",
    "\n",
    "hyperbolicParametersBacksteppingTen = _with_initial_condition(hyperbolicParametersBackstepping, getInitialConditionTen)\n",
    "hyperbolicParametersBacksteppingOne = _with_initial_condition(hyperbolicParametersBackstepping, getInitialConditionOne)\n",
    "hyperbolicParametersRLTen = _with_initial_condition(hyperbolicParametersRL, getInitialConditionTen)\n",
    "hyperbolicParametersRLOne = _with_initial_condition(hyperbolicParametersRL, getInitialConditionOne)\n",
    "\n",
    "# Make environments (backstepping Ten/One first, then RL Ten/One)\n",
    "transport_env_id = \"PDEControlGym-TransportPDE1D\"\n",
    "envBcksTen = gym.make(transport_env_id, **hyperbolicParametersBacksteppingTen)\n",
    "envBcksOne = gym.make(transport_env_id, **hyperbolicParametersBacksteppingOne)\n",
    "\n",
    "envRLTen = gym.make(transport_env_id, **hyperbolicParametersRLTen)\n",
    "envRLOne = gym.make(transport_env_id, **hyperbolicParametersRLOne)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80b4b856-1914-4264-8578-ee418a76e83b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run one episode per controller and initial condition; each call returns (reward, trajectory).\n",
    "rewBcksTen, uBcksTen = runSingleEpisode(bcksController, envBcksTen, beta)\n",
    "rewBcksOne, uBcksOne = runSingleEpisode(bcksController, envBcksOne, beta)\n",
    "\n",
    "rewPPOTen, uPPOTen = runSingleEpisode(RLController, envRLTen, ppoModel)\n",
    "rewPPOOne, uPPOOne = runSingleEpisode(RLController, envRLOne, ppoModel)\n",
    "\n",
    "rewSACTen, uSACTen = runSingleEpisode(RLController, envRLTen, sacModel)\n",
    "rewSACOne, uSACOne = runSingleEpisode(RLController, envRLOne, sacModel)\n",
    "\n",
    "# Open-loop baseline. Pass an explicit None as the controller parameter instead of\n",
    "# IPython's `_` (the last cell output), which changes with execution history.\n",
    "# NOTE(review): assumes openLoopController ignores this argument - confirm.\n",
    "rewOpenTen, uOpenTen = runSingleEpisode(openLoopController, envBcksTen, None)\n",
    "# Keep the u = 1 reward under its own name (it used to overwrite rewOpenTen).\n",
    "rewOpenOne, uOpenOne = runSingleEpisode(openLoopController, envBcksOne, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "796cc833-e0cc-4e18-aed6-aa7a13e39f03",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "mpl.rcParams.update(mpl.rcParamsDefault)\n",
    "# %matplotlib inline\n",
    "\n",
    "# PLOT OPEN-LOOP EXAMPLE. PLOTS ARE NOT SCALED THE SAME ON Z SO MAY HAVE TO ADJUST\n",
    "fig = plt.figure(figsize=set_size(433, 0.99, (1, 2), height_add=1))\n",
    "subfigs = fig.subfigures(nrows=1, ncols=1, hspace=0)\n",
    "\n",
    "subfig = subfigs\n",
    "subfig.suptitle(r\"Open-loop (U(t)=0) instability of transport PDE for u(x, 0)=1, 10\")\n",
    "subfig.subplots_adjust(left=0.03, bottom=0.05, right=1, top=0.95, wspace=0, hspace=0)\n",
    "\n",
    "# Grid constants, set here so this cell does not rely on a later cell having run\n",
    "# first (same values as the backstepping/PPO/SAC example cells).\n",
    "X = 1\n",
    "dx = 1e-2\n",
    "T = 5\n",
    "spatial = np.linspace(dx, X, int(round(X/dx)))\n",
    "# Size the time axis from the open-loop data that is actually plotted below.\n",
    "assert uOpenTen.shape == uOpenOne.shape, \"both open-loop runs must share one grid\"\n",
    "temporal = np.linspace(0, T, len(uOpenOne))\n",
    "meshx, mesht = np.meshgrid(spatial, temporal)\n",
    "\n",
    "ax = subfig.subplots(nrows=1, ncols=2, subplot_kw={\"projection\": \"3d\", \"computed_zorder\": False})\n",
    "\n",
    "for axes in ax:\n",
    "    # Blue axis lines, light dashed grid, white panes (via matplotlib's private _axinfo).\n",
    "    for axis in [axes.xaxis, axes.yaxis, axes.zaxis]:\n",
    "        axis._axinfo['axisline']['linewidth'] = 1\n",
    "        axis._axinfo['axisline']['color'] = \"b\"\n",
    "        axis._axinfo['grid']['linewidth'] = 0.2\n",
    "        axis._axinfo['grid']['linestyle'] = \"--\"\n",
    "        axis._axinfo['grid']['color'] = \"#d1d1d1\"\n",
    "        axis.set_pane_color((1,1,1))\n",
    "    axes.set_xlabel(\"x\", labelpad=-3)\n",
    "    axes.set_ylabel(\"Time\", labelpad=-3)\n",
    "    axes.tick_params(axis='x', which='major', pad=-3)\n",
    "    axes.tick_params(axis='y', which='major', pad=-3)\n",
    "    axes.tick_params(axis='z', which='major', pad=-1)\n",
    "\n",
    "ax[0].view_init(10, 35)\n",
    "ax[0].set_zlabel(r\"$u(x, t)$\", rotation=90, labelpad=-7)\n",
    "ax[0].zaxis.set_rotate_label(False)\n",
    "ax[0].set_xticks([0, 0.5, 1])\n",
    "\n",
    "# Left panel: u(x, 0) = 1. The red line traces the last spatial column (the x = 1 edge).\n",
    "ax[0].plot_surface(meshx, mesht, uOpenOne, edgecolor=\"black\",lw=0.2, rstride=50, cstride=2, \n",
    "                        alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uOpenOne.transpose())[-1] \n",
    "ax[0].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    " \n",
    "# Right panel: u(x, 0) = 10, same layout.\n",
    "ax[1].plot_surface(meshx, mesht, uOpenTen, edgecolor=\"black\",lw=0.2, rstride=50, cstride=2, \n",
    "                        alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "ax[1].view_init(10, 35)\n",
    "ax[1].zaxis.set_rotate_label(False)\n",
    "ax[1].set_xticks([0, 0.5, 1])\n",
    "test = np.ones(len(temporal))\n",
    "vals = (uOpenTen.transpose())[-1] \n",
    "ax[1].plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    "\n",
    "# To export, save BEFORE plt.show(); the inline backend closes the figure afterwards:\n",
    "# fig.savefig(\"hyperbolicOpenloop.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db8f8aaf-5e05-409a-b3c7-364fcabd4f18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the open-loop trajectory with its axes reversed (same as uOpenOne.transpose()).\n",
    "uOpenOne.T.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3ef157d-8135-4148-a42c-8f025f99744f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PLOT EACH EXAMPLE. PLOTS ARE NOT SCALED THE SAME ON Z SO MAY HAVE TO ADJUST\n",
    "# Top row: u(x, 0) = 1; bottom row: u(x, 0) = 10. Columns: backstepping, PPO, SAC.\n",
    "fig = plt.figure(figsize=set_size(433, 0.99, (2, 3), height_add=1))\n",
    "subfigs = fig.subfigures(nrows=2, ncols=1, hspace=0)\n",
    "\n",
    "# Grid constants shared by every panel.\n",
    "X = 1\n",
    "dx = 1e-2\n",
    "T = 5\n",
    "spatial = np.linspace(dx, X, int(round(X/dx)))\n",
    "\n",
    "# One entry per row: (row title, trajectories for the three columns).\n",
    "example_rows = [\n",
    "    (r\"Example trajectories for $u(x, 0)=1$ with backstepping, PPO, and SAC\",\n",
    "     [uBcksOne, uPPOOne, uSACOne]),\n",
    "    (r\"Example trajectories for $u(x, 0)=10$ with backstepping, PPO, and SAC\",\n",
    "     [uBcksTen, uPPOTen, uSACTen]),\n",
    "]\n",
    "\n",
    "for subfig, (row_title, trajectories) in zip(subfigs, example_rows):\n",
    "    subfig.suptitle(row_title)\n",
    "    subfig.subplots_adjust(left=0.03, bottom=0.05, right=1, top=0.95, wspace=0, hspace=0)\n",
    "    ax = subfig.subplots(nrows=1, ncols=3, subplot_kw={\"projection\": \"3d\", \"computed_zorder\": False})\n",
    "\n",
    "    for col, (axes, u) in enumerate(zip(ax, trajectories)):\n",
    "        # Blue axis lines, light dashed grid, white panes (via matplotlib's private _axinfo).\n",
    "        for axis in [axes.xaxis, axes.yaxis, axes.zaxis]:\n",
    "            axis._axinfo['axisline']['linewidth'] = 1\n",
    "            axis._axinfo['axisline']['color'] = \"b\"\n",
    "            axis._axinfo['grid']['linewidth'] = 0.2\n",
    "            axis._axinfo['grid']['linestyle'] = \"--\"\n",
    "            axis._axinfo['grid']['color'] = \"#d1d1d1\"\n",
    "            axis.set_pane_color((1,1,1))\n",
    "        axes.view_init(10, 35)\n",
    "        axes.zaxis.set_rotate_label(False)\n",
    "        axes.set_xlabel(\"x\", labelpad=-3)\n",
    "        axes.set_ylabel(\"Time\", labelpad=-3)\n",
    "        axes.tick_params(axis='x', which='major', pad=-3)\n",
    "        axes.tick_params(axis='y', which='major', pad=-3)\n",
    "        axes.tick_params(axis='z', which='major', pad=-1)\n",
    "        if col == 0:\n",
    "            # Only the leftmost panel carries the z label.\n",
    "            axes.set_zlabel(r\"$u(x, t)$\", rotation=90, labelpad=-7)\n",
    "            axes.set_xticks([0, 0.5, 1])\n",
    "\n",
    "        # Time axis sized from this trajectory, so its mesh always matches its shape.\n",
    "        temporal = np.linspace(0, T, len(u))\n",
    "        meshx, mesht = np.meshgrid(spatial, temporal)\n",
    "        axes.plot_surface(meshx, mesht, u, edgecolor=\"black\",lw=0.2, rstride=50, cstride=1, \n",
    "                          alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "        if col > 0:\n",
    "            # These panels set x-ticks after plotting (the leftmost sets them before).\n",
    "            # Per the matplotlib docs set_xticks may widen the view limits, so the\n",
    "            # ordering is kept to preserve the established layout.\n",
    "            axes.set_xticks([0, 0.5, 1])\n",
    "\n",
    "        # Red line along the x = 1 edge: the last spatial column over time.\n",
    "        test = np.ones(len(temporal))\n",
    "        vals = (u.transpose())[-1]\n",
    "        axes.plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    "\n",
    "# To export, save BEFORE plt.show(); the inline backend closes the figure afterwards:\n",
    "# fig.savefig(\"hyperbolicExamples.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e780575c-a669-4483-b920-be735d882c2b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c923dec-dd18-44a4-974a-29871ad60773",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5078b80e-47c6-45e7-a1e2-78ec06a35256",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PLOT EACH EXAMPLE. PLOTS ARE NOT SCALED THE SAME ON Z SO MAY HAVE TO ADJUST\n",
    "# NOTE(review): an earlier cell renders this same figure; one of the two can likely be removed.\n",
    "# Top row: u(x, 0) = 1; bottom row: u(x, 0) = 10. Columns: backstepping, PPO, SAC.\n",
    "fig = plt.figure(figsize=set_size(433, 0.99, (2, 3), height_add=1))\n",
    "subfigs = fig.subfigures(nrows=2, ncols=1, hspace=0)\n",
    "\n",
    "# Grid constants shared by every panel.\n",
    "X = 1\n",
    "dx = 1e-2\n",
    "T = 5\n",
    "spatial = np.linspace(dx, X, int(round(X/dx)))\n",
    "\n",
    "# One entry per row: (row title, trajectories for the three columns).\n",
    "example_rows = [\n",
    "    (r\"Example trajectories for $u(x, 0)=1$ with backstepping, PPO, and SAC\",\n",
    "     [uBcksOne, uPPOOne, uSACOne]),\n",
    "    (r\"Example trajectories for $u(x, 0)=10$ with backstepping, PPO, and SAC\",\n",
    "     [uBcksTen, uPPOTen, uSACTen]),\n",
    "]\n",
    "\n",
    "for subfig, (row_title, trajectories) in zip(subfigs, example_rows):\n",
    "    subfig.suptitle(row_title)\n",
    "    subfig.subplots_adjust(left=0.03, bottom=0.05, right=1, top=0.95, wspace=0, hspace=0)\n",
    "    ax = subfig.subplots(nrows=1, ncols=3, subplot_kw={\"projection\": \"3d\", \"computed_zorder\": False})\n",
    "\n",
    "    for col, (axes, u) in enumerate(zip(ax, trajectories)):\n",
    "        # Blue axis lines, light dashed grid, white panes (via matplotlib's private _axinfo).\n",
    "        for axis in [axes.xaxis, axes.yaxis, axes.zaxis]:\n",
    "            axis._axinfo['axisline']['linewidth'] = 1\n",
    "            axis._axinfo['axisline']['color'] = \"b\"\n",
    "            axis._axinfo['grid']['linewidth'] = 0.2\n",
    "            axis._axinfo['grid']['linestyle'] = \"--\"\n",
    "            axis._axinfo['grid']['color'] = \"#d1d1d1\"\n",
    "            axis.set_pane_color((1,1,1))\n",
    "        axes.view_init(10, 35)\n",
    "        axes.zaxis.set_rotate_label(False)\n",
    "        axes.set_xlabel(\"x\", labelpad=-3)\n",
    "        axes.set_ylabel(\"Time\", labelpad=-3)\n",
    "        axes.tick_params(axis='x', which='major', pad=-3)\n",
    "        axes.tick_params(axis='y', which='major', pad=-3)\n",
    "        axes.tick_params(axis='z', which='major', pad=-1)\n",
    "        if col == 0:\n",
    "            # Only the leftmost panel carries the z label.\n",
    "            axes.set_zlabel(r\"$u(x, t)$\", rotation=90, labelpad=-7)\n",
    "            axes.set_xticks([0, 0.5, 1])\n",
    "\n",
    "        # Time axis sized from this trajectory, so its mesh always matches its shape.\n",
    "        temporal = np.linspace(0, T, len(u))\n",
    "        meshx, mesht = np.meshgrid(spatial, temporal)\n",
    "        axes.plot_surface(meshx, mesht, u, edgecolor=\"black\",lw=0.2, rstride=50, cstride=1, \n",
    "                          alpha=1, color=\"white\", shade=False, rasterized=True, antialiased=True)\n",
    "        if col > 0:\n",
    "            # These panels set x-ticks after plotting (the leftmost sets them before).\n",
    "            # Per the matplotlib docs set_xticks may widen the view limits, so the\n",
    "            # ordering is kept to preserve the established layout.\n",
    "            axes.set_xticks([0, 0.5, 1])\n",
    "\n",
    "        # Red line along the x = 1 edge: the last spatial column over time.\n",
    "        test = np.ones(len(temporal))\n",
    "        vals = (u.transpose())[-1]\n",
    "        axes.plot(test[1:], temporal[1:], vals[1:], color=\"red\", lw=0.1, antialiased=False, rasterized=False)\n",
    "\n",
    "# To export, save BEFORE plt.show(); the inline backend closes the figure afterwards:\n",
    "# fig.savefig(\"hyperbolicExamples.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d8c7741-6ac9-4656-8af5-6adb032819a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Revise\n",
    "using Burgers#, Plots\n",
    "using DataDeps, MAT, MLUtils\n",
    "using NeuralOperators, Flux\n",
    "using BSON\n",
    "using DataDeps, MAT, MLUtils\n",
    "using NeuralOperators, Flux\n",
    "using CUDA, FluxTraining, BSON\n",
    "import Flux: params\n",
    "using BSON: @save, @load\n",
    "using ProgressBars\n",
    "\n",
    "\n",
    "\n",
    "using Burgers\n",
    "using FluxTraining\n",
    "# using Test\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "function my_get_data(file_path; n = 50000, Δsamples = 1, grid_size = div(51, Δsamples), T = Float32)\n",
    "# function my_get_data(file_path; n = 2048, Δsamples = 2^3, grid_size = div(2^13, Δsamples), T = Float32)\n",
    "    # file = matopen(joinpath(datadep\"Burgers\", \"burgers_data_R10.mat\"))\n",
    "    file = matopen(file_path)\n",
    "    \n",
    "    x_data = T.(collect(read(file, \"a\")[1:n, 1:Δsamples:end]'))\n",
    "    y_data = T.(collect(read(file, \"u\")[1:n, 1:Δsamples:end]'))\n",
    "    safe_labels = T.(collect(read(file, \"safe\")[1:n, 1:Δsamples:end]'))\n",
    "    pf_labels = T.(collect(read(file, \"pf\")[1:n, 1:Δsamples:end]'))\n",
    "    close(file)\n",
    "\n",
    "    x_loc_data = Array{T, 3}(undef, 2, grid_size, n)\n",
    "    x_loc_data[1, :, :] .= reshape(repeat(LinRange(0, 5, grid_size), n), (grid_size, n))\n",
    "    x_loc_data[2, :, :] .= x_data\n",
    "\n",
    "    return x_loc_data, reshape(y_data, 1, :, n), safe_labels, pf_labels\n",
    "end\n",
    "\n",
    "function my_get_dataloader(; ratio::Float64 = 0.9, batchsize = 100)\n",
    "    𝐱1, 𝐲1, safe1, pf1 = my_get_data(\"data_bcks_hyperbolic_1.mat\")\n",
    "    \n",
    "    data_train1, data_test1 = splitobs((𝐱1, 𝐲1, safe1, pf1), at = ratio)\n",
    "    𝐱2, 𝐲2, safe2, pf2 = my_get_data(\"data_ppo_hyperbolic_1.mat\")\n",
    "    \n",
    "    data_train2, data_test2 = splitobs((𝐱2, 𝐲2, safe2, pf2), at = ratio)\n",
    "    𝐱3, 𝐲3, safe3, pf3 = my_get_data(\"data_sac_hyperbolic_1.mat\")\n",
    "    \n",
    "    data_train3, data_test3 = splitobs((𝐱3, 𝐲3, safe3, pf3), at = ratio)\n",
    "\n",
    "    # @show size(data_train3[1]), size(data_test3[2])\n",
    "\n",
    "    data_train1_x_pf = data_train1[1][:,:,(data_train1[4][1,:].==1)]\n",
    "    data_test1_x_pf = data_test1[1][:,:,(data_test1[4][1,:].==1)]\n",
    "    data_train1_y_pf = data_train1[2][:,:,(data_train1[4][1,:].==1)]\n",
    "    data_test1_y_pf = data_test1[2][:,:,(data_test1[4][1,:].==1)]\n",
    "    data_train1_safe_pf = data_train1[3][:,(data_train1[4][1,:].==1)]\n",
    "    data_test1_safe_pf = data_test1[3][:,(data_test1[4][1,:].==1)]\n",
    "\n",
    "    data_train2_x_pf = data_train2[1][:,:,(data_train2[4][1,:].==1)]\n",
    "    data_test2_x_pf = data_test2[1][:,:,(data_test2[4][1,:].==1)]\n",
    "    data_train2_y_pf = data_train2[2][:,:,(data_train2[4][1,:].==1)]\n",
    "    data_test2_y_pf = data_test2[2][:,:,(data_test2[4][1,:].==1)]\n",
    "    data_train2_safe_pf = data_train2[3][:,(data_train2[4][1,:].==1)]\n",
    "    data_test2_safe_pf = data_test2[3][:,(data_test2[4][1,:].==1)]\n",
    "\n",
    "    data_train3_x_pf = data_train3[1][:,:,(data_train3[4][1,:].==1)]\n",
    "    data_test3_x_pf = data_test3[1][:,:,(data_test3[4][1,:].==1)]\n",
    "    data_train3_y_pf = data_train3[2][:,:,(data_train3[4][1,:].==1)]\n",
    "    data_test3_y_pf = data_test3[2][:,:,(data_test3[4][1,:].==1)]\n",
    "    data_train3_safe_pf = data_train3[3][:,(data_train3[4][1,:].==1)]\n",
    "    data_test3_safe_pf = data_test3[3][:,(data_test3[4][1,:].==1)]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    data_train = (cat(cat(data_train1_x_pf, data_train2_x_pf, dims=3), data_train3_x_pf, dims=3), \n",
    "                    cat(cat(data_train1_y_pf, data_train2_y_pf, dims=3), data_train3_y_pf, dims=3), \n",
    "                    cat(cat(data_train1_safe_pf, data_train2_safe_pf, dims=2), data_train3_safe_pf, dims=2)) # omit the last pf tumple\n",
    "    data_test = (cat(cat(data_test1_x_pf, data_test2_x_pf, dims=3), data_test3_x_pf, dims=3), \n",
    "                cat(cat(data_test1_y_pf, data_test2_y_pf, dims=3), data_test3_y_pf, dims=3), \n",
    "                cat(cat(data_test1_safe_pf, data_test2_safe_pf, dims=2), data_test3_safe_pf, dims=2)) # # omit the last pf tumple\n",
    "    loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)\n",
    "    loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false)\n",
    "\n",
    "    return loader_train, loader_test\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 500)\n",
    "\n",
    "Fit a Fourier neural operator with `Learner`/`fit!` on the `(train, test)` loaders\n",
    "returned by `my_get_dataloader()`, using `l₂loss` and Adam(`η₀`) with weight decay `λ`,\n",
    "for `epochs` epochs on the GPU when `cuda` is set and one is available. Saves the CPU\n",
    "copy of the model to `model/hyper_FNO_all_pf.bson` and returns the learner.\n",
    "\"\"\"\n",
    "function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 500)\n",
    "    if cuda && CUDA.has_cuda()\n",
    "        device = gpu\n",
    "        CUDA.allowscalar(false)\n",
    "        @info \"Training on GPU\"\n",
    "    else\n",
    "        device = cpu\n",
    "        @info \"Training on CPU\"\n",
    "    end\n",
    "    model = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,), \n",
    "                                  σ = gelu)\n",
    "    data = my_get_dataloader()  # (loader_train, loader_test)\n",
    "    optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))\n",
    "    loss_func = l₂loss\n",
    "\n",
    "    learner = Learner(model, data, optimiser, loss_func,\n",
    "                      ToDevice(device, device))\n",
    "\n",
    "    fit!(learner, epochs)\n",
    "    # NOTE(review): `@save` is a macro, so BSON must already be loaded when this\n",
    "    # definition is evaluated — confirm an earlier `using BSON` exists in this cell.\n",
    "    model = learner.model |> cpu\n",
    "    @save \"model/hyper_FNO_all_pf.bson\" model\n",
    "\n",
    "    return learner\n",
    "end\n",
    "\n",
    "using Flux, CUDA, BSON\n",
    "using Logging\n",
    "\n",
    "\"\"\"\n",
    "    loss_naive_safeset(ϕ, x, y_init)\n",
    "\n",
    "Safe-set hinge loss for the barrier `ϕ`. The first channel of `x` is flattened into a\n",
    "`(1, :)` row and fed to `ϕ`; `y_init` holds the matching labels (safe: 1, unsafe: 0),\n",
    "either 3-D with the same leading channel as `x` or a plain `(points, batch)` matrix.\n",
    "Returns the mean of `relu((2y - 1) * ϕ(x) + 1e-6)`, which is zero exactly when\n",
    "`ϕ <= -1e-6` on safe points and `ϕ >= 1e-6` on unsafe points.\n",
    "\"\"\"\n",
    "function loss_naive_safeset(ϕ::Chain, x::AbstractArray,y_init::AbstractArray)\n",
    "    x = reshape(vec(x[1, :, :]), 1, :)\n",
    "    # Flatten labels column-major so they line up with `vec(x[1, :, :])`.\n",
    "    labels = ndims(y_init) == 3 ? vec(y_init[1, :, :]) : vec(y_init)\n",
    "    # Activation functions act on scalars, so `relu` must be broadcast.\n",
    "    loss = relu.((2 .* labels .- 1) .* ϕ(x)[1, :] .+ 1e-6)\n",
    "    return sum(loss) / size(loss)[end]\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    loss_regularization(ϕ, x, y_init)\n",
    "\n",
    "Smooth sign term for the barrier `ϕ`: the mean of `sigmoid_fast((2y - 1) * ϕ(x))`.\n",
    "The first channel of `x` is flattened into a `(1, :)` row and fed to `ϕ`; `y_init`\n",
    "holds the matching labels (safe: 1, unsafe: 0), either 3-D with the same leading\n",
    "channel as `x` or a plain `(points, batch)` matrix.\n",
    "\"\"\"\n",
    "function loss_regularization(ϕ::Chain, x::AbstractArray,y_init::AbstractArray)\n",
    "    x = reshape(vec(x[1, :, :]), 1, :)\n",
    "    # Flatten labels column-major so they line up with `vec(x[1, :, :])`.\n",
    "    labels = ndims(y_init) == 3 ? vec(y_init[1, :, :]) : vec(y_init)\n",
    "    # Activation functions act on scalars, so `sigmoid_fast` must be broadcast.\n",
    "    loss = sigmoid_fast.((2 .* labels .- 1) .* ϕ(x)[1, :])\n",
    "    return sum(loss) / size(loss)[end]\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    find_derivative(vector)\n",
    "\n",
    "Finite-difference derivative of channel 2 with respect to channel 1 of a `(2, M, N)`\n",
    "array, taken along the second axis: central differences at interior points and\n",
    "one-sided differences at both ends. Returns a `(1, M, N)` `Float64` array.\n",
    "Throws `ArgumentError` when `M < 2`.\n",
    "\"\"\"\n",
    "function find_derivative(vector)\n",
    "    M, N = size(vector, 2), size(vector, 3)\n",
    "    M >= 2 || throw(ArgumentError(\"find_derivative needs at least 2 points along dimension 2, got $M\"))\n",
    "\n",
    "    inputs = vector[1, :, :]   # (M, N) abscissae\n",
    "    outputs = vector[2, :, :]  # (M, N) values being differentiated\n",
    "    derivatives = zeros(Float64, 1, M, N)\n",
    "\n",
    "    # Central differences for the interior points (2 to M-1)\n",
    "    derivatives[1, 2:M-1, :] = (outputs[3:M, :] .- outputs[1:M-2, :]) ./ (inputs[3:M, :] .- inputs[1:M-2, :])\n",
    "    # Forward difference for the first point\n",
    "    derivatives[1, 1, :] = (outputs[2, :] .- outputs[1, :]) ./ (inputs[2, :] .- inputs[1, :])\n",
    "    # Backward difference for the last point\n",
    "    derivatives[1, M, :] = (outputs[M, :] .- outputs[M-1, :]) ./ (inputs[M, :] .- inputs[M-1, :])\n",
    "\n",
    "    return derivatives\n",
    "end\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "    loss_pf(ϕ, U, Y, model_NO, U̇, α)\n",
    "\n",
    "Hinge loss on the barrier condition `ϕ̇ + α ϕ(Y) <= -1e-6`, averaged over points.\n",
    "`ϕ̇` is formed pointwise by the chain rule `∂ϕ/∂Y · ∂model_NO/∂U · U̇`, where `U̇` is\n",
    "supplied by the caller (e.g. `find_derivative(U)`, computed outside any gradient call).\n",
    "The first channel of `Y` is flattened into a `(1, :)` row before being fed to `ϕ`.\n",
    "\"\"\"\n",
    "function loss_pf(ϕ::Chain, U::AbstractArray, Y::AbstractArray, model_NO::Chain, U̇::AbstractArray, α::Float32)\n",
    "    Y = reshape(vec(Y[1, :, :]), 1, :)\n",
    "    state_dim, batchsize = size(Y)\n",
    "\n",
    "    # ∂ϕ/∂Y per point; a ones seed gives the pointwise derivative assuming ϕ acts\n",
    "    # column-wise (true for a Dense chain) — `one.` keeps the seed on ϕY's device.\n",
    "    ϕY, back_ϕ = Zygote.pullback(ϕ, Y)\n",
    "    ∇ϕ_Y = back_ϕ(one.(ϕY))[1] ./ state_dim\n",
    "    ∇ϕ_Y = reshape(∇ϕ_Y, (1, state_dim, batchsize))\n",
    "\n",
    "    # ∂model_NO/∂U: the cotangent seed must match the operator's output, not `Y`.\n",
    "    out_NO, back_NO = Zygote.pullback(model_NO, U)\n",
    "    ∇NO_U = back_NO(one.(out_NO))[1] ./ state_dim\n",
    "    # NOTE(review): `U` has two channels but `U̇` has one; only the channel-2\n",
    "    # sensitivity (the channel `find_derivative` differentiates) is kept — confirm.\n",
    "    ∇NO_U = reshape(∇NO_U[2:2, :, :], (1, state_dim, batchsize))\n",
    "\n",
    "    U̇ = reshape(U̇, (state_dim, 1, batchsize))\n",
    "    ϕ̇ = reshape(batched_mul(∇ϕ_Y, batched_mul(∇NO_U, U̇)), size(ϕY))\n",
    "    # Activation functions act on scalars, so `relu` must be broadcast.\n",
    "    loss = relu.(ϕ̇ .+ α .* ϕY .+ 1e-6)\n",
    "    return sum(loss) / size(loss)[end]\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    my_train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 500,\n",
    "             α = 0.1f0, λ_pf = λ, μ = 1.0f0)\n",
    "\n",
    "Jointly train a Fourier neural operator (`model_NO`) and a small MLP barrier function\n",
    "(`model_CBF`) for `epochs` epochs on the loaders returned by `my_get_dataloader()`.\n",
    "\n",
    "Per training batch: one AdamW step on `model_NO` (`l₂loss` against the targets), then\n",
    "one NADAM step on `model_CBF` with\n",
    "`loss_naive_safeset + λ_pf * loss_pf + μ * loss_regularization`. The same losses are\n",
    "summed, without updates, over the test loader. CPU copies of both models are written\n",
    "to `model/` after every epoch.\n",
    "\n",
    "`α`, `λ_pf` and `μ` are placeholder defaults (TODO: tune); `λ_pf` defaults to `λ`.\n",
    "\n",
    "Returns `(model_NO, model_CBF, training_losses, test_losses)` with both models on the\n",
    "CPU; each loss vector holds one per-epoch sum of operator loss + barrier loss.\n",
    "\"\"\"\n",
    "function my_train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 500,\n",
    "                  α = 0.1f0, λ_pf = λ, μ = 1.0f0)\n",
    "    if cuda && CUDA.has_cuda()\n",
    "        device = gpu\n",
    "        CUDA.allowscalar(false)\n",
    "        @info \"Training on GPU\"\n",
    "    else\n",
    "        device = cpu\n",
    "        @info \"Training on CPU\"\n",
    "    end\n",
    "    lr_CBF = 0.001\n",
    "\n",
    "    loader_train, loader_test = my_get_dataloader()\n",
    "    model_NO = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),\n",
    "                                     σ = gelu) |> device\n",
    "    model_CBF = Chain(\n",
    "        Dense(1 => 16, relu),\n",
    "        Dense(16 => 64, relu),\n",
    "        Dense(64 => 16, relu),\n",
    "        Dense(16 => 1),\n",
    "    ) |> device\n",
    "    optim_NO = Flux.setup(Flux.Optimise.AdamW(η₀, (0.9, 0.999), λ), model_NO)\n",
    "    # NOTE(review): NADAM's third positional argument is ϵ, so this sets ϵ = 0.1 (not a\n",
    "    # weight decay) — confirm that is intended.\n",
    "    optim_CBF = Flux.setup(Flux.Optimise.NADAM(lr_CBF, (0.9, 0.999), 0.1), model_CBF)\n",
    "\n",
    "    # Barrier loss shared by the training and test loops (`loss_pf` requires a Float32 α).\n",
    "    cbf_loss(m, x, y, safe, U̇) =\n",
    "        loss_naive_safeset(m, y, safe) +\n",
    "        λ_pf * loss_pf(m, x, y, model_NO, U̇, Float32(α)) +\n",
    "        μ * loss_regularization(m, y, safe)\n",
    "\n",
    "    # Move one loader item (inputs, targets, safe labels) to `device`. `U̇` is computed\n",
    "    # from the CPU copy, outside any gradient call, because `find_derivative` writes\n",
    "    # into a preallocated CPU buffer.\n",
    "    function prepare(item)\n",
    "        U̇ = device(find_derivative(cpu(item[1])))\n",
    "        return device(item[1]), device(item[2]), device(item[3]), U̇\n",
    "    end\n",
    "\n",
    "    training_losses = Float64[]\n",
    "    test_losses = Float64[]\n",
    "    # NOTE(review): `ProgressBar` is not defined here; assumes a package such as\n",
    "    # ProgressBars.jl is loaded elsewhere in the notebook.\n",
    "    for epoch in ProgressBar(1:epochs)\n",
    "        training_loss_epoch = 0.0\n",
    "        for item in loader_train\n",
    "            x_batch, y_batch, safe_batch, U̇ = prepare(item)\n",
    "            @debug \"batch sizes\" x_size = size(x_batch) y_size = size(y_batch) safe_size = size(safe_batch)\n",
    "\n",
    "            # train NO\n",
    "            NO_loss, NO_grads = Flux.withgradient(model_NO) do m\n",
    "                l₂loss(m(x_batch), y_batch)\n",
    "            end\n",
    "            Flux.update!(optim_NO, model_NO, NO_grads[1])\n",
    "\n",
    "            # train CBF (against the just-updated operator)\n",
    "            CBF_loss, CBF_grads = Flux.withgradient(model_CBF) do m\n",
    "                cbf_loss(m, x_batch, y_batch, safe_batch, U̇)\n",
    "            end\n",
    "            Flux.update!(optim_CBF, model_CBF, CBF_grads[1])\n",
    "\n",
    "            training_loss_epoch += NO_loss + CBF_loss\n",
    "        end\n",
    "\n",
    "        test_loss_epoch = 0.0\n",
    "        for item in loader_test\n",
    "            x_batch, y_batch, safe_batch, U̇ = prepare(item)\n",
    "            test_loss_epoch += l₂loss(model_NO(x_batch), y_batch) +\n",
    "                               cbf_loss(model_CBF, x_batch, y_batch, safe_batch, U̇)\n",
    "        end\n",
    "\n",
    "        @info \"epoch $epoch\" training_loss_epoch test_loss_epoch\n",
    "        # Save CPU copies under the original variable names (BSON keys).\n",
    "        let model_NO = cpu(model_NO), model_CBF = cpu(model_CBF)\n",
    "            @save \"model/hyper_NO_$epoch.bson\" model_NO\n",
    "            @save \"model/hyper_CBF_$epoch.bson\" model_CBF\n",
    "        end\n",
    "        push!(training_losses, training_loss_epoch)\n",
    "        push!(test_losses, test_loss_epoch)\n",
    "    end\n",
    "\n",
    "    # NOTE(review): same path that `train()` writes, so whichever runs last wins.\n",
    "    model = cpu(model_NO)\n",
    "    @save \"model/hyper_FNO_all_pf.bson\" model\n",
    "\n",
    "    return model, cpu(model_CBF), training_losses, test_losses\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    train_nomad(; n = 50000, cuda = true, learning_rate = 0.001, epochs = 400)\n",
    "\n",
    "Train a NOMAD model on the first `n` samples from `get_data_don` (first 90% for\n",
    "training, remaining 10% for validation) with sensor locations drawn at random from\n",
    "`0:0.02:1`. Saves the model to `model/hyper_NOMAD_bcks.bson` and returns the mean\n",
    "absolute validation error.\n",
    "\"\"\"\n",
    "function train_nomad(; n = 50000, cuda = true, learning_rate = 0.001, epochs = 400)\n",
    "    if cuda && has_cuda()\n",
    "        @info \"Training on GPU\"\n",
    "        device = gpu\n",
    "    else\n",
    "        @info \"Training on CPU\"\n",
    "        device = cpu\n",
    "    end\n",
    "\n",
    "    x, y = get_data_don(; n = n)\n",
    "    nsamples = size(x, 1)\n",
    "    ntrain = div(9 * nsamples, 10)  # 45000 of 50000 by default\n",
    "    nval = nsamples - ntrain\n",
    "\n",
    "    xtrain = x[1:ntrain, :]'\n",
    "    ytrain = y[1:ntrain, :]\n",
    "    xval = x[ntrain+1:end, :]' |> device\n",
    "    yval = y[ntrain+1:end, :] |> device\n",
    "\n",
    "    grid = rand(collect(0:0.02:1), (ntrain, 51)) |> device\n",
    "    gridval = rand(collect(0:0.02:1), (nval, 51)) |> device\n",
    "\n",
    "    opt = Adam(learning_rate)\n",
    "    m = NOMAD((51, 51), (102, 51), gelu, gelu) |> device\n",
    "\n",
    "    loss(X, y, sensor) = Flux.Losses.mse(m(X, sensor), y)\n",
    "    evalcb() = @show(loss(xval, yval, gridval))\n",
    "\n",
    "    data = [(xtrain, ytrain, grid)] |> device\n",
    "    # Plain loop in place of the deprecated `Flux.@epochs` macro.\n",
    "    for epoch in 1:epochs\n",
    "        @info \"Epoch $epoch\"\n",
    "        Flux.train!(loss, params(m), data, opt, cb = evalcb)\n",
    "    end\n",
    "    ỹ = m(xval, gridval)\n",
    "    @save \"model/hyper_NOMAD_bcks.bson\" m\n",
    "    diffvec = vec(abs.(cpu(yval) .- cpu(ỹ)))\n",
    "    mean_diff = sum(diffvec) / length(diffvec)\n",
    "    return mean_diff\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    get_data_don(; n = 50000, Δsamples = 1, grid_size = div(51, Δsamples), T = Float32,\n",
    "                 path = \"/Users/james/Hanjiang/data_bcks_hyperbolic.mat\")\n",
    "\n",
    "Read the first `n` rows of the `\"a\"` and `\"u\"` arrays from the MAT file at `path`,\n",
    "keeping every `Δsamples`-th column, and return them converted to element type `T` as\n",
    "`(x_data, y_data)`. `grid_size` is accepted but currently unused.\n",
    "\"\"\"\n",
    "function get_data_don(; n = 50000, Δsamples = 1, grid_size = div(51, Δsamples), T = Float32,\n",
    "                      path = \"/Users/james/Hanjiang/data_bcks_hyperbolic.mat\")\n",
    "    # NOTE(review): the default `path` is machine-specific; pass `path` on other machines.\n",
    "    # The do-block closes the file even if a read throws.\n",
    "    return matopen(path) do file\n",
    "        x_data = T.(read(file, \"a\")[1:n, 1:Δsamples:end])\n",
    "        y_data = T.(read(file, \"u\")[1:n, 1:Δsamples:end])\n",
    "        x_data, y_data\n",
    "    end\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    train_don(; n = 50000, cuda = true, learning_rate = 0.001, epochs = 400)\n",
    "\n",
    "Train a DeepONet on the first `n` samples from `get_data_don` (first 90% for training,\n",
    "remaining 10% for validation) using a fixed uniform 51-point grid on [0, 1]. Saves the\n",
    "model to `model/hyper_DON_bcks.bson` and returns the mean absolute validation error.\n",
    "\"\"\"\n",
    "function train_don(; n = 50000, cuda = true, learning_rate = 0.001, epochs = 400)\n",
    "    if cuda && has_cuda()\n",
    "        @info \"Training on GPU\"\n",
    "        device = gpu\n",
    "    else\n",
    "        @info \"Training on CPU\"\n",
    "        device = cpu\n",
    "    end\n",
    "\n",
    "    x, y = get_data_don(; n = n)\n",
    "    nsamples = size(x, 1)\n",
    "    ntrain = div(9 * nsamples, 10)  # 45000 of 50000 by default\n",
    "\n",
    "    xtrain = x[1:ntrain, :]'\n",
    "    ytrain = y[1:ntrain, :]\n",
    "    xval = x[ntrain+1:end, :]' |> device\n",
    "    yval = y[ntrain+1:end, :] |> device\n",
    "\n",
    "    grid = collect(range(0, 1, length=51)') |> device\n",
    "\n",
    "    opt = Adam(learning_rate)\n",
    "    m = DeepONet((51, 51, 51), (1, 51, 51), gelu, gelu) |> device\n",
    "\n",
    "    loss(X, y, sensor) = Flux.Losses.mse(m(X, sensor), y)\n",
    "    evalcb() = @show(loss(xval, yval, grid))\n",
    "\n",
    "    data = [(xtrain, ytrain, grid)] |> device\n",
    "    # Plain loop in place of the deprecated `Flux.@epochs` macro.\n",
    "    for epoch in 1:epochs\n",
    "        @info \"Epoch $epoch\"\n",
    "        Flux.train!(loss, params(m), data, opt, cb = evalcb)\n",
    "    end\n",
    "    ỹ = m(xval, grid)\n",
    "\n",
    "    diffvec = vec(abs.(cpu(yval) .- cpu(ỹ)))\n",
    "    mean_diff = sum(diffvec) / length(diffvec)\n",
    "    @save \"model/hyper_DON_bcks.bson\" m\n",
    "    return mean_diff\n",
    "end\n",
    "\n",
    "# NOTE(review): the trailing numbers look like results recorded from earlier runs — confirm.\n",
    "train(; epochs = 100) # 0.005\n",
    "# train_nomad(; epochs = 500) # 1.76\n",
    "# train_don(; epochs = 500) # 2.56\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "976968bf-9615-4751-93bc-5b41c5583ada",
   "metadata": {},
   "outputs": [],
   "source": [
    "using BSON\n",
    "using DataDeps, MAT, MLUtils\n",
    "\n",
    "\n",
    "\n",
    "T = Float32\n",
    "# NOTE(review): absolute, machine-specific path — consider a configurable data directory.\n",
    "data_path = \"/Users/james/Hanjiang/data_sac_hyperbolic.mat\"\n",
    "\n",
    "# The do-block closes the file even if a read throws.\n",
    "x_data, y_data = matopen(data_path) do file\n",
    "    T.(collect(read(file, \"a\"))), T.(collect(read(file, \"u\")))\n",
    "end\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "# Label arrays are allocated with x_data's shape but filled from y_data.\n",
    "@assert size(x_data) == size(y_data) \"x_data and y_data must have the same shape\"\n",
    "\n",
    "threshold = 1\n",
    "# safe: 1, not safe: 0 — entrywise, unsafe wherever u exceeds the threshold\n",
    "safe_labels = ones(size(x_data))\n",
    "safe_labels[y_data .> threshold] .= 0\n",
    "# pf: 1, not pf: 0 — a whole row is 0 when the row's last entry of u exceeds the threshold\n",
    "pf_labels = ones(size(x_data))\n",
    "pf_labels[y_data[:, end] .> threshold, :] .= 0\n",
    "\n",
    "matwrite(\"data_sac_hyperbolic_1.mat\", Dict(\n",
    "    \"a\" => x_data,\n",
    "    \"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show sum(pf_labels[:, 1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11b1700d-f3ee-4633-8e3f-3e45301d87f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# BUILD CONTROL SIGNAL PLOTS\n",
    "# Control signal U(t) — the last spatial entry of each (time, space) trajectory — for\n",
    "# SAC, PPO and backstepping, with u(0, x)=1 on the left and u(0, x)=10 on the right.\n",
    "fig = plt.figure(figsize=set_size(433, 0.99, (1, 2), height_add=1))\n",
    "subfig = fig.subfigures(nrows=1, ncols=1, hspace=0)\n",
    "subfig.suptitle(r\"Control Signals for $u(0, x)=1$ and $u(0, x)=10$\")\n",
    "subfig.subplots_adjust(left=0.1, bottom=0.2, right=.98, top=0.86, wspace=0.25, hspace=0.1)\n",
    "\n",
    "X = 1\n",
    "dx = 1e-2\n",
    "T = 10\n",
    "spatial = np.linspace(dx, X, int(round(X/dx)))\n",
    "temporal = np.linspace(0, T, len(uPPOOne))\n",
    "\n",
    "ax = subfig.subplots(nrows=1, ncols=2)\n",
    "panels = [(uSACOne, uPPOOne, uBcksOne), (uSACTen, uPPOTen, uBcksTen)]\n",
    "for axis, (uSAC, uPPO, uBcks) in zip(ax, panels):\n",
    "    l2, = axis.plot(temporal, uSAC.transpose()[-1], label=\"SAC\", linestyle=linestyle_tuple[2][1], color=\"green\")\n",
    "    l1, = axis.plot(temporal, uPPO.transpose()[-1], label=\"PPO\", linestyle=linestyle_tuple[2][1], color=\"orange\")\n",
    "    l3, = axis.plot(temporal, uBcks.transpose()[-1], label=\"Backstepping\", color=\"#0096FF\")\n",
    "    axis.set_xlabel(\"Time\")\n",
    "    axis.set_ylabel(r\"$U(t)$\", labelpad=-2)\n",
    "\n",
    "# One legend call on the right panel: a second plt.legend() would replace this one.\n",
    "ax[1].legend([l1, l2, l3], [\"PPO\", \"SAC\", \"Backstepping\"], loc=\"lower left\",\n",
    "             bbox_to_anchor=[.56, .86], reverse=True, handletextpad=0.1)\n",
    "\n",
    "# Save before plt.show() if needed; show() can leave an empty current figure.\n",
    "#plt.savefig(\"hyperbolicControlSignals.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03599419-4ba8-4837-9bfd-b043aed40ff6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PDE L2 Error\n",
    "def getPDEl2(u, uhat):\n",
    "    nt = len(u)\n",
    "    nx = len(u[0])\n",
    "    pdeError = np.zeros(nt-1)\n",
    "    for i in range(1, nt):\n",
    "        error = 0\n",
    "        for j in range(nx):\n",
    "            error += (u[i][j] - uhat[i][j])**2\n",
    "        error = np.sqrt(error)\n",
    "        pdeError[i-1] = error\n",
    "    return pdeError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00159f4a-0521-4965-86f8-a315accbe57c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share Rewards and L2 Norms for each problem\n",
    "# Total L2 norm = getPDEl2 summed over time, measured against an all-zero array of the\n",
    "# trajectory's own shape (i.e. its distance from the zero state).\n",
    "print((\"InitialCondition\\tModel Trained\\tHyperbolic1D Rewards\\tHyperbolic1DTotalL2Norm\").expandtabs(30))\n",
    "results = [\n",
    "    (\"u(x, 0)=1\", \"Backstepping\", rewBcksOne, uBcksOne),\n",
    "    (\"u(x, 0)=1\", \"PPO\", rewPPOOne, uPPOOne),\n",
    "    (\"u(x, 0)=1\", \"SAC\", rewSACOne, uSACOne),\n",
    "    (\"u(x, 0)=10\", \"Backstepping\", rewBcksTen, uBcksTen),\n",
    "    (\"u(x, 0)=10\", \"PPO\", rewPPOTen, uPPOTen),\n",
    "    (\"u(x, 0)=10\", \"SAC\", rewSACTen, uSACTen),\n",
    "]\n",
    "for initial_condition, model_name, reward, trajectory in results:\n",
    "    total_l2 = sum(getPDEl2(trajectory, np.zeros(np.shape(trajectory))))\n",
    "    row = initial_condition + \"\\t\" + model_name + \"\\t\" + str(reward) + \"\\t\" + str(total_l2)\n",
    "    print(row.expandtabs(30))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pdecontrol",
   "language": "python",
   "name": "pdecontrol"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
