{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../../\")\n",
    "import copy\n",
    "import os\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src import functional\n",
    "from src.models import ModelandTokenizer\n",
    "# from src.data import load_relation\n",
    "import json\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from collections import OrderedDict\n",
    "\n",
    "@dataclass\n",
    "class LayerRewriteResult:\n",
    "    layer_idx: int\n",
    "    score: float\n",
    "    efficacy: float\n",
    "    generalization: float\n",
    "    specificity: float"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "#####################################################################################\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 18\n",
    "MEDIUM_SIZE = 22\n",
    "BIGGER_SIZE = 26\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=MEDIUM_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE+2)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE+2)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=BIGGER_SIZE)  # legend fontsize\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "#####################################################################################\n",
    "\n",
    "from typing import Optional\n",
    "\n",
    "def layername(i, jump=5):\n",
    "    if i % jump == 0:\n",
    "        return str(i)\n",
    "    else:\n",
    "        return \"\"\n",
    "\n",
    "def visulize_sweep_results(\n",
    "    module_wise_results: dict[int, LayerRewriteResult],\n",
    "    savepdf: Optional[str] = None,\n",
    "):\n",
    "    metric_title = {\n",
    "        \"efficacy\": {\"title\": \"Efficacy (ES)\"},\n",
    "        \"generalization\": {\"title\": \"Generalization (PS)\"},\n",
    "        \"specificity\": {\"title\": \"Specificity (NS)\"},\n",
    "        \"score\": {\"title\": \"Score (S)\"},\n",
    "    }\n",
    "\n",
    "    module_color = {\n",
    "        \"in_proj/ssm\": \"firebrick\",\n",
    "        \"in_proj/non_ssm\": \"green\",\n",
    "        \"out_proj\": \"blue\",\n",
    "        \"dense_4h_to_h\": \"#004c4c\"\n",
    "    }\n",
    "\n",
    "    fig, ax = plt.subplots(1, 4, figsize=(20, 4))\n",
    "\n",
    "    for module in module_wise_results:\n",
    "        layer_results = module_wise_results[module]\n",
    "        layers = sorted(list(layer_results.keys()))\n",
    "        data = {\n",
    "            \"efficacy\": [layer_results[layer].efficacy for layer in layers],\n",
    "            \"generalization\": [layer_results[layer].generalization for layer in layers],\n",
    "            \"specificity\": [layer_results[layer].specificity for layer in layers],\n",
    "            \"score\": [layer_results[layer].score for layer in layers],\n",
    "        }\n",
    "        \n",
    "        for canvas, (k, v) in zip(ax, data.items()):\n",
    "            canvas.plot(\n",
    "                layers, v,\n",
    "                marker=\"o\",\n",
    "                markersize=3,\n",
    "                linewidth=2.5,\n",
    "                alpha=0.7,\n",
    "                color=module_color[module],\n",
    "                label=module if k == \"efficacy\" else None,\n",
    "            )\n",
    "            canvas.set_title(k)\n",
    "            canvas.set_xlabel(\"Edit Layer\")\n",
    "            # canvas.set_ylabel(\"Score\")\n",
    "            canvas.set_title(metric_title[k][\"title\"])\n",
    "            canvas.set_ylim(0, 108)\n",
    "            layer_names = [\n",
    "                str(l) if i%3 == 0 else \"\" for i, l in enumerate(layers)\n",
    "            ]\n",
    "            canvas.set_xticks(layers, layer_names)\n",
    "            canvas.set_yticks(np.arange(0, 110, 25))\n",
    "\n",
    "    fig.legend(ncol = 3, bbox_to_anchor=(0.5, -.13), loc='lower center', frameon=False)\n",
    "    fig.tight_layout()\n",
    "\n",
    "    if savepdf:\n",
    "        os.makedirs(os.path.dirname(savepdf), exist_ok=True)\n",
    "        fig.savefig(savepdf, bbox_inches=\"tight\", dpi=300)\n",
    "\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.summarize import main as summarize\n",
    "from pathlib import Path\n",
    "from typing import Optional, Literal\n",
    "\n",
    "MODEL_INFO = {\n",
    "    \"pythia-2.8b-deduped\": {\n",
    "        \"n_layer\": 32,\n",
    "        \"dirs\": [\"dense_4h_to_h\"],\n",
    "        \"hparams\": \"pythia-3b.json\",\n",
    "        \"known_data\": \"data/known/pythia-2.8b-deduped.json\"\n",
    "    },\n",
    "    \"mamba-2.8b\": {\n",
    "        \"n_layer\": 64,\n",
    "        # \"rewrite_module\": \"layers.{}.mixer.out_proj\"\n",
    "        # \"rewrite_module\": \"layers.{}.mixer.in_proj\",\n",
    "        \"dirs\": [\"in_proj/ssm\", \"in_proj/non_ssm\", \"out_proj\"],\n",
    "        \"hparams\": \"mamba-3b.json\",\n",
    "        \"known_data\": \"data/known/mamba-2.8b.json\"\n",
    "    }\n",
    "}\n",
    "\n",
    "#############################################################\n",
    "ROME_DIR = \"../../results/ROME/\"\n",
    "# MODEL = \"pythia-2.8b-deduped\"\n",
    "MODEL = \"mamba-2.8b\"\n",
    "#############################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rewrite_matrix = MODEL_INFO[MODEL][\"dirs\"][0]\n",
    "sweep_dir = os.path.join(ROME_DIR, MODEL, rewrite_matrix)\n",
    "# os.listdir(sweep_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module_wise_results = {module: OrderedDict() for module in MODEL_INFO[MODEL][\"dirs\"]}\n",
    "\n",
    "for module in MODEL_INFO[MODEL][\"dirs\"]:\n",
    "    sweep_dir = os.path.join(ROME_DIR, MODEL, module)\n",
    "\n",
    "    for layer_folder in os.listdir(sweep_dir):\n",
    "        summary = summarize(\n",
    "            dir_name=Path(sweep_dir) / layer_folder,\n",
    "            # runs=[\"run_000\"],\n",
    "            abs_path=True\n",
    "        )[0]\n",
    "        layer_idx = int(layer_folder.split(\"_\")[-1])\n",
    "        module_wise_results[module][layer_idx] = LayerRewriteResult(\n",
    "            layer_idx=layer_idx,\n",
    "            score=summary[\"post_score\"][0],\n",
    "            efficacy=summary[\"post_rewrite_success\"][0],\n",
    "            generalization=summary[\"post_paraphrase_success\"][0],\n",
    "            specificity=summary[\"post_neighborhood_success\"][0]\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module_wise_results.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# module_wise_results[\"in_proj/ssm\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visulize_sweep_results(\n",
    "    module_wise_results,\n",
    "    savepdf=f\"../../Figures/ROME/{MODEL}_sweep.pdf\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
