{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54bc22ca-6f62-43bc-ab4d-f614ee9239af",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "from pac_utils import pac_labeling, zero_one_loss, squared_loss\n",
    "from plotting_utils import pac_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7df5a45-eea9-4d84-9f6b-1c653d742fc0",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63bef1f0-6164-42b9-9457-26003f4c839a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the dataset to run; its CSV is read from '<dataset>.csv' relative to the working directory.\n",
    "dataset = 'alphafold' # one of 'stance', 'misinfo', 'bias', 'imagenet', 'imagenetv2', 'sentiment', 'alphafold'\n",
    "data = pd.read_csv(dataset + '.csv')\n",
    "Y = data[\"Y\"].to_numpy()\n",
    "\n",
    "# The prediction column has a different name depending on the dataset.\n",
    "if dataset in ['stance', 'misinfo', 'bias', 'sentiment']:\n",
    "    Yhat = data[\"Yhat (GPT4o)\"].to_numpy()\n",
    "elif dataset in ['imagenet', 'imagenetv2', 'alphafold']:\n",
    "    Yhat = data[\"Yhat\"].to_numpy()\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on Yhat in a later cell.\n",
    "    raise ValueError(f'Unknown dataset {dataset!r}: no prediction column is configured for it')\n",
    "confidence = data[\"confidence\"].to_numpy()\n",
    "n = len(Y)\n",
    "\n",
    "# Pick the loss function (imported from pac_utils) that this dataset is scored with.\n",
    "if dataset in ['stance', 'misinfo', 'bias', 'imagenet', 'imagenetv2']:\n",
    "    loss = zero_one_loss\n",
    "elif dataset in ['sentiment', 'alphafold']:\n",
    "    loss = squared_loss\n",
    "else:\n",
    "    # Same guard as above: never leave `loss` undefined for the cells that follow.\n",
    "    raise ValueError(f'Unknown dataset {dataset!r}: no loss is configured for it')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7d86e7d-601f-4753-ae99-7890e7e5192b",
   "metadata": {},
   "source": [
    "### Set parameters for PAC labeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3078b2ae-7d80-44b8-94d8-70221aaf562a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# alpha and epsilon are passed to pac_labeling; alpha also sets the error quantile reported after the trials.\n",
    "# NOTE(review): presumably epsilon is the error tolerance and alpha the failure probability - confirm in pac_utils.\n",
    "alpha = 0.05\n",
    "epsilon = 0.05\n",
    "\n",
    "# One entry per trial; both arrays are filled in by the trial loop.\n",
    "num_trials = 1000\n",
    "errs = np.zeros(num_trials)\n",
    "percent_saved = np.zeros(num_trials)\n",
    "\n",
    "# All-ones vector with one entry per data point, handed to pac_labeling as `pi`.\n",
    "# NOTE(review): presumably per-point sampling probabilities - confirm in pac_utils.\n",
    "pi = np.ones(n)\n",
    "\n",
    "# Per-dataset settings for the `asymptotic` flag and for K (both forwarded to pac_labeling).\n",
    "if dataset in ['stance', 'misinfo', 'bias']:\n",
    "    asymptotic = False\n",
    "    K = 500\n",
    "elif dataset in ['imagenet', 'imagenetv2']:\n",
    "    asymptotic = False\n",
    "    K = n // 10\n",
    "elif dataset in ['sentiment', 'alphafold']:\n",
    "    asymptotic = True\n",
    "    K = 1000\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on K / asymptotic in the trial loop.\n",
    "    raise ValueError(f'Unknown dataset {dataset!r}: no PAC labeling parameters are configured for it')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5301373-9a3b-47ef-8b2e-1bd53cceba96",
   "metadata": {},
   "source": [
    "### Run PAC labeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c991358-7bb9-4170-8e19-51bf7d6b1e2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global generator so re-running this cell reproduces the same tie-breaking noise.\n",
    "# NOTE(review): pac_labeling's own randomness is only covered if it also draws from np.random - confirm.\n",
    "np.random.seed(0)\n",
    "\n",
    "for i in tqdm(range(num_trials)):\n",
    "    # Uncertainty is one minus the confidence, plus a tiny Gaussian jitter so equal confidences do not tie.\n",
    "    uncertainty = 1 - confidence + 1e-5*np.random.normal(size=n) # break ties\n",
    "    Y_tilde, labeled_inds, _ = pac_labeling(Y, Yhat, loss, epsilon, alpha, uncertainty, pi, K, asymptotic=asymptotic)\n",
    "    # Loss of the labels returned by pac_labeling, measured against Y.\n",
    "    errs[i] = loss(Y,Y_tilde)\n",
    "    # Percentage of entries of labeled_inds that equal 0.\n",
    "    # NOTE(review): presumably the points that were not sent for labeling - confirm in pac_utils.\n",
    "    percent_saved[i] = np.mean(labeled_inds==0.0)*100\n",
    "\n",
    "# Report the (1 - alpha) quantile of the per-trial errors and the mean +/- std of the savings.\n",
    "print('Error:', np.quantile(errs, 1-alpha), 'Budget save:(', np.mean(percent_saved), '+/-', np.std(percent_saved),')')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "373fbb3a-b6ba-4034-a6ed-cf9c783a79ca",
   "metadata": {},
   "source": [
    "### Plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5075cef7-8c00-4b74-b509-22161932a63c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the per-trial results, the data and the run settings to pac_plot (from plotting_utils).\n",
    "# NOTE(review): xlim presumably bounds the x-axis and plot_naive adds a baseline - confirm in plotting_utils.\n",
    "pac_plot(\n",
    "    errs, percent_saved, epsilon,\n",
    "    Y, Yhat, confidence, loss,\n",
    "    dataset, num_trials,\n",
    "    xlim=[0,0.4],\n",
    "    plot_naive=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f4e61a0-f9cd-4d2e-9497-f00ffc7b8544",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
