{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3962e067-b25d-4b88-ae4a-c55e124040f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd \n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM\n",
    "import re\n",
    "from collections import defaultdict\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict\n",
    "from huggingface_hub import login\n",
    "# Authenticate with the Hugging Face Hub; access to the models must be requested\n",
    "# on HuggingFace first. The token is read from the HF_TOKEN environment variable\n",
    "# rather than pasted into the notebook, so it cannot leak when the file is shared.\n",
    "# If the variable is unset or empty, login() falls back to its interactive prompt.\n",
    "import os\n",
    "login(token=os.environ.get(\"HF_TOKEN\") or None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "faabb613-964d-4532-b542-29e38cc01b14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Compare sign differences ----\n",
    "def compare_signs(model_a, model_b):\n",
    "    \"\"\"\n",
    "    Count element-wise sign changes between the tensors of two models.\n",
    "\n",
    "    Tensors are paired by state-dict key (not by position), so the comparison stays\n",
    "    correct when the two state dicts list their entries in a different order.\n",
    "    Entries missing from model_b, or whose shapes differ, are reported and skipped.\n",
    "\n",
    "    Args:\n",
    "        model_a: reference model exposing state_dict(); its keys name the results.\n",
    "        model_b: model compared against model_a.\n",
    "\n",
    "    Returns:\n",
    "        stats: dict mapping tensor name -> {\"total\", \"flipped\", \"flipped_pct\"}.\n",
    "        overall: dict with \"total_params\", \"total_flipped\" and \"flipped_pct\" over all\n",
    "            compared tensors (percentages are 0.0 when nothing was compared).\n",
    "    \"\"\"\n",
    "    stats = {}\n",
    "    total_params = 0\n",
    "    flipped = 0\n",
    "\n",
    "    state_b = model_b.state_dict()\n",
    "    for name_a, param_a in model_a.state_dict().items():\n",
    "        if name_a not in state_b:\n",
    "            print(f\"Skipping {name_a}: not present in second model\")\n",
    "            continue\n",
    "        param_b = state_b[name_a]\n",
    "        if param_a.shape != param_b.shape:\n",
    "            print(f\"Skipping {name_a} due to shape mismatch\")\n",
    "            continue\n",
    "        # torch.sign maps each element to -1, 0 or +1; the comparison runs on CPU so\n",
    "        # tensors held on different devices can still be matched.\n",
    "        sign_a = torch.sign(param_a).cpu()\n",
    "        sign_b = torch.sign(param_b).cpu()\n",
    "        diff = (sign_a != sign_b).sum().item()\n",
    "        count = param_a.numel()\n",
    "        stats[name_a] = {\n",
    "            \"total\": count,\n",
    "            \"flipped\": diff,\n",
    "            # Zero-element tensors would otherwise divide by zero.\n",
    "            \"flipped_pct\": 100 * diff / count if count > 0 else 0.0\n",
    "        }\n",
    "\n",
    "        total_params += count\n",
    "        flipped += diff\n",
    "\n",
    "    overall = {\n",
    "        \"total_params\": total_params,\n",
    "        \"total_flipped\": flipped,\n",
    "        \"flipped_pct\": 100 * flipped / total_params if total_params > 0 else 0.0\n",
    "    }\n",
    "    return stats, overall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "26709d4d-8386-4513-b496-c975669e6110",
   "metadata": {},
   "outputs": [],
   "source": [
    "def safe_norm(tensor, p=1, clip_value=None):\n",
    "    \"\"\"\n",
    "    Compute the Lp norm of a tensor while ignoring non-finite entries.\n",
    "\n",
    "    NaN and +/-inf elements are treated as 0, and the remaining values are optionally\n",
    "    clamped to [-clip_value, clip_value] before the norm is taken over all elements.\n",
    "    The input tensor is never modified.\n",
    "\n",
    "    Args:\n",
    "        tensor: tensor to measure.\n",
    "        p: order of the norm (1 -> sum of absolute values, 2 -> Euclidean).\n",
    "        clip_value: optional magnitude bound applied element-wise before the norm.\n",
    "\n",
    "    Returns:\n",
    "        The norm as a Python float, accumulated in float32.\n",
    "    \"\"\"\n",
    "    t = tensor.detach().cpu().float()  # accumulate in float32\n",
    "    # nan_to_num returns a new tensor. Masked in-place assignment would write through\n",
    "    # to the caller's data whenever the input is already a float32 CPU tensor, because\n",
    "    # detach()/cpu()/float() then share its storage instead of copying.\n",
    "    t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)\n",
    "    if clip_value is not None:\n",
    "        t = torch.clamp(t, -clip_value, clip_value)\n",
    "    return torch.norm(t, p=p).item()\n",
    "\n",
    "def compare_qk_norm_diff_safe(model_pre, model_ft, clip_value=None):\n",
    "    \"\"\"\n",
    "    Measure how fine-tuning changes the Q/K norm imbalance of each attention block.\n",
    "\n",
    "    For every Q/K tensor pair the imbalance is computed for both models as\n",
    "        L1: | ||Q||_1   - ||K||_1   |\n",
    "        L2: | ||Q||_2^2 - ||K||_2^2 |   (squared L2 norms)\n",
    "    and delta = imbalance(fine-tuned) - imbalance(pretrained). A delta <= 0 means\n",
    "    the Q and K norms are at least as balanced after fine-tuning as before.\n",
    "\n",
    "    Args:\n",
    "        model_pre: pretrained model exposing state_dict().\n",
    "        model_ft: fine-tuned model exposing state_dict().\n",
    "        clip_value: optional magnitude bound forwarded to safe_norm.\n",
    "\n",
    "    Returns:\n",
    "        stats: dict keyed by the name prefix shared by a Q/K pair, holding\n",
    "            diff_pre_*, diff_ft_* and delta_* for the L1 and L2 measures.\n",
    "        overall: summed and averaged deltas plus \"percentageposl1\"/\"percentageposl2\",\n",
    "            the fraction (0.0-1.0) of pairs whose delta is <= 0.\n",
    "    \"\"\"\n",
    "    stats = {}\n",
    "    total_layers = 0\n",
    "    total_l1_delta = 0.0\n",
    "    total_l2_delta = 0.0\n",
    "\n",
    "    total_pos_l1 = 0.0\n",
    "    total_pos_l2 = 0.0\n",
    "\n",
    "    state_pre = model_pre.state_dict()\n",
    "    state_ft  = model_ft.state_dict()\n",
    "\n",
    "    # Find Q/K pairs by matching the base name before \".q\" or \".k\"\n",
    "    qk_layers = {}\n",
    "    for name in state_pre:\n",
    "        match_q = re.match(r\"(.+)\\.q.*\", name)\n",
    "        if not match_q:\n",
    "            continue\n",
    "        base = match_q.group(1)\n",
    "        if base in qk_layers:\n",
    "            # Keep the first Q tensor found for a block. A later match with the same\n",
    "            # base (e.g. a q bias) must not replace it, otherwise that tensor would be\n",
    "            # compared against an unrelated K tensor.\n",
    "            continue\n",
    "        # Look for K tensors that share the same base\n",
    "        k_name_candidates = [n for n in state_pre if n.startswith(base) and \".k\" in n]\n",
    "        if not k_name_candidates:\n",
    "            continue\n",
    "        # Prefer the K tensor of the same kind (same final name component, e.g.\n",
    "        # \"weight\"); fall back to the first candidate when none shares it.\n",
    "        kind = name.rsplit(\".\", 1)[-1]\n",
    "        same_kind = [n for n in k_name_candidates if n.rsplit(\".\", 1)[-1] == kind]\n",
    "        qk_layers[base] = (name, (same_kind or k_name_candidates)[0])\n",
    "\n",
    "    # Compute L1 and L2 norm differences safely\n",
    "    for base, (q_name, k_name) in qk_layers.items():\n",
    "        if q_name not in state_ft or k_name not in state_ft:\n",
    "            print(f\"Skipping {base}: Q/K tensors missing from fine-tuned model\")\n",
    "            continue\n",
    "\n",
    "        Q_pre = state_pre[q_name]\n",
    "        K_pre = state_pre[k_name]\n",
    "        Q_ft  = state_ft[q_name]\n",
    "        K_ft  = state_ft[k_name]\n",
    "\n",
    "        # L1 differences\n",
    "        diff_pre_l1 = abs(safe_norm(Q_pre, p=1, clip_value=clip_value) - safe_norm(K_pre, p=1, clip_value=clip_value))\n",
    "        diff_ft_l1  = abs(safe_norm(Q_ft, p=1, clip_value=clip_value)  - safe_norm(K_ft, p=1, clip_value=clip_value))\n",
    "        delta_l1    = diff_ft_l1 - diff_pre_l1\n",
    "        pos_l1 = (delta_l1 <= 0)  # True when the imbalance did not grow\n",
    "\n",
    "        # L2 differences (on squared norms)\n",
    "        diff_pre_l2 = abs(safe_norm(Q_pre, p=2, clip_value=clip_value)**2 - safe_norm(K_pre, p=2, clip_value=clip_value)**2)\n",
    "        diff_ft_l2  = abs(safe_norm(Q_ft, p=2, clip_value=clip_value)**2  - safe_norm(K_ft, p=2, clip_value=clip_value)**2)\n",
    "        delta_l2    = diff_ft_l2 - diff_pre_l2\n",
    "        pos_l2 = (delta_l2 <= 0)\n",
    "\n",
    "        stats[base] = {\n",
    "            \"diff_pre_l1\": diff_pre_l1,\n",
    "            \"diff_ft_l1\": diff_ft_l1,\n",
    "            \"delta_l1\": delta_l1,\n",
    "            \"diff_pre_l2\": diff_pre_l2,\n",
    "            \"diff_ft_l2\": diff_ft_l2,\n",
    "            \"delta_l2\": delta_l2,\n",
    "        }\n",
    "\n",
    "        total_layers += 1\n",
    "        total_l1_delta += delta_l1\n",
    "        total_l2_delta += delta_l2\n",
    "\n",
    "        total_pos_l1 += pos_l1\n",
    "        total_pos_l2 += pos_l2\n",
    "\n",
    "    overall = {\n",
    "        \"total_qk_layers\": total_layers,\n",
    "        \"sum_delta_l1\": total_l1_delta,\n",
    "        \"avg_delta_l1\": total_l1_delta / total_layers if total_layers > 0 else 0.0,\n",
    "        \"sum_delta_l2\": total_l2_delta,\n",
    "        \"avg_delta_l2\": total_l2_delta / total_layers if total_layers > 0 else 0.0,\n",
    "        \"percentageposl1\": total_pos_l1 / total_layers if total_layers > 0 else 0.0,\n",
    "        \"percentageposl2\": total_pos_l2 / total_layers if total_layers > 0 else 0.0\n",
    "    }\n",
    "\n",
    "    return stats, overall\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5578df40-54e7-4e6b-a936-45036f382817",
   "metadata": {},
   "outputs": [],
   "source": [
    "def group_qkvo(layer_stats):\n",
    "    \"\"\"\n",
    "    Split per-tensor stats into attention-projection buckets.\n",
    "\n",
    "    A tensor goes to the first of \"q\", \"k\", \"v\", \"o\" whose marker (\".<letter>\" or\n",
    "    \"<letter>_proj\") occurs in its name; anything else lands in \"other\".\n",
    "\n",
    "    Returns a defaultdict mapping bucket -> {tensor name: stats}.\n",
    "    \"\"\"\n",
    "    buckets = defaultdict(dict)\n",
    "\n",
    "    for tensor_name, entry in layer_stats.items():\n",
    "        bucket = \"other\"\n",
    "        # Markers are tried in q, k, v, o order; the first hit wins.\n",
    "        for letter in \"qkvo\":\n",
    "            if f\".{letter}\" in tensor_name or f\"{letter}_proj\" in tensor_name:\n",
    "                bucket = letter\n",
    "                break\n",
    "        buckets[bucket][tensor_name] = entry\n",
    "\n",
    "    return buckets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b494b5bc-5855-4ba6-b5a6-133493f6b8a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def aggregate_qkvo(layer_stats):\n",
    "    \"\"\"\n",
    "    Sum sign-flip counts per attention-projection bucket.\n",
    "\n",
    "    Each entry of layer_stats must carry \"total\" and \"flipped\" counts. Tensors are\n",
    "    assigned to \"q\", \"k\", \"v\", \"o\" (first marker found in the name) or \"other\", and\n",
    "    every bucket is reported even when it is empty.\n",
    "\n",
    "    Returns a dict bucket -> {\"total\", \"flipped\", \"flipped_pct\"}.\n",
    "    \"\"\"\n",
    "    def bucket_of(tensor_name):\n",
    "        # Same precedence as an if/elif chain: q, then k, then v, then o.\n",
    "        for letter in (\"q\", \"k\", \"v\", \"o\"):\n",
    "            if \".\" + letter in tensor_name or letter + \"_proj\" in tensor_name:\n",
    "                return letter\n",
    "        return \"other\"\n",
    "\n",
    "    # running sums per bucket\n",
    "    bucket_names = (\"q\", \"k\", \"v\", \"o\", \"other\")\n",
    "    totals = dict.fromkeys(bucket_names, 0)\n",
    "    flips = dict.fromkeys(bucket_names, 0)\n",
    "    for tensor_name, entry in layer_stats.items():\n",
    "        key = bucket_of(tensor_name)\n",
    "        totals[key] += entry[\"total\"]\n",
    "        flips[key] += entry[\"flipped\"]\n",
    "\n",
    "    # turn the sums into the reported summary\n",
    "    summary = {}\n",
    "    for key in bucket_names:\n",
    "        n_total = totals[key]\n",
    "        n_flipped = flips[key]\n",
    "        if n_total > 0:\n",
    "            pct = n_flipped / n_total * 100\n",
    "        else:\n",
    "            pct = 0.0\n",
    "        summary[key] = {\"total\": n_total, \"flipped\": n_flipped, \"flipped_pct\": pct}\n",
    "\n",
    "    return summary\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e89888d1-e3f5-4564-9831-5406d413278b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading pretrained model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`torch_dtype` is deprecated! Use `dtype` instead!\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b104c8b1cd044a3c82bfc53234554fd0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading fine-tuned model...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c8ec1b431be746c3b3646c9bc7d89d09",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# ---- Checkpoints to compare (Hugging Face model IDs or local paths) ----\n",
    "pretrained_model_id8b = \"meta-llama/Llama-3.1-8b\"            # pretrained checkpoint\n",
    "finetuned_model_id8b = \"meta-llama/Llama-3.1-8b-Instruct\"   # fine-tuned checkpoint\n",
    "\n",
    "# ---- Load the weights only; no tokenizer is needed for this analysis ----\n",
    "# Both checkpoints are loaded on CPU in float16 (use device_map=\"auto\" for GPU).\n",
    "print(\"Loading pretrained model...\")\n",
    "model_pre8b = AutoModelForCausalLM.from_pretrained(\n",
    "    pretrained_model_id8b,\n",
    "    device_map=\"cpu\",\n",
    "    low_cpu_mem_usage=False,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "\n",
    "print(\"Loading fine-tuned model...\")\n",
    "model_ft8b = AutoModelForCausalLM.from_pretrained(\n",
    "    finetuned_model_id8b,\n",
    "    device_map=\"cpu\",\n",
    "    low_cpu_mem_usage=False,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "015ea174-c0ba-4c0a-8103-5ba6dea79449",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Comparing signs...\n",
      "\n",
      "=== Overall Sign Flip Stats ===\n",
      "{'total_params': 8030261248, 'total_flipped': 126182338, 'flipped_pct': 1.5713354037071547}\n",
      "\n",
      "=== Sample per-layer stats ===\n",
      "model.embed_tokens.weight {'total': 525336576, 'flipped': 6370738, 'flipped_pct': 1.212696448533597}\n",
      "model.layers.0.self_attn.q_proj.weight {'total': 16777216, 'flipped': 233328, 'flipped_pct': 1.3907432556152344}\n",
      "model.layers.0.self_attn.k_proj.weight {'total': 4194304, 'flipped': 49577, 'flipped_pct': 1.1820077896118164}\n",
      "model.layers.0.self_attn.v_proj.weight {'total': 4194304, 'flipped': 107970, 'flipped_pct': 2.5742053985595703}\n",
      "model.layers.0.self_attn.o_proj.weight {'total': 16777216, 'flipped': 334499, 'flipped_pct': 1.9937694072723389}\n",
      "model.layers.0.mlp.gate_proj.weight {'total': 58720256, 'flipped': 736372, 'flipped_pct': 1.2540340423583984}\n",
      "model.layers.0.mlp.up_proj.weight {'total': 58720256, 'flipped': 794299, 'flipped_pct': 1.3526831354413713}\n",
      "model.layers.0.mlp.down_proj.weight {'total': 58720256, 'flipped': 822219, 'flipped_pct': 1.4002306120736259}\n",
      "model.layers.0.input_layernorm.weight {'total': 4096, 'flipped': 185, 'flipped_pct': 4.5166015625}\n",
      "model.layers.0.post_attention_layernorm.weight {'total': 4096, 'flipped': 9, 'flipped_pct': 0.2197265625}\n"
     ]
    }
   ],
   "source": [
    "# Count, tensor by tensor, the sign changes between the two 8B checkpoints.\n",
    "print(\"Comparing signs...\")\n",
    "layer_stats8b, overall_stats8b = compare_signs(model_pre8b, model_ft8b)\n",
    "\n",
    "# ---- Whole-model summary, then a short per-tensor preview ----\n",
    "print(\"\\n=== Overall Sign Flip Stats ===\")\n",
    "print(overall_stats8b)\n",
    "\n",
    "print(\"\\n=== Sample per-layer stats ===\")\n",
    "for layer in list(layer_stats8b)[:10]:  # first 10 entries only\n",
    "    print(layer, layer_stats8b[layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "42ad3a59-6602-4253-a9cd-07255d0a400f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "q keys: ['model.layers.0.self_attn.q_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.4.self_attn.q_proj.weight']\n",
      "k keys: ['model.layers.0.self_attn.k_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.4.self_attn.k_proj.weight']\n",
      "v keys: ['model.layers.0.self_attn.v_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.self_attn.v_proj.weight']\n",
      "o keys: ['model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.o_proj.weight']\n",
      "q {'total': 536870912, 'flipped': 6698672, 'flipped_pct': 1.2477248907089233}\n",
      "k {'total': 134217728, 'flipped': 1167319, 'flipped_pct': 0.8697204291820526}\n",
      "v {'total': 134217728, 'flipped': 2354355, 'flipped_pct': 1.75413116812706}\n",
      "o {'total': 536870912, 'flipped': 9013124, 'flipped_pct': 1.678825169801712}\n",
      "other {'total': 6688083968, 'flipped': 106948868, 'flipped_pct': 1.599095772596616}\n"
     ]
    }
   ],
   "source": [
    "# Bucket the per-tensor stats by attention projection and preview each bucket's first names.\n",
    "grouped_stats8b = group_qkvo(layer_stats8b)\n",
    "\n",
    "for group in (\"q\", \"k\", \"v\", \"o\"):\n",
    "    print(f\"{group} keys:\", list(grouped_stats8b[group])[:5])\n",
    "\n",
    "# Flip totals and percentage per bucket.\n",
    "agg_stats8b = aggregate_qkvo(layer_stats8b)\n",
    "for group, group_stats in agg_stats8b.items():\n",
    "    print(group, group_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "db6ed11c-416d-4290-a018-1ac778e641c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall Q/K L1 difference delta: {'total_qk_layers': 32, 'sum_delta_l1': -19972.0859375, 'avg_delta_l1': -624.127685546875, 'sum_delta_l2': -648.9048654412909, 'avg_delta_l2': -20.27827704504034, 'percentageposl1': 1.0, 'percentageposl2': 1.0}\n"
     ]
    }
   ],
   "source": [
    "# Change in the Q/K norm gap between the pretrained and fine-tuned 8B checkpoints.\n",
    "stats8b, overall8b = compare_qk_norm_diff_safe(model_pre=model_pre8b, model_ft=model_ft8b)\n",
    "print(\"Overall Q/K L1 difference delta:\", overall8b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c2e66317-712d-4ff6-9ca3-be52bda4b42d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading pretrained model...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff34241ad47d47c296c202c08a912bc6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading fine-tuned model...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d314e38c1f1464bae57346ebb5725a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# ---- Checkpoints to compare (Hugging Face model IDs or local paths) ----\n",
    "pretrained_model_id3b = \"meta-llama/Llama-3.2-3B\"            # pretrained checkpoint\n",
    "finetuned_model_id3b = \"meta-llama/Llama-3.2-3B-Instruct\"   # fine-tuned checkpoint\n",
    "\n",
    "# Both checkpoints are loaded on CPU in float16 (use device_map=\"auto\" for GPU).\n",
    "print(\"Loading pretrained model...\")\n",
    "model_pre3b = AutoModelForCausalLM.from_pretrained(\n",
    "    pretrained_model_id3b,\n",
    "    device_map=\"cpu\",\n",
    "    low_cpu_mem_usage=False,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "\n",
    "print(\"Loading fine-tuned model...\")\n",
    "model_ft3b = AutoModelForCausalLM.from_pretrained(\n",
    "    finetuned_model_id3b,\n",
    "    device_map=\"cpu\",\n",
    "    low_cpu_mem_usage=False,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "bc2c6277-327f-42d8-8811-500beae48251",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Comparing signs...\n",
      "\n",
      "=== Overall Sign Flip Stats ===\n",
      "{'total_params': 3606752256, 'total_flipped': 181914991, 'flipped_pct': 5.043734032393712}\n",
      "\n",
      "=== Sample per-layer stats ===\n",
      "model.embed_tokens.weight {'total': 394002432, 'flipped': 21561224, 'flipped_pct': 5.4723580995560965}\n",
      "model.layers.0.self_attn.q_proj.weight {'total': 9437184, 'flipped': 289456, 'flipped_pct': 3.0671861436631946}\n",
      "model.layers.0.self_attn.k_proj.weight {'total': 3145728, 'flipped': 98506, 'flipped_pct': 3.1314214070638022}\n",
      "model.layers.0.self_attn.v_proj.weight {'total': 3145728, 'flipped': 223260, 'flipped_pct': 7.0972442626953125}\n",
      "model.layers.0.self_attn.o_proj.weight {'total': 9437184, 'flipped': 586865, 'flipped_pct': 6.218645307752821}\n",
      "model.layers.0.mlp.gate_proj.weight {'total': 25165824, 'flipped': 986269, 'flipped_pct': 3.919080893198649}\n",
      "model.layers.0.mlp.up_proj.weight {'total': 25165824, 'flipped': 1033136, 'flipped_pct': 4.105313618977864}\n",
      "model.layers.0.mlp.down_proj.weight {'total': 25165824, 'flipped': 1050806, 'flipped_pct': 4.175527890523274}\n",
      "model.layers.0.input_layernorm.weight {'total': 3072, 'flipped': 20, 'flipped_pct': 0.6510416666666666}\n",
      "model.layers.0.post_attention_layernorm.weight {'total': 3072, 'flipped': 5, 'flipped_pct': 0.16276041666666666}\n"
     ]
    }
   ],
   "source": [
    "# Count, tensor by tensor, the sign changes between the two 3B checkpoints.\n",
    "print(\"Comparing signs...\")\n",
    "layer_stats3b, overall_stats3b = compare_signs(model_pre3b, model_ft3b)\n",
    "\n",
    "# ---- Whole-model summary, then a short per-tensor preview ----\n",
    "print(\"\\n=== Overall Sign Flip Stats ===\")\n",
    "print(overall_stats3b)\n",
    "\n",
    "print(\"\\n=== Sample per-layer stats ===\")\n",
    "for layer in list(layer_stats3b)[:10]:  # first 10 entries only\n",
    "    print(layer, layer_stats3b[layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3ac7d196-a9df-4027-a555-b18ec83988f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "q keys: ['model.layers.0.self_attn.q_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.4.self_attn.q_proj.weight']\n",
      "k keys: ['model.layers.0.self_attn.k_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.4.self_attn.k_proj.weight']\n",
      "v keys: ['model.layers.0.self_attn.v_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.self_attn.v_proj.weight']\n",
      "o keys: ['model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.o_proj.weight']\n",
      "q {'total': 264241152, 'flipped': 11535917, 'flipped_pct': 4.365677682180253}\n",
      "k {'total': 88080384, 'flipped': 3086174, 'flipped_pct': 3.5038153330485025}\n",
      "v {'total': 88080384, 'flipped': 4772396, 'flipped_pct': 5.418227967761811}\n",
      "o {'total': 264241152, 'flipped': 13757920, 'flipped_pct': 5.206577361576141}\n",
      "other {'total': 2902109184, 'flipped': 148762584, 'flipped_pct': 5.126016099606541}\n"
     ]
    }
   ],
   "source": [
    "# Bucket the per-tensor stats by attention projection and preview each bucket's first names.\n",
    "grouped_stats3b = group_qkvo(layer_stats3b)\n",
    "\n",
    "for group in (\"q\", \"k\", \"v\", \"o\"):\n",
    "    print(f\"{group} keys:\", list(grouped_stats3b[group])[:5])\n",
    "\n",
    "# Flip totals and percentage per bucket.\n",
    "agg_stats3b = aggregate_qkvo(layer_stats3b)\n",
    "for group, group_stats in agg_stats3b.items():\n",
    "    print(group, group_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f3af81a5-3c38-4df5-ad5f-5d04c49ee1ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall Q/K L1 difference delta: {'total_qk_layers': 28, 'sum_delta_l1': -49197.1640625, 'avg_delta_l1': -1757.0415736607142, 'sum_delta_l2': -2847.8831539862877, 'avg_delta_l2': -101.71011264236742, 'percentageposl1': 1.0, 'percentageposl2': 1.0}\n"
     ]
    }
   ],
   "source": [
    "# Change in the Q/K norm gap between the pretrained and fine-tuned 3B checkpoints.\n",
    "stats3b, overall3b = compare_qk_norm_diff_safe(model_pre=model_pre3b, model_ft=model_ft3b)\n",
    "print(\"Overall Q/K L1 difference delta:\", overall3b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6e3ceb59-7466-4b07-aebb-5de3939877b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading pretrained model...\n",
      "Loading fine-tuned model...\n"
     ]
    }
   ],
   "source": [
    "# ---- Checkpoints to compare (Hugging Face model IDs or local paths) ----\n",
    "pretrained_model_id1b = \"meta-llama/Llama-3.2-1B\"            # pretrained checkpoint\n",
    "finetuned_model_id1b = \"meta-llama/Llama-3.2-1B-Instruct\"   # fine-tuned checkpoint\n",
    "\n",
    "# Both checkpoints are loaded on CPU in float16.\n",
    "print(\"Loading pretrained model...\")\n",
    "model_pre1b = AutoModelForCausalLM.from_pretrained(\n",
    "    pretrained_model_id1b,\n",
    "    device_map=\"cpu\",\n",
    "    low_cpu_mem_usage=False,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "\n",
    "print(\"Loading fine-tuned model...\")\n",
    "model_ft1b = AutoModelForCausalLM.from_pretrained(\n",
    "    finetuned_model_id1b,\n",
    "    device_map=\"cpu\",\n",
    "    low_cpu_mem_usage=False,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9333b407-a20e-43ef-b805-c556048c1a2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Comparing signs...\n",
      "\n",
      "=== Overall Sign Flip Stats ===\n",
      "{'total_params': 1498482688, 'total_flipped': 85833939, 'flipped_pct': 5.728056766178669}\n",
      "\n",
      "=== Sample per-layer stats ===\n",
      "model.embed_tokens.weight {'total': 262668288, 'flipped': 16070981, 'flipped_pct': 6.118356015629873}\n",
      "model.layers.0.self_attn.q_proj.weight {'total': 4194304, 'flipped': 152442, 'flipped_pct': 3.634500503540039}\n",
      "model.layers.0.self_attn.k_proj.weight {'total': 1048576, 'flipped': 31946, 'flipped_pct': 3.0466079711914062}\n",
      "model.layers.0.self_attn.v_proj.weight {'total': 1048576, 'flipped': 79733, 'flipped_pct': 7.603931427001953}\n",
      "model.layers.0.self_attn.o_proj.weight {'total': 4194304, 'flipped': 275668, 'flipped_pct': 6.572437286376953}\n",
      "model.layers.0.mlp.gate_proj.weight {'total': 16777216, 'flipped': 682260, 'flipped_pct': 4.066586494445801}\n",
      "model.layers.0.mlp.up_proj.weight {'total': 16777216, 'flipped': 737083, 'flipped_pct': 4.393357038497925}\n",
      "model.layers.0.mlp.down_proj.weight {'total': 16777216, 'flipped': 751979, 'flipped_pct': 4.482144117355347}\n",
      "model.layers.0.input_layernorm.weight {'total': 2048, 'flipped': 1, 'flipped_pct': 0.048828125}\n",
      "model.layers.0.post_attention_layernorm.weight {'total': 2048, 'flipped': 3, 'flipped_pct': 0.146484375}\n"
     ]
    }
   ],
   "source": [
    "# Count, tensor by tensor, the sign changes between the two 1B checkpoints.\n",
    "print(\"Comparing signs...\")\n",
    "layer_stats1b, overall_stats1b = compare_signs(model_pre1b, model_ft1b)\n",
    "\n",
    "# ---- Whole-model summary, then a short per-tensor preview ----\n",
    "print(\"\\n=== Overall Sign Flip Stats ===\")\n",
    "print(overall_stats1b)\n",
    "\n",
    "print(\"\\n=== Sample per-layer stats ===\")\n",
    "for layer in list(layer_stats1b)[:10]:  # first 10 entries only\n",
    "    print(layer, layer_stats1b[layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a08560d5-8ca0-4872-95f7-9131fe313f20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "q keys: ['model.layers.0.self_attn.q_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.4.self_attn.q_proj.weight']\n",
      "k keys: ['model.layers.0.self_attn.k_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.4.self_attn.k_proj.weight']\n",
      "v keys: ['model.layers.0.self_attn.v_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.self_attn.v_proj.weight']\n",
      "o keys: ['model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.o_proj.weight']\n",
      "q {'total': 67108864, 'flipped': 3174628, 'flipped_pct': 4.73056435585022}\n",
      "k {'total': 16777216, 'flipped': 544163, 'flipped_pct': 3.243464231491089}\n",
      "v {'total': 16777216, 'flipped': 949962, 'flipped_pct': 5.662214756011963}\n",
      "o {'total': 67108864, 'flipped': 3859626, 'flipped_pct': 5.751290917396545}\n",
      "other {'total': 1330710528, 'flipped': 77305560, 'flipped_pct': 5.809344584970474}\n"
     ]
    }
   ],
   "source": [
    "# Bucket the per-tensor stats by attention projection and preview each bucket's first names.\n",
    "grouped_stats1b = group_qkvo(layer_stats1b)\n",
    "\n",
    "for group in (\"q\", \"k\", \"v\", \"o\"):\n",
    "    print(f\"{group} keys:\", list(grouped_stats1b[group])[:5])\n",
    "\n",
    "# Flip totals and percentage per bucket.\n",
    "agg_stats1b = aggregate_qkvo(layer_stats1b)\n",
    "for group, group_stats in agg_stats1b.items():\n",
    "    print(group, group_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7e8b68c2-93e4-4a6f-9355-00e7291f237d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall Q/K L1 difference delta: {'total_qk_layers': 16, 'sum_delta_l1': -14268.134765625, 'avg_delta_l1': -891.7584228515625, 'sum_delta_l2': -1058.476924921375, 'avg_delta_l2': -66.15480780758594, 'percentageposl1': 1.0, 'percentageposl2': 1.0}\n"
     ]
    }
   ],
   "source": [
    "# Change in the Q/K norm gap between the pretrained and fine-tuned 1B checkpoints.\n",
    "stats1b, overall1b = compare_qk_norm_diff_safe(model_pre=model_pre1b, model_ft=model_ft1b)\n",
    "print(\"Overall Q/K L1 difference delta:\", overall1b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "276940f9-da97-4962-8ec4-c50a8b840cea",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
