{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "54d3f1d5",
   "metadata": {},
   "source": [
    "# DR CI for Comparing Abstaining Classifiers on CIFAR-100\n",
    "\n",
    "Using pretrained features from a VGG-16 model.\n",
    "\n",
    "https://github.com/chenyaofo/pytorch-cifar-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1b492c1c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[MLENS] backend: threading\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os.path\n",
    "from collections import OrderedDict\n",
    "\n",
    "from tqdm.notebook import tqdm, trange\n",
    "from joblib import Parallel, delayed\n",
    "import time\n",
    "import itertools as it\n",
    "\n",
    "import comparecast as cc\n",
    "import comparecast_causal as c3\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "20d01798",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import datasets, svm, metrics\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.pipeline import make_pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6ca28dfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for images\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms as T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9825393",
   "metadata": {},
   "source": [
    "Configurations for evaluation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1c5a07ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "RANDOM_SEED = 1234567890\n",
    "\n",
    "rng = np.random.default_rng(RANDOM_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ca0002a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "ALPHA = 0.05\n",
    "SCORING_RULE = \"brier\"\n",
    "EPSILON = 0.2       # how much to clip large values of estimated abstention probability\n",
    "NOISE_LEVEL = 0.0   # label noise\n",
    "N_TRIALS = 100      # repetitions  \n",
    "N_JOBS = -1         # number of CPU cores for parallel processing\n",
    "MIXED_EST = False   # use mixed estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "34c19e31",
   "metadata": {},
   "outputs": [],
   "source": [
    "PLOTS_DIR = \"./plots/cifar100_vgg16/experiment1{}/{}/eps{:g}\".format(\n",
    "    \"_mixed\" if MIXED_EST else \"\", SCORING_RULE, EPSILON,\n",
    ")\n",
    "os.makedirs(PLOTS_DIR, exist_ok=True)\n",
    "\n",
    "FONT = \"Liberation Serif\"  # for servers & colab; if not use \"DejaVu Serif\"\n",
    "cc.set_theme(style=\"whitegrid\", font=FONT)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd0e3661",
   "metadata": {},
   "source": [
    "## CIFAR-10 features\n",
    "\n",
    "Extracted using:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms as T\n",
    "\n",
    "model = torch.hub.load(\"chenyaofo/pytorch-cifar-models\", \"cifar100_vgg16_bn\", pretrained=True).eval()\n",
    "\n",
    "transform = T.Compose([T.ToTensor(), \n",
    "                       T.Normalize(mean=[0.4914, 0.4822, 0.4465], \n",
    "                                   std=[0.2023, 0.1994, 0.2010])])\n",
    "val_data = datasets.CIFAR100(\"./data/cifar100\", download=True, train=False, transform=transform)\n",
    "val_loader = torch.utils.data.DataLoader(val_data, batch_size=128, shuffle=False)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for inputs_b, labels_b in tqdm(val_loader):\n",
    "        features = model.features(inputs_b).squeeze().numpy()\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2dbaa5b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 512) (10000,)\n"
     ]
    }
   ],
   "source": [
    "X_eval = np.load(\"./data/cifar100/cifar100_vgg16_features.npy\")\n",
    "y_eval = np.load(\"./data/cifar100/cifar100_labels.npy\")\n",
    "print(X_eval.shape, y_eval.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f510225",
   "metadata": {},
   "source": [
    "## Classifiers\n",
    "\n",
    "We use the VGG model's features but train the final layer separately and also vary the abstention mechanisms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5a87e51e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original VGG16\n",
      "Accuracy: 0.7106\n",
      "Score: 0.7576894299713142\n"
     ]
    }
   ],
   "source": [
    "# Create a classifier\n",
    "clf = c3.get_learner(\"c\", \"Logistic\", scaler=None, max_iter=100)   # lbfgs\n",
    "\n",
    "# Split eval again into 1:1 train:validation\n",
    "indices = np.arange(len(y_eval))\n",
    "X_train, X_valid, y_train, y_valid, idx_train, idx_valid = train_test_split(\n",
    "    X_eval, y_eval, indices, test_size=0.5, shuffle=True, random_state=101,\n",
    ")\n",
    "\n",
    "# 1. \"ground truth\"\n",
    "preds_vgg = np.load(\"./data/cifar100/cifar100_vgg16_predictions.npy\")  # pre-softmax\n",
    "preds_vgg = F.softmax(torch.tensor(preds_vgg), dim=1).numpy()[idx_valid]\n",
    "score_fn = cc.get_scoring_rule(SCORING_RULE)\n",
    "print(\"Original VGG16\")\n",
    "print(\"Accuracy:\", np.mean(preds_vgg.argmax(-1) == y_valid))\n",
    "print(\"Score:\", score_fn(preds_vgg, y_valid).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c5d5140b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training: Pipeline(steps=[('logisticregression', LogisticRegression())])\n",
      "elapsed time: 5.75s\n",
      "Retrained VGG\n",
      "Accuracy: 0.701\n",
      "Score: 0.7869779531986764\n",
      "Mean SR: 0.7805081505716034\n"
     ]
    }
   ],
   "source": [
    "# 2. trained final layer\n",
    "print(\"training:\", clf)\n",
    "t0 = time.time()\n",
    "clf.fit(X_train, y_train)\n",
    "t1 = time.time()\n",
    "print(\"elapsed time: {:.2f}s\".format(t1 - t0))\n",
    "\n",
    "# Predict the value of the digit on the test subset\n",
    "predictions = clf.predict_proba(X_valid)\n",
    "print(\"Retrained VGG\")\n",
    "print(\"Accuracy:\", np.mean(predictions.argmax(-1) == y_valid))\n",
    "print(\"Score:\", score_fn(predictions, y_valid).mean())\n",
    "print(\"Mean SR:\", predictions.max(-1).mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76171b99",
   "metadata": {},
   "source": [
    "## Abstaining classifier comparisons on the \"test\" set\n",
    "\n",
    "We compare the same classifiers that use different abstention rules. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4cf5d24a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_preds_cifar(\n",
    "    X, y, predictions,\n",
    "    confidence_metric=\"sr\", stochastic=True, threshold=0.7, eps=EPSILON, plot=False,\n",
    "):\n",
    "    _, abstentions, confidence, pred_labels = c3.predict_or_abstain(\n",
    "        predictions,\n",
    "        confidence_metric=confidence_metric,\n",
    "        stochastic=stochastic,\n",
    "        threshold=threshold,\n",
    "        eps=eps,\n",
    "    )\n",
    "    scores = c3.compute_scores(predictions, abstentions, y, \n",
    "                               scoring_rule=SCORING_RULE, compute_se=True)\n",
    "\n",
    "    if plot:\n",
    "        print(f\"[{confidence_metric}, stochastic: {stochastic}, eps: {eps}]\")\n",
    "        for name, (m, se) in scores.items():\n",
    "            print(\"{}: {:.3f} +/- {:.3f}\".format(name, m, se))\n",
    "        print(\"-\" * 40)\n",
    "\n",
    "    return predictions, abstentions, confidence, pred_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "197d084c",
   "metadata": {},
   "outputs": [],
   "source": [
    "scenarios = [\n",
    "    # cf(A)=cf(B), deterministic\n",
    "    ((\"SRDet0.8\", dict(predictions=predictions, \n",
    "               confidence_metric=\"sr\", \n",
    "               stochastic=False, \n",
    "               threshold=0.8,\n",
    "               eps=EPSILON)),\n",
    "     (\"SRDet0.5\", dict(predictions=predictions, \n",
    "               confidence_metric=\"sr\", \n",
    "               stochastic=False, \n",
    "               threshold=0.5,\n",
    "               eps=EPSILON))),\n",
    "    # cf(A)=cf(B), stochastic\n",
    "    ((\"SRStoc\", dict(predictions=predictions, \n",
    "               confidence_metric=\"sr\", \n",
    "               stochastic=True, \n",
    "               eps=EPSILON)),\n",
    "     (\"GiniStoc\", dict(predictions=predictions, \n",
    "               confidence_metric=\"impurity\", \n",
    "               stochastic=True, \n",
    "               eps=EPSILON))),\n",
    "    # cf(A)>cf(B), stochastic\n",
    "    ((\"VGGSRStoc\", dict(predictions=preds_vgg, \n",
    "               confidence_metric=\"sr\", \n",
    "               stochastic=True, \n",
    "               eps=EPSILON)),\n",
    "     (\"LinearSRStoc\", dict(predictions=predictions, \n",
    "               confidence_metric=\"sr\", \n",
    "               stochastic=True, \n",
    "               eps=EPSILON))),    \n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a6455995",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[sr, stochastic: True, eps: 0.2]\n",
      "selective_score: 0.822 +/- 0.005\n",
      "coverage: 0.677 +/- 0.007\n",
      "oracle_cf_score: 0.787 +/- 0.004\n",
      "----------------------------------------\n",
      "[impurity, stochastic: True, eps: 0.2]\n",
      "selective_score: 0.834 +/- 0.005\n",
      "coverage: 0.620 +/- 0.007\n",
      "oracle_cf_score: 0.787 +/- 0.004\n",
      "----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "clfA, clfB = \"Logistic_SR\", \"Logistic_Gini\"\n",
    "preds_AB = {\n",
    "    clfA: compute_preds_cifar(X_valid, y_valid, predictions,\n",
    "                              confidence_metric=\"sr\", stochastic=True, #threshold=0.8,\n",
    "                              eps=EPSILON, plot=True),\n",
    "    clfB: compute_preds_cifar(X_valid, y_valid, predictions,\n",
    "                              confidence_metric=\"impurity\", stochastic=True, #threshold=0.5, \n",
    "                              eps=EPSILON, plot=True),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ff750e00",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SRDet0.8/SRDet0.5\n",
      "[sr, stochastic: False, eps: 0.2]\n",
      "selective_score: 0.899 +/- 0.005\n",
      "coverage: 0.620 +/- 0.007\n",
      "oracle_cf_score: 0.787 +/- 0.004\n",
      "----------------------------------------\n",
      "[sr, stochastic: False, eps: 0.2]\n",
      "selective_score: 0.839 +/- 0.005\n",
      "coverage: 0.819 +/- 0.005\n",
      "oracle_cf_score: 0.787 +/- 0.004\n",
      "----------------------------------------\n",
      "95% CI for CF score difference [estimator: dr, learner: {'model': 'Linear', 'scaler': None}]\n",
      "pi [split=0] evaluation: accuracy 0.83606\n",
      "mu0 [split=0] evaluation: MSE 0.07894\n",
      "pi_b [split=0] evaluation: accuracy 0.84763\n",
      "mu0_b [split=0] evaluation: MSE 0.08153\n",
      "pi [split=1] evaluation: accuracy 0.81588\n",
      "mu0 [split=1] evaluation: MSE 0.07613\n",
      "pi_b [split=1] evaluation: accuracy 0.84356\n",
      "mu0_b [split=1] evaluation: MSE 0.08051\n",
      "Target CF score difference: 0.00000\n",
      "Cross-fit estimate: 0.00672\n",
      "Asymptotic CI: (-0.00455, 0.01799)\n",
      "CI width: 0.02255\n",
      "Contains oracle CF score: True\n",
      "Rejection (for H0: diff=0): False\n",
      "----------------------------------------\n",
      "SRStoc/GiniStoc\n",
      "[sr, stochastic: True, eps: 0.2]\n",
      "selective_score: 0.824 +/- 0.005\n",
      "coverage: 0.705 +/- 0.006\n",
      "oracle_cf_score: 0.787 +/- 0.004\n",
      "----------------------------------------\n",
      "[impurity, stochastic: True, eps: 0.2]\n",
      "selective_score: 0.838 +/- 0.005\n",
      "coverage: 0.632 +/- 0.007\n",
      "oracle_cf_score: 0.787 +/- 0.004\n",
      "----------------------------------------\n",
      "95% CI for CF score difference [estimator: dr, learner: {'model': 'Linear', 'scaler': None}]\n",
      "pi [split=0] evaluation: accuracy 0.69461\n",
      "mu0 [split=0] evaluation: MSE 0.07875\n",
      "pi_b [split=0] evaluation: accuracy 0.64311\n",
      "mu0_b [split=0] evaluation: MSE 0.08135\n",
      "pi [split=1] evaluation: accuracy 0.67415\n",
      "mu0 [split=1] evaluation: MSE 0.07504\n",
      "pi_b [split=1] evaluation: accuracy 0.66453\n",
      "mu0_b [split=1] evaluation: MSE 0.07966\n",
      "Target CF score difference: 0.00000\n",
      "Cross-fit estimate: 0.00434\n",
      "Asymptotic CI: (-0.00627, 0.01494)\n",
      "CI width: 0.02121\n",
      "Contains oracle CF score: True\n",
      "Rejection (for H0: diff=0): False\n",
      "----------------------------------------\n",
      "VGGSRStoc/LinearSRStoc\n",
      "[sr, stochastic: True, eps: 0.2]\n",
      "selective_score: 0.770 +/- 0.006\n",
      "coverage: 0.776 +/- 0.006\n",
      "oracle_cf_score: 0.758 +/- 0.005\n",
      "----------------------------------------\n",
      "[sr, stochastic: True, eps: 0.2]\n",
      "selective_score: 0.818 +/- 0.005\n",
      "coverage: 0.695 +/- 0.007\n",
      "oracle_cf_score: 0.787 +/- 0.004\n",
      "----------------------------------------\n",
      "95% CI for CF score difference [estimator: dr, learner: {'model': 'Linear', 'scaler': None}]\n",
      "pi [split=0] evaluation: accuracy 0.75716\n",
      "mu0 [split=0] evaluation: MSE 0.12975\n",
      "pi_b [split=0] evaluation: accuracy 0.68301\n",
      "mu0_b [split=0] evaluation: MSE 0.08209\n",
      "pi [split=1] evaluation: accuracy 0.75887\n",
      "mu0 [split=1] evaluation: MSE 0.13588\n",
      "pi_b [split=1] evaluation: accuracy 0.68543\n",
      "mu0_b [split=1] evaluation: MSE 0.08130\n",
      "Target CF score difference: -0.02929\n",
      "Cross-fit estimate: -0.04015\n",
      "Asymptotic CI: (-0.05123, -0.02907)\n",
      "CI width: 0.02216\n",
      "Contains oracle CF score: True\n",
      "Rejection (for H0: diff=0): True\n",
      "----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "learners = {\n",
    "    \"Linear\": dict(model=\"Linear\", scaler=None),\n",
    "}\n",
    "\n",
    "drcis = OrderedDict()\n",
    "summary = OrderedDict()\n",
    "for learner_name, learner_config in learners.items():\n",
    "    for (clfA, configA), (clfB, configB) in scenarios:\n",
    "        scenario = \"/\".join([clfA, clfB])\n",
    "        print(scenario)\n",
    "        preds_AB = {\n",
    "            clfA: compute_preds_cifar(X_valid, y_valid, plot=True, **configA),\n",
    "            clfB: compute_preds_cifar(X_valid, y_valid, plot=True, **configB),\n",
    "        }\n",
    "        drcis[learner_name], summary[learner_name] = c3.run_experiment(\n",
    "            X=X_valid, \n",
    "            y=y_valid,\n",
    "            predictions=preds_AB[clfA][0],\n",
    "            abstentions=preds_AB[clfA][1],\n",
    "            predictions_b=preds_AB[clfB][0],\n",
    "            abstentions_b=preds_AB[clfB][1],\n",
    "            estimator=\"dr\", \n",
    "            learner=learner_config, \n",
    "            mixed_estimation=False, \n",
    "            mixed_coef=None,\n",
    "            alpha=ALPHA, \n",
    "            scoring_rule=SCORING_RULE, \n",
    "            clip_pi=EPSILON, \n",
    "            rng=rng, \n",
    "            verbose=True,\n",
    "        )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9 (c3)",
   "language": "python",
   "name": "c3_py39"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
