{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "from scipy.stats import norm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pathlib\n",
    "import os\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "from bin_cp.helpers.storage import load_smooth_prediction\n",
    "from bin_cp.helpers.tensor import get_smooth_scores, get_cal_mask, quantization_pdf, bound_tensor\n",
    "from bin_cp.robust.confidence import bernstein_bound, dkw_cdf\n",
    "from bin_cp.robust.confidence import clopper_pearson_lower, clopper_pearson_upper\n",
    "from bin_cp.robust.bounds import mean_bounds_l2, CDF_bounds_l2\n",
    "\n",
    "from bin_cp.cp.core import ConformalClassifier as CP\n",
    "from bin_cp.cp.scores import APSScore, TPSScore, LogitScore\n",
    "\n",
    "from bin_cp.methods.robust_cp import RobustCP, VanillaSmoothCP\n",
    "from bin_cp.methods.cas import CAS\n",
    "from bin_cp.methods.bincp import BinCP\n",
    "from bin_cp.methods.rcp_one import RCP1, RCP1Plus\n",
    "# from qrcp.methods.binary import QRCPThresholds\n",
    "import time\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Parameters of the sanity check below: nk is the number of trials, nn enters the alpha correction.\n",
    "nn = 500\n",
    "nk = 10000\n",
    "\n",
    "from statsmodels.stats.proportion import proportion_confint\n",
    "# Clopper-Pearson (method=\"beta\") interval for zero successes out of nk trials.\n",
    "# alpha is divided by (nn + 2) -- presumably a Bonferroni-style correction; TODO confirm.\n",
    "proportion_confint(\n",
    "        0.0 * nk, nk, alpha=0.001/(nn + 2), method=\"beta\")\n",
    "\n",
    "# IMPORTANT: the lower bound of the displayed interval must be (close to) 0.0;\n",
    "# otherwise there is a bug due to the installed version of scipy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#region primary configs of the experiment\n",
    "\n",
    "# Base directory for results; a later cell appends dataset_name to it and creates the folder.\n",
    "output_dir = \"./results/\"\n",
    "\n",
    "# Dataset / smoothing settings. Except for n_classes and n_trial_samples, these are forwarded to\n",
    "# load_smooth_prediction in the loading cell (n_classes is re-assigned there).\n",
    "dataset_name = \"cifar10\"\n",
    "model_sigma = 0.5\n",
    "n_classes=10\n",
    "n_datapoints = 2048\n",
    "smoothing_sigma = 0.5\n",
    "n_samples = 10000\n",
    "n_trial_samples = 10000\n",
    "\n",
    "# Not read by the cells shown directly below -- presumably consumed by the later experiment loops.\n",
    "score_method = \"TPS\"\n",
    "calibration_budget = 0.5\n",
    "n_iterations = 100\n",
    "\n",
    "# confidence is passed to ProxyRobustnessCertificate via certificate_setup[\"confidence\"].\n",
    "confidence = 0.999\n",
    "coverage_range = [0.85, 0.9, 0.95]\n",
    "# coverage_range = [0.9]\n",
    "r_range = [0.06, 0.12, 0.18, 0.25, 0.37, 0.5, 0.75]\n",
    "# r_range = [0.12,  0.25, 0.5, ]\n",
    "\n",
    "# Single-run settings: marginal_coverage -> marginal_guarantee, r -> radius (also selects the\n",
    "# perturbed logits to load), n_0 -> certificate_setup[\"n_0\"].\n",
    "marginal_coverage = 0.8\n",
    "r = 0.25\n",
    "n_0 = 50\n",
    "#endregion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build <output_dir>/<dataset_name> and create it. The suffix is appended only when it is not\n",
    "# already the last path component, so re-running this cell does not nest the folder again\n",
    "# (e.g. results/cifar10/cifar10).\n",
    "output_dir = pathlib.Path(output_dir)\n",
    "if output_dir.name != dataset_name:\n",
    "    output_dir = output_dir / dataset_name\n",
    "output_dir.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directories forwarded to load_smooth_prediction (models_dir, dataset_dir, logits_dir).\n",
    "# They are left empty here; pathlib.Path(\"\") refers to the current working directory.\n",
    "models_dir = pathlib.Path(\"\")\n",
    "dataset_dir = pathlib.Path(\"\")\n",
    "logits_dir = pathlib.Path(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#region loading smooth logit predictions\n",
    "# Clamp the trial-sample count to the number of samples, as the message below announces.\n",
    "if n_samples < n_trial_samples:\n",
    "    n_trial_samples = n_samples\n",
    "    print(f\"Number of trial samples is set to {n_trial_samples} as the number of samples is smaller than the requested number of trial samples.\")\n",
    "\n",
    "# NOTE: the CUDA auto-detection is immediately overridden, so the notebook runs on the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "# Clean smoothed predictions. If the file is missing only a hint is printed, so\n",
    "# smooth_prediction stays undefined and the following cells cannot run.\n",
    "try:\n",
    "    smooth_prediction = load_smooth_prediction(dataset_name=dataset_name,\n",
    "        model_sigma=model_sigma,\n",
    "        n_datapoints=n_datapoints,\n",
    "        smoothing_sigma=smoothing_sigma,\n",
    "        n_samples=n_samples,\n",
    "        models_dir=models_dir,\n",
    "        dataset_dir=dataset_dir,\n",
    "        logits_dir=logits_dir,)\n",
    "    n_classes = 10 if dataset_name == \"cifar10\" else None\n",
    "except FileNotFoundError as e:\n",
    "    print(\"Smooth predictions not found, you can generate them using bin/smooth_logits_clean.py\")\n",
    "    print(\"Full description of the error: \", e)\n",
    "\n",
    "# Smoothed predictions for the perturbed inputs (attack \"pgd_rs\" with radius r).\n",
    "try:\n",
    "    smooth_prediction_pert = load_smooth_prediction(dataset_name=dataset_name,\n",
    "        model_sigma=model_sigma,\n",
    "        n_datapoints=n_datapoints,\n",
    "        smoothing_sigma=smoothing_sigma,\n",
    "        models_dir=models_dir,\n",
    "        dataset_dir=dataset_dir,\n",
    "        logits_dir=logits_dir,\n",
    "        n_samples=n_samples, r=r, attack=\"pgd_rs\")\n",
    "\n",
    "except FileNotFoundError as e:\n",
    "    print(\"Smooth predictions for perturbation not found, you can generate them using bin/smooth_logits_pert.py\")\n",
    "    print(\"Full description of the error: \", e)\n",
    "\n",
    "#endregion\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Class for Proxy Certificate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ProxyRobustnessCertificate(object):\n",
    "    def __init__(self, r=0, smoothing_sigma=0, marginal_guarantee=0.9, proxy_certificate=None, risk_setup=\"marginal\", proxy_setup=\"vanilla\", **kwargs):\n",
    "        self.r = r\n",
    "        self.smoothing_sigma = smoothing_sigma\n",
    "        self._marginal_guarantee = marginal_guarantee\n",
    "        self.proxy_setup = proxy_setup\n",
    "\n",
    "        if self.proxy_setup == \"vanilla\":\n",
    "            self.marginal_guarantee = deepcopy(self._marginal_guarantee)\n",
    "            \n",
    "        else:\n",
    "            self.marginal_guarantee = self.adjust_robust_marginal_guarantee()\n",
    "\n",
    "        if \"certificate_setup\" in kwargs:\n",
    "            self.certificate_setup = kwargs[\"certificate_setup\"]\n",
    "            self.certificate_confidence = self.certificate_setup.get(\"confidence\", 0.999)\n",
    "            self.certificate_n_0 = self.certificate_setup.get(\"n_0\", 100)\n",
    "        \n",
    "        self.proxy_certificate = proxy_certificate  # proxy certificate receives two arguments: logits, lambd and returns two elements: robust_labels, y_top\n",
    "        self.risk_setup = risk_setup\n",
    "        # for logits samples are always along the dim=1\n",
    "\n",
    "        self._acceptable_risk = None\n",
    "        self._acceptable_lambd = None\n",
    "\n",
    "    def adjust_robust_marginal_guarantee(self):\n",
    "        if self.r == 0:\n",
    "            adjusted_guarantee = deepcopy(self._marginal_guarantee)\n",
    "        else:\n",
    "            adjusted_guarantee = norm.cdf(norm.ppf(self._marginal_guarantee, scale=self.smoothing_sigma) + self.r, scale=self.smoothing_sigma)\n",
    "        return adjusted_guarantee\n",
    "\n",
    "    def hard_certify(self, logits, y_top=None, dim=1, r_coef=1):\n",
    "        precomputed = (y_top is not None)\n",
    "        if y_top is None:\n",
    "            y_pred = torch.argmax(logits[:, :self.certificate_n_0, :], dim=dim)\n",
    "            votes, y_top = torch.nn.functional.one_hot(y_pred, num_classes=logits.shape[-1]).float().mean(1).max(-1)\n",
    "\n",
    "        votes_intensive = (torch.argmax(logits[:, self.certificate_n_0 if not precomputed else 0:, :], dim=-1) == y_top.reshape(-1, 1)).sum(-1)\n",
    "        probs = clopper_pearson_lower(votes_intensive.cpu(), logits.shape[1] - (0 if precomputed else self.certificate_n_0), alpha=1 - self.marginal_guarantee)\n",
    "        prob_lower = norm.cdf(norm.ppf(probs, scale=self.smoothing_sigma) - r_coef * self.r, scale=self.smoothing_sigma)\n",
    "        robust_labels = torch.tensor(prob_lower >= 0.5)\n",
    "        return robust_labels, y_top\n",
    "    \n",
    "    def tune_proxy_certificate(self, logits, dim=2, verbose=True):\n",
    "        certified, y_top_certified = self.hard_certify(logits, r_coef= 1 if self.proxy_setup == \"vanilla\" else 2, dim=dim)\n",
    "        certified = certified.to(device)\n",
    "        y_top_certified = y_top_certified.to(device)\n",
    "\n",
    "        lambd_min = 0.0\n",
    "        lambd_max = 1.0\n",
    "        while lambd_max - lambd_min > 1e-7:\n",
    "            lambd_mid = (lambd_min + lambd_max) / 2\n",
    "            certified_estim, y_top_estim = self.proxy_certificate(logits, lambd_mid, y_top=y_top_certified)\n",
    "\n",
    "            y_top_acc = (y_top_estim == y_top_certified).float().mean().item()\n",
    "            robust_acc = ((certified_estim == certified).float().mean().item())\n",
    "            \n",
    "            if self.risk_setup == \"marginal\":\n",
    "                # risk = ((~certified) & (certified_estim)) | (y_top_estim != y_top_certified)\n",
    "                risk = ((~certified) & (certified_estim)) # | (y_top_estim != y_top_certified) TODO: RETURN\n",
    "                risk = (risk.float().sum().item() + 1) / (logits.shape[0] + 1)\n",
    "            else:\n",
    "                risk = (y_top_estim != y_top_certified) | (certified_estim)\n",
    "                risk = risk[~certified]\n",
    "                risk = (risk.float().sum().item() + 1) / ((~certified).sum().item() + 1)\n",
    "\n",
    "            if self._acceptable_risk is None and risk <= (1 - self.marginal_guarantee):\n",
    "                self._acceptable_risk = risk\n",
    "                self._acceptable_lambd = lambd_mid\n",
    "            elif self._acceptable_risk is not None and risk >= self._acceptable_risk and risk <= (1 - self.marginal_guarantee):\n",
    "                self._acceptable_risk = risk\n",
    "                self._acceptable_lambd = lambd_mid\n",
    "\n",
    "            if risk > (1 - self.marginal_guarantee):\n",
    "                lambd_min = lambd_mid\n",
    "            else:\n",
    "                lambd_max = lambd_mid\n",
    "\n",
    "            if verbose:    \n",
    "                print(f\"lambda: {lambd_mid}, robust accuracy: {robust_acc:.4f}, y_top accuracy: {y_top_acc:.4f}, risk: {risk:.4f}\")\n",
    "        if verbose:\n",
    "            if self._acceptable_lambd is not None:\n",
    "                print(f\"Found acceptable lambda: {self._acceptable_lambd} with risk: {self._acceptable_risk}\")\n",
    "            else: \n",
    "                print(\"No acceptable lambda found.\")\n",
    "    \n",
    "    def certify_proxy(self, logits, y_top=None, dim=2):\n",
    "        if self._acceptable_lambd is None:\n",
    "            print(Warning(\"Proxy certificate not tuned yet. Please run tune_proxy_certificate first.\"))\n",
    "            result, result_label = torch.tensor([False]*logits.shape[0]), torch.tensor([-1]*logits.shape[0])\n",
    "        else:\n",
    "            result, result_label = self.proxy_certificate(logits, self._acceptable_lambd, y_top=y_top)\n",
    "        return result, result_label\n",
    "\n",
    "def single_sample_proxy_certificate(logits, lambd, y_top=None, vectorized=False):\n",
    "    if logits.ndim == 3 and not vectorized:\n",
    "        computable_logits = logits[:, 0, :] \n",
    "    else:\n",
    "        computable_logits = logits\n",
    "\n",
    "    softmaxes = F.softmax(computable_logits, dim=-1)\n",
    "    if y_top is not None:\n",
    "        if not vectorized:\n",
    "            conf = softmaxes[torch.arange(softmaxes.shape[0]), y_top]\n",
    "        else:\n",
    "            conf = torch.stack([softmaxes[torch.arange(softmaxes.shape[0]), i, y_top] for i in range(softmaxes.shape[1])]).permute(1, 0)\n",
    "    else:\n",
    "        if vectorized:\n",
    "            raise NotImplementedError(\"Vectorized proxy certificate requires y_top to be provided.\")\n",
    "        conf, y_top = softmaxes.max(dim=-1)\n",
    "\n",
    "    votes = (conf >= lambd)\n",
    "    return votes, y_top\n",
    "\n",
    "def estimate_y_top(logits, n_0, dim=2):\n",
    "    if logits.ndim == 3:\n",
    "        computable_logits = logits[:, :n_0, :] \n",
    "    else:\n",
    "        computable_logits = logits\n",
    "\n",
    "    y_pred = torch.argmax(computable_logits, dim=dim)\n",
    "    votes, y_top = torch.nn.functional.one_hot(y_pred, num_classes=logits.shape[-1]).float().mean(1).max(-1)\n",
    "    return y_top"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of robustly classified points (clean): 1313 out of 2048, ratio: 0.6411\n",
      "Number of robustly classified points (perturbed): 1211 out of 2048, ratio: 0.5913\n"
     ]
    }
   ],
   "source": [
    "# Build the certificate object from the config cell (\"vanilla\" proxy setup, marginal risk).\n",
    "proxy_certificate = ProxyRobustnessCertificate(\n",
    "    r=r,\n",
    "    smoothing_sigma=smoothing_sigma,\n",
    "    marginal_guarantee=marginal_coverage,\n",
    "    certificate_setup={\"confidence\": confidence, \"n_0\": n_0}, \n",
    "    proxy_certificate=single_sample_proxy_certificate,\n",
    "    proxy_setup=\"vanilla\", risk_setup=\"marginal\"\n",
    ")\n",
    "# Sampling-based certificate on the clean and on the perturbed smoothed logits; dim=2 is passed\n",
    "# so that the argmax inside hard_certify runs over the last (class) dimension.\n",
    "robust_labels, y_top = proxy_certificate.hard_certify(smooth_prediction.logits, r_coef=1, dim=2)\n",
    "robust_labels_pert, y_top_pert = proxy_certificate.hard_certify(smooth_prediction_pert.logits, r_coef=1, dim=2)\n",
    "n_robust = robust_labels.sum().item()\n",
    "n_robust_pert = robust_labels_pert.sum().item()\n",
    "# Ratios are taken w.r.t. the configured n_datapoints -- assumes the loaded logits contain\n",
    "# exactly that many points; TODO confirm.\n",
    "print(f\"Number of robustly classified points (clean): {n_robust} out of {n_datapoints}, ratio: {n_robust/n_datapoints:.4f}\")\n",
    "print(f\"Number of robustly classified points (perturbed): {n_robust_pert} out of {n_datapoints}, ratio: {n_robust_pert/n_datapoints:.4f}\")\n",
    "certified_ratio_clean = n_robust / n_datapoints\n",
    "certified_ratio_pert = n_robust_pert / n_datapoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the clean logits to the selected device (this replaces the attribute on the loaded object).\n",
    "# NOTE(review): smooth_prediction_pert.logits is not moved in this cell -- confirm that is intended.\n",
    "smooth_prediction.logits = smooth_prediction.logits.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      " nominal marginal guarantee: 0.6, adjusted robust marginal guarantee: 0.6\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6367 , marginal validity: 0.6367\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [00:01<00:11,  1.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9463, conditional validity: 0.5840 , marginal validity: 0.5840\n",
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6475 , marginal validity: 0.6475\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 2/10 [00:02<00:10,  1.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9434, conditional validity: 0.5986 , marginal validity: 0.5986\n",
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6641 , marginal validity: 0.6641\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 3/10 [00:04<00:09,  1.36s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9424, conditional validity: 0.6143 , marginal validity: 0.6143\n",
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6504 , marginal validity: 0.6504\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 4/10 [00:05<00:08,  1.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9395, conditional validity: 0.6035 , marginal validity: 0.6035\n",
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6484 , marginal validity: 0.6484\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 5/10 [00:06<00:06,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9375, conditional validity: 0.5957 , marginal validity: 0.5957\n",
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6602 , marginal validity: 0.6602\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 6/10 [00:08<00:05,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9561, conditional validity: 0.6064 , marginal validity: 0.6064\n",
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6816 , marginal validity: 0.6816\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 7/10 [00:09<00:04,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9521, conditional validity: 0.6377 , marginal validity: 0.6377\n",
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6582 , marginal validity: 0.6582\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 8/10 [00:11<00:02,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9473, conditional validity: 0.5996 , marginal validity: 0.5996\n",
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6396 , marginal validity: 0.6396\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 9/10 [00:12<00:01,  1.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9395, conditional validity: 0.5869 , marginal validity: 0.5869\n",
      "Certified ratio (proxy, clean): 1.0000, proxy label accuracy: 1.0000, conditional validity: 0.6514 , marginal validity: 0.6514\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:13<00:00,  1.40s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 1.0000, proxy label accuracy: 0.9385, conditional validity: 0.5859 , marginal validity: 0.5859\n",
      "\n",
      "\n",
      " nominal marginal guarantee: 0.6, adjusted robust marginal guarantee: 0.7743793197188857\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 0.6709, proxy label accuracy: 1.0000, conditional validity: 0.8887 , marginal validity: 0.8887\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [00:01<00:11,  1.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6357, proxy label accuracy: 0.9385, conditional validity: 0.8701 , marginal validity: 0.8701\n",
      "Certified ratio (proxy, clean): 0.6904, proxy label accuracy: 1.0000, conditional validity: 0.8789 , marginal validity: 0.8789\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 2/10 [00:02<00:11,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6328, proxy label accuracy: 0.9336, conditional validity: 0.8760 , marginal validity: 0.8760\n",
      "Certified ratio (proxy, clean): 0.6768, proxy label accuracy: 1.0000, conditional validity: 0.8877 , marginal validity: 0.8877\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 3/10 [00:04<00:09,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6328, proxy label accuracy: 0.9463, conditional validity: 0.8691 , marginal validity: 0.8691\n",
      "Certified ratio (proxy, clean): 0.6719, proxy label accuracy: 1.0000, conditional validity: 0.8984 , marginal validity: 0.8984\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 4/10 [00:05<00:08,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6221, proxy label accuracy: 0.9424, conditional validity: 0.8789 , marginal validity: 0.8789\n",
      "Certified ratio (proxy, clean): 0.6504, proxy label accuracy: 1.0000, conditional validity: 0.9014 , marginal validity: 0.9014\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 5/10 [00:06<00:07,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6348, proxy label accuracy: 0.9473, conditional validity: 0.8789 , marginal validity: 0.8789\n",
      "Certified ratio (proxy, clean): 0.6416, proxy label accuracy: 1.0000, conditional validity: 0.9082 , marginal validity: 0.9082\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 6/10 [00:08<00:05,  1.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6172, proxy label accuracy: 0.9385, conditional validity: 0.8779 , marginal validity: 0.8779\n",
      "Certified ratio (proxy, clean): 0.6396, proxy label accuracy: 1.0000, conditional validity: 0.9092 , marginal validity: 0.9092\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 7/10 [00:09<00:04,  1.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6172, proxy label accuracy: 0.9424, conditional validity: 0.8936 , marginal validity: 0.8936\n",
      "Certified ratio (proxy, clean): 0.6641, proxy label accuracy: 1.0000, conditional validity: 0.9053 , marginal validity: 0.9053\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 8/10 [00:11<00:02,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6123, proxy label accuracy: 0.9395, conditional validity: 0.8955 , marginal validity: 0.8955\n",
      "Certified ratio (proxy, clean): 0.6699, proxy label accuracy: 1.0000, conditional validity: 0.8926 , marginal validity: 0.8926\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 9/10 [00:12<00:01,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6367, proxy label accuracy: 0.9355, conditional validity: 0.8662 , marginal validity: 0.8662\n",
      "Certified ratio (proxy, clean): 0.7051, proxy label accuracy: 1.0000, conditional validity: 0.8535 , marginal validity: 0.8535\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:13<00:00,  1.39s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.6465, proxy label accuracy: 0.9326, conditional validity: 0.8486 , marginal validity: 0.8486\n",
      "\n",
      "\n",
      " nominal marginal guarantee: 0.7, adjusted robust marginal guarantee: 0.7\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 0.9326, proxy label accuracy: 1.0000, conditional validity: 0.7305 , marginal validity: 0.7305\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [00:01<00:12,  1.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.9102, proxy label accuracy: 0.9463, conditional validity: 0.6816 , marginal validity: 0.6816\n",
      "Certified ratio (proxy, clean): 0.9395, proxy label accuracy: 1.0000, conditional validity: 0.7090 , marginal validity: 0.7090\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 2/10 [00:02<00:11,  1.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.9160, proxy label accuracy: 0.9355, conditional validity: 0.6807 , marginal validity: 0.6807\n",
      "Certified ratio (proxy, clean): 0.9629, proxy label accuracy: 1.0000, conditional validity: 0.6631 , marginal validity: 0.6631\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 3/10 [00:04<00:09,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.9551, proxy label accuracy: 0.9453, conditional validity: 0.6201 , marginal validity: 0.6201\n",
      "Certified ratio (proxy, clean): 0.9307, proxy label accuracy: 1.0000, conditional validity: 0.7031 , marginal validity: 0.7031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 4/10 [00:05<00:08,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.9121, proxy label accuracy: 0.9346, conditional validity: 0.6709 , marginal validity: 0.6709\n",
      "Certified ratio (proxy, clean): 0.9131, proxy label accuracy: 1.0000, conditional validity: 0.7354 , marginal validity: 0.7354\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 5/10 [00:07<00:07,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.9004, proxy label accuracy: 0.9482, conditional validity: 0.6914 , marginal validity: 0.6914\n",
      "Certified ratio (proxy, clean): 0.9619, proxy label accuracy: 1.0000, conditional validity: 0.6689 , marginal validity: 0.6689\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 6/10 [00:08<00:05,  1.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.9502, proxy label accuracy: 0.9385, conditional validity: 0.6279 , marginal validity: 0.6279\n",
      "Certified ratio (proxy, clean): 0.9307, proxy label accuracy: 1.0000, conditional validity: 0.7178 , marginal validity: 0.7178\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 7/10 [00:09<00:04,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.9150, proxy label accuracy: 0.9434, conditional validity: 0.6709 , marginal validity: 0.6709\n",
      "Certified ratio (proxy, clean): 0.9297, proxy label accuracy: 1.0000, conditional validity: 0.7158 , marginal validity: 0.7158\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 8/10 [00:11<00:02,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.9131, proxy label accuracy: 0.9424, conditional validity: 0.6895 , marginal validity: 0.6895\n",
      "Certified ratio (proxy, clean): 0.9238, proxy label accuracy: 1.0000, conditional validity: 0.7178 , marginal validity: 0.7178\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 9/10 [00:12<00:01,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.9053, proxy label accuracy: 0.9443, conditional validity: 0.6875 , marginal validity: 0.6875\n",
      "Certified ratio (proxy, clean): 0.9238, proxy label accuracy: 1.0000, conditional validity: 0.7383 , marginal validity: 0.7383\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:14<00:00,  1.41s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.8955, proxy label accuracy: 0.9473, conditional validity: 0.7197 , marginal validity: 0.7197\n",
      "\n",
      "\n",
      " nominal marginal guarantee: 0.7, adjusted robust marginal guarantee: 0.8471769300473455\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 0.5781, proxy label accuracy: 1.0000, conditional validity: 0.9297 , marginal validity: 0.9297\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [00:01<00:12,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5420, proxy label accuracy: 0.9395, conditional validity: 0.9229 , marginal validity: 0.9229\n",
      "Certified ratio (proxy, clean): 0.5684, proxy label accuracy: 1.0000, conditional validity: 0.9375 , marginal validity: 0.9375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 2/10 [00:02<00:11,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5420, proxy label accuracy: 0.9453, conditional validity: 0.9199 , marginal validity: 0.9199\n",
      "Certified ratio (proxy, clean): 0.5732, proxy label accuracy: 1.0000, conditional validity: 0.9316 , marginal validity: 0.9316\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 3/10 [00:04<00:09,  1.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5645, proxy label accuracy: 0.9424, conditional validity: 0.9219 , marginal validity: 0.9219\n",
      "Certified ratio (proxy, clean): 0.5625, proxy label accuracy: 1.0000, conditional validity: 0.9385 , marginal validity: 0.9385\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 4/10 [00:05<00:08,  1.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5332, proxy label accuracy: 0.9443, conditional validity: 0.9326 , marginal validity: 0.9326\n",
      "Certified ratio (proxy, clean): 0.5557, proxy label accuracy: 1.0000, conditional validity: 0.9326 , marginal validity: 0.9326\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 5/10 [00:06<00:06,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5146, proxy label accuracy: 0.9443, conditional validity: 0.9307 , marginal validity: 0.9307\n",
      "Certified ratio (proxy, clean): 0.5732, proxy label accuracy: 1.0000, conditional validity: 0.9336 , marginal validity: 0.9336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 6/10 [00:08<00:05,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5371, proxy label accuracy: 0.9434, conditional validity: 0.9229 , marginal validity: 0.9229\n",
      "Certified ratio (proxy, clean): 0.5781, proxy label accuracy: 1.0000, conditional validity: 0.9180 , marginal validity: 0.9180\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 7/10 [00:09<00:04,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5459, proxy label accuracy: 0.9453, conditional validity: 0.9111 , marginal validity: 0.9111\n",
      "Certified ratio (proxy, clean): 0.5771, proxy label accuracy: 1.0000, conditional validity: 0.9346 , marginal validity: 0.9346\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 8/10 [00:11<00:02,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5449, proxy label accuracy: 0.9375, conditional validity: 0.9160 , marginal validity: 0.9160\n",
      "Certified ratio (proxy, clean): 0.5996, proxy label accuracy: 1.0000, conditional validity: 0.9170 , marginal validity: 0.9170\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 9/10 [00:12<00:01,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5713, proxy label accuracy: 0.9287, conditional validity: 0.9004 , marginal validity: 0.9004\n",
      "Certified ratio (proxy, clean): 0.5781, proxy label accuracy: 1.0000, conditional validity: 0.9219 , marginal validity: 0.9219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:13<00:00,  1.40s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.5469, proxy label accuracy: 0.9395, conditional validity: 0.9082 , marginal validity: 0.9082\n",
      "\n",
      "\n",
      " nominal marginal guarantee: 0.8, adjusted robust marginal guarantee: 0.8\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 0.8008, proxy label accuracy: 1.0000, conditional validity: 0.8076 , marginal validity: 0.8076\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [00:01<00:12,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7852, proxy label accuracy: 0.9385, conditional validity: 0.7852 , marginal validity: 0.7852\n",
      "Certified ratio (proxy, clean): 0.8037, proxy label accuracy: 1.0000, conditional validity: 0.7930 , marginal validity: 0.7930\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 2/10 [00:02<00:11,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7900, proxy label accuracy: 0.9424, conditional validity: 0.7646 , marginal validity: 0.7646\n",
      "Certified ratio (proxy, clean): 0.8223, proxy label accuracy: 1.0000, conditional validity: 0.8076 , marginal validity: 0.8076\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 3/10 [00:03<00:09,  1.31s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7822, proxy label accuracy: 0.9482, conditional validity: 0.7871 , marginal validity: 0.7871\n",
      "Certified ratio (proxy, clean): 0.7871, proxy label accuracy: 1.0000, conditional validity: 0.8359 , marginal validity: 0.8359\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 4/10 [00:05<00:08,  1.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7695, proxy label accuracy: 0.9375, conditional validity: 0.7969 , marginal validity: 0.7969\n",
      "Certified ratio (proxy, clean): 0.7793, proxy label accuracy: 1.0000, conditional validity: 0.8262 , marginal validity: 0.8262\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 5/10 [00:06<00:06,  1.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7549, proxy label accuracy: 0.9365, conditional validity: 0.7852 , marginal validity: 0.7852\n",
      "Certified ratio (proxy, clean): 0.7930, proxy label accuracy: 1.0000, conditional validity: 0.8145 , marginal validity: 0.8145\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 6/10 [00:08<00:05,  1.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7754, proxy label accuracy: 0.9414, conditional validity: 0.7578 , marginal validity: 0.7578\n",
      "Certified ratio (proxy, clean): 0.7754, proxy label accuracy: 1.0000, conditional validity: 0.8330 , marginal validity: 0.8330\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 7/10 [00:09<00:04,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7656, proxy label accuracy: 0.9492, conditional validity: 0.7920 , marginal validity: 0.7920\n",
      "Certified ratio (proxy, clean): 0.8174, proxy label accuracy: 1.0000, conditional validity: 0.7959 , marginal validity: 0.7959\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 8/10 [00:10<00:02,  1.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7793, proxy label accuracy: 0.9414, conditional validity: 0.7812 , marginal validity: 0.7812\n",
      "Certified ratio (proxy, clean): 0.8086, proxy label accuracy: 1.0000, conditional validity: 0.8066 , marginal validity: 0.8066\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 9/10 [00:12<00:01,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7695, proxy label accuracy: 0.9453, conditional validity: 0.7803 , marginal validity: 0.7803\n",
      "Certified ratio (proxy, clean): 0.8242, proxy label accuracy: 1.0000, conditional validity: 0.7910 , marginal validity: 0.7910\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:13<00:00,  1.38s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.7930, proxy label accuracy: 0.9443, conditional validity: 0.7656 , marginal validity: 0.7656\n",
      "\n",
      "\n",
      " nominal marginal guarantee: 0.8, adjusted robust marginal guarantee: 0.9101405810766334\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 0.4580, proxy label accuracy: 1.0000, conditional validity: 0.9678 , marginal validity: 0.9678\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [00:01<00:12,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.4062, proxy label accuracy: 0.9395, conditional validity: 0.9648 , marginal validity: 0.9648\n",
      "Certified ratio (proxy, clean): 0.4697, proxy label accuracy: 1.0000, conditional validity: 0.9541 , marginal validity: 0.9541\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 2/10 [00:02<00:11,  1.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.4092, proxy label accuracy: 0.9336, conditional validity: 0.9600 , marginal validity: 0.9600\n",
      "Certified ratio (proxy, clean): 0.4561, proxy label accuracy: 1.0000, conditional validity: 0.9609 , marginal validity: 0.9609\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 3/10 [00:04<00:09,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.4053, proxy label accuracy: 0.9326, conditional validity: 0.9629 , marginal validity: 0.9629\n",
      "Certified ratio (proxy, clean): 0.4824, proxy label accuracy: 1.0000, conditional validity: 0.9541 , marginal validity: 0.9541\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 4/10 [00:05<00:08,  1.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.4209, proxy label accuracy: 0.9434, conditional validity: 0.9707 , marginal validity: 0.9707\n",
      "Certified ratio (proxy, clean): 0.4561, proxy label accuracy: 1.0000, conditional validity: 0.9639 , marginal validity: 0.9639\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 5/10 [00:06<00:06,  1.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.3984, proxy label accuracy: 0.9297, conditional validity: 0.9678 , marginal validity: 0.9678\n",
      "Certified ratio (proxy, clean): 0.4492, proxy label accuracy: 1.0000, conditional validity: 0.9619 , marginal validity: 0.9619\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 5/10 [00:08<00:08,  1.61s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[96], line 38\u001b[0m\n\u001b[1;32m     36\u001b[0m eval_y_top_pert \u001b[38;5;241m=\u001b[39m estimate_y_top(smooth_prediction_pert\u001b[38;5;241m.\u001b[39mlogits[eval_mask], n_0\u001b[38;5;241m=\u001b[39mn_0, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m     37\u001b[0m proxy_prediction_pert, proxy_labels_pert \u001b[38;5;241m=\u001b[39m proxy_certificate\u001b[38;5;241m.\u001b[39mcertify_proxy(smooth_prediction_pert\u001b[38;5;241m.\u001b[39mlogits[eval_mask], y_top\u001b[38;5;241m=\u001b[39meval_y_top_pert, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 38\u001b[0m eval_true_labels_pert, hard_certificate_labels_pert \u001b[38;5;241m=\u001b[39m \u001b[43mproxy_certificate\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhard_certify\u001b[49m\u001b[43m(\u001b[49m\u001b[43msmooth_prediction_pert\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlogits\u001b[49m\u001b[43m[\u001b[49m\u001b[43meval_mask\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_0\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     40\u001b[0m certified_ratio_pert \u001b[38;5;241m=\u001b[39m eval_true_labels_pert\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39mmean()\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m     41\u001b[0m certified_ratio_pert_proxy \u001b[38;5;241m=\u001b[39m proxy_prediction_pert\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39mmean()\u001b[38;5;241m.\u001b[39mitem()\n",
      "Cell \u001b[0;32mIn[85], line 37\u001b[0m, in \u001b[0;36mProxyRobustnessCertificate.hard_certify\u001b[0;34m(self, logits, y_top, dim, r_coef)\u001b[0m\n\u001b[1;32m     35\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m y_top \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m     36\u001b[0m     y_pred \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(logits[:, :\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcertificate_n_0, :], dim\u001b[38;5;241m=\u001b[39mdim)\n\u001b[0;32m---> 37\u001b[0m     votes, y_top \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mone_hot\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_pred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_classes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogits\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mmean(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     39\u001b[0m votes_intensive \u001b[38;5;241m=\u001b[39m (torch\u001b[38;5;241m.\u001b[39margmax(logits[:, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcertificate_n_0 \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m precomputed \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;241m0\u001b[39m:, :], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m==\u001b[39m y_top\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, 
\u001b[38;5;241m1\u001b[39m))\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     40\u001b[0m probs \u001b[38;5;241m=\u001b[39m clopper_pearson_lower(votes_intensive\u001b[38;5;241m.\u001b[39mcpu(), logits\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m-\u001b[39m (\u001b[38;5;241m0\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m precomputed \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcertificate_n_0), alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmarginal_guarantee)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "\n",
    "# Sweep over nominal marginal guarantees: for every guarantee level and proxy setup,\n",
    "# tune the proxy certificate on a calibration split and score it on the remaining\n",
    "# points, once on the clean and once on the perturbed smooth predictions.\n",
    "guarantee_range = [0.6, 0.7, 0.8, 0.9, 0.95]\n",
    "n_trials = 10  # calibration/evaluation splits drawn per (guarantee, setup) pair\n",
    "\n",
    "\n",
    "def validity_metrics(proxy_certified, truly_certified, n_eval):\n",
    "    \"\"\"Score the proxy certificate against the hard certificate.\n",
    "\n",
    "    A violation is a point that the proxy certifies although the hard certificate\n",
    "    does not.\n",
    "\n",
    "    Args:\n",
    "        proxy_certified: per-point proxy decisions (truthy = certified).\n",
    "        truly_certified: per-point hard-certificate decisions (truthy = certified).\n",
    "        n_eval: number of evaluation points.\n",
    "\n",
    "    Returns:\n",
    "        (conditional_validity, marginal_validity), where\n",
    "        marginal_validity = 1 - violations / n_eval and\n",
    "        conditional_validity = 1 - violations / #(points the hard certificate rejects).\n",
    "    \"\"\"\n",
    "    rejected = truly_certified == False\n",
    "    n_violations = ((proxy_certified == True) & rejected).sum().item()\n",
    "    n_rejected = rejected.sum().item()\n",
    "    marginal = 1 - n_violations / n_eval\n",
    "    # Normalise by the rejected points only: dividing by n_eval here would simply\n",
    "    # reproduce the marginal rate.\n",
    "    # With no rejected point there is nothing to violate, so report 1.0 instead of dividing by zero.\n",
    "    conditional = 1 - n_violations / n_rejected if n_rejected > 0 else 1.0\n",
    "    return conditional, marginal\n",
    "\n",
    "\n",
    "results = []\n",
    "\n",
    "for marginal_coverage in guarantee_range:\n",
    "    for setup in [\"vanilla\", \"robust\"]:\n",
    "        proxy_certificate = ProxyRobustnessCertificate(\n",
    "            r=r,\n",
    "            smoothing_sigma=smoothing_sigma,\n",
    "            marginal_guarantee=marginal_coverage,\n",
    "            certificate_setup={\"confidence\": confidence, \"n_0\": n_0}, \n",
    "            proxy_certificate=single_sample_proxy_certificate,\n",
    "            proxy_setup=setup, risk_setup=\"marginal\"\n",
    "        )\n",
    "\n",
    "        print(f\"\\n\\n nominal marginal guarantee: {marginal_coverage}, adjusted robust marginal guarantee: {proxy_certificate.marginal_guarantee}\")\n",
    "        print(\"\\n\\n\\n\")\n",
    "\n",
    "        for i in tqdm(range(n_trials)):\n",
    "            # Draw a new calibration split; everything outside it is used for evaluation.\n",
    "            cal_mask = get_cal_mask(smooth_prediction.logits, fraction=calibration_budget)\n",
    "            eval_mask = ~cal_mask\n",
    "            n_eval = eval_mask.sum().item()\n",
    "\n",
    "            # --- clean inputs ---\n",
    "            proxy_certificate.tune_proxy_certificate(smooth_prediction.logits[cal_mask], dim=2, verbose=False)\n",
    "            eval_y_top = estimate_y_top(smooth_prediction.logits[eval_mask], n_0=n_0, dim=2)\n",
    "            proxy_prediction_clean, proxy_labels = proxy_certificate.certify_proxy(smooth_prediction.logits[eval_mask], dim=2, y_top=eval_y_top)\n",
    "            # The hard certificate only sees the samples after the first n_0 and reuses the same y_top.\n",
    "            eval_true_labels, hard_certificate_labels = proxy_certificate.hard_certify(smooth_prediction.logits[eval_mask][:, n_0:, :], dim=2, y_top=eval_y_top)\n",
    "\n",
    "            certified_ratio_clean = eval_true_labels.float().mean().item()\n",
    "            certified_ratio_clean_proxy = proxy_prediction_clean.float().mean().item()\n",
    "            proxy_label_acc = (proxy_labels == hard_certificate_labels).float().mean().item()\n",
    "\n",
    "            conditional_validity, marginal_validity = validity_metrics(proxy_prediction_clean, eval_true_labels, n_eval)\n",
    "            print(f\"Certified ratio (proxy, clean): {certified_ratio_clean_proxy:.4f}, proxy label accuracy: {proxy_label_acc:.4f}, conditional validity: {conditional_validity:.4f} , marginal validity: {marginal_validity:.4f}\")\n",
    "\n",
    "            # --- perturbed inputs (same split, certificate tuned on clean calibration data) ---\n",
    "            eval_y_top_pert = estimate_y_top(smooth_prediction_pert.logits[eval_mask], n_0=n_0, dim=2)\n",
    "            proxy_prediction_pert, proxy_labels_pert = proxy_certificate.certify_proxy(smooth_prediction_pert.logits[eval_mask], y_top=eval_y_top_pert, dim=2)\n",
    "            # NOTE(review): unlike the clean call above, y_top is not passed here, so hard_certify\n",
    "            # presumably re-estimates the top class from the already-sliced logits. That may be why\n",
    "            # the perturbed proxy label accuracy is below 1 -- confirm this is intended, otherwise\n",
    "            # pass y_top=eval_y_top_pert.\n",
    "            eval_true_labels_pert, hard_certificate_labels_pert = proxy_certificate.hard_certify(smooth_prediction_pert.logits[eval_mask][:, n_0:, :], dim=2)\n",
    "\n",
    "            certified_ratio_pert = eval_true_labels_pert.float().mean().item()\n",
    "            certified_ratio_pert_proxy = proxy_prediction_pert.float().mean().item()\n",
    "            proxy_label_acc_pert = (proxy_labels_pert == hard_certificate_labels_pert).float().mean().item()\n",
    "            conditional_validity_pert, marginal_validity_pert = validity_metrics(proxy_prediction_pert, eval_true_labels_pert, n_eval)\n",
    "            print(f\"Certified ratio (proxy, perturbed): {certified_ratio_pert_proxy:.4f}, proxy label accuracy: {proxy_label_acc_pert:.4f}, conditional validity: {conditional_validity_pert:.4f} , marginal validity: {marginal_validity_pert:.4f}\")\n",
    "            \n",
    "            # One record per (guarantee, setup, trial); collected into a DataFrame after the sweep.\n",
    "            results.append({\n",
    "                \"setup\": setup,\n",
    "                \"risk_type\": \"marginal\",\n",
    "                \"trial\": i,\n",
    "                \"r\": r,\n",
    "                \"smoothing_sigma\": smoothing_sigma,\n",
    "                \"n_samples\": n_samples,\n",
    "                \n",
    "                \"nominal_coverage\": marginal_coverage,\n",
    "                \"robust_nominal_coverage\": proxy_certificate.marginal_guarantee,\n",
    "                \n",
    "                \"certified_ratio_clean_proxy\": certified_ratio_clean_proxy,\n",
    "                \"certified_ratio_pert_proxy\": certified_ratio_pert_proxy,\n",
    "\n",
    "                \"certified_ratio_clean\": certified_ratio_clean,\n",
    "                \"certified_ratio_pert\": certified_ratio_pert,\n",
    "\n",
    "                # \"proxy_label_acc\": proxy_label_acc,\n",
    "                \"conditional_validity\": conditional_validity,\n",
    "                \"marginal_validity\": marginal_validity,\n",
    "                \"proxy_label_acc_pert\": proxy_label_acc_pert,\n",
    "                \"conditional_validity_pert\": conditional_validity_pert,\n",
    "                \"marginal_validity_pert\": marginal_validity_pert,\n",
    "                \"acceptable_risk\": proxy_certificate._acceptable_risk,\n",
    "                \"acceptable_lambd\": proxy_certificate._acceptable_lambd,\n",
    "            })\n",
    "results_convex_df = pd.DataFrame(results)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>trial</th>\n",
       "      <th>r</th>\n",
       "      <th>smoothing_sigma</th>\n",
       "      <th>n_samples</th>\n",
       "      <th>robust_nominal_coverage</th>\n",
       "      <th>certified_ratio_clean_proxy</th>\n",
       "      <th>certified_ratio_pert_proxy</th>\n",
       "      <th>certified_ratio_clean</th>\n",
       "      <th>certified_ratio_pert</th>\n",
       "      <th>conditional_validity</th>\n",
       "      <th>marginal_validity</th>\n",
       "      <th>proxy_label_acc_pert</th>\n",
       "      <th>conditional_validity_pert</th>\n",
       "      <th>marginal_validity_pert</th>\n",
       "      <th>acceptable_risk</th>\n",
       "      <th>acceptable_lambd</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nominal_coverage</th>\n",
       "      <th>setup</th>\n",
       "      <th>risk_type</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">0.60</th>\n",
       "      <th>robust</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.774379</td>\n",
       "      <td>0.664648</td>\n",
       "      <td>0.625879</td>\n",
       "      <td>0.643848</td>\n",
       "      <td>0.595996</td>\n",
       "      <td>0.889062</td>\n",
       "      <td>0.889062</td>\n",
       "      <td>0.944824</td>\n",
       "      <td>0.875293</td>\n",
       "      <td>0.875293</td>\n",
       "      <td>0.225366</td>\n",
       "      <td>4.649901e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>vanilla</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.600000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.647949</td>\n",
       "      <td>0.600098</td>\n",
       "      <td>0.647949</td>\n",
       "      <td>0.647949</td>\n",
       "      <td>0.942871</td>\n",
       "      <td>0.600098</td>\n",
       "      <td>0.600098</td>\n",
       "      <td>0.366341</td>\n",
       "      <td>5.960464e-08</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">0.70</th>\n",
       "      <th>robust</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.847177</td>\n",
       "      <td>0.566113</td>\n",
       "      <td>0.528418</td>\n",
       "      <td>0.644434</td>\n",
       "      <td>0.597266</td>\n",
       "      <td>0.934082</td>\n",
       "      <td>0.934082</td>\n",
       "      <td>0.941797</td>\n",
       "      <td>0.926074</td>\n",
       "      <td>0.926074</td>\n",
       "      <td>0.152195</td>\n",
       "      <td>5.849057e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>vanilla</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.700000</td>\n",
       "      <td>0.936133</td>\n",
       "      <td>0.920801</td>\n",
       "      <td>0.645801</td>\n",
       "      <td>0.597070</td>\n",
       "      <td>0.702539</td>\n",
       "      <td>0.702539</td>\n",
       "      <td>0.940332</td>\n",
       "      <td>0.663965</td>\n",
       "      <td>0.663965</td>\n",
       "      <td>0.299512</td>\n",
       "      <td>1.445975e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">0.80</th>\n",
       "      <th>robust</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.910141</td>\n",
       "      <td>0.453027</td>\n",
       "      <td>0.396680</td>\n",
       "      <td>0.650977</td>\n",
       "      <td>0.598926</td>\n",
       "      <td>0.963672</td>\n",
       "      <td>0.963672</td>\n",
       "      <td>0.940625</td>\n",
       "      <td>0.966895</td>\n",
       "      <td>0.966895</td>\n",
       "      <td>0.089756</td>\n",
       "      <td>7.272440e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>vanilla</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.809375</td>\n",
       "      <td>0.780566</td>\n",
       "      <td>0.646191</td>\n",
       "      <td>0.601172</td>\n",
       "      <td>0.803223</td>\n",
       "      <td>0.803223</td>\n",
       "      <td>0.939648</td>\n",
       "      <td>0.778906</td>\n",
       "      <td>0.778906</td>\n",
       "      <td>0.199024</td>\n",
       "      <td>3.027473e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">0.90</th>\n",
       "      <th>robust</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.962589</td>\n",
       "      <td>0.323047</td>\n",
       "      <td>0.281641</td>\n",
       "      <td>0.638379</td>\n",
       "      <td>0.587891</td>\n",
       "      <td>0.985938</td>\n",
       "      <td>0.985938</td>\n",
       "      <td>0.941211</td>\n",
       "      <td>0.984863</td>\n",
       "      <td>0.984863</td>\n",
       "      <td>0.037073</td>\n",
       "      <td>8.607546e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>vanilla</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.900000</td>\n",
       "      <td>0.640234</td>\n",
       "      <td>0.603027</td>\n",
       "      <td>0.642090</td>\n",
       "      <td>0.592871</td>\n",
       "      <td>0.901074</td>\n",
       "      <td>0.901074</td>\n",
       "      <td>0.940723</td>\n",
       "      <td>0.886523</td>\n",
       "      <td>0.886523</td>\n",
       "      <td>0.099512</td>\n",
       "      <td>4.917556e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">0.95</th>\n",
       "      <th>robust</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.984018</td>\n",
       "      <td>0.237109</td>\n",
       "      <td>0.199023</td>\n",
       "      <td>0.638672</td>\n",
       "      <td>0.591113</td>\n",
       "      <td>0.995801</td>\n",
       "      <td>0.995801</td>\n",
       "      <td>0.943848</td>\n",
       "      <td>0.993457</td>\n",
       "      <td>0.993457</td>\n",
       "      <td>0.015610</td>\n",
       "      <td>9.294368e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>vanilla</th>\n",
       "      <th>marginal</th>\n",
       "      <td>4.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.950000</td>\n",
       "      <td>0.490723</td>\n",
       "      <td>0.441504</td>\n",
       "      <td>0.639160</td>\n",
       "      <td>0.591309</td>\n",
       "      <td>0.955566</td>\n",
       "      <td>0.955566</td>\n",
       "      <td>0.937012</td>\n",
       "      <td>0.950586</td>\n",
       "      <td>0.950586</td>\n",
       "      <td>0.049756</td>\n",
       "      <td>6.769572e-01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    trial     r  smoothing_sigma  n_samples  \\\n",
       "nominal_coverage setup   risk_type                                            \n",
       "0.60             robust  marginal     4.5  0.25              0.5    10000.0   \n",
       "                 vanilla marginal     4.5  0.25              0.5    10000.0   \n",
       "0.70             robust  marginal     4.5  0.25              0.5    10000.0   \n",
       "                 vanilla marginal     4.5  0.25              0.5    10000.0   \n",
       "0.80             robust  marginal     4.5  0.25              0.5    10000.0   \n",
       "                 vanilla marginal     4.5  0.25              0.5    10000.0   \n",
       "0.90             robust  marginal     4.5  0.25              0.5    10000.0   \n",
       "                 vanilla marginal     4.5  0.25              0.5    10000.0   \n",
       "0.95             robust  marginal     4.5  0.25              0.5    10000.0   \n",
       "                 vanilla marginal     4.5  0.25              0.5    10000.0   \n",
       "\n",
       "                                    robust_nominal_coverage  \\\n",
       "nominal_coverage setup   risk_type                            \n",
       "0.60             robust  marginal                  0.774379   \n",
       "                 vanilla marginal                  0.600000   \n",
       "0.70             robust  marginal                  0.847177   \n",
       "                 vanilla marginal                  0.700000   \n",
       "0.80             robust  marginal                  0.910141   \n",
       "                 vanilla marginal                  0.800000   \n",
       "0.90             robust  marginal                  0.962589   \n",
       "                 vanilla marginal                  0.900000   \n",
       "0.95             robust  marginal                  0.984018   \n",
       "                 vanilla marginal                  0.950000   \n",
       "\n",
       "                                    certified_ratio_clean_proxy  \\\n",
       "nominal_coverage setup   risk_type                                \n",
       "0.60             robust  marginal                      0.664648   \n",
       "                 vanilla marginal                      1.000000   \n",
       "0.70             robust  marginal                      0.566113   \n",
       "                 vanilla marginal                      0.936133   \n",
       "0.80             robust  marginal                      0.453027   \n",
       "                 vanilla marginal                      0.809375   \n",
       "0.90             robust  marginal                      0.323047   \n",
       "                 vanilla marginal                      0.640234   \n",
       "0.95             robust  marginal                      0.237109   \n",
       "                 vanilla marginal                      0.490723   \n",
       "\n",
       "                                    certified_ratio_pert_proxy  \\\n",
       "nominal_coverage setup   risk_type                               \n",
       "0.60             robust  marginal                     0.625879   \n",
       "                 vanilla marginal                     1.000000   \n",
       "0.70             robust  marginal                     0.528418   \n",
       "                 vanilla marginal                     0.920801   \n",
       "0.80             robust  marginal                     0.396680   \n",
       "                 vanilla marginal                     0.780566   \n",
       "0.90             robust  marginal                     0.281641   \n",
       "                 vanilla marginal                     0.603027   \n",
       "0.95             robust  marginal                     0.199023   \n",
       "                 vanilla marginal                     0.441504   \n",
       "\n",
       "                                    certified_ratio_clean  \\\n",
       "nominal_coverage setup   risk_type                          \n",
       "0.60             robust  marginal                0.643848   \n",
       "                 vanilla marginal                0.647949   \n",
       "0.70             robust  marginal                0.644434   \n",
       "                 vanilla marginal                0.645801   \n",
       "0.80             robust  marginal                0.650977   \n",
       "                 vanilla marginal                0.646191   \n",
       "0.90             robust  marginal                0.638379   \n",
       "                 vanilla marginal                0.642090   \n",
       "0.95             robust  marginal                0.638672   \n",
       "                 vanilla marginal                0.639160   \n",
       "\n",
       "                                    certified_ratio_pert  \\\n",
       "nominal_coverage setup   risk_type                         \n",
       "0.60             robust  marginal               0.595996   \n",
       "                 vanilla marginal               0.600098   \n",
       "0.70             robust  marginal               0.597266   \n",
       "                 vanilla marginal               0.597070   \n",
       "0.80             robust  marginal               0.598926   \n",
       "                 vanilla marginal               0.601172   \n",
       "0.90             robust  marginal               0.587891   \n",
       "                 vanilla marginal               0.592871   \n",
       "0.95             robust  marginal               0.591113   \n",
       "                 vanilla marginal               0.591309   \n",
       "\n",
       "                                    conditional_validity  marginal_validity  \\\n",
       "nominal_coverage setup   risk_type                                            \n",
       "0.60             robust  marginal               0.889062           0.889062   \n",
       "                 vanilla marginal               0.647949           0.647949   \n",
       "0.70             robust  marginal               0.934082           0.934082   \n",
       "                 vanilla marginal               0.702539           0.702539   \n",
       "0.80             robust  marginal               0.963672           0.963672   \n",
       "                 vanilla marginal               0.803223           0.803223   \n",
       "0.90             robust  marginal               0.985938           0.985938   \n",
       "                 vanilla marginal               0.901074           0.901074   \n",
       "0.95             robust  marginal               0.995801           0.995801   \n",
       "                 vanilla marginal               0.955566           0.955566   \n",
       "\n",
       "                                    proxy_label_acc_pert  \\\n",
       "nominal_coverage setup   risk_type                         \n",
       "0.60             robust  marginal               0.944824   \n",
       "                 vanilla marginal               0.942871   \n",
       "0.70             robust  marginal               0.941797   \n",
       "                 vanilla marginal               0.940332   \n",
       "0.80             robust  marginal               0.940625   \n",
       "                 vanilla marginal               0.939648   \n",
       "0.90             robust  marginal               0.941211   \n",
       "                 vanilla marginal               0.940723   \n",
       "0.95             robust  marginal               0.943848   \n",
       "                 vanilla marginal               0.937012   \n",
       "\n",
       "                                    conditional_validity_pert  \\\n",
       "nominal_coverage setup   risk_type                              \n",
       "0.60             robust  marginal                    0.875293   \n",
       "                 vanilla marginal                    0.600098   \n",
       "0.70             robust  marginal                    0.926074   \n",
       "                 vanilla marginal                    0.663965   \n",
       "0.80             robust  marginal                    0.966895   \n",
       "                 vanilla marginal                    0.778906   \n",
       "0.90             robust  marginal                    0.984863   \n",
       "                 vanilla marginal                    0.886523   \n",
       "0.95             robust  marginal                    0.993457   \n",
       "                 vanilla marginal                    0.950586   \n",
       "\n",
       "                                    marginal_validity_pert  acceptable_risk  \\\n",
       "nominal_coverage setup   risk_type                                            \n",
       "0.60             robust  marginal                 0.875293         0.225366   \n",
       "                 vanilla marginal                 0.600098         0.366341   \n",
       "0.70             robust  marginal                 0.926074         0.152195   \n",
       "                 vanilla marginal                 0.663965         0.299512   \n",
       "0.80             robust  marginal                 0.966895         0.089756   \n",
       "                 vanilla marginal                 0.778906         0.199024   \n",
       "0.90             robust  marginal                 0.984863         0.037073   \n",
       "                 vanilla marginal                 0.886523         0.099512   \n",
       "0.95             robust  marginal                 0.993457         0.015610   \n",
       "                 vanilla marginal                 0.950586         0.049756   \n",
       "\n",
       "                                    acceptable_lambd  \n",
       "nominal_coverage setup   risk_type                    \n",
       "0.60             robust  marginal       4.649901e-01  \n",
       "                 vanilla marginal       5.960464e-08  \n",
       "0.70             robust  marginal       5.849057e-01  \n",
       "                 vanilla marginal       1.445975e-01  \n",
       "0.80             robust  marginal       7.272440e-01  \n",
       "                 vanilla marginal       3.027473e-01  \n",
       "0.90             robust  marginal       8.607546e-01  \n",
       "                 vanilla marginal       4.917556e-01  \n",
       "0.95             robust  marginal       9.294368e-01  \n",
       "                 vanilla marginal       6.769572e-01  "
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_convex_df.groupby([\"nominal_coverage\", \"setup\", \"risk_type\"]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_convex_df.to_csv(output_dir/f\"convex_marginal_binary_certificate_proxy_r{r}_sigma{smoothing_sigma}_n{n_samples}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Beta Sampling Scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ProxyNonconvexRobustnessCertificate(ProxyRobustnessCertificate):\n",
    "    def __init__(self, ci=0.999, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.ci = ci\n",
    "    \n",
    "    def adjust_robust_marginal_guarantee(self):\n",
    "        adjusted_guarantee = deepcopy(self._marginal_guarantee)\n",
    "        return adjusted_guarantee\n",
    "    \n",
    "    def tune_proxy_certificate(self, logits, dim=2, verbose=True):\n",
    "        certified, y_top_certified = self.hard_certify(logits, r_coef= 1 if self.proxy_setup == \"vanilla\" else 2, dim=dim)\n",
    "        certified = certified.to(device)\n",
    "        y_top_certified = y_top_certified.to(device)\n",
    "\n",
    "        lambd_min = 0.0\n",
    "        lambd_max = 1.0\n",
    "        while lambd_max - lambd_min > 1e-7:\n",
    "            lambd_mid = (lambd_min + lambd_max) / 2\n",
    "            certified_estim, y_top_estim = self.proxy_certificate(logits, lambd_mid, y_top=y_top_certified, vectorized=True)\n",
    "            \n",
    "            robust_acc = ((certified_estim == certified.reshape(-1, 1)).float().mean(0).mean().item())\n",
    "            \n",
    "            if self.risk_setup == \"marginal\":\n",
    "                # risk = ((~certified) & (certified_estim)) | (y_top_estim != y_top_certified)\n",
    "                risk = ((~certified.reshape(-1, 1)) & (certified_estim)) # | (y_top_estim != y_top_certified) TODO: RETURN\n",
    "                risk_votes = risk.float().sum(-1)\n",
    "                upper_vote_proba = clopper_pearson_upper(risk_votes.cpu(), logits.shape[1], alpha=1 - self.ci)\n",
    "                threat_proba = norm.cdf(norm.ppf(upper_vote_proba, scale=self.smoothing_sigma) + self.r, scale=self.smoothing_sigma)\n",
    "                \n",
    "                # risk = (risk.float().sum().item() + 1) / (logits.shape[0] + 1)\n",
    "                risk = (threat_proba.sum().item() + 1.0) / (logits.shape[0] + 1)\n",
    "            else:\n",
    "                raise NotImplementedError(\"Non-convex risk setup only implemented for marginal risk.\")\n",
    "\n",
    "            if self._acceptable_risk is None and risk <= (1 - self.marginal_guarantee):\n",
    "                self._acceptable_risk = risk\n",
    "                self._acceptable_lambd = lambd_mid\n",
    "            elif self._acceptable_risk is not None and risk >= self._acceptable_risk and risk <= (1 - self.marginal_guarantee):\n",
    "                self._acceptable_risk = risk\n",
    "                self._acceptable_lambd = lambd_mid\n",
    "\n",
    "            if risk > (1 - self.marginal_guarantee):\n",
    "                lambd_min = lambd_mid\n",
    "            else:\n",
    "                lambd_max = lambd_mid\n",
    "\n",
    "            if verbose:    \n",
    "                print(f\"lambda: {lambd_mid}, robust accuracy: {robust_acc:.4f}, risk: {risk:.4f}\")\n",
    "        if verbose:\n",
    "            if self._acceptable_lambd is not None:\n",
    "                print(f\"Found acceptable lambda: {self._acceptable_lambd} with risk: {self._acceptable_risk}\")\n",
    "            else: \n",
    "                print(\"No acceptable lambda found.\")\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      " nominal marginal guarantee: 0.9, adjusted robust marginal guarantee: 0.9\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 0.5020, proxy label accuracy: 1.0000, conditional validity: 0.9434 , marginal validity: 0.9434\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:14<00:14, 14.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.4717, proxy label accuracy: 0.9492, conditional validity: 0.9473 , marginal validity: 0.9473\n",
      "Certified ratio (proxy, clean): 0.5088, proxy label accuracy: 1.0000, conditional validity: 0.9561 , marginal validity: 0.9561\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:29<00:00, 14.93s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.4512, proxy label accuracy: 0.9453, conditional validity: 0.9551 , marginal validity: 0.9551\n",
      "\n",
      "\n",
      " nominal marginal guarantee: 0.9, adjusted robust marginal guarantee: 0.9\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 0.3789, proxy label accuracy: 1.0000, conditional validity: 0.9756 , marginal validity: 0.9756\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:14<00:14, 14.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.3389, proxy label accuracy: 0.9404, conditional validity: 0.9844 , marginal validity: 0.9844\n",
      "Certified ratio (proxy, clean): 0.3506, proxy label accuracy: 1.0000, conditional validity: 0.9766 , marginal validity: 0.9766\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:29<00:00, 14.78s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.3174, proxy label accuracy: 0.9395, conditional validity: 0.9814 , marginal validity: 0.9814\n",
      "\n",
      "\n",
      " nominal marginal guarantee: 0.95, adjusted robust marginal guarantee: 0.95\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 0.3682, proxy label accuracy: 1.0000, conditional validity: 0.9814 , marginal validity: 0.9814\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:14<00:14, 14.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.3398, proxy label accuracy: 0.9404, conditional validity: 0.9746 , marginal validity: 0.9746\n",
      "Certified ratio (proxy, clean): 0.3975, proxy label accuracy: 1.0000, conditional validity: 0.9775 , marginal validity: 0.9775\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:29<00:00, 14.87s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.3408, proxy label accuracy: 0.9492, conditional validity: 0.9814 , marginal validity: 0.9814\n",
      "\n",
      "\n",
      " nominal marginal guarantee: 0.95, adjusted robust marginal guarantee: 0.95\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, clean): 0.2715, proxy label accuracy: 1.0000, conditional validity: 0.9902 , marginal validity: 0.9902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:14<00:14, 14.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.2109, proxy label accuracy: 0.9375, conditional validity: 0.9893 , marginal validity: 0.9893\n",
      "Certified ratio (proxy, clean): 0.2793, proxy label accuracy: 1.0000, conditional validity: 0.9902 , marginal validity: 0.9902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:29<00:00, 14.82s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Certified ratio (proxy, perturbed): 0.2295, proxy label accuracy: 0.9502, conditional validity: 0.9912 , marginal validity: 0.9912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "guarantee_range = [0.6, 0.7, 0.8, 0.9, 0.95]\n",
    "\n",
    "results = []\n",
    "\n",
    "for marginal_coverage in guarantee_range:\n",
    "    for setup in [\"vanilla\", \"robust\"]:\n",
    "        proxy_certificate = ProxyNonconvexRobustnessCertificate(\n",
    "            r=r,\n",
    "            smoothing_sigma=smoothing_sigma,\n",
    "            marginal_guarantee=marginal_coverage,\n",
    "            certificate_setup={\"confidence\": confidence, \"n_0\": n_0}, \n",
    "            proxy_certificate=single_sample_proxy_certificate,\n",
    "            proxy_setup=setup, risk_setup=\"marginal\"\n",
    "        )\n",
    "\n",
    "        print(f\"\\n\\n nominal marginal guarantee: {marginal_coverage}, adjusted robust marginal guarantee: {proxy_certificate.marginal_guarantee}\")\n",
    "        print(\"\\n\\n\")\n",
    "\n",
    "        for i in tqdm(range(10)):\n",
    "            cal_mask = get_cal_mask(smooth_prediction.logits, fraction=calibration_budget)\n",
    "            eval_mask = ~cal_mask\n",
    "\n",
    "            proxy_certificate.tune_proxy_certificate(smooth_prediction.logits[cal_mask], dim=2, verbose=False)\n",
    "            eval_y_top = estimate_y_top(smooth_prediction.logits[eval_mask], n_0=n_0, dim=2)\n",
    "            proxy_prediction_clean, proxy_labels = proxy_certificate.certify_proxy(smooth_prediction.logits[eval_mask], dim=2, y_top=eval_y_top)\n",
    "            eval_true_labels, hard_certificate_labels = proxy_certificate.hard_certify(smooth_prediction.logits[eval_mask][:, n_0:, :], dim=2, y_top=eval_y_top)\n",
    "\n",
    "            certified_ratio_clean = eval_true_labels.float().mean().item()\n",
    "            certified_ratio_clean_proxy = proxy_prediction_clean.float().mean().item()\n",
    "            proxy_label_acc = (proxy_labels == hard_certificate_labels).float().mean().item()\n",
    "\n",
    "            conditional_validity = 1 - ((proxy_prediction_clean == True)[eval_true_labels == False].sum().item() / eval_mask.sum().item())\n",
    "            marginal_validity = 1 - (((proxy_prediction_clean == True) & (eval_true_labels == False)).sum().item() / eval_mask.sum().item())\n",
    "            print(f\"Certified ratio (proxy, clean): {certified_ratio_clean_proxy:.4f}, proxy label accuracy: {proxy_label_acc:.4f}, conditional validity: {conditional_validity:.4f} , marginal validity: {marginal_validity:.4f}\")\n",
    "\n",
    "            eval_y_top_pert = estimate_y_top(smooth_prediction_pert.logits[eval_mask], n_0=n_0, dim=2)\n",
    "            proxy_prediction_pert, proxy_labels_pert = proxy_certificate.certify_proxy(smooth_prediction_pert.logits[eval_mask], y_top=eval_y_top_pert, dim=2)\n",
    "            eval_true_labels_pert, hard_certificate_labels_pert = proxy_certificate.hard_certify(smooth_prediction_pert.logits[eval_mask][:, n_0:, :], dim=2)\n",
    "\n",
    "            certified_ratio_pert = eval_true_labels_pert.float().mean().item()\n",
    "            certified_ratio_pert_proxy = proxy_prediction_pert.float().mean().item()\n",
    "            proxy_label_acc_pert = (proxy_labels_pert == hard_certificate_labels_pert).float().mean().item()\n",
    "            conditional_validity_pert = 1 - ((proxy_prediction_pert == True)[eval_true_labels_pert == False].sum().item() / eval_mask.sum().item())\n",
    "            marginal_validity_pert = 1 - (((proxy_prediction_pert == True) & (eval_true_labels_pert == False)).sum().item() / eval_mask.sum().item())\n",
    "            print(f\"Certified ratio (proxy, perturbed): {certified_ratio_pert_proxy:.4f}, proxy label accuracy: {proxy_label_acc_pert:.4f}, conditional validity: {conditional_validity_pert:.4f} , marginal validity: {marginal_validity_pert:.4f}\")\n",
    "            \n",
    "            results.append({\n",
    "                \"setup\": setup,\n",
    "                \"risk_type\": \"marginal\",\n",
    "                \"trial\": i,\n",
    "                \"r\": r,\n",
    "                \"smoothing_sigma\": smoothing_sigma,\n",
    "                \"n_samples\": n_samples,\n",
    "                \n",
    "                \"nominal_coverage\": marginal_coverage,\n",
    "                \"robust_nominal_coverage\": proxy_certificate.marginal_guarantee,\n",
    "                \n",
    "                \"certified_ratio_clean_proxy\": certified_ratio_clean_proxy,\n",
    "                \"certified_ratio_pert_proxy\": certified_ratio_pert_proxy,\n",
    "\n",
    "                \"certified_ratio_clean\": certified_ratio_clean,\n",
    "                \"certified_ratio_pert\": certified_ratio_pert,\n",
    "\n",
    "                # \"proxy_label_acc\": proxy_label_acc,\n",
    "                \"conditional_validity\": conditional_validity,\n",
    "                \"marginal_validity\": marginal_validity,\n",
    "                \"proxy_label_acc_pert\": proxy_label_acc_pert,\n",
    "                \"conditional_validity_pert\": conditional_validity_pert,\n",
    "                \"marginal_validity_pert\": marginal_validity_pert,\n",
    "                \"acceptable_risk\": proxy_certificate._acceptable_risk,\n",
    "                \"acceptable_lambd\": proxy_certificate._acceptable_lambd,\n",
    "            })\n",
    "results_non_convex_df = pd.DataFrame(results)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>trial</th>\n",
       "      <th>r</th>\n",
       "      <th>smoothing_sigma</th>\n",
       "      <th>n_samples</th>\n",
       "      <th>robust_nominal_coverage</th>\n",
       "      <th>certified_ratio_clean_proxy</th>\n",
       "      <th>certified_ratio_pert_proxy</th>\n",
       "      <th>certified_ratio_clean</th>\n",
       "      <th>certified_ratio_pert</th>\n",
       "      <th>conditional_validity</th>\n",
       "      <th>marginal_validity</th>\n",
       "      <th>proxy_label_acc_pert</th>\n",
       "      <th>conditional_validity_pert</th>\n",
       "      <th>marginal_validity_pert</th>\n",
       "      <th>acceptable_risk</th>\n",
       "      <th>acceptable_lambd</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nominal_coverage</th>\n",
       "      <th>setup</th>\n",
       "      <th>risk_type</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">0.90</th>\n",
       "      <th>robust</th>\n",
       "      <th>marginal</th>\n",
       "      <td>0.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.90</td>\n",
       "      <td>0.364746</td>\n",
       "      <td>0.328125</td>\n",
       "      <td>0.638184</td>\n",
       "      <td>0.591309</td>\n",
       "      <td>0.976074</td>\n",
       "      <td>0.976074</td>\n",
       "      <td>0.939941</td>\n",
       "      <td>0.982910</td>\n",
       "      <td>0.982910</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.812341</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>vanilla</th>\n",
       "      <th>marginal</th>\n",
       "      <td>0.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.90</td>\n",
       "      <td>0.505371</td>\n",
       "      <td>0.461426</td>\n",
       "      <td>0.646484</td>\n",
       "      <td>0.602539</td>\n",
       "      <td>0.949707</td>\n",
       "      <td>0.949707</td>\n",
       "      <td>0.947266</td>\n",
       "      <td>0.951172</td>\n",
       "      <td>0.951172</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.660024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">0.95</th>\n",
       "      <th>robust</th>\n",
       "      <th>marginal</th>\n",
       "      <td>0.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.95</td>\n",
       "      <td>0.275391</td>\n",
       "      <td>0.220215</td>\n",
       "      <td>0.643555</td>\n",
       "      <td>0.594727</td>\n",
       "      <td>0.990234</td>\n",
       "      <td>0.990234</td>\n",
       "      <td>0.943848</td>\n",
       "      <td>0.990234</td>\n",
       "      <td>0.990234</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.909685</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>vanilla</th>\n",
       "      <th>marginal</th>\n",
       "      <td>0.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>0.95</td>\n",
       "      <td>0.382812</td>\n",
       "      <td>0.340332</td>\n",
       "      <td>0.637695</td>\n",
       "      <td>0.592773</td>\n",
       "      <td>0.979492</td>\n",
       "      <td>0.979492</td>\n",
       "      <td>0.944824</td>\n",
       "      <td>0.978027</td>\n",
       "      <td>0.978027</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.797353</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    trial     r  smoothing_sigma  n_samples  \\\n",
       "nominal_coverage setup   risk_type                                            \n",
       "0.90             robust  marginal     0.5  0.25              0.5    10000.0   \n",
       "                 vanilla marginal     0.5  0.25              0.5    10000.0   \n",
       "0.95             robust  marginal     0.5  0.25              0.5    10000.0   \n",
       "                 vanilla marginal     0.5  0.25              0.5    10000.0   \n",
       "\n",
       "                                    robust_nominal_coverage  \\\n",
       "nominal_coverage setup   risk_type                            \n",
       "0.90             robust  marginal                      0.90   \n",
       "                 vanilla marginal                      0.90   \n",
       "0.95             robust  marginal                      0.95   \n",
       "                 vanilla marginal                      0.95   \n",
       "\n",
       "                                    certified_ratio_clean_proxy  \\\n",
       "nominal_coverage setup   risk_type                                \n",
       "0.90             robust  marginal                      0.364746   \n",
       "                 vanilla marginal                      0.505371   \n",
       "0.95             robust  marginal                      0.275391   \n",
       "                 vanilla marginal                      0.382812   \n",
       "\n",
       "                                    certified_ratio_pert_proxy  \\\n",
       "nominal_coverage setup   risk_type                               \n",
       "0.90             robust  marginal                     0.328125   \n",
       "                 vanilla marginal                     0.461426   \n",
       "0.95             robust  marginal                     0.220215   \n",
       "                 vanilla marginal                     0.340332   \n",
       "\n",
       "                                    certified_ratio_clean  \\\n",
       "nominal_coverage setup   risk_type                          \n",
       "0.90             robust  marginal                0.638184   \n",
       "                 vanilla marginal                0.646484   \n",
       "0.95             robust  marginal                0.643555   \n",
       "                 vanilla marginal                0.637695   \n",
       "\n",
       "                                    certified_ratio_pert  \\\n",
       "nominal_coverage setup   risk_type                         \n",
       "0.90             robust  marginal               0.591309   \n",
       "                 vanilla marginal               0.602539   \n",
       "0.95             robust  marginal               0.594727   \n",
       "                 vanilla marginal               0.592773   \n",
       "\n",
       "                                    conditional_validity  marginal_validity  \\\n",
       "nominal_coverage setup   risk_type                                            \n",
       "0.90             robust  marginal               0.976074           0.976074   \n",
       "                 vanilla marginal               0.949707           0.949707   \n",
       "0.95             robust  marginal               0.990234           0.990234   \n",
       "                 vanilla marginal               0.979492           0.979492   \n",
       "\n",
       "                                    proxy_label_acc_pert  \\\n",
       "nominal_coverage setup   risk_type                         \n",
       "0.90             robust  marginal               0.939941   \n",
       "                 vanilla marginal               0.947266   \n",
       "0.95             robust  marginal               0.943848   \n",
       "                 vanilla marginal               0.944824   \n",
       "\n",
       "                                    conditional_validity_pert  \\\n",
       "nominal_coverage setup   risk_type                              \n",
       "0.90             robust  marginal                    0.982910   \n",
       "                 vanilla marginal                    0.951172   \n",
       "0.95             robust  marginal                    0.990234   \n",
       "                 vanilla marginal                    0.978027   \n",
       "\n",
       "                                    marginal_validity_pert  acceptable_risk  \\\n",
       "nominal_coverage setup   risk_type                                            \n",
       "0.90             robust  marginal                 0.982910             0.10   \n",
       "                 vanilla marginal                 0.951172             0.10   \n",
       "0.95             robust  marginal                 0.990234             0.05   \n",
       "                 vanilla marginal                 0.978027             0.05   \n",
       "\n",
       "                                    acceptable_lambd  \n",
       "nominal_coverage setup   risk_type                    \n",
       "0.90             robust  marginal           0.812341  \n",
       "                 vanilla marginal           0.660024  \n",
       "0.95             robust  marginal           0.909685  \n",
       "                 vanilla marginal           0.797353  "
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-group column means, grouped by nominal coverage, setup and risk type.\n",
    "# Kept as the cell's final expression so the resulting frame is displayed.\n",
    "(\n",
    "    results_non_convex_df\n",
    "    .groupby([\"nominal_coverage\", \"setup\", \"risk_type\"])\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-trial results table as CSV (row index omitted).\n",
    "# Wrap output_dir in Path so the `/` join works whether it is a plain str\n",
    "# (as set in the config cell) or already a pathlib.Path; `str / str` raises TypeError.\n",
    "results_csv_path = pathlib.Path(output_dir) / f\"beta_marginal_binary_certificate_proxy_r{r}_sigma{smoothing_sigma}_n{n_samples}.csv\"\n",
    "# Create the target directory if needed so a fresh Restart & Run All does not fail on a missing folder.\n",
    "results_csv_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "results_non_convex_df.to_csv(results_csv_path, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
