{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "65fa2b56",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Pinned dependencies for this smoke test.\n",
        "# `%pip` is the explicit IPython magic (bare `pip` only works while automagic is on) and\n",
        "# installs into the environment of the running kernel; `-q` keeps the resolver log out of\n",
        "# the saved notebook. Restart the kernel afterwards if any version actually changed.\n",
        "%pip install -q huggingface_hub==0.36.1 transformers==4.57.6 transformer_lens==2.17.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "f958edf3",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Stop immediately if the imported huggingface_hub is not the pinned release; a kernel\n",
        "# that was started before the install cell ran may still be holding a different version.\n",
        "import huggingface_hub\n",
        "\n",
        "assert (\n",
        "    huggingface_hub.__version__ == \"0.36.1\"\n",
        "), f\"RESTART KERNEL! Got {huggingface_hub.__version__}, need 0.36.1\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "3759ffbc",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Cell 1: Imports\n",
        "import torch\n",
        "from huggingface_hub import list_repo_refs\n",
        "from transformer_lens import HookedTransformer\n",
        "import time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "e95a5791",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[0, 15000, 30000, 60000, 90000, 120000, 140000, 143000]\n"
          ]
        }
      ],
      "source": [
        "from huggingface_hub import list_repo_refs\n",
        "\n",
        "# Check that the training steps we want exist as revisions of Pythia-160M.\n",
        "steps = [0, 15000, 30000, 60000, 90000, 120000, 140000, 143000]\n",
        "available_refs = list_repo_refs(\"EleutherAI/pythia-160m-deduped\")\n",
        "\n",
        "# Look at branches and tags together.\n",
        "all_refs = available_refs.branches + available_refs.tags\n",
        "\n",
        "# Compare exact ref names. A substring test against str(ref) cannot work here: \"0\" occurs\n",
        "# in almost every ref (and in commit hashes), and \"15000\" would also match \"step115000\",\n",
        "# so every requested step would always be reported as available.\n",
        "step_refs = {r.name for r in all_refs if r.name.startswith(\"step\")}\n",
        "print([s for s in steps if f\"step{s}\" in step_refs])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "check-gpu",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Cell 2: Check GPU\n",
        "# Report whether CUDA is usable and, when it is, the name and total memory of device 0.\n",
        "has_cuda = torch.cuda.is_available()\n",
        "print(f\"CUDA available: {has_cuda}\")\n",
        "if has_cuda:\n",
        "    gpu_props = torch.cuda.get_device_properties(0)\n",
        "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
        "    print(f\"GPU memory: {gpu_props.total_memory / 1e9:.1f} GB\")\n",
        "else:\n",
        "    print(\"GPU: None\")\n",
        "    print(\"N/A\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "verify-checkpoints",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Cell 3: Verify checkpoint availability\n",
        "MODELS = [\n",
        "    \"EleutherAI/pythia-160m-deduped\",\n",
        "    \"EleutherAI/pythia-1b-deduped\",\n",
        "    \"EleutherAI/pythia-2.8b-deduped\",\n",
        "]\n",
        "# Each entry is multiplied by 1000 below to form a revision name (15 -> \"step15000\").\n",
        "# NOTE(review): the Pythia model cards list step143000 as the final checkpoint, so\n",
        "# 150/152/153 are expected to appear under \"Missing required\". If these were meant as\n",
        "# checkpoint *indices* rather than thousands of steps, the mapping below needs changing.\n",
        "CHECKPOINT_INDICES = [0, 15, 30, 60, 90, 120, 140, 150, 152, 153]\n",
        "\n",
        "# The required revision names are identical for every model, so build them once.\n",
        "required = [f\"step{i * 1000}\" for i in CHECKPOINT_INDICES]\n",
        "\n",
        "# `model_name` (not `model`) so the loop does not clobber the name used for loaded networks.\n",
        "for model_name in MODELS:\n",
        "    print(f\"\\n=== {model_name} ===\")\n",
        "    refs = list_repo_refs(model_name)\n",
        "    available = [r.name for r in refs.branches if \"step\" in r.name]\n",
        "    # Sorted by step number so the report is stable between runs (set order is not).\n",
        "    missing = sorted(\n",
        "        set(required) - set(available),\n",
        "        key=lambda name: int(name[len(\"step\"):]),\n",
        "    )\n",
        "    print(f\"Available: {len(available)} step revisions\")\n",
        "    print(f\"Missing required: {missing if missing else 'None'}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "test-model-loading",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Cell 4: Test model loading (smallest first)\n",
        "# Plain Hugging Face load, without TransformerLens, of the smallest model in the study.\n",
        "# AutoTokenizer is not used in this cell; it is imported here for the checkpoint-loading\n",
        "# cell that follows.\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "# Load the -deduped repo so the smoke test exercises the same weights as every other cell\n",
        "# in this notebook (the plain \"pythia-160m\" repo is a different model).\n",
        "model = AutoModelForCausalLM.from_pretrained(\"EleutherAI/pythia-160m-deduped\")\n",
        "print(\"model loaded successfully!\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "test-checkpoint-loading",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Cell 5: Test checkpoint loading (corrected — uses revision parameter)\n",
        "# Load one intermediate checkpoint through Hugging Face (selected with `revision`) and\n",
        "# pass the resulting weights to TransformerLens via `hf_model=`.\n",
        "# Imported here so the cell does not depend on names leaked from an earlier cell.\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "\n",
        "CHECKPOINT_REPO = \"EleutherAI/pythia-160m-deduped\"\n",
        "CHECKPOINT_REVISION = \"step30000\"\n",
        "\n",
        "# Pre-bind both names so the cleanup in `finally` is valid even if loading fails early.\n",
        "hf_model = model_step = None\n",
        "try:\n",
        "    hf_model = AutoModelForCausalLM.from_pretrained(\n",
        "        CHECKPOINT_REPO,\n",
        "        revision=CHECKPOINT_REVISION,\n",
        "        torch_dtype=torch.float32,\n",
        "    )\n",
        "    model_step = HookedTransformer.from_pretrained(\n",
        "        CHECKPOINT_REPO,\n",
        "        hf_model=hf_model,\n",
        "        tokenizer=AutoTokenizer.from_pretrained(\n",
        "            CHECKPOINT_REPO, revision=CHECKPOINT_REVISION\n",
        "        ),\n",
        "    )\n",
        "    print(\"✅ Checkpoint loading works (via HF revision parameter)!\")\n",
        "except Exception as e:\n",
        "    # Deliberately broad: this is a smoke test, so report the failure and carry on.\n",
        "    print(f\"❌ Checkpoint loading failed: {e}\")\n",
        "finally:\n",
        "    # Release both models whether or not loading succeeded, then hand cached GPU blocks\n",
        "    # back to the driver so the 2.8B test starts with as much free memory as possible.\n",
        "    del hf_model, model_step\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "memory-test-2-8b",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Cell 6: Memory test with 2.8B\n",
        "# Load the 2.8B model through TransformerLens and report how long the load took.\n",
        "print(\"\\nLoading pythia-2.8b (this is the critical test)...\")\n",
        "\n",
        "load_started = time.time()\n",
        "model_2_8b = HookedTransformer.from_pretrained(\"pythia-2.8b-deduped\")\n",
        "load_time = time.time() - load_started\n",
        "\n",
        "print(f\"Loaded in {load_time:.1f}s\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cache-speed-test",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Cell 7: Attention extraction speed test\n",
        "# Time one forward pass that caches activations, then report the size of the cache.\n",
        "test_input = \"A screen reader is\"\n",
        "tokens = model_2_8b.to_tokens(test_input)\n",
        "\n",
        "use_cuda = torch.cuda.is_available()\n",
        "\n",
        "# Inference only: with autograd enabled the cached activations would also keep\n",
        "# references to the backward graph, inflating memory use for no benefit.\n",
        "with torch.no_grad():\n",
        "    if use_cuda:\n",
        "        torch.cuda.synchronize()  # finish pending GPU work so it is not billed to this run\n",
        "    start = time.perf_counter()\n",
        "    logits, cache = model_2_8b.run_with_cache(tokens)\n",
        "    if use_cuda:\n",
        "        torch.cuda.synchronize()  # CUDA kernels are asynchronous; wait before stopping the clock\n",
        "    cache_time = time.perf_counter() - start\n",
        "\n",
        "print(f\"Cache extraction: {cache_time:.2f}s\")\n",
        "\n",
        "# Bytes per element times element count, summed over every cached tensor.\n",
        "cache_bytes = sum(v.element_size() * v.nelement() for v in cache.values())\n",
        "print(f\"Cache memory: {cache_bytes / 1e6:.1f} MB\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.x"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
