{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9a6e67e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from matplotlib.ticker import FuncFormatter\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "46e45582",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean_std(data):\n",
    "    \"\"\"Return the mean and standard deviation of ``data`` along axis 0.\n",
    "\n",
    "    Args:\n",
    "        data: Array-like of stacked samples; statistics are taken across\n",
    "            the first axis, i.e. one entry per sample.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(mean, std)`` as computed by ``np.mean`` and ``np.std`` with\n",
    "        ``axis=0`` (population std, the NumPy default ``ddof=0``).\n",
    "    \"\"\"\n",
    "    return np.mean(data, axis=0), np.std(data, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "821a9003",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data_list(file_path):\n",
    "    \"\"\"Load a JSONL result file and summarise it per prompt-variant group.\n",
    "\n",
    "    Each non-blank line of ``file_path`` is parsed as a JSON record and its\n",
    "    ``\"prompt_key0\"``, ``\"prompt_key1\"``, ``\"grads\"`` and ``\"delta_hiddens\"``\n",
    "    entries are read. A record is kept only when its first prompt is a base\n",
    "    prompt (``prompt_1``-``prompt_3``) and its second prompt belongs to one of\n",
    "    the variant groups defined below; every other record is ignored.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path to a ``*_result.jsonl`` file.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(front_dict, back_dict, less_dict, more_dict)``. Each dict has\n",
    "        the keys ``grads_mean``/``grads_std``, ``delta_mean``/``delta_std`` and\n",
    "        ``grads_x_delta_mean``/``grads_x_delta_std``: the mean and std over the\n",
    "        kept records (axis 0) of ``grads``, ``delta_hiddens`` and their\n",
    "        element-wise product.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a variant group has no matching record. Without this\n",
    "            check its statistics would be NaN scalars, which later break\n",
    "            ``len()``/slicing in the plotting code with an opaque error.\n",
    "    \"\"\"\n",
    "    group_base = [\"prompt_1\", \"prompt_2\", \"prompt_3\"]\n",
    "    # Insertion order fixes the order of the returned dicts (front, back, less, more).\n",
    "    variant_groups = {\n",
    "        \"front\": [\"prompt_4\", \"prompt_5\", \"prompt_6\"],\n",
    "        \"back\": [\"prompt_7\", \"prompt_8\", \"prompt_9\"],\n",
    "        \"less\": [\"prompt_10\", \"prompt_11\", \"prompt_12\"],\n",
    "        \"more\": [\"prompt_13\", \"prompt_14\", \"prompt_15\"],\n",
    "    }\n",
    "    grads_by_group = {name: [] for name in variant_groups}\n",
    "    deltas_by_group = {name: [] for name in variant_groups}\n",
    "\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                # json.loads raises on an empty string, so skip blank lines.\n",
    "                continue\n",
    "            record = json.loads(line)\n",
    "            # Base and variant prompt sets are disjoint, so this membership\n",
    "            # test also rules out identical prompt pairs.\n",
    "            if record[\"prompt_key0\"] not in group_base:\n",
    "                continue\n",
    "            for name, members in variant_groups.items():\n",
    "                if record[\"prompt_key1\"] in members:\n",
    "                    grads_by_group[name].append(record[\"grads\"])\n",
    "                    deltas_by_group[name].append(record[\"delta_hiddens\"])\n",
    "                    break\n",
    "\n",
    "    summaries = []\n",
    "    for name in variant_groups:\n",
    "        if not grads_by_group[name]:\n",
    "            raise ValueError(\n",
    "                f\"No records pair a base prompt with a '{name}' prompt in {file_path}\"\n",
    "            )\n",
    "        grads = np.array(grads_by_group[name])\n",
    "        deltas = np.array(deltas_by_group[name])\n",
    "        grads_mean, grads_std = get_mean_std(grads)\n",
    "        delta_mean, delta_std = get_mean_std(deltas)\n",
    "        grads_x_delta_mean, grads_x_delta_std = get_mean_std(grads * deltas)\n",
    "        summaries.append({\n",
    "            \"grads_mean\": grads_mean,\n",
    "            \"grads_std\": grads_std,\n",
    "            \"delta_mean\": delta_mean,\n",
    "            \"delta_std\": delta_std,\n",
    "            \"grads_x_delta_mean\": grads_x_delta_mean,\n",
    "            \"grads_x_delta_std\": grads_x_delta_std,\n",
    "        })\n",
    "\n",
    "    front_dict, back_dict, less_dict, more_dict = summaries\n",
    "    return front_dict, back_dict, less_dict, more_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84b15b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a873816",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_two_series(x, xticks, series, title, save_path, y_formatter=None):\n",
    "    \"\"\"Plot labelled line series on one small square figure and save it.\n",
    "\n",
    "    Args:\n",
    "        x: Shared x positions for every series.\n",
    "        xticks: Tick labels placed at the positions in ``x``.\n",
    "        series: Iterable of ``(values, label, color)`` tuples, drawn in order.\n",
    "        title: Figure title.\n",
    "        save_path: Output file; missing parent directories are created.\n",
    "        y_formatter: Optional matplotlib tick formatter for the y axis.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(2.5, 2.5))\n",
    "    for values, label, color in series:\n",
    "        ax.plot(x, values, label=label, color=color)\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(xticks)\n",
    "    ax.set_xlabel(\"Number of layers\")\n",
    "    ax.set_title(title)\n",
    "    ax.legend()\n",
    "    ax.grid(True)\n",
    "    if y_formatter is not None:\n",
    "        ax.yaxis.set_major_formatter(y_formatter)\n",
    "    fig.tight_layout()\n",
    "\n",
    "    os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "    fig.savefig(save_path)\n",
    "    # Close this specific figure so looping over many dataset/model pairs\n",
    "    # does not accumulate open figures.\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def plot_grads_vs_embedding(dataset, model_name_or_path, show_layer_num=5):\n",
    "    \"\"\"Plot per-layer hidden-state deltas for one dataset/model pair.\n",
    "\n",
    "    Loads ``{dataset}_result.jsonl`` for ``model_name_or_path`` via\n",
    "    ``get_data_list`` and saves two PDFs showing the last ``show_layer_num``\n",
    "    entries of each group's ``delta_mean``:\n",
    "\n",
    "    * ``first_vs_latter/{model}/{dataset}_first_vs_latter.pdf``: front vs. back.\n",
    "    * ``fewer_vs_more/{model}/{dataset}_fewer_vs_more.pdf``: more vs. less.\n",
    "\n",
    "    Args:\n",
    "        dataset: Dataset name used in the input and output file names.\n",
    "        model_name_or_path: Model identifier used as a sub-directory.\n",
    "        show_layer_num: Number of trailing layers to plot (default 5).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``show_layer_num`` is below 1 (a ``[-0:]`` slice would\n",
    "            silently select every layer).\n",
    "    \"\"\"\n",
    "    if show_layer_num < 1:\n",
    "        raise ValueError(f\"show_layer_num must be >= 1, got {show_layer_num}\")\n",
    "\n",
    "    file_path = f\"../../results/data_results/real_dataset_misalignment/{model_name_or_path}/{dataset}_result.jsonl\"\n",
    "    front_dict, back_dict, less_dict, more_dict = get_data_list(file_path)\n",
    "\n",
    "    # Only the hidden-state deltas are plotted; the grads and grads_x_delta\n",
    "    # statistics returned by get_data_list are not used here.\n",
    "    layer_name = [str(i) for i in range(len(front_dict[\"delta_mean\"]))]\n",
    "    xticks = layer_name[-show_layer_num:]\n",
    "    x = np.arange(len(xticks))\n",
    "\n",
    "    def last_layers(stats):\n",
    "        return stats[\"delta_mean\"][-show_layer_num:]\n",
    "\n",
    "    front_color = \"#0D4C6D\"\n",
    "    latter_color = \"#BF1E2E\"\n",
    "\n",
    "    _plot_two_series(\n",
    "        x,\n",
    "        xticks,\n",
    "        [\n",
    "            (last_layers(front_dict), \"First\", front_color),\n",
    "            (last_layers(back_dict), \"Latter\", latter_color),\n",
    "        ],\n",
    "        title=\"First vs. Latter\",\n",
    "        save_path=f\"../../results/figure_results/first_vs_latter/{model_name_or_path}/{dataset}_first_vs_latter.pdf\",\n",
    "    )\n",
    "\n",
    "    _plot_two_series(\n",
    "        x,\n",
    "        xticks,\n",
    "        [\n",
    "            (last_layers(more_dict), \"More\", front_color),\n",
    "            (last_layers(less_dict), \"Fewer\", latter_color),\n",
    "        ],\n",
    "        title=\"Fewer vs. More\",\n",
    "        save_path=f\"../../results/figure_results/fewer_vs_more/{model_name_or_path}/{dataset}_fewer_vs_more.pdf\",\n",
    "        # NOTE(review): int() truncates fractional ticks (0.5 -> \"0\") and ticks\n",
    "        # >= 1000 are blanked; kept from the original, revisit for small deltas.\n",
    "        y_formatter=FuncFormatter(lambda val, pos: f\"{int(val)}\" if val < 1000 else \"\"),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2089cbfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets and models whose result files are plotted below.\n",
    "dataset_list = \"ARC_Challenge CommonSenseQA MMLU OpenBookQA\".split()\n",
    "# Grouped by model family; order is Llama 1B, 3B, then Qwen 4B, 0.5B, 1.8B.\n",
    "model_list = (\n",
    "    [f\"meta-llama/Llama-3.2-{size}\" for size in (\"1B\", \"3B\")]\n",
    "    + [f\"Qwen/Qwen1.5-{size}\" for size in (\"4B\", \"0.5B\", \"1.8B\")]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2205e583",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render both comparison figures for every (dataset, model) combination,\n",
    "# datasets in the outer position and models in the inner one.\n",
    "for dataset, model_name_or_path in [(d, m) for d in dataset_list for m in model_list]:\n",
    "    plot_grads_vs_embedding(dataset, model_name_or_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
