{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "from scipy.stats import norm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from bin_cp.helpers.storage import load_smooth_prediction\n",
    "from bin_cp.helpers.tensor import get_smooth_scores, get_cal_mask, quantization_pdf, bound_tensor\n",
    "from bin_cp.robust.confidence import bernstein_bound, dkw_cdf\n",
    "from bin_cp.robust.confidence import clopper_pearson_lower\n",
    "from bin_cp.robust.bounds import mean_bounds_l2, CDF_bounds_l2\n",
    "\n",
    "from bin_cp.cp.core import ConformalClassifier as CP\n",
    "from bin_cp.cp.scores import APSScore, TPSScore\n",
    "\n",
    "from bin_cp.methods.robust_cp import RobustCP, VanillaSmoothCP\n",
    "from bin_cp.methods.cas import CAS\n",
    "from bin_cp.methods.bin import BinCP\n",
    "from bin_cp.methods.binary import BinCPThresholds\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import logging\n",
    "# Root logging goes to std.log; filemode 'w' truncates the file whenever the handler is created.\n",
    "log_format = '%(name)s - %(levelname)s - %(message)s'\n",
    "logging.basicConfig(format=log_format, filemode='w', filename='std.log')\n",
    "\n",
    "# Notebook logger: accepts DEBUG and above and hands its records up to the root logger's handlers.\n",
    "logger = logging.getLogger(__name__)\n",
    "logger.setLevel(logging.DEBUG)\n",
    "logger.propagate = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "#region primary configs of the experiment\n",
    "\n",
    "result_folder = \"../../../output-results\"\n",
    "\n",
    "dataset_name = \"cifar10\"\n",
    "model_sigma = 0.25\n",
    "n_classes = 10\n",
    "n_datapoints = 2048\n",
    "smoothing_sigma = 0.25\n",
    "n_samples = 10000\n",
    "n_trial_samples = 10  # only the first n_trial_samples of the loaded samples are used\n",
    "\n",
    "score_method = \"APS\"  # \"TPS\" selects TPSScore; any other value falls back to APSScore\n",
    "calibration_budget = 0.1\n",
    "n_iterations = 100  # number of calibration/evaluation splits to average over\n",
    "confidence = 0.99\n",
    "\n",
    "# Seed the global torch/NumPy RNGs so Restart & Run All reproduces the reported averages.\n",
    "# get_cal_mask is called repeatedly with identical arguments to draw different splits, so it\n",
    "# presumably draws from one of these RNGs (TODO confirm against bin_cp.helpers.tensor).\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "#endregion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading cifar10 dataset with 2048 datapoints and 10 samples: Score method: APS\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3388557/2835240085.py:67: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  vanilla_results[vanilla_results[\"coverage_guarantee\"] == 0.9].mean()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "iteration             49.500000\n",
       "coverage_guarantee     0.900000\n",
       "r                      0.000000\n",
       "smoothing_sigma        0.250000\n",
       "model_sigma            0.250000\n",
       "threshold              0.210366\n",
       "empirical_coverage     0.901123\n",
       "average_set_size       1.505206\n",
       "calibration_budget     0.100000\n",
       "dtype: float64"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coverage_range = [0.9, 0.95]\n",
    "r_range = [0.25]\n",
    "\n",
    "#region loading smooth logit predictions\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "smooth_prediction = load_smooth_prediction(dataset_name=dataset_name,\n",
    "    model_sigma=model_sigma,\n",
    "    n_datapoints=n_datapoints,\n",
    "    smoothing_sigma=smoothing_sigma,\n",
    "    n_samples=n_samples)\n",
    "#endregion\n",
    "\n",
    "#region defining basic setup for conformal evaluation\n",
    "\n",
    "score_pipeline = [\n",
    "    TPSScore(softmax=True) if score_method == \"TPS\" else APSScore(softmax=True)] # defining the score function\n",
    "cp = CP(score_pipeline=score_pipeline, coverage_guarantee=0.9) # the guarantee can vary later by cp.coverage_guarantee\n",
    "smooth_scores = get_smooth_scores(smooth_prediction.logits, cp, mean=False)\n",
    "smooth_scores = smooth_scores[:, :, :n_trial_samples]\n",
    "# The config-cell n_classes is the single source of truth for the class count (no per-dataset override).\n",
    "# Axis 1 of smooth_scores is the class axis: prediction-set sizes are summed over dim=1 below.\n",
    "assert smooth_scores.shape[1] == n_classes, f\"config n_classes={n_classes} does not match the {smooth_scores.shape[1]} classes in smooth_scores\"\n",
    "y_true_mask = F.one_hot(smooth_prediction.y_true, num_classes=n_classes).bool().to(device)\n",
    "mean_scores = smooth_scores.mean(dim=-1)  # average over the kept samples; loop-invariant input of get_cal_mask\n",
    "#endregion\n",
    "print(f\"Loading {dataset_name} dataset with {n_datapoints} datapoints and {n_trial_samples} samples: Score method: {score_method}\")\n",
    "\n",
    "cal_mask = get_cal_mask(mean_scores, calibration_budget)\n",
    "n_dcal = cal_mask.sum().item()  # calibration-set size, passed as n_dcal to CAS / BinCP\n",
    "\n",
    "vanilla_cp = VanillaSmoothCP(nominal_coverage=0.9)\n",
    "vanilla_cp.pre_compute(smooth_scores, smooth_prediction.y_true)\n",
    "\n",
    "vanilla_records = []  # one dict per (coverage_guarantee, split); converted to a DataFrame once at the end\n",
    "for coverage_guarantee in coverage_range:\n",
    "    r = 0  # recorded as r = 0 for comparison with the r > 0 runs\n",
    "    vanilla_cp.set_nominal_coverage(coverage_guarantee)\n",
    "\n",
    "    for iter_i in range(n_iterations):\n",
    "        cal_mask = get_cal_mask(mean_scores, calibration_budget)\n",
    "        eval_mask = ~cal_mask\n",
    "        threshold = vanilla_cp.pre_compute_calibrate(cal_mask)\n",
    "        pred_set = vanilla_cp.pre_compute_predict(eval_mask)\n",
    "\n",
    "        empirical_coverage = vanilla_cp.internal_cp.coverage(pred_set, y_true_mask[eval_mask])\n",
    "        average_set_size = pred_set.sum(dim=1).float().mean().item()\n",
    "\n",
    "        vanilla_records.append({\n",
    "            \"method\": \"vanilla\",\n",
    "            \"iteration\": iter_i,\n",
    "            \"coverage_guarantee\": coverage_guarantee,\n",
    "            \"r\": r,\n",
    "            \"smoothing_sigma\": smoothing_sigma,\n",
    "            \"model_sigma\": model_sigma,\n",
    "            \"threshold\": threshold,\n",
    "            \"empirical_coverage\": empirical_coverage,\n",
    "            \"average_set_size\": average_set_size,\n",
    "            \"score_method\": score_method,\n",
    "            \"dataset_name\": dataset_name,\n",
    "            \"calibration_budget\": calibration_budget,\n",
    "        })\n",
    "\n",
    "vanilla_results = pd.DataFrame(vanilla_records)\n",
    "# vanilla_results.to_csv(f\"{result_folder}/vanilla_results-{dataset_name}-smooth{smoothing_sigma}-model{model_sigma}-{score_method}-nsamples{n_trial_samples}.csv\", index=False)\n",
    "# numeric_only=True: the string columns (method, score_method, dataset_name) cannot be averaged;\n",
    "# pandas < 2.0 drops them with a FutureWarning, pandas >= 2.0 raises TypeError.\n",
    "vanilla_results[vanilla_results[\"coverage_guarantee\"] == 0.9].mean(numeric_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2048, 10, 10])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# expected layout: (n_datapoints, n_classes, n_trial_samples)\n",
    "smooth_scores.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CAS pre-computed\n",
      "bin pre-computed\n",
      "Running for r=0.25, coverage=0.9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 472.04it/s]\n"
     ]
    }
   ],
   "source": [
    "r = 0.25\n",
    "coverage_guarantee = 0.9  # TODO: sweep coverage_range / r_range instead of a single setting\n",
    "\n",
    "cas_cp = CAS(nominal_coverage=0.9, r=r, smoothing_sigma=smoothing_sigma, confidence_level=confidence, n_dcal=n_dcal, n_classes=n_classes,\n",
    "             error_correction=False)\n",
    "cas_cp.pre_compute(smooth_scores, smooth_prediction.y_true)\n",
    "print(\"CAS pre-computed\")\n",
    "\n",
    "# BinCP is not pre-computed here: it is calibrated from the raw scores of each split inside the loop.\n",
    "bin_cp = BinCP(nominal_coverage=0.9, smoothing_sigma=smoothing_sigma, n_dcal=n_dcal, n_classes=n_classes,\n",
    "               r=r, confidence_level=confidence,\n",
    "               error_correction=False,\n",
    "               p_base=0.5)\n",
    "print(\"bin initialized\")\n",
    "\n",
    "print(f\"Running for r={r}, coverage={coverage_guarantee}\")\n",
    "cas_cp.set_nominal_coverage(coverage_guarantee)\n",
    "bin_cp.set_nominal_coverage(coverage_guarantee)\n",
    "\n",
    "\n",
    "def _make_record(method, method_cp, iter_i, threshold, pred_set, eval_mask):\n",
    "    \"\"\"Build one result row for `method` on a single calibration/evaluation split.\n",
    "\n",
    "    Coverage is computed by `method_cp.internal_cp.coverage` against the true labels of the\n",
    "    evaluation points; the set size is the per-point class count of `pred_set`, averaged.\n",
    "    The remaining columns are the current experiment settings (notebook globals).\n",
    "    \"\"\"\n",
    "    return {\n",
    "        \"method\": method,\n",
    "        \"coverage_guarantee\": coverage_guarantee,\n",
    "        \"iteration\": iter_i,\n",
    "        \"r\": r,\n",
    "        \"smoothing_sigma\": smoothing_sigma,\n",
    "        \"model_sigma\": model_sigma,\n",
    "        \"threshold\": threshold,\n",
    "        \"empirical_coverage\": method_cp.internal_cp.coverage(pred_set, y_true_mask[eval_mask]),\n",
    "        \"average_set_size\": pred_set.sum(dim=1).float().mean().item(),\n",
    "        \"score_method\": score_method,\n",
    "        \"confidence_level\": confidence,\n",
    "        \"dataset_name\": dataset_name,\n",
    "        \"calibration_budget\": calibration_budget,\n",
    "    }\n",
    "\n",
    "\n",
    "mean_scores = smooth_scores.mean(dim=-1)  # loop-invariant input of get_cal_mask\n",
    "cas_records = []\n",
    "bin_records = []\n",
    "for iter_i in tqdm(range(n_iterations)):\n",
    "    cal_mask = get_cal_mask(mean_scores, calibration_budget)\n",
    "    eval_mask = ~cal_mask\n",
    "\n",
    "    # evaluating cas (uses the scores cached by cas_cp.pre_compute)\n",
    "    threshold_cas = cas_cp.pre_compute_calibrate(cal_mask)\n",
    "    pred_set_cas = cas_cp.pre_compute_predict(eval_mask)\n",
    "    cas_records.append(_make_record(\"cas\", cas_cp, iter_i, threshold_cas, pred_set_cas, eval_mask))\n",
    "\n",
    "    # evaluating bin (calibrated directly from the split's scores)\n",
    "    threshold_bin = bin_cp.calibrate_from_scores(smooth_scores[cal_mask], smooth_prediction.y_true[cal_mask])\n",
    "    pred_set_bin = bin_cp.predict_from_scores(smooth_scores[eval_mask])\n",
    "    bin_records.append(_make_record(\"bin\", bin_cp, iter_i, threshold_bin, pred_set_bin, eval_mask))\n",
    "\n",
    "cas_results = pd.DataFrame(cas_records)\n",
    "bin_results = pd.DataFrame(bin_records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3388557/2375947796.py:1: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  cas_results.mean()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "coverage_guarantee     0.900000\n",
       "iteration             49.500000\n",
       "r                      0.250000\n",
       "smoothing_sigma        0.250000\n",
       "model_sigma            0.250000\n",
       "threshold              0.046067\n",
       "empirical_coverage     0.969740\n",
       "average_set_size       2.546291\n",
       "confidence_level       0.990000\n",
       "calibration_budget     0.100000\n",
       "dtype: float64"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# numeric_only=True: string columns (method, score_method, dataset_name) cannot be averaged\n",
    "# (pandas < 2.0 drops them with a FutureWarning, pandas >= 2.0 raises TypeError).\n",
    "cas_results.mean(numeric_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3388557/1839281949.py:1: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  bin_results.mean()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "coverage_guarantee     0.900000\n",
       "iteration             49.500000\n",
       "r                      0.250000\n",
       "smoothing_sigma        0.250000\n",
       "model_sigma            0.250000\n",
       "threshold              0.120033\n",
       "empirical_coverage     0.961057\n",
       "average_set_size       2.302207\n",
       "confidence_level       0.990000\n",
       "calibration_budget     0.100000\n",
       "dtype: float64"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# numeric_only=True: string columns (method, score_method, dataset_name) cannot be averaged\n",
    "# (pandas < 2.0 drops them with a FutureWarning, pandas >= 2.0 raises TypeError).\n",
    "bin_results.mean(numeric_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
