{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nn_core.common import PROJECT_ROOT\n",
    "\n",
    "results_dir = PROJECT_ROOT / \"results_paper\" / \"attention_opt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "end2end_results = pd.read_csv(results_dir / \"end2end_results.tsv\", sep=\"\\t\")\n",
    "stitching_results = pd.read_csv(results_dir / \"stitching_results.tsv\", sep=\"\\t\")\n",
    "opt_results = pd.read_csv(results_dir / \"opt_results.tsv\", sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "end2end_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stitching_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder_name = opt_results.iloc[0][\"encoder_name\"]\n",
    "encoder_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder_name = opt_results.iloc[0][\"decoder_name\"]\n",
    "decoder_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "w_orig = torch.load(results_dir / \"w_orig.pt\")\n",
    "w_qkv_opt = torch.load(results_dir / \"w_qkv_opt.pt\")\n",
    "w_classifier_opt = torch.load(results_dir / \"w_classifier_opt.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with (results_dir / \"projection_names_order.json\").open(\"r\") as f:\n",
    "    projection_names_order = json.load(f)\n",
    "projection_names_order"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cool vizs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles, figsizes, axes, fonts\n",
    "\n",
    "from sklearn.metrics import ConfusionMatrixDisplay\n",
    "import seaborn as sns\n",
    "\n",
    "# from tueplots.figsizes import _GOLDEN_RATIO\n",
    "\n",
    "N_ROWS = 1\n",
    "N_COLS = 3\n",
    "RATIO = 1  # _GOLDEN_RATIO\n",
    "\n",
    "plt.rcParams.update({\"figure.dpi\": 150})\n",
    "plt.rcParams.update(bundles.iclr2024())\n",
    "plt.rcParams.update(figsizes.iclr2024(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "plt.rcParams.update(axes.lines())\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(nrows=N_ROWS, ncols=N_COLS, sharex=True, sharey=False)\n",
    "\n",
    "for ax, w, title in zip(\n",
    "    ax,\n",
    "    [w_orig, w_qkv_opt, w_classifier_opt],\n",
    "    [\"Orig\", \"QKV\", \"Classifier\"],\n",
    "):\n",
    "    ConfusionMatrixDisplay(w.numpy(), display_labels=projection_names_order,).plot(\n",
    "        values_format=\".2f\",\n",
    "        ax=ax,\n",
    "        cmap=plt.cm.Reds,\n",
    "        colorbar=False,\n",
    "        xticks_rotation=30,\n",
    "    )\n",
    "    # ax.set_title(title)\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "\n",
    "\n",
    "filename = PROJECT_ROOT / \"results_paper\" / \"attention_opt\" / \"cm_plot\"\n",
    "filename.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "fig.savefig(f\"{filename}.svg\", bbox_inches=\"tight\", pad_inches=0)\n",
    "!rsvg-convert -f pdf -o {filename}.pdf {filename}.svg\n",
    "!rm {filename}.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stitching_results.drop(columns=[\"seed\"])[\n",
    "    (stitching_results.encoder_name == encoder_name) & (stitching_results.decoder_name == decoder_name)\n",
    "].groupby([\"encoder_name\", \"decoder_name\", \"runname\"]).agg(\"mean\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent-invariances",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
