{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from dictionary_learning.dictionary import (\n",
    "    AutoEncoder,\n",
    "    GatedAutoEncoder,\n",
    "    AutoEncoderNew,\n",
    "    JumpReluAutoEncoder,\n",
    ")\n",
    "from dictionary_learning.trainers.top_k import AutoEncoderTopK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dslabra5/.conda/envs/sae4dlm/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Fetching 4 files: 100%|██████████| 4/4 [18:43<00:00, 280.91s/it] \n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.71s/it]\n",
      "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What is your name?assed\n",
      "What is your name?assed\n",
      "What is your name?assed\n",
      "What is your name?assed\n",
      "What is your name?assed\n",
      "What is your name?\n"
     ]
    }
   ],
   "source": [
    "# Load model directly\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B\")\n",
    "model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen2.5-7B\")\n",
    "messages = [\n",
    "    {\"role\": \"user\", \"content\": \"Who are you?\"},\n",
    "]\n",
    "inputs = tokenizer.apply_chat_template(\n",
    "\tmessages,\n",
    "\tadd_generation_prompt=True,\n",
    "\ttokenize=True,\n",
    "\treturn_dict=True,\n",
    "\treturn_tensors=\"pt\",\n",
    ").to(model.device)\n",
    "\n",
    "outputs = model.generate(**inputs, max_new_tokens=40)\n",
    "print(tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[-1]:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[load] /home/dslabra5/sae4dlm/dictionary_learning_demo/saes__Qwen_Qwen2.5-7B_top_k/resid_post_layer_1/trainer_0/ae.pt\n",
      "[top] ckpt: dict(len=6) keys(sample)=['b_dec', 'k', 'threshold', 'decoder.weight', 'encoder.weight', 'encoder.bias']\n",
      "\n",
      "[top keys]\n",
      " - b_dec                    : Tensor: tensor([-0.0149,  0.0329, -0.0190,  ...,  0.0298, -0.0419,  0.0128])\n",
      " - k                        : Tensor: tensor(50, dtype=torch.int32)\n",
      " - threshold                : Tensor: tensor(0.1899)\n",
      " - decoder.weight           : Tensor: tensor([[-2.9374e-03,  2.0523e-03,  1.4573e-02,  ...,  8.6945e-03,\n",
      "         -2.5500e-03,  1.2732e-02],\n",
      "        [ 1.2998e-02, -1.2335e-03, -8.6354e-04,  ...,  6.0299e-03,\n",
      "          1.0892e-02, -3.1802e\n",
      " - encoder.weight           : Tensor: tensor([[-0.0316,  0.0187,  0.0128,  ...,  0.0049,  0.0048,  0.0091],\n",
      "        [-0.0194, -0.0083,  0.0408,  ...,  0.0377, -0.0509,  0.0156],\n",
      "        [ 0.0073, -0.0077,  0.0291,  ...,  0.0166,  0.0363, \n",
      " - encoder.bias             : Tensor: tensor([-1.0703, -0.6782, -1.0128,  ..., -1.4562, -1.3037, -1.2682])\n",
      "\n",
      "[state_dict] located at: ckpt(itself)\n",
      "[state_dict] num keys: 6\n",
      "\n",
      "[state_dict] top-level prefix frequency (top 20):\n",
      " - encoder              2\n",
      " - b_dec                1\n",
      " - k                    1\n",
      " - threshold            1\n",
      " - decoder              1\n",
      "\n",
      "[state_dict] key samples (first 60):\n",
      " - b_dec                                                         shape=(3584,)\n",
      " - k                                                             shape=()\n",
      " - threshold                                                     shape=()\n",
      " - decoder.weight                                                shape=(3584, 16384)\n",
      " - encoder.weight                                                shape=(16384, 3584)\n",
      " - encoder.bias                                                  shape=(16384,)\n",
      "\n",
      "[check] contains 'encoder.weight'? -> True\n",
      "\n",
      "[candidates] likely encoder weight keys:\n",
      " - score=11  encoder.weight                                                shape=(16384, 3584)  dtype=torch.float32\n",
      " - score= 4  decoder.weight                                                shape=(3584, 16384)  dtype=torch.float32\n",
      "\n",
      "[encoder.weight] shape=(16384, 3584) dtype=torch.float32\n",
      "[inferred] dict_size=16384, activation_dim=3584\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import re\n",
    "import torch\n",
    "from collections import Counter\n",
    "\n",
    "def human(x, max_items=20):\n",
    "    if isinstance(x, dict):\n",
    "        return f\"dict(len={len(x)}) keys(sample)={list(x.keys())[:max_items]}\"\n",
    "    if isinstance(x, (list, tuple)):\n",
    "        return f\"{type(x).__name__}(len={len(x)}) sample={x[:3]}\"\n",
    "    return f\"{type(x).__name__}: {str(x)[:200]}\"\n",
    "\n",
    "def strip_common_prefixes(sd):\n",
    "    # 常见前缀：module. / model. / net. / sae. 等（只做轻量处理）\n",
    "    prefixes = [\"module.\", \"model.\", \"net.\", \"sae.\", \"ae.\"]\n",
    "    # 找到出现频率最高的前缀（如果存在）\n",
    "    hits = {p: sum(k.startswith(p) for k in sd.keys()) for p in prefixes}\n",
    "    best = max(hits, key=hits.get)\n",
    "    if hits[best] > 0:\n",
    "        print(f\"[info] Detected prefix '{best}' in {hits[best]} keys. Stripping it.\")\n",
    "        sd = { (k[len(best):] if k.startswith(best) else k): v for k, v in sd.items() }\n",
    "    return sd\n",
    "\n",
    "def find_state_dict(obj):\n",
    "    \"\"\"\n",
    "    返回：(state_dict, path_description)\n",
    "    兼容：\n",
    "      - 直接就是 state_dict\n",
    "      - {'state_dict': ...}\n",
    "      - {'model_state_dict': ...}\n",
    "      - Lightning: {'state_dict': ...}\n",
    "      - 其他：尝试在顶层 dict 里找 “像 state_dict 的 dict”\n",
    "    \"\"\"\n",
    "    if isinstance(obj, dict):\n",
    "        for k in [\"state_dict\", \"model_state_dict\", \"model\", \"weights\", \"params\"]:\n",
    "            if k in obj and isinstance(obj[k], dict):\n",
    "                # 判断是否“像权重字典”：key 大多是字符串、value 大多是 tensor\n",
    "                v = obj[k]\n",
    "                tensor_like = sum(isinstance(x, torch.Tensor) for x in v.values())\n",
    "                if tensor_like >= max(1, len(v) // 4):\n",
    "                    return v, f\"ckpt['{k}']\"\n",
    "        # 否则：找最像的那个子 dict\n",
    "        best_key, best_score = None, -1\n",
    "        for k, v in obj.items():\n",
    "            if isinstance(v, dict):\n",
    "                tensor_like = sum(isinstance(x, torch.Tensor) for x in v.values())\n",
    "                score = tensor_like\n",
    "                if score > best_score:\n",
    "                    best_key, best_score = k, score\n",
    "        if best_key is not None and best_score > 0:\n",
    "            return obj[best_key], f\"ckpt['{best_key}'](heuristic)\"\n",
    "    # 如果本身就像 state_dict\n",
    "    if isinstance(obj, dict):\n",
    "        tensor_like = sum(isinstance(x, torch.Tensor) for x in obj.values())\n",
    "        if tensor_like >= max(1, len(obj) // 4):\n",
    "            return obj, \"ckpt(itself)\"\n",
    "    return None, None\n",
    "\n",
    "def rank_encoder_candidates(sd, topk=30):\n",
    "    \"\"\"\n",
    "    从 state_dict 里找“可能是 encoder 权重”的键：\n",
    "      - 包含 'encoder' 且以 '.weight' 结尾\n",
    "      - 或者包含 'enc' 且 '.weight'\n",
    "      - 或者包含 'W_enc'\n",
    "    并按“像线性层权重(2D tensor)”优先排序\n",
    "    \"\"\"\n",
    "    candidates = []\n",
    "    for k, v in sd.items():\n",
    "        if not isinstance(v, torch.Tensor):\n",
    "            continue\n",
    "        lk = k.lower()\n",
    "        score = 0\n",
    "        if \"encoder\" in lk and lk.endswith(\"weight\"):\n",
    "            score += 5\n",
    "        if re.search(r\"\\benc\\b\", lk) and \"weight\" in lk:\n",
    "            score += 3\n",
    "        if \"w_enc\" in lk or \"wenc\" in lk:\n",
    "            score += 4\n",
    "        if \"encoder\" in lk and \".weight\" in lk:\n",
    "            score += 2\n",
    "        if v.ndim == 2:\n",
    "            score += 3\n",
    "        if v.ndim == 2 and (v.shape[0] > 1 and v.shape[1] > 1):\n",
    "            score += 1\n",
    "        if score > 0:\n",
    "            candidates.append((score, k, tuple(v.shape), v.dtype))\n",
    "    candidates.sort(reverse=True, key=lambda x: x[0])\n",
    "    return candidates[:topk]\n",
    "\n",
    "def main(ae_path, device=\"cpu\"):\n",
    "    assert os.path.exists(ae_path), f\"File not found: {ae_path}\"\n",
    "    print(f\"[load] {ae_path}\")\n",
    "    ckpt = torch.load(ae_path, map_location=device)\n",
    "    print(\"[top] ckpt:\", human(ckpt))\n",
    "\n",
    "    if isinstance(ckpt, dict):\n",
    "        print(\"\\n[top keys]\")\n",
    "        for k in list(ckpt.keys())[:50]:\n",
    "            v = ckpt[k]\n",
    "            print(f\" - {k:24s} : {human(v, max_items=10)}\")\n",
    "\n",
    "    sd, sd_path = find_state_dict(ckpt)\n",
    "    if sd is None:\n",
    "        print(\"\\n[error] Could not locate a tensor state_dict inside this checkpoint.\")\n",
    "        return\n",
    "\n",
    "    print(f\"\\n[state_dict] located at: {sd_path}\")\n",
    "    print(f\"[state_dict] num keys: {len(sd)}\")\n",
    "\n",
    "    # 统计前缀情况\n",
    "    prefix_counts = Counter(k.split(\".\")[0] for k in sd.keys() if isinstance(k, str))\n",
    "    print(\"\\n[state_dict] top-level prefix frequency (top 20):\")\n",
    "    for p, c in prefix_counts.most_common(20):\n",
    "        print(f\" - {p:20s} {c}\")\n",
    "\n",
    "    # 打印 key 样例\n",
    "    print(\"\\n[state_dict] key samples (first 60):\")\n",
    "    for k in list(sd.keys())[:60]:\n",
    "        v = sd[k]\n",
    "        shape = tuple(v.shape) if isinstance(v, torch.Tensor) else None\n",
    "        print(f\" - {k:60s}  shape={shape}\")\n",
    "\n",
    "    # 处理常见前缀\n",
    "    sd2 = strip_common_prefixes(sd)\n",
    "\n",
    "    # 直接检查目标键\n",
    "    print(\"\\n[check] contains 'encoder.weight'? ->\", \"encoder.weight\" in sd2)\n",
    "\n",
    "    # 候选 encoder 权重\n",
    "    print(\"\\n[candidates] likely encoder weight keys:\")\n",
    "    for score, k, shape, dtype in rank_encoder_candidates(sd2, topk=30):\n",
    "        print(f\" - score={score:2d}  {k:60s}  shape={shape}  dtype={dtype}\")\n",
    "\n",
    "    # 如果找到了 encoder.weight，顺便给出它的 shape\n",
    "    if \"encoder.weight\" in sd2:\n",
    "        w = sd2[\"encoder.weight\"]\n",
    "        print(f\"\\n[encoder.weight] shape={tuple(w.shape)} dtype={w.dtype}\")\n",
    "        if w.ndim == 2:\n",
    "            dict_size, activation_dim = w.shape\n",
    "            print(f\"[inferred] dict_size={dict_size}, activation_dim={activation_dim}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # 把这里换成你的路径\n",
    "    ae_path = \"/home/dslabra5/sae4dlm/dictionary_learning_demo/saes__Qwen_Qwen2.5-7B_top_k/resid_post_layer_1/trainer_0/ae.pt\"\n",
    "    main(ae_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device = \"cpu\"\n",
    "\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "d_model = 100\n",
    "\n",
    "torch.manual_seed(1)\n",
    "\n",
    "scale = 4\n",
    "\n",
    "x = torch.randn(1000, d_model, device=device)\n",
    "\n",
    "x_scaled = x / scale\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jumprelu_ae = JumpReluAutoEncoder(activation_dim=d_model, dict_size=d_model * 8, device=device)\n",
    "\n",
    "jumprelu_ae.b_enc.data = torch.randn_like(jumprelu_ae.b_enc.data)\n",
    "jumprelu_ae.b_dec.data = torch.randn_like(jumprelu_ae.b_dec.data)\n",
    "jumprelu_ae.threshold.data = abs(torch.randn_like(jumprelu_ae.threshold.data))\n",
    "\n",
    "reconstruction_1 = jumprelu_ae(x_scaled)\n",
    "\n",
    "def scale_jumprelu(ae: JumpReluAutoEncoder, scale: float):\n",
    "    ae.b_dec.data *= scale\n",
    "    ae.b_enc.data *= scale\n",
    "    ae.threshold.data *= scale\n",
    "\n",
    "print(jumprelu_ae.threshold.mean())\n",
    "scale_jumprelu(jumprelu_ae, (scale))\n",
    "print(jumprelu_ae.threshold.mean())\n",
    "\n",
    "reconstruction_2 = jumprelu_ae(x)\n",
    "\n",
    "reconstruction_1 = reconstruction_1 * scale\n",
    "\n",
    "diff = torch.abs(reconstruction_1 - reconstruction_2)\n",
    "print(f\"max diff: {diff.max()}, mean diff: {diff.mean()}\")\n",
    "\n",
    "assert torch.allclose(reconstruction_1, reconstruction_2, atol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gated_ae = GatedAutoEncoder(activation_dim=d_model, dict_size=d_model * 8, device=device)\n",
    "\n",
    "gated_ae.r_mag.data = torch.randn_like(gated_ae.r_mag.data)\n",
    "gated_ae.decoder_bias.data = torch.randn_like(gated_ae.decoder_bias.data)\n",
    "gated_ae.mag_bias.data = torch.randn_like(gated_ae.mag_bias.data)\n",
    "gated_ae.gate_bias.data = torch.randn_like(gated_ae.gate_bias.data)\n",
    "\n",
    "reconstruction_1 = gated_ae(x_scaled)\n",
    "\n",
    "def scale_gated(ae: GatedAutoEncoder, scale: float):\n",
    "    ae.decoder_bias.data *= scale\n",
    "    ae.mag_bias.data *= scale\n",
    "    ae.gate_bias.data *= scale\n",
    "\n",
    "print(gated_ae.r_mag.mean(), gated_ae.decoder_bias.mean(), gated_ae.mag_bias.mean(), gated_ae.gate_bias.mean())\n",
    "scale_gated(gated_ae, (scale))\n",
    "scale_gated(gated_ae, (1 / scale))\n",
    "scale_gated(gated_ae, (scale))\n",
    "\n",
    "\n",
    "print(gated_ae.r_mag.mean(), gated_ae.decoder_bias.mean(), gated_ae.mag_bias.mean(), gated_ae.gate_bias.mean())\n",
    "\n",
    "reconstruction_2 = gated_ae(x)\n",
    "\n",
    "reconstruction_1 = reconstruction_1 * scale\n",
    "\n",
    "diff = torch.abs(reconstruction_1 - reconstruction_2)\n",
    "\n",
    "print(f\"max diff: {diff.max()}, mean diff: {diff.mean()}\")\n",
    "assert torch.allclose(reconstruction_1, reconstruction_2, atol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relu_ae = AutoEncoder(activation_dim=d_model, dict_size=d_model * 8)\n",
    "relu_ae = relu_ae.to(device)\n",
    "\n",
    "# relu_ae.encoder.bias.data = torch.randn_like(relu_ae.decoder.bias.data)\n",
    "relu_ae.bias.data = torch.randn_like(relu_ae.bias.data)\n",
    "\n",
    "reconstruction_1 = relu_ae(x_scaled)\n",
    "\n",
    "def scale_relu(ae: AutoEncoder, scale: float):\n",
    "    ae.encoder.bias.data *= scale\n",
    "    ae.bias.data *= scale\n",
    "\n",
    "\n",
    "print(relu_ae.bias.mean())\n",
    "scale_relu(relu_ae, (scale))\n",
    "print(relu_ae.bias.mean())\n",
    "\n",
    "reconstruction_2 = relu_ae(x)\n",
    "\n",
    "reconstruction_1 = reconstruction_1 * scale\n",
    "\n",
    "diff = torch.abs(reconstruction_1 - reconstruction_2)\n",
    "\n",
    "print(f\"max diff: {diff.max()}, mean diff: {diff.mean()}\")\n",
    "assert torch.allclose(reconstruction_1, reconstruction_2, atol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "topk_ae = AutoEncoderTopK(activation_dim=d_model, dict_size=d_model * 8, k=20)\n",
    "\n",
    "topk_ae.encoder.bias.data = torch.randn_like(topk_ae.encoder.bias.data)\n",
    "topk_ae.b_dec.data = torch.randn_like(topk_ae.b_dec.data)\n",
    "print(topk_ae.threshold)\n",
    "topk_ae.threshold = abs(torch.randn_like(topk_ae.threshold))\n",
    "print(topk_ae.threshold)\n",
    "\n",
    "reconstruction_1 = topk_ae(x_scaled)\n",
    "\n",
    "def scale_topk(ae: AutoEncoderTopK, scale: float):\n",
    "    ae.encoder.bias.data *= scale\n",
    "    ae.b_dec.data *= scale\n",
    "    if ae.threshold >= 0:\n",
    "        ae.threshold *= scale\n",
    "\n",
    "print(topk_ae.encoder.bias.mean(), topk_ae.b_dec.mean())\n",
    "scale_topk(topk_ae, (scale))\n",
    "print(topk_ae.encoder.bias.mean(), topk_ae.b_dec.mean())\n",
    "\n",
    "reconstruction_2 = topk_ae(x)\n",
    "\n",
    "reconstruction_1 = reconstruction_1 * scale\n",
    "\n",
    "diff = torch.abs(reconstruction_1 - reconstruction_2)\n",
    "\n",
    "print(f\"max diff: {diff.max()}, mean diff: {diff.mean()}\")\n",
    "assert torch.allclose(reconstruction_1, reconstruction_2, atol=1e-5)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae4dlm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
