{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de369d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8af77d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from experiments.summarize import main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6897bb2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory that holds one sub-directory per sweep named in SWEEP_DATA.\n",
    "# Set the ROME_SWEEP_DIR environment variable to point the notebook at another\n",
    "# copy of the data; without it the original cluster path is used.\n",
    "DIR_NAME = os.environ.get(\n",
    "    \"ROME_SWEEP_DIR\",\n",
    "    \"/share/projects/rewriting-knowledge/OFFICIAL_DATA/sweeps/ROME\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28a532fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SWEEP_DATA maps a sweep sub-directory of DIR_NAME to the label shown in the plots.\n",
    "# The commented presets below are alternative selections (raw strings keep the\n",
    "# LaTeX backslash intact if one of them is re-enabled).\n",
    "\n",
    "# SWEEP_DATA = {\n",
    "#     \"FT_layers_sweep_2\": r\"FT on GPT-2 XL, $\\epsilon=5e-4$\",\n",
    "#     \"FT_layers_sweep_1\": r\"FT on GPT-2 XL, $\\epsilon=1e-3$\",\n",
    "#     \"FT_layers_sweep_0\": r\"FT on GPT-2 XL, $\\epsilon=5e-3$\",\n",
    "#     \"FT_layers_sweep_3\": \"FT on GPT-2 XL, Unconstrained\",\n",
    "# }\n",
    "\n",
    "# SWEEP_DATA = {\n",
    "#     \"FT_layers_sweep_4\": r\"FT on GPT-J, $\\epsilon=1e-5$\",\n",
    "#     \"FT_layers_sweep_5\": r\"FT on GPT-J, $\\epsilon=5e-5$\",\n",
    "#     \"FT_layers_sweep_6\": r\"FT on GPT-J, $\\epsilon=1e-4$\",\n",
    "#     \"FT_layers_sweep_7\": \"FT on GPT-J, Unconstrained\",\n",
    "# }\n",
    "\n",
    "# SWEEP_DATA = {\n",
    "#     \"FT_norm_constraint_sweep_0\": \"FT+L Attn Norm Sweep\",\n",
    "# }\n",
    "\n",
    "# Active selection; insertion order is the order in which sweeps are loaded.\n",
    "SWEEP_DATA = dict(\n",
    "    [\n",
    "        (\"ROME_layers_sweep_token_subject_first\", \"First subject token\"),\n",
    "        (\"ROME_layers_sweep_token_subject_last\", \"Last subject token\"),\n",
    "        (\n",
    "            \"ROME_layers_sweep_token_subject_first_after_last\",\n",
    "            \"First token after subject\",\n",
    "        ),\n",
    "        (\"ROME_layers_sweep_token_last\", \"Last token\"),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f81d4b8e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# One list of run summaries per sweep in SWEEP_DATA, each ordered by its\n",
    "# \"run_dir\" entry. Building sorted copies keeps the cell idempotent on re-run.\n",
    "data = [\n",
    "    sorted(\n",
    "        main(dir_name=f\"{DIR_NAME}/{sweep_name}\", runs=None),\n",
    "        key=lambda run: run[\"run_dir\"],\n",
    "    )\n",
    "    for sweep_name in SWEEP_DATA\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85e2ba98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Font sizes shared by every figure in this notebook\n",
    "SMALL_SIZE = 22\n",
    "MEDIUM_SIZE = 23\n",
    "BIGGER_SIZE = 24\n",
    "\n",
    "# Global matplotlib defaults, applied in one update\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"figure.dpi\": 200,\n",
    "        \"font.family\": \"Times New Roman\",\n",
    "        \"font.size\": SMALL_SIZE,  # default text size\n",
    "        \"axes.titlesize\": BIGGER_SIZE,  # axes title\n",
    "        \"axes.labelsize\": MEDIUM_SIZE,  # x and y labels\n",
    "        \"xtick.labelsize\": SMALL_SIZE,  # x tick labels\n",
    "        \"ytick.labelsize\": SMALL_SIZE,  # y tick labels\n",
    "        \"legend.fontsize\": SMALL_SIZE,  # legend entries\n",
    "        \"figure.titlesize\": BIGGER_SIZE,  # figure title\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c4985be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metric key in a run summary -> short label used in plot legends\n",
    "NAMES = dict(\n",
    "    post_rewrite_success=\"Rewrite Score\",\n",
    "    post_paraphrase_success=\"Paraphrase Score\",\n",
    "    post_neighborhood_success=\"Neighborh. Score\",\n",
    "    post_rewrite_diff=\"Rewrite Magni.\",\n",
    "    post_paraphrase_diff=\"Paraphrase Magni.\",\n",
    "    post_neighborhood_diff=\"Neighborh. Magni.\",\n",
    "    post_reference_score=\"Consistency\",\n",
    "    post_ngram_entropy=\"Fluency\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ffad9c0",
   "metadata": {},
   "source": [
    "## FT on MLP sweeps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fa3124e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def do_stuff(keys, xlim=None, ylim=None):\n",
    "    \"\"\"\n",
    "    Plot one panel per entry of `data`, with one line per metric in `keys`.\n",
    "\n",
    "    Args:\n",
    "        keys: metric names; each must be a key of NAMES and of every run dict.\n",
    "        xlim, ylim: optional axis limits applied to every panel\n",
    "            (None keeps matplotlib's automatic limits).\n",
    "\n",
    "    Side effects: overwrites tmp_plot.pdf and shows the figure.\n",
    "    \"\"\"\n",
    "    # Shaded band half-width is err * 1.96 / sqrt(100).\n",
    "    # NOTE(review): presumably run[key] is (mean, std) over 100 samples, which\n",
    "    # would make this a 95% confidence interval -- confirm against summarize.main.\n",
    "    ci_scale = 1.96 / np.sqrt(100)\n",
    "\n",
    "    n = len(data)\n",
    "    fig, axes = plt.subplots(1, n, figsize=(n * 4, 3.5), squeeze=False)\n",
    "    for i, (ax, cur) in enumerate(zip(axes[0], data)):\n",
    "        layers = np.arange(len(cur))\n",
    "\n",
    "        for key in keys:\n",
    "            # Entropy metrics are divided by 10 before plotting\n",
    "            divisor = 10 if \"entropy\" in key else 1\n",
    "            vals = np.array([run[key][0] for run in cur], dtype=float) / divisor\n",
    "            err = np.array([run[key][1] for run in cur], dtype=float) / divisor\n",
    "            # Not in-place: `err *= ...` raises on integer-typed arrays\n",
    "            err = err * ci_scale\n",
    "\n",
    "            ax.plot(layers, vals, label=NAMES[key])\n",
    "            ax.fill_between(layers, vals - err, vals + err, alpha=0.4)\n",
    "\n",
    "        # Per-panel decorations, set once rather than once per metric\n",
    "        ax.set_xlabel(\"Layer\")\n",
    "        ax.set_title(SWEEP_DATA[cur[0][\"run_dir\"].split(\"/\")[-2]])\n",
    "\n",
    "        # Legend goes on the last panel regardless of how many sweeps are loaded\n",
    "        # (it used to be tied to panel index 3).\n",
    "        if i == n - 1:\n",
    "            ax.legend(loc=4, prop={\"size\": 14}, framealpha=0.5)\n",
    "\n",
    "        ax.set_xlim(xlim)\n",
    "        ax.set_ylim(ylim)\n",
    "        ax.grid(True, color=\"#93a1a1\", alpha=0.3)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(\"tmp_plot.pdf\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "do_stuff(\n",
    "    [\"post_rewrite_success\", \"post_paraphrase_success\", \"post_neighborhood_success\"],\n",
    "    ylim=(0, 105),\n",
    ")\n",
    "do_stuff(\n",
    "    [\"post_rewrite_diff\", \"post_paraphrase_diff\", \"post_neighborhood_diff\"],\n",
    "    ylim=(-50, 100),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b00e16af",
   "metadata": {},
   "source": [
    "## ROME Sweeps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bda9dd3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Panel titles for the ROME figure, keyed by metric name\n",
    "TMP_NAMES = {\n",
    "    #     \"post_rewrite_success\": \"(a) Rewrite Score\",\n",
    "    #     \"post_paraphrase_success\": \"(b) Paraphrase Score\",\n",
    "    #     \"post_neighborhood_success\": \"(c) Neighborh. Score\",\n",
    "    \"post_rewrite_diff\": \"(a) Efficacy (EM)\",\n",
    "    \"post_paraphrase_diff\": \"(b) Generalization (PM)\",\n",
    "    \"post_neighborhood_diff\": \"(c) Specificity (NM)\",\n",
    "    \"post_reference_score\": \"(d) Consistency (RS)\",\n",
    "    \"post_ngram_entropy\": \"(e) Fluency\",\n",
    "    \"post_score\": \"(d) Score (S)\",\n",
    "}\n",
    "\n",
    "# Line colour and z-order of the i-th entry of `data` (at most four sweeps)\n",
    "COL_ORDER = [\"orange\", \"tomato\", \"green\", \"cornflowerblue\"]\n",
    "Z_ORDER = [10, 0, 0, 0]\n",
    "\n",
    "\n",
    "def do_stuff(keys, xlim=None, ylim=None):\n",
    "    \"\"\"\n",
    "    Plot one key at a time, over all layers.\n",
    "    Each line is a different entry in SWEEP_DATA\n",
    "\n",
    "    Args:\n",
    "        keys: metric names; each must be a key of TMP_NAMES and of every run dict.\n",
    "        xlim, ylim: optional axis limits applied to every panel\n",
    "            (None keeps matplotlib's automatic limits).\n",
    "\n",
    "    Side effects: overwrites tmp_plot.pdf and shows the figure.\n",
    "    \"\"\"\n",
    "    # Shaded band half-width is err * 1.96 / sqrt(100).\n",
    "    # NOTE(review): presumably run[key] is (mean, std) over 100 samples, which\n",
    "    # would make this a 95% confidence interval -- confirm against summarize.main.\n",
    "    ci_scale = 1.96 / np.sqrt(100)\n",
    "\n",
    "    n = len(keys)\n",
    "    fig, axes = plt.subplots(1, n, figsize=(n * 5, 2.75), squeeze=False)\n",
    "    for j, (ax, key) in enumerate(zip(axes[0], keys)):\n",
    "        # Entropy metrics are divided by 10 before plotting\n",
    "        divisor = 10 if \"entropy\" in key else 1\n",
    "\n",
    "        for i, cur in enumerate(data):\n",
    "            vals = np.array([run[key][0] for run in cur], dtype=float) / divisor\n",
    "            err = np.array([run[key][1] for run in cur], dtype=float) / divisor\n",
    "            # Not in-place: `err *= ...` raises on integer-typed arrays\n",
    "            err = err * ci_scale\n",
    "\n",
    "            sweep_label = SWEEP_DATA[cur[0][\"run_dir\"].split(\"/\")[-2]]\n",
    "            layers = np.arange(len(cur))\n",
    "\n",
    "            ax.plot(\n",
    "                layers,\n",
    "                vals,\n",
    "                label=sweep_label,\n",
    "                color=COL_ORDER[i],\n",
    "                zorder=Z_ORDER[i],\n",
    "                linewidth=3,\n",
    "            )\n",
    "            ax.fill_between(\n",
    "                layers,\n",
    "                vals - err,\n",
    "                vals + err,\n",
    "                alpha=0.4,\n",
    "                color=COL_ORDER[i],\n",
    "                zorder=Z_ORDER[i],\n",
    "            )\n",
    "\n",
    "        # Per-panel decorations, set once rather than once per sweep\n",
    "        ax.set_xlabel(\"single layer edited by ROME\")\n",
    "        ax.set_title(TMP_NAMES[key])\n",
    "\n",
    "        # A single shared legend, anchored below the middle panel\n",
    "        if j == n // 2:\n",
    "            leg = ax.legend(\n",
    "                prop={\"size\": 18}, framealpha=0.5, bbox_to_anchor=(1.5, -0.5), ncol=n\n",
    "            )\n",
    "            # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and the\n",
    "            # old attribute was later removed; support both spellings.\n",
    "            handles = getattr(leg, \"legend_handles\", None)\n",
    "            if handles is None:\n",
    "                handles = leg.legendHandles\n",
    "            for handle in handles:\n",
    "                handle.set_linewidth(8.0)\n",
    "\n",
    "        ax.set_xlim(xlim)\n",
    "        ax.set_ylim(ylim)\n",
    "        ax.grid(True, color=\"#93a1a1\", alpha=0.3)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(\"tmp_plot.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# do_stuff([\"post_rewrite_success\", \"post_paraphrase_success\", \"post_neighborhood_success\", \"post_reference_score\"])#, ylim=(0, 105))\n",
    "# do_stuff([\"post_rewrite_diff\", \"post_paraphrase_diff\", \"post_neighborhood_diff\", \"post_reference_score\"])#, ylim=(-50, 100))\n",
    "\n",
    "do_stuff(\n",
    "    [\n",
    "        \"post_rewrite_diff\",\n",
    "        \"post_paraphrase_diff\",\n",
    "        \"post_neighborhood_diff\",\n",
    "        \"post_score\",\n",
    "    ]\n",
    ")  # , ylim=(-50, 100))\n",
    "# do_stuff([\"post_score\"])\n",
    "\n",
    "# do_stuff([\"post_rewrite_success\", \"post_paraphrase_success\", \"post_neighborhood_success\", \"post_reference_score\", \"post_ngram_entropy\"], ylim=(0, 105))\n",
    "# do_stuff([\"post_rewrite_diff\", \"post_paraphrase_diff\", \"post_neighborhood_diff\", \"post_reference_score\", \"post_ngram_entropy\"], ylim=(-50, 100))\n",
    "# do_stuff([\"post_ngram_entropy\"])\n",
    "# do_stuff([\"post_reference_score\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd5b0e32",
   "metadata": {},
   "source": [
    "## FT on Attn Sweeps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e1a789b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def do_stuff(\n",
    "    lkeys,\n",
    "    epsilons=(5e-4, 6e-4, 7e-4, 8e-4, 9e-4, 1e-3, 2e-3, 3e-3, 4e-3, 5e-3),\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot metrics of the first sweep in `data` against epsilon, one panel per key group.\n",
    "\n",
    "    Args:\n",
    "        lkeys: list of lists of metric names (keys of NAMES and of each run dict);\n",
    "            every inner list becomes one panel.\n",
    "        epsilons: x-value for each run of data[0], in order. The default is the\n",
    "            grid that used to be hard-coded in the plot call.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if data[0] does not hold exactly one run per epsilon value.\n",
    "\n",
    "    Side effects: overwrites tmp_plot.pdf and shows the figure.\n",
    "    \"\"\"\n",
    "    sweep = data[0]\n",
    "    epsilons = list(epsilons)\n",
    "    # Fail with a readable message instead of matplotlib's shape-mismatch error\n",
    "    if len(sweep) != len(epsilons):\n",
    "        raise ValueError(\n",
    "            f\"data[0] has {len(sweep)} runs but {len(epsilons)} epsilon values were given\"\n",
    "        )\n",
    "\n",
    "    n = len(lkeys)\n",
    "    fig, axes = plt.subplots(1, n, figsize=(n * 4, 3.25), squeeze=False)\n",
    "    for ax, keys in zip(axes[0], lkeys):\n",
    "        for key in keys:\n",
    "            ax.plot(epsilons, [run[key][0] for run in sweep], label=NAMES[key])\n",
    "\n",
    "        # Raw string: \"\\e\" is not a valid Python escape sequence\n",
    "        ax.set_xlabel(r\"$\\epsilon$\")\n",
    "        ax.set_xticks(0.001 * np.arange(1, 6))\n",
    "        ax.set_title(SWEEP_DATA[sweep[0][\"run_dir\"].split(\"/\")[-2]])\n",
    "        # One legend per panel, built after all of its lines exist\n",
    "        ax.legend(loc=4, prop={\"size\": 10}, framealpha=0.5)\n",
    "        ax.grid(True, color=\"#93a1a1\", alpha=0.3)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(\"tmp_plot.pdf\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "do_stuff(\n",
    "    [\n",
    "        [\n",
    "            \"post_rewrite_success\",\n",
    "            \"post_paraphrase_success\",\n",
    "            \"post_neighborhood_success\",\n",
    "        ],\n",
    "        [\"post_rewrite_diff\", \"post_paraphrase_diff\", \"post_neighborhood_diff\"],\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4122682a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
