{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d05f4176",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "\n",
    "#os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
    "model_path = ''\n",
    "\n",
    "# from custom_llama import LlamaTokenizer, LlamaForCausalLM\n",
    "from transformers import LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, GPTJForCausalLM, GPTNeoXForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "28d6ee2b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "qwen2.5-instruct\n"
     ]
    }
   ],
   "source": [
    "# Map the checkpoint path to a short model name (used in result filenames).\n",
    "# Matching is case-insensitive so that e.g. 'Qwen/Qwen2.5-7B-Instruct' is recognised;\n",
    "# patterns are tried in order and the first match wins.\n",
    "MODEL_NAME_PATTERNS = [\n",
    "    ('llama-3-8b', 'llama3'),\n",
    "    ('llama-3.1-8b', 'llama3.1'),\n",
    "    ('llama-2', 'llama2'),\n",
    "    # NOTE(review): 'v.15' looks like a typo for 'v1.5'; kept because it is part of result filenames.\n",
    "    ('vicuna-7b-v1.5', 'vicuna-7b-v.15'),\n",
    "    ('qwen', 'qwen2.5-instruct'),\n",
    "]\n",
    "model_path_lower = model_path.lower()\n",
    "model_name = next(\n",
    "    (name for pattern, name in MODEL_NAME_PATTERNS if pattern in model_path_lower),\n",
    "    'dolly',  # fallback when no pattern matches\n",
    ")\n",
    "\n",
    "print(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2d45eaf8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [01:32<00:00, 23.25s/it]\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): device_map, low_cpu_mem_usage and use_cache look like model-loading options;\n",
    "# they are forwarded to the tokenizer unchanged here - confirm the tokenizer actually needs them.\n",
    "tokenizer_kwargs = dict(\n",
    "    device_map=\"auto\",\n",
    "    trust_remote_code=True,\n",
    "    use_fast=False,\n",
    "    low_cpu_mem_usage=True,\n",
    "    use_cache=False,\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)\n",
    "\n",
    "# Load the weights in float16, move them to the GPU and switch to inference mode.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path,\n",
    "    torch_dtype=torch.float16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    use_cache=False,\n",
    ")\n",
    "model = model.to('cuda').eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6993da11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward hooks that record intermediate activations during generate().\n",
    "# Each hook appends the module's raw output to a notebook-level list that is looked up\n",
    "# by name at call time, so the extraction loop can re-bind the lists before every prompt.\n",
    "def mlp_hook(module, input, output):\n",
    "    \"\"\"Record the output of an MLP sub-module in `mlp_outputs`.\"\"\"\n",
    "    mlp_outputs.append(output)\n",
    "\n",
    "\n",
    "def attention_hook(module, input, output):\n",
    "    \"\"\"Record the output of a self-attention sub-module in `attention_outputs`.\"\"\"\n",
    "    attention_outputs.append(output)\n",
    "\n",
    "\n",
    "def layer_outputs_hook(module, input, output):\n",
    "    \"\"\"Record the output of a whole decoder layer in `layer_outputs_outputs`.\"\"\"\n",
    "    layer_outputs_outputs.append(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "962a1db5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach forward hooks to every decoder layer's MLP, self-attention and whole-layer output.\n",
    "# Handles are kept so that re-running this cell removes the previous hooks first; otherwise\n",
    "# each forward pass would be recorded several times and the positional indexing used by the\n",
    "# extraction cell (a fixed number of entries per forward pass) would be silently corrupted.\n",
    "for handle in globals().get('hook_handles', []):\n",
    "    handle.remove()\n",
    "hook_handles = []\n",
    "\n",
    "print(model.config.num_hidden_layers)\n",
    "for layer in model.model.layers[:model.config.num_hidden_layers]:\n",
    "    hook_handles.append(layer.mlp.register_forward_hook(mlp_hook))\n",
    "    hook_handles.append(layer.self_attn.register_forward_hook(attention_hook))\n",
    "    hook_handles.append(layer.register_forward_hook(layer_outputs_hook))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19212b25",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "data_path = ''\n",
    "attacked_file = os.path.join(data_path, 'rag_12000_attacked_probing.csv')\n",
    "unattacked_file = os.path.join(data_path, 'rag_12000_safe_probing.csv')\n",
    "N_QUESTIONS = 50      # rows read from each CSV\n",
    "MAX_NEW_TOKENS = 100  # tokens generated per prompt (fixes how many forward passes get hooked)\n",
    "NUM_TOKENS = 100      # trailing prompt positions averaged when probing the prompt pass\n",
    "NUM_LAYERS = model.config.num_hidden_layers  # hook entries appended per forward pass\n",
    "\n",
    "# 1-based position in the flat hook lists: (forward_pass - 1) * NUM_LAYERS + layer + 1.\n",
    "# For a 28-layer model, 2773 = 99 * 28 + 1 is layer 0 of the 100th forward pass.\n",
    "probing_layer_idx = 2773\n",
    "\n",
    "# Both classes are generated with identical, deterministic settings; otherwise a probe could\n",
    "# separate attacked from unattacked prompts by decoding strategy alone. Greedy decoding also\n",
    "# keeps batch item 0 of every hooked tensor aligned with the generated sequence.\n",
    "GENERATION_KWARGS = dict(\n",
    "    max_new_tokens=MAX_NEW_TOKENS,\n",
    "    min_new_tokens=MAX_NEW_TOKENS,\n",
    "    num_beams=1,\n",
    "    do_sample=False,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    ")\n",
    "\n",
    "\n",
    "def _load_questions(path, nrows=N_QUESTIONS):\n",
    "    \"\"\"Return the 'data' column of a probing CSV as a list of prompt strings.\n",
    "\n",
    "    Raises KeyError if the column is missing: iterating the raw DataFrame instead\n",
    "    would silently loop over its column names rather than its rows.\n",
    "    \"\"\"\n",
    "    frame = pd.read_csv(path, nrows=nrows)\n",
    "    if 'data' not in frame.columns:\n",
    "        raise KeyError(f\"{path} has no 'data' column; found {list(frame.columns)}\")\n",
    "    return frame['data'].tolist()\n",
    "\n",
    "\n",
    "def _mean_last_tokens(hidden, num_tokens=NUM_TOKENS):\n",
    "    \"\"\"Average the last `num_tokens` rows of an activation and copy the result to the CPU.\n",
    "\n",
    "    `hidden` is assumed to be (seq, hidden) - TODO confirm for every hooked module.\n",
    "    Raises ValueError when fewer than `num_tokens` rows are available.\n",
    "    \"\"\"\n",
    "    if hidden.size(0) < num_tokens:\n",
    "        raise ValueError(f\"Layer {probing_layer_idx} has fewer than {num_tokens} tokens.\")\n",
    "    return hidden[-num_tokens:].mean(dim=0).clone().cpu()\n",
    "\n",
    "\n",
    "def _extract_features(questions, mlp_store, attention_store, layer_store):\n",
    "    \"\"\"Generate for every question and append the probed activations to the three stores.\n",
    "\n",
    "    The forward hooks append to the notebook-level lists `mlp_outputs`,\n",
    "    `attention_outputs` and `layer_outputs_outputs`; they are re-bound to fresh\n",
    "    lists before each prompt and emptied afterwards to release the captured tensors.\n",
    "    \"\"\"\n",
    "    global mlp_outputs, attention_outputs, layer_outputs_outputs\n",
    "    for question in questions:\n",
    "        mlp_outputs, attention_outputs, layer_outputs_outputs = [], [], []\n",
    "        inputs = tokenizer(question, return_tensors=\"pt\")\n",
    "        model.generate(\n",
    "            input_ids=inputs.input_ids.cuda(),\n",
    "            attention_mask=inputs.attention_mask.cuda(),\n",
    "            **GENERATION_KWARGS,\n",
    "        )\n",
    "        if probing_layer_idx > NUM_LAYERS or probing_layer_idx <= 0:\n",
    "            # Probe every layer of one forward pass, keeping the last position of batch item 0\n",
    "            # (assumes hooked tensors are (batch, seq, hidden) - TODO confirm).\n",
    "            for pos in range(probing_layer_idx - 1, probing_layer_idx - 1 + NUM_LAYERS):\n",
    "                mlp_store.append(mlp_outputs[pos][0][-1].clone().cpu())\n",
    "                attention_store.append(attention_outputs[pos][0][0][-1].clone().cpu())\n",
    "                layer_store.append(layer_outputs_outputs[pos][0][0][-1].clone().cpu())\n",
    "        else:\n",
    "            # Probe one layer of the prompt forward pass, averaged over its last NUM_TOKENS positions.\n",
    "            pos = probing_layer_idx - 1\n",
    "            mlp_store.append(_mean_last_tokens(mlp_outputs[pos][0]))\n",
    "            attention_store.append(_mean_last_tokens(attention_outputs[pos][0][0]))\n",
    "            layer_store.append(_mean_last_tokens(layer_outputs_outputs[pos][0][0]))\n",
    "        mlp_outputs, attention_outputs, layer_outputs_outputs = [], [], []\n",
    "        gc.collect()\n",
    "\n",
    "\n",
    "attacked_questions = _load_questions(attacked_file)\n",
    "unattacked_questions = _load_questions(unattacked_file)\n",
    "\n",
    "attacked_mlp, attacked_attention, attacked_layer_outputs = [], [], []\n",
    "unattacked_mlp, unattacked_attention, unattacked_layer_outputs = [], [], []\n",
    "\n",
    "with torch.no_grad():\n",
    "    _extract_features(attacked_questions, attacked_mlp, attacked_attention, attacked_layer_outputs)\n",
    "    _extract_features(unattacked_questions, unattacked_mlp, unattacked_attention, unattacked_layer_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ce04467",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "\n",
    "def compute_vi_accuracy_f1(all_vectors, labels, probing_layer_name, num_labels=2, eps=0.01):\n",
    "    \"\"\"Fit a logistic-regression probe and report accuracy, weighted F1 and V-usable information.\n",
    "\n",
    "    Vi is estimated as H(Y | null) - H(Y | X). The null model is the same probe trained on\n",
    "    all-zero inputs, so it can only learn the class prior. Both terms are the mean negative\n",
    "    log2-probability of the true label on a stratified 40% held-out split (random_state=42).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    all_vectors : array-like of shape (n_samples, n_features)\n",
    "    labels : sequence of class labels, one per row of `all_vectors`\n",
    "    probing_layer_name : str, used only in the printed summary\n",
    "    num_labels : int, unused; kept for backward compatibility\n",
    "    eps : float, added to every probability before taking log2 (default 0.01)\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (accuracy, weighted_f1, Vi)\n",
    "    \"\"\"\n",
    "    X_train, X_test, y_train, y_test = train_test_split(\n",
    "        all_vectors, labels, test_size=0.4, random_state=42, stratify=labels\n",
    "    )\n",
    "    y_test = np.asarray(y_test)\n",
    "\n",
    "    def _cross_entropy_bits(clf, X):\n",
    "        # Map each true label to its predict_proba column via clf.classes_; using the\n",
    "        # label value itself as the column index only works for labels 0..K-1.\n",
    "        proba = clf.predict_proba(X)\n",
    "        cols = np.searchsorted(clf.classes_, y_test)\n",
    "        return -np.mean(np.log2(proba[np.arange(len(y_test)), cols] + eps))\n",
    "\n",
    "    clf_null = LogisticRegression(max_iter=1000)\n",
    "    clf_null.fit(np.zeros_like(X_train), y_train)\n",
    "    H_yb = _cross_entropy_bits(clf_null, np.zeros_like(X_test))\n",
    "\n",
    "    clf = LogisticRegression(max_iter=1000)\n",
    "    clf.fit(X_train, y_train)\n",
    "    H_yx = _cross_entropy_bits(clf, X_test)\n",
    "\n",
    "    Vi = H_yb - H_yx\n",
    "\n",
    "    y_pred = clf.predict(X_test)\n",
    "    acc = accuracy_score(y_test, y_pred)\n",
    "    f1 = f1_score(y_test, y_pred, average='weighted')\n",
    "\n",
    "    print(f\"{probing_layer_name} - Accuracy: {acc:.4f}, F1 Score: {f1:.4f}, V-Usable Information (Vi): {Vi:.4f}\")\n",
    "\n",
    "    return acc, f1, Vi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b356fe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from sklearn.feature_selection import mutual_info_classif\n",
    "\n",
    "\n",
    "def _probe_component(attacked, unattacked, offset, group_size, name):\n",
    "    \"\"\"Probe one component (MLP / attention / layer output) at one stored layer offset.\n",
    "\n",
    "    The extraction cell stores `group_size` vectors per question, so entries\n",
    "    offset, offset + group_size, ... come from the same layer. Unattacked vectors\n",
    "    are labelled 0 and attacked vectors 1.\n",
    "\n",
    "    Returns (Vi, attacked_np, unattacked_np); the last two are lists of numpy arrays.\n",
    "    \"\"\"\n",
    "    attacked_np = [tensor.cpu().numpy() for tensor in attacked[offset::group_size]]\n",
    "    unattacked_np = [tensor.cpu().numpy() for tensor in unattacked[offset::group_size]]\n",
    "    labels = [0] * len(unattacked_np) + [1] * len(attacked_np)\n",
    "    vectors = np.concatenate((unattacked_np, attacked_np), axis=0)\n",
    "    _, _, vi = compute_vi_accuracy_f1(vectors, labels, probing_layer_name=name, num_labels=2)\n",
    "    return vi, attacked_np, unattacked_np\n",
    "\n",
    "\n",
    "# Vectors stored per question: one per layer when probing a generation step,\n",
    "# a single one when probing an averaged prompt layer.\n",
    "group_size = len(attacked_mlp) // len(attacked_questions)\n",
    "assert group_size * len(attacked_questions) == len(attacked_mlp), 'feature count is not a multiple of the question count'\n",
    "\n",
    "Vi_mlps = []\n",
    "Vi_attentions = []\n",
    "Vi_layer_outputs = []\n",
    "for i in range(group_size):\n",
    "    # The *_cpu names are kept at notebook level (last chunk) for later cells.\n",
    "    Vi_mlp, attacked_mlp_cpu, unattacked_mlp_cpu = _probe_component(\n",
    "        attacked_mlp, unattacked_mlp, i, group_size, \"MLP Outputs\")\n",
    "    Vi_mlps.append(Vi_mlp)\n",
    "\n",
    "    Vi_attention, attacked_attention_cpu, unattacked_attention_cpu = _probe_component(\n",
    "        attacked_attention, unattacked_attention, i, group_size, \"Attention Outputs\")\n",
    "    Vi_attentions.append(Vi_attention)\n",
    "\n",
    "    Vi_layer, attacked_layer_outputs_cpu, unattacked_layer_outputs_cpu = _probe_component(\n",
    "        attacked_layer_outputs, unattacked_layer_outputs, i, group_size, \"Layer Outputs\")\n",
    "    Vi_layer_outputs.append(Vi_layer)\n",
    "\n",
    "print(\"Vi for MLPs:\", Vi_mlps)\n",
    "print(\"Vi for attentions:\", Vi_attentions)\n",
    "print(\"Vi for layer outputs:\", Vi_layer_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "784cdd89",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Per-layer Vi curves to compare. NOTE: these lists are not computed in this notebook run;\n",
    "# they must be defined beforehand (e.g. from `Vi_layer_outputs` of separate per-model runs).\n",
    "curves = [\n",
    "    (Vi_layer_outputs_qwen, 'Qwen2.5-7B-Instruct', 'red', 'o'),\n",
    "    (Vi_layer_outputs_llama3, 'LlaMA3-8B-Instruct', 'blue', '^'),\n",
    "    (Vi_layer_outputs_vicuna, 'Vicuna-7B-v1.5', 'green', 's'),\n",
    "    (llama3_to_qwen, 'LlaMA3-transfer-to-Qwen', 'magenta', 'D'),\n",
    "]\n",
    "\n",
    "sns.set(style=\"whitegrid\", context=\"talk\")\n",
    "plt.rcParams['axes.facecolor'] = '#f6f6f6'\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "for values, label, color, marker in curves:\n",
    "    # x positions come from each curve's own length, so models with different depths\n",
    "    # (e.g. 28 vs 32 layers) cannot cause an x/y length mismatch.\n",
    "    ax.plot(np.arange(len(values)), values,\n",
    "            label=label, color=color, marker=marker, linewidth=2, markersize=5)\n",
    "\n",
    "ax.set_xlabel(\"Layer Number\", fontsize=22)\n",
    "ax.set_ylabel(\"Vi\", fontsize=22)\n",
    "ax.tick_params(axis='both', which='major', labelsize=14)\n",
    "ax.legend(fontsize=18, frameon=False)\n",
    "ax.grid(True, linestyle='--', linewidth=1.0, alpha=1.0)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"vi_comparison_by_layer.png\", bbox_inches='tight', pad_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9318fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "sns.set(style=\"whitegrid\", context=\"talk\")\n",
    "plt.rcParams['axes.facecolor'] = '#f6f6f6'\n",
    "\n",
    "# Decode the 1-based flat hook position into (forward pass / token, layer) for titles and filenames.\n",
    "probe_token, probe_layer = divmod(probing_layer_idx - 1, model.config.num_hidden_layers)\n",
    "probe_token += 1\n",
    "# Vectors stored per question by the extraction cell. The first vector of each group comes from\n",
    "# hook position probing_layer_idx - 1 (i.e. layer `probe_layer`), so a stride of `group_size`\n",
    "# selects that same layer for every question, matching the plot titles.\n",
    "group_size = len(attacked_mlp) // len(attacked_questions)\n",
    "\n",
    "\n",
    "def _plot_tsne(attacked, unattacked, title_prefix, file_suffix):\n",
    "    \"\"\"PCA -> t-SNE scatter of attacked (red) vs unattacked (blue) vectors at `probe_layer`.\n",
    "\n",
    "    Saves the figure under ../results/qwen/ and returns it.\n",
    "    \"\"\"\n",
    "    attacked_np = [tensor.cpu().float().numpy() for tensor in attacked[::group_size]]\n",
    "    unattacked_np = [tensor.cpu().float().numpy() for tensor in unattacked[::group_size]]\n",
    "    all_vectors = np.concatenate((attacked_np, unattacked_np), axis=0)\n",
    "\n",
    "    # PCA cannot keep more components than there are samples or features.\n",
    "    pca_result = PCA(n_components=min(50, *all_vectors.shape)).fit_transform(all_vectors)\n",
    "    # Fixed random_state so the embedding (and the saved figure) is reproducible.\n",
    "    tsne_result = TSNE(n_components=2, perplexity=28.0, random_state=42).fit_transform(pca_result)\n",
    "\n",
    "    df = pd.DataFrame(data=tsne_result, columns=['Component 1', 'Component 2'])\n",
    "    df['Category'] = ['attacked'] * len(attacked_np) + ['unattacked'] * len(unattacked_np)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    sns.scatterplot(data=df, x='Component 1', y='Component 2', hue='Category',\n",
    "                    palette=['red', 'blue'], s=90, alpha=0.8, legend=None, ax=ax)\n",
    "    ax.set_title(f'{title_prefix}, Token {probe_token}, Layer {probe_layer}', fontsize=28, weight='medium')\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "    ax.tick_params(axis='both', which='major', labelsize=24, pad=0.5)\n",
    "    ax.grid(True, linestyle='--', linewidth=1.0, alpha=1.0)\n",
    "    fig.tight_layout()\n",
    "\n",
    "    out_path = '../results/qwen/%s_token_%s_layer_%s_%s.png' % (model_name, probe_token, probe_layer, file_suffix)\n",
    "    os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
    "    fig.savefig(out_path, bbox_inches='tight', pad_inches=0)\n",
    "    return fig\n",
    "\n",
    "\n",
    "_plot_tsne(attacked_mlp, unattacked_mlp, 'MLP', 'mlp')\n",
    "# NOTE(review): 'Our' is kept from the original title; it may have been meant as 'Attention'.\n",
    "_plot_tsne(attacked_attention, unattacked_attention, 'Our', 'attention')\n",
    "_plot_tsne(attacked_layer_outputs, unattacked_layer_outputs, 'Layer', 'layer_outputs');"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
