{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4bcfc9fd-c1e1-4d83-9fd2-0949976723b7",
   "metadata": {
    "papermill": {
     "duration": 0.014672,
     "end_time": "2025-09-21T11:34:05.972376",
     "exception": false,
     "start_time": "2025-09-21T11:34:05.957704",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Conformal Robustness Control"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ecc9092-f73d-414f-814a-a77e4cd33bbc",
   "metadata": {},
   "source": [
    "### CRC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "027da74f-f82a-4655-bfbf-94fcbaa6fb97",
   "metadata": {
    "papermill": {
     "duration": 4.37205,
     "end_time": "2025-09-21T11:34:10.398546",
     "exception": false,
     "start_time": "2025-09-21T11:34:06.026496",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import os\n",
    "import io\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import pytz\n",
    "import math\n",
    "import cvxpy as cp\n",
    "import rsome as rso\n",
    "\n",
    "from rsome import ro\n",
    "from rsome import msk_solver as SOLVER\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.distributions import MultivariateNormal\n",
    "from cvxpylayers.torch import CvxpyLayer\n",
    "from typing import Protocol\n",
    "from typing import NamedTuple\n",
    "from torch import distributions as tdist, nn, Tensor\n",
    "from collections.abc import Mapping\n",
    "from torch import Tensor\n",
    "from collections.abc import Sequence\n",
    "from torch.utils.data import TensorDataset\n",
    "from collections.abc import Mapping, Sequence\n",
    "from pandas.tseries.holiday import USFederalHolidayCalendar\n",
    "from datetime import datetime, timedelta\n",
    "from tqdm.auto import tqdm\n",
    "from typing import Any"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6277c9ff-e259-4c38-8adb-c89d220c6a88",
   "metadata": {
    "papermill": {
     "duration": 0.031196,
     "end_time": "2025-09-21T11:34:10.452187",
     "exception": false,
     "start_time": "2025-09-21T11:34:10.420991",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "seed = random.randint(0, 99)\n",
    "print(\"Device:\", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f8ca6a-17af-49fc-b3ea-7d20e26605dc",
   "metadata": {
    "papermill": {
     "duration": 0.026759,
     "end_time": "2025-09-21T11:34:10.499888",
     "exception": false,
     "start_time": "2025-09-21T11:34:10.473129",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "tz = pytz.timezone('America/New_York')\n",
    "df = pd.read_csv('data.csv', parse_dates=[0])\n",
    "display(df)\n",
    "df['log_da_price'] = np.log(df['da_price'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6427c656-414d-40da-918f-a97b50dd1390",
   "metadata": {
    "papermill": {
     "duration": 0.1398,
     "end_time": "2025-09-21T11:34:10.664722",
     "exception": false,
     "start_time": "2025-09-21T11:34:10.524922",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def validate_data(df: pd.DataFrame) -> None:\n",
    "    expected_dts = pd.date_range(datetime(2011, 1, 3, 0, 0, 0), datetime(2016, 12, 31, 23, 0, 0), freq=timedelta(hours=1))\n",
    "    actual_dts = set(df['datetime'])\n",
    "    assert len(expected_dts.difference(actual_dts)) == 0\n",
    "    for col in ('datetime', 'da_price', 'load_forecast'):\n",
    "        assert df[col].isna().sum() == 0\n",
    "\n",
    "validate_data(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6e0a796-39bc-4e5e-9247-33c904640934",
   "metadata": {
    "papermill": {
     "duration": 0.024836,
     "end_time": "2025-09-21T11:34:10.711035",
     "exception": false,
     "start_time": "2025-09-21T11:34:10.686199",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "df['date'] = df['datetime'].dt.date\n",
    "df['hour'] = df['datetime'].dt.hour\n",
    "df_prices = df.pivot(index='date', columns='hour', values='da_price')\n",
    "df_logprices = df.pivot(index='date', columns='hour', values='log_da_price')\n",
    "df_load = df.pivot(index='date', columns='hour', values='load_forecast')\n",
    "df_temp = df.pivot(index='date', columns='hour', values='temp_dca')\n",
    "df_temp = df_temp.transpose().bfill().ffill().transpose()\n",
    "assert df_logprices.index.equals(df_load.index)\n",
    "assert df_logprices.index.equals(df_temp.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "223d9327-9351-4a2c-bf56-a882fa574b4c",
   "metadata": {
    "papermill": {
     "duration": 0.214594,
     "end_time": "2025-09-21T11:34:10.944320",
     "exception": false,
     "start_time": "2025-09-21T11:34:10.729726",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_dates = df_logprices.index.to_series()\n",
    "\n",
    "holidays = USFederalHolidayCalendar().holidays(\n",
    "    start='2011-01-01', end='2017-01-01').date\n",
    "\n",
    "df_feat = pd.DataFrame({\n",
    "    'weekend': df_dates.map(lambda x: x.isoweekday() >= 6),\n",
    "    'holiday': df_dates.isin(holidays),\n",
    "    'dst': df_dates.map(\n",
    "        lambda x: tz.localize(datetime.combine(x, datetime.min.time())).dst().seconds > 0\n",
    "    ),\n",
    "    \"cos_doy\": df_dates.map(lambda x: np.cos(x.timetuple().tm_yday/365*2*np.pi)),\n",
    "    \"sin_doy\": df_dates.map(lambda x: np.sin(x.timetuple().tm_yday/365*2*np.pi))\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45f8b903-c339-4d5d-9b9f-2bd01b6ed6a3",
   "metadata": {
    "papermill": {
     "duration": 0.045006,
     "end_time": "2025-09-21T11:34:11.013776",
     "exception": false,
     "start_time": "2025-09-21T11:34:10.968770",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "X = np.hstack([df_logprices.iloc[:-1].values, \n",
    "               df_load.iloc[1:].values,        \n",
    "               df_temp.iloc[:-1].values,       \n",
    "               df_temp.iloc[1:].values,        \n",
    "               df_feat.iloc[1:].values]).astype(np.float64)\n",
    "X_mean = np.mean(X, axis=0)\n",
    "X_std = np.std(X, axis=0)\n",
    "X = (X - X_mean) / X_std\n",
    "Y = df_prices.iloc[1:].values\n",
    "\n",
    "dates = np.array(df_prices.iloc[1:].index).astype('datetime64[D]')\n",
    "np.savez_compressed(f'data.npz', X=X, Y=Y, X_mean=X_mean, X_std=X_std, dates=dates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f59bba5-d078-4388-9c05-a1aa16f218a3",
   "metadata": {
    "papermill": {
     "duration": 0.062589,
     "end_time": "2025-09-21T11:34:11.095554",
     "exception": false,
     "start_time": "2025-09-21T11:34:11.032965",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "TRAIN_CALIB_FRAC = 0.8 \n",
    "TRAIN_FRAC = 0.3        \n",
    "\n",
    "def load_data() -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n",
    "    with np.load(f'data.npz') as npz:\n",
    "        X = npz['X']\n",
    "        Y = npz['Y']\n",
    "        dates = npz['dates']\n",
    "    assert X.shape == (2189, 101)\n",
    "    assert Y.shape == (2189, 24)\n",
    "    return X.astype(np.float32), Y.astype(np.float32), dates.astype('datetime64[D]')\n",
    "\n",
    "def get_traincalib_test_split(\n",
    "    X: np.ndarray, Y: np.ndarray, dates: np.ndarray, shuffle: bool = False\n",
    ") -> dict[str, Tensor | np.ndarray]:\n",
    "   \n",
    "    if shuffle:\n",
    "        rng = np.random.default_rng(seed=seed)\n",
    "        inds = rng.permutation(len(X))\n",
    "        X = X[inds]\n",
    "        Y = Y[inds]\n",
    "        dates = dates[inds]\n",
    "\n",
    "    n = int(X.shape[0] * TRAIN_CALIB_FRAC)\n",
    "    tensors: dict[str, Tensor | np.ndarray] = {\n",
    "        'X_traincalib': torch.from_numpy(X[:n]),\n",
    "        'Y_traincalib': torch.from_numpy(Y[:n]),\n",
    "        'X_test': torch.from_numpy(X[n:]),\n",
    "        'Y_test': torch.from_numpy(Y[n:]),\n",
    "        'date_traincalib': dates[:n],\n",
    "        'date_test': dates[n:]\n",
    "    }\n",
    "    return tensors\n",
    "\n",
    "\n",
    "def get_train_calib_split(\n",
    "    tensors: Mapping[str, Tensor | np.ndarray], seed: int\n",
    ") -> tuple[dict[str, Tensor], dict[str, np.ndarray]]:\n",
    "  \n",
    "    X_traincalib: Tensor = tensors['X_traincalib']  \n",
    "    Y_traincalib: Tensor = tensors['Y_traincalib']  \n",
    "    dates_traincalib: np.ndarray = tensors['date_traincalib']  \n",
    "\n",
    "    rng = np.random.default_rng(seed=seed)\n",
    "    inds = rng.permutation(len(X_traincalib))\n",
    "\n",
    "    n_train = int(len(X_traincalib) * TRAIN_FRAC)\n",
    "    train_inds = inds[:n_train]\n",
    "    calib_inds = inds[n_train:]\n",
    "\n",
    "    new_tensors: dict[str, Tensor] = {\n",
    "        'X_train': X_traincalib[train_inds],\n",
    "        'Y_train': Y_traincalib[train_inds],\n",
    "        'X_calib': X_traincalib[calib_inds],\n",
    "        'Y_calib': Y_traincalib[calib_inds],\n",
    "        'X_test': tensors['X_test'],  \n",
    "        'Y_test': tensors['Y_test'],  \n",
    "    }\n",
    "    new_dates: dict[str, np.ndarray] = {\n",
    "        'train': dates_traincalib[train_inds],\n",
    "        'calib': dates_traincalib[calib_inds],\n",
    "        'test': tensors['date_test'],  \n",
    "    }\n",
    "    return new_tensors, new_dates\n",
    "\n",
    "def get_loaders(tensors: Mapping[str, Tensor], batch_size: int) -> dict[str, DataLoader]:\n",
    "   \n",
    "    shuffle = True\n",
    "    if batch_size == -1:\n",
    "        batch_size = len(tensors['X_train'])\n",
    "        shuffle = False\n",
    "    train_loader = DataLoader(\n",
    "        TensorDataset(tensors['X_train'], tensors['Y_train']),\n",
    "        shuffle=shuffle, batch_size=batch_size, pin_memory=True)\n",
    "\n",
    "    num_test = len(tensors['X_test'])\n",
    "    test_loader = DataLoader(\n",
    "        TensorDataset(tensors['X_test'], tensors['Y_test']),\n",
    "        shuffle=False, batch_size=num_test, pin_memory=True)\n",
    "\n",
    "    num_calib = len(tensors['X_calib'])\n",
    "    calib_loader = DataLoader(\n",
    "        TensorDataset(tensors['X_calib'], tensors['Y_calib']),\n",
    "        shuffle=False, batch_size=num_calib, pin_memory=True)\n",
    "    return {'train': train_loader, 'test': test_loader, 'calib': calib_loader}\n",
    "\n",
    "def get_tensors(\n",
    "    shuffle: bool, log_prices: bool\n",
    ") -> tuple[dict[str, Tensor | np.ndarray], str | tuple[np.ndarray, np.ndarray]]:\n",
    "   \n",
    "    X, Y, dates = load_data()\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray]\n",
    "    if log_prices:\n",
    "        Y = np.log(Y)\n",
    "        y_info = 'log'\n",
    "    else:\n",
    "        Y_mean = Y.mean(axis=0)\n",
    "        Y_std = Y.std(axis=0)\n",
    "        Y_1 = (Y - Y_mean) / Y_std\n",
    "        y_info = (Y_mean, Y_std)\n",
    "    tensors = get_traincalib_test_split(X, Y, dates, shuffle=shuffle)\n",
    "    return tensors, y_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da945b63-5b05-4ffd-a167-f0930e0d7e2d",
   "metadata": {
    "papermill": {
     "duration": 0.03975,
     "end_time": "2025-09-21T11:34:11.158961",
     "exception": false,
     "start_time": "2025-09-21T11:34:11.119211",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "tensors_tc, y_info = get_tensors(shuffle=True, log_prices=False)\n",
    "tensors, dates = get_train_calib_split(tensors_tc, seed=seed)\n",
    "\n",
    "X_train, Y_train = tensors['X_train'], tensors['Y_train']\n",
    "X_cal,   Y_cal   = tensors['X_calib'], tensors['Y_calib']   \n",
    "X_test,  Y_test  = tensors['X_test'],  tensors['Y_test']\n",
    "loaders = get_loaders(tensors, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28f4f3b8-2996-4058-80b6-da0727156fc8",
   "metadata": {
    "papermill": {
     "duration": 0.029719,
     "end_time": "2025-09-21T11:34:11.782863",
     "exception": false,
     "start_time": "2025-09-21T11:34:11.753144",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "from cvxpylayers.torch import CvxpyLayer\n",
    "\n",
    "def create_storage_box_layer(\n",
    "    T: int,\n",
    "    const,                         \n",
    "    simplex: bool = False      \n",
    ") -> CvxpyLayer:\n",
    "\n",
    "    lam, eps, eff, c_in, c_out, B = const\n",
    "    z_state = cp.Variable(T + 1, name='z_state', nonneg=True)\n",
    "    z_in    = cp.Variable(T,     name='z_in',    nonneg=True)\n",
    "    z_out   = cp.Variable(T,     name='z_out',   nonneg=True)\n",
    "    v       = cp.Variable(T,     name='v',       nonneg=True)  \n",
    "\n",
    "    y_lo = cp.Parameter(T, name='y_lo')   \n",
    "    y_hi = cp.Parameter(T, name='y_hi')  \n",
    "\n",
    "    cons = [\n",
    "        z_state[0] == B/2,\n",
    "        z_state[1:] == z_state[:-1] + eff * z_in - z_out,  \n",
    "        z_state <= B,                                      \n",
    "        z_in  <= c_in,\n",
    "        z_out <= c_out,\n",
    "    ]\n",
    "\n",
    "    c_vec = z_in - z_out                                  \n",
    "    cons += [v - c_vec >= 0]\n",
    "\n",
    "    box_dual_obj = (y_hi - y_lo) @ v + y_lo @ c_vec\n",
    "    reg_obj = (\n",
    "        lam * cp.sum_squares(z_state - B/2) +\n",
    "        eps * cp.sum_squares(z_in) +\n",
    "        eps * cp.sum_squares(z_out)\n",
    "    )\n",
    "    obj = cp.Minimize(box_dual_obj + reg_obj)\n",
    "    prob = cp.Problem(obj, cons)\n",
    "    assert prob.is_dpp('dcp')\n",
    "\n",
    "    return CvxpyLayer(prob, parameters=[y_lo, y_hi],\n",
    "                      variables=[z_in, z_out, z_state, v])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2405853-3b8c-42a2-a342-c70bd02d0663",
   "metadata": {
    "papermill": {
     "duration": 0.032388,
     "end_time": "2025-09-21T11:34:11.834237",
     "exception": false,
     "start_time": "2025-09-21T11:34:11.801849",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class BoxQuantileModel(nn.Module):\n",
    "  \n",
    "    def __init__(self, input_dim, hidden_dims, n=24, dropout_rate=0.5, enable_implicit=True):\n",
    "        super().__init__()\n",
    "        self.n = n                \n",
    "        self.enable_implicit = enable_implicit\n",
    "\n",
    "        if isinstance(hidden_dims, int):\n",
    "            hidden_dims = [hidden_dims] * 3\n",
    "\n",
    "        layers, in_dim = [], input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers += [\n",
    "                nn.Linear(in_dim, h),\n",
    "                nn.BatchNorm1d(h),\n",
    "                nn.ReLU(),\n",
    "                nn.Dropout(dropout_rate),\n",
    "            ]\n",
    "            in_dim = h\n",
    "        self.backbone   = nn.Sequential(*layers)\n",
    "        self.head_lo    = nn.Linear(in_dim, n)\n",
    "        self.head_delta = nn.Linear(in_dim, n)\n",
    "\n",
    "    def forward(self, x, box_layer=None, q_hat=None):\n",
    "    \n",
    "        B = x.size(0)\n",
    "        h = self.backbone(x)\n",
    "\n",
    "        y_lo  = self.head_lo(h)                          \n",
    "        delta = F.softplus(self.head_delta(h), beta=50)  \n",
    "        y_hi  = y_lo + delta\n",
    "\n",
    "        if q_hat is not None:\n",
    "            if not torch.is_tensor(q_hat):\n",
    "                q_hat = torch.tensor(q_hat, dtype=y_lo.dtype, device=x.device)\n",
    "            assert q_hat.dim() == 0, \"q Error\"\n",
    "            y_lo_adj = y_lo - q_hat\n",
    "            y_hi_adj = y_hi + q_hat\n",
    "        else:\n",
    "            y_lo_adj, y_hi_adj = y_lo, y_hi\n",
    "\n",
    "        y_hi_adj = torch.maximum(y_hi_adj, y_lo_adj + 1e-8)\n",
    "\n",
    "        if (not self.enable_implicit) or (box_layer is None):\n",
    "            return y_lo_adj, y_hi_adj, None\n",
    "\n",
    "        z_in, z_out, z_state, v = box_layer(y_lo_adj.detach().cpu(), y_hi_adj.detach().cpu())\n",
    "\n",
    "        def _squeeze_last(t):\n",
    "            return t.squeeze(-1) if (t.dim() == 3 and t.size(-1) == 1) else t\n",
    "\n",
    "        z_in    = _squeeze_last(z_in   ).to(x.device, dtype=y_lo_adj.dtype)\n",
    "        z_out   = _squeeze_last(z_out  ).to(x.device, dtype=y_lo_adj.dtype)\n",
    "        z_state = _squeeze_last(z_state).to(x.device, dtype=y_lo_adj.dtype)\n",
    "        v       = _squeeze_last(v      ).to(x.device, dtype=y_lo_adj.dtype)\n",
    "\n",
    "        return y_lo_adj, y_hi_adj, (z_in, z_out, z_state, v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aedf628-1d5b-46de-8583-fbab96c88e5c",
   "metadata": {
    "papermill": {
     "duration": 0.027393,
     "end_time": "2025-09-21T11:34:11.889062",
     "exception": false,
     "start_time": "2025-09-21T11:34:11.861669",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class PinballLoss(nn.Module):\n",
    "    def __init__(self, q: float):\n",
    "        super().__init__()\n",
    "        self.q = q\n",
    "    def forward(self, pred, target):\n",
    "        e = target - pred               \n",
    "        loss = torch.maximum(self.q*e, (self.q-1)*e)  \n",
    "        return loss.sum(dim=-1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f97d7a30-60e4-47c5-b8e9-b3a7b2158e40",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BoxQuantileModel(input_dim=101, hidden_dims=[20,20], n=24, dropout_rate = 0.3, enable_implicit=False)\n",
    "device = torch.device(\"cpu\")\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay = 1e-3) #1e-3\n",
    "\n",
    "alpha  = 0.1\n",
    "\n",
    "q_lo   = alpha/2.0\n",
    "q_hi   = 1.0 - alpha/2.0\n",
    "loss_lo = PinballLoss(q_lo)\n",
    "loss_hi = PinballLoss(q_hi)\n",
    "\n",
    "for epoch in range(1, 300+1):\n",
    "    model.train()\n",
    "    total = 0.0\n",
    "    for x_cpu, y_cpu in loaders['train']:\n",
    "        x = x_cpu.to(device, non_blocking=True)\n",
    "        y = y_cpu.to(device, non_blocking=True)\n",
    "\n",
    "        y_lo_pred, y_hi_pred, _ = model(x, box_layer=None, q_hat=None) \n",
    "        y_hi_pred = torch.maximum(y_hi_pred, y_lo_pred + 1e-8)\n",
    "\n",
    "        loss = loss_lo(y_lo_pred, y) + loss_hi(y_hi_pred, y)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total += loss.item() * x.size(0)\n",
    "\n",
    "    avg = total / len(loaders['train'].dataset)\n",
    "    print(f\"Epoch {epoch:>3d} | Quantile Loss: {avg:.6f}\")\n",
    "\n",
    "os.makedirs('outputs/model/baseline', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "torch.save(model.state_dict(), save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80e65366-7e58-42bf-b769-0ccd3c5258bf",
   "metadata": {
    "papermill": {
     "duration": 0.046449,
     "end_time": "2025-09-21T11:34:30.022331",
     "exception": false,
     "start_time": "2025-09-21T11:34:29.975882",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def update_theta(\n",
    "    model,\n",
    "    storage_box_layer,         \n",
    "    opt_h,\n",
    "    x, y,\n",
    "    lambda_param,\n",
    "    alpha,\n",
    "    sigma,\n",
    "    batt_consts,              \n",
    "    q_hat=None                 \n",
    "):\n",
    "    lam_batt, eps_batt, B = batt_consts\n",
    "\n",
    "    y_lo, y_hi, sol = model(x, box_layer=storage_box_layer, q_hat=q_hat)\n",
    "    y_hi = torch.maximum(y_hi, y_lo + 1e-8)\n",
    "\n",
    "    assert isinstance(sol, (tuple, list)) and len(sol) == 4, \"expect (z_in,z_out,z_state,v)\"\n",
    "    z_in, z_out, z_state, v = sol\n",
    "    c_vec = z_in - z_out  \n",
    "\n",
    "    box_dual = (v * (y_hi - y_lo) + y_lo * c_vec).sum(dim=1)  # [B]\n",
    "    reg_term = (\n",
    "        lam_batt * ((z_state - B/2.0) ** 2).sum(dim=1) +\n",
    "        eps_batt * (z_in ** 2).sum(dim=1) +\n",
    "        eps_batt * (z_out ** 2).sum(dim=1))  \n",
    "    r = box_dual + reg_term  \n",
    "\n",
    "    s = (y * c_vec).sum(dim=1) + reg_term  # [B]\n",
    "\n",
    "    approx_ind = 0.5 * (1.0 + torch.erf((r - s) / (sigma * math.sqrt(2.0))))\n",
    "\n",
    "    lam = lambda_param.detach()\n",
    "    g = lam * ((1.0 - alpha) - approx_ind.mean())\n",
    "\n",
    "    loss = r.mean() + g\n",
    "    opt_h.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_h.step()\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "def update_lambda(\n",
    "    model,\n",
    "    storage_box_layer,\n",
    "    opt_l,\n",
    "    x, y,\n",
    "    lambda_param,\n",
    "    alpha,\n",
    "    batt_consts,\n",
    "    q_hat=None\n",
    "):\n",
    "    lam_batt, eps_batt, B = batt_consts\n",
    "    device = x.device\n",
    "\n",
    "    with torch.no_grad():\n",
    "        y_lo_det, y_hi_det, _ = model(x, box_layer=None, q_hat=q_hat)\n",
    "        y_hi_det = torch.maximum(y_hi_det, y_lo_det + 1e-8)\n",
    "\n",
    "    z_in, z_out, z_state, v = storage_box_layer(y_lo_det.detach().cpu(), y_hi_det.detach().cpu())\n",
    "    \n",
    "    def _sq(t):  \n",
    "        return t.squeeze(-1) if (t.dim() == 3 and t.size(-1) == 1) else t\n",
    "        \n",
    "    z_in, z_out, z_state, v = map(_sq, (z_in, z_out, z_state, v))\n",
    "    z_in, z_out, z_state, v = (t.to(device, dtype=y_lo_det.dtype) for t in (z_in, z_out, z_state, v))\n",
    "\n",
    "    c_vec = z_in - z_out  \n",
    "\n",
    "    box_dual = (v * (y_hi_det - y_lo_det) + y_lo_det * c_vec).sum(dim=1)  # [B]\n",
    "    reg_term = (\n",
    "        lam_batt * ((z_state - B/2.0) ** 2).sum(dim=1) +\n",
    "        eps_batt * (z_in ** 2).sum(dim=1) +\n",
    "        eps_batt * (z_out ** 2).sum(dim=1)\n",
    "    )\n",
    "    r = box_dual + reg_term  # [B]\n",
    "\n",
    "    s = (y * c_vec).sum(dim=1) + reg_term # [B]\n",
    "    indicator = (s <= r).float()\n",
    "\n",
    "    obj = r.mean() + lambda_param * ((1.0 - alpha) - indicator.mean())\n",
    "    loss = -obj\n",
    "    opt_l.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_l.step()\n",
    "    with torch.no_grad():\n",
    "        lambda_param.clamp_(min=0.0)\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "def total_loss(\n",
    "    model,\n",
    "    storage_box_layer,\n",
    "    x, y,\n",
    "    lambda_param,\n",
    "    alpha,\n",
    "    batt_consts,\n",
    "    q_hat=None,\n",
    "    use_hard_indicator=True,\n",
    "    sigma: float = 0.1\n",
    "):\n",
    "    lam_batt, eps_batt, B = batt_consts\n",
    "\n",
    "    with torch.no_grad():\n",
    "        y_lo, y_hi, sol = model(x, box_layer=storage_box_layer, q_hat=q_hat)\n",
    "        y_hi = torch.maximum(y_hi, y_lo + 1e-8)\n",
    "\n",
    "        assert isinstance(sol, (tuple, list)) and len(sol) == 4\n",
    "        z_in, z_out, z_state, v = sol\n",
    "        c_vec = z_in - z_out\n",
    "\n",
    "        box_dual = (v * (y_hi - y_lo) + y_lo * c_vec).sum(dim=1)\n",
    "        reg_term = (\n",
    "            lam_batt * ((z_state - B/2.0) ** 2).sum(dim=1) +\n",
    "            eps_batt * (z_in ** 2).sum(dim=1) +\n",
    "            eps_batt * (z_out ** 2).sum(dim=1)\n",
    "        )\n",
    "        r = box_dual + reg_term\n",
    "\n",
    "        s = (y * c_vec).sum(dim=1) + reg_term\n",
    "\n",
    "        if use_hard_indicator:\n",
    "            indicator = (s <= r).float()\n",
    "        else:\n",
    "            indicator = 0.5 * (1.0 + torch.erf((r - s) / (sigma * math.sqrt(2.0))))\n",
    "\n",
    "        f = r.mean()\n",
    "        g = lambda_param * ((1.0 - alpha) - indicator.mean())\n",
    "        return (f + g).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61316057-e795-423d-80aa-77441fd09bc4",
   "metadata": {
    "papermill": {
     "duration": 0.058295,
     "end_time": "2025-09-21T11:34:30.116866",
     "exception": false,
     "start_time": "2025-09-21T11:34:30.058571",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def compute_q_hat_box_full(\n",
    "    model,                 \n",
    "    X_cal: torch.Tensor,   \n",
    "    Y_cal: torch.Tensor,   \n",
    "    alpha: float,\n",
    "    device: torch.device\n",
    ") -> torch.Tensor:\n",
    "   \n",
    "    model.eval()\n",
    "    x = X_cal.to(device, non_blocking=True)\n",
    "    y = Y_cal.to(device, non_blocking=True)\n",
    "\n",
    "    y_lo, y_hi, _ = model(x, box_layer=None, q_hat=None)  \n",
    "    y_hi = torch.maximum(y_hi, y_lo + 1e-8)\n",
    "\n",
    "    s_left  = (y_lo - y).amax(dim=1)   # [N]\n",
    "    s_right = (y - y_hi).amax(dim=1)   # [N]\n",
    "    scores  = torch.maximum(s_left, s_right)   # [N]\n",
    "\n",
    "    q_hat = torch.quantile(scores, 1.0 - alpha)\n",
    "    q_hat = torch.clamp(q_hat, min=0.0)\n",
    "    return q_hat.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a40b2560-b297-4226-b21b-e02a9e45fb8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "T = 24  \n",
    "n = T\n",
    "\n",
    "lam_batt = 0.1\n",
    "eps_batt = 0.05\n",
    "eff      = 0.9\n",
    "c_in     = 0.5\n",
    "c_out    = 0.2\n",
    "B        = 1.0\n",
    "batt_consts = (lam_batt, eps_batt, B)\n",
    "\n",
    "cvxpylayer = create_storage_box_layer(\n",
    "    T=T,\n",
    "    const=(lam_batt, eps_batt, eff, c_in, c_out, B),   \n",
    "    simplex=False)\n",
    "\n",
    "model = BoxQuantileModel(\n",
    "    input_dim=101, hidden_dims=[20,20], n=n,\n",
    "    enable_implicit=True,\n",
    "    dropout_rate = 0.2,\n",
    ").to(device)\n",
    "\n",
    "model_path = os.path.join('outputs', 'model', 'baseline', f'baseline_model.pth')\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "\n",
    "opt_h = torch.optim.Adam(model.parameters(), lr=5e-4)\n",
    "lambda_param = nn.Parameter(torch.tensor(7.0, device=device), requires_grad=True)\n",
    "opt_l = torch.optim.Adam([lambda_param], lr=1e-3)\n",
    "\n",
    "N = X_cal.shape[0]\n",
    "n = int(N/2)\n",
    "X_cal_1 = X_cal[:n];    Y_cal_1 = Y_cal[:n]\n",
    "X_cal_2 = X_cal[n:];    Y_cal_2 = Y_cal[n:]\n",
    "X_cal_1 = X_cal_1.to(device); Y_cal_1 = Y_cal_1.to(device)\n",
    "X_cal_2 = X_cal_2.to(device); Y_cal_2 = Y_cal_2.to(device)\n",
    "X_cal   = X_cal.to(device);   y_cal   = Y_cal.to(device)\n",
    "\n",
    "alpha  = 0.1\n",
    "sigma  = 0.1\n",
    "epochs = 300\n",
    "\n",
    "total_hist = []\n",
    "prev_loss = None\n",
    "consecutive_count = 0\n",
    "\n",
    "for epoch in range(1, epochs + 1):\n",
    "    model.train()\n",
    "    l_t = update_theta(\n",
    "        model=model,\n",
    "        storage_box_layer=cvxpylayer,  \n",
    "        opt_h=opt_h,\n",
    "        x=X_cal_1, y=Y_cal_1,\n",
    "        lambda_param=lambda_param,\n",
    "        alpha=alpha,\n",
    "        sigma=sigma,\n",
    "        batt_consts=batt_consts,       \n",
    "        q_hat=None                     \n",
    "    )\n",
    "\n",
    "    l_l = update_lambda(\n",
    "        model=model,\n",
    "        storage_box_layer=cvxpylayer,\n",
    "        opt_l=opt_l,\n",
    "        x=X_cal_2, y=Y_cal_2,\n",
    "        lambda_param=lambda_param,\n",
    "        alpha=alpha,\n",
    "        batt_consts=batt_consts,\n",
    "        q_hat=None\n",
    "    )\n",
    "\n",
    "    tot = total_loss(\n",
    "        model=model,\n",
    "        storage_box_layer=cvxpylayer,\n",
    "        x=X_cal, y=y_cal,\n",
    "        lambda_param=lambda_param,\n",
    "        alpha=alpha,\n",
    "        batt_consts=batt_consts,\n",
    "        q_hat=None,\n",
    "        use_hard_indicator=True,\n",
    "        sigma=sigma\n",
    "    )\n",
    "    total_hist.append(tot)\n",
    "\n",
    "    print(f\"Epoch {epoch:02d} | loss_θ={l_t:.6f} | loss_λ={l_l:.6f} | total={tot:.6f} | λ={lambda_param.item():.4f}\")\n",
    "\n",
    "    if epoch >= 100:\n",
    "        if prev_loss is not None and abs(prev_loss - tot) < 5e-4:\n",
    "            consecutive_count += 1\n",
    "        else:\n",
    "            consecutive_count = 0\n",
    "        if consecutive_count >= 10:\n",
    "            print(f\"Early stopping at epoch {epoch} due to small loss change for 10 consecutive epochs.\")\n",
    "            break\n",
    "    prev_loss = tot\n",
    "\n",
    "os.makedirs('outputs/model/final', exist_ok=True)\n",
    "torch.save(model.state_dict(),\n",
    "           os.path.join('outputs', 'model', 'final', 'final_model.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656c0046-9a75-46ad-bb36-473203897413",
   "metadata": {
    "papermill": {
     "duration": 2.785828,
     "end_time": "2025-09-21T12:40:01.782358",
     "exception": false,
     "start_time": "2025-09-21T12:39:58.996530",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "lam = 0.1\n",
    "eps = 0.05\n",
    "B   = 1.0\n",
    "\n",
    "model.eval()\n",
    "model.enable_implicit = True\n",
    "\n",
    "X_test_cpu = X_test.to(device, non_blocking=True)\n",
    "Y_test_cpu = Y_test.to(device, non_blocking=True)\n",
    "\n",
    "with torch.no_grad():\n",
    "    y_lo, y_hi, sol = model(X_test_cpu, box_layer=cvxpylayer, q_hat=None)\n",
    "\n",
    "y_hi = torch.maximum(y_hi, y_lo + 1e-8)\n",
    "\n",
    "assert isinstance(sol, (tuple, list)) and len(sol) == 4, \"expect (z_in, z_out, z_state, v)\"\n",
    "z_in, z_out, z_state, v = sol  \n",
    "c_vec = z_in - z_out  \n",
    "\n",
    "inside_joint = ((Y_test_cpu >= y_lo) & (Y_test_cpu <= y_hi)).all(dim=1)  # [B]\n",
    "pct_in = inside_joint.float().mean().item() * 100.0\n",
    "\n",
    "box_dual = (v * (y_hi - y_lo) + y_lo * c_vec).sum(dim=1)  # [B]\n",
    "reg_term = (\n",
    "    lam * ((z_state - B/2.0) ** 2).sum(dim=1) +\n",
    "    eps * (z_in ** 2).sum(dim=1) +\n",
    "    eps * (z_out ** 2).sum(dim=1)\n",
    ") \n",
    "\n",
    "h_full = box_dual + reg_term \n",
    "Average_Risk = h_full.mean().item()\n",
    "\n",
    "loss_revenue = (Y_test_cpu * c_vec).sum(dim=1)       \n",
    "loss_full    = loss_revenue + reg_term  \n",
    "\n",
    "Average_Loss = loss_full.mean().item()\n",
    "robustness = (loss_full <= h_full).float().mean().item() * 100.0\n",
    "\n",
    "print(f\"Marginal Coverage:{pct_in:.2f}%\")\n",
    "print(f\"Average_risk: {Average_Risk:.4f}\")\n",
    "print(f\"Average_loss: {Average_Loss:.4f}\")\n",
    "print(f\"Robustness: {robustness:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb8a91ac-516d-4e03-8645-391855da1fd7",
   "metadata": {
    "papermill": {
     "duration": 0.171488,
     "end_time": "2025-09-21T12:40:02.034032",
     "exception": false,
     "start_time": "2025-09-21T12:40:01.862544",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "data_dir = 'outputs/results/data_CRC'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "txt_path = os.path.join(data_dir, f'metrics_CRC.txt')\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.write(f\"Marginal Coverage: {pct_in:.2f}%\\n\")\n",
    "    f.write(f\"Average Risk     : {Average_Risk:.4f}\\n\")\n",
    "    f.write(f\"Average Loss     : {Average_Loss:.4f}\\n\")\n",
    "    f.write(f\"Robustness       : {robustness:.2f}%\\n\")\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1dbc539-58c9-4d00-93ae-fd45da653c0e",
   "metadata": {
    "papermill": {
     "duration": 0.046235,
     "end_time": "2025-09-21T12:40:02.203069",
     "exception": false,
     "start_time": "2025-09-21T12:40:02.156834",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### CRO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbf9392-0f04-4e8b-8d9c-485ae46d2a9a",
   "metadata": {
    "papermill": {
     "duration": 0.257314,
     "end_time": "2025-09-21T12:40:02.667729",
     "exception": false,
     "start_time": "2025-09-21T12:40:02.410415",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")\n",
    "\n",
    "model = BoxQuantileModel(\n",
    "    input_dim=101, hidden_dims=[20,20], n=24, dropout_rate=0.5,\n",
    "    enable_implicit=False\n",
    ").to(device)         \n",
    "model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "model.load_state_dict(torch.load(model_path, map_location='cpu'))\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21b72b22-0d43-4160-a425-df5da0ebc302",
   "metadata": {
    "papermill": {
     "duration": 0.059773,
     "end_time": "2025-09-21T12:40:02.924135",
     "exception": false,
     "start_time": "2025-09-21T12:40:02.864362",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "alpha = 0.1\n",
    "    \n",
    "X_cal_cpu  = X_cal.cpu()\n",
    "Y_cal_cpu  = Y_cal.cpu()\n",
    "X_test_cpu = X_test.cpu()\n",
    "Y_test_cpu = Y_test.cpu()\n",
    "\n",
    "q_hat = compute_q_hat_box_full(model, X_cal_cpu, Y_cal_cpu, alpha, device)\n",
    "y_lo, y_hi, _ = model(X_test_cpu, q_hat = None)\n",
    "y_lo_test, y_hi_test, _ = model(X_test_cpu, q_hat = q_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9102613-eb5c-4362-a43d-fc9eab27932b",
   "metadata": {
    "papermill": {
     "duration": 0.134387,
     "end_time": "2025-09-21T12:40:03.159678",
     "exception": false,
     "start_time": "2025-09-21T12:40:03.025291",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "s_left  = (y_lo - Y_test_cpu).amax(dim=1)   \n",
    "s_right = (Y_test_cpu - y_hi).amax(dim=1)   \n",
    "scores  = torch.maximum(s_left, s_right)   \n",
    "count_1 = 0\n",
    "for i in range(Y_test.shape[0]):\n",
    "    if scores[i] <= q_hat:\n",
    "        count_1+=1\n",
    "print('Marginal Coverage:', count_1/Y_test.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f86275d-a5ea-4e5a-8d9a-697635120461",
   "metadata": {},
   "outputs": [],
   "source": [
    "B       = 1.0\n",
    "gamma   = 0.9\n",
    "c_in    = 0.5\n",
    "c_out   = 0.2\n",
    "lam     = 0.1\n",
    "eps_reg = 0.05\n",
    "\n",
    "z_in_value    = []\n",
    "z_out_value   = []\n",
    "z_state_value = []\n",
    "obj_value     = []\n",
    "r_value = []\n",
    "\n",
    "y_lo_np = y_lo_test.detach().cpu().numpy().astype(float)\n",
    "y_hi_np = y_hi_test.detach().cpu().numpy().astype(float)\n",
    "n_test  = y_lo_np.shape[0]\n",
    "\n",
    "for i in range(n_test):\n",
    "    lb = y_lo_np[i].reshape(-1)   # [T_i]\n",
    "    ub = y_hi_np[i].reshape(-1)   # [T_i]\n",
    "    T_i = lb.shape[0]\n",
    "    e   = np.ones(T_i)\n",
    "    I_T = np.eye(T_i)\n",
    "\n",
    "    model = ro.Model()\n",
    "\n",
    "    y = model.rvar(T_i)                \n",
    "    uset = (lb <= y, y <= ub)        \n",
    "\n",
    "    z_in    = model.dvar(T_i)\n",
    "    z_out   = model.dvar(T_i)\n",
    "    z_state = model.dvar(T_i)\n",
    "\n",
    "    s1 = model.dvar()   \n",
    "    s2 = model.dvar()   \n",
    "    s3 = model.dvar()   \n",
    "    r  = model.dvar()   \n",
    "\n",
    "    model.st(z_in  >= 0);  model.st(z_in  <= c_in)\n",
    "    model.st(z_out >= 0);  model.st(z_out <= c_out)\n",
    "    model.st(z_state >= 0); model.st(z_state <= B)\n",
    "\n",
    "    model.st(z_state[0] == B/2)\n",
    "    for t in range(1, T_i):\n",
    "        model.st(z_state[t] == z_state[t-1] - z_out[t] + gamma * z_in[t])\n",
    "\n",
    "    model.st(s1 >= rso.quad(z_state - (B/2)*e, I_T))\n",
    "    model.st(s2 >= rso.quad(z_in,  I_T))\n",
    "    model.st(s3 >= rso.quad(z_out, I_T))\n",
    "    model.st(r  >= lam * s1 + eps_reg * (s2 + s3))\n",
    "\n",
    "    c_vec = z_in - z_out\n",
    "    model.minmax(y @ c_vec + r, uset )\n",
    "\n",
    "    model.solve(SOLVER, display=False)\n",
    "\n",
    "    z_in_value.append(z_in.get())\n",
    "    z_out_value.append(z_out.get())\n",
    "    z_state_value.append(z_state.get())\n",
    "    obj_value.append(model.get())\n",
    "\n",
    "    print(f\"Test_sample: {i}, Completed!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c1c10d1-8007-426a-bd25-91e62c5189f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "risk_value = []\n",
    "n_test = y_lo_np.shape[0]   \n",
    "\n",
    "for i in range(n_test):\n",
    "    lb = y_lo_np[i].ravel()          \n",
    "    ub = y_hi_np[i].ravel()          \n",
    "    T_i = lb.shape[0]\n",
    "\n",
    "    z_in_i    = z_in_value[i].reshape(-1)\n",
    "    z_out_i   = z_out_value[i].reshape(-1)\n",
    "    z_state_i = z_state_value[i].reshape(-1)\n",
    "    c_vec = z_in_i - z_out_i          \n",
    "\n",
    "    model = ro.Model()\n",
    "    y = model.dvar(T_i)               \n",
    "    model.max(y @ c_vec)            \n",
    "    model.st(lb <= y); model.st(y <= ub)\n",
    "    model.solve(SOLVER, display=False)\n",
    "\n",
    "    worst_lin = model.get()           \n",
    "    reg_val = (\n",
    "        lam * ((z_state_i - B/2.0)**2).sum()\n",
    "        + eps_reg * ( (z_in_i**2).sum() + (z_out_i**2).sum() )\n",
    "    )\n",
    "\n",
    "    risk = float(worst_lin + reg_val)\n",
    "    risk_value.append(risk)\n",
    "    print(f\"Test_sample: {i}, Completed!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2a4ae64-5aaa-4088-aa14-2d751298eb19",
   "metadata": {
    "papermill": {
     "duration": 0.179491,
     "end_time": "2025-09-21T12:41:05.039151",
     "exception": false,
     "start_time": "2025-09-21T12:41:04.859660",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "lam     = float(lam)       \n",
    "eps_reg = float(eps_reg)   \n",
    "B_val   = float(B)         \n",
    "\n",
    "Y_t   = torch.as_tensor(Y_test_cpu, device=device, dtype=torch.float32)      \n",
    "ylo_t = torch.as_tensor(y_lo_test, device=device, dtype=torch.float32)       \n",
    "yhi_t = torch.as_tensor(y_hi_test, device=device, dtype=torch.float32)       \n",
    "\n",
    "N, T = Y_t.shape\n",
    "\n",
    "inside = (Y_t >= ylo_t) & (Y_t <= yhi_t)          # [N,T]\n",
    "pct_joint = inside.all(dim=1).float().mean().item() * 100.0\n",
    "\n",
    "z_in_np    = np.stack(z_in_value[:N],  axis=0).astype(np.float32)   \n",
    "z_out_np   = np.stack(z_out_value[:N], axis=0).astype(np.float32)   \n",
    "z_state_np = np.stack(z_state_value[:N], axis=0).astype(np.float32) \n",
    "\n",
    "z_in_t    = torch.as_tensor(z_in_np,  device=device, dtype=Y_t.dtype)\n",
    "z_out_t   = torch.as_tensor(z_out_np, device=device, dtype=Y_t.dtype)\n",
    "z_state_t = torch.as_tensor(z_state_np, device=device, dtype=Y_t.dtype)\n",
    "\n",
    "if z_state_t.shape[1] == T + 1:\n",
    "    z_state_use = z_state_t[:, 1:]    \n",
    "elif z_state_t.shape[1] == T:\n",
    "    z_state_use = z_state_t         \n",
    "else:\n",
    "    raise ValueError(f\"Unexpected z_state shape: {z_state_t.shape}, expect [N,T] or [N,T+1].\")\n",
    "\n",
    "c_t = z_in_t - z_out_t \n",
    "\n",
    "risk_arr = np.asarray(risk_value[:N], dtype=np.float64)  \n",
    "Average_Risk_1 = float(risk_arr.mean())\n",
    "\n",
    "decision_loss_1 = torch.einsum('ij,ij->i', Y_t, c_t)  \n",
    "\n",
    "reg_term = (\n",
    "    lam     * ((z_state_use - B_val/2.0)**2).sum(dim=1) +\n",
    "    eps_reg * ((z_in_t**2).sum(dim=1) + (z_out_t**2).sum(dim=1)))  \n",
    "\n",
    "full_cost = decision_loss_1 + reg_term  \n",
    "Average_Loss_1  = float(full_cost.mean().item())\n",
    "robustness_1 = float((full_cost <= torch.as_tensor(risk_arr, device=device)).float().mean().item() * 100.0)\n",
    "\n",
    "print(f\"Marginal Coverage:    {pct_joint:.2f}%\")\n",
    "print(f\"Average_Risk:      {Average_Risk_1:.4f}\")\n",
    "print(f\"Average_Loss:      {Average_Loss_1:.4f}\")\n",
    "print(f\"Robustness:        {robustness_1:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0049060d-384d-4472-a637-c40eb70de705",
   "metadata": {
    "papermill": {
     "duration": 0.077482,
     "end_time": "2025-09-21T12:41:05.280568",
     "exception": false,
     "start_time": "2025-09-21T12:41:05.203086",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "data_dir = 'outputs/results/data_Two_step'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "txt_path = os.path.join(data_dir, 'metrics_Two_step.txt')\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.write(f\"Marginal Coverage: {pct_joint:.2f}%\\n\")\n",
    "    f.write(f\"Average Risk     : {Average_Risk_1:.4f}\\n\")\n",
    "    f.write(f\"Average Loss     : {Average_Loss_1:.4f}\\n\")\n",
    "    f.write(f\"Robustness       : {robustness_1:.2f}%\\n\")\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "914a16d2-451b-4d85-b514-253a8337405e",
   "metadata": {
    "papermill": {
     "duration": 0.136094,
     "end_time": "2025-09-21T12:41:05.500749",
     "exception": false,
     "start_time": "2025-09-21T12:41:05.364655",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### End-to-end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f6877a-f96b-4b3c-b651-bfde3755dd9a",
   "metadata": {
    "papermill": {
     "duration": 0.083947,
     "end_time": "2025-09-21T12:41:05.714965",
     "exception": false,
     "start_time": "2025-09-21T12:41:05.631018",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, X, Y):\n",
    "        self.X = torch.tensor(X, dtype=torch.float32)  \n",
    "        self.Y = torch.tensor(Y, dtype=torch.float32)  \n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.Y[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "415765b2-3ad4-4f2f-87a6-bed381ddf9df",
   "metadata": {
    "papermill": {
     "duration": 0.097644,
     "end_time": "2025-09-21T12:41:05.913877",
     "exception": false,
     "start_time": "2025-09-21T12:41:05.816233",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "train_dataset = CustomDataset(X_train, Y_train)\n",
    "cal_dataset = CustomDataset(X_cal, Y_cal)\n",
    "test_dataset = CustomDataset(X_test, Y_test)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "cal_loader = DataLoader(cal_dataset, batch_size=batch_size, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "loaders = {'train': train_loader, 'test': test_loader, 'calib': cal_loader}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e95b21c-642c-4a76-8ef7-ef46f101efe7",
   "metadata": {
    "papermill": {
     "duration": 0.162245,
     "end_time": "2025-09-21T12:41:06.181565",
     "exception": false,
     "start_time": "2025-09-21T12:41:06.019320",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "all_Y = np.concatenate([Y_train, Y_cal, Y_test], axis=0)\n",
    "y_mean = all_Y.mean(axis=0)  \n",
    "y_std = all_Y.std(axis=0)   \n",
    "y_mean = np.array(y_mean)\n",
    "y_std = np.array(y_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17d18a3-b2dd-4653-8b02-f017508eb1a5",
   "metadata": {
    "papermill": {
     "duration": 0.223315,
     "end_time": "2025-09-21T12:41:06.554784",
     "exception": false,
     "start_time": "2025-09-21T12:41:06.331469",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from collections.abc import Sequence\n",
    "from typing import NamedTuple\n",
    "\n",
    "import cvxpy as cp\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import Tensor\n",
    "\n",
    "from collections.abc import Sequence\n",
    "from typing import Protocol\n",
    "\n",
    "class BaseProblemProtocol(Protocol):\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    vars: dict[str, cp.Variable]\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(self):\n",
    "        self.vars = self.primal_vars | self.dual_vars\n",
    "        self.constraints.extend(self.dual_constraints)\n",
    "        obj = self.f_tilde + self.dual_obj\n",
    "\n",
    "        prob = cp.Problem(cp.Minimize(obj), self.constraints)\n",
    "        assert prob.is_dpp()\n",
    "        self.prob = prob\n",
    "\n",
    "    def task_loss_np(self, y: np.ndarray, is_standardized: bool) -> float:\n",
    "        ...\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, y: Tensor, is_standardized: bool, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        ...\n",
    "\n",
    "    def get_cvxpylayer(self) -> CvxpyLayer:\n",
    "        return CvxpyLayer(\n",
    "            self.prob,\n",
    "            parameters=list(self.params.values()),\n",
    "            variables=list(self.vars.values())\n",
    "        )\n",
    "\n",
    "class BoxProblemProtocol(BaseProblemProtocol, Protocol):\n",
    "    def solve(self, pred_lo: np.ndarray, pred_hi: np.ndarray) -> cp.Problem:\n",
    "        ...\n",
    "\n",
    "class BoxProblemV2:\n",
    "    # instance variables\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    # subclass must provide these instance variables\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(\n",
    "        self, y_dim: int, y_mean: np.ndarray, y_std: np.ndarray, Fz: cp.Expression\n",
    "    ):\n",
    "        self.y_mean = y_mean\n",
    "        self.y_std = y_std\n",
    "\n",
    "        nu = cp.Variable(y_dim, nonneg=True)\n",
    "        self.dual_vars = {'nu': nu}\n",
    "\n",
    "        # parameters\n",
    "        y_min = cp.Parameter(y_dim, name='y_min')\n",
    "        y_max = cp.Parameter(y_dim, name='y_max')\n",
    "        self.params = {\n",
    "            'y_min': y_min,\n",
    "            'y_max': y_max,\n",
    "        }\n",
    "\n",
    "        c = Fz\n",
    "        self.dual_obj = (\n",
    "            (y_max - y_min) @ nu\n",
    "            + y_min @ c \n",
    "        )\n",
    "\n",
    "        self.dual_constraints = [nu - c >= 0]\n",
    "\n",
    "    def solve(self, pred_lo: np.ndarray, pred_hi: np.ndarray) -> cp.Problem:\n",
    "     \n",
    "        self.params['y_min'].value = pred_lo\n",
    "        self.params['y_max'].value = pred_hi\n",
    "        prob = self.prob\n",
    "        prob.solve()\n",
    "        if prob.status != 'optimal':\n",
    "            print('Problem status:', prob.status)\n",
    "        v_val = np.asarray(self.dual_vars['nu'].value, dtype=float).ravel() \n",
    "        return prob, v_val\n",
    "\n",
    "class Constants(NamedTuple):\n",
    "    lam: float\n",
    "    eps: float\n",
    "    eff: float\n",
    "    c_in: float\n",
    "    c_out: float\n",
    "    B: float\n",
    "\n",
    "\n",
    "DEFAULT_CONSTANTS = Constants(\n",
    "    lam=0.1,\n",
    "    eps=0.05,\n",
    "    eff=0.9,\n",
    "    c_in=0.5,\n",
    "    c_out=0.2,\n",
    "    B=1\n",
    ")\n",
    "\n",
    "\n",
    "class StorageProblemBase:\n",
    "   \n",
    "    const: Constants\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    # to be implemented in subclass\n",
    "    prob: cp.Problem\n",
    "    y_mean: np.ndarray\n",
    "    y_std: np.ndarray\n",
    "\n",
    "    def __init__(self, T: int, const: Constants = DEFAULT_CONSTANTS):\n",
    "        self.const = const\n",
    "        lam, eps, eff, c_in, c_out, B = self.const\n",
    "\n",
    "        z_state = cp.Variable(T + 1, name='z_state', nonneg=True)\n",
    "        z_in = cp.Variable(T, name='z_in', nonneg=True)\n",
    "        z_out = cp.Variable(T, name='z_out', nonneg=True)\n",
    "        self.primal_vars = {\n",
    "            'z_in': z_in,\n",
    "            'z_out': z_out,\n",
    "            'z_state': z_state,\n",
    "        }\n",
    "\n",
    "        self.constraints = [\n",
    "            # SOC constraints\n",
    "            z_state[0] == B/2,\n",
    "            z_state[1:] == z_state[:-1] + eff * z_in - z_out,\n",
    "            z_state <= B,\n",
    "\n",
    "            # ramp constraints\n",
    "            z_in <= c_in,\n",
    "            z_out <= c_out,\n",
    "        ]\n",
    "\n",
    "        self.Fz = z_in - z_out\n",
    "\n",
    "        self.f_tilde = (\n",
    "            lam * cp.norm2(z_state - B/2)**2\n",
    "            + eps * cp.norm2(z_in)**2\n",
    "            + eps * cp.norm2(z_out)**2\n",
    "        )\n",
    "\n",
    "    def task_loss_np(self, v: np.ndarray, y_min: np.ndarray, y_max: np.ndarray) -> float:\n",
    "    \n",
    "        assert self.prob.value is not None, 'Problem must be solved first'\n",
    "        lam = self.const.lam\n",
    "        eps = self.const.eps\n",
    "        B = self.const.B\n",
    "\n",
    "        z_in = self.primal_vars['z_in'].value\n",
    "        z_out = self.primal_vars['z_out'].value\n",
    "        z_state = self.primal_vars['z_state'].value\n",
    "\n",
    "        task_loss = (\n",
    "            (y_max - y_min) @ v + y_min @ (z_in - z_out)\n",
    "            + lam * np.linalg.norm(z_state - B/2)**2\n",
    "            + eps * np.linalg.norm(z_in)**2\n",
    "            + eps * np.linalg.norm(z_out)**2\n",
    "        )\n",
    "        return task_loss\n",
    "\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, v: torch.Tensor, y_min: torch.Tensor, y_max: torch.Tensor, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        \n",
    "        z_in, z_out, z_state = solution[:3]\n",
    "\n",
    "        lam = self.const.lam\n",
    "        eps = self.const.eps\n",
    "        B = self.const.B\n",
    "\n",
    "        task_loss = (\n",
    "            ((y_max - y_min) * v).sum(dim=-1) + (y_min * (z_in - z_out)).sum(dim=-1)\n",
    "            + lam * torch.norm(z_state - B/2, dim=-1)**2\n",
    "            + eps * torch.norm(z_in, dim=-1)**2\n",
    "            + eps * torch.norm(z_out, dim=-1)**2\n",
    "        )\n",
    "        assert task_loss.shape == y_min.shape[:-1]\n",
    "        return task_loss\n",
    "\n",
    "    def task_loss_torch_1(\n",
    "        self, y: torch.Tensor, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        \n",
    "        z_in, z_out, z_state = solution[:3]\n",
    "\n",
    "        lam = self.const.lam\n",
    "        eps = self.const.eps\n",
    "        B = self.const.B\n",
    "\n",
    "        task_loss = (\n",
    "            (y * (z_in - z_out)).sum(dim=-1)\n",
    "            + lam * torch.norm(z_state - B/2, dim=-1)**2\n",
    "            + eps * torch.norm(z_in, dim=-1)**2\n",
    "            + eps * torch.norm(z_out, dim=-1)**2\n",
    "        )\n",
    "        assert task_loss.shape == y.shape[:-1]\n",
    "        return task_loss\n",
    "\n",
    "class StorageProblemBoxV2(StorageProblemBase, BoxProblemV2, BoxProblemProtocol):\n",
    "    def __init__(self, T: int, y_mean: np.ndarray, y_std: np.ndarray):\n",
    "        StorageProblemBase.__init__(self, T=T)\n",
    "        BoxProblemV2.__init__(self, y_dim=T, y_mean=y_mean, y_std=y_std, Fz=self.Fz)\n",
    "        BoxProblemProtocol.__init__(self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c932ba85-85b0-4d7c-92e9-e50d023aa287",
   "metadata": {
    "papermill": {
     "duration": 0.104658,
     "end_time": "2025-09-21T12:41:06.866628",
     "exception": false,
     "start_time": "2025-09-21T12:41:06.761970",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def multidim_quantile_score(pred: torch.Tensor | tuple, y: torch.Tensor, y_dim: int) -> torch.Tensor:\n",
    "  \n",
    "    if isinstance(pred, (tuple, list)):\n",
    "        pred_lo, pred_hi = pred[:2]        \n",
    "    else:\n",
    "        pred_lo = pred[..., :y_dim]\n",
    "        pred_hi = pred[..., y_dim:]\n",
    "\n",
    "    lo_gap = (pred_lo - y).amax(dim=-1)      \n",
    "    hi_gap = (y - pred_hi).amax(dim=-1)     \n",
    "    return torch.maximum(lo_gap, hi_gap)\n",
    "\n",
    "def calc_q(scores: Tensor, alpha: float) -> Tensor:\n",
    "   \n",
    "    n = len(scores)\n",
    "    j = int(np.ceil((n+1) * (1-alpha)))\n",
    "    if j > n:\n",
    "        return torch.tensor(torch.inf)\n",
    "    sorted_inds = torch.argsort(scores)\n",
    "    q = scores[sorted_inds[j-1]]\n",
    "    return q\n",
    "\n",
    "@torch.no_grad()\n",
    "def conformal_q(model, loader, y_dim: int, alpha: float, device: str = \"cpu\"):\n",
    "    model.eval()\n",
    "    model.to(device)\n",
    "\n",
    "    all_scores = []\n",
    "    for x, y in loader:\n",
    "        x = x.to(device, non_blocking=True)\n",
    "        y = y.to(device, non_blocking=True)\n",
    "\n",
    "        out = model(x)                                \n",
    "        scores = multidim_quantile_score(out, y, y_dim)\n",
    "        all_scores.append(scores)\n",
    "\n",
    "    scores = torch.cat(all_scores, dim=0)\n",
    "    q = calc_q(scores, alpha)\n",
    "    return q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43924fa-4568-449d-b0ae-1210f3972f2c",
   "metadata": {
    "papermill": {
     "duration": 0.122604,
     "end_time": "2025-09-21T12:41:07.112719",
     "exception": false,
     "start_time": "2025-09-21T12:41:06.990115",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_e2e_epoch(\n",
    "    model: BoxQuantileModel,          \n",
    "    prob: BoxProblemProtocol,        \n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    y_dim: int,\n",
    "    alpha: float,\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray],\n",
    "    quantile_loss_frac: float,\n",
    "    rng: np.random.Generator,\n",
    "    optimizer: torch.optim.Optimizer,\n",
    "    show_pbar: bool = False\n",
    ") -> tuple[float, float, float]:\n",
    "    assert 0 < alpha < 1\n",
    "    model.train()\n",
    "\n",
    "    cvxpylayer = prob.get_cvxpylayer()   \n",
    "    lo_quantile_loss_fn = PinballLoss(alpha/2)\n",
    "    hi_quantile_loss_fn = PinballLoss(1-alpha/2)\n",
    "\n",
    "    total_loss = total_pinball_loss = total_task_loss = 0.0\n",
    "    total_num_tasks = 0\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    def make_storage_box_layer(y_info):\n",
    "        if y_info == 'log':\n",
    "            def layer(lo, hi):\n",
    "                return cvxpylayer(torch.exp(lo).cpu(), torch.exp(hi).cpu()) \n",
    "        else:\n",
    "            def layer(lo, hi):\n",
    "                return cvxpylayer(lo.cpu(), hi.cpu())\n",
    "        return layer\n",
    "\n",
    "    for x, y in tqdm(loader) if show_pbar else loader:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        Bsz = x.size(0)\n",
    "        loss = torch.zeros((), device=device)\n",
    "\n",
    "        y_lo, y_hi, _ = model(x, box_layer=None, q_hat=None)\n",
    "        if quantile_loss_frac > 0:\n",
    "            lo_q = lo_quantile_loss_fn(y_lo, y)\n",
    "            hi_q = hi_quantile_loss_fn(y_hi, y)\n",
    "            pinball_loss = lo_q + hi_q        \n",
    "            loss = loss + quantile_loss_frac * pinball_loss\n",
    "            total_pinball_loss += pinball_loss.item() * Bsz\n",
    "        else:\n",
    "            pinball_loss = torch.tensor(0.0)\n",
    "\n",
    "        if quantile_loss_frac < 1:\n",
    "            perm = rng.permutation(Bsz)\n",
    "            cal_inds  = torch.as_tensor(perm[: Bsz//2], device=device)\n",
    "            task_inds = torch.as_tensor(perm[ Bsz//2:], device=device)\n",
    "\n",
    "            pred_full = torch.cat([y_lo, y_hi], dim=-1)\n",
    "            scores_cal = multidim_quantile_score(pred_full[cal_inds], y[cal_inds], y_dim)\n",
    "            q = calc_q(scores_cal, alpha)\n",
    "            q = torch.as_tensor(q, dtype=y_lo.dtype, device=device)\n",
    "            if not torch.isfinite(q):\n",
    "                if show_pbar: tqdm.write('Batch is too small, skipping')\n",
    "                continue\n",
    "\n",
    "            storage_box_layer = make_storage_box_layer(y_info)\n",
    "            y_lo_adj, y_hi_adj, sol = model(x[task_inds], box_layer=storage_box_layer, q_hat=q)\n",
    "            assert sol is not None and len(sol) == 4, \"Need to return (z_in,z_out,z_state,v)\"\n",
    "            z_in, z_out, z_state, v_star = sol\n",
    "\n",
    "            if y_info == 'log':\n",
    "                y_min = torch.exp(y_lo_adj)\n",
    "                y_max = torch.exp(y_hi_adj) \n",
    "            else:\n",
    "                y_min, y_max = y_lo_adj, y_hi_adj\n",
    "               \n",
    "            task_loss_vec = prob.task_loss_torch(\n",
    "                v=v_star,\n",
    "                y_min=y_min,\n",
    "                y_max=y_max,\n",
    "                solution=(z_in, z_out, z_state)   ) \n",
    "            task_loss = task_loss_vec.mean()\n",
    "\n",
    "            loss = loss + (1 - quantile_loss_frac) * task_loss\n",
    "            total_task_loss += task_loss.item() * task_inds.numel()\n",
    "            total_num_tasks += task_inds.numel()\n",
    "\n",
    "            if show_pbar:\n",
    "                tqdm.write(f'pinball_loss: {pinball_loss.item():.4f}, task_loss: {task_loss.item():.4f}')\n",
    "\n",
    "        total_loss += loss.item() * Bsz\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    N = len(loader.dataset)\n",
    "    avg_loss = total_loss / N\n",
    "    avg_pinball_loss = total_pinball_loss / N if quantile_loss_frac > 0 else 0.0\n",
    "    avg_task_loss = (total_task_loss / max(total_num_tasks, 1)) if quantile_loss_frac < 1 else 0.0\n",
    "    return avg_loss, avg_pinball_loss, avg_task_loss\n",
    "\n",
    "\n",
    "def train_e2e(\n",
    "    y_dim: int,\n",
    "    alpha: float,\n",
    "    max_epochs: int,\n",
    "    lr: float,\n",
    "    l2reg: float,\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray],\n",
    "    prob: BoxProblemProtocol,\n",
    "    rng: np.random.Generator,\n",
    "    quantile_loss_frac: float | Sequence[float],\n",
    "    saved_model_path: str = ''\n",
    ") -> tuple[BoxQuantileModel, dict[str, Any]]:\n",
    "  \n",
    "    model = BoxQuantileModel(input_dim=101, hidden_dims=[20,20], n=24,\n",
    "                             dropout_rate=0.5, enable_implicit=True).cpu()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2reg)\n",
    "    if saved_model_path:  \n",
    "        model.load_state_dict(torch.load(saved_model_path, weights_only=True, map_location=\"cpu\"))\n",
    "\n",
    "    result: dict[str, Any] = {\n",
    "        'train_e2e_losses': [],\n",
    "        'train_pinball_losses': [],\n",
    "        'train_task_losses': [],\n",
    "        'val_task_losses': [],\n",
    "        'best_epoch': 0,\n",
    "        'val_task_loss': np.inf,\n",
    "    }\n",
    "    steps_since_decrease = 0\n",
    "    buffer = io.BytesIO()\n",
    "\n",
    "    pbar = tqdm(range(max_epochs))\n",
    "    for epoch in pbar:\n",
    "        weight = (quantile_loss_frac[min(epoch, len(quantile_loss_frac)-1)]\n",
    "                  if isinstance(quantile_loss_frac, Sequence) else quantile_loss_frac)\n",
    "\n",
    "        train_e2e_loss, train_pinball_loss, train_task_loss = train_e2e_epoch(\n",
    "            model=model, prob=prob, loader=train_loader, y_dim=y_dim, alpha=alpha,\n",
    "            y_info=y_info, quantile_loss_frac=weight, rng=rng, optimizer=optimizer, show_pbar=False\n",
    "        )\n",
    "        result['train_e2e_losses'].append(train_e2e_loss)\n",
    "        result['train_pinball_losses'].append(train_pinball_loss)\n",
    "        result['train_task_losses'].append(train_task_loss)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            model.eval()\n",
    "            q = conformal_q(model, cal_loader, y_dim=y_dim, alpha=alpha)\n",
    "            val_task_loss = optimize(model, prob=prob, loader=cal_loader,\n",
    "                                         y_dim=y_dim, q=q, y_info=y_info)\n",
    "            result['val_task_losses'].append(val_task_loss)\n",
    "\n",
    "        msg = (f'Epoch {epoch}, train_task_loss {train_task_loss:.3f}, '\n",
    "               f'val_task_loss {val_task_loss:.3f}, q {q:.3f}')\n",
    "        pbar.set_description(msg)\n",
    "\n",
    "        steps_since_decrease += 1\n",
    "        if val_task_loss < result['val_task_loss']:\n",
    "            result['best_epoch'] = epoch\n",
    "            result['val_task_loss'] = val_task_loss\n",
    "            steps_since_decrease = 0\n",
    "            buffer.seek(0)\n",
    "            torch.save(model.state_dict(), buffer)\n",
    "\n",
    "        if steps_since_decrease > 10:\n",
    "            break\n",
    "\n",
    "    buffer.seek(0)\n",
    "    model.load_state_dict(torch.load(buffer, weights_only=True, map_location=\"cpu\"))\n",
    "    return model.cpu(), result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6b2779-1707-450b-915c-02d854be30bf",
   "metadata": {
    "papermill": {
     "duration": 0.093106,
     "end_time": "2025-09-21T12:41:07.344245",
     "exception": false,
     "start_time": "2025-09-21T12:41:07.251139",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def optimize(\n",
    "    model: BoxQuantileModel,\n",
    "    prob: BoxProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    y_dim: int,\n",
    "    q: float,\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray],\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    ") -> float:\n",
    "  \n",
    "    model.eval().to(device)\n",
    "    cvxpylayer = prob.get_cvxpylayer()\n",
    "\n",
    "    def storage_box_layer(lo: torch.Tensor, hi: torch.Tensor):\n",
    "        if y_info == 'log':\n",
    "            return cvxpylayer(torch.exp(lo).cpu(), torch.exp(hi).cpu())  \n",
    "        else:\n",
    "            return cvxpylayer(lo.cpu(), hi.cpu())\n",
    "\n",
    "    q_hat = torch.as_tensor(q, device=device)\n",
    "\n",
    "    losses = []\n",
    "    it = tqdm(loader, desc=\"optimize\") if show_pbar else loader\n",
    "    for x_batch, _ in it:\n",
    "        x_batch = x_batch.to(device, non_blocking=True)\n",
    "\n",
    "        y_lo_adj, y_hi_adj, sol = model(\n",
    "            x_batch, box_layer=storage_box_layer, q_hat=q_hat)\n",
    "\n",
    "        if sol is None:\n",
    "            z_in, z_out, z_state, v_star = storage_box_layer(y_lo_adj, y_hi_adj)\n",
    "    \n",
    "            if z_in.dim() == 3 and z_in.size(-1) == 1:\n",
    "                z_in, z_out, z_state, v_star = (\n",
    "                    t.squeeze(-1) for t in (z_in, z_out, z_state, v_star))\n",
    "           \n",
    "            z_in  = z_in.to(device, dtype=y_lo_adj.dtype)\n",
    "            z_out = z_out.to(device, dtype=y_lo_adj.dtype)\n",
    "            z_state = z_state.to(device, dtype=y_lo_adj.dtype)\n",
    "            v_star  = v_star.to(device, dtype=y_lo_adj.dtype)\n",
    "        else:\n",
    "            assert isinstance(sol, (tuple, list)) and len(sol) == 4, \"Need to return (z_in, z_out, z_state, v)\"\n",
    "            z_in, z_out, z_state, v_star = sol \n",
    "\n",
    "        if y_info == 'log':\n",
    "            y_min = torch.exp(y_lo_adj)\n",
    "            y_max = torch.exp(y_hi_adj)\n",
    "        else:\n",
    "            y_min, y_max = y_lo_adj, y_hi_adj\n",
    "\n",
    "        loss_vec = prob.task_loss_torch(\n",
    "            v=v_star,\n",
    "            y_min=y_min,\n",
    "            y_max=y_max,\n",
    "            solution=(z_in, z_out, z_state)\n",
    "        )  \n",
    "        losses.append(loss_vec)\n",
    "\n",
    "    return torch.cat(losses, dim=0).mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d57ea1-e1ce-439a-a0ad-6fe321d49b9d",
   "metadata": {
    "papermill": {
     "duration": 728.676367,
     "end_time": "2025-09-21T12:53:16.138609",
     "exception": false,
     "start_time": "2025-09-21T12:41:07.462242",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os, io, copy, math, itertools, torch, numpy as np\n",
    "\n",
    "alpha = 0.1\n",
    "max_epochs = 100\n",
    "nll_loss_frac = 0.0      \n",
    "lr_list    = [1e-3]\n",
    "l2_list    = [1e-2,1e-3,1e-4]\n",
    "\n",
    "best = {\n",
    "    \"val_task_loss\": float(\"inf\"),\n",
    "    \"lr\": None,\n",
    "    \"l2reg\": None,\n",
    "    \"epoch\": None,\n",
    "    \"state_dict\": None,\n",
    "}\n",
    "all_results = []\n",
    "\n",
    "for lr in lr_list:\n",
    "    for l2 in l2_list:\n",
    "        print(f\"lr: {lr} | l2={l2}\")\n",
    "\n",
    "        prob = StorageProblemBoxV2(T=24, y_mean=y_mean, y_std=y_std)\n",
    "\n",
    "        model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "        saved_model_path = model_path if os.path.exists(model_path) else ''\n",
    "\n",
    "        rng = np.random.default_rng(42)\n",
    "\n",
    "        model_i, res = train_e2e(\n",
    "            y_dim=24,\n",
    "            alpha=alpha,\n",
    "            max_epochs=max_epochs,\n",
    "            lr=lr,\n",
    "            l2reg=l2,\n",
    "            y_info=y_info,                   \n",
    "            prob=prob,\n",
    "            rng=rng,\n",
    "            quantile_loss_frac=nll_loss_frac,  \n",
    "            saved_model_path=saved_model_path\n",
    "        )\n",
    "\n",
    "        record = {\n",
    "            \"lr\": lr,\n",
    "            \"l2reg\": l2,\n",
    "            \"best_epoch\": res[\"best_epoch\"],\n",
    "            \"val_task_loss\": res[\"val_task_loss\"],\n",
    "        }\n",
    "        all_results.append(record)\n",
    "\n",
    "        if res[\"val_task_loss\"] < best[\"val_task_loss\"]:\n",
    "            best.update(record)\n",
    "            best[\"state_dict\"] = copy.deepcopy(model_i.state_dict())\n",
    "\n",
    "print(\"\\n=== Grid Search Summary ===\")\n",
    "all_results = sorted(all_results, key=lambda r: r[\"val_task_loss\"])\n",
    "for r in all_results:\n",
    "    print(f\"lr={r['lr']:g}, l2={r['l2reg']:g}, \"\n",
    "          f\"best_epoch={r['best_epoch']}, val_task_loss={r['val_task_loss']:.6f}\")\n",
    "\n",
    "print(\"\\n>>> Best config:\")\n",
    "print(f\"lr={best['lr']:g}, l2={best['l2reg']:g}, \"\n",
    "      f\"best_epoch={best['best_epoch']}, val_task_loss={best['val_task_loss']:.6f}\")\n",
    "\n",
    "best_model = BoxQuantileModel(input_dim=101, hidden_dims=[20,20], n=24, dropout_rate = 0.5,enable_implicit=True).cpu()\n",
    "best_model.load_state_dict(best[\"state_dict\"])\n",
    "\n",
    "os.makedirs('outputs/model/E2E', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'E2E', f'E2E_model.pth')\n",
    "torch.save(best_model.state_dict(), save_path)\n",
    "print(f\"Saved best model to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1cc17dc-9129-408d-b422-f74fa4c53d4c",
   "metadata": {
    "papermill": {
     "duration": 0.080686,
     "end_time": "2025-09-21T12:53:16.283923",
     "exception": false,
     "start_time": "2025-09-21T12:53:16.203237",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def optimize_test_box(\n",
    "    model: BoxQuantileModel,\n",
    "    prob: BoxProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    y_dim: int,\n",
    "    q: float,\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray],\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    "):\n",
    "   \n",
    "    model.eval().to(device)\n",
    "    cvxpylayer = prob.get_cvxpylayer()\n",
    "\n",
    "    losses_task, losses_dec = [], []\n",
    "    robust_flags = []\n",
    "    z_list: List[np.ndarray] = []\n",
    "\n",
    "    iterator = loader if not show_pbar else __import__('tqdm').tqdm(loader, desc=\"optimize_test_box\")\n",
    "\n",
    "    total_in = 0.0\n",
    "    total_cnt = 0\n",
    "\n",
    "    for x, y in iterator:\n",
    "        x = x.to(device, non_blocking=True)\n",
    "        y = y.to(device, non_blocking=True)\n",
    "\n",
    "        out = model(x)\n",
    "        if isinstance(out, (tuple, list)):\n",
    "            y_lo, y_hi = out[:2]        \n",
    "        else:\n",
    "            y_lo, y_hi = out[..., :y_dim], out[..., y_dim:]\n",
    "\n",
    "        q_t = torch.as_tensor(q, dtype=y_lo.dtype, device=y_lo.device)\n",
    "        y_min = y_lo - q_t\n",
    "        y_max = y_hi + q_t\n",
    "        y_max = torch.maximum(y_max, y_min + 1e-8)\n",
    "\n",
    "        if y_info == 'log':\n",
    "            y_min_solver = torch.exp(y_min)\n",
    "            y_max_solver = torch.exp(y_max)\n",
    "            y_for_decision = torch.exp(y)\n",
    "        else:\n",
    "            y_min_solver, y_max_solver = y_min, y_max\n",
    "            y_for_decision = y\n",
    "\n",
    "        z_in, z_out, z_state, v_star = cvxpylayer(y_min_solver.cpu(), y_max_solver.cpu())\n",
    "        if z_in.dim() == 3 and z_in.size(-1) == 1:\n",
    "            z_in, z_out, z_state, v_star = (\n",
    "                t.squeeze(-1) for t in (z_in, z_out, z_state, v_star)\n",
    "            )\n",
    "  \n",
    "        z_in    = z_in.to(y_min_solver.device, dtype=y_min_solver.dtype)\n",
    "        z_out   = z_out.to(y_min_solver.device, dtype=y_min_solver.dtype)\n",
    "        z_state = z_state.to(y_min_solver.device, dtype=y_min_solver.dtype)\n",
    "        v_star  = v_star.to(y_min_solver.device, dtype=y_min_solver.dtype)\n",
    "\n",
    "        task_loss_vec = prob.task_loss_torch(\n",
    "            v=v_star,\n",
    "            y_min=y_min_solver,\n",
    "            y_max=y_max_solver,\n",
    "            solution=(z_in, z_out, z_state))  \n",
    "        losses_task.append(task_loss_vec.cpu())\n",
    "\n",
    "        c_vec = z_in - z_out\n",
    "        decision_vec = prob.task_loss_torch_1(\n",
    "            y = y_for_decision,\n",
    "            solution=(z_in, z_out, z_state)\n",
    "        )  \n",
    "        losses_dec.append(decision_vec.cpu())\n",
    "\n",
    "        inside = ((y_for_decision >= y_min_solver) & (y_for_decision <= y_max_solver)).float()\n",
    "        total_in  += inside.sum().item()\n",
    "        total_cnt += inside.numel()\n",
    "\n",
    "        robust_flags.append((decision_vec <= task_loss_vec).float().cpu())\n",
    "        z_list.extend([c.cpu().numpy().ravel() for c in c_vec])\n",
    "\n",
    "    task_all = torch.cat(losses_task) if losses_task else torch.tensor([0.0])\n",
    "    dec_all  = torch.cat(losses_dec)  if losses_dec  else torch.tensor([0.0])\n",
    "    rob_all  = torch.cat(robust_flags) if robust_flags else torch.tensor([0.0])\n",
    "\n",
    "    avg_task     = task_all.mean().item()\n",
    "    avg_decision = dec_all.mean().item()\n",
    "    coverage     = (total_in / max(total_cnt, 1)) * 100.0\n",
    "    robustness   = rob_all.mean().item() * 100.0\n",
    "\n",
    "    return avg_task, avg_decision, coverage, robustness, z_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac47837-086d-49f7-b370-cc4fb586c70b",
   "metadata": {
    "papermill": {
     "duration": 0.940009,
     "end_time": "2025-09-21T12:53:17.288411",
     "exception": false,
     "start_time": "2025-09-21T12:53:16.348402",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_path = os.path.join('outputs', 'model', 'E2E', 'E2E_model.pth')\n",
    "\n",
    "model = BoxQuantileModel(input_dim=101, hidden_dims=[20,20], n=24, dropout_rate = 0.5, enable_implicit=False).cpu()\n",
    "model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    q = conformal_q(model, loaders['calib'], y_dim=24, alpha=alpha, device='cpu').item()\n",
    "\n",
    "avg_risk, avg_loss, coverage, robustness_2, z_value = optimize_test_box(\n",
    "    model=model,\n",
    "    prob=prob,               \n",
    "    loader=test_loader,\n",
    "    y_dim=24,\n",
    "    q=q,\n",
    "    y_info=y_info,          \n",
    "    device='cpu',\n",
    "    show_pbar=False\n",
    ")\n",
    "\n",
    "print(f\"Coverage: {coverage:.2f}%\")\n",
    "print(f\"Avg risk: {avg_risk:.4f}\")\n",
    "print(f\"Avg decision: {avg_loss:.4f}\")\n",
    "print(f\"Robustness: {robustness_2:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dae6393-43b6-4df2-81fc-f185f1b65ae4",
   "metadata": {
    "papermill": {
     "duration": 0.070385,
     "end_time": "2025-09-21T12:53:17.424977",
     "exception": false,
     "start_time": "2025-09-21T12:53:17.354592",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "data_dir = 'outputs/results/data_E2E'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "txt_path = os.path.join(data_dir, f'metrics_E2E.txt')\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.write(f\"Marginal Coverage: {coverage:.2f}%\\n\")\n",
    "    f.write(f\"Average Risk     : {avg_risk:.4f}\\n\")\n",
    "    f.write(f\"Average Loss     : {avg_loss:.4f}\\n\")\n",
    "    f.write(f\"Robustness       : {robustness_2:.2f}%\\n\")\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6452e5-6bb3-44f5-9a0e-a1b23ff8a1b6",
   "metadata": {
    "papermill": {
     "duration": 0.063571,
     "end_time": "2025-09-21T12:53:18.903990",
     "exception": false,
     "start_time": "2025-09-21T12:53:18.840419",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 4755.498947,
   "end_time": "2025-09-21T12:53:20.492351",
   "environment_variables": {},
   "exception": null,
   "input_path": "CRC_batter_box.ipynb",
   "output_path": "outputs/file/CRC_batter_box_0.ipynb",
   "parameters": {
    "run_id": 0
   },
   "start_time": "2025-09-21T11:34:04.993404",
   "version": "2.6.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "06b4951450514b3ab9a38009630206a0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "0d19941d3746443bbecae1c69d99db69": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "danger",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_bea0128c381d47a2a159872f3a9f1afa",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_c6539c2e62d842fc90696d4d5f9a35cf",
       "tabbable": null,
       "tooltip": null,
       "value": 61
      }
     },
     "0d9b1568a9354b85ba7ccfc29f136e9c": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_3af1d68ba42340e4b01508e7680b2dd0",
       "placeholder": "​",
       "style": "IPY_MODEL_4a835f4b27b640d2844e6b616e4358f2",
       "tabbable": null,
       "tooltip": null,
       "value": " 61/100 [06:20&lt;02:36,  4.00s/it]"
      }
     },
     "18133ad8e9dc4d1693238f7b2bfba9f7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "1cbfb82ac3da4aacbf0d8a684a93618e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "danger",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_30c18ff0a3f94c81be8711a7ed30959c",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_34ab5670e1df427aac93b9a4cd5b4b12",
       "tabbable": null,
       "tooltip": null,
       "value": 22
      }
     },
     "30c18ff0a3f94c81be8711a7ed30959c": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "34ab5670e1df427aac93b9a4cd5b4b12": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "3af1d68ba42340e4b01508e7680b2dd0": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "480c0b4fc5894d8da3394b7758567bd7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_7027fea03a4b45899579c9f2bf6e4c8a",
       "placeholder": "​",
       "style": "IPY_MODEL_06b4951450514b3ab9a38009630206a0",
       "tabbable": null,
       "tooltip": null,
       "value": " 26/100 [04:31&lt;12:47, 10.37s/it]"
      }
     },
     "4a835f4b27b640d2844e6b616e4358f2": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "67e3bdf0a69c4ce29133d49d275a9e6f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_80f8d1d84e614b75b808ec4aec639b31",
        "IPY_MODEL_1cbfb82ac3da4aacbf0d8a684a93618e",
        "IPY_MODEL_95b1e2aefaca42a4b78f302789c32551"
       ],
       "layout": "IPY_MODEL_adc52b094b854dc38c3cd0d013eb95de",
       "tabbable": null,
       "tooltip": null
      }
     },
     "6c5880f07b2740bba6eb556b8d02582e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "6e8c9f0197b7405fa94a9c0e5a20546e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "7027fea03a4b45899579c9f2bf6e4c8a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "78e4ae57ae464981bc751dba99536801": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_a2a7327a81634f34858c205e7b8acd78",
        "IPY_MODEL_a3df48812f524ae0ba35982024e21fc1",
        "IPY_MODEL_480c0b4fc5894d8da3394b7758567bd7"
       ],
       "layout": "IPY_MODEL_de693d34db4e43a6bad0ec64eacf375a",
       "tabbable": null,
       "tooltip": null
      }
     },
     "7b59e1da1d6d4d38ba979c716f105d8f": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "80f8d1d84e614b75b808ec4aec639b31": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_fd279050a2c34b02bf7440743e7e61f3",
       "placeholder": "​",
       "style": "IPY_MODEL_986559b0d75440ddb05686e933c36415",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 22, train_task_loss -9.299, val_task_loss -14.264, q 1.842:  22%"
      }
     },
     "882a7da288ae4e04b8ffe6e928e767ec": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "95b1e2aefaca42a4b78f302789c32551": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_7b59e1da1d6d4d38ba979c716f105d8f",
       "placeholder": "​",
       "style": "IPY_MODEL_d2d35bb4c63040a981a87edb442001ff",
       "tabbable": null,
       "tooltip": null,
       "value": " 22/100 [01:17&lt;03:57,  3.04s/it]"
      }
     },
     "970646cc08334607bdb2b901a6b91587": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_6e8c9f0197b7405fa94a9c0e5a20546e",
       "placeholder": "​",
       "style": "IPY_MODEL_18133ad8e9dc4d1693238f7b2bfba9f7",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 61, train_task_loss -8.813, val_task_loss -14.287, q 4.038:  61%"
      }
     },
     "9777221545f54392a7d7afa22c009fe3": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_970646cc08334607bdb2b901a6b91587",
        "IPY_MODEL_0d19941d3746443bbecae1c69d99db69",
        "IPY_MODEL_0d9b1568a9354b85ba7ccfc29f136e9c"
       ],
       "layout": "IPY_MODEL_cd15143afd594ed5a96b2d1071348909",
       "tabbable": null,
       "tooltip": null
      }
     },
     "986559b0d75440ddb05686e933c36415": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "a2a7327a81634f34858c205e7b8acd78": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_6c5880f07b2740bba6eb556b8d02582e",
       "placeholder": "​",
       "style": "IPY_MODEL_d872d703a16c4932b635ac5855de97ca",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 26, train_task_loss -8.658, val_task_loss -13.886, q 1.771:  26%"
      }
     },
     "a3df48812f524ae0ba35982024e21fc1": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "danger",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_882a7da288ae4e04b8ffe6e928e767ec",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_cc32773d2dc94759bd8e78cf919d3536",
       "tabbable": null,
       "tooltip": null,
       "value": 26
      }
     },
     "adc52b094b854dc38c3cd0d013eb95de": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "bea0128c381d47a2a159872f3a9f1afa": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "c6539c2e62d842fc90696d4d5f9a35cf": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "cc32773d2dc94759bd8e78cf919d3536": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "cd15143afd594ed5a96b2d1071348909": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "d2d35bb4c63040a981a87edb442001ff": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "d872d703a16c4932b635ac5855de97ca": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "de693d34db4e43a6bad0ec64eacf375a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "fd279050a2c34b02bf7440743e7e61f3": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
