{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.datasets as dset\n",
    "from torchvision.datasets import CIFAR10\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.models import resnet18\n",
    "from transformers import set_seed\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import pandas as pd\n",
    "from itertools import product\n",
    "from sklearn.metrics import accuracy_score, balanced_accuracy_score\n",
    "\n",
    "from torch.utils.data import Subset\n",
    "\n",
    "from torchcp.classification import Metrics\n",
    "from torchcp.classification.predictor import SplitPredictor\n",
    "from torchcp.classification.score import THR, APS, SAPS, RAPS, Margin, TOPK\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# ---- Configuration ----\n",
    "SEED = 42  # fixes the cal/test split and all torch/numpy/python RNGs\n",
    "num_classes = 10\n",
    "CHECKPOINT_PATH = \"finetuned_models/clf_cifar10h_dbg.pth\"\n",
    "BATCH_SIZE = 64\n",
    "# Debug-sized slices (indices into each split); widen for a full run.\n",
    "CAL_SUBSET = range(0, 100)\n",
    "TEST_SUBSET = range(500, 600)\n",
    "\n",
    "set_seed(SEED)\n",
    "\n",
    "# ImageNet mean/std normalization at 224x224 (standard for torchvision ResNets).\n",
    "# NOTE(review): assumed to match the preprocessing used to fine-tune the checkpoint.\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize((224, 224)),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Build the architecture only: every weight is overwritten by the fine-tuned\n",
    "# checkpoint (strict load), so downloading ImageNet weights first is wasted work.\n",
    "model = resnet18(weights=None)\n",
    "model.fc = nn.Linear(model.fc.in_features, num_classes)\n",
    "state_dict = torch.load(CHECKPOINT_PATH, map_location=\"cpu\", weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "model = model.to(device)\n",
    "model.eval()\n",
    "\n",
    "# The CIFAR-10 *train* split is re-split into calibration and test sets.\n",
    "# NOTE(review): if the checkpoint was fine-tuned on this split, these samples\n",
    "# are not held out and coverage/accuracy will be optimistic - confirm.\n",
    "dataset = CIFAR10(root=\"./data\", train=True, download=True, transform=transform)\n",
    "split_generator = torch.Generator().manual_seed(SEED)\n",
    "cal_data, test_data = torch.utils.data.random_split(\n",
    "    dataset, [10000, 40000], generator=split_generator\n",
    ")\n",
    "\n",
    "cal_data = Subset(cal_data, CAL_SUBSET)\n",
    "test_data = Subset(test_data, TEST_SUBSET)\n",
    "\n",
    "cal_data_loader = DataLoader(cal_data, batch_size=BATCH_SIZE)\n",
    "test_data_loader = DataLoader(test_data, batch_size=BATCH_SIZE)\n",
    "\n",
    "#######################################\n",
    "# A standard process of conformal prediction\n",
    "#######################################\n",
    "# For every (score function, alpha) pair: calibrate a split conformal\n",
    "# predictor on the calibration set, then evaluate prediction-set metrics\n",
    "# and point-prediction accuracy on the test set. One result row per pair.\n",
    "scoring_methods = [THR(), APS(), RAPS(penalty=0)]\n",
    "alphas = [0.05, 0.1, 0.2]\n",
    "metric = Metrics()\n",
    "results = []\n",
    "\n",
    "for score_function, alpha in product(scoring_methods, alphas):\n",
    "    predictor = SplitPredictor(score_function, model)\n",
    "    predictor.calibrate(cal_data_loader, alpha)\n",
    "\n",
    "    predictions_sets_list = []\n",
    "    predictions_list = []\n",
    "    labels_list = []\n",
    "    logits_list = []\n",
    "\n",
    "    # `model` is also called directly below, so make sure it is in eval mode.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        desc = f\"{type(score_function).__name__} alpha={alpha}\"\n",
    "        for inputs, labels in tqdm(test_data_loader, desc=desc, leave=False):\n",
    "            # The model was moved to `device`; its inputs must be there too,\n",
    "            # otherwise the direct forward pass fails on GPU.\n",
    "            inputs = inputs.to(device)\n",
    "\n",
    "            # Prediction sets (N x C). Everything accumulated is moved to CPU\n",
    "            # so the concatenated tensors and the metrics share one device.\n",
    "            batch_predictions = predictor.predict(inputs).cpu()\n",
    "            logits = model(inputs).cpu()\n",
    "\n",
    "            predictions_sets_list.append(batch_predictions)\n",
    "            predictions_list.append(logits.argmax(dim=1))\n",
    "            labels_list.append(labels)\n",
    "            logits_list.append(logits)\n",
    "\n",
    "    # Concatenate all batches\n",
    "    val_prediction_sets = torch.cat(predictions_sets_list, dim=0)  # (N_val x C)\n",
    "    val_predictions = torch.cat(predictions_list, dim=0)\n",
    "    val_labels = torch.cat(labels_list, dim=0)  # (N_val,)\n",
    "    val_logits = torch.cat(logits_list, dim=0)\n",
    "\n",
    "    y_pred = val_predictions.numpy()\n",
    "    y_true = val_labels.numpy()\n",
    "\n",
    "    # Compute evaluation metrics\n",
    "    metrics = {\n",
    "        \"coverage_rate\": metric(\"coverage_rate\")(\n",
    "            prediction_sets=val_prediction_sets, labels=val_labels\n",
    "        ),\n",
    "        \"average_size\": metric(\"average_size\")(\n",
    "            prediction_sets=val_prediction_sets, labels=val_labels\n",
    "        ),\n",
    "        \"cov_gap\": metric(\"CovGap\")(\n",
    "            prediction_sets=val_prediction_sets,\n",
    "            labels=val_labels,\n",
    "            alpha=alpha,\n",
    "            num_classes=num_classes,\n",
    "        ),\n",
    "        \"vio_classes\": metric(\"VioClasses\")(\n",
    "            prediction_sets=val_prediction_sets,\n",
    "            labels=val_labels,\n",
    "            alpha=alpha,\n",
    "            num_classes=num_classes,\n",
    "        ),\n",
    "        \"sscv\": metric(\"SSCV\")(\n",
    "            prediction_sets=val_prediction_sets,\n",
    "            labels=val_labels,\n",
    "            alpha=alpha,\n",
    "        ),\n",
    "        # WSC is omitted: it needs the raw test inputs, and keeping every\n",
    "        # test image in memory does not scale past the debug subset.\n",
    "        \"acc\": accuracy_score(y_true, y_pred),\n",
    "        \"bacc\": balanced_accuracy_score(y_true, y_pred),\n",
    "    }\n",
    "    # Keep every configuration's metrics; `metrics` alone only holds the last pair.\n",
    "    results.append({\"score\": type(score_function).__name__, \"alpha\": alpha, **metrics})\n",
    "\n",
    "results_df = pd.DataFrame(results)\n",
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "plnet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
