{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 234,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "from scipy.stats import norm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pathlib\n",
    "import os\n",
    "\n",
    "\n",
    "from bin_cp.helpers.storage import load_smooth_prediction\n",
    "from bin_cp.helpers.tensor import get_smooth_scores, get_cal_mask, quantization_pdf, bound_tensor\n",
    "from bin_cp.robust.confidence import bernstein_bound, dkw_cdf\n",
    "from bin_cp.robust.confidence import clopper_pearson_lower\n",
    "from bin_cp.robust.bounds import mean_bounds_l2, CDF_bounds_l2\n",
    "\n",
    "from bin_cp.cp.core import ConformalClassifier as CP\n",
    "from bin_cp.cp.scores import APSScore, TPSScore, LogitScore\n",
    "\n",
    "from bin_cp.methods.robust_cp import RobustCP, VanillaSmoothCP\n",
    "from bin_cp.methods.cas import CAS\n",
    "from bin_cp.methods.bincp import BinCP\n",
    "from bin_cp.methods.rcp_one import RCP1\n",
    "# from qrcp.methods.binary import QRCPThresholds\n",
    "import time\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 235,
   "metadata": {},
   "outputs": [],
   "source": [
    "#region primary configs of the experiment\n",
    "\n",
    "# Base results folder; a later cell appends `dataset_name` to it.\n",
    "output_dir = \"../../results/\"\n",
    "\n",
    "# --- data / smoothing ---\n",
    "dataset_name = \"cifar10\"\n",
    "n_classes = 10\n",
    "n_datapoints = 2048\n",
    "model_sigma = 0.25\n",
    "smoothing_sigma = 0.25\n",
    "n_samples = 10_000\n",
    "n_trial_samples = 10_000  # scores are truncated to this many samples per datapoint\n",
    "\n",
    "# --- conformal setup ---\n",
    "score_method = \"TPS\"  # \"TPS\" selects TPSScore, anything else APSScore\n",
    "calibration_budget = 0.1  # forwarded to get_cal_mask\n",
    "n_iterations = 100\n",
    "confidence = 0.999  # passed as confidence_level to CAS and BinCP\n",
    "\n",
    "# --- sweep grids (shorter alternatives for quick runs in the trailing comments) ---\n",
    "coverage_range = [0.85, 0.9, 0.95]  # quick: [0.9]\n",
    "r_range = [0.06, 0.12, 0.18, 0.25, 0.37, 0.5, 0.75]  # quick: [0.12, 0.25, 0.5]\n",
    "\n",
    "#endregion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the per-dataset results folder and make sure it exists.\n",
    "# The guard keeps this cell idempotent: re-running it must not append\n",
    "# `dataset_name` a second time (e.g. results/cifar10/cifar10).\n",
    "output_dir = pathlib.Path(output_dir)\n",
    "if output_dir.name != dataset_name:\n",
    "    output_dir = output_dir / dataset_name\n",
    "output_dir.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input locations (models, datasets, cached logits), relative to the working directory.\n",
    "models_dir, dataset_dir, logits_dir = (\n",
    "    pathlib.Path(location)\n",
    "    for location in (\"./Robust_CP/models/\", \"./datasets/\", \"./Robust_CP/logits/\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 238,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading cifar10 dataset with 2048 datapoints and 10000 samples: Score method: TPS\n"
     ]
    }
   ],
   "source": [
    "#region loading smooth logit predictions\n",
    "# The scores are truncated to `n_trial_samples` below and later cells index them with\n",
    "# `torch.randperm(n_trial_samples)`, so the value is capped at `n_samples`\n",
    "# (assuming the loaded logits hold `n_samples` samples per datapoint - TODO confirm).\n",
    "if n_samples < n_trial_samples:\n",
    "    n_trial_samples = n_samples\n",
    "    print(f\"Number of trial samples is set to {n_trial_samples} as it cannot exceed the number of samples.\")\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "try:\n",
    "    smooth_prediction = load_smooth_prediction(\n",
    "        dataset_name=dataset_name,\n",
    "        model_sigma=model_sigma,\n",
    "        n_datapoints=n_datapoints,\n",
    "        smoothing_sigma=smoothing_sigma,\n",
    "        n_samples=n_samples,\n",
    "        models_dir=models_dir,\n",
    "        dataset_dir=dataset_dir,\n",
    "        logits_dir=logits_dir,\n",
    "    )\n",
    "except FileNotFoundError as e:\n",
    "    print(\"Smooth predictions not found, you can generate them using bin/smooth_logits_clean.py\")\n",
    "    print(\"Full description of the error: \", e)\n",
    "    # Everything below needs `smooth_prediction`; stop here with the real cause\n",
    "    # instead of failing a few lines later with an unrelated NameError.\n",
    "    raise\n",
    "\n",
    "# CIFAR-10 has 10 classes; for any other dataset keep the value from the config cell\n",
    "# rather than overwriting it with None.\n",
    "if dataset_name == \"cifar10\":\n",
    "    n_classes = 10\n",
    "#endregion\n",
    "\n",
    "# region computing the conformal scores.\n",
    "score_pipeline = [\n",
    "    TPSScore(softmax=True) if score_method == \"TPS\" else APSScore(softmax=True)] # defining the score function\n",
    "cp = CP(score_pipeline=score_pipeline, coverage_guarantee=0.9) # the guarantee can vary later by cp.coverage_guarantee\n",
    "smooth_scores = get_smooth_scores(smooth_prediction.logits, cp, mean=False)\n",
    "# Keep only the first `n_trial_samples` samples along the last axis.\n",
    "smooth_scores = smooth_scores[:, :, :n_trial_samples]\n",
    "# One-hot mask of the true label, sized by `n_classes` instead of a hard-coded 10.\n",
    "y_true_mask = F.one_hot(smooth_prediction.y_true, num_classes=n_classes).bool().to(device)\n",
    "print(f\"Loading {dataset_name} dataset with {n_datapoints} datapoints and {n_samples} samples: Score method: {score_method}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 337,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single (r, coverage) setting examined in the cells below.\n",
    "r, coverage_guarantee = 0.12, 0.9\n",
    "# Number of samples per datapoint handed to one calibration / prediction call.\n",
    "internal_sample_rate = 150\n",
    "\n",
    "# Calibration uses the first `internal_sample_rate` samples along the last axis.\n",
    "smooth_scores_cal = smooth_scores[..., :internal_sample_rate]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 338,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running the experiment for r=0.12 and coverage=0.9\n"
     ]
    }
   ],
   "source": [
    "# Calibration / evaluation split. The mask is computed exactly once so that\n",
    "# `n_dcal`, `cal_mask` and `eval_mask` all describe the same split.\n",
    "cal_mask = get_cal_mask(smooth_scores, calibration_budget)\n",
    "eval_mask = ~cal_mask\n",
    "n_dcal = cal_mask.sum().item()\n",
    "\n",
    "# One instance per method, all built for the same r / coverage / calibration-set size.\n",
    "# NOTE(review): scheme=\"guass\" is passed through verbatim; confirm this is the spelling BinCP expects.\n",
    "cp_methods = {\n",
    "    \"CAS\": CAS(smoothing_sigma=smoothing_sigma, confidence_level=confidence, n_dcal=n_dcal, n_classes=n_classes, nominal_coverage=coverage_guarantee, r=r),\n",
    "    \"BinCP\": BinCP(smoothing_sigma=smoothing_sigma, confidence_level=confidence, n_dcal=n_dcal, n_classes=n_classes,\n",
    "                p_base=0.8, scheme=\"guass\", nominal_coverage=coverage_guarantee, r=r),\n",
    "    \"RCP1\": RCP1(smoothing_sigma=smoothing_sigma, n_dcal=n_dcal, n_classes=n_classes, nominal_coverage=coverage_guarantee, r=r),\n",
    "}\n",
    "\n",
    "print(f\"Running the experiment for r={r} and coverage={coverage_guarantee}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 339,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:12<00:00, 81.51it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([1000, 1844, 10])"
      ]
     },
     "execution_count": 339,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "method_name = \"BinCP\"\n",
    "method = cp_methods[method_name]\n",
    "\n",
    "# Calibrate once on the calibration split (first `internal_sample_rate` samples).\n",
    "threshold = method.calibrate_from_scores(smooth_scores_cal[cal_mask], smooth_prediction.y_true[cal_mask])\n",
    "\n",
    "# Not read in this cell; kept in case other cells rely on it.\n",
    "random_iterations_max = n_trial_samples // internal_sample_rate\n",
    "\n",
    "# Number of independent re-draws of the sample subset.\n",
    "n_random_draws = 1000\n",
    "# Draw indices only among samples that actually exist along the last axis.\n",
    "n_available = smooth_scores.shape[-1]\n",
    "# Dedicated seeded generator: the drawn subsets are reproducible and independent of\n",
    "# how much of the global RNG stream earlier cells consumed.\n",
    "draw_rng = torch.Generator().manual_seed(0)\n",
    "\n",
    "pred_set_groups = []\n",
    "for i in tqdm(range(n_random_draws)):\n",
    "    sample_idx = torch.randperm(n_available, generator=draw_rng)[:internal_sample_rate]\n",
    "    smooth_scores_eval = smooth_scores[:, :, sample_idx]\n",
    "    pred_set = method.predict_from_scores(smooth_scores_eval[eval_mask], return_scores=False)\n",
    "    # Per-point indicator that the true label is in the set (only the last draw's value survives the loop).\n",
    "    covered = (pred_set[torch.arange(pred_set.shape[0]), smooth_prediction.y_true[eval_mask]])\n",
    "    pred_set_groups.append(pred_set)\n",
    "# Stack the per-draw sets along a new leading axis.\n",
    "pred_set_groups = torch.stack(pred_set_groups, dim=0)\n",
    "pred_set_groups.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 340,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1813, 0.0000, 0.4708,  ..., 0.3611, 0.4903, 0.0706],\n",
       "        [0.4046, 0.0000, 0.3964,  ..., 0.4841, 0.0000, 0.4467],\n",
       "        [0.1435, 0.0000, 0.0000,  ..., 0.4955, 0.0000, 0.0000],\n",
       "        ...,\n",
       "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 340,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Order evaluation points by the variance of their set size across draws (largest first).\n",
    "# `sorted_indices` is reused by later cells to keep the same row order.\n",
    "bincp_set_sizes = pred_set_groups.float().sum(-1)\n",
    "sorted_indices = bincp_set_sizes.var(0).sort(0, descending=True).indices\n",
    "bincp_sorted_pred_set_groups = pred_set_groups.index_select(1, sorted_indices)\n",
    "# Std over draws of each (point, class) membership indicator.\n",
    "bincp_sorted_pred_set_groups.float().std(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 341,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per evaluation point: fraction of draws that produced each set size (0..n_classes).\n",
    "bin_cp_set_size_probs = (\n",
    "    F.one_hot(pred_set_groups.sum(dim=-1), num_classes=n_classes + 1)\n",
    "    .float()\n",
    "    .mean(dim=0)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 342,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:00<00:00, 7940.75it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([1000, 1844, 10])"
      ]
     },
     "execution_count": 342,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "method_name = \"RCP1\"\n",
    "method = cp_methods[method_name]\n",
    "\n",
    "# Calibrate once on the calibration split (first `internal_sample_rate` samples).\n",
    "threshold = method.calibrate_from_scores(smooth_scores_cal[cal_mask], smooth_prediction.y_true[cal_mask])\n",
    "# Not read in this cell; kept in case other cells rely on it.\n",
    "random_iterations_max = n_trial_samples // internal_sample_rate\n",
    "\n",
    "# One prediction per individual sample index; capped so that indexing the last axis\n",
    "# cannot go out of range when fewer than 1000 samples are available.\n",
    "n_single_sample_runs = min(1000, smooth_scores.shape[-1])\n",
    "\n",
    "pred_set_groups = []\n",
    "for i in tqdm(range(n_single_sample_runs)):\n",
    "    smooth_scores_eval = smooth_scores[:, :, i]\n",
    "    pred_set = method.predict_from_scores(smooth_scores_eval[eval_mask], return_scores=False)\n",
    "    # Per-point indicator that the true label is in the set (only the last run's value survives the loop).\n",
    "    covered = (pred_set[torch.arange(pred_set.shape[0]), smooth_prediction.y_true[eval_mask]])\n",
    "    pred_set_groups.append(pred_set)\n",
    "# Stack the per-run sets along a new leading axis.\n",
    "pred_set_groups = torch.stack(pred_set_groups, dim=0)\n",
    "pred_set_groups.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 343,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the method explicitly instead of using whichever `method` the previous cell left behind.\n",
    "# NOTE(review): the instance is assumed to have been calibrated by the RCP1 cell above.\n",
    "method_name = \"RCP1\"\n",
    "method = cp_methods[method_name]\n",
    "\n",
    "# Number of independent re-draws of the sample subset.\n",
    "n_random_draws = 1000\n",
    "# Draw indices only among samples that actually exist along the last axis.\n",
    "n_available = smooth_scores.shape[-1]\n",
    "# Dedicated seeded generator so the drawn subsets are reproducible.\n",
    "draw_rng = torch.Generator().manual_seed(1)\n",
    "\n",
    "pred_set_mean_groups = []\n",
    "for i in range(n_random_draws):\n",
    "    sample_idx = torch.randperm(n_available, generator=draw_rng)[:internal_sample_rate]\n",
    "    # Average the scores over the drawn samples before predicting.\n",
    "    smooth_scores_eval = smooth_scores[:, :, sample_idx].mean(-1)\n",
    "    pred_set_mean = method.predict_from_scores(smooth_scores_eval[eval_mask], return_scores=False)\n",
    "    # Row count comes from `pred_set_mean` itself, not from the stale `pred_set` of an earlier cell.\n",
    "    covered = (pred_set_mean[torch.arange(pred_set_mean.shape[0]), smooth_prediction.y_true[eval_mask]])\n",
    "    pred_set_mean_groups.append(pred_set_mean)\n",
    "# Stack the per-draw sets along a new leading axis.\n",
    "pred_set_mean_groups = torch.stack(pred_set_mean_groups, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 344,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([3.1432, 3.1367, 3.1198, 3.1274, 3.1242, 3.1188, 3.1410, 3.1253, 3.1269,\n",
       "        3.1405, 3.1426, 3.1345, 3.1258, 3.1323, 3.1405, 3.1410, 3.1307, 3.1383,\n",
       "        3.1264, 3.1388, 3.1302, 3.1264, 3.1388, 3.1302, 3.1329, 3.1220, 3.1253,\n",
       "        3.1274, 3.1464, 3.1356, 3.1372, 3.1318, 3.1247, 3.1166, 3.1361, 3.1247,\n",
       "        3.1415, 3.1361, 3.1453, 3.1388, 3.1280, 3.1405, 3.1350, 3.1405, 3.1312,\n",
       "        3.1264, 3.1415, 3.1209, 3.1264, 3.1258, 3.1209, 3.1302, 3.1215, 3.1410,\n",
       "        3.1274, 3.1296, 3.1269, 3.1274, 3.1345, 3.1302, 3.1117, 3.1264, 3.1280,\n",
       "        3.1350, 3.1334, 3.1383, 3.1350, 3.1269, 3.1253, 3.1377, 3.1350, 3.1361,\n",
       "        3.1269, 3.1329, 3.1280, 3.1329, 3.1247, 3.1356, 3.1356, 3.1437, 3.1410,\n",
       "        3.1226, 3.1307, 3.1383, 3.1280, 3.1318, 3.1204, 3.1264, 3.1394, 3.1350,\n",
       "        3.1318, 3.1367, 3.1339, 3.1345, 3.1264, 3.1377, 3.1302, 3.1323, 3.1318,\n",
       "        3.1361, 3.1307, 3.1339, 3.1296, 3.1215, 3.1302, 3.1285, 3.1377, 3.1285,\n",
       "        3.1318, 3.1291, 3.1312, 3.1415, 3.1323, 3.1410, 3.1350, 3.1377, 3.1323,\n",
       "        3.1215, 3.1198, 3.1345, 3.1312, 3.1285, 3.1166, 3.1302, 3.1388, 3.1345,\n",
       "        3.1405, 3.1339, 3.1345, 3.1274, 3.1323, 3.1285, 3.1361, 3.1204, 3.1394,\n",
       "        3.1236, 3.1399, 3.1345, 3.1405, 3.1377, 3.1274, 3.1285, 3.1285, 3.1367,\n",
       "        3.1356, 3.1329, 3.1356, 3.1329, 3.1226, 3.1367, 3.1318, 3.1226, 3.1345,\n",
       "        3.1350, 3.1361, 3.1405, 3.1253, 3.1318, 3.1415, 3.1432, 3.1247, 3.1253,\n",
       "        3.1231, 3.1329, 3.1247, 3.1258, 3.1318, 3.1307, 3.1345, 3.1307, 3.1405,\n",
       "        3.1285, 3.1264, 3.1377, 3.1421, 3.1323, 3.1361, 3.1253, 3.1253, 3.1280,\n",
       "        3.1258, 3.1258, 3.1361, 3.1307, 3.1269, 3.1377, 3.1372, 3.1302, 3.1394,\n",
       "        3.1345, 3.1372, 3.1383, 3.1258, 3.1269, 3.1323, 3.1209, 3.1323, 3.1274,\n",
       "        3.1345, 3.1372, 3.1280, 3.1388, 3.1269, 3.1231, 3.1361, 3.1361, 3.1491,\n",
       "        3.1329, 3.1226, 3.1231, 3.1339, 3.1291, 3.1529, 3.1410, 3.1236, 3.1307,\n",
       "        3.1312, 3.1231, 3.1329, 3.1415, 3.1388, 3.1361, 3.1226, 3.1296, 3.1312,\n",
       "        3.1296, 3.1236, 3.1350, 3.1323, 3.1388, 3.1367, 3.1296, 3.1307, 3.1258,\n",
       "        3.1285, 3.1274, 3.1291, 3.1302, 3.1323, 3.1285, 3.1361, 3.1302, 3.1367,\n",
       "        3.1242, 3.1264, 3.1264, 3.1226, 3.1291, 3.1285, 3.1296, 3.1215, 3.1307,\n",
       "        3.1242, 3.1323, 3.1220, 3.1318, 3.1367, 3.1432, 3.1377, 3.1415, 3.1236,\n",
       "        3.1350, 3.1253, 3.1285, 3.1329, 3.1329, 3.1285, 3.1345, 3.1242, 3.1323,\n",
       "        3.1318, 3.1258, 3.1209, 3.1367, 3.1339, 3.1356, 3.1334, 3.1356, 3.1383,\n",
       "        3.1323, 3.1307, 3.1302, 3.1345, 3.1231, 3.1443, 3.1307, 3.1285, 3.1302,\n",
       "        3.1247, 3.1280, 3.1329, 3.1258, 3.1329, 3.1388, 3.1361, 3.1399, 3.1367,\n",
       "        3.1394, 3.1312, 3.1285, 3.1350, 3.1307, 3.1339, 3.1312, 3.1323, 3.1453,\n",
       "        3.1470, 3.1253, 3.1302, 3.1291, 3.1274, 3.1339, 3.1426, 3.1388, 3.1356,\n",
       "        3.1405, 3.1302, 3.1312, 3.1323, 3.1345, 3.1247, 3.1388, 3.1459, 3.1318,\n",
       "        3.1377, 3.1367, 3.1274, 3.1312, 3.1253, 3.1394, 3.1318, 3.1302, 3.1307,\n",
       "        3.1312, 3.1274, 3.1350, 3.1318, 3.1247, 3.1296, 3.1264, 3.1426, 3.1312,\n",
       "        3.1280, 3.1285, 3.1329, 3.1318, 3.1405, 3.1350, 3.1209, 3.1247, 3.1410,\n",
       "        3.1361, 3.1334, 3.1182, 3.1280, 3.1361, 3.1345, 3.1269, 3.1318, 3.1285,\n",
       "        3.1339, 3.1356, 3.1302, 3.1318, 3.1307, 3.1372, 3.1247, 3.1242, 3.1437,\n",
       "        3.1274, 3.1269, 3.1318, 3.1383, 3.1231, 3.1377, 3.1345, 3.1383, 3.1448,\n",
       "        3.1302, 3.1318, 3.1285, 3.1285, 3.1426, 3.1350, 3.1269, 3.1291, 3.1226,\n",
       "        3.1247, 3.1388, 3.1198, 3.1339, 3.1264, 3.1453, 3.1426, 3.1269, 3.1269,\n",
       "        3.1215, 3.1405, 3.1285, 3.1377, 3.1323, 3.1513, 3.1405, 3.1177, 3.1405,\n",
       "        3.1339, 3.1367, 3.1443, 3.1312, 3.1296, 3.1318, 3.1274, 3.1291, 3.1296,\n",
       "        3.1345, 3.1339, 3.1323, 3.1383, 3.1264, 3.1258, 3.1372, 3.1302, 3.1421,\n",
       "        3.1258, 3.1307, 3.1518, 3.1280, 3.1285, 3.1350, 3.1193, 3.1296, 3.1318,\n",
       "        3.1182, 3.1312, 3.1269, 3.1247, 3.1285, 3.1312, 3.1296, 3.1204, 3.1285,\n",
       "        3.1361, 3.1285, 3.1394, 3.1388, 3.1350, 3.1334, 3.1291, 3.1258, 3.1171,\n",
       "        3.1437, 3.1307, 3.1367, 3.1296, 3.1426, 3.1350, 3.1329, 3.1361, 3.1334,\n",
       "        3.1383, 3.1269, 3.1372, 3.1258, 3.1405, 3.1367, 3.1198, 3.1339, 3.1323,\n",
       "        3.1383, 3.1280, 3.1334, 3.1312, 3.1497, 3.1226, 3.1329, 3.1323, 3.1318,\n",
       "        3.1307, 3.1361, 3.1307, 3.1372, 3.1329, 3.1302, 3.1302, 3.1432, 3.1345,\n",
       "        3.1285, 3.1307, 3.1253, 3.1339, 3.1318, 3.1415, 3.1296, 3.1372, 3.1367,\n",
       "        3.1405, 3.1421, 3.1291, 3.1247, 3.1470, 3.1242, 3.1329, 3.1356, 3.1258,\n",
       "        3.1356, 3.1345, 3.1274, 3.1383, 3.1361, 3.1339, 3.1312, 3.1274, 3.1291,\n",
       "        3.1323, 3.1345, 3.1166, 3.1405, 3.1372, 3.1285, 3.1291, 3.1394, 3.1421,\n",
       "        3.1226, 3.1339, 3.1264, 3.1302, 3.1426, 3.1350, 3.1410, 3.1280, 3.1253,\n",
       "        3.1399, 3.1405, 3.1383, 3.1318, 3.1443, 3.1253, 3.1388, 3.1226, 3.1415,\n",
       "        3.1296, 3.1296, 3.1307, 3.1323, 3.1318, 3.1350, 3.1323, 3.1345, 3.1318,\n",
       "        3.1350, 3.1242, 3.1236, 3.1307, 3.1307, 3.1285, 3.1296, 3.1198, 3.1323,\n",
       "        3.1307, 3.1329, 3.1350, 3.1220, 3.1372, 3.1198, 3.1399, 3.1171, 3.1334,\n",
       "        3.1334, 3.1302, 3.1318, 3.1350, 3.1415, 3.1258, 3.1421, 3.1421, 3.1215,\n",
       "        3.1399, 3.1296, 3.1350, 3.1356, 3.1253, 3.1383, 3.1323, 3.1312, 3.1182,\n",
       "        3.1383, 3.1399, 3.1253, 3.1415, 3.1312, 3.1513, 3.1329, 3.1291, 3.1220,\n",
       "        3.1339, 3.1242, 3.1253, 3.1166, 3.1296, 3.1318, 3.1334, 3.1253, 3.1350,\n",
       "        3.1339, 3.1280, 3.1323, 3.1437, 3.1264, 3.1312, 3.1258, 3.1318, 3.1302,\n",
       "        3.1367, 3.1274, 3.1399, 3.1356, 3.1302, 3.1242, 3.1264, 3.1415, 3.1350,\n",
       "        3.1334, 3.1231, 3.1302, 3.1296, 3.1209, 3.1475, 3.1269, 3.1312, 3.1377,\n",
       "        3.1464, 3.1405, 3.1215, 3.1339, 3.1231, 3.1415, 3.1318, 3.1296, 3.1177,\n",
       "        3.1285, 3.1377, 3.1405, 3.1350, 3.1258, 3.1350, 3.1323, 3.1334, 3.1339,\n",
       "        3.1231, 3.1296, 3.1264, 3.1253, 3.1291, 3.1334, 3.1274, 3.1258, 3.1339,\n",
       "        3.1394, 3.1464, 3.1307, 3.1307, 3.1220, 3.1285, 3.1399, 3.1367, 3.1318,\n",
       "        3.1329, 3.1345, 3.1280, 3.1345, 3.1312, 3.1269, 3.1367, 3.1258, 3.1388,\n",
       "        3.1307, 3.1307, 3.1334, 3.1361, 3.1269, 3.1302, 3.1280, 3.1372, 3.1204,\n",
       "        3.1415, 3.1302, 3.1264, 3.1307, 3.1253, 3.1361, 3.1280, 3.1329, 3.1345,\n",
       "        3.1329, 3.1329, 3.1242, 3.1312, 3.1334, 3.1329, 3.1258, 3.1367, 3.1220,\n",
       "        3.1334, 3.1247, 3.1312, 3.1258, 3.1291, 3.1269, 3.1329, 3.1383, 3.1253,\n",
       "        3.1302, 3.1399, 3.1247, 3.1247, 3.1291, 3.1356, 3.1448, 3.1307, 3.1307,\n",
       "        3.1312, 3.1280, 3.1296, 3.1264, 3.1339, 3.1302, 3.1339, 3.1388, 3.1421,\n",
       "        3.1312, 3.1329, 3.1426, 3.1334, 3.1372, 3.1410, 3.1339, 3.1307, 3.1377,\n",
       "        3.1269, 3.1356, 3.1372, 3.1302, 3.1198, 3.1383, 3.1372, 3.1302, 3.1312,\n",
       "        3.1280, 3.1334, 3.1372, 3.1296, 3.1247, 3.1372, 3.1236, 3.1226, 3.1307,\n",
       "        3.1339, 3.1274, 3.1339, 3.1388, 3.1361, 3.1383, 3.1264, 3.1323, 3.1339,\n",
       "        3.1459, 3.1432, 3.1193, 3.1372, 3.1377, 3.1280, 3.1280, 3.1285, 3.1285,\n",
       "        3.1274, 3.1410, 3.1247, 3.1280, 3.1285, 3.1345, 3.1361, 3.1329, 3.1291,\n",
       "        3.1345, 3.1312, 3.1307, 3.1329, 3.1361, 3.1383, 3.1296, 3.1269, 3.1285,\n",
       "        3.1307, 3.1421, 3.1264, 3.1383, 3.1323, 3.1334, 3.1367, 3.1253, 3.1307,\n",
       "        3.1334, 3.1339, 3.1296, 3.1307, 3.1215, 3.1367, 3.1350, 3.1372, 3.1302,\n",
       "        3.1312, 3.1388, 3.1356, 3.1329, 3.1177, 3.1166, 3.1377, 3.1220, 3.1269,\n",
       "        3.1356, 3.1280, 3.1264, 3.1296, 3.1356, 3.1296, 3.1383, 3.1356, 3.1356,\n",
       "        3.1209, 3.1415, 3.1264, 3.1280, 3.1426, 3.1215, 3.1296, 3.1296, 3.1367,\n",
       "        3.1312, 3.1274, 3.1291, 3.1302, 3.1334, 3.1329, 3.1253, 3.1448, 3.1356,\n",
       "        3.1372, 3.1334, 3.1415, 3.1367, 3.1307, 3.1188, 3.1193, 3.1269, 3.1437,\n",
       "        3.1394, 3.1459, 3.1383, 3.1329, 3.1372, 3.1296, 3.1296, 3.1280, 3.1285,\n",
       "        3.1258, 3.1405, 3.1323, 3.1253, 3.1399, 3.1350, 3.1361, 3.1291, 3.1329,\n",
       "        3.1285, 3.1323, 3.1323, 3.1334, 3.1307, 3.1394, 3.1426, 3.1291, 3.1329,\n",
       "        3.1334, 3.1361, 3.1367, 3.1394, 3.1296, 3.1274, 3.1350, 3.1426, 3.1318,\n",
       "        3.1312, 3.1399, 3.1220, 3.1307, 3.1198, 3.1264, 3.1274, 3.1231, 3.1318,\n",
       "        3.1312, 3.1334, 3.1312, 3.1291, 3.1377, 3.1318, 3.1182, 3.1421, 3.1247,\n",
       "        3.1258, 3.1226, 3.1410, 3.1318, 3.1323, 3.1285, 3.1226, 3.1291, 3.1296,\n",
       "        3.1296, 3.1399, 3.1247, 3.1220, 3.1242, 3.1399, 3.1312, 3.1231, 3.1258,\n",
       "        3.1383, 3.1312, 3.1296, 3.1426, 3.1302, 3.1453, 3.1280, 3.1323, 3.1253,\n",
       "        3.1350, 3.1182, 3.1318, 3.1318, 3.1258, 3.1220, 3.1426, 3.1291, 3.1361,\n",
       "        3.1356, 3.1394, 3.1367, 3.1399, 3.1280, 3.1226, 3.1220, 3.1198, 3.1269,\n",
       "        3.1367, 3.1193, 3.1372, 3.1188, 3.1291, 3.1323, 3.1318, 3.1334, 3.1323,\n",
       "        3.1318, 3.1296, 3.1285, 3.1350, 3.1226, 3.1296, 3.1226, 3.1204, 3.1448,\n",
       "        3.1399, 3.1307, 3.1242, 3.1312, 3.1448, 3.1329, 3.1307, 3.1339, 3.1345,\n",
       "        3.1339, 3.1307, 3.1253, 3.1323, 3.1285, 3.1356, 3.1329, 3.1410, 3.1432,\n",
       "        3.1367, 3.1264, 3.1198, 3.1291, 3.1356, 3.1247, 3.1307, 3.1307, 3.1367,\n",
       "        3.1312], device='cuda:0')"
      ]
     },
     "execution_count": 344,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Set size per (draw, point), averaged over the points: one value per draw.\n",
    "mean_set_size_per_draw = pred_set_mean_groups.sum(dim=-1).float().mean(dim=-1)\n",
    "mean_set_size_per_draw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 345,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.4882, 0.3877, 0.4903,  ..., 0.4991, 0.4935, 0.4986],\n",
       "        [0.4952, 0.0000, 0.4781,  ..., 0.4971, 0.0891, 0.4982],\n",
       "        [0.4957, 0.0000, 0.0000,  ..., 0.5002, 0.2805, 0.0447],\n",
       "        ...,\n",
       "        [0.0000, 0.0000, 0.0834,  ..., 0.1401, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0547, 0.0000,  ..., 0.0000, 0.0000, 0.3118],\n",
       "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 345,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reuse `sorted_indices` from the BinCP cell (deliberately not recomputed here), so row k\n",
    "# refers to the same evaluation point as in `bincp_sorted_pred_set_groups`.\n",
    "rcp1_sorted_pred_set_groups = pred_set_groups[:, sorted_indices]\n",
    "# Std over runs of each (point, class) membership indicator.\n",
    "rcp1_sorted_pred_set_groups.float().std(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 346,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per evaluation point: fraction of runs that produced each set size (0..n_classes).\n",
    "rcp1_set_size_probs = (\n",
    "    F.one_hot(pred_set_groups.sum(dim=-1), num_classes=n_classes + 1)\n",
    "    .float()\n",
    "    .mean(dim=0)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 347,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0316, 0.2201, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        ...,\n",
       "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0632],\n",
       "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 347,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same row order as the BinCP cell: `sorted_indices` is reused, not recomputed.\n",
    "rcp1_sorted_pred_set_mean_groups = pred_set_mean_groups[:, sorted_indices]\n",
    "# Std over draws of each (point, class) membership indicator.\n",
    "rcp1_sorted_pred_set_mean_groups.float().std(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 348,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per evaluation point: fraction of draws that produced each set size (0..n_classes).\n",
    "rcp1_mean_set_size_probs = (\n",
    "    F.one_hot(pred_set_mean_groups.sum(dim=-1), num_classes=n_classes + 1)\n",
    "    .float()\n",
    "    .mean(dim=0)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 349,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:ylabel='Density'>"
      ]
     },
     "execution_count": 349,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAD4CAYAAAD//dEpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAutUlEQVR4nO3de3xU9Z3/8dcn95Abl9wvXMMduQZEQQEVQcWilSJatbbusq662+62uz+3u71s99Juu9ttt7a1Wl1t6+LdSquIogKiIgTkEi7hTkhIyBWSEHKdz++PGWyMCZnAzJzJzOf5eMxjZs45M/NmOMkn53u+3+8RVcUYY4zpTYTTAYwxxvQPVjCMMcZ4xQqGMcYYr1jBMMYY4xUrGMYYY7wS5XQAX0pNTdXhw4c7HcMYY/qNbdu2VatqmjfbhlTBGD58OIWFhU7HMMaYfkNEjnu7rTVJGWOM8YoVDGOMMV6xgmGMMcYrVjCMMcZ4xQqGMcYYr1jBMMYY4xUrGMYYY7xiBcMYh3S47NICpn8JqYF7xgS7j47U8PP1h9lVepqG5nbGZSZx85Rs7r1yOHHRkU7HM+aCrGAYEwBtHS6+/WoRq7acIC0plhsmZZEcH8X243X8YM1+frf5OI/dXcCE7GSnoxrTI78VDBF5ElgCVKrqJM+y54Cxnk0GAqdVdWo3rz0GNAAdQLuqFvgrpzH+1truYuVvC1lfXMX980bxtetGf+po4v1D1Xz9+Z0s/9WHPH5PAVeMGuJgWmN65s9zGE8BizsvUNXbVXWqp0i8BLx8gdcv8GxrxcL0W6rKt35fxPriKv791st4+IZxn2l6mpOfyssPXElmShwrf1PIocoGh9Iac2F+KxiquhGo7W6diAiwHFjlr883Jhg881EJzxWe4KEF+dx5+dAet8seGM/TX5lFbHQkX3mqkDNNbQFMaYx3nOoldRVwSlUP9rBegTdFZJuIrLzQG4nIShEpFJHCqqoqnwc15mKV1DTxb6/t4+oxafztwjG9bp8zMJ7H7plB2elz/PMf9gQgoTF941TBuIMLH13MVdXpwA3AgyJydU8bqupjqlqgqgVpaV5N6W6M36kqD7+8i8gI4Qefv4yICPHqddOHDuLBBfm8/HEZa/dU+DmlMX0T8IIhIlHA54HnetpGVcs895XAK8CswKQzxjde213OB4drePiGcWQPjO/Tax9akM+ErGS+8+oemlrb/ZTQmL5z4gjjOmC/qpZ2t1JEEkQk6fxj4HqgKID5jLkkzW0d/GDNfsZnJXPHrJ7PW/QkJiqCf146kYr6Zh7dcMQPCY25OH4rGCKyCvgQGCsipSJyn2fVCro0R4lItoi87nmaAWwSkZ3AFuA1VX3DXzmN8bXffHiM0rpz/NNN44n0simqq5nDB7Nkcha/2nCYk6fP+TihMRfHb+MwVPWOHpbf282yk8CNnsdHgCn+ymWMPzW3dfDYxiNcNTqVOfmpl/ReD98wjrV7Knjk3UP8+62X+SihMRfP5pIyxodeKDxBdWMrDy7Iv+T3yh00gBUzh/L81hOcqG3yQTpjLo0VDGN8pL3Dxa82HmH60IFcPmKwT97zgQWjiIgQHnnnkE/ez5hLYQXDGB/5465ySuvO8Zfz83GPTb10WSnx3F6Qxysfl1HZ0OyT9zTmYlnBMMYHVJVfrj/MmIxErh2X7tP3/srcEbS5XPzuw+M+fV9j+soKhjE+8M7+SopPNXD/vFFeD9Lz1ojUBK4dl8HvPiqhua3Dp+9tTF9YwTDmEqkqv1h/mJyB8dw8Jdsvn3Hf3BHUnm3llY/L/PL+xnjDCoYxl2jrsTq2Ha9j5dUjiY70z4/U7JGDmZidzBObjqJqV+ozzrCCYcwl+sX6QwxJiGF5QZ7fPkNEuG/uCA5VNrLhgE2yaZxhBcOYS7D3ZD3ri6v4ytwRxMf49xKrSyZnk54UyxObjvr1c4zpiRUMYy7BLzccJjE2irtmD/P7
