{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4bcfc9fd-c1e1-4d83-9fd2-0949976723b7",
   "metadata": {
    "papermill": {
     "duration": 0.007914,
     "end_time": "2025-09-22T07:30:10.704635",
     "exception": false,
     "start_time": "2025-09-22T07:30:10.696721",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Conformal Robustness Control"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9070f7a-1dc8-459e-8177-54c7a006115d",
   "metadata": {},
   "source": [
    "### CRC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "027da74f-f82a-4655-bfbf-94fcbaa6fb97",
   "metadata": {
    "papermill": {
     "duration": 2.636018,
     "end_time": "2025-09-22T07:30:13.369523",
     "exception": false,
     "start_time": "2025-09-22T07:30:10.733505",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import os\n",
    "import io\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import pytz\n",
    "import math\n",
    "import cvxpy as cp\n",
    "import rsome as rso\n",
    "\n",
    "from rsome import ro\n",
    "from rsome import msk_solver as SOLVER\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.distributions import MultivariateNormal\n",
    "from cvxpylayers.torch import CvxpyLayer\n",
    "from typing import Protocol\n",
    "from typing import NamedTuple\n",
    "from torch import distributions as tdist, nn, Tensor\n",
    "from collections.abc import Mapping\n",
    "from torch import Tensor\n",
    "from collections.abc import Sequence\n",
    "from torch.utils.data import TensorDataset\n",
    "from collections.abc import Mapping, Sequence\n",
    "from pandas.tseries.holiday import USFederalHolidayCalendar\n",
    "from datetime import datetime, timedelta\n",
    "from tqdm.auto import tqdm\n",
    "from typing import Any"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f8ca6a-17af-49fc-b3ea-7d20e26605dc",
   "metadata": {
    "papermill": {
     "duration": 0.013446,
     "end_time": "2025-09-22T07:30:13.393975",
     "exception": false,
     "start_time": "2025-09-22T07:30:13.380529",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "seed = random.randint(0, 99) \n",
    "print(\"Device:\", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eefd3914-9881-4079-8e71-0dcfe7b19ade",
   "metadata": {
    "papermill": {
     "duration": 0.021308,
     "end_time": "2025-09-22T07:30:13.461545",
     "exception": false,
     "start_time": "2025-09-22T07:30:13.440237",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_test_scaler(train, val, test):\n",
    "    scaler = StandardScaler()\n",
    "    normalized_x_train = pd.DataFrame(scaler.fit_transform(train))\n",
    "    normalized_x_val = pd.DataFrame(scaler.transform(val))\n",
    "    normalized_x_test = pd.DataFrame(scaler.transform(test))\n",
    "    return normalized_x_train, normalized_x_val, normalized_x_test\n",
    "\n",
    "\n",
    "def reshuffle(train, val, test, perm):\n",
    "    a, b, c = len(train), len(val), len(test)\n",
    "    assert len(perm) == a + b + c\n",
    "    combined = pd.concat([train, val, test], axis='index')\n",
    "    train = combined.iloc[perm[:a]]\n",
    "    val = combined.iloc[perm[a:a+b]]\n",
    "    test = combined.iloc[perm[a+b:a+b+c]]\n",
    "    return train, val, test\n",
    "\n",
    "def portfolio_data_gen(year, seed, shuffled):\n",
    "    year_directory = f'yfinance/2012_samples'\n",
    "    returns_file = f'2012_returns_0.txt'\n",
    "    side_file = f'2012_data_side_0.txt'\n",
    "\n",
    "    returns_file_path = os.path.join(year_directory, returns_file)\n",
    "    side_file_path = os.path.join(year_directory, side_file)\n",
    "\n",
    "    with open(returns_file_path, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "        returns_lists = [list( line.strip().split(',')) for line in lines]\n",
    "\n",
    "    with open(side_file_path, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "        side_info_lists = [list( line.strip().split(',')) for line in lines]\n",
    "\n",
    "    returns_cols = returns_lists[0]\n",
    "    print('returns_cols:',returns_cols)\n",
    "    side_info_cols = side_info_lists[0]\n",
    "    print('side_info_cols:',side_info_cols)\n",
    "\n",
    "    returns = pd.read_csv('yfinance/expected_return.csv')\n",
    "    data_side = pd.read_csv('yfinance/side_info.csv')\n",
    "\n",
    "    def create_train_val_test(df, year, col_list):\n",
    "\n",
    "        df_sub = df[col_list].copy()\n",
    "        df_sub['DATE'] = pd.to_datetime(df_sub['DATE'])\n",
    "        df_sub['year'] = df_sub['DATE'].dt.year\n",
    "        df_sub['year'] = df_sub['year'].astype(int)\n",
    "\n",
    "        start_year = int(year)\n",
    "        train_end_year = start_year + 1\n",
    "        val_start_year = train_end_year + 1\n",
    "        val_end_year = val_start_year + 4\n",
    "        test_start_year= val_end_year + 1\n",
    "\n",
    "        df_train = df_sub[(df_sub.year >= start_year) & (df_sub.year <= train_end_year)].copy()\n",
    "        df_val = df_sub[(df_sub.year >= val_start_year) & (df_sub.year <= val_end_year)].copy()\n",
    "        df_test = df_sub[df_sub.year >= test_start_year].copy()\n",
    "\n",
    "        df_train = df_train.drop(['DATE', 'year'], axis=1)\n",
    "        df_val = df_val.drop(['DATE', 'year'], axis=1)\n",
    "        df_test = df_test.drop(['DATE', 'year'], axis=1)\n",
    "        return df_train, df_val, df_test\n",
    "\n",
    "\n",
    "    returns_cols += ['DATE']\n",
    "    side_info_cols += ['DATE']\n",
    "    returns_train, returns_val, returns_test = create_train_val_test(returns, year, returns_cols)\n",
    "    side_train, side_val, side_test = create_train_val_test(data_side, year, side_info_cols)\n",
    "\n",
    "    if shuffled:\n",
    "        rng = np.random.default_rng(seed=seed)\n",
    "        perm = rng.permutation(len(returns_train) + len(returns_val) + len(returns_test))\n",
    "        returns_train, returns_val, returns_test = reshuffle(returns_train, returns_val, returns_test, perm)\n",
    "        side_train, side_val, side_test = reshuffle(side_train, side_val, side_test, perm)\n",
    "\n",
    "    returns_means = returns_train.mean()\n",
    "    returns_train.fillna(returns_means, inplace=True)\n",
    "    returns_val.fillna(returns_means, inplace=True)\n",
    "    returns_test.fillna(returns_means, inplace=True)\n",
    "    side_means = side_train.mean()\n",
    "    side_train.fillna(side_means, inplace=True)\n",
    "    side_val.fillna(side_means, inplace=True)\n",
    "    side_test.fillna(side_means, inplace=True)\n",
    "\n",
    "    returns_train_s = torch.from_numpy(100*returns_train.values)\n",
    "    returns_val_s   = torch.from_numpy(100*returns_val.values)\n",
    "    returns_test_s  = torch.from_numpy(100*returns_test.values)\n",
    "\n",
    "    side_train_s, side_val_s, side_test_s = train_test_scaler(side_train, side_val, side_test)\n",
    "    side_train_s, side_val_s, side_test_s = torch.from_numpy(side_train_s.values), torch.from_numpy(side_val_s.values), torch.from_numpy(side_test_s.values)\n",
    "\n",
    "    return side_train_s, side_val_s, side_test_s, returns_train_s, returns_val_s, returns_test_s\n",
    "\n",
    "def get_data_and_loaders(\n",
    "    batch_size: int, year: int, seed: int, shuffled: bool\n",
    ") -> tuple[\n",
    "    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,\n",
    "    dict[str, DataLoader], tuple[np.ndarray, np.ndarray]\n",
    "]:\n",
    "   \n",
    "    X_train, X_cal, X_test, y_train, y_cal, y_test = portfolio_data_gen(year, seed, shuffled)\n",
    "\n",
    "    X_train = X_train.to(torch.float32)\n",
    "    X_cal   = X_cal.to(torch.float32)\n",
    "    X_test  = X_test.to(torch.float32)\n",
    "    y_train = y_train.to(torch.float32)\n",
    "    y_cal   = y_cal.to(torch.float32)\n",
    "    y_test  = y_test.to(torch.float32)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        all_y = torch.cat([y_train, y_cal, y_test], dim=0)\n",
    "        y_mean = all_y.mean(dim=0)\n",
    "        y_std  = all_y.std(dim=0)\n",
    "        y_info = (y_mean.numpy(), y_std.numpy())\n",
    "\n",
    "        y_train_1 = (y_train - y_mean) / y_std\n",
    "        y_cal_1   = (y_cal   - y_mean) / y_std\n",
    "        y_test_1  = (y_test  - y_mean) / y_std\n",
    "\n",
    "    train_loader = DataLoader(TensorDataset(X_train, y_train), shuffle=True,  batch_size=batch_size)\n",
    "    calib_loader = DataLoader(TensorDataset(X_cal,   y_cal),   shuffle=False, batch_size=batch_size)\n",
    "    test_loader  = DataLoader(TensorDataset(X_test,  y_test),  shuffle=False, batch_size=batch_size)\n",
    "    loaders_dict = {'train': train_loader, 'test': test_loader, 'calib': calib_loader}\n",
    "\n",
    "    return X_train, y_train, X_cal, y_cal, X_test, y_test, loaders_dict, y_info\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5c54e29-5e02-4c3a-94be-1735ac7c5f32",
   "metadata": {
    "papermill": {
     "duration": 0.105329,
     "end_time": "2025-09-22T07:30:13.575519",
     "exception": false,
     "start_time": "2025-09-22T07:30:13.470190",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "year = 2012        \n",
    "seed = seed          \n",
    "shuffled = True   \n",
    "\n",
    "X_train, y_train, X_cal, y_cal, X_test, y_test, loaders_dict, y_info = get_data_and_loaders(batch_size, year, seed, shuffled)\n",
    "\n",
    "print(X_train.shape, y_train.shape)  \n",
    "print(X_cal.shape,   y_cal.shape)    \n",
    "print(X_test.shape,  y_test.shape)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28f4f3b8-2996-4058-80b6-da0727156fc8",
   "metadata": {
    "papermill": {
     "duration": 0.014019,
     "end_time": "2025-09-22T07:30:13.599936",
     "exception": false,
     "start_time": "2025-09-22T07:30:13.585917",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "from cvxpylayers.torch import CvxpyLayer\n",
    "import torch\n",
    "\n",
    "def create_box_layer_bounds(n: int,\n",
    "                            p: int = None,\n",
    "                            F_const: torch.Tensor | None = None,\n",
    "                            simplex: bool = True,\n",
    "                            ) -> CvxpyLayer:\n",
    "    if F_const is None:\n",
    "        p = n if p is None else p\n",
    "    else:\n",
    "        F_const = F_const.detach().cpu()\n",
    "        p = F_const.shape[1]\n",
    "\n",
    "    z = cp.Variable(p)                \n",
    "    v = cp.Variable(n)             \n",
    "\n",
    "    y_lo = cp.Parameter(n, name='y_lo')\n",
    "    y_hi = cp.Parameter(n, name='y_hi')\n",
    "\n",
    "    if F_const is None:\n",
    "        F = cp.Constant(np.eye(n, p))  \n",
    "    else:\n",
    "        F = cp.Constant(F_const.numpy())  \n",
    "\n",
    "    cons = [v >= 0]\n",
    "    cons += [v + F @ z >= 0]              \n",
    "\n",
    "    if simplex:\n",
    "        cons += [cp.sum(z) == 1, z >= 0]\n",
    "\n",
    "    c = y_hi - y_lo\n",
    "    obj = c @ v - y_lo @ (F @ z)\n",
    "\n",
    "    prob = cp.Problem(cp.Minimize(obj), cons)\n",
    "    assert prob.is_dpp('dcp')\n",
    "\n",
    "    return CvxpyLayer(prob, parameters=[y_lo, y_hi], variables=[z, v])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2405853-3b8c-42a2-a342-c70bd02d0663",
   "metadata": {
    "papermill": {
     "duration": 0.015286,
     "end_time": "2025-09-22T07:30:13.623770",
     "exception": false,
     "start_time": "2025-09-22T07:30:13.608484",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class BoxQuantileModel(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dims, n=15, dropout_rate=0.5, enable_implicit=True):\n",
    "        super().__init__()\n",
    "        self.n = n\n",
    "        self.enable_implicit = enable_implicit\n",
    "\n",
    "        if isinstance(hidden_dims, int):\n",
    "            hidden_dims = [hidden_dims]*3\n",
    "        layers, in_dim = [], input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers += [nn.Linear(in_dim, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(dropout_rate)]\n",
    "            in_dim = h\n",
    "        self.backbone   = nn.Sequential(*layers)\n",
    "        self.head_lo    = nn.Linear(in_dim, n)\n",
    "        self.head_delta = nn.Linear(in_dim, n)\n",
    "\n",
    "    def forward(self, x, box_layer=None, q_hat=None):\n",
    "        B = x.size(0)\n",
    "        h = self.backbone(x)\n",
    "\n",
    "        y_lo  = self.head_lo(h)                        \n",
    "        delta = F.softplus(self.head_delta(h), beta=50) \n",
    "        y_hi  = y_lo + delta\n",
    "\n",
    "        if q_hat is not None:\n",
    "            if not torch.is_tensor(q_hat):\n",
    "                q_hat = torch.tensor(q_hat, dtype=y_lo.dtype, device=x.device)\n",
    "            assert q_hat.dim() == 0, \"q Error\"\n",
    "            y_lo_adj = y_lo - q_hat\n",
    "            y_hi_adj = y_hi + q_hat\n",
    "        else:\n",
    "            y_lo_adj, y_hi_adj = y_lo, y_hi\n",
    "\n",
    "        y_hi_adj = torch.maximum(y_hi_adj, y_lo_adj + 1e-8)\n",
    "\n",
    "        if (not self.enable_implicit) or (box_layer is None):\n",
    "            return y_lo_adj, y_hi_adj, None\n",
    "\n",
    "        z_star, v_star = box_layer(y_lo_adj.cpu(), y_hi_adj.cpu())\n",
    "        return y_lo_adj, y_hi_adj, (z_star.to(x.device), v_star.to(x.device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aedf628-1d5b-46de-8583-fbab96c88e5c",
   "metadata": {
    "papermill": {
     "duration": 0.011939,
     "end_time": "2025-09-22T07:30:13.644152",
     "exception": false,
     "start_time": "2025-09-22T07:30:13.632213",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class PinballLoss(nn.Module):\n",
    "    def __init__(self, q: float):\n",
    "        super().__init__()\n",
    "        self.q = q\n",
    "    def forward(self, pred, target):\n",
    "        e = target - pred                \n",
    "        loss = torch.maximum(self.q*e, (self.q-1)*e)  \n",
    "        return loss.sum(dim=-1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22fc1cca-71e6-498d-8557-c0405c9ae106",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BoxQuantileModel(input_dim=21, hidden_dims=[20,20], n=15, dropout_rate=0.5, enable_implicit=False)\n",
    "device = torch.device(\"cpu\")\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "alpha  = 0.1\n",
    "\n",
    "q_lo   = alpha/2.0\n",
    "q_hi   = 1.0 - alpha/2.0\n",
    "loss_lo = PinballLoss(q_lo)\n",
    "loss_hi = PinballLoss(q_hi)\n",
    "\n",
    "for epoch in range(1, 300+1):\n",
    "    model.train()\n",
    "    total = 0.0\n",
    "    for x_cpu, y_cpu in loaders_dict['train']:\n",
    "        x = x_cpu.to(device, non_blocking=True)\n",
    "        y = y_cpu.to(device, non_blocking=True)\n",
    "\n",
    "        y_lo_pred, y_hi_pred, _ = model(x, box_layer=None, q_hat=None)  \n",
    "        y_hi_pred = torch.maximum(y_hi_pred, y_lo_pred + 1e-8)\n",
    "\n",
    "        loss = loss_lo(y_lo_pred, y) + loss_hi(y_hi_pred, y)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total += loss.item() * x.size(0)\n",
    "\n",
    "    avg = total / len(loaders_dict['train'].dataset)\n",
    "    print(f\"Epoch {epoch:>3d} | Quantile Loss: {avg:.6f}\")\n",
    "\n",
    "os.makedirs('outputs/model/baseline', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "torch.save(model.state_dict(), save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80e65366-7e58-42bf-b769-0ccd3c5258bf",
   "metadata": {
    "papermill": {
     "duration": 0.021044,
     "end_time": "2025-09-22T07:30:21.285622",
     "exception": false,
     "start_time": "2025-09-22T07:30:21.264578",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def update_theta(\n",
    "    model,\n",
    "    cvxpylayer,\n",
    "    opt_h,\n",
    "    x, y,\n",
    "    lambda_param,\n",
    "    alpha,\n",
    "    sigma,\n",
    "    q_hat=None,         \n",
    "    reg_eps: float = 1e-6\n",
    "):\n",
    "    \n",
    "    y_lo, y_hi, zv = model(x, box_layer=cvxpylayer, q_hat=None)\n",
    "    if not (isinstance(zv, (tuple, list)) and len(zv) == 2):\n",
    "        raise RuntimeError(\"box_layer need to return (z*, v*)\")\n",
    "    z_star, v_star = zv    \n",
    "    \n",
    "    z_eff = z_star\n",
    "    c = y_hi - y_lo                     \n",
    "\n",
    "    r = (c * v_star).sum(dim=1) - (y_lo * z_eff).sum(dim=1)         \n",
    "\n",
    "    s = (- y * z_star).sum(dim=1)         \n",
    "    approx_ind = 0.5 * (1.0 + torch.erf((r - s) / (sigma * math.sqrt(2.0))))\n",
    "\n",
    "    lam = lambda_param.detach()\n",
    "    g = lam * ((1.0 - alpha) - approx_ind.mean())\n",
    "\n",
    "    loss = r.mean() + g\n",
    "    opt_h.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_h.step()\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "def update_lambda(\n",
    "    model, cvxpylayer, opt_l,\n",
    "    x, y, lambda_param, alpha,\n",
    "    q_hat=None\n",
    "):\n",
    "   \n",
    "    device = x.device\n",
    "\n",
    "    with torch.no_grad():\n",
    "        y_lo_det, y_hi_det, _ = model(x, box_layer=None, q_hat=None)\n",
    "        y_hi_det = torch.maximum(y_hi_det, y_lo_det + 1e-8)\n",
    "\n",
    "    out = cvxpylayer(y_lo_det.cpu(), y_hi_det.cpu())\n",
    "    \n",
    "    if isinstance(out, (tuple, list)) and len(out) == 2:\n",
    "        z_star, v_star = out\n",
    "    else:\n",
    "        z_star, v_star = out, None\n",
    "    z_star = z_star.to(device)\n",
    "    v_star = v_star.to(device) \n",
    "\n",
    "    z_eff = z_star\n",
    "    c     = y_hi_det - y_lo_det\n",
    "    r  = (c * v_star).sum(dim=1) - (y_lo_det * z_eff).sum(dim=1) \n",
    "    s     = (- y * z_star).sum(dim=1)               \n",
    "    indicator = (s <= r).float()\n",
    "\n",
    "    obj = r.mean() + lambda_param * ((1.0 - alpha) - indicator.mean())\n",
    "    loss = -obj\n",
    "    opt_l.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_l.step()\n",
    "    with torch.no_grad():\n",
    "        lambda_param.clamp_(min=0.0)\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "def total_loss(\n",
    "    model, cvxpylayer,\n",
    "    x, y, lambda_param, alpha,\n",
    "    q_hat=None, use_hard_indicator=True,\n",
    "    sigma: float = 0.1, reg_eps: float = 1e-6\n",
    "):\n",
    "\n",
    "    with torch.no_grad():\n",
    "        y_lo, y_hi, zv = model(x, box_layer=cvxpylayer, q_hat=None)\n",
    "        if isinstance(zv, (tuple, list)) and len(zv) == 2:\n",
    "            z_star, v_star = zv\n",
    "        else:\n",
    "            z_star, v_star = zv, None\n",
    "        if v_star is None:\n",
    "            v_star = -z_star\n",
    "\n",
    "        z_eff = z_star\n",
    "        c     = y_hi - y_lo\n",
    "        r  = (c * v_star).sum(dim=1) - (y_lo * z_eff).sum(dim=1)\n",
    "        s     = (- y * z_star).sum(dim=1)\n",
    "        \n",
    "        if use_hard_indicator:\n",
    "            indicator = (s <= r).float()\n",
    "        else:\n",
    "            indicator = 0.5 * (1.0 + torch.erf((r - s) / (sigma * math.sqrt(2.0))))\n",
    "\n",
    "        f = r.mean()\n",
    "        g = lambda_param * ((1.0 - alpha) - indicator.mean())\n",
    "        return (f + g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61316057-e795-423d-80aa-77441fd09bc4",
   "metadata": {
    "papermill": {
     "duration": 0.020886,
     "end_time": "2025-09-22T07:30:21.316963",
     "exception": false,
     "start_time": "2025-09-22T07:30:21.296077",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def compute_q_hat_box_full(\n",
    "    model,                \n",
    "    X_cal: torch.Tensor,   \n",
    "    Y_cal: torch.Tensor,   \n",
    "    alpha: float,\n",
    "    device: torch.device\n",
    ") -> torch.Tensor:\n",
    "    \n",
    "    model.eval()\n",
    "\n",
    "    x = X_cal.to(device, non_blocking=True)\n",
    "    y = Y_cal.to(device, non_blocking=True)\n",
    "\n",
    "    y_lo, y_hi, _ = model(x, box_layer=None, q_hat=None) \n",
    "    y_hi = torch.maximum(y_hi, y_lo + 1e-8)\n",
    "\n",
    "    s_left  = (y_lo - y).amax(dim=1)   \n",
    "    s_right = (y - y_hi).amax(dim=1)   \n",
    "    scores  = torch.maximum(s_left, s_right)   \n",
    "\n",
    "    q_hat = torch.quantile(scores, 1.0 - alpha)  \n",
    "    return q_hat.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ad36ba9-cb1b-4fb4-8158-644c5faed229",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "n = 15\n",
    "\n",
    "cvxpylayer = create_box_layer_bounds(n)   \n",
    "model = BoxQuantileModel(\n",
    "    input_dim=21, hidden_dims=[20,20], n=n, dropout_rate=0.3,\n",
    "    enable_implicit=True\n",
    ").to(device)\n",
    "model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "\n",
    "opt_h = torch.optim.Adam(model.parameters(), lr=1e-4)  \n",
    "\n",
    "lambda_param = nn.Parameter(torch.tensor(1.0, device=device), requires_grad=True)\n",
    "opt_l = torch.optim.Adam([lambda_param], lr=1e-3)\n",
    "\n",
    "\n",
    "N = X_cal.shape[0]\n",
    "n = int(N/2)\n",
    "X_cal_1 = X_cal[:n];    Y_cal_1 = y_cal[:n]\n",
    "X_cal_2 = X_cal[n:];    Y_cal_2 = y_cal[n:]\n",
    "\n",
    "alpha  = 0.1\n",
    "sigma  = 0.1\n",
    "epochs = 100\n",
    "\n",
    "total_hist = []\n",
    "prev_loss = None\n",
    "consecutive_count = 0\n",
    "\n",
    "for epoch in range(1, epochs + 1):\n",
    "    model.train()\n",
    "    l_t = update_theta(               \n",
    "        model, cvxpylayer, opt_h,\n",
    "        X_cal_1.to(device), Y_cal_1.to(device),\n",
    "        lambda_param, alpha, sigma,\n",
    "        q_hat=None\n",
    "    )\n",
    "\n",
    "    l_l = update_lambda(\n",
    "        model, cvxpylayer, opt_l,\n",
    "        X_cal_2.to(device), Y_cal_2.to(device),\n",
    "        lambda_param, alpha,\n",
    "        q_hat=None\n",
    "    )\n",
    "\n",
    "    tot = total_loss(\n",
    "        model, cvxpylayer,\n",
    "        X_cal.to(device), y_cal.to(device),\n",
    "        lambda_param, alpha,\n",
    "        q_hat=None, use_hard_indicator=True, sigma=sigma\n",
    "    )\n",
    "    total_hist.append(tot.item())\n",
    "    print(f\"Epoch {epoch:02d} | loss_θ={l_t:.6f} | loss_λ={l_l:.6f} | total={tot:.6f} | λ={lambda_param.item():.4f}\")\n",
    "\n",
    "   \n",
    "    if prev_loss is not None and abs(prev_loss - tot.item()) < 5e-4:\n",
    "        consecutive_count += 1\n",
    "    else:\n",
    "        consecutive_count = 0\n",
    "    if consecutive_count >= 10:\n",
    "        print(f\"Early stopping at epoch {epoch} due to small loss change for 10 consecutive epochs.\")\n",
    "        break\n",
    "    prev_loss = tot.item()\n",
    "\n",
    "os.makedirs('outputs/model/final', exist_ok=True)\n",
    "torch.save(model.state_dict(),\n",
    "           os.path.join('outputs', 'model', 'final', 'final_model.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656c0046-9a75-46ad-bb36-473203897413",
   "metadata": {
    "papermill": {
     "duration": 0.557454,
     "end_time": "2025-09-22T07:35:27.070123",
     "exception": false,
     "start_time": "2025-09-22T07:35:26.512669",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.eval()\n",
    "model.enable_implicit = True\n",
    "\n",
    "X_test_cpu = X_test.to(device, non_blocking=True)\n",
    "Y_test_cpu = y_test.to(device, non_blocking=True)\n",
    "\n",
    "with torch.no_grad():\n",
    "    y_lo, y_hi, zv = model(X_test_cpu, box_layer=cvxpylayer, q_hat=None)\n",
    "\n",
    "y_hi = torch.maximum(y_hi, y_lo)\n",
    "\n",
    "if isinstance(zv, (tuple, list)) and len(zv) == 2:\n",
    "    z_test, _v_test = zv\n",
    "else:\n",
    "    z_test = zv\n",
    "\n",
    "inside = ((Y_test_cpu >= y_lo) & (Y_test_cpu <= y_hi)).all(dim=1)\n",
    "marginal_inside = ((Y_test_cpu >= y_lo) & (Y_test_cpu <= y_hi)).float().mean().item() * 100.0\n",
    "\n",
    "s_pos = F.relu(-z_test)      \n",
    "s_neg = -F.relu(z_test)      \n",
    "h = (_v_test * (y_hi - y_lo) - y_lo * z_test).sum(dim=1)  \n",
    "Average_Risk = h.mean().item()\n",
    "decision_loss = (- Y_test_cpu * z_test).sum(dim=1)     \n",
    "Average_Loss  = decision_loss.mean().item()\n",
    "robustness = (decision_loss <= h).float().mean().item() * 100.0\n",
    "\n",
    "print(f\"Marginal Coverage: {marginal_inside:.2f}%\")\n",
    "print(f\"Average_Risk: {Average_Risk:.4f}\")\n",
    "print(f\"Average_Loss: {Average_Loss:.4f}\")\n",
    "print(f\"Robustness: {robustness:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb8a91ac-516d-4e03-8645-391855da1fd7",
   "metadata": {
    "papermill": {
     "duration": 0.032232,
     "end_time": "2025-09-22T07:35:27.118637",
     "exception": false,
     "start_time": "2025-09-22T07:35:27.086405",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "data_dir = 'outputs/results/data_CRC'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "txt_path = os.path.join(data_dir, f'metrics_CRC.txt')\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.write(f\"Marginal Coverage: {marginal_inside:.2f}%\\n\")\n",
    "    f.write(f\"Average Risk     : {Average_Risk:.4f}\\n\")\n",
    "    f.write(f\"Average Loss     : {Average_Loss:.4f}\\n\")\n",
    "    f.write(f\"Robustness       : {robustness:.2f}%\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1dbc539-58c9-4d00-93ae-fd45da653c0e",
   "metadata": {
    "papermill": {
     "duration": 0.016728,
     "end_time": "2025-09-22T07:35:27.149946",
     "exception": false,
     "start_time": "2025-09-22T07:35:27.133218",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### CRO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbf9392-0f04-4e8b-8d9c-485ae46d2a9a",
   "metadata": {
    "papermill": {
     "duration": 0.044727,
     "end_time": "2025-09-22T07:35:27.208290",
     "exception": false,
     "start_time": "2025-09-22T07:35:27.163563",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")\n",
    "model = BoxQuantileModel(\n",
    "    input_dim=21, hidden_dims=[20, 20], n=n, dropout_rate=0.5,\n",
    "    enable_implicit=False\n",
    ").to(device)         \n",
    "model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "model.load_state_dict(torch.load(model_path, map_location='cpu'))\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21b72b22-0d43-4160-a425-df5da0ebc302",
   "metadata": {
    "papermill": {
     "duration": 0.020426,
     "end_time": "2025-09-22T07:35:27.255471",
     "exception": false,
     "start_time": "2025-09-22T07:35:27.235045",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "alpha = 0.1\n",
    "\n",
    "X_cal_cpu  = X_cal.cpu()\n",
    "Y_cal_cpu  = y_cal.cpu()\n",
    "X_test_cpu = X_test.cpu()\n",
    "Y_test_cpu = y_test.cpu()\n",
    "\n",
    "q_hat = compute_q_hat_box_full(model, X_cal_cpu, Y_cal_cpu, alpha, device)\n",
    "y_lo, y_hi, _ = model(X_test_cpu, q_hat = None)\n",
    "y_lo_test, y_hi_test, _ = model(X_test_cpu, q_hat = q_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9102613-eb5c-4362-a43d-fc9eab27932b",
   "metadata": {
    "papermill": {
     "duration": 0.054473,
     "end_time": "2025-09-22T07:35:27.328471",
     "exception": false,
     "start_time": "2025-09-22T07:35:27.273998",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "s_left  = (y_lo - Y_test_cpu).amax(dim=1)  \n",
    "s_right = (Y_test_cpu - y_hi).amax(dim=1)   \n",
    "scores  = torch.maximum(s_left, s_right)   \n",
    "count_1 = 0\n",
    "num = Y_test.shape[0]\n",
    "for i in range(num):\n",
    "    if scores[i] <= q_hat:\n",
    "        count_1+=1\n",
    "print('Marginal Covergae:', count_1/num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "942a4219-1a5d-4940-94be-f2c2faa8d940",
   "metadata": {},
   "outputs": [],
   "source": [
    "eta = 1e-2 \n",
    "tol = 1e-4\n",
    "n = 15\n",
    "\n",
    "z_value = []\n",
    "loss_value = []\n",
    "n_test = Y_test_cpu.shape[0]\n",
    "\n",
    "with torch.no_grad():\n",
    "    y_lo_np = y_lo_test.detach().cpu().numpy().astype(float) \n",
    "    y_hi_np = y_hi_test.detach().cpu().numpy().astype(float)  \n",
    "\n",
    "for i in range(n_test):\n",
    "    lb = y_lo_np[i].ravel()\n",
    "    ub = y_hi_np[i].ravel()\n",
    "\n",
    "    model = ro.Model()\n",
    "    y     = model.rvar(n)   \n",
    "    z_d   = model.dvar(n)   \n",
    "    \n",
    "    uset  = (lb <= y, y <= ub)\n",
    "    model.minmax(- y @ z_d, uset)     \n",
    "    model.st(z_d <= 1)\n",
    "    model.st(z_d >= 0)\n",
    "    model.st(z_d.sum() == 1)\n",
    "\n",
    "    model.solve(SOLVER,display=False)\n",
    "\n",
    "    z_opt = z_d.get()               \n",
    "    z_value.append(z_opt)\n",
    "\n",
    "    print(f\"Test_sample: {i}, Completed!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df9c98bd-1cfe-46de-91bd-fcbf671d6234",
   "metadata": {},
   "outputs": [],
   "source": [
    "risk_value = []\n",
    "for i in range(n_test):\n",
    "    lb = y_lo_np[i].ravel()\n",
    "    ub = y_hi_np[i].ravel()\n",
    "\n",
    "    model = ro.Model()\n",
    "    y     = model.dvar(n)      \n",
    "    model.max(- y @ z_value[i]) \n",
    "    model.st(lb <= y)\n",
    "    model.st(y <= ub)\n",
    "    model.solve(SOLVER,display=False)\n",
    "\n",
    "    risk = model.get()             \n",
    "    risk_value.append(risk)\n",
    "    print(f\"Test_sample: {i}, Completed!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2a4ae64-5aaa-4088-aa14-2d751298eb19",
   "metadata": {
    "papermill": {
     "duration": 0.08382,
     "end_time": "2025-09-22T07:36:22.283147",
     "exception": false,
     "start_time": "2025-09-22T07:36:22.199327",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Evaluate the two-step (CRC) decisions on the test set.\n",
    "#   decision loss : -y_true @ z          (realised loss of the chosen weights)\n",
    "#   risk          : worst-case loss over the conformal box (from `risk_value`)\n",
    "#   robustness    : share of samples whose realised loss does not exceed the risk\n",
    "if torch.is_tensor(Y_test_gpu):\n",
    "    Y_t = Y_test_gpu\n",
    "else:\n",
    "    Y_t = torch.as_tensor(\n",
    "        Y_test_gpu,\n",
    "        device='cuda' if torch.cuda.is_available() else 'cpu',\n",
    "        dtype=torch.float32\n",
    "    )\n",
    "\n",
    "N = Y_t.shape[0]\n",
    "num = N  # one sample count for every metric (was taken from y_test separately)\n",
    "assert len(z_value) >= N and len(risk_value) >= N, 'fewer decisions/risks than test samples'\n",
    "\n",
    "risk_loss_1    = np.asarray(risk_value[:N], dtype=float)\n",
    "Average_Risk_1 = risk_loss_1.mean()\n",
    "\n",
    "z_np = np.stack(z_value[:N], axis=0).astype(np.float32)\n",
    "z_t  = torch.as_tensor(z_np, device=Y_t.device, dtype=Y_t.dtype)\n",
    "\n",
    "decision_loss_1 = -torch.einsum('ij,ij->i', Y_t, z_t)\n",
    "Average_Loss_1  = decision_loss_1.mean().item()\n",
    "\n",
    "# vectorised comparison (the per-element loop forced one device sync per sample)\n",
    "risk_t   = torch.as_tensor(risk_loss_1, device=decision_loss_1.device, dtype=decision_loss_1.dtype)\n",
    "r_count1 = int((decision_loss_1 <= risk_t).sum().item())\n",
    "\n",
    "print(f\"Marginal_Coverage: {(count_1/N)*100:.2f}%\")\n",
    "print(f\"Average_Risk: {Average_Risk_1.item():.4f}\")\n",
    "print(f\"Average_Loss: {Average_Loss_1:.4f}\")\n",
    "print(f\"Robustness: {(r_count1/num)*100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0049060d-384d-4472-a637-c40eb70de705",
   "metadata": {
    "papermill": {
     "duration": 0.11094,
     "end_time": "2025-09-22T07:36:22.556621",
     "exception": false,
     "start_time": "2025-09-22T07:36:22.445681",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "data_dir = 'outputs/results/data_Two_step'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "txt_path = os.path.join(data_dir, f'metrics_Two_step_{run_id}.txt')\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.write(f\"Marginal Coverage: {(count_1/N)*100:.2f}%\\n\")\n",
    "    f.write(f\"Average Risk     : {Average_Risk_1:.4f}\\n\")\n",
    "    # NOTE(review): `VaR_1` is not defined by the CRC evaluation cell above; write it only when\n",
    "    # some earlier cell defined it, instead of failing with NameError on a fresh kernel.\n",
    "    if 'VaR_1' in globals():\n",
    "        f.write(f\"VaR              : {VaR_1:.4f}\\n\")\n",
    "    f.write(f\"Average Loss     : {Average_Loss_1:.4f}\\n\")\n",
    "    f.write(f\"Robustness       : {(r_count1/num)*100:.2f}%\\n\")\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "914a16d2-451b-4d85-b514-253a8337405e",
   "metadata": {
    "papermill": {
     "duration": 0.122877,
     "end_time": "2025-09-22T07:36:22.876145",
     "exception": false,
     "start_time": "2025-09-22T07:36:22.753268",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### End-to-end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f6877a-f96b-4b3c-b651-bfde3755dd9a",
   "metadata": {
    "papermill": {
     "duration": 0.076985,
     "end_time": "2025-09-22T07:36:23.207445",
     "exception": false,
     "start_time": "2025-09-22T07:36:23.130460",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class CustomDataset(Dataset):\n",
    "    \"\"\"Map-style dataset returning (X[idx], Y[idx]) pairs as float32 tensors.\"\"\"\n",
    "\n",
    "    def __init__(self, X, Y):\n",
    "        # as_tensor(...).detach().clone() accepts numpy arrays *and* tensors (torch.tensor on an\n",
    "        # existing tensor emits a UserWarning) while still giving the dataset its own copy.\n",
    "        self.X = torch.as_tensor(X, dtype=torch.float32).detach().clone()\n",
    "        self.Y = torch.as_tensor(Y, dtype=torch.float32).detach().clone()\n",
    "        if len(self.X) != len(self.Y):\n",
    "            raise ValueError(\n",
    "                f'X and Y must have the same number of samples, got {len(self.X)} and {len(self.Y)}'\n",
    "            )\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.Y[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "415765b2-3ad4-4f2f-87a6-bed381ddf9df",
   "metadata": {
    "papermill": {
     "duration": 0.05012,
     "end_time": "2025-09-22T07:36:23.359634",
     "exception": false,
     "start_time": "2025-09-22T07:36:23.309514",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "\n",
    "# One float32 dataset per split; only the training split is shuffled.\n",
    "train_dataset, cal_dataset, test_dataset = (\n",
    "    CustomDataset(X_train, y_train),\n",
    "    CustomDataset(X_cal, y_cal),\n",
    "    CustomDataset(X_test, y_test),\n",
    ")\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "cal_loader = DataLoader(cal_dataset, batch_size=batch_size, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "loaders = dict(train=train_loader, test=test_loader, calib=cal_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e95b21c-642c-4a76-8ef7-ef46f101efe7",
   "metadata": {
    "papermill": {
     "duration": 0.032691,
     "end_time": "2025-09-22T07:36:23.435654",
     "exception": false,
     "start_time": "2025-09-22T07:36:23.402963",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Target statistics handed to the decision problem. Computed on the training split only,\n",
    "# so no calibration/test labels leak into anything configured from them downstream.\n",
    "y_train_arr = np.asarray(y_train, dtype=float)\n",
    "y_mean = y_train_arr.mean(axis=0)\n",
    "y_std = y_train_arr.std(axis=0)\n",
    "# constant targets would give a zero scale; use 1.0 there (as sklearn's StandardScaler does)\n",
    "y_std = np.where(y_std > 0, y_std, 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17d18a3-b2dd-4653-8b02-f017508eb1a5",
   "metadata": {
    "papermill": {
     "duration": 0.070138,
     "end_time": "2025-09-22T07:36:23.559346",
     "exception": false,
     "start_time": "2025-09-22T07:36:23.489208",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from collections.abc import Sequence\n",
    "from typing import NamedTuple, Protocol\n",
    "\n",
    "import cvxpy as cp\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "class BaseProblemProtocol(Protocol):\n",
    "    \"\"\"Skeleton of a robust decision problem exposed as a differentiable CvxpyLayer.\n",
    "\n",
    "    Concrete classes populate the primal part (constraints, f_tilde, Fz, primal_vars) and the\n",
    "    dual of the inner maximisation over the uncertainty set (dual_constraints, dual_obj,\n",
    "    dual_vars, params), then call BaseProblemProtocol.__init__, which assembles the single\n",
    "    problem  min f_tilde + dual_obj  and checks that it is DPP.\n",
    "    \"\"\"\n",
    "\n",
    "    # primal part\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    # dual of the inner maximisation and its parameters\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    # assembled problem\n",
    "    vars: dict[str, cp.Variable]\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(self):\n",
    "        # primal variables come first, so the CvxpyLayer returns (primal..., dual...)\n",
    "        self.vars = self.primal_vars | self.dual_vars\n",
    "        self.constraints.extend(self.dual_constraints)\n",
    "        obj = self.f_tilde + self.dual_obj\n",
    "\n",
    "        prob = cp.Problem(cp.Minimize(obj), self.constraints)\n",
    "        assert prob.is_dpp()  # CvxpyLayer requires a DPP problem\n",
    "        self.prob = prob\n",
    "\n",
    "    def task_loss_np(self, y: np.ndarray, is_standardized: bool) -> float:\n",
    "        ...\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, y: Tensor, is_standardized: bool, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        ...\n",
    "\n",
    "    def get_cvxpylayer(self) -> CvxpyLayer:\n",
    "        \"\"\"Wrap self.prob as a CvxpyLayer: inputs are self.params, outputs self.vars (insertion order).\"\"\"\n",
    "        return CvxpyLayer(\n",
    "            self.prob,\n",
    "            parameters=list(self.params.values()),\n",
    "            variables=list(self.vars.values())\n",
    "        )\n",
    "\n",
    "\n",
    "class BoxProblemProtocol(BaseProblemProtocol, Protocol):\n",
    "    \"\"\"Problem whose uncertainty set is the box [pred_lo, pred_hi].\"\"\"\n",
    "\n",
    "    def solve(self, pred_lo: np.ndarray, pred_hi: np.ndarray) -> tuple[cp.Problem, np.ndarray]:\n",
    "        ...\n",
    "\n",
    "\n",
    "class BoxProblemV2:\n",
    "    \"\"\"Dual of the worst case over a box uncertainty set.\n",
    "\n",
    "    For fixed c = Fz,\n",
    "        max_{y_min <= y <= y_max} -y @ c  =  min_{nu >= 0, nu + c >= 0} (y_max - y_min) @ nu - y_min @ c,\n",
    "    which is what dual_obj / dual_constraints encode, with y_min / y_max as cvxpy Parameters.\n",
    "    \"\"\"\n",
    "\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(\n",
    "        self, y_dim: int, y_mean: np.ndarray, y_std: np.ndarray, Fz: cp.Expression\n",
    "    ):\n",
    "        # stored only; not used by the dual formulation in this cell\n",
    "        self.y_mean = y_mean\n",
    "        self.y_std = y_std\n",
    "\n",
    "        nu = cp.Variable(y_dim, nonneg=True)\n",
    "        self.dual_vars = {'nu': nu}\n",
    "\n",
    "        # box bounds: set directly in solve(), or fed as inputs by the CvxpyLayer\n",
    "        y_min = cp.Parameter(y_dim, name='y_min')\n",
    "        y_max = cp.Parameter(y_dim, name='y_max')\n",
    "        self.params = {\n",
    "            'y_min': y_min,\n",
    "            'y_max': y_max,\n",
    "        }\n",
    "\n",
    "        c = Fz\n",
    "\n",
    "        self.dual_obj = (\n",
    "            (y_max - y_min) @ nu\n",
    "            - y_min @ c\n",
    "        )\n",
    "\n",
    "        self.dual_constraints = [nu + c >= 0]\n",
    "\n",
    "    def solve(self, pred_lo: np.ndarray, pred_hi: np.ndarray) -> tuple[cp.Problem, np.ndarray]:\n",
    "        \"\"\"Solve for one box; returns (problem, optimal dual nu as a flat float array).\"\"\"\n",
    "        self.params['y_min'].value = pred_lo\n",
    "        self.params['y_max'].value = pred_hi\n",
    "        prob = self.prob\n",
    "        prob.solve()\n",
    "        if prob.status != 'optimal':\n",
    "            print('Problem status:', prob.status)\n",
    "        v_val = np.asarray(self.dual_vars['nu'].value, dtype=float).ravel()\n",
    "        return prob, v_val\n",
    "\n",
    "\n",
    "class PortfolioProblemBase:\n",
    "    \"\"\"Fully-invested, long-only portfolio: z >= 0, sum(z) == 1, with Fz = z and f_tilde = 0.\"\"\"\n",
    "\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    # to be implemented in subclass\n",
    "    prob: cp.Problem\n",
    "    y_mean: np.ndarray\n",
    "    y_std: np.ndarray\n",
    "\n",
    "    def __init__(self, N: int):\n",
    "        self.N = N\n",
    "\n",
    "        z = cp.Variable(N, name='z', nonneg=True)\n",
    "        self.primal_vars = {'z': z}\n",
    "\n",
    "        self.constraints = [\n",
    "            cp.sum(z) == 1,\n",
    "        ]\n",
    "\n",
    "        self.Fz = z\n",
    "        self.f_tilde = 0.\n",
    "\n",
    "    def task_loss_np(\n",
    "        self, v: np.ndarray, y_min: np.ndarray, y_max: np.ndarray, is_standardized: bool = False\n",
    "    ) -> float:\n",
    "        \"\"\"Robust task loss (y_max - y_min) @ v - y_min @ z at the last solved z.\n",
    "\n",
    "        is_standardized is accepted for compatibility with BaseProblemProtocol and ignored.\n",
    "        \"\"\"\n",
    "        assert self.prob.value is not None, 'Problem must be solved first'\n",
    "\n",
    "        z = self.primal_vars['z'].value\n",
    "        task_loss = (y_max - y_min) @ v - y_min @ z\n",
    "        return task_loss\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, v: torch.Tensor, y_min: torch.Tensor, y_max: torch.Tensor, solution: Sequence[Tensor],\n",
    "        is_standardized: bool = False,\n",
    "    ) -> Tensor:\n",
    "        \"\"\"Batched robust task loss: sum((y_max - y_min) * v, -1) - sum(y_min * z, -1), z = solution[0].\n",
    "\n",
    "        is_standardized is accepted (and ignored) so callers following the BaseProblemProtocol\n",
    "        signature can pass it by keyword without a TypeError.\n",
    "        \"\"\"\n",
    "        z = solution[0]\n",
    "        assert y_min.shape == z.shape\n",
    "        task_loss = ((y_max - y_min) * v).sum(dim=-1) - (y_min * z).sum(dim=-1)\n",
    "        return task_loss\n",
    "\n",
    "\n",
    "class PortfolioProblemBox(PortfolioProblemBase, BoxProblemV2, BoxProblemProtocol):\n",
    "    \"\"\"Portfolio problem with a box uncertainty set, ready for get_cvxpylayer().\"\"\"\n",
    "\n",
    "    def __init__(self, N: int, y_mean: np.ndarray, y_std: np.ndarray):\n",
    "        # order matters: primal part, then the box dual (needs self.Fz), then assembly\n",
    "        PortfolioProblemBase.__init__(self, N=N)\n",
    "        BoxProblemV2.__init__(self, y_dim=N, y_mean=y_mean, y_std=y_std, Fz=self.Fz)\n",
    "        BoxProblemProtocol.__init__(self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c932ba85-85b0-4d7c-92e9-e50d023aa287",
   "metadata": {
    "papermill": {
     "duration": 0.067219,
     "end_time": "2025-09-22T07:36:23.699069",
     "exception": false,
     "start_time": "2025-09-22T07:36:23.631850",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def multidim_quantile_score(pred: torch.Tensor | tuple, y: torch.Tensor, y_dim: int) -> torch.Tensor:\n",
    "    \"\"\"Box nonconformity score of y with respect to the predicted box [lo, hi].\n",
    "\n",
    "    pred is either a (lo, hi, ...) tuple/list or a tensor holding lo in its first y_dim\n",
    "    entries along the last axis and hi in the rest. The score is\n",
    "        max( max_j (lo_j - y_j), max_j (y_j - hi_j) )\n",
    "    over the last axis, so it is <= 0 exactly when y lies inside the box.\n",
    "    \"\"\"\n",
    "    if isinstance(pred, (tuple, list)):\n",
    "        lo, hi = pred[:2]\n",
    "    else:\n",
    "        lo, hi = pred[..., :y_dim], pred[..., y_dim:]\n",
    "\n",
    "    below = torch.amax(lo - y, dim=-1)   # largest shortfall under the lower bound\n",
    "    above = torch.amax(y - hi, dim=-1)   # largest excess over the upper bound\n",
    "    return torch.maximum(below, above)\n",
    "\n",
    "\n",
    "def calc_q(scores: Tensor, alpha: float) -> Tensor:\n",
    "    \"\"\"Split-conformal threshold: the ceil((n+1)(1-alpha))-th smallest score, or inf if that rank exceeds n.\"\"\"\n",
    "    n = len(scores)\n",
    "    rank = int(np.ceil((n + 1) * (1 - alpha)))\n",
    "    if rank > n:\n",
    "        return torch.tensor(torch.inf)\n",
    "    ordered, _ = torch.sort(scores)\n",
    "    return ordered[rank - 1]\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def conformal_q(model, loader, y_dim: int, alpha: float, device: str = \"cpu\"):\n",
    "    \"\"\"Calibrate the box half-width on `loader` via calc_q over multidim_quantile_score.\n",
    "\n",
    "    Side effects: puts `model` in eval mode and moves it to `device`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model.to(device)\n",
    "\n",
    "    batch_scores = [\n",
    "        multidim_quantile_score(\n",
    "            model(x.to(device, non_blocking=True)), y.to(device, non_blocking=True), y_dim\n",
    "        )\n",
    "        for x, y in loader\n",
    "    ]\n",
    "    return calc_q(torch.cat(batch_scores, dim=0), alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43924fa-4568-449d-b0ae-1210f3972f2c",
   "metadata": {
    "papermill": {
     "duration": 0.092395,
     "end_time": "2025-09-22T07:36:24.047214",
     "exception": false,
     "start_time": "2025-09-22T07:36:23.954819",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_e2e_epoch(\n",
    "    model: BoxQuantileModel,\n",
    "    prob: BoxProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    y_dim: int,\n",
    "    alpha: float,\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray],\n",
    "    quantile_loss_frac: float,\n",
    "    rng: np.random.Generator,\n",
    "    optimizer: torch.optim.Optimizer,\n",
    "    show_pbar: bool = False\n",
    ") -> tuple[float, float, float]:\n",
    "    \"\"\"Run one end-to-end training epoch.\n",
    "\n",
    "    Each batch is split at random into two halves: the first half calibrates the conformal\n",
    "    half-width q (calc_q over multidim_quantile_score), the second half is passed through the\n",
    "    differentiable box layer with that q and yields the robust task loss. The optimised loss is\n",
    "        quantile_loss_frac * pinball_loss + (1 - quantile_loss_frac) * task_loss.\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, avg_pinball_loss, avg_task_loss): the first two are normalised by\n",
    "        len(loader.dataset), the task loss by the number of samples that entered task halves.\n",
    "    \"\"\"\n",
    "    assert 0 < alpha < 1\n",
    "    model.train()\n",
    "\n",
    "    cvxpylayer = prob.get_cvxpylayer()\n",
    "    lo_quantile_loss_fn = PinballLoss(alpha/2)\n",
    "    hi_quantile_loss_fn = PinballLoss(1-alpha/2)\n",
    "\n",
    "    # Loop-invariant box layer, built once per epoch instead of once per batch. Inputs are\n",
    "    # moved to CPU for the CvxpyLayer and exponentiated first when y_info == 'log'.\n",
    "    if y_info == 'log':\n",
    "        def box_layer(lo, hi):\n",
    "            return cvxpylayer(torch.exp(lo).cpu(), torch.exp(hi).cpu())\n",
    "    else:\n",
    "        def box_layer(lo, hi):\n",
    "            return cvxpylayer(lo.cpu(), hi.cpu())\n",
    "\n",
    "    total_loss = total_pinball_loss = total_task_loss = 0.0\n",
    "    total_num_tasks = 0\n",
    "\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    for x, y in tqdm(loader) if show_pbar else loader:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        B = x.size(0)\n",
    "\n",
    "        y_lo, y_hi, _ = model(x, box_layer=None, q_hat=None)\n",
    "        loss = torch.zeros((), device=device)\n",
    "\n",
    "        if quantile_loss_frac > 0:\n",
    "            lo_q = lo_quantile_loss_fn(y_lo, y)\n",
    "            hi_q = hi_quantile_loss_fn(y_hi, y)\n",
    "            pinball_loss = (lo_q + hi_q).mean()\n",
    "            loss = loss + quantile_loss_frac * pinball_loss\n",
    "            total_pinball_loss += pinball_loss.item() * B\n",
    "        else:\n",
    "            pinball_loss = torch.tensor(0.0)\n",
    "\n",
    "        if quantile_loss_frac < 1:\n",
    "            # random split of the batch: calibration half -> q, task half -> decision loss\n",
    "            perm = rng.permutation(B)\n",
    "            cal_inds  = torch.as_tensor(perm[: B//2], device=device)\n",
    "            task_inds = torch.as_tensor(perm[ B//2:], device=device)\n",
    "\n",
    "            pred_full = torch.cat([y_lo, y_hi], dim=-1)\n",
    "            scores_cal = multidim_quantile_score(pred_full[cal_inds], y[cal_inds], y_dim)\n",
    "            q = calc_q(scores_cal, alpha)\n",
    "            q = torch.as_tensor(q, dtype=y_lo.dtype, device=device)\n",
    "            if not torch.isfinite(q):\n",
    "                # calibration half too small for the (1 - alpha) rank: the whole batch,\n",
    "                # including its pinball term, is skipped without an optimizer step\n",
    "                if show_pbar: tqdm.write('Batch is too small, skipping')\n",
    "                continue\n",
    "\n",
    "            y_lo_adj_task, y_hi_adj_task, zv = model(x[task_inds], box_layer=box_layer, q_hat=q)\n",
    "            assert zv is not None, \"enable_implicit=True need to return (z_star, v_star)\"\n",
    "            z_star, v_star = zv\n",
    "\n",
    "            if y_info == 'log':\n",
    "                y_min = torch.exp(y_lo_adj_task)\n",
    "                y_max = torch.exp(y_hi_adj_task)\n",
    "            else:\n",
    "                y_min, y_max = y_lo_adj_task, y_hi_adj_task\n",
    "\n",
    "            # only the arguments PortfolioProblemBase.task_loss_torch declares are passed\n",
    "            task_loss_vec = prob.task_loss_torch(\n",
    "                v=v_star, y_min=y_min, y_max=y_max, solution=(z_star,)\n",
    "            )\n",
    "            task_loss = task_loss_vec.mean()\n",
    "\n",
    "            loss = loss + (1 - quantile_loss_frac) * task_loss\n",
    "            total_task_loss += task_loss.item() * task_inds.numel()\n",
    "            total_num_tasks += task_inds.numel()\n",
    "\n",
    "            if show_pbar:\n",
    "                tqdm.write(f'pinball_loss: {pinball_loss.item():.4f}, task_loss: {task_loss.item():.4f}')\n",
    "\n",
    "        total_loss += loss.item() * B\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    N = len(loader.dataset)\n",
    "    avg_loss = total_loss / N\n",
    "    avg_pinball_loss = total_pinball_loss / N if quantile_loss_frac > 0 else 0.0\n",
    "    avg_task_loss = (total_task_loss / max(total_num_tasks, 1)) if quantile_loss_frac < 1 else 0.0\n",
    "    return avg_loss, avg_pinball_loss, avg_task_loss\n",
    "\n",
    "from typing import Any  # evaluated at definition time by the return annotation below\n",
    "\n",
    "\n",
    "def train_e2e(\n",
    "    y_dim: int,\n",
    "    alpha: float,\n",
    "    max_epochs: int,\n",
    "    lr: float,\n",
    "    l2reg: float,\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray],\n",
    "    prob: BoxProblemProtocol,\n",
    "    rng: np.random.Generator,\n",
    "    quantile_loss_frac: float | Sequence[float],\n",
    "    saved_model_path: str = '',\n",
    "    patience: int = 10,\n",
    "    input_dim: int = 21,\n",
    "    train_dl: torch.utils.data.DataLoader | None = None,\n",
    "    calib_dl: torch.utils.data.DataLoader | None = None,\n",
    ") -> tuple[BoxQuantileModel, dict[str, Any]]:\n",
    "    \"\"\"Train a BoxQuantileModel end to end, early-stopping on the calibration task loss.\n",
    "\n",
    "    Args:\n",
    "        y_dim: target / decision dimension (also used as the model's output size n).\n",
    "        alpha: miscoverage level of the conformal box.\n",
    "        max_epochs, lr, l2reg: epoch budget, Adam learning rate and weight decay.\n",
    "        y_info: 'log' if bounds must be exponentiated before reaching the solver.\n",
    "        prob: decision problem providing the CvxpyLayer and the task loss.\n",
    "        rng: generator for the per-batch calibration/task split.\n",
    "        quantile_loss_frac: pinball-loss weight, or a per-epoch schedule whose last value is\n",
    "            reused once the epoch index runs past its end.\n",
    "        saved_model_path: optional state_dict to warm-start from.\n",
    "        patience: stop after more than this many epochs without improvement.\n",
    "        input_dim: model input dimension (previously hard-coded; default unchanged).\n",
    "        train_dl, calib_dl: training / validation loaders; default to the notebook-level\n",
    "            train_loader / cal_loader.\n",
    "\n",
    "    Returns:\n",
    "        (model, result): the model restored to its best validation epoch (on CPU) and a dict\n",
    "        with per-epoch loss histories plus 'best_epoch' and 'val_task_loss'.\n",
    "    \"\"\"\n",
    "    train_dl = train_loader if train_dl is None else train_dl\n",
    "    calib_dl = cal_loader if calib_dl is None else calib_dl\n",
    "\n",
    "    model = BoxQuantileModel(input_dim=input_dim, hidden_dims=[20,20], n=y_dim,\n",
    "                             dropout_rate=0.3, enable_implicit=True).cpu()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2reg)\n",
    "    if saved_model_path:\n",
    "        model.load_state_dict(torch.load(saved_model_path, weights_only=True, map_location=\"cpu\"))\n",
    "\n",
    "    result: dict[str, Any] = {\n",
    "        'train_e2e_losses': [],\n",
    "        'train_pinball_losses': [],\n",
    "        'train_task_losses': [],\n",
    "        'val_task_losses': [],\n",
    "        'best_epoch': 0,\n",
    "        'val_task_loss': np.inf,\n",
    "    }\n",
    "    steps_since_decrease = 0\n",
    "    best_state: bytes | None = None  # serialized state_dict of the best epoch so far\n",
    "\n",
    "    pbar = tqdm(range(max_epochs))\n",
    "    for epoch in pbar:\n",
    "        weight = (quantile_loss_frac[min(epoch, len(quantile_loss_frac)-1)]\n",
    "                  if isinstance(quantile_loss_frac, Sequence) else quantile_loss_frac)\n",
    "\n",
    "        train_e2e_loss, train_pinball_loss, train_task_loss = train_e2e_epoch(\n",
    "            model=model, prob=prob, loader=train_dl, y_dim=y_dim, alpha=alpha,\n",
    "            y_info=y_info, quantile_loss_frac=weight, rng=rng, optimizer=optimizer, show_pbar=False\n",
    "        )\n",
    "        result['train_e2e_losses'].append(train_e2e_loss)\n",
    "        result['train_pinball_losses'].append(train_pinball_loss)\n",
    "        result['train_task_losses'].append(train_task_loss)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            model.eval()\n",
    "            # plain float: optimize() builds its own tensor from q\n",
    "            q = float(conformal_q(model, calib_dl, y_dim=y_dim, alpha=alpha))\n",
    "            val_task_loss = optimize(model, prob=prob, loader=calib_dl,\n",
    "                                     y_dim=y_dim, q=q, y_info=y_info)\n",
    "            result['val_task_losses'].append(val_task_loss)\n",
    "\n",
    "        msg = (f'Epoch {epoch}, train_task_loss {train_task_loss:.3f}, '\n",
    "               f'val_task_loss {val_task_loss:.3f}, q {q:.3f}')\n",
    "        pbar.set_description(msg)\n",
    "\n",
    "        steps_since_decrease += 1\n",
    "        if val_task_loss < result['val_task_loss']:\n",
    "            result['best_epoch'] = epoch\n",
    "            result['val_task_loss'] = val_task_loss\n",
    "            steps_since_decrease = 0\n",
    "            # fresh buffer per checkpoint, so no stale bytes survive from an earlier save\n",
    "            buffer = io.BytesIO()\n",
    "            torch.save(model.state_dict(), buffer)\n",
    "            best_state = buffer.getvalue()\n",
    "\n",
    "        if steps_since_decrease > patience:\n",
    "            break\n",
    "\n",
    "    if best_state is None:\n",
    "        # no epoch ever beat inf (e.g. max_epochs == 0 or NaN losses): keep the current weights\n",
    "        print('train_e2e: no improving epoch recorded; returning the final model state')\n",
    "    else:\n",
    "        model.load_state_dict(torch.load(io.BytesIO(best_state), weights_only=True, map_location=\"cpu\"))\n",
    "    return model.cpu(), result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6b2779-1707-450b-915c-02d854be30bf",
   "metadata": {
    "papermill": {
     "duration": 0.077565,
     "end_time": "2025-09-22T07:36:24.251695",
     "exception": false,
     "start_time": "2025-09-22T07:36:24.174130",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def optimize(\n",
    "    model: BoxQuantileModel,\n",
    "    prob: BoxProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    y_dim: int,\n",
    "    q: float,\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray],\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    ") -> float:\n",
    "    \"\"\"Mean robust task loss of the model's conformalised boxes over `loader`.\n",
    "\n",
    "    Each batch goes through the model with the CvxpyLayer as box_layer and half-width q;\n",
    "    the per-sample losses from prob.task_loss_torch are averaged over the whole loader.\n",
    "    Side effects: switches `model` to eval mode and moves it to `device`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `loader` yields no batches.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model = model.to(device)\n",
    "\n",
    "    cvxpylayer = prob.get_cvxpylayer()\n",
    "\n",
    "    def box_layer(lo: torch.Tensor, hi: torch.Tensor):\n",
    "        # inputs are moved to CPU for the CvxpyLayer; exponentiated first in log mode\n",
    "        if y_info == 'log':\n",
    "            return cvxpylayer(torch.exp(lo).cpu(), torch.exp(hi).cpu())  # (z, v)\n",
    "        else:\n",
    "            return cvxpylayer(lo.cpu(), hi.cpu())\n",
    "\n",
    "    # as_tensor accepts both a Python float and the 0-dim tensor returned by conformal_q\n",
    "    # (torch.tensor(<tensor>) emits a copy-construct UserWarning)\n",
    "    q_hat = torch.as_tensor(q, device=device)\n",
    "\n",
    "    losses = []\n",
    "    iter_loader = tqdm(loader, desc=\"optimize\") if show_pbar else loader\n",
    "    for x_batch, _ in iter_loader:\n",
    "        x_batch = x_batch.to(device, non_blocking=True)\n",
    "\n",
    "        y_lo_adj, y_hi_adj, zv = model(x_batch, box_layer=box_layer, q_hat=q_hat)\n",
    "        assert zv is not None, \"enable_implicit=True need to return (z*, v*)\"\n",
    "        z_star, v_star = zv  # [B,N]\n",
    "\n",
    "        if y_info == 'log':\n",
    "            y_min = torch.exp(y_lo_adj)\n",
    "            y_max = torch.exp(y_hi_adj)\n",
    "        else:\n",
    "            y_min, y_max = y_lo_adj, y_hi_adj\n",
    "\n",
    "        # only the arguments PortfolioProblemBase.task_loss_torch declares are passed\n",
    "        task_loss_vec = prob.task_loss_torch(\n",
    "            v=v_star, y_min=y_min, y_max=y_max, solution=(z_star,)\n",
    "        )\n",
    "        losses.append(task_loss_vec)\n",
    "\n",
    "    if not losses:\n",
    "        raise ValueError('optimize: loader yielded no batches')\n",
    "    return torch.cat(losses, dim=0).mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d57ea1-e1ce-439a-a0ad-6fe321d49b9d",
   "metadata": {
    "papermill": {
     "duration": 577.317714,
     "end_time": "2025-09-22T07:46:01.785479",
     "exception": false,
     "start_time": "2025-09-22T07:36:24.467765",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os, io, copy, math, itertools, torch, numpy as np\n",
    "\n",
    "# ---- grid-search configuration ----\n",
    "alpha = 0.1\n",
    "max_epochs = 100\n",
    "nll_loss_frac = 0.0      # passed as quantile_loss_frac: pinball-loss weight (0.0 = task loss only)\n",
    "lr_list    = [1e-3]\n",
    "l2_list    = [1e-2, 1e-3, 1e-4]\n",
    "E2E_Y_DIM = 15\n",
    "E2E_INPUT_DIM = 21\n",
    "E2E_HIDDEN_DIMS = [20, 20]\n",
    "E2E_DROPOUT = 0.3        # same rate train_e2e builds its model with (was 0.5 here)\n",
    "GRID_SEED = 42\n",
    "\n",
    "best = {\n",
    "    \"val_task_loss\": float(\"inf\"),\n",
    "    \"lr\": None,\n",
    "    \"l2reg\": None,\n",
    "    \"epoch\": None,\n",
    "    \"state_dict\": None,\n",
    "}\n",
    "all_results = []\n",
    "\n",
    "for lr in lr_list:\n",
    "    for l2 in l2_list:\n",
    "\n",
    "        prob = PortfolioProblemBox(N=E2E_Y_DIM, y_mean=y_mean, y_std=y_std)\n",
    "\n",
    "        model_path = os.path.join('outputs', 'model', 'baseline', f'baseline_model_{run_id}.pth')\n",
    "        saved_model_path = model_path if os.path.exists(model_path) else ''\n",
    "\n",
    "        # identical seeds for every configuration (batch-split rng, shuffling, dropout, init)\n",
    "        rng = np.random.default_rng(GRID_SEED)\n",
    "        torch.manual_seed(GRID_SEED)\n",
    "\n",
    "        model_i, res = train_e2e(\n",
    "            y_dim=E2E_Y_DIM,\n",
    "            alpha=alpha,\n",
    "            max_epochs=max_epochs,\n",
    "            lr=lr,\n",
    "            l2reg=l2,\n",
    "            y_info=y_info,\n",
    "            prob=prob,\n",
    "            rng=rng,\n",
    "            quantile_loss_frac=nll_loss_frac,\n",
    "            saved_model_path=saved_model_path\n",
    "        )\n",
    "\n",
    "        record = {\n",
    "            \"lr\": lr,\n",
    "            \"l2reg\": l2,\n",
    "            \"best_epoch\": res[\"best_epoch\"],\n",
    "            \"val_task_loss\": res[\"val_task_loss\"],\n",
    "        }\n",
    "        all_results.append(record)\n",
    "\n",
    "        if res[\"val_task_loss\"] < best[\"val_task_loss\"]:\n",
    "            best.update(record)\n",
    "            best[\"state_dict\"] = copy.deepcopy(model_i.state_dict())\n",
    "\n",
    "print(\"\\n=== Grid Search Summary ===\")\n",
    "all_results = sorted(all_results, key=lambda r: r[\"val_task_loss\"])\n",
    "for r in all_results:\n",
    "    print(f\"lr={r['lr']:g}, l2={r['l2reg']:g}, \"\n",
    "          f\"best_epoch={r['best_epoch']}, val_task_loss={r['val_task_loss']:.6f}\")\n",
    "\n",
    "if best[\"state_dict\"] is None:\n",
    "    raise RuntimeError(\"Grid search produced no configuration with a finite validation task loss\")\n",
    "\n",
    "print(\"\\n>>> Best config:\")\n",
    "print(f\"lr={best['lr']:g}, l2={best['l2reg']:g}, \"\n",
    "      f\"best_epoch={best['best_epoch']}, val_task_loss={best['val_task_loss']:.6f}\")\n",
    "\n",
    "best_model = BoxQuantileModel(input_dim=E2E_INPUT_DIM, hidden_dims=E2E_HIDDEN_DIMS, n=E2E_Y_DIM,\n",
    "                              dropout_rate=E2E_DROPOUT, enable_implicit=True).cpu()\n",
    "best_model.load_state_dict(best[\"state_dict\"])\n",
    "\n",
    "os.makedirs('outputs/model/E2E', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'E2E', f'E2E_model.pth')\n",
    "torch.save(best_model.state_dict(), save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1cc17dc-9129-408d-b422-f74fa4c53d4c",
   "metadata": {
    "papermill": {
     "duration": 0.14156,
     "end_time": "2025-09-22T07:46:01.951945",
     "exception": false,
     "start_time": "2025-09-22T07:46:01.810385",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from typing import Tuple, List\n",
    "\n",
    "@torch.no_grad()\n",
    "def optimize_test_box(\n",
    "    model: BoxQuantileModel,\n",
    "    prob: BoxProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    y_dim: int,\n",
    "    q: float,\n",
    "    y_info: str | tuple[np.ndarray, np.ndarray],\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    ") -> Tuple[float, float, float, float, List[np.ndarray]]:\n",
    "    \"\"\"Evaluate conformal-box decisions on a test loader.\n",
    "\n",
    "    Each predicted box [y_lo - q, y_hi + q] (exponentiated when y_info == 'log') is passed to\n",
    "    the CvxpyLayer to obtain the decision z* and dual v*.\n",
    "\n",
    "    Returns:\n",
    "        avg_task:     mean robust task loss from prob.task_loss_torch\n",
    "        avg_decision: mean realised loss -y_true @ z*, with y_true in the solver's space\n",
    "        coverage:     % of samples whose multidim_quantile_score is <= q\n",
    "        robustness:   % of samples whose realised loss does not exceed their task loss\n",
    "        z_list:       per-sample decisions as flat numpy arrays\n",
    "    Raises:\n",
    "        ValueError: if `loader` yields no samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model = model.to(device)\n",
    "    cvxpylayer = prob.get_cvxpylayer()\n",
    "\n",
    "    task_losses, decision_losses, robust_flags = [], [], []\n",
    "    z_list: List[np.ndarray] = []\n",
    "    n_covered = 0\n",
    "    n_total = 0\n",
    "\n",
    "    if show_pbar:\n",
    "        from tqdm import tqdm as _tqdm\n",
    "        it = _tqdm(loader, desc=\"optimize_test_box\")\n",
    "    else:\n",
    "        it = loader\n",
    "\n",
    "    for x, y in it:\n",
    "        x = x.to(device, non_blocking=True)\n",
    "        y = y.to(device, non_blocking=True)\n",
    "\n",
    "        out = model(x)\n",
    "        if isinstance(out, (tuple, list)):\n",
    "            y_lo, y_hi = out[:2]        # [B,N]\n",
    "        else:\n",
    "            y_lo, y_hi = out[..., :y_dim], out[..., y_dim:]\n",
    "\n",
    "        q_t = torch.as_tensor(q, dtype=y_lo.dtype, device=y_lo.device)\n",
    "        y_min = y_lo - q_t\n",
    "        y_max = y_hi + q_t\n",
    "\n",
    "        if y_info == 'log':\n",
    "            # the solver sees exp(bounds); y is compared directly with y_lo/y_hi (as the pinball\n",
    "            # loss in training does), so the realised loss must use exp(y) to be comparable\n",
    "            y_min_solver, y_max_solver = torch.exp(y_min), torch.exp(y_max)\n",
    "            y_for_decision = torch.exp(y)\n",
    "        else:\n",
    "            y_min_solver, y_max_solver = y_min, y_max\n",
    "            y_for_decision = y\n",
    "\n",
    "        z_star, v_star = cvxpylayer(y_min_solver.cpu(), y_max_solver.cpu())\n",
    "        if z_star.dim() == 3 and z_star.size(-1) == 1: z_star = z_star.squeeze(-1)\n",
    "        if v_star.dim() == 3 and v_star.size(-1) == 1: v_star = v_star.squeeze(-1)\n",
    "        z_star = z_star.to(y_min_solver.device, dtype=y_min_solver.dtype)\n",
    "        v_star = v_star.to(y_min_solver.device, dtype=y_min_solver.dtype)\n",
    "\n",
    "        # only the arguments PortfolioProblemBase.task_loss_torch declares are passed\n",
    "        task_loss_vec = prob.task_loss_torch(\n",
    "            v=v_star, y_min=y_min_solver, y_max=y_max_solver, solution=(z_star,)\n",
    "        )\n",
    "        task_losses.append(task_loss_vec.cpu())\n",
    "\n",
    "        decision_vec = (-y_for_decision * z_star).sum(dim=-1)\n",
    "        decision_losses.append(decision_vec.cpu())\n",
    "\n",
    "        # coverage uses the same nonconformity score as calibration (conformal_q)\n",
    "        scores = multidim_quantile_score((y_lo, y_hi), y, y_dim)\n",
    "        n_covered += (scores <= q).sum().item()\n",
    "        n_total += y.size(0)\n",
    "\n",
    "        robust = (decision_vec <= task_loss_vec).float()  # [B]\n",
    "        robust_flags.append(robust.cpu())\n",
    "\n",
    "        z_list.extend(z.cpu().numpy().ravel() for z in z_star)\n",
    "\n",
    "    if n_total == 0:\n",
    "        raise ValueError('optimize_test_box: loader yielded no samples')\n",
    "\n",
    "    avg_task = torch.cat(task_losses).mean().item()\n",
    "    avg_decision = torch.cat(decision_losses).mean().item()\n",
    "    coverage = n_covered / n_total * 100.0\n",
    "    robustness = torch.cat(robust_flags).mean().item() * 100.0\n",
    "\n",
    "    return avg_task, avg_decision, coverage, robustness, z_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac47837-086d-49f7-b370-cc4fb586c70b",
   "metadata": {
    "papermill": {
     "duration": 0.588232,
     "end_time": "2025-09-22T07:46:02.566435",
     "exception": false,
     "start_time": "2025-09-22T07:46:01.978203",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Evaluate the saved end-to-end model on the test set.\n",
    "model_path = os.path.join('outputs', 'model', 'E2E', f'E2E_model.pth')\n",
    "\n",
    "model = BoxQuantileModel(input_dim=21, hidden_dims=[20, 20], n=15, enable_implicit=False).cpu()\n",
    "model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    q = conformal_q(model, loaders['calib'], y_dim=15, alpha=alpha, device='cpu').item()\n",
    "\n",
    "# `prob` is the PortfolioProblemBox left by the grid-search cell. The E2E decisions are stored\n",
    "# as `z_value_e2e` so the two-step decisions in `z_value` are not overwritten.\n",
    "avg_risk, avg_loss, coverage, robustness_1, z_value_e2e = optimize_test_box(\n",
    "    model=model,\n",
    "    prob=prob,\n",
    "    loader=loaders['test'],\n",
    "    y_dim=15,\n",
    "    q=q,\n",
    "    y_info=y_info,\n",
    "    device='cpu',\n",
    "    show_pbar=False\n",
    ")\n",
    "\n",
    "print(f\"Coverage: {coverage:.2f}%\")\n",
    "print(f\"Avg risk: {avg_risk:.4f}\")\n",
    "print(f\"Avg decision: {avg_loss:.4f}\")\n",
    "print(f\"Robustness: {robustness_1:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dae6393-43b6-4df2-81fc-f185f1b65ae4",
   "metadata": {
    "papermill": {
     "duration": 0.028747,
     "end_time": "2025-09-22T07:46:02.620652",
     "exception": false,
     "start_time": "2025-09-22T07:46:02.591905",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the headline test metrics as a plain-text report (overwritten on each run).\n",
    "data_dir = 'outputs/results/data_E2E'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "txt_path = os.path.join(data_dir, 'metrics_E2E.txt')\n",
    "\n",
    "with open(txt_path, 'w') as fh:\n",
    "    fh.writelines([\n",
    "        f\"Marginal Coverage: {coverage:.2f}%\\n\",\n",
    "        f\"Average Risk     : {avg_risk:.4f}\\n\",\n",
    "        f\"Average Loss     : {avg_loss:.4f}\\n\",\n",
    "        f\"Robustness       : {robustness_1:.2f}%\\n\",\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6452e5-6bb3-44f5-9a0e-a1b23ff8a1b6",
   "metadata": {
    "papermill": {
     "duration": 0.022777,
     "end_time": "2025-09-22T07:46:03.229988",
     "exception": false,
     "start_time": "2025-09-22T07:46:03.207211",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 954.563047,
   "end_time": "2025-09-22T07:46:04.572971",
   "environment_variables": {},
   "exception": null,
   "input_path": "CRC_stock_box.ipynb",
   "output_path": "outputs/file/CRC_stock_box_0.ipynb",
   "parameters": {
    "run_id": 0
   },
   "start_time": "2025-09-22T07:30:10.009924",
   "version": "2.6.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "11884bc8004448a48d184b8b5e3bc8a7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "15ed5a68bb56416db27a5ccde7c3e80d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_9d34d50806ce4238925823a545bd1c14",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_aa5d02745bea4655b7c96388441c91cd",
       "tabbable": null,
       "tooltip": null,
       "value": 100
      }
     },
     "1dec5b4011ca470ab9b41a3c2ee23f00": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "27e9ed28863f46dab32a964308c94c72": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "35dbf8631341472d8ad5688dbc10f6dd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_f1d938bd80cd4a169c8bbc581814068b",
       "placeholder": "​",
       "style": "IPY_MODEL_11884bc8004448a48d184b8b5e3bc8a7",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 99, train_task_loss 2.061, val_task_loss 2.349, q 3.066: 100%"
      }
     },
     "39cce1a8d4084e6bb5ed422276b1b621": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "39ed563ef3464cc18a76aff4e0745515": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_57d0a1b242fa4ee6a49143f824f88eec",
       "placeholder": "​",
       "style": "IPY_MODEL_7f2ea9fd507a437b870fffdc82249e32",
       "tabbable": null,
       "tooltip": null,
       "value": " 100/100 [03:08&lt;00:00,  1.84s/it]"
      }
     },
     "482a416a0bc34982b50365cb896eb718": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_65deec09646849248b2ddb38a312e76b",
       "placeholder": "​",
       "style": "IPY_MODEL_a34a3686dd9e4f21bad5948ff63c2117",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 99, train_task_loss 1.794, val_task_loss 2.257, q 3.173: 100%"
      }
     },
     "4fb30c3d2bea4b11a9a0670aebe6082d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_1dec5b4011ca470ab9b41a3c2ee23f00",
       "placeholder": "​",
       "style": "IPY_MODEL_87c97ce2e02141cc93f00244a2814239",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 99, train_task_loss 1.982, val_task_loss 2.199, q 3.129: 100%"
      }
     },
     "569eb8c4e8124170a0fb065dfbc4430d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_90c5131f55d14f35aaf2866b42e698b3",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_9c7e40bcb78c443ebd2a467be281f19e",
       "tabbable": null,
       "tooltip": null,
       "value": 100
      }
     },
     "574e07c7991b4eee994e5584ce148345": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_d901899b08df49f693bfd850359e201e",
       "placeholder": "​",
       "style": "IPY_MODEL_39cce1a8d4084e6bb5ed422276b1b621",
       "tabbable": null,
       "tooltip": null,
       "value": " 100/100 [03:12&lt;00:00,  1.68s/it]"
      }
     },
     "57d0a1b242fa4ee6a49143f824f88eec": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "65deec09646849248b2ddb38a312e76b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "70169d8d16c04fd58bf9e3bc5d7cdca4": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_d090e0a7ebb74642aa9b15572fb58bc5",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_f03305a6cf0a4313a915a535ee83bfc9",
       "tabbable": null,
       "tooltip": null,
       "value": 100
      }
     },
     "72b79678ac174640bc0bdc7f55f1f99c": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "7f2ea9fd507a437b870fffdc82249e32": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "8591acacd11f41b492f738c53cbb955e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_4fb30c3d2bea4b11a9a0670aebe6082d",
        "IPY_MODEL_569eb8c4e8124170a0fb065dfbc4430d",
        "IPY_MODEL_f8921d5e602148a28cdb942cc2d04090"
       ],
       "layout": "IPY_MODEL_72b79678ac174640bc0bdc7f55f1f99c",
       "tabbable": null,
       "tooltip": null
      }
     },
     "87c97ce2e02141cc93f00244a2814239": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "90c5131f55d14f35aaf2866b42e698b3": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "9c7e40bcb78c443ebd2a467be281f19e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "9d34d50806ce4238925823a545bd1c14": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "a34a3686dd9e4f21bad5948ff63c2117": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "aa5d02745bea4655b7c96388441c91cd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "af744cb55cba46838a0f34ca182844f3": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "b0ac5044cc4a48399f5c966d64623649": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_35dbf8631341472d8ad5688dbc10f6dd",
        "IPY_MODEL_70169d8d16c04fd58bf9e3bc5d7cdca4",
        "IPY_MODEL_574e07c7991b4eee994e5584ce148345"
       ],
       "layout": "IPY_MODEL_d9f23b0be471432ab18b616f1bb13fd9",
       "tabbable": null,
       "tooltip": null
      }
     },
     "d090e0a7ebb74642aa9b15572fb58bc5": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "d901899b08df49f693bfd850359e201e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "d9f23b0be471432ab18b616f1bb13fd9": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "f03305a6cf0a4313a915a535ee83bfc9": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "f1d938bd80cd4a169c8bbc581814068b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "f34efffb72164c049a3f5f39c65bcbe6": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_482a416a0bc34982b50365cb896eb718",
        "IPY_MODEL_15ed5a68bb56416db27a5ccde7c3e80d",
        "IPY_MODEL_39ed563ef3464cc18a76aff4e0745515"
       ],
       "layout": "IPY_MODEL_f7385b25d56b46c698c78c252427cf35",
       "tabbable": null,
       "tooltip": null
      }
     },
     "f7385b25d56b46c698c78c252427cf35": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "f8921d5e602148a28cdb942cc2d04090": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_af744cb55cba46838a0f34ca182844f3",
       "placeholder": "​",
       "style": "IPY_MODEL_27e9ed28863f46dab32a964308c94c72",
       "tabbable": null,
       "tooltip": null,
       "value": " 100/100 [03:16&lt;00:00,  1.82s/it]"
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
