{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 1\n",
    "\n",
    "Figure 1 compares the two ways of reporting results: the left panel shows what previous results looked like (point estimates only), and the right panel shows what the new results will look like (estimates with confidence intervals, CIs)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# Global default font size for figures in this notebook.\n",
    "# NOTE(review): plt.xkcd() installs its own font.size while active, so figures\n",
    "# drawn inside `with plt.xkcd():` may not use this value - confirm.\n",
    "font = {\"size\": 15}\n",
    "matplotlib.rcParams.update({f\"font.{key}\": value for key, value in font.items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def axvline_with_caps(ax, x, ymin, ymax, cap_width=0.1, color=\"black\", **line_kwargs):\n",
    "    \"\"\"\n",
    "    Draw a vertical line segment with horizontal caps at both ends.\n",
    "\n",
    "    Every segment is drawn with ``ax.plot`` in data coordinates, so the line\n",
    "    stays attached to ymin/ymax even if the y-limits change after this call\n",
    "    (autoscaling, a later ``set_ylim``, shared axes, non-linear scales).\n",
    "\n",
    "    Parameters:\n",
    "    - ax: matplotlib Axes object\n",
    "    - x: x position of the vertical line (in data coordinates)\n",
    "    - ymin: lower y-coordinate (in data coordinates)\n",
    "    - ymax: upper y-coordinate (in data coordinates)\n",
    "    - cap_width: total width of each cap (in data coordinates)\n",
    "    - color: color of the line and both caps\n",
    "    - **line_kwargs: additional keyword arguments passed to every ``ax.plot`` call\n",
    "\n",
    "    Returns:\n",
    "    - list of the created Line2D artists: [vertical line, lower cap, upper cap]\n",
    "    \"\"\"\n",
    "    half = cap_width / 2\n",
    "\n",
    "    # Use data coordinates rather than axvline's axes fractions: fractions computed\n",
    "    # from the current ylim go stale as soon as the limits change.\n",
    "    lines = ax.plot([x, x], [ymin, ymax], color=color, **line_kwargs)\n",
    "\n",
    "    # Horizontal caps centred on x at both ends.\n",
    "    lines += ax.plot([x - half, x + half], [ymin, ymin], color=color, **line_kwargs)\n",
    "    lines += ax.plot([x - half, x + half], [ymax, ymax], color=color, **line_kwargs)\n",
    "    return lines\n",
    "\n",
    "\n",
    "def axhline_with_caps(ax, y, xmin, xmax, cap_height=0.05, color=\"black\", **line_kwargs):\n",
    "    \"\"\"\n",
    "    Draw a horizontal line segment with vertical caps at both ends.\n",
    "\n",
    "    Every segment is drawn with ``ax.plot`` in data coordinates, so the line\n",
    "    stays attached to xmin/xmax even if the x-limits change after this call\n",
    "    (autoscaling, a later ``set_xlim``, shared axes, non-linear scales).\n",
    "\n",
    "    Parameters:\n",
    "    - ax: matplotlib Axes object\n",
    "    - y: y position of the horizontal line (in data coordinates)\n",
    "    - xmin: lower x-coordinate (in data coordinates)\n",
    "    - xmax: upper x-coordinate (in data coordinates)\n",
    "    - cap_height: total height of each cap (in data coordinates)\n",
    "    - color: color of the line and both caps\n",
    "    - **line_kwargs: additional keyword arguments passed to every ``ax.plot`` call\n",
    "\n",
    "    Returns:\n",
    "    - list of the created Line2D artists: [horizontal line, left cap, right cap]\n",
    "    \"\"\"\n",
    "    half = cap_height / 2\n",
    "\n",
    "    # Use data coordinates rather than axhline's axes fractions: fractions computed\n",
    "    # from the current xlim go stale as soon as the limits change.\n",
    "    lines = ax.plot([xmin, xmax], [y, y], color=color, **line_kwargs)\n",
    "\n",
    "    # Vertical caps centred on y at both ends.\n",
    "    lines += ax.plot([xmin, xmin], [y - half, y + half], color=color, **line_kwargs)\n",
    "    lines += ax.plot([xmax, xmax], [y - half, y + half], color=color, **line_kwargs)\n",
    "    return lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulated EC (kappa) values: 100 draws per classifier, used by both panels.\n",
    "np.random.seed(1)\n",
    "df = pd.DataFrame(\n",
    "    {\n",
    "        \"model\": np.repeat([\"humans\", \"A\", \"B\", \"C\"], 100).tolist(),\n",
    "        \"EC\": np.concatenate(\n",
    "            [\n",
    "                np.random.normal(mean, std, 100)\n",
    "                # draw order (humans, A, B, C) must match the row order of \"model\"\n",
    "                for mean, std in [(0.42, 0.03), (0.23, 0.04), (0.2, 0.05), (0.18, 0.05)]\n",
    "            ]\n",
    "        ).tolist(),\n",
    "    }\n",
    ")\n",
    "\n",
    "\n",
    "def plot(data=None, path=\"figures/demo_figure_v2.pdf\"):\n",
    "    \"\"\"\n",
    "    Draw Figure 1: point estimates only (left) vs. estimates with intervals (right).\n",
    "\n",
    "    Parameters:\n",
    "    - data: long-format DataFrame with columns \"model\" and \"EC\";\n",
    "      defaults to the module-level ``df`` built in this cell\n",
    "    - path: output file for the figure; missing parent directories are created\n",
    "\n",
    "    Returns:\n",
    "    - the matplotlib Figure that was drawn and saved\n",
    "    \"\"\"\n",
    "    if data is None:\n",
    "        data = df\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)\n",
    "    an_color = \"grey\"\n",
    "\n",
    "    # original plot on the left: point estimates only\n",
    "    ax = axes[0]\n",
    "    sns.pointplot(data=data, x=\"model\", y=\"EC\", linestyle=\"none\", errorbar=None, ax=ax)\n",
    "    sns.despine(ax=ax)\n",
    "    ax.grid(axis=\"y\")  # kept for non-xkcd rendering; the xkcd style hides gridlines\n",
    "    ax.set_ylim(0, 0.55)  # sharey=True applies these limits to the right panel too\n",
    "    ax.set_ylabel(\"EC [kappa]\")\n",
    "    ax.set_xlabel(\"Classifier\")\n",
    "    ax.set_title(\"Previous Work\")\n",
    "\n",
    "    # add annotations\n",
    "    axvline_with_caps(ax, x=-0.15, ymin=0.23, ymax=0.42, color=an_color)\n",
    "    ax.annotate(\"Human-Machine\\nGap\", (-0.1, 0.3), fontsize=12, color=an_color)\n",
    "\n",
    "    axhline_with_caps(ax, y=0.11, xmin=1, xmax=3, color=an_color)\n",
    "    ax.annotate(\"Model Ranking\", (1.5, 0.05), fontsize=12, color=an_color)\n",
    "\n",
    "    # new plot on the right: same data with 95% percentile intervals.\n",
    "    # NOTE(review): (\"pi\", 95) is the spread of the rows themselves; it is a CI only\n",
    "    # if each model's rows are resampled estimates (e.g. bootstrap) - confirm.\n",
    "    ax = axes[1]\n",
    "    sns.pointplot(\n",
    "        data=data,\n",
    "        x=\"model\",\n",
    "        y=\"EC\",\n",
    "        errorbar=(\"pi\", 95),\n",
    "        capsize=0.3,\n",
    "        linestyle=\"none\",\n",
    "        ax=ax,\n",
    "    )\n",
    "    sns.despine(ax=ax)\n",
    "    ax.grid(axis=\"y\")\n",
    "    ax.set_xlabel(\"Classifier\")\n",
    "    ax.set_title(\"Our Work\")\n",
    "\n",
    "    # add annotations\n",
    "    axvline_with_caps(ax, x=-0.25, ymin=0.3, ymax=0.35, color=an_color)\n",
    "    ax.annotate(\"Gap shrinks\", (-0.19, 0.31), fontsize=12, color=an_color)\n",
    "\n",
    "    axhline_with_caps(ax, y=0.35, xmin=1, xmax=3, color=an_color)\n",
    "    ax.annotate(\n",
    "        \"differences are insignificant\", (1.0, 0.4), fontsize=12, color=an_color\n",
    "    )\n",
    "\n",
    "    fig.tight_layout()\n",
    "    # savefig does not create missing directories; make sure the target exists.\n",
    "    Path(path).parent.mkdir(parents=True, exist_ok=True)\n",
    "    fig.savefig(path, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "    return fig\n",
    "\n",
    "\n",
    "with plt.xkcd():\n",
    "    plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
