{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Rankings implied by different experiments\n",
    "\n",
    "Geirhos et al just take an average, but it would be funny to show that different experiments imply different rankings of models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from os.path import join as pjoin\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ranking_df(fname, colname):\n",
    "    # loading the df containing only the EC for the actual empirical data, without bootstrapping\n",
    "    standard_df = pd.read_parquet(pjoin(\"data\", fname), engine=\"pyarrow\")\n",
    "\n",
    "    # take the average within each of the 17 experiments (over conditions)\n",
    "    experiment_df = standard_df.groupby(\n",
    "        [\"experiment\", \"model\", \"bootstrap_id\"], observed=True, as_index=False\n",
    "    ).mean(numeric_only=True)\n",
    "    experiment_df.drop(columns=[\"bootstrap_id\"], inplace=True)\n",
    "\n",
    "    # display(experiment_df)\n",
    "\n",
    "    # what I need next is a df like the following:\n",
    "    # model  exp_1  exp_2  exp_3 ...\n",
    "    #     A    0.4    0.2    0.2\n",
    "    #     B    0.1    0.3    0.2\n",
    "    # then, I can calculate the Spearman correlation between numeric columns\n",
    "\n",
    "    pivoted = experiment_df.pivot(\n",
    "        columns=[\"experiment\"], index=\"model\", values=\"model-human-ec\"\n",
    "    )\n",
    "\n",
    "    # display(pivoted)\n",
    "\n",
    "    return pivoted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate Spearman rank correlation matrix between the different metrics\n",
    "rankings_df = get_ranking_df(\n",
    "    f\"model_wise_bootstrapped_ecs_standard_1.parquet\", \"model-human-ec\"\n",
    ")\n",
    "corr_matrix = rankings_df.corr(method=\"kendall\", numeric_only=True)\n",
    "\n",
    "# Plot heatmap\n",
    "fig, ax = plt.subplots(figsize=(12, 10))\n",
    "sns.heatmap(corr_matrix, annot=True, cmap=\"coolwarm\", vmin=-1, vmax=1, fmt=\".2f\", ax=ax)\n",
    "ax.set_xlabel(\"Experiment\")\n",
    "ax.set_ylabel(\"Experiment\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(pjoin(\"figures\", f\"model_rankings.pdf\"), bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Rankings implied by different bootstraps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loading the bootstrap-results\n",
    "standard_df = pd.read_parquet(\n",
    "    pjoin(\"data\", f\"model_wise_bootstrapped_ecs_standard_10000.parquet\"),\n",
    "    engine=\"pyarrow\",\n",
    ")\n",
    "\n",
    "# They take the average EC by first averaging within each experiment, then averaging across them.\n",
    "# (how you average within each experiment doesn't matter, because first conditions then humans = first humans then conditions = all at once)\n",
    "\n",
    "# take the average within each of the 12 experiments\n",
    "exp_mean_df = standard_df.groupby(\n",
    "    [\"bootstrap_id\", \"experiment\", \"model\"], observed=True, as_index=False\n",
    ").mean(numeric_only=True)\n",
    "\n",
    "# take the average across the experiments\n",
    "mean_df = exp_mean_df.groupby(\n",
    "    [\"bootstrap_id\", \"model\"], observed=True, as_index=False\n",
    ").mean(numeric_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import spearmanr, kendalltau\n",
    "\n",
    "# get canonical ordering of models\n",
    "models = mean_df[mean_df[\"bootstrap_id\"] == 0][\"model\"].unique().tolist()\n",
    "\n",
    "\n",
    "# given a df, get the model-human-ecs for the models\n",
    "def get_sorted_values_from_df(df):\n",
    "    return df.sort_values(by=[\"model\"])[\"model-human-ec\"].values\n",
    "\n",
    "\n",
    "ranking_0 = get_sorted_values_from_df(mean_df[mean_df[\"bootstrap_id\"] == 0])\n",
    "\n",
    "spearman_rs = []\n",
    "spearman_ps = []\n",
    "kendall_ts = []\n",
    "kendall_ps = []\n",
    "ids = []\n",
    "for id, rdf in mean_df[mean_df[\"bootstrap_id\"] > 0].groupby(\"bootstrap_id\"):\n",
    "    ranking = get_sorted_values_from_df(rdf)\n",
    "    r = spearmanr(ranking_0, ranking)\n",
    "    tau = kendalltau(ranking_0, ranking, nan_policy=\"raise\")\n",
    "    ids.append(id)\n",
    "    spearman_rs.append(r.statistic)\n",
    "    spearman_ps.append(r.pvalue)\n",
    "    kendall_ts.append(tau.statistic)\n",
    "    kendall_ps.append(tau.pvalue)\n",
    "\n",
    "res_df = pd.DataFrame(\n",
    "    {\n",
    "        \"id\": ids,\n",
    "        \"r\": spearman_rs,\n",
    "        \"r_p\": spearman_ps,\n",
    "        \"t\": kendall_ts,\n",
    "        \"t_p\": kendall_ps,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_df(df, method):\n",
    "    if method == \"spearman\":\n",
    "        stat = \"r\"\n",
    "        pval = \"r_p\"\n",
    "    elif method == \"kendall\":\n",
    "        stat = \"t\"\n",
    "        pval = \"t_p\"\n",
    "    else:\n",
    "        raise RuntimeError(\"Method unknown\")\n",
    "\n",
    "    df[\"Significance\"] = df[pval] < 0.05\n",
    "    label_map = {False: \"p > 0.05\", True: \"p < 0.05 (significant)\"}\n",
    "    df[\"Significance\"] = df[\"Significance\"].map(lambda x: label_map[x])\n",
    "    method_str = \"Kendall's Tau\" if method == \"kendall\" else \"Spearman's r\"\n",
    "    print(\n",
    "        f\"{len(df[df[pval] < 0.05]) / len(df) * 100:.2f}% of {method_str} are significant at alpha = 0.05.\"\n",
    "    )\n",
    "    print(f\"The average {method_str} is {df[stat].mean()}\")\n",
    "\n",
    "    return df, stat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = \"kendall\"\n",
    "plot_df, xval = prepare_df(res_df, method)\n",
    "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
    "sns.kdeplot(\n",
    "    data=plot_df,\n",
    "    x=xval,\n",
    "    hue=\"Significance\",\n",
    "    legend=True,\n",
    "    palette=[\"maroon\", \"blue\"],  #'crest',\n",
    "    fill=True,\n",
    "    ax=ax,\n",
    ")\n",
    "sns.despine()\n",
    "ax.set_xlabel(\"Kendall's tau\" if method == \"kendall\" else \"Spearman's rho\")\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(pjoin(\"figures\", f\"bootstrap_rankings_{method}.pdf\"), bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
