{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Differences\n",
    "\n",
    "The point of this notebook is to evaluate to what extent error consistency (EC) is a suitable measure for finding differences between models: are the differences between DNN models large enough to support claims about how similar each of them is to humans?\n",
    "\n",
    "We do this by first loading the Geirhos et al. data and keeping only those conditions that they include in their analysis.\n",
    "\n",
    "Then we loop over all models and, for every model, over all conditions. We bootstrap and aggregate the ECs exactly as we did for main Figure 4.\n",
    "\n",
    "The bootstrapping itself is done in `bootstrap_models.py`; this notebook loads and plots the results.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from os.path import join as pjoin\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# base font size applied to every figure drawn in this notebook\n",
    "font = dict(size=15)\n",
    "matplotlib.rc(\"font\", **font)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load our bootstrap results\n",
    "method = \"lower\"\n",
    "TOP_K = 30  # number of models (highest mean model-human EC) kept for plotting\n",
    "standard_df = pd.read_parquet(\n",
    "    f\"data/brainscore_bootstrapped_ecs_1000_{method}.parquet\",\n",
    "    engine=\"pyarrow\",\n",
    ")\n",
    "display(standard_df)\n",
    "\n",
    "# Geirhos et al. take the average EC by first averaging within each experiment, then averaging across them.\n",
    "# (How you average within each experiment doesn't matter, because first conditions then humans\n",
    "# = first humans then conditions = all at once.)\n",
    "# NOTE(review): that equality of means only holds if every condition contributes the same\n",
    "# number of rows (balanced design) - confirm for this data.\n",
    "\n",
    "# take the average within each of the 12 experiments\n",
    "exp_mean_df = standard_df.groupby(\n",
    "    [\"bootstrap_id\", \"experiment\", \"model\"], observed=True, as_index=False\n",
    ").mean(numeric_only=True)\n",
    "\n",
    "# take the average across the experiments\n",
    "mean_df = exp_mean_df.groupby(\n",
    "    [\"bootstrap_id\", \"model\"], observed=True, as_index=False\n",
    ").mean(numeric_only=True)\n",
    "\n",
    "# find the TOP_K models with the highest model-human EC, averaged over bootstraps\n",
    "topk_models = (\n",
    "    mean_df.groupby(\"model\", observed=True)[\"model-human-ec\"]\n",
    "    .mean()\n",
    "    .nlargest(TOP_K)\n",
    "    .index.tolist()\n",
    ")\n",
    "\n",
    "# retain only the top-k models; drop=True avoids adding a stray \"index\" column\n",
    "topk_df = mean_df[mean_df[\"model\"].isin(topk_models)].reset_index(drop=True)\n",
    "standard_df = standard_df[standard_df[\"model\"].isin(topk_models)].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now do the same aggregation for humans\n",
    "# NOTE(review): the file name suggests 10000 human bootstraps vs 1000 for models - confirm intended.\n",
    "human_df = pd.read_parquet(\n",
    "    pjoin(\"../geirhos_analysis/data\", \"bootstrapped_human_ecs_standard_10000.parquet\"),\n",
    "    engine=\"pyarrow\",\n",
    ")\n",
    "\n",
    "# take the average within each of the 12 experiments\n",
    "human_mean_df = human_df.groupby(\n",
    "    [\"bootstrap_id\", \"experiment\"], observed=True, as_index=False\n",
    ").mean(numeric_only=True)\n",
    "human_mean_df[\"name\"] = \"Humans\"  # single x-axis category used by the plotting functions\n",
    "\n",
    "# take the average across the experiments\n",
    "human_final_df = human_mean_df.groupby(\n",
    "    [\"bootstrap_id\"], observed=True, as_index=False\n",
    ").mean(numeric_only=True)\n",
    "# numeric_only=True drops the string \"name\" column, so it has to be set again\n",
    "human_final_df[\"name\"] = \"Humans\"\n",
    "print(\"Mean human-human EC:\", human_final_df[\"human-human-ec\"].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_df(\n",
    "    df,\n",
    "    hdf,\n",
    "    ylim_min,\n",
    "    ylim_max,\n",
    "    name,\n",
    "    show_original=True,\n",
    "    model_ref_df=None,\n",
    "    human_ref_df=None,\n",
    "):\n",
    "    \"\"\"Plot bootstrapped model-human EC per model next to the human-human EC.\n",
    "\n",
    "    Models are ordered by decreasing mean EC. Every point is the mean with a 95%\n",
    "    percentile interval over the rows of `df` (models) or `hdf` (humans). The\n",
    "    figure is saved to figures/brainscore_model_comparison_{name}_{method}.pdf,\n",
    "    where `method` is the notebook-level setting from the loading cell.\n",
    "\n",
    "    Args:\n",
    "        df: model rows with columns \"model\" and \"model-human-ec\".\n",
    "        hdf: human rows with columns \"name\" and \"human-human-ec\".\n",
    "        ylim_min, ylim_max: y-axis limits.\n",
    "        name: tag inserted into the output file name.\n",
    "        show_original: if True, overlay the rows with bootstrap_id == 0 as blue\n",
    "            \"x\" markers (presumably the original, un-resampled data - confirm).\n",
    "        model_ref_df: model frame for those markers; None falls back to the\n",
    "            notebook-global `topk_df` (the previously hard-coded source).\n",
    "        human_ref_df: human frame for those markers; None falls back to the\n",
    "            notebook-global `human_final_df`.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    order = (\n",
    "        df.groupby(\"model\", observed=True)[\"model-human-ec\"]\n",
    "        .mean()\n",
    "        .sort_values(ascending=False)\n",
    "        .index\n",
    "    )\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
    "    ax.grid(axis=\"y\")\n",
    "\n",
    "    # plot humans\n",
    "    sns.pointplot(\n",
    "        data=hdf,\n",
    "        estimator=np.mean,\n",
    "        errorbar=(\"pi\", 95),\n",
    "        capsize=0.4,\n",
    "        x=\"name\",\n",
    "        y=\"human-human-ec\",\n",
    "        legend=False,\n",
    "        color=\"maroon\",\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    # plot model CIs\n",
    "    sns.pointplot(\n",
    "        data=df,\n",
    "        palette=\"mako\",\n",
    "        estimator=np.mean,\n",
    "        errorbar=(\"pi\", 95),\n",
    "        x=\"model\",\n",
    "        y=\"model-human-ec\",\n",
    "        hue=\"model\",\n",
    "        legend=False,\n",
    "        linestyle=\"none\",\n",
    "        capsize=0.4,\n",
    "        order=order,\n",
    "        hue_order=order,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    if show_original:\n",
    "        # fall back to the notebook globals only when no explicit frame is passed\n",
    "        if model_ref_df is None:\n",
    "            model_ref_df = topk_df\n",
    "        if human_ref_df is None:\n",
    "            human_ref_df = human_final_df\n",
    "\n",
    "        # plot Geirhos-datapoints for models\n",
    "        sns.pointplot(\n",
    "            data=model_ref_df[model_ref_df[\"bootstrap_id\"] == 0],\n",
    "            x=\"model\",\n",
    "            y=\"model-human-ec\",\n",
    "            estimator=np.mean,\n",
    "            errorbar=None,\n",
    "            markers=\"x\",\n",
    "            color=\"blue\",\n",
    "            legend=False,\n",
    "            linestyle=\"none\",\n",
    "            order=order,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "        # plot Geirhos-datapoints for humans\n",
    "        # NOTE(review): rows are duplicated before plotting, which leaves the mean unchanged;\n",
    "        # presumably a workaround for single-row input - confirm it is still needed.\n",
    "        human_ref = pd.concat([human_ref_df, human_ref_df], ignore_index=True)\n",
    "        sns.pointplot(\n",
    "            data=human_ref[human_ref[\"bootstrap_id\"] == 0],\n",
    "            x=\"name\",\n",
    "            y=\"human-human-ec\",\n",
    "            estimator=np.mean,\n",
    "            errorbar=None,\n",
    "            markers=\"x\",\n",
    "            color=\"blue\",\n",
    "            legend=False,\n",
    "            linestyle=\"none\",\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    ax.set_ylim(ylim_min, ylim_max)\n",
    "    ax.set_xlabel(\"Models\")\n",
    "\n",
    "    # use either top for paper-version or bottom for debugging\n",
    "    ax.set_xticklabels([])\n",
    "    # ax.tick_params(axis=\"x\", labelrotation=90)\n",
    "\n",
    "    ax.set_ylabel(\"Error Consistency to Humans [Kappa]\")\n",
    "    sns.despine(ax=ax)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    # savefig raises if the target directory is missing (e.g. on a fresh checkout)\n",
    "    os.makedirs(\"figures\", exist_ok=True)\n",
    "    fig.savefig(\n",
    "        pjoin(\"figures\", f\"brainscore_model_comparison_{name}_{method}.pdf\"),\n",
    "        bbox_inches=\"tight\",\n",
    "    )\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregating like Geirhos et al\n",
    "\n",
    "If one defines the Geirhos error consistency as the mean EC over all conditions, bootstraps the entire pipeline (i.e. obtains one final mean per bootstrap sample), and then reports a confidence interval over these means, the picture already looks somewhat dubious: the CIs of the first 15 models overlap, and so do the CIs of a long run of models in the middle, as shown here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_df(\n",
    "    topk_df,\n",
    "    human_final_df,\n",
    "    ylim_min=0.2,\n",
    "    ylim_max=0.45,\n",
    "    name=\"mean\",\n",
    "    show_original=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Without hiding the variance\n",
    "\n",
    "What this aggregation hides, however, is that these means are themselves averages over values with a very wide spread. Without that averaging, the variability in the data looks more like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_df(\n",
    "    standard_df,\n",
    "    human_mean_df,\n",
    "    ylim_min=-0.1,\n",
    "    ylim_max=0.8,\n",
    "    name=\"all\",\n",
    "    show_original=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting without humans\n",
    "\n",
    "Including the humans stretches the y-axis, which makes the models' CIs look smaller. We therefore plot the models on their own:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_df_wo_humans(df, ylim_min, ylim_max, name):\n",
    "    \"\"\"Plot bootstrapped model-human EC per model, without the human reference.\n",
    "\n",
    "    Same layout as `plot_df` but omits the human-human EC, so the y-axis can be\n",
    "    zoomed in on the model values. Models are ordered by decreasing mean EC;\n",
    "    every point is the mean with a 95% percentile interval over the rows of `df`.\n",
    "    The figure is saved to figures/brainscore_model_comparison_{name}_{method}.pdf,\n",
    "    where `method` is the notebook-level setting from the loading cell.\n",
    "\n",
    "    Args:\n",
    "        df: model rows with columns \"model\" and \"model-human-ec\".\n",
    "        ylim_min, ylim_max: y-axis limits.\n",
    "        name: tag inserted into the output file name.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    order = (\n",
    "        df.groupby(\"model\", observed=True)[\"model-human-ec\"]\n",
    "        .mean()\n",
    "        .sort_values(ascending=False)\n",
    "        .index\n",
    "    )\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
    "    ax.grid(axis=\"y\")\n",
    "\n",
    "    # plot model CIs\n",
    "    sns.pointplot(\n",
    "        data=df,\n",
    "        palette=\"mako\",\n",
    "        estimator=np.mean,\n",
    "        errorbar=(\"pi\", 95),\n",
    "        x=\"model\",\n",
    "        y=\"model-human-ec\",\n",
    "        hue=\"model\",\n",
    "        legend=False,\n",
    "        linestyle=\"none\",\n",
    "        capsize=0.4,\n",
    "        order=order,\n",
    "        hue_order=order,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    ax.set_ylim(ylim_min, ylim_max)\n",
    "    ax.set_xlabel(\"Models\")\n",
    "\n",
    "    # use either top for paper-version or bottom for debugging\n",
    "    ax.set_xticklabels([])\n",
    "    # ax.tick_params(axis=\"x\", labelrotation=90)\n",
    "\n",
    "    ax.set_ylabel(\"Error Consistency to Humans [Kappa]\")\n",
    "    sns.despine(ax=ax)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    # savefig raises if the target directory is missing (e.g. on a fresh checkout)\n",
    "    os.makedirs(\"figures\", exist_ok=True)\n",
    "    fig.savefig(\n",
    "        pjoin(\"figures\", f\"brainscore_model_comparison_{name}_{method}.pdf\"),\n",
    "        bbox_inches=\"tight\",\n",
    "    )\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_df_wo_humans(\n",
    "    topk_df,\n",
    "    ylim_min=0.2,\n",
    "    ylim_max=0.35,\n",
    "    name=\"meanwohumans\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
