{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing the Brain-Score data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited modules (such as the `utils` module imported in the next cell) automatically\n",
    "# before each cell is executed, so changes are picked up without restarting the kernel.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "sys.path.append(os.path.abspath(\"..\"))\n",
    "from utils import (\n",
    "    simulate_trials_from_copy_model,\n",
    "    fast_cohen,\n",
    "    calc_accuracy_bounds_from_kappa,\n",
    "    filter_df,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the Brain-Score benchmark scores (presumably one row per model - see `model_name` below)\n",
    "benchmark_scores_path = \"data/benchmark_scores.csv\"\n",
    "all_df = pd.read_csv(benchmark_scores_path)\n",
    "display(all_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# focusing on the error consistency (EC) columns and dropping all models where we don't have every value\n",
    "mean_ec_col = \"Geirhos2021-error_consistency\"\n",
    "ec_prefix, ec_suffix = \"Geirhos2021\", \"-error_consistency\"\n",
    "relevant_cols = [\"model_name\"] + [\n",
    "    col for col in all_df.columns if \"error_consistency\" in col\n",
    "]\n",
    "score_cols = [col for col in relevant_cols if col != \"model_name\"]\n",
    "ec_cols = [col for col in score_cols if col != mean_ec_col]\n",
    "\n",
    "# ceiling scores for each condition, taken from Brain-Score tests\n",
    "ceilings = {\n",
    "    \"colour\": 0.41543,\n",
    "    \"contrast\": 0.43703,\n",
    "    \"cueconflict\": 0.33105,\n",
    "    \"edge\": 0.31844,\n",
    "    \"eidolonI\": 0.38634,\n",
    "    \"eidolonII\": 0.45402,\n",
    "    \"eidolonIII\": 0.45953,\n",
    "    \"falsecolour\": 0.44405,\n",
    "    \"highpass\": 0.44014,\n",
    "    \"lowpass\": 0.46888,\n",
    "    \"phasescrambling\": 0.44667,\n",
    "    \"powerequalisation\": 0.51063,\n",
    "    \"rotation\": 0.43851,\n",
    "    \"silhouette\": 0.47571,\n",
    "    \"sketch\": 0.36962,\n",
    "    \"stylized\": 0.50058,\n",
    "    \"uniformnoise\": 0.43406,\n",
    "}\n",
    "# scale factor of the aggregate Geirhos2021 EC score (presumably its ceiling - TODO confirm)\n",
    "overall_ceiling = 0.42899\n",
    "# largest acceptable mean |delta| between the reported and the recomputed aggregate EC\n",
    "max_mean_delta = 0.01\n",
    "\n",
    "# condition names that are spelled differently in Geirhos et al\n",
    "geirhos_names = {\n",
    "    \"cueconflict\": \"cue-conflict\",\n",
    "    \"falsecolour\": \"false-colour\",\n",
    "    \"phasescrambling\": \"phase-scrambling\",\n",
    "    \"powerequalisation\": \"power-equalisation\",\n",
    "    \"uniformnoise\": \"uniform-noise\",\n",
    "    \"lowpass\": \"low-pass\",\n",
    "    \"highpass\": \"high-pass\",\n",
    "}\n",
    "\n",
    "\n",
    "def _condition_from_col(col):\n",
    "    \"\"\"Map e.g. 'Geirhos2021cueconflict-error_consistency' to its condition, 'cueconflict'.\"\"\"\n",
    "    return col.split(ec_prefix)[1].split(ec_suffix)[0]\n",
    "\n",
    "\n",
    "# coerce every score to float (unparseable entries become NaN) and drop models with NaNs anywhere\n",
    "df = (\n",
    "    all_df[relevant_cols]\n",
    "    .assign(**{col: pd.to_numeric(all_df[col], errors=\"coerce\") for col in score_cols})\n",
    "    .dropna(subset=score_cols)\n",
    ")\n",
    "\n",
    "# the constituent EC metrics are ceiling-normalised; multiplying by the ceiling removes that\n",
    "df = df.assign(**{col: df[col] * ceilings[_condition_from_col(col)] for col in ec_cols})\n",
    "\n",
    "# making the names consistent with Geirhos et al\n",
    "rename_map = {}\n",
    "for col in ec_cols:\n",
    "    condition = _condition_from_col(col)\n",
    "    rename_map[col] = geirhos_names.get(condition, condition)\n",
    "df = df.rename(columns=rename_map)\n",
    "ec_cols = list(rename_map.values())\n",
    "\n",
    "# checking that the Geirhos2021-error_consistency column (rescaled) is the average over the other ones\n",
    "df[\"raw_ec\"] = df[ec_cols].mean(axis=1)\n",
    "df[\"delta\"] = (df[mean_ec_col] * overall_ceiling - df[\"raw_ec\"]).abs()\n",
    "mean_delta = df[\"delta\"].mean()\n",
    "print(\"Mean Delta between real and expected final value:\", mean_delta)  # 0.0017, good enough\n",
    "assert mean_delta < max_mean_delta, \"Aggregate EC does not match the mean of the per-condition ECs!\"\n",
    "print(\"There are\", len(df), \"models.\")\n",
    "display(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's start by plotting the raw ECs (ceiling normalisation removed), one point per model\n",
    "order = df.groupby(\"model_name\")[\"raw_ec\"].mean().sort_values(ascending=False).index\n",
    "fig, ax = plt.subplots(1, 1, figsize=(25, 5))\n",
    "ax.grid(axis=\"y\")\n",
    "sns.pointplot(\n",
    "    data=df,\n",
    "    y=\"raw_ec\",\n",
    "    x=\"model_name\",\n",
    "    hue=\"model_name\",\n",
    "    legend=False,\n",
    "    linestyle=\"none\",\n",
    "    order=order,\n",
    "    palette=\"mako\",\n",
    "    hue_order=order,\n",
    "    ax=ax,\n",
    ")\n",
    "# hide the model names; the models are sorted by decreasing raw EC instead\n",
    "ax.set_xticks([], [])\n",
    "ax.set_xlabel(\"Model\")\n",
    "ax.set_ylabel(\"Error Consistency to Humans [kappa]\")\n",
    "ax.set_title(\"Raw error consistency of Brain-Score models with humans\")\n",
    "ax.set_ylim(0, 0.8)\n",
    "sns.despine()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def estimate_accuracy(acc, kappa, method=\"lower\", eps=0.001):\n",
    "    \"\"\"Estimate a model accuracy compatible with a human accuracy and an error consistency.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    acc : float\n",
    "        Human accuracy (the bootstrap below passes the mean over all subjects).\n",
    "    kappa : float\n",
    "        Error consistency between the model and the humans.\n",
    "    method : str\n",
    "        Which point of the interval returned by `calc_accuracy_bounds_from_kappa`\n",
    "        to use: \"lower\", \"upper\" or \"middle\".\n",
    "    eps : float\n",
    "        Margin: the result is always clipped to [eps, 1 - eps].\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        The estimated model accuracy.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    RuntimeError\n",
    "        If `method` is not one of the supported values.\n",
    "    \"\"\"\n",
    "    # fail fast, before computing any bounds\n",
    "    if method not in (\"lower\", \"upper\", \"middle\"):\n",
    "        raise RuntimeError(f\"Method not known! {method}\")\n",
    "\n",
    "    def clip(val):\n",
    "        return min(1 - eps, max(0 + eps, val))\n",
    "\n",
    "    # get bounds\n",
    "    lower, upper = calc_accuracy_bounds_from_kappa(acc, kappa)\n",
    "\n",
    "    assert 0 <= lower <= upper, \"Bounds were not sensible!\"\n",
    "\n",
    "    if method == \"lower\":\n",
    "        # step eps into the interval so the estimate does not sit exactly on the bound\n",
    "        return clip(lower + eps)\n",
    "    if method == \"upper\":\n",
    "        return clip(upper - eps)\n",
    "    # \"middle\": clipped like the other methods, so degenerate bounds (both 0 or both 1)\n",
    "    # cannot yield an accuracy of exactly 0 or 1\n",
    "    return clip(lower + (upper - lower) / 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next, we load the Geirhos data and for every model, generate some possible trial data\n",
    "geirhos_df = pd.read_parquet(\n",
    "    \"../geirhos_analysis/data/geirhos_raw_data.parquet\", engine=\"pyarrow\"\n",
    ")\n",
    "\n",
    "# keep only humans\n",
    "geirhos_df = geirhos_df[geirhos_df[\"subj\"].str.contains(\"subject-\")]\n",
    "\n",
    "# keep only relevant conditions\n",
    "geirhos_df = filter_df(geirhos_df)\n",
    "\n",
    "# bootstrap consistency values, this takes about half an hour\n",
    "method = \"lower\"\n",
    "n_bootstraps = 1000\n",
    "seed = 0\n",
    "# NOTE(review): assumes simulate_trials_from_copy_model draws from NumPy's global RNG - confirm in utils\n",
    "np.random.seed(seed)\n",
    "res_dfs = []\n",
    "excluded_models = []\n",
    "for exp, exp_df in geirhos_df.groupby(\"experiment\", observed=True):\n",
    "\n",
    "    # the Brain-Score EC is given per experiment, so decide once here which models to skip.\n",
    "    # This only happens for vonenet on some experiments.\n",
    "    has_negative_ec = df[exp] < 0\n",
    "    for name in df.loc[has_negative_ec, \"model_name\"]:\n",
    "        print(f\"Excluding model {name} on {exp} because it had a negative EC.\")\n",
    "        excluded_models.append(name)\n",
    "    valid_models_df = df[~has_negative_ec]\n",
    "\n",
    "    for con, con_df in exp_df.groupby(\"condition\", observed=True):\n",
    "\n",
    "        # trials per subject, counted for subject-01 (assumed to be representative)\n",
    "        n_trials = len(con_df[con_df[\"subj\"] == \"subject-01\"])\n",
    "        assert n_trials > 0, f\"subject-01 has no trials in {exp} {con}!\"\n",
    "        n_subjects = len(con_df[\"subj\"].unique())\n",
    "        avg_human_acc = con_df[\"correct\"].mean()  # taking the mean over all humans\n",
    "\n",
    "        # loop over all remaining Brain-Score models\n",
    "        for _, row in valid_models_df.iterrows():\n",
    "            gt_error_consistency = row[exp]\n",
    "            model_acc = estimate_accuracy(avg_human_acc, gt_error_consistency, method)\n",
    "\n",
    "            # bootstrap 0 is the observed EC, the others are simulated from the copy model\n",
    "            model_human_ecs = [gt_error_consistency]\n",
    "            for _ in range(1, n_bootstraps):\n",
    "                ecs = [\n",
    "                    fast_cohen(\n",
    "                        *simulate_trials_from_copy_model(\n",
    "                            gt_error_consistency,\n",
    "                            avg_human_acc,\n",
    "                            model_acc,\n",
    "                            n_trials,\n",
    "                        )\n",
    "                    )\n",
    "                    for _subject in range(n_subjects)\n",
    "                ]\n",
    "                model_human_ecs.append(np.mean(ecs))\n",
    "\n",
    "            # make a df of bootstrapped model-human error consistencies\n",
    "            res_dfs.append(\n",
    "                pd.DataFrame(\n",
    "                    {\n",
    "                        \"model\": row[\"model_name\"],\n",
    "                        \"bootstrap_id\": np.arange(n_bootstraps),\n",
    "                        \"experiment\": exp,\n",
    "                        \"condition\": con,\n",
    "                        \"model-human-ec\": model_human_ecs,\n",
    "                    }\n",
    "                )\n",
    "            )\n",
    "\n",
    "# join all dfs (ignore_index: every per-model frame has its own 0..n_bootstraps-1 index)\n",
    "res_df = pd.concat(res_dfs, ignore_index=True)\n",
    "display(res_df)\n",
    "res_df.to_parquet(\n",
    "    f\"data/brainscore_bootstrapped_ecs_{n_bootstraps}_{method}.parquet\",\n",
    "    engine=\"pyarrow\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
