{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "from scipy.stats import norm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from qrcp.helpers.storage import load_smooth_prediction\n",
    "from qrcp.helpers.tensor import get_smooth_scores, get_cal_mask, quantization_pdf, bound_tensor\n",
    "from qrcp.robust.confidence import bernstein_bound, dkw_cdf\n",
    "from qrcp.robust.confidence import clopper_pearson_lower\n",
    "from qrcp.robust.bounds import mean_bounds_l2, CDF_bounds_l2\n",
    "\n",
    "from qrcp.cp.core import ConformalClassifier as CP\n",
    "from qrcp.cp.scores import APSScore, TPSScore\n",
    "\n",
    "from qrcp.methods.robust_cp import RobustCP, VanillaSmoothCP\n",
    "from qrcp.methods.cas import CAS\n",
    "from qrcp.methods.vote import VoteCP\n",
    "from qrcp.methods.binary import QRCPThresholds\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import logging\n",
    "logging.basicConfig(filename='std.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')\n",
    "logger = logging.getLogger(__name__)\n",
    "logger.setLevel(logging.DEBUG)\n",
    "logger.propagate = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "#region primary configs of the experiment\n",
    "\n",
    "result_folder = \"../../../output-results\"\n",
    "\n",
    "dataset_name = \"cifar10\"\n",
    "model_sigma = 0.25\n",
    "n_classes=10\n",
    "n_datapoints = 2048\n",
    "smoothing_sigma = 0.25\n",
    "n_samples = 10000\n",
    "n_trial_samples = 10\n",
    "\n",
    "score_method = \"APS\"\n",
    "calibration_budget = 0.1\n",
    "n_iterations = 100\n",
    "confidence = 0.99\n",
    "\n",
    "#endregion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading cifar10 dataset with 2048 datapoints and 10 samples: Score method: APS\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3388557/2835240085.py:67: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  vanilla_results[vanilla_results[\"coverage_guarantee\"] == 0.9].mean()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "iteration             49.500000\n",
       "coverage_guarantee     0.900000\n",
       "r                      0.000000\n",
       "smoothing_sigma        0.250000\n",
       "model_sigma            0.250000\n",
       "threshold              0.210366\n",
       "empirical_coverage     0.901123\n",
       "average_set_size       1.505206\n",
       "calibration_budget     0.100000\n",
       "dtype: float64"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coverage_range = [0.9, 0.95]\n",
    "r_range = [0.25]\n",
    "\n",
    "\n",
    "\n",
    "#region loding smooth logit predictions\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "smooth_prediction = load_smooth_prediction(dataset_name=dataset_name,\n",
    "    model_sigma=model_sigma,\n",
    "    n_datapoints=n_datapoints,\n",
    "    smoothing_sigma=smoothing_sigma,\n",
    "    n_samples=n_samples)\n",
    "n_classes = 10 if dataset_name == \"cifar10\" else None\n",
    "#endregion\n",
    "\n",
    "#region defining basic setup for conformal evaluation\n",
    "\n",
    "score_pipeline = [\n",
    "    TPSScore(softmax=True) if score_method == \"TPS\" else APSScore(softmax=True)] # defining the score function\n",
    "cp = CP(score_pipeline=score_pipeline, coverage_guarantee=0.9) # the guarantee can vary later by cp.coverage_guarantee\n",
    "smooth_scores = get_smooth_scores(smooth_prediction.logits, cp, mean=False)\n",
    "smooth_scores = smooth_scores[:, :, :n_trial_samples]\n",
    "y_true_mask = F.one_hot(smooth_prediction.y_true, num_classes=10).bool().to(device)\n",
    "mean_scores = smooth_scores.mean(dim=-1)\n",
    "#endregion\n",
    "print(f\"Loading {dataset_name} dataset with {n_datapoints} datapoints and {n_trial_samples} samples: Score method: {score_method}\")\n",
    "\n",
    "cal_mask = get_cal_mask(smooth_scores.mean(dim=-1), calibration_budget)\n",
    "n_dcal = cal_mask.sum().item()\n",
    "\n",
    "vanilla_cp = VanillaSmoothCP(nominal_coverage=0.9)\n",
    "vanilla_results = []\n",
    "\n",
    "vanilla_cp.pre_compute(smooth_scores, smooth_prediction.y_true)\n",
    "\n",
    "for coverage_guarantee in coverage_range:\n",
    "    r = 0\n",
    "    vanilla_cp.set_nominal_coverage(coverage_guarantee)\n",
    "\n",
    "    for iter_i in range(n_iterations):\n",
    "        cal_mask = get_cal_mask(smooth_scores.mean(dim=-1), calibration_budget)\n",
    "        eval_mask = ~cal_mask\n",
    "        threshold = vanilla_cp.pre_compute_calibrate(cal_mask)\n",
    "        pred_set = vanilla_cp.pre_compute_predict(eval_mask)\n",
    "\n",
    "        empirical_coverage = vanilla_cp.internal_cp.coverage(pred_set, y_true_mask[eval_mask])\n",
    "        average_set_size = pred_set.sum(dim=1).float().mean().item()\n",
    "\n",
    "        vanilla_results.append({\n",
    "            \"method\": \"vanilla\", \n",
    "            \"iteration\": iter_i,\n",
    "            \"coverage_guarantee\": coverage_guarantee,\n",
    "            \"r\": r,\n",
    "            \"smoothing_sigma\": smoothing_sigma,\n",
    "            \"model_sigma\": model_sigma,\n",
    "            \"threshold\": threshold,\n",
    "            \"empirical_coverage\": empirical_coverage,\n",
    "            \"average_set_size\": average_set_size,\n",
    "            \"score_method\": score_method,\n",
    "            \"dataset_name\": dataset_name,\n",
    "            \"calibration_budget\": calibration_budget,\n",
    "        })\n",
    "\n",
    "vanilla_results = pd.DataFrame(vanilla_results)\n",
    "# vanilla_results.to_csv(f\"{result_folder}/vanilla_results-{dataset_name}-smooth{smoothing_sigma}-model{model_sigma}-{score_method}-nsamples{n_trial_samples}.csv\", index=False)\n",
    "vanilla_results[vanilla_results[\"coverage_guarantee\"] == 0.9].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2048, 10, 10])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "smooth_scores.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CAS pre-computed\n",
      "Vote pre-computed\n",
      "Running for r=0.25, coverage=0.9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 472.04it/s]\n"
     ]
    }
   ],
   "source": [
    "r = 0.25\n",
    "\n",
    "\n",
    "cas_results = []\n",
    "vote_results = []\n",
    "\n",
    "cas_cp = CAS(nominal_coverage=0.9, r=r, smoothing_sigma=smoothing_sigma, confidence_level=confidence, n_dcal=n_dcal, n_classes=n_classes, \n",
    "                        error_correction=False)\n",
    "cas_cp.pre_compute(smooth_scores, smooth_prediction.y_true)\n",
    "\n",
    "print(\"CAS pre-computed\")\n",
    "\n",
    "vote_cp = VoteCP(nominal_coverage=0.9, smoothing_sigma=smoothing_sigma, n_dcal=n_dcal, n_classes=n_classes,\n",
    "                    r=r, confidence_level=confidence,\n",
    "                    error_correction=False,\n",
    "                    p_base=0.5)\n",
    "\n",
    "# vote_cp.pre_compute(smooth_scores, smooth_prediction.y_true)\n",
    "print(\"Vote pre-computed\")\n",
    "\n",
    "# here goes a for\n",
    "coverage_guarantee = 0.9\n",
    "\n",
    "print(f\"Running for r={r}, coverage={coverage_guarantee}\")\n",
    "cas_cp.set_nominal_coverage(coverage_guarantee)\n",
    "vote_cp.set_nominal_coverage(coverage_guarantee)\n",
    "\n",
    "# here goes a for\n",
    "for iter_i in tqdm(range(100)):\n",
    "    cal_mask = get_cal_mask(smooth_scores.mean(dim=-1), calibration_budget)\n",
    "    eval_mask = ~cal_mask\n",
    "\n",
    "    # evaluating cas\n",
    "    threshold_cas = cas_cp.pre_compute_calibrate(cal_mask)\n",
    "    pred_set_cas = cas_cp.pre_compute_predict(eval_mask)\n",
    "\n",
    "    empirical_coverage_cas = cas_cp.internal_cp.coverage(pred_set_cas, y_true_mask[eval_mask])\n",
    "    average_set_size_cas = pred_set_cas.sum(dim=1).float().mean().item()\n",
    "\n",
    "    cas_results.append({\n",
    "        \"method\": \"cas\",\n",
    "        \"coverage_guarantee\": coverage_guarantee,\n",
    "        \"iteration\": iter_i,\n",
    "        \"r\": r,\n",
    "        \"smoothing_sigma\": smoothing_sigma,\n",
    "        \"model_sigma\": model_sigma,\n",
    "        \"threshold\": threshold_cas,\n",
    "        \"empirical_coverage\": empirical_coverage_cas,\n",
    "        \"average_set_size\": average_set_size_cas,\n",
    "        \"score_method\": score_method,\n",
    "        \"confidence_level\": confidence,\n",
    "        \"dataset_name\": dataset_name,\n",
    "        \"calibration_budget\": calibration_budget,\n",
    "    })\n",
    "\n",
    "    # evaluating vote\n",
    "    threshold_vote = vote_cp.calibrate_from_scores(smooth_scores[cal_mask], smooth_prediction.y_true[cal_mask])\n",
    "    pred_set_vote = vote_cp.predict_from_scores(smooth_scores[eval_mask])\n",
    "\n",
    "    empirical_coverage_vote = vote_cp.internal_cp.coverage(pred_set_vote, y_true_mask[eval_mask])\n",
    "    average_set_size_vote = pred_set_vote.sum(dim=1).float().mean().item()\n",
    "\n",
    "    vote_results.append({\n",
    "        \"method\": \"vote\",\n",
    "        \"coverage_guarantee\": coverage_guarantee,\n",
    "        \"iteration\": iter_i,\n",
    "        \"r\": r,\n",
    "        \"smoothing_sigma\": smoothing_sigma,\n",
    "        \"model_sigma\": model_sigma,\n",
    "        \"threshold\": threshold_vote,\n",
    "        \"empirical_coverage\": empirical_coverage_vote,\n",
    "        \"average_set_size\": average_set_size_vote,\n",
    "        \"score_method\": score_method,\n",
    "        \"confidence_level\": confidence,\n",
    "        \"dataset_name\": dataset_name,\n",
    "        \"calibration_budget\": calibration_budget,\n",
    "    })\n",
    "\n",
    "cas_results = pd.DataFrame(cas_results)\n",
    "vote_results = pd.DataFrame(vote_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3388557/2375947796.py:1: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  cas_results.mean()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "coverage_guarantee     0.900000\n",
       "iteration             49.500000\n",
       "r                      0.250000\n",
       "smoothing_sigma        0.250000\n",
       "model_sigma            0.250000\n",
       "threshold              0.046067\n",
       "empirical_coverage     0.969740\n",
       "average_set_size       2.546291\n",
       "confidence_level       0.990000\n",
       "calibration_budget     0.100000\n",
       "dtype: float64"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cas_results.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3388557/1839281949.py:1: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  vote_results.mean()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "coverage_guarantee     0.900000\n",
       "iteration             49.500000\n",
       "r                      0.250000\n",
       "smoothing_sigma        0.250000\n",
       "model_sigma            0.250000\n",
       "threshold              0.120033\n",
       "empirical_coverage     0.961057\n",
       "average_set_size       2.302207\n",
       "confidence_level       0.990000\n",
       "calibration_budget     0.100000\n",
       "dtype: float64"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vote_results.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
