{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4bcfc9fd-c1e1-4d83-9fd2-0949976723b7",
   "metadata": {
    "papermill": {
     "duration": 0.006828,
     "end_time": "2025-09-21T08:32:15.654709",
     "exception": false,
     "start_time": "2025-09-21T08:32:15.647881",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Conformal Robustness Control"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef4ebb71-9251-4a54-b6eb-9b1c8de14b77",
   "metadata": {},
   "source": [
    "### CRC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "027da74f-f82a-4655-bfbf-94fcbaa6fb97",
   "metadata": {
    "papermill": {
     "duration": 2.276159,
     "end_time": "2025-09-21T08:32:18.042628",
     "exception": false,
     "start_time": "2025-09-21T08:32:15.766469",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import os\n",
    "import io\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import pytz\n",
    "import math\n",
    "import cvxpy as cp\n",
    "import rsome as rso\n",
    "\n",
    "from rsome import ro\n",
    "from rsome import msk_solver as SOLVER\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.distributions import MultivariateNormal\n",
    "from cvxpylayers.torch import CvxpyLayer\n",
    "from typing import Protocol\n",
    "from typing import NamedTuple\n",
    "from torch import distributions as tdist, nn, Tensor\n",
    "from collections.abc import Mapping\n",
    "from torch import Tensor\n",
    "from collections.abc import Sequence\n",
    "from torch.utils.data import TensorDataset\n",
    "from collections.abc import Mapping, Sequence\n",
    "from pandas.tseries.holiday import USFederalHolidayCalendar\n",
    "from datetime import datetime, timedelta\n",
    "from tqdm.auto import tqdm\n",
    "from typing import Any"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f8ca6a-17af-49fc-b3ea-7d20e26605dc",
   "metadata": {
    "papermill": {
     "duration": 0.011301,
     "end_time": "2025-09-21T08:32:18.066855",
     "exception": false,
     "start_time": "2025-09-21T08:32:18.055554",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "seed = random.randint(0, 99)\n",
    "print(\"Device:\", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eefd3914-9881-4079-8e71-0dcfe7b19ade",
   "metadata": {
    "papermill": {
     "duration": 0.025035,
     "end_time": "2025-09-21T08:32:18.106734",
     "exception": false,
     "start_time": "2025-09-21T08:32:18.081699",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_test_scaler(train, val, test):\n",
    "    \"\"\"Standardise three splits using statistics fitted on ``train`` only.\n",
    "\n",
    "    Returns the scaled train, validation and test splits, in that order, each\n",
    "    wrapped in a fresh DataFrame built from the scaler output.\n",
    "    \"\"\"\n",
    "    fitted = StandardScaler().fit(train)\n",
    "    scaled_train, scaled_val, scaled_test = (\n",
    "        pd.DataFrame(fitted.transform(part)) for part in (train, val, test)\n",
    "    )\n",
    "    return scaled_train, scaled_val, scaled_test\n",
    "\n",
    "\n",
    "def reshuffle(train, val, test, perm):\n",
    "    a, b, c = len(train), len(val), len(test)\n",
    "    assert len(perm) == a + b + c\n",
    "    combined = pd.concat([train, val, test], axis='index')\n",
    "    train = combined.iloc[perm[:a]]\n",
    "    val = combined.iloc[perm[a:a+b]]\n",
    "    test = combined.iloc[perm[a+b:a+b+c]]\n",
    "    return train, val, test\n",
    "\n",
    "def portfolio_data_gen(year, seed, shuffled):\n",
    "    \"\"\"Load the return and side-information panels and build train/val/test tensors.\n",
    "\n",
    "    Column names come from the header line of the sample files under\n",
    "    ``yfinance/2012_samples``; the data comes from ``yfinance/expected_return.csv``\n",
    "    and ``yfinance/side_info.csv``. Rows are split by the calendar year of ``DATE``:\n",
    "    ``year`` and ``year + 1`` form the training split, the following five years the\n",
    "    validation split, and every later year the test split.\n",
    "\n",
    "    Args:\n",
    "        year: First calendar year of the training window (anything ``int()`` accepts).\n",
    "        seed: Seed of the row permutation; only used when ``shuffled`` is true.\n",
    "        shuffled: If true, pool all rows and re-deal them into splits of unchanged\n",
    "            size, using one permutation shared by returns and side information.\n",
    "\n",
    "    Returns:\n",
    "        ``(side_train, side_val, side_test, returns_train, returns_val, returns_test)``\n",
    "        as tensors built with ``torch.from_numpy``. Side information is standardised\n",
    "        with training-split statistics, returns are multiplied by 100, and missing\n",
    "        values are replaced by training-split column means.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a sample file has no header line.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the sample directory and file names are pinned to 2012 regardless\n",
    "    # of ``year`` -- confirm that this is intended.\n",
    "    year_directory = 'yfinance/2012_samples'\n",
    "    returns_file_path = os.path.join(year_directory, '2012_returns_0.txt')\n",
    "    side_file_path = os.path.join(year_directory, '2012_data_side_0.txt')\n",
    "\n",
    "    def read_header(path):\n",
    "        # Only the first line (comma-separated column names) is used.\n",
    "        with open(path, 'r') as f:\n",
    "            header = f.readline().strip()\n",
    "        if not header:\n",
    "            raise ValueError(f'{path} has no header line')\n",
    "        return header.split(',')\n",
    "\n",
    "    returns_cols = read_header(returns_file_path)\n",
    "    print('returns_cols:',returns_cols)\n",
    "    side_info_cols = read_header(side_file_path)\n",
    "    print('side_info_cols:',side_info_cols)\n",
    "\n",
    "    returns = pd.read_csv('yfinance/expected_return.csv')\n",
    "    data_side = pd.read_csv('yfinance/side_info.csv')\n",
    "\n",
    "    def create_train_val_test(df, year, col_list):\n",
    "        # Select the requested columns plus DATE, split by calendar year, then drop\n",
    "        # the helper columns again.\n",
    "        df_sub = df[col_list + ['DATE']].copy()\n",
    "        df_sub['year'] = pd.to_datetime(df_sub['DATE']).dt.year.astype(int)\n",
    "\n",
    "        start_year = int(year)\n",
    "        train_end_year = start_year + 1\n",
    "        val_start_year = train_end_year + 1\n",
    "        val_end_year = val_start_year + 4\n",
    "        test_start_year = val_end_year + 1\n",
    "\n",
    "        in_train = (df_sub.year >= start_year) & (df_sub.year <= train_end_year)\n",
    "        in_val = (df_sub.year >= val_start_year) & (df_sub.year <= val_end_year)\n",
    "        in_test = df_sub.year >= test_start_year\n",
    "\n",
    "        helper_cols = ['DATE', 'year']\n",
    "        return tuple(\n",
    "            df_sub[mask].drop(columns=helper_cols) for mask in (in_train, in_val, in_test)\n",
    "        )\n",
    "\n",
    "    returns_train, returns_val, returns_test = create_train_val_test(returns, year, returns_cols)\n",
    "    side_train, side_val, side_test = create_train_val_test(data_side, year, side_info_cols)\n",
    "\n",
    "    if shuffled:\n",
    "        rng = np.random.default_rng(seed=seed)\n",
    "        perm = rng.permutation(len(returns_train) + len(returns_val) + len(returns_test))\n",
    "        returns_train, returns_val, returns_test = reshuffle(returns_train, returns_val, returns_test, perm)\n",
    "        side_train, side_val, side_test = reshuffle(side_train, side_val, side_test, perm)\n",
    "\n",
    "    # Impute with training-split column means. The filled frames are assigned to new\n",
    "    # objects instead of filling in place, because the splits are derived from a\n",
    "    # parent frame by slicing / positional indexing.\n",
    "    returns_means = returns_train.mean()\n",
    "    returns_train, returns_val, returns_test = (\n",
    "        part.fillna(returns_means) for part in (returns_train, returns_val, returns_test)\n",
    "    )\n",
    "    side_means = side_train.mean()\n",
    "    side_train, side_val, side_test = (\n",
    "        part.fillna(side_means) for part in (side_train, side_val, side_test)\n",
    "    )\n",
    "\n",
    "    returns_train_s = torch.from_numpy(100*returns_train.values)\n",
    "    returns_val_s   = torch.from_numpy(100*returns_val.values)\n",
    "    returns_test_s  = torch.from_numpy(100*returns_test.values)\n",
    "\n",
    "    side_train_s, side_val_s, side_test_s = (\n",
    "        torch.from_numpy(part.values)\n",
    "        for part in train_test_scaler(side_train, side_val, side_test)\n",
    "    )\n",
    "\n",
    "    return side_train_s, side_val_s, side_test_s, returns_train_s, returns_val_s, returns_test_s\n",
    "\n",
    "\n",
    "def get_loaders(\n",
    "    batch_size: int, year: int, seed: int, shuffled: bool\n",
    ") -> tuple[dict[str, DataLoader], tuple[np.ndarray, np.ndarray]]:\n",
    "    \"\"\"Build train/calib/test loaders over standardised targets.\n",
    "\n",
    "    All tensors from ``portfolio_data_gen`` are cast to float32. Targets are\n",
    "    standardised column-wise with the mean and standard deviation of the pooled\n",
    "    train, calibration and test targets; those statistics are also returned as\n",
    "    ``y_info = (mean, std)`` NumPy arrays. Only the training loader shuffles.\n",
    "    \"\"\"\n",
    "    splits = [t.to(torch.float32) for t in portfolio_data_gen(year, seed, shuffled)]\n",
    "    X_train, X_cal, X_test, y_train, y_cal, y_test = splits\n",
    "\n",
    "    with torch.no_grad():\n",
    "        pooled = torch.cat([y_train, y_cal, y_test], dim=0)\n",
    "        y_mean, y_std = pooled.mean(dim=0), pooled.std(dim=0)\n",
    "        y_info = (y_mean.numpy(), y_std.numpy())\n",
    "        y_train, y_cal, y_test = [(y - y_mean) / y_std for y in (y_train, y_cal, y_test)]\n",
    "\n",
    "    def make_loader(X, y, shuffle):\n",
    "        return DataLoader(TensorDataset(X, y), shuffle=shuffle, batch_size=batch_size)\n",
    "\n",
    "    loaders_dict = {\n",
    "        'train': make_loader(X_train, y_train, True),\n",
    "        'test': make_loader(X_test, y_test, False),\n",
    "        'calib': make_loader(X_cal, y_cal, False),\n",
    "    }\n",
    "    return loaders_dict, y_info\n",
    "\n",
    "def get_data_and_loaders(\n",
    "    batch_size: int, year: int, seed: int, shuffled: bool\n",
    ") -> tuple[\n",
    "    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,\n",
    "    dict[str, DataLoader], tuple[np.ndarray, np.ndarray]\n",
    "]:\n",
    "    \"\"\"Return the float32 split tensors together with loaders built over them.\n",
    "\n",
    "    Unlike ``get_loaders``, the targets are left on their original scale: neither\n",
    "    the returned tensors nor the loaders are standardised. ``y_info`` still reports\n",
    "    the column-wise mean and standard deviation of the pooled targets.\n",
    "\n",
    "    Returns:\n",
    "        ``(X_train, y_train, X_cal, y_cal, X_test, y_test, loaders_dict, y_info)``\n",
    "        where ``loaders_dict`` has the keys ``'train'`` (shuffled), ``'test'`` and\n",
    "        ``'calib'``, and ``y_info`` is a ``(mean, std)`` pair of NumPy arrays.\n",
    "    \"\"\"\n",
    "    X_train, X_cal, X_test, y_train, y_cal, y_test = (\n",
    "        t.to(torch.float32) for t in portfolio_data_gen(year, seed, shuffled)\n",
    "    )\n",
    "\n",
    "    # Pooled target statistics are reported to the caller only; they are not applied.\n",
    "    with torch.no_grad():\n",
    "        all_y = torch.cat([y_train, y_cal, y_test], dim=0)\n",
    "        y_info = (all_y.mean(dim=0).numpy(), all_y.std(dim=0).numpy())\n",
    "\n",
    "    train_loader = DataLoader(TensorDataset(X_train, y_train), shuffle=True,  batch_size=batch_size)\n",
    "    calib_loader = DataLoader(TensorDataset(X_cal,   y_cal),   shuffle=False, batch_size=batch_size)\n",
    "    test_loader  = DataLoader(TensorDataset(X_test,  y_test),  shuffle=False, batch_size=batch_size)\n",
    "    loaders_dict = {'train': train_loader, 'test': test_loader, 'calib': calib_loader}\n",
    "\n",
    "    return X_train, y_train, X_cal, y_cal, X_test, y_test, loaders_dict, y_info\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5c54e29-5e02-4c3a-94be-1735ac7c5f32",
   "metadata": {
    "papermill": {
     "duration": 0.09129,
     "end_time": "2025-09-21T08:32:18.205279",
     "exception": false,
     "start_time": "2025-09-21T08:32:18.113989",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Experiment configuration; ``seed`` comes from the setup cell above.\n",
    "batch_size = 32\n",
    "year = 2012\n",
    "shuffled = True\n",
    "\n",
    "(\n",
    "    X_train, y_train,\n",
    "    X_cal, y_cal,\n",
    "    X_test, y_test,\n",
    "    loaders_dict, y_info,\n",
    ") = get_data_and_loaders(batch_size, year, seed, shuffled)\n",
    "\n",
    "# Shapes of (features, targets) for each split.\n",
    "print(X_train.shape, y_train.shape)\n",
    "print(X_cal.shape, y_cal.shape)\n",
    "print(X_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28f4f3b8-2996-4058-80b6-da0727156fc8",
   "metadata": {
    "papermill": {
     "duration": 0.011831,
     "end_time": "2025-09-21T08:32:18.234541",
     "exception": false,
     "start_time": "2025-09-21T08:32:18.222710",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def create_simplex_layer(n: int) -> CvxpyLayer:\n",
    "    \"\"\"Build a differentiable layer for the allocation problem on the simplex.\n",
    "\n",
    "    The layer solves ``minimize ||Pp @ z||_2 - q^T z`` subject to ``sum(z) == 1``\n",
    "    and ``0 <= z <= 1``. It takes the parameters in the order ``(Pp, q)`` with\n",
    "    shapes ``(n, n)`` and ``(n,)`` and returns the optimal ``z``.\n",
    "    \"\"\"\n",
    "    weights = cp.Variable(n)\n",
    "    # NOTE(review): callers in this notebook pass a Cholesky factor for ``Pp``, which\n",
    "    # is triangular rather than symmetric PSD -- confirm the PSD attribute is intended.\n",
    "    Pp = cp.Parameter((n, n), PSD=True)\n",
    "    q = cp.Parameter(n)\n",
    "\n",
    "    objective = cp.Minimize(cp.norm(Pp @ weights, 2) - q.T @ weights)\n",
    "    simplex = [cp.sum(weights) == 1, weights >= 0, weights <= 1]\n",
    "    problem = cp.Problem(objective, simplex)\n",
    "    # The layer requires a disciplined parametrized program.\n",
    "    assert problem.is_dpp('dcp')\n",
    "    return CvxpyLayer(problem, parameters=[Pp, q], variables=[weights])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2405853-3b8c-42a2-a342-c70bd02d0663",
   "metadata": {
    "papermill": {
     "duration": 0.018702,
     "end_time": "2025-09-21T08:32:18.260662",
     "exception": false,
     "start_time": "2025-09-21T08:32:18.241960",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class EllipsoidalUncertaintyModel(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dims, n=2, enable_implicit=False, dropout_rate = 0.5):\n",
    "        super().__init__()\n",
    "        self.n = n\n",
    "        self.enable_implicit = enable_implicit\n",
    "\n",
    "        if isinstance(hidden_dims, int):\n",
    "            hidden_dims = [hidden_dims] * 3\n",
    "        layers, in_dim = [], input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers += [nn.Linear(in_dim, h), nn.BatchNorm1d(h), nn.ReLU(),nn.Dropout(dropout_rate)]\n",
    "            in_dim = h\n",
    "        self.backbone = nn.Sequential(*layers)\n",
    "\n",
    "        # outputs\n",
    "        self.fc_mu = nn.Linear(in_dim, n)\n",
    "        self.fc_L  = nn.Linear(in_dim, n*(n+1)//2)\n",
    "\n",
    "    def forward(self, x, cvxpylayer=None):\n",
    "        B = x.size(0)\n",
    "        h = self.backbone(x)\n",
    "        mu = self.fc_mu(h)\n",
    "\n",
    "        l_flat = self.fc_L(h)\n",
    "        L = x.new_zeros((B, self.n, self.n))\n",
    "        idx = 0\n",
    "        for i in range(self.n):\n",
    "            for j in range(i+1):\n",
    "                v = l_flat[:, idx]\n",
    "                L[:, i, j] = F.softplus(v) + (1e-3 if i==j else 0)\n",
    "                idx += 1\n",
    "        Sigma = L @ L.transpose(1, 2)\n",
    "\n",
    "        if not self.enable_implicit:\n",
    "            return mu, L, Sigma\n",
    "\n",
    "        if cvxpylayer is None:\n",
    "            raise ValueError(\"enable_implicit=False\")\n",
    "        Pp = torch.linalg.cholesky(Sigma + 1e-3 * torch.eye(self.n, device=x.device))\n",
    "        z_opt, = cvxpylayer(Pp.cpu(), mu.cpu())\n",
    "        return mu, Sigma, z_opt.to(x.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38c1277a-448c-46c4-a6bb-130c9a689bc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# From here on the notebook runs on the CPU.\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "model = EllipsoidalUncertaintyModel(\n",
    "    input_dim=21, hidden_dims=[20,20], n=15,\n",
    "    enable_implicit=False,\n",
    "    dropout_rate = 0.2\n",
    ").to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "# Jitter added to every predicted covariance before it parameterises the Gaussian.\n",
    "jitter = 1e-3 * torch.eye(model.n, device=device)\n",
    "\n",
    "for epoch in range(1,300 + 1):\n",
    "    model.train()\n",
    "    # Named ``epoch_loss`` so that this cell does not shadow the ``total_loss``\n",
    "    # function defined further down when it is re-run.\n",
    "    epoch_loss = 0.0\n",
    "    for x_cpu, y_cpu in loaders_dict['train']:\n",
    "        x = x_cpu.to(device, non_blocking=True)\n",
    "        y = y_cpu.to(device, non_blocking=True)\n",
    "\n",
    "        mu, _, Sigma = model(x)\n",
    "        Sigma = Sigma + jitter\n",
    "        dist = MultivariateNormal(mu, Sigma)\n",
    "        loss = -dist.log_prob(y).mean()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Weight by batch size so the epoch average is per sample.\n",
    "        epoch_loss += loss.item() * x.size(0)\n",
    "\n",
    "    avg = epoch_loss / len(loaders_dict['train'].dataset)\n",
    "    print(f\"Epoch {epoch:>2d} | NLL Loss: {avg:.4f}\")\n",
    "\n",
    "os.makedirs('outputs/model/baseline', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "torch.save(model.state_dict(), save_path)\n",
    "print(f\"Saved model to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80e65366-7e58-42bf-b769-0ccd3c5258bf",
   "metadata": {
    "papermill": {
     "duration": 0.020973,
     "end_time": "2025-09-21T08:32:54.131048",
     "exception": false,
     "start_time": "2025-09-21T08:32:54.110075",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def update_theta(model, cvxpylayer, opt_h, x, y, lambda_param, alpha, sigma, eps = 1e-3):\n",
    "    mu, Sigma, z = model(x, cvxpylayer)\n",
    "    Sigma_j = Sigma + eps*torch.eye(model.n, device=device)\n",
    "\n",
    "    L = torch.linalg.cholesky(Sigma_j)\n",
    "    term1 = torch.norm((L @ z.unsqueeze(-1)).squeeze(-1), dim=1)\n",
    "    term2 = (mu*z).sum(dim=1)\n",
    "    f = (term1 - term2).mean()\n",
    "\n",
    "    term = term1 - term2\n",
    "    approx_ind = 0.5*(1 + torch.erf((term - (- y*z).sum(dim=1)) / (sigma*(2**0.5))))\n",
    "    g = lambda_param * ((1-alpha) - approx_ind.mean())\n",
    "\n",
    "    loss = f + g\n",
    "    opt_h.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_h.step()\n",
    "    return loss.item()\n",
    "\n",
    "def update_lambda(model, cvxpylayer, opt_l, x, y, lambda_param, alpha, sigma, eps = 1e-3):\n",
    "    mu, Sigma, _ = model(x, cvxpylayer)\n",
    "    mu_det     = mu.detach()\n",
    "    Sigma_det  = (Sigma + eps*torch.eye(model.n, device=device)).detach()\n",
    "\n",
    "    L_j = torch.linalg.cholesky(Sigma_det)\n",
    "    z2, = cvxpylayer(L_j.cpu(), mu_det.cpu())\n",
    "    z2 = z2.to(device)\n",
    "\n",
    "    term1 = torch.norm((L_j.to(device) @ z2.unsqueeze(-1)).squeeze(-1), dim=1)\n",
    "    term2 = (mu_det * z2).sum(dim=1)\n",
    "    f2 = (term1 - term2).mean()\n",
    "\n",
    "    term = term1 - term2\n",
    "    indicator = ((term - (- y*z2).sum(dim=1)) >= 0).float()\n",
    "    g2 = lambda_param * ((1-alpha) - indicator.mean())\n",
    "\n",
    "    loss = -(f2 + g2)\n",
    "    opt_l.zero_grad()\n",
    "    loss.backward()\n",
    "    opt_l.step()\n",
    "    with torch.no_grad():\n",
    "        lambda_param.clamp_(min=0.0)\n",
    "    return loss.item()\n",
    "\n",
    "def total_loss(model, cvxpylayer, x, y, lambda_param, alpha, sigma=0.1, eps = 1e-3):\n",
    "    mu, Sigma, z = model(x, cvxpylayer)\n",
    "    Sigma_j = Sigma + eps*torch.eye(model.n, device=device)\n",
    "    L = torch.linalg.cholesky(Sigma_j)\n",
    "    term1 = torch.norm((L @ z.unsqueeze(-1)).squeeze(-1), dim=1)\n",
    "    term2 = (mu*z).sum(dim=1)\n",
    "    f = (term1 - term2).mean()\n",
    "\n",
    "    indicator = ((term1-term2) - (- y*z).sum(dim=1) >= 0).float()\n",
    "    g = lambda_param * ((1-alpha) - indicator.mean())\n",
    "    return f + g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d215e79-cba1-4033-88b6-a5a71bde0c43",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 15  # number of assets; not reused for anything else in this cell\n",
    "cvxpylayer = create_simplex_layer(n=n)\n",
    "\n",
    "# Start from the baseline weights and enable the optimisation layer in the forward pass.\n",
    "model = EllipsoidalUncertaintyModel(\n",
    "    input_dim=21,\n",
    "    hidden_dims=[20,20],\n",
    "    n=n,\n",
    "    enable_implicit=True,\n",
    "    dropout_rate = 0.5\n",
    ").to(device)\n",
    "model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "\n",
    "opt_h = torch.optim.Adam([\n",
    "    {'params': model.fc_mu.parameters(),    'lr':1e-3},\n",
    "    {'params': model.fc_L.parameters(),     'lr':5e-4},\n",
    "    {'params': model.backbone.parameters(), 'lr':1e-3},\n",
    "])\n",
    "\n",
    "lambda_param = nn.Parameter(torch.tensor(7.0, device=device), requires_grad=True)\n",
    "opt_l = torch.optim.Adam([lambda_param], lr=1e-3)\n",
    "\n",
    "# First half of the calibration set drives the model update, second half the\n",
    "# multiplier update. The split index has its own name so that ``n`` keeps meaning\n",
    "# \"number of assets\" if cells are re-run.\n",
    "N = X_cal.shape[0]\n",
    "n_half = N // 2\n",
    "X_cal_1, Y_cal_1 = X_cal[:n_half], y_cal[:n_half]\n",
    "X_cal_2, Y_cal_2 = X_cal[n_half:], y_cal[n_half:]\n",
    "\n",
    "alpha = 0.1\n",
    "sigma = 0.1\n",
    "epochs = 120\n",
    "\n",
    "total = []\n",
    "prev_loss = None\n",
    "consecutive_count = 0\n",
    "\n",
    "for epoch in range(1, epochs+1):\n",
    "    l_t = update_theta(model, cvxpylayer, opt_h, X_cal_1, Y_cal_1, lambda_param, alpha, sigma)\n",
    "    l_l = update_lambda(model, cvxpylayer, opt_l, X_cal_2, Y_cal_2, lambda_param, alpha, sigma)\n",
    "    tot = total_loss(model, cvxpylayer, X_cal, y_cal, lambda_param, alpha, sigma)\n",
    "    total.append(tot.item())\n",
    "    print(f\"Epoch {epoch:02d} | loss_θ={l_t:.6f} | loss_λ={l_l:.6f} | total={tot:.6f} | λ={lambda_param.item():.4f}\")\n",
    "\n",
    "    # Early stopping: the total loss moved by less than 3e-4 ten epochs in a row.\n",
    "    if prev_loss is not None and abs(prev_loss - tot.item()) < 3e-4:\n",
    "        consecutive_count += 1\n",
    "    else:\n",
    "        consecutive_count = 0\n",
    "    if consecutive_count >= 10:\n",
    "        print(f\"Early stopping at epoch {epoch} due to small loss change for 10 consecutive epochs.\")\n",
    "        break\n",
    "\n",
    "    prev_loss = tot.item()\n",
    "\n",
    "os.makedirs('outputs/model/final', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'final', 'final_model.pth')\n",
    "torch.save(model.state_dict(), save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656c0046-9a75-46ad-bb36-473203897413",
   "metadata": {
    "papermill": {
     "duration": 9.539646,
     "end_time": "2025-09-21T08:51:38.931698",
     "exception": false,
     "start_time": "2025-09-21T08:51:29.392052",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.eval()\n",
    "model.enable_implicit = True\n",
    "\n",
    "X_test_gpu = X_test.to(device, non_blocking=True)\n",
    "Y_test_gpu = y_test.to(device, non_blocking=True)\n",
    "\n",
    "with torch.no_grad():\n",
    "    mu_test, Sigma_test, z_test = model(X_test_gpu, cvxpylayer)\n",
    "\n",
    "eps = 1e-3\n",
    "I = torch.eye(model.n, device=device)\n",
    "\n",
    "# Regularised covariance, built once and reused for the inverse and the factor.\n",
    "Sigma_j    = Sigma_test + eps * I.unsqueeze(0)\n",
    "Sigma_inv  = torch.linalg.inv(Sigma_j)\n",
    "Sigma_half = torch.linalg.cholesky(Sigma_j)\n",
    "\n",
    "# Marginal coverage: share of test targets whose squared Mahalanobis distance\n",
    "# to the predicted centre is at most 1.\n",
    "delta = Y_test_gpu - mu_test                     # [B, n]\n",
    "d2    = torch.einsum('bi,bij,bj->b', delta, Sigma_inv, delta)\n",
    "in_set = d2 <= 1.0\n",
    "pct_in = in_set.float().mean().item() * 100\n",
    "\n",
    "# NOTE(review): with Sigma_j = C C^T, sqrt(z^T Sigma_j z) equals ||C^T z||; the\n",
    "# risk below uses ||C z||, the same expression as the training objective --\n",
    "# confirm that this orientation is intended.\n",
    "z_col        = z_test.unsqueeze(-1)\n",
    "Sigma_z      = torch.matmul(Sigma_half, z_col)\n",
    "norm_Sigma_z = torch.norm(Sigma_z, p=2, dim=(1,2))\n",
    "\n",
    "mu_z = (mu_test * z_test).sum(dim=1)\n",
    "risk_loss    = norm_Sigma_z - mu_z\n",
    "Average_Risk = risk_loss.mean().item()\n",
    "\n",
    "# Robustness: share of test points whose realised loss stays below the predicted risk.\n",
    "decision_loss = (- Y_test_gpu * z_test).sum(dim=1)\n",
    "Average_Loss  = decision_loss.mean().item()\n",
    "robustness = (decision_loss <= risk_loss).float().mean().item() * 100\n",
    "\n",
    "print(f\"Marginal Coverage: {pct_in:.2f}%\")\n",
    "print(f\"Average_Risk: {Average_Risk:.4f}\")\n",
    "print(f\"Average_Loss: {Average_Loss:.4f}\")\n",
    "print(f\"Robustness: {robustness:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb8a91ac-516d-4e03-8645-391855da1fd7",
   "metadata": {
    "papermill": {
     "duration": 0.068377,
     "end_time": "2025-09-21T08:51:39.313780",
     "exception": false,
     "start_time": "2025-09-21T08:51:39.245403",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the CRC test metrics as a small text report.\n",
    "data_dir = 'outputs/results/data_CRC'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "txt_path = os.path.join(data_dir, 'metrics_CRC.txt')\n",
    "\n",
    "metric_lines = [\n",
    "    f\"Marginal Coverage: {pct_in:.2f}%\\n\",\n",
    "    f\"Average Risk     : {Average_Risk:.4f}\\n\",\n",
    "    f\"Average Loss     : {Average_Loss:.4f}\\n\",\n",
    "    f\"Robustness       : {robustness:.2f}%\\n\",\n",
    "]\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.writelines(metric_lines)\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1dbc539-58c9-4d00-93ae-fd45da653c0e",
   "metadata": {
    "papermill": {
     "duration": 0.067351,
     "end_time": "2025-09-21T08:51:39.587917",
     "exception": false,
     "start_time": "2025-09-21T08:51:39.520566",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### CRO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbf9392-0f04-4e8b-8d9c-485ae46d2a9a",
   "metadata": {
    "papermill": {
     "duration": 0.104721,
     "end_time": "2025-09-21T08:51:39.835477",
     "exception": false,
     "start_time": "2025-09-21T08:51:39.730756",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Reload the baseline (NLL-trained) network for the two-step pipeline.\n",
    "device = torch.device(\"cpu\")\n",
    "model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "\n",
    "model = EllipsoidalUncertaintyModel(input_dim=21, hidden_dims=[20,20], n=15, enable_implicit=False)\n",
    "model = model.to(device)\n",
    "baseline_state = torch.load(model_path, map_location='cpu')\n",
    "model.load_state_dict(baseline_state)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21b72b22-0d43-4160-a425-df5da0ebc302",
   "metadata": {
    "papermill": {
     "duration": 0.173854,
     "end_time": "2025-09-21T08:51:40.385513",
     "exception": false,
     "start_time": "2025-09-21T08:51:40.211659",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "X_cal_cpu  = X_cal.cpu()\n",
    "Y_cal_cpu  = y_cal.cpu()\n",
    "X_test_cpu = X_test.cpu()\n",
    "Y_test_cpu = y_test.cpu()\n",
    "\n",
    "eps = 1e-3\n",
    "\n",
    "# Scores are only evaluated here, so no autograd graph is needed.\n",
    "with torch.no_grad():\n",
    "    # Calibration scores: Mahalanobis distance of the target to the predicted\n",
    "    # centre under the regularised covariance Sigma + eps * I.\n",
    "    Y_cal_pred, _, cal_Sigma = model(X_cal_cpu)\n",
    "    Sigma_inv_cal = torch.linalg.inv(cal_Sigma + eps * torch.eye(cal_Sigma.size(-1)))\n",
    "    cal_diff = (Y_cal_cpu - Y_cal_pred)\n",
    "    cal_score = torch.einsum('bi,bij,bj->b', cal_diff, Sigma_inv_cal, cal_diff)\n",
    "    S_cal = torch.sqrt(cal_score.clamp(min=1e-12))\n",
    "\n",
    "    # Same score on the test split.\n",
    "    Y_test_pred, _, test_Sigma = model(X_test_cpu)\n",
    "    Sigma_inv_test = torch.linalg.inv(test_Sigma + eps * torch.eye(test_Sigma.size(-1)))\n",
    "    test_diff = (Y_test_cpu - Y_test_pred)\n",
    "    test_score = torch.einsum('bi,bij,bj->b', test_diff, Sigma_inv_test, test_diff)\n",
    "    S_test = torch.sqrt(test_score.clamp(min=1e-12))\n",
    "\n",
    "# NOTE(review): this is the plain empirical 0.9-quantile of the calibration scores,\n",
    "# without a finite-sample (n + 1) correction -- confirm that this is intended.\n",
    "marginal_quantile = torch.quantile(S_cal, 0.9).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8b71b65-7054-4ce4-8a35-1a05e4c94374",
   "metadata": {},
   "outputs": [],
   "source": [
    "eta = 1e-2\n",
    "tol = 1e-4\n",
    "n = 15\n",
    "T = 200\n",
    "\n",
    "z_value = []\n",
    "loss_value = []\n",
    "n_test = Y_test_pred.shape[0]\n",
    "\n",
    "# Loop-invariant quantities, computed once for the whole test batch.\n",
    "# The uncertainty set must match the conformal score, which is the Mahalanobis\n",
    "# distance under ``Sigma_inv_test`` (the inverse of the eps-regularised covariance):\n",
    "#   (y - y_hat)^T Sigma_inv (y - y_hat) = ||L^T (y - y_hat)||^2  for  Sigma_inv = L L^T,\n",
    "# so the set uses the transposed (upper) Cholesky factor.\n",
    "Sigma_inv = Sigma_inv_test.detach()\n",
    "L_torch = torch.linalg.cholesky(Sigma_inv)\n",
    "L = L_torch.cpu().numpy()\n",
    "set_factors = np.ascontiguousarray(np.transpose(L, (0, 2, 1)))\n",
    "y_hat_all = Y_test_pred.detach().cpu().numpy()\n",
    "\n",
    "# ``S_cal`` already holds square-rooted scores, so its quantile is the radius itself.\n",
    "conformal_quantile = float(marginal_quantile)\n",
    "\n",
    "for i in tqdm(range(n_test), desc=\"CRO solve\"):\n",
    "    # A dedicated name keeps the torch ``model`` from being overwritten.\n",
    "    ro_model = ro.Model()\n",
    "    y_unc = ro_model.rvar(n)\n",
    "    z_d   = ro_model.dvar(n)\n",
    "\n",
    "    y_hat = y_hat_all[i]\n",
    "    uset  = rso.norm(set_factors[i] @ (y_unc - y_hat), 2) <= conformal_quantile\n",
    "    ro_model.minmax(- y_unc @ z_d, uset)\n",
    "    ro_model.st(z_d <= 1)\n",
    "    ro_model.st(z_d >= 0)\n",
    "    ro_model.st(z_d.sum() == 1)\n",
    "\n",
    "    ro_model.solve(SOLVER, display=False)\n",
    "\n",
    "    z_opt = z_d.get()\n",
    "    z_value.append(z_opt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78c88e60-9169-444f-a4ba-df389bc3a59c",
   "metadata": {
    "papermill": {
     "duration": 0.221028,
     "end_time": "2025-09-21T08:53:01.302318",
     "exception": false,
     "start_time": "2025-09-21T08:53:01.081290",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Marginal coverage: test points whose score falls inside the calibrated radius.\n",
    "c_count1 = int((S_test <= marginal_quantile).sum().item())\n",
    "\n",
    "# Use the same regularised covariance that defines the conformal score.\n",
    "Sigma_reg = test_Sigma.detach() + eps * torch.eye(test_Sigma.size(-1))\n",
    "Sigma_half = torch.linalg.cholesky(Sigma_reg)\n",
    "# Stack the per-sample solutions first; this also keeps the cell re-runnable.\n",
    "z_value = torch.as_tensor(np.asarray(z_value), dtype=torch.float32)\n",
    "z_col = z_value.unsqueeze(-1)\n",
    "\n",
    "# With Sigma = C C^T, sqrt(z^T Sigma z) = ||C^T z||, hence the transposed factor.\n",
    "Sigma_z = torch.matmul(Sigma_half.transpose(1, 2), z_col)\n",
    "norm_Sigma_z = torch.norm(Sigma_z, p=2, dim=(1,2))\n",
    "mu_z = (Y_test_pred.detach() * z_value).sum(dim=1)\n",
    "\n",
    "# Worst-case loss over the ellipsoid of radius ``marginal_quantile``.\n",
    "risk_loss_1    = marginal_quantile * norm_Sigma_z - mu_z\n",
    "Average_Risk_1 = risk_loss_1.mean()\n",
    "\n",
    "decision_loss_1 = (- y_test * z_value).sum(dim=1)\n",
    "Average_Loss_1  = decision_loss_1.mean()\n",
    "\n",
    "# Robustness: realised loss does not exceed the predicted worst case.\n",
    "num = y_test.shape[0]\n",
    "r_count1 = int((decision_loss_1 <= risk_loss_1).sum().item())\n",
    "\n",
    "print(f\"Marginal_Coverage: {(c_count1/y_test.shape[0])*100:.2f}%\")\n",
    "print(f\"Average_Risk: {Average_Risk_1.item():.4f}\")\n",
    "print(f\"Average_Loss: {Average_Loss_1.item():.4f}\")\n",
    "print(f\"Robustness: {(r_count1/num)*100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0049060d-384d-4472-a637-c40eb70de705",
   "metadata": {
    "papermill": {
     "duration": 0.236455,
     "end_time": "2025-09-21T08:53:01.776454",
     "exception": false,
     "start_time": "2025-09-21T08:53:01.539999",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the two-step (CRO) test metrics as a small text report.\n",
    "data_dir = 'outputs/results/data_Two_step'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "txt_path = os.path.join(data_dir, 'metrics_Two_step.txt')\n",
    "\n",
    "metric_lines = [\n",
    "    f\"Marginal Coverage: {(c_count1/y_test.shape[0])*100:.2f}%\\n\",\n",
    "    f\"Average Risk     : {Average_Risk_1:.4f}\\n\",\n",
    "    f\"Average Loss     : {Average_Loss_1:.4f}\\n\",\n",
    "    f\"Robustness       : {(r_count1/num)*100:.2f}%\\n\",\n",
    "]\n",
    "with open(txt_path, 'w') as f:\n",
    "    f.writelines(metric_lines)\n",
    "print(f\"✅ Saved metrics to {txt_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "914a16d2-451b-4d85-b514-253a8337405e",
   "metadata": {
    "papermill": {
     "duration": 0.233557,
     "end_time": "2025-09-21T08:53:02.213745",
     "exception": false,
     "start_time": "2025-09-21T08:53:01.980188",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### End-to-end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f6877a-f96b-4b3c-b651-bfde3755dd9a",
   "metadata": {
    "papermill": {
     "duration": 0.16578,
     "end_time": "2025-09-21T08:53:02.812725",
     "exception": false,
     "start_time": "2025-09-21T08:53:02.646945",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class TransformDataset(Dataset):\n",
    "    \"\"\"Map-style dataset that stores features and targets as float32 tensors.\"\"\"\n",
    "\n",
    "    def __init__(self, X, Y):\n",
    "        # Convert both inputs to float32 tensors once, at construction time.\n",
    "        self.X, self.Y = (torch.tensor(arr, dtype=torch.float32) for arr in (X, Y))\n",
    "\n",
    "    def __len__(self):\n",
    "        # Number of samples is the length of the feature tensor.\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Return the (features, target) pair stored at position `idx`.\n",
    "        features, target = self.X[idx], self.Y[idx]\n",
    "        return features, target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "415765b2-3ad4-4f2f-87a6-bed381ddf9df",
   "metadata": {
    "papermill": {
     "duration": 0.173784,
     "end_time": "2025-09-21T08:53:03.168310",
     "exception": false,
     "start_time": "2025-09-21T08:53:02.994526",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Wrap every split in a TransformDataset (float32 tensors).\n",
    "train_dataset, cal_dataset, test_dataset = (\n",
    "    TransformDataset(X_split, y_split)\n",
    "    for X_split, y_split in ((X_train, y_train), (X_cal, y_cal), (X_test, y_test))\n",
    ")\n",
    "\n",
    "# Only the training loader reshuffles; calibration and test keep their order.\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=256)\n",
    "cal_loader, test_loader = (\n",
    "    DataLoader(split, batch_size=256, shuffle=False)\n",
    "    for split in (cal_dataset, test_dataset)\n",
    ")\n",
    "\n",
    "# Per-column target statistics, pooled over the train, calibration and test targets.\n",
    "all_Y = np.concatenate([y_train, y_cal, y_test], axis=0)\n",
    "y_mean = np.array(all_Y.mean(axis=0))\n",
    "y_std = np.array(all_Y.std(axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17d18a3-b2dd-4653-8b02-f017508eb1a5",
   "metadata": {
    "papermill": {
     "duration": 0.082188,
     "end_time": "2025-09-21T08:53:03.431820",
     "exception": false,
     "start_time": "2025-09-21T08:53:03.349632",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class BaseProblemProtocol(Protocol):\n",
    "    \"\"\"Structural interface shared by the optimisation problems in this notebook.\n",
    "\n",
    "    The attributes declared below are expected to be populated before\n",
    "    ``__init__`` runs; ``__init__`` then assembles them into ``self.prob``.\n",
    "    \"\"\"\n",
    "\n",
    "    # Primal pieces.\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    # Dual pieces and the problem parameters.\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    # Assembled by __init__.\n",
    "    vars: dict[str, cp.Variable]\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Merge primal and dual pieces and build the minimisation problem.\"\"\"\n",
    "        self.vars = {**self.primal_vars, **self.dual_vars}\n",
    "        self.constraints += self.dual_constraints\n",
    "\n",
    "        problem = cp.Problem(cp.Minimize(self.dual_obj), self.constraints)\n",
    "        # Fail fast if the assembled problem is not DPP-compliant.\n",
    "        assert problem.is_dpp()\n",
    "        self.prob = problem\n",
    "\n",
    "    def task_loss_np(self, y: np.ndarray, is_standardized: bool) -> float:\n",
    "        \"\"\"NumPy task loss; implemented by concrete problems.\"\"\"\n",
    "        ...\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, y: Tensor, is_standardized: bool, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        \"\"\"Torch task loss; implemented by concrete problems.\"\"\"\n",
    "        ...\n",
    "\n",
    "    def get_cvxpylayer(self) -> CvxpyLayer:\n",
    "        \"\"\"Wrap ``self.prob`` in a CvxpyLayer over all parameters and variables.\"\"\"\n",
    "        layer_params = list(self.params.values())\n",
    "        layer_vars = list(self.vars.values())\n",
    "        return CvxpyLayer(self.prob, parameters=layer_params, variables=layer_vars)\n",
    "\n",
    "class EllipsoidProblemProtocol(BaseProblemProtocol, Protocol):\n",
    "    # Structural type for problems that are solved from a location vector and a\n",
    "    # matrix factor (`loc`, `scale_tril`).\n",
    "    def solve(self, loc: np.ndarray, scale_tril: np.ndarray) -> cp.Problem:\n",
    "        # Protocol stub: concrete classes provide the implementation.\n",
    "        ...\n",
    "\n",
    "class EllipsoidProblem:\n",
    "    \"\"\"Parameters and objective for a problem driven by ``loc`` and ``scale_tril``.\n",
    "\n",
    "    This class only sets ``params``, ``dual_obj``, ``dual_vars`` and\n",
    "    ``dual_constraints``; ``self.prob`` must be assigned elsewhere before\n",
    "    ``solve`` is called.\n",
    "    \"\"\"\n",
    "\n",
    "    dual_constraints: list[cp.Constraint]\n",
    "    dual_obj: cp.Expression\n",
    "    dual_vars: dict[str, cp.Variable]\n",
    "    params: dict[str, cp.Parameter]\n",
    "\n",
    "    prob: cp.Problem\n",
    "\n",
    "    def __init__(\n",
    "        self, y_dim: int, y_mean: np.ndarray, y_std: np.ndarray, Fz: cp.Expression\n",
    "    ):\n",
    "        \"\"\"Create the ``loc`` / ``scale_tril`` parameters and the objective.\n",
    "\n",
    "        Args:\n",
    "            y_dim: length of ``loc``; ``scale_tril`` is ``(y_dim, y_dim)``.\n",
    "            y_mean: stored on the instance as ``self.y_mean``.\n",
    "            y_std: stored on the instance as ``self.y_std``.\n",
    "            Fz: expression in the decision variable used by the objective.\n",
    "        \"\"\"\n",
    "        self.y_mean, self.y_std = y_mean, y_std\n",
    "\n",
    "        # Parameters are filled in per sample by `solve` (or by a CvxpyLayer).\n",
    "        self.params = {\n",
    "            'loc': cp.Parameter(y_dim, name='loc'),\n",
    "            'scale_tril': cp.Parameter((y_dim, y_dim), name='scale_tril'),\n",
    "        }\n",
    "        loc, scale_tril = self.params['loc'], self.params['scale_tril']\n",
    "\n",
    "        # Objective: ||scale_tril @ Fz|| - loc . Fz\n",
    "        # NOTE(review): for the set {loc + scale_tril @ u : ||u|| <= 1} the support\n",
    "        # function would use scale_tril.T @ Fz - confirm the intended convention.\n",
    "        spread_term = cp.norm(scale_tril @ Fz)\n",
    "        self.dual_obj = spread_term - loc @ Fz\n",
    "\n",
    "        # This formulation introduces no extra dual variables or constraints.\n",
    "        self.dual_vars = {}\n",
    "        self.dual_constraints = []\n",
    "\n",
    "    def solve(self, loc: np.ndarray, scale_tril: np.ndarray) -> cp.Problem:\n",
    "        \"\"\"Set both parameters, solve ``self.prob`` and return it.\n",
    "\n",
    "        A status other than ``'optimal'`` is printed, not raised.\n",
    "        \"\"\"\n",
    "        for key, value in (('loc', loc), ('scale_tril', scale_tril)):\n",
    "            self.params[key].value = value\n",
    "\n",
    "        problem = self.prob\n",
    "        problem.solve()\n",
    "        if problem.status != 'optimal':\n",
    "            print('Problem status:', problem.status)\n",
    "        return problem\n",
    "\n",
    "class PortfolioProblemBase:\n",
    "    \"\"\"Primal side of the portfolio problem plus its loss helpers.\n",
    "\n",
    "    The decision variable ``z`` is a non-negative vector of length ``N`` whose\n",
    "    entries are constrained to sum to one.\n",
    "    \"\"\"\n",
    "\n",
    "    constraints: list[cp.Constraint]\n",
    "    f_tilde: cp.Expression | float\n",
    "    Fz: cp.Expression\n",
    "    primal_vars: dict[str, cp.Variable]\n",
    "\n",
    "    prob: cp.Problem\n",
    "    y_mean: np.ndarray\n",
    "    y_std: np.ndarray\n",
    "\n",
    "    def __init__(self, N: int):\n",
    "        \"\"\"Create the weight variable ``z`` and the sum-to-one constraint.\"\"\"\n",
    "        self.N = N\n",
    "\n",
    "        z = cp.Variable(N, name='z', nonneg=True)\n",
    "        self.primal_vars = {'z': z}\n",
    "\n",
    "        self.constraints = [\n",
    "            cp.sum(z) == 1,\n",
    "        ]\n",
    "\n",
    "        self.Fz = z\n",
    "        self.f_tilde = 0.\n",
    "\n",
    "    def task_loss_np(self, scale_tril: np.ndarray, mu: np.ndarray) -> float:\n",
    "        \"\"\"Return ``||scale_tril @ z||_2 - mu . z`` at the solved ``z``.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if the problem has not been solved yet.\n",
    "        \"\"\"\n",
    "        assert self.prob.value is not None, 'Problem must be solved first'\n",
    "        z = self.primal_vars['z'].value\n",
    "        task_loss = np.linalg.norm(scale_tril @ z, 2) - np.dot(mu, z)\n",
    "        return task_loss\n",
    "\n",
    "    def task_loss_np_1(self, y: np.ndarray, is_standardized: bool = False) -> float:\n",
    "        \"\"\"Return ``-y . z`` at the solved ``z``.\n",
    "\n",
    "        Args:\n",
    "            y: outcome vector with one entry per component of ``z``.\n",
    "            is_standardized: when True, ``y`` is first mapped through\n",
    "                ``y * self.y_std + self.y_mean``. Defaults to False so that\n",
    "                positional one-argument calls keep working.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if the problem has not been solved yet.\n",
    "        \"\"\"\n",
    "        assert self.prob.value is not None, 'Problem must be solved first'\n",
    "        # `is_standardized` used to be read without being a parameter, which\n",
    "        # raised NameError (and TypeError for callers passing it by keyword).\n",
    "        if is_standardized:\n",
    "            y = y * self.y_std + self.y_mean\n",
    "        z = self.primal_vars['z'].value\n",
    "        task_loss = - y @ z\n",
    "        return task_loss\n",
    "\n",
    "    def task_loss_torch(\n",
    "        self, y: Tensor, scale_tril: Tensor, mu: Tensor, solution: Sequence[Tensor]\n",
    "    ) -> Tensor:\n",
    "        \"\"\"Per-sample torch version of ``task_loss_np``.\n",
    "\n",
    "        Args:\n",
    "            y: outcomes; only used to check that the shape matches ``z``.\n",
    "            scale_tril: matrix factors multiplied into ``z`` sample by sample.\n",
    "            mu: location vectors, same shape as ``z``.\n",
    "            solution: solver output whose first element is ``z``.\n",
    "\n",
    "        Returns:\n",
    "            ``||scale_tril @ z|| - mu . z`` reduced over the last dimension.\n",
    "        \"\"\"\n",
    "        z = solution[0]\n",
    "        assert y.shape == z.shape\n",
    "        spread = torch.norm((scale_tril @ z.unsqueeze(-1)).squeeze(-1), dim=-1)\n",
    "        task_loss = spread - (mu * z).sum(dim=-1)\n",
    "        return task_loss\n",
    "\n",
    "class PortfolioProblemEllipsoid(PortfolioProblemBase, EllipsoidProblem, EllipsoidProblemProtocol):\n",
    "    # Concrete portfolio problem: PortfolioProblemBase supplies the primal pieces,\n",
    "    # EllipsoidProblem the parameters and objective, and the protocol __init__\n",
    "    # assembles them into self.prob.\n",
    "    def __init__(self, N: int, y_mean: np.ndarray, y_std: np.ndarray):\n",
    "        # The three initialisers are called explicitly, in dependency order:\n",
    "        # EllipsoidProblem needs self.Fz (set by PortfolioProblemBase), and the\n",
    "        # protocol __init__ reads the attributes set by both of them.\n",
    "        PortfolioProblemBase.__init__(self, N=N)\n",
    "        EllipsoidProblem.__init__(self, y_dim=N, y_mean=y_mean, y_std=y_std, Fz=self.Fz)\n",
    "        EllipsoidProblemProtocol.__init__(self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c932ba85-85b0-4d7c-92e9-e50d023aa287",
   "metadata": {
    "papermill": {
     "duration": 0.047998,
     "end_time": "2025-09-21T08:53:03.650765",
     "exception": false,
     "start_time": "2025-09-21T08:53:03.602767",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def mahalanobis_dist2(loc: Tensor, scale_tril: Tensor, y: Tensor) -> Tensor:\n",
    "    \"\"\"Squared Mahalanobis distance of ``y`` from ``loc`` under factor ``scale_tril``.\n",
    "\n",
    "    Delegates to torch's (private) ``_batch_mahalanobis`` helper.\n",
    "    \"\"\"\n",
    "    residual = y - loc\n",
    "    return tdist.multivariate_normal._batch_mahalanobis(bL=scale_tril, bx=residual)\n",
    "\n",
    "def calc_q(scores: Tensor, alpha: float) -> Tensor:\n",
    "    \"\"\"Return the ``ceil((n + 1) * (1 - alpha))``-th smallest score.\n",
    "\n",
    "    When that rank exceeds the number of scores ``n``, an infinite tensor is\n",
    "    returned instead.\n",
    "    \"\"\"\n",
    "    n = len(scores)\n",
    "    rank = int(np.ceil((n + 1) * (1 - alpha)))\n",
    "    if rank <= n:\n",
    "        # Index through argsort so the result stays attached to `scores`.\n",
    "        ordered = scores[torch.argsort(scores)]\n",
    "        return ordered[rank - 1]\n",
    "    return torch.tensor(torch.inf)\n",
    "\n",
    "def conformal_q(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    alpha: float,\n",
    "    device: str = 'cpu'\n",
    ") -> Tensor:\n",
    "    \"\"\"Compute the conformal quantile of the model's scores over ``loader``.\n",
    "\n",
    "    Every batch is scored with ``mahalanobis_dist2`` on the model's\n",
    "    ``(loc, scale_tril)`` output; the pooled scores go through ``calc_q``.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if ``calc_q`` returns infinity (too few samples).\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "\n",
    "    batch_scores = []\n",
    "    for x, y in loader:\n",
    "        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)\n",
    "        loc, scale_tril, _ = model(x)\n",
    "        batch_scores.append(mahalanobis_dist2(loc, scale_tril, y))\n",
    "\n",
    "    q = calc_q(torch.cat(batch_scores), alpha)\n",
    "    assert q != torch.inf, 'Size of calibration set is too small'\n",
    "    return q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43924fa-4568-449d-b0ae-1210f3972f2c",
   "metadata": {
    "papermill": {
     "duration": 0.12181,
     "end_time": "2025-09-21T08:53:04.040068",
     "exception": false,
     "start_time": "2025-09-21T08:53:03.918258",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_e2e_epoch(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    alpha: float,\n",
    "    nll_loss_frac: float,\n",
    "    rng: np.random.Generator,\n",
    "    optimizer: torch.optim.Optimizer,\n",
    "    show_pbar: bool = False\n",
    ") -> tuple[float, float, float]:\n",
    "    \"\"\"Run one training epoch mixing the Gaussian NLL and the task loss.\n",
    "\n",
    "    When the task loss is active (``nll_loss_frac < 1``) each batch is split at\n",
    "    random into two halves: the first half yields the conformal quantile ``q``,\n",
    "    the second half is solved through the CvxpyLayer with\n",
    "    ``scale_tril * sqrt(q)``.\n",
    "\n",
    "    Args:\n",
    "        model: network returning ``(loc, scale_tril, _)`` for a batch.\n",
    "        prob: problem providing ``get_cvxpylayer`` and ``task_loss_torch``.\n",
    "        loader: training batches of ``(x, y)``.\n",
    "        alpha: level passed to ``calc_q``; must lie strictly in (0, 1).\n",
    "        nll_loss_frac: weight of the NLL term; the task loss gets ``1 - nll_loss_frac``.\n",
    "        rng: generator used to permute each batch before splitting it.\n",
    "        optimizer: optimizer stepped once per processed batch.\n",
    "        show_pbar: wrap the loader in tqdm and log per-batch losses.\n",
    "\n",
    "    Returns:\n",
    "        ``(avg_loss, avg_nll_loss, avg_task_loss)``. The first two are divided by\n",
    "        the dataset size; ``avg_task_loss`` is divided by the number of task\n",
    "        samples and is NaN when no task sample was processed.\n",
    "    \"\"\"\n",
    "    assert 0 < alpha < 1\n",
    "    model.train()\n",
    "\n",
    "    cvxpylayer = prob.get_cvxpylayer()\n",
    "\n",
    "    total_loss = 0.\n",
    "    total_nll_loss = 0.\n",
    "    total_task_loss = 0.\n",
    "    total_num_tasks = 0\n",
    "    for x, y in tqdm(loader) if show_pbar else loader:\n",
    "        batch_size = x.shape[0]\n",
    "        loc, scale_tril, _ = model(x)\n",
    "\n",
    "        loss = torch.tensor(0.)\n",
    "        msgs = []\n",
    "\n",
    "        if nll_loss_frac > 0:\n",
    "            nll_loss = -tdist.MultivariateNormal(loc, scale_tril=scale_tril).log_prob(y).mean()\n",
    "            loss += nll_loss_frac * nll_loss\n",
    "            total_nll_loss += nll_loss.item() * batch_size\n",
    "            msgs.append(f'nll_loss: {nll_loss.item()}')\n",
    "\n",
    "        if nll_loss_frac < 1:\n",
    "            perm = rng.permutation(batch_size)\n",
    "\n",
    "            # First half of the shuffled batch: calibrate the quantile q.\n",
    "            cal_inds = perm[:batch_size//2]\n",
    "            scores_cal = mahalanobis_dist2(loc[cal_inds], scale_tril[cal_inds], y[cal_inds])\n",
    "            q = calc_q(scores_cal, alpha)\n",
    "            if q == torch.inf:\n",
    "                # The whole batch is dropped: no optimizer step is taken for it.\n",
    "                tqdm.write('Batch is too small, skipping')\n",
    "                continue\n",
    "\n",
    "            # Second half: solve the problem with the factor scaled by sqrt(q).\n",
    "            task_inds = perm[batch_size//2:]\n",
    "            y_task = y[task_inds]\n",
    "            loc_task = loc[task_inds]\n",
    "            scale_tril_task = scale_tril[task_inds] * torch.sqrt(q)\n",
    "\n",
    "            try:\n",
    "                solution = cvxpylayer(loc_task, scale_tril_task)\n",
    "            except Exception as e:\n",
    "                # Print the solver message, then re-raise with the original traceback.\n",
    "                print(e)\n",
    "                raise\n",
    "\n",
    "            task_loss = prob.task_loss_torch(y_task, scale_tril_task, loc_task, solution=solution).mean()\n",
    "            loss += (1 - nll_loss_frac) * task_loss\n",
    "\n",
    "            total_task_loss += task_loss.item() * task_inds.shape[0]\n",
    "            total_num_tasks += task_inds.shape[0]\n",
    "\n",
    "            msgs.append(f'task_loss: {task_loss.item()}')\n",
    "\n",
    "        if show_pbar:\n",
    "            tqdm.write(','.join(msgs))\n",
    "        total_loss += loss.item() * batch_size\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    num_samples = len(loader.dataset)\n",
    "    avg_loss = total_loss / num_samples\n",
    "    avg_nll_loss = total_nll_loss / num_samples\n",
    "    # No task sample is processed when nll_loss_frac == 1 or every batch was\n",
    "    # skipped; report NaN instead of raising ZeroDivisionError.\n",
    "    if total_num_tasks > 0:\n",
    "        avg_task_loss = total_task_loss / total_num_tasks\n",
    "    else:\n",
    "        avg_task_loss = float('nan')\n",
    "    return avg_loss, avg_nll_loss, avg_task_loss\n",
    "\n",
    "\n",
    "def train_e2e(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    alpha: float,\n",
    "    max_epochs: int,\n",
    "    lr: float,\n",
    "    l2reg: float,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    rng: np.random.Generator,\n",
    "    nll_loss_frac: float | Sequence[float],\n",
    "    saved_model_path: str = '',\n",
    "    patience: int = 10\n",
    ") -> dict[str, Any]:\n",
    "    \"\"\"Train ``model`` end-to-end with early stopping on the validation task loss.\n",
    "\n",
    "    Reads the notebook-level ``train_loader`` and ``cal_loader``: the former\n",
    "    drives ``train_e2e_epoch``; the latter is used after every epoch both to\n",
    "    compute ``q`` (``conformal_q``) and to evaluate the task loss (``optimize``).\n",
    "    The weights of the best epoch are restored into ``model`` before returning.\n",
    "\n",
    "    Args:\n",
    "        model: network to train (modified in place).\n",
    "        alpha: level forwarded to the conformal quantile computations.\n",
    "        max_epochs: upper bound on the number of epochs.\n",
    "        lr: Adam learning rate.\n",
    "        l2reg: Adam ``weight_decay``.\n",
    "        prob: optimisation problem used for the task loss.\n",
    "        rng: generator forwarded to ``train_e2e_epoch``.\n",
    "        nll_loss_frac: NLL weight, either constant or one value per epoch\n",
    "            (the last value is reused once the sequence is exhausted).\n",
    "        saved_model_path: optional state-dict file loaded before training.\n",
    "        patience: training stops once more than this many consecutive epochs\n",
    "            fail to improve the validation task loss.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the per-epoch loss histories, ``best_epoch`` and the best\n",
    "        ``val_task_loss``.\n",
    "    \"\"\"\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2reg)\n",
    "    if saved_model_path != '':\n",
    "        tqdm.write(f'Loading saved model: {saved_model_path}')\n",
    "        model.load_state_dict(torch.load(saved_model_path, weights_only=True))\n",
    "\n",
    "    result: dict[str, Any] = {\n",
    "        'train_e2e_losses': [],\n",
    "        'train_nll_losses': [],\n",
    "        'train_task_losses': [],\n",
    "        'val_task_losses': [],\n",
    "        'best_epoch': 0,\n",
    "        'val_task_loss': np.inf,  # best loss on val set\n",
    "    }\n",
    "    steps_since_decrease = 0\n",
    "    buffer = io.BytesIO()\n",
    "    has_checkpoint = False\n",
    "\n",
    "    pbar = tqdm(range(max_epochs))\n",
    "    for epoch in pbar:\n",
    "        if isinstance(nll_loss_frac, Sequence):\n",
    "            weight = nll_loss_frac[min(epoch, len(nll_loss_frac) - 1)]\n",
    "        else:\n",
    "            weight = nll_loss_frac\n",
    "\n",
    "        train_e2e_loss, train_nll_loss, train_task_loss = train_e2e_epoch(\n",
    "            model, prob=prob, loader=train_loader, alpha=alpha,\n",
    "            nll_loss_frac=weight, rng=rng, optimizer=optimizer)\n",
    "        result['train_e2e_losses'].append(train_e2e_loss)\n",
    "        result['train_nll_losses'].append(train_nll_loss)\n",
    "        result['train_task_losses'].append(train_task_loss)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            model.eval()\n",
    "            q = conformal_q(model, cal_loader, alpha=alpha).item()\n",
    "            val_task_loss = optimize(model, prob=prob, loader=cal_loader, q=q)\n",
    "            result['val_task_losses'].append(val_task_loss)\n",
    "\n",
    "        msg = (f'Epoch {epoch}, train_task_loss {train_task_loss:.3f}, '\n",
    "               f'val_task_loss {val_task_loss:.3f}, q {q:.3f}')\n",
    "        pbar.set_description(msg)\n",
    "\n",
    "        steps_since_decrease += 1\n",
    "\n",
    "        if val_task_loss < result['val_task_loss']:\n",
    "            result['best_epoch'] = epoch\n",
    "            result['val_task_loss'] = val_task_loss\n",
    "            steps_since_decrease = 0\n",
    "            # Rewind AND truncate so no bytes of an older checkpoint survive\n",
    "            # past the end of the new one.\n",
    "            buffer.seek(0)\n",
    "            buffer.truncate()\n",
    "            torch.save(model.state_dict(), buffer)\n",
    "            has_checkpoint = True\n",
    "\n",
    "        if steps_since_decrease > patience:\n",
    "            break\n",
    "\n",
    "    if has_checkpoint:\n",
    "        buffer.seek(0)\n",
    "        model.load_state_dict(torch.load(buffer, weights_only=True))\n",
    "    else:\n",
    "        # Nothing was checkpointed (no epochs, or no finite improvement):\n",
    "        # loading the empty buffer would fail, so keep the current weights.\n",
    "        tqdm.write('No epoch improved the validation task loss; keeping current weights')\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6b2779-1707-450b-915c-02d854be30bf",
   "metadata": {
    "papermill": {
     "duration": 0.078854,
     "end_time": "2025-09-21T08:53:04.544020",
     "exception": false,
     "start_time": "2025-09-21T08:53:04.465166",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def optimize(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    q: float,\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    ") -> float:\n",
    "    \"\"\"Solve ``prob`` for every sample in ``loader`` and average the task loss.\n",
    "\n",
    "    The model's ``scale_tril`` output is multiplied by ``sqrt(q)`` before each\n",
    "    solve; the loss is ``prob.task_loss_np(scale_tril, loc)``.\n",
    "\n",
    "    Returns:\n",
    "        Mean task loss over all samples, as a Python float.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model = model.to(device)\n",
    "\n",
    "    pbar = tqdm(total=len(loader.dataset)) if show_pbar else None\n",
    "    radius = np.sqrt(q)\n",
    "    task_losses = []\n",
    "\n",
    "    for x_batch, y_batch in loader:\n",
    "        outputs = model(x_batch.to(device, non_blocking=True))\n",
    "        loc_batch = outputs[0].detach().cpu().numpy()\n",
    "        scale_tril_batch = outputs[1].detach().cpu().numpy()\n",
    "        y_batch = y_batch.detach().cpu().numpy()\n",
    "\n",
    "        # Scale in place so the array keeps its dtype.\n",
    "        scale_tril_batch *= radius\n",
    "\n",
    "        # The targets only pace the iteration; the loss does not read them.\n",
    "        for _, loc, scale_tril in zip(y_batch, loc_batch, scale_tril_batch):\n",
    "            prob.solve(loc, scale_tril)\n",
    "            task_losses.append(prob.task_loss_np(scale_tril, loc))\n",
    "\n",
    "            if pbar is not None:\n",
    "                pbar.update(1)\n",
    "\n",
    "    return np.mean(task_losses).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d57ea1-e1ce-439a-a0ad-6fe321d49b9d",
   "metadata": {
    "papermill": {
     "duration": 0.093328,
     "end_time": "2025-09-21T08:53:04.748332",
     "exception": false,
     "start_time": "2025-09-21T08:53:04.655004",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import itertools\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Settings for the end-to-end grid search.\n",
    "alpha = 0.1           # level passed to the conformal quantile computations\n",
    "max_epochs = 100\n",
    "nll_loss_frac = 0.0   # 0.0 -> train on the task loss only\n",
    "lr_list    = [5e-4]\n",
    "l2_list    = [1e-4]\n",
    "\n",
    "# Best configuration seen so far. The keys match the per-run `record` dicts\n",
    "# (\"best_epoch\", not \"epoch\") so that `best.update(record)` overwrites these\n",
    "# placeholders instead of adding a parallel key.\n",
    "best = {\n",
    "    \"val_task_loss\": float(\"inf\"),\n",
    "    \"lr\": None,\n",
    "    \"l2reg\": None,\n",
    "    \"best_epoch\": None,\n",
    "    \"state_dict\": None,\n",
    "}\n",
    "all_results = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1cc17dc-9129-408d-b422-f74fa4c53d4c",
   "metadata": {
    "papermill": {
     "duration": 3150.366946,
     "end_time": "2025-09-21T09:45:35.488024",
     "exception": false,
     "start_time": "2025-09-21T08:53:05.121078",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Grid search over (lr, l2): every run restarts from the baseline weights.\n",
    "for lr in lr_list:\n",
    "    for l2 in l2_list:\n",
    "        model = EllipsoidalUncertaintyModel(input_dim=21, hidden_dims = [20, 20], n=15, enable_implicit=False)  \n",
    "        model_path = os.path.join('outputs', 'model', 'baseline', 'baseline_model.pth')\n",
    "        # weights_only=True matches the other state-dict loads in this notebook.\n",
    "        model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))\n",
    "        prob = PortfolioProblemEllipsoid(N=15, y_mean=y_mean, y_std=y_std)\n",
    "\n",
    "        # Fresh generator per run so every configuration sees the same batch splits.\n",
    "        rng = np.random.default_rng(42)\n",
    "        print(f\"lr: {lr } | l2={l2}\")\n",
    "        \n",
    "        res = train_e2e(\n",
    "            model=model,\n",
    "            alpha=alpha,\n",
    "            max_epochs=max_epochs,\n",
    "            lr=lr,\n",
    "            l2reg=l2,\n",
    "            prob=prob,\n",
    "            rng=rng,\n",
    "            nll_loss_frac=nll_loss_frac,\n",
    "            saved_model_path=''\n",
    "        )\n",
    "\n",
    "        record = {\n",
    "            \"lr\": lr,\n",
    "            \"l2reg\": l2,\n",
    "            \"best_epoch\": res[\"best_epoch\"],\n",
    "            \"val_task_loss\": res[\"val_task_loss\"],\n",
    "        }\n",
    "        all_results.append(record)\n",
    "\n",
    "        # NOTE(review): only strictly positive validation losses are eligible;\n",
    "        # confirm that excluding runs with loss <= 0 is intended.\n",
    "        if 0 < res[\"val_task_loss\"] < best[\"val_task_loss\"]:\n",
    "            best.update(record)\n",
    "            best[\"state_dict\"] = copy.deepcopy(model.state_dict())\n",
    "\n",
    "print(\"\\n=== Grid Search Summary ===\")\n",
    "all_results = sorted(all_results, key=lambda r: r[\"val_task_loss\"])\n",
    "for r in all_results:\n",
    "    print(f\"lr={r['lr']:g}, l2={r['l2reg']:g}, \"\n",
    "          f\"best_epoch={r['best_epoch']}, val_task_loss={r['val_task_loss']:.6f}\")\n",
    "\n",
    "# Without an eligible run `best` still holds its None placeholders; stop with a\n",
    "# clear message instead of failing while formatting or loading None.\n",
    "if best[\"state_dict\"] is None:\n",
    "    raise RuntimeError(\n",
    "        \"Grid search found no run with a strictly positive validation task loss; \"\n",
    "        \"no model was saved.\"\n",
    "    )\n",
    "\n",
    "print(\"\\n>>> Best config:\")\n",
    "print(f\"lr={best['lr']:g}, l2={best['l2reg']:g}, \"\n",
    "      f\"best_epoch={best['best_epoch']}, val_task_loss={best['val_task_loss']:.6f}\")\n",
    "\n",
    "best_model = EllipsoidalUncertaintyModel(input_dim=21, hidden_dims = [20, 20], n=15, enable_implicit=False) \n",
    "best_model.load_state_dict(best[\"state_dict\"])\n",
    "\n",
    "os.makedirs('outputs/model/E2E', exist_ok=True)\n",
    "save_path = os.path.join('outputs', 'model', 'E2E', 'E2E_model.pth')\n",
    "torch.save(best_model.state_dict(), save_path)\n",
    "# Report the path that was actually written.\n",
    "print(f\"Saved best model to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b183982-1989-439a-b017-8e891595ce33",
   "metadata": {
    "papermill": {
     "duration": 0.036607,
     "end_time": "2025-09-21T09:45:35.566561",
     "exception": false,
     "start_time": "2025-09-21T09:45:35.529954",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def optimize_test(\n",
    "    model: EllipsoidalUncertaintyModel,\n",
    "    prob: EllipsoidProblemProtocol,\n",
    "    loader: torch.utils.data.DataLoader,\n",
    "    q: float,\n",
    "    device: str = 'cpu',\n",
    "    show_pbar: bool = False\n",
    ") -> tuple[float, float, float, float, list[np.ndarray]]:\n",
    "    \"\"\"Solve ``prob`` for every sample in ``loader`` and collect test metrics.\n",
    "\n",
    "    For each sample the model's ``scale_tril`` is scaled by ``sqrt(q)``, the\n",
    "    problem is solved, and two quantities are compared: the realised loss\n",
    "    ``prob.task_loss_np_1(y)`` and the objective value\n",
    "    ``||scale_tril @ z|| - loc . z`` at the solution.\n",
    "\n",
    "    Args:\n",
    "        model: network returning ``(loc, scale_tril, ...)`` for a batch.\n",
    "        prob: problem exposing ``solve``, ``task_loss_np_1`` and ``Fz``.\n",
    "        loader: batches of ``(x, y)`` to evaluate.\n",
    "        q: conformal quantile of the squared Mahalanobis scores.\n",
    "        device: device the model and inputs are moved to.\n",
    "        show_pbar: show a tqdm progress bar over the samples.\n",
    "\n",
    "    Returns:\n",
    "        ``(coverage, avg_risk, avg_loss, robustness, z_value)`` where\n",
    "        ``coverage`` is the percentage of samples whose squared Mahalanobis\n",
    "        distance is at most ``q``, ``avg_risk`` the mean objective value,\n",
    "        ``avg_loss`` the mean realised loss, ``robustness`` the percentage of\n",
    "        samples whose realised loss is below the objective value, and\n",
    "        ``z_value`` the list of per-sample solutions.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``loader`` yields no samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model = model.to(device)\n",
    "\n",
    "    task_losses = []\n",
    "    expr_losses = []\n",
    "    z_value = []\n",
    "    count_smaller = 0\n",
    "    total_points = 0\n",
    "    count_within_q = 0\n",
    "    sqrt_q = np.sqrt(q)\n",
    "\n",
    "    pbar = tqdm(total=len(loader.dataset)) if show_pbar else None\n",
    "    for x_batch, y_batch in loader:\n",
    "        x_batch = x_batch.to(device, non_blocking=True)\n",
    "        pred = model(x_batch)\n",
    "        loc_batch = pred[0].detach().cpu().numpy()\n",
    "        scale_tril_batch = pred[1].detach().cpu().numpy()\n",
    "        y_batch = y_batch.detach().cpu().numpy()\n",
    "\n",
    "        scale_tril_batch *= sqrt_q\n",
    "\n",
    "        for y, loc, scale_tril in zip(y_batch, loc_batch, scale_tril_batch):\n",
    "            prob.solve(loc, scale_tril)\n",
    "            task_loss = prob.task_loss_np_1(y, is_standardized=False)\n",
    "            task_losses.append(task_loss)\n",
    "\n",
    "            # Objective value at the solution, recomputed in NumPy.\n",
    "            Fz_val = np.asarray(prob.Fz.value).reshape(-1)\n",
    "            z_value.append(Fz_val)\n",
    "            expr_loss = np.linalg.norm(scale_tril @ Fz_val) - loc @ Fz_val\n",
    "            expr_losses.append(expr_loss)\n",
    "\n",
    "            if task_loss < expr_loss:\n",
    "                count_smaller += 1\n",
    "\n",
    "            # Squared Mahalanobis distance under the unscaled factor L:\n",
    "            # (y - loc)^T (L L^T)^{-1} (y - loc) = ||L^{-1} (y - loc)||^2.\n",
    "            # A linear solve avoids forming and inverting L L^T explicitly.\n",
    "            whitened = np.linalg.solve(scale_tril / sqrt_q, y - loc)\n",
    "            mahalanobis_sq = whitened @ whitened\n",
    "\n",
    "            if mahalanobis_sq <= q:\n",
    "                count_within_q += 1\n",
    "\n",
    "            total_points += 1\n",
    "\n",
    "            if pbar is not None:\n",
    "                pbar.update(1)\n",
    "\n",
    "    if pbar is not None:\n",
    "        pbar.close()\n",
    "\n",
    "    if total_points == 0:\n",
    "        raise ValueError('optimize_test received an empty loader; no metrics to compute')\n",
    "\n",
    "    avg_loss = np.mean(task_losses)\n",
    "    avg_risk = np.mean(expr_losses)\n",
    "    robustness = count_smaller / total_points * 100\n",
    "    coverage = count_within_q / total_points * 100\n",
    "\n",
    "    return coverage, avg_risk, avg_loss, robustness, z_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac47837-086d-49f7-b370-cc4fb586c70b",
   "metadata": {
    "papermill": {
     "duration": 1.624377,
     "end_time": "2025-09-21T09:45:37.224894",
     "exception": false,
     "start_time": "2025-09-21T09:45:35.600517",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Reload the saved end-to-end weights and evaluate them on the test split.\n",
    "model_path = os.path.join('outputs', 'model', 'E2E', f'E2E_model.pth')\n",
    "# NOTE(review): `model` and `prob` are whatever the grid-search loop left bound\n",
    "# last, not `best_model` - confirm this is intended.\n",
    "model.load_state_dict(torch.load(model_path, weights_only=True))\n",
    "model.eval()\n",
    "# The quantile is computed from the calibration loader, without gradients.\n",
    "with torch.no_grad():\n",
    "    q = conformal_q(model, cal_loader, alpha=alpha).item()\n",
    "coverage, avg_risk, avg_loss, robustness_1, z_value = optimize_test(model, prob, test_loader, q=q)\n",
    "\n",
    "print(f\"Coverage: {coverage:.2f}%\")\n",
    "print(f\"Ave_task_loss: {avg_risk:.4f}\")\n",
    "# NOTE(review): `var` is not assigned in this section and optimize_test does not\n",
    "# return it - confirm it is defined earlier in the notebook, otherwise this\n",
    "# line raises NameError.\n",
    "print(f\"VaR: {var:.4f}\")\n",
    "print(f\"Ave_loss: {avg_loss:.4f}\")\n",
    "print(f\"Robustness: {robustness_1:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dae6393-43b6-4df2-81fc-f185f1b65ae4",
   "metadata": {
    "papermill": {
     "duration": 0.034883,
     "end_time": "2025-09-21T09:45:37.299239",
     "exception": false,
     "start_time": "2025-09-21T09:45:37.264356",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the end-to-end metrics as a small plain-text report.\n",
    "data_dir = 'outputs/results/data_E2E'\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "txt_path = os.path.join(data_dir, 'metrics_E2E.txt')\n",
    "\n",
    "with open(txt_path, 'w') as f:\n",
    "    # One line per metric; the labels are padded so the colons line up.\n",
    "    f.writelines([\n",
    "        f\"Marginal Coverage: {coverage:.2f}%\\n\",\n",
    "        f\"Average Risk     : {avg_risk:.4f}\\n\",\n",
    "        f\"Average Loss     : {avg_loss:.4f}\\n\",\n",
    "        f\"Robustness       : {robustness_1:.2f}%\\n\",\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6452e5-6bb3-44f5-9a0e-a1b23ff8a1b6",
   "metadata": {
    "papermill": {
     "duration": 0.02936,
     "end_time": "2025-09-21T09:45:38.291321",
     "exception": false,
     "start_time": "2025-09-21T09:45:38.261961",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 4404.732068,
   "end_time": "2025-09-21T09:45:39.646450",
   "environment_variables": {},
   "exception": null,
   "input_path": "CRC_GPU_stock.ipynb",
   "output_path": "outputs/file/CRC_GPU_stock0.ipynb",
   "parameters": {
    "run_id": 0
   },
   "start_time": "2025-09-21T08:32:14.914382",
   "version": "2.6.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "00c3331c1d88409e8eb58ae52f10af11": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "06e323c3b4a248559c4ca00d59acfd2f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "2374ffa9cce2450e8e142badc3731b54": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "4d643c0074f64a9ba6096e3bc29dbf42": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "4dcc2fc27bb74c0d971bbb1a68d6dcd1": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "58dda13bec1740dc9a8d1721fa4a893f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "a528aeaad665457796224b1c3054591c": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_00c3331c1d88409e8eb58ae52f10af11",
       "placeholder": "​",
       "style": "IPY_MODEL_06e323c3b4a248559c4ca00d59acfd2f",
       "tabbable": null,
       "tooltip": null,
       "value": " 100/100 [52:30&lt;00:00, 22.66s/it]"
      }
     },
     "d9de77425abc45e2867e002c0807488b": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_f4db7bed4a314308999229d907bbb1f1",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_2374ffa9cce2450e8e142badc3731b54",
       "tabbable": null,
       "tooltip": null,
       "value": 100
      }
     },
     "eace2e2c80024d1e9910806b33d53752": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_4dcc2fc27bb74c0d971bbb1a68d6dcd1",
       "placeholder": "​",
       "style": "IPY_MODEL_58dda13bec1740dc9a8d1721fa4a893f",
       "tabbable": null,
       "tooltip": null,
       "value": "Epoch 99, train_task_loss 5.963, val_task_loss 4.817, q 58.292: 100%"
      }
     },
     "f110635798584ae48e994b0688c9de64": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_eace2e2c80024d1e9910806b33d53752",
        "IPY_MODEL_d9de77425abc45e2867e002c0807488b",
        "IPY_MODEL_a528aeaad665457796224b1c3054591c"
       ],
       "layout": "IPY_MODEL_4d643c0074f64a9ba6096e3bc29dbf42",
       "tabbable": null,
       "tooltip": null
      }
     },
     "f4db7bed4a314308999229d907bbb1f1": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
