{
 "cells": [
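  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bootstrapping Confidence Intervals\n",
    "\n",
    "This notebook demonstrates how the size of confidence intervals for error consistency (EC) depends on the number of sampled trials, on the true EC, and on the accuracy of the two classifiers. Here both simulated classifiers have the same accuracy (i.e. an accuracy delta of zero); the intervals shown are 95% percentile intervals of the empirical EC over repeated simulations."
   ]
  },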
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `%autoreload 2` re-imports modules (such as the local `utils`) before each cell runs,\n",
    "# so edits to imported files take effect without restarting the kernel\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.lines as mlines\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# not using this here because finer control needed\n",
    "# font = {\"size\": 15}\n",
    "# matplotlib.rc(\"font\", **font)\n",
    "\n",
    "# make the parent directory importable so the local `utils` module can be found\n",
    "sys.path.append(os.path.abspath(\"..\"))\n",
    "from utils import fast_cohen, simulate_trials_from_copy_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_ci_data(\n",
    "    kappas=(0, 0.25, 0.5, 0.75, 0.95, 0.999),\n",
    "    accuracies=(0.05, 0.25, 0.5, 0.75, 0.95),\n",
    "    trial_counts=(100, 200, 400, 800, 1600),\n",
    "    n_iterations=5000,\n",
    "):\n",
    "    \"\"\"Simulate the sampling distribution of empirical EC over a grid of settings.\n",
    "\n",
    "    For every combination of true EC (`kappas`), accuracy (`accuracies`) and\n",
    "    number of trials (`trial_counts`), `n_iterations` pairs of trial vectors\n",
    "    are drawn with `simulate_trials_from_copy_model`, using the same accuracy\n",
    "    for both simulated classifiers, and their empirical EC is computed with\n",
    "    `fast_cohen`. The defaults are the grid used for `data/ecdf_equal.csv`\n",
    "    (6 x 5 x 5 settings x 5000 iterations); pass smaller values for quick runs.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        One row per retained sample, with columns \"True EC\", \"Accuracy\",\n",
    "        \"Trials\", \"Iteration\" and \"Empirical EC\" and a fresh 0..n-1 index.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    Samples in which either trial vector is all truthy or all falsy\n",
    "    (presumably: a classifier that is right on every trial or on none) are\n",
    "    dropped, presumably because EC is degenerate there -- TODO confirm against\n",
    "    `fast_cohen`. This conditions the distribution on non-degenerate outcomes,\n",
    "    which matters most for extreme accuracies combined with few trials.\n",
    "\n",
    "    NOTE(review): no random seed is set here; whether results are\n",
    "    reproducible depends on how `simulate_trials_from_copy_model` draws\n",
    "    random numbers.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    for kappa in kappas:\n",
    "        for acc in accuracies:\n",
    "            for n_trials in trial_counts:\n",
    "                for i in range(n_iterations):\n",
    "                    trials1, trials2 = simulate_trials_from_copy_model(\n",
    "                        kappa, acc, acc, n_trials\n",
    "                    )\n",
    "\n",
    "                    # one of the vectors is constant -> store NaN, dropped below\n",
    "                    degenerate = (\n",
    "                        np.all(trials1)\n",
    "                        or np.all(trials2)\n",
    "                        or not np.any(trials1)\n",
    "                        or not np.any(trials2)\n",
    "                    )\n",
    "                    records.append(\n",
    "                        {\n",
    "                            \"True EC\": kappa,\n",
    "                            \"Accuracy\": acc,\n",
    "                            \"Trials\": n_trials,\n",
    "                            \"Iteration\": i,\n",
    "                            \"Empirical EC\": (\n",
    "                                np.nan if degenerate else fast_cohen(trials1, trials2)\n",
    "                            ),\n",
    "                        }\n",
    "                    )\n",
    "\n",
    "    # explicit columns keep the schema even when the grid is empty\n",
    "    columns = [\"True EC\", \"Accuracy\", \"Trials\", \"Iteration\", \"Empirical EC\"]\n",
    "    df = pd.DataFrame(records, columns=columns)\n",
    "    return df.dropna(subset=[\"Empirical EC\"]).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulating the full grid is slow (6 x 5 x 5 settings x 5000 iterations by\n",
    "# default), so the results are cached as CSV. Set `generate_data = True` to force\n",
    "# a fresh simulation; it also runs automatically when the cache file is missing.\n",
    "generate_data = False\n",
    "data_path = \"data/ecdf_equal.csv\"\n",
    "if generate_data or not os.path.exists(data_path):\n",
    "    os.makedirs(os.path.dirname(data_path), exist_ok=True)\n",
    "    df = generate_ci_data()\n",
    "    df.to_csv(data_path, na_rep=\"NULL\", index=False)\n",
    "else:\n",
    "    df = pd.read_csv(data_path)\n",
    "\n",
    "# degenerate samples are dropped during generation, so no EC may be missing\n",
    "assert df[\"Empirical EC\"].notna().all(), \"unexpected missing Empirical EC values\"\n",
    "display(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Provenance of the trial counts used in `literature_sizes` (next cell); kept\n",
    "# for reference only. The Geirhos part needs the raw parquet data to run.\n",
    "# # extracting the number of trials from the data by Geirhos et al\n",
    "# geirhos_raw_df = pd.read_parquet(\n",
    "#     \"../geirhos_analysis/data/geirhos_raw_data.parquet\", engine=\"pyarrow\"\n",
    "# )\n",
    "# human_df = geirhos_raw_df[geirhos_raw_df[\"subject_type\"] == \"human\"]\n",
    "# human_df.groupby([\"experiment\", \"condition\", \"subj\"], observed=True)[\"category\"].count().unique()\n",
    "\n",
    "# # Wiles et al:\n",
    "# # (the mean of these values is 3221.5, i.e. the 3_222 used in `literature_sizes`)\n",
    "# words = [3727, 2729, 3901, 3304, 2792, 3298, 2549, 3472]\n",
    "# print(np.mean(words))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# display labels of the reference works whose trial counts are marked on the plots\n",
    "geirhos = \"Geirhos et al. 2021\"\n",
    "li = \"Li et al. 2025\"\n",
    "wiles = \"Wiles et al. 2024\"\n",
    "ollikka = \"Ollikka et al. 2024\"\n",
    "\n",
    "# number(s) of trials used in each reference work\n",
    "literature_sizes = {\n",
    "    geirhos: [160, 320, 560, 640, 800, 1_280],\n",
    "    li: [600],  # seems like the fairer comparison\n",
    "    wiles: [3_222],\n",
    "    ollikka: [147],\n",
    "}\n",
    "\n",
    "# matplotlib marker per reference work\n",
    "symbols = {\n",
    "    geirhos: \"v\",  # downward-pointing triangle\n",
    "    li: \"^\",  # upward-pointing triangle\n",
    "    ollikka: \"o\",  # circle\n",
    "    wiles: \"p\",  # pentagon\n",
    "}\n",
    "\n",
    "# one color per reference work from seaborn's \"mako\" palette, assigned in the\n",
    "# insertion order of `literature_sizes`\n",
    "cmap = sns.color_palette(\"mako\", n_colors=len(literature_sizes), as_cmap=False)\n",
    "colors = dict(zip(literature_sizes, cmap))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def put_markers_on_x_axis(ax, annotations, marker_map=None, color_map=None):\n",
    "    \"\"\"Mark the trial counts of reference works along the bottom of `ax`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax : matplotlib.axes.Axes\n",
    "        Axes to annotate. Set its final y-limits first: the markers are drawn\n",
    "        at the current lower y-limit.\n",
    "    annotations : dict\n",
    "        Maps a reference label to the list of x positions (trial counts) to mark.\n",
    "    marker_map, color_map : dict, optional\n",
    "        Marker format string / color per label. Default to the notebook-level\n",
    "        `symbols` and `colors` dicts; labels missing from a map fall back to\n",
    "        \"o\" (circle) and \"black\".\n",
    "\n",
    "    Adds a \"Reference Works\" legend centered below the axes. `ax.legend`\n",
    "    replaces the axes' current legend, so callers that want to keep an\n",
    "    existing legend must register it with `ax.add_artist` before calling this.\n",
    "    Does nothing if `annotations` is empty.\n",
    "    \"\"\"\n",
    "    if not annotations:\n",
    "        return  # nothing to mark; also avoids a legend with zero columns\n",
    "    if marker_map is None:\n",
    "        marker_map = symbols\n",
    "    if color_map is None:\n",
    "        color_map = colors\n",
    "\n",
    "    y_marker = ax.get_ylim()[0]\n",
    "\n",
    "    handles = []\n",
    "    labels = []\n",
    "    for label, xvals in annotations.items():\n",
    "        handle = ax.plot(\n",
    "            xvals,\n",
    "            [y_marker] * len(xvals),\n",
    "            marker_map.get(label, \"o\"),\n",
    "            color=color_map.get(label, \"black\"),\n",
    "            label=label,\n",
    "            zorder=10,\n",
    "            clip_on=False,  # markers sit on the axes edge; don't clip them away\n",
    "        )[0]\n",
    "        handles.append(handle)\n",
    "        labels.append(label)\n",
    "\n",
    "    reference_legend = ax.legend(\n",
    "        handles,\n",
    "        labels,\n",
    "        loc=\"upper center\",  # the legend's top-center point ...\n",
    "        bbox_to_anchor=(0.5, -0.15),  # ... is anchored centered below the axes\n",
    "        ncol=len(handles),  # one column per reference: a single horizontal row\n",
    "        title=\"Reference Works\",\n",
    "    )\n",
    "    # also register it as a plain artist so a later `ax.legend` call cannot remove it\n",
    "    ax.add_artist(reference_legend)\n",
    "\n",
    "\n",
    "# plotting this as pointplots with 95% PIs\n",
    "def plot_cis(df, save=False, ylim=(-0.2, 1.05)):\n",
    "    \"\"\"Plot the spread of the empirical EC against the number of trials.\n",
    "\n",
    "    Draws one figure per unique \"True EC\" value in `df`: a point plot of\n",
    "    \"Empirical EC\" over \"Trials\" (log-scaled x axis), split by \"Accuracy\",\n",
    "    with 95% percentile intervals as error bars. The x axis is annotated with\n",
    "    the trial counts in the notebook-level `literature_sizes`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Simulation results with columns \"True EC\", \"Trials\", \"Accuracy\" and\n",
    "        \"Empirical EC\" (as produced by `generate_ci_data`).\n",
    "    save : bool, default False\n",
    "        If True, also write each figure to\n",
    "        `figures/pointplot_{true EC}_standard.pdf`.\n",
    "    ylim : tuple of float, default (-0.2, 1.05)\n",
    "        y-axis limits. -0.2 as the lower limit is meant for the paper figure;\n",
    "        -0.4 is the better choice when exploring.\n",
    "    \"\"\"\n",
    "    for gt_ec in df[\"True EC\"].unique():\n",
    "        # copy so that converting \"Trials\" does not write into a slice of `df`\n",
    "        pdf = df[df[\"True EC\"] == gt_ec].copy()\n",
    "        pdf[\"Trials\"] = pd.to_numeric(pdf[\"Trials\"], errors=\"coerce\")\n",
    "\n",
    "        fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
    "        ax.grid(axis=\"y\")\n",
    "        sns.pointplot(\n",
    "            data=pdf,\n",
    "            errorbar=(\"pi\", 95),\n",
    "            capsize=0.1,\n",
    "            x=\"Trials\",\n",
    "            y=\"Empirical EC\",\n",
    "            hue=\"Accuracy\",\n",
    "            dodge=0.4,\n",
    "            linestyle=\"none\",\n",
    "            legend=True,\n",
    "            log_scale=(True, False),\n",
    "            native_scale=True,\n",
    "            ax=ax,\n",
    "        )\n",
    "        ax.set_ylim(*ylim)\n",
    "        sns.despine(ax=ax)\n",
    "\n",
    "        # label the x axis exactly at the simulated trial counts\n",
    "        trial_ticks = np.sort(pdf[\"Trials\"].unique())\n",
    "        ax.set_xticks(trial_ticks, trial_ticks)\n",
    "        ax.tick_params(axis=\"x\", labelsize=13)\n",
    "\n",
    "        # keep seaborn's accuracy legend: put_markers_on_x_axis calls ax.legend,\n",
    "        # which would otherwise replace it\n",
    "        accuracy_legend = ax.get_legend()\n",
    "        if accuracy_legend is not None:\n",
    "            ax.add_artist(accuracy_legend)\n",
    "\n",
    "        ax.set_ylabel(\"Empirical EC\", fontsize=15)\n",
    "        ax.set_xlabel(\"Number of Trials\", fontsize=15)\n",
    "\n",
    "        # annotate x axis with literature results\n",
    "        put_markers_on_x_axis(ax, literature_sizes)\n",
    "        fig.tight_layout()\n",
    "        # leave room below the axes for the \"Reference Works\" legend\n",
    "        fig.subplots_adjust(bottom=0.3)\n",
    "        if save:\n",
    "            os.makedirs(\"figures\", exist_ok=True)\n",
    "            fig.savefig(f\"figures/pointplot_{gt_ec}_standard.pdf\")\n",
    "        plt.show()\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 3 variants: one point plot per true EC, relating the spread of the empirical\n",
    "# EC (CI size) to the number of trials; with save=True each figure is also written as a pdf.\n",
    "# NOTE: the legend with references doesn't show up here, it's only in the pdf\n",
    "plot_cis(df=df, save=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exploratory view (not saved): empirical EC per ground-truth EC, split by the\n",
    "# number of trials, with 95% percentile intervals as error bars\n",
    "fig, ax = plt.subplots(figsize=(8, 5))\n",
    "ax.grid(axis=\"y\")\n",
    "sns.pointplot(\n",
    "    data=df,\n",
    "    ax=ax,\n",
    "    x=\"True EC\",\n",
    "    y=\"Empirical EC\",\n",
    "    hue=\"Trials\",\n",
    "    dodge=0.4,\n",
    "    capsize=0.2,\n",
    "    errorbar=(\"pi\", 95),\n",
    "    linestyle=\"none\",\n",
    ")\n",
    "sns.despine()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical EC against ground-truth EC for different numbers of trials; the error\n",
    "# bands are 95% percentile intervals, so their width shows the CI size\n",
    "fig, ax = plt.subplots(1, 1, figsize=(8, 5))\n",
    "ax.grid(axis=\"y\")\n",
    "sns.lineplot(\n",
    "    data=df,\n",
    "    x=\"True EC\",\n",
    "    y=\"Empirical EC\",\n",
    "    hue=\"Trials\",\n",
    "    errorbar=(\"pi\", 95),\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_xlim(0.0, 0.95)\n",
    "ax.set_ylim(-0.2, 1.0)\n",
    "sns.despine(ax=ax)\n",
    "\n",
    "fig.tight_layout()\n",
    "# create the output directory here instead of relying on plot_cis(save=True)\n",
    "os.makedirs(\"figures\", exist_ok=True)\n",
    "fig.savefig(\"figures/ci_size_plot.pdf\")\n",
    "plt.show()\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Centered version of the CI-size figure: plot the deviation of the empirical EC\n",
    "# from the true EC, so every true-EC value shares the same zero line.\n",
    "norm_df = df.assign(delta=df[\"Empirical EC\"] - df[\"True EC\"])\n",
    "\n",
    "# One explicit color per trial count, shared by the bands and the hand-built legend.\n",
    "# Passing a dict as `palette` makes seaborn use exactly these colors; with a palette\n",
    "# *name* and a numeric hue it instead normalizes the trial counts linearly over their\n",
    "# range, which evenly resampled legend colors would not match. (`matplotlib.cm.get_cmap`\n",
    "# is also no longer available in Matplotlib >= 3.9.)\n",
    "palette_name = \"crest\"\n",
    "hue_levels = np.sort(norm_df[\"Trials\"].unique())\n",
    "palette = dict(zip(hue_levels, sns.color_palette(palette_name, n_colors=len(hue_levels))))\n",
    "\n",
    "# scope the larger font to this figure instead of changing matplotlib's global\n",
    "# rcParams, which would leak into every cell that is re-run afterwards\n",
    "with plt.rc_context({\"font.size\": 15}):\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
    "    ax.grid(axis=\"y\")\n",
    "    sns.lineplot(\n",
    "        data=norm_df,\n",
    "        x=\"True EC\",\n",
    "        y=\"delta\",\n",
    "        hue=\"Trials\",\n",
    "        palette=palette,\n",
    "        errorbar=(\"pi\", 95),\n",
    "        linestyle=\"none\",\n",
    "        legend=False,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    # lines are hidden (linestyle=\"none\"), so build marker-only legend entries by hand\n",
    "    handles = [\n",
    "        mlines.Line2D([], [], color=palette[level], marker=\"o\", linestyle=\"None\", label=str(level))\n",
    "        for level in hue_levels\n",
    "    ]\n",
    "    ax.legend(handles=handles, title=\"Trials\")\n",
    "\n",
    "    ax.set_ylabel(\"Delta between empirical and true EC\")\n",
    "    ax.set_xlim(0.0, 1)\n",
    "    ax.set_ylim(-0.4, 0.4)\n",
    "    ax.set_yticks(np.linspace(-0.4, 0.4, 5))  # a tick every 0.2\n",
    "    sns.despine(ax=ax)\n",
    "    # set the alpha of the 95% PI bands (the axes' collections) to 0.8\n",
    "    for collection in ax.collections:\n",
    "        collection.set_alpha(0.8)\n",
    "    fig.tight_layout()\n",
    "    os.makedirs(\"figures\", exist_ok=True)\n",
    "    fig.savefig(\"figures/ci_size_plot_centered.pdf\")\n",
    "    plt.show()\n",
    "plt.close(fig)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
