{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "392dbe38-422d-4fd6-ab02-98e0a6e10c4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "from pac_utils import pac_labeling, zero_one_loss, squared_loss\n",
    "from plotting_utils import pac_plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8b8a0834-ba16-4dbe-85db-37257cb2983a",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "157ce7f2-6520-4707-9953-4a64cd39c60d",
   "metadata": {},
   "source": [
    "## Demonstration on Bias Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "41ba7f78-1470-45cb-9900-e2888b06db8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('bias.csv')\n",
    "df['Yhat'] = df['Yhat (GPT4o)']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e63b117-3b27-488e-a513-0b2d6bc81e14",
   "metadata": {},
   "source": [
    "### Initial Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e9b513a1-9d50-4caa-92b1-860ddc6f519a",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.05\n",
    "epsilon = 0.05\n",
    "num_trials = 100\n",
    "errs = np.zeros(num_trials)\n",
    "percent_saved = np.zeros(num_trials)\n",
    "pi = np.ones(len(df))\n",
    "K = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "11cb89c9-8472-4664-a42d-fe49f07f734a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "08e64004c897448bb5d3c7a994525e3b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: 0.04104552276138069 Budget save:( 13.684342171085543 +/- 3.1861053779909403 )\n"
     ]
    }
   ],
   "source": [
    "Y = np.array(df['Y'])\n",
    "Yhat = np.array(df['Yhat'])\n",
    "\n",
    "for i in tqdm(range(num_trials)):\n",
    "    uncertainty = 1 - df['confidence'] + 1e-5 * np.random.normal(size=len(df)) # break ties\n",
    "    Y_tilde, labeled_inds, _ = pac_labeling(Y, \n",
    "                                            Yhat, \n",
    "                                            zero_one_loss, \n",
    "                                            epsilon, \n",
    "                                            alpha, \n",
    "                                            uncertainty, \n",
    "                                            pi, \n",
    "                                            K,\n",
    "                                            asymptotic=False)\n",
    "    errs[i] = zero_one_loss(Y,Y_tilde)\n",
    "    percent_saved[i] = np.mean(labeled_inds==0.0)*100\n",
    "print('Error:', np.quantile(errs, 1-alpha), 'Budget save:(', np.mean(percent_saved), '+/-', np.std(percent_saved),')')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7500c5c-45ff-45d7-8743-f5dc53e1b485",
   "metadata": {},
   "source": [
    "### Next, calibrate the predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cc5e3150-9bcc-45ad-936c-3ea21b46082e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters (set without tuning)\n",
    "calibration_frac = 0.15\n",
    "granularity = 3\n",
    "min_size = 15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "79be53c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Categorized the news sources into 5 clusters:\n",
    "# 1. Very Conservative\n",
    "# 2. Conservative\n",
    "# 3. Liberal\n",
    "# 4. Very Liberal\n",
    "# 5. Center/Unknown\n",
    "\n",
    "\n",
    "lookup_fine = {\n",
    "    \"The Intercept\": -2,\n",
    "    \"Mother Jones\": -2,\n",
    "    \"Salon\": -2,\n",
    "    \"Media Matters\": -2,\n",
    "    \"The Root\": -2,\n",
    "    \"Democracy Now\": -2,\n",
    "    \"ACLU\": -2,\n",
    "    \"ThinkProgress\": -2,\n",
    "    \"Daily Kos\": -2,\n",
    "    \"TruthOut\": -2,\n",
    "    \"CNN - Editorial\": -1,\n",
    "    \"NPR Online News\": -1,\n",
    "    \"CNN (Web News)\": -1,\n",
    "    \"New York Times - News\": -1,\n",
    "    \"The Atlantic\": -1,\n",
    "    \"HuffPost\": -1,\n",
    "    \"Guest Writer - Left\": -1,\n",
    "    \"Vox\": -1,\n",
    "    \"NBC News (Online)\": -1,\n",
    "    \"Daily Beast\": -1,\n",
    "    \"The New Yorker\": -1,\n",
    "    \"FiveThirtyEight\": -1,\n",
    "    \"The Guardian\": -1,\n",
    "    \"The Boston Globe\": -1,\n",
    "    \"Slate\": -1,\n",
    "    \"Vanity Fair\": -1,\n",
    "    \"New Republic\": -1,\n",
    "    \"BuzzFeed News\": -1,\n",
    "    \"Juan Williams\": -1,\n",
    "    \"New York Magazine\": -1,\n",
    "    \"Julian Zelizer\": -1,\n",
    "    \"Chicago Sun-Times\": -1,\n",
    "    \"Center For American Progress\": -1,\n",
    "    \"Washington Post\": -1,\n",
    "    \"RollingStone.com\": -1,\n",
    "    \"Ezra Klein\": -1,\n",
    "    \"Bustle\": -1,\n",
    "    \"Pacific Standard\": -1,\n",
    "    \"Vice\": -1,\n",
    "    \"NBCNews.com\": -1,\n",
    "    \"New York Times - Opinion\": -1,\n",
    "    \"MSNBC\": -1,\n",
    "    \"Grist\": -1,\n",
    "    \"Detroit Free Press\": -1,\n",
    "    \"Esquire\": -1,\n",
    "    \"ProPublica\": -1,\n",
    "    \"The Week - Opinion\": -1,\n",
    "    \"New York Times (Online News)\": -1,\n",
    "    \"The Verge\": -1,\n",
    "    \"The Observer (New York)\": -1,\n",
    "    \"Time Magazine\": -1,\n",
    "    \"BBC News\": 0,\n",
    "    \"Politico\": 0,\n",
    "    \"Christian Science Monitor\": 0,\n",
    "    \"The Hill\": 0,\n",
    "    \"Mediaite\": 0,\n",
    "    \"Thomas Frey\": 0,\n",
    "    \"USA TODAY\": 0,\n",
    "    \"Reuters\": 0,\n",
    "    \"Lifehacker\": 0,\n",
    "    \"Yahoo! The 360\": 0,\n",
    "    \"Associated Press\": 0,\n",
    "    \"Reason\": 0,\n",
    "    \"Wall Street Journal - News\": 0,\n",
    "    \"The Flip Side\": 0,\n",
    "    \"ABC News\": 0,\n",
    "    \"Atlanta Journal-Constitution\": 0,\n",
    "    \"Al Jazeera\": 0,\n",
    "    \"Deadline.com\": 0,\n",
    "    \"Business Insider\": 0,\n",
    "    \"Bloomberg\": 0,\n",
    "    \"Scientific American\": 0,\n",
    "    \"The Week - News\": 0,\n",
    "    \"MarketWatch\": 0,\n",
    "    \"Smithsonian Magazine\": 0,\n",
    "    \"CBS News\": 0,\n",
    "    \"Guest Writer\": 0,\n",
    "    \"Howard Kurtz\": 0,\n",
    "    \"TechCrunch\": 0,\n",
    "    \"Damon Linker\": 0,\n",
    "    \"The Dallas Morning News\": 0,\n",
    "    \"Yahoo! News\": 0,\n",
    "    \"Indy Online\": 0,\n",
    "    \"RealClearPolitics\": 0,\n",
    "    \"Foreign Policy\": 0,\n",
    "    \"Pew Research Center\": 0,\n",
    "    \"Brookings Institution\": 0,\n",
    "    \"CNBC\": 0,\n",
    "    \"The Economist\": 0,\n",
    "    \"Axios\": 0,\n",
    "    \"DAG Blog\": 0,\n",
    "    \"Nieman Lab\": 0,\n",
    "    \"The Texas Tribune\": 0,\n",
    "    \"ABC News (Online)\": 0,\n",
    "    \"Whitehouse.gov\": 0,\n",
    "    \"Bring Me The News\": 0,\n",
    "    \"NBC Today Show\": 0,\n",
    "    \"Bipartisan Policy Center\": 0,\n",
    "    \"PACE\": 0,\n",
    "    \"Washington Times\": 1,\n",
    "    \"National Review\": 1,\n",
    "    \"Fox Online News\": 1,\n",
    "    \"Guest Writer - Right\": 1,\n",
    "    \"Fox News\": 1,\n",
    "    \"Fox News Opinion\": 1,\n",
    "    \"Allysia Finley (Wall Street Journal)\": 1,\n",
    "    \"The Daily Caller\": 1,\n",
    "    \"John Stossel\": 1,\n",
    "    \"David Brooks\": 1,\n",
    "    \"New York Post\": 1,\n",
    "    \"Scott Walker\": 1,\n",
    "    \"The Daily Signal\": 1,\n",
    "    \"Rich Lowry\": 1,\n",
    "    \"HotAir\": 1,\n",
    "    \"The Epoch Times\": 1,\n",
    "    \"Independent Journal Review\": 1,\n",
    "    \"The Dispatch\": 1,\n",
    "    \"City Journal\": 1,\n",
    "    \"John Fund\": 1,\n",
    "    \"Michael Brendan Dougherty\": 1,\n",
    "    \"Fox News (Online)\": 1,\n",
    "    \"Commonwealth Journal\": 1,\n",
    "    \"Jonah Goldberg\": 1,\n",
    "    \"Quillette\": 1,\n",
    "    \"The Libertarian Republic\": 1,\n",
    "    \"American Enterprise Institute\": 1,\n",
    "    \"Daily Mail\": 1,\n",
    "    \"Washington Free Beacon\": 1,\n",
    "    \"Townhall\": 2,\n",
    "    \"CBN\": 2,\n",
    "    \"Newsmax\": 2,\n",
    "    \"TheBlaze.com\": 2,\n",
    "    \"Breitbart News\": 2,\n",
    "    \"Newsmax - Opinion\": 2,\n",
    "    \"Ann Coulter\": 2,\n",
    "    \"American Spectator\": 2,\n",
    "    \"Michelle Malkin\": 2,\n",
    "    \"Victor Hanson\": 2,\n",
    "    \"Newsmax - News\": 2,\n",
    "    \"Charles Krauthammer\": 2,\n",
    "    \"The Daily Wire\": 2,\n",
    "    \"NewsBusters\": 2,\n",
    "    \"Ben Shapiro\": 2,\n",
    "    \"Media Research Center\": 2,\n",
    "    \"Pat Buchanan\": 2,\n",
    "    \"Newsmax (News)\": 2,\n",
    "    \"InfoWars\": 2,\n",
    "    \"ZeroHedge\": 2,\n",
    "    \"The Western Journal\": 2,\n",
    "    \"CNS News\": 2\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d4842b32-2ad2-4a9a-a02c-8f3a63dcb5fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data loading\n",
    "df['confidence_bin'] = pd.cut(df['confidence'], bins=granularity, labels=False)\n",
    "df['source_type'] = df['source'].apply(lookup_fine.get)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "295cae28-d55f-4674-9b56-83730a46c290",
   "metadata": {},
   "outputs": [],
   "source": [
    "calibration_df = df.sample(frac=calibration_frac, replace=False)\n",
    "rest_of_df = df[~df.index.isin(calibration_df.index)]\n",
    "calibration_frac = 1 - len(rest_of_df) / len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e9afdb19-f459-42ff-9bec-90511f981f9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calibrating cluster source_type=-2...\n",
      "2 0.10000000000000009 6\n",
      "Skipping empty bin 1\n",
      "Skipping empty bin 0\n",
      "Calibrating cluster source_type=1...\n",
      "2 -0.05833333333333335 54\n",
      "1 -0.05263157894736836 19\n",
      "0 0.7 2\n",
      "Calibrating cluster source_type=2...\n",
      "2 -0.03888888888888886 18\n",
      "1 -0.5428571428571429 7\n",
      "0 0.7 4\n",
      "Calibrating cluster source_type=-1...\n",
      "2 -0.29047619047619044 63\n",
      "1 -0.20434782608695634 23\n",
      "0 -0.15 6\n",
      "Calibrating cluster source_type=0...\n",
      "2 -0.3749999999999999 58\n",
      "1 -0.21282051282051273 39\n",
      "0 -0.3 1\n"
     ]
    }
   ],
   "source": [
    "cluster_col = 'source_type'\n",
    "for cluster_id in df[cluster_col].unique():\n",
    "    print(f'Calibrating cluster {cluster_col}={cluster_id}...')\n",
    "    cluster_df = calibration_df[calibration_df[cluster_col] == cluster_id]\n",
    "    for bin_id in df['confidence_bin'].unique():\n",
    "        bin_df = cluster_df[cluster_df[\"confidence_bin\"] == bin_id].copy()\n",
    "        \n",
    "        # Skip empty bins\n",
    "        if len(bin_df) == 0:\n",
    "            print(f'Skipping empty bin {bin_id}')\n",
    "            continue\n",
    "            \n",
    "        bin_df['correct'] = (bin_df['Y'] == bin_df['Yhat'])\n",
    "        correction = bin_df['correct'].mean() - bin_df['confidence'].mean()\n",
    "\n",
    "        print(bin_id, correction, len(bin_df))\n",
    "        if len(bin_df) > min_size:\n",
    "            mask = (rest_of_df[cluster_col] == cluster_id) & (rest_of_df['confidence_bin'] == bin_id)\n",
    "            rest_of_df.loc[mask, 'confidence'] = np.clip(rest_of_df.loc[mask, 'confidence'] + correction, 1e-5, 1 - 1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d622f19-8757-4515-b904-bfb0ba798314",
   "metadata": {},
   "source": [
    "### Post-calibration PAC Labeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6a49fad1-4569-4bc0-8dec-93b7bee6eac4",
   "metadata": {},
   "outputs": [],
   "source": [
    "rest_of_df = rest_of_df.reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f899c8d-6fbb-431c-9638-5d7aa098c9e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b039bc647ce345ee8ce28a05d9f1c8f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "Y = np.array(rest_of_df['Y'])\n",
    "Yhat = np.array(rest_of_df['Yhat'])\n",
    "pi = np.ones(len(rest_of_df)) # 0.1*np.ones(len(rest_of_df))\n",
    "\n",
    "for i in tqdm(range(num_trials)):\n",
    "    uncertainty = 1 - rest_of_df['confidence'] + 1e-5 * np.random.normal(size=len(rest_of_df)) # break ties\n",
    "    Y_tilde, labeled_inds, _ = pac_labeling(Y, \n",
    "                                            Yhat, \n",
    "                                            zero_one_loss, \n",
    "                                            epsilon * (1 / (1 - calibration_frac)),\n",
    "                                            alpha, \n",
    "                                            uncertainty, \n",
    "                                            pi, \n",
    "                                            K,\n",
    "                                            asymptotic=False)\n",
    "    errs[i] = zero_one_loss(Y,Y_tilde)\n",
    "    percent_saved[i] = np.mean(labeled_inds==0.0)*100\n",
    "\n",
    "err = np.quantile(errs, 1-alpha) * (1 - calibration_frac)\n",
    "print('Error:', err, \n",
    "      'Budget save:(', np.mean(percent_saved) * (1 - calibration_frac), \n",
    "      '+/-', np.std(percent_saved) * (1 - calibration_frac),')')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
