{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57cf3926-086c-4b94-bc58-f246766fc667",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "from router_utils import train_weighter\n",
    "from pac_utils import pac_labeling, zero_one_loss\n",
    "from plotting_utils import pac_router_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94315487-6493-421c-9245-ad6b2fcf1cab",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff5df671-145b-4daa-b7c1-863a98fdcdea",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"bias_source_features.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbdab28c-56be-4880-b4aa-f7dbaa279f52",
   "metadata": {},
   "source": [
    "### Train PAC router"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a9369c-d111-4d25-8443-d28c49ac73fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "costless = True\n",
    "if costless:\n",
    "    costs = np.array([0,0])\n",
    "else:\n",
    "    costs = np.array([0.075, 0.25]) # real ratio of claude sonnet and gpt-4o\n",
    "expert_cost = 1\n",
    "results = train_weighter(df, calibration_frac=0.1, costs=costs, expert_cost=expert_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42c20340-5ba1-4628-a87a-5fb2ab0b6803",
   "metadata": {},
   "source": [
    "### Set parameters for PAC labeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00658a21-bec2-4edb-9373-4465150881b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.05\n",
    "epsilon = 0.05\n",
    "num_trials = 1000\n",
    "K = 500\n",
    "Y = df[\"Y\"].to_numpy()[~results[\"label_collected\"]]\n",
    "Yhat_routed = results[\"Yhat_routed\"].to_numpy()[~results[\"label_collected\"]]\n",
    "confidence_routed = results[\"confidence_routed\"].to_numpy()[~results[\"label_collected\"]]\n",
    "Yhat_claude = df[\"Yhat_0\"].to_numpy()[~results[\"label_collected\"]]\n",
    "confidence_claude = df[\"confidence_0\"].to_numpy()[~results[\"label_collected\"]]\n",
    "Yhat_gpt = df[\"Yhat_1\"].to_numpy()[~results[\"label_collected\"]]\n",
    "confidence_gpt = df[\"confidence_1\"].to_numpy()[~results[\"label_collected\"]]\n",
    "pi = 1*np.ones(len(Y))\n",
    "cost_sensitive = (np.sum(costs) > 0)\n",
    "if cost_sensitive:\n",
    "    cost_Yhats = costs[results[\"routed_model\"].to_numpy()[~results[\"label_collected\"]]]\n",
    "    cost_claude = costs[0]\n",
    "    cost_gpt = costs[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20459cfa-dee2-4642-81c5-0d2af8cf18ea",
   "metadata": {},
   "source": [
    "### Run PAC labeling with router"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5999fc40-6fea-4abd-8158-474cd22c1051",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_tilde = np.zeros(len(Y))\n",
    "budget_save_router = np.zeros(num_trials)\n",
    "errs_router = np.zeros(num_trials)\n",
    "for i in tqdm(range(num_trials)):\n",
    "    uncertainty_routed = 1 - confidence_routed + 1e-5*np.random.randn(len(Y)) # break ties\n",
    "    Y_tilde, labeled_inds, _ = pac_labeling(Y, Yhat_routed, zero_one_loss, epsilon, alpha, uncertainty_routed, pi, K, asymptotic=False)\n",
    "    if cost_sensitive:\n",
    "        budget_save_router[i] = (len(Y) - np.sum(labeled_inds))*expert_cost - np.sum(cost_Yhats[np.where(1-labeled_inds)])\n",
    "    else:\n",
    "        budget_save_router[i] = (labeled_inds == 0).mean()*100\n",
    "    errs_router[i] = zero_one_loss(Y, Y_tilde)\n",
    "print('Error:', np.quantile(errs_router, 1-alpha), 'Budget save:(', np.mean(budget_save_router), '+/-', np.std(budget_save_router),')')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe713486-327b-444a-8a99-8a702c60fc81",
   "metadata": {},
   "source": [
    "### Run PAC labeling with GPT/Claude individually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae3b56e-dca3-4d8e-a5aa-c47b138e48c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_tilde = np.zeros(len(Y))\n",
    "budget_save_claude = np.zeros(num_trials)\n",
    "errs_claude = np.zeros(num_trials)\n",
    "for i in tqdm(range(num_trials)):\n",
    "    uncertainty_claude = 1 - confidence_claude + 1e-5*np.random.randn(len(Y)) # break ties\n",
    "    Y_tilde, labeled_inds, _ = pac_labeling(Y, Yhat_claude, zero_one_loss, epsilon, alpha, uncertainty_claude, pi, K, asymptotic=False)\n",
    "    if cost_sensitive:\n",
    "        budget_save_claude[i] = (len(Y) - np.sum(labeled_inds))*expert_cost - cost_claude*np.sum(1-labeled_inds)\n",
    "    else:\n",
    "        budget_save_claude[i] = (labeled_inds == 0).mean()*100\n",
    "    errs_claude[i] = zero_one_loss(Y, Y_tilde)\n",
    "print('Error:', np.quantile(errs_claude, 1-alpha), 'Budget save:(', np.mean(budget_save_claude), '+/-', np.std(budget_save_claude),')')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18e50b21-24fd-4305-a1f9-372ff2978ebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_tilde = np.zeros(len(Y))\n",
    "budget_save_gpt = np.zeros(num_trials)\n",
    "errs_gpt = np.zeros(num_trials)\n",
    "for i in tqdm(range(num_trials)):\n",
    "    uncertainty_gpt = 1 - confidence_gpt + 1e-5*np.random.randn(len(Y)) # break ties\n",
    "    Y_tilde, labeled_inds, _ = pac_labeling(Y, Yhat_gpt, zero_one_loss, epsilon, alpha, uncertainty_gpt, pi, K, asymptotic=False)\n",
    "    if cost_sensitive:\n",
    "        budget_save_gpt[i] = (len(Y) - np.sum(labeled_inds))*expert_cost - cost_gpt*np.sum(1-labeled_inds)\n",
    "    else:\n",
    "        budget_save_gpt[i] = (labeled_inds == 0).mean()*100\n",
    "    errs_gpt[i] = zero_one_loss(Y, Y_tilde)\n",
    "print('Error:', np.quantile(errs_gpt, 1-alpha), 'Budget save:(', np.mean(budget_save_gpt), '+/-', np.std(budget_save_gpt),')')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a26be92-70e7-416f-87a3-e46a7f81a539",
   "metadata": {},
   "source": [
    "### Plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bed112a1-25f8-42e3-9604-4aea2e650dc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "pac_router_plot([errs_router, errs_gpt, errs_claude], [budget_save_router, budget_save_gpt, budget_save_claude], epsilon, \"routed_costs_epsilon_\", num_trials=num_trials, cost_free = (np.sum(costs)==0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b10ac66-d0f6-4511-8aef-1b34d556132c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.axhline(y=epsilon, linestyle='--', color='black', linewidth=1.2, alpha=0.6)\n",
    "plt.plot(range(len(uncertainty_routed)), [np.sum((Yhat_routed != Y)[np.where(uncertainty_routed < u)])/len(Y) for u in np.sort(uncertainty_routed)], color='#2274A5', label='router')\n",
    "plt.plot(range(len(uncertainty_gpt)), [np.sum((Yhat_gpt != Y)[np.where(uncertainty_gpt < u)])/len(Y) for u in np.sort(uncertainty_gpt)], color=\"#00CC66\", label='GPT')\n",
    "plt.plot(range(len(uncertainty_claude)), [np.sum((Yhat_claude != Y)[np.where(uncertainty_claude < u)])/len(Y) for u in np.sort(uncertainty_claude)], color=\"#F75C03\", label='Claude')\n",
    "plt.xlabel(\"$u$\", fontsize=16)\n",
    "plt.ylabel(\"$L^u$\", fontsize=16)\n",
    "plt.gca().spines['top'].set_visible(False)\n",
    "plt.gca().spines['right'].set_visible(False)\n",
    "plt.grid(True, linestyle=':', color='gray', alpha=0.4)\n",
    "plt.legend(frameon=False, fontsize=16, loc='upper left')\n",
    "yticks = plt.yticks()[0]\n",
    "yticks = np.append(yticks, epsilon)\n",
    "plt.yticks(yticks)\n",
    "plt.gca().set_yticklabels(\n",
    "    [r'$\\varepsilon=$' + str(epsilon) if np.isclose(tick, epsilon) else f'{tick:.2f}' for tick in yticks],\n",
    "    fontsize=12\n",
    ")\n",
    "plt.xticks(fontsize=12)\n",
    "plt.ylim([0,0.5])\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"Lu.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e84a435-375b-4f6d-b562-6c46821f40a5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
