{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json, os\n",
    "import torch\n",
    "import bertviz, uuid\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.pyplot import MultipleLocator\n",
    "from collections import Counter\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from IPython.core.display import display, HTML, Javascript\n",
    "from bertviz.util import format_special_chars, format_attention, num_layers, num_heads\n",
    "# Architecture constants of the analysed checkpoint; they must agree with its config\n",
    "# (NOTE(review): these values look like Qwen3-8B -- confirm against model.config).\n",
    "LAYER_NUM = 36  # Qwen3 has 36 layers\n",
    "HEAD_NUM = 32  # attention heads per layer (o_proj input is split into HEAD_NUM x HEAD_DIM below)\n",
    "HEAD_DIM = 128  # width of one attention head\n",
    "HIDDEN_DIM = HEAD_NUM * HEAD_DIM  # concatenated head width\n",
    "# New tensors (e.g. the torch.tensor(...) calls in later cells) default to the GPU the model is moved to.\n",
    "torch.set_default_device(\"cuda:1\")  # Use GPU 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transfer_output(model_output):\n",
    "    \"\"\"Unpack the per-layer records emitted by the instrumented model.\n",
    "\n",
    "    Each record must hold at least eight fields, in this order: layer input, attention\n",
    "    output, post-attention residual, FFN output, layer output, attention sub-values\n",
    "    (`last_attn_subvalues`), FFN coefficient scores and attention weights. The first\n",
    "    seven fields are indexed at [0] (presumably the batch axis -- confirm) and turned\n",
    "    into nested Python lists; the attention weights are passed through untouched.\n",
    "    Unpacking stops at the first record with fewer than eight fields.\n",
    "\n",
    "    Returns:\n",
    "        A tuple of eight lists (one per field), each with one entry per unpacked layer.\n",
    "    \"\"\"\n",
    "    n_fields = 8\n",
    "    columns = tuple([] for _ in range(n_fields))\n",
    "    for idx in range(len(model_output)):\n",
    "        record = model_output[idx]\n",
    "        if len(record) < n_fields:\n",
    "            break\n",
    "        # Tensor-valued fields: take element 0 and convert it to Python lists.\n",
    "        for field_idx in range(n_fields - 1):\n",
    "            columns[field_idx].append(record[field_idx][0].tolist())\n",
    "        # Attention weights are kept as returned by the model.\n",
    "        columns[n_fields - 1].append(record[n_fields - 1])\n",
    "    return columns\n",
    "\n",
    "def get_fc2_params(model, layer_num):\n",
    "    \"\"\"Return the down_proj (FFN output projection) weight tensor of decoder layer `layer_num`.\"\"\"\n",
    "    ffn = model.model.layers[layer_num].mlp\n",
    "    return ffn.down_proj.weight.data\n",
    "def get_bsvalues(vector, model, final_var, eps=1e-6):\n",
    "    \"\"\"Project residual-stream vector(s) onto vocabulary logits.\n",
    "\n",
    "    Reproduces the model's final RMSNorm followed by lm_head, except that the\n",
    "    normalising mean square is the caller-supplied `final_var` instead of being\n",
    "    recomputed from `vector`. Using one fixed scale for every input lets the logits\n",
    "    of different residual components be compared on the same footing.\n",
    "\n",
    "    Args:\n",
    "        vector: tensor whose last dimension is the hidden size; leading dimensions\n",
    "            (positions, heads, neurons, ...) are carried through.\n",
    "        model: causal LM exposing `model.model.norm.weight` and `model.lm_head`.\n",
    "        final_var: mean-square value used for normalisation (this notebook passes the\n",
    "            final layer output's mean square at the last position).\n",
    "        eps: stability term added to `final_var` (default 1e-6).\n",
    "\n",
    "    Returns:\n",
    "        Logits tensor with `vector`'s leading dimensions and a vocabulary-sized last dim.\n",
    "    \"\"\"\n",
    "    normalized = vector * torch.rsqrt(final_var + eps)\n",
    "    scaled = normalized * model.model.norm.weight.data\n",
    "    return model.lm_head(scaled).data\n",
    "def get_prob(vector):\n",
    "    \"\"\"Convert logits to probabilities with a softmax over the last dimension.\"\"\"\n",
    "    return torch.softmax(vector, dim=-1)\n",
    "def transfer_l(l):\n",
    "    \"\"\"Split an iterable of pairs into (list of first items, list of second items).\"\"\"\n",
    "    pairs = list(l)\n",
    "    return [p[0] for p in pairs], [p[1] for p in pairs]\n",
    "def plt_bar(x, y, yname=\"log increase\"):\n",
    "    \"\"\"Bar chart with attention and FFN contributions interleaved per layer.\n",
    "\n",
    "    Even positions of `x` are drawn as attention-layer bars, odd positions as FFN-layer\n",
    "    bars. Every x value is halved so that a layer's attention bar and its FFN bar share\n",
    "    one integer tick.\n",
    "\n",
    "    Args:\n",
    "        x: indexable numeric positions (callers pass range(len(y))); must be non-empty.\n",
    "        y: bar heights aligned with `x`.\n",
    "        yname: y-axis label.\n",
    "    \"\"\"\n",
    "    half_x = [pos / 2 for pos in x]\n",
    "    even_idx = range(0, len(x), 2)\n",
    "    odd_idx = range(1, len(x), 2)\n",
    "    plt.figure(figsize=(8, 3))\n",
    "    plt.gca().xaxis.set_major_locator(MultipleLocator(1))\n",
    "    plt.xlim(-0.5, half_x[-1] + 0.49)\n",
    "    plt.bar([x[i] / 2 for i in even_idx], [y[i] for i in even_idx], color=\"darksalmon\", label=\"attention layers\")\n",
    "    plt.bar([x[i] / 2 for i in odd_idx], [y[i] for i in odd_idx], color=\"lightseagreen\", label=\"FFN layers\")\n",
    "    plt.xlabel(\"layer\")\n",
    "    plt.ylabel(yname)\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "def plt_heatmap(data, title=\"attn head log increase heatmap\"):\n",
    "    \"\"\"Render a 2-D grid of scores as a heatmap with one labelled tick per row and column.\n",
    "\n",
    "    Args:\n",
    "        data: non-empty 2-D nested sequence (rows x columns) of numbers.\n",
    "        title: figure title; defaults to the original fixed title.\n",
    "    \"\"\"\n",
    "    n_rows, n_cols = len(data), len(data[0])\n",
    "    fig = plt.figure(figsize=(10, 8))\n",
    "    ax = fig.add_subplot(111)\n",
    "    ax.imshow(data, cmap=plt.cm.hot_r)\n",
    "    # Pin tick positions before labelling them: set_yticklabels on auto-placed ticks\n",
    "    # attaches labels 0, 1, 2, ... to whatever positions the default locator chose\n",
    "    # (e.g. 0, 5, 10, ...), which mislabels the rows.\n",
    "    ax.set_xticks(range(n_cols))\n",
    "    ax.set_yticks(range(n_rows))\n",
    "    ax.set_yticklabels([str(r) for r in range(n_rows)])\n",
    "    plt.title(title)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local path or hub id of the checkpoint to analyse (NOTE(review): \"Qwen\" looks like a placeholder -- confirm).\n",
    "modelname = \"Qwen\" \n",
    "tokenizer = AutoTokenizer.from_pretrained(modelname)\n",
    "# Eager attention is requested because later cells pass output_attentions=True, which\n",
    "# presumably needs the eager implementation to return attention weights.\n",
    "model = AutoModelForCausalLM.from_pretrained(modelname, attn_implementation=\"eager\")\n",
    "model.eval()  # inference mode (disables dropout)\n",
    "model.to(\"cuda:1\")  # Use GPU 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: top-10 next-token predictions for the probe sentence.\n",
    "test_sentence = \"Tim Duncan plays the sport of\"\n",
    "indexed_tokens = tokenizer.encode(test_sentence)\n",
    "tokens = [tokenizer.decode(x) for x in indexed_tokens]\n",
    "# Put the input on the model's GPU explicitly (as the other cells do) instead of\n",
    "# relying only on torch.set_default_device.\n",
    "tokens_tensor = torch.tensor([indexed_tokens]).to(\"cuda:1\")\n",
    "with torch.no_grad():\n",
    "    outputs = model(tokens_tensor)\n",
    "    predictions = outputs[0]  # logits, indexed below as [batch][position][vocab]\n",
    "predicted_top10 = torch.argsort(predictions[0][-1], descending=True)[:10]\n",
    "predicted_text = [tokenizer.decode(x) for x in predicted_top10]\n",
    "print(test_sentence, \"=>\", predicted_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug cell: inspect the per-layer records the instrumented model returns through\n",
    "# `past_key_values` and check that transfer_output can unpack them.\n",
    "test_sentence = \"Tim Duncan plays the sport of\"\n",
    "indexed_tokens = tokenizer.encode(test_sentence)\n",
    "tokens = [tokenizer.decode(x) for x in indexed_tokens]\n",
    "tokens_tensor = torch.tensor([indexed_tokens]).to(\"cuda:1\")\n",
    "\n",
    "with torch.no_grad():\n",
    "    base_outputs = model.model(tokens_tensor, use_cache=True, output_attentions=True)\n",
    "\n",
    "layer_records = base_outputs.past_key_values\n",
    "# Report the problem for both a missing and an empty record list.\n",
    "if layer_records is None or len(layer_records) == 0:\n",
    "    print(\"Base past_key_values empty\")\n",
    "else:\n",
    "    print(type(layer_records[0]))\n",
    "    print(len(layer_records[0]))\n",
    "    try:\n",
    "        all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output, all_pos_layer_output, all_last_attn_subvalues, all_pos_coefficient_scores, all_attn_scores = transfer_output(layer_records)\n",
    "    except Exception:\n",
    "        # Best-effort debugging: show the full traceback but keep the kernel running.\n",
    "        import traceback\n",
    "        traceback.print_exc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main setup: run the probe sentence, record the top-10 predictions, and unpack the\n",
    "# per-layer records that every attribution cell below reads.\n",
    "test_sentence = \"Tim Duncan plays the sport of\"\n",
    "indexed_tokens = tokenizer.encode(test_sentence)\n",
    "tokens = [tokenizer.decode(x) for x in indexed_tokens]\n",
    "tokens_tensor = torch.tensor([indexed_tokens]).to(\"cuda:1\")\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(tokens_tensor, use_cache=True, output_attentions=True)\n",
    "    predictions = outputs[0]  # logits, indexed below as [batch][position][vocab]\n",
    "predicted_top10 = torch.argsort(predictions[0][-1], descending=True)[:10]\n",
    "predicted_text = [tokenizer.decode(x) for x in predicted_top10]\n",
    "print(test_sentence, \"=>\", predicted_text)\n",
    "\n",
    "# Second pass through the base decoder (model.model); the instrumented model returns its\n",
    "# per-layer records through `past_key_values`, which transfer_output unpacks.\n",
    "with torch.no_grad():\n",
    "    base_outputs = model.model(tokens_tensor, use_cache=True, output_attentions=True)\n",
    "\n",
    "all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output, all_pos_layer_output, all_last_attn_subvalues, all_pos_coefficient_scores, all_attn_scores = transfer_output(base_outputs.past_key_values)\n",
    "# Mean square of the last layer's output at the last position. get_bsvalues uses it as a\n",
    "# fixed RMSNorm scale, so every residual component is normalised identically.\n",
    "final_var = torch.tensor(all_pos_layer_output[-1][-1]).pow(2).mean(-1, keepdim=True)\n",
    "pos_len = len(tokens)  # number of prompt tokens\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token id of the top-1 prediction; the attribution cells below measure changes in its log-probability.\n",
    "predict_index = int(predicted_top10[0])\n",
    "print(predict_index, tokenizer.decode(predict_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#layer-level increase (value layers)\n",
    "# Attention contribution of each layer at the last position: log p(target) after adding the\n",
    "# attention output to the layer input, minus log p(target) of the layer input alone. Both are\n",
    "# projected with get_bsvalues using the shared final_var.\n",
    "all_attn_log_increase = []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    cur_attn_vector = torch.tensor(all_pos_attn_output[layer_i][-1])\n",
    "    cur_layer_input = torch.tensor(all_pos_layer_input[layer_i][-1])\n",
    "    origin_prob_log = torch.log(get_prob(get_bsvalues(cur_layer_input, model, final_var))[predict_index])\n",
    "    cur_attn_vector_plus = cur_attn_vector + cur_layer_input\n",
    "    cur_attn_vector_bsvalues = get_bsvalues(cur_attn_vector_plus, model, final_var)\n",
    "    cur_attn_vector_probs = get_prob(cur_attn_vector_bsvalues)\n",
    "    cur_attn_vector_probs = cur_attn_vector_probs[predict_index]\n",
    "    cur_attn_vector_probs_log = torch.log(cur_attn_vector_probs)\n",
    "    cur_attn_vector_probs_log_increase = cur_attn_vector_probs_log - origin_prob_log\n",
    "    all_attn_log_increase.append(cur_attn_vector_probs_log_increase.item())\n",
    "# FFN contribution: same measure, with the post-attention residual as the baseline and the\n",
    "# FFN output as the added vector.\n",
    "all_ffn_log_increase = []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    cur_ffn_vector = torch.tensor(all_pos_ffn_output[layer_i][-1])\n",
    "    cur_residual = torch.tensor(all_pos_residual_output[layer_i][-1])\n",
    "    origin_prob_log = torch.log(get_prob(get_bsvalues(cur_residual, model, final_var))[predict_index])\n",
    "    cur_ffn_vector_plus = cur_ffn_vector + cur_residual\n",
    "    cur_ffn_vector_bsvalues = get_bsvalues(cur_ffn_vector_plus, model, final_var)\n",
    "    cur_ffn_vector_probs = get_prob(cur_ffn_vector_bsvalues)\n",
    "    cur_ffn_vector_probs = cur_ffn_vector_probs[predict_index]\n",
    "    cur_ffn_vector_probs_log = torch.log(cur_ffn_vector_probs)\n",
    "    cur_ffn_vector_probs_log_increase = cur_ffn_vector_probs_log - origin_prob_log\n",
    "    # .tolist() on a 0-d tensor yields a Python float, like .item() above.\n",
    "    all_ffn_log_increase.append(cur_ffn_vector_probs_log_increase.tolist())\n",
    "\n",
    "# Print the per-layer FFN scores\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    print(f\"Layer {layer_i}: {all_ffn_log_increase[layer_i]:.6f}\")\n",
    "\n",
    "# Rank layers by contribution (highest first) and print totals of the values rounded to 3 decimals.\n",
    "attn_list, ffn_list = [], []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    attn_list.append([str(layer_i), all_attn_log_increase[layer_i]])\n",
    "    ffn_list.append([str(layer_i), all_ffn_log_increase[layer_i]])\n",
    "attn_list_sort = sorted(attn_list, key=lambda x: x[-1])[::-1]#[:10]\n",
    "ffn_list_sort = sorted(ffn_list, key=lambda x: x[-1])[::-1]#[:10]\n",
    "attn_increase_compute, ffn_increase_compute = [], []\n",
    "for indx, increase in attn_list_sort:\n",
    "    attn_increase_compute.append((indx, round(increase, 3)))\n",
    "for indx, increase in ffn_list_sort:\n",
    "    ffn_increase_compute.append((indx, round(increase, 3)))\n",
    "print(\"attn sum: \", sum([x[1] for x in attn_increase_compute]), \n",
    "      \"ffn sum: \", sum([x[1] for x in ffn_increase_compute]))\n",
    "print(\"attn: \", attn_increase_compute)\n",
    "print(\"ffn: \", ffn_increase_compute)\n",
    "# Interleave scores in layer order (attn0, ffn0, attn1, ffn1, ...) as plt_bar expects.\n",
    "all_increases_draw = []\n",
    "for i in range(len(attn_list)):\n",
    "    all_increases_draw.append(attn_list[i][1])\n",
    "    all_increases_draw.append(ffn_list[i][1])    \n",
    "plt_bar(range(len(all_increases_draw)), all_increases_draw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#head-level increase (value heads)\n",
    "# Per-head contribution: add each head's output (summed over positions) to the last-position\n",
    "# layer input and measure the change in log p(target).\n",
    "all_head_increase = []\n",
    "for test_layer in range(LAYER_NUM):\n",
    "    cur_layer_input = torch.tensor(all_pos_layer_input[test_layer])\n",
    "    # NOTE(review): bmm below requires cur_v_heads to be (HEAD_NUM, N, HEAD_DIM); N is presumably\n",
    "    # the sequence length -- confirm against the instrumented model.\n",
    "    cur_v_heads = torch.tensor(all_last_attn_subvalues[test_layer])\n",
    "    # o_proj.weight.T reshaped to (HEAD_NUM, HEAD_DIM, hidden): each head's slice of the output projection.\n",
    "    cur_attn_o_split = model.model.layers[test_layer].self_attn.o_proj.weight.data.T.view(HEAD_NUM, HEAD_DIM, -1)\n",
    "    # (HEAD_NUM, N, hidden) per-head contributions, permuted to (N, HEAD_NUM, hidden).\n",
    "    cur_attn_subvalues_headrecompute = torch.bmm(cur_v_heads, cur_attn_o_split).permute(1, 0, 2)\n",
    "    # Sum over axis 0 -> one output vector per head: (HEAD_NUM, hidden).\n",
    "    cur_attn_subvalues_head_sum = torch.sum(cur_attn_subvalues_headrecompute, 0)\n",
    "    cur_layer_input_last = cur_layer_input[-1]\n",
    "    origin_prob = torch.log(get_prob(get_bsvalues(cur_layer_input_last, model, final_var))[predict_index])\n",
    "    cur_attn_subvalues_head_plus = cur_attn_subvalues_head_sum + cur_layer_input_last\n",
    "    cur_attn_plus_probs = torch.log(get_prob(get_bsvalues(\n",
    "            cur_attn_subvalues_head_plus, model, final_var))[:, predict_index])\n",
    "    cur_attn_plus_probs_increase = cur_attn_plus_probs - origin_prob\n",
    "    # Keys are \"<layer>_<head>\", scores rounded to 4 decimals.\n",
    "    for i in range(len(cur_attn_plus_probs_increase)):\n",
    "        all_head_increase.append([str(test_layer)+\"_\"+str(i), round(cur_attn_plus_probs_increase[i].item(), 4)])\n",
    "\n",
    "all_head_increase_sort = sorted(all_head_increase, key=lambda x:x[-1])[::-1]\n",
    "print(all_head_increase_sort[:30])\n",
    "# Entries were appended layer by layer, so view as (LAYER_NUM, HEAD_NUM), then transpose so\n",
    "# heatmap rows are heads and columns are layers.\n",
    "all_head_increase_list = [x[1] for x in all_head_increase]\n",
    "all_head_increase_list_split = torch.tensor(all_head_increase_list).view((LAYER_NUM, HEAD_NUM)).permute((1,0)).tolist()\n",
    "plt_heatmap(all_head_increase_list_split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#pos-level increase in a specified head\n",
    "# Layer/head to inspect (e.g. one ranked highly by the head-level cell).\n",
    "test_layer, test_head = 15, 15\n",
    "cur_layer_input = torch.tensor(all_pos_layer_input[test_layer])\n",
    "# NOTE(review): assumes cur_v_heads is (HEAD_NUM, seq_len, HEAD_DIM) -- confirm.\n",
    "cur_v_heads = torch.tensor(all_last_attn_subvalues[test_layer])\n",
    "cur_attn_o_split = model.model.layers[test_layer].self_attn.o_proj.weight.data.T.view(HEAD_NUM, HEAD_DIM, -1)\n",
    "cur_attn_subvalues_headrecompute = torch.bmm(cur_v_heads, cur_attn_o_split).permute(1, 0, 2)\n",
    "# Per-position contributions of the chosen head: (seq_len, hidden).\n",
    "cur_attn_subvalues_headrecompute_curhead = cur_attn_subvalues_headrecompute[:, test_head, :]\n",
    "cur_layer_input_last = cur_layer_input[-1]\n",
    "origin_prob = torch.log(get_prob(get_bsvalues(cur_layer_input_last, model, final_var))[predict_index])\n",
    "# Change in log p(target) when each position's contribution is added to the last-position input.\n",
    "cur_attn_subvalues_headrecompute_curhead_plus = cur_attn_subvalues_headrecompute_curhead + cur_layer_input_last\n",
    "cur_attn_subvalues_headrecompute_curhead_plus_probs = torch.log(get_prob(get_bsvalues(\n",
    "    cur_attn_subvalues_headrecompute_curhead_plus, model, final_var))[:, predict_index])\n",
    "cur_attn_subvalues_headrecompute_increase = cur_attn_subvalues_headrecompute_curhead_plus_probs - origin_prob\n",
    "cur_attn_subvalues_headrecompute_increase_zip = list(zip(range(len(cur_attn_subvalues_headrecompute_increase)), \n",
    "    tokens, cur_attn_subvalues_headrecompute_increase.tolist()))\n",
    "cur_attn_subvalues_headrecompute_increase_zip_sort = sorted(cur_attn_subvalues_headrecompute_increase_zip,\n",
    "    key=lambda x:x[-1])[::-1]\n",
    "# Vocabulary rankings (token ids, best first) of the layer input and of the head's value per position.\n",
    "cur_layer_input_bsvalues = get_bsvalues(cur_layer_input, model, final_var)\n",
    "cur_layer_input_bsvalues_sort = torch.argsort(cur_layer_input_bsvalues, descending=True)\n",
    "cur_attn_subvalues_headrecompute_curhead_bsvalues = get_bsvalues(\n",
    "    cur_attn_subvalues_headrecompute_curhead, model, final_var)\n",
    "cur_attn_subvalues_headrecompute_curhead_bsvalues_sort = torch.argsort(\n",
    "    cur_attn_subvalues_headrecompute_curhead_bsvalues, descending=True)\n",
    "# \"Key\" residual: layer input minus all_pos_layer_input[0] and minus every earlier FFN output.\n",
    "# If the residual stream is exactly the sum of those parts, what remains is the earlier attention outputs.\n",
    "key_input = cur_layer_input.clone()\n",
    "key_input -= torch.tensor(all_pos_layer_input[0])\n",
    "for layer_i in range(test_layer):\n",
    "    key_input -= torch.tensor(all_pos_ffn_output[layer_i])\n",
    "key_input_bsvalues = get_bsvalues(key_input, model, final_var)\n",
    "key_input_bsvalues_sort = torch.argsort(key_input_bsvalues, descending=True)\n",
    "print(list(zip(range(len(tokens)), tokens)))\n",
    "# For each position (highest increase first): score, attention weight from the last query\n",
    "# position to `pos` (presumably indexed [batch][head][query][key]), and top tokens.\n",
    "for pos, word, increase in cur_attn_subvalues_headrecompute_increase_zip_sort:\n",
    "    print(\"\\n\", pos, word, \"increase: \", round(increase, 4), \"attn: \", round(\n",
    "        all_attn_scores[test_layer][0][test_head][-1][pos].item(), 4))\n",
    "    print(\"layer input: \", [tokenizer.decode(x) for x in cur_layer_input_bsvalues_sort[pos][:20]])\n",
    "    print(\"key: \", [tokenizer.decode(x) for x in key_input_bsvalues_sort[pos][:20]])\n",
    "    print(\"value: \", [tokenizer.decode(x) for x in cur_attn_subvalues_headrecompute_curhead_bsvalues_sort[pos][:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#FFN neuron increase (value FFN neuron)\n",
    "# Sub-value of each FFN neuron at the last position: coefficient score times its down_proj column.\n",
    "all_ffn_subvalues = []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    coefficient_scores = torch.tensor(all_pos_coefficient_scores[layer_i][-1])\n",
    "    fc2_vectors = get_fc2_params(model, layer_i)\n",
    "    # Broadcast over down_proj columns, then transpose -> one row (hidden-sized) per neuron.\n",
    "    ffn_subvalues = (coefficient_scores * fc2_vectors).T\n",
    "    all_ffn_subvalues.append(ffn_subvalues)\n",
    "ffn_subvalue_list = []\n",
    "for layer_i in range(LAYER_NUM):\n",
    "    cur_ffn_subvalues = all_ffn_subvalues[layer_i]\n",
    "    cur_residual = torch.tensor(all_pos_residual_output[layer_i][-1])\n",
    "    origin_prob_log = torch.log(get_prob(get_bsvalues(cur_residual, model, final_var))[predict_index])\n",
    "    cur_ffn_subvalues_plus = cur_ffn_subvalues + cur_residual\n",
    "    # Projects every neuron at once: a (neurons x vocab) logits tensor, which can be very large in memory.\n",
    "    cur_ffn_subvalues_bsvalues = get_bsvalues(cur_ffn_subvalues_plus, model, final_var)\n",
    "    cur_ffn_subvalues_probs = get_prob(cur_ffn_subvalues_bsvalues)\n",
    "    cur_ffn_subvalues_probs = cur_ffn_subvalues_probs[:, predict_index]\n",
    "    cur_ffn_subvalues_probs_log = torch.log(cur_ffn_subvalues_probs)\n",
    "    cur_ffn_subvalues_probs_log_increase = cur_ffn_subvalues_probs_log - origin_prob_log\n",
    "    # Keys are \"<layer>_<neuron>\".\n",
    "    for index, ffn_increase in enumerate(cur_ffn_subvalues_probs_log_increase):\n",
    "        ffn_subvalue_list.append([str(layer_i)+\"_\"+str(index), ffn_increase.item()])\n",
    "ffn_subvalue_list_sort = sorted(ffn_subvalue_list, key=lambda x: x[-1])[::-1]\n",
    "# For the 10 best neurons, show the tokens their raw (unscaled) value vector promotes / suppresses most.\n",
    "for x in ffn_subvalue_list_sort[:10]:\n",
    "    print(x[0], round(x[1], 4))\n",
    "    layer = int(x[0].split(\"_\")[0])\n",
    "    neuron = int(x[0].split(\"_\")[1])\n",
    "    cur_vector = get_fc2_params(model, layer).T[neuron]\n",
    "    cur_vector_bsvalue = get_bsvalues(cur_vector, model, final_var)\n",
    "    cur_vector_bsvalue_sort = torch.argsort(cur_vector_bsvalue, descending=True)\n",
    "    print(\"top10: \", [tokenizer.decode(a) for a in cur_vector_bsvalue_sort[:10]])\n",
    "    print(\"last10: \", [tokenizer.decode(a) for a in cur_vector_bsvalue_sort[-10:].tolist()[::-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#visualize the number of value FFN neurons in different layers\n",
    "# Per-layer count of the 300 highest-scoring value FFN neurons (names are \"<layer>_<neuron>\").\n",
    "FFN_value_neurons = [name for name, _ in ffn_subvalue_list_sort[:300]]\n",
    "FFN_layer_count_value = sorted(Counter(int(name.split(\"_\")[0]) for name in FFN_value_neurons).items())\n",
    "gpt_FFN_value_x, gpt_FFN_value_y = transfer_l(FFN_layer_count_value)\n",
    "\n",
    "plt.figure(figsize=(6,3))\n",
    "plt.xticks(fontsize=10)\n",
    "plt.yticks(fontsize=10)\n",
    "plt.plot(gpt_FFN_value_x, gpt_FFN_value_y, \"bo-\", label=\"qwen3 FFN value neurons\")\n",
    "plt.xlabel(\"layer\", fontsize=10)\n",
    "plt.ylabel(\"count\", fontsize=10)\n",
    "plt.legend(fontsize=10, loc=\"upper right\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#find query layers activating FFN neurons\n",
    "# For each of the top-30 value FFN neurons, split its activation across the residual components\n",
    "# feeding that FFN (index 0 = layer-0 input, 2l+1 = attention of layer l, 2l+2 = FFN of layer l)\n",
    "# and accumulate the score-weighted shares per component.\n",
    "all_residual_scores = [0.0]*(1+2*LAYER_NUM)\n",
    "for l_n, increase_score in ffn_subvalue_list_sort[:30]:\n",
    "    ffn_layer, ffn_neuron = l_n.split(\"_\")\n",
    "    ffn_layer, ffn_neuron = int(ffn_layer), int(ffn_neuron)\n",
    "    ffn_mlp = model.model.layers[ffn_layer].mlp\n",
    "    # The neuron's key is an input-side weight row, not down_proj[:, neuron]: that column is the\n",
    "    # neuron's output (value) vector and does not decide whether the neuron fires. In the Qwen3\n",
    "    # (Llama-style) gated MLP, down_proj(act(gate_proj(x)) * up_proj(x)), up_proj is the factor that\n",
    "    # is linear in the normalised input, so with the gate held fixed each component's share of\n",
    "    # up_proj . x is an exact linear split. NOTE(review): gate_proj is the alternative key -- confirm.\n",
    "    ffn_neuron_key = ffn_mlp.up_proj.weight.data[ffn_neuron]\n",
    "    # Fold in the RMSNorm gain applied before the MLP; the shared 1/rms factor cancels in the ratio below.\n",
    "    ffn_neuron_key_new = ffn_neuron_key * model.model.layers[ffn_layer].post_attention_layernorm.weight.data\n",
    "    # Residual components at the last position that form this FFN's input.\n",
    "    last_layer_residualstream = [torch.tensor(all_pos_layer_input[0][-1]).unsqueeze(0)]\n",
    "    for layer_i in range(ffn_layer):\n",
    "        last_layer_residualstream.append(torch.tensor(all_pos_attn_output[layer_i][-1]).unsqueeze(0))\n",
    "        last_layer_residualstream.append(torch.tensor(all_pos_ffn_output[layer_i][-1]).unsqueeze(0))\n",
    "    last_layer_residualstream.append(torch.tensor(all_pos_attn_output[ffn_layer][-1]).unsqueeze(0))\n",
    "    last_layer_residualstream_cat = torch.cat(last_layer_residualstream, 0)\n",
    "    last_layer_residualstream_innerproduct = torch.sum(last_layer_residualstream_cat*ffn_neuron_key_new, -1)\n",
    "    last_layer_residualstream_innerproduct_zip = list(zip(range(len(last_layer_residualstream_innerproduct)), last_layer_residualstream_innerproduct.tolist()))\n",
    "    sum_inner_product = sum([x[1] for x in last_layer_residualstream_innerproduct_zip])\n",
    "    for l, inner in last_layer_residualstream_innerproduct_zip:\n",
    "        all_residual_scores[l] += inner/sum_inner_product * increase_score\n",
    "all_residual_scores_zip = list(zip(range(len(all_residual_scores)), all_residual_scores))\n",
    "all_residual_scores_zip_sort = sorted(all_residual_scores_zip, key=lambda x: x[-1])[::-1]\n",
    "# Index -> layer coordinate: attention of layer l -> l, FFN of layer l -> l + 0.5, layer-0 input -> -0.5.\n",
    "print([(a[0]/2-0.5, round(a[1],4)) for a in all_residual_scores_zip_sort])\n",
    "plt_bar(range(len(all_residual_scores[1:])), all_residual_scores[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#find query layers activating attn neurons\n",
    "# FIXME: no cell in this notebook defines `cur_file_attn_neuron_list_sort`. It is expected to be a list\n",
    "# of [\"<layer>_<head>_<neuron>_<pos>\", increase_score] entries, presumably sorted by score (highest\n",
    "# first), produced by an attention-neuron attribution step. Fail with an explicit message instead of\n",
    "# a bare NameError on Restart & Run All.\n",
    "if \"cur_file_attn_neuron_list_sort\" not in globals():\n",
    "    raise NameError(\"cur_file_attn_neuron_list_sort is not defined; run the attention-neuron attribution step that builds it first\")\n",
    "# Qwen3 uses grouped-query attention: v_proj has num_key_value_heads * HEAD_DIM output rows, so a\n",
    "# query-head index must be mapped to the KV head it shares (query head h -> KV head h // group).\n",
    "# Without num_key_value_heads in the config, every query head keeps its own KV head (group == 1).\n",
    "num_kv_heads = getattr(model.config, \"num_key_value_heads\", None) or HEAD_NUM\n",
    "kv_group_size = HEAD_NUM // num_kv_heads\n",
    "all_residual_scores = [0.0]*(1+2*LAYER_NUM)\n",
    "avg_attn_layer_curdir = []\n",
    "for l_h_n_p, increase_score in cur_file_attn_neuron_list_sort[:30]:\n",
    "    attn_layer, attn_head, attn_neuron, attn_pos = l_h_n_p.split(\"_\")\n",
    "    attn_layer, attn_head, attn_neuron, attn_pos = int(attn_layer), int(attn_head), int(attn_neuron), int(attn_pos)\n",
    "    avg_attn_layer_curdir.append(attn_layer)\n",
    "    # NOTE(review): assumes attn_head indexes the HEAD_NUM query heads (as in the o_proj split) -- confirm.\n",
    "    cur_attn_neuron = (attn_head // kv_group_size)*HEAD_DIM+attn_neuron\n",
    "    attn_neuron_key = model.model.layers[attn_layer].self_attn.v_proj.weight.data[cur_attn_neuron]\n",
    "    # Fold in the RMSNorm gain applied before attention; the shared 1/rms factor cancels in the ratio below.\n",
    "    attn_neuron_key_new = attn_neuron_key * model.model.layers[attn_layer].input_layernorm.weight.data\n",
    "    # Residual components at attn_pos feeding this layer: layer-0 input, then attn/FFN of each earlier layer.\n",
    "    pos_layer_residualstream = [torch.tensor(all_pos_layer_input[0][attn_pos]).unsqueeze(0)]\n",
    "    for layer_i in range(attn_layer):\n",
    "        pos_layer_residualstream.append(torch.tensor(all_pos_attn_output[layer_i][attn_pos]).unsqueeze(0))\n",
    "        pos_layer_residualstream.append(torch.tensor(all_pos_ffn_output[layer_i][attn_pos]).unsqueeze(0))\n",
    "    pos_layer_residualstream_cat = torch.cat(pos_layer_residualstream, 0)\n",
    "    pos_layer_residualstream_innerproduct = torch.sum(pos_layer_residualstream_cat*attn_neuron_key_new, -1)\n",
    "    pos_layer_residualstream_innerproduct_zip = list(zip(range(len(pos_layer_residualstream_innerproduct)), pos_layer_residualstream_innerproduct.tolist()))\n",
    "    sum_inner_product = sum([x[1] for x in pos_layer_residualstream_innerproduct_zip])\n",
    "    for l, inner in pos_layer_residualstream_innerproduct_zip:\n",
    "        all_residual_scores[l] += inner/sum_inner_product * increase_score\n",
    "all_residual_scores_zip = list(zip(range(len(all_residual_scores)), all_residual_scores))\n",
    "all_residual_scores_zip_sort = sorted(all_residual_scores_zip, key=lambda x: x[-1])[::-1]\n",
    "if avg_attn_layer_curdir:\n",
    "    print(\"avg attn layer: \", sum(avg_attn_layer_curdir)/len(avg_attn_layer_curdir))\n",
    "# Index -> layer coordinate: attention of layer l -> l, FFN of layer l -> l + 0.5, layer-0 input -> -0.5.\n",
    "print([(a[0]/2-0.5, a[1]) for a in all_residual_scores_zip_sort[:10]])\n",
    "plt_bar(range(len(all_residual_scores[1:])), all_residual_scores[1:])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
