{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "de369d3c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8af77d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "from experiments.summarize import main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "29d68944",
   "metadata": {},
   "outputs": [],
   "source": [
    "COLS = [\n",
    "    \"post_score\",\n",
    "    \"post_rewrite_acc\",\n",
    "    \"post_paraphrase_acc\",\n",
    "    \"post_neighborhood_acc\",\n",
    "]\n",
    "OPTIM = [1, 1, 1, 1]\n",
    "LIM = [20, 50, 50, 20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fba58b26",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASELINE_NAMES = [\"GPT-2 XL\", \"GPT-2 M\", \"GPT-2 L\", \"GPT-J\"]\n",
    "\n",
    "\n",
    "def execute(RUN_DIR, RUN_DATA, FIRST_N):\n",
    "    data = {}\n",
    "    for k, (d, alt_para) in RUN_DATA.items():\n",
    "        cur = main(\n",
    "            dir_name=RUN_DIR / d, runs=[\"run_000\"], first_n_cases=FIRST_N, abs_path=True\n",
    "        )\n",
    "        assert len(cur) == 1\n",
    "        data[k] = cur[0]\n",
    "\n",
    "    m = []\n",
    "    for k, v in data.items():\n",
    "        m.append(\n",
    "            [k]\n",
    "            + [\n",
    "                v[\n",
    "                    z\n",
    "                    if all(k != z for z in BASELINE_NAMES) or z == \"time\"\n",
    "                    else \"pre_\" + z[len(\"post_\") :]\n",
    "                ]\n",
    "                for z in COLS\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    m_np = np.array([[col[0] for col in row[1:]] for row in m[1:]])\n",
    "    m_amax = np.argmax(m_np, axis=0)\n",
    "    m_amin = np.argmin(m_np, axis=0)\n",
    "\n",
    "    res = []\n",
    "\n",
    "    for i, row in enumerate(m):\n",
    "        lstr = [row[0]]\n",
    "        for j, el in enumerate(row[1:]):\n",
    "            mean, std = np.round(el[0], 1), el[1]\n",
    "            interval = 1.96 * std / np.sqrt(FIRST_N)\n",
    "\n",
    "            mean, interval = str(mean), f\"$\\pm${np.round(interval, 1)}\"\n",
    "            bmark = m_amax if OPTIM[j] == 1 else m_amin\n",
    "            res_str = f\"{mean} ({interval})\" if not np.isnan(std) else f\"{mean}\"\n",
    "            if bmark[j] + 1 == i:\n",
    "                lstr.append(\"\\\\goodmetric{\" + res_str + \"}\")\n",
    "            elif not any(lstr[0] in z for z in BASELINE_NAMES) and (\n",
    "                (OPTIM[j] == 1 and float(mean) < LIM[j])\n",
    "                or (OPTIM[j] == 0 and float(mean) > LIM[j])\n",
    "            ):\n",
    "                lstr.append(\"\\\\badmetric{\" + res_str + \"}\")\n",
    "            else:\n",
    "                lstr.append(res_str)\n",
    "\n",
    "        res.append(\n",
    "            \" & \".join(lstr)\n",
    "            + \"\\\\\\\\\"\n",
    "            + (\"\\\\midrule\" if any(lstr[0] == z for z in BASELINE_NAMES) else \"\")\n",
    "        )\n",
    "\n",
    "    return \"\\n\".join(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "010a675f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'num_cases': 10000,\n",
      " 'post_neighborhood_acc': (24.12, 24.73),\n",
      " 'post_paraphrase_acc': (64.75, 32.87),\n",
      " 'post_rewrite_acc': (69.58, 30.15),\n",
      " 'post_score': (42.1, nan),\n",
      " 'run_dir': '/share/projects/rewriting-knowledge/OFFICIAL_DATA_MROME/zsre/gpt-j/FT/run_000',\n",
      " 'time': (1922.1313865184784, 0.0)}\n",
      "{'num_cases': 10000,\n",
      " 'post_neighborhood_acc': (22.39, 24.07),\n",
      " 'post_paraphrase_acc': (18.56, 25.62),\n",
      " 'post_rewrite_acc': (19.37, 26.16),\n",
      " 'post_score': (19.98, nan),\n",
      " 'run_dir': '/share/projects/rewriting-knowledge/OFFICIAL_DATA_MROME/zsre/gpt-j/MEND/run_000',\n",
      " 'time': (202.65541529655457, 0.0)}\n",
      "{'num_cases': 10000,\n",
      " 'post_neighborhood_acc': (0.93, 4.25),\n",
      " 'post_paraphrase_acc': (19.56, 33.18),\n",
      " 'post_rewrite_acc': (21.04, 34.09),\n",
      " 'post_score': (2.55, nan),\n",
      " 'run_dir': '/share/projects/rewriting-knowledge/OFFICIAL_DATA_MROME/zsre/gpt-j/ROME/run_000',\n",
      " 'time': (46397.495433568954, 0.0)}\n",
      "FT-W & 42.1 & 69.6 ($\\pm$0.6) & 64.8 ($\\pm$0.6) & 24.1 ($\\pm$0.5)\\\\\n",
      "MEND & \\goodmetric{20.0} & \\badmetric{19.4 ($\\pm$0.5)} & \\badmetric{18.6 ($\\pm$0.5)} & \\goodmetric{22.4 ($\\pm$0.5)}\\\\\n",
      "ROME & \\badmetric{2.6} & \\goodmetric{21.0 ($\\pm$0.7)} & \\goodmetric{19.6 ($\\pm$0.7)} & \\badmetric{0.9 ($\\pm$0.1)}\\\\\n"
     ]
    }
   ],
   "source": [
    "gap = \"\\n\\\\midrule\\\\midrule\\n\"\n",
    "\n",
    "dir2j = Path(\"/share/projects/rewriting-knowledge/OFFICIAL_DATA_MROME/zsre/gpt-j\")\n",
    "data2j = {\n",
    "    #     \"GPT-J\": (\"ROME\", False),\n",
    "    \"FT-W\": (\"FT\", False),\n",
    "    \"MEND\": (\"MEND\", False),\n",
    "    \"ROME\": (\"ROME\", False),\n",
    "}\n",
    "first2j = 10000\n",
    "\n",
    "print(execute(dir2j, data2j, first2j))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f25701f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "gap = \"\\n\\\\midrule\\\\midrule\\n\"\n",
    "\n",
    "dir2medium = Path(\"/share/projects/rewriting-knowledge/OFFICIAL_DATA/zsre/gpt2-medium\")\n",
    "data2medium = {\n",
    "    \"GPT-2 M\": (\"ROME\", False),\n",
    "    \"FT+L\": (\"FT_L\", False),\n",
    "    \"ROME\": (\"ROME\", False),\n",
    "}\n",
    "first2medium = 10000\n",
    "\n",
    "dir2l = Path(\"/share/projects/rewriting-knowledge/OFFICIAL_DATA/zsre/gpt2-large\")\n",
    "data2l = {\n",
    "    \"GPT-2 L\": (\"ROME\", False),\n",
    "    \"FT+L\": (\"FT_L\", False),\n",
    "    \"ROME\": (\"ROME\", False),\n",
    "}\n",
    "first2l = 10000\n",
    "\n",
    "dir2xl = Path(\"/share/projects/rewriting-knowledge/OFFICIAL_DATA/zsre/gpt2-xl\")\n",
    "data2xl = {\n",
    "    \"GPT-2 XL\": (\"FT\", True),\n",
    "    \"FT\": (\"FT\", True),\n",
    "    \"FT+L\": (\"FT_L\", True),\n",
    "    \"KE\": (\"KE\", False),\n",
    "    \"KE-zsRE\": (\"KE_zsRE\", False),\n",
    "    \"MEND\": (\"MEND\", False),\n",
    "    \"MEND-CF\": (\"MEND_zsRE\", False),\n",
    "    \"ROME\": (\"ROME\", False),\n",
    "}\n",
    "first2xl = 10000\n",
    "\n",
    "print(\n",
    "    execute(dir2medium, data2medium, first2medium)\n",
    "    + gap\n",
    "    + execute(dir2l, data2l, first2l)\n",
    "    + gap\n",
    "    + execute(dir2xl, data2xl, first2xl)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdef5b45",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
