{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d465f696",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bdfca4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "import matplotlib.pyplot as plt\n",
    "from experiments.summarize import main as summarize_main\n",
    "from pathlib import Path\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "451eb471",
   "metadata": {},
   "outputs": [],
   "source": [
    "RESULTS_DIR = Path(\"results/iclr\")\n",
    "DATA = {}\n",
    "KEYS = None\n",
    "for method_dir in RESULTS_DIR.iterdir():\n",
    "    method_name = str(method_dir).split(\"/\")[-1]\n",
    "    print(method_name)\n",
    "    n_edit_folders = list(method_dir.glob(\"*_edits_setting_*\"))\n",
    "    for n_edit_folder in n_edit_folders:\n",
    "        n_edits = str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[0]\n",
    "        try:\n",
    "            res = summarize_main(n_edit_folder.relative_to(\"results\"), [\"run_000\"])[0]\n",
    "\n",
    "            DATA[method_name] = DATA.get(method_name, {})\n",
    "            DATA[method_name][n_edits] = res\n",
    "            if KEYS is None:\n",
    "                KEYS = list(res.keys())\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "print({k: list(v.keys()) for k, v in DATA.items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b9f0860",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 14\n",
    "MEDIUM_SIZE = 15\n",
    "BIGGER_SIZE = 16\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8b41acc",
   "metadata": {},
   "outputs": [],
   "source": [
    "TITLES = {\n",
    "    \"post_score\": \"Score (S)\",\n",
    "    \"post_rewrite_success\": \"Efficacy Succ. (ES)\",\n",
    "    \"post_paraphrase_success\": \"Generalization Succ. (PS)\",\n",
    "    \"post_neighborhood_success\": \"Specificity Succ. (NS)\",\n",
    "    \"post_rewrite_acc\": \"Efficacy Acc (EA)\",\n",
    "    \"post_paraphrase_acc\": \"Generalization Acc. (PA)\",\n",
    "    \"post_neighborhood_acc\": \"Specificity Acc. (NA)\",\n",
    "    \"post_reference_score\": \"Consistency (RS)\",\n",
    "}\n",
    "\n",
    "SHOW_KEYS = list(TITLES.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1d443f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "SHOW_KEYS = KEYS\n",
    "SHOW_KEYS.pop(SHOW_KEYS.index(\"run_dir\"))\n",
    "TITLES = {k: k for k in SHOW_KEYS}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49efeea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 4\n",
    "h = math.ceil(len(KEYS) / w)\n",
    "plt.figure(figsize=(w * 3.5, h * 2.5))\n",
    "\n",
    "assert all(k in KEYS for k in SHOW_KEYS)\n",
    "for i, key in enumerate(SHOW_KEYS):\n",
    "    plt.subplot(h, w, i + 1)\n",
    "    for method, results in sorted([(k, v) for k, v in DATA.items() if \"_fix\" not in k]):\n",
    "        try:\n",
    "            n_edits = list(map(int, results.keys()))\n",
    "            values = [\n",
    "                f[0] if (type(f := results[str(n)][key]) is tuple) else f\n",
    "                for n in n_edits\n",
    "            ]\n",
    "            plt.plot(n_edits, values, marker=\"o\", markersize=4, label=method)\n",
    "            plt.xlabel(\"# Edits\")\n",
    "            # plt.ylabel(\"metric value\")\n",
    "            plt.title(TITLES[key])\n",
    "            plt.legend()\n",
    "        except:\n",
    "            pass\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"tmp.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae8e7ea4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