Z8VERXDX7GG8d7CakhobyGcCzwqGMRfpWPVZXtt1ki/OHkpKfHRAPvMLBblECDy7tSQgn2dMZ1YwjLlIj713hKjICO6bMyJgn5mVEs+Csem8sK2Utg5XwD7XGLCCYcxFqaxv5sXCUpbNyCU9OS6gn33HrKFUNbTw9r7KgH6uMVYwjLkIT7x/lHaXi7+4emTAP3v+2DQykmOtWcoEnBUMY/rozLk2ntlcwk2Tsxk2JCHgnx8VGcHtBXlsOFBFmV0rwwSQFQxj+uh/3z9KY0s7988L/NHFectnusd8PLf1hGMZTPixgmFMH5w518aTm45y/YQMJmanOJYjd9AArhqdxguFJ+hw2chvExhWMIzpg6feP0Z9czt/fe1op6OwYmYe5Weaee+gjfw2gWEFwxgv1Te38cSmIyyckMGkHOeOLs67dnw6gwZE80JhqdNRTJiwgmGMl572HF18NQiOLgBioyK5dVoub+6toPZsq9NxTBjwW8EQkSdFpFJEijot+66IlInIDs/txh5eu1hEikXkkIg87K+MxnirobmNX286ynXj04Pi6OK85TNzaetQXt1h054b//PnEcZTwOJulv+3qk713F7vulJEIoGfAzcAE4A7RGSCH3Ma06vHNx7hzLk2vnrtGKejfMq4zGQm56bw3NYTNu258Tu/FQxV3QjUXsRLZwGHVPWIqrYCzwJLfRrOmD44Vd/M4+8dZcnkLC7LDZ6ji/O+UJDH/ooGisrqnY5iQpwT5zAeEpFdniarQd2szwE6dy4v9SzrloisFJFCESmsqrLeIsb3/vutA7S7XPz9onFOR+nW56ZkExsVwfOFNibD+FegC8YvgVHAVKAc+K9LfUNVfUxVC1S1IC0t7VLfzphPOXCqgecLT3D37OEMHTLA6TjdSomPZvGkTF7dUWbX/DZ+FdCCoaqnVLVDVV3A47ibn7oqAzpfuizXs8yYgPvBmv0kxEbxV9fkOx3lgm4vyKO+uZ21eyqcjmJCWEALhohkdXp6K1DUzWZbgdEiMkJEYoAVwOpA5DOmsw8OV/PO/koemJ/PoIQYp+Nc0OyRQ8gdFG9jMoxf+bNb7SrgQ2CsiJSKyH3AD0Vkt4jsAhYAf+PZNltEXgdQ1XbgIWAtsA94XlX3+CunMd3pcCn/+sd95AyM58tzhjsdp1cREcIXZuSx6VA1J2rtanzGP6L89caqekc3i5/oYduTwI2dnr8OfKbLrTGB8tL2UvaW1/PTFVOJi/bvtbp95bYZOfzk7QO8uK2Uv1kYXN1/TWiwkd7GdHG2pZ3/XFvM1LyBfG5KttNxvJY7aABz81N5cVspLpuQ0PiBFQxjuvjVxiNUNrTwrSUTEBGn4/TJ8oI8yk6f44PDNU5HMSHICoYxnZSfOcdjGw+zZHIWM4Z1N0wouC2ckEFKfLSNyTB+YQXDmE5+9EYxLoX/tzg4B+n1Ji46klumZvPGngrONLU5HceEGCsYxnjsKj3Nyx+Xcd/cEeQNDs5Bet74QkEere0uXt1pw5eMb1nBMAZQdXejTU2M4YH5o5yOc0km5aQwISvZmqWMz1nBMAbYdKiaLcdq+eq1o0mKi3Y6ziVbXpBLUVk9e06ecTqKCSFWMEzYU1V+su4g2SlxLJ+Z1/sL+oFbpuUQExlhI7+NT1nBMGFv06Fqth2v4y8X5BMb1T8G6fVm4IAYrp+Ywe93lNHSbhMSGt+wgmHCmqry03UHyUqJY3lBrtNxfGp5QR6nm9p4a+8pp6OYEGEFw4S19w/VUHi8jgfmjwqZo4vz5uSnkp0Sx/PWLGV8xAqGCVuqyk/fPkBmcuicu+gsMkJYNiOX9w5WcfL0OafjmBBgBcOEre0ldWw9Vsf980aG3NHFeV8oyEMVXtpmRxnm0lnBMGHryU3HSI6L4gsFoXd0cV7e4AFcOWoIL9iEhMYHrGCYsFRa18SaonLuuHwoCbF+m+U/KCwvyKOktonNR21CQnNprGCYsPTbzccREe65YrjTUfxu
8aRMkuKieG6rjfw2l8YKhgk7bR0uXtpWyrXj0skZGO90HL+Li47ktum5rNldQXVji9NxTD9mBcOEnXf2V1Ld2MrtIdgzqid3zR5Ka4fL5pcyl8QKhgk7z289QXpSLPPGpDkdJWDy05O4YuQQntlcQoed/DYXyQqGCSsVZ5p5t7iSZTNyiYoMr93/7iuGUXb6HBsOVDodxfRTfvuJEZEnRaRSRIo6LfuRiOwXkV0i8oqIDOzhtcdEZLeI7BCRQn9lNOHnpe2luNTdcyjcLJyQQXpSLL/98LjTUUw/5c8/sZ4CFndZ9hYwSVUnAweAf7jA6xeo6lRVLfBTPhNmXC7l+cITzB45mOGpCU7HCbjoyAhWzBrK+gNVlNQ0OR3H9EN+KxiquhGo7bLsTVVt9zzdDITWbG8mqG0+WsPxmqawOtnd1R2z8ogQ4ZktdpRh+s7JRtyvAGt6WKfAmyKyTURWXuhNRGSliBSKSGFVVZXPQ5rQ8dK2MpLiorhhUpbTURyTlRLPwvEZPL/1BM1tNu256RtHCoaI/CPQDjzTwyZzVXU6cAPwoIhc3dN7qepjqlqgqgVpaeHT68X0TUt7B2/uqWDxxEziokNz3ihv3TV7GHVNbawpKnc6iulnAl4wROReYAnwRVXttn+fqpZ57iuBV4BZAQtoQtJ7B6ppaGnnpsnhe3Rx3pWjhjAyNcFOfps+C2jBEJHFwN8Dn1PVbs+6iUiCiCSdfwxcDxR1t60x3nptdzkp8dHMyU91OorjIiKEL84exvaS0xSV2TW/jff82a12FfAhMFZESkXkPuARIAl4y9Nl9lHPttki8rrnpRnAJhHZCWwBXlPVN/yV04S+5rYO3tp7isUTM4kOs7EXPVk2PZe46Ag7yjB94rdpOlX1jm4WP9HDtieBGz2PjwBT/JXLhJ+NB6potOaoT0kZEM2t03J45eMyvnnjeFIGRDsdyfQD9ueWCXmv7S5n0IBorhg1xOkoQeXu2cNpbnPxwjabX8p4xwqGCWnNbR2s23uKxZOyrDmqiwnZyRQMG8RvNx+3iysZr9hPkAlp64urONvawRJrjurWPVcO53hNExsO2hgm0zsrGCakvb67nCEJMVw+YrDTUYLS4omZpCba/FLGO1YwTMhqbXfxbnEl143PCLuZab0VExXBnbPyeLe40uaXMr2ynyITsrYeq6WhuZ3rJmQ4HSWo3Xn5MCJE+N1HdpRhLswKhglZb+09RWxUBHNtsN4FZabEsWhiBs8X2vxS5sKsYJiQpKqs23eKufmpxMeE99xR3rh79nBON7WxeudJp6OYIGYFw4Sk4lMNlNads+YoL80eOZgxGYn85sNj9DDFmzFWMExoWrf3FADXjkt3OEn/ICLcfcVwisrq+fjEaafjmCBlBcOEpLf2VTIlbyDpyXFOR+k3bp2WQ2JslHWxNT2ygmFCTmV9MztPnGbheDu66IvE2Chum57Da7vKqW5scTqOCUJWMEzIeXt/JYCdv7gId18xnNYOF89ttfmlzGdZwTAh5+19p8gZGM/YjCSno/Q7+emJzMkfwjObj9Pe4XI6jgkyXhUMEXlZRG4SESswJqg1t3Xw/qEarhmXjog4Hadfunv2cE6eaf7kSM2Y87wtAL8A7gQOisgPRGSsHzMZc9EKj9Vxrq2D+WPt+u4X67rx6WSnxNnJb/MZXhUMVV2nql8EpgPHgHUi8oGIfFlE7MorJmisL64kJjLCrn1xCaIiI/ji7GFsOlTNocpGp+OYIOJ1E5OIDAHuBf4M+Bj4Ke4C8pZfkhlzEdYfqOLykYMZEOO3i0mGhdtn5hETGcHvNttRhvkTb89hvAK8BwwAblbVz6nqc6r6V0CiPwMa463SuiYOVTYyb4w1R12q1MRYbpqcxUvbSjnb0u50HBMkvD3CeFxVJ6jq91W1HEBEYgFUtcBv6Yzpg/XF7osA2fkL37j7imE0tLTz+x1lTkcxQcLbgvGv3Sz7sLcXiciTIlIpIkWdlg0W
kbdE5KDnflAPr/2SZ5uDIvIlL3OaMLbhQBU5A+MZlWYHvb4wLW8g4zKTWLWlxOkoJkhcsGCISKaIzADiRWSaiEz33Objbp7qzVPA4i7LHgbeVtXRwNue510/dzDwHeByYBbwnZ4KizHgvljSB4eqmT82zbrT+oiIcOflQykqq2d36Rmn45gg0NsRxiLgP4Fc4MfAf3lufwt8s7c3V9WNQG2XxUuBpz2PnwZu6eFz31LVWlWtw31ivWvhMeYThcdqOdvawfyxNh2ILy2dmkNcdASrttpRhoELdiVR1aeBp0XkNlV9yUefmXH+PAhQAXQ3f0MO0HluglLPss8QkZXASoChQ4f6KKLpb9YfqCImMoIrrTutT6XER3PTZdms3nGSf7xxPAmx1vssnPXWJHWX5+FwEfnbrrdL/XB1T7x/SZPvq+pjqlqgqgVpaXayM1ytL65k5ohB9gvND+6YlUdjSzt/3GUXVwp3vTVJJXjuE4Gkbm4X45SIZAF47rubf6AMyOv0PNezzJjPOHn6HAdOWXdaf5kxbBCj0xNZtcUmJAx3vTVJ/cpz/88+/MzVwJeAH3juX+1mm7XAv3c60X098A8+zGBCyIYD57vT2vkLfxARVswayr/8cS97T9YzITvZ6UjGId4O3PuhiCSLSLSIvC0iVZ2aqy70ulW4u9+OFZFSEbkPd6FYKCIHges8zxGRAhH5NYCq1gL/Amz13L7nWWbMZ6wvriQ7JY7R6dad1l8+Py2HmKgInrWT32HN23EY16tqPbAE91xS+cDf9fYiVb1DVbNUNVpVc1X1CVWtUdVrVXW0ql53vhCoaqGq/lmn1z6pqvme2//2/Z9mwkFru4v3D9Uwb6zNTutPgxJiuGFSJq98XMa51g6n4xiHeFswzjdd3QS8oKrWKdsEhe0ldTS2tNvo7gC4Y9ZQGprbeW13ee8bm5DkbcH4o4jsB2YAb4tIGtDsv1jGeGfDgSqiIsS60wbA5SMGMzI1wUZ+hzFvpzd/GLgSKFDVNuAs7gF4xjhqQ3EVM4YNIinOZtn3NxHh9pl5bDtex+Eqm/Y8HPXlCnrjgNtF5B5gGe6eS8Y4prK+mb3l9cyz5qiAuXVaDhECL28vdTqKcYC3vaR+i3uKkLnATM/NZqk1jtp4sBrAxl8EUHpyHFePSePl7WV0uC5pzK3ph7wdFlsATPCMzDYmKGw8UEVaUiwTsmxcQCDdNj2Xv1r1MR8ermHu6FSn45gA8rZJqgjI9GcQY/qiw6W8d7CKq0anWnfaAFs4IYOkuChesmapsONtwUgF9orIWhFZff7mz2DGXMjusjPUNbVZc5QD4qIjuXlKNmuKymlobnM6jgkgb5ukvuvPEMb01YbiKkTgqtFWMJywbEYu//dRCWt2V7B8Zl7vLzAhwdtutRtwj/CO9jzeCmz3Yy5jLmjjwSom5w5kcEKM01HC0rS8gYxMTeBFa5YKK972kvpz4EXgV55FOcDv/ZTJmAs609TGxyV1zLMTro4REW6bkcuWo7WcqG1yOo4JEG/PYTwIzAHqAVT1IGBTgxpHbDpUjUux8RcO+9yUbAD+YNfJCBveFowWVW09/0REorjECx8Zc7E2HKgkOS6KKbkDnY4S1vIGD2D60IGs3mEFI1x4WzA2iMg3gXgRWQi8APzBf7GM6Z6qsvFANVeNTiMqsi8TFRh/WDo1h/0VDRw41eB0FBMA3v7EPQxUAbuBvwBeB/7JX6GM6cmBU41U1Ddz9Rg7fxEMbrwsiwjBjjLChLe9pFy4T3I/oKrLVPVxG/VtnLDhgPuKvlfb+IugkJYUy5z8VFbvPIn9Sgh9FywY4vZdEakGioFiz9X2vh2YeMZ82oYDVYzNSCIrJd7pKMbjc1OyKaltYseJ005HMX7W2xHG3+DuHTVTVQer6mDgcmCOiPyN39MZ00lTaztbj9ZZ76ggs2hSJjFREazeac1Soa63gnE3cIeqHj2/QFWPAHcB9/gzmDFdbT5SQ2uHi6tt
dHdQSY6L5pqx6fxhZ7nNYBvieisY0apa3XWhqlYBdsUaE1AbiquIj46kYPggp6OYLm6ekk11Ywtbj9U6HcX4UW8Fo/Ui1/VIRMaKyI5Ot3oR+VqXbeaLyJlO29g5E8OGA1VcMWoIcdGRTkcxXSwYl0ZcdARr7HrfIa23yQeniEh9N8sFiLuYD1TVYmAqgIhEAmXAK91s+p6qLrmYzzCh53jNWY7VNPHlOSOcjmK6MSAmigVj01lTVMF3bp5IRIRNOR+KLniEoaqRqprczS1JVX3RJHUtcFhVj/vgvUwI23CgCrDutMFs8aRMKhta2F5S53QU4ydOD5VdAazqYd0VIrJTRNaIyMSe3kBEVopIoYgUVlVV+Selcdw7+ysZkZrAiNQEp6OYHlwzLp2YqAhe313hdBTjJ44VDBGJAT6He5qRrrYDw1R1CvAzLjAzrqo+pqoFqlqQlmZ/fYaiptZ2Pjhcw4KxNt9lMEuKi+bq0WmsKSrHZb2lQpKTRxg3ANtV9VTXFapar6qNnsevA9EiYnNBhKn3D9XQ2u7i2vFWMILdjZdlUn6mmZ2lp52OYvzAyYJxBz00R4lIpngu1Cwis3DnrAlgNhNE3tlfSWJsFDOHD3Y6iunFteMziI4U1hRZs1QocqRgiEgCsBB4udOy+0Xkfs/TZUCRiOwE/gdYYXNXhSdV5d39lVw1OpWYKKdPuZnepMRHMzc/ldd3l9vcUiHI22t6+5SqngWGdFn2aKfHjwCPBDqXCT57y+upqG/mmnHWHNVf3DApi3df2sWek/VMyklxOo7xIfuTzQS1d/ZVIgLz7YR3v3HdhAwiBNbusWapUGMFwwS1t/dXMjl3IGlJsU5HMV4anBDD5SOG8Iadxwg5VjBM0KpubGFn6WmuteaofmfRxAwOVjZyuKrR6SjGh6xgmKC1vrgKVez8RT90/cRMwJqlQo0VDBO03txTQWZyHBOzk52OYvooe2A8U3JTWLvnM8OsTD9mBcMEpabWdjYcqGLRxAw8Q3JMP3P9xEx2njhN+ZlzTkcxPmIFwwSljQeqaGl3sWhSptNRzEVa5GmWetOOMkKGFQwTlN4oqmDQgGhm2ejufis/PZH89EQ7jxFCrGCYoNPa7uLt/ZVcNz6DqEjbRfuzRRMz+OhoLXVnL+p6aybI2E+jCTofHqmhobmdxdYc1e8tnphFh0tZt8+apUKBFQwTdN4oqiAhJpI5+TZBcX83KSeZnIHx1lsqRFjBMEGlw6W8tfcU88el27W7Q4CIsHBCBhsPVnG2pd3pOOYSWcEwQWV7SR3VjS0snmjNUaFi8aRMWttdn1xm1/RfVjBMUFmzu4KYyAjmj7WrJ4aKmcMHMzghxnpLhQArGCZodLiUP+w6yYJxaSTFRTsdx/hIZIRw3fh03tlXSWu7y+k45hJYwTBBY/ORGqoaWlg6NcfpKMbHFk/KpKGlnQ8OVzsdxVwCKxgmaLy6o4zE2CibbDAEXTkqlYSYSGuW6uesYJig0NzWwZrdFSyamGm9o0JQXHQkC8al89beU3S47NKt/ZUVDBMU1hdX0tDSztKp2U5HMX6yaGIm1Y2tbC+pczqKuUhWMExQeHXHSVITY7hy1JDeNzb90vyxacRERtiV+PoxxwqGiBwTkd0iskNECrtZLyLyPyJySER2ich0J3Ia/6tvbuPt/ZUsmZxtc0eFsKS4aOaOTmXtngpUrVmqP3L6p3OBqk5V1YJu1t0AjPbcVgK/DGgyEzBriypobXdZc1QYWDQxg9K6c+w5We90FHMRnC4YF7IU+I26bQYGikiW06GM7724rZRhQwYwNW+g01GMn103PoMIcV9N0fQ/ThYMBd4UkW0isrKb9TnAiU7PSz3LPkVEVopIoYgUVlXZ1AP9zeGqRj46WsvtM/PsynphYEhiLDOHD7bJCPspJwvGXFWdjrvp6UERufpi3kRVH1PVAlUtSEuz6ST6m2e3lBAVISybket0FBMgiyZmUnyqgaPVZ52O
YvrIsYKhqmWe+0rgFWBWl03KgLxOz3M9y0yIaGnv4KXtZSyckEF6UpzTcUyAnL/srg3i638cKRgikiAiSecfA9cDRV02Ww3c4+ktNRs4o6rlAY5q/OjNPaeoPdvKillDnY5iAihnYDyX5aRY99p+yKkjjAxgk4jsBLYAr6nqGyJyv4jc79nmdeAIcAh4HHjAmajGX1ZtKSFnYDxX2YWSws6iiRnsOHGaijPNTkcxfRDlxIeq6hFgSjfLH+30WIEHA5nLBM6x6rN8cLiGb1w/hogIO9kdbhZPyuQ/3zzAW3sruPuK4U7HMV4K5m61JoSt2lpCZITwhYK83jc2ISc/PYmRaQm8Yecx+hUrGCbgzra0s+qjEhaOzyAj2U52h6vFEzPZfKSW002tTkcxXrKCYQLu+cIT1De3s3LeSKejGActmphJh0tZt6/S6SjGS1YwTEC1d7h4YtNRCoYNYvrQQU7HMQ6anJtCVkqc9ZbqR6xgmIB6Y08FpXXn+POr7egi3IkIN0zKYuOBKmuW6iesYJiAUVUe33iEEakJXDc+w+k4JgjcNiOH1g4Xq3eedDqK8YIVDBMwW47WsrP0DPfNHUGkdaU1wMTsFMZnJfPitlKnoxgvWMEwAfPLDYcZnBDDbdNt3ijzJ8tm5LKr9AzFFQ1ORzG9sIJhAmLb8VrWF1dx39wRxMfYNbvNn9wyNZuoCOGl7XaUEeysYBi/U1V+tLaY1MQYvjxnuNNxTJAZkhjLNePSeXl7Ge0dLqfjmAuwgmH87v1DNWw+UsuDC/IZEOPIbDQmyC2bkUt1YwsbDtg1bYKZFQzjV6rKj94sJjsljjsvt1lpTfcWjEtnSEIMLxRas1Qws4Jh/Grdvkp2njjNX187mtgoO3dhuhcdGcEt03JYt+8UlfU2g22wsoJh/Katw8UP39jP8CEDuM2uqGd6cdfsYbS7lP/bUuJ0FNMDKxjGb57+4BgHKxv55o3jiY60Xc1c2IjUBOaNSeOZj0pobbeT38HIfoqNX1TWN/OTdQeZPzaNhRNsVLfxzr1XDqeqoYU1RXZxzWBkBcP4xffX7Ke13cV3bp6IiI3qNt6ZNyaNkakJPP7eEdzXUDPBxAqG8bktR2t55eMy/vzqEYxITXA6julHIiKElVePpKisng8O1zgdx3RhBcP4VEt7B9/6fRHZKXE8uCDf6TimH7plWg5pSbE8uuGw01FMF1YwjE/9dN1Bik818C+3TLJBeuaixEVHct/cEbx3sJptx+ucjmM6CXjBEJE8EXlXRPaKyB4R+Wo328wXkTMissNz+3agc5q+215Sx6MbDrO8IJdrbfpycwnunj2MIQkx/GTdAaejmE6cOMJoB76uqhOA2cCDIjKhm+3eU9Wpntv3AhvR9NW51g6+8fxOslLi+daS7v47jfFeQmwU988bxXsHq9lytNbpOMYj4AVDVctVdbvncQOwD8gJdA7jW//xxn6OVJ/lR8smkxQX7XQcEwLumj2M9KRYvr9mn/WYChKOnsMQkeHANOCjblZfISI7RWSNiEy8wHusFJFCESmsqrKJy5zw1t5TPPXBMe69cjhX5qc6HceEiPiYSL5+/Rg+LjnNa7ttXEYwcKxgiEgi8BLwNVWt77J6OzBMVacAPwN+39P7qOpjqlqgqgVpaWl+y2u6d6K2ia8/v4PLclL4hxvHOR3HhJhlM/IYl5nED9bsp7mtw+k4Yc+RgiEi0biLxTOq+nLX9apar6qNnsevA9EiYn+6Bpnmtg4eeGY7Cvz8zuk2uaDxucgI4ds3T6C07hy/WG/dbJ3mRC8pAZ4A9qnqj3vYJtOzHSIyC3dOG8UTZP71tb3sLjvDf31hCkOHDHA6jglRV45K5Zap2Ty6/jCHqxqdjhPWnDjCmAPcDVzTqdvsjSJyv4jc79lmGVAkIjuB/wFWqJ31Cir/91EJv9tcwsqrR3L9xEyn45gQ9483TSAuOoK/f3GXXZXPQQEfWaWqm4ALTi6kqo8AjwQm
kemrDw5X8+1Xi5g3Jo2/XzTW6TgmDKQlxfK9pZP42nM7eHTDYR66ZrTTkcKSjfQ2fXKs+iwPPLOd4akJ/OzOaUTZtOUmQJZOzWbJ5Cx+su4gu0pPOx0nLNlPu/Ha6aZW7nt6KwBPfKmAZBtvYQJIRPi3Wy4jNTGWrz23g3Ot1msq0KxgGK80t3Xw578p5ETtOR69awbDhtgstCbwUgZE81/Lp3Ck6iz//Ic9NqAvwKxgmF51uJSvPvsxhcfr+PHtU5g9cojTkUwYm5OfyoMLRvHs1hP85sPjTscJK1YwzAWpKt9ZXcTaPaf41k0TWDI52+lIxvD1hWO5bnw63/vjXt4/VO10nLBhBcP0SFX5jzeK+d3mEv7i6pF8Ze4IpyMZA7gvtPSTFdMYlZbAA89s52j1WacjhQUrGKZHP1l3kEc3HOaLlw/l4Rts2g8TXBJjo/j1PTOJEPjSk1uoONPsdKSQZwXDdOsX6w/x07cPsrwgl39ZOsmuy22C0tAhA3jy3pnUNLbwxV9vprqxxelIIc0KhvkUVeXHbxbzwzeKuXVaDt///GQiIqxYmOA1beggnrx3JmWnz3HXrz+i7myr05FClhUM84kOl/LNV4r4n3cOsWJmHj9aNplIKxamH7h85BB+fc9MjlSf5Y7HN1N+5pzTkUKSFQwDuK+Y9+Az21m1pYQHF4zi+5+/zEZxm35l7uhUnvzSTErrznHLz99nz8kzTkcKOfYbwXC4qpFbfv4+a/dW8E83jefvFo2zcxamX5o7OpUX//IKIkRY/uiHvLP/lNORQooVjDD3h50n+dzPNlHZ0Mz/3juTP7tqpNORjLkk4zKT+f2Dcxg2JIGvPFXId1fvsWlEfMQKRpiqbGjmof/bzl+t+phxWcm89tdXMX9sutOxjPGJjOQ4Xn7gSr48ZzhPfXCMJT97j50nTjsdq9+TUJqLpaCgQAsLC52OEdRON7Xy5PvH+PV7R2jvUB5ckM8DC0YRbecrTIjadLCab7ywk4r6ZpZMzuJvF45hZFqi07GChohsU9UCr7a1ghH62jpcbDlay6s7yli98yTNbS5uuiyLbyway4hUm0TQhL4z59p4fOMRnnz/KC3tLpZOzWZ5QR6zhg8O+27jVjDCVIdLqTnbQmV9C4erGtlX3sD+inoKj9XR2NJOQkwkSyZn85W5IxibmeR0XGMCrqqhhV+sP8TzW09wtrWD3EHx3DwlmytGDmH6sEEkxgb8mnKOs4IRglwupbKhhdK6Jkrrzn1yf6q+mcqGFiobWqhpbMHV6b8zOlIYlZbItKGDmDcmlXlj0omPiXTuH2FMkGhqbWftngpe2lbGh0dq6HApkRHC2Iwk8tMTGZGawPDUAaQlxpGaFENqYiyDBsSE5LgkKxj9XFNrO8UVDewrb2Bv+Rn2lTdQXNFAY0v7p7ZLTYwlMyWW9KQ40pNiSUuK/eR+RKp7p4+JsnMTxlzI2ZZ2tpfUseVoLTtLz3CkqpGy0+fo+qsxQmBwgvvnKzM5lsyUODKS48hMjiN30ABGpSeQmRzX77qk96VghN/xVxBRVcpOn2N/eQP7yuvZV1HP/vIGjtac/WRnTYqNYnxWMrdNzyE/I4m8QfHkDhpA7qB44qLtaMGYS5UQG8VVo9O4anTaJ8ua2zoorTtHdWOL+9bQQnVjK9WNLVQ1tHCqoZndZWeobvz0NCQJMZGMSk9kVFoiYzOTGJeZxLjMZDKSY/tdIemOIwVDRBYDPwUigV+r6g+6rI8FfgPMAGqA21X1WKBz+oKqUt3YSkltEyW1Zzle00RJTRPHa5s4cKqBhuY/HTUMGzKA8ZnJLJ2aw/isJMZnJZM7KD4kdjRj+pO46Ejy0xPJT79wb6rWdheVDc2U1DZxuOoshysbOVzVyOYjNbzycdkn2w0cEP1J8Rif5b4fk5HU75qIA14wRCQS+DmwECgFtorIalXd22mz+4A6Vc0XkRXAfwC3ByKfquJSaHe5cLk+fd+hSnuH0tTa
TmNLB2db2mlsaf/k/vxfINUNLVR5/jKpamihuc3V6d8PWclx5A0ewNKp2Z4dKJlxmUkkhOEJN2P6s5ioCM8R/wCuHJX6qXVnmtooPuXueHK+A8rzhSdo8gwiFIHM5DiyB8aTMzCe7IHxZKXEMXBANMnx0STHRZMSH0V8TBTRkUJMZATRn9zEkT8knfgNNQs4pKpHAETkWWAp0LlgLAW+63n8IvCIiIj66YTLtO+9ydmWDjpU6XBd/EeIwKABMaQlxpKaFMP0oYNITYwld1A8w4YMYOjgBGtKMiZMpAyIZtaIwcwaMfiTZS6XcqKu6ZMCcqL2HCdPn2PHidOsKSqnrcP73z+REUKEgIiQlhjL+w9f449/xqc4UTBygBOdnpcCl/e0jaq2i8gZYAjwmWsxishKYKXnaaOIFPs88Z+kdpehs2N+/PA+6jVrELGs/mFZfS8ocx4E5B8+s9jbrMO8/Zx+3waiqo8BjwXis0Sk0NveBE6zrP5hWf2jv2TtLznBP1md6HNZBuR1ep7rWdbtNiISBaTgPvltjDHGIU4UjK3AaBEZISIxwApgdZdtVgNf8jxeBrzjr/MXxhhjvBPwJinPOYmHgLW4u9U+qap7ROR7QKGqrgaeAH4rIoeAWtxFJRgEpOnLRyyrf1hW/+gvWftLTvBD1pAa6W2MMcZ/bN4IY4wxXrGCYYwxxithWTBEJE9E3hWRvSKyR0S+2s0280XkjIjs8Ny+7Vk+ttOyHSJSLyJf86z7roiUdVp3YyCydsq7w7PNhk7LF4tIsYgcEpGHOy0fISIfeZY/5+mA4FjWC702SL/XYyKy27OusNPywSLylogc9NwPcjJrMO6vIvJ3nT6zSEQ6RGSwZ11Q7a89ZQ3G/bWX79U3+6uqht0NyAKmex4nAQeACV22mQ/8sZf3iQQqgGGe598FvuFA1oG4R8oP9TxP75TvMDASiAF2nn8t8DywwvP4UeAvHc7a42uD7Xv1PD4GpHbzvj8EHvY8fhj4D6ezBtv+2mX7m3H3ggzK/fUCWYNuf+0pqy/317A8wlDVclXd7nncAOzDPbq8r64FDqvqcV/m68zLrHcCL6tqiWe7Ss/yT6ZhUdVW4FlgqYgIcA3uaVcAngZucTKrD/9P/J61F0txf58QBN9rF8Gyv3Z2B7DK8zgY99duswbp/tpt1l70aX8Ny4LRmYgMB6YBH3Wz+goR2Skia0RkYjfrV/DZ/5SHRGSXiDzpi+YIL7OOAQaJyHoR2SYi93iWdzcNSw7uaVZOq2p7l+VOZu3ttcH0vQIo8KZn+cpOyzNUtdzzuALICIKs5wXL/np+/QBgMfCSZ1Ew7q89Ze3ttcH0vYKv9ldfHjb1txuQCGwDPt/NumQg0fP4RuBgl/UxuOdpyei0LAP3YXUE8G+4x5gEIusjwGYgAff8MQdx/wJZhnv6+PPb3e3ZNhX3X3Lnl+cBRU5mvdBrg+179azL8dyn4246udrz/HSX96hzOmuw7a+dtrkd+EOn50G3v/aUNRj31wtl9dX+GrZHGCISjbsCP6OqL3ddr6r1qtroefw6EC0inecvvgHYrqqnOr3mlKp2qKoLeBz3Ibbfs+L+i2utqp5V1WpgIzCFnqdhqQEGinvalc7Lncza42uD8HtFVcs895XAK50ynRKRLM/7ZwHeNGP5NatHMO2v53U94gnG/bWnrMG4v/aY1Vf7a1gWDE+b6BPAPlX9cQ/bZHq2Q0Rm4f6uOs9n9Zk2wvNfvMetQFEgsgKvAnNFJMpzOHo57jbObqdhUfefEu/i/osO3NOwvOpk1gu9Nti+VxFJEJEkz/skANd3ytR5WhvHv9dO64Npf0VEUoB5fPr7Ccb9tdusQbq/9pTVd/urrw6V+tMNmIu7TW8XsMNzuxG4H7jfs81DwB7ch2+bgSs7vT4Bd/FI6fK+vwV2e953NZAViKye7f4Ody+ZIuBrnZbfiLtHxWHgHzstHwlsAQ4BLwCxTmbt6bXB+L16vrud
ntueLt/rEOBt3E1C64DBQbAPBOP+ei/wbDevD8b99TNZg3h/7S6rz/ZXmxrEGGOMV8KyScoYY0zfWcEwxhjjFSsYxhhjvGIFwxhjjFesYBhjjPGKFQxjjDFesYJhjDHGK/8fFMHOA5+l9l4AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Density plot of `pred_set_groups` reduced to one number per leading index:\n",
    "# sum over the last axis, cast to float, then average over the axis that is last after the sum.\n",
    "# NOTE(review): presumably the last axis is classes (sum = set size) and the averaged axis is\n",
    "# datapoints, giving one mean set size per trial -- TODO confirm against where the tensor is built.\n",
    "sns.kdeplot(\n",
    "    pred_set_groups.sum(dim=-1).float().mean(dim=-1).cpu().numpy()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 350,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0219, device='cuda:0')"
      ]
     },
     "execution_count": 350,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Standard deviation of the per-index mean of the last-axis sums of `pred_set_groups`\n",
    "# (the same reduction the density plot in the previous cell uses).\n",
    "(\n",
    "    pred_set_groups\n",
    "    .sum(dim=-1)\n",
    "    .float()\n",
    "    .mean(dim=-1)\n",
    "    .std()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 351,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest per-class vote frequency for each entry of `smooth_scores`:\n",
    "#   1. argmax over dim 1 picks one index per remaining position,\n",
    "#   2. one-hot encoding (width `n_classes`) followed by a mean over dim 1 turns the picks into frequencies,\n",
    "#   3. max over the last axis keeps the top frequency; the accompanying indices are discarded.\n",
    "# NOTE(review): presumably dim 1 of `smooth_scores` indexes classes and the trailing dim indexes\n",
    "# noise samples -- TODO confirm.\n",
    "p_hat, _ = (\n",
    "    F.one_hot(smooth_scores.argmax(dim=1), num_classes=n_classes)\n",
    "    .float()\n",
    "    .mean(dim=1)\n",
    "    .max(dim=-1)\n",
    ")\n",
    "# Keep only the entries selected by `eval_mask`.\n",
    "p_hat = p_hat[eval_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 352,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the statistics computed above together with the settings that produced them,\n",
    "# then write the bundle to a single `.pt` file.\n",
    "state_dict = {\n",
    "    # Standard deviation along dim 0 of the sorted prediction-set tensors of each method.\n",
    "    \"BinCP\": bincp_sorted_pred_set_groups.float().std(0),\n",
    "    \"RCP1\": rcp1_sorted_pred_set_groups.float().std(0),\n",
    "    # Settings of this run.\n",
    "    \"r\": r,\n",
    "    \"sample_rate\": internal_sample_rate,\n",
    "    \"bin_cp_support\": n_trial_samples // internal_sample_rate,\n",
    "    \"rcp1_support\": n_trial_samples,\n",
    "    # `p_hat` reordered with the same index tensor that is stored under the `sorting` key.\n",
    "    \"p_hat\": p_hat[sorted_indices],\n",
    "    \"BinCP_set_sizes\": bin_cp_set_size_probs,\n",
    "    \"RCP1_set_sizes\": rcp1_set_size_probs,\n",
    "    \"RCP1_mean_set_sizes\": rcp1_mean_set_size_probs,\n",
    "    \"sorting\": sorted_indices,\n",
    "}\n",
    "# `output_dir` is a plain string in the config cell, and `str / str` raises TypeError;\n",
    "# wrapping it in pathlib.Path makes the join valid whether it is a str or already a Path.\n",
    "# NOTE(review): the file name labels `internal_sample_rate` as `n_samples`; left unchanged so\n",
    "# existing result files and loaders keep matching.\n",
    "torch.save(\n",
    "    state_dict,\n",
    "    pathlib.Path(output_dir) / f\"noise-effects_{dataset_name}_sigma-{model_sigma}_r-{r}_n_samples_{internal_sample_rate}.pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
