{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d7f7311e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 1 GPU(s) on this machine.\n",
      "GPU 0: NVIDIA GeForce RTX 3090\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "def check_gpu_availability():\n",
    "    \"\"\"Check for available GPUs and print their names.\"\"\"\n",
    "    gpu_count = torch.cuda.device_count()\n",
    "    if gpu_count > 0:\n",
    "        print(f\"Found {gpu_count} GPU(s) on this machine.\")\n",
    "        for i in range(gpu_count):\n",
    "            print(f\"GPU {i}: {torch.cuda.get_device_name(i)}\")\n",
    "    else:\n",
    "        print(\"No GPU found on this machine.\")\n",
    "        \n",
    "check_gpu_availability()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "79cc0493",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Currently processing layer 0...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_0\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_0/x4_SprCoef0.1_LR0.0002_Los0.7811_L5.7996_R0.9986_BalancedBest_epoch8580.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_0/x4_SprCoef0.1_LR0.0002_Los0.7811_L5.7996_R0.9986_BalancedBest_epoch8580.pt\n",
      "Layer 0 - Total active atoms: 35407\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_0/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_0/atoms_id.pt\n",
      "Currently processing layer 1...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_1\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_1/x4_SprCoef0.1_LR0.0002_Los0.8810_L6.7453_R0.9984_BalancedBest_epoch8880.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_1/x4_SprCoef0.1_LR0.0002_Los0.8810_L6.7453_R0.9984_BalancedBest_epoch8880.pt\n",
      "Layer 1 - Total active atoms: 36589\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_1/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_1/atoms_id.pt\n",
      "Currently processing layer 2...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_2\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_2/x4_SprCoef0.1_LR0.0002_Los0.8617_L7.2533_R0.9987_BalancedBest_epoch8360.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_2/x4_SprCoef0.1_LR0.0002_Los0.8617_L7.2533_R0.9987_BalancedBest_epoch8360.pt\n",
      "Layer 2 - Total active atoms: 36142\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_2/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_2/atoms_id.pt\n",
      "Currently processing layer 3...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_3\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_3/x4_SprCoef0.1_LR0.0002_Los0.5163_L4.0182_R0.9992_BalancedBest_epoch9720.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_3/x4_SprCoef0.1_LR0.0002_Los0.5163_L4.0182_R0.9992_BalancedBest_epoch9720.pt\n",
      "Layer 3 - Total active atoms: 35799\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_3/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_3/atoms_id.pt\n",
      "Currently processing layer 4...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_4\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_4/x4_SprCoef0.1_LR0.0002_Los0.3889_L3.2714_R0.9994_BalancedBest_epoch8580.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_4/x4_SprCoef0.1_LR0.0002_Los0.3889_L3.2714_R0.9994_BalancedBest_epoch8580.pt\n",
      "Layer 4 - Total active atoms: 33560\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_4/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_4/atoms_id.pt\n",
      "Currently processing layer 5...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_5\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_5/x4_SprCoef0.1_LR0.0002_Los0.5470_L5.1462_R0.9996_BalancedBest_epoch7300.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_5/x4_SprCoef0.1_LR0.0002_Los0.5470_L5.1462_R0.9996_BalancedBest_epoch7300.pt\n",
      "Layer 5 - Total active atoms: 35473\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_5/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_5/atoms_id.pt\n",
      "Currently processing layer 6...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_6\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_6/x4_SprCoef0.1_LR0.0002_Los0.5695_L5.3801_R0.9995_BalancedBest_epoch6320.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_6/x4_SprCoef0.1_LR0.0002_Los0.5695_L5.3801_R0.9995_BalancedBest_epoch6320.pt\n",
      "Layer 6 - Total active atoms: 33332\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_6/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_6/atoms_id.pt\n",
      "Currently processing layer 7...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_7\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_7/x4_SprCoef0.1_LR0.0002_Los0.3791_L3.1689_R0.9993_BalancedBest_epoch8560.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_7/x4_SprCoef0.1_LR0.0002_Los0.3791_L3.1689_R0.9993_BalancedBest_epoch8560.pt\n",
      "Layer 7 - Total active atoms: 32251\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_7/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_7/atoms_id.pt\n",
      "Currently processing layer 8...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_8\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_8/x4_SprCoef0.1_LR0.0002_Los0.6836_L6.4229_R0.9996_BalancedBest_epoch6960.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_8/x4_SprCoef0.1_LR0.0002_Los0.6836_L6.4229_R0.9996_BalancedBest_epoch6960.pt\n",
      "Layer 8 - Total active atoms: 34241\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_8/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_8/atoms_id.pt\n",
      "Currently processing layer 9...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_9\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_9/x4_SprCoef0.1_LR0.0002_Los0.4801_L4.5709_R0.9995_BalancedBest_epoch6060.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_9/x4_SprCoef0.1_LR0.0002_Los0.4801_L4.5709_R0.9995_BalancedBest_epoch6060.pt\n",
      "Layer 9 - Total active atoms: 32726\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_9/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_9/atoms_id.pt\n",
      "Currently processing layer 10...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_10\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_10/x4_SprCoef0.1_LR0.0002_Los0.6357_L6.0707_R0.9992_BalancedBest_epoch5500.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_10/x4_SprCoef0.1_LR0.0002_Los0.6357_L6.0707_R0.9992_BalancedBest_epoch5500.pt\n",
      "Layer 10 - Total active atoms: 31476\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_10/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_10/atoms_id.pt\n",
      "Currently processing layer 11...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_11\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_11/x4_SprCoef0.1_LR0.0002_Los0.6243_L5.9365_R0.9992_BalancedBest_epoch5520.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_11/x4_SprCoef0.1_LR0.0002_Los0.6243_L5.9365_R0.9992_BalancedBest_epoch5520.pt\n",
      "Layer 11 - Total active atoms: 29187\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_11/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_11/atoms_id.pt\n",
      "Currently processing layer 12...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_12\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_12/x4_SprCoef0.1_LR0.0002_Los0.8234_L7.9852_R0.9991_BalancedBest_epoch5340.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_12/x4_SprCoef0.1_LR0.0002_Los0.8234_L7.9852_R0.9991_BalancedBest_epoch5340.pt\n",
      "Layer 12 - Total active atoms: 24989\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_12/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_12/atoms_id.pt\n",
      "Currently processing layer 13...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_13\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_13/x4_SprCoef0.1_LR0.0002_Los0.7069_L6.4095_R0.9973_BalancedBest_epoch5700.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_13/x4_SprCoef0.1_LR0.0002_Los0.7069_L6.4095_R0.9973_BalancedBest_epoch5700.pt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Layer 13 - Total active atoms: 22445\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_13/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_13/atoms_id.pt\n",
      "Currently processing layer 14...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_14\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_14/x4_SprCoef0.1_LR0.0002_Los0.4161_L3.9074_R0.9989_BalancedBest_epoch7540.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_14/x4_SprCoef0.1_LR0.0002_Los0.4161_L3.9074_R0.9989_BalancedBest_epoch7540.pt\n",
      "Layer 14 - Total active atoms: 25996\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_14/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_14/atoms_id.pt\n",
      "Currently processing layer 15...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_15\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_15/x4_SprCoef0.1_LR0.0002_Los0.4730_L4.4783_R0.9988_BalancedBest_epoch5520.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_15/x4_SprCoef0.1_LR0.0002_Los0.4730_L4.4783_R0.9988_BalancedBest_epoch5520.pt\n",
      "Layer 15 - Total active atoms: 25881\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_15/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_15/atoms_id.pt\n",
      "Currently processing layer 16...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_16\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_16/x4_SprCoef0.1_LR0.0002_Los0.6589_L6.2815_R0.9992_BalancedBest_epoch5600.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_16/x4_SprCoef0.1_LR0.0002_Los0.6589_L6.2815_R0.9992_BalancedBest_epoch5600.pt\n",
      "Layer 16 - Total active atoms: 25924\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_16/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_16/atoms_id.pt\n",
      "Currently processing layer 17...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_17\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_17/x4_SprCoef0.1_LR0.0002_Los0.4576_L4.2865_R0.9994_BalancedBest_epoch5940.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_17/x4_SprCoef0.1_LR0.0002_Los0.4576_L4.2865_R0.9994_BalancedBest_epoch5940.pt\n",
      "Layer 17 - Total active atoms: 25968\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_17/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_17/atoms_id.pt\n",
      "Currently processing layer 18...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_18\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_18/x4_SprCoef0.1_LR0.0002_Los0.9494_L9.0535_R0.9996_BalancedBest_epoch5700.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_18/x4_SprCoef0.1_LR0.0002_Los0.9494_L9.0535_R0.9996_BalancedBest_epoch5700.pt\n",
      "Layer 18 - Total active atoms: 29698\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_18/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_18/atoms_id.pt\n",
      "Currently processing layer 19...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_19\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_19/x4_SprCoef0.1_LR0.0002_Los0.6353_L6.0711_R0.9997_BalancedBest_epoch6160.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_19/x4_SprCoef0.1_LR0.0002_Los0.6353_L6.0711_R0.9997_BalancedBest_epoch6160.pt\n",
      "Layer 19 - Total active atoms: 29727\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_19/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_19/atoms_id.pt\n",
      "Currently processing layer 20...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_20\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_20/x4_SprCoef0.1_LR0.0002_Los0.5327_L4.6972_R0.9989_BalancedBest_epoch6020.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_20/x4_SprCoef0.1_LR0.0002_Los0.5327_L4.6972_R0.9989_BalancedBest_epoch6020.pt\n",
      "Layer 20 - Total active atoms: 29359\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_20/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_20/atoms_id.pt\n",
      "Currently processing layer 21...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_21\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_21/x4_SprCoef0.1_LR0.0002_Los0.8323_L8.1301_R0.9997_BalancedBest_epoch5340.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_21/x4_SprCoef0.1_LR0.0002_Los0.8323_L8.1301_R0.9997_BalancedBest_epoch5340.pt\n",
      "Layer 21 - Total active atoms: 30082\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_21/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_21/atoms_id.pt\n",
      "Currently processing layer 22...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_22\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_22/x4_SprCoef0.1_LR0.0002_Los0.9583_L9.3776_R0.9995_BalancedBest_epoch5120.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_22/x4_SprCoef0.1_LR0.0002_Los0.9583_L9.3776_R0.9995_BalancedBest_epoch5120.pt\n",
      "Layer 22 - Total active atoms: 32424\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_22/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_22/atoms_id.pt\n",
      "Currently processing layer 23...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_23\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_23/x4_SprCoef0.1_LR0.0002_Los0.6076_L5.8653_R0.9994_BalancedBest_epoch5420.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_23/x4_SprCoef0.1_LR0.0002_Los0.6076_L5.8653_R0.9994_BalancedBest_epoch5420.pt\n",
      "Layer 23 - Total active atoms: 30039\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_23/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_23/atoms_id.pt\n",
      "Currently processing layer 24...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_24\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_24/x4_SprCoef0.1_LR0.0002_Los0.6277_L6.0088_R0.9994_BalancedBest_epoch5480.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_24/x4_SprCoef0.1_LR0.0002_Los0.6277_L6.0088_R0.9994_BalancedBest_epoch5480.pt\n",
      "Layer 24 - Total active atoms: 29452\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_24/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_24/atoms_id.pt\n",
      "Currently processing layer 25...\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_25\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_25/x4_SprCoef0.1_LR0.0002_Los0.7987_L7.6514_R0.9993_BalancedBest_epoch5420.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_25/x4_SprCoef0.1_LR0.0002_Los0.7987_L7.6514_R0.9993_BalancedBest_epoch5420.pt\n",
      "Layer 25 - Total active atoms: 27478\n",
      "Saved counts → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_25/sparsity.json\n",
      "Saved active atoms IDs (torch tensor) → ../data/Atomicity/Atoms/google/gemma-2-2b/layer_25/atoms_id.pt\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "from pathlib import Path\n",
    "import gc\n",
    "import re\n",
    "\n",
    "# ---------- Configuration ----------\n",
    "model_name = \"google/gemma-2-2b\"\n",
    "d_model = 9216\n",
    "d_sae = d_model * 4\n",
    "base_dir = Path(f\"../saved_models/{model_name}\")\n",
    "repr_dir = Path(f\"../data/Representation/down_proj_input/{model_name}/Counterfact\")\n",
    "sparsity_out_dir = os.path.join(\"../data\", \"Atomicity\", \"Atoms\", model_name)\n",
    "\n",
    "class JumpReLUSAE(nn.Module):\n",
    "    \"\"\"\n",
    "    Sparse AutoEncoder with Jump ReLU nonlinearity.\n",
    "    - d_model: input dimension\n",
    "    - d_sae: hidden dimension (e.g. 4x d_model)\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model, d_sae):\n",
    "        super().__init__()\n",
    "        self.use_pre_enc_bias = True\n",
    "        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))\n",
    "        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))\n",
    "        self.log_threshold = nn.Parameter(torch.zeros(d_sae))  # use log to keep threshold positive\n",
    "        self.b_enc = nn.Parameter(torch.zeros(d_sae))\n",
    "        self.b_dec = nn.Parameter(torch.zeros(d_model))\n",
    "        \n",
    "    @property\n",
    "    def threshold(self):\n",
    "        # dynamic threshold computed from log_threshold\n",
    "        return torch.exp(self.log_threshold)\n",
    "    \n",
    "    def encode(self, input_acts):\n",
    "        if self.use_pre_enc_bias:\n",
    "            input_acts = input_acts - self.b_dec\n",
    "        pre_acts = input_acts @ self.W_enc + self.b_enc\n",
    "        acts = torch.relu(pre_acts) * (pre_acts > self.threshold)\n",
    "        return acts\n",
    "    \n",
    "    def decode(self, acts):\n",
    "        return acts @ self.W_dec + self.b_dec\n",
    "    \n",
    "    def forward(self, x):\n",
    "        acts = self.encode(x)\n",
    "        recon = self.decode(acts)\n",
    "        return recon\n",
    "\n",
    "\n",
    "# ---------- Traverse and process each layer ----------\n",
    "layer_dirs = sorted(base_dir.glob(\"layer_*\"), key=lambda p: int(re.search(r'\\d+', p.name).group()))\n",
    "for layer_dir in layer_dirs:\n",
    "    layer_number = re.search(r'layer_(\\d+)', str(layer_dir)).group(1)\n",
    "    print(f\"Currently processing layer {layer_number}...\")\n",
    "    \n",
    "    pt_files = sorted(layer_dir.glob(\"*.pt\"))\n",
    "    print(f\"Total {len(pt_files)} model found in dir {layer_dir}\")\n",
    "    if not pt_files:\n",
    "        continue\n",
    "    pt_path = pt_files[0]\n",
    "    print(f\"📂 Processing {pt_path}\")\n",
    "\n",
    "    # Load model\n",
    "    sae = JumpReLUSAE(d_model, d_sae)\n",
    "    state_dict = torch.load(pt_path, map_location=\"cpu\")\n",
    "    sae.load_state_dict(state_dict)\n",
    "    sae.eval()\n",
    "    print(f\"SAE model loaded successfully from: {pt_path}\")\n",
    "    \n",
    "    # Load activations\n",
    "    act_path = repr_dir / f\"layer_{layer_number}.pt\"\n",
    "    activation_values = torch.load(act_path)\n",
    "    keys, dataset = zip(*activation_values.items())\n",
    "    \n",
    "    target_act = torch.stack(dataset).to(torch.float32)\n",
    "    \n",
    "    # Statistics container\n",
    "    per_datapoint_counts = []\n",
    "    ever_active = torch.zeros(d_sae, dtype=torch.bool, device=\"cpu\")\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        BATCH_SIZE_INFER = 4096  # Adjustable\n",
    "        for start_idx in range(0, target_act.shape[0], BATCH_SIZE_INFER):\n",
    "            end_idx = min(start_idx + BATCH_SIZE_INFER, target_act.shape[0])\n",
    "            batch = target_act[start_idx:end_idx]\n",
    "            sae_acts_batch = sae.encode(batch)\n",
    "            pos = (sae_acts_batch > 0)\n",
    "            per_datapoint_counts.extend(pos.sum(dim=1).int().cpu().tolist())\n",
    "            ever_active |= pos.any(dim=0).cpu()\n",
    "            del batch, sae_acts_batch, pos\n",
    "            torch.cuda.empty_cache()\n",
    "            \n",
    "    active_ids = torch.nonzero(ever_active, as_tuple=True)[0]\n",
    "    print(f\"Layer {layer_number} - Total active atoms: {active_ids.numel()}\")\n",
    "    \n",
    "    layer_out_dir = f\"{sparsity_out_dir}/layer_{layer_number}\"\n",
    "    os.makedirs(layer_out_dir, exist_ok=True)\n",
    "    \n",
    "    sparsity_json_path = f\"{layer_out_dir}/sparsity.json\"\n",
    "    with open(sparsity_json_path, \"w\") as f:\n",
    "        json.dump(per_datapoint_counts, f)\n",
    "    print(f\"Saved counts → {sparsity_json_path}\")\n",
    "    \n",
    "    atoms_id_path = f\"{layer_out_dir}/atoms_id.pt\"\n",
    "    torch.save(active_ids, atoms_id_path)\n",
    "    print(f\"Saved active atoms IDs (torch tensor) → {atoms_id_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d386f8e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Layer 0 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_0\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_0/x4_SprCoef0.1_LR0.0002_Los0.7811_L5.7996_R0.9986_BalancedBest_epoch8580.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_0/x4_SprCoef0.1_LR0.0002_Los0.7811_L5.7996_R0.9986_BalancedBest_epoch8580.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 35407)\n",
      "  W_normed shape: (9216, 35407)\n",
      "  G shape: (35407, 35407)\n",
      "  off-diag count: 1,253,620,242\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999364 | μ(α)=6.249996e-02 | K(α)=17.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.997399 | μ(α)=3.455008e-02 | K(α)=14.972\n",
      "\n",
      "=== Layer 1 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_1\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_1/x4_SprCoef0.1_LR0.0002_Los0.8810_L6.7453_R0.9984_BalancedBest_epoch8880.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_1/x4_SprCoef0.1_LR0.0002_Los0.8810_L6.7453_R0.9984_BalancedBest_epoch8880.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 36589)\n",
      "  W_normed shape: (9216, 36589)\n",
      "  G shape: (36589, 36589)\n",
      "  off-diag count: 1,338,718,332\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999595 | μ(α)=6.666665e-02 | K(α)=16.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.997752 | μ(α)=3.448275e-02 | K(α)=15.000\n",
      "\n",
      "=== Layer 2 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_2\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_2/x4_SprCoef0.1_LR0.0002_Los0.8617_L7.2533_R0.9987_BalancedBest_epoch8360.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_2/x4_SprCoef0.1_LR0.0002_Los0.8617_L7.2533_R0.9987_BalancedBest_epoch8360.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 36142)\n",
      "  W_normed shape: (9216, 36142)\n",
      "  G shape: (36142, 36142)\n",
      "  off-diag count: 1,306,208,022\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999557 | μ(α)=6.249993e-02 | K(α)=17.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.997803 | μ(α)=3.448274e-02 | K(α)=15.000\n",
      "\n",
      "=== Layer 3 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_3\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_3/x4_SprCoef0.1_LR0.0002_Los0.5163_L4.0182_R0.9992_BalancedBest_epoch9720.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_3/x4_SprCoef0.1_LR0.0002_Los0.5163_L4.0182_R0.9992_BalancedBest_epoch9720.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 35799)\n",
      "  W_normed shape: (9216, 35799)\n",
      "  G shape: (35799, 35799)\n",
      "  off-diag count: 1,281,532,602\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999669 | μ(α)=9.090909e-02 | K(α)=12.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.998824 | μ(α)=5.263157e-02 | K(α)=10.000\n",
      "\n",
      "=== Layer 4 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_4\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_4/x4_SprCoef0.1_LR0.0002_Los0.3889_L3.2714_R0.9994_BalancedBest_epoch8580.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_4/x4_SprCoef0.1_LR0.0002_Los0.3889_L3.2714_R0.9994_BalancedBest_epoch8580.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 33560)\n",
      "  W_normed shape: (9216, 33560)\n",
      "  G shape: (33560, 33560)\n",
      "  off-diag count: 1,126,240,040\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999791 | μ(α)=1.111111e-01 | K(α)=10.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.999147 | μ(α)=5.882348e-02 | K(α)=9.000\n",
      "\n",
      "=== Layer 5 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_5\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_5/x4_SprCoef0.1_LR0.0002_Los0.5470_L5.1462_R0.9996_BalancedBest_epoch7300.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_5/x4_SprCoef0.1_LR0.0002_Los0.5470_L5.1462_R0.9996_BalancedBest_epoch7300.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 35473)\n",
      "  W_normed shape: (9216, 35473)\n",
      "  G shape: (35473, 35473)\n",
      "  off-diag count: 1,258,298,256\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999694 | μ(α)=7.692283e-02 | K(α)=14.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.998335 | μ(α)=3.999999e-02 | K(α)=13.000\n",
      "\n",
      "=== Layer 6 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_6\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_6/x4_SprCoef0.1_LR0.0002_Los0.5695_L5.3801_R0.9995_BalancedBest_epoch6320.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_6/x4_SprCoef0.1_LR0.0002_Los0.5695_L5.3801_R0.9995_BalancedBest_epoch6320.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 33332)\n",
      "  W_normed shape: (9216, 33332)\n",
      "  G shape: (33332, 33332)\n",
      "  off-diag count: 1,110,988,892\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999688 | μ(α)=6.831554e-02 | K(α)=15.638\n",
      "    Max α for  μ < 1/(2K-1):  α=0.998323 | μ(α)=3.703702e-02 | K(α)=14.000\n",
      "\n",
      "=== Layer 7 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_7\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_7/x4_SprCoef0.1_LR0.0002_Los0.3791_L3.1689_R0.9993_BalancedBest_epoch8560.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_7/x4_SprCoef0.1_LR0.0002_Los0.3791_L3.1689_R0.9993_BalancedBest_epoch8560.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 32251)\n",
      "  W_normed shape: (9216, 32251)\n",
      "  G shape: (32251, 32251)\n",
      "  off-diag count: 1,040,094,750\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999813 | μ(α)=1.111108e-01 | K(α)=10.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.999395 | μ(α)=6.666665e-02 | K(α)=8.000\n",
      "\n",
      "=== Layer 8 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_8\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_8/x4_SprCoef0.1_LR0.0002_Los0.6836_L6.4229_R0.9996_BalancedBest_epoch6960.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_8/x4_SprCoef0.1_LR0.0002_Los0.6836_L6.4229_R0.9996_BalancedBest_epoch6960.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 34241)\n",
      "  W_normed shape: (9216, 34241)\n",
      "  G shape: (34241, 34241)\n",
      "  off-diag count: 1,172,411,840\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999647 | μ(α)=6.666658e-02 | K(α)=16.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.997653 | μ(α)=3.448275e-02 | K(α)=15.000\n",
      "\n",
      "=== Layer 9 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_9\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_9/x4_SprCoef0.1_LR0.0002_Los0.4801_L4.5709_R0.9995_BalancedBest_epoch6060.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_9/x4_SprCoef0.1_LR0.0002_Los0.4801_L4.5709_R0.9995_BalancedBest_epoch6060.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 32726)\n",
      "  W_normed shape: (9216, 32726)\n",
      "  G shape: (32726, 32726)\n",
      "  off-diag count: 1,070,958,350\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999687 | μ(α)=7.692282e-02 | K(α)=14.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.998665 | μ(α)=4.347825e-02 | K(α)=12.000\n",
      "\n",
      "=== Layer 10 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_10\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_10/x4_SprCoef0.1_LR0.0002_Los0.6357_L6.0707_R0.9992_BalancedBest_epoch5500.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_10/x4_SprCoef0.1_LR0.0002_Los0.6357_L6.0707_R0.9992_BalancedBest_epoch5500.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 31476)\n",
      "  W_normed shape: (9216, 31476)\n",
      "  G shape: (31476, 31476)\n",
      "  off-diag count: 990,707,100\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999650 | μ(α)=6.666649e-02 | K(α)=16.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.998276 | μ(α)=3.703704e-02 | K(α)=14.000\n",
      "\n",
      "=== Layer 11 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_11\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_11/x4_SprCoef0.1_LR0.0002_Los0.6243_L5.9365_R0.9992_BalancedBest_epoch5520.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_11/x4_SprCoef0.1_LR0.0002_Los0.6243_L5.9365_R0.9992_BalancedBest_epoch5520.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 29187)\n",
      "  W_normed shape: (9216, 29187)\n",
      "  G shape: (29187, 29187)\n",
      "  off-diag count: 851,851,782\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999591 | μ(α)=6.249995e-02 | K(α)=17.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.997925 | μ(α)=3.448275e-02 | K(α)=15.000\n",
      "\n",
      "=== Layer 12 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_12\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_12/x4_SprCoef0.1_LR0.0002_Los0.8234_L7.9852_R0.9991_BalancedBest_epoch5340.pt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_12/x4_SprCoef0.1_LR0.0002_Los0.8234_L7.9852_R0.9991_BalancedBest_epoch5340.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 24989)\n",
      "  W_normed shape: (9216, 24989)\n",
      "  G shape: (24989, 24989)\n",
      "  off-diag count: 624,425,132\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999046 | μ(α)=4.247600e-02 | K(α)=24.543\n",
      "    Max α for  μ < 1/(2K-1):  α=0.991072 | μ(α)=2.439024e-02 | K(α)=21.000\n",
      "\n",
      "=== Layer 13 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_13\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_13/x4_SprCoef0.1_LR0.0002_Los0.7069_L6.4095_R0.9973_BalancedBest_epoch5700.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_13/x4_SprCoef0.1_LR0.0002_Los0.7069_L6.4095_R0.9973_BalancedBest_epoch5700.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 22445)\n",
      "  W_normed shape: (9216, 22445)\n",
      "  G shape: (22445, 22445)\n",
      "  off-diag count: 503,755,580\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999159 | μ(α)=4.796492e-02 | K(α)=21.849\n",
      "    Max α for  μ < 1/(2K-1):  α=0.995977 | μ(α)=2.702702e-02 | K(α)=19.000\n",
      "\n",
      "=== Layer 14 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_14\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_14/x4_SprCoef0.1_LR0.0002_Los0.4161_L3.9074_R0.9989_BalancedBest_epoch7540.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_14/x4_SprCoef0.1_LR0.0002_Los0.4161_L3.9074_R0.9989_BalancedBest_epoch7540.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 25996)\n",
      "  W_normed shape: (9216, 25996)\n",
      "  G shape: (25996, 25996)\n",
      "  off-diag count: 675,766,020\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999833 | μ(α)=9.999973e-02 | K(α)=11.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.999306 | μ(α)=5.263149e-02 | K(α)=10.000\n",
      "\n",
      "=== Layer 15 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_15\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_15/x4_SprCoef0.1_LR0.0002_Los0.4730_L4.4783_R0.9988_BalancedBest_epoch5520.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_15/x4_SprCoef0.1_LR0.0002_Los0.4730_L4.4783_R0.9988_BalancedBest_epoch5520.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 25881)\n",
      "  W_normed shape: (9216, 25881)\n",
      "  G shape: (25881, 25881)\n",
      "  off-diag count: 669,800,280\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999763 | μ(α)=8.333331e-02 | K(α)=13.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.998881 | μ(α)=4.276926e-02 | K(α)=12.191\n",
      "\n",
      "=== Layer 16 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_16\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_16/x4_SprCoef0.1_LR0.0002_Los0.6589_L6.2815_R0.9992_BalancedBest_epoch5600.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_16/x4_SprCoef0.1_LR0.0002_Los0.6589_L6.2815_R0.9992_BalancedBest_epoch5600.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 25924)\n",
      "  W_normed shape: (9216, 25924)\n",
      "  G shape: (25924, 25924)\n",
      "  off-diag count: 672,027,852\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999525 | μ(α)=5.882349e-02 | K(α)=18.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.997147 | μ(α)=3.225806e-02 | K(α)=16.000\n",
      "\n",
      "=== Layer 17 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_17\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_17/x4_SprCoef0.1_LR0.0002_Los0.4576_L4.2865_R0.9994_BalancedBest_epoch5940.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_17/x4_SprCoef0.1_LR0.0002_Los0.4576_L4.2865_R0.9994_BalancedBest_epoch5940.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 25968)\n",
      "  W_normed shape: (9216, 25968)\n",
      "  G shape: (25968, 25968)\n",
      "  off-diag count: 674,311,056\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999793 | μ(α)=9.090889e-02 | K(α)=12.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.999080 | μ(α)=4.761904e-02 | K(α)=11.000\n",
      "\n",
      "=== Layer 18 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_18\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_18/x4_SprCoef0.1_LR0.0002_Los0.9494_L9.0535_R0.9996_BalancedBest_epoch5700.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_18/x4_SprCoef0.1_LR0.0002_Los0.9494_L9.0535_R0.9996_BalancedBest_epoch5700.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 29698)\n",
      "  W_normed shape: (9216, 29698)\n",
      "  G shape: (29698, 29698)\n",
      "  off-diag count: 881,941,506\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999045 | μ(α)=4.545453e-02 | K(α)=23.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.994439 | μ(α)=2.702703e-02 | K(α)=19.000\n",
      "\n",
      "=== Layer 19 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_19\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_19/x4_SprCoef0.1_LR0.0002_Los0.6353_L6.0711_R0.9997_BalancedBest_epoch6160.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_19/x4_SprCoef0.1_LR0.0002_Los0.6353_L6.0711_R0.9997_BalancedBest_epoch6160.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 29727)\n",
      "  W_normed shape: (9216, 29727)\n",
      "  G shape: (29727, 29727)\n",
      "  off-diag count: 883,664,802\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999556 | μ(α)=6.666663e-02 | K(α)=16.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.998011 | μ(α)=3.703704e-02 | K(α)=14.000\n",
      "\n",
      "=== Layer 20 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_20\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_20/x4_SprCoef0.1_LR0.0002_Los0.5327_L4.6972_R0.9989_BalancedBest_epoch6020.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_20/x4_SprCoef0.1_LR0.0002_Los0.5327_L4.6972_R0.9989_BalancedBest_epoch6020.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 29359)\n",
      "  W_normed shape: (9216, 29359)\n",
      "  G shape: (29359, 29359)\n",
      "  off-diag count: 861,921,522\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999766 | μ(α)=8.900110e-02 | K(α)=12.236\n",
      "    Max α for  μ < 1/(2K-1):  α=0.998906 | μ(α)=4.761904e-02 | K(α)=11.000\n",
      "\n",
      "=== Layer 21 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_21\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_21/x4_SprCoef0.1_LR0.0002_Los0.8323_L8.1301_R0.9997_BalancedBest_epoch5340.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_21/x4_SprCoef0.1_LR0.0002_Los0.8323_L8.1301_R0.9997_BalancedBest_epoch5340.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 30082)\n",
      "  W_normed shape: (9216, 30082)\n",
      "  G shape: (30082, 30082)\n",
      "  off-diag count: 904,896,642\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999380 | μ(α)=5.555552e-02 | K(α)=19.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.996438 | μ(α)=3.030303e-02 | K(α)=17.000\n",
      "\n",
      "=== Layer 22 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_22\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_22/x4_SprCoef0.1_LR0.0002_Los0.9583_L9.3776_R0.9995_BalancedBest_epoch5120.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_22/x4_SprCoef0.1_LR0.0002_Los0.9583_L9.3776_R0.9995_BalancedBest_epoch5120.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 32424)\n",
      "  W_normed shape: (9216, 32424)\n",
      "  G shape: (32424, 32424)\n",
      "  off-diag count: 1,051,283,352\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.998830 | μ(α)=4.347825e-02 | K(α)=24.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.992186 | μ(α)=2.564102e-02 | K(α)=20.000\n",
      "\n",
      "=== Layer 23 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_23\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_23/x4_SprCoef0.1_LR0.0002_Los0.6076_L5.8653_R0.9994_BalancedBest_epoch5420.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_23/x4_SprCoef0.1_LR0.0002_Los0.6076_L5.8653_R0.9994_BalancedBest_epoch5420.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 30039)\n",
      "  W_normed shape: (9216, 30039)\n",
      "  G shape: (30039, 30039)\n",
      "  off-diag count: 902,311,482\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999567 | μ(α)=7.052356e-02 | K(α)=15.180\n",
      "    Max α for  μ < 1/(2K-1):  α=0.997894 | μ(α)=3.703704e-02 | K(α)=14.000\n",
      "\n",
      "=== Layer 24 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_24\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_24/x4_SprCoef0.1_LR0.0002_Los0.6277_L6.0088_R0.9994_BalancedBest_epoch5480.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_24/x4_SprCoef0.1_LR0.0002_Los0.6277_L6.0088_R0.9994_BalancedBest_epoch5480.pt\n",
      "  W shape: (9216, 36864)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  W shape: (9216, 29452)\n",
      "  W_normed shape: (9216, 29452)\n",
      "  G shape: (29452, 29452)\n",
      "  off-diag count: 867,390,852\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999626 | μ(α)=7.692305e-02 | K(α)=14.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.998138 | μ(α)=4.000000e-02 | K(α)=13.000\n",
      "\n",
      "=== Layer 25 ===\n",
      "Total 1 model found in dir ../saved_models/google/gemma-2-2b/layer_25\n",
      "📂 Processing ../saved_models/google/gemma-2-2b/layer_25/x4_SprCoef0.1_LR0.0002_Los0.7987_L7.6514_R0.9993_BalancedBest_epoch5420.pt\n",
      "SAE model loaded successfully from: ../saved_models/google/gemma-2-2b/layer_25/x4_SprCoef0.1_LR0.0002_Los0.7987_L7.6514_R0.9993_BalancedBest_epoch5420.pt\n",
      "  W shape: (9216, 36864)\n",
      "  W shape: (9216, 27478)\n",
      "  W_normed shape: (9216, 27478)\n",
      "  G shape: (27478, 27478)\n",
      "  off-diag count: 755,013,006\n",
      "  α search results:\n",
      "    Max α for  μ < 1/(K-1):   α=0.999171 | μ(α)=5.263156e-02 | K(α)=20.000\n",
      "    Max α for  μ < 1/(2K-1):  α=0.994274 | μ(α)=2.857143e-02 | K(α)=18.000\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "from pathlib import Path\n",
    "import gc\n",
    "import re\n",
    "\n",
    "# ---------- Configuration ----------\n",
    "# Cap on the number of off-diagonal coherence values sampled per layer\n",
    "MAX_Q_SAMPLES = 5_000_000\n",
    "model_name = \"google/gemma-2-2b\"\n",
    "d_model = 9216        # SAE input dimension\n",
    "d_sae = 4 * d_model   # SAE hidden dimension (4x d_model)\n",
    "# Root directory that is scanned for layer_* sub-directories holding the checkpoints\n",
    "base_dir = Path(f\"../saved_models/{model_name}\")\n",
    "\n",
    "class JumpReLUSAE(nn.Module):\n",
    "    \"\"\"\n",
    "    Sparse autoencoder whose hidden units pass through a JumpReLU gate.\n",
    "\n",
    "    Args:\n",
    "        d_model: width of the input and of the reconstruction\n",
    "        d_sae: width of the hidden code (e.g. 4x d_model)\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model, d_sae):\n",
    "        super().__init__()\n",
    "        # When True, encode() subtracts the decoder bias before projecting\n",
    "        self.use_pre_enc_bias = True\n",
    "        # Every parameter is zero-initialised\n",
    "        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))\n",
    "        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))\n",
    "        # Kept in log space so that exp() always yields a positive threshold\n",
    "        self.log_threshold = nn.Parameter(torch.zeros(d_sae))\n",
    "        self.b_enc = nn.Parameter(torch.zeros(d_sae))\n",
    "        self.b_dec = nn.Parameter(torch.zeros(d_model))\n",
    "\n",
    "    @property\n",
    "    def threshold(self):\n",
    "        \"\"\"Per-unit activation threshold, exp(log_threshold).\"\"\"\n",
    "        return self.log_threshold.exp()\n",
    "\n",
    "    def encode(self, input_acts):\n",
    "        \"\"\"Map inputs to hidden activations; units not above their threshold are multiplied by 0.\"\"\"\n",
    "        centered = input_acts - self.b_dec if self.use_pre_enc_bias else input_acts\n",
    "        pre_acts = centered @ self.W_enc + self.b_enc\n",
    "        gate = pre_acts > self.threshold\n",
    "        return torch.relu(pre_acts) * gate\n",
    "\n",
    "    def decode(self, acts):\n",
    "        \"\"\"Project hidden activations back to the input space and add the decoder bias.\"\"\"\n",
    "        return acts @ self.W_dec + self.b_dec\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Reconstruct x by encoding it and decoding the result.\"\"\"\n",
    "        return self.decode(self.encode(x))\n",
    "    \n",
    "\n",
    "def largest_alpha(mu_source, K_source, cond_fn, lo=0.0, hi=0.999999, iters=30):\n",
    "    \"\"\"\n",
    "    Bisect [lo, hi] for the largest quantile level alpha at which\n",
    "    cond_fn(mu(alpha), K(alpha)) holds, where mu(alpha) and K(alpha) are the\n",
    "    alpha-quantiles (linear interpolation) of mu_source and K_source.\n",
    "\n",
    "    The bisection assumes cond_fn is monotone in alpha: once it fails at some\n",
    "    level it keeps failing at every higher level.\n",
    "\n",
    "    Args:\n",
    "        mu_source: array-like or tensor; converted to a flat float32 numpy array\n",
    "        K_source:  array-like or tensor; converted to a flat float32 numpy array\n",
    "        cond_fn:   callable(mu_alpha, K_alpha) -> bool\n",
    "        lo, hi:    bounds of the search interval\n",
    "        iters:     number of bisection steps\n",
    "\n",
    "    Returns:\n",
    "        The largest level found (float). None when either input is empty, or\n",
    "        when no bisection midpoint satisfied cond_fn and lo does not satisfy\n",
    "        it either.\n",
    "    \"\"\"\n",
    "    def _flatten(values):\n",
    "        # Tensors are detached and moved to CPU first; everything ends up 1-D float32\n",
    "        if torch.is_tensor(values):\n",
    "            return values.detach().float().cpu().numpy().ravel()\n",
    "        return np.asarray(values, dtype=np.float32).ravel()\n",
    "\n",
    "    mu_source = _flatten(mu_source)\n",
    "    K_source = _flatten(K_source)\n",
    "\n",
    "    # Prevent empty or invalid\n",
    "    if mu_source.size == 0 or K_source.size == 0:\n",
    "        return None\n",
    "\n",
    "    def _holds(alpha):\n",
    "        mu_alpha = np.quantile(mu_source, alpha, method=\"linear\")\n",
    "        K_alpha = np.quantile(K_source, alpha, method=\"linear\")\n",
    "        return bool(cond_fn(mu_alpha, K_alpha))\n",
    "\n",
    "    lo_a, hi_a = lo, hi\n",
    "    found = False\n",
    "    for _ in range(iters):\n",
    "        mid = (lo_a + hi_a) / 2.0\n",
    "        if _holds(mid):\n",
    "            lo_a = mid  # Can be bigger\n",
    "            found = True\n",
    "        else:\n",
    "            hi_a = mid  # Needs to be smaller\n",
    "\n",
    "    # If lo_a never advanced it is still the untested lower bound, so verify it\n",
    "    # before reporting it as a solution; otherwise callers could never see None.\n",
    "    if not found and not _holds(lo):\n",
    "        return None\n",
    "    return lo_a\n",
    "\n",
    "\n",
    "# ---------- Traverse and process each layer ----------\n",
    "# Per layer: load the SAE checkpoint, keep the decoder columns listed in atoms_id.pt,\n",
    "# measure their pairwise coherence, then search for the largest quantile level alpha\n",
    "# at which each coherence/sparsity condition still holds.\n",
    "SAMPLE_SEED = 0  # fixed so the sub-sampled coherence values are reproducible\n",
    "rng = np.random.default_rng(SAMPLE_SEED)\n",
    "atoms_root = Path(f\"../data/Atomicity/Atoms/{model_name}\")\n",
    "\n",
    "\n",
    "# RIP condition\n",
    "def cond1(mu_alpha, K_alpha):\n",
    "    \"\"\"True when mu_alpha < 1/(K_alpha - 1); always False for K_alpha <= 1.\"\"\"\n",
    "    if K_alpha <= 1:\n",
    "        return False\n",
    "    return mu_alpha < 1.0 / (K_alpha - 1.0)\n",
    "\n",
    "\n",
    "# Uniqueness condition\n",
    "def cond2(mu_alpha, K_alpha):\n",
    "    \"\"\"True when mu_alpha < 1/(2*K_alpha - 1); always False when 2*K_alpha - 1 <= 0.\"\"\"\n",
    "    if (2.0 * K_alpha - 1.0) <= 0:\n",
    "        return False\n",
    "    return mu_alpha < 1.0 / (2.0 * K_alpha - 1.0)\n",
    "\n",
    "\n",
    "def summarize(alpha, mu_values, K_values):\n",
    "    \"\"\"Return (mu(alpha), K(alpha)) as floats for printing, or (None, None) when alpha is None.\"\"\"\n",
    "    if alpha is None:\n",
    "        return None, None\n",
    "    mu_a = float(np.quantile(mu_values, alpha, method=\"linear\"))\n",
    "    K_a = float(np.quantile(K_values, alpha, method=\"linear\"))\n",
    "    return mu_a, K_a\n",
    "\n",
    "\n",
    "layer_dirs = sorted(base_dir.glob(\"layer_*\"), key=lambda p: int(re.search(r'\\d+', p.name).group()))\n",
    "for layer_dir in layer_dirs:\n",
    "    layer = re.search(r'layer_(\\d+)', str(layer_dir)).group(1)\n",
    "    print(f\"\\n=== Layer {layer} ===\")\n",
    "\n",
    "    pt_files = sorted(layer_dir.glob(\"*.pt\"))\n",
    "    print(f\"Total {len(pt_files)} model found in dir {layer_dir}\")\n",
    "    if not pt_files:\n",
    "        continue\n",
    "    pt_path = pt_files[0]  # only the first checkpoint (sorted by name) is analysed\n",
    "    print(f\"📂 Processing {pt_path}\")\n",
    "\n",
    "    # Load model\n",
    "    sae = JumpReLUSAE(d_model, d_sae)\n",
    "    # NOTE(review): torch.load unpickles the file, so only point this at trusted checkpoints\n",
    "    state_dict = torch.load(pt_path, map_location=\"cpu\")\n",
    "    sae.load_state_dict(state_dict)\n",
    "    sae.eval()\n",
    "    print(f\"SAE model loaded successfully from: {pt_path}\")\n",
    "\n",
    "    # Columns of W are the decoder directions, one per SAE hidden unit\n",
    "    W = sae.W_dec.detach().T\n",
    "    print(f\"  W shape: {tuple(W.shape)}\")\n",
    "    # Separate name on purpose: layer_dir is the checkpoint directory being iterated\n",
    "    atoms_dir = atoms_root / f\"layer_{layer}\"\n",
    "    atoms_id_path = atoms_dir / \"atoms_id.pt\"\n",
    "    if not atoms_id_path.exists():\n",
    "        print(f\"⚠️ atoms_id not found: {atoms_id_path} (skip layer)\")\n",
    "        continue\n",
    "    atoms_id = torch.load(atoms_id_path, map_location=\"cpu\")\n",
    "    atoms_id = torch.unique(atoms_id)\n",
    "\n",
    "    # Drop ids that do not index a column of W\n",
    "    valid_mask = (atoms_id >= 0) & (atoms_id < W.shape[1])\n",
    "    if not bool(valid_mask.all()):\n",
    "        invalid = atoms_id[~valid_mask]\n",
    "        print(f\"⚠️ filtered {invalid.numel()} invalid atom ids out of range [0,{W.shape[1]-1}]\")\n",
    "        atoms_id = atoms_id[valid_mask]\n",
    "\n",
    "    W = W[:, atoms_id]\n",
    "    print(f\"  W shape: {tuple(W.shape)}\")\n",
    "    W_normed = W / W.norm(dim=0, keepdim=True)  # scale every column to unit norm\n",
    "    print(f\"  W_normed shape: {tuple(W_normed.shape)}\")\n",
    "\n",
    "    M = W_normed @ W_normed.T\n",
    "    try:\n",
    "        M_inv = torch.inverse(M)\n",
    "    except RuntimeError as e:\n",
    "        print(f\"❌ Matrix inversion failed: {e}\")\n",
    "        continue\n",
    "\n",
    "    G = W_normed.T @ M_inv @ W_normed\n",
    "    print(f\"  G shape: {tuple(G.shape)}\")\n",
    "\n",
    "    # Normalization: rescale G so its diagonal becomes 1, then keep |off-diagonal| entries\n",
    "    diag = torch.diagonal(G)\n",
    "    sqrt_diag = torch.sqrt(diag)\n",
    "    denom = torch.outer(sqrt_diag, sqrt_diag)\n",
    "    G_normalized = G / denom\n",
    "    off_mask = ~torch.eye(G_normalized.shape[0], dtype=torch.bool, device=G_normalized.device)\n",
    "    G_norm_wo_diag = G_normalized[off_mask].abs()\n",
    "    print(f\"  off-diag count: {G_norm_wo_diag.numel():,}\")\n",
    "\n",
    "    counts_path = atoms_dir / \"sparsity.json\"\n",
    "    if not counts_path.exists():\n",
    "        print(f\"⚠️ activation counts not found: {counts_path} (skip alpha search)\")\n",
    "        continue\n",
    "\n",
    "    with open(counts_path, \"r\") as f:\n",
    "        counts_list = json.load(f)\n",
    "\n",
    "    mu_np_full = G_norm_wo_diag.detach().float().cpu().numpy().ravel()\n",
    "    # Sub-sample only when there are more values than the cap: requesting more samples\n",
    "    # than exist with replace=False raises ValueError.\n",
    "    if mu_np_full.size > MAX_Q_SAMPLES:\n",
    "        idx = rng.choice(mu_np_full.size, size=MAX_Q_SAMPLES, replace=False)\n",
    "        mu_np = mu_np_full[idx]\n",
    "    else:\n",
    "        mu_np = mu_np_full\n",
    "    K_np = np.asarray(counts_list, dtype=np.float32).ravel()\n",
    "\n",
    "    # The n x n intermediates are not needed once the sample is taken; drop them so\n",
    "    # they are not still alive while the next layer builds its own.\n",
    "    del G, G_normalized, denom, off_mask, G_norm_wo_diag, mu_np_full\n",
    "    gc.collect()\n",
    "\n",
    "    # Search for the largest alpha\n",
    "    alpha1 = largest_alpha(mu_np, K_np, cond1)\n",
    "    alpha2 = largest_alpha(mu_np, K_np, cond2)\n",
    "\n",
    "    # Calculate μ and K corresponding to this alpha (for printing and verification)\n",
    "    mu1, K1 = summarize(alpha1, mu_np, K_np)\n",
    "    mu2, K2 = summarize(alpha2, mu_np, K_np)\n",
    "\n",
    "    print(\"  α search results:\")\n",
    "    if alpha1 is not None:\n",
    "        print(f\"    Max α for  μ < 1/(K-1):   α={alpha1:.6f} | μ(α)={mu1:.6e} | K(α)={K1:.3f}\")\n",
    "    else:\n",
    "        print(\"    Max α for  μ < 1/(K-1):   not found\")\n",
    "\n",
    "    if alpha2 is not None:\n",
    "        print(f\"    Max α for  μ < 1/(2K-1):  α={alpha2:.6f} | μ(α)={mu2:.6e} | K(α)={K2:.3f}\")\n",
    "    else:\n",
    "        print(\"    Max α for  μ < 1/(2K-1):  not found\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c175466",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
