{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beaaebbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch \n",
    "import resnet\n",
    "import random\n",
    "import traceback\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data as data\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn.functional as F\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from typing import List, Tuple\n",
    "from sklearn.metrics import f1_score\n",
    "from imagec_utils import load_cifar10c, load_cifar10_like_c\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from confseq.betting import betting_lower_cs, betting_cs\n",
    "from confseq.conjmix_bounded import conjmix_empbern_lower_cs\n",
    "from confseq.predmix import predmix_empbern_lower_cs, predmix_empbern_twosided_cs\n",
    "from confseq.boundaries import normal_mixture_bound, gamma_exponential_mixture_bound\n",
    "from confseq.conjmix_bounded import conjmix_empbern_lower_cs, conjmix_empbern_twosided_cs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "446fc859",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['font.family'] = 'Times New Roman'\n",
    "def set_seed(seed: int):\n",
    "    import os\n",
    "    import random\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4bee705",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import os\n",
    "\n",
    "def compute_v_opt(x, t_opt):\n",
    "    x = np.array(x)\n",
    "    t = np.arange(1, len(x) + 1)\n",
    "    S_t = np.cumsum(x)\n",
    "    mu_hat_t = S_t / t\n",
    "    mu_hat_tminus1 = np.append(1 / 2, mu_hat_t[0 : (len(mu_hat_t) - 1)])\n",
    "    V_t = np.cumsum(np.power(x - mu_hat_tminus1, 2))\n",
    "    v_opt = V_t[t_opt] * t_opt\n",
    "    return v_opt\n",
    "    \n",
    "def running_average_cumulative(x): \n",
    "    return np.cumsum(x) / (np.arange(len(x)) + 1)\n",
    "\n",
    "def shuffle_by_severity_multi(trajs, severities):\n",
    "    shuffled_results = [[] for _ in trajs]\n",
    "    start = 0\n",
    "    n_sev = len(severities)\n",
    "    length_per_sev = len(trajs[0]) // n_sev \n",
    "    for sev in severities:\n",
    "        idx = np.arange(length_per_sev)\n",
    "        np.random.shuffle(idx)\n",
    "        for i, traj in enumerate(trajs):\n",
    "            part = traj[start:start + length_per_sev]\n",
    "            shuffled_results[i].append(part[idx])  \n",
    "        start += length_per_sev\n",
    "\n",
    "    return [np.concatenate(parts) for parts in shuffled_results]\n",
    "\n",
    "para = 1\n",
    "def conjmix_empbern_cs_flexible(x, v_opt, alpha=0.05, c=1,  running_intersection=False):\n",
    "    x = np.array(x)\n",
    "    t = np.arange(1, len(x) + 1)\n",
    "    S_t = np.cumsum(x)\n",
    "    mu_hat_t = S_t / t\n",
    "    mu_hat_tminus1 = np.append(1/2., mu_hat_t[0:(len(mu_hat_t) - 1)])\n",
    "    V_t = np.cumsum(np.power(x - mu_hat_tminus1, 2))\n",
    "    bdry = (gamma_exponential_mixture_bound(\n",
    "        V_t, alpha=alpha / 2, v_opt=v_opt, c=c, alpha_opt=alpha / 2) / t)\n",
    "    l, u = mu_hat_t - bdry, mu_hat_t + bdry\n",
    "    l = np.maximum(l, -1 * para)\n",
    "    u = np.minimum(u, 2 * para)\n",
    "    if running_intersection:\n",
    "        l = np.maximum.accumulate(l)\n",
    "        u = np.minimum.accumulate(u)\n",
    "    return l, u\n",
    "\n",
    "    \n",
    "def evaluate_and_save(model, severities, data_dir, dataset_name, batch_size=128, n_examples=None, corruption_type=None, batch_order=None, shuffle=False, collate_fn=None, save_dir=\"./results\"):\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "    os.makedirs(os.path.join(save_dir, dataset_name), exist_ok=True)\n",
    "    model.eval()\n",
    "    all_results = {}  \n",
    "    with torch.no_grad():\n",
    "        for severity in severities:\n",
    "            print(f\"Processing severity {severity}...\")\n",
    "            if severity == 0:\n",
    "                x_test, y_test = load_cifar10_like_c(data_dir=data_dir, n_examples=n_examples)\n",
    "            else:\n",
    "                x_test, y_test = load_cifar10c(n_examples, severity, data_dir, [corruption_type], batch_order)\n",
    "\n",
    "            dataset = TensorDataset(x_test, y_test)\n",
    "            testloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)\n",
    "\n",
    "            logits_list, labels_list = [], []\n",
    "\n",
    "            accuracies = {}\n",
    "            for imgs, labels in testloader:\n",
    "                imgs, labels = imgs.to(device), labels.to(device)\n",
    "                logits = model(imgs)\n",
    "                logits_list.append(logits.cpu())\n",
    "                labels_list.append(labels.cpu())\n",
    "\n",
    "            logits_all = torch.cat(logits_list, dim=0)\n",
    "            labels_all = torch.cat(labels_list, dim=0)\n",
    "            preds = logits_all.argmax(dim=1)\n",
    "            acc = (preds == labels_all).float().mean().item()\n",
    "            accuracies[severity] = acc\n",
    "\n",
    "            print(f\"Severity {severity} Accuracy: {acc*100:.2f}%\")\n",
    "\n",
    "            all_results[severity] = {\n",
    "                \"logits\": logits_all,\n",
    "                \"labels\": labels_all\n",
    "            }\n",
    "            torch.save(\n",
    "                {\"logits\": logits_all, \"labels\": labels_all},\n",
    "                os.path.join(save_dir, f\"{dataset_name}/sev{severity}.pt\")\n",
    "            )\n",
    "            print(f\"Saved severity {severity} -> {save_dir}/{dataset_name}/sev{severity}.pt\")\n",
    "\n",
    "    return all_results\n",
    "\n",
    "def load_severity_data( dataset_name, save_dir=\"./results\", severities=[0,1,2,3,4,5]):\n",
    "    data_dict = {}\n",
    "    for s in severities:\n",
    "        path = os.path.join(save_dir, f\"{dataset_name}/sev{s}.pt\")\n",
    "        if os.path.exists(path):\n",
    "            data_dict[s] = torch.load(path)\n",
    "        else:\n",
    "            print(f\"Warning: missing {path}\")\n",
    "    return data_dict\n",
    "\n",
    "def split_into_batches(data, batch_size):\n",
    "    logits, labels = data[\"logits\"], data[\"labels\"]\n",
    "    n = logits.size(0)\n",
    "    n_batches = (n + batch_size - 1) // batch_size  \n",
    "\n",
    "    batches = []\n",
    "    for i in range(n_batches):\n",
    "        start, end = i * batch_size, min((i+1) * batch_size, n)\n",
    "        batches.append({\n",
    "            \"logits\": logits[start:end],\n",
    "            \"labels\": labels[start:end]\n",
    "        })\n",
    "    return batches\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d79e1862",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "arch = \"resnet110\"\n",
    "dataset_name=\"cifar10\"\n",
    "severities= [0,1,2,3,4,5]\n",
    "data_dir = \"data\"\n",
    "batch_order = \"uniform\"\n",
    "shuffle = True if batch_order == \"uniform\" else False\n",
    "corruption_type = \"gaussian_noise\"\n",
    "n_examples = 10000\n",
    "batch_size = 128\n",
    "collate_fn = None\n",
    "results_dir = \"./results\"\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "checkpoint_path = \"pretrained_models/resnet110-1d1ed7c2.th\"\n",
    "checkpoint = torch.load(checkpoint_path)\n",
    "model = torch.nn.DataParallel(resnet.__dict__[arch]())\n",
    "model.load_state_dict(checkpoint['state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772cee21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def temperature_scaling(logits, T):\n",
    "    return logits / T\n",
    "\n",
    "def calibrate_temperature(model, calibloader, device='cuda', max_iter=200):\n",
    "    model.eval()\n",
    "    model.to(device)\n",
    "    \n",
    "    logits_list, labels_list = [], []\n",
    "    with torch.no_grad():\n",
    "        for x, y in calibloader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            logits = model(x)\n",
    "            logits_list.append(logits)\n",
    "            labels_list.append(y)\n",
    "    \n",
    "    logits_val = torch.cat(logits_list, dim=0)\n",
    "    labels_val = torch.cat(labels_list, dim=0)\n",
    "\n",
    "    class _TempScaler(nn.Module):\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.theta = nn.Parameter(torch.zeros(1))  \n",
    "        def forward(self, logits):\n",
    "            return logits / torch.exp(self.theta)\n",
    "        def temperature(self):\n",
    "            return torch.exp(self.theta).item()\n",
    "\n",
    "    scaler = _TempScaler().to(device)\n",
    "    nll = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.LBFGS(scaler.parameters(), lr=0.1, max_iter=max_iter)\n",
    "\n",
    "    def closure():\n",
    "        optimizer.zero_grad()\n",
    "        scaled_logits = scaler(logits_val)\n",
    "        loss = nll(scaled_logits, labels_val)\n",
    "        loss.backward()\n",
    "        return loss\n",
    "\n",
    "    optimizer.step(closure)\n",
    "    T = scaler.temperature()\n",
    "    print(f\"[Temperature Calibration] Optimal T = {T:.4f}\")\n",
    "    return T\n",
    "\n",
    "x_all, y_all = load_cifar10c(n_examples, 1, data_dir, [corruption_type], batch_order)\n",
    "\n",
    "\n",
    "calib_ratio = 0.2  \n",
    "n_calib = int(len(x_all) * calib_ratio)\n",
    "indices = torch.randperm(len(x_all))\n",
    "calib_idx = indices[:n_calib]\n",
    "test_idx = indices[n_calib:]\n",
    "x_calib, y_calib = x_all[calib_idx], y_all[calib_idx]\n",
    "calibset = TensorDataset(x_calib, y_calib)\n",
    "calibloader = DataLoader(calibset, batch_size=batch_size, shuffle=False)\n",
    "T = calibrate_temperature(model, calibloader, device=device)\n",
    "\n",
    "print(f\"Calibrated Temperature: {T:.3f}\")\n",
    "\n",
    "idx = 2\n",
    "logits_test = model(x_all[idx:idx+1].to(device)).detach()\n",
    "probs_test1 = F.softmax(logits_test, dim=1).cpu()\n",
    "logits_test = temperature_scaling(logits_test, T)\n",
    "probs_test2 = F.softmax(logits_test, dim=1).cpu()\n",
    "\n",
    "print(\"Before temperature scaling:\")\n",
    "print(probs_test1)\n",
    "print(\"After temperature scaling:\")\n",
    "print(probs_test2)\n",
    "plt.plot(probs_test1.numpy().flatten(), label='Before T scaling')\n",
    "plt.plot(probs_test2.numpy().flatten(), label='After T scaling')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91421209",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def compute_probs_error_max(logits):\n",
    "    probs = F.softmax(logits, dim=-1)\n",
    "    probs_error = 1 - torch.max(probs, dim=-1).values\n",
    "    return probs_error\n",
    "\n",
    "def compute_brier_score(\n",
    "    logits: torch.Tensor, labels: torch.Tensor, mean_red: bool = False, c: int = 0.5\n",
    ") -> torch.Tensor:\n",
    "    probs = F.softmax(logits, dim=-1)\n",
    "    one_hot = F.one_hot(labels.to(torch.int64), num_classes=logits.shape[1]).to(dtype=probs.dtype)\n",
    "    brier_scores = c * ((probs - one_hot) ** 2).sum(dim=1)\n",
    "    return brier_scores\n",
    "\n",
    "def compute_acc(\n",
    "    logits: torch.Tensor, labels: torch.Tensor, mean_red: bool = False, c: float = 0.5\n",
    ") -> torch.Tensor:\n",
    "    preds = torch.argmax(logits, dim=-1)\n",
    "    acc = (preds == labels).to(dtype=logits.dtype)\n",
    "    return 1 - acc\n",
    "\n",
    "def calibrate_max_f1_q(E,q,E_hat,q_hat,):\n",
    "    max_f1 = -1\n",
    "    best_q = None\n",
    "    best_q_hat = None\n",
    "    # Iterate over all possible pairs of thresholds\n",
    "    for q_i in q:\n",
    "        E_binary = E > q_i  # Convert continuous E to binary using threshold q_i\n",
    "        for q_hat_i in q_hat:\n",
    "            E_hat_binary = E_hat > q_hat_i  # Apply threshold to E_hat\n",
    "            # Calculate F1 score for this pair of thresholds      \n",
    "            current_f1 = f1_score(E_binary, E_hat_binary)\n",
    "            if current_f1 > max_f1:\n",
    "                max_f1 = current_f1\n",
    "                best_q = q_i\n",
    "                best_q_hat = q_hat_i\n",
    "    return best_q, best_q_hat, max_f1\n",
    "\n",
    "def calibrate_max_f1(E, E_hat, q_hat):\n",
    "    f1_scores = [f1_score(E, E_hat > _q_hat) for _q_hat in q_hat]\n",
    "    best_q_hat = q_hat[np.argmax(f1_scores)]\n",
    "    return best_q_hat\n",
    "\n",
    "def compute_wn(n_calib, alpha_calib):\n",
    "    return (np.log(2) - np.log(alpha_calib)) / (2 * n_calib)\n",
    "\n",
    "def compute_false_alarm(E, E_hat, q, q_hat):\n",
    "    mask = (E_hat > q_hat) & (E <= q)\n",
    "    return mask.float()\n",
    "save_dir = results_dir\n",
    "source_sev = 0\n",
    "path = os.path.join(save_dir, f\"{dataset_name}/sev{source_sev}.pt\")\n",
    "data = torch.load(path)\n",
    "Q_MIN = 0.001\n",
    "Q_MAX = 0.99 \n",
    "Q_STEP = 0.01 \n",
    "Q_HAT_MIN = 0.001\n",
    "Q_HAT_MAX = 0.99 \n",
    "Q_HAT_STEP = 0.01\n",
    "CALIB_SIZE = 1000\n",
    "ALPHA_PROD_2 = 0.2\n",
    "\n",
    "logits_source, labels_source = data[\"logits\"][:CALIB_SIZE], data[\"labels\"][:CALIB_SIZE]\n",
    "\n",
    "print(f\"Loaded logits shape: {logits_source.shape}, labels shape: {labels_source.shape}\")\n",
    "E_hat = compute_probs_error_max(logits_source)\n",
    "E = compute_brier_score(logits=logits_source, labels=labels_source, mean_red=True)\n",
    "q = torch.arange(Q_MIN, Q_MAX, Q_STEP)\n",
    "q_hat = torch.arange(Q_HAT_MIN, Q_HAT_MAX, Q_HAT_STEP)\n",
    "best_q, best_q_hat, _ = calibrate_max_f1_q(E=E, E_hat=E_hat, q=q, q_hat=q_hat)\n",
    "false_alarm = compute_false_alarm(E=E, E_hat=E_hat, q=best_q, q_hat=best_q_hat)\n",
    "L1 = false_alarm.mean() + compute_wn(\n",
    "    n_calib=CALIB_SIZE, alpha_calib=ALPHA_PROD_2)\n",
    "print(f\"Best q: {best_q}, Best q_hat: {best_q_hat}\")\n",
    "print(f\"False Alarm: {false_alarm.mean():.4f}, L1={L1:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b489a05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax_brier_loss(logits, targets, temperature=1.0):\n",
    "    probs = F.softmax(logits/temperature, dim=1)\n",
    "    if targets.ndim == 1:\n",
    "        targets = F.one_hot(targets.to(torch.int64), num_classes=probs.shape[1]).to(dtype=probs.dtype)\n",
    "    return torch.mean(torch.sum((probs - targets) ** 2, dim=1)) * 0.5\n",
    "\n",
    "def split_into_batches(logits, labels, batch_size):\n",
    "    n = logits.size(0)\n",
    "    indices = torch.randperm(n)\n",
    "    logits, labels = logits[indices], labels[indices]\n",
    "    return [\n",
    "        (logits[i:i+batch_size], labels[i:i+batch_size])\n",
    "        for i in range(0, n, batch_size)\n",
    "    ]\n",
    "def softmax_brier_loss_vector(logits, targets, temperature=1.0):\n",
    "    probs = F.softmax(logits / temperature, dim=1)\n",
    "\n",
    "    if targets.ndim == 1:\n",
    "        targets = F.one_hot(\n",
    "            targets.to(torch.int64),\n",
    "            num_classes=probs.shape[1]\n",
    "        ).to(dtype=probs.dtype)\n",
    "\n",
    "    loss_vec = 0.5 * torch.sum((probs - targets) ** 2, dim=1)\n",
    "    return loss_vec.to(torch.float32)\n",
    "def compute_eta_t(labeled_losses_hist, agent_labeled_losses_hist, agent_unlabeled_losses_hist, eta_max,  eps=1e-8):\n",
    "    u = np.concatenate(labeled_losses_hist, axis=0)\n",
    "    u_tilde_l = np.concatenate(agent_labeled_losses_hist, axis=0)\n",
    "    u_tilde_u = np.concatenate(agent_unlabeled_losses_hist, axis=0)\n",
    "    n_l = len(u)\n",
    "    n_u = len(u_tilde_u)\n",
    "\n",
    "    if n_l == 0 or n_u == 0:\n",
    "        return 0.0\n",
    "    u_bar = u.mean()\n",
    "    u_tilde_l_bar = u_tilde_l.mean()\n",
    "    u_tilde_u_bar = u_tilde_u.mean()\n",
    "    cov = np.mean((u - u_bar) * (u_tilde_l - u_tilde_l_bar))\n",
    "    var = np.mean((u_tilde_u - u_tilde_u_bar) ** 2)\n",
    "    if var < eps:\n",
    "        return 0.0\n",
    "    eta = cov / ((1.0 + n_l / n_u) * var)\n",
    "    eta = np.clip(eta, 0.0, eta_max)\n",
    "    return float(eta)\n",
    "\n",
    "def temperature_scaling(logits, T):\n",
    "    return logits / T\n",
    "\n",
    "def compute_brier_trajectories(save_dir=\"./results\", severities=[0,1,2,3,4,5], \n",
    "                               batch_size=256, temperature=1.0, max_batch_num=200, n_label=1, threshold=0.2, WINDOW_SIZE=100, ETA_MAX=1.0):\n",
    "    trajectories = {\n",
    "        \"labeled_true\": [],\n",
    "        \"labeled_pred\": [],\n",
    "        \"unlabeled_pred\": [],\n",
    "        \"unlabeled_true\": [],\n",
    "        \"unlabeled_unsup\": []\n",
    "    }\n",
    "    eta_seq = []\n",
    "    labeled_losses_hist = []\n",
    "    agent_labeled_losses_hist = []\n",
    "    agent_unlabeled_losses_hist = []\n",
    "    for sev in severities:\n",
    "        path = os.path.join(save_dir, f\"{dataset_name}/sev{sev}.pt\")\n",
    "        if not os.path.exists(path):\n",
    "            print(f\"Warning: missing severity {sev}\")\n",
    "            continue\n",
    "\n",
    "        data = torch.load(path)\n",
    "        logits, labels = data[\"logits\"], data[\"labels\"]\n",
    "        batches = split_into_batches(logits, labels, batch_size=batch_size)\n",
    "\n",
    "        # 保存每个 batch 的 Brier loss\n",
    "        l_true_list, l_pred_list = [], []\n",
    "        u_pred_list, u_true_list = [], []\n",
    "        u_unsup_list = []\n",
    "        for b_logits, b_labels in batches:\n",
    "            if len(l_true_list) >= max_batch_num:\n",
    "                break\n",
    "            if b_logits.size(0) <= n_label:\n",
    "                l_logits = b_logits\n",
    "                l_labels = b_labels\n",
    "                u_logits = torch.empty(0, b_logits.size(1))\n",
    "                u_labels = torch.empty(0, dtype=torch.long)\n",
    "            else:\n",
    "                l_logits, l_labels = b_logits[:n_label], b_labels[:n_label]\n",
    "                u_logits, u_labels = b_logits[n_label:], b_labels[n_label:]\n",
    "            l_pred_labels = l_logits.argmax(1)\n",
    "            u_pred_labels = u_logits.argmax(1) if u_logits.size(0) > 0 else torch.empty(0, dtype=torch.long)\n",
    "            l_true_list.append(softmax_brier_loss(l_logits, l_labels, temperature).item())\n",
    "            l_pred_list.append(softmax_brier_loss(l_logits, l_pred_labels, temperature).item())\n",
    "            if u_logits.size(0) > 0:\n",
    "                u_pred_list.append(softmax_brier_loss(u_logits, u_pred_labels, temperature).item())\n",
    "                u_true_list.append(softmax_brier_loss(u_logits, u_labels , temperature).item())\n",
    "            probs_data = F.softmax(temperature_scaling(b_logits, temperature), dim=1)\n",
    "            probs_error = 1. - torch.max(probs_data, dim=1)[0]\n",
    "            probs_error_temp = torch.mean((probs_error > threshold).float()) \n",
    "            u_unsup_list.append(probs_error_temp)\n",
    "            u_t = softmax_brier_loss_vector(l_logits, l_labels, temperature)\n",
    "            u_tilde_l_t = softmax_brier_loss_vector(l_logits, l_pred_labels, temperature)\n",
    "            u_tilde_u_t = softmax_brier_loss_vector(u_logits, u_pred_labels, temperature)\n",
    "            # ---- compute eta_t using ONLY history ----\n",
    "            if len(labeled_losses_hist) > 5:\n",
    "                eta_t = compute_eta_t(labeled_losses_hist, agent_labeled_losses_hist, agent_unlabeled_losses_hist, ETA_MAX)\n",
    "            else:\n",
    "                eta_t = ETA_MAX / 2   \n",
    "            eta_seq.append(eta_t)\n",
    "            # ---- update history ----\n",
    "            labeled_losses_hist.append(u_t)\n",
    "            agent_labeled_losses_hist.append(u_tilde_l_t)\n",
    "            agent_unlabeled_losses_hist.append(u_tilde_u_t)\n",
    "            # ---- sliding window ----\n",
    "            if len(labeled_losses_hist) > WINDOW_SIZE:\n",
    "                labeled_losses_hist.pop(0)\n",
    "                agent_labeled_losses_hist.pop(0)\n",
    "                agent_unlabeled_losses_hist.pop(0)\n",
    "        trajectories[\"labeled_true\"].append(torch.tensor(l_true_list))\n",
    "        trajectories[\"labeled_pred\"].append(torch.tensor(l_pred_list))\n",
    "        trajectories[\"unlabeled_pred\"].append(torch.tensor(u_pred_list))\n",
    "        trajectories[\"unlabeled_true\"].append(torch.tensor(u_true_list))\n",
    "        trajectories[\"unlabeled_unsup\"].append(torch.tensor(u_unsup_list))\n",
    "    for key in trajectories.keys():\n",
    "        trajectories[key] = torch.cat(trajectories[key], dim=0)\n",
    "    return trajectories, eta_seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dee55b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def compute_bounds_cmeb(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha, t_opt_ratio):\n",
    "    # --- prepare proxy L1 (as np array) ---\n",
    "    L1 = L1_proxy.numpy() if hasattr(L1_proxy, \"numpy\") else np.array(L1_proxy)\n",
    "    # ====================== conjmix ======================\n",
    "    T_OPT = int(len(traj_sup_true) * t_opt_ratio)\n",
    "    v_opt_sup = compute_v_opt(traj_sup_true, T_OPT)\n",
    "    v_opt_unsup = compute_v_opt(traj_unsup_proxy, T_OPT)\n",
    "    L_sup_conjmix, _ = conjmix_empbern_twosided_cs(x=traj_sup_true, alpha=alpha * 2, v_opt=v_opt_sup)\n",
    "    # unsupervised\n",
    "    L_unsup_conjmix, _ = conjmix_empbern_twosided_cs(x=traj_unsup_proxy, alpha=alpha * 2, v_opt=v_opt_unsup)\n",
    "    L_unsup_conjmix = np.maximum(L_unsup_conjmix - L1, np.zeros_like(L1))\n",
    "    supervised_bounds = {\"CM-EB\": L_sup_conjmix}\n",
    "    unsupervised_bounds = {\"CM-EB\": L_unsup_conjmix}\n",
    "    return supervised_bounds, unsupervised_bounds\n",
    "\n",
    "def compute_bounds_pmeb_betting(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha):\n",
    "    # --- prepare proxy L1 (as np array) ---\n",
    "    L1 = L1_proxy.numpy() if hasattr(L1_proxy, \"numpy\") else np.array(L1_proxy)\n",
    "    # ====================== predmix ======================\n",
    "    # supervised\n",
    "    L_sup_predmix, _ = predmix_empbern_twosided_cs(x=traj_sup_true, alpha=alpha)\n",
    "    # unsupervised\n",
    "    L_unsup_predmix, _ = predmix_empbern_twosided_cs(x=traj_unsup_proxy, alpha=alpha)\n",
    "    L_unsup_predmix = np.maximum(L_unsup_predmix - L1, np.zeros_like(L1))\n",
    "\n",
    "    # ====================== betting ======================\n",
    "    L_sup_betting, _ = betting_cs(x=traj_sup_true, alpha=alpha)\n",
    "\n",
    "    L_unsup_betting, _ = betting_cs(x=traj_unsup_proxy, alpha=alpha)\n",
    "    L_unsup_betting = np.maximum(L_unsup_betting - L1, np.zeros_like(L_unsup_betting))\n",
    "\n",
    "    supervised_bounds = {\"PM-EB\": L_sup_predmix,\"Betting\": L_sup_betting,}\n",
    "    unsupervised_bounds = {\"PM-EB\": L_unsup_predmix,\"Betting\": L_unsup_betting,}\n",
    "    return supervised_bounds, unsupervised_bounds\n",
    "\n",
    "\n",
    "def compute_bounds_cmeb_ppi(traj_ppi_any, alpha, t_opt_ratio, eta_max):\n",
    "    a = - eta_max\n",
    "    b = 1 + eta_max\n",
    "    traj_ppi_any = (traj_ppi_any  - a) / (b - a)\n",
    "    T_OPT = int(len(traj_ppi_any) * t_opt_ratio)\n",
    "    v_opt_ppi = compute_v_opt(traj_ppi_any, T_OPT)\n",
    "    # v_opt_ppi = T_OPT * 0.5\n",
    "    L_ppi_conjmix, _ = conjmix_empbern_twosided_cs(x=traj_ppi_any, alpha=alpha * 2, v_opt=v_opt_ppi)\n",
    "    L_ppi_conjmix = L_ppi_conjmix * (b - a) + a\n",
    "    ppi_bounds = {\"CM-EB\": L_ppi_conjmix}\n",
    "    return ppi_bounds\n",
    "\n",
    "def compute_bounds_pmeb_betting_ppi(traj_ppi_any, alpha, eta_max):\n",
    "    a = - eta_max\n",
    "    b = 1 + eta_max\n",
    "    traj_ppi_any = (traj_ppi_any  - a) / (b - a)\n",
    "    L_ppi_predmix, _ = predmix_empbern_twosided_cs(x=traj_ppi_any, alpha=alpha)\n",
    "    L_ppi_predmix = L_ppi_predmix * (b - a) + a\n",
    "    L_ppi_betting, _ = betting_cs(x=traj_ppi_any, alpha=alpha)\n",
    "    L_ppi_betting = L_ppi_betting * (b - a) + a\n",
    "    ppi_bounds = {\"PM-EB\": L_ppi_predmix, \"Betting\": L_ppi_betting}\n",
    "    return ppi_bounds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cce50bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_EXPERIMENTS = 10                 \n",
    "ALPHA = 0.2                          \n",
    "THRESH = 0.15     \n",
    "T_OPT_RATIO = 0.25\n",
    "VERBOSE = True\n",
    "ETA_MAX =  1.0\n",
    "WINDOW_SIZE = 60\n",
    "eta_fixed = 1.0\n",
    "# np.random.seed(1)\n",
    "severities = [0, 1, 2]  \n",
    "temperature = 5.\n",
    "max_batch_num = 300\n",
    "n_label = 1\n",
    "best_q_hat = torch.tensor(best_q_hat)\n",
    "batch_size = 16\n",
    "set_seed(2)\n",
    "methods = [\"CM-EB\"]\n",
    "results = {m: {\"supervised\": [], \"unsupervised\": [], \"ppi\": [], \"adaptive_ppi\": [], \"ideal_ppi\": []} for m in methods}\n",
    "results_box = {m: {\"supervised\": [], \"unsupervised\": [], \"ppi\": [], \"adaptive_ppi\": [], \"ideal_ppi\": []} for m in methods}\n",
    "all_traj_dicts = []\n",
    "for exp_id in range(N_EXPERIMENTS):\n",
    "    traj_dict, eta_seq = compute_brier_trajectories(\n",
    "        save_dir=results_dir,\n",
    "        severities=severities,\n",
    "        batch_size=batch_size,\n",
    "        temperature=temperature,\n",
    "        max_batch_num = max_batch_num,\n",
    "        n_label = n_label,\n",
    "        threshold=best_q_hat,\n",
    "        WINDOW_SIZE=WINDOW_SIZE,\n",
    "        ETA_MAX=ETA_MAX)\n",
    "    all_traj_dicts.append(traj_dict)\n",
    "    risk_traj_supervised_true   =  traj_dict[\"labeled_true\"].numpy()\n",
    "    risk_traj_supervised_pred   =  traj_dict[\"labeled_pred\"].numpy()\n",
    "    risk_traj_unsupervised_pred =  traj_dict[\"unlabeled_pred\"].numpy()\n",
    "    risk_traj_ppi_ideal =  traj_dict[\"unlabeled_true\"].numpy()\n",
    "    risk_traj_unsupervised_proxy =  traj_dict[\"unlabeled_unsup\"].numpy()\n",
    "\n",
    "    traj_sup_true = risk_traj_supervised_true\n",
    "    traj_ppi_ideal = risk_traj_ppi_ideal\n",
    "    traj_unsup_proxy = risk_traj_unsupervised_proxy \n",
    "    traj_ppi_pred = eta_fixed * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_fixed * risk_traj_supervised_pred\n",
    "    traj_ppi_pred_adaptive = eta_seq * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_seq * risk_traj_supervised_pred\n",
    "    L1_proxy = L1 \n",
    "    if VERBOSE:\n",
    "        print(f\"=== Experiment {exp_id+1}/{N_EXPERIMENTS} ===\")\n",
    "    try:\n",
    "        sup_bounds, unsup_bounds = compute_bounds_cmeb(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha=ALPHA, t_opt_ratio=T_OPT_RATIO)\n",
    "        ppi_bounds =               compute_bounds_cmeb_ppi(traj_ppi_pred, alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "        adaptive_ppi_bounds =      compute_bounds_cmeb_ppi(traj_ppi_pred_adaptive,  alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=ETA_MAX)\n",
    "        ideal_ppi_bounds =         compute_bounds_cmeb_ppi(traj_ppi_ideal,  alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "    except Exception as e:\n",
    "        print(f\"[Warning] Skipping exp {exp_id}: compute_bounds() failed — {e}\")\n",
    "        continue\n",
    "    for m in methods:\n",
    "        results[m][\"supervised\"].append(sup_bounds[m])\n",
    "        results[m][\"unsupervised\"].append(unsup_bounds[m])\n",
    "        results[m][\"ppi\"].append(ppi_bounds[m])\n",
    "        results[m][\"ideal_ppi\"].append(ideal_ppi_bounds[m])\n",
    "        results[m][\"adaptive_ppi\"].append(adaptive_ppi_bounds[m])\n",
    "    for m in methods:\n",
    "        sup_arr = np.asarray(sup_bounds.get(m, []))\n",
    "        unsup_arr = np.asarray(unsup_bounds.get(m, []))\n",
    "        adaptive_ppi_arr = np.asarray(adaptive_ppi_bounds.get(m, []))\n",
    "        ppi_arr = np.asarray(ppi_bounds.get(m, []))\n",
    "        ideal_ppi_arr = np.asarray(ideal_ppi_bounds.get(m, []))\n",
    "\n",
    "        assert len(sup_arr) == len(ppi_arr) == len(ideal_ppi_arr) == len(unsup_arr) == len(traj_sup_true)\n",
    "        # find threshold crossings\n",
    "        t_sup = int(np.argmax(sup_arr > THRESH) + 1) if np.any(sup_arr > THRESH) else np.nan\n",
    "        t_unsup = int(np.argmax(unsup_arr > THRESH) + 1) if np.any(unsup_arr > THRESH) else np.nan\n",
    "        t_ppi = int(np.argmax(ppi_arr > THRESH) + 1) if np.any(ppi_arr > THRESH) else np.nan\n",
    "        t_adaptive = int(np.argmax(adaptive_ppi_arr > THRESH) + 1) if np.any(adaptive_ppi_arr > THRESH) else np.nan\n",
    "        t_ideal = int(np.argmax(ideal_ppi_arr > THRESH) + 1) if np.any(ideal_ppi_arr > THRESH) else np.nan\n",
    "\n",
    "        results_box[m][\"supervised\"].append(t_sup)\n",
    "        results_box[m][\"unsupervised\"].append(t_unsup)\n",
    "        results_box[m][\"ppi\"].append(t_ppi)\n",
    "        results_box[m][\"adaptive_ppi\"].append(t_adaptive)\n",
    "        results_box[m][\"ideal_ppi\"].append(t_ideal)\n",
    "colors = {\"supervised\": \"#1f77b4\", \"unsupervised\": \"#7f7f7f\",  \"ppi\": \"#ff7f0e\",  \"adaptive_ppi\": \"#2ca02c\", \"ideal_ppi\": \"#9467bd\" }\n",
    "labels = {\"supervised\": \"SRM\", \"unsupervised\": \"URM\", \"adaptive_ppi\": r\"PPRM\", \"ppi\": \"PPRM\", \"ideal_ppi\": \"Ideal PPRM\"}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6.5))\n",
    "m = methods[0]\n",
    "if True:\n",
    "    for key in [\"supervised\", \"adaptive_ppi\", \"ideal_ppi\"]:\n",
    "        arrs = np.array(results[m][key])\n",
    "        if len(arrs) == 0:\n",
    "            continue\n",
    "        mean_curve = np.nanmean(arrs, axis=0)\n",
    "        std_curve = np.nanstd(arrs, axis=0)\n",
    "        steps = np.arange(len(mean_curve))\n",
    "        ax.plot(steps, mean_curve, label=labels[key], color=colors[key])\n",
    "        ax.fill_between(\n",
    "            steps,\n",
    "            mean_curve - std_curve,\n",
    "            mean_curve + std_curve,\n",
    "            color=colors[key],\n",
    "            alpha=0.2)\n",
    "    try:\n",
    "        traj_ppi_ideal_all = np.array([\n",
    "            running_average_cumulative(traj[\"unlabeled_true\"].numpy())\n",
    "            for traj in all_traj_dicts  ])\n",
    "        traj_ppi_ideal_mean = np.nanmean(traj_ppi_ideal_all, axis=0)\n",
    "        traj_ppi_ideal_std = np.nanstd(traj_ppi_ideal_all, axis=0)\n",
    "        steps = np.arange(len(traj_ppi_ideal_mean))\n",
    "        ax.plot(\n",
    "            steps,\n",
    "            traj_ppi_ideal_mean,\n",
    "            linestyle=\"--\",\n",
    "            color=\"#9467bd\",\n",
    "            linewidth=2,\n",
    "            label=\"Running Risk\")\n",
    "        ax.fill_between(\n",
    "            steps,\n",
    "            traj_ppi_ideal_mean - traj_ppi_ideal_std,\n",
    "            traj_ppi_ideal_mean + traj_ppi_ideal_std,\n",
    "            color=\"#9467bd\",\n",
    "            alpha=0.15)\n",
    "    except Exception as e:\n",
    "        print(f\"[Warning] Failed to plot averaged PPI_Ideal risk trajectory: {e}\")\n",
    "    font_size = 22\n",
    "    ax.axhline(y=THRESH, color=\"red\", linestyle=\"--\", linewidth=2, label=f\"Risk Threshold\")\n",
    "    ax.set_xlabel(r\"Time Step $t$\", fontsize=font_size)\n",
    "    ax.grid(True, linestyle=\":\", linewidth=2, alpha=0.4)\n",
    "    ax.set_ylabel(\"Running risk lower bound\", fontsize=font_size+1)\n",
    "    ax.legend(fontsize=font_size, loc='upper left', ncol=1)\n",
    "    ax.tick_params(axis='both', labelsize=font_size+1)  \n",
    "    plt.ylim(0.06, 0.205)\n",
    "    plt.tight_layout()\n",
    "    save_dir = \"Simulations/Results_Image\"\n",
    "    os.makedirs(save_dir, exist_ok=True)  #\n",
    "\n",
    "    save_path_pdf = os.path.join(save_dir, f\"sim_fig_imgc_increase_sev{severities[0:]}_lowerbound_Brier_self.pdf\")\n",
    "    plt.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "font_size = 22.5\n",
    "if True:\n",
    "    plot_params = {\n",
    "        \"title_fontsize\": font_size,\n",
    "        \"xlabel_fontsize\": font_size,\n",
    "        \"ylabel_fontsize\": font_size,\n",
    "        \"xtick_fontsize\": font_size,\n",
    "        \"ytick_fontsize\": font_size,\n",
    "        \"legend_fontsize\": font_size,\n",
    "        \"suptitle_fontsize\": font_size,\n",
    "        \"title_fontweight\": \"bold\",\n",
    "        \"label_fontweight\": \"normal\"}\n",
    "\n",
    "    colors = [ \"#1f77b4\",\"#2ca02c\",  \"#9467bd\" ]\n",
    "    title_map = {\"srm\": \"SRM\", \"pprm\": \"PPRM\", \"ideal pprm\": \"Ideal PPRM\", \"urm\": \"URM\"}\n",
    "    fig, ax = plt.subplots(figsize=(6, 6.5))\n",
    "    m = methods[0]\n",
    "    data_for_box = []\n",
    "    labels_for_box = []\n",
    "\n",
    "    # supervised\n",
    "    sup_list = np.array(results_box[m][\"supervised\"], dtype=float)\n",
    "    sup_list = sup_list[~np.isnan(sup_list)]\n",
    "    if sup_list.size == 0:\n",
    "        sup_list = np.array([np.nan])\n",
    "    data_for_box.append(sup_list)\n",
    "    labels_for_box.append(\"SRM\")\n",
    "\n",
    "    # ppi\n",
    "    ppi_list = np.array(results_box[m][\"adaptive_ppi\"], dtype=float)\n",
    "    ppi_list = ppi_list[~np.isnan(ppi_list)]\n",
    "    if ppi_list.size == 0:\n",
    "        ppi_list = np.array([np.nan])\n",
    "    data_for_box.append(ppi_list)\n",
    "    labels_for_box.append(\"PPRM\")\n",
    "\n",
    "    # ideal ppi\n",
    "    ideal_list = np.array(results_box[m][\"ideal_ppi\"], dtype=float)\n",
    "    ideal_list = ideal_list[~np.isnan(ideal_list)]\n",
    "    if ideal_list.size == 0:\n",
    "        ideal_list = np.array([np.nan])\n",
    "    data_for_box.append(ideal_list)\n",
    "    labels_for_box.append(\"Ideal PPRM\")\n",
    "    bp = ax.boxplot(\n",
    "        data_for_box,\n",
    "        labels=labels_for_box,\n",
    "        showmeans=True,\n",
    "        patch_artist=True,\n",
    "        boxprops=dict(linewidth=2),\n",
    "        whiskerprops=dict(linewidth=2),\n",
    "        capprops=dict(linewidth=2),\n",
    "        medianprops=dict(linewidth=2),\n",
    "        showfliers=False)\n",
    "    for patch, color in zip(bp['boxes'], colors):\n",
    "        patch.set_facecolor(color)\n",
    "        patch.set_alpha(0.6)\n",
    "    for mean in bp['means']:\n",
    "        mean.set_markerfacecolor(\"red\")\n",
    "        mean.set_markeredgecolor(\"black\")\n",
    "    ax.tick_params(axis=\"x\", labelsize=plot_params[\"xtick_fontsize\"])\n",
    "    ax.tick_params(axis=\"y\", labelsize=plot_params[\"ytick_fontsize\"])\n",
    "    ax.grid(True, linestyle=\":\", linewidth=2, alpha=0.4)\n",
    "\n",
    "    ax.set_ylabel(\"Average time to alarm\",\n",
    "                    fontsize=plot_params[\"ylabel_fontsize\"],\n",
    "                    fontweight=plot_params[\"label_fontweight\"])\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "    save_dir = \"Simulations/Results_Image\"\n",
    "    os.makedirs(save_dir, exist_ok=True)  \n",
    "\n",
    "\n",
    "    save_path_pdf = os.path.join(save_dir, f\"sim_fig_imgc_increase_sev{severities[0:]}_alarm_time_brier_self.pdf\")\n",
    "    plt.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "    plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
