{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4a9b9c24",
   "metadata": {},
   "source": [
    "# Checkpoint Verification v2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d531f9e",
   "metadata": {},
   "source": [
    "### Checkpoint Verification Summary\n",
    "\n",
    "This configuration reflects corrected checkpoint availability based on findings from  \n",
    "`verify_setup.ipynb`.\n",
    "\n",
    "**What’s been verified:**\n",
    "- Uses the HuggingFace `revision` parameter  \n",
    "  *(instead of the deprecated `checkpoint_value`)*\n",
    "- Tests **8 verified checkpoints** rather than the original 10\n",
    "- Confirms successful loading through the  \n",
    "  **HuggingFace → HookedTransformer** pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40f1c130",
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install transformers==4.57.6 transformer_lens==2.17.0 accelerate==1.12.0 huggingface_hub==0.36.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f76a8fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers, huggingface_hub, accelerate\n",
    "assert transformers.__version__ == \"4.57.6\", f\"RESTART KERNEL! Got transformers {transformers.__version__}\"\n",
    "assert huggingface_hub.__version__ == \"0.36.1\", f\"RESTART KERNEL! Got huggingface_hub {huggingface_hub.__version__}\"\n",
    "print(f\"✅ transformers={transformers.__version__}, huggingface_hub={huggingface_hub.__version__}, accelerate={accelerate.__version__}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f5aa5cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import time\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78c5a060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corrected checkpoint map (8 verified checkpoints)\n",
    "CHECKPOINT_STEPS = [\n",
    "    \"step0\",\n",
    "    \"step15000\",\n",
    "    \"step30000\",\n",
    "    \"step60000\",\n",
    "    \"step90000\",\n",
    "    \"step120000\",\n",
    "    \"step140000\",\n",
    "    \"step143000\",  # Final checkpoint (proxy for \"late training\")\n",
    "]\n",
    "\n",
    "MODEL_SIZES = [\"160m\", \"1b\", \"2.8b\"]\n",
    "\n",
    "# Quick availability check: test 3 representative checkpoints per model\n",
    "SPOT_CHECK_STEPS = [\"step0\", \"step30000\", \"step143000\"]"
   ]
  },
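  {
   "cell_type": "markdown",
   "id": "a3c91e02",
   "metadata": {},
   "source": [
    "Before downloading any weights, we can cross-check that the pinned step names exist as\n",
    "branches on the Hub. The sketch below uses `huggingface_hub.list_repo_refs`, which\n",
    "fetches branch names only (no weight downloads); it assumes network access and uses the\n",
    "160M repo as a representative example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d04f11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional cross-check: compare CHECKPOINT_STEPS against the branch names\n",
    "# actually published on the Hub. list_repo_refs() fetches refs, not weights.\n",
    "from huggingface_hub import list_repo_refs\n",
    "\n",
    "refs = list_repo_refs(\"EleutherAI/pythia-160m-deduped\")\n",
    "available = {branch.name for branch in refs.branches}\n",
    "missing = [step for step in CHECKPOINT_STEPS if step not in available]\n",
    "print(f\"Branches found on the Hub: {len(available)}\")\n",
    "print(f\"Missing from CHECKPOINT_STEPS: {missing or 'none'}\")"
   ]
  },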
  {
   "cell_type": "markdown",
   "id": "5d4d85c5",
   "metadata": {},
   "source": [
    "## Step 1: Spot-Check Checkpoint Availability (HF Download Only)\n",
    "\n",
    "Test **early**, **mid**, and **late** checkpoints across all **three model sizes** using  \n",
    "`AutoModelForCausalLM` with the HuggingFace `revision` parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e7fba6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "\n",
    "for size in MODEL_SIZES:\n",
    "    for step in SPOT_CHECK_STEPS:\n",
    "        model_name = f\"EleutherAI/pythia-{size}-deduped\"\n",
    "        key = f\"{size}/{step}\"\n",
    "        try:\n",
    "            start = time.time()\n",
    "            model = AutoModelForCausalLM.from_pretrained(\n",
    "                model_name,\n",
    "                revision=step,\n",
    "                dtype=torch.float32,        # was: torch_dtype (deprecated in 5.1.0)\n",
    "                device_map=\"auto\",\n",
    "                use_safetensors=False,       # suppress auto-conversion thread errors\n",
    "            )\n",
    "            elapsed = time.time() - start\n",
    "            param_count = sum(p.numel() for p in model.parameters()) / 1e6\n",
    "            results[key] = {\"status\": \"OK\", \"time_s\": elapsed, \"params_M\": param_count}\n",
    "            print(f\"✅ {key}: {elapsed:.1f}s, {param_count:.0f}M params\")\n",
    "            del model\n",
    "            if torch.cuda.is_available():\n",
    "                torch.cuda.empty_cache()\n",
    "        except Exception as e:\n",
    "            results[key] = {\"status\": \"FAIL\", \"error\": str(e)}\n",
    "            print(f\"❌ {key}: {e}\")"
   ]
  },
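  {
   "cell_type": "markdown",
   "id": "c2e85a37",
   "metadata": {},
   "source": [
    "A quick tally of the spot-check outcomes, assuming the `results` dict built in the\n",
    "previous cell (a convenience sketch, not part of the verification itself)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d91f6b48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarize the spot-check: count passes and surface any failure messages.\n",
    "ok = [k for k, v in results.items() if v[\"status\"] == \"OK\"]\n",
    "failed = [k for k, v in results.items() if v[\"status\"] == \"FAIL\"]\n",
    "print(f\"Spot checks passed: {len(ok)}/{len(results)}\")\n",
    "for k in failed:\n",
    "    print(f\"  FAIL {k}: {results[k]['error']}\")"
   ]
  },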
  {
   "cell_type": "markdown",
   "id": "3d5c913b",
   "metadata": {},
   "source": [
    "## Step 2: Verify All 8 Checkpoints (160M Model)\n",
    "\n",
    "Run a full sweep across all **8 corrected checkpoints** using the **smallest (160M)** model  \n",
    "to confirm availability before scaling up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df64ae1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_160m = \"EleutherAI/pythia-160m-deduped\"\n",
    "all_ok = True\n",
    "\n",
    "for step in CHECKPOINT_STEPS:\n",
    "    try:\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name_160m,\n",
    "            revision=step,\n",
    "            dtype=torch.float32,        # was: torch_dtype\n",
    "            use_safetensors=False,       # suppress auto-conversion thread errors\n",
    "        )\n",
    "        print(f\"✅ 160m {step}: OK\")\n",
    "        del model\n",
    "    except Exception as e:\n",
    "        all_ok = False\n",
    "        print(f\"❌ 160m {step}: {e}\")\n",
    "\n",
    "print(f\"\\nAll 8 checkpoints available: {'YES' if all_ok else 'NO'}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f907553b",
   "metadata": {},
   "source": [
    "## Step 3: End-to-End HookedTransformer Loading Test\n",
    "\n",
    "Verify the full pipeline from **HuggingFace model loading** to  \n",
    "**HookedTransformer initialization** using the `revision` parameter,  \n",
    "implemented via `src/utils_model.py`."
   ]
  },
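  {
   "cell_type": "markdown",
   "id": "e4a72c90",
   "metadata": {},
   "source": [
    "For reference, a minimal sketch of what `load_pythia_with_checkpoint` is assumed to do\n",
    "(the actual implementation lives in `src/utils_model.py` and may differ): load the HF\n",
    "model at the requested `revision`, then wrap it via the `hf_model` argument of\n",
    "`HookedTransformer.from_pretrained`, avoiding the deprecated `checkpoint_value` path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8b31d72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch of the loader in src/utils_model.py; the real\n",
    "# implementation may differ (e.g., dtype or device handling).\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "def load_pythia_with_checkpoint_sketch(size: str, step: str) -> HookedTransformer:\n",
    "    model_name = f\"EleutherAI/pythia-{size}-deduped\"\n",
    "    hf_model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_name, revision=step, dtype=torch.float32, use_safetensors=False\n",
    "    )\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name, revision=step)\n",
    "    # Passing the preloaded hf_model makes HookedTransformer wrap this exact\n",
    "    # checkpoint instead of re-resolving the repo's main branch.\n",
    "    return HookedTransformer.from_pretrained(\n",
    "        model_name, hf_model=hf_model, tokenizer=tokenizer\n",
    "    )"
   ]
  },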
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4f8d96e",
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.insert(0, \"..\")\n",
    "from src.utils_model import load_pythia_with_checkpoint\n",
    "\n",
    "# Test the full pipeline on 160m at step30000\n",
    "start = time.time()\n",
    "hooked_model = load_pythia_with_checkpoint(\"160m\", \"step30000\")\n",
    "elapsed = time.time() - start\n",
    "print(f\"✅ HookedTransformer loaded in {elapsed:.1f}s\")\n",
    "print(f\"   Model: {hooked_model.cfg.model_name}\")\n",
    "print(f\"   Layers: {hooked_model.cfg.n_layers}\")\n",
    "print(f\"   Device: {hooked_model.cfg.device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f7a2a79",
   "metadata": {},
   "source": [
    "## Step 4: Attention Extraction Sanity Check\n",
    "\n",
    "Confirm that **attention cache extraction** works correctly with models  \n",
    "loaded via the `revision` parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2701c615",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_input = \"A screen reader is\"\n",
    "tokens = hooked_model.to_tokens(test_input)\n",
    "\n",
    "start = time.time()\n",
    "logits, cache = hooked_model.run_with_cache(tokens)\n",
    "cache_time = time.time() - start\n",
    "\n",
    "cache_bytes = sum(v.element_size() * v.nelement() for v in cache.values())\n",
    "print(f\"✅ Attention cache extracted in {cache_time:.2f}s\")\n",
    "print(f\"   Cache memory: {cache_bytes / 1e6:.1f} MB\")\n",
    "print(f\"   Token count: {tokens.shape[1]}\")\n",
    "print(f\"   Logits shape: {logits.shape}\")\n",
    "\n",
    "# Verify attention pattern shape\n",
    "attn_key = [k for k in cache.keys() if \"attn\" in k and \"pattern\" in k][0]\n",
    "print(f\"   Attention pattern shape: {cache[attn_key].shape}\")\n",
    "\n",
    "del hooked_model, logits, cache\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
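  {
   "cell_type": "markdown",
   "id": "a9d4e513",
   "metadata": {},
   "source": [
    "Looking ahead, the Stage 1A sweep in the summary below is the cross product of\n",
    "checkpoints, model sizes, and probe terms. A quick enumeration sketch follows; the\n",
    "term names are hypothetical placeholders (the real probe terms are defined in Stage\n",
    "1A, not in this notebook)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0c5f624",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enumerate the Stage 1A grid: 8 checkpoints x 3 models x 3 terms = 72 runs.\n",
    "# PROBE_TERMS is a hypothetical placeholder for the real Stage 1A terms.\n",
    "from itertools import product\n",
    "\n",
    "PROBE_TERMS = [\"term_a\", \"term_b\", \"term_c\"]\n",
    "grid = list(product(CHECKPOINT_STEPS, MODEL_SIZES, PROBE_TERMS))\n",
    "print(f\"Total Stage 1A combinations: {len(grid)}\")  # expected: 72"
   ]
  },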
  {
   "cell_type": "markdown",
   "id": "e3f17f1b",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "Run all cells above. Expected outcomes are listed below.\n",
    "\n",
    "| Check | Expected Result |\n",
    "|------:|-----------------|\n",
    "| Spot-check (3 sizes × 3 steps) | All 9 combinations pass ✅ |\n",
    "| Full 8-checkpoint sweep (160M) | All 8 checkpoints pass ✅ |\n",
    "| HookedTransformer pipeline | Loads successfully via `utils_model.py` ✅ |\n",
    "| Attention extraction | Cache extracted; attention pattern shapes correct ✅ |\n",
    "\n",
    "**If all checks pass → proceed to _Stage 1A (Behavioral Probes)_:**  \n",
    "**8 checkpoints × 3 models × 3 terms = 72 total combinations.**"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
