{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test-time recalibration of conformal predictors\n",
    "This notebook contains the code to generate the figures in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "import plotly\n",
    "import plotly.express as px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_yaxes_matches(fig: plotly.graph_objs._figure.Figure, num_facets: int):\n",
    "    \"\"\"Set the ``matches`` attribute of each facet's y-axis to that axis' own id.\n",
    "\n",
    "    The first axis is stored under the layout key ``'yaxis'`` (no numeric\n",
    "    suffix) and is always updated; axes 2..num_facets are stored under\n",
    "    ``'yaxis2'``, ``'yaxis3'``, and so on.\n",
    "\n",
    "    Args:\n",
    "        fig: figure whose layout is modified in place.\n",
    "        num_facets: number of y-axes to update, counted from the first one.\n",
    "    \"\"\"\n",
    "    layout_keys = ['yaxis'] + [f'yaxis{n}' for n in range(2, num_facets + 1)]\n",
    "    for position, layout_key in enumerate(layout_keys, start=1):\n",
    "        fig.layout[layout_key].matches = f'y{position}'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## APS under distribution shift\n",
    "### Figure 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Figure 2 configuration ---\n",
    "# CSV inputs: one file per family of shifted datasets.\n",
    "aps_imagenet_data_file = \"../experiments/outputs/natural_datasets_aps_bars.csv\"\n",
    "aps_breeds_data_file = \"../experiments/outputs/breeds_aps_bars.csv\"\n",
    "\n",
    "# Value matched against the `Classifier` column of the ImageNet results.\n",
    "arch = 'resnet50'\n",
    "# Value matched against the `alpha` column of the combined results.\n",
    "alpha = 0.1\n",
    "\n",
    "# Value matched against the `on_classes_of` column; when falsy, only rows\n",
    "# with a missing `on_classes_of` are kept.\n",
    "use_classes_of_dataset = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the raw results for both families of shifted datasets.\n",
    "nat_bars_df = pd.read_csv(aps_imagenet_data_file)\n",
    "bred_bars_df = pd.read_csv(aps_breeds_data_file)\n",
    "\n",
    "# Restrict the ImageNet results to the requested class subset; when no subset\n",
    "# is requested, keep only the rows whose `on_classes_of` is missing.\n",
    "if use_classes_of_dataset:\n",
    "    nat_bars_df = nat_bars_df[nat_bars_df.on_classes_of == use_classes_of_dataset]\n",
    "else:\n",
    "    nat_bars_df = nat_bars_df[nat_bars_df.on_classes_of.isna()]\n",
    "\n",
    "# .copy() detaches the filtered rows from the frame they were selected from,\n",
    "# so the column assignment below writes to an independent frame instead of\n",
    "# raising pandas' SettingWithCopyWarning.\n",
    "# NOTE(review): only the ImageNet rows are filtered by `arch`; confirm the\n",
    "# BREEDS file holds results for a single classifier.\n",
    "nat_bars_df = nat_bars_df.query(\"Classifier == @arch\").copy()\n",
    "\n",
    "# Tag every row with the family of shift it belongs to.\n",
    "nat_bars_df[\"shift_type\"] = \"ImageNet\"\n",
    "bred_bars_df[\"shift_type\"] = \"BREEDS\"\n",
    "\n",
    "aps_bars_df = pd.concat([nat_bars_df, bred_bars_df], ignore_index=True)\n",
    "\n",
    "# Keep only the configured alpha and leave out the 'nonliving26' dataset.\n",
    "aps_bars_df = aps_bars_df.drop(columns=[\"Classifier\"])\n",
    "aps_bars_df = aps_bars_df.query(\"alpha == @alpha and Dataset != 'nonliving26'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = px.bar(\n",
    "    aps_bars_df,\n",
    "    x='Method',\n",
    "    y='predicted_coverage',\n",
    "    barmode='group',\n",
    "    facet_col='Dataset',\n",
    "    facet_col_wrap=3,\n",
    "    facet_col_spacing=0.08,\n",
    "    facet_row_spacing=0.15,\n",
    "    category_orders={\n",
    "        \"Dataset\": [\"ImageNetV2\", \"ImageNet-Sketch\", \"ImageNet-R\", \"entity13\", \"entity30\", \"living17\"]\n",
    "    },\n",
    "    title=\"APS coverage under distribution shift w/ recalibration\",\n",
    "    template=\"simple_white\"\n",
    ")\n",
    "\n",
    "# Keep only the text after the last '=' in each facet annotation.\n",
    "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n",
    "\n",
    "# Decouple the y-axes so that every facet can get its own range below.\n",
    "fig.update_yaxes(matches=None, showticklabels=True)\n",
    "\n",
    "# Hand-picked y-range per facet, keyed by the integer passed as `selector`.\n",
    "# NOTE(review): these ranges look tuned for alpha = 0.1; revisit them if\n",
    "# `alpha` is changed.\n",
    "facet_y_ranges = {\n",
    "    0: [0.73, 0.93],\n",
    "    1: [0.67, 0.93],\n",
    "    2: [0.65, 0.93],\n",
    "    3: [0.85, 0.91],\n",
    "    4: [0.60, 0.93],\n",
    "    5: [0.46, 0.95],\n",
    "}\n",
    "for axis_index, y_range in facet_y_ranges.items():\n",
    "    fig.update_yaxes(range=y_range, selector=axis_index)\n",
    "\n",
    "# Dotted reference line at the nominal coverage 1 - alpha, derived from the\n",
    "# same `alpha` used to filter the data rather than a hard-coded 0.9.\n",
    "fig.add_hline(y=1 - alpha, line_dash=\"dot\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Figure 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Figure 3 configuration ---\n",
    "# CSV input with the APS results on the ImageNet-derived datasets.\n",
    "aps_imagenet_data_file = \"../experiments/outputs/natural_datasets_aps_bars.csv\"\n",
    "\n",
    "# Value matched against the `Classifier` column.\n",
    "arch = 'resnet50'\n",
    "\n",
    "# Value matched against the `on_classes_of` column; when falsy, only rows\n",
    "# with a missing `on_classes_of` are kept.\n",
    "use_classes_of_dataset = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "nat_df = pd.read_csv(aps_imagenet_data_file)\n",
    "\n",
    "# Keep the requested class subset, or the rows with no `on_classes_of` value.\n",
    "if use_classes_of_dataset:\n",
    "    subset_mask = nat_df.on_classes_of == use_classes_of_dataset\n",
    "else:\n",
    "    subset_mask = nat_df.on_classes_of.isna()\n",
    "nat_df = nat_df[subset_mask]\n",
    "\n",
    "aps_df = nat_df.query(\"Classifier == @arch\").copy()\n",
    "aps_df['1_alpha'] = 1.0 - aps_df['alpha']\n",
    "\n",
    "# Both extra curves are built from the QTC rows with `Method` and\n",
    "# `predicted_coverage` overwritten.\n",
    "qtc_rows = aps_df[aps_df.Method == 'QTC']\n",
    "\n",
    "# Reference curve: coverage equal to 1 - alpha.\n",
    "yx_sub_df = qtc_rows.assign(Method=r\"$y=x$\", predicted_coverage=qtc_rows['1_alpha'])\n",
    "\n",
    "# 'original' curve: coverage taken from the `original_coverage` column.\n",
    "orig_sub_df = qtc_rows.assign(Method='original', predicted_coverage=qtc_rows['original_coverage'])\n",
    "\n",
    "aps_df = pd.concat([yx_sub_df, orig_sub_df, aps_df], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per method, one facet column per dataset.\n",
    "fig = px.line(\n",
    "    aps_df,\n",
    "    x='1_alpha',\n",
    "    y='predicted_coverage',\n",
    "    color='Method',\n",
    "    markers=True,\n",
    "    facet_col='Dataset',\n",
    "    facet_col_spacing=0.08,\n",
    "    category_orders={\"Dataset\": [\"ImageNetV2\", \"ImageNet-Sketch\", \"ImageNet-R\"]},\n",
    "    labels={'1_alpha': r'$1-\\alpha$', 'predicted_coverage': 'achieved coverage'},\n",
    "    title=\"APS coverage under distribution shift w/ and w/o recalibration\",\n",
    ")\n",
    "\n",
    "# Keep only the text after the last '=' in each facet annotation.\n",
    "fig.for_each_annotation(lambda ann: ann.update(text=ann.text.split(\"=\")[-1]))\n",
    "\n",
    "# Give every facet its own y-axis scale and tick labels.\n",
    "fig.update_yaxes(matches=None, showticklabels=True)\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RAPS under distribution shift\n",
    "### Figure 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Figure 4 configuration ---\n",
    "# CSV input with the RAPS results on the ImageNet-derived datasets.\n",
    "raps_imagenet_data_file = \"../experiments/outputs/natural_datasets_raps_vs_alpha.csv\"\n",
    "\n",
    "# Value matched against the `Classifier` column.\n",
    "arch = 'resnet50'\n",
    "\n",
    "# Values matched against the `raps_kreg` and `raps_lambda` columns.\n",
    "kreg = 2\n",
    "lamda = 0.2\n",
    "\n",
    "# Value matched against the `on_classes_of` column; when falsy, only rows\n",
    "# with a missing `on_classes_of` are kept.\n",
    "use_classes_of_dataset = None\n",
    "# Value matched against the `use_platt_scaling` column.\n",
    "use_platt_scaling = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "nat_df = pd.read_csv(raps_imagenet_data_file)\n",
    "\n",
    "# Keep the requested class subset, or the rows with no `on_classes_of` value.\n",
    "if use_classes_of_dataset:\n",
    "    nat_df = nat_df[nat_df.on_classes_of == use_classes_of_dataset]\n",
    "else:\n",
    "    nat_df = nat_df[nat_df.on_classes_of.isna()]\n",
    "\n",
    "# Select the configured classifier / Platt-scaling / RAPS hyper-parameters.\n",
    "# .copy() detaches the selection from `nat_df`, so adding the '1_alpha'\n",
    "# column below writes to an independent frame instead of raising pandas'\n",
    "# SettingWithCopyWarning.\n",
    "raps_df = (\n",
    "    nat_df\n",
    "    .query(\"Classifier == @arch & use_platt_scaling == @use_platt_scaling\")\n",
    "    .query(\"raps_kreg == @kreg & raps_lambda == @lamda\")\n",
    "    .copy()\n",
    ")\n",
    "\n",
    "raps_df['1_alpha'] = 1.0 - raps_df['alpha']\n",
    "\n",
    "# Reference curve: the QTC rows with the coverage replaced by 1 - alpha.\n",
    "yx_sub_df = raps_df[raps_df.Method == 'QTC'].copy()\n",
    "yx_sub_df['Method'] = r\"$y=x$\"\n",
    "yx_sub_df['predicted_coverage'] = yx_sub_df['1_alpha']\n",
    "\n",
    "# 'original' curve: the QTC rows with the coverage taken from the\n",
    "# `original_coverage` column.\n",
    "orig_sub_df = raps_df[raps_df.Method == 'QTC'].copy()\n",
    "orig_sub_df['Method'] = 'original'\n",
    "orig_sub_df['predicted_coverage'] = orig_sub_df['original_coverage']\n",
    "\n",
    "raps_df = pd.concat([yx_sub_df, orig_sub_df, raps_df], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per method, one facet column per dataset.\n",
    "fig = px.line(\n",
    "    raps_df,\n",
    "    x='1_alpha',\n",
    "    y='predicted_coverage',\n",
    "    color='Method',\n",
    "    markers=True,\n",
    "    facet_col='Dataset',\n",
    "    facet_col_spacing=0.08,\n",
    "    category_orders={\"Dataset\": [\"ImageNetV2\", \"ImageNet-Sketch\", \"ImageNet-R\"]},\n",
    "    labels={'1_alpha': r'$1-\\alpha$', 'predicted_coverage': 'achieved coverage'},\n",
    "    title=\"RAPS coverage under distribution shift w/ and w/o recalibration\",\n",
    ")\n",
    "\n",
    "# Keep only the text after the last '=' in each facet annotation.\n",
    "fig.for_each_annotation(lambda ann: ann.update(text=ann.text.split(\"=\")[-1]))\n",
    "\n",
    "# Give every facet its own y-axis scale and tick labels.\n",
    "fig.update_yaxes(matches=None, showticklabels=True)\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
  },
  "kernelspec": {
   "display_name": "Python 3.7.11 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
