{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Differences\n",
    "\n",
    "The point of this notebook is to evaluate to what extent error consistency (EC) is a suitable measure for finding differences between models: are the differences between DNN models large enough to support claims about their similarity to humans?\n",
    "\n",
    "We will do this by first loading the data by Geirhos et al., keeping only those conditions that they actually include in their analysis.\n",
    "\n",
    "Then, we will loop over all models and, for every model, over all conditions. We then bootstrap and aggregate ECs just as we did for the main Figure 4.\n",
    "\n",
    "This is done in `bootstrap_models.py` and we plot the results here.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from os.path import join as pjoin\n",
    "from typing import List, Optional, Sequence\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# global font size for all matplotlib figures created below\n",
    "font = dict(size=15)\n",
    "matplotlib.rc(\"font\", size=font[\"size\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean_df(fname: str, experiments: List[str]):\n",
    "    \"\"\"\n",
    "    Given a bootstrap-result-df, aggregate the data like Geirhos et al do:\n",
    "    They take the average EC by first averaging within each experiment, then averaging across them.\n",
    "    How you average within each experiment doesn't matter, because\n",
    "    \"first conditions then humans\" == \"first humans then conditions\" == \"all at once\".\n",
    "\n",
    "    :param fname: path to parquet file\n",
    "    :param experiments: list of experiment names to include, or None if all should be included\n",
    "\n",
    "    :return: todo write\n",
    "    \"\"\"\n",
    "\n",
    "    # loading bootstrap-results\n",
    "    standard_df = pd.read_parquet(fname, engine=\"pyarrow\")\n",
    "\n",
    "    # filtering experiments\n",
    "    if experiments is None:\n",
    "        experiments = standard_df[\"experiment\"].unique()\n",
    "    standard_df = standard_df[standard_df[\"experiment\"].isin(experiments)]\n",
    "\n",
    "    # take the average within each of the 17 experiments\n",
    "    exp_mean_df = standard_df.groupby(\n",
    "        [\"bootstrap_id\", \"experiment\", \"model\"], observed=True, as_index=False\n",
    "    ).mean(numeric_only=True)\n",
    "\n",
    "    # take the average across the experiments\n",
    "    mean_df = exp_mean_df.groupby(\n",
    "        [\"bootstrap_id\", \"model\"], observed=True, as_index=False\n",
    "    ).mean(numeric_only=True)\n",
    "\n",
    "    # list for each model what its EC is, just for comparing to the Geirhos data\n",
    "    # print(mean_df.groupby([\"model\"], observed=True)[\"model-human-ec\"].mean())\n",
    "\n",
    "    return mean_df, standard_df, exp_mean_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bootstrapped model-human ECs, aggregated over all experiments\n",
    "mean_df, standard_df, exp_mean_df = get_mean_df(\n",
    "    fname=pjoin(\"data\", \"model_wise_bootstrapped_ecs_standard_10000.parquet\"),\n",
    "    experiments=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bootstrapped model-human ECs, aggregated over the core experiments only\n",
    "core_experiments = [\n",
    "    \"sketch\",\n",
    "    \"eidolonII\",\n",
    "    \"power-equalisation\",\n",
    "    \"cue-conflict\",\n",
    "    \"colour\",\n",
    "    \"high-pass\",\n",
    "    \"false-colour\",\n",
    "    \"phase-scrambling\",\n",
    "    \"rotation\",\n",
    "]\n",
    "core_mean_df, _, _ = get_mean_df(\n",
    "    fname=pjoin(\"data\", \"model_wise_bootstrapped_ecs_standard_10000.parquet\"),\n",
    "    experiments=core_experiments,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# aggregate the human-human ECs the same way as the models:\n",
    "# first within each experiment, then across experiments\n",
    "human_df = pd.read_parquet(\n",
    "    pjoin(\"data\", \"bootstrapped_human_ecs_standard_10000.parquet\"), engine=\"pyarrow\"\n",
    ")\n",
    "\n",
    "# one row per (bootstrap_id, experiment); \"name\" is the x-axis category for plotting\n",
    "human_mean_df = (\n",
    "    human_df.groupby([\"bootstrap_id\", \"experiment\"], observed=True, as_index=False)\n",
    "    .mean(numeric_only=True)\n",
    "    .assign(name=\"Humans\")\n",
    ")\n",
    "\n",
    "# one row per bootstrap_id (the non-numeric \"name\" column is re-added after averaging)\n",
    "human_final_df = (\n",
    "    human_mean_df.groupby(\"bootstrap_id\", observed=True, as_index=False)\n",
    "    .mean(numeric_only=True)\n",
    "    .assign(name=\"Humans\")\n",
    ")\n",
    "print(\"Mean human-human EC:\", human_final_df[\"human-human-ec\"].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_df(\n",
    "    df: pd.DataFrame,\n",
    "    hdf: pd.DataFrame,\n",
    "    ylim_min: float,\n",
    "    ylim_max: float,\n",
    "    name: str,\n",
    "    show_original: bool = True,\n",
    "    show_models: bool = True,\n",
    "    original_df: Optional[pd.DataFrame] = None,\n",
    "    original_hdf: Optional[pd.DataFrame] = None,\n",
    "    yticks: Optional[Sequence[float]] = None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot the ECs with CIs of models and humans based on bootstrapped data.\n",
    "\n",
    "    Humans are drawn first (maroon), then the models ordered by their mean EC over the\n",
    "    bootstraps (descending). Error bars are 95% percentile intervals. The figure is saved\n",
    "    to figures/model_comparison_{name}.pdf and shown.\n",
    "\n",
    "    :param df: the df with columns bootstrap_id, model, and final model-human-ec\n",
    "    :param hdf: the same df for humans, with columns bootstrap_id, name and final human-human-ec\n",
    "    :param ylim_min: lower y-limit of the plot\n",
    "    :param ylim_max: upper y-limit of the plot\n",
    "    :param name: only used in the filename of the pdf\n",
    "    :param show_original: whether to show the values of bootstrap 0 (the Geirhos\n",
    "        datapoints) as x's\n",
    "    :param show_models: whether to label the x-axis with model indices (\"H\" for humans)\n",
    "    :param original_df: model df the bootstrap-0 markers are taken from; defaults to df.\n",
    "        Pass e.g. the all-experiment mean_df to show those values as a reference.\n",
    "    :param original_hdf: human df the bootstrap-0 marker is taken from; defaults to hdf\n",
    "    :param yticks: positions of the y-ticks; defaults to [0.1, 0.2, 0.3, 0.4, 0.5]\n",
    "    :return: the order of the models along the x-axis\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    # take the reference markers from the arguments, not from notebook globals\n",
    "    if original_df is None:\n",
    "        original_df = df\n",
    "    if original_hdf is None:\n",
    "        original_hdf = hdf\n",
    "    if yticks is None:\n",
    "        yticks = [0.1, 0.2, 0.3, 0.4, 0.5]\n",
    "\n",
    "    # order models by mean of the bootstraps\n",
    "    order = (\n",
    "        df.groupby(\"model\", observed=True)[\"model-human-ec\"]\n",
    "        .mean()\n",
    "        .sort_values(ascending=False)\n",
    "        .index\n",
    "    )\n",
    "\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
    "    ax.grid(axis=\"y\")\n",
    "\n",
    "    # plot humans\n",
    "    sns.pointplot(\n",
    "        data=hdf,\n",
    "        estimator=np.mean,\n",
    "        errorbar=(\"pi\", 95),\n",
    "        capsize=0.4,\n",
    "        x=\"name\",\n",
    "        y=\"human-human-ec\",\n",
    "        legend=False,\n",
    "        color=\"maroon\",\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    # plot CIs\n",
    "    sns.pointplot(\n",
    "        data=df,\n",
    "        palette=\"mako\",\n",
    "        estimator=np.mean,\n",
    "        errorbar=(\"pi\", 95),\n",
    "        x=\"model\",\n",
    "        y=\"model-human-ec\",\n",
    "        hue=\"model\",\n",
    "        legend=False,\n",
    "        linestyle=\"none\",\n",
    "        capsize=0.4,\n",
    "        order=order,\n",
    "        hue_order=order,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    if show_original:\n",
    "\n",
    "        # plot Geirhos-datapoints for models\n",
    "        sns.pointplot(\n",
    "            data=original_df[original_df[\"bootstrap_id\"] == 0],\n",
    "            x=\"model\",\n",
    "            y=\"model-human-ec\",\n",
    "            estimator=np.mean,\n",
    "            errorbar=None,\n",
    "            markers=\"x\",\n",
    "            color=\"blue\",\n",
    "            legend=False,\n",
    "            linestyle=\"none\",\n",
    "            order=order,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "        # plot Geirhos-datapoints for humans\n",
    "        sns.pointplot(\n",
    "            data=original_hdf[original_hdf[\"bootstrap_id\"] == 0],\n",
    "            x=\"name\",\n",
    "            y=\"human-human-ec\",\n",
    "            estimator=np.mean,\n",
    "            errorbar=None,\n",
    "            markers=\"x\",\n",
    "            color=\"blue\",\n",
    "            legend=False,\n",
    "            linestyle=\"none\",\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    ax.set_ylim(ylim_min, ylim_max)\n",
    "    ax.set_yticks(yticks)\n",
    "    ax.set_xlabel(\"Models\")\n",
    "\n",
    "    if show_models:\n",
    "        # \"H\" labels the human point, the models are numbered in plotting order\n",
    "        ax.set_xticklabels([\"H\"] + list(range(len(order))), fontsize=10)\n",
    "    else:\n",
    "        ax.set_xticklabels([])\n",
    "\n",
    "    ax.set_ylabel(\"Error Consistency to Humans [Kappa]\")\n",
    "    sns.despine(ax=ax)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    # create the output directory so a fresh checkout can Run All\n",
    "    os.makedirs(\"figures\", exist_ok=True)\n",
    "    fig.savefig(pjoin(\"figures\", f\"model_comparison_{name}.pdf\"), bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "    return order"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregating like Geirhos et al\n",
    "\n",
    "If one defines Geirhos error consistency as the mean EC over all conditions, bootstraps this all the way through (i.e. obtains one final mean for every bootstrap), and then reports a confidence interval over these means, the picture already looks a bit dubious: the CIs of the first 15 models overlap, and so do the CIs of a long sequence of models in the middle, as shown here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geirhos_mean_order = plot_df(\n",
    "    mean_df, human_final_df, 0.1, 0.5, \"mean\", show_original=False, show_models=True\n",
    ")\n",
    "\n",
    "# appendix table mapping the x-axis indices of the plot above to model names\n",
    "model_appdx_df = pd.DataFrame(\n",
    "    {\"Index\": np.arange(0, len(geirhos_mean_order)), \"Model Name\": geirhos_mean_order}\n",
    ")\n",
    "\n",
    "\n",
    "def format_func(x: str) -> str:\n",
    "    \"\"\"Escape underscores so that model names are rendered literally by LaTeX.\"\"\"\n",
    "    # raw string: a plain \"\\_\" literal is an invalid escape sequence (SyntaxWarning in 3.12+)\n",
    "    return x.replace(\"_\", r\"\\_\")\n",
    "\n",
    "\n",
    "latex = model_appdx_df.to_latex(\n",
    "    index=False, formatters={\"Model Name\": format_func}, float_format=\"{:.1f}\".format\n",
    ")\n",
    "print(latex)\n",
    "display(model_appdx_df)\n",
    "print(geirhos_mean_order)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Only plotting the core experiments\n",
    "\n",
    "Here we plot the model means obtained from only the core experiments, while the human reference is still the value we get by aggregating over all experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# models aggregated over the core experiments only, humans over all experiments\n",
    "plot_df(\n",
    "    df=core_mean_df,\n",
    "    hdf=human_final_df,\n",
    "    ylim_min=0,\n",
    "    ylim_max=0.5,\n",
    "    name=\"mean_core\",\n",
    "    show_original=False,\n",
    "    show_models=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Without hiding the variance\n",
    "\n",
    "But this aggregation hides the fact that these means were themselves obtained by averaging over values with a very wide spread; when we plot the per-experiment values instead, the variance in the data looks more like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-experiment means, i.e. without averaging across experiments\n",
    "plot_df(\n",
    "    df=exp_mean_df,\n",
    "    hdf=human_mean_df,\n",
    "    ylim_min=-0.1,\n",
    "    ylim_max=0.6,\n",
    "    name=\"all\",\n",
    "    show_original=False,\n",
    "    show_models=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Root Cause\n",
    "\n",
    "By plotting the distributions of EC values obtained in the different experiments, we can see where the variance comes from. We do this for one exemplary model (`clip`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EC distributions per experiment for one exemplary model\n",
    "example_model = \"clip\"\n",
    "model_df = standard_df.loc[standard_df[\"model\"] == example_model].reset_index()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "sns.kdeplot(\n",
    "    data=model_df,\n",
    "    x=\"model-human-ec\",\n",
    "    hue=\"experiment\",\n",
    "    fill=False,\n",
    "    legend=True,\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_xlabel(\"Error Consistency to Humans [kappa]\")\n",
    "ax.set_xlim(-0.2, 0.8)\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.savefig(pjoin(\"figures\", \"experiment_distributions.pdf\"), bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now make the same point that we make in the Alignment of Alignment paper: If you want to aggregate over different distributions that are not centered at the same location, you cannot simply take their mean. It would probably be better to z-transform values first.\n",
    "\n",
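    "We can also take a look at the distributions for humans, which are noticeably tighter. Note, however, that these are computed from per-experiment means (`human_mean_df`), whereas the model distributions above use the un-averaged rows of `standard_df`, so part of the difference may be due to this extra averaging."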
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# same plot for humans, one curve per experiment\n",
    "# NOTE(review): human_mean_df holds per-experiment means for every bootstrap, while the\n",
    "# model plot above uses the un-averaged rows of standard_df, so the two spreads are not\n",
    "# computed at the same granularity\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "sns.kdeplot(\n",
    "    data=human_mean_df, x=\"human-human-ec\", hue=\"experiment\", fill=False, legend=True, ax=ax\n",
    ")\n",
    "ax.set_xlabel(\"Error Consistency to Humans [kappa]\")\n",
    "ax.set_xlim(-0.2, 0.8)\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.savefig(pjoin(\"figures\", \"experiment_human_distributions.pdf\"), bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Checking the distributions\n",
    "\n",
    "Earlier, we had a bug in the plotting code that made it look as if the Geirhos value differed from the mean of the bootstraps. It is fixed now; in the process we also looked at the shape of the distributions of bootstrapped EC values and found that they are more or less normally distributed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import norm\n",
    "\n",
    "# one panel per model: distribution of the bootstrapped aggregated EC in mean_df (blue),\n",
    "# best-fitting normal density (red) and the value of bootstrap 0 (vertical line)\n",
    "by_model = mean_df.groupby(\"model\", observed=True)\n",
    "ncols = 4\n",
    "nrows = int(np.ceil(by_model.ngroups / ncols))\n",
    "fig, axes = plt.subplots(\n",
    "    nrows=nrows,\n",
    "    ncols=ncols,\n",
    "    figsize=(ncols * 3.5, nrows * 3.5),\n",
    "    sharex=True,\n",
    "    sharey=True,\n",
    "    squeeze=False,\n",
    ")\n",
    "flatax = axes.flatten()\n",
    "for ax, (model, boot_df) in zip(flatax, by_model):\n",
    "\n",
    "    # plot the distribution of the bootstrapped values\n",
    "    sns.kdeplot(data=boot_df, x=\"model-human-ec\", color=\"blue\", fill=True, ax=ax)\n",
    "\n",
    "    # plot the best fitting gaussian as its exact density rather than as a KDE of\n",
    "    # random samples, so the figure is deterministic and free of sampling noise\n",
    "    mu, sigma = norm.fit(boot_df[\"model-human-ec\"].to_numpy())\n",
    "    xs = np.linspace(mu - 4 * sigma, mu + 4 * sigma, 200)\n",
    "    ax.plot(xs, norm.pdf(xs, loc=mu, scale=sigma), color=\"red\")\n",
    "\n",
    "    ax.set_title(model)\n",
    "    ax.axvline(\n",
    "        x=boot_df.loc[boot_df[\"bootstrap_id\"] == 0, \"model-human-ec\"].iloc[0],\n",
    "    )\n",
    "\n",
    "# hide the unused panels when the number of models is not a multiple of ncols\n",
    "for ax in flatax[by_model.ngroups:]:\n",
    "    ax.set_visible(False)\n",
    "\n",
    "sns.despine()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting CIs for all individual experiments\n",
    "\n",
    "It would be useful to be able to say: \"look, every individual experiment has CIs like these; only by averaging all of them do you get something with small CIs\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-experiment means for every (bootstrap_id, experiment, model); kept under its own\n",
    "# name so the exp_mean_df returned by get_mean_df is not overwritten\n",
    "exp_model_df = standard_df.groupby(\n",
    "    [\"bootstrap_id\", \"experiment\", \"model\"], observed=True, as_index=False\n",
    ").mean(numeric_only=True)\n",
    "\n",
    "# number of distinct conditions per experiment, computed once instead of re-filtering\n",
    "# standard_df in every iteration (dropna=False counts NaN like len(unique()) does)\n",
    "n_conditions_per_exp = standard_df.groupby(\"experiment\", observed=True)[\n",
    "    \"condition\"\n",
    "].nunique(dropna=False)\n",
    "\n",
    "for exp, exp_df in exp_model_df.groupby(\"experiment\", observed=True):\n",
    "\n",
    "    # order models by mean of the bootstraps\n",
    "    order = (\n",
    "        exp_df.groupby(\"model\", observed=True)[\"model-human-ec\"]\n",
    "        .mean()\n",
    "        .sort_values(ascending=False)\n",
    "        .index\n",
    "    )\n",
    "\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
    "    ax.grid(axis=\"y\")\n",
    "\n",
    "    # plot CIs\n",
    "    sns.pointplot(\n",
    "        data=exp_df,\n",
    "        palette=\"mako\",\n",
    "        estimator=np.mean,\n",
    "        errorbar=(\"pi\", 95),\n",
    "        x=\"model\",\n",
    "        y=\"model-human-ec\",\n",
    "        hue=\"model\",\n",
    "        legend=False,\n",
    "        linestyle=\"none\",\n",
    "        capsize=0.4,\n",
    "        order=order,\n",
    "        hue_order=order,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    ax.set_xlabel(\"Models\")\n",
    "    n_conditions = n_conditions_per_exp[exp]\n",
    "    ax.set_title(f\"{exp} ({n_conditions} condition{'s' if n_conditions > 1 else ''})\")\n",
    "    ax.set_xticklabels([])\n",
    "\n",
    "    ax.set_ylabel(\"Error Consistency to Humans [Kappa]\")\n",
    "    sns.despine()\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(pjoin(\"figures\", f\"model_comparison_{exp}.pdf\"), bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "    # release the figure; this loop creates one figure per experiment\n",
    "    plt.close(fig)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
