{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4bcfc9fd-c1e1-4d83-9fd2-0949976723b7",
   "metadata": {
    "papermill": {
     "duration": 0.015241,
     "end_time": "2025-09-22T14:13:13.524339",
     "exception": false,
     "start_time": "2025-09-22T14:13:13.509098",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Conformal Robustness Control"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55ab34a0-4b2e-41e9-b17d-04bd2eaef62e",
   "metadata": {},
   "source": [
    "### CRC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "027da74f-f82a-4655-bfbf-94fcbaa6fb97",
   "metadata": {
    "papermill": {
     "duration": 2.383004,
     "end_time": "2025-09-22T14:13:15.939956",
     "exception": false,
     "start_time": "2025-09-22T14:13:13.556952",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import os\n",
    "import io\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import pytz\n",
    "import cvxpy as cp\n",
    "import rsome as rso\n",
    "\n",
    "from rsome import ro\n",
    "from rsome import msk_solver as SOLVER\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.distributions import MultivariateNormal\n",
    "from cvxpylayers.torch import CvxpyLayer\n",
    "from typing import Protocol\n",
    "from typing import NamedTuple\n",
    "from torch import distributions as tdist, nn, Tensor\n",
    "from collections.abc import Mapping\n",
    "from torch import Tensor\n",
    "from collections.abc import Sequence\n",
    "from torch.utils.data import TensorDataset\n",
    "from collections.abc import Mapping, Sequence\n",
    "from pandas.tseries.holiday import USFederalHolidayCalendar\n",
    "from datetime import datetime, timedelta\n",
    "from tqdm.auto import tqdm\n",
    "from typing import Any"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f8ca6a-17af-49fc-b3ea-7d20e26605dc",
   "metadata": {
    "papermill": {
     "duration": 0.012685,
     "end_time": "2025-09-22T14:13:16.058227",
     "exception": false,
     "start_time": "2025-09-22T14:13:16.045542",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "seed = random.randint(0, 99)\n",
    "print(\"Device:\", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b10486b-e69b-4957-833f-5a17c4028540",
   "metadata": {
    "papermill": {
     "duration": 0.076459,
     "end_time": "2025-09-22T14:13:16.144139",
     "exception": false,
     "start_time": "2025-09-22T14:13:16.067680",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Timezone used further down to decide whether a calendar date carries a DST offset.\n",
    "tz = pytz.timezone('America/New_York')\n",
    "# The first CSV column is parsed as datetimes.\n",
    "df = pd.read_csv('data.csv', parse_dates=[0])\n",
    "display(df)\n",
    "# Natural log of the `da_price` column, stored as an extra column.\n",
    "df['log_da_price'] = np.log(df['da_price'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "349193f9-a9f6-4ea2-aca8-5092abe8eb11",
   "metadata": {
    "papermill": {
     "duration": 0.102672,
     "end_time": "2025-09-22T14:13:16.275389",
     "exception": false,
     "start_time": "2025-09-22T14:13:16.172717",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def validate_data(df: pd.DataFrame) -> None:\n",
    "    expected_dts = pd.date_range(datetime(2011, 1, 3, 0, 0, 0), datetime(2016, 12, 31, 23, 0, 0), freq=timedelta(hours=1))\n",
    "    actual_dts = set(df['datetime'])\n",
    "    assert len(expected_dts.difference(actual_dts)) == 0\n",
    "    for col in ('datetime', 'da_price', 'load_forecast'):\n",
    "        assert df[col].isna().sum() == 0\n",
    "\n",
    "# Stop here with an AssertionError if the loaded frame fails the checks above.\n",
    "validate_data(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54ffbadd-99c7-435c-9fea-9445b52a54e3",
   "metadata": {
    "papermill": {
     "duration": 0.033228,
     "end_time": "2025-09-22T14:13:16.367336",
     "exception": false,
     "start_time": "2025-09-22T14:13:16.334108",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Split every timestamp into a calendar date (row key) and an hour of day (column key).\n",
    "timestamps = df['datetime'].dt\n",
    "df['date'] = timestamps.date\n",
    "df['hour'] = timestamps.hour\n",
    "\n",
    "\n",
    "def _daily_table(column: str) -> pd.DataFrame:\n",
    "    \"\"\"Reshape one hourly column of `df` into a (date x hour) table.\"\"\"\n",
    "    return df.pivot(index='date', columns='hour', values=column)\n",
    "\n",
    "\n",
    "df_prices, df_logprices, df_load, df_temp = (\n",
    "    _daily_table(column)\n",
    "    for column in ('da_price', 'log_da_price', 'load_forecast', 'temp_dca')\n",
    ")\n",
    "# Fill temperature gaps within each day: from the next valid hour first, then from the previous one.\n",
    "df_temp = df_temp.T.bfill().ffill().T\n",
    "\n",
    "# Every table must share the same date index so rows can be stacked positionally later.\n",
    "for other in (df_load, df_temp):\n",
    "    assert df_logprices.index.equals(other.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33d9a61c-77b2-4f5a-a720-544925dffca7",
   "metadata": {
    "papermill": {
     "duration": 0.057557,
     "end_time": "2025-09-22T14:13:16.459368",
     "exception": false,
     "start_time": "2025-09-22T14:13:16.401811",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_dates = df_logprices.index.to_series()\n",
    "\n",
    "holidays = USFederalHolidayCalendar().holidays(start='2011-01-01', end='2017-01-01').date\n",
    "\n",
    "\n",
    "def _is_weekend(day) -> bool:\n",
    "    \"\"\"True for ISO weekdays 6 and 7 (Saturday, Sunday).\"\"\"\n",
    "    return day.isoweekday() >= 6\n",
    "\n",
    "\n",
    "def _is_dst(day) -> bool:\n",
    "    \"\"\"True when midnight of `day`, localised to `tz`, has a positive DST offset.\"\"\"\n",
    "    midnight = datetime.combine(day, datetime.min.time())\n",
    "    return tz.localize(midnight).dst().seconds > 0\n",
    "\n",
    "\n",
    "def _doy_angle(day) -> float:\n",
    "    \"\"\"Day of year mapped to an angle, using a fixed 365-day period.\"\"\"\n",
    "    return day.timetuple().tm_yday/365*2*np.pi\n",
    "\n",
    "\n",
    "# One row per date: three boolean flags plus a cyclical day-of-year encoding.\n",
    "df_feat = pd.DataFrame({\n",
    "    'weekend': df_dates.map(_is_weekend),\n",
    "    'holiday': df_dates.isin(holidays),\n",
    "    'dst': df_dates.map(_is_dst),\n",
    "    \"cos_doy\": df_dates.map(lambda day: np.cos(_doy_angle(day))),\n",
    "    \"sin_doy\": df_dates.map(lambda day: np.sin(_doy_angle(day))),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2368ef08-4f60-4f99-9d03-9510f8560e30",
   "metadata": {
    "papermill": {
     "duration": 0.135045,
     "end_time": "2025-09-22T14:13:16.604991",
     "exception": false,
     "start_time": "2025-09-22T14:13:16.469946",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Row k describes the day at position k+1: previous-day log prices, same-day load forecast,\n",
    "# previous-day and same-day temperatures, and same-day calendar features.\n",
    "feature_blocks = [\n",
    "    df_logprices.iloc[:-1].values,\n",
    "    df_load.iloc[1:].values,\n",
    "    df_temp.iloc[:-1].values,\n",
    "    df_temp.iloc[1:].values,\n",
    "    df_feat.iloc[1:].values,\n",
    "]\n",
    "X = np.hstack(feature_blocks).astype(np.float64)\n",
    "\n",
    "# Standardise every feature column with statistics taken over all rows.\n",
    "X_mean = X.mean(axis=0)\n",
    "X_std = X.std(axis=0)\n",
    "X = (X - X_mean) / X_std\n",
    "\n",
    "# Targets are the raw prices of the same days (the first day has no predecessor and is dropped).\n",
    "target_days = df_prices.iloc[1:]\n",
    "Y = target_days.values\n",
    "dates = np.array(target_days.index).astype('datetime64[D]')\n",
    "\n",
    "np.savez_compressed('data.npz', X=X, Y=Y, X_mean=X_mean, X_std=X_std, dates=dates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5f4136e-5388-4b06-9b72-5f5ab32cb92f",
   "metadata": {
    "papermill": {
     "duration": 0.018729,
     "end_time": "2025-09-22T14:13:16.631059",
     "exception": false,
     "start_time": "2025-09-22T14:13:16.612330",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fraction of all samples that goes to the combined train+calibration set; the rest is the test set.\n",
    "TRAIN_CALIB_FRAC = 0.8 \n",
    "# Fraction of the train+calibration set used for training; the rest is used for calibration.\n",
    "TRAIN_FRAC = 0.3        \n",
    "\n",
    "def load_data() -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n",
    "    \"\"\"Read the preprocessed arrays back from ``data.npz``.\n",
    "\n",
    "    Returns:\n",
    "        ``(X, Y, dates)`` with X and Y cast to float32 and dates cast to datetime64[D].\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if X is not of shape (2189, 101) or Y is not of shape (2189, 24).\n",
    "    \"\"\"\n",
    "    with np.load('data.npz') as archive:\n",
    "        features, targets, days = (archive[key] for key in ('X', 'Y', 'dates'))\n",
    "    assert features.shape == (2189, 101)\n",
    "    assert targets.shape == (2189, 24)\n",
    "    return (\n",
    "        features.astype(np.float32),\n",
    "        targets.astype(np.float32),\n",
    "        days.astype('datetime64[D]'),\n",
    "    )\n",
    "\n",
    "def get_traincalib_test_split(\n",
    "    X: np.ndarray, Y: np.ndarray, dates: np.ndarray, shuffle: bool = False\n",
    ") -> dict[str, Tensor | np.ndarray]:\n",
    "    \"\"\"Split the samples into a leading train+calibration part and a trailing test part.\n",
    "\n",
    "    Args:\n",
    "        X, Y, dates: arrays indexed by sample along their first axis.\n",
    "        shuffle: if True, the samples are permuted first, using a NumPy generator\n",
    "            seeded with the module-level ``seed``.\n",
    "\n",
    "    Returns:\n",
    "        dict holding torch tensors under 'X_traincalib', 'Y_traincalib', 'X_test' and\n",
    "        'Y_test', and NumPy date arrays under 'date_traincalib' and 'date_test'.\n",
    "    \"\"\"\n",
    "    if shuffle:\n",
    "        order = np.random.default_rng(seed=seed).permutation(len(X))\n",
    "        X, Y, dates = X[order], Y[order], dates[order]\n",
    "\n",
    "    # The first TRAIN_CALIB_FRAC share of the rows is train+calibration, the rest is test.\n",
    "    cut = int(X.shape[0] * TRAIN_CALIB_FRAC)\n",
    "    head, tail = slice(None, cut), slice(cut, None)\n",
    "    return {\n",
    "        'X_traincalib': torch.from_numpy(X[head]),\n",
    "        'Y_traincalib': torch.from_numpy(Y[head]),\n",
    "        'X_test': torch.from_numpy(X[tail]),\n",
    "        'Y_test': torch.from_numpy(Y[tail]),\n",
    "        'date_traincalib': dates[head],\n",
    "        'date_test': dates[tail],\n",
    "    }\n",
    "\n",
    "\n",
    "def get_train_calib_split(\n",
    "    tensors: Mapping[str, Tensor | np.ndarray], seed: int\n",
    ") -> tuple[dict[str, Tensor], dict[str, np.ndarray]]:\n",
    "    \"\"\"Randomly divide the train+calibration samples into a train and a calibration part.\n",
    "\n",
    "    Args:\n",
    "        tensors: mapping with the keys 'X_traincalib', 'Y_traincalib', 'date_traincalib',\n",
    "            'X_test', 'Y_test' and 'date_test'.\n",
    "        seed: seed of the NumPy generator that draws the permutation.\n",
    "\n",
    "    Returns:\n",
    "        ``(new_tensors, new_dates)``. The first maps 'X_train', 'Y_train', 'X_calib',\n",
    "        'Y_calib', 'X_test' and 'Y_test' to their data; the second maps 'train', 'calib'\n",
    "        and 'test' to the matching dates. Test entries are passed through unchanged.\n",
    "    \"\"\"\n",
    "    pool_X = tensors['X_traincalib']\n",
    "    pool_Y = tensors['Y_traincalib']\n",
    "    pool_dates = tensors['date_traincalib']\n",
    "\n",
    "    order = np.random.default_rng(seed=seed).permutation(len(pool_X))\n",
    "    # The first TRAIN_FRAC share of the permutation trains, the remainder calibrates.\n",
    "    n_train = int(len(pool_X) * TRAIN_FRAC)\n",
    "    picks = {'train': order[:n_train], 'calib': order[n_train:]}\n",
    "\n",
    "    new_tensors: dict[str, Tensor] = {}\n",
    "    new_dates: dict[str, np.ndarray] = {}\n",
    "    for split, idx in picks.items():\n",
    "        new_tensors[f'X_{split}'] = pool_X[idx]\n",
    "        new_tensors[f'Y_{split}'] = pool_Y[idx]\n",
    "        new_dates[split] = pool_dates[idx]\n",
    "\n",
    "    new_tensors['X_test'] = tensors['X_test']\n",
    "    new_tensors['Y_test'] = tensors['Y_test']\n",
    "    new_dates['test'] = tensors['date_test']\n",
    "    return new_tensors, new_dates\n",
    "\n",
    "def get_loaders(tensors: Mapping[str, Tensor], batch_size: int) -> dict[str, DataLoader]:\n",
    "\n",
    "    shuffle = True\n",
    "    if batch_size == -1:\n",
    "        batch_size = len(tensors['X_train'])\n",
    "        shuffle = False\n",
    "    train_loader = DataLoader(\n",
    "        TensorDataset(tensors['X_train'], tensors['Y_train']),\n",
    "        shuffle=shuffle, batch_size=batch_size, pin_memory=True)\n",
    "\n",
    "    num_test = len(tensors['X_test'])\n",
    "    test_loader = DataLoader(\n",
    "        TensorDataset(tensors['X_test'], tensors['Y_test']),\n",
    "        shuffle=False, batch_size=num_test, pin_memory=True)\n",
    "\n",
    "    num_calib = len(tensors['X_calib'])\n",
    "    calib_loader = DataLoader(\n",
    "        TensorDataset(tensors['X_calib'], tensors['Y_calib']),\n",
    "        shuffle=False, batch_size=num_calib, pin_memory=True)\n",
    "    return {'train': train_loader, 'test': test_loader, 'calib': calib_loader}\n",
    "\n",
    "def get_tensors(\n",
    "    shuffle: bool, log_prices: bool\n",
    ") -> tuple[dict[str, Tensor | np.ndarray], str | tuple[np.ndarray, np.ndarray]]:\n",
    "    \"\"\"Load the dataset, optionally log-transform the targets, and split off the test set.\n",
    "\n",
    "    Args:\n",
    "        shuffle: forwarded to ``get_traincalib_test_split``.\n",
    "        log_prices: if True the targets are replaced by their natural log.\n",
    "\n",
    "    Returns:\n",
    "        ``(tensors, y_info)``. ``tensors`` is the dict produced by\n",
    "        ``get_traincalib_test_split``. ``y_info`` is the string 'log' when\n",
    "        ``log_prices`` is True, otherwise the tuple ``(Y_mean, Y_std)`` of per-column\n",
    "        statistics of the targets.\n",
    "\n",
    "    Note:\n",
    "        With ``log_prices=False`` the targets are returned on their ORIGINAL scale.\n",
    "        The mean and standard deviation are only reported, never applied. The previous\n",
    "        version computed a standardised copy and discarded it; that dead computation\n",
    "        has been removed without changing what is returned.\n",
    "    \"\"\"\n",
    "    X, Y, dates = load_data()\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray]\n",
    "    if log_prices:\n",
    "        Y = np.log(Y)\n",
    "        y_info = 'log'\n",
    "    else:\n",
    "        # NOTE(review): if standardised targets were intended, Y must be reassigned here\n",
    "        # and every consumer of Y must undo the scaling via y_info.\n",
    "        y_info = (Y.mean(axis=0), Y.std(axis=0))\n",
    "    tensors = get_traincalib_test_split(X, Y, dates, shuffle=shuffle)\n",
    "    return tensors, y_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfbd26a1-0eaf-48fd-af1b-8ee574a50cde",
   "metadata": {
    "papermill": {
     "duration": 0.041949,
     "end_time": "2025-09-22T14:13:16.682601",
     "exception": false,
     "start_time": "2025-09-22T14:13:16.640652",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Raw (non-log) prices; samples are shuffled before the train+calibration / test split.\n",
    "tensors_tc, y_info = get_tensors(shuffle=True, log_prices=False)\n",
    "# Random train / calibration split of the train+calibration part, drawn with the same seed.\n",
    "tensors, dates = get_train_calib_split(tensors_tc, seed=seed)\n",
    "\n",
    "# Short aliases used by the cells below.\n",
    "X_train, Y_train = tensors['X_train'], tensors['Y_train']\n",
    "X_cal,   Y_cal   = tensors['X_calib'], tensors['Y_calib']   \n",
    "X_test,  Y_test  = tensors['X_test'],  tensors['Y_test']\n",
    "# Mini-batches of 32 for training; test and calibration loaders yield one full batch each.\n",
    "loaders = get_loaders(tensors, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e4648ff-c62e-4d94-a4ad-22f778c821a9",
   "metadata": {
    "papermill": {
     "duration": 0.015137,
     "end_time": "2025-09-22T14:13:16.736107",
     "exception": false,
     "start_time": "2025-09-22T14:13:16.720970",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class EllipsoidalUncertaintyModel(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dims, n=2, enable_implicit=False, dropout_rate=0.5):\n",
    "        super().__init__()\n",
    "        self.n = n\n",
    "        self.enable_implicit = enable_implicit\n",
    "\n",
    "        if isinstance(hidden_dims, int):\n",
    "            hidden_dims = [hidden_dims] * 3\n",
    "        layers, in_dim = [], input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers += [nn.Linear(in_dim, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(dropout_rate)]  \n",
    "            in_dim = h\n",
    "        self.backbone = nn.Sequential(*layers)\n",
    "\n",
    "        # outputs\n",
    "        self.fc_mu = nn.Linear(in_dim, n)\n",
    "        self.fc_L  = nn.Linear(in_dim, n*(n+1)//2)\n",
    "\n",
    "    def forward(self, x, cvxpylayer=None):\n",
    "        B = x.size(0)\n",
    "        h = self.backbone(x)\n",
    "        mu = self.fc_mu(h)\n",
    "\n",
    "        l_flat = self.fc_L(h)\n",
    "        L = x.new_zeros((B, self.n, self.n))\n",
    "        idx = 0\n",
    "        for i in range(self.n):\n",
    "            for j in range(i+1):\n",
    "                v = l_flat[:, idx]\n",
    "                L[:, i, j] = F.softplus(v) + (1e-3 if i==j else 0)\n",
    "                idx += 1\n",
    "        Sigma = L @ L.transpose(1, 2)\n",
    "\n",
    "        if not self.enable_implicit:\n",
    "            return mu, L, Sigma\n",
    "\n",
    "        if cvxpylayer is None:\n",
    "            raise ValueError(\"enable_implicit=False\")\n",
    "        \n",
    "        eye = torch.eye(self.n, device=x.device, dtype=Sigma.dtype)\n",
    "        Pp  = torch.linalg.cholesky(Sigma + 1e-3 * eye)     \n",
    "        q   = mu                                             \n",
    "        Pp_in = Pp.cpu()\n",
    "        q_in  = q.cpu()\n",
    "\n",
    "        z_in_opt, z_out_opt, z_state_opt = cvxpylayer(Pp_in, q_in)\n",
    "        z_in_opt   = z_in_opt.to(x.device).float()\n",
    "        z_out_opt  = z_out_opt.to(x.device).float()\n",
    "        z_state_opt= z_state_opt.to(x.device).float()\n",
    "\n",
    "        return mu, Sigma, (z_in_opt, z_out_opt, z_state_opt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c799f1d0-9941-4f43-bfed-6c297e42c4f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.distributions import MultivariateNormal\n",
    "\n",
    "# Baseline model: predicts mu and Sigma only (no optimisation layer), without dropout.\n",
    "model = EllipsoidalUncertaintyModel(\n",
    "    input_dim=101, hidden_dims=[20, 20], n=24,\n",
    "    enable_implicit=False, dropout_rate = 0.0\n",
    ")\n",
    "model.train()  \n",
    "\n",
    "# The baseline is trained on the CPU; note that this rebinds the global `device`.\n",
    "device = torch.device(\"cpu\")\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay = 1e-3)\n",
    "\n",
    "for epoch in range(1,600 + 1):\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "    for x_cpu, y_cpu in loaders['train']:\n",
    "        x = x_cpu.to(device, non_blocking=True)\n",
    "        y = y_cpu.to(device, non_blocking=True)\n",
    "\n",
    "        mu, _, Sigma = model(x)\n",
    "        # Small diagonal jitter before Sigma is used as a Gaussian covariance.\n",
    "        Sigma = Sigma + 1e-3 * torch.eye(24, device=device)\n",
    "        dist = MultivariateNormal(mu, Sigma)\n",
    "        # Loss: mean negative Gaussian log-likelihood of the targets in the batch.\n",
    "        loss = -dist.log_prob(y).mean()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Weight by batch size so the epoch figure below is a per-sample average.\n",
    "        total_loss += loss.item() * x.size(0)\n",
    "\n",
    "    avg = total_loss / len(loaders['train'].dataset)\n",
    "    print(f\"Epoch {epoch:>2d} | NLL Loss: {avg:.4f}\")\n",
    "\n",
    "# Persist the weights; later cells reload them from this path.\n",
    "os.makedirs('outputs/model/baseline', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "torch.save(model.state_dict(), save_path)\n",
    "print(f\"Saved model to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f4cf718-6aa4-4a24-8b96-a26de416f846",
   "metadata": {
    "papermill": {
     "duration": 0.035475,
     "end_time": "2025-09-22T14:16:07.828130",
     "exception": false,
     "start_time": "2025-09-22T14:16:07.792655",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "def create_storage_layer(T: int, const) -> CvxpyLayer:\n",
    "    \"\"\"Build a differentiable cvxpy layer for the storage scheduling problem.\n",
    "\n",
    "    The layer minimises, over ``z_in``, ``z_out`` (length T) and ``z_state`` (length T+1),\n",
    "    ``||Pp (z_in - z_out)||_2 + q^T (z_in - z_out) + lam * ||z_state - B/2||^2\n",
    "    + eps * (||z_in||^2 + ||z_out||^2)`` subject to the constraints listed below.\n",
    "\n",
    "    Args:\n",
    "        T: number of time steps.\n",
    "        const: tuple ``(lam, eps, eff, c_in, c_out, B)``.\n",
    "\n",
    "    Returns:\n",
    "        CvxpyLayer with parameters ``[Pp, q]`` and outputs ``[z_in, z_out, z_state]``.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the problem is not DPP-compliant.\n",
    "    \"\"\"\n",
    "    lam, eps, eff, c_in, c_out, B = const\n",
    "    z_state = cp.Variable(T + 1, name='z_state', nonneg=True)\n",
    "    z_in    = cp.Variable(T,     name='z_in',    nonneg=True)\n",
    "    z_out   = cp.Variable(T,     name='z_out',   nonneg=True)\n",
    "    \n",
    "    # NOTE(review): callers in this notebook pass a Cholesky factor for Pp, which is\n",
    "    # triangular rather than symmetric PSD - confirm that PSD=True is intended.\n",
    "    Pp = cp.Parameter((T, T), PSD=True)\n",
    "    q  = cp.Parameter(T)\n",
    "\n",
    "    constraints = [\n",
    "        # The state starts at half of B.\n",
    "        z_state[0] == B/2,\n",
    "        # State recursion: inflow scaled by eff, outflow subtracted as is.\n",
    "        z_state[1:] == z_state[:-1] + eff * z_in - z_out,   \n",
    "        z_state <= B,                                       \n",
    "        z_in  <= c_in,\n",
    "        z_out <= c_out,\n",
    "    ]\n",
    "    lin_term = cp.norm(Pp @ (z_in - z_out), 2) + q.T @ (z_in - z_out)\n",
    "    quad_term = (\n",
    "        lam * cp.sum_squares(z_state - B/2) +\n",
    "        eps * cp.sum_squares(z_in) +\n",
    "        eps * cp.sum_squares(z_out))\n",
    "    \n",
    "    obj = cp.Minimize(lin_term + quad_term)\n",
    "    problem = cp.Problem(obj, constraints)\n",
    "    # The problem is checked for DPP compliance before it is wrapped in a CvxpyLayer.\n",
    "    assert problem.is_dpp('dcp')   \n",
    "    return CvxpyLayer(problem, parameters=[Pp, q], variables=[z_in, z_out, z_state])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80e65366-7e58-42bf-b769-0ccd3c5258bf",
   "metadata": {
    "papermill": {
     "duration": 0.040462,
     "end_time": "2025-09-22T14:16:08.782733",
     "exception": false,
     "start_time": "2025-09-22T14:16:08.742271",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def update_theta(model, cvxpylayer, opt_h, x, y, lambda_param, const, alpha, sigma, eps = 1e-3):\n",
    "    lam, eps, eff, c_in, c_out, B = const\n",
    "    mu, Sigma, z = model(x, cvxpylayer)\n",
    "    Sigma_j = Sigma + eps*torch.eye(model.n, device=device)\n",
    "\n",
    "    L = torch.linalg.cholesky(Sigma_j)\n",
    "    z_0 = z[0] - z[1]\n",
    "    term1 = torch.norm((L @ z_0.unsqueeze(-1)).squeeze(-1), dim=1)\n",
    "    term2 = (mu*z_0).sum(dim=1)\n",
    "    term3 = (\n",
    "    lam * torch.norm(z[2] - B / 2, p=2, dim=1) ** 2 +  \n",
    "    eps * torch.norm(z[0], p=2, dim=1) ** 2 + \n",
    "    eps * torch.norm(z[1], p=2, dim=1) ** 2 )\n",
    "    f = (term1 + term2 + term3).mean()\n",
    "\n",
    "    term = term1 + term2\n",
    "    approx_ind = 0.5*(1 + torch.erf((term - (y*z_0).sum(dim=1)) / (sigma*(2**0.5))))\n",
    "    g = lambda_param * ((1 - alpha) - approx_ind.mean())\n",
    "\n",
    "    loss = f + g\n",
    "    opt_h.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_h.step()\n",
    "    return loss.item()\n",
    "\n",
    "def update_lambda(model, cvxpylayer, opt_l, x, y, lambda_param, const, alpha, sigma, eps = 1e-3):\n",
    "    lam, eps, eff, c_in, c_out, B = const\n",
    "    mu, Sigma, _ = model(x, cvxpylayer)\n",
    "    mu_det     = mu.detach()\n",
    "    Sigma_det  = (Sigma + eps * torch.eye(model.n, device=device)).detach()\n",
    "\n",
    "    L_j = torch.linalg.cholesky(Sigma_det)\n",
    "    z_in, z_out, z_state = cvxpylayer(L_j.cpu(), mu_det.cpu())\n",
    "    z_out = z_out.to(device)\n",
    "    z_state = z_state.to(device)\n",
    "    z_0 = z_in - z_out\n",
    "\n",
    "    term1 = torch.norm((L_j.to(device) @ z_0.unsqueeze(-1)).squeeze(-1), dim=1)\n",
    "    term2 = (mu_det * z_0).sum(dim=1)\n",
    "    term3 = (\n",
    "    lam * torch.norm(z_state - B / 2, p=2, dim=1) ** 2 +  \n",
    "    eps * torch.norm(z_in, p=2, dim=1) ** 2 + \n",
    "    eps * torch.norm(z_out, p=2, dim=1) ** 2 )\n",
    "    f2 = (term1 + term2 + term3).mean()\n",
    "\n",
    "    term = term1 + term2\n",
    "    indicator = ((term - (y*z_0).sum(dim=1)) >= 0).float()\n",
    "    g2 = lambda_param * ((1 - alpha) - indicator.mean())\n",
    "\n",
    "    loss = -(f2 + g2)\n",
    "    opt_l.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_l.step()\n",
    "    with torch.no_grad():\n",
    "        lambda_param.clamp_(min=0.0)\n",
    "    return loss.item()\n",
    "\n",
    "def total_loss(model, cvxpylayer, x, y, lambda_param, const, alpha, sigma=0.1, eps = 1e-3):\n",
    "    lam, eps, eff, c_in, c_out, B = const\n",
    "    mu, Sigma, z = model(x, cvxpylayer)\n",
    "    Sigma_j = Sigma + eps * torch.eye(model.n, device=device)\n",
    "    L = torch.linalg.cholesky(Sigma_j)\n",
    "    z_0 = z[0] - z[1]\n",
    "    term1 = torch.norm((L @ z_0.unsqueeze(-1)).squeeze(-1), dim=1)\n",
    "    term2 = (mu*z_0).sum(dim=1)\n",
    "    term3 = (\n",
    "    lam * torch.norm(z[2] - B / 2, p=2, dim=1) ** 2 +  \n",
    "    eps * torch.norm(z[0], p=2, dim=1) ** 2 + \n",
    "    eps * torch.norm(z[1], p=2, dim=1) ** 2 )\n",
    "    f = (term1 + term2 + term3).mean()\n",
    "\n",
    "    indicator = ((term1 + term2) - (y*z_0).sum(dim=1) >= 0).float()\n",
    "    g = lambda_param * ((1 - alpha) - indicator.mean())\n",
    "    return f + g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c89e753-24ec-4c4a-8b88-9cf2c9471b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem constants, in the order expected by create_storage_layer and the update functions.\n",
    "T   = 24\n",
    "B   = 1.0\n",
    "eff = 0.9\n",
    "c_in  = 0.5\n",
    "c_out = 0.2\n",
    "lam = 0.1\n",
    "eps = 0.05\n",
    "const = (lam, eps, eff, c_in, c_out, B)\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "total = []\n",
    "prev_loss = None\n",
    "consecutive_count = 0\n",
    "n = 24\n",
    "cvxpylayer = create_storage_layer(T=n, const = const)\n",
    "\n",
    "# Start from the baseline weights, now with the optimisation layer enabled.\n",
    "model = EllipsoidalUncertaintyModel(\n",
    "    input_dim=101,\n",
    "    hidden_dims=[20, 20],\n",
    "    n=n,\n",
    "    enable_implicit=True\n",
    ").to(device)\n",
    "model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "\n",
    "opt_h = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "lambda_param = nn.Parameter(torch.tensor(1.0, device=device), requires_grad=True)\n",
    "opt_l = torch.optim.Adam([lambda_param], lr=1e-3)\n",
    "\n",
    "# The calibration data must live on the same device as the model; feeding the CPU\n",
    "# tensors directly fails as soon as `device` is a GPU.\n",
    "X_cal_dev = X_cal.to(device)\n",
    "Y_cal_dev = Y_cal.to(device)\n",
    "\n",
    "# First half of the calibration set updates the model, second half updates lambda.\n",
    "N = X_cal_dev.shape[0]\n",
    "n = int(N/2)\n",
    "X_cal_1 = X_cal_dev[:n]\n",
    "Y_cal_1 = Y_cal_dev[:n]\n",
    "X_cal_2 = X_cal_dev[n:]\n",
    "Y_cal_2 = Y_cal_dev[n:]\n",
    "\n",
    "alpha = 0.1\n",
    "sigma = 0.1\n",
    "epochs = 60\n",
    "\n",
    "for epoch in range(1, epochs+1):\n",
    "    l_t = update_theta(model, cvxpylayer, opt_h, X_cal_1, Y_cal_1, lambda_param, const, alpha, sigma)\n",
    "    l_l = update_lambda(model, cvxpylayer, opt_l, X_cal_2, Y_cal_2, lambda_param, const, alpha, sigma)\n",
    "    # Monitoring only: no gradients are needed for the logged value.\n",
    "    with torch.no_grad():\n",
    "        tot = total_loss(model, cvxpylayer, X_cal_dev, Y_cal_dev, lambda_param, const, alpha, sigma)\n",
    "    total.append(tot.item())\n",
    "    print(f\"Epoch {epoch:02d} | loss_θ={l_t:.6f} | loss_λ={l_l:.6f} | total={tot:.6f} | λ={lambda_param.item():.4f}\")\n",
    "\n",
    "    # Early stopping: ten consecutive epochs whose total loss moved by less than 3e-4.\n",
    "    if prev_loss is not None and abs(prev_loss - tot.item()) < 3e-4:\n",
    "        consecutive_count += 1\n",
    "    else:\n",
    "        consecutive_count = 0\n",
    "    if consecutive_count >= 10:\n",
    "        print(f\"Early stopping at epoch {epoch} due to small loss change for 10 consecutive epochs.\")\n",
    "        break\n",
    "    prev_loss = tot.item()\n",
    "\n",
    "os.makedirs('outputs/model/final', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'final', 'final_model.pth')\n",
    "torch.save(model.state_dict(), save_path)\n",
    "print(f\"Saved final model to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9da1cae-7417-46d6-9429-386f0123fa5b",
   "metadata": {
    "papermill": {
     "duration": 0.034186,
     "end_time": "2025-09-22T14:16:07.990635",
     "exception": false,
     "start_time": "2025-09-22T14:16:07.956449",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def compute_loss_components(z_in, z_out, z_state, B, lam, eps):\n",
    "    \n",
    "    state_loss = lam * torch.sum((z_state - B / 2) ** 2, dim=1)  \n",
    "    in_loss = eps * torch.sum(z_in ** 2, dim=1)  \n",
    "    out_loss = eps * torch.sum(z_out ** 2, dim=1)  \n",
    "\n",
    "    return state_loss, in_loss, out_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656c0046-9a75-46ad-bb36-473203897413",
   "metadata": {
    "papermill": {
     "duration": 0.691653,
     "end_time": "2025-09-22T14:25:23.108917",
     "exception": false,
     "start_time": "2025-09-22T14:25:22.417264",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.eval()\n",
    "model.enable_implicit = True\n",
    "const = (0.1, 0.05, 0.9, 0.5, 0.2, 1.0)\n",
    "cvxpylayer = create_storage_layer(T=T, const=const)\n",
    "# Penalty weights and B are read from `const` so the reported losses use the same\n",
    "# objective as the layer. Previously the jitter value 1e-3 was passed as the\n",
    "# regularisation weight instead of const's 0.05.\n",
    "reg_lam, reg_eps, reg_B = const[0], const[1], const[5]\n",
    "\n",
    "X_test_cpu = X_test.to(device, non_blocking=True)\n",
    "Y_test_cpu = Y_test.to(device, non_blocking=True)\n",
    "\n",
    "with torch.no_grad():\n",
    "    mu_test, Sigma_test, z_test = model(X_test_cpu, cvxpylayer)\n",
    "\n",
    "# From here on `eps` is only the numerical jitter added to Sigma.\n",
    "eps = 1e-3\n",
    "I = torch.eye(model.n, device=device)\n",
    "loss = compute_loss_components(z_test[0], z_test[1], z_test[2], reg_B, reg_lam, reg_eps)\n",
    "loss_0 = loss[0] + loss [1] + loss[2]\n",
    "\n",
    "Sigma_j   = Sigma_test + eps * I.unsqueeze(0)\n",
    "Sigma_inv = torch.linalg.inv(Sigma_j)\n",
    "\n",
    "# Marginal coverage: share of test targets with squared Mahalanobis distance <= 1.\n",
    "delta = Y_test_cpu - mu_test\n",
    "d2    = torch.einsum('bi,bij,bj->b', delta, Sigma_inv, delta)\n",
    "in_set = d2 <= 1.0\n",
    "\n",
    "pct_in_0 = in_set.float().mean().item() * 100\n",
    "Sigma_half = torch.linalg.cholesky(Sigma_test + eps * I.unsqueeze(0))\n",
    "\n",
    "z_col_in  = z_test[0].unsqueeze(-1)\n",
    "z_col_out  = z_test[1].unsqueeze(-1)\n",
    "z_col = z_col_in - z_col_out\n",
    "\n",
    "Sigma_z     = torch.matmul(Sigma_half, z_col)\n",
    "norm_Sigma_z = torch.norm(Sigma_z, p=2, dim=(1,2))\n",
    "mu_z = (mu_test * (z_test[0]-z_test[1]) ).sum(dim=1)\n",
    "\n",
    "# Risk: the robust bound on the objective for the chosen decisions.\n",
    "risk_loss    = norm_Sigma_z + mu_z  + loss_0\n",
    "Average_Risk = risk_loss.mean().item()\n",
    "\n",
    "# Decision loss: the same objective evaluated at the realised targets.\n",
    "decision_loss = (Y_test_cpu * (z_test[0]-z_test[1])).sum(dim=1)  + loss_0\n",
    "Average_Loss  = decision_loss.mean().item()\n",
    "# Robustness: share of test samples whose realised loss stays within the bound.\n",
    "robustness = (decision_loss <= risk_loss).float().mean().item() * 100\n",
    "\n",
    "print(f\"Average_Risk: {Average_Risk:.4f}\")\n",
    "print(f\"Marginal Coverage: {pct_in_0:.2f}%\")\n",
    "print(f\"Average_Loss: {Average_Loss:.4f}\")\n",
    "print(f\"Robustness: {robustness:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb8a91ac-516d-4e03-8645-391855da1fd7",
   "metadata": {
    "papermill": {
     "duration": 0.037525,
     "end_time": "2025-09-22T14:25:23.182845",
     "exception": false,
     "start_time": "2025-09-22T14:25:23.145320",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Write the four CRC test metrics to a small text report.\n",
    "data_dir = 'outputs/results/data_CRC'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "txt_path = os.path.join(data_dir, 'metrics_CRC.txt')\n",
    "\n",
    "metric_lines = [\n",
    "    f\"Marginal Coverage: {pct_in_0:.2f}%\\n\",\n",
    "    f\"Average Risk     : {Average_Risk:.4f}\\n\",\n",
    "    f\"Average Loss     : {Average_Loss:.4f}\\n\",\n",
    "    f\"Robustness       : {robustness:.2f}%\\n\",\n",
    "]\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.writelines(metric_lines)\n",
    "\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1dbc539-58c9-4d00-93ae-fd45da653c0e",
   "metadata": {
    "papermill": {
     "duration": 0.031383,
     "end_time": "2025-09-22T14:25:23.245217",
     "exception": false,
     "start_time": "2025-09-22T14:25:23.213834",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### CRO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbf9392-0f04-4e8b-8d9c-485ae46d2a9a",
   "metadata": {
    "papermill": {
     "duration": 0.047534,
     "end_time": "2025-09-22T14:25:23.328314",
     "exception": false,
     "start_time": "2025-09-22T14:25:23.280780",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Settings for the CRO section; note that `device`, `eps` and `model` are rebound here.\n",
    "device      = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "input_dim   = 101      \n",
    "n_outputs   = 24       \n",
    "hidden_dim  = [20, 20]\n",
    "# Diagonal jitter added to Sigma before it is inverted in the score cells below.\n",
    "eps = 1e-6\n",
    "\n",
    "# Reload the baseline weights into a model without the optimisation layer.\n",
    "model = EllipsoidalUncertaintyModel(\n",
    "        input_dim=input_dim,\n",
    "        hidden_dims=hidden_dim,\n",
    "        n=n_outputs,\n",
    "        enable_implicit=False).to(device)\n",
    "model.load_state_dict(torch.load(\"outputs/model/baseline/baseline_model.pth\", map_location=device))\n",
    "# Switch to inference mode; as the last expression this also displays the module.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21b72b22-0d43-4160-a425-df5da0ebc302",
   "metadata": {
    "papermill": {
     "duration": 0.092522,
     "end_time": "2025-09-22T14:25:23.452242",
     "exception": false,
     "start_time": "2025-09-22T14:25:23.359720",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# The inputs must be on the model's device; the raw calibration tensors are on the\n",
    "# CPU, which fails once `device` is a GPU.\n",
    "X_cal_dev, Y_cal_dev = X_cal.to(device), Y_cal.to(device)\n",
    "\n",
    "# Pure inference: no autograd graph is needed for the conformal scores.\n",
    "with torch.no_grad():\n",
    "    Y_cal_pred, _, cal_Sigma = model(X_cal_dev)\n",
    "\n",
    "Sigma_inv = torch.linalg.inv(cal_Sigma + eps * torch.eye(cal_Sigma.size(-1), device=device))\n",
    "cal_diff = (Y_cal_dev - Y_cal_pred)\n",
    "# Score = Mahalanobis distance of each calibration target from its predicted centre.\n",
    "cal_score = torch.einsum('bi,bij,bj->b', cal_diff, Sigma_inv, cal_diff)\n",
    "S_cal = torch.sqrt(cal_score.clamp(min=1e-12))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae5f9ebd-e113-40bc-a52a-74307af0493a",
   "metadata": {
    "papermill": {
     "duration": 0.064511,
     "end_time": "2025-09-22T14:25:23.553197",
     "exception": false,
     "start_time": "2025-09-22T14:25:23.488686",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "n0 = X_cal.shape[0]\n",
    "\n",
    "# The inputs must be on the model's device; the raw test tensors are on the CPU,\n",
    "# which fails once `device` is a GPU.\n",
    "X_test_dev, Y_test_dev = X_test.to(device), Y_test.to(device)\n",
    "\n",
    "# Pure inference: no autograd graph is needed for the conformal scores.\n",
    "with torch.no_grad():\n",
    "    Y_test_pred, _, test_Sigma = model(X_test_dev)\n",
    "\n",
    "Sigma_inv = torch.linalg.inv(test_Sigma + eps * torch.eye(test_Sigma.size(-1), device=device))\n",
    "test_diff = (Y_test_dev - Y_test_pred)\n",
    "test_score = torch.einsum('bi,bij,bj->b', test_diff, Sigma_inv, test_diff)\n",
    "S_test = torch.sqrt(test_score.clamp(min=1e-12))\n",
    "\n",
    "# Finite-sample corrected 90% level; capped at 1 because torch.quantile rejects larger\n",
    "# values (this only matters for very small calibration sets).\n",
    "quantile_level = min(1.0, 0.9 * (1 + 1 / n0))\n",
    "marginal_quantile = torch.quantile(S_cal , quantile_level).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cb491f8-0092-43f1-9bbf-a84f6d51ff2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constants of the two-step (predict, then robust-optimize) storage problem.\n",
    "eta = 1e-2\n",
    "tol = 1e-4\n",
    "B = 1.0\n",
    "gamma = 0.9\n",
    "c_in  = 0.5\n",
    "c_out = 0.2\n",
    "lam   = 0.1\n",
    "eps_reg = 0.05      \n",
    "\n",
    "# Per-sample optimal decisions and objective values, appended in test order.\n",
    "z_in_value    = []\n",
    "z_out_value   = []\n",
    "z_state_value = []\n",
    "obj_value     = []\n",
    "\n",
    "\n",
    "n_test = Y_test_pred.shape[0]\n",
    "jitter = 1e-2 \n",
    "\n",
    "# Invert the jittered predicted covariance; either one [T,T] matrix or a batch\n",
    "# [n,T,T] is accepted.\n",
    "# NOTE(review): the calibration scores were computed with `eps` as the ridge,\n",
    "# while this set uses `jitter`; confirm the mismatch is intended.\n",
    "if test_Sigma.ndim == 2:\n",
    "    T = test_Sigma.shape[0]\n",
    "    I = torch.eye(T, device=test_Sigma.device, dtype=test_Sigma.dtype)\n",
    "    Sigma_inv = torch.inverse(test_Sigma + jitter * I)\n",
    "elif test_Sigma.ndim == 3:\n",
    "    n, T, _ = test_Sigma.shape\n",
    "    I = torch.eye(T, device=test_Sigma.device, dtype=test_Sigma.dtype)\n",
    "    Sigma_inv = torch.inverse(test_Sigma + jitter * I.unsqueeze(0))\n",
    "else:\n",
    "    raise ValueError(\"test_Sigma must be [T,T] or [n,T,T].\")\n",
    "\n",
    "# Lower Cholesky factor: Sigma_inv = L @ L^T.\n",
    "L_torch = torch.linalg.cholesky(Sigma_inv)\n",
    "L = L_torch.detach().cpu().numpy()\n",
    "\n",
    "for i in range(n_test):\n",
    "    y_hat = Y_test_pred[i].detach().cpu().numpy().reshape(-1)   \n",
    "    T_i   = y_hat.shape[0]\n",
    "\n",
    "    if L.ndim == 2:\n",
    "        Li = L\n",
    "    elif L.ndim == 3:\n",
    "        Li = L[i]\n",
    "    else:\n",
    "        raise ValueError(\"Unexpected L shape\")\n",
    "\n",
    "    q = float(marginal_quantile)   \n",
    "\n",
    "    # Named ro_model so the torch `model` used by the scoring cells is not overwritten.\n",
    "    ro_model = ro.Model()\n",
    "\n",
    "    y       = ro_model.rvar(T_i)\n",
    "    z_in    = ro_model.dvar(T_i)\n",
    "    z_out   = ro_model.dvar(T_i)\n",
    "    z_state = ro_model.dvar(T_i)\n",
    "\n",
    "    # Ellipsoid {y : (y - y_hat)^T Sigma_inv (y - y_hat) <= q^2}. With\n",
    "    # Sigma_inv = L L^T the quadratic form equals ||L^T (y - y_hat)||^2, so the\n",
    "    # transposed factor is the one that belongs inside the norm.\n",
    "    Li_T = np.ascontiguousarray(Li.T)\n",
    "    uset = (rso.norm(Li_T @ (y - y_hat), 2) <= q)\n",
    "\n",
    "    # Box constraints on charge, discharge and state of charge.\n",
    "    ro_model.st(z_in  >= 0);   ro_model.st(z_in  <= c_in)\n",
    "    ro_model.st(z_out >= 0);   ro_model.st(z_out <= c_out)\n",
    "    ro_model.st(z_state >= 0); ro_model.st(z_state <= B)\n",
    "\n",
    "    # State dynamics, starting from a half-full storage.\n",
    "    ro_model.st(z_state[0] == B/2)\n",
    "    for t in range(1, T_i):\n",
    "        ro_model.st(z_state[t] == z_state[t-1] - z_out[t] + gamma * z_in[t])\n",
    "\n",
    "    e = np.ones(T_i)\n",
    "\n",
    "    # Uncertain linear cost; its worst case over `uset` is taken by minmax below.\n",
    "    worst_lin = y @ (z_in - z_out)\n",
    "    \n",
    "    # Epigraph variables for the quadratic regularisers.\n",
    "    I_T = np.eye(T_i)\n",
    "    r  = ro_model.dvar()  \n",
    "    s1 = ro_model.dvar()  \n",
    "    s2 = ro_model.dvar()  \n",
    "    s3 = ro_model.dvar()  \n",
    "\n",
    "    ro_model.st(s1 >= rso.quad(z_state - (B/2)*e, I_T))\n",
    "    ro_model.st(s2 >= rso.quad(z_in,  I_T))\n",
    "    ro_model.st(s3 >= rso.quad(z_out, I_T))\n",
    "    ro_model.st(r >= lam * s1 + eps_reg * (s2 + s3))\n",
    "\n",
    "    ro_model.minmax(worst_lin + r, uset)\n",
    "    ro_model.solve(SOLVER, display=False)\n",
    "\n",
    "    z_in_value.append(z_in.get())\n",
    "    z_out_value.append(z_out.get())\n",
    "    z_state_value.append(z_state.get())\n",
    "    obj_value.append(ro_model.get())\n",
    "\n",
    "    print(f\"Test_sample: {i}, Completed!\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2a4ae64-5aaa-4088-aa14-2d751298eb19",
   "metadata": {
    "papermill": {
     "duration": 0.39141,
     "end_time": "2025-09-22T14:26:09.685364",
     "exception": false,
     "start_time": "2025-09-22T14:26:09.293954",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def stack_to_tensor(xs, device, dtype=torch.float32):\n",
    "    arr = np.vstack(xs)  \n",
    "    return torch.as_tensor(arr, dtype=dtype, device=device)\n",
    "    \n",
    "# Stack the per-sample solution lists into one tensor each, placed on `device`.\n",
    "z_in_t    = stack_to_tensor(z_in_value,    device)  \n",
    "z_out_t   = stack_to_tensor(z_out_value,   device)   \n",
    "z_state_t = stack_to_tensor(z_state_value, device)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b951756b-7558-4667-9d3d-20e5e35ce819",
   "metadata": {
    "papermill": {
     "duration": 0.248948,
     "end_time": "2025-09-22T14:26:10.497764",
     "exception": false,
     "start_time": "2025-09-22T14:26:10.248816",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "const = (0.1, 0.05, 0.9, 0.5, 0.2, 1.0)    \n",
    "X_test_gpu = X_test.to(device, non_blocking=True)\n",
    "Y_test_gpu = Y_test.to(device, non_blocking=True)\n",
    "\n",
    "eps = 1e-3\n",
    "# Identity sized from the covariance itself rather than a hard-coded horizon.\n",
    "I = torch.eye(test_Sigma.size(-1), device=device)\n",
    "# NOTE(review): `eps` is the 1e-3 ridge set just above. If the last argument of\n",
    "# compute_loss_components is meant to be the regularisation weight (eps_reg in\n",
    "# the robust-optimization cell), this call passes the wrong value - confirm.\n",
    "loss = compute_loss_components(z_in_t, z_out_t, z_state_t, B, lam, eps)\n",
    "loss_0 = loss[0] + loss [1] + loss[2]\n",
    "\n",
    "# One regularised covariance feeds both the coverage check and the risk bound,\n",
    "# so the inverse and the Cholesky factor describe the same ellipsoid.\n",
    "Sigma_j   = test_Sigma + eps * I.unsqueeze(0)  \n",
    "Sigma_inv = torch.linalg.inv(Sigma_j)  \n",
    "Sigma_half = torch.linalg.cholesky(Sigma_j)   # lower factor: Sigma_j = L @ L^T\n",
    "\n",
    "# Empirical coverage: share of test targets inside the calibrated ellipsoid.\n",
    "delta = Y_test_gpu - Y_test_pred                    \n",
    "d2    = torch.einsum('bi,bij,bj->b', delta, Sigma_inv, delta)\n",
    "in_set = d2 <= marginal_quantile**2\n",
    "pct_in = in_set.float().mean().item() * 100\n",
    "\n",
    "z_col_in  = z_in_t.unsqueeze(-1)\n",
    "z_col_out  = z_out_t.unsqueeze(-1)\n",
    "z_col = z_col_in - z_col_out\n",
    "\n",
    "# Worst case of y^T z over the ellipsoid is mu^T z + q * sqrt(z^T Sigma z), and\n",
    "# z^T Sigma z = ||L^T z||^2 for Sigma = L L^T, hence the transposed factor.\n",
    "Sigma_z     = torch.matmul(Sigma_half.transpose(-1, -2), z_col)            \n",
    "norm_Sigma_z = torch.norm(Sigma_z, p=2, dim=(1,2))       \n",
    "mu_z = (Y_test_pred * (z_in_t - z_out_t) ).sum(dim=1)                     \n",
    "\n",
    "risk_loss_1    = marginal_quantile * norm_Sigma_z + mu_z + loss_0              \n",
    "Average_Risk_1 = risk_loss_1.mean().item()\n",
    "\n",
    "# Realised cost under the observed targets, and how often the bound holds.\n",
    "decision_loss_1 = (Y_test_gpu * (z_in_t - z_out_t)).sum(dim=1) + loss_0  \n",
    "Average_Loss_1  = decision_loss_1.mean().item()\n",
    "robustness_1 = (decision_loss_1 <= risk_loss_1).float().mean().item() * 100\n",
    "\n",
    "print(f\"Marginal Coverage: {pct_in:.2f}%\")\n",
    "print(f\"Average_Risk: {Average_Risk_1:.4f}\")\n",
    "print(f\"Average_Loss: {Average_Loss_1:.4f}\")\n",
    "print(f\"Robustness: {robustness_1:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0049060d-384d-4472-a637-c40eb70de705",
   "metadata": {
    "papermill": {
     "duration": 0.045718,
     "end_time": "2025-09-22T14:26:11.138410",
     "exception": false,
     "start_time": "2025-09-22T14:26:11.092692",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the two-step metrics as a small text report.\n",
    "data_dir = 'outputs/results/data_Two_step'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "txt_path = os.path.join(data_dir, f'metrics_Two_step.txt')\n",
    "\n",
    "with open(txt_path, 'w') as f:\n",
    "    for report_line in (\n",
    "        f\"Marginal Coverage: {pct_in:.2f}%\\n\",\n",
    "        f\"Average Risk     : {Average_Risk_1:.4f}\\n\",\n",
    "        f\"Average Loss     : {Average_Loss_1:.4f}\\n\",\n",
    "        f\"Robustness       : {robustness_1:.2f}%\\n\",\n",
    "    ):\n",
    "        f.write(report_line)\n",
    "\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "914a16d2-451b-4d85-b514-253a8337405e",
   "metadata": {
    "papermill": {
     "duration": 0.200639,
     "end_time": "2025-09-22T14:26:11.610747",
     "exception": false,
     "start_time": "2025-09-22T14:26:11.410108",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### End-to-end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f6877a-f96b-4b3c-b651-bfde3755dd9a",
   "metadata": {
    "papermill": {
     "duration": 0.055188,
     "end_time": "2025-09-22T14:26:12.013456",
     "exception": false,
     "start_time": "2025-09-22T14:26:11.958268",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class TransformDataset(Dataset):\n",
    "    def __init__(self, X, Y):\n",
    "        self.X = torch.tensor(X, dtype=torch.float32)  \n",
    "        self.Y = torch.tensor(Y, dtype=torch.float32) \n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.Y[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "415765b2-3ad4-4f2f-87a6-bed381ddf9df",
   "metadata": {
    "papermill": {
     "duration": 0.318988,
     "end_time": "2025-09-22T14:26:12.597806",
     "exception": false,
     "start_time": "2025-09-22T14:26:12.278818",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Wrap each split as a TransformDataset, in train / calibration / test order.\n",
    "train_dataset, cal_dataset, test_dataset = (\n",
    "    TransformDataset(features, targets)\n",
    "    for features, targets in ((X_train, Y_train), (X_cal, Y_cal), (X_test, Y_test))\n",
    ")\n",
    "\n",
    "# Only the training loader shuffles; calibration and test keep their order.\n",
    "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)\n",
    "cal_loader, test_loader = (\n",
    "    DataLoader(split, batch_size=256, shuffle=False)\n",
    "    for split in (cal_dataset, test_dataset)\n",
    ")\n",
    "\n",
    "# Per-column target statistics over all three splits together.\n",
    "all_Y = np.concatenate((Y_train, Y_cal, Y_test), axis=0)\n",
    "y_mean = np.array(all_Y.mean(axis=0))\n",
    "y_std = np.array(all_Y.std(axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17d18a3-b2dd-4653-8b02-f017508eb1a5",
   "metadata": {
    "papermill": {
     "duration": 0.484142,
     "end_time": "2025-09-22T14:26:13.664896",
     "exception": false,
     "start_time": "2025-09-22T14:26:13.180754",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class BaseProblemProtocol(Protocol):\n",
    "    \"\"\"Interface of a problem assembled from primal and dual building blocks.\n",
    "\n",
    "    `__init__` reads the primal and dual attributes declared below, so they\n",
    "    must already be set on the instance when it is called.\n",
    "    \"\"\"\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    vars: dict[str, cp.Variable]\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Merge the primal and dual parts into `self.prob`, minimizing dual_obj + f_tilde.\"\"\"\n",
    "        self.vars = {**self.primal_vars, **self.dual_vars}\n",
    "        self.constraints += self.dual_constraints\n",
    "\n",
    "        objective = cp.Minimize(self.dual_obj + self.f_tilde)\n",
    "        problem = cp.Problem(objective, self.constraints)\n",
    "        # The problem is later wrapped in a CvxpyLayer, which needs DPP compliance.\n",
    "        assert problem.is_dpp()\n",
    "        self.prob = problem\n",
    "\n",
    "    def task_loss_np(self, y: np.ndarray, is_standardized: bool) -> float:\n",
    "        \"\"\"Task loss from the solved problem's values (stub).\"\"\"\n",
    "        ...\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, y: Tensor, is_standardized: bool, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        \"\"\"Task loss from a tensor solution (stub).\"\"\"\n",
    "        ...\n",
    "\n",
    "    def get_cvxpylayer(self) -> CvxpyLayer:\n",
    "        \"\"\"Wrap `self.prob` in a CvxpyLayer over every parameter and variable.\"\"\"\n",
    "        layer_params = [*self.params.values()]\n",
    "        layer_vars = [*self.vars.values()]\n",
    "        return CvxpyLayer(self.prob, parameters=layer_params, variables=layer_vars)\n",
    "\n",
    "class EllipsoidProblemProtocol(BaseProblemProtocol, Protocol):\n",
    "    \"\"\"Problem protocol whose `solve` takes the ellipsoid parameters `loc` and `scale_tril`.\"\"\"\n",
    "    def solve(self, loc: np.ndarray, scale_tril: np.ndarray) -> cp.Problem:\n",
    "        ...\n",
    "\n",
    "class EllipsoidProblem:\n",
    "    \"\"\"Worst-case terms of a linear cost over an ellipsoidal uncertainty set.\n",
    "\n",
    "    The set is {loc + scale_tril @ u : ||u||_2 <= 1}, i.e. `scale_tril` is a\n",
    "    factor of the shape matrix scale_tril @ scale_tril.T (the same convention\n",
    "    the notebook uses when it scores points with `scale_tril`). The supremum of\n",
    "    y @ Fz over that set is loc @ Fz + ||scale_tril.T @ Fz||_2, which is what\n",
    "    `dual_obj` holds.\n",
    "    \"\"\"\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(\n",
    "        self, y_dim: int, y_mean: np.ndarray, y_std: np.ndarray, Fz: cp.Expression\n",
    "    ):\n",
    "        \"\"\"Create the `loc` / `scale_tril` parameters and the worst-case objective.\n",
    "\n",
    "        Args:\n",
    "            y_dim: dimension of the uncertain vector.\n",
    "            y_mean, y_std: stored on the instance; no rescaling is applied here.\n",
    "            Fz: cvxpy expression multiplying the uncertain vector in the cost.\n",
    "        \"\"\"\n",
    "        self.y_mean = y_mean\n",
    "        self.y_std = y_std\n",
    "\n",
    "        loc = cp.Parameter(y_dim, name='loc')\n",
    "        scale_tril = cp.Parameter((y_dim, y_dim), name='scale_tril')\n",
    "        self.params = {\n",
    "            'loc': loc,\n",
    "            'scale_tril': scale_tril,\n",
    "        }\n",
    "\n",
    "        Fz_ystd = Fz\n",
    "\n",
    "        # The norm must act on scale_tril.T @ Fz: for y = loc + scale_tril @ u,\n",
    "        # y @ Fz = loc @ Fz + u @ (scale_tril.T @ Fz), maximised over ||u|| <= 1.\n",
    "        self.dual_obj = (\n",
    "            cp.norm(scale_tril.T @ Fz_ystd)\n",
    "            + loc @ Fz_ystd)\n",
    "\n",
    "        self.dual_vars = {}\n",
    "        self.dual_constraints = []\n",
    "\n",
    "    def solve(self, loc: np.ndarray, scale_tril: np.ndarray) -> cp.Problem:\n",
    "        \"\"\"Set the parameter values, solve `self.prob` and return it.\n",
    "\n",
    "        A non-optimal status is printed rather than raised.\n",
    "        \"\"\"\n",
    "        self.params['loc'].value = loc\n",
    "        self.params['scale_tril'].value = scale_tril\n",
    "        prob = self.prob\n",
    "        prob.solve()\n",
    "        if prob.status != 'optimal':\n",
    "            print('Problem status:', prob.status)\n",
    "        return prob\n",
    "\n",
    "class Constants(NamedTuple):\n",
    "    \"\"\"Scalar settings of the storage problem, unpacked positionally by StorageProblemBase.\"\"\"\n",
    "    lam: float    # weight on ||z_state - B/2||^2\n",
    "    eps: float    # weight on ||z_in||^2 and ||z_out||^2\n",
    "    eff: float    # multiplies z_in in the state update\n",
    "    c_in: float   # upper bound on z_in\n",
    "    c_out: float  # upper bound on z_out\n",
    "    B: float      # upper bound on z_state; z_state[0] is fixed to B/2\n",
    "\n",
    "\n",
    "DEFAULT_CONSTANTS = Constants(\n",
    "    lam=0.1,\n",
    "    eps=0.05,\n",
    "    eff=0.9,\n",
    "    c_in=0.5,\n",
    "    c_out=0.2,\n",
    "    B=1\n",
    ")\n",
    "\n",
    "\n",
    "class StorageProblemBase:\n",
    "    \"\"\"Primal part of the storage problem: variables, constraints and regulariser.\"\"\"\n",
    "\n",
    "    const: Constants\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    prob: cp.Problem\n",
    "    y_mean: np.ndarray\n",
    "    y_std: np.ndarray\n",
    "\n",
    "    def __init__(self, T: int, const: Constants = DEFAULT_CONSTANTS):\n",
    "        \"\"\"Declare z_in, z_out (length T) and z_state (length T + 1) with their constraints.\"\"\"\n",
    "        self.const = const\n",
    "        lam, eps, eff, c_in, c_out, B = self.const\n",
    "\n",
    "        z_state = cp.Variable(T + 1, name='z_state', nonneg=True)\n",
    "        z_in = cp.Variable(T, name='z_in', nonneg=True)\n",
    "        z_out = cp.Variable(T, name='z_out', nonneg=True)\n",
    "        self.primal_vars = {\n",
    "            'z_in': z_in,\n",
    "            'z_out': z_out,\n",
    "            'z_state': z_state,\n",
    "        }\n",
    "\n",
    "        self.constraints = [\n",
    "            z_state[0] == B/2,\n",
    "            z_state[1:] == z_state[:-1] + eff * z_in - z_out,\n",
    "            z_state <= B,\n",
    "\n",
    "            z_in <= c_in,\n",
    "            z_out <= c_out,\n",
    "        ]\n",
    "\n",
    "        # Coefficient of the uncertain vector in the cost.\n",
    "        self.Fz = z_in - z_out\n",
    "\n",
    "        self.f_tilde = (\n",
    "            lam * cp.norm2(z_state - B/2)**2\n",
    "            + eps * cp.norm2(z_in)**2\n",
    "            + eps * cp.norm2(z_out)**2\n",
    "        )\n",
    "\n",
    "    def task_loss_np(self, scale_tril: np.ndarray, mu: np.ndarray) -> float:\n",
    "        \"\"\"Worst-case objective at the current solution.\n",
    "\n",
    "        Evaluates ||scale_tril.T @ Fz|| + mu @ Fz plus the regulariser, using the\n",
    "        values stored in the solved problem's variables.\n",
    "        \"\"\"\n",
    "        assert self.prob.value is not None, 'Problem must be solved first'\n",
    "        lam = self.const.lam\n",
    "        eps = self.const.eps\n",
    "        B = self.const.B\n",
    "\n",
    "        z_in = self.primal_vars['z_in'].value\n",
    "        z_out = self.primal_vars['z_out'].value\n",
    "        z_state = self.primal_vars['z_state'].value\n",
    "\n",
    "        # Transposed factor, matching dual_obj.\n",
    "        task_loss = (\n",
    "            np.linalg.norm(scale_tril.T @ (z_in - z_out), 2) + np.dot(mu, (z_in - z_out))\n",
    "            + lam * np.linalg.norm(z_state - B/2)**2\n",
    "            + eps * np.linalg.norm(z_in)**2\n",
    "            + eps * np.linalg.norm(z_out)**2\n",
    "        )\n",
    "        return task_loss\n",
    "\n",
    "    def task_loss_np_1(self, y: np.ndarray) -> float:\n",
    "        \"\"\"Objective at the current solution with the linear cost evaluated at `y`.\"\"\"\n",
    "        assert self.prob.value is not None, 'Problem must be solved first'\n",
    "        lam = self.const.lam\n",
    "        eps = self.const.eps\n",
    "        B = self.const.B\n",
    "\n",
    "        z_in = self.primal_vars['z_in'].value\n",
    "        z_out = self.primal_vars['z_out'].value\n",
    "        z_state = self.primal_vars['z_state'].value\n",
    "        \n",
    "        task_loss = (\n",
    "            y @ (z_in - z_out)\n",
    "            + lam * np.linalg.norm(z_state - B/2)**2\n",
    "            + eps * np.linalg.norm(z_in)**2\n",
    "            + eps * np.linalg.norm(z_out)**2\n",
    "        )\n",
    "        return task_loss\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, scale_tril: Tensor, mu: Tensor, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        \"\"\"Batched worst-case objective for a tensor solution.\n",
    "\n",
    "        Args:\n",
    "            scale_tril: batch of factors, one square matrix per sample.\n",
    "            mu: batch of centers; the result has shape mu.shape[:-1].\n",
    "            solution: sequence whose first three entries are z_in, z_out, z_state.\n",
    "        \"\"\"\n",
    "        z_in, z_out, z_state = solution[:3]\n",
    "\n",
    "        lam = self.const.lam\n",
    "        eps = self.const.eps\n",
    "        B = self.const.B\n",
    "\n",
    "        Fz = z_in - z_out\n",
    "        # Transposed factor, matching dual_obj.\n",
    "        worst_case_norm = torch.norm(\n",
    "            (scale_tril.transpose(-1, -2) @ Fz.unsqueeze(-1)).squeeze(-1), dim=-1\n",
    "        )\n",
    "\n",
    "        task_loss = (\n",
    "            worst_case_norm + (mu * Fz).sum(dim=-1)\n",
    "            + lam * torch.norm(z_state - B/2, dim=-1)**2\n",
    "            + eps * torch.norm(z_in, dim=-1)**2\n",
    "            + eps * torch.norm(z_out, dim=-1)**2\n",
    "        )\n",
    "        assert task_loss.shape == mu.shape[:-1]\n",
    "        return task_loss\n",
    "\n",
    "class StorageProblemEllipsoid(StorageProblemBase, EllipsoidProblem, EllipsoidProblemProtocol):\n",
    "    \"\"\"Storage problem with ellipsoidal uncertainty, built by running its three bases in turn.\"\"\"\n",
    "\n",
    "    def __init__(self, T: int, y_mean: np.ndarray, y_std: np.ndarray):\n",
    "        # Order matters: the primal part defines self.Fz, the ellipsoid part\n",
    "        # consumes it, and the protocol initialiser combines both into self.prob.\n",
    "        StorageProblemBase.__init__(self, T)\n",
    "        EllipsoidProblem.__init__(self, T, y_mean, y_std, self.Fz)\n",
    "        EllipsoidProblemProtocol.__init__(self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c932ba85-85b0-4d7c-92e9-e50d023aa287",
   "metadata": {
    "papermill": {
     "duration": 0.388255,
     "end_time": "2025-09-22T14:26:14.740921",
     "exception": false,
     "start_time": "2025-09-22T14:26:14.352666",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def mahalanobis_dist2(loc: Tensor, scale_tril: Tensor, y: Tensor) -> Tensor:\n",
    "    return tdist.multivariate_normal._batch_mahalanobis(bL=scale_tril, bx=y - loc)\n",
    "\n",
    "def calc_q(scores: Tensor, alpha: float) -> Tensor:\n",
    "    n = len(scores)\n",
    "    j = int(np.ceil((n+1) * (1-alpha)))\n",
    "    if j > n:\n",
    "        return torch.tensor(torch.inf)\n",
    "    sorted_inds = torch.argsort(scores)\n",
    "    q = scores[sorted_inds[j-1]]\n",
    "    return q\n",
    "\n",
    "def conformal_q(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    alpha: float,\n",
    "    device: str = 'cpu'\n",
    ") -> Tensor:\n",
    "    \"\"\"Conformal threshold over every sample yielded by `loader`.\n",
    "\n",
    "    Scores each batch with `mahalanobis_dist2` using the model's first two\n",
    "    outputs, concatenates the scores and passes them to `calc_q`.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `calc_q` returns inf.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "\n",
    "    def score_batch(x, y):\n",
    "        x = x.to(device, non_blocking=True)\n",
    "        y = y.to(device, non_blocking=True)\n",
    "        loc, scale_tril, _ = model(x)\n",
    "        return mahalanobis_dist2(loc, scale_tril, y)\n",
    "\n",
    "    scores = torch.cat([score_batch(x, y) for x, y in loader])\n",
    "    q = calc_q(scores, alpha)\n",
    "    assert q != torch.inf, 'Size of calibration set is too small'\n",
    "    return q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43924fa-4568-449d-b0ae-1210f3972f2c",
   "metadata": {
    "papermill": {
     "duration": 0.228579,
     "end_time": "2025-09-22T14:26:15.421324",
     "exception": false,
     "start_time": "2025-09-22T14:26:15.192745",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_e2e_epoch(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    alpha: float,\n",
    "    nll_loss_frac: float,\n",
    "    rng: np.random.Generator,\n",
    "    optimizer: torch.optim.Optimizer,\n",
    "    show_pbar: bool = False\n",
    ") -> tuple[float, float, float]:\n",
    "    \"\"\"Run one epoch of end-to-end training.\n",
    "\n",
    "    Each batch contributes up to two terms, mixed by `nll_loss_frac`: a Gaussian\n",
    "    negative log-likelihood and a task loss. For the task loss the batch is\n",
    "    split at random into two halves: the first calibrates the conformal\n",
    "    threshold q, the second is solved through the cvxpy layer with its\n",
    "    scale_tril multiplied by sqrt(q).\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, avg_nll_loss, avg_task_loss). The first two are averaged over\n",
    "        len(loader.dataset); the task loss over the samples that were actually\n",
    "        solved, and is NaN when there were none.\n",
    "    \"\"\"\n",
    "    assert 0 < alpha < 1\n",
    "    model.train()\n",
    "\n",
    "    cvxpylayer = prob.get_cvxpylayer()\n",
    "\n",
    "    total_loss = 0.\n",
    "    total_nll_loss = 0.\n",
    "    total_task_loss = 0.\n",
    "    total_num_tasks = 0\n",
    "    for x, y in tqdm(loader) if show_pbar else loader:\n",
    "        batch_size = x.shape[0]\n",
    "        loc, scale_tril, _ = model(x)\n",
    "\n",
    "        loss = torch.tensor(0.)\n",
    "        msgs = []\n",
    "\n",
    "        if nll_loss_frac > 0:\n",
    "            nll_loss = -tdist.MultivariateNormal(loc, scale_tril=scale_tril).log_prob(y).mean()\n",
    "            loss += nll_loss_frac * nll_loss\n",
    "            total_nll_loss += nll_loss.item() * batch_size\n",
    "            msgs.append(f'nll_loss: {nll_loss.item()}')\n",
    "\n",
    "        if nll_loss_frac < 1:\n",
    "            perm = rng.permutation(batch_size)\n",
    "\n",
    "            # First half of the shuffled batch calibrates q.\n",
    "            cal_inds = perm[:batch_size//2]\n",
    "            scores_cal = mahalanobis_dist2(loc[cal_inds], scale_tril[cal_inds], y[cal_inds])\n",
    "            q = calc_q(scores_cal, alpha)\n",
    "            if q == torch.inf:\n",
    "                tqdm.write('Batch is too small, skipping')\n",
    "                continue\n",
    "\n",
    "            # Second half is optimized over the q-scaled ellipsoid.\n",
    "            task_inds = perm[batch_size//2:]\n",
    "            loc_task = loc[task_inds]\n",
    "            scale_tril_task = scale_tril[task_inds] * torch.sqrt(q)\n",
    "\n",
    "            # Solver failures propagate to the caller unchanged.\n",
    "            solution = cvxpylayer(loc_task, scale_tril_task)\n",
    "\n",
    "            task_loss = prob.task_loss_torch(scale_tril_task, loc_task, solution=solution).mean()\n",
    "            loss += (1 - nll_loss_frac) * task_loss\n",
    "\n",
    "            total_task_loss += task_loss.item() * task_inds.shape[0]\n",
    "            total_num_tasks += task_inds.shape[0]\n",
    "\n",
    "            msgs.append(f'task_loss: {task_loss.item()}')\n",
    "\n",
    "        if show_pbar:\n",
    "            tqdm.write(','.join(msgs))\n",
    "        total_loss += loss.item() * batch_size\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    avg_loss = total_loss / len(loader.dataset)\n",
    "    avg_nll_loss = total_nll_loss / len(loader.dataset)\n",
    "    # No task sample is solved when nll_loss_frac == 1 or every batch was skipped.\n",
    "    avg_task_loss = total_task_loss / total_num_tasks if total_num_tasks else float('nan')\n",
    "    return avg_loss, avg_nll_loss, avg_task_loss\n",
    "\n",
    "\n",
    "def train_e2e(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    alpha: float,\n",
    "    max_epochs: int,\n",
    "    lr: float,\n",
    "    l2reg: float,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    rng: np.random.Generator,\n",
    "    nll_loss_frac: float | Sequence[float],\n",
    "    saved_model_path: str = ''\n",
    ") -> dict[str, Any]:\n",
    "    \"\"\"Train end-to-end with early stopping on the calibration-split task loss.\n",
    "\n",
    "    Reads `train_loader` and `cal_loader` from the notebook's global scope.\n",
    "    After every epoch a conformal q is computed on `cal_loader` and the task\n",
    "    loss under that q is used as the validation metric. Training stops once the\n",
    "    metric has not improved for more than 10 epochs, and the weights of the best\n",
    "    epoch are restored into `model` before returning.\n",
    "\n",
    "    Args:\n",
    "        nll_loss_frac: a single weight, or one weight per epoch (the last entry\n",
    "            is reused when there are more epochs than entries).\n",
    "        saved_model_path: optional state-dict file loaded into `model` first.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the per-epoch loss histories, 'best_epoch' and 'val_task_loss'.\n",
    "    \"\"\"\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2reg)\n",
    "    if saved_model_path != '':\n",
    "        tqdm.write(f'Loading saved model: {saved_model_path}')\n",
    "        model.load_state_dict(torch.load(saved_model_path, weights_only=True))\n",
    "\n",
    "    result: dict[str, Any] = {\n",
    "        'train_e2e_losses': [],\n",
    "        'train_nll_losses': [],\n",
    "        'train_task_losses': [],\n",
    "        'val_task_losses': [],\n",
    "        'best_epoch': 0,\n",
    "        'val_task_loss': np.inf,  \n",
    "    }\n",
    "    steps_since_decrease = 0\n",
    "    buffer = io.BytesIO()\n",
    "    has_checkpoint = False\n",
    "\n",
    "    pbar = tqdm(range(max_epochs))\n",
    "    for epoch in pbar:\n",
    "        if isinstance(nll_loss_frac, Sequence):\n",
    "            weight = nll_loss_frac[min(epoch, len(nll_loss_frac) - 1)]\n",
    "        else:\n",
    "            weight = nll_loss_frac\n",
    "\n",
    "        train_e2e_loss, train_nll_loss, train_task_loss = train_e2e_epoch(\n",
    "            model, prob=prob, loader=train_loader, alpha=alpha,\n",
    "            nll_loss_frac=weight, rng=rng, optimizer=optimizer)\n",
    "        result['train_e2e_losses'].append(train_e2e_loss)\n",
    "        result['train_nll_losses'].append(train_nll_loss)\n",
    "        result['train_task_losses'].append(train_task_loss)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            model.eval()\n",
    "            q = conformal_q(model, cal_loader, alpha=alpha).item()\n",
    "            val_task_loss = optimize(model, prob=prob, loader=cal_loader, q=q)\n",
    "            result['val_task_losses'].append(val_task_loss)\n",
    "\n",
    "        msg = (f'Epoch {epoch}, train_task_loss {train_task_loss:.3f}, '\n",
    "               f'val_task_loss {val_task_loss:.3f}, q {q:.3f}')\n",
    "        pbar.set_description(msg)\n",
    "\n",
    "        steps_since_decrease += 1\n",
    "\n",
    "        if val_task_loss < result['val_task_loss']:\n",
    "            result['best_epoch'] = epoch\n",
    "            result['val_task_loss'] = val_task_loss\n",
    "            steps_since_decrease = 0\n",
    "            # Drop the previous checkpoint completely so no stale bytes can\n",
    "            # trail the new one.\n",
    "            buffer.seek(0)\n",
    "            buffer.truncate()\n",
    "            torch.save(model.state_dict(), buffer)\n",
    "            has_checkpoint = True\n",
    "\n",
    "        if steps_since_decrease > 10:\n",
    "            break\n",
    "\n",
    "    # Without any improvement (e.g. NaN validation losses) the buffer is empty\n",
    "    # and the model keeps its current weights.\n",
    "    if has_checkpoint:\n",
    "        buffer.seek(0)\n",
    "        model.load_state_dict(torch.load(buffer, weights_only=True))\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6b2779-1707-450b-915c-02d854be30bf",
   "metadata": {
    "papermill": {
     "duration": 0.044407,
     "end_time": "2025-09-22T14:26:16.070383",
     "exception": false,
     "start_time": "2025-09-22T14:26:16.025976",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def optimize(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    q: float,\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    ") -> float:\n",
    "    \"\"\"Mean worst-case task loss over `loader` for a fixed conformal threshold.\n",
    "\n",
    "    For every sample the predicted scale_tril is multiplied by sqrt(q), the\n",
    "    problem is solved with cvxpy and `prob.task_loss_np` is recorded. The\n",
    "    targets in `loader` are not used.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the recorded task losses as a Python float.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model = model.to(device)\n",
    "\n",
    "    task_losses = []\n",
    "    pbar = tqdm(total=len(loader.dataset)) if show_pbar else None\n",
    "    for x_batch, _ in loader:\n",
    "        x_batch = x_batch.to(device, non_blocking=True)\n",
    "        # Inference only: no autograd graph is needed for the numpy solve.\n",
    "        with torch.no_grad():\n",
    "            pred = model(x_batch)\n",
    "        loc_batch = pred[0].detach().cpu().numpy()\n",
    "        scale_tril_batch = pred[1].detach().cpu().numpy()\n",
    "\n",
    "        scale_tril_batch *= np.sqrt(q)\n",
    "\n",
    "        for loc, scale_tril in zip(loc_batch, scale_tril_batch):\n",
    "            prob.solve(loc, scale_tril)\n",
    "            task_losses.append(prob.task_loss_np(scale_tril, loc))\n",
    "            if pbar is not None:\n",
    "                pbar.update(1)\n",
    "\n",
    "    if pbar is not None:\n",
    "        pbar.close()\n",
    "\n",
    "    return np.mean(task_losses).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d57ea1-e1ce-439a-a0ad-6fe321d49b9d",
   "metadata": {
    "papermill": {
     "duration": 0.099702,
     "end_time": "2025-09-22T14:26:16.405200",
     "exception": false,
     "start_time": "2025-09-22T14:26:16.305498",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import copy, math, itertools, torch, numpy as np\n",
    "\n",
    "alpha = 0.1\n",
    "max_epochs = 50\n",
    "nll_loss_frac = 0.0\n",
    "lr_list    = [1e-4]\n",
    "l2_list    = [1e-3, 1e-4]\n",
    "\n",
    "best = {\n",
    "    \"val_task_loss\": float(\"inf\"),\n",
    "    \"lr\": None,\n",
    "    \"l2reg\": None,\n",
    "    \"epoch\": None,\n",
    "    \"state_dict\": None,\n",
    "}\n",
    "all_results = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d83d21b-ba57-42d6-926e-87776d1b99f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "for lr in lr_list:\n",
    "    for l2 in l2_list:\n",
    "        # Every configuration starts from the same pretrained baseline weights.\n",
    "        model = EllipsoidalUncertaintyModel(input_dim=101, hidden_dims = [20, 20], n=24, enable_implicit=False)  \n",
    "        model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "        model.load_state_dict(torch.load(model_path, map_location='cpu'))\n",
    "        prob = StorageProblemEllipsoid(T=24, y_mean=y_mean, y_std=y_std)\n",
    "\n",
    "        # Re-seeded per configuration so each one sees the same batch splits.\n",
    "        rng = np.random.default_rng(42)\n",
    "        print(f\"lr: {lr } | l2={l2}\")\n",
    "        \n",
    "        res = train_e2e(\n",
    "            model=model,\n",
    "            alpha=alpha,\n",
    "            max_epochs=max_epochs,\n",
    "            lr=lr,\n",
    "            l2reg=l2,\n",
    "            prob=prob,\n",
    "            rng=rng,\n",
    "            nll_loss_frac=nll_loss_frac,\n",
    "            saved_model_path=''\n",
    "        )\n",
    "\n",
    "        record = {\n",
    "            \"lr\": lr,\n",
    "            \"l2reg\": l2,\n",
    "            \"best_epoch\": res[\"best_epoch\"],\n",
    "            \"val_task_loss\": res[\"val_task_loss\"],\n",
    "        }\n",
    "        all_results.append(record)\n",
    "\n",
    "        if res[\"val_task_loss\"] < best[\"val_task_loss\"] :\n",
    "            best.update(record)\n",
    "            best[\"state_dict\"] = copy.deepcopy(model.state_dict())\n",
    "\n",
    "print(\"\\n=== Grid Search Summary ===\")\n",
    "all_results = sorted(all_results, key=lambda r: r[\"val_task_loss\"])\n",
    "for r in all_results:\n",
    "    print(f\"lr={r['lr']:g}, l2={r['l2reg']:g}, \"\n",
    "          f\"best_epoch={r['best_epoch']}, val_task_loss={r['val_task_loss']:.6f}\")\n",
    "\n",
    "# Fail with a clear message instead of a formatting TypeError on None below.\n",
    "if best[\"state_dict\"] is None:\n",
    "    raise RuntimeError(\n",
    "        \"No configuration improved on the initial val_task_loss of inf; nothing to save.\"\n",
    "    )\n",
    "\n",
    "print(\"\\n>>> Best config:\")\n",
    "print(f\"lr={best['lr']:g}, l2={best['l2reg']:g}, \"\n",
    "      f\"best_epoch={best['best_epoch']}, val_task_loss={best['val_task_loss']:.6f}\")\n",
    "\n",
    "best_model = EllipsoidalUncertaintyModel(input_dim=101, hidden_dims = [20, 20], n=24, enable_implicit=False) \n",
    "best_model.load_state_dict(best[\"state_dict\"])\n",
    "\n",
    "os.makedirs('outputs/model/E2E', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'E2E', 'E2E_model.pth')\n",
    "torch.save(best_model.state_dict(), save_path)\n",
    "# Report the path that was actually written.\n",
    "print(f\"Saved best model to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b183982-1989-439a-b017-8e891595ce33",
   "metadata": {
    "papermill": {
     "duration": 0.044902,
     "end_time": "2025-09-22T15:08:10.734226",
     "exception": false,
     "start_time": "2025-09-22T15:08:10.689324",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def optimize_test(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    q: float,\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    ") -> tuple[float, float, float, float, list[np.ndarray]]:\n",
    "    \"\"\"Evaluate the model and the robust decisions on `loader` for a fixed q.\n",
    "\n",
    "    For every sample the problem is solved on the sqrt(q)-scaled ellipsoid and\n",
    "    four quantities are tracked: the realised loss at the true target, the\n",
    "    worst-case loss, whether the realised loss stays below the worst case, and\n",
    "    whether the target's squared Mahalanobis distance is at most q.\n",
    "\n",
    "    Returns:\n",
    "        (coverage, avg_risk, avg_loss, robustness, z_value): coverage and\n",
    "        robustness are percentages, avg_risk is the mean worst-case loss,\n",
    "        avg_loss the mean realised loss, and z_value the flattened value of\n",
    "        prob.Fz for each sample.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model = model.to(device)\n",
    "\n",
    "    task_losses = []\n",
    "    expr_losses = []\n",
    "    z_value = []       \n",
    "    count_smaller = 0  \n",
    "    total_points = 0  \n",
    "    count_within_q = 0\n",
    "    \n",
    "    pbar = tqdm(total=len(loader.dataset)) if show_pbar else None\n",
    "    for x_batch, y_batch in loader:\n",
    "        x_batch = x_batch.to(device, non_blocking=True)\n",
    "        # Inference only: no autograd graph is needed for the numpy solve.\n",
    "        with torch.no_grad():\n",
    "            pred = model(x_batch)\n",
    "        loc_batch = pred[0].detach().cpu().numpy()\n",
    "        scale_tril_batch = pred[1].detach().cpu().numpy()\n",
    "        y_batch = y_batch.detach().cpu().numpy()\n",
    "\n",
    "        scale_tril_batch = np.sqrt(q) * scale_tril_batch\n",
    "\n",
    "        for y, loc, scale_tril in zip(y_batch, loc_batch, scale_tril_batch):\n",
    "            prob.solve(loc, scale_tril)\n",
    "            task_loss = prob.task_loss_np_1(y)\n",
    "            task_losses.append(task_loss)\n",
    "            \n",
    "            z_value.append(np.asarray(prob.Fz.value).reshape(-1))\n",
    "           \n",
    "            expr_loss = prob.task_loss_np(scale_tril, loc)\n",
    "            expr_losses.append(expr_loss)\n",
    "        \n",
    "            if task_loss <= expr_loss:\n",
    "                count_smaller += 1\n",
    "\n",
    "            # Undo the sqrt(q) scaling to score the target under the model's own covariance.\n",
    "            scale_tril_1 = scale_tril / np.sqrt(q)\n",
    "            Sigma_inv = np.linalg.inv(scale_tril_1 @ scale_tril_1.T) \n",
    "            diff = y - loc\n",
    "            mahalanobis_sq = diff.T @ Sigma_inv @ diff\n",
    "\n",
    "            if mahalanobis_sq <= q:\n",
    "                count_within_q += 1\n",
    "\n",
    "            total_points += 1\n",
    "\n",
    "            if pbar is not None:\n",
    "                pbar.update(1)\n",
    "\n",
    "    if pbar is not None:\n",
    "        pbar.close()\n",
    "                \n",
    "    avg_loss = np.mean(task_losses)\n",
    "    avg_risk = np.mean(expr_losses)\n",
    "    robustness = count_smaller / total_points * 100\n",
    "    coverage = count_within_q / total_points * 100\n",
    "\n",
    "    return coverage, avg_risk, avg_loss, robustness, z_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac47837-086d-49f7-b370-cc4fb586c70b",
   "metadata": {
    "papermill": {
     "duration": 4.32185,
     "end_time": "2025-09-22T15:08:15.094065",
     "exception": false,
     "start_time": "2025-09-22T15:08:10.772215",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_path = os.path.join('outputs', 'model', 'E2E', 'E2E_model.pth')\n",
    "model.load_state_dict(torch.load(model_path, weights_only=True))\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    q = conformal_q(model, cal_loader, alpha=alpha).item()\n",
    "prob = StorageProblemEllipsoid(T=24, y_mean=y_mean, y_std=y_std)\n",
    "coverage, avg_risk, avg_loss, robustness_3, z_value = optimize_test(model, prob, test_loader, q=q)\n",
    "\n",
    "print(f\"Coverage: {coverage:.2f}%\")\n",
    "print(f\"Ave_task_loss: {avg_risk:.4f}\")\n",
    "print(f\"Ave_loss: {avg_loss:.4f}\")\n",
    "print(f\"Robustness: {robustness_3:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dae6393-43b6-4df2-81fc-f185f1b65ae4",
   "metadata": {
    "papermill": {
     "duration": 0.042207,
     "end_time": "2025-09-22T15:08:15.176761",
     "exception": false,
     "start_time": "2025-09-22T15:08:15.134554",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "data_dir = 'outputs/results/data_E2E'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "txt_path = os.path.join(data_dir, f'metrics_E2E.txt')\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.write(f\"Marginal Coverage: {coverage:.2f}%\\n\")\n",
    "    f.write(f\"Average Risk     : {avg_risk:.4f}\\n\")\n",
    "    f.write(f\"Average Loss     : {avg_loss:.4f}\\n\")\n",
    "    f.write(f\"Robustness       : {robustness_3:.2f}%\\n\")\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6452e5-6bb3-44f5-9a0e-a1b23ff8a1b6",
   "metadata": {
    "papermill": {
     "duration": 0.038462,
     "end_time": "2025-09-22T15:08:15.983920",
     "exception": false,
     "start_time": "2025-09-22T15:08:15.945458",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 3306.165671,
   "end_time": "2025-09-22T15:08:18.908576",
   "environment_variables": {},
   "exception": null,
   "input_path": "CRC_GPU_batter.ipynb",
   "output_path": "outputs/file/CRC_GPU_stock0.ipynb",
   "parameters": {
    "run_id": 0
   },
   "start_time": "2025-09-22T14:13:12.742905",
   "version": "2.6.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "0281c6f3298646a0b4411e58c34d3287": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "192a8109d4744b3d9a34cfa9af671661": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "3eee15a71207484783203893ce3ddbb9": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "46c1b3fb272b49aca09ad850e745a8ee": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_0281c6f3298646a0b4411e58c34d3287",
       "placeholder": "​",
       "style": "IPY_MODEL_e3a440e628714b81a890b289bb00327e",
       "tabbable": null,
       "tooltip": null,
       "value": " 11/50 [19:52&lt;1:01:50, 95.15s/it]"
      }
     },
     "48c7a411e8ff40afb4a656bba2e6a907": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "50e048dc6a7f4e1883764cfcc0f47e1e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "514db8568c0849dc8728a60bb9fa45af": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "566c70f680ac41de8be14fe0195279da": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "6270fe14dc11489b80db8dd92fc59345": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_50e048dc6a7f4e1883764cfcc0f47e1e",
       "placeholder": "​",
       "style": "IPY_MODEL_566c70f680ac41de8be14fe0195279da",
       "tabbable": null,
       "tooltip": null,
       "value": " 11/50 [22:01&lt;1:12:32, 111.59s/it]"
      }
     },
     "733d28e25e734d2197527b0fcbc09348": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_9fa3933b7e5b42c793018b0bb094cc3e",
        "IPY_MODEL_c404e641b15b4a6dad2c0ce6c1cd9fe5",
        "IPY_MODEL_46c1b3fb272b49aca09ad850e745a8ee"
       ],
       "layout": "IPY_MODEL_f7d343660f5c4942abbff21ff962f292",
       "tabbable": null,
       "tooltip": null
      }
     },
     "805cb91299bd4107af0181b94cf7fd96": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "danger",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_48c7a411e8ff40afb4a656bba2e6a907",
       "max": 50,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_c3217893cc7b4383b544bdc4d4ff0528",
       "tabbable": null,
       "tooltip": null,
       "value": 11
      }
     },
     "886afaac09804ff99813c899ab0d770b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "8990062cefab42b48d14091eb25644af": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_9817eeceb9e3484583eb93bb9addf36d",
        "IPY_MODEL_805cb91299bd4107af0181b94cf7fd96",
        "IPY_MODEL_6270fe14dc11489b80db8dd92fc59345"
       ],
       "layout": "IPY_MODEL_b0e37f5f903e47b7a6efcdf77671ac80",
       "tabbable": null,
       "tooltip": null
      }
     },
     "9817eeceb9e3484583eb93bb9addf36d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_192a8109d4744b3d9a34cfa9af671661",
       "placeholder": "​",
       "style": "IPY_MODEL_e175207d970145a8a4387a13d1c5e61b",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 11, train_task_loss -0.170, val_task_loss -0.727, q 72.258:  22%"
      }
     },
     "9fa3933b7e5b42c793018b0bb094cc3e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_514db8568c0849dc8728a60bb9fa45af",
       "placeholder": "​",
       "style": "IPY_MODEL_3eee15a71207484783203893ce3ddbb9",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 11, train_task_loss -0.124, val_task_loss -0.743, q 70.125:  22%"
      }
     },
     "b0e37f5f903e47b7a6efcdf77671ac80": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "c3217893cc7b4383b544bdc4d4ff0528": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "c404e641b15b4a6dad2c0ce6c1cd9fe5": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "danger",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_886afaac09804ff99813c899ab0d770b",
       "max": 50,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_dce8653f9bb94858b5e1a725b153dcbd",
       "tabbable": null,
       "tooltip": null,
       "value": 11
      }
     },
     "dce8653f9bb94858b5e1a725b153dcbd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "e175207d970145a8a4387a13d1c5e61b": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "e3a440e628714b81a890b289bb00327e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "f7d343660f5c4942abbff21ff962f292": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
