{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from ast import literal_eval\n",
    "import functools\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "\n",
    "\n",
    "# Scientific packages\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "torch.cuda.set_device(1)\n",
    "from tqdm import tqdm\n",
    "torch.set_grad_enabled(False)\n",
    "tqdm.pandas()\n",
    "\n",
    "# Visuals\n",
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns\n",
    "# Configure seaborn in ONE call. `sns.set` is an alias of `sns.set_theme`, and every call\n",
    "# re-applies the plotting context, so a second `set_theme(style=...)` call would silently\n",
    "# discard the font-size overrides passed through `rc` in the first one.\n",
    "sns.set_theme(\n",
    "    context=\"notebook\",\n",
    "    style=\"whitegrid\",\n",
    "    rc={\n",
    "        \"font.size\": 16,\n",
    "        \"axes.titlesize\": 16,\n",
    "        \"axes.labelsize\": 16,\n",
    "        \"xtick.labelsize\": 16.0,\n",
    "        \"ytick.labelsize\": 16.0,\n",
    "        \"legend.fontsize\": 16.0,\n",
    "    },\n",
    ")\n",
    "# Custom palette assembled from slices of seaborn's \"Set1\" colors.\n",
    "palette_ = sns.color_palette(\"Set1\")\n",
    "palette = palette_[2:5] + palette_[7:]\n",
    "\n",
    "# Utilities\n",
    "from utils import (\n",
    "    ModelAndTokenizer,\n",
    "    make_inputs,\n",
    "    decode_tokens,\n",
    "    find_token_range,\n",
    "    predict_from_input,\n",
    ")\n",
    "import nethook\n",
    "# List of stopwords from NLTK, needed only for the attributes rate evaluation.\n",
    "# import nltk\n",
    "# nltk.download('stopwords')\n",
    "# from nltk.corpus import stopwords\n",
    "# stopwords0_ = stopwords.words('english')\n",
    "# stopwords0_ = {word: \"\" for word in stopwords0_}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get CounterFact data for GPT2-xl, from the ROME repository.\n",
    "# NOTE(review): the comment above mentions GPT2-xl while the notebook loads GPT-J below —\n",
    "# confirm this file is the intended prompt set.\n",
    "knowns_df = pd.read_json(\"../memit_attn/data/known_1000.json\")\n",
    "# Row count; plot_sim divides the accumulated similarity sums by this value to get averages.\n",
    "knowns_df_size = len(knowns_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPT-J"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "# Load model from local\n",
    "# Both values are handed to the project-local ModelAndTokenizer helper; presumably the\n",
    "# checkpoint is resolved under `model_path` — TODO confirm against utils.py.\n",
    "model_path = \"../ptms/\"\n",
    "model_name = \"EleutherAI/gpt-j-6B\"\n",
    "mt = ModelAndTokenizer(\n",
    "    model_path,\n",
    "    model_name,\n",
    "    low_cpu_mem_usage=False,\n",
    "    # torch_dtype=torch.half,\n",
    ")\n",
    "# Put the model in evaluation mode. As the last expression of the cell, the value returned\n",
    "# by eval() is what the notebook displays.\n",
    "mt.model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cache of hidden representations of GPT-J"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a cache of per-module hidden states and predicted-token probabilities.\n",
    "def edit_output_fn(cur_out, cur_layer):\n",
    "    \"\"\"Identity hook: return `cur_out` unchanged (not referenced elsewhere in this cell).\"\"\"\n",
    "    return cur_out\n",
    "\n",
    "# Output-embedding weights and the final layer norm; used below to project a hidden\n",
    "# state onto the vocabulary.\n",
    "E = mt.model.get_output_embeddings().weight.detach()\n",
    "ln_f = nethook.get_module(mt.model, \"transformer.ln_f\")\n",
    "\n",
    "mlp_module_tmp = \"transformer.h.{}.mlp\"\n",
    "attn_module_tmp = \"transformer.h.{}.attn\"\n",
    "layers_to_cache = range(0,28)\n",
    "\n",
    "# Register the MLP and then the attention module of every layer to cache.\n",
    "# The order matters: the attention entry reuses the input captured for the MLP of the\n",
    "# same layer (`last_input` below) instead of reading its own traced input\n",
    "# (presumably because both sublayers receive the same tensor — TODO confirm).\n",
    "layers_to_trace = []\n",
    "for layer in layers_to_cache:\n",
    "    layers_to_trace.append(mlp_module_tmp.format(layer))\n",
    "    layers_to_trace.append(attn_module_tmp.format(layer))\n",
    "\n",
    "# Maps (module_name, prompt) -> {'states': [input, output], 'probs': [p_input, p_input_plus_output]}.\n",
    "hs_cache = {}\n",
    "for row_i, row in tqdm(knowns_df.iterrows(), total=len(knowns_df)):\n",
    "    prompt = row.prompt\n",
    "    inp = make_inputs(mt.tokenizer, [prompt])\n",
    "\n",
    "    with nethook.TraceDict(\n",
    "        module=mt.model,\n",
    "        layers=layers_to_trace,\n",
    "        retain_input=True,\n",
    "        retain_output=True,\n",
    "        detach=True\n",
    "    ) as tr:\n",
    "        logits = mt.model(**inp).logits\n",
    "        # Softmax is monotonic, so the argmax of the raw logits is the predicted token.\n",
    "        predicted_token_index = torch.argmax(logits[:, -1, :], dim=-1)[0]\n",
    "        last_input = None\n",
    "        for layer in layers_to_trace:\n",
    "            if 'attn' in layer:\n",
    "                # The traced attention output is indexed twice (first element of the returned\n",
    "                # structure, then its first item); the input is the one saved from the MLP.\n",
    "                layer_output = tr[layer].output[0][0]\n",
    "            else:\n",
    "                last_input = tr[layer].input[0]\n",
    "                layer_output = tr[layer].output[0]\n",
    "\n",
    "            # Probability of the predicted token when decoding (0) the module input alone\n",
    "            # and (1) the module output added to that input.\n",
    "            probs = []\n",
    "            for full_repr in (last_input, layer_output + last_input):\n",
    "                token_probs = torch.softmax(ln_f(full_repr) @ E.T, dim=-1)\n",
    "                # Drop a leading batch axis if present so the indexing below is always\n",
    "                # (position, token); a try/except on the index can silently pick the wrong\n",
    "                # element when the token id happens to be a valid position.\n",
    "                if token_probs.dim() == 3:\n",
    "                    token_probs = token_probs.squeeze(0)\n",
    "                probs.append(token_probs[-1, predicted_token_index].item())\n",
    "\n",
    "            # Entries are keyed by (layer, prompt), which is the order every later lookup uses.\n",
    "            hs_cache[(layer, prompt)] = {\n",
    "                'states': [last_input, layer_output],\n",
    "                'probs': probs,\n",
    "            }\n",
    "\n",
    "len(hs_cache)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get sentence representations' similarities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Projection of token representations\n",
    "# NOTE: this rebinds `E`. The cache cell above set it to the output-embedding weights;\n",
    "# from here on it holds the input-embedding weights.\n",
    "E = mt.model.get_input_embeddings().weight.detach()\n",
    "# Number of top-ranked vocabulary tokens kept per representation for the Jaccard comparison.\n",
    "k = 50\n",
    "\n",
    "def jaccard_set(list1, list2):\n",
    "    \"\"\"Define Jaccard Similarity function for two sets\"\"\"\n",
    "    intersection = len(list(set(list1).intersection(list2)))\n",
    "    union = (len(list1) + len(list2)) - intersection\n",
    "    return float(intersection) / union\n",
    "records = []\n",
    "# Per-layer running sums accumulated over all prompts.\n",
    "all_probs = {}\n",
    "all_cos_sim_attn = [0.0] * len(layers_to_cache)\n",
    "all_cos_sim_mlp = [0.0] * len(layers_to_cache)\n",
    "all_jac_sim_attn = [0.0] * len(layers_to_cache)\n",
    "all_jac_sim_mlp = [0.0] * len(layers_to_cache)\n",
    "for row_i, row in tqdm(knowns_df.iterrows(), total=len(knowns_df)):\n",
    "    prompt = row.prompt\n",
    "    subject = row.subject\n",
    "\n",
    "    inp = make_inputs(mt.tokenizer, [prompt])\n",
    "    # Index of the last prompt token; it does not depend on the layer.\n",
    "    position = len(inp[\"input_ids\"][0]) - 1\n",
    "    for layer in layers_to_trace:\n",
    "        # Raw string: \"\\d\" in a normal string literal is an invalid escape sequence.\n",
    "        layer_num = int(re.findall(r\"\\d+\", layer)[0])\n",
    "        # The cache cell stores one entry per (layer, prompt); a missing entry raises\n",
    "        # KeyError here rather than being skipped for part of the statistics only.\n",
    "        cached = hs_cache[(layer, prompt)]\n",
    "        layer_probs = all_probs.setdefault(layer, [0, 0])\n",
    "        layer_probs[0] += cached['probs'][0]\n",
    "        layer_probs[1] += cached['probs'][1]\n",
    "\n",
    "        desc = f\"no_subj_last_{layer}\"\n",
    "        in_hs = cached['states'][0][position]\n",
    "        # Drop a leading batch axis if present instead of relying on an IndexError.\n",
    "        out_states = cached['states'][1]\n",
    "        if out_states.dim() == 3:\n",
    "            out_states = out_states.squeeze(0)\n",
    "        ot_hs = out_states[position]\n",
    "\n",
    "        # Rank the vocabulary by projection score and keep the k best tokens of each state.\n",
    "        in_projs = in_hs.matmul(E.T).cpu().numpy()\n",
    "        ot_projs = ot_hs.matmul(E.T).cpu().numpy()\n",
    "        in_ind = np.argsort(-in_projs)\n",
    "        ot_ind = np.argsort(-ot_projs)\n",
    "        in_topks = [decode_tokens(mt.tokenizer, [i])[0] for i in in_ind[:k]]\n",
    "        ot_topks = [decode_tokens(mt.tokenizer, [i])[0] for i in ot_ind[:k]]\n",
    "\n",
    "        # Accumulate plain Python floats so the sums convert cleanly to a NumPy array later.\n",
    "        cos_sim = torch.cosine_similarity(in_hs, ot_hs, dim=0).item()\n",
    "        jac_sim = jaccard_set(in_topks, ot_topks)\n",
    "        if \"mlp\" in layer:\n",
    "            all_cos_sim_mlp[layer_num] += cos_sim\n",
    "            all_jac_sim_mlp[layer_num] += jac_sim\n",
    "        else:\n",
    "            all_cos_sim_attn[layer_num] += cos_sim\n",
    "            all_jac_sim_attn[layer_num] += jac_sim\n",
    "        records.append({\n",
    "            \"example_index\": row_i,\n",
    "            \"subject\": subject,\n",
    "            \"layer\": layer,\n",
    "            \"position\": position,\n",
    "            \"desc\": desc,\n",
    "            \"desc_short\": desc.rsplit(\"_\", 1)[0],\n",
    "            \"input_top_k_preds\": in_topks,\n",
    "            \"output_top_k_preds\": ot_topks,\n",
    "        })\n",
    "\n",
    "tmp = pd.DataFrame.from_records(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_sim(layers_to_cache, ylabel, all_sim_attn, attn_label, all_sim_mlp, mlp_label, save_path=None):\n",
    "    \"\"\"Plot per-layer average similarity for the attention and MLP modules.\n",
    "\n",
    "    `all_sim_attn` and `all_sim_mlp` hold per-layer sums; each is divided by the global\n",
    "    `knowns_df_size` before plotting. Label and legend fonts come from the global `font1`\n",
    "    dict, which must be defined by the time this function is called.\n",
    "    When `save_path` is truthy the figure is written there as a PDF; otherwise it is shown.\n",
    "    \"\"\"\n",
    "    # Layers are shown 1-based on the x axis.\n",
    "    x_layers = [idx + 1 for idx in layers_to_cache]\n",
    "    _, ax = plt.subplots(figsize=(6, 4))\n",
    "\n",
    "    # (sums, legend label, light line color, line style, marker) for each curve.\n",
    "    series = (\n",
    "        (all_sim_attn, attn_label, 'skyblue', '-', 'o'),\n",
    "        (all_sim_mlp, mlp_label, 'lightcoral', '--', 'x'),\n",
    "    )\n",
    "    for sums, label, color, linestyle, marker in series:\n",
    "        plt.plot(x_layers, np.array(sums) / knowns_df_size, label=label, color=color, linestyle=linestyle, marker=marker, linewidth=2.5)\n",
    "\n",
    "    plt.tick_params(labelsize=14)\n",
    "    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():\n",
    "        tick_label.set_fontname('Times New Roman')\n",
    "\n",
    "    plt.xlabel('Layer', font1)\n",
    "    plt.ylabel(ylabel, font1)\n",
    "    plt.legend(loc='best', prop=font1)\n",
    "    plt.grid(True)\n",
    "    plt.tight_layout()  # keep labels and legend inside the figure area\n",
    "\n",
    "    if not save_path:\n",
    "        plt.show()\n",
    "    else:\n",
    "        plt.savefig(save_path, format='pdf')\n",
    "\n",
    "\n",
    "# Font settings read by plot_sim as a global: used as the font dict of the axis labels\n",
    "# and as `prop` of the legend.\n",
    "font1 = {'family' : 'Times New Roman',\n",
    "'weight' : 'normal',\n",
    "'size'   : 14,\n",
    "}\n",
    "# A save path is passed to both calls, so each figure is written as a PDF\n",
    "# (plot_sim only calls plt.show() when no path is given).\n",
    "plot_sim(layers_to_cache, 'Average Cosine Similarity', all_cos_sim_attn,'Avg. Cos. Sim. - MHSA', all_cos_sim_mlp, 'Avg. Cos. Sim. - FFN', \"cos_sim.pdf\")\n",
    "plot_sim(layers_to_cache, 'Average Jaccard Similarity', all_jac_sim_attn,'Avg. Jac. Sim. - MHSA',all_jac_sim_mlp, 'Avg. Jac. Sim. - FFN', \"jac_sim.pdf\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "vscode": {
   "interpreter": {
    "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
