{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de369d3c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8af77d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "from experiments.summarize import main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29d68944",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary metrics rendered as table columns by `execute`, in display order.\n",
    "COLS = [\n",
    "    \"post_score\",\n",
    "    \"post_rewrite_acc\",\n",
    "    \"post_paraphrase_acc\",\n",
    "    \"post_neighborhood_acc\",\n",
    "]\n",
    "# Per-column direction, aligned with COLS: 1 -> the largest mean is marked as\n",
    "# the best value (argmax); any other value -> the smallest one is (argmin).\n",
    "OPTIM = [1, 1, 1, 1]\n",
    "# Per-column threshold, aligned with COLS: with OPTIM == 1 a non-baseline mean\n",
    "# below it is flagged as bad; with OPTIM == 0 a mean above it is.\n",
    "LIM = [20, 50, 50, 20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fba58b26",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASELINE_NAMES = [\"GPT-2 XL\", \"GPT-2 M\", \"GPT-2 L\", \"GPT-J\"]\n",
    "\n",
    "\n",
    "def execute(RUN_DIR, RUN_DATA, FIRST_N):\n",
    "    \"\"\"Summarize the runs of one model and render them as LaTeX table rows.\n",
    "\n",
    "    Args:\n",
    "        RUN_DIR: `Path` of the directory holding one sub-directory per method.\n",
    "        RUN_DATA: Mapping of row label -> (sub-directory name, alt_para flag),\n",
    "            in display order. The first entry is expected to be the unedited\n",
    "            baseline: it is left out of the best-value search. Labels found in\n",
    "            BASELINE_NAMES are read from the \"pre_*\" metrics instead of the\n",
    "            \"post_*\" ones. The alt_para flag is unpacked but not used here.\n",
    "        FIRST_N: Passed to `main` as `first_n_cases`; also used as the sample\n",
    "            size n of the half-width 1.96 * std / sqrt(n) printed per cell.\n",
    "\n",
    "    Returns:\n",
    "        One LaTeX row per entry of RUN_DATA, joined with newlines. The best\n",
    "        non-baseline mean of each column is wrapped in a goodmetric macro and\n",
    "        non-baseline means on the wrong side of LIM in a badmetric macro.\n",
    "        Baseline rows are followed by a midrule.\n",
    "    \"\"\"\n",
    "    summaries = {}\n",
    "    for label, (subdir, _alt_para) in RUN_DATA.items():\n",
    "        cur = main(\n",
    "            dir_name=RUN_DIR / subdir,\n",
    "            runs=[\"run_000\"],\n",
    "            first_n_cases=FIRST_N,\n",
    "            abs_path=True,\n",
    "        )\n",
    "        assert len(cur) == 1\n",
    "        summaries[label] = cur[0]\n",
    "\n",
    "    # Build [label, cell, cell, ...] rows; baseline labels use the \"pre_*\" keys.\n",
    "    # Each cell is indexed as cell[0] (mean) and cell[1] (std) further down.\n",
    "    rows = []\n",
    "    for label, summary in summaries.items():\n",
    "        is_baseline = label in BASELINE_NAMES\n",
    "        keys = [\n",
    "            \"pre_\" + col[len(\"post_\") :] if is_baseline and col != \"time\" else col\n",
    "            for col in COLS\n",
    "        ]\n",
    "        rows.append([label] + [summary[key] for key in keys])\n",
    "\n",
    "    # Index of the best mean per column, searched over rows[1:] only, hence the\n",
    "    # \"+ 1\" when it is compared against the row index below.\n",
    "    means = np.array([[cell[0] for cell in row[1:]] for row in rows[1:]])\n",
    "    has_candidates = means.size > 0  # argmax/argmin raise on an empty array\n",
    "    best_max = np.argmax(means, axis=0) if has_candidates else None\n",
    "    best_min = np.argmin(means, axis=0) if has_candidates else None\n",
    "\n",
    "    lines = []\n",
    "    for i, row in enumerate(rows):\n",
    "        label = row[0]\n",
    "        # Exact match, consistent with the \"pre_*\" lookup above.\n",
    "        is_baseline = label in BASELINE_NAMES\n",
    "        cells = [label]\n",
    "        for j, el in enumerate(row[1:]):\n",
    "            mean, std = np.round(el[0], 1), el[1]\n",
    "            text = str(mean)\n",
    "            if not np.isnan(std):\n",
    "                half_width = np.round(1.96 * std / np.sqrt(FIRST_N), 1)\n",
    "                # The backslash is doubled: backslash-p is not a valid escape.\n",
    "                text += f\" ($\\\\pm${half_width})\"\n",
    "\n",
    "            best = best_max if OPTIM[j] == 1 else best_min\n",
    "            below_limit = OPTIM[j] == 1 and mean < LIM[j]\n",
    "            above_limit = OPTIM[j] == 0 and mean > LIM[j]\n",
    "            if has_candidates and best[j] + 1 == i:\n",
    "                cells.append(\"\\\\goodmetric{\" + text + \"}\")\n",
    "            elif not is_baseline and (below_limit or above_limit):\n",
    "                cells.append(\"\\\\badmetric{\" + text + \"}\")\n",
    "            else:\n",
    "                cells.append(text)\n",
    "\n",
    "        line = \" & \".join(cells) + \"\\\\\\\\\"\n",
    "        if is_baseline:\n",
    "            line += \"\\\\midrule\"\n",
    "        lines.append(line)\n",
    "\n",
    "    return \"\\n\".join(lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f25701f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Separator printed between the per-model blocks of the table.\n",
    "gap = \"\\n\\\\midrule\\\\midrule\\n\"\n",
    "\n",
    "# NOTE: absolute, machine-specific location of the zsRE results; adjust it\n",
    "# when running elsewhere.\n",
    "ZSRE_ROOT = Path(\"/share/projects/rewriting-knowledge/OFFICIAL_DATA/zsre\")\n",
    "\n",
    "# Each mapping is row label -> (results sub-directory, alt_para flag).\n",
    "# The first row of every mapping carries a label from BASELINE_NAMES, so\n",
    "# `execute` reads the \"pre_*\" metrics of the listed run for it.\n",
    "dir2medium = ZSRE_ROOT / \"gpt2-medium\"\n",
    "data2medium = {\n",
    "    \"GPT-2 M\": (\"ROME\", False),\n",
    "    \"FT+L\": (\"FT_L\", False),\n",
    "    \"ROME\": (\"ROME\", False),\n",
    "}\n",
    "first2medium = 10000\n",
    "\n",
    "dir2l = ZSRE_ROOT / \"gpt2-large\"\n",
    "data2l = {\n",
    "    \"GPT-2 L\": (\"ROME\", False),\n",
    "    \"FT+L\": (\"FT_L\", False),\n",
    "    \"ROME\": (\"ROME\", False),\n",
    "}\n",
    "first2l = 10000\n",
    "\n",
    "dir2xl = ZSRE_ROOT / \"gpt2-xl\"\n",
    "data2xl = {\n",
    "    \"GPT-2 XL\": (\"FT\", True),\n",
    "    \"FT\": (\"FT\", True),\n",
    "    \"FT+L\": (\"FT_L\", True),\n",
    "    \"KE\": (\"KE\", False),\n",
    "    \"KE-zsRE\": (\"KE_zsRE\", False),\n",
    "    \"MEND\": (\"MEND\", False),\n",
    "    # Label matches the MEND_zsRE directory, like the KE-zsRE row above.\n",
    "    \"MEND-zsRE\": (\"MEND_zsRE\", False),\n",
    "    \"ROME\": (\"ROME\", False),\n",
    "}\n",
    "first2xl = 10000\n",
    "\n",
    "# print() is used on purpose so the newlines of the LaTeX rows are rendered.\n",
    "print(\n",
    "    gap.join(\n",
    "        [\n",
    "            execute(dir2medium, data2medium, first2medium),\n",
    "            execute(dir2l, data2l, first2l),\n",
    "            execute(dir2xl, data2xl, first2xl),\n",
    "        ]\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdef5b45",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
