{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "24388ece",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch\n",
    "import pickle\n",
    "import json\n",
    "import bertviz\n",
    "import uuid\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.pyplot import MultipleLocator\n",
    "from collections import Counter\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from IPython.core.display import display, HTML, Javascript\n",
    "from bertviz.util import format_special_chars, format_attention, num_layers, num_heads\n",
    "LAYER_NUM = 28\n",
    "HEAD_NUM = 16\n",
    "HEAD_DIM = 256\n",
    "HIDDEN_DIM = HEAD_NUM * HEAD_DIM\n",
    "torch.set_default_device(\"cuda:1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eff8bee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "LAYER_NUM = 28\n",
    "HEAD_NUM = 16\n",
    "HEAD_DIM = 256\n",
    "HIDDEN_DIM = HEAD_NUM * HEAD_DIM\n",
    "torch.set_default_device(\"cuda:1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fff8b3af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def transfer_output(model_output):\n",
    "    all_pos_layer_input = []\n",
    "    all_pos_attn_output = []\n",
    "    all_pos_residual_output = []\n",
    "    all_pos_ffn_output = []\n",
    "    all_pos_layer_output = []\n",
    "    all_last_attn_subvalues = []\n",
    "    all_pos_coefficient_scores = []\n",
    "    all_attn_scores = []\n",
    "    for layer_i in range(LAYER_NUM):\n",
    "        cur_layer_input = model_output[layer_i][0]\n",
    "        cur_attn_output = model_output[layer_i][1]\n",
    "        cur_residual_output = model_output[layer_i][2]\n",
    "        cur_ffn_output = model_output[layer_i][3]\n",
    "        cur_layer_output = model_output[layer_i][4]\n",
    "        cur_last_attn_subvalues = model_output[layer_i][5]\n",
    "        cur_coefficient_scores = model_output[layer_i][6]\n",
    "        cur_attn_weights = model_output[layer_i][7]\n",
    "        all_pos_layer_input.append(cur_layer_input[0].tolist())\n",
    "        all_pos_attn_output.append(cur_attn_output[0].tolist())\n",
    "        all_pos_residual_output.append(cur_residual_output[0].tolist())\n",
    "        all_pos_ffn_output.append(cur_ffn_output[0].tolist())\n",
    "        all_pos_layer_output.append(cur_layer_output[0].tolist())\n",
    "        all_last_attn_subvalues.append(cur_last_attn_subvalues[0].tolist())\n",
    "        all_pos_coefficient_scores.append(cur_coefficient_scores[0].tolist())\n",
    "        all_attn_scores.append(cur_attn_weights)\n",
    "    return all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output, \\\n",
    "           all_pos_layer_output, all_last_attn_subvalues, all_pos_coefficient_scores, all_attn_scores\n",
    "def get_bsvalues(vector, model, final_var):\n",
    "    E = torch.mean(vector, -1)\n",
    "    vector_ln = (vector-E.unsqueeze(-1))/final_var * model.transformer.ln_f.weight.data\n",
    "    vector_bsvalues = model.lm_head(vector_ln).data\n",
    "    return vector_bsvalues\n",
    "def get_prob(vector):\n",
    "    prob = torch.nn.Softmax(-1)(vector)\n",
    "    return prob\n",
    "def get_fc2_params(model, layer_num):\n",
    "    return model.transformer.h[layer_num].mlp.fc_out.weight.data\n",
    "def transfer_l(l):\n",
    "    new_x, new_y = [], []\n",
    "    for x in l:\n",
    "        new_x.append(x[0])\n",
    "        new_y.append(x[1])\n",
    "    return new_x, new_y\n",
    "def plt_bar(x, y, yname=\"log increase\"):\n",
    "    x_major_locator=MultipleLocator(1)\n",
    "    plt.figure(figsize=(8, 3))\n",
    "    ax=plt.gca()\n",
    "    ax.xaxis.set_major_locator(x_major_locator)\n",
    "    plt_x = [a/2 for a in x]\n",
    "    plt.xlim(-0.5, plt_x[-1]+0.49)\n",
    "    x_attn, y_attn, x_ffn, y_ffn = [], [], [], []\n",
    "    for i in range(len(x)):\n",
    "        if i%2 == 0:\n",
    "            x_attn.append(x[i]/2)\n",
    "            y_attn.append(y[i])\n",
    "        else:\n",
    "            x_ffn.append(x[i]/2)\n",
    "            y_ffn.append(y[i])\n",
    "    plt.bar(x_attn, y_attn, color=\"darksalmon\", label=\"attention layers\")\n",
    "    plt.bar(x_ffn, y_ffn, color=\"lightseagreen\", label=\"FFN layers\")\n",
    "    plt.xlabel(\"layer\")\n",
    "    plt.ylabel(yname)\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "def plt_heatmap(data):\n",
    "    xLabel = range(len(data[0]))\n",
    "    yLabel = range(len(data))\n",
    "    fig = plt.figure(figsize=(10,8))\n",
    "    ax = fig.add_subplot(111)\n",
    "    ax.set_xticks(range(len(xLabel)))\n",
    "    ax.set_xticklabels(xLabel)\n",
    "    ax.set_yticks(range(len(yLabel)))\n",
    "    ax.set_yticklabels(yLabel)\n",
    "    im = ax.imshow(data, cmap=plt.cm.hot_r)\n",
    "    #plt.colorbar(im)\n",
    "    plt.title(\"attn head log increase heatmap\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb7b9069",
   "metadata": {},
   "outputs": [],
   "source": [
    "modelname = \"model_gptj\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(modelname)\n",
    "model = AutoModelForCausalLM.from_pretrained(modelname)\n",
    "model.eval()\n",
    "model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1436682f",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_sentence = \"Tim Duncan plays the sport of\"\n",
    "indexed_tokens = tokenizer.encode(test_sentence)\n",
    "tokens = [tokenizer.decode(x) for x in indexed_tokens]\n",
    "tokens_tensor = torch.tensor([indexed_tokens]).to(\"cuda:1\")\n",
    "with torch.no_grad():\n",
    "    outputs = model(tokens_tensor)\n",
    "    predictions = outputs[0]\n",
    "predicted_top10 = torch.argsort(predictions[0][-1], descending=True)[:10]\n",
    "predicted_text = [tokenizer.decode(x) for x in predicted_top10]\n",
    "print(test_sentence, \"=>\", predicted_text)\n",
    "# print(predictions)\n",
    "# print(len(outputs))\n",
    "# print(outputs)\n",
    "all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output, all_pos_layer_output, all_last_attn_subvalues, all_pos_coefficient_scores, all_attn_scores = transfer_output(outputs[1])\n",
    "final_var = ((torch.var(torch.tensor(all_pos_layer_output[-1][-1]), -1, unbiased=False)+1e-5)**0.5).item()\n",
    "pos_len = len(tokens)\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "914a37b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_index = predicted_top10[0].item()\n",
    "print(predict_index, tokenizer.decode(predict_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99a9f05a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#layer-level increase\n",
    "all_attn_log_increase = []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    cur_attn_vector = torch.tensor(all_pos_attn_output[layer_i][-1])\n",
    "    cur_layer_input = torch.tensor(all_pos_layer_input[layer_i][-1])\n",
    "    origin_prob_log = torch.log(get_prob(get_bsvalues(cur_layer_input, model, final_var))[predict_index])\n",
    "    cur_attn_vector_plus = cur_attn_vector + cur_layer_input\n",
    "    cur_attn_vector_bsvalues = get_bsvalues(cur_attn_vector_plus, model, final_var)\n",
    "    cur_attn_vector_probs = get_prob(cur_attn_vector_bsvalues)\n",
    "    cur_attn_vector_probs = cur_attn_vector_probs[predict_index]\n",
    "    cur_attn_vector_probs_log = torch.log(cur_attn_vector_probs)\n",
    "    cur_attn_vector_probs_log_increase = cur_attn_vector_probs_log - origin_prob_log\n",
    "    all_attn_log_increase.append(cur_attn_vector_probs_log_increase.item())\n",
    "all_ffn_log_increase = []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    cur_ffn_vector = torch.tensor(all_pos_ffn_output[layer_i][-1])\n",
    "    cur_residual = torch.tensor(all_pos_residual_output[layer_i][-1])\n",
    "    origin_prob_log = torch.log(get_prob(get_bsvalues(cur_residual, model, final_var))[predict_index])\n",
    "    cur_ffn_vector_plus = cur_ffn_vector + cur_residual\n",
    "    cur_ffn_vector_bsvalues = get_bsvalues(cur_ffn_vector_plus, model, final_var)\n",
    "    cur_ffn_vector_probs = get_prob(cur_ffn_vector_bsvalues)\n",
    "    cur_ffn_vector_probs = cur_ffn_vector_probs[predict_index]\n",
    "    cur_ffn_vector_probs_log = torch.log(cur_ffn_vector_probs)\n",
    "    cur_ffn_vector_probs_log_increase = cur_ffn_vector_probs_log - origin_prob_log\n",
    "    all_ffn_log_increase.append(cur_ffn_vector_probs_log_increase.tolist())\n",
    "attn_list, ffn_list = [], []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    attn_list.append([str(layer_i), all_attn_log_increase[layer_i]])\n",
    "    ffn_list.append([str(layer_i), all_ffn_log_increase[layer_i]])\n",
    "attn_list_sort = sorted(attn_list, key=lambda x: x[-1])[::-1]#[:10]\n",
    "ffn_list_sort = sorted(ffn_list, key=lambda x: x[-1])[::-1]#[:10]\n",
    "attn_increase_compute, ffn_increase_compute = [], []\n",
    "for indx, increase in attn_list_sort:\n",
    "    attn_increase_compute.append((indx, round(increase, 3)))\n",
    "for indx, increase in ffn_list_sort:\n",
    "    ffn_increase_compute.append((indx, round(increase, 3)))\n",
    "print(\"attn sum: \", sum([x[1] for x in attn_increase_compute]), \n",
    "      \"ffn sum: \", sum([x[1] for x in ffn_increase_compute]))\n",
    "print(\"attn: \", attn_increase_compute)\n",
    "print(\"ffn: \", ffn_increase_compute)\n",
    "all_increases_draw = []\n",
    "for i in range(len(attn_list)):\n",
    "    all_increases_draw.append(attn_list[i][1])\n",
    "    all_increases_draw.append(ffn_list[i][1])    \n",
    "plt_bar(range(len(all_increases_draw)), all_increases_draw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b136bcad",
   "metadata": {},
   "outputs": [],
   "source": [
    "#head-level increase\n",
    "all_head_increase = []\n",
    "for test_layer in range(LAYER_NUM):\n",
    "    cur_layer_input = torch.tensor(all_pos_layer_input[test_layer])\n",
    "    cur_v_heads = torch.tensor(all_last_attn_subvalues[test_layer])\n",
    "    cur_attn_o_split = model.transformer.h[test_layer].attn.out_proj.weight.data.view(HEAD_NUM, HEAD_DIM, -1)\n",
    "    cur_attn_subvalues_headrecompute = torch.bmm(cur_v_heads, cur_attn_o_split).permute(1, 0, 2)\n",
    "    cur_attn_subvalues_head_sum = torch.sum(cur_attn_subvalues_headrecompute, 0)\n",
    "    cur_layer_input_last = cur_layer_input[-1]\n",
    "    origin_prob = torch.log(get_prob(get_bsvalues(cur_layer_input_last, model, final_var))[predict_index])\n",
    "    cur_attn_subvalues_head_plus = cur_attn_subvalues_head_sum + cur_layer_input_last\n",
    "    cur_attn_plus_probs = torch.log(get_prob(get_bsvalues(\n",
    "            cur_attn_subvalues_head_plus, model, final_var))[:, predict_index])\n",
    "    cur_attn_plus_probs_increase = cur_attn_plus_probs - origin_prob\n",
    "    for i in range(len(cur_attn_plus_probs_increase)):\n",
    "        all_head_increase.append([str(test_layer)+\"_\"+str(i), round(cur_attn_plus_probs_increase[i].item(), 4)])\n",
    "\n",
    "all_head_increase_sort = sorted(all_head_increase, key=lambda x:x[-1])[::-1]\n",
    "print(all_head_increase_sort[:20])\n",
    "all_head_increase_list = [x[1] for x in all_head_increase]\n",
    "all_head_increase_list_split = torch.tensor(all_head_increase_list).view((LAYER_NUM, HEAD_NUM)).permute((1,0)).tolist()\n",
    "plt_heatmap(all_head_increase_list_split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d612931",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_ffn_subvalues = []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    coefficient_scores = torch.tensor(all_pos_coefficient_scores[layer_i][-1])\n",
    "    fc2_vectors = get_fc2_params(model, layer_i)\n",
    "    ffn_subvalues = (coefficient_scores * fc2_vectors).T\n",
    "    all_ffn_subvalues.append(ffn_subvalues)\n",
    "ffn_subvalue_list = []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    cur_ffn_subvalues = all_ffn_subvalues[layer_i]\n",
    "    cur_residual = torch.tensor(all_pos_residual_output[layer_i][-1])\n",
    "    origin_prob_log = torch.log(get_prob(get_bsvalues(cur_residual, model, final_var))[predict_index])\n",
    "    cur_ffn_subvalues_plus = cur_ffn_subvalues + cur_residual\n",
    "    cur_ffn_subvalues_bsvalues = get_bsvalues(cur_ffn_subvalues_plus, model, final_var)\n",
    "    cur_ffn_subvalues_probs = get_prob(cur_ffn_subvalues_bsvalues)\n",
    "    cur_ffn_subvalues_probs = cur_ffn_subvalues_probs[:, predict_index]\n",
    "    cur_ffn_subvalues_probs_log = torch.log(cur_ffn_subvalues_probs)\n",
    "    cur_ffn_subvalues_probs_log_increase = cur_ffn_subvalues_probs_log - origin_prob_log\n",
    "    for index, ffn_increase in enumerate(cur_ffn_subvalues_probs_log_increase):\n",
    "        ffn_subvalue_list.append([str(layer_i)+\"_\"+str(index), ffn_increase.item()])\n",
    "ffn_subvalue_list_sort = sorted(ffn_subvalue_list, key=lambda x: x[-1])[::-1]\n",
    "for x in ffn_subvalue_list_sort[:10]:\n",
    "    print(x[0], round(x[1], 4))\n",
    "    layer = int(x[0].split(\"_\")[0])\n",
    "    neuron = int(x[0].split(\"_\")[1])\n",
    "    cur_vector = get_fc2_params(model, layer).T[neuron]\n",
    "    cur_vector_bsvalue = get_bsvalues(cur_vector, model, final_var)\n",
    "    cur_vector_bsvalue_sort = torch.argsort(cur_vector_bsvalue, descending=True)\n",
    "    print(\"top10: \", [tokenizer.decode(a) for a in cur_vector_bsvalue_sort[:10]])\n",
    "    print(\"last10: \", [tokenizer.decode(a) for a in cur_vector_bsvalue_sort[-10:]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2fe70d83",
   "metadata": {},
   "outputs": [],
   "source": [
    "#attn neuron (pos)\n",
    "cur_file_attn_neuron_list = []\n",
    "for test_layer in range(LAYER_NUM):\n",
    "    cur_layer_input = torch.tensor(all_pos_layer_input[test_layer])\n",
    "    cur_v_heads_recompute = torch.tensor(all_last_attn_subvalues[test_layer]).permute(1, 0, 2)\n",
    "    cur_attn_o_split = model.transformer.h[test_layer].attn.out_proj.weight.data.view(HEAD_NUM, HEAD_DIM, -1)\n",
    "    cur_attn_o_recompute = cur_attn_o_split * cur_v_heads_recompute.unsqueeze(-1)\n",
    "    cur_layer_input_last = cur_layer_input[-1]\n",
    "    origin_prob = torch.log(get_prob(get_bsvalues(cur_layer_input_last, model, final_var))[predict_index])\n",
    "    cur_attn_o_head_plus = cur_attn_o_recompute + cur_layer_input_last\n",
    "    cur_attn_plus_probs = torch.log(get_prob(get_bsvalues(\n",
    "        cur_attn_o_head_plus, model, final_var))[:, :, :, predict_index])\n",
    "    cur_attn_plus_probs_increase = cur_attn_plus_probs - origin_prob\n",
    "    for pos_index in range(cur_attn_plus_probs_increase.size(0)):\n",
    "        for head_index in range(cur_attn_plus_probs_increase.size(1)):\n",
    "            for attn_neuron_index in range(cur_attn_plus_probs_increase.size(2)):\n",
    "                cur_file_attn_neuron_list.append((str(test_layer)+\"_\"+str(head_index)+\"_\"+str(\n",
    "                    attn_neuron_index)+\"_\"+str(pos_index), \n",
    "                    cur_attn_plus_probs_increase[pos_index][head_index][attn_neuron_index].item()))\n",
    "cur_file_attn_neuron_list_sort = sorted(cur_file_attn_neuron_list, key=lambda x: x[-1])[::-1]\n",
    "print(list(zip(range(len(tokens)), tokens)))\n",
    "for x in cur_file_attn_neuron_list_sort[:10]:\n",
    "    layer_i, head_i, neuron_i, _ = x[0].split(\"_\")\n",
    "    layer_i, head_i, neuron_i = int(layer_i), int(head_i), int(neuron_i)\n",
    "    cur_neuron = model.transformer.h[layer_i].attn.out_proj.weight.data.view(HEAD_NUM, HEAD_DIM, -1)[head_i][neuron_i]\n",
    "    cur_neuron_bsvalue = get_bsvalues(cur_neuron, model, final_var)\n",
    "    cur_neuron_bsvalue_sort = torch.argsort(cur_neuron_bsvalue, descending=True)\n",
    "    print(x[0], round(x[1], 4), \"top10: \", [tokenizer.decode(a) for a in cur_neuron_bsvalue_sort[:10]])\n",
    "    print(x[0], round(x[1], 4), \"last10: \", [tokenizer.decode(a) for a in cur_neuron_bsvalue_sort[-10:]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "e64bcd89",
   "metadata": {},
   "outputs": [],
   "source": [
    "#attn neuron (pos)\n",
    "cur_file_attn_neuron_list = []\n",
    "for test_layer in range(LAYER_NUM):\n",
    "    cur_layer_input = torch.tensor(all_pos_layer_input[test_layer])\n",
    "    cur_v_heads_recompute = torch.tensor(all_last_attn_subvalues[test_layer]).permute(1, 0, 2)\n",
    "    cur_attn_o_split = model.transformer.h[test_layer].attn.c_proj.weight.data.view(HEAD_NUM, HEAD_DIM, -1)\n",
    "    cur_attn_o_recompute = cur_attn_o_split * cur_v_heads_recompute.unsqueeze(-1)\n",
    "    cur_layer_input_last = cur_layer_input[-1]\n",
    "    origin_prob = torch.log(get_prob(get_bsvalues(cur_layer_input_last, model, final_var))[predict_index])\n",
    "    cur_attn_o_head_plus = cur_attn_o_recompute + cur_layer_input_last\n",
    "    cur_attn_plus_probs = torch.log(get_prob(get_bsvalues(\n",
    "        cur_attn_o_head_plus, model, final_var))[:, :, :, predict_index])\n",
    "    cur_attn_plus_probs_increase = cur_attn_plus_probs - origin_prob\n",
    "    for pos_index in range(cur_attn_plus_probs_increase.size(0)):\n",
    "        for head_index in range(cur_attn_plus_probs_increase.size(1)):\n",
    "            for attn_neuron_index in range(cur_attn_plus_probs_increase.size(2)):\n",
    "                cur_file_attn_neuron_list.append((str(test_layer)+\"_\"+str(head_index)+\"_\"+str(\n",
    "                    attn_neuron_index)+\"_\"+str(pos_index), \n",
    "                    cur_attn_plus_probs_increase[pos_index][head_index][attn_neuron_index].item()))\n",
    "cur_file_attn_neuron_list_sort = sorted(cur_file_attn_neuron_list, key=lambda x: x[-1])[::-1]\n",
    "print(list(zip(range(len(tokens)), tokens)))\n",
    "for x in cur_file_attn_neuron_list_sort[:10]:\n",
    "    layer_i, head_i, neuron_i, _ = x[0].split(\"_\")\n",
    "    layer_i, head_i, neuron_i = int(layer_i), int(head_i), int(neuron_i)\n",
    "    cur_neuron = model.transformer.h[layer_i].attn.c_proj.weight.data.view(HEAD_NUM, HEAD_DIM, -1)[head_i][neuron_i]\n",
    "    cur_neuron_bsvalue = get_bsvalues(cur_neuron, model, final_var)\n",
    "    cur_neuron_bsvalue_sort = torch.argsort(cur_neuron_bsvalue, descending=True)\n",
    "    print(x[0], round(x[1], 4), \"top10: \", [tokenizer.decode(a) for a in cur_neuron_bsvalue_sort[:10]])\n",
    "    print(x[0], round(x[1], 4), \"last10: \", [tokenizer.decode(a) for a in cur_neuron_bsvalue_sort[-10:]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "e2bb9740",
   "metadata": {},
   "outputs": [],
   "source": [
    "#find query FFN neurons activating attn neurons\n",
    "curfile_ffn_score_dict = {}\n",
    "for l_h_n_p, increase_score in cur_file_attn_neuron_list_sort[:30]:\n",
    "    attn_layer, attn_head, attn_neuron, attn_pos = l_h_n_p.split(\"_\")\n",
    "    attn_layer, attn_head, attn_neuron, attn_pos = int(attn_layer), int(attn_head), int(attn_neuron), int(attn_pos)\n",
    "    cur_attn_neuron = attn_head * HEAD_DIM + attn_neuron\n",
    "    \n",
    "    attn_neuron_key = model.transformer.h[attn_layer].attn.k_proj.weight.data[:, cur_attn_neuron]\n",
    "    \n",
    "    attn_neuron_key_new = attn_neuron_key * model.transformer.h[attn_layer].ln_1.weight.data\n",
    "    cur_inner_all = torch.sum(torch.tensor(all_pos_layer_input[attn_layer][attn_pos]) * attn_neuron_key_new, -1)\n",
    "    for layer_i in range(attn_layer):\n",
    "        cur_layer_neurons = (torch.tensor(all_pos_coefficient_scores[layer_i][attn_pos])*get_fc2_params(model, layer_i)).T\n",
    "        cur_layer_neurons_innerproduct = torch.sum(cur_layer_neurons * attn_neuron_key_new, -1)/cur_inner_all\n",
    "        for neuron_i in range(len(cur_layer_neurons_innerproduct)):\n",
    "            if str(layer_i)+\"_\"+str(neuron_i) not in curfile_ffn_score_dict:\n",
    "                curfile_ffn_score_dict[str(layer_i)+\"_\"+str(neuron_i)] = 0.0\n",
    "            curfile_ffn_score_dict[str(layer_i)+\"_\"+str(neuron_i)] += cur_layer_neurons_innerproduct[neuron_i].item() * increase_score\n",
    "cur_file_neurons_ffn_zip = list(zip(curfile_ffn_score_dict.keys(), curfile_ffn_score_dict.values()))\n",
    "cur_file_neurons_ffn_zip_sort = sorted(cur_file_neurons_ffn_zip, key=lambda x: x[-1])[::-1]\n",
    "for x in cur_file_neurons_ffn_zip_sort[:10]:\n",
    "    print(x[0], round(x[1], 4))\n",
    "    layer = int(x[0].split(\"_\")[0])\n",
    "    neuron = int(x[0].split(\"_\")[1])\n",
    "    cur_vector = get_fc2_params(model, layer).T[neuron]\n",
    "    cur_vector_bsvalue = get_bsvalues(cur_vector, model, final_var)\n",
    "    cur_vector_bsvalue_sort = torch.argsort(cur_vector_bsvalue, descending=True)\n",
    "    print(\"top10: \", [tokenizer.decode(a) for a in cur_vector_bsvalue_sort[:10]])\n",
    "    print(\"last10: \", [tokenizer.decode(a) for a in cur_vector_bsvalue_sort[-10:]])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
