{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "a07f6eee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "from sklearn.gaussian_process import GaussianProcessRegressor\n",
    "from sklearn.gaussian_process.kernels import RBF, WhiteKernel\n",
    "import warnings\n",
    "from sklearn.exceptions import ConvergenceWarning\n",
    "import math\n",
    "import random\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "303957d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# ── Hyperparameters ───────────────────────────────────────────────────────────\n",
    "ENV_ID       = \"CartPole-v1\"\n",
    "N_EPISODES   = 160       # total episodes N\n",
    "N_INTRINSIC  = 40    # intrinsic‐exploration episodes n*\n",
    "EPISODE_LEN  = 200       # max steps per episode\n",
    "COST_LIMIT   = 0.3       # allowable average cost per episode\n",
    "RAND_SEARCH  = 100       # candidates per inner solve\n",
    "SEARCH_STD   = 0.5       # neighbor std for random search\n",
    "SEED         = 1\n",
    "num_traj=2\n",
    "Position_safe_lb=-1.9\n",
    "Position_safe_ub=1.9\n",
    "Pole_Angle_safe_ld= -0.15\n",
    "Pole_Angle_safe_ud=   0.15\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "Position_lb=-4.8\n",
    "Position_ub=4.8\n",
    "\n",
    "Velocity_lb=-10000\n",
    "Velocity_ub=10000\n",
    "Pole_Angle_ld= -0.418 \n",
    "Pole_Angle_ud= 0.418 \n",
    "\n",
    "\n",
    "np.random.seed(SEED)\n",
    "def cal_cost(state, action, next_state):\n",
    "    \"\"\"Safety cost of a transition, judged on the resulting state only.\n",
    "\n",
    "    The cost is the larger of (a) how far ``next_state[0]`` lies outside\n",
    "    ``[Position_safe_lb, Position_safe_ub]`` and (b) ten times how far\n",
    "    ``next_state[2]`` lies outside ``[Pole_Angle_safe_ld, Pole_Angle_safe_ud]``.\n",
    "    ``state`` and ``action`` are accepted but not used.\n",
    "    \"\"\"\n",
    "    position = next_state[0]\n",
    "    angle = next_state[2]\n",
    "\n",
    "    position_excess = 0\n",
    "    if position <= Position_safe_lb:\n",
    "        position_excess = Position_safe_lb - position\n",
    "    if position >= Position_safe_ub:\n",
    "        position_excess = position - Position_safe_ub\n",
    "\n",
    "    angle_excess = 0\n",
    "    if angle <= Pole_Angle_safe_ld:\n",
    "        angle_excess = Pole_Angle_safe_ld - angle\n",
    "    if angle >= Pole_Angle_safe_ud:\n",
    "        angle_excess = angle - Pole_Angle_safe_ud\n",
    "\n",
    "    # angle violations are weighted 10x relative to position violations\n",
    "    return np.maximum(position_excess, 10 * angle_excess)\n",
    "       \n",
    "# ── Dynamics Model via GP ────────────────────────────────────────────────────\n",
    "# class GPModel:\n",
    "#     def __init__(self):\n",
    "#         # RBF + white noise kernel\n",
    "#         kernel = RBF(length_scale=1.0) + WhiteKernel(noise_level=1e-3)\n",
    "#         # one GP per state‐dimension\n",
    "#         self.gps = [GaussianProcessRegressor(kernel) for _ in range(4)]\n",
    "#         self.X = []   # (s,a) pairs\n",
    "#         self.Y = []   # s' targets\n",
    "\n",
    "#     def update(self, transitions):\n",
    "#         # transitions: list of (s,a,s')\n",
    "#         X_new = [np.hstack((s, [a])) for (s,a,s1,_,_) in transitions]\n",
    "#         Y_new = [s1 for (s,a,s1,_,_) in transitions]\n",
    "#         self.X += X_new\n",
    "#         self.Y += Y_new\n",
    "#         X = np.array(self.X)\n",
    "#         Y = np.array(self.Y)\n",
    "#         with warnings.catch_warnings():\n",
    "#             warnings.simplefilter(\"ignore\", category=ConvergenceWarning)\n",
    "#             for i, gp in enumerate(self.gps):\n",
    "#                 gp.fit(X, Y[:, i])\n",
    "\n",
    "#     def predict(self, s, a):\n",
    "#         x = np.hstack((s, [a]))[None, :]\n",
    "#         mus, sigs = [], []\n",
    "#         for gp in self.gps:\n",
    "#             mu, sigma = gp.predict(x, return_std=True)\n",
    "#             mus.append(mu.item())\n",
    "#             sigs.append(sigma.item())\n",
    "#         return np.array(mus), np.array(sigs)\n",
    "\n",
    "    \n",
    "def gaussian_nll_loss(y_pred_mean, y_pred_log_var, y_true):\n",
    "    # NLL for a Gaussian: log(sigma^2) + (y - mu)^2 / sigma^2\n",
    "    loss = 0.5 * (y_pred_log_var + ((y_true - y_pred_mean) ** 2) / torch.exp(y_pred_log_var))\n",
    "    return loss.mean()\n",
    "\n",
    "# def gaussian_nll(mu, var, y):\n",
    "#     # mu, var, y: all (B, nS)\n",
    "#     return 0.5 * (torch.log(var) + (y - mu)**2 / var).sum(dim=-1)\n",
    "\n",
    "# class GaussianTransitionModel(nn.Module):\n",
    "#     def __init__(self, n_states, n_actions, hidden_size=128):\n",
    "#         super().__init__()\n",
    "#         self.backbone = nn.Sequential(\n",
    "#             nn.Linear(n_states + n_actions, hidden_size),\n",
    "#             nn.ReLU(),\n",
    "#             nn.Linear(hidden_size, hidden_size),\n",
    "#             nn.ReLU(),\n",
    "#         )\n",
    "#         # two heads: one for mean, one for log‐variance\n",
    "#         self.fc_mu      = nn.Linear(hidden_size, n_states)\n",
    "#         self.fc_logvar  = nn.Linear(hidden_size, n_states)\n",
    "        \n",
    "#     def forward(self, s_onehot, a_onehot):\n",
    "#         x    = torch.cat([s_onehot, a_onehot], dim=-1)\n",
    "#         h    = self.backbone(x)\n",
    "#         mu   = self.fc_mu(h)             # (B, n_states)\n",
    "#         logv = self.fc_logvar(h)         # (B, n_states)\n",
    "#         # ensure positivity of variance\n",
    "#         var  = torch.exp(logv)           # (B, n_states)\n",
    "#         return mu, var\n",
    "        \n",
    "#     def update(self,trans):\n",
    "#         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
    "#         for i in range(len(trans)):\n",
    "#             optimizer.zero_grad()\n",
    "#             (s, a, s1, r, cost)=trans[i]\n",
    "#             data=np.concatenate((s, a.reshape(1, )), axis=0)\n",
    "#             data_tensor=torch.from_numpy(data).float().unsqueeze(0).to(device)\n",
    "#             labels=np.array(s1).reshape(1,)\n",
    "#             labels_tensor=torch.from_numpy(labels).float().unsqueeze(0).to(device)\n",
    "#             mean, log_var = self.forward(data_tensor)\n",
    "#             loss = gaussian_nll(mean, log_var, labels_tensor)\n",
    "#             loss.backward()\n",
    "#             optimizer.step()\n",
    "\n",
    "\n",
    "class ProbabilisticRegressor(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim):\n",
    "        super().__init__()\n",
    "        self.shared = nn.Sequential(\n",
    "            nn.Linear(input_dim, hidden_dim),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.mean_head = nn.Linear(hidden_dim, 4)\n",
    "        self.log_var_head = nn.Linear(hidden_dim, 4)  # log variance for stability\n",
    "\n",
    "    def forward(self, x):\n",
    "        h = self.shared(x)\n",
    "        mean = self.mean_head(h)\n",
    "        log_var = self.log_var_head(h)\n",
    "        return mean, log_var  # we return log variance to keep it numerically stable\n",
    "    \n",
    "    def update(self,trans):\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
    "        for i in range(len(trans)):\n",
    "            optimizer.zero_grad()\n",
    "            (s, a, s1, r, cost)=trans[i]\n",
    "            data=np.concatenate((s, np.array(a).reshape(1, )), axis=0)\n",
    "            data_tensor=torch.from_numpy(data).float().unsqueeze(0).to(device)\n",
    "            labels_tensor=torch.from_numpy(s1).float().unsqueeze(0).to(device)\n",
    "            mean, log_var = self.forward(data_tensor)\n",
    "            loss = gaussian_nll_loss(mean, log_var, labels_tensor)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "# ── Policy: theta ∈ R⁴ → action ∈ {0,1} ───────────────────────────────────────────\n",
    "def policy(s, theta):\n",
    "    z = theta.dot(s)\n",
    "    return 1 if np.tanh(z) > 0 else 0\n",
    "\n",
    "# ── Evaluate objectives via learned-model rollout ─────────────────────────────\n",
    "def eval_episode_model(model, theta, intrinsic):\n",
    "    \"\"\"Roll the policy out for EPISODE_LEN steps inside the learned dynamics model.\n",
    "\n",
    "    Starting from the all-zero state, each step feeds (state, action) to ``model``\n",
    "    and uses the predicted mean as the next state.\n",
    "\n",
    "    Args:\n",
    "        model:     torch module returning ``(mean, log_var)`` for a ``(1, 5)`` input.\n",
    "        theta:     policy parameters passed to ``policy``.\n",
    "        intrinsic: if True, J accumulates the summed predicted variance per step;\n",
    "                   otherwise J gains +1 for every step whose current state lies inside\n",
    "                   the Position_* / Pole_Angle_* bounds.\n",
    "\n",
    "    Returns:\n",
    "        (J, C / EPISODE_LEN): the accumulated objective (not averaged) and the\n",
    "        per-step average of ``cal_cost``.\n",
    "    \"\"\"\n",
    "    # Run on whatever device the model's weights live on (CPU if it has none).\n",
    "    first_param = next(iter(model.parameters()), None)\n",
    "    model_device = first_param.device if first_param is not None else torch.device(\"cpu\")\n",
    "\n",
    "    s = np.zeros(4)  # start at origin (approx)\n",
    "    J = 0.0\n",
    "    C = 0.0\n",
    "    for _ in range(EPISODE_LEN):\n",
    "        a = policy(s, theta)\n",
    "        X = np.concatenate((s, np.array(a).reshape(1, )), axis=0)\n",
    "        X_tensor = torch.from_numpy(X).float().unsqueeze(0).to(model_device)\n",
    "        # The forward pass itself must sit inside no_grad; planning never backpropagates.\n",
    "        with torch.no_grad():\n",
    "            mean, log_var = model(X_tensor)\n",
    "        next_s = mean.cpu().numpy()[0]\n",
    "        if intrinsic:\n",
    "            # intrinsic reward = sum over state dimensions of the predicted variance\n",
    "            J += np.exp(log_var.cpu().numpy()[0]).sum()\n",
    "        elif Position_lb <= s[0] <= Position_ub and Pole_Angle_ld <= s[2] <= Pole_Angle_ud:\n",
    "            # extrinsic reward (CartPole): +1 per step while the current state is in bounds\n",
    "            J += 1.0\n",
    "        C += cal_cost(s, a, next_s)\n",
    "        s = next_s  # use the model-predicted mean as the next state\n",
    "    # return J / EPISODE_LEN, C / EPISODE_LEN\n",
    "    return J, C / EPISODE_LEN\n",
    "\n",
    "\n",
    "def objective_over_ensemble(ensemble, theta, intrinsic):\n",
    "    \"\"\"Score ``theta`` on every model in ``ensemble`` and keep the largest J and largest C.\n",
    "\n",
    "    Each model is rolled out ``num_traj`` times with ``eval_episode_model``; the two\n",
    "    maxima are taken independently over all rollouts.\n",
    "\n",
    "    NOTE(review): eval_episode_model as written in this notebook looks deterministic,\n",
    "    so the ``num_traj`` repeats per model probably return identical values — confirm\n",
    "    whether sampled rollouts were intended.\n",
    "    \"\"\"\n",
    "    rollouts = [\n",
    "        eval_episode_model(f, theta, intrinsic)\n",
    "        for f in ensemble\n",
    "        for _ in range(num_traj)\n",
    "    ]\n",
    "    # max_{f ∈ Qₙ} J  and  max_{f ∈ Qₙ} C\n",
    "    J_max = np.max([J for J, _ in rollouts])\n",
    "    C_max = np.max([C for _, C in rollouts])\n",
    "    return J_max, C_max\n",
    "\n",
    "\n",
    "# ── Simple random‐search planner as a stand‐in for LBSGD ───────────────────────\n",
    "def solve_barrier(ensemble, theta0, intrinsic, lam=10000.0):\n",
    "    \"\"\"\n",
    "    Hill-climbing random search on the penalized objective\n",
    "    J - λ * max(0, C - COST_LIMIT).\n",
    "\n",
    "    Starting from a copy of ``theta0``, RAND_SEARCH Gaussian perturbations (std\n",
    "    SEARCH_STD) of the current best are tried in turn; a candidate replaces the best\n",
    "    only when its score is strictly higher. Returns the best parameters found.\n",
    "    \"\"\"\n",
    "    def penalized_score(theta):\n",
    "        J, C = objective_over_ensemble(ensemble, theta, intrinsic)\n",
    "        return J - lam * max(0, C - COST_LIMIT)\n",
    "\n",
    "    incumbent = theta0.copy()\n",
    "    incumbent_score = penalized_score(incumbent)\n",
    "    for _ in range(RAND_SEARCH):\n",
    "        noise = np.random.randn(*incumbent.shape) * SEARCH_STD\n",
    "        challenger = incumbent + noise\n",
    "        challenger_score = penalized_score(challenger)\n",
    "        if challenger_score > incumbent_score:\n",
    "            incumbent, incumbent_score = challenger, challenger_score\n",
    "    return incumbent\n",
    "\n",
    "# ── Roll out in the real environment to collect data ───────────────────────────\n",
    "def rollout_env(env, theta):\n",
    "    \"\"\"\n",
    "    Execute one episode in the *real* CartPole and collect transitions.\n",
    "\n",
    "    Runs ``policy(s, theta)`` for at most EPISODE_LEN steps and stops early when the\n",
    "    environment reports that the episode is over. Both the 5-tuple\n",
    "    (obs, reward, terminated, truncated, info) and the older 4-tuple\n",
    "    (obs, reward, done, info) ``step`` results are accepted, mirroring the handling\n",
    "    of ``reset``.\n",
    "\n",
    "    Returns:\n",
    "        transitions:  list of (s, a, s1, r, cost) tuples, one per step taken\n",
    "        total_reward: sum of the environment rewards\n",
    "        avg_cost:     mean of cal_cost over the steps taken (0.0 if no step was taken)\n",
    "        num_vio:      number of steps whose cost exceeded COST_LIMIT\n",
    "    \"\"\"\n",
    "    transitions = []\n",
    "    # reset may also return (obs, info) in Gymnasium—just grab the first item\n",
    "    s = env.reset()\n",
    "    if isinstance(s, tuple):\n",
    "        s = s[0]\n",
    "\n",
    "    total_reward = 0.0\n",
    "    total_cost   = 0.0\n",
    "    num_vio = 0\n",
    "    steps = 0\n",
    "    for _ in range(EPISODE_LEN):\n",
    "        a = policy(s, theta)\n",
    "\n",
    "        step_result = env.step(a)\n",
    "        if len(step_result) == 5:\n",
    "            s1, r, terminated, truncated, _info = step_result\n",
    "            done = terminated or truncated\n",
    "        else:\n",
    "            s1, r, done, _info = step_result\n",
    "\n",
    "        cost = cal_cost(s, a, s1)\n",
    "        transitions.append((s, a, s1, r, cost))\n",
    "\n",
    "        total_reward += r\n",
    "        total_cost   += cost\n",
    "        steps += 1\n",
    "        s = s1\n",
    "        if cost > COST_LIMIT:\n",
    "            num_vio += 1\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "    # Count steps explicitly so an empty rollout cannot hit an undefined loop variable.\n",
    "    avg_cost = total_cost / steps if steps else 0.0\n",
    "    return transitions, total_reward, avg_cost, num_vio\n",
    "\n",
    "# ── Main two‐phase ACTSAFE loop for CartPole ─────────────────────────────────\n",
    "# def run():\n",
    "#     env = gym.make(ENV_ID)\n",
    "#     model =  ProbabilisticRegressor(input_dim=5, hidden_dim=128)\n",
    "#     ensemble = [ model ]\n",
    "#     theta = np.zeros(4)  # initial policy params\n",
    "\n",
    "#     # 1) Intrinsic‐exploration phase\n",
    "#     for ep in range(1, N_INTRINSIC + 1):\n",
    "#         # a) Solve Eq (7): maximize intrinsic bonus under cost limit\n",
    "#         theta = solve_barrier(ensemble, theta, intrinsic=True)\n",
    "#         # b) Collect real transitions\n",
    "#         trans, R, C = rollout_env(env, theta)\n",
    "\n",
    "\n",
    "#         model.update(trans)\n",
    "#         print(f\"[I⋅{ep}] R={R:.2f}, C={C:.2f}\")\n",
    "\n",
    "#     # 2) Extrinsic‐exploitation phase\n",
    "#     for ep in range(N_INTRINSIC + 1, N_EPISODES + 1):\n",
    "#         # a) Solve Eq (8): maximize extrinsic reward under cost limit\n",
    "#         theta = solve_barrier(ensemble, theta, intrinsic=False)\n",
    "#         trans, R, C = rollout_env(env, theta)\n",
    "#         model.update(trans)\n",
    "#         print(f\"[E⋅{ep}] R={R:.2f}, C={C:.2f}\")\n",
    "\n",
    "    # env.close()\n",
    "\n",
    "# if __name__ == \"__main__\":\n",
    "#     main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "b89de9b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# s=np.zeros(4)\n",
    "# a=int(1)\n",
    "# X=np.concatenate((s, np.array(a).reshape(1, )), axis=0)\n",
    "\n",
    "# X_tensor = torch.from_numpy(X).float().unsqueeze(0).to(device)\n",
    "# model =  ProbabilisticRegressor(input_dim=5, hidden_dim=128)\n",
    "# mean, log_var = model(X_tensor)\n",
    "# np.exp(log_var.detach().numpy())[0][0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "d62b774b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[I⋅1] R=9.00, C=0.14\n",
      "[I⋅2] R=38.00, C=0.04\n",
      "[I⋅3] R=102.00, C=0.04\n",
      "[I⋅4] R=88.00, C=0.04\n",
      "[I⋅5] R=78.00, C=0.04\n",
      "[I⋅6] R=200.00, C=0.00\n",
      "[I⋅7] R=200.00, C=0.00\n",
      "[I⋅8] R=200.00, C=0.00\n",
      "[I⋅9] R=90.00, C=0.03\n",
      "[I⋅10] R=200.00, C=0.01\n",
      "[I⋅11] R=49.00, C=0.04\n",
      "[I⋅12] R=110.00, C=0.03\n",
      "[I⋅13] R=200.00, C=0.00\n",
      "[I⋅14] R=23.00, C=0.09\n",
      "[I⋅15] R=9.00, C=0.20\n",
      "[I⋅16] R=10.00, C=0.18\n",
      "[I⋅17] R=9.00, C=0.20\n",
      "[I⋅18] R=9.00, C=0.11\n",
      "[I⋅19] R=9.00, C=0.14\n",
      "[I⋅20] R=44.00, C=0.06\n",
      "[I⋅21] R=200.00, C=0.00\n",
      "[I⋅22] R=200.00, C=0.00\n",
      "[I⋅23] R=68.00, C=0.05\n",
      "[I⋅24] R=61.00, C=0.04\n",
      "[I⋅25] R=58.00, C=0.04\n",
      "[I⋅26] R=65.00, C=0.04\n",
      "[I⋅27] R=200.00, C=0.00\n",
      "[I⋅28] R=177.00, C=0.02\n",
      "[I⋅29] R=37.00, C=0.04\n",
      "[I⋅30] R=10.00, C=0.17\n",
      "[I⋅31] R=10.00, C=0.09\n",
      "[I⋅32] R=80.00, C=0.04\n",
      "[I⋅33] R=120.00, C=0.02\n",
      "[I⋅34] R=200.00, C=0.00\n",
      "[I⋅35] R=200.00, C=0.07\n",
      "[I⋅36] R=200.00, C=0.00\n",
      "[I⋅37] R=200.00, C=0.00\n",
      "[I⋅38] R=107.00, C=0.04\n",
      "[I⋅39] R=95.00, C=0.04\n",
      "[I⋅40] R=200.00, C=0.00\n",
      "[E⋅41] R=100.00, C=0.03\n",
      "[E⋅42] R=200.00, C=0.00\n",
      "[E⋅43] R=74.00, C=0.03\n",
      "[E⋅44] R=200.00, C=0.00\n",
      "[E⋅45] R=95.00, C=0.03\n",
      "[E⋅46] R=200.00, C=0.00\n",
      "[E⋅47] R=200.00, C=0.00\n",
      "[E⋅48] R=54.00, C=0.05\n",
      "[E⋅49] R=45.00, C=0.06\n",
      "[E⋅50] R=197.00, C=0.01\n",
      "[E⋅51] R=137.00, C=0.02\n",
      "[E⋅52] R=200.00, C=0.00\n",
      "[E⋅53] R=51.00, C=0.05\n",
      "[E⋅54] R=52.00, C=0.05\n",
      "[E⋅55] R=38.00, C=0.06\n",
      "[E⋅56] R=57.00, C=0.03\n",
      "[E⋅57] R=29.00, C=0.07\n",
      "[E⋅58] R=46.00, C=0.04\n",
      "[E⋅59] R=82.00, C=0.03\n",
      "[E⋅60] R=48.00, C=0.05\n",
      "[E⋅61] R=56.00, C=0.04\n",
      "[E⋅62] R=53.00, C=0.04\n",
      "[E⋅63] R=34.00, C=0.06\n",
      "[E⋅64] R=35.00, C=0.08\n",
      "[E⋅65] R=64.00, C=0.04\n",
      "[E⋅66] R=143.00, C=0.02\n",
      "[E⋅67] R=47.00, C=0.04\n",
      "[E⋅68] R=152.00, C=0.01\n",
      "[E⋅69] R=50.00, C=0.05\n",
      "[E⋅70] R=44.00, C=0.04\n",
      "[E⋅71] R=51.00, C=0.05\n",
      "[E⋅72] R=160.00, C=0.02\n",
      "[E⋅73] R=53.00, C=0.04\n",
      "[E⋅74] R=85.00, C=0.03\n",
      "[E⋅75] R=50.00, C=0.04\n",
      "[E⋅76] R=200.00, C=0.00\n",
      "[E⋅77] R=144.00, C=0.02\n",
      "[E⋅78] R=43.00, C=0.05\n",
      "[E⋅79] R=54.00, C=0.04\n",
      "[E⋅80] R=60.00, C=0.04\n",
      "[E⋅81] R=56.00, C=0.04\n",
      "[E⋅82] R=42.00, C=0.05\n",
      "[E⋅83] R=200.00, C=0.00\n",
      "[E⋅84] R=200.00, C=0.00\n",
      "[E⋅85] R=200.00, C=0.00\n",
      "[E⋅86] R=200.00, C=0.00\n",
      "[E⋅87] R=200.00, C=0.00\n",
      "[E⋅88] R=200.00, C=0.00\n",
      "[E⋅89] R=121.00, C=0.03\n",
      "[E⋅90] R=32.00, C=0.05\n",
      "[E⋅91] R=66.00, C=0.04\n",
      "[E⋅92] R=37.00, C=0.07\n",
      "[E⋅93] R=39.00, C=0.05\n",
      "[E⋅94] R=37.00, C=0.05\n",
      "[E⋅95] R=10.00, C=0.13\n",
      "[E⋅96] R=24.00, C=0.08\n",
      "[E⋅97] R=8.00, C=0.10\n",
      "[E⋅98] R=12.00, C=0.10\n",
      "[E⋅99] R=21.00, C=0.07\n",
      "[E⋅100] R=24.00, C=0.08\n",
      "[E⋅101] R=27.00, C=0.06\n",
      "[E⋅102] R=25.00, C=0.04\n",
      "[E⋅103] R=54.00, C=0.02\n",
      "[E⋅104] R=26.00, C=0.05\n",
      "[E⋅105] R=66.00, C=0.03\n",
      "[E⋅106] R=50.00, C=0.05\n",
      "[E⋅107] R=32.00, C=0.06\n",
      "[E⋅108] R=78.00, C=0.03\n",
      "[E⋅109] R=63.00, C=0.04\n",
      "[E⋅110] R=50.00, C=0.05\n",
      "[E⋅111] R=70.00, C=0.02\n",
      "[E⋅112] R=200.00, C=0.00\n",
      "[E⋅113] R=200.00, C=0.00\n",
      "[E⋅114] R=107.00, C=0.02\n",
      "[E⋅115] R=43.00, C=0.06\n",
      "[E⋅116] R=67.00, C=0.04\n",
      "[E⋅117] R=63.00, C=0.03\n",
      "[E⋅118] R=44.00, C=0.05\n",
      "[E⋅119] R=200.00, C=0.00\n",
      "[E⋅120] R=200.00, C=0.00\n",
      "[E⋅121] R=77.00, C=0.04\n",
      "[E⋅122] R=200.00, C=0.00\n",
      "[E⋅123] R=97.00, C=0.03\n",
      "[E⋅124] R=64.00, C=0.05\n",
      "[E⋅125] R=58.00, C=0.05\n",
      "[E⋅126] R=54.00, C=0.04\n",
      "[E⋅127] R=66.00, C=0.04\n",
      "[E⋅128] R=200.00, C=0.00\n",
      "[E⋅129] R=52.00, C=0.05\n",
      "[E⋅130] R=73.00, C=0.03\n",
      "[E⋅131] R=81.00, C=0.03\n",
      "[E⋅132] R=195.00, C=0.01\n",
      "[E⋅133] R=93.00, C=0.03\n",
      "[E⋅134] R=105.00, C=0.02\n",
      "[E⋅135] R=200.00, C=0.00\n",
      "[E⋅136] R=200.00, C=0.00\n",
      "[E⋅137] R=163.00, C=0.01\n",
      "[E⋅138] R=200.00, C=0.00\n",
      "[E⋅139] R=200.00, C=0.00\n",
      "[E⋅140] R=200.00, C=0.00\n",
      "[E⋅141] R=200.00, C=0.00\n",
      "[E⋅142] R=200.00, C=0.00\n",
      "[E⋅143] R=200.00, C=0.00\n",
      "[E⋅144] R=58.00, C=0.04\n",
      "[E⋅145] R=84.00, C=0.03\n",
      "[E⋅146] R=89.00, C=0.03\n",
      "[E⋅147] R=200.00, C=0.01\n",
      "[E⋅148] R=105.00, C=0.03\n",
      "[E⋅149] R=139.00, C=0.02\n",
      "[E⋅150] R=166.00, C=0.02\n",
      "[E⋅151] R=161.00, C=0.02\n",
      "[E⋅152] R=200.00, C=0.00\n",
      "[E⋅153] R=200.00, C=0.00\n",
      "[E⋅154] R=200.00, C=0.00\n",
      "[E⋅155] R=200.00, C=0.00\n",
      "[E⋅156] R=200.00, C=0.00\n",
      "[E⋅157] R=200.00, C=0.00\n",
      "[E⋅158] R=200.00, C=0.00\n",
      "[E⋅159] R=200.00, C=0.00\n"
     ]
    }
   ],
   "source": [
    "env = gym.make(ENV_ID)\n",
    "env.reset(seed=SEED)  # seed the environment's own RNG once; later resets continue that stream\n",
    "model =  ProbabilisticRegressor(input_dim=5, hidden_dim=128)\n",
    "model1 =  ProbabilisticRegressor(input_dim=5, hidden_dim=128)\n",
    "model2 =  ProbabilisticRegressor(input_dim=5, hidden_dim=128)\n",
    "# NOTE: this optimizer is not used by the loop below; model.update() steps its own.\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "ensemble = [ model ]\n",
    "# ensemble = [ model1,model2 ]\n",
    "theta = np.zeros(4)  # initial policy params\n",
    "# One row per episode: [total reward, average cost, number of steps with cost > COST_LIMIT].\n",
    "train_data = np.zeros((N_EPISODES, 3))\n",
    "\n",
    "# Episodes are numbered 1..N_EPISODES and stored in row ep-1, so every row is filled\n",
    "# and exactly N_EPISODES episodes are run.\n",
    "for ep in range(1, N_EPISODES + 1):\n",
    "    # 1) Intrinsic‐exploration phase for the first N_INTRINSIC episodes (Eq (7):\n",
    "    #    maximize intrinsic bonus under cost limit);\n",
    "    # 2) Extrinsic‐exploitation phase afterwards (Eq (8): maximize extrinsic reward\n",
    "    #    under cost limit).\n",
    "    intrinsic = ep <= N_INTRINSIC\n",
    "    phase = \"I\" if intrinsic else \"E\"\n",
    "    theta = solve_barrier(ensemble, theta, intrinsic=intrinsic)\n",
    "    # Collect real transitions with the planned policy, then refit the model on them.\n",
    "    trans, R, C, vio = rollout_env(env, theta)\n",
    "    train_data[ep - 1] = (R, C, vio)\n",
    "    model.update(trans)\n",
    "    print(f\"[{phase}⋅{ep}] R={R:.2f}, C={C:.2f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "2ab76a3f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(17022.0, 439.0, 0.025790153918458465)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training totals: (sum of episode rewards, total steps with cost > COST_LIMIT,\n",
    "# violations per unit of reward)\n",
    "(\n",
    "    train_data[:, 0].sum(),\n",
    "    train_data[:, 2].sum(),\n",
    "    train_data[:, 2].sum() / train_data[:, 0].sum(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "aa10c519",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average reward over 100 test episodes: 200.00\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the agent's performance\n",
    "TEST_SEED = 2025\n",
    "np.random.seed(TEST_SEED)\n",
    "test_episodes = 100\n",
    "max_steps_per_test_episode=200\n",
    "episode_rewards = []\n",
    "test_average_cost=np.zeros(test_episodes)          # mean per-step cost of each episode\n",
    "test_epi_length=np.zeros(test_episodes)            # number of steps taken in each episode\n",
    "test_cost_epi=np.zeros(test_episodes)              # summed cost of each episode\n",
    "test_cost_per_transition=np.zeros((test_episodes,max_steps_per_test_episode))\n",
    "for episode in range(test_episodes):\n",
    "    # np.random.seed does not reach the environment's RNG, so seed the environment on\n",
    "    # the first reset; later resets continue from that seeded stream.\n",
    "    state = (env.reset(seed=TEST_SEED) if episode == 0 else env.reset())[0]\n",
    "    episode_reward = 0\n",
    "    for step in range(max_steps_per_test_episode):\n",
    "        action = policy(state, theta)\n",
    "        next_state, reward, terminated, truncated, _ = env.step(action)\n",
    "        done = terminated or truncated  # stop on either end-of-episode signal\n",
    "        cost = cal_cost(state, action, next_state)\n",
    "        episode_reward += reward\n",
    "        state = next_state\n",
    "        test_epi_length[episode] += 1\n",
    "        test_cost_epi[episode] += cost\n",
    "        test_cost_per_transition[episode, step] = cost\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "    # max(..., 1) only guards the degenerate zero-step case against division by zero\n",
    "    test_average_cost[episode] = test_cost_epi[episode] / max(test_epi_length[episode], 1)\n",
    "    episode_rewards.append(episode_reward)\n",
    "\n",
    "average_reward = sum(episode_rewards) / test_episodes\n",
    "print(f\"Average reward over {test_episodes} test episodes: {average_reward:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "9e75ec93",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x22d8be2b3a0>]"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGdCAYAAADuR1K7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgSElEQVR4nO3df2zU9eHH8de1pVcQ7joK9Ci0go6sIAja2nJowmIvK0qmnbhhU6ViI9EBgmUov8nmWN2MCgyUsESJAQbDKVPGMKw41HAWKKDyq7LoAMG7gqw9fpbae3//MJ7fmwUL9Fru7fORfIL9fN6fu/fnncg98+nd4TDGGAEAAFgiob0nAAAA0JqIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWSWrvCbSHcDiso0ePqkuXLnI4HO09HQAA0ALGGJ08eVIZGRlKSLjw/ZnvZdwcPXpUmZmZ7T0NAABwGQ4fPqzevXtf8Pj3Mm66dOki6avFcblc7TwbAADQEqFQSJmZmZHX8Qv5XsbN17+KcrlcxA0AAHHmu95SwhuKAQCAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFilTeJm8eLF6tOnj1JSUpSfn6+tW7dedPyaNWuUnZ2tlJQUDRo0SOvXr7/g2EceeUQOh0Pz589v5VkDAIB4FPO4Wb16tcrLyzV37lzt2LFDgwcPVmFhoWpra5sdv2XLFhUXF6usrEw7d+5UUVGRioqKtHv37m+Nff311/X+++8rIyMj1pcBAADiRMzj5rnnntPDDz+ssWPHasCAAVqyZIk6deqkl156qdnxCxYs0IgRIzR16lT1799fTz31lG6++WYtWrQoatyRI0c0ceJErVixQh06dIj1ZQAAgDgR07g5f/68qqur5fP5vnnChAT5fD75/f5mz/H7/VHjJamwsDBqfDgc1gMPPKCpU6fqhhtu+M55NDQ0KBQKRW0AAMBOMY2b48ePq6mpSenp6VH709PTFQgEmj0nEAh85/jf//73SkpK0mOPPdaieVRUVMjtdke2zMzMS7wSAAAQL+Lu01LV1dVasGCBli1bJofD0aJzpk+frvr6+sh2+PDhGM8SAAC0l5jGTbdu3ZSYmKhgMBi1PxgMyuPxNHuOx+O56Ph3331XtbW1ysrKUlJSkpKSknTw4EFNmTJFffr0afYxnU6nXC5X1AYAAOwU07hJTk5WTk6OKisrI/vC4bAqKyvl9XqbPcfr9UaNl6SNGzdGxj/wwAP68MMPtWvXrsiWkZGhqVOn6q233ordxQAAgLiQFOsnKC8vV2lpqXJzc5WXl6f58+fr9OnTGjt2rCRpzJgx6tWrlyoqKiRJkyZN0vDhw/Xss89q5MiRWrVqlbZv366lS5dKktLS0pSWlhb1HB06dJDH49GPfvSjWF8OAAC4ysU8bkaPHq1jx45pzpw5CgQCGjJkiDZs2BB50/ChQ4eUkPDNDaRhw4Zp5cqVmjVrlmbMmKF+/fpp7dq1GjhwYKynCgAALOAwxpj2nkRbC4VCcrvdqq+v5/03AADEiZa+fsfdp6UAAAAuhrgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV
4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYJU2iZvFixerT58+SklJUX5+vrZu3XrR8WvWrFF2drZSUlI0aNAgrV+/PnKssbFRTz75pAYNGqRrrrlGGRkZGjNmjI4ePRrrywAAAHEg5nGzevVqlZeXa+7cudqxY4cGDx6swsJC1dbWNjt+y5YtKi4uVllZmXbu3KmioiIVFRVp9+7dkqQzZ85ox44dmj17tnbs2KHXXntNNTU1uuuuu2J9KQAAIA44jDEmlk+Qn5+vW265RYsWLZIkhcNhZWZmauLEiZo2bdq3xo8ePVqnT5/WunXrIvuGDh2qIUOGaMmSJc0+x7Zt25SXl6eDBw8qKyvrO+cUCoXkdrtVX18vl8t1mVcGAADaUktfv2N65+b8+fOqrq6Wz+f75gkTEuTz+eT3+5s9x+/3R42XpMLCwguOl6T6+no5HA6lpqY2e7yhoUGhUChqAwAAdopp3Bw/flxNTU1KT0+P2p+enq5AINDsOYFA4JLGnzt3Tk8++aSKi4svWHEVFRVyu92RLTMz8zKuBgAAxIO4/rRUY2OjfvGLX8gYoxdffPGC46ZPn676+vrIdvjw4TacJQAAaEtJsXzwbt26KTExUcFgMGp/MBiUx+Np9hyPx9Oi8V+HzcGDB7Vp06aL/u7N6XTK6XRe5lUAAIB4EtM7N8nJycrJyVFlZWVkXzgcVmVlpbxeb7PneL3eqPGStHHjxqjxX4fNgQMH9M9//lNpaWmxuQAAABB3YnrnRpLKy8tVWlqq3Nxc5eXlaf78+Tp9+rTGjh0rSRozZox69eqliooKSdKkSZM0fPhwPfvssxo5cqRWrVql7du3a+nSpZK+Cpt7771XO3bs0Lp169TU1BR5P07Xrl2VnJwc60sCAABXsZjHzejRo3Xs2DHNmTNHgUBAQ4YM0YYNGyJvGj506JASEr65gTRs2DCtXLlSs2bN0owZM9SvXz+tXbtWAwcOlCQdOXJEb7zxhiRpyJAhUc/19ttv68c//nGsLwkAAFzFYv49N1cjvucGAID4c1V8zw0AAEBbI24AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWKVN4mbx4sXq06ePUlJSlJ+fr61bt150/Jo1a5Sdna2UlBQNGjRI69evjzpujNGcOXPUs2dPdezYUT6fTwcOHIjlJQAAgDgR87hZvXq1ysvLNXfuXO3YsUODBw9WYWGhamtrmx2/ZcsWFRcXq6ysTDt37lRRUZGKioq0e/fuyJg//OEPWrhwoZYsWaKqqipdc801Kiws1Llz52J9OQAA4CrnMMaYWD5Bfn6+brnlFi1atEiSFA6HlZmZqYkT
J2ratGnfGj969GidPn1a69ati+wbOnSohgwZoiVLlsgYo4yMDE2ZMkW/+tWvJEn19fVKT0/XsmXLdN99933nnEKhkNxut+rr6+VyuVrpSr+6o3S2sanVHg8AgHjVsUOiHA5Hqz5mS1+/k1r1Wf/H+fPnVV1drenTp0f2JSQkyOfzye/3N3uO3+9XeXl51L7CwkKtXbtWkvTpp58qEAjI5/NFjrvdbuXn58vv9zcbNw0NDWpoaIj8HAqFruSyLuhsY5MGzHkrJo8NAEA82fubQnVKjmlmXFBMfy11/PhxNTU1KT09PWp/enq6AoFAs+cEAoGLjv/6z0t5zIqKCrnd7siWmZl5WdcDAACufu2TVG1s+vTpUXeDQqFQTAKnY4dE7f1NYas/LgAA8aZjh8R2e+6Yxk23bt2UmJioYDAYtT8YDMrj8TR7jsfjuej4r/8MBoPq2bNn1JghQ4Y0+5hOp1NOp/NyL6PFHA5Hu92CAwAAX4npr6WSk5OVk5OjysrKyL5wOKzKykp5vd5mz/F6vVHjJWnjxo2R8X379pXH44kaEwqFVFVVdcHHBAAA3x8xv81QXl6u0tJS5ebmKi8vT/Pnz9fp06c1duxYSdKYMWPUq1cvVVRUSJImTZqk4cOH69lnn9XIkSO1atUqbd++XUuXLpX01d2RyZMn67e//a369eunvn37avbs2crIyFBRUVGsLwcAAFzlYh43o0eP1rFjxzRnzhwFAgENGTJEGzZsiLwh+NChQ0pI+OYG0rBhw7Ry5UrNmjVLM2bMUL9+/bR27VoNHDgwMuaJJ57Q6dOnNW7cONXV1em2227Thg0blJKSEuvLAQAAV7mYf8/N1ShW33MDAABip6Wv3/zbUgAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrxCxuTpw4oZKSErlcLqWmpqqsrEynTp266Dnnzp3T+PHjlZaWps6dO2vUqFEKBoOR4x988IGKi4uVmZmpjh07qn///lqwYEGsLgEAAMShmMVNSUmJ9uzZo40bN2rdunV65513NG7cuIue8/jjj+vNN9/UmjVrtHnzZh09elT33HNP5Hh1dbV69Oih5cuXa8+ePZo5c6amT5+uRYsWxeoyAABAnHEYY0xrP+i+ffs0YMAAbdu2Tbm5uZKkDRs26M4779Rnn32mjIyMb51TX1+v7t27a+XKlbr33nslSfv371f//v3l9/s1dOjQZp9r/Pjx2rdvnzZt2tTi+YVCIbndbtXX18vlcl3GFQIAgLbW0tfvmNy58fv9Sk1NjYSNJPl8PiUkJKiqqqrZc6qrq9XY2CifzxfZl52draysLPn9/gs+V319vbp27dp6kwcAAHEtKRYPGggE1KNHj+gnSkpS165dFQgELnhOcnKyUlNTo/anp6df8JwtW7Zo9erV+vvf/37R+TQ0NKihoSHycygUasFVAACAeHRJd26mTZsmh8Nx0W3//v2xmmuU3bt36+6779bcuXP1k5/85KJjKyoq5Ha7I1tmZmabzBEAALS9S7pzM2XKFD344IMXHXPdddfJ4/GotrY2av+XX36pEydOyOPxNHuex+PR+fPnVVdXF3X3JhgMfuucvXv3qqCg
QOPGjdOsWbO+c97Tp09XeXl55OdQKETgAABgqUuKm+7du6t79+7fOc7r9aqurk7V1dXKycmRJG3atEnhcFj5+fnNnpOTk6MOHTqosrJSo0aNkiTV1NTo0KFD8nq9kXF79uzR7bffrtLSUs2bN69F83Y6nXI6nS0aCwAA4ltMPi0lSXfccYeCwaCWLFmixsZGjR07Vrm5uVq5cqUk6ciRIyooKNArr7yivLw8SdKjjz6q9evXa9myZXK5XJo4caKkr95bI331q6jbb79dhYWFeuaZZyLPlZiY2KLo+hqflgIAIP609PU7Jm8olqQVK1ZowoQJKigoUEJCgkaNGqWFCxdGjjc2NqqmpkZnzpyJ7Hv++ecjYxsaGlRYWKgXXnghcvzVV1/VsWPHtHz5ci1fvjyy/9prr9V//vOfWF0KAACIIzG7c3M1484NAADxp12/5wYAAKC9EDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAq8Qsbk6cOKGSkhK5XC6lpqaqrKxMp06duug5586d0/jx45WWlqbOnTtr1KhRCgaDzY794osv1Lt3bzkcDtXV1cXgCgAAQDyKWdyUlJRoz5492rhxo9atW6d33nlH48aNu+g5jz/+uN58802tWbNGmzdv1tGjR3XPPfc0O7asrEw33nhjLKYOAADimMMYY1r7Qfft26cBAwZo27Ztys3NlSRt2LBBd955pz777DNlZGR865z6+np1795dK1eu1L333itJ2r9/v/r37y+/36+hQ4dGxr744otavXq15syZo4KCAv33v/9Vampqi+cXCoXkdrtVX18vl8t1ZRcLAADaREtfv2Ny58bv9ys1NTUSNpLk8/mUkJCgqqqqZs+prq5WY2OjfD5fZF92draysrLk9/sj+/bu3avf/OY3euWVV5SQ0LLpNzQ0KBQKRW0AAMBOMYmbQCCgHj16RO1LSkpS165dFQgELnhOcnLyt+7ApKenR85paGhQcXGxnnnmGWVlZbV4PhUVFXK73ZEtMzPz0i4IAADEjUuKm2nTpsnhcFx0279/f6zmqunTp6t///66//77L/m8+vr6yHb48OEYzRAAALS3pEsZPGXKFD344IMXHXPdddfJ4/GotrY2av+XX36pEydOyOPxNHuex+PR+fPnVVdXF3X3JhgMRs7ZtGmTPvroI7366quSpK/fLtStWzfNnDlTv/71r5t9bKfTKafT2ZJLBAAAce6S4qZ79+7q3r37d47zer2qq6tTdXW1cnJyJH0VJuFwWPn5+c2ek5OTow4dOqiyslKjRo2SJNXU1OjQoUPyer2SpL/+9a86e/Zs5Jxt27bpoYce0rvvvqvrr7/+Ui4FAABY6pLipqX69++vESNG6OGHH9aSJUvU2NioCRMm6L777ot8UurIkSMqKCjQK6+8ory8PLndbpWVlam8vFxdu3aVy+XSxIkT5fV6I5+U+t+AOX78eOT5LuXTUgAAwF4xiRtJWrFihSZMmKCCggIlJCRo1KhRWrhwYeR4Y2OjampqdObMmci+559/PjK2oaFBhYWFeuGFF2I1RQAAYKGYfM/N1Y7vuQEAIP606/fcAAAAtBfiBgAAWIW4AQAA
ViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYJam9J9AejDGSpFAo1M4zAQAALfX16/bXr+MX8r2Mm5MnT0qSMjMz23kmAADgUp08eVJut/uCxx3mu/LHQuFwWEePHlWXLl3kcDha9bFDoZAyMzN1+PBhuVyuVn1sRGOt2w5r3XZY67bDWred1lprY4xOnjypjIwMJSRc+J0138s7NwkJCerdu3dMn8PlcvE/SxthrdsOa912WOu2w1q3ndZY64vdsfkabygGAABWIW4AAIBViJtW5nQ6NXfuXDmdzvaeivVY67bDWrcd1rrtsNZtp63X+nv5hmIAAGAv7twAAACrEDcAAMAqxA0AALAKcQMAAKxC3LSixYsXq0+fPkpJSVF+fr62bt3a3lOKexUVFbrlllvUpUsX9ejRQ0VFRaqpqYkac+7cOY0fP15paWnq3LmzRo0apWAw2E4ztsfTTz8th8OhyZMnR/ax1q3nyJEjuv/++5WWlqaOHTtq0KBB2r59e+S4MUZz5sxRz5491bFjR/l8Ph04cKAdZxyfmpqaNHv2bPXt21cdO3bU9ddfr6eeeirq3yZirS/PO++8o5/+9KfKyMiQw+HQ2rVro463ZF1PnDihkpISuVwupaamqqysTKdOnbryyRm0ilWrVpnk5GTz0ksvmT179piHH37YpKammmAw2N5Ti2uFhYXm5ZdfNrt37za7du0yd955p8nKyjKnTp2KjHnkkUdMZmamqaysNNu3bzdDhw41w4YNa8dZx7+tW7eaPn36mBtvvNFMmjQpsp+1bh0nTpww1157rXnwwQdNVVWV+eSTT8xbb71l/v3vf0fGPP3008btdpu1a9eaDz74wNx1112mb9++5uzZs+048/gzb948k5aWZtatW2c+/fRTs2bNGtO5c2ezYMGCyBjW+vKsX7/ezJw507z22mtGknn99dejjrdkXUeMGGEGDx5s3n//ffPuu++aH/7wh6a4uPiK50bctJK8vDwzfvz4yM9NTU0mIyPDVFRUtOOs7FNbW2skmc2bNxtjjKmrqzMdOnQwa9asiYzZt2+fkWT8fn97TTOunTx50vTr189s3LjRDB8+PBI3rHXrefLJJ81tt912wePhcNh4PB7zzDPPRPbV1dUZp9Np/vznP7fFFK0xcuRI89BDD0Xtu+eee0xJSYkxhrVuLf8bNy1Z17179xpJZtu2bZEx//jHP4zD4TBHjhy5ovnwa6lWcP78eVVXV8vn80X2JSQkyOfzye/3t+PM7FNfXy9J6tq1qySpurpajY2NUWufnZ2trKws1v4yjR8/XiNHjoxaU4m1bk1vvPGGcnNz9fOf/1w9evTQTTfdpD/96U+R459++qkCgUDUWrvdbuXn57PWl2jYsGGqrKzUxx9/LEn64IMP9N577+mOO+6QxFrHSkvW1e/3KzU1Vbm5uZExPp9PCQkJqqqquqLn/17+w5mt7fjx42pqalJ6enrU/vT0dO3fv7+dZmWfcDisyZMn69Zbb9XAgQMlSYFAQMnJyUpNTY0am56erkAg0A6zjG+rVq3Sjh07tG3btm8dY61bzyeffKIXX3xR5eXlmjFjhrZt26bHHntMycnJKi0tjaxnc3+nsNaXZtq0aQqFQsrOzlZiYqKampo0b948lZSUSBJrHSMtWddAIKAePXpEHU9KSlLXrl2veO2JG8SN8ePHa/fu3XrvvffaeypWOnz4sCZNmqSNGzcqJSWl
vadjtXA4rNzcXP3ud7+TJN10003avXu3lixZotLS0naenV3+8pe/aMWKFVq5cqVuuOEG7dq1S5MnT1ZGRgZrbTF+LdUKunXrpsTExG99aiQYDMrj8bTTrOwyYcIErVu3Tm+//bZ69+4d2e/xeHT+/HnV1dVFjWftL111dbVqa2t18803KykpSUlJSdq8ebMWLlyopKQkpaens9atpGfPnhowYEDUvv79++vQoUOSFFlP/k65clOnTtW0adN03333adCgQXrggQf0+OOPq6KiQhJrHSstWVePx6Pa2tqo419++aVOnDhxxWtP3LSC5ORk5eTkqLKyMrIvHA6rsrJSXq+3HWcW/4wxmjBhgl5//XVt2rRJffv2jTqek5OjDh06RK19TU2NDh06xNpfooKCAn300UfatWtXZMvNzVVJSUnkv1nr1nHrrbd+6ysNPv74Y1177bWSpL59+8rj8UStdSgUUlVVFWt9ic6cOaOEhOiXusTERIXDYUmsday0ZF29Xq/q6upUXV0dGbNp0yaFw2Hl5+df2QSu6O3IiFi1apVxOp1m2bJlZu/evWbcuHEmNTXVBAKB9p5aXHv00UeN2+02//rXv8znn38e2c6cORMZ88gjj5isrCyzadMms337duP1eo3X623HWdvj/39ayhjWurVs3brVJCUlmXnz5pkDBw6YFStWmE6dOpnly5dHxjz99NMmNTXV/O1vfzMffvihufvuu/l48mUoLS01vXr1inwU/LXXXjPdunUzTzzxRGQMa315Tp48aXbu3Gl27txpJJnnnnvO7Ny50xw8eNAY07J1HTFihLnppptMVVWVee+990y/fv34KPjV5o9//KPJysoyycnJJi8vz7z//vvtPaW4J6nZ7eWXX46MOXv2rPnlL39pfvCDH5hOnTqZn/3sZ+bzzz9vv0lb5H/jhrVuPW+++aYZOHCgcTqdJjs72yxdujTqeDgcNrNnzzbp6enG6XSagoICU1NT006zjV+hUMhMmjTJZGVlmZSUFHPdddeZmTNnmoaGhsgY1vryvP32283+/VxaWmqMadm6fvHFF6a4uNh07tzZuFwuM3bsWHPy5MkrnpvDmP/3NY0AAABxjvfcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArPJ/s6un4nY2lk4AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Line plot of test_cost_epi against its positional index. The figure is\n",
    "# shown explicitly so the cell does not emit a Line2D repr as text output.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(test_cost_epi)\n",
    "ax.set(title='test_cost_epi', xlabel='index', ylabel='value')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "f8910883",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x22d8be93ee0>]"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGdCAYAAADuR1K7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgSElEQVR4nO3df2zU9eHH8de1pVcQ7joK9Ci0go6sIAja2nJowmIvK0qmnbhhU6ViI9EBgmUov8nmWN2MCgyUsESJAQbDKVPGMKw41HAWKKDyq7LoAMG7gqw9fpbae3//MJ7fmwUL9Fru7fORfIL9fN6fu/fnncg98+nd4TDGGAEAAFgiob0nAAAA0JqIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWSWrvCbSHcDiso0ePqkuXLnI4HO09HQAA0ALGGJ08eVIZGRlKSLjw/ZnvZdwcPXpUmZmZ7T0NAABwGQ4fPqzevXtf8Pj3Mm66dOki6avFcblc7TwbAADQEqFQSJmZmZHX8Qv5XsbN17+KcrlcxA0AAHHmu95SwhuKAQCAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFilTeJm8eLF6tOnj1JSUpSfn6+tW7dedPyaNWuUnZ2tlJQUDRo0SOvXr7/g2EceeUQOh0Pz589v5VkDAIB4FPO4Wb16tcrLyzV37lzt2LFDgwcPVmFhoWpra5sdv2XLFhUXF6usrEw7d+5UUVGRioqKtHv37m+Nff311/X+++8rIyMj1pcBAADiRMzj5rnnntPDDz+ssWPHasCAAVqyZIk6deqkl156qdnxCxYs0IgRIzR16lT1799fTz31lG6++WYtWrQoatyRI0c0ceJErVixQh06dIj1ZQAAgDgR07g5f/68qqur5fP5vnnChAT5fD75/f5mz/H7/VHjJamwsDBqfDgc1gMPPKCpU6fqhhtu+M55NDQ0KBQKRW0AAMBOMY2b48ePq6mpSenp6VH709PTFQgEmj0nEAh85/jf//73SkpK0mOPPdaieVRUVMjtdke2zMzMS7wSAAAQL+Lu01LV1dVasGCBli1bJofD0aJzpk+frvr6+sh2+PDhGM8SAAC0l5jGTbdu3ZSYmKhgMBi1PxgMyuPxNHuOx+O56Ph3331XtbW1ysrKUlJSkpKSknTw4EFNmTJFffr0afYxnU6nXC5X1AYAAOwU07hJTk5WTk6OKisrI/vC4bAqKyvl9XqbPcfr9UaNl6SNGzdGxj/wwAP68MMPtWvXrsiWkZGhqVOn6q233ordxQAAgLiQFOsnKC8vV2lpqXJzc5WXl6f58+fr9OnTGjt2rCRpzJgx6tWrlyoqKiRJkyZN0vDhw/Xss89q5MiRWrVqlbZv366lS5dKktLS0pSWlhb1HB06dJDH49GPfvSjWF8OAAC4ysU8bkaPHq1jx45pzpw5CgQCGjJkiDZs2BB50/ChQ4eUkPDNDaRhw4Zp5cqVmjVrlmbMmKF+/fpp7dq1GjhwYKynCgAALOAwxpj2nkRbC4VCcrvdqq+v5/03AADEiZa+fsfdp6UAAAAuhrgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV
4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYJU2iZvFixerT58+SklJUX5+vrZu3XrR8WvWrFF2drZSUlI0aNAgrV+/PnKssbFRTz75pAYNGqRrrrlGGRkZGjNmjI4ePRrrywAAAHEg5nGzevVqlZeXa+7cudqxY4cGDx6swsJC1dbWNjt+y5YtKi4uVllZmXbu3KmioiIVFRVp9+7dkqQzZ85ox44dmj17tnbs2KHXXntNNTU1uuuuu2J9KQAAIA44jDEmlk+Qn5+vW265RYsWLZIkhcNhZWZmauLEiZo2bdq3xo8ePVqnT5/WunXrIvuGDh2qIUOGaMmSJc0+x7Zt25SXl6eDBw8qKyvrO+cUCoXkdrtVX18vl8t1mVcGAADaUktfv2N65+b8+fOqrq6Wz+f75gkTEuTz+eT3+5s9x+/3R42XpMLCwguOl6T6+no5HA6lpqY2e7yhoUGhUChqAwAAdopp3Bw/flxNTU1KT0+P2p+enq5AINDsOYFA4JLGnzt3Tk8++aSKi4svWHEVFRVyu92RLTMz8zKuBgAAxIO4/rRUY2OjfvGLX8gYoxdffPGC46ZPn676+vrIdvjw4TacJQAAaEtJsXzwbt26KTExUcFgMGp/MBiUx+Np9hyPx9Oi8V+HzcGDB7Vp06aL/u7N6XTK6XRe5lUAAIB4EtM7N8nJycrJyVFlZWVkXzgcVmVlpbxeb7PneL3eqPGStHHjxqjxX4fNgQMH9M9//lNpaWmxuQAAABB3YnrnRpLKy8tVWlqq3Nxc5eXlaf78+Tp9+rTGjh0rSRozZox69eqliooKSdKkSZM0fPhwPfvssxo5cqRWrVql7du3a+nSpZK+Cpt7771XO3bs0Lp169TU1BR5P07Xrl2VnJwc60sCAABXsZjHzejRo3Xs2DHNmTNHgUBAQ4YM0YYNGyJvGj506JASEr65gTRs2DCtXLlSs2bN0owZM9SvXz+tXbtWAwcOlCQdOXJEb7zxhiRpyJAhUc/19ttv68c//nGsLwkAAFzFYv49N1cjvucGAID4c1V8zw0AAEBbI24AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWKVN4mbx4sXq06ePUlJSlJ+fr61bt150/Jo1a5Sdna2UlBQNGjRI69evjzpujNGcOXPUs2dPdezYUT6fTwcOHIjlJQAAgDgR87hZvXq1ysvLNXfuXO3YsUODBw9WYWGhamtrmx2/ZcsWFRcXq6ysTDt37lRRUZGKioq0e/fuyJg//OEPWrhwoZYsWaKqqipdc801Kiws1Llz52J9OQAA4CrnMMaYWD5Bfn6+brnlFi1atEiSFA6HlZmZqYkT
J2ratGnfGj969GidPn1a69ati+wbOnSohgwZoiVLlsgYo4yMDE2ZMkW/+tWvJEn19fVKT0/XsmXLdN99933nnEKhkNxut+rr6+VyuVrpSr+6o3S2sanVHg8AgHjVsUOiHA5Hqz5mS1+/k1r1Wf/H+fPnVV1drenTp0f2JSQkyOfzye/3N3uO3+9XeXl51L7CwkKtXbtWkvTpp58qEAjI5/NFjrvdbuXn58vv9zcbNw0NDWpoaIj8HAqFruSyLuhsY5MGzHkrJo8NAEA82fubQnVKjmlmXFBMfy11/PhxNTU1KT09PWp/enq6AoFAs+cEAoGLjv/6z0t5zIqKCrnd7siWmZl5WdcDAACufu2TVG1s+vTpUXeDQqFQTAKnY4dE7f1NYas/LgAA8aZjh8R2e+6Yxk23bt2UmJioYDAYtT8YDMrj8TR7jsfjuej4r/8MBoPq2bNn1JghQ4Y0+5hOp1NOp/NyL6PFHA5Hu92CAwAAX4npr6WSk5OVk5OjysrKyL5wOKzKykp5vd5mz/F6vVHjJWnjxo2R8X379pXH44kaEwqFVFVVdcHHBAAA3x8xv81QXl6u0tJS5ebmKi8vT/Pnz9fp06c1duxYSdKYMWPUq1cvVVRUSJImTZqk4cOH69lnn9XIkSO1atUqbd++XUuXLpX01d2RyZMn67e//a369eunvn37avbs2crIyFBRUVGsLwcAAFzlYh43o0eP1rFjxzRnzhwFAgENGTJEGzZsiLwh+NChQ0pI+OYG0rBhw7Ry5UrNmjVLM2bMUL9+/bR27VoNHDgwMuaJJ57Q6dOnNW7cONXV1em2227Thg0blJKSEuvLAQAAV7mYf8/N1ShW33MDAABip6Wv3/zbUgAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrxCxuTpw4oZKSErlcLqWmpqqsrEynTp266Dnnzp3T+PHjlZaWps6dO2vUqFEKBoOR4x988IGKi4uVmZmpjh07qn///lqwYEGsLgEAAMShmMVNSUmJ9uzZo40bN2rdunV65513NG7cuIue8/jjj+vNN9/UmjVrtHnzZh09elT33HNP5Hh1dbV69Oih5cuXa8+ePZo5c6amT5+uRYsWxeoyAABAnHEYY0xrP+i+ffs0YMAAbdu2Tbm5uZKkDRs26M4779Rnn32mjIyMb51TX1+v7t27a+XKlbr33nslSfv371f//v3l9/s1dOjQZp9r/Pjx2rdvnzZt2tTi+YVCIbndbtXX18vlcl3GFQIAgLbW0tfvmNy58fv9Sk1NjYSNJPl8PiUkJKiqqqrZc6qrq9XY2CifzxfZl52draysLPn9/gs+V319vbp27dp6kwcAAHEtKRYPGggE1KNHj+gnSkpS165dFQgELnhOcnKyUlNTo/anp6df8JwtW7Zo9erV+vvf/37R+TQ0NKihoSHycygUasFVAACAeHRJd26mTZsmh8Nx0W3//v2xmmuU3bt36+6779bcuXP1k5/85KJjKyoq5Ha7I1tmZmabzBEAALS9S7pzM2XKFD344IMXHXPdddfJ4/GotrY2av+XX36pEydOyOPxNHuex+PR+fPnVVdXF3X3JhgMfuucvXv3qqCg
QOPGjdOsWbO+c97Tp09XeXl55OdQKETgAABgqUuKm+7du6t79+7fOc7r9aqurk7V1dXKycmRJG3atEnhcFj5+fnNnpOTk6MOHTqosrJSo0aNkiTV1NTo0KFD8nq9kXF79uzR7bffrtLSUs2bN69F83Y6nXI6nS0aCwAA4ltMPi0lSXfccYeCwaCWLFmixsZGjR07Vrm5uVq5cqUk6ciRIyooKNArr7yivLw8SdKjjz6q9evXa9myZXK5XJo4caKkr95bI331q6jbb79dhYWFeuaZZyLPlZiY2KLo+hqflgIAIP609PU7Jm8olqQVK1ZowoQJKigoUEJCgkaNGqWFCxdGjjc2NqqmpkZnzpyJ7Hv++ecjYxsaGlRYWKgXXnghcvzVV1/VsWPHtHz5ci1fvjyy/9prr9V//vOfWF0KAACIIzG7c3M1484NAADxp12/5wYAAKC9EDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArELcAAAAq8Qsbk6cOKGSkhK5XC6lpqaqrKxMp06duug5586d0/jx45WWlqbOnTtr1KhRCgaDzY794osv1Lt3bzkcDtXV1cXgCgAAQDyKWdyUlJRoz5492rhxo9atW6d33nlH48aNu+g5jz/+uN58802tWbNGmzdv1tGjR3XPPfc0O7asrEw33nhjLKYOAADimMMYY1r7Qfft26cBAwZo27Ztys3NlSRt2LBBd955pz777DNlZGR865z6+np1795dK1eu1L333itJ2r9/v/r37y+/36+hQ4dGxr744otavXq15syZo4KCAv33v/9Vampqi+cXCoXkdrtVX18vl8t1ZRcLAADaREtfv2Ny58bv9ys1NTUSNpLk8/mUkJCgqqqqZs+prq5WY2OjfD5fZF92draysrLk9/sj+/bu3avf/OY3euWVV5SQ0LLpNzQ0KBQKRW0AAMBOMYmbQCCgHj16RO1LSkpS165dFQgELnhOcnLyt+7ApKenR85paGhQcXGxnnnmGWVlZbV4PhUVFXK73ZEtMzPz0i4IAADEjUuKm2nTpsnhcFx0279/f6zmqunTp6t///66//77L/m8+vr6yHb48OEYzRAAALS3pEsZPGXKFD344IMXHXPdddfJ4/GotrY2av+XX36pEydOyOPxNHuex+PR+fPnVVdXF3X3JhgMRs7ZtGmTPvroI7366quSpK/fLtStWzfNnDlTv/71r5t9bKfTKafT2ZJLBAAAce6S4qZ79+7q3r37d47zer2qq6tTdXW1cnJyJH0VJuFwWPn5+c2ek5OTow4dOqiyslKjRo2SJNXU1OjQoUPyer2SpL/+9a86e/Zs5Jxt27bpoYce0rvvvqvrr7/+Ui4FAABY6pLipqX69++vESNG6OGHH9aSJUvU2NioCRMm6L777ot8UurIkSMqKCjQK6+8ory8PLndbpWVlam8vFxdu3aVy+XSxIkT5fV6I5+U+t+AOX78eOT5LuXTUgAAwF4xiRtJWrFihSZMmKCCggIlJCRo1KhRWrhwYeR4Y2OjampqdObMmci+559/PjK2oaFBhYWFeuGFF2I1RQAAYKGYfM/N1Y7vuQEAIP606/fcAAAAtBfiBgAAWIW4AQAA
ViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYhbgBAABWIW4AAIBViBsAAGAV4gYAAFiFuAEAAFYhbgAAgFWIGwAAYBXiBgAAWIW4AQAAViFuAACAVYgbAABgFeIGAABYJam9J9AejDGSpFAo1M4zAQAALfX16/bXr+MX8r2Mm5MnT0qSMjMz23kmAADgUp08eVJut/uCxx3mu/LHQuFwWEePHlWXLl3kcDha9bFDoZAyMzN1+PBhuVyuVn1sRGOt2w5r3XZY67bDWred1lprY4xOnjypjIwMJSRc+J0138s7NwkJCerdu3dMn8PlcvE/SxthrdsOa912WOu2w1q3ndZY64vdsfkabygGAABWIW4AAIBViJtW5nQ6NXfuXDmdzvaeivVY67bDWrcd1rrtsNZtp63X+nv5hmIAAGAv7twAAACrEDcAAMAqxA0AALAKcQMAAKxC3LSixYsXq0+fPkpJSVF+fr62bt3a3lOKexUVFbrlllvUpUsX9ejRQ0VFRaqpqYkac+7cOY0fP15paWnq3LmzRo0apWAw2E4ztsfTTz8th8OhyZMnR/ax1q3nyJEjuv/++5WWlqaOHTtq0KBB2r59e+S4MUZz5sxRz5491bFjR/l8Ph04cKAdZxyfmpqaNHv2bPXt21cdO3bU9ddfr6eeeirq3yZirS/PO++8o5/+9KfKyMiQw+HQ2rVro463ZF1PnDihkpISuVwupaamqqysTKdOnbryyRm0ilWrVpnk5GTz0ksvmT179piHH37YpKammmAw2N5Ti2uFhYXm5ZdfNrt37za7du0yd955p8nKyjKnTp2KjHnkkUdMZmamqaysNNu3bzdDhw41w4YNa8dZx7+tW7eaPn36mBtvvNFMmjQpsp+1bh0nTpww1157rXnwwQdNVVWV+eSTT8xbb71l/v3vf0fGPP3008btdpu1a9eaDz74wNx1112mb9++5uzZs+048/gzb948k5aWZtatW2c+/fRTs2bNGtO5c2ezYMGCyBjW+vKsX7/ezJw507z22mtGknn99dejjrdkXUeMGGEGDx5s3n//ffPuu++aH/7wh6a4uPiK50bctJK8vDwzfvz4yM9NTU0mIyPDVFRUtOOs7FNbW2skmc2bNxtjjKmrqzMdOnQwa9asiYzZt2+fkWT8fn97TTOunTx50vTr189s3LjRDB8+PBI3rHXrefLJJ81tt912wePhcNh4PB7zzDPPRPbV1dUZp9Np/vznP7fFFK0xcuRI89BDD0Xtu+eee0xJSYkxhrVuLf8bNy1Z17179xpJZtu2bZEx//jHP4zD4TBHjhy5ovnwa6lWcP78eVVXV8vn80X2JSQkyOfzye/3t+PM7FNfXy9J6tq1qySpurpajY2NUWufnZ2trKws1v4yjR8/XiNHjoxaU4m1bk1vvPGGcnNz9fOf/1w9evTQTTfdpD/96U+R459++qkCgUDUWrvdbuXn57PWl2jYsGGqrKzUxx9/LEn64IMP9N577+mOO+6QxFrHSkvW1e/3KzU1Vbm5uZExPp9PCQkJqqqquqLn/17+w5mt7fjx42pqalJ6enrU/vT0dO3fv7+dZmWfcDisyZMn69Zbb9XAgQMlSYFAQMnJyUpNTY0am56erkAg0A6zjG+rVq3Sjh07tG3btm8dY61bzyeffKIXX3xR5eXlmjFjhrZt26bHHntMycnJKi0tjaxnc3+nsNaXZtq0aQqFQsrOzlZiYqKampo0b948lZSUSBJrHSMtWddAIKAePXpEHU9KSlLXrl2veO2JG8SN8ePHa/fu3XrvvffaeypWOnz4sCZNmqSNGzcqJSWl
vadjtXA4rNzcXP3ud7+TJN10003avXu3lixZotLS0naenV3+8pe/aMWKFVq5cqVuuOEG7dq1S5MnT1ZGRgZrbTF+LdUKunXrpsTExG99aiQYDMrj8bTTrOwyYcIErVu3Tm+//bZ69+4d2e/xeHT+/HnV1dVFjWftL111dbVqa2t18803KykpSUlJSdq8ebMWLlyopKQkpaens9atpGfPnhowYEDUvv79++vQoUOSFFlP/k65clOnTtW0adN03333adCgQXrggQf0+OOPq6KiQhJrHSstWVePx6Pa2tqo419++aVOnDhxxWtP3LSC5ORk5eTkqLKyMrIvHA6rsrJSXq+3HWcW/4wxmjBhgl5//XVt2rRJffv2jTqek5OjDh06RK19TU2NDh06xNpfooKCAn300UfatWtXZMvNzVVJSUnkv1nr1nHrrbd+6ysNPv74Y1177bWSpL59+8rj8UStdSgUUlVVFWt9ic6cOaOEhOiXusTERIXDYUmsday0ZF29Xq/q6upUXV0dGbNp0yaFw2Hl5+df2QSu6O3IiFi1apVxOp1m2bJlZu/evWbcuHEmNTXVBAKB9p5aXHv00UeN2+02//rXv8znn38e2c6cORMZ88gjj5isrCyzadMms337duP1eo3X623HWdvj/39ayhjWurVs3brVJCUlmXnz5pkDBw6YFStWmE6dOpnly5dHxjz99NMmNTXV/O1vfzMffvihufvuu/l48mUoLS01vXr1inwU/LXXXjPdunUzTzzxRGQMa315Tp48aXbu3Gl27txpJJnnnnvO7Ny50xw8eNAY07J1HTFihLnppptMVVWVee+990y/fv34KPjV5o9//KPJysoyycnJJi8vz7z//vvtPaW4J6nZ7eWXX46MOXv2rPnlL39pfvCDH5hOnTqZn/3sZ+bzzz9vv0lb5H/jhrVuPW+++aYZOHCgcTqdJjs72yxdujTqeDgcNrNnzzbp6enG6XSagoICU1NT006zjV+hUMhMmjTJZGVlmZSUFHPdddeZmTNnmoaGhsgY1vryvP32283+/VxaWmqMadm6fvHFF6a4uNh07tzZuFwuM3bsWHPy5MkrnpvDmP/3NY0AAABxjvfcAAAAqxA3AADAKsQNAACwCnEDAACsQtwAAACrEDcAAMAqxA0AALAKcQMAAKxC3AAAAKsQNwAAwCrEDQAAsApxAwAArPJ/s6un4nY2lk4AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Line plot of test_average_cost against its positional index. The figure\n",
    "# is shown explicitly so the cell does not emit a Line2D repr as text output.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(test_average_cost)\n",
    "ax.set(title='test_average_cost', xlabel='index', ylabel='value')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "92718187",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Position (in the flattened data) of the first maximum value of\n",
    "# test_cost_per_transition.\n",
    "np.asarray(test_cost_per_transition).argmax()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
