{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "from scipy.stats import norm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from qrcp.helpers.storage import load_smooth_prediction\n",
    "from qrcp.helpers.tensor import get_smooth_scores, get_cal_mask, quantization_pdf, bound_tensor\n",
    "from qrcp.robust.confidence import bernstein_bound, dkw_cdf\n",
    "from qrcp.robust.confidence import clopper_pearson_lower\n",
    "from qrcp.robust.bounds import mean_bounds_l2, CDF_bounds_l2\n",
    "\n",
    "from qrcp.cp.core import ConformalClassifier as CP\n",
    "from qrcp.cp.scores import APSScore, TPSScore\n",
    "\n",
    "from qrcp.methods.robust_cp import RobustCP, VanillaSmoothCP\n",
    "from qrcp.methods.cas import CAS\n",
    "from qrcp.methods.vote import VoteCP\n",
    "from qrcp.methods.binary import QRCPThresholds\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import logging\n",
    "logging.basicConfig(filename='std.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')\n",
    "logger = logging.getLogger(__name__)\n",
    "logger.setLevel(logging.DEBUG)\n",
    "logger.propagate = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "#region primary configs of the experiment\n",
    "\n",
    "result_folder = \"../../../output-results\"\n",
    "\n",
    "dataset_name = \"cifar10\"\n",
    "model_sigma = 0.25\n",
    "n_classes=10\n",
    "n_datapoints = 2048\n",
    "smoothing_sigma = 0.25\n",
    "n_samples = 10000\n",
    "n_trial_samples = 2000\n",
    "\n",
    "score_method = \"TPS\"\n",
    "calibration_budget = 0.1\n",
    "n_iterations = 100\n",
    "confidence = 0.99\n",
    "\n",
    "r=0.0\n",
    "#endregion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "smooth = False\n",
    "smooth_worse = True\n",
    "coverage_range = [0.9]\n",
    "r_range = [r]\n",
    "\n",
    "#region loding smooth logit predictions\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "if not smooth:\n",
    "    vanilla_prediction = load_smooth_prediction(dataset_name=dataset_name,\n",
    "        model_sigma=0.0,\n",
    "        n_datapoints=n_datapoints,\n",
    "        smoothing_sigma=0.0,\n",
    "        n_samples=5)\n",
    "\n",
    "    n_classes = 10 if dataset_name == \"cifar10\" else None\n",
    "\n",
    "    score_pipeline = [\n",
    "        TPSScore(softmax=True) if score_method == \"TPS\" else APSScore(softmax=True)] # defining the score function\n",
    "    cp = CP(score_pipeline=score_pipeline, coverage_guarantee=0.9) # the guarantee can vary later by cp.coverage_guarantee\n",
    "    vanilla_scores = get_smooth_scores(vanilla_prediction.logits, cp, mean=False).to(device)[:, :, 0]\n",
    "\n",
    "    pert_logits = {}\n",
    "    pert_scores = {}\n",
    "\n",
    "    r_range = [0.0, 0.12, 0.25, 0.5]\n",
    "    for r in r_range:\n",
    "        pert_logits[r] = load_smooth_prediction(dataset_name=dataset_name,\n",
    "            model_sigma=0.0,\n",
    "            n_datapoints=n_datapoints,\n",
    "            smoothing_sigma=0.0,\n",
    "            n_samples=5, r=r)\n",
    "        pert_scores[r] = get_smooth_scores(pert_logits[r].logits, cp, mean=False).to(device)[:, :, 0]\n",
    "    #endregion\n",
    "\n",
    "if smooth:\n",
    "    vanilla_prediction = load_smooth_prediction(dataset_name=dataset_name,\n",
    "        model_sigma=0.25,\n",
    "        n_datapoints=n_datapoints,\n",
    "        smoothing_sigma=0.25,\n",
    "        n_samples=2000)\n",
    "\n",
    "    n_classes = 10 if dataset_name == \"cifar10\" else None\n",
    "\n",
    "    score_pipeline = [\n",
    "        TPSScore(softmax=True) if score_method == \"TPS\" else APSScore(softmax=True)] # defining the score function\n",
    "    cp = CP(score_pipeline=score_pipeline, coverage_guarantee=0.9) # the guarantee can vary later by cp.coverage_guarantee\n",
    "    vanilla_scores = get_smooth_scores(vanilla_prediction.logits, cp, mean=False).to(device).mean(dim=-1)\n",
    "\n",
    "    pert_logits = {}\n",
    "    pert_scores = {}\n",
    "\n",
    "    r_range = [0.0, 0.12, 0.25, 0.5]\n",
    "    for r in r_range:\n",
    "        pert_logits[r] = load_smooth_prediction(dataset_name=dataset_name,\n",
    "            model_sigma=0.25,\n",
    "            n_datapoints=n_datapoints,\n",
    "            smoothing_sigma=0.25,\n",
    "            n_samples=2000, r=r)\n",
    "        pert_scores[r] = get_smooth_scores(pert_logits[r].logits, cp, mean=False).to(device).mean(dim=-1)\n",
    "    #endregion\n",
    "\n",
    "    y_true_mask = F.one_hot(vanilla_prediction.y_true, n_classes).bool().to(device)\n",
    "\n",
    "if smooth_worse:\n",
    "    vanilla_prediction = load_smooth_prediction(dataset_name=dataset_name,\n",
    "        model_sigma=0,\n",
    "        n_datapoints=n_datapoints,\n",
    "        smoothing_sigma=0.12,\n",
    "        n_samples=2000)\n",
    "\n",
    "    n_classes = 10 if dataset_name == \"cifar10\" else None\n",
    "\n",
    "    score_pipeline = [\n",
    "        TPSScore(softmax=True) if score_method == \"TPS\" else APSScore(softmax=True)] # defining the score function\n",
    "    cp = CP(score_pipeline=score_pipeline, coverage_guarantee=0.9) # the guarantee can vary later by cp.coverage_guarantee\n",
    "    vanilla_scores = get_smooth_scores(vanilla_prediction.logits, cp, mean=False).to(device).mean(dim=-1)\n",
    "\n",
    "    pert_logits = {}\n",
    "    pert_scores = {}\n",
    "\n",
    "    r_range = [0.0, 0.12, 0.25, 0.5]\n",
    "    for r in r_range:\n",
    "        pert_logits[r] = load_smooth_prediction(dataset_name=dataset_name,\n",
    "            model_sigma=0,\n",
    "            n_datapoints=n_datapoints,\n",
    "            smoothing_sigma=0.12,\n",
    "            n_samples=2000, r=r)\n",
    "        pert_scores[r] = get_smooth_scores(pert_logits[r].logits, cp, mean=False).to(device).mean(dim=-1)\n",
    "    #endregion\n",
    "\n",
    "    y_true_mask = F.one_hot(vanilla_prediction.y_true, n_classes).bool().to(device)\n",
    "\n",
    "\n",
    "y_true_mask = F.one_hot(vanilla_prediction.y_true, n_classes).bool().to(device)\n",
    "\n",
    "# smooth_prediction_pert = load_smooth_prediction(dataset_name=dataset_name,\n",
    "#     model_sigma=model_sigma,\n",
    "#     n_datapoints=n_datapoints,\n",
    "#     smoothing_sigma=smoothing_sigma,\n",
    "#     n_samples=n_samples, r=r)\n",
    "# #endregion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []\n",
    "# for iter_i in range(n_iterations):\n",
    "for iter_i  in range(n_iterations):\n",
    "    cal_mask = get_cal_mask(vanilla_scores, calibration_budget)\n",
    "    eval_mask = ~cal_mask\n",
    "\n",
    "    cp.change_coverage_guarantee(0.9)\n",
    "    threshold = cp.calibrate_from_scores(vanilla_scores[cal_mask], y_true_mask[cal_mask])\n",
    "    pred_set = cp.predict_from_scores(vanilla_scores[eval_mask])\n",
    "\n",
    "    results.append({\n",
    "        \"score_method\": score_method,\n",
    "        \"r\": 0,\n",
    "        \"threshold\": threshold,\n",
    "        \"coverage_guarantee\": 0.9,\n",
    "        \"coverage\": cp.coverage(pred_set, y_true_mask[eval_mask]),\n",
    "        \"average_set_size\": cp.average_set_size(pred_set)\n",
    "    })\n",
    "\n",
    "    empirical_coverage = cp.coverage(pred_set, y_true_mask[eval_mask])\n",
    "    average_set_size = cp.average_set_size(pred_set)\n",
    "\n",
    "    for r in r_range:\n",
    "        pert_pred_set = cp.predict_from_scores(pert_scores[r][eval_mask]) \n",
    "        pert_empirical_coverage = cp.coverage(pert_pred_set, y_true_mask[eval_mask])\n",
    "        pert_average_set_size = cp.average_set_size(pert_pred_set)\n",
    "\n",
    "        results.append({\n",
    "            \"score_method\": score_method,\n",
    "            \"r\": r,\n",
    "            \"threshold\": threshold,\n",
    "            \"coverage_guarantee\": 0.9,\n",
    "            \"coverage\": pert_empirical_coverage,\n",
    "            \"average_set_size\": pert_average_set_size\n",
    "        })\n",
    "\n",
    "results_df = pd.DataFrame(results)\n",
    "if smooth_worse:\n",
    "    results_df.to_csv(f\"{result_folder}/comparison_vanilla_smooth_worse_results_{dataset_name}_{score_method}.csv\", index=False)\n",
    "elif not smooth:\n",
    "    results_df.to_csv(f\"{result_folder}/comparison_vanilla_results_{dataset_name}_{score_method}.csv\", index=False)\n",
    "elif smooth:\n",
    "    results_df.to_csv(f\"{result_folder}/comparison_vanilla_smooth_results_{dataset_name}_{score_method}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3764261/3588143333.py:1: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  results_df[results_df[\"r\"] == 0].mean()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "r                     0.000000\n",
       "threshold             0.428225\n",
       "coverage_guarantee    0.900000\n",
       "coverage              0.901329\n",
       "average_set_size      1.013004\n",
       "dtype: float64"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_df[results_df[\"r\"] == 0].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []\n",
    "# for iter_i in range(n_iterations):\n",
    "for iter_i  in range(n_iterations):\n",
    "    cal_mask = get_cal_mask(vanilla_scores, calibration_budget)\n",
    "    eval_mask = ~cal_mask\n",
    "\n",
    "    cp.change_coverage_guarantee(0.9)\n",
    "    threshold = cp.calibrate_from_scores(vanilla_scores[cal_mask], y_true_mask[cal_mask])\n",
    "    pred_set = cp.predict_from_scores(vanilla_scores[eval_mask])\n",
    "\n",
    "    results.append({\n",
    "        \"score_method\": score_method,\n",
    "        \"r\": 0,\n",
    "        \"threshold\": threshold,\n",
    "        \"coverage_guarantee\": 0.9,\n",
    "        \"coverage\": cp.coverage(pred_set, y_true_mask[eval_mask]),\n",
    "        \"average_set_size\": cp.average_set_size(pred_set)\n",
    "    })\n",
    "\n",
    "    empirical_coverage = cp.coverage(pred_set, y_true_mask[eval_mask])\n",
    "    average_set_size = cp.average_set_size(pred_set)\n",
    "\n",
    "    for r in r_range:\n",
    "        pert_pred_set = cp.predict_from_scores(pert_scores[r][eval_mask]) \n",
    "        pert_empirical_coverage = cp.coverage(pert_pred_set, y_true_mask[eval_mask])\n",
    "        pert_average_set_size = cp.average_set_size(pert_pred_set)\n",
    "\n",
    "        results.append({\n",
    "            \"score_method\": score_method,\n",
    "            \"r\": r,\n",
    "            \"threshold\": threshold,\n",
    "            \"coverage_guarantee\": 0.9,\n",
    "            \"coverage\": pert_empirical_coverage,\n",
    "            \"average_set_size\": pert_average_set_size\n",
    "        })\n",
    "\n",
    "results_df = pd.DataFrame(results)\n",
    "if not smooth:\n",
    "    results_df.to_csv(f\"{result_folder}/comparison_vanilla_results_{dataset_name}_{score_method}.csv\", index=False)\n",
    "if smooth:\n",
    "    results_df.to_csv(f\"{result_folder}/comparison_vanilla_smooth_results_{dataset_name}_{score_method}.csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.-1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
