{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4bcfc9fd-c1e1-4d83-9fd2-0949976723b7",
   "metadata": {
    "papermill": {
     "duration": 0.00372,
     "end_time": "2025-08-20T04:32:23.940684",
     "exception": false,
     "start_time": "2025-08-20T04:32:23.936964",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Conformal Robustness Control"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1c560fb-9697-4c12-9186-9da41b02b5c4",
   "metadata": {},
   "source": [
    "### CRC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "027da74f-f82a-4655-bfbf-94fcbaa6fb97",
   "metadata": {
    "papermill": {
     "duration": 2.240649,
     "end_time": "2025-08-20T04:32:26.184976",
     "exception": false,
     "start_time": "2025-08-20T04:32:23.944327",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import os\n",
    "import io\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import pytz\n",
    "import math\n",
    "import cvxpy as cp\n",
    "import rsome as rso\n",
    "\n",
    "from rsome import ro\n",
    "from rsome import msk_solver as SOLVER\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.distributions import MultivariateNormal\n",
    "from cvxpylayers.torch import CvxpyLayer\n",
    "from typing import Protocol\n",
    "from typing import NamedTuple\n",
    "from torch import distributions as tdist, nn, Tensor\n",
    "from collections.abc import Mapping\n",
    "from torch import Tensor\n",
    "from collections.abc import Sequence\n",
    "from torch.utils.data import TensorDataset\n",
    "from collections.abc import Mapping, Sequence\n",
    "from pandas.tseries.holiday import USFederalHolidayCalendar\n",
    "from datetime import datetime, timedelta\n",
    "from tqdm.auto import tqdm\n",
    "from typing import Any"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f8ca6a-17af-49fc-b3ea-7d20e26605dc",
   "metadata": {
    "papermill": {
     "duration": 0.012479,
     "end_time": "2025-08-20T04:32:26.228912",
     "exception": false,
     "start_time": "2025-08-20T04:32:26.216433",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Presumably works around Intel OpenMP's duplicate-runtime abort (\"OMP: Error #15\")\n",
    "# when torch and another OpenMP-linked library each load their own copy.\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "\n",
    "# Use the first CUDA GPU when one is visible, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(\"Device:\", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bdede33-96e2-4338-afd3-354fa3a0cacc",
   "metadata": {
    "papermill": {
     "duration": 0.009271,
     "end_time": "2025-08-20T04:32:26.242136",
     "exception": false,
     "start_time": "2025-08-20T04:32:26.232865",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class SyntheticRegressionDataset(Dataset):\n",
    "    \"\"\"Synthetic two-target regression data.\n",
    "\n",
    "    Features X ~ N(mu, cov_scale * I_2) and noise e ~ N(0, noise_std^2 I_2), with\n",
    "        Y1 = 5*X1 + 2*X2^2 - e1,    Y2 = 3*X1^2 + X2 - e2.\n",
    "    Samples come from NumPy's global RNG (features first, then noise), so the data are\n",
    "    reproducible under np.random.seed. Items are (features, labels) pairs of float32\n",
    "    tensors of shape (2,).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_samples, mu=None, cov_scale=2.25, noise_std=1.0):\n",
    "        mean = np.array([1.0, 1.0]) if mu is None else mu\n",
    "        X = np.random.multivariate_normal(mean, cov_scale * np.eye(2), size=n_samples)\n",
    "        noise = np.random.normal(0, noise_std, size=(n_samples, 2))\n",
    "\n",
    "        x1, x2 = X[:, 0], X[:, 1]\n",
    "        targets = np.column_stack((\n",
    "            5 * x1 + 2 * x2**2 - noise[:, 0],\n",
    "            3 * x1**2 + x2 - noise[:, 1],\n",
    "        ))\n",
    "\n",
    "        self.features = torch.tensor(X, dtype=torch.float32)\n",
    "        self.labels = torch.tensor(targets, dtype=torch.float32)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.features)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.features[idx], self.labels[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28f4f3b8-2996-4058-80b6-da0727156fc8",
   "metadata": {
    "papermill": {
     "duration": 0.008875,
     "end_time": "2025-08-20T04:32:26.254886",
     "exception": false,
     "start_time": "2025-08-20T04:32:26.246011",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def create_simplex_layer(n: int) -> CvxpyLayer:\n",
    "    \"\"\"Build a differentiable layer solving min_z ||P z||_2 + q^T z over the probability simplex.\n",
    "\n",
    "    The returned layer takes (P, q) positionally -- P an (n, n) matrix, q a length-n vector,\n",
    "    optionally with a leading batch dimension as the callers in this notebook use -- and\n",
    "    returns a one-element tuple holding the optimal z.\n",
    "\n",
    "    P is deliberately an unconstrained matrix: callers pass a triangular Cholesky factor,\n",
    "    which is not symmetric, so declaring the parameter PSD/symmetric would be false.\n",
    "    \"\"\"\n",
    "    z = cp.Variable(n)\n",
    "    Pp = cp.Parameter((n, n))\n",
    "    q = cp.Parameter(n)\n",
    "    # sum(z) == 1 with z >= 0 already implies z <= 1; the explicit bound is harmless.\n",
    "    constraints = [cp.sum(z) == 1, z >= 0, z <= 1]\n",
    "    obj = cp.Minimize(cp.norm(Pp @ z, 2) + q.T @ z)\n",
    "    problem = cp.Problem(obj, constraints)\n",
    "    # CvxpyLayer requires a DPP-compliant problem.\n",
    "    assert problem.is_dpp('dcp')\n",
    "    return CvxpyLayer(problem, parameters=[Pp, q], variables=[z])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2405853-3b8c-42a2-a342-c70bd02d0663",
   "metadata": {
    "papermill": {
     "duration": 0.024276,
     "end_time": "2025-08-20T04:32:26.283034",
     "exception": false,
     "start_time": "2025-08-20T04:32:26.258758",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class EllipsoidalUncertaintyModel(nn.Module):\n",
    "    \"\"\"MLP that predicts an ellipsoidal uncertainty set (mu, Sigma) for an n-dim target.\n",
    "\n",
    "    Sigma = L L^T with L lower-triangular (softplus entries, +1e-6 on the diagonal), so\n",
    "    Sigma is positive definite.\n",
    "\n",
    "    forward(x, cvxpylayer=None) returns\n",
    "      * enable_implicit=False: (mu, L, Sigma)\n",
    "      * enable_implicit=True:  (mu, Sigma, z), z being the decision returned by\n",
    "        `cvxpylayer` (built by create_simplex_layer) for this ellipsoid.\n",
    "    The tuple layout differs between the two modes; callers rely on both.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dims, n=2, enable_implicit=False):\n",
    "        super().__init__()\n",
    "        self.n = n\n",
    "        self.enable_implicit = enable_implicit\n",
    "\n",
    "        # An int means three hidden layers of that width.\n",
    "        if isinstance(hidden_dims, int):\n",
    "            hidden_dims = [hidden_dims] * 3\n",
    "        layers, in_dim = [], input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers += [nn.Linear(in_dim, h), nn.BatchNorm1d(h), nn.ReLU()]\n",
    "            in_dim = h\n",
    "        self.backbone = nn.Sequential(*layers)\n",
    "\n",
    "        self.fc_mu = nn.Linear(in_dim, n)\n",
    "        # One output per lower-triangular entry of L.\n",
    "        self.fc_L  = nn.Linear(in_dim, n*(n+1)//2)\n",
    "\n",
    "    def forward(self, x, cvxpylayer=None):\n",
    "        B = x.size(0)\n",
    "        h = self.backbone(x)\n",
    "        mu = self.fc_mu(h)\n",
    "\n",
    "        # Scatter fc_L's outputs into the lower triangle in row-major order\n",
    "        # ((0,0), (1,0), (1,1), (2,0), ...), adding 1e-6 on the diagonal only.\n",
    "        # NOTE(review): softplus is applied to the off-diagonal entries too, so every entry\n",
    "        # of L (hence of Sigma) is >= 0 and only non-negative correlations are\n",
    "        # representable. Confirm this restriction is intended.\n",
    "        l_flat = self.fc_L(h)\n",
    "        rows, cols = torch.tril_indices(self.n, self.n, device=x.device)\n",
    "        diag_jitter = (rows == cols).to(l_flat.dtype) * 1e-6\n",
    "        L = l_flat.new_zeros((B, self.n, self.n))\n",
    "        L[:, rows, cols] = F.softplus(l_flat) + diag_jitter\n",
    "        Sigma = L @ L.transpose(1, 2)\n",
    "\n",
    "        if not self.enable_implicit:\n",
    "            return mu, L, Sigma\n",
    "\n",
    "        if cvxpylayer is None:\n",
    "            raise ValueError(\"cvxpylayer is required when enable_implicit=True\")\n",
    "        # create_simplex_layer minimises ||P z|| + q^T z. With Sigma = C C^T (C lower-\n",
    "        # triangular), the worst case of y^T z over {y : (y - mu)^T Sigma^{-1} (y - mu) <= 1}\n",
    "        # is mu^T z + sqrt(z^T Sigma z) = mu^T z + ||C^T z||, so P must be C^T, not C.\n",
    "        C = torch.linalg.cholesky(Sigma + 1e-6 * torch.eye(self.n, device=x.device))\n",
    "        Pp = C.transpose(-2, -1).contiguous()\n",
    "        z_opt, = cvxpylayer(Pp.cpu(), mu.cpu())\n",
    "        return mu, Sigma, z_opt.to(x.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab325b8b-dfb7-4f09-8a7d-a0cdf024561f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG the notebook uses so data generation, shuffling and weight\n",
    "# initialisation are reproducible across Restart & Run All.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "n_train = 1500\n",
    "n_cal = 1500\n",
    "n_test = 1500\n",
    "mu = [1.0, 1.0]  # mean of the synthetic input features\n",
    "cov_scale = 2.25\n",
    "noise_std = 1.0\n",
    "\n",
    "n_epochs = 250\n",
    "\n",
    "train_ds = SyntheticRegressionDataset(n_train, mu, cov_scale, noise_std)\n",
    "cal_ds   = SyntheticRegressionDataset(n_cal,   mu, cov_scale, noise_std)\n",
    "test_ds  = SyntheticRegressionDataset(n_test,  mu, cov_scale, noise_std)\n",
    "\n",
    "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)\n",
    "cal_loader = DataLoader(cal_ds, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)\n",
    "test_loader= DataLoader(test_ds,batch_size=32, shuffle=False, num_workers=4, pin_memory=True)\n",
    "\n",
    "# Stage 1: fit the baseline (mu, Sigma) model by Gaussian negative log-likelihood.\n",
    "model = EllipsoidalUncertaintyModel(input_dim=2, hidden_dims=[20,10], n=2, enable_implicit=False).to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)\n",
    "\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    model.train()\n",
    "    # Not called `total_loss`: that name is the CRC objective function defined in a later\n",
    "    # cell, and re-running this cell would otherwise replace it with a float.\n",
    "    running_loss = 0.0\n",
    "    for x_cpu, y_cpu in train_loader:\n",
    "        x = x_cpu.to(device, non_blocking=True)\n",
    "        y = y_cpu.to(device, non_blocking=True)\n",
    "\n",
    "        # `mu_pred`, not `mu`, so the data-mean setting above is not overwritten.\n",
    "        mu_pred, _, Sigma = model(x)\n",
    "        Sigma = Sigma + 1e-6 * torch.eye(2, device=device)\n",
    "        dist = MultivariateNormal(mu_pred, Sigma)\n",
    "        loss = -dist.log_prob(y).mean()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item() * x.size(0)\n",
    "\n",
    "    avg = running_loss / len(train_loader.dataset)\n",
    "    print(f\"Epoch {epoch:>2d} | NLL Loss: {avg:.4f}\")\n",
    "\n",
    "os.makedirs('outputs/model/baseline', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "torch.save(model.state_dict(), save_path)\n",
    "print(f\"Saved model to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80e65366-7e58-42bf-b769-0ccd3c5258bf",
   "metadata": {
    "papermill": {
     "duration": 0.028602,
     "end_time": "2025-08-20T04:33:46.744068",
     "exception": false,
     "start_time": "2025-08-20T04:33:46.715466",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def update_theta(model, cvxpylayer, opt_h, x, y, lambda_param, alpha, sigma, eps = 1e-6):\n",
    "    \"\"\"Take one primal (model-parameter) step on the smoothed Lagrangian.\n",
    "\n",
    "    With risk_b = mu_b.z_b + sqrt(z_b^T Sigma_b z_b) -- the worst case of y.z_b over the\n",
    "    ellipsoid {y : (y - mu_b)^T Sigma_b^{-1} (y - mu_b) <= 1} -- the loss is\n",
    "        mean(risk) + lambda_param * ((1 - alpha) - mean(Phi((risk - y.z) / sigma))),\n",
    "    where Phi is the standard normal CDF, a smooth stand-in for the indicator y.z <= risk.\n",
    "    Steps opt_h and returns the loss as a Python float. backward() also accumulates a\n",
    "    gradient into lambda_param if it requires grad; opt_h does not step it.\n",
    "    \"\"\"\n",
    "    mu, Sigma, z = model(x, cvxpylayer)\n",
    "    Sigma_j = Sigma + eps * torch.eye(model.n, device=Sigma.device, dtype=Sigma.dtype)\n",
    "\n",
    "    # Sigma_j = L L^T, hence sqrt(z^T Sigma_j z) = ||L^T z|| (||L z|| is a different quantity).\n",
    "    L = torch.linalg.cholesky(Sigma_j)\n",
    "    term1 = torch.norm((L.transpose(-2, -1) @ z.unsqueeze(-1)).squeeze(-1), dim=1)\n",
    "    term2 = (mu * z).sum(dim=1)\n",
    "    term = term1 + term2\n",
    "    f = term.mean()\n",
    "\n",
    "    approx_ind = 0.5 * (1 + torch.erf((term - (y * z).sum(dim=1)) / (sigma * (2 ** 0.5))))\n",
    "    g = lambda_param * ((1 - alpha) - approx_ind.mean())\n",
    "\n",
    "    loss = f + g\n",
    "    opt_h.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_h.step()\n",
    "    return loss.item()\n",
    "\n",
    "def update_lambda(model, cvxpylayer, opt_l, x, y, lambda_param, alpha, sigma, eps = 1e-6):\n",
    "    \"\"\"Take one projected dual-ascent step on the multiplier lambda_param.\n",
    "\n",
    "    Minimises -(mean(risk) + lambda_param * ((1 - alpha) - mean(1[y.z <= risk]))) with\n",
    "    opt_l, i.e. ascends the Lagrangian in lambda, then clamps lambda_param to >= 0.\n",
    "    risk = mu.z + sqrt(z^T Sigma z). The hard indicator suffices because the gradient\n",
    "    w.r.t. lambda does not differentiate through it. `sigma` is unused (uniform signature).\n",
    "    Returns the minimised loss as a Python float.\n",
    "    \"\"\"\n",
    "    # Only lambda_param is updated, so the model/decision pass needs no autograd graph.\n",
    "    # z is the decision the model's own layer already returned for these inputs, so the\n",
    "    # layer is not solved a second time.\n",
    "    with torch.no_grad():\n",
    "        mu, Sigma, z = model(x, cvxpylayer)\n",
    "        Sigma_j = Sigma + eps * torch.eye(model.n, device=Sigma.device, dtype=Sigma.dtype)\n",
    "        L_j = torch.linalg.cholesky(Sigma_j)\n",
    "        # sqrt(z^T Sigma_j z) = ||L_j^T z|| for the lower-triangular factor L_j.\n",
    "        term1 = torch.norm((L_j.transpose(-2, -1) @ z.unsqueeze(-1)).squeeze(-1), dim=1)\n",
    "        term = term1 + (mu * z).sum(dim=1)\n",
    "        f2 = term.mean()\n",
    "        indicator = ((term - (y * z).sum(dim=1)) >= 0).float()\n",
    "\n",
    "    g2 = lambda_param * ((1 - alpha) - indicator.mean())\n",
    "    loss = -(f2 + g2)\n",
    "    opt_l.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_l.step()\n",
    "    with torch.no_grad():\n",
    "        lambda_param.clamp_(min=0.0)\n",
    "    return loss.item()\n",
    "\n",
    "def total_loss(model, cvxpylayer, x, y, lambda_param, alpha, sigma=0.1, eps = 1e-6):\n",
    "    \"\"\"Evaluate the (unsmoothed) Lagrangian used to monitor CRC training.\n",
    "\n",
    "    Returns mean(risk) + lambda_param * ((1 - alpha) - mean(1[y.z <= risk])) as a 0-dim\n",
    "    tensor, with risk = mu.z + sqrt(z^T Sigma z). `sigma` is unused.\n",
    "    \"\"\"\n",
    "    mu, Sigma, z = model(x, cvxpylayer)\n",
    "    Sigma_j = Sigma + eps * torch.eye(model.n, device=Sigma.device, dtype=Sigma.dtype)\n",
    "    L = torch.linalg.cholesky(Sigma_j)\n",
    "    # sqrt(z^T Sigma_j z) = ||L^T z|| for the lower-triangular factor L.\n",
    "    term1 = torch.norm((L.transpose(-2, -1) @ z.unsqueeze(-1)).squeeze(-1), dim=1)\n",
    "    term2 = (mu * z).sum(dim=1)\n",
    "    risk = term1 + term2\n",
    "    f = risk.mean()\n",
    "\n",
    "    indicator = (risk - (y * z).sum(dim=1) >= 0).float()\n",
    "    g = lambda_param * ((1 - alpha) - indicator.mean())\n",
    "    return f + g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "790ea979-204e-4e02-ad2b-1ae63ede6e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage 2: fine-tune the baseline model end-to-end through the decision layer with\n",
    "# alternating primal (theta) and dual (lambda) steps on the calibration split.\n",
    "cvxpylayer = create_simplex_layer(n=2)\n",
    "\n",
    "model = EllipsoidalUncertaintyModel(\n",
    "    input_dim=2, hidden_dims=[20, 10], n=2, enable_implicit=True\n",
    ").to(device)\n",
    "model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "\n",
    "# Primal optimiser: larger step on the mean head than on the covariance head and backbone.\n",
    "opt_h = torch.optim.Adam([\n",
    "    {'params': model.fc_mu.parameters(), 'lr': 1e-3},\n",
    "    {'params': model.fc_L.parameters(), 'lr': 1e-4},\n",
    "    {'params': model.backbone.parameters(), 'lr': 1e-4},\n",
    "])\n",
    "\n",
    "# Lagrange multiplier for the coverage constraint, and its own optimiser.\n",
    "lambda_param = nn.Parameter(torch.tensor(7.0, device=device), requires_grad=True)\n",
    "opt_l = torch.optim.Adam([lambda_param], lr=1e-3)\n",
    "\n",
    "X_cal = torch.stack([pair[0] for pair in cal_ds], dim=0).to(device)\n",
    "Y_cal = torch.stack([pair[1] for pair in cal_ds], dim=0).to(device)\n",
    "\n",
    "# First half of the calibration split drives the theta step, second half the lambda step.\n",
    "N = X_cal.shape[0]\n",
    "n = int(N / 2)\n",
    "X_cal_1, Y_cal_1 = X_cal[:n], Y_cal[:n]\n",
    "X_cal_2, Y_cal_2 = X_cal[n:], Y_cal[n:]\n",
    "\n",
    "alpha = 0.1\n",
    "sigma = 0.1\n",
    "epochs = 500\n",
    "\n",
    "# History of the full-calibration objective; after epoch 350, stop once it has changed by\n",
    "# less than 3e-4 for 10 consecutive epochs.\n",
    "total = []\n",
    "prev_loss = None\n",
    "consecutive_count = 0\n",
    "\n",
    "for epoch in range(1, epochs + 1):\n",
    "    l_t = update_theta(model, cvxpylayer, opt_h, X_cal_1, Y_cal_1, lambda_param, alpha, sigma)\n",
    "    l_l = update_lambda(model, cvxpylayer, opt_l, X_cal_2, Y_cal_2, lambda_param, alpha, sigma)\n",
    "    tot = total_loss(model, cvxpylayer, X_cal, Y_cal, lambda_param, alpha, sigma)\n",
    "    total.append(tot.item())\n",
    "    print(f\"Epoch {epoch:02d} | loss_θ={l_t:.6f} | loss_λ={l_l:.6f} | total={tot:.6f} | λ={lambda_param.item():.4f}\")\n",
    "\n",
    "    if epoch > 350:\n",
    "        if prev_loss is None or not abs(prev_loss - tot.item()) < 3e-4:\n",
    "            consecutive_count = 0\n",
    "        else:\n",
    "            consecutive_count += 1\n",
    "        if consecutive_count >= 10:\n",
    "            print(f\"Early stopping at epoch {epoch} due to small loss change for 10 consecutive epochs.\")\n",
    "            break\n",
    "    prev_loss = tot.item()\n",
    "\n",
    "os.makedirs('outputs/model/final', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'final', 'final_model.pth')\n",
    "torch.save(model.state_dict(), save_path)\n",
    "print(f\"Saved final model to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656c0046-9a75-46ad-bb36-473203897413",
   "metadata": {
    "papermill": {
     "duration": 1.334016,
     "end_time": "2025-08-20T05:50:37.574643",
     "exception": false,
     "start_time": "2025-08-20T05:50:36.240627",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.eval()\n",
    "model.enable_implicit = True\n",
    "\n",
    "X_test  = torch.stack([x for x, y in test_ds],  dim=0)\n",
    "Y_test  = torch.stack([y for x, y in test_ds],  dim=0)\n",
    "\n",
    "# NOTE: despite the `_cpu` suffix these tensors live on `device`; the names are kept\n",
    "# because later cells may refer to them.\n",
    "X_test_cpu = X_test.to(device, non_blocking=True)\n",
    "Y_test_cpu = Y_test.to(device, non_blocking=True)\n",
    "\n",
    "with torch.no_grad():\n",
    "    mu_test, Sigma_test, z_test = model(X_test_cpu, cvxpylayer)\n",
    "\n",
    "eps = 1e-6\n",
    "I = torch.eye(model.n, device=device)\n",
    "\n",
    "Sigma_j   = Sigma_test + eps * I.unsqueeze(0)\n",
    "Sigma_inv = torch.linalg.inv(Sigma_j)\n",
    "# Lower-triangular factor: Sigma_j = Sigma_half @ Sigma_half^T.\n",
    "Sigma_half = torch.linalg.cholesky(Sigma_j)\n",
    "\n",
    "# Marginal coverage: fraction of test labels with (y - mu)^T Sigma^{-1} (y - mu) <= 1.\n",
    "delta = Y_test_cpu - mu_test\n",
    "d2    = torch.einsum('bi,bij,bj->b', delta, Sigma_inv, delta)\n",
    "in_set = d2 <= 1.0\n",
    "pct_in = in_set.float().mean().item() * 100\n",
    "\n",
    "# Worst-case loss over that ellipsoid: mu^T z + sqrt(z^T Sigma z) = mu^T z + ||Sigma_half^T z||\n",
    "# (||Sigma_half z|| is a different quantity). With this risk every covered label also\n",
    "# satisfies y^T z <= risk, so Robustness >= Marginal Coverage up to rounding.\n",
    "z_col       = z_test.unsqueeze(-1)\n",
    "Sigma_z     = Sigma_half.transpose(-2, -1) @ z_col\n",
    "norm_Sigma_z = torch.norm(Sigma_z, p=2, dim=(1,2))\n",
    "\n",
    "mu_z = (mu_test * z_test).sum(dim=1)\n",
    "\n",
    "risk_loss    = norm_Sigma_z + mu_z\n",
    "Average_Risk = risk_loss.mean().item()\n",
    "\n",
    "decision_loss = (Y_test_cpu * z_test).sum(dim=1)\n",
    "Average_Loss  = decision_loss.mean().item()\n",
    "robustness = (decision_loss <= risk_loss).float().mean().item() * 100\n",
    "\n",
    "print(f\"Marginal Coverage: {pct_in:.2f}%\")\n",
    "print(f\"Average_Risk: {Average_Risk:.4f}\")\n",
    "print(f\"Average_Loss: {Average_Loss:.4f}\")\n",
    "print(f\"Robustness: {robustness:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb8a91ac-516d-4e03-8645-391855da1fd7",
   "metadata": {
    "papermill": {
     "duration": 0.036961,
     "end_time": "2025-08-20T05:50:37.684455",
     "exception": false,
     "start_time": "2025-08-20T05:50:37.647494",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "data_dir = 'outputs/results/data_CRC'\n",
    "txt_path = os.path.join(data_dir, 'metrics_CRC.txt')\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "# Persist the CRC test metrics, one per line.\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.writelines([\n",
    "        f\"Marginal Coverage: {pct_in:.2f}%\\n\",\n",
    "        f\"Average Risk     : {Average_Risk:.4f}\\n\",\n",
    "        f\"Average Loss     : {Average_Loss:.4f}\\n\",\n",
    "        f\"Robustness       : {robustness:.2f}%\\n\",\n",
    "    ])\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9483c74a-9245-4170-9957-f5d73e63780a",
   "metadata": {},
   "source": [
    "### Cal-CRC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1123f25c-1e20-4935-9c66-44425edd528f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full calibration and test splits as single tensors on `device`.\n",
    "X_cal, Y_cal = (torch.stack(part, dim=0).to(device) for part in zip(*cal_ds))\n",
    "X_test, Y_test = (torch.stack(part, dim=0).to(device) for part in zip(*test_ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3271e2fe-0363-4f7e-b894-286e07c1acb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(X):\n",
    "    \"\"\"Predict (mu, Sigma) for X with the fine-tuned model saved under outputs/model/final.\n",
    "\n",
    "    The network is built and its checkpoint loaded once, then cached on the function; it\n",
    "    is reloaded only when the checkpoint's modification time changes (e.g. after\n",
    "    re-training). crccal reaches this (via interval) for every single sample, so\n",
    "    re-reading the file on each call is very costly.\n",
    "    \"\"\"\n",
    "    model_path = os.path.join('outputs', 'model', 'final', 'final_model.pth')\n",
    "    cache_key = (model_path, os.path.getmtime(model_path))\n",
    "    cached = getattr(predict, '_cache', None)\n",
    "    if cached is None or cached[0] != cache_key:\n",
    "        model = EllipsoidalUncertaintyModel(\n",
    "            input_dim=2,\n",
    "            hidden_dims=[20, 10],\n",
    "            n=2,\n",
    "            enable_implicit=False\n",
    "        ).to(device)\n",
    "        model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "        # Eval mode: BatchNorm must use running statistics for single-sample batches.\n",
    "        model.eval()\n",
    "        predict._cache = (cache_key, model)\n",
    "    model = predict._cache[1]\n",
    "    with torch.no_grad():\n",
    "        mu, _, Sigma = model(X)\n",
    "    return mu, Sigma\n",
    "\n",
    "def score(X, Y, eps = 1e-6):\n",
    "    \"\"\"Squared Mahalanobis score (Y - mu)^T (Sigma + eps I)^{-1} (Y - mu) per row.\n",
    "\n",
    "    (mu, Sigma) come from predict(X); returns a 1-D tensor of length X.shape[0].\n",
    "    \"\"\"\n",
    "    mu, Sigma = predict(X)\n",
    "    # Identity on Sigma's device/dtype; a default (CPU) eye fails against CUDA tensors.\n",
    "    eye = torch.eye(Sigma.size(-1), device=Sigma.device, dtype=Sigma.dtype)\n",
    "    Sigma_inv = torch.linalg.inv(Sigma + eps * eye)\n",
    "    diff = Y - mu\n",
    "    return torch.einsum('bi,bij,bj->b', diff, Sigma_inv, diff)\n",
    "\n",
    "def interval(X,Q):\n",
    "    \"\"\"Return the predicted ellipsoid for X as (mu, Sigma, Q); Q is passed through unchanged.\"\"\"\n",
    "    prediction = predict(X)\n",
    "    return (*prediction, Q)\n",
    "\n",
    "def to_numpy(x):\n",
    "    \"\"\"Convert x to a NumPy array; torch tensors are detached and copied to the CPU first.\n",
    "\n",
    "    Best effort: if torch is unavailable or the tensor conversion raises, falls back to\n",
    "    np.asarray(x).\n",
    "    \"\"\"\n",
    "    import numpy as np\n",
    "    converted = None\n",
    "    try:\n",
    "        import torch\n",
    "        if isinstance(x, torch.Tensor):\n",
    "            converted = x.detach().cpu().numpy()\n",
    "    except Exception:\n",
    "        converted = None\n",
    "    return np.asarray(x) if converted is None else converted\n",
    "\n",
    "def ensure_np_1sample(mean, covariance):\n",
    "    \"\"\"Return (mean, covariance) as contiguous NumPy arrays for a single sample.\n",
    "\n",
    "    A leading batch axis of size 1 is dropped: (1, k) -> (k,) for the mean and\n",
    "    (1, k, k) -> (k, k) for the covariance. Other shapes are left as they are.\n",
    "    \"\"\"\n",
    "    def _single(arr, batched_ndim):\n",
    "        arr = to_numpy(arr)\n",
    "        if arr.ndim == batched_ndim and arr.shape[0] == 1:\n",
    "            arr = arr[0]\n",
    "        return np.ascontiguousarray(arr)\n",
    "\n",
    "    return _single(mean, 2), _single(covariance, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "576474af-7c0e-4607-8461-9d4cc59a3426",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "from scipy.linalg import sqrtm\n",
    "\n",
    "def BoxProblem(w, y_down, y_up):\n",
    "    \"\"\"Linear loss of decision w at the upper corner y_up of the box [y_down, y_up].\n",
    "\n",
    "    This equals the worst case of y.w over the box when w >= 0 (as for simplex\n",
    "    decisions). y_down is accepted but unused.\n",
    "    \"\"\"\n",
    "    upper = y_up.ravel()\n",
    "    return np.sum(upper * w)\n",
    "\n",
    "def projection(w):\n",
    "    \"\"\"Euclidean projection of a 2-vector w onto the simplex {w >= 0, w[0] + w[1] = 1}.\n",
    "\n",
    "    The vertices are returned as integer arrays ([1, 0] / [0, 1]).\n",
    "    \"\"\"\n",
    "    h = 0.5 + 0.5 * (w[0] - w[1])  # first coordinate of the projection onto the line\n",
    "    if h > 1:\n",
    "        return np.array([1, 0])\n",
    "    if h < 0:\n",
    "        return np.array([0, 1])\n",
    "    return np.array([h, 1 - h])\n",
    "\n",
    "def optimize(y_matrix, y_index, r, w_0, T, eta):\n",
    "    \"\"\"Projected subgradient descent for min_w max over selected boxes of BoxProblem.\n",
    "\n",
    "    Box (i, j) is y_matrix[i][j] +/- r and is selected where y_index[i, j] == 1. Each of\n",
    "    the T steps moves w against the maximising box's upper corner, projects back onto the\n",
    "    2-simplex, and shrinks the step size by 0.9. Returns the last iterate.\n",
    "    \"\"\"\n",
    "    N = y_index.shape[0]\n",
    "    w = w_0\n",
    "    step = eta\n",
    "    for _ in range(T):\n",
    "        # Unselected cells stay at -inf so argmax only picks selected boxes.\n",
    "        losses = np.full((N, N), -np.inf)\n",
    "        for i, j in np.ndindex(N, N):\n",
    "            if y_index[i, j] == 1:\n",
    "                center = y_matrix[i][j]\n",
    "                losses[i, j] = BoxProblem(w, center - r, center + r)\n",
    "\n",
    "        i_max, j_max = np.unravel_index(np.argmax(losses), losses.shape)\n",
    "        w = projection(w - step * (y_matrix[i_max][j_max] + r).ravel())\n",
    "        step *= 0.9\n",
    "\n",
    "    return w\n",
    "\n",
    "\n",
    "def rrisk(y_matrix, y_index, r, w_0):\n",
    "    \"\"\"Worst-case loss of decision w_0 over all selected grid boxes y_matrix[i][j] +/- r.\n",
    "\n",
    "    (i, j) is selected where y_index[i, j] == 1; returns -inf if nothing is selected.\n",
    "    \"\"\"\n",
    "    N = y_index.shape[0]\n",
    "    losses = np.full((N, N), -np.inf)\n",
    "    for i, j in np.argwhere(y_index[:N, :N] == 1):\n",
    "        center = y_matrix[int(i)][int(j)]\n",
    "        losses[i, j] = BoxProblem(w_0, center - r, center + r)\n",
    "    return np.max(losses)\n",
    "\n",
    "\n",
    "def evaluate(X_test, Y_test, Interval_test, Decision_test, Risk_test_dic, grid_x=None, grid_y=None):\n",
    "    \"\"\"Per-test-point metrics for the Cal-CRC outputs produced by crccal.\n",
    "\n",
    "    For each test index te:\n",
    "      * Loss       = Y_test[te] . Decision_test[te]\n",
    "      * Risk       = float(Risk_test_dic[te])\n",
    "      * Coverage   = 1 if the grid node nearest to Y_test[te] (same snapping as linkfun)\n",
    "                     is marked in the prediction set Interval_test[te], else 0. The set is an\n",
    "                     N x N 0/1 array whose row i / column j refer to grid_y[i] / grid_x[j].\n",
    "      * Robustness = 1 if Loss <= Risk else 0\n",
    "\n",
    "    grid_x / grid_y default to crccal's grid (20 points on [-10, 45] and on [-2, 46]).\n",
    "    X_test is unused. Returns four dicts keyed by te: Loss, Risk, Coverage, Robustness.\n",
    "    \"\"\"\n",
    "    if grid_x is None:\n",
    "        grid_x = np.linspace(-10.00, 45.00, 20)\n",
    "    if grid_y is None:\n",
    "        grid_y = np.linspace(-2.00, 46.00, 20)\n",
    "\n",
    "    Y_test = to_numpy(Y_test)\n",
    "\n",
    "    n_test = Y_test.shape[0]\n",
    "    Loss_test = {}\n",
    "    Risk_test = {}\n",
    "    Coverage_test = {}\n",
    "    Robustness_test = {}\n",
    "\n",
    "    for te in range(n_test):\n",
    "        Y_index = to_numpy(Interval_test[te])\n",
    "        yte = Y_test[te, :].reshape(-1)\n",
    "        dte = to_numpy(Decision_test[te]).reshape(-1)\n",
    "\n",
    "        # Coverage means \"the true label lies in the set\", not merely \"the set is non-empty\".\n",
    "        j = int(np.argmin(np.abs(yte[0] - grid_x)))\n",
    "        i = int(np.argmin(np.abs(yte[1] - grid_y)))\n",
    "        Coverage_test[te] = 1 if Y_index[i, j] == 1 else 0\n",
    "\n",
    "        Loss_test[te] = float(np.sum(yte * dte))\n",
    "        Risk_test[te] = float(Risk_test_dic[te])\n",
    "        Robustness_test[te] = 1 if (Loss_test[te] <= Risk_test[te]) else 0\n",
    "\n",
    "    return Loss_test, Risk_test, Coverage_test, Robustness_test\n",
    "\n",
    "\n",
    "def linkfun(Y_pre):\n",
    "    \"\"\"Snap each 2-D point in Y_pre to the nearest node of crccal's fixed 20 x 20 grid.\n",
    "\n",
    "    Y_pre may be a torch tensor or array-like of shape (n, >=2); only the first two\n",
    "    columns are used. Returns a float64 array of shape (n, 2) holding, per row, the\n",
    "    grid node (x[j], y[i]) nearest in each coordinate (ties go to the lower index).\n",
    "    \"\"\"\n",
    "    try:\n",
    "        import torch\n",
    "        is_tensor = isinstance(Y_pre, torch.Tensor)\n",
    "    except ImportError:\n",
    "        is_tensor = False\n",
    "    Y_pre_np = Y_pre.detach().cpu().numpy() if is_tensor else np.asarray(Y_pre)\n",
    "\n",
    "    N = 20\n",
    "    x = np.linspace(-10.00, 45.00, N)\n",
    "    y = np.linspace(-2.00, 46.00, N)\n",
    "\n",
    "    n_pre = len(Y_pre_np)\n",
    "    if n_pre == 0:\n",
    "        return np.zeros((0, 2), dtype=float)\n",
    "\n",
    "    # Nearest grid column for the first coordinate, nearest grid row for the second.\n",
    "    j = np.abs(Y_pre_np[:, 0, None] - x[None, :]).argmin(axis=1)\n",
    "    i = np.abs(Y_pre_np[:, 1, None] - y[None, :]).argmin(axis=1)\n",
    "    Y_pre_new = np.stack([x[j], y[i]], axis=1).astype(float)\n",
    "    return Y_pre_new\n",
    "    \n",
    "    \n",
    "def _crccal_decision(mean, covariance, Q, q):\n",
    "    \"\"\"Robust decision for the ellipsoid {y : (y - mean)^T covariance^{-1} (y - mean) <= Q}.\n",
    "\n",
    "    Solves min_z sqrt(Q) * ||covariance^{1/2} z|| + mean^T z over the probability simplex\n",
    "    (z of length q) and returns (decision, risk), where\n",
    "    risk = sqrt(Q) * sqrt(z^T covariance z) + mean^T z is the decision's worst-case loss.\n",
    "    \"\"\"\n",
    "    z = cp.Variable((1, q), nonneg=True)\n",
    "    covariance_sqrt = sqrtm(covariance)\n",
    "    Q_sqrt = np.sqrt(Q)\n",
    "\n",
    "    objective = cp.Minimize(Q_sqrt * cp.norm(covariance_sqrt @ z.T) + mean @ z.T)\n",
    "    problem = cp.Problem(objective, [cp.sum(z) == 1])\n",
    "    problem.solve()\n",
    "\n",
    "    decision = z.value[0]\n",
    "    middle = decision.reshape(1, -1) @ covariance @ decision.reshape(1, -1).T\n",
    "    risk = np.sqrt(Q) * np.sqrt(middle[0, 0]) + np.sum(mean * decision)\n",
    "    return decision, risk\n",
    "\n",
    "\n",
    "def _crccal_prediction(X, idx):\n",
    "    \"\"\"Predicted (mean, covariance) for row idx of X as NumPy arrays without the batch axis.\"\"\"\n",
    "    mean, covariance, _ = interval(X[idx:idx + 1, :], None)\n",
    "    return to_numpy(mean)[0], to_numpy(covariance)[0]\n",
    "\n",
    "\n",
    "def crccal(X_cal, Y_cal, X_test, alpha):\n",
    "    \"\"\"Cal-CRC: conformal prediction sets on a fixed label grid plus robust decisions.\n",
    "\n",
    "    Candidate radii are Q_l = 1 + 0.05 * l for l = 0 .. 249. Rate(l) counts calibration\n",
    "    points whose grid-snapped label (linkfun) satisfies y.z <= risk for the decision\n",
    "    solved at Q_l (see _crccal_decision). With target = (1 - alpha) * (n_cal + 1), a grid\n",
    "    point g joins a test point's set when, scanning l upwards over radii with\n",
    "    Rate(l) >= target - 1, g satisfies g.z <= risk at Q_l before a radius is reached\n",
    "    where Rate(l) >= target on its own.\n",
    "\n",
    "    Per test point, an empty set falls back to the ellipsoid decision at Q = 1; otherwise\n",
    "    the decision minimises the worst case over the selected grid boxes (optimize/rrisk).\n",
    "    Returns dicts keyed by test index: Interval_test (N x N 0/1 array), Decision_test,\n",
    "    Risk_test.\n",
    "    \"\"\"\n",
    "    n_cal = X_cal.shape[0]\n",
    "    n_test = X_test.shape[0]\n",
    "    q = Y_cal.shape[1]\n",
    "\n",
    "    score_quantile = 1\n",
    "    delta = 0.05\n",
    "    Delta = 250\n",
    "    target = (1 - alpha) * (n_cal + 1)\n",
    "\n",
    "    Y_cal = linkfun(Y_cal)\n",
    "\n",
    "    N = 20\n",
    "    x = np.linspace(-10.00, 45.00, N)\n",
    "    y = np.linspace(-2.00, 46.00, N)\n",
    "    X, Y = np.meshgrid(x, y, indexing='xy')\n",
    "    # Y_matrix[i][j] is the (1, 2) grid point (x[j], y[i]); r holds half the cell widths,\n",
    "    # so each grid point stands for the box Y_matrix[i][j] +/- r.\n",
    "    Y_matrix = {i: {j: np.array([[X[i, j], Y[i, j]]]) for j in range(N)} for i in range(N)}\n",
    "    r = np.array([[(x[1] - x[0]) / 2, (y[1] - y[0]) / 2]])\n",
    "    w_0 = np.array([0.4, 0.6])\n",
    "    T = 100\n",
    "    eta = 0.2\n",
    "\n",
    "    # Calibration predictions and Rate(l) do not depend on the test point or grid cell,\n",
    "    # so each is computed at most once.\n",
    "    cal_predictions = []\n",
    "    Rate_down = {}\n",
    "\n",
    "    def calibration_rate(l):\n",
    "        if l not in Rate_down:\n",
    "            if not cal_predictions:\n",
    "                cal_predictions.extend(_crccal_prediction(X_cal, cal) for cal in range(n_cal))\n",
    "            Q = score_quantile + l * delta\n",
    "            Rate = 0\n",
    "            for cal, (mean, covariance) in enumerate(cal_predictions):\n",
    "                decision, risk = _crccal_decision(mean, covariance, Q, q)\n",
    "                if np.sum(Y_cal[cal, :] * decision) <= risk:\n",
    "                    Rate += 1\n",
    "            Rate_down[l] = Rate\n",
    "        return Rate_down[l]\n",
    "\n",
    "    Interval_test = dict()\n",
    "    Decision_test = dict()\n",
    "    Risk_test = dict()\n",
    "\n",
    "    for te in tqdm(range(n_test), desc=\"crccal\"):\n",
    "        mean_te, covariance_te = _crccal_prediction(X_test, te)\n",
    "        # The test decision at Q_l is the same for every grid point; solve it once per l.\n",
    "        test_decisions = {}\n",
    "\n",
    "        Y_index = np.zeros((N, N))\n",
    "        for i in range(N):\n",
    "            for j in range(N):\n",
    "                # Reset per grid point: if every l is skipped, `point` would otherwise keep\n",
    "                # the previous grid point's value (or be unbound for the very first one).\n",
    "                point = 0\n",
    "                for l in range(Delta):\n",
    "                    Rate = calibration_rate(l)\n",
    "                    if Rate < target - 1:\n",
    "                        continue\n",
    "                    if l not in test_decisions:\n",
    "                        test_decisions[l] = _crccal_decision(\n",
    "                            mean_te, covariance_te, score_quantile + l * delta, q)\n",
    "                    decision, risk = test_decisions[l]\n",
    "                    if np.sum(Y_matrix[i][j] * decision) <= risk:\n",
    "                        point = 1\n",
    "                        break\n",
    "                    if Rate >= target:\n",
    "                        break\n",
    "                if point == 1:\n",
    "                    Y_index[i, j] = 1\n",
    "\n",
    "        if np.sum(Y_index) == 0:\n",
    "            Decision_test[te], Risk_test[te] = _crccal_decision(\n",
    "                mean_te, covariance_te, score_quantile, q)\n",
    "        else:\n",
    "            Decision_test[te] = optimize(Y_matrix, Y_index, r, w_0, T, eta)\n",
    "            Risk_test[te] = rrisk(Y_matrix, Y_index, r, Decision_test[te])\n",
    "        Interval_test[te] = Y_index\n",
    "\n",
    "    return Interval_test, Decision_test, Risk_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b33a619-5443-4ada-bf50-999e653bf823",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Wall-clock cost of Cal-CRC (calibration + per-point decisions) plus evaluation.\n",
    "alpha = 0.1\n",
    "start_time = time.time()\n",
    "Interval_crccal, Decision_crccal, Risk_crccal = crccal(X_cal, Y_cal, X_test, alpha)\n",
    "# evaluate returns the risks again as plain floats, replacing Risk_crccal.\n",
    "Loss_crccal, Risk_crccal, Coverage_crccal, Robustness_crccal = evaluate(\n",
    "    X_test, Y_test, Interval_crccal, Decision_crccal, Risk_crccal\n",
    ")\n",
    "end_time = time.time()\n",
    "elapsed_time = end_time - start_time\n",
    "print(f\"CRCCal:{elapsed_time}s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66233933-e0b9-4e0a-a9eb-3e35125da981",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_from_dict(d):\n",
    "    \"\"\"Mean of a dict's values (or of a plain sequence) as a Python float; NaN if empty.\"\"\"\n",
    "    vals = list(d.values()) if isinstance(d, dict) else d\n",
    "    if not len(vals):\n",
    "        return float('nan')\n",
    "    return float(np.mean(vals))\n",
    "\n",
    "# One placeholder per metric: evaluate() returns Loss, Risk, Coverage and Robustness only.\n",
    "print(\n",
    "    \"CRCCal: Loss {:.6f}, Risk {:.6f}, Coverage {:.4f}, Robustness {:.4f}\".format(\n",
    "        mean_from_dict(Loss_crccal),\n",
    "        mean_from_dict(Risk_crccal),\n",
    "        mean_from_dict(Coverage_crccal),\n",
    "        mean_from_dict(Robustness_crccal),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f8217f3-f8b9-4f90-8e3a-a7d43fe631ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scalar means of the CRC metrics (a: loss, b: risk, d: coverage, e: robustness).\n",
    "a, b, d, e = (\n",
    "    mean_from_dict(metric)\n",
    "    for metric in (Loss_crccal, Risk_crccal, Coverage_crccal, Robustness_crccal)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3bf4019-dba0-499e-9252-af057dbfa76c",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = 'outputs/results/data_cal_crc'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "txt_path = os.path.join(data_dir, f'metrics_cal_crc_{run_id}.txt')\n",
    "\n",
    "# Coverage (d) and robustness (e) are scaled by 100 and reported as percentages.\n",
    "with open(txt_path, 'w') as f:\n",
    "    for line in (\n",
    "        f\"Marginal Coverage: {d*100:.2f}%\\n\",\n",
    "        f\"Average Risk     : {b:.4f}\\n\",\n",
    "        f\"Average Loss     : {a:.4f}\\n\",\n",
    "        f\"Robustness       : {e*100:.2f}%\\n\",\n",
    "    ):\n",
    "        f.write(line)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1dbc539-58c9-4d00-93ae-fd45da653c0e",
   "metadata": {
    "papermill": {
     "duration": 0.030927,
     "end_time": "2025-08-20T05:50:37.770614",
     "exception": false,
     "start_time": "2025-08-20T05:50:37.739687",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### CRO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbf9392-0f04-4e8b-8d9c-485ae46d2a9a",
   "metadata": {
    "papermill": {
     "duration": 0.040439,
     "end_time": "2025-08-20T05:50:37.842399",
     "exception": false,
     "start_time": "2025-08-20T05:50:37.801960",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rebuild the baseline uncertainty model and restore its pretrained weights on the CPU.\n",
    "device = torch.device(\"cpu\")\n",
    "model = EllipsoidalUncertaintyModel(input_dim=2, hidden_dims=[20, 10], n=2, enable_implicit=False).to(device)\n",
    "model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "# weights_only=True restricts unpickling to tensors/primitive containers, like the E2E loads below.\n",
    "model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21b72b22-0d43-4160-a425-df5da0ebc302",
   "metadata": {
    "papermill": {
     "duration": 0.046145,
     "end_time": "2025-08-20T05:50:37.939741",
     "exception": false,
     "start_time": "2025-08-20T05:50:37.893596",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "X_cal_cpu  = X_cal.cpu()\n",
    "Y_cal_cpu  = Y_cal.cpu()\n",
    "X_test_cpu = X_test.cpu()\n",
    "Y_test_cpu = Y_test.cpu()\n",
    "\n",
    "eps = 1e-6  # diagonal jitter added before inverting the predicted covariances\n",
    "\n",
    "# Inference only: the conformal scores need no autograd graph.\n",
    "with torch.no_grad():\n",
    "    Y_cal_pred, _, cal_Sigma = model(X_cal_cpu)\n",
    "    Y_test_pred, _, test_Sigma = model(X_test_cpu)\n",
    "\n",
    "# Nonconformity score: S = sqrt((y - y_hat)^T (Sigma + eps*I)^{-1} (y - y_hat)).\n",
    "Sigma_inv_cal = torch.linalg.inv(cal_Sigma + eps * torch.eye(cal_Sigma.size(-1)))\n",
    "cal_diff = (Y_cal_cpu - Y_cal_pred)\n",
    "cal_score = torch.einsum('bi,bij,bj->b', cal_diff, Sigma_inv_cal, cal_diff)\n",
    "S_cal = torch.sqrt(cal_score.clamp(min=1e-12))\n",
    "\n",
    "Sigma_inv_test = torch.linalg.inv(test_Sigma + eps * torch.eye(test_Sigma.size(-1)))\n",
    "test_diff = (Y_test_cpu - Y_test_pred)\n",
    "test_score = torch.einsum('bi,bij,bj->b', test_diff, Sigma_inv_test, test_diff)\n",
    "S_test = torch.sqrt(test_score.clamp(min=1e-12))\n",
    "\n",
    "# Split-conformal threshold: the ceil((n0 + 1) * (1 - alpha))-th smallest calibration score\n",
    "# (same rule as calc_q in the E2E section). Uses the notebook-level `alpha`; the rank is\n",
    "# checked explicitly, so a small calibration set fails with a clear message.\n",
    "n0 = S_cal.shape[0]\n",
    "conformal_rank = math.ceil((n0 + 1) * (1 - alpha))\n",
    "if conformal_rank > n0:\n",
    "    raise ValueError(f'Calibration set of size {n0} is too small for alpha={alpha}')\n",
    "marginal_quantile = torch.sort(S_cal).values[conformal_rank - 1].item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "037ead9e-e89a-48a3-a3f3-d87e5118ff6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "eta = 1e-2\n",
    "tol = 1e-4\n",
    "n = 2      # outcome / decision dimension\n",
    "T = 200\n",
    "\n",
    "z_value = []\n",
    "loss_value = []\n",
    "n_test = Y_test_pred.shape[0]\n",
    "conformal_quantile = float(marginal_quantile)\n",
    "\n",
    "# Factor the predicted precision matrices once for the whole test set (loop-invariant).\n",
    "# With Sigma^{-1} = L L^T, the Mahalanobis norm of d is sqrt(d^T Sigma^{-1} d) = ||L^T d||_2.\n",
    "with torch.no_grad():\n",
    "    Sigma_inv = torch.inverse(test_Sigma)\n",
    "    L_torch   = torch.linalg.cholesky(Sigma_inv)\n",
    "L = L_torch.detach().cpu().numpy()\n",
    "Y_hat_np = Y_test_pred.detach().cpu().numpy()\n",
    "\n",
    "for i in range(n_test):\n",
    "    # `ro_model`, not `model`: keep the neural network bound to `model` intact.\n",
    "    ro_model = ro.Model()\n",
    "    y   = ro_model.rvar(n)\n",
    "    z_d = ro_model.dvar(n)\n",
    "\n",
    "    # Ellipsoid {y : ||L_i^T (y - y_hat_i)||_2 <= q}, i.e. Mahalanobis distance at most q.\n",
    "    y_hat = Y_hat_np[i]\n",
    "    uset = rso.norm(L[i].T @ (y - y_hat), 2) <= conformal_quantile\n",
    "\n",
    "    # min over the probability simplex of the worst-case loss y @ z over uset.\n",
    "    ro_model.minmax(y @ z_d, uset)\n",
    "    ro_model.st(z_d <= 1)\n",
    "    ro_model.st(z_d >= 0)\n",
    "    ro_model.st(z_d.sum() == 1)\n",
    "\n",
    "    ro_model.solve(SOLVER, display=False)\n",
    "\n",
    "    z_opt = z_d.get()\n",
    "    z_value.append(z_opt)\n",
    "\n",
    "    print(f\"Test_sample: {i}, Completed!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2a4ae64-5aaa-4088-aa14-2d751298eb19",
   "metadata": {
    "papermill": {
     "duration": 0.181828,
     "end_time": "2025-08-20T05:55:55.121793",
     "exception": false,
     "start_time": "2025-08-20T05:55:54.939965",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Marginal coverage: test points whose conformal score is within the threshold.\n",
    "c_count1 = int((S_test <= marginal_quantile).sum())\n",
    "\n",
    "# Worst-case loss of z over {y : (y - y_hat)^T Sigma^{-1} (y - y_hat) <= q^2} is\n",
    "#     y_hat @ z + q * sqrt(z^T Sigma z) = y_hat @ z + q * ||C^T z||_2,   Sigma = C C^T,\n",
    "# so the lower Cholesky factor must be transposed before multiplying by z.\n",
    "Sigma_half = torch.linalg.cholesky(test_Sigma)\n",
    "# as_tensor(np.asarray(...)) accepts both the list built by the solver loop and an\n",
    "# already-converted tensor, so re-running this cell is safe.\n",
    "z_value = torch.as_tensor(np.asarray(z_value), dtype=torch.float32)\n",
    "z_col = z_value.unsqueeze(-1)\n",
    "Sigma_z = torch.matmul(Sigma_half.transpose(-1, -2), z_col)\n",
    "\n",
    "norm_Sigma_z = torch.norm(Sigma_z, p=2, dim=(1,2))\n",
    "mu_z = (Y_test_pred * z_value).sum(dim=1)\n",
    "\n",
    "risk_loss_1    = marginal_quantile * norm_Sigma_z + mu_z\n",
    "Average_Risk_1 = risk_loss_1.mean()\n",
    "\n",
    "decision_loss_1 = (Y_test * z_value).sum(dim=1)\n",
    "Average_Loss_1  = decision_loss_1.mean()\n",
    "\n",
    "# Robustness: test points whose realized loss does not exceed the worst-case risk.\n",
    "num = Y_test.shape[0]\n",
    "r_count1 = int((decision_loss_1 <= risk_loss_1).sum())\n",
    "\n",
    "print(f\"Marginal_Coverage: {(c_count1/Y_test.shape[0])*100:.2f}%\")\n",
    "print(f\"Average_Risk: {Average_Risk_1.item():.4f}\")\n",
    "print(f\"Average_Loss: {Average_Loss_1.item():.4f}\")\n",
    "print(f\"Robustness: {(r_count1/num)*100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0049060d-384d-4472-a637-c40eb70de705",
   "metadata": {
    "papermill": {
     "duration": 0.191656,
     "end_time": "2025-08-20T05:55:55.491566",
     "exception": false,
     "start_time": "2025-08-20T05:55:55.299910",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the two-step (predict, then robust-optimize) metrics.\n",
    "data_dir = 'outputs/results/data_Two_step'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "txt_path = os.path.join(data_dir, 'metrics_Two_step.txt')\n",
    "\n",
    "with open(txt_path, 'w') as f:\n",
    "    for line in (\n",
    "        f\"Marginal Coverage: {(c_count1/Y_test.shape[0])*100:.2f}%\\n\",\n",
    "        f\"Average Risk     : {Average_Risk_1:.4f}\\n\",\n",
    "        f\"Average Loss     : {Average_Loss_1:.4f}\\n\",\n",
    "        f\"Robustness       : {(r_count1/num)*100:.2f}%\\n\",\n",
    "    ):\n",
    "        f.write(line)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "914a16d2-451b-4d85-b514-253a8337405e",
   "metadata": {
    "papermill": {
     "duration": 0.15556,
     "end_time": "2025-08-20T05:55:55.811113",
     "exception": false,
     "start_time": "2025-08-20T05:55:55.655553",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### End-to-end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9698adb0-7373-463c-b586-db323dabea45",
   "metadata": {
    "papermill": {
     "duration": 0.173288,
     "end_time": "2025-08-20T05:55:56.142193",
     "exception": false,
     "start_time": "2025-08-20T05:55:55.968905",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "\n",
    "# Whole training split as two stacked tensors.\n",
    "X_train = torch.stack([features for features, _ in train_ds], dim=0).to(device)\n",
    "Y_train = torch.stack([target for _, target in train_ds], dim=0).to(device)\n",
    "\n",
    "# Mini-batch loaders; only the training split is shuffled.\n",
    "train_loader = DataLoader(train_ds, shuffle=True, batch_size=batch_size)\n",
    "val_loader = DataLoader(cal_ds, shuffle=False, batch_size=batch_size)\n",
    "test_loader = DataLoader(test_ds, shuffle=False, batch_size=batch_size)\n",
    "\n",
    "# Per-dimension outcome mean/std passed to the portfolio problem.\n",
    "# NOTE(review): these pool train, calibration AND test outcomes; if they ever influence\n",
    "# training or decisions, compute them from the training split only.\n",
    "all_Y = np.concatenate([Y_train, Y_cal, Y_test], axis=0)\n",
    "y_mean = np.array(all_Y.mean(axis=0))\n",
    "y_std = np.array(all_Y.std(axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17d18a3-b2dd-4653-8b02-f017508eb1a5",
   "metadata": {
    "papermill": {
     "duration": 0.171912,
     "end_time": "2025-08-20T05:55:56.495686",
     "exception": false,
     "start_time": "2025-08-20T05:55:56.323774",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class BaseProblemProtocol(Protocol):\n",
    "    \"\"\"Interface of the decision problems used for end-to-end training.\n",
    "\n",
    "    Implementations set the primal side (``primal_vars``, ``constraints``, ``Fz``,\n",
    "    ``f_tilde``) and the robust/dual side (``dual_obj``, ``dual_constraints``,\n",
    "    ``dual_vars``, ``params``) before calling ``__init__``, which assembles ``prob``.\n",
    "    \"\"\"\n",
    "\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    vars: dict[str, cp.Variable]\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(self):\n",
    "        # Primal variables come first, so solution[0] of the CvxpyLayer is the first primal variable.\n",
    "        self.vars = self.primal_vars | self.dual_vars\n",
    "        self.constraints.extend(self.dual_constraints)\n",
    "        obj = self.dual_obj\n",
    "\n",
    "        prob = cp.Problem(cp.Minimize(obj), self.constraints)\n",
    "        # CvxpyLayer can only differentiate DPP-compliant problems.\n",
    "        assert prob.is_dpp()\n",
    "        self.prob = prob\n",
    "\n",
    "    def task_loss_np(self, y: np.ndarray, is_standardized: bool) -> float:\n",
    "        ...\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, y: Tensor, is_standardized: bool, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        ...\n",
    "\n",
    "    def get_cvxpylayer(self) -> CvxpyLayer:\n",
    "        \"\"\"Differentiable layer mapping ``params`` (dict order) to ``vars`` (dict order).\"\"\"\n",
    "        return CvxpyLayer(\n",
    "            self.prob,\n",
    "            parameters=list(self.params.values()),\n",
    "            variables=list(self.vars.values())\n",
    "        )\n",
    "\n",
    "class EllipsoidProblemProtocol(BaseProblemProtocol, Protocol):\n",
    "    \"\"\"Problem whose uncertainty set is an ellipsoid given by ``loc`` and ``scale_tril``.\"\"\"\n",
    "\n",
    "    def solve(self, loc: np.ndarray, scale_tril: np.ndarray) -> cp.Problem:\n",
    "        ...\n",
    "\n",
    "class EllipsoidProblem:\n",
    "    \"\"\"Robust objective over the ellipsoid {loc + scale_tril @ u : ||u||_2 <= 1}.\n",
    "\n",
    "    ``scale_tril`` is the lower-triangular factor of the (conformally inflated)\n",
    "    covariance, Sigma = scale_tril @ scale_tril.T. The worst case of y @ Fz over\n",
    "    the ellipsoid is loc @ Fz + ||scale_tril.T @ Fz||_2 (Cauchy-Schwarz).\n",
    "    \"\"\"\n",
    "\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(\n",
    "        self, y_dim: int, y_mean: np.ndarray, y_std: np.ndarray, Fz: cp.Expression\n",
    "    ):\n",
    "        self.y_mean = y_mean\n",
    "        self.y_std = y_std\n",
    "\n",
    "        # parameters (dict order = CvxpyLayer argument order: loc, scale_tril)\n",
    "        loc = cp.Parameter(y_dim, name='loc')\n",
    "        scale_tril = cp.Parameter((y_dim, y_dim), name='scale_tril')\n",
    "        self.params = {\n",
    "            'loc': loc,\n",
    "            'scale_tril': scale_tril,\n",
    "        }\n",
    "\n",
    "        # NOTE: y_mean / y_std are stored but not applied; outcomes are used on their original scale.\n",
    "        Fz_ystd = Fz\n",
    "\n",
    "        # objective: max_{||u||<=1} (loc + S u) @ Fz = loc @ Fz + ||S^T Fz||_2.\n",
    "        # The transpose matters: ||S @ Fz|| differs for a non-diagonal lower-triangular S.\n",
    "        self.dual_obj = (\n",
    "            cp.norm(scale_tril.T @ Fz_ystd)\n",
    "            + loc @ Fz_ystd)\n",
    "\n",
    "        self.dual_vars = {}\n",
    "        self.dual_constraints = []\n",
    "\n",
    "    def solve(self, loc: np.ndarray, scale_tril: np.ndarray) -> cp.Problem:\n",
    "        \"\"\"Set the ellipsoid parameters, solve in place and return the solved problem.\"\"\"\n",
    "        self.params['loc'].value = loc\n",
    "        self.params['scale_tril'].value = scale_tril\n",
    "        prob = self.prob\n",
    "        prob.solve()\n",
    "        if prob.status != 'optimal':\n",
    "            print('Problem status:', prob.status)\n",
    "        return prob\n",
    "\n",
    "class PortfolioProblemBase:\n",
    "    \"\"\"Long-only, fully invested portfolio: z >= 0, sum(z) == 1, realized loss y @ z.\"\"\"\n",
    "\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    prob: cp.Problem\n",
    "    y_mean: np.ndarray\n",
    "    y_std: np.ndarray\n",
    "\n",
    "    def __init__(self, N: int):\n",
    "        self.N = N\n",
    "\n",
    "        z = cp.Variable(N, name='z', nonneg=True)\n",
    "        self.primal_vars = {'z': z}\n",
    "\n",
    "        self.constraints = [\n",
    "            cp.sum(z) == 1,\n",
    "        ]\n",
    "\n",
    "        self.Fz = z\n",
    "        self.f_tilde = 0.\n",
    "\n",
    "    def task_loss_np(self, scale_tril: np.ndarray, mu: np.ndarray) -> float:\n",
    "        \"\"\"Worst-case loss of the last solved z over {mu + scale_tril @ u : ||u||_2 <= 1}.\"\"\"\n",
    "        assert self.prob.value is not None, 'Problem must be solved first'\n",
    "        z = self.primal_vars['z'].value\n",
    "        task_loss = np.linalg.norm(scale_tril.T @ z, 2) + np.dot(mu, z)\n",
    "        return task_loss\n",
    "\n",
    "    def task_loss_np_1(self, y: np.ndarray) -> float:\n",
    "        \"\"\"Realized loss y @ z of the last solved z.\"\"\"\n",
    "        assert self.prob.value is not None, 'Problem must be solved first'\n",
    "        z = self.primal_vars['z'].value\n",
    "        task_loss = y @ z\n",
    "        return task_loss\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, y: Tensor, scale_tril: Tensor, mu: Tensor, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        \"\"\"Batched, differentiable worst-case loss ||scale_tril^T z||_2 + mu @ z per sample.\n",
    "\n",
    "        ``y`` is only used to check that the decisions have the outcome's shape.\n",
    "        \"\"\"\n",
    "        z = solution[0]\n",
    "        assert y.shape == z.shape\n",
    "        Lt_z = (scale_tril.transpose(-1, -2) @ z.unsqueeze(-1)).squeeze(-1)\n",
    "        task_loss = torch.linalg.vector_norm(Lt_z, dim=-1) + (mu * z).sum(dim=-1)\n",
    "        return task_loss\n",
    "\n",
    "class PortfolioProblemEllipsoid(PortfolioProblemBase, EllipsoidProblem, EllipsoidProblemProtocol):\n",
    "    \"\"\"Portfolio problem with an ellipsoidal uncertainty set, ready for CvxpyLayer.\"\"\"\n",
    "\n",
    "    def __init__(self, N: int, y_mean: np.ndarray, y_std: np.ndarray):\n",
    "        PortfolioProblemBase.__init__(self, N=N)\n",
    "        EllipsoidProblem.__init__(self, y_dim=N, y_mean=y_mean, y_std=y_std, Fz=self.Fz)\n",
    "        EllipsoidProblemProtocol.__init__(self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c932ba85-85b0-4d7c-92e9-e50d023aa287",
   "metadata": {
    "papermill": {
     "duration": 0.161303,
     "end_time": "2025-08-20T05:55:56.845873",
     "exception": false,
     "start_time": "2025-08-20T05:55:56.684570",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def mahalanobis_dist2(loc: Tensor, scale_tril: Tensor, y: Tensor) -> Tensor:\n",
    "    \"\"\"Squared Mahalanobis distance ||scale_tril^{-1} (y - loc)||^2 (batched).\n",
    "\n",
    "    Relies on torch's private ``_batch_mahalanobis`` helper from MultivariateNormal.\n",
    "    \"\"\"\n",
    "    residual = y - loc\n",
    "    return tdist.multivariate_normal._batch_mahalanobis(bL=scale_tril, bx=residual)\n",
    "\n",
    "def calc_q(scores: Tensor, alpha: float) -> Tensor:\n",
    "    \"\"\"Split-conformal threshold: the ceil((n + 1) * (1 - alpha))-th smallest score.\n",
    "\n",
    "    Returns ``torch.tensor(inf)`` when that rank exceeds n (too few scores). The chosen\n",
    "    element is indexed out of ``scores``, so gradients flow back to it.\n",
    "    \"\"\"\n",
    "    n = len(scores)\n",
    "    rank = int(np.ceil((n + 1) * (1 - alpha)))\n",
    "    if rank <= n:\n",
    "        order = torch.argsort(scores)\n",
    "        return scores[order[rank - 1]]\n",
    "    return torch.tensor(torch.inf)\n",
    "\n",
    "def conformal_q(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    alpha: float,\n",
    "    device: str = 'cpu'\n",
    ") -> Tensor:\n",
    "    \"\"\"Conformal threshold on the squared Mahalanobis scores of a calibration loader.\n",
    "\n",
    "    Args:\n",
    "        model: maps a batch ``x`` to ``(loc, scale_tril, _)``.\n",
    "        loader: calibration batches ``(x, y)``.\n",
    "        alpha: miscoverage level.\n",
    "        device: device the model and batches are moved to.\n",
    "\n",
    "    Returns:\n",
    "        The ceil((n + 1) * (1 - alpha))-th smallest score as a 0-d tensor (see ``calc_q``).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the loader is empty or too small for ``alpha``.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    all_scores = []\n",
    "    # Calibration is pure inference, so no autograd graph is built.\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            x = x.to(device, non_blocking=True)\n",
    "            y = y.to(device, non_blocking=True)\n",
    "            loc, scale_tril, _ = model(x)\n",
    "            all_scores.append(mahalanobis_dist2(loc, scale_tril, y))\n",
    "\n",
    "    if not all_scores:\n",
    "        raise ValueError('Calibration loader is empty')\n",
    "    scores = torch.cat(all_scores)\n",
    "    q = calc_q(scores, alpha)\n",
    "    # Explicit check instead of `assert`, so it also fires under `python -O`.\n",
    "    if torch.isinf(q):\n",
    "        raise ValueError('Size of calibration set is too small')\n",
    "    return q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43924fa-4568-449d-b0ae-1210f3972f2c",
   "metadata": {
    "papermill": {
     "duration": 0.164581,
     "end_time": "2025-08-20T05:55:57.162745",
     "exception": false,
     "start_time": "2025-08-20T05:55:56.998164",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_e2e_epoch(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    alpha: float,\n",
    "    nll_loss_frac: float,\n",
    "    rng: np.random.Generator,\n",
    "    optimizer: torch.optim.Optimizer,\n",
    "    show_pbar: bool = False\n",
    ") -> tuple[float, float, float]:\n",
    "    \"\"\"Run one epoch of end-to-end conformal training.\n",
    "\n",
    "    Each batch is split at random into halves: the first half calibrates the\n",
    "    conformal threshold q, the second half is solved through the differentiable\n",
    "    robust problem with ``scale_tril`` inflated by sqrt(q). The batch loss is\n",
    "    ``nll_loss_frac * NLL + (1 - nll_loss_frac) * task_loss``.\n",
    "\n",
    "    Returns:\n",
    "        ``(avg_loss, avg_nll_loss, avg_task_loss)``. The first two are averaged over\n",
    "        ``len(loader.dataset)``; the task loss over the task halves actually solved,\n",
    "        and it is NaN when no task loss was computed (``nll_loss_frac >= 1`` or every\n",
    "        batch was too small to calibrate).\n",
    "    \"\"\"\n",
    "    assert 0 < alpha < 1\n",
    "    model.train()\n",
    "\n",
    "    cvxpylayer = prob.get_cvxpylayer()\n",
    "\n",
    "    total_loss = 0.\n",
    "    total_nll_loss = 0.\n",
    "    total_task_loss = 0.\n",
    "    total_num_tasks = 0\n",
    "    for x, y in tqdm(loader) if show_pbar else loader:\n",
    "        batch_size = x.shape[0]\n",
    "        loc, scale_tril, _ = model(x)\n",
    "\n",
    "        loss = torch.tensor(0.)\n",
    "        msgs = []\n",
    "\n",
    "        # compute nll loss if desired\n",
    "        if nll_loss_frac > 0:\n",
    "            nll_loss = -tdist.MultivariateNormal(loc, scale_tril=scale_tril).log_prob(y).mean()\n",
    "            loss += nll_loss_frac * nll_loss\n",
    "            total_nll_loss += nll_loss.item() * batch_size\n",
    "            msgs.append(f'nll_loss: {nll_loss.item()}')\n",
    "\n",
    "        # compute task loss if desired\n",
    "        if nll_loss_frac < 1:\n",
    "            perm = rng.permutation(batch_size)\n",
    "\n",
    "            # First half calibrates q; q is indexed from model outputs, so gradients flow through it.\n",
    "            cal_inds = perm[:batch_size//2]\n",
    "            scores_cal = mahalanobis_dist2(loc[cal_inds], scale_tril[cal_inds], y[cal_inds])\n",
    "            q = calc_q(scores_cal, alpha)\n",
    "            if q == torch.inf:\n",
    "                tqdm.write('Batch is too small, skipping')\n",
    "                continue\n",
    "\n",
    "            # Second half: robust decisions under the inflated ellipsoids.\n",
    "            task_inds = perm[batch_size//2:]\n",
    "            y_task = y[task_inds]\n",
    "            loc_task = loc[task_inds]\n",
    "            scale_tril_task = scale_tril[task_inds] * torch.sqrt(q)\n",
    "\n",
    "            solution = cvxpylayer(loc_task, scale_tril_task)\n",
    "            task_loss = prob.task_loss_torch(y_task, scale_tril_task, loc_task, solution=solution).mean()\n",
    "            loss += (1 - nll_loss_frac) * task_loss\n",
    "\n",
    "            total_task_loss += task_loss.item() * task_inds.shape[0]\n",
    "            total_num_tasks += task_inds.shape[0]\n",
    "\n",
    "            msgs.append(f'task_loss: {task_loss.item()}')\n",
    "\n",
    "        if show_pbar:\n",
    "            tqdm.write(','.join(msgs))\n",
    "        total_loss += loss.item() * batch_size\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    avg_loss = total_loss / len(loader.dataset)\n",
    "    avg_nll_loss = total_nll_loss / len(loader.dataset)\n",
    "    # Guard against ZeroDivisionError when no task loss was ever computed.\n",
    "    avg_task_loss = total_task_loss / total_num_tasks if total_num_tasks > 0 else float('nan')\n",
    "    return avg_loss, avg_nll_loss, avg_task_loss\n",
    "\n",
    "\n",
    "def train_e2e(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    loaders: Mapping[str, torch.utils.data.DataLoader],\n",
    "    alpha: float,\n",
    "    max_epochs: int,\n",
    "    lr: float,\n",
    "    l2reg: float,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    rng: np.random.Generator,\n",
    "    nll_loss_frac: float | Sequence[float],\n",
    "    saved_model_path: str = '',\n",
    "    patience: int = 10\n",
    ") -> dict[str, Any]:\n",
    "    \"\"\"Train ``model`` end to end, early-stopping on the calibration task loss.\n",
    "\n",
    "    Args:\n",
    "        loaders: must contain ``'train'`` and ``'calib'`` loaders.\n",
    "        nll_loss_frac: NLL weight, either fixed or a per-epoch schedule (its last\n",
    "            entry is reused once the schedule runs out).\n",
    "        saved_model_path: optional state dict to start from.\n",
    "        patience: training stops once ``steps_since_decrease > patience``.\n",
    "\n",
    "    Returns:\n",
    "        Per-epoch loss histories plus ``best_epoch`` and the best ``val_task_loss``.\n",
    "        On return ``model`` holds the best weights seen (or its starting weights if\n",
    "        no epoch ever improved on ``np.inf``).\n",
    "    \"\"\"\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2reg)\n",
    "    if saved_model_path != '':\n",
    "        tqdm.write(f'Loading saved model: {saved_model_path}')\n",
    "        model.load_state_dict(torch.load(saved_model_path, weights_only=True))\n",
    "\n",
    "    # Train model\n",
    "    result: dict[str, Any] = {\n",
    "        'train_e2e_losses': [],\n",
    "        'train_nll_losses': [],\n",
    "        'train_task_losses': [],\n",
    "        'val_task_losses': [],\n",
    "        'best_epoch': 0,\n",
    "        'val_task_loss': np.inf,  # best loss on val set\n",
    "    }\n",
    "    steps_since_decrease = 0\n",
    "\n",
    "    # In-memory checkpoint of the best weights.\n",
    "    buffer = io.BytesIO()\n",
    "\n",
    "    def save_checkpoint() -> None:\n",
    "        # Overwrite the buffer; truncating drops any stale bytes of an older checkpoint.\n",
    "        buffer.seek(0)\n",
    "        buffer.truncate()\n",
    "        torch.save(model.state_dict(), buffer)\n",
    "\n",
    "    # Seed the checkpoint with the starting weights so the final restore never reads\n",
    "    # an empty buffer (max_epochs == 0, or a NaN validation loss that never improves).\n",
    "    save_checkpoint()\n",
    "\n",
    "    pbar = tqdm(range(max_epochs))\n",
    "    for epoch in pbar:\n",
    "        if isinstance(nll_loss_frac, Sequence):\n",
    "            weight = nll_loss_frac[min(epoch, len(nll_loss_frac) - 1)]\n",
    "        else:\n",
    "            weight = nll_loss_frac\n",
    "\n",
    "        train_e2e_loss, train_nll_loss, train_task_loss = train_e2e_epoch(\n",
    "            model, prob=prob, loader=loaders['train'], alpha=alpha,\n",
    "            nll_loss_frac=weight, rng=rng, optimizer=optimizer)\n",
    "        result['train_e2e_losses'].append(train_e2e_loss)\n",
    "        result['train_nll_losses'].append(train_nll_loss)\n",
    "        result['train_task_losses'].append(train_task_loss)\n",
    "\n",
    "        # Validation: recalibrate q on the calibration split, then score the decisions.\n",
    "        with torch.no_grad():\n",
    "            model.eval()\n",
    "            q = conformal_q(model, loaders['calib'], alpha=alpha).item()\n",
    "            val_task_loss = optimize(model, prob=prob, loader=loaders['calib'], q=q)\n",
    "            result['val_task_losses'].append(val_task_loss)\n",
    "\n",
    "        msg = (f'Epoch {epoch}, train_task_loss {train_task_loss:.3f}, '\n",
    "               f'val_task_loss {val_task_loss:.3f}, q {q:.3f}')\n",
    "        pbar.set_description(msg)\n",
    "\n",
    "        steps_since_decrease += 1\n",
    "\n",
    "        if val_task_loss < result['val_task_loss']:\n",
    "            result['best_epoch'] = epoch\n",
    "            result['val_task_loss'] = val_task_loss\n",
    "            steps_since_decrease = 0\n",
    "            save_checkpoint()\n",
    "\n",
    "        if steps_since_decrease > patience:\n",
    "            break\n",
    "    pbar.close()\n",
    "\n",
    "    buffer.seek(0)\n",
    "    model.load_state_dict(torch.load(buffer, weights_only=True))\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6b2779-1707-450b-915c-02d854be30bf",
   "metadata": {
    "papermill": {
     "duration": 0.168955,
     "end_time": "2025-08-20T05:55:57.500731",
     "exception": false,
     "start_time": "2025-08-20T05:55:57.331776",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def optimize(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    q: float,\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    ") -> float:\n",
    "    \"\"\"Mean worst-case task loss of ``prob``'s robust decisions over ``loader``.\n",
    "\n",
    "    Each predicted ellipsoid is inflated by sqrt(q), the robust problem is solved,\n",
    "    and ``prob.task_loss_np`` is evaluated at the solution. Returns NaN for an empty loader.\n",
    "\n",
    "    NOTE(review): ``crccal`` also calls a function named ``optimize`` with different\n",
    "    arguments; running this cell rebinds that name.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model = model.to(device)\n",
    "\n",
    "    task_losses = []\n",
    "    pbar = tqdm(total=len(loader.dataset)) if show_pbar else None\n",
    "    with torch.no_grad():\n",
    "        for x_batch, _ in loader:\n",
    "            x_batch = x_batch.to(device, non_blocking=True)\n",
    "            pred = model(x_batch)\n",
    "            loc_batch = pred[0].detach().cpu().numpy()\n",
    "            scale_tril_batch = pred[1].detach().cpu().numpy()\n",
    "\n",
    "            # In place, so the factors keep the model's dtype.\n",
    "            scale_tril_batch *= np.sqrt(q)\n",
    "\n",
    "            for loc, scale_tril in zip(loc_batch, scale_tril_batch):\n",
    "                prob.solve(loc, scale_tril)\n",
    "                task_losses.append(prob.task_loss_np(scale_tril, loc))\n",
    "                if pbar is not None:\n",
    "                    pbar.update(1)\n",
    "    if pbar is not None:\n",
    "        pbar.close()\n",
    "\n",
    "    if not task_losses:\n",
    "        return float('nan')\n",
    "    return np.mean(task_losses).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d57ea1-e1ce-439a-a0ad-6fe321d49b9d",
   "metadata": {
    "papermill": {
     "duration": 0.163649,
     "end_time": "2025-08-20T05:55:57.859541",
     "exception": false,
     "start_time": "2025-08-20T05:55:57.695892",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import itertools\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Grid-search settings for end-to-end training.\n",
    "loaders = {\n",
    "    'train': train_loader,\n",
    "    'calib': val_loader,  # calibration split, also used for validation/early stopping\n",
    "    'test': test_loader,\n",
    "}\n",
    "alpha = 0.1          # miscoverage level\n",
    "max_epochs = 100\n",
    "nll_loss_frac = 0.0  # weight of the NLL term; 0.0 trains on the task loss only\n",
    "lr_list = [1e-3, 1e-4]\n",
    "l2_list = [1e-2, 1e-3, 1e-4]\n",
    "\n",
    "# Best configuration found so far (updated by the search loop).\n",
    "best = dict(val_task_loss=float(\"inf\"), lr=None, l2reg=None, epoch=None, state_dict=None)\n",
    "all_results = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1cc17dc-9129-408d-b422-f74fa4c53d4c",
   "metadata": {
    "papermill": {
     "duration": 2398.440752,
     "end_time": "2025-08-20T06:35:56.472392",
     "exception": false,
     "start_time": "2025-08-20T05:55:58.031640",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Every configuration warm-starts from the same pretrained baseline weights.\n",
    "model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "\n",
    "for lr in lr_list:\n",
    "    for l2 in l2_list:\n",
    "        model = EllipsoidalUncertaintyModel(input_dim=2, hidden_dims=[20, 10], n=2, enable_implicit=False)\n",
    "        model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))\n",
    "        prob = PortfolioProblemEllipsoid(N=2, y_mean=y_mean, y_std=y_std)\n",
    "\n",
    "        # Re-seeded per configuration; drives the per-batch calibration/task split.\n",
    "        rng = np.random.default_rng(42)\n",
    "\n",
    "        res = train_e2e(\n",
    "            model=model,\n",
    "            loaders=loaders,\n",
    "            alpha=alpha,\n",
    "            max_epochs=max_epochs,\n",
    "            lr=lr,\n",
    "            l2reg=l2,\n",
    "            prob=prob,\n",
    "            rng=rng,\n",
    "            nll_loss_frac=nll_loss_frac,\n",
    "            saved_model_path=''\n",
    "        )\n",
    "\n",
    "        record = {\n",
    "            \"lr\": lr,\n",
    "            \"l2reg\": l2,\n",
    "            \"best_epoch\": res[\"best_epoch\"],\n",
    "            \"val_task_loss\": res[\"val_task_loss\"],\n",
    "        }\n",
    "        all_results.append(record)\n",
    "\n",
    "        # NOTE(review): non-positive validation losses are treated as invalid runs; confirm this\n",
    "        # filter is intended, since a worst-case portfolio loss can in principle be negative.\n",
    "        if 0 < res[\"val_task_loss\"] < best[\"val_task_loss\"]:\n",
    "            best.update(record)\n",
    "            best[\"state_dict\"] = copy.deepcopy(model.state_dict())\n",
    "\n",
    "print(\"\\n=== Grid Search Summary ===\")\n",
    "all_results = sorted(all_results, key=lambda r: r[\"val_task_loss\"])\n",
    "for r in all_results:\n",
    "    print(f\"lr={r['lr']:g}, l2={r['l2reg']:g}, \"\n",
    "          f\"best_epoch={r['best_epoch']}, val_task_loss={r['val_task_loss']:.6f}\")\n",
    "\n",
    "# Without an eligible run there is no state dict to save (and best['lr'] is still None).\n",
    "if best[\"state_dict\"] is None:\n",
    "    raise RuntimeError(\"No configuration reached a positive, finite val_task_loss; nothing to save.\")\n",
    "\n",
    "print(\"\\n>>> Best config:\")\n",
    "print(f\"lr={best['lr']:g}, l2={best['l2reg']:g}, \"\n",
    "      f\"best_epoch={best['best_epoch']}, val_task_loss={best['val_task_loss']:.6f}\")\n",
    "\n",
    "best_model = EllipsoidalUncertaintyModel(input_dim=2, hidden_dims=[20, 10], n=2, enable_implicit=False)\n",
    "best_model.load_state_dict(best[\"state_dict\"])\n",
    "\n",
    "save_path = os.path.join('outputs', 'model', 'E2E', 'E2E_model.pth')\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "torch.save(best_model.state_dict(), save_path)\n",
    "print(f\"Saved best model to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b183982-1989-439a-b017-8e891595ce33",
   "metadata": {
    "papermill": {
     "duration": 0.161478,
     "end_time": "2025-08-20T06:35:56.799127",
     "exception": false,
     "start_time": "2025-08-20T06:35:56.637649",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def optimize_test(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    q: float,\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    ") -> tuple[float, float, float, float, list[np.ndarray]]:\n",
    "    \"\"\"Evaluate the robust decisions of ``prob`` on a test loader.\n",
    "\n",
    "    For each sample the predicted ellipsoid is inflated by sqrt(q) and the robust\n",
    "    problem is solved; the realized loss y @ z is compared with the worst-case loss\n",
    "    over the ellipsoid, and y is checked against the conformal region.\n",
    "\n",
    "    Returns:\n",
    "        ``(coverage, avg_risk, avg_loss, robustness, z_value)``: coverage and\n",
    "        robustness in percent, the mean worst-case loss, the mean realized loss and\n",
    "        the per-sample decisions. All four metrics are NaN for an empty loader.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model = model.to(device)\n",
    "\n",
    "    task_losses = []    # realized losses y @ z\n",
    "    expr_losses = []    # worst-case losses over the inflated ellipsoid\n",
    "    z_value = []\n",
    "    count_smaller = 0   # realized loss <= worst-case loss\n",
    "    total_points = 0\n",
    "    count_within_q = 0  # squared Mahalanobis score <= q\n",
    "\n",
    "    pbar = tqdm(total=len(loader.dataset)) if show_pbar else None\n",
    "    with torch.no_grad():\n",
    "        for x_batch, y_batch in loader:\n",
    "            x_batch = x_batch.to(device, non_blocking=True)\n",
    "            pred = model(x_batch)\n",
    "            loc_batch = pred[0].detach().cpu().numpy()\n",
    "            scale_tril_batch = pred[1].detach().cpu().numpy()\n",
    "            y_batch = y_batch.detach().cpu().numpy()\n",
    "\n",
    "            scale_tril_batch *= np.sqrt(q)\n",
    "\n",
    "            for y, loc, scale_tril in zip(y_batch, loc_batch, scale_tril_batch):\n",
    "                prob.solve(loc, scale_tril)\n",
    "                task_loss = prob.task_loss_np_1(y)\n",
    "                task_losses.append(task_loss)\n",
    "\n",
    "                Fz_val = np.asarray(prob.Fz.value).reshape(-1)\n",
    "                z_value.append(Fz_val)\n",
    "                # Worst case of y @ z over {loc + scale_tril @ u : ||u||_2 <= 1} is\n",
    "                # loc @ z + ||scale_tril.T @ z||_2 (scale_tril is lower triangular).\n",
    "                expr_loss = np.linalg.norm(scale_tril.T @ Fz_val) + loc @ Fz_val\n",
    "                expr_losses.append(expr_loss)\n",
    "\n",
    "                # Same <= convention as the two-step robustness check.\n",
    "                if task_loss <= expr_loss:\n",
    "                    count_smaller += 1\n",
    "\n",
    "                # Squared Mahalanobis score under the uninflated factor, as in conformal_q:\n",
    "                # ||L^{-1} (y - loc)||^2 with L = scale_tril / sqrt(q).\n",
    "                whitened = np.linalg.solve(scale_tril / np.sqrt(q), y - loc)\n",
    "                if whitened @ whitened <= q:\n",
    "                    count_within_q += 1\n",
    "\n",
    "                total_points += 1\n",
    "                if pbar is not None:\n",
    "                    pbar.update(1)\n",
    "    if pbar is not None:\n",
    "        pbar.close()\n",
    "\n",
    "    if total_points == 0:\n",
    "        nan = float('nan')\n",
    "        return nan, nan, nan, nan, z_value\n",
    "\n",
    "    avg_loss = np.mean(task_losses)\n",
    "    avg_risk = np.mean(expr_losses)\n",
    "    robustness = count_smaller / total_points * 100\n",
    "    coverage = count_within_q / total_points * 100\n",
    "\n",
    "    return coverage, avg_risk, avg_loss, robustness, z_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac47837-086d-49f7-b370-cc4fb586c70b",
   "metadata": {
    "papermill": {
     "duration": 1.544949,
     "end_time": "2025-08-20T06:35:58.506044",
     "exception": false,
     "start_time": "2025-08-20T06:35:56.961095",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Restore the trained E2E model. map_location puts the checkpoint tensors on the\n",
    "# device the model already lives on, so loading works regardless of where it was saved.\n",
    "model_path = os.path.join('outputs', 'model', 'E2E', 'E2E_model.pth')\n",
    "model_device = next(model.parameters()).device\n",
    "model.load_state_dict(torch.load(model_path, map_location=model_device, weights_only=True))\n",
    "model.eval()\n",
    "\n",
    "# Neither the calibration step nor the test evaluation needs gradients\n",
    "# (optimize_test detaches model outputs and works in numpy), so run both\n",
    "# without building an autograd graph.\n",
    "with torch.no_grad():\n",
    "    # Conformal threshold q computed on the calibration split.\n",
    "    q = conformal_q(model, loaders['calib'], alpha=alpha).item()\n",
    "    coverage, avg_risk, avg_loss, robustness_1, z_value = optimize_test(model, prob, loaders['test'], q=q)\n",
    "\n",
    "# avg_risk: mean robust bound (scale term + mean term); avg_loss: mean task loss at the true y.\n",
    "print(f\"Coverage: {coverage:.2f}%\")\n",
    "print(f\"Ave_risk: {avg_risk:.4f}\")\n",
    "print(f\"Ave_loss: {avg_loss:.4f}\")\n",
    "print(f\"Robustness: {robustness_1:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dae6393-43b6-4df2-81fc-f185f1b65ae4",
   "metadata": {
    "papermill": {
     "duration": 0.158075,
     "end_time": "2025-08-20T06:35:58.851090",
     "exception": false,
     "start_time": "2025-08-20T06:35:58.693015",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the headline test metrics of the E2E model to a plain-text report.\n",
    "data_dir = 'outputs/results/data_E2E'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "txt_path = os.path.join(data_dir, 'metrics_E2E.txt')\n",
    "\n",
    "# One line per metric; labels are padded so the colons line up.\n",
    "metric_lines = [\n",
    "    f\"Marginal Coverage: {coverage:.2f}%\\n\",\n",
    "    f\"Average Risk     : {avg_risk:.4f}\\n\",\n",
    "    f\"Average Loss     : {avg_loss:.4f}\\n\",\n",
    "    f\"Robustness       : {robustness_1:.2f}%\\n\",\n",
    "]\n",
    "with open(txt_path, 'w') as report_file:\n",
    "    report_file.writelines(metric_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6452e5-6bb3-44f5-9a0e-a1b23ff8a1b6",
   "metadata": {
    "papermill": {
     "duration": 0.151598,
     "end_time": "2025-08-20T06:36:00.528723",
     "exception": false,
     "start_time": "2025-08-20T06:36:00.377125",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 7418.752614,
   "end_time": "2025-08-20T06:36:02.011798",
   "environment_variables": {},
   "exception": null,
   "input_path": "CRC_GPU.ipynb",
   "output_path": "outputs/file/CRC_GPU_0.ipynb",
   "parameters": {
    "run_id": 0
   },
   "start_time": "2025-08-20T04:32:23.259184",
   "version": "2.6.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "026e337f7d484d6e95dcd375fd7de27c": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_fc23a79426b344c890f2a26ef3c9156d",
       "placeholder": "​",
       "style": "IPY_MODEL_92112d2baeb14eb29054b40747c031f8",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 99, train_task_loss 9.586, val_task_loss 9.325, q 1.807: 100%"
      }
     },
     "036a5ff217f74069885b76af9995d86f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_4480f109a80848a48d2f508aa0bef738",
       "placeholder": "​",
       "style": "IPY_MODEL_13308ffc808445848f9d06028566ac9a",
       "tabbable": null,
       "tooltip": null,
       "value": " 45/100 [04:46&lt;05:39,  6.17s/it]"
      }
     },
     "0692ea86de604b4183af86cd0aa27fe6": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "118e573edad642b4a43d8450d044bf77": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "126dcedd70674b96b42ca18a3f374119": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "13308ffc808445848f9d06028566ac9a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "1844a65aaf354db9976d5f52fe4d70dc": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_9f5562e6959f4f57804ef27a67065409",
       "placeholder": "​",
       "style": "IPY_MODEL_f63f9ffabae842d0a16e5482ba568c44",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 39, train_task_loss 9.132, val_task_loss 8.844, q 3.481:  39%"
      }
     },
     "1d64cbd813544d939a9abb0a51456c3f": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "1e449cbedb414c90bddf7c2146f4ee82": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "24a4cae76d924a4bb399bb1223c4d67c": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_1844a65aaf354db9976d5f52fe4d70dc",
        "IPY_MODEL_4ac45c42b83448638fb5810eaa7d311e",
        "IPY_MODEL_65642d95774c41bcac85e02731cf6314"
       ],
       "layout": "IPY_MODEL_1d64cbd813544d939a9abb0a51456c3f",
       "tabbable": null,
       "tooltip": null
      }
     },
     "2fe816cedfea461fab8568369d61bdb8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "36878497c10d494da43ed30aa908e612": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_8bb0a53804bf4824a05761b89ccb8485",
       "placeholder": "​",
       "style": "IPY_MODEL_68750af418e340938b36b55dec057c95",
       "tabbable": null,
       "tooltip": null,
       "value": " 83/100 [08:29&lt;01:43,  6.06s/it]"
      }
     },
     "3bed96318a50421fba4feaa80ac804ea": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_8871bed2d3314a059760f13a60c6fe5c",
       "placeholder": "​",
       "style": "IPY_MODEL_6004261dcfdb45b487cd3eec83edb3eb",
       "tabbable": null,
       "tooltip": null,
       "value": " 100/100 [07:24&lt;00:00,  4.35s/it]"
      }
     },
     "3d72b7fdb57c42df9339901963a148da": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "40a7a6033ec2456dab6ebcd824735d12": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_ba8bae1ebc424312b5cb2eab28d3d6b8",
       "placeholder": "​",
       "style": "IPY_MODEL_c2df85849b0441c5b5e0f89fd2b45f45",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 45, train_task_loss 9.155, val_task_loss 8.598, q 3.241:  45%"
      }
     },
     "4480f109a80848a48d2f508aa0bef738": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "488146816bc342f8b36063c362f33373": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "4ac24431d46b4782b7298dcd6c6df6bb": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "danger",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_994a930ace6b4743a5eb405c23af81eb",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_2fe816cedfea461fab8568369d61bdb8",
       "tabbable": null,
       "tooltip": null,
       "value": 53
      }
     },
     "4ac45c42b83448638fb5810eaa7d311e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "danger",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_7832cad38e61461493974e711473a9c6",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_118e573edad642b4a43d8450d044bf77",
       "tabbable": null,
       "tooltip": null,
       "value": 39
      }
     },
     "4b33169a5a6040f6b2e4783933ce3c99": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_bde5c90770ac44beaa1923c7025f9b72",
        "IPY_MODEL_74033833488745d19e41989257824018",
        "IPY_MODEL_f06a2777ec8b460da3bfc9b0a3b60539"
       ],
       "layout": "IPY_MODEL_d6a409a6f5354c55b96bbc03d07610e2",
       "tabbable": null,
       "tooltip": null
      }
     },
     "4e83ee49eaff408cab8884464d3dc24e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_40a7a6033ec2456dab6ebcd824735d12",
        "IPY_MODEL_99bb07bb9a0a427983cc1c0b0105d13d",
        "IPY_MODEL_036a5ff217f74069885b76af9995d86f"
       ],
       "layout": "IPY_MODEL_58e0e604043241d1a5c1918afca97041",
       "tabbable": null,
       "tooltip": null
      }
     },
     "4eda984a09e04f3b8e6e781cff09d670": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "58e0e604043241d1a5c1918afca97041": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "5eee1687dd1e45ddb58e40c7ec351877": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_c8578b8d720548bc8e462c1ea5d1faae",
        "IPY_MODEL_4ac24431d46b4782b7298dcd6c6df6bb",
        "IPY_MODEL_97cfa26232484b9f96c88b8f390eb1c4"
       ],
       "layout": "IPY_MODEL_999e45cc25d94036b0aef735aa655a42",
       "tabbable": null,
       "tooltip": null
      }
     },
     "6004261dcfdb45b487cd3eec83edb3eb": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "6352c4a1b6874e06b3471580a5ad2c1a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "65642d95774c41bcac85e02731cf6314": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_126dcedd70674b96b42ca18a3f374119",
       "placeholder": "​",
       "style": "IPY_MODEL_a47a656cf1634945bac9883171886922",
       "tabbable": null,
       "tooltip": null,
       "value": " 39/100 [04:03&lt;06:17,  6.20s/it]"
      }
     },
     "67013b57888f4fda8314d34b6eff070f": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "68750af418e340938b36b55dec057c95": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "6ee91f1e0d6a41f3b288ee03e702ca3d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_026e337f7d484d6e95dcd375fd7de27c",
        "IPY_MODEL_cc4d856b8f004cd1a88a6562661cd752",
        "IPY_MODEL_3bed96318a50421fba4feaa80ac804ea"
       ],
       "layout": "IPY_MODEL_bd58602ea4cd4b2c9ffa31d2df7cc60a",
       "tabbable": null,
       "tooltip": null
      }
     },
     "74033833488745d19e41989257824018": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_1e449cbedb414c90bddf7c2146f4ee82",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_6352c4a1b6874e06b3471580a5ad2c1a",
       "tabbable": null,
       "tooltip": null,
       "value": 100
      }
     },
     "7751fd9c34ca49a18d68b98ffa9b7489": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "7832cad38e61461493974e711473a9c6": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "8871bed2d3314a059760f13a60c6fe5c": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "88f0a35cd19d4b89a78066eefdcefbc6": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "8bb0a53804bf4824a05761b89ccb8485": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "8bca2508e4d142ceb46a6e9bdd5ec64d": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "8ebfe4ed1d9046ffa2c8eb15c19a0f06": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "8f9790e104c74fbe87a1f0f86e2a819b": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_cf2623197fa24f2cbef8c3f839c0d1ab",
        "IPY_MODEL_ca33c90d06964d1cbdb97194bdebf3c7",
        "IPY_MODEL_36878497c10d494da43ed30aa908e612"
       ],
       "layout": "IPY_MODEL_f0f02c851bf64590a9fca9db6fd3a019",
       "tabbable": null,
       "tooltip": null
      }
     },
     "92112d2baeb14eb29054b40747c031f8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "959b93bd26014e099b3b23f625dd37d5": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "97cfa26232484b9f96c88b8f390eb1c4": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_88f0a35cd19d4b89a78066eefdcefbc6",
       "placeholder": "​",
       "style": "IPY_MODEL_0692ea86de604b4183af86cd0aa27fe6",
       "tabbable": null,
       "tooltip": null,
       "value": " 53/100 [05:31&lt;04:41,  5.99s/it]"
      }
     },
     "994a930ace6b4743a5eb405c23af81eb": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "999e45cc25d94036b0aef735aa655a42": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "99bb07bb9a0a427983cc1c0b0105d13d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "danger",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_3d72b7fdb57c42df9339901963a148da",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_fce8f8580a134fc9b24a60ef1a302f80",
       "tabbable": null,
       "tooltip": null,
       "value": 45
      }
     },
     "9f5562e6959f4f57804ef27a67065409": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "a47a656cf1634945bac9883171886922": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "ba8bae1ebc424312b5cb2eab28d3d6b8": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "bd58602ea4cd4b2c9ffa31d2df7cc60a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "bde5c90770ac44beaa1923c7025f9b72": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_8bca2508e4d142ceb46a6e9bdd5ec64d",
       "placeholder": "​",
       "style": "IPY_MODEL_d1f9b860f3564f86b26e7fcc833feab8",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 99, train_task_loss 9.487, val_task_loss 9.300, q 1.776: 100%"
      }
     },
     "c2df85849b0441c5b5e0f89fd2b45f45": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "c8578b8d720548bc8e462c1ea5d1faae": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_67013b57888f4fda8314d34b6eff070f",
       "placeholder": "​",
       "style": "IPY_MODEL_e83424edce8142aeb374d9b9335f249b",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 53, train_task_loss 8.746, val_task_loss 8.611, q 3.074:  53%"
      }
     },
     "ca33c90d06964d1cbdb97194bdebf3c7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "danger",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_caf508369e784251ae7f558b388fedf8",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_ccfaf27355b14fd08450b5faced556dc",
       "tabbable": null,
       "tooltip": null,
       "value": 83
      }
     },
     "caf508369e784251ae7f558b388fedf8": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "cc4d856b8f004cd1a88a6562661cd752": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_959b93bd26014e099b3b23f625dd37d5",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_8ebfe4ed1d9046ffa2c8eb15c19a0f06",
       "tabbable": null,
       "tooltip": null,
       "value": 100
      }
     },
     "ccfaf27355b14fd08450b5faced556dc": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "cf2623197fa24f2cbef8c3f839c0d1ab": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_7751fd9c34ca49a18d68b98ffa9b7489",
       "placeholder": "​",
       "style": "IPY_MODEL_e5a483d37aa642c093f2c0f824266c1a",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 83, train_task_loss 9.796, val_task_loss 9.393, q 1.773:  83%"
      }
     },
     "d1f9b860f3564f86b26e7fcc833feab8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "d6a409a6f5354c55b96bbc03d07610e2": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "e5a483d37aa642c093f2c0f824266c1a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "e83424edce8142aeb374d9b9335f249b": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "f06a2777ec8b460da3bfc9b0a3b60539": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_488146816bc342f8b36063c362f33373",
       "placeholder": "​",
       "style": "IPY_MODEL_4eda984a09e04f3b8e6e781cff09d670",
       "tabbable": null,
       "tooltip": null,
       "value": " 100/100 [09:42&lt;00:00,  5.36s/it]"
      }
     },
     "f0f02c851bf64590a9fca9db6fd3a019": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "f63f9ffabae842d0a16e5482ba568c44": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "fc23a79426b344c890f2a26ef3c9156d": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "fce8f8580a134fc9b24a60ef1a302f80": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
