{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9a6e67e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "fa34cdeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean_std(data): \n",
    "    mean = np.mean(data, axis=0)\n",
    "    std = np.std(data, axis=0)\n",
    "    return mean, std\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "821a9003",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _summarize_group(grads, delta_hiddens):\n",
    "    \"\"\"Mean/std (over samples, axis 0) of grads, delta_hiddens and their product.\n",
    "\n",
    "    Returns a dict with the keys \"grads_mean\", \"grads_std\", \"delta_mean\",\n",
    "    \"delta_std\", \"grads_x_delta_mean\" and \"grads_x_delta_std\".\n",
    "    \"\"\"\n",
    "    grads_mean, grads_std = get_mean_std(grads)\n",
    "    delta_mean, delta_std = get_mean_std(delta_hiddens)\n",
    "    grads_x_delta_mean, grads_x_delta_std = get_mean_std(grads * delta_hiddens)\n",
    "    return {\n",
    "        \"grads_mean\": grads_mean,\n",
    "        \"grads_std\": grads_std,\n",
    "        \"delta_mean\": delta_mean,\n",
    "        \"delta_std\": delta_std,\n",
    "        \"grads_x_delta_mean\": grads_x_delta_mean,\n",
    "        \"grads_x_delta_std\": grads_x_delta_std,\n",
    "    }\n",
    "\n",
    "\n",
    "def get_data_list(file_path):\n",
    "    \"\"\"Load a results JSONL file and summarize it per length relation.\n",
    "\n",
    "    Every non-blank line must be a JSON object with the keys \"len0\", \"len1\",\n",
    "    \"grads\" and \"delta_hiddens\"; \"grads\" and \"delta_hiddens\" must have the same\n",
    "    shape in every record (they are stacked with ``np.stack``). len0/len1 are\n",
    "    compared as scalars and split the samples into three groups:\n",
    "\n",
    "    * pad:   len0 > len1\n",
    "    * trim:  len0 < len1\n",
    "    * equal: all remaining samples (len0 == len1)\n",
    "\n",
    "    Each group is truncated to the size of the smallest group, keeping file\n",
    "    order, so all statistics are computed over equally many samples.\n",
    "\n",
    "    Returns:\n",
    "        (pad_dict, trim_dict, equal_dict). Each dict maps \"grads_mean\",\n",
    "        \"grads_std\", \"delta_mean\", \"delta_std\", \"grads_x_delta_mean\" and\n",
    "        \"grads_x_delta_std\" to statistics taken over the samples (axis 0).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the file contains no records, or if any group is empty\n",
    "            (the balanced sample size would be 0 and every statistic NaN).\n",
    "    \"\"\"\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        # Skip blank lines instead of failing inside json.loads.\n",
    "        results = [json.loads(line) for line in f if line.strip()]\n",
    "    if not results:\n",
    "        raise ValueError(f\"No records found in {file_path}\")\n",
    "\n",
    "    all_len0 = np.array([r[\"len0\"] for r in results])\n",
    "    all_len1 = np.array([r[\"len1\"] for r in results])\n",
    "    all_grads = np.stack([r[\"grads\"] for r in results], axis=0)\n",
    "    all_delta_hiddens = np.stack([r[\"delta_hiddens\"] for r in results], axis=0)\n",
    "\n",
    "    pad_mask = all_len0 > all_len1\n",
    "    trim_mask = all_len0 < all_len1\n",
    "    # Complement of the other two groups, i.e. the if/elif/else fall-through.\n",
    "    equal_mask = ~(pad_mask | trim_mask)\n",
    "\n",
    "    counts = {\n",
    "        \"pad\": int(pad_mask.sum()),\n",
    "        \"trim\": int(trim_mask.sum()),\n",
    "        \"equal\": int(equal_mask.sum()),\n",
    "    }\n",
    "    min_len = min(counts.values())\n",
    "    if min_len == 0:\n",
    "        raise ValueError(\n",
    "            f\"Cannot balance length groups in {file_path}: sample counts {counts}\"\n",
    "        )\n",
    "\n",
    "    # Boolean indexing preserves file order, so [:min_len] keeps the first\n",
    "    # min_len samples of each group.\n",
    "    pad_dict = _summarize_group(all_grads[pad_mask][:min_len], all_delta_hiddens[pad_mask][:min_len])\n",
    "    trim_dict = _summarize_group(all_grads[trim_mask][:min_len], all_delta_hiddens[trim_mask][:min_len])\n",
    "    equal_dict = _summarize_group(all_grads[equal_mask][:min_len], all_delta_hiddens[equal_mask][:min_len])\n",
    "\n",
    "    return pad_dict, trim_dict, equal_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a873816",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grads_vs_embedding(dataset, model_name_or_path, show_layer_num=5, results_dir=\"../../results\"):\n",
    "    \"\"\"Plot \"delta_mean\" of the last layers for the pad / trim / equal groups.\n",
    "\n",
    "    Reads {results_dir}/data_results/real_dataset/{model_name_or_path}/{dataset}_result.jsonl\n",
    "    through get_data_list and writes the figure to\n",
    "    {results_dir}/figure_results/pad_vs_trim_vs_equal/{model_name_or_path}/{dataset}_pad_vs_trim_vs_equal.pdf,\n",
    "    creating the output directory if needed. The figure is closed, not displayed.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset name used in the input and output file names.\n",
    "        model_name_or_path: model id used as a sub-directory (may contain \"/\").\n",
    "        show_layer_num: number of trailing layers to plot (default 5).\n",
    "        results_dir: root of the results tree (default \"../../results\").\n",
    "    \"\"\"\n",
    "    file_path = f\"{results_dir}/data_results/real_dataset/{model_name_or_path}/{dataset}_result.jsonl\"\n",
    "    pad_dict, trim_dict, equal_dict = get_data_list(file_path)\n",
    "\n",
    "    # Derive tick labels from the series that is actually plotted (delta_mean),\n",
    "    # not from grads_mean, so labels stay correct if the two differ in length.\n",
    "    layer_ids = list(range(len(pad_dict[\"delta_mean\"])))[-show_layer_num:]\n",
    "    x = np.arange(len(layer_ids))\n",
    "\n",
    "    series = [\n",
    "        (\"Pad\", pad_dict, \"#0D4C6D\"),\n",
    "        (\"Trim\", trim_dict, \"#FEB705\"),\n",
    "        (\"Equal\", equal_dict, \"#BF1E2E\"),\n",
    "    ]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(2.5, 2.5))\n",
    "    try:\n",
    "        for label, stats, color in series:\n",
    "            ax.plot(x, stats[\"delta_mean\"][-show_layer_num:], label=label, color=color)\n",
    "        ax.set_xticks(x)\n",
    "        ax.set_xticklabels([str(i) for i in layer_ids])\n",
    "        # The ticks are layer indices, not a count of layers.\n",
    "        ax.set_xlabel(\"Layer index\")\n",
    "        ax.set_title(\"Pad vs. Trim vs. Equal\")\n",
    "        ax.legend()\n",
    "        ax.grid(True)\n",
    "        fig.tight_layout()\n",
    "\n",
    "        save_path = (\n",
    "            f\"{results_dir}/figure_results/pad_vs_trim_vs_equal/\"\n",
    "            f\"{model_name_or_path}/{dataset}_pad_vs_trim_vs_equal.pdf\"\n",
    "        )\n",
    "        os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "        fig.savefig(save_path)\n",
    "    finally:\n",
    "        # Always release the figure so repeated calls do not accumulate open figures.\n",
    "        plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2089cbfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_list = [\"ARC_Challenge\", \"CommonSenseQA\", \"MMLU\", \"OpenBookQA\"\n",
    "                ]\n",
    "model_list = [\n",
    "                \"meta-llama/Llama-3.2-1B\",\n",
    "                \"meta-llama/Llama-3.2-3B\",\n",
    "                \"Qwen/Qwen1.5-4B\",\n",
    "                \"Qwen/Qwen1.5-0.5B\",\n",
    "                \"Qwen/Qwen1.5-1.8B\",\n",
    "                ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "2205e583",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per (dataset, model) pair; datasets form the outer loop.\n",
    "for dataset, model_name_or_path in [(d, m) for d in dataset_list for m in model_list]:\n",
    "    plot_grads_vs_embedding(dataset, model_name_or_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
