{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b8e252ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dslabra5/.conda/envs/sae4dlm/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "A new version of the following files was downloaded from https://huggingface.co/Dream-org/Dream-v0-Base-7B:\n",
      "- tokenization_dream.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "A new version of the following files was downloaded from https://huggingface.co/Dream-org/Dream-v0-Base-7B:\n",
      "- configuration_dream.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "`torch_dtype` is deprecated! Use `dtype` instead!\n",
      "A new version of the following files was downloaded from https://huggingface.co/Dream-org/Dream-v0-Base-7B:\n",
      "- modeling_dream.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "Fetching 4 files: 100%|██████████| 4/4 [17:36<00:00, 264.03s/it] \n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 79.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Layer 5 resid_pre:  dtype=torch.float32 shape=(1, 58, 3584) device=cuda:0\n",
      "Layer 5 resid_post: dtype=torch.float32 shape=(1, 58, 3584) device=cuda:0\n",
      "Relative change ||post - pre|| / ||pre|| = 0.918909\n",
      "resid_pre  stats: min=-14.688 max=5.469 mean=0.001957 std=0.390240\n",
      "resid_post stats: min=-67.500 max=31.000 mean=0.000817 std=0.545529\n"
     ]
    }
   ],
   "source": [
    "#!/usr/bin/env python3\n",
    "# Minimal check: verify that a forward hook on model.layers[L] captures the layer's\n",
    "# output residual stream (\"resid_post\"), and compare it to the pre-hook input (\"resid_pre\").\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "MODEL = \"Dream-org/Dream-v0-Base-7B\"   # or \"Dream-org/Dream-v0-Instruct-7B\"\n",
    "DTYPE = torch.bfloat16\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "TARGET_LAYER = 5                        # change to any layer index you want\n",
    "TEXT = (\n",
    "    \"Chandra and VLA observations of the symbiotic star R Aqr in 2004 reveal \"\n",
    "    \"significant changes over the three to four year interval between these \"\n",
    "    \"observations and previous observations taken with the VLA in 1999 and with \"\n",
    "    \"Chandra in 2000.\"\n",
    ")\n",
    "\n",
    "def normalize_output(out):\n",
    "    if isinstance(out, torch.Tensor):\n",
    "        return out\n",
    "    if isinstance(out, (tuple, list)) and len(out) > 0 and isinstance(out[0], torch.Tensor):\n",
    "        return out[0]\n",
    "    if isinstance(out, dict):\n",
    "        for key in (\"last_hidden_state\", \"hidden_states\", \"logits\"):\n",
    "            v = out.get(key, None)\n",
    "            if isinstance(v, torch.Tensor):\n",
    "                return v\n",
    "            if key == \"hidden_states\" and isinstance(v, (tuple, list)) and len(v) > 0 and isinstance(v[-1], torch.Tensor):\n",
    "                return v[-1]\n",
    "        for v in out.values():\n",
    "            if isinstance(v, torch.Tensor):\n",
    "                return v\n",
    "    if hasattr(out, \"last_hidden_state\"):\n",
    "        return out.last_hidden_state\n",
    "    if hasattr(out, \"hidden_states\") and isinstance(out.hidden_states, (tuple, list)) and len(out.hidden_states) > 0:\n",
    "        if isinstance(out.hidden_states[-1], torch.Tensor):\n",
    "            return out.hidden_states[-1]\n",
    "    if hasattr(out, \"logits\"):\n",
    "        return out.logits\n",
    "    raise TypeError(f\"Cannot normalize output of type: {type(out)}\")\n",
    "\n",
    "def main():\n",
    "    tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True, use_fast=True)\n",
    "    if tok.pad_token is None:\n",
    "        tok.pad_token = tok.eos_token\n",
    "\n",
    "    model = AutoModel.from_pretrained(MODEL, trust_remote_code=True, torch_dtype=DTYPE).to(DEVICE).eval()\n",
    "\n",
    "    # Prepare a single unpadded sample (keep it simple)\n",
    "    enc = tok(\n",
    "        TEXT,\n",
    "        return_tensors=\"pt\",\n",
    "        padding=False,\n",
    "        truncation=True,\n",
    "        max_length=256,\n",
    "        return_attention_mask=True,\n",
    "    )\n",
    "    enc = {k: v.to(DEVICE) for k, v in enc.items()}\n",
    "    if \"attention_mask\" in enc:\n",
    "        enc[\"attention_mask\"] = enc[\"attention_mask\"].to(torch.bool)\n",
    "\n",
    "    # Grab the target layer module directly (HF uses ModuleList at model.model.layers)\n",
    "    layer_mod = model.model.layers[TARGET_LAYER]\n",
    "\n",
    "    cache = {}\n",
    "\n",
    "    def pre_hook(_, inputs):\n",
    "        # inputs is a tuple; first arg is hidden_states (residual pre)\n",
    "        x = inputs[0]\n",
    "        cache[\"resid_pre\"] = x.detach()\n",
    "\n",
    "    def post_hook(_, __, output):\n",
    "        # output is the layer's forward result (residual post)\n",
    "        y = normalize_output(output)\n",
    "        cache[\"resid_post\"] = y.detach()\n",
    "\n",
    "    h_pre = layer_mod.register_forward_pre_hook(pre_hook)\n",
    "    h_post = layer_mod.register_forward_hook(post_hook)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        _ = model(**enc)\n",
    "\n",
    "    h_pre.remove()\n",
    "    h_post.remove()\n",
    "\n",
    "    assert \"resid_pre\" in cache and \"resid_post\" in cache, \"Hooks did not fire.\"\n",
    "\n",
    "    pre = cache[\"resid_pre\"].float()\n",
    "    post = cache[\"resid_post\"].float()\n",
    "\n",
    "    print(f\"Layer {TARGET_LAYER} resid_pre:  dtype={pre.dtype} shape={tuple(pre.shape)} device={pre.device}\")\n",
    "    print(f\"Layer {TARGET_LAYER} resid_post: dtype={post.dtype} shape={tuple(post.shape)} device={post.device}\")\n",
    "\n",
    "    # Simple sanity check: resid_post should differ from resid_pre\n",
    "    diff = post - pre\n",
    "    rel = diff.norm() / (pre.norm() + 1e-6)\n",
    "    print(f\"Relative change ||post - pre|| / ||pre|| = {float(rel.cpu()):.6f}\")\n",
    "\n",
    "    # Print a few stats for visual confirmation\n",
    "    def stats(x):\n",
    "        xf = x.float()\n",
    "        return float(xf.min().cpu()), float(xf.max().cpu()), float(xf.mean().cpu()), float(xf.std().cpu())\n",
    "\n",
    "    a = stats(pre)\n",
    "    b = stats(post)\n",
    "    print(f\"resid_pre  stats: min={a[0]:.3f} max={a[1]:.3f} mean={a[2]:.6f} std={a[3]:.6f}\")\n",
    "    print(f\"resid_post stats: min={b[0]:.3f} max={b[1]:.3f} mean={b[2]:.6f} std={b[3]:.6f}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7304d531",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Info] Loading tokenizer/model: Dream-org/Dream-v0-Base-7B\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 89.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== MODULE NAMES (subset: 'mlp' or 'attn') ===\n",
      "model.layers.0.self_attn\n",
      "model.layers.0.self_attn.q_proj\n",
      "model.layers.0.self_attn.k_proj\n",
      "model.layers.0.self_attn.v_proj\n",
      "model.layers.0.self_attn.o_proj\n",
      "model.layers.0.self_attn.rotary_emb\n",
      "model.layers.0.mlp\n",
      "model.layers.0.mlp.gate_proj\n",
      "model.layers.0.mlp.up_proj\n",
      "model.layers.0.mlp.down_proj\n",
      "model.layers.0.mlp.act_fn\n",
      "model.layers.0.post_attention_layernorm\n",
      "model.layers.1.self_attn\n",
      "model.layers.1.self_attn.q_proj\n",
      "model.layers.1.self_attn.k_proj\n",
      "model.layers.1.self_attn.v_proj\n",
      "model.layers.1.self_attn.o_proj\n",
      "model.layers.1.self_attn.rotary_emb\n",
      "model.layers.1.mlp\n",
      "model.layers.1.mlp.gate_proj\n",
      "model.layers.1.mlp.up_proj\n",
      "model.layers.1.mlp.down_proj\n",
      "model.layers.1.mlp.act_fn\n",
      "model.layers.1.post_attention_layernorm\n",
      "model.layers.2.self_attn\n",
      "model.layers.2.self_attn.q_proj\n",
      "model.layers.2.self_attn.k_proj\n",
      "model.layers.2.self_attn.v_proj\n",
      "model.layers.2.self_attn.o_proj\n",
      "model.layers.2.self_attn.rotary_emb\n",
      "model.layers.2.mlp\n",
      "model.layers.2.mlp.gate_proj\n",
      "model.layers.2.mlp.up_proj\n",
      "model.layers.2.mlp.down_proj\n",
      "model.layers.2.mlp.act_fn\n",
      "model.layers.2.post_attention_layernorm\n",
      "model.layers.3.self_attn\n",
      "model.layers.3.self_attn.q_proj\n",
      "model.layers.3.self_attn.k_proj\n",
      "model.layers.3.self_attn.v_proj\n",
      "model.layers.3.self_attn.o_proj\n",
      "model.layers.3.self_attn.rotary_emb\n",
      "model.layers.3.mlp\n",
      "model.layers.3.mlp.gate_proj\n",
      "model.layers.3.mlp.up_proj\n",
      "model.layers.3.mlp.down_proj\n",
      "model.layers.3.mlp.act_fn\n",
      "model.layers.3.post_attention_layernorm\n",
      "model.layers.4.self_attn\n",
      "model.layers.4.self_attn.q_proj\n",
      "model.layers.4.self_attn.k_proj\n",
      "model.layers.4.self_attn.v_proj\n",
      "model.layers.4.self_attn.o_proj\n",
      "model.layers.4.self_attn.rotary_emb\n",
      "model.layers.4.mlp\n",
      "model.layers.4.mlp.gate_proj\n",
      "model.layers.4.mlp.up_proj\n",
      "model.layers.4.mlp.down_proj\n",
      "model.layers.4.mlp.act_fn\n",
      "model.layers.4.post_attention_layernorm\n",
      "model.layers.5.self_attn\n",
      "model.layers.5.self_attn.q_proj\n",
      "model.layers.5.self_attn.k_proj\n",
      "model.layers.5.self_attn.v_proj\n",
      "model.layers.5.self_attn.o_proj\n",
      "model.layers.5.self_attn.rotary_emb\n",
      "model.layers.5.mlp\n",
      "model.layers.5.mlp.gate_proj\n",
      "model.layers.5.mlp.up_proj\n",
      "model.layers.5.mlp.down_proj\n",
      "model.layers.5.mlp.act_fn\n",
      "model.layers.5.post_attention_layernorm\n",
      "model.layers.6.self_attn\n",
      "model.layers.6.self_attn.q_proj\n",
      "model.layers.6.self_attn.k_proj\n",
      "model.layers.6.self_attn.v_proj\n",
      "model.layers.6.self_attn.o_proj\n",
      "model.layers.6.self_attn.rotary_emb\n",
      "model.layers.6.mlp\n",
      "model.layers.6.mlp.gate_proj\n",
      "model.layers.6.mlp.up_proj\n",
      "model.layers.6.mlp.down_proj\n",
      "model.layers.6.mlp.act_fn\n",
      "model.layers.6.post_attention_layernorm\n",
      "model.layers.7.self_attn\n",
      "model.layers.7.self_attn.q_proj\n",
      "model.layers.7.self_attn.k_proj\n",
      "model.layers.7.self_attn.v_proj\n",
      "model.layers.7.self_attn.o_proj\n",
      "model.layers.7.self_attn.rotary_emb\n",
      "model.layers.7.mlp\n",
      "model.layers.7.mlp.gate_proj\n",
      "model.layers.7.mlp.up_proj\n",
      "model.layers.7.mlp.down_proj\n",
      "model.layers.7.mlp.act_fn\n",
      "model.layers.7.post_attention_layernorm\n",
      "model.layers.8.self_attn\n",
      "model.layers.8.self_attn.q_proj\n",
      "model.layers.8.self_attn.k_proj\n",
      "model.layers.8.self_attn.v_proj\n",
      "model.layers.8.self_attn.o_proj\n",
      "model.layers.8.self_attn.rotary_emb\n",
      "model.layers.8.mlp\n",
      "model.layers.8.mlp.gate_proj\n",
      "model.layers.8.mlp.up_proj\n",
      "model.layers.8.mlp.down_proj\n",
      "model.layers.8.mlp.act_fn\n",
      "model.layers.8.post_attention_layernorm\n",
      "model.layers.9.self_attn\n",
      "model.layers.9.self_attn.q_proj\n",
      "model.layers.9.self_attn.k_proj\n",
      "model.layers.9.self_attn.v_proj\n",
      "model.layers.9.self_attn.o_proj\n",
      "model.layers.9.self_attn.rotary_emb\n",
      "model.layers.9.mlp\n",
      "model.layers.9.mlp.gate_proj\n",
      "model.layers.9.mlp.up_proj\n",
      "model.layers.9.mlp.down_proj\n",
      "model.layers.9.mlp.act_fn\n",
      "model.layers.9.post_attention_layernorm\n",
      "model.layers.10.self_attn\n",
      "model.layers.10.self_attn.q_proj\n",
      "model.layers.10.self_attn.k_proj\n",
      "model.layers.10.self_attn.v_proj\n",
      "model.layers.10.self_attn.o_proj\n",
      "model.layers.10.self_attn.rotary_emb\n",
      "model.layers.10.mlp\n",
      "model.layers.10.mlp.gate_proj\n",
      "model.layers.10.mlp.up_proj\n",
      "model.layers.10.mlp.down_proj\n",
      "model.layers.10.mlp.act_fn\n",
      "model.layers.10.post_attention_layernorm\n",
      "model.layers.11.self_attn\n",
      "model.layers.11.self_attn.q_proj\n",
      "model.layers.11.self_attn.k_proj\n",
      "model.layers.11.self_attn.v_proj\n",
      "model.layers.11.self_attn.o_proj\n",
      "model.layers.11.self_attn.rotary_emb\n",
      "model.layers.11.mlp\n",
      "model.layers.11.mlp.gate_proj\n",
      "model.layers.11.mlp.up_proj\n",
      "model.layers.11.mlp.down_proj\n",
      "model.layers.11.mlp.act_fn\n",
      "model.layers.11.post_attention_layernorm\n",
      "model.layers.12.self_attn\n",
      "model.layers.12.self_attn.q_proj\n",
      "model.layers.12.self_attn.k_proj\n",
      "model.layers.12.self_attn.v_proj\n",
      "model.layers.12.self_attn.o_proj\n",
      "model.layers.12.self_attn.rotary_emb\n",
      "model.layers.12.mlp\n",
      "model.layers.12.mlp.gate_proj\n",
      "model.layers.12.mlp.up_proj\n",
      "model.layers.12.mlp.down_proj\n",
      "model.layers.12.mlp.act_fn\n",
      "model.layers.12.post_attention_layernorm\n",
      "model.layers.13.self_attn\n",
      "model.layers.13.self_attn.q_proj\n",
      "model.layers.13.self_attn.k_proj\n",
      "model.layers.13.self_attn.v_proj\n",
      "model.layers.13.self_attn.o_proj\n",
      "model.layers.13.self_attn.rotary_emb\n",
      "model.layers.13.mlp\n",
      "model.layers.13.mlp.gate_proj\n",
      "model.layers.13.mlp.up_proj\n",
      "model.layers.13.mlp.down_proj\n",
      "model.layers.13.mlp.act_fn\n",
      "model.layers.13.post_attention_layernorm\n",
      "model.layers.14.self_attn\n",
      "model.layers.14.self_attn.q_proj\n",
      "model.layers.14.self_attn.k_proj\n",
      "model.layers.14.self_attn.v_proj\n",
      "model.layers.14.self_attn.o_proj\n",
      "model.layers.14.self_attn.rotary_emb\n",
      "model.layers.14.mlp\n",
      "model.layers.14.mlp.gate_proj\n",
      "model.layers.14.mlp.up_proj\n",
      "model.layers.14.mlp.down_proj\n",
      "model.layers.14.mlp.act_fn\n",
      "model.layers.14.post_attention_layernorm\n",
      "model.layers.15.self_attn\n",
      "model.layers.15.self_attn.q_proj\n",
      "model.layers.15.self_attn.k_proj\n",
      "model.layers.15.self_attn.v_proj\n",
      "model.layers.15.self_attn.o_proj\n",
      "model.layers.15.self_attn.rotary_emb\n",
      "model.layers.15.mlp\n",
      "model.layers.15.mlp.gate_proj\n",
      "model.layers.15.mlp.up_proj\n",
      "model.layers.15.mlp.down_proj\n",
      "model.layers.15.mlp.act_fn\n",
      "model.layers.15.post_attention_layernorm\n",
      "model.layers.16.self_attn\n",
      "model.layers.16.self_attn.q_proj\n",
      "model.layers.16.self_attn.k_proj\n",
      "model.layers.16.self_attn.v_proj\n",
      "model.layers.16.self_attn.o_proj\n",
      "model.layers.16.self_attn.rotary_emb\n",
      "model.layers.16.mlp\n",
      "model.layers.16.mlp.gate_proj\n",
      "model.layers.16.mlp.up_proj\n",
      "model.layers.16.mlp.down_proj\n",
      "model.layers.16.mlp.act_fn\n",
      "model.layers.16.post_attention_layernorm\n",
      "model.layers.17.self_attn\n",
      "model.layers.17.self_attn.q_proj\n",
      "model.layers.17.self_attn.k_proj\n",
      "model.layers.17.self_attn.v_proj\n",
      "model.layers.17.self_attn.o_proj\n",
      "model.layers.17.self_attn.rotary_emb\n",
      "model.layers.17.mlp\n",
      "model.layers.17.mlp.gate_proj\n",
      "model.layers.17.mlp.up_proj\n",
      "model.layers.17.mlp.down_proj\n",
      "model.layers.17.mlp.act_fn\n",
      "model.layers.17.post_attention_layernorm\n",
      "model.layers.18.self_attn\n",
      "model.layers.18.self_attn.q_proj\n",
      "model.layers.18.self_attn.k_proj\n",
      "model.layers.18.self_attn.v_proj\n",
      "model.layers.18.self_attn.o_proj\n",
      "model.layers.18.self_attn.rotary_emb\n",
      "model.layers.18.mlp\n",
      "model.layers.18.mlp.gate_proj\n",
      "model.layers.18.mlp.up_proj\n",
      "model.layers.18.mlp.down_proj\n",
      "model.layers.18.mlp.act_fn\n",
      "model.layers.18.post_attention_layernorm\n",
      "model.layers.19.self_attn\n",
      "model.layers.19.self_attn.q_proj\n",
      "model.layers.19.self_attn.k_proj\n",
      "model.layers.19.self_attn.v_proj\n",
      "model.layers.19.self_attn.o_proj\n",
      "model.layers.19.self_attn.rotary_emb\n",
      "model.layers.19.mlp\n",
      "model.layers.19.mlp.gate_proj\n",
      "model.layers.19.mlp.up_proj\n",
      "model.layers.19.mlp.down_proj\n",
      "model.layers.19.mlp.act_fn\n",
      "model.layers.19.post_attention_layernorm\n",
      "model.layers.20.self_attn\n",
      "model.layers.20.self_attn.q_proj\n",
      "model.layers.20.self_attn.k_proj\n",
      "model.layers.20.self_attn.v_proj\n",
      "model.layers.20.self_attn.o_proj\n",
      "model.layers.20.self_attn.rotary_emb\n",
      "model.layers.20.mlp\n",
      "model.layers.20.mlp.gate_proj\n",
      "model.layers.20.mlp.up_proj\n",
      "model.layers.20.mlp.down_proj\n",
      "model.layers.20.mlp.act_fn\n",
      "model.layers.20.post_attention_layernorm\n",
      "model.layers.21.self_attn\n",
      "model.layers.21.self_attn.q_proj\n",
      "model.layers.21.self_attn.k_proj\n",
      "model.layers.21.self_attn.v_proj\n",
      "model.layers.21.self_attn.o_proj\n",
      "model.layers.21.self_attn.rotary_emb\n",
      "model.layers.21.mlp\n",
      "model.layers.21.mlp.gate_proj\n",
      "model.layers.21.mlp.up_proj\n",
      "model.layers.21.mlp.down_proj\n",
      "model.layers.21.mlp.act_fn\n",
      "model.layers.21.post_attention_layernorm\n",
      "model.layers.22.self_attn\n",
      "model.layers.22.self_attn.q_proj\n",
      "model.layers.22.self_attn.k_proj\n",
      "model.layers.22.self_attn.v_proj\n",
      "model.layers.22.self_attn.o_proj\n",
      "model.layers.22.self_attn.rotary_emb\n",
      "model.layers.22.mlp\n",
      "model.layers.22.mlp.gate_proj\n",
      "model.layers.22.mlp.up_proj\n",
      "model.layers.22.mlp.down_proj\n",
      "model.layers.22.mlp.act_fn\n",
      "model.layers.22.post_attention_layernorm\n",
      "model.layers.23.self_attn\n",
      "model.layers.23.self_attn.q_proj\n",
      "model.layers.23.self_attn.k_proj\n",
      "model.layers.23.self_attn.v_proj\n",
      "model.layers.23.self_attn.o_proj\n",
      "model.layers.23.self_attn.rotary_emb\n",
      "model.layers.23.mlp\n",
      "model.layers.23.mlp.gate_proj\n",
      "model.layers.23.mlp.up_proj\n",
      "model.layers.23.mlp.down_proj\n",
      "model.layers.23.mlp.act_fn\n",
      "model.layers.23.post_attention_layernorm\n",
      "model.layers.24.self_attn\n",
      "model.layers.24.self_attn.q_proj\n",
      "model.layers.24.self_attn.k_proj\n",
      "model.layers.24.self_attn.v_proj\n",
      "model.layers.24.self_attn.o_proj\n",
      "model.layers.24.self_attn.rotary_emb\n",
      "model.layers.24.mlp\n",
      "model.layers.24.mlp.gate_proj\n",
      "model.layers.24.mlp.up_proj\n",
      "model.layers.24.mlp.down_proj\n",
      "model.layers.24.mlp.act_fn\n",
      "model.layers.24.post_attention_layernorm\n",
      "model.layers.25.self_attn\n",
      "... (truncated)\n",
      "\n",
      "[Info] Candidate targets at layer 17:\n",
      "  - model.layers.17\n",
      "  - model.layers.17.mlp.down_proj\n",
      "  - model.layers.17.mlp\n",
      "  - model.layers.17.self_attn.o_proj\n",
      "[Info] Using target: model.layers.17\n",
      "\n",
      "=== PROBED LAYER: model.layers.17 ===\n",
      "resid_pre: dtype=torch.bfloat16 shape=(1, 14, 3584) device=cuda:0 min=-596.000000 max=174.000000 mean=-0.025351 std=4.066421\n",
      "resid_post  (this is what SAE 'io=out' would see): dtype=torch.bfloat16 shape=(1, 14, 3584) device=cuda:0 min=-592.000000 max=174.000000 mean=-0.023262 std=4.056655\n",
      "feature-wise |x| mean (first 10 dims): [0.256801, 0.407017, 0.287654, 0.480469, 0.410017, 0.734166, 0.371233, 0.606602, 1.024815, 1.050712]\n",
      "Done.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "from typing import Tuple, Any, Dict, Optional\n",
    "\n",
    "MODEL = \"Dream-org/Dream-v0-Base-7B\"  # or \"Dream-org/Dream-v0-Instruct-7B\"\n",
    "DTYPE = torch.bfloat16\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "TARGET_LAYER_INDEX = 17  # change to the layer you want to probe\n",
    "MAX_LEN = 128\n",
    "\n",
    "\n",
    "def get_module_by_name(root, name: str):\n",
    "    \"\"\"Return a submodule from a dotted path name, or None if not found.\"\"\"\n",
    "    curr = root\n",
    "    for p in name.split(\".\"):\n",
    "        if not hasattr(curr, p):\n",
    "            return None\n",
    "        curr = getattr(curr, p)\n",
    "    return curr\n",
    "\n",
    "\n",
    "def pick_target_names(model) -> Tuple[str, list]:\n",
    "    \"\"\"\n",
    "    Choose a robust target name under model.layers.{L}.\n",
    "    Preference (we want residual post; hooking the entire layer is best):\n",
    "      0) model.layers.{L}               <-- whole layer, forward_hook ~ resid_post\n",
    "      1) model.layers.{L}.mlp.down_proj\n",
    "      2) model.layers.{L}.mlp\n",
    "      3) model.layers.{L}.self_attn.o_proj\n",
    "    Returns (chosen_name, all_candidates_found).\n",
    "    \"\"\"\n",
    "    # Try to get #layers\n",
    "    num_layers = getattr(getattr(model, \"config\", object()), \"num_hidden_layers\", None)\n",
    "    if num_layers is None:\n",
    "        layer_prefixes = [n for n, _ in model.named_modules()\n",
    "                          if n.endswith(\".mlp\") and \"model.layers.\" in n]\n",
    "        if not layer_prefixes:\n",
    "            raise RuntimeError(\"Cannot infer number of layers; no 'model.layers.*.mlp' found.\")\n",
    "        max_idx = max(int(x.split(\".\")[2]) for x in layer_prefixes)\n",
    "        num_layers = max_idx + 1\n",
    "\n",
    "    if TARGET_LAYER_INDEX < 0 or TARGET_LAYER_INDEX >= num_layers:\n",
    "        raise ValueError(f\"TARGET_LAYER_INDEX={TARGET_LAYER_INDEX} out of range [0, {num_layers-1}]\")\n",
    "\n",
    "    base = f\"model.layers.{TARGET_LAYER_INDEX}\"\n",
    "    pref = [\n",
    "        f\"{base}\",                     # whole block (forward_hook -> resid_post)\n",
    "        f\"{base}.mlp.down_proj\",\n",
    "        f\"{base}.mlp\",\n",
    "        f\"{base}.self_attn.o_proj\",\n",
    "    ]\n",
    "    found = [name for name in pref if get_module_by_name(model, name) is not None]\n",
    "    if not found:\n",
    "        raise RuntimeError(f\"No usable submodule found under layer {TARGET_LAYER_INDEX}\")\n",
    "    return found[0], found\n",
    "\n",
    "\n",
    "def normalize_output(out: Any) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Convert various output types to a tensor:\n",
    "      - torch.Tensor: return as is\n",
    "      - tuple: take the first tensor\n",
    "      - dict/ModelOutput: prefer 'last_hidden_state' or 'hidden_states'[-1] or 'logits'\n",
    "    \"\"\"\n",
    "    if isinstance(out, torch.Tensor):\n",
    "        return out\n",
    "    if isinstance(out, tuple) and len(out) > 0 and isinstance(out[0], torch.Tensor):\n",
    "        return out[0]\n",
    "    if isinstance(out, dict):\n",
    "        for k in (\"last_hidden_state\", \"hidden_states\", \"logits\"):\n",
    "            v = out.get(k, None)\n",
    "            if isinstance(v, torch.Tensor):\n",
    "                return v\n",
    "            if k == \"hidden_states\" and isinstance(v, (list, tuple)) and len(v) > 0 and isinstance(v[-1], torch.Tensor):\n",
    "                return v[-1]\n",
    "        for v in out.values():\n",
    "            if isinstance(v, torch.Tensor):\n",
    "                return v\n",
    "    if hasattr(out, \"last_hidden_state\"):\n",
    "        return out.last_hidden_state\n",
    "    if hasattr(out, \"hidden_states\") and isinstance(out.hidden_states, (list, tuple)) and len(out.hidden_states) > 0:\n",
    "        if isinstance(out.hidden_states[-1], torch.Tensor):\n",
    "            return out.hidden_states[-1]\n",
    "    if hasattr(out, \"logits\"):\n",
    "        return out.logits\n",
    "    raise TypeError(f\"Cannot normalize output type: {type(out)}\")\n",
    "\n",
    "\n",
    "def run_forward_three_ways(model, inputs: Dict[str, torch.Tensor]) -> Any:\n",
    "    \"\"\"\n",
    "    Try three safe ways to call Dream forward to avoid SDPA attn_mask dtype issues:\n",
    "      A) Remove attention_mask entirely (works if we don't pad).\n",
    "      B) Keep attention_mask but cast to bool (SDPA accepts bool).\n",
    "      C) Convert attention_mask to additive float mask (0 -> -inf at pads) with shape (B,1,1,T).\n",
    "    Returns the raw model output (whatever Dream returns).\n",
    "    \"\"\"\n",
    "    # A) no attention_mask\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            inpA = dict(inputs)\n",
    "            inpA.pop(\"attention_mask\", None)\n",
    "            return model(**inpA)\n",
    "    except Exception as e:\n",
    "        print(f\"[Run A] Without attention_mask failed: {type(e).__name__}: {e}\")\n",
    "\n",
    "    # B) bool mask\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            inpB = dict(inputs)\n",
    "            if \"attention_mask\" in inpB:\n",
    "                inpB[\"attention_mask\"] = inpB[\"attention_mask\"].to(torch.bool)\n",
    "            return model(**inpB)\n",
    "    except Exception as e:\n",
    "        print(f\"[Run B] With bool attention_mask failed: {type(e).__name__}: {e}\")\n",
    "\n",
    "    # C) additive float mask (pad -> -inf)\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            inpC = dict(inputs)\n",
    "            if \"attention_mask\" in inpC:\n",
    "                am = inpC[\"attention_mask\"]\n",
    "                if am.dtype not in (torch.float32, torch.float16, torch.bfloat16):\n",
    "                    am = am.float()\n",
    "                am4 = am[:, None, None, :]  # (B,1,1,T)\n",
    "                neg_inf = torch.finfo(am4.dtype).min\n",
    "                add_mask = torch.where(am4 > 0, torch.zeros_like(am4), torch.full_like(am4, neg_inf))\n",
    "                inpC[\"attention_mask\"] = add_mask.to(device=am.device)\n",
    "            return model(**inpC)\n",
    "    except Exception as e:\n",
    "        print(f\"[Run C] With additive float attention_mask failed: {type(e).__name__}: {e}\")\n",
    "        raise\n",
    "\n",
    "\n",
    "def main():\n",
    "    print(f\"[Info] Loading tokenizer/model: {MODEL}\")\n",
    "    tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True, use_fast=True)\n",
    "    if not getattr(tok, \"pad_token\", None):\n",
    "        tok.pad_token = tok.eos_token\n",
    "\n",
    "    model = AutoModel.from_pretrained(MODEL, trust_remote_code=True, torch_dtype=DTYPE).to(DEVICE)\n",
    "    model.eval()\n",
    "\n",
    "    print(\"=== MODULE NAMES (subset: 'mlp' or 'attn') ===\")\n",
    "    shown = 0\n",
    "    for n, _ in model.named_modules():\n",
    "        ln = n.lower()\n",
    "        if \"mlp\" in ln or \"attn\" in ln or \"attention\" in ln:\n",
    "            print(n)\n",
    "            shown += 1\n",
    "            if shown > 300:\n",
    "                print(\"... (truncated)\")\n",
    "                break\n",
    "    print()\n",
    "\n",
    "    chosen, all_found = pick_target_names(model)\n",
    "    print(f\"[Info] Candidate targets at layer {TARGET_LAYER_INDEX}:\")\n",
    "    for n in all_found:\n",
    "        print(\"  -\", n)\n",
    "    print(f\"[Info] Using target: {chosen}\")\n",
    "\n",
    "    acts: Dict[str, torch.Tensor] = {}\n",
    "\n",
    "    # We try to capture both \"resid_pre\" (layer input) and \"resid_post\" (layer output).\n",
    "    resid_pre: Dict[str, torch.Tensor] = {}\n",
    "    resid_post: Dict[str, torch.Tensor] = {}\n",
    "\n",
    "    # If chosen == whole layer, this forward hook will be \"resid_post\"\n",
    "    submodule = get_module_by_name(model, chosen)\n",
    "    assert submodule is not None, f\"Submodule not found: {chosen}\"\n",
    "\n",
    "    def pre_hook_fn(_, inputs):\n",
    "        try:\n",
    "            x = inputs[0] if isinstance(inputs, (tuple, list)) and len(inputs) > 0 else inputs\n",
    "            if not isinstance(x, torch.Tensor):\n",
    "                # Sometimes layers accept dicts / complicated inputs; try to locate a tensor\n",
    "                if isinstance(x, dict):\n",
    "                    for v in x.values():\n",
    "                        if isinstance(v, torch.Tensor):\n",
    "                            x = v\n",
    "                            break\n",
    "            if isinstance(x, torch.Tensor):\n",
    "                resid_pre[\"x\"] = x.detach()\n",
    "        except Exception as e:\n",
    "            print(f\"[Pre-Hook] Failed: {e}\")\n",
    "\n",
    "    def fwd_hook_fn(_, __, out):\n",
    "        try:\n",
    "            x = out\n",
    "            if isinstance(x, tuple):\n",
    "                x = x[0]\n",
    "            if not isinstance(x, torch.Tensor):\n",
    "                x = normalize_output(x)\n",
    "            resid_post[\"x\"] = x.detach()\n",
    "        except Exception as e:\n",
    "            print(f\"[Fwd-Hook] Failed: {e}\")\n",
    "\n",
    "    # Register both hooks\n",
    "    h_pre = submodule.register_forward_pre_hook(pre_hook_fn)\n",
    "    h_fwd = submodule.register_forward_hook(fwd_hook_fn)\n",
    "\n",
    "    # Single short input (no padding to simplify mask)\n",
    "    text = \"Hello Dream LM! This is a small probing sentence to verify shapes.\"\n",
    "    inputs = tok([text], return_tensors=\"pt\", padding=False, truncation=True, max_length=MAX_LEN)\n",
    "    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}\n",
    "\n",
    "    _ = run_forward_three_ways(model, inputs)\n",
    "\n",
    "    # Remove hooks\n",
    "    h_pre.remove()\n",
    "    h_fwd.remove()\n",
    "\n",
    "    # Report\n",
    "    def _stat(t: torch.Tensor, name: str):\n",
    "        xx = t.float()\n",
    "        print(f\"{name}: dtype={t.dtype} shape={tuple(t.shape)} device={t.device} \"\n",
    "              f\"min={float(xx.min().cpu()):.6f} max={float(xx.max().cpu()):.6f} \"\n",
    "              f\"mean={float(xx.mean().cpu()):.6f} std={float(xx.std().cpu()):.6f}\")\n",
    "\n",
    "    print(\"\\n=== PROBED LAYER:\", chosen, \"===\")\n",
    "    if \"x\" in resid_pre:\n",
    "        _stat(resid_pre[\"x\"], \"resid_pre\")\n",
    "    else:\n",
    "        print(\"resid_pre: <not captured>\")\n",
    "\n",
    "    if \"x\" in resid_post:\n",
    "        _stat(resid_post[\"x\"], \"resid_post  (this is what SAE 'io=out' would see)\")\n",
    "        acts[\"x\"] = resid_post[\"x\"]\n",
    "    else:\n",
    "        print(\"resid_post: <not captured>\")\n",
    "\n",
    "    # If resid_post was captured, also show a quick flatten-and-summary as SAE would see\n",
    "    if \"x\" in acts:\n",
    "        x = acts[\"x\"]\n",
    "        if x.dim() == 3:\n",
    "            b, t, d = x.shape\n",
    "            xf = x.reshape(b * t, d).float()\n",
    "        elif x.dim() == 2:\n",
    "            xf = x.float()\n",
    "        else:\n",
    "            xf = x.reshape(-1, x.shape[-1]).float()\n",
    "        fmean = xf.abs().mean(dim=0)[:10].cpu().tolist()\n",
    "        print(\"feature-wise |x| mean (first 10 dims):\", [round(v, 6) for v in fmean])\n",
    "\n",
    "    print(\"Done.\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "675ebfe1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dslabra5/.conda/envs/sae4dlm/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Dream-org/Dream-v0-Base-7B on cuda ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`torch_dtype` is deprecated! Use `dtype` instead!\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 91.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "################################################################################\n",
      "# Task: GSM8K-CoT\n",
      "# Prompt preview:\n",
      "You are a helpful math tutor. Solve the problem step by step and give the final answer as an integer on the last line.\n",
      "Problem: Tom has 12 apples. He gives 3 to Alice and then buys 7 more. He eats 4 apples on the way home. How many apples does Tom have now?\n",
      "Reasoning: Let's think step by step.\n",
      "Answer:\n",
      "# Gen config base: {'max_new_tokens': 128, 'steps': 64}\n",
      "################################################################################\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.0, top_p=1.0\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +3 -> ' '], [pos +4 -> '1']\n",
      "step   2: [pos +5 -> '2'], [pos +6 -> ' apples']\n",
      "step   3: [pos +1 -> ' starts'], [pos +2 -> ' with']\n",
      "step   4: [pos +7 -> '.'], [pos +9 -> ' gives']\n",
      "step   5: [pos +8 -> ' He'], [pos +10 -> ' ']\n",
      "step   6: [pos +11 -> '3'], [pos +12 -> ' to']\n",
      "step   7: [pos +13 -> ' Alice'], [pos +14 -> ',']\n",
      "step   8: [pos +15 -> ' so'], [pos +16 -> ' he']\n",
      "step   9: [pos +17 -> ' has'], [pos +18 -> ' has']\n",
      "step  10: [pos +19 -> ' '], [pos +20 -> '1']\n",
      "step  11: [pos +21 -> '2'], [pos +24 -> '3']\n",
      "step  12: [pos +22 -> ' -'], [pos +23 -> ' ']\n",
      "step  13: [pos +25 -> ' ='], [pos +26 -> ' ']\n",
      "step  14: [pos +27 -> '9'], [pos +28 -> ' apples']\n",
      "step  15: [pos +29 -> ' left'], [pos +30 -> '.']\n",
      "step  16: [pos +31 -> ' Then'], [pos +33 -> ' buys']\n",
      "step  17: [pos +34 -> ' '], [pos +35 -> '7']\n",
      "step  18: [pos +32 -> ' he'], [pos +36 -> ' more']\n",
      "step  19: [pos +37 -> ' apples'], [pos +38 -> ',']\n",
      "step  20: [pos +39 -> ' so'], [pos +40 -> ' he']\n",
      "step  21: [pos +41 -> ' has'], [pos +47 -> '7']\n",
      "step  22: [pos +45 -> ' +'], [pos +46 -> ' ']\n",
      "step  23: [pos +48 -> ' ='], [pos +49 -> ' ']\n",
      "step  24: [pos +44 -> '9'], [pos +50 -> '1']\n",
      "step  25: [pos +43 -> ' '], [pos +51 -> '6']\n",
      "step  26: [pos +52 -> ' apples'], [pos +53 -> '.']\n",
      "step  27: [pos +42 -> ' has'], [pos +56 -> ' ']\n",
      "step  28: [pos +57 -> '4'], [pos +61 -> ' way']\n",
      "step  29: [pos +59 -> ' on'], [pos +62 -> ' home']\n",
      "step  30: [pos +58 -> ' apples'], [pos +63 -> ',']\n",
      "step  31: [pos +54 -> ' He'], [pos +60 -> ' the']\n",
      "step  32: [pos +55 -> ' eats'], [pos +64 -> ' so']\n",
      "step  33: [pos +65 -> ' he'], [pos +76 -> '1']\n",
      "step  34: [pos +74 -> ' ='], [pos +75 -> ' ']\n",
      "step  35: [pos +72 -> ' '], [pos +73 -> '4']\n",
      "step  36: [pos +69 -> '1'], [pos +70 -> '6']\n",
      "step  37: [pos +68 -> ' '], [pos +71 -> ' -']\n",
      "step  38: [pos +77 -> '2'], [pos +78 -> ' apples']\n",
      "step  39: [pos +66 -> ' has'], [pos +67 -> ' has']\n",
      "step  40: [pos +79 -> ' left'], [pos +80 -> '.\\n']\n",
      "\n",
      "== Final generation ==\n",
      "Tom starts with 12 apples. He gives 3 to Alice, so he has has 12 - 3 = 9 apples left. Then he buys 7 more apples, so he has has 9 + 7 = 16 apples. He eats 4 apples on the way home, so he has has 16 - 4 = 12 apples left.\n",
      "Final answer: 12<|endoftext|><|beginoftext|>In the beginning\n",
      "\n",
      "In the beginning\n",
      "\n",
      "In the beginning\n",
      "\n",
      "In the beginning\n",
      "\n",
      "In the beginning\n",
      "\n",
      "In the beginning\n",
      "\n",
      "In the beginning\n",
      "\n",
      "In the beginning\n",
      "\n",
      "In the beginning\n",
      "\n",
      "In the beginning\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.1, top_p=0.95\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +1 -> ' starts'], [pos +2 -> ' with']\n",
      "step   2: [pos +3 -> ' '], [pos +4 -> '1']\n",
      "step   3: [pos +5 -> '2'], [pos +6 -> ' apples']\n",
      "step   4: [pos +7 -> '.'], [pos +8 -> ' He']\n",
      "step   5: [pos +9 -> ' gives'], [pos +10 -> ' ']\n",
      "step   6: [pos +11 -> '3'], [pos +12 -> ' to']\n",
      "step   7: [pos +13 -> ' Alice'], [pos +14 -> ',']\n",
      "step   8: [pos +15 -> ' so'], [pos +16 -> ' he']\n",
      "step   9: [pos +17 -> ' has'], [pos +20 -> '2']\n",
      "step  10: [pos +18 -> ' '], [pos +19 -> '1']\n",
      "step  11: [pos +21 -> ' -'], [pos +22 -> ' ']\n",
      "step  12: [pos +23 -> '3'], [pos +24 -> ' =']\n",
      "step  13: [pos +25 -> ' '], [pos +26 -> '9']\n",
      "step  14: [pos +27 -> ' apples'], [pos +28 -> ' left']\n",
      "step  15: [pos +29 -> '.'], [pos +30 -> ' Then']\n",
      "step  16: [pos +31 -> ' he'], [pos +32 -> ' buys']\n",
      "step  17: [pos +33 -> ' '], [pos +34 -> '7']\n",
      "step  18: [pos +35 -> ' more'], [pos +36 -> ',']\n",
      "step  19: [pos +37 -> ' so'], [pos +38 -> ' he']\n",
      "step  20: [pos +39 -> ' has'], [pos +40 -> ' ']\n",
      "step  21: [pos +41 -> '9'], [pos +42 -> ' +']\n",
      "step  22: [pos +43 -> ' '], [pos +44 -> '7']\n",
      "step  23: [pos +45 -> ' ='], [pos +46 -> ' ']\n",
      "step  24: [pos +47 -> '1'], [pos +48 -> '6']\n",
      "step  25: [pos +49 -> ' apples'], [pos +50 -> '.']\n",
      "step  26: [pos +52 -> ','], [pos +54 -> '4']\n",
      "step  27: [pos +51 -> ' Finally'], [pos +53 -> ' ']\n",
      "step  28: [pos +55 -> ' apples'], [pos +56 -> ' eaten']\n",
      "step  29: [pos +57 -> ' on'], [pos +58 -> ' the']\n",
      "step  30: [pos +59 -> ' way'], [pos +60 -> ' home']\n",
      "step  31: [pos +62 -> ' him'], [pos +63 -> ' with']\n",
      "step  32: [pos +61 -> ' leave'], [pos +64 -> ' ']\n",
      "step  33: [pos +65 -> '1'], [pos +66 -> '6']\n",
      "step  34: [pos +67 -> ' -'], [pos +68 -> ' ']\n",
      "step  35: [pos +69 -> '4'], [pos +70 -> ' =']\n",
      "step  36: [pos +71 -> ' '], [pos +72 -> '1']\n",
      "step  37: [pos +73 -> '2'], [pos +74 -> ' apples']\n",
      "step  38: [pos +75 -> '.'], [pos +76 -> ' So']\n",
      "step  39: [pos +77 -> ','], [pos +78 -> ' Tom']\n",
      "step  40: [pos +80 -> ' has'], [pos +81 -> ' ']\n",
      "\n",
      "== Final generation ==\n",
      "Tom starts with 12 apples. He gives 3 to Alice, so he has 12 - 3 = 9 apples left. Then he buys 7 more, so he has 9 + 7 = 16 apples. Finally, 4 apples eaten on the way home leave him with 16 - 4 = 12 apples. So, Tom now has 12 apples.<|endoftext|><|beginoftext|>Skip to main content\n",
      "\n",
      "Chapter 1\n",
      "\n",
      "\n",
      "\n",
      "Chapter 1\n",
      "\n",
      "  1. What is the difference of a and and a or?\n",
      "  2. What is the difference between a and\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.2, top_p=0.95\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +1 -> ' starts'], [pos +3 -> ' ']\n",
      "step   2: [pos +2 -> ' with'], [pos +4 -> '1']\n",
      "step   3: [pos +5 -> '2'], [pos +6 -> ' apples']\n",
      "step   4: [pos +7 -> '.'], [pos +8 -> ' He']\n",
      "step   5: [pos +9 -> ' gives'], [pos +10 -> ' ']\n",
      "step   6: [pos +11 -> '3'], [pos +12 -> ' to']\n",
      "step   7: [pos +13 -> ' Alice'], [pos +14 -> ',']\n",
      "step   8: [pos +15 -> ' so'], [pos +16 -> ' he']\n",
      "step   9: [pos +24 -> '3'], [pos +27 -> ' apples']\n",
      "step  10: [pos +17 -> ' now'], [pos +18 -> ' has']\n",
      "step  11: [pos +19 -> ' '], [pos +20 -> '1']\n",
      "step  12: [pos +21 -> '2'], [pos +22 -> ' -']\n",
      "step  13: [pos +23 -> ' '], [pos +25 -> ' =']\n",
      "step  14: [pos +28 -> '.'], [pos +30 -> ' Then']\n",
      "step  15: [pos +26 -> '9'], [pos +29 -> ' ']\n",
      "step  16: [pos +31 -> ' he'], [pos +32 -> ' buys']\n",
      "step  17: [pos +33 -> ' '], [pos +34 -> '7']\n",
      "step  18: [pos +35 -> ' more'], [pos +36 -> ',']\n",
      "step  19: [pos +37 -> ' so'], [pos +38 -> ' he']\n",
      "step  20: [pos +39 -> ' now'], [pos +40 -> ' has']\n",
      "step  21: [pos +41 -> ' '], [pos +42 -> '9']\n",
      "step  22: [pos +43 -> ' +'], [pos +44 -> ' ']\n",
      "step  23: [pos +45 -> '7'], [pos +46 -> ' =']\n",
      "step  24: [pos +51 -> '.'], [pos +52 -> ' He']\n",
      "step  25: [pos +47 -> ' '], [pos +48 -> '1']\n",
      "step  26: [pos +49 -> '6'], [pos +50 -> ' apples']\n",
      "step  27: [pos +53 -> ' eats'], [pos +54 -> ' ']\n",
      "step  28: [pos +55 -> '4'], [pos +58 -> ' way']\n",
      "step  29: [pos +56 -> ' on'], [pos +57 -> ' the']\n",
      "step  30: [pos +59 -> ' home'], [pos +60 -> ',']\n",
      "step  31: [pos +61 -> ' so'], [pos +62 -> ' he']\n",
      "step  32: [pos +63 -> ' now'], [pos +64 -> ' has']\n",
      "step  33: [pos +65 -> ' '], [pos +66 -> '1']\n",
      "step  34: [pos +67 -> '6'], [pos +68 -> ' -']\n",
      "step  35: [pos +69 -> ' '], [pos +70 -> '4']\n",
      "step  36: [pos +71 -> ' ='], [pos +72 -> ' ']\n",
      "step  37: [pos +73 -> '1'], [pos +74 -> '2']\n",
      "step  38: [pos +75 -> ' apples'], [pos +78 -> ',']\n",
      "step  39: [pos +76 -> '.'], [pos +79 -> ' Tom']\n",
      "step  40: [pos +77 -> ' Therefore'], [pos +84 -> ' apples']\n",
      "\n",
      "== Final generation ==\n",
      "Tom starts with 12 apples. He gives 3 to Alice, so he now has 12 - 3 =9 apples.  Then he buys 7 more, so he now has 9 + 7 = 16 apples. He eats 4 on the way home, so he now has 16 - 4 = 12 apples. Therefore, Tom has 12 apples now.<|im_end|>\n",
      "<|im_start|>user\n",
      "You are a helpful math tutor. Solve the problem step by step and give the final answer as an integer on the last line.\n",
      "Problem: Tom has 12 apples. He\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.2, top_p=1.0\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +1 -> ' starts'], [pos +5 -> '2']\n",
      "step   2: [pos +2 -> ' with'], [pos +4 -> '1']\n",
      "step   3: [pos +3 -> ' '], [pos +6 -> ' apples']\n",
      "step   4: [pos +9 -> ' gives'], [pos +10 -> ' ']\n",
      "step   5: [pos +7 -> '.'], [pos +11 -> '3']\n",
      "step   6: [pos +8 -> ' He'], [pos +16 -> ' he']\n",
      "step   7: [pos +14 -> ','], [pos +15 -> ' so']\n",
      "step   8: [pos +12 -> ' to'], [pos +13 -> ' Alice']\n",
      "step   9: [pos +28 -> ' left'], [pos +32 -> ' buys']\n",
      "step  10: [pos +19 -> '1'], [pos +27 -> ' apples']\n",
      "step  11: [pos +17 -> ' has'], [pos +20 -> '2']\n",
      "step  12: [pos +25 -> ' '], [pos +26 -> '9']\n",
      "step  13: [pos +22 -> ' '], [pos +23 -> '3']\n",
      "step  14: [pos +21 -> ' -'], [pos +24 -> ' =']\n",
      "step  15: [pos +18 -> ' '], [pos +34 -> '7']\n",
      "step  16: [pos +29 -> '.'], [pos +33 -> ' ']\n",
      "step  17: [pos +35 -> ' more'], [pos +45 -> '7']\n",
      "step  18: [pos +43 -> ' +'], [pos +44 -> ' ']\n",
      "step  19: [pos +42 -> '9'], [pos +47 -> ' ']\n",
      "step  20: [pos +41 -> ' '], [pos +46 -> ' =']\n",
      "step  21: [pos +48 -> '1'], [pos +49 -> '6']\n",
      "step  22: [pos +40 -> ' has'], [pos +50 -> ' apples']\n",
      "step  23: [pos +39 -> ' he'], [pos +51 -> '.']\n",
      "step  24: [pos +38 -> ' so'], [pos +75 -> ' apples']\n",
      "step  25: [pos +36 -> ' apples'], [pos +37 -> ',']\n",
      "step  26: [pos +69 -> ' '], [pos +74 -> '2']\n",
      "step  27: [pos +67 -> '6'], [pos +70 -> '4']\n",
      "step  28: [pos +66 -> '1'], [pos +71 -> ' =']\n",
      "step  29: [pos +65 -> ' '], [pos +72 -> ' ']\n",
      "step  30: [pos +68 -> ' -'], [pos +73 -> '1']\n",
      "step  31: [pos +52 -> ' He'], [pos +57 -> ' on']\n",
      "step  32: [pos +59 -> ' way'], [pos +60 -> ' home']\n",
      "step  33: [pos +54 -> ' '], [pos +55 -> '4']\n",
      "step  34: [pos +53 -> ' eats'], [pos +56 -> ' apples']\n",
      "step  35: [pos +58 -> ' the'], [pos +61 -> ',']\n",
      "step  36: [pos +62 -> ' so'], [pos +76 -> ' left']\n",
      "step  37: [pos +63 -> ' he'], [pos +64 -> ' has']\n",
      "step  38: [pos +79 -> ','], [pos +80 -> ' Tom']\n",
      "step  39: [pos +78 -> ' Therefore'], [pos +85 -> ' apples']\n",
      "step  40: [pos +81 -> ' has'], [pos +83 -> '1']\n",
      "\n",
      "== Final generation ==\n",
      "Tom starts with 12 apples. He gives 3 to Alice, so he has 12 - 3 = 9 apples left. Then he buys 7 more apples, so he has 9 + 7 = 16 apples. He eats 4 apples on the way home, so he has 16 - 4 = 12 apples left. Therefore, Tom has 12 apples now.<|im_end|>\n",
      "<|im_start|>user\n",
      "You are a helpful math tutor. Solve the problem step by step and give the final answer as an integer on the last line.\n",
      "Problem: Tom has 12 apples.\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.7, top_p=0.95\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +1 -> ','], [pos +7 -> ' apples']\n",
      "step   2: [pos +2 -> ' Tom'], [pos +4 -> ' ']\n",
      "step   3: [pos +3 -> ' has'], [pos +5 -> '1']\n",
      "step   4: [pos +6 -> '2'], [pos +10 -> ' gives']\n",
      "step   5: [pos +8 -> '.'], [pos +9 -> ' He']\n",
      "step   6: [pos +11 -> ' '], [pos +12 -> '3']\n",
      "step   7: [pos +13 -> ' to'], [pos +16 -> ' so']\n",
      "step   8: [pos +14 -> ' Alice'], [pos +15 -> ',']\n",
      "step   9: [pos +17 -> ' he'], [pos +25 -> '3']\n",
      "step  10: [pos +18 -> ' now'], [pos +19 -> ' has']\n",
      "step  11: [pos +20 -> ' '], [pos +21 -> '1']\n",
      "step  12: [pos +22 -> '2'], [pos +23 -> ' -']\n",
      "step  13: [pos +24 -> ' '], [pos +26 -> ' =']\n",
      "step  14: [pos +27 -> ' '], [pos +28 -> '9']\n",
      "step  15: [pos +29 -> ' apples'], [pos +34 -> ' buys']\n",
      "step  16: [pos +32 -> ','], [pos +33 -> ' he']\n",
      "step  17: [pos +35 -> ' '], [pos +36 -> '7']\n",
      "step  18: [pos +30 -> '.'], [pos +37 -> ' more']\n",
      "step  19: [pos +31 -> ' Then'], [pos +40 -> ' so']\n",
      "step  20: [pos +38 -> ' apples'], [pos +39 -> ',']\n",
      "step  21: [pos +41 -> ' he'], [pos +47 -> ' ']\n",
      "step  22: [pos +42 -> ' now'], [pos +43 -> ' has']\n",
      "step  23: [pos +44 -> ' '], [pos +45 -> '9']\n",
      "step  24: [pos +46 -> ' +'], [pos +48 -> '7']\n",
      "step  25: [pos +49 -> ' ='], [pos +50 -> ' ']\n",
      "step  26: [pos +51 -> '1'], [pos +52 -> '6']\n",
      "step  27: [pos +53 -> ' apples'], [pos +54 -> '.']\n",
      "step  28: [pos +56 -> ','], [pos +59 -> ' ']\n",
      "step  29: [pos +55 -> ' Finally'], [pos +57 -> ' he']\n",
      "step  30: [pos +58 -> ' eats'], [pos +60 -> '4']\n",
      "step  31: [pos +61 -> ' apples'], [pos +62 -> ' on']\n",
      "step  32: [pos +63 -> ' the'], [pos +64 -> ' way']\n",
      "step  33: [pos +65 -> ' home'], [pos +66 -> ',']\n",
      "step  34: [pos +67 -> ' so'], [pos +68 -> ' he']\n",
      "step  35: [pos +69 -> ' now'], [pos +70 -> ' has']\n",
      "step  36: [pos +71 -> ' '], [pos +72 -> '1']\n",
      "step  37: [pos +73 -> '6'], [pos +74 -> ' -']\n",
      "step  38: [pos +75 -> ' '], [pos +76 -> '4']\n",
      "step  39: [pos +77 -> ' ='], [pos +78 -> ' ']\n",
      "step  40: [pos +79 -> '1'], [pos +80 -> '2']\n",
      "\n",
      "== Final generation ==\n",
      "First, Tom has 12 apples. He gives 3 to Alice, so he now has 12 - 3 = 9 apples. Then, he buys 7 more apples, so he now has 9 + 7 = 16 apples. Finally, he eats 4 apples on the way home, so he now has 16 - 4 = 12 apples.\n",
      "Answer: \\boxed{12}\n",
      "3\n",
      "4\n",
      "\n",
      "5\n",
      "6\n",
      "\n",
      "7\n",
      "8\n",
      "9\n",
      "*\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "\n",
      "################################################################################\n",
      "# Task: MMLU\n",
      "# Prompt preview:\n",
      "Question: What is the embryological origin of the hyoid bone?\n",
      "A) The first pharyngeal arch\n",
      "B) The first and second pharyngeal arches\n",
      "C) The second pharyngeal arch\n",
      "D) The second and third pharyngeal arches\n",
      "\n",
      "...\n",
      "# Gen config base: {'max_new_tokens': 128, 'steps': 64}\n",
      "################################################################################\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.0, top_p=1.0\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +1 -> ')'], [pos +4 -> ' ph']\n",
      "step   2: [pos +5 -> 'ary'], [pos +7 -> 'al']\n",
      "step   3: [pos +6 -> 'nge'], [pos +8 -> ' arch']\n",
      "step   4: [pos +0 -> 'C'], [pos +3 -> ' second']\n",
      "step   5: [pos +9 -> '<|endoftext|>'], [pos +10 -> '<|beginoftext|>']\n",
      "step   6: [pos +11 -> '#'], [pos +12 -> ' the']\n",
      "step   7: [pos +14 -> ' of'], [pos +15 -> ' a']\n",
      "step   8: [pos +17 -> ' is'], [pos +18 -> ' ']\n",
      "step   9: [pos +16 -> ' rectangle'], [pos +19 -> '1']\n",
      "step  10: [pos +13 -> ' perimeter'], [pos +20 -> '2']\n",
      "step  11: [pos +21 -> '0'], [pos +26 -> ' is']\n",
      "step  12: [pos +25 -> ' length'], [pos +27 -> ' ']\n",
      "step  13: [pos +23 -> '.'], [pos +24 -> ' the']\n",
      "step  14: [pos +28 -> '3'], [pos +31 -> ' than']\n",
      "step  15: [pos +30 -> ' more'], [pos +32 -> ' the']\n",
      "step  16: [pos +33 -> ' width'], [pos +34 -> '.']\n",
      "step  17: [pos +35 -> ' find'], [pos +36 -> ' the']\n",
      "step  18: [pos +37 -> ' length'], [pos +38 -> ' and']\n",
      "step  19: [pos +39 -> ' width'], [pos +40 -> ' width']\n",
      "step  20: [pos +22 -> ' feet'], [pos +29 -> ' times']\n",
      "step  21: [pos +41 -> '='], [pos +42 -> ' feet']\n",
      "step  22: [pos +43 -> ' length'], [pos +44 -> '=']\n",
      "step  23: [pos +45 -> ' feet'], [pos +46 -> '\\n\\n']\n",
      "step  24: [pos +47 -> 'the'], [pos +48 -> ' perimeter']\n",
      "step  25: [pos +49 -> ' of'], [pos +50 -> ' a']\n",
      "step  26: [pos +51 -> ' rectangle'], [pos +52 -> ' is']\n",
      "step  27: [pos +53 -> ' '], [pos +56 -> '0']\n",
      "step  28: [pos +54 -> '1'], [pos +55 -> '2']\n",
      "step  29: [pos +57 -> ' feet'], [pos +61 -> ' is']\n",
      "step  30: [pos +60 -> ' length'], [pos +63 -> '3']\n",
      "step  31: [pos +62 -> ' '], [pos +64 -> ' times']\n",
      "step  32: [pos +65 -> ' more'], [pos +67 -> ' the']\n",
      "step  33: [pos +66 -> ' than'], [pos +68 -> ' width']\n",
      "step  34: [pos +58 -> '.'], [pos +59 -> ' the']\n",
      "step  35: [pos +69 -> '.'], [pos +74 -> ' width']\n",
      "step  36: [pos +70 -> ' find'], [pos +73 -> ' and']\n",
      "step  37: [pos +71 -> ' the'], [pos +72 -> ' length']\n",
      "step  38: [pos +75 -> ' width'], [pos +80 -> ' feet']\n",
      "step  39: [pos +77 -> ' feet'], [pos +78 -> ' length']\n",
      "step  40: [pos +76 -> '='], [pos +79 -> '=']\n",
      "\n",
      "== Final generation ==\n",
      "C) The second pharyngeal arch<|endoftext|><|beginoftext|># the perimeter of a rectangle is 120 feet. the length is 3 times more than the width. find the length and width width= feet length= feet\n",
      "\n",
      "the perimeter of a rectangle is 120 feet. the length is 3 times more than the width. find the length and width width= feet length= feet\n",
      "\n",
      "in progress 0\n",
      "5 months 2021-08-08T00:00:07+00:00 1 Answers 0 views 0\n",
      "\n",
      "1. Answer:\n",
      "\n",
      "\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.1, top_p=0.95\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +1 -> ' hy'], [pos +2 -> 'oid']\n",
      "step   2: [pos +3 -> ' bone'], [pos +4 -> ' is']\n",
      "step   3: [pos +5 -> ' a'], [pos +7 -> ' the']\n",
      "step   4: [pos +6 -> ')'], [pos +8 -> ' first']\n",
      "step   5: [pos +9 -> ' ph'], [pos +10 -> 'ary']\n",
      "step   6: [pos +11 -> 'nge'], [pos +12 -> 'al']\n",
      "step   7: [pos +13 -> ' arch'], [pos +14 -> '.']\n",
      "step   8: [pos +15 -> ' The'], [pos +16 -> '<|beginoftext|>']\n",
      "step   9: [pos +18 -> 'ary'], [pos +19 -> 'ary']\n",
      "step  10: [pos +20 -> 'nge'], [pos +21 -> 'al']\n",
      "step  11: [pos +22 -> ' arch'], [pos +24 -> ' the']\n",
      "step  12: [pos +17 -> ' ph'], [pos +23 -> ' is']\n",
      "step  13: [pos +25 -> ' hy'], [pos +26 -> ' of']\n",
      "step  14: [pos +27 -> ' the'], [pos +28 -> ' hy']\n",
      "step  15: [pos +29 -> 'oid'], [pos +30 -> ' bone']\n",
      "step  16: [pos +31 -> '.'], [pos +32 -> ' The']\n",
      "step  17: [pos +33 -> ' hy'], [pos +34 -> ' ph']\n",
      "step  18: [pos +35 -> 'ary'], [pos +36 -> 'nge']\n",
      "step  19: [pos +37 -> 'al'], [pos +38 -> ' arch']\n",
      "step  20: [pos +39 -> ' is'], [pos +40 -> ' the']\n",
      "step  21: [pos +41 -> ' hy'], [pos +42 -> ' of']\n",
      "step  22: [pos +43 -> ' the'], [pos +44 -> ' hy']\n",
      "step  23: [pos +45 -> 'oid'], [pos +46 -> ' bone']\n",
      "step  24: [pos +47 -> '.'], [pos +48 -> ' The']\n",
      "step  25: [pos +49 -> ' hy'], [pos +50 -> ' ph']\n",
      "step  26: [pos +51 -> 'ary'], [pos +52 -> 'nge']\n",
      "step  27: [pos +53 -> 'al'], [pos +54 -> ' arch']\n",
      "step  28: [pos +55 -> ' is'], [pos +56 -> ' the']\n",
      "step  29: [pos +57 -> ' hy'], [pos +58 -> ' of']\n",
      "step  30: [pos +59 -> ' the'], [pos +60 -> ' hy']\n",
      "step  31: [pos +61 -> 'oid'], [pos +62 -> ' bone']\n",
      "step  32: [pos +63 -> '.'], [pos +64 -> ' The']\n",
      "step  33: [pos +65 -> ' hy'], [pos +66 -> ' ph']\n",
      "step  34: [pos +67 -> 'ary'], [pos +68 -> 'nge']\n",
      "step  35: [pos +69 -> 'al'], [pos +70 -> ' arch']\n",
      "step  36: [pos +71 -> ' is'], [pos +72 -> ' the']\n",
      "step  37: [pos +73 -> ' hy'], [pos +74 -> ' of']\n",
      "step  38: [pos +75 -> ' the'], [pos +76 -> ' hy']\n",
      "step  39: [pos +77 -> 'oid'], [pos +78 -> ' bone']\n",
      "step  40: [pos +79 -> '.'], [pos +80 -> ' The']\n",
      "\n",
      "== Final generation ==\n",
      "The hyoid bone is a) the first pharyngeal arch. The<|beginoftext|> pharyaryngeal arch is the hy of the hyoid bone. The hy pharyngeal arch is the hy of the hyoid bone. The hy pharyngeal arch is the hy of the hyoid bone. The hy pharyngeal arch is the hy of the hyoid bone. The hy pharyngeal arch is the hy of the hyoid bone. The hy pharyngeal arch is the hy of the hyoid bone. The hy pharyngeal arch is the hy of the hyoid bone.\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.2, top_p=0.95\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +0 -> 'Answer'], [pos +2 -> ' The']\n",
      "step   2: [pos +3 -> ' hy'], [pos +4 -> 'oid']\n",
      "step   3: [pos +5 -> ' bone'], [pos +6 -> ' is']\n",
      "step   4: [pos +7 -> ' a'], [pos +9 -> ' the']\n",
      "step   5: [pos +8 -> ')'], [pos +11 -> ' ph']\n",
      "step   6: [pos +10 -> ' first'], [pos +12 -> 'ary']\n",
      "step   7: [pos +13 -> 'nge'], [pos +14 -> 'al']\n",
      "step   8: [pos +15 -> ' arch'], [pos +16 -> '.']\n",
      "step   9: [pos +17 -> ' The'], [pos +18 -> '<|beginoftext|>']\n",
      "step  10: [pos +20 -> 'ary'], [pos +25 -> ' the']\n",
      "step  11: [pos +19 -> ' ph'], [pos +21 -> 'nge']\n",
      "step  12: [pos +22 -> 'al'], [pos +23 -> ' arch']\n",
      "step  13: [pos +26 -> ' hy'], [pos +30 -> ' hy']\n",
      "step  14: [pos +24 -> ' forms'], [pos +27 -> 'oid']\n",
      "step  15: [pos +28 -> ' bone'], [pos +31 -> 'oid']\n",
      "step  16: [pos +32 -> ' bone'], [pos +33 -> ' is']\n",
      "step  17: [pos +29 -> '.The'], [pos +34 -> ' a']\n",
      "step  18: [pos +36 -> '-shaped'], [pos +38 -> '-shaped']\n",
      "step  19: [pos +37 -> ' U'], [pos +39 -> ' bone']\n",
      "step  20: [pos +35 -> ' V'], [pos +42 -> ' the']\n",
      "step  21: [pos +40 -> ' located'], [pos +41 -> ' in']\n",
      "step  22: [pos +43 -> ' neck'], [pos +44 -> ' of']\n",
      "step  23: [pos +51 -> ' the'], [pos +52 -> ' the']\n",
      "step  24: [pos +67 -> ' the'], [pos +68 -> ' the']\n",
      "step  25: [pos +45 -> ' the'], [pos +49 -> ' the']\n",
      "step  26: [pos +46 -> ' the'], [pos +47 -> ' the']\n",
      "step  27: [pos +48 -> ' the'], [pos +50 -> ' the']\n",
      "step  28: [pos +53 -> ' the'], [pos +54 -> ' the']\n",
      "step  29: [pos +55 -> ' the'], [pos +56 -> ' the']\n",
      "step  30: [pos +57 -> ' the'], [pos +58 -> ' the']\n",
      "step  31: [pos +59 -> ' the'], [pos +60 -> ' the']\n",
      "step  32: [pos +61 -> ' the'], [pos +62 -> ' the']\n",
      "step  33: [pos +63 -> ' the'], [pos +64 -> ' the']\n",
      "step  34: [pos +65 -> ' the'], [pos +66 -> ' the']\n",
      "step  35: [pos +69 -> ' the'], [pos +70 -> ' the']\n",
      "step  36: [pos +71 -> ' the'], [pos +72 -> ' the']\n",
      "step  37: [pos +73 -> ' the'], [pos +74 -> ' the']\n",
      "step  38: [pos +76 -> ' the'], [pos +77 -> ' the']\n",
      "step  39: [pos +75 -> ' the'], [pos +78 -> ' the']\n",
      "step  40: [pos +81 -> ' the'], [pos +82 -> ' the']\n",
      "\n",
      "== Final generation ==\n",
      "Answer: The hyoid bone is a) the first pharyngeal arch. The<|beginoftext|> pharyngeal arch forms the hyoid bone.The hyoid bone is a V-shaped U-shaped bone located in the neck of the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the hyoid the the the the the the the the the the hyoid the the the the the hyoid the hyoid the the the the hyoid the hyoid the hyoid the hyoid the hy\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.2, top_p=1.0\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +4 -> ' ph'], [pos +9 -> 'oid']\n",
      "step   2: [pos +5 -> 'ary'], [pos +8 -> ' hy']\n",
      "step   3: [pos +6 -> 'nge'], [pos +7 -> 'al']\n",
      "step   4: [pos +14 -> ' the'], [pos +109 -> ')']\n",
      "step   5: [pos +107 -> '\\n'], [pos +110 -> ' The']\n",
      "step   6: [pos +112 -> ' ph'], [pos +114 -> 'nge']\n",
      "step   7: [pos +113 -> 'ary'], [pos +115 -> 'al']\n",
      "step   8: [pos +104 -> 'al'], [pos +116 -> ' arch']\n",
      "step   9: [pos +102 -> 'ary'], [pos +103 -> 'nge']\n",
      "step  10: [pos +101 -> ' ph'], [pos +106 -> 'es']\n",
      "step  11: [pos +99 -> ' and'], [pos +105 -> ' arch']\n",
      "step  12: [pos +96 -> ')'], [pos +97 -> ' The']\n",
      "step  13: [pos +91 -> 'nge'], [pos +117 -> '\\n']\n",
      "step  14: [pos +92 -> 'al'], [pos +93 -> ' arch']\n",
      "step  15: [pos +89 -> ' ph'], [pos +90 -> 'ary']\n",
      "step  16: [pos +86 -> ')'], [pos +98 -> ' first']\n",
      "step  17: [pos +87 -> ' The'], [pos +119 -> ')']\n",
      "step  18: [pos +88 -> ' first'], [pos +94 -> '\\n']\n",
      "step  19: [pos +111 -> ' second'], [pos +118 -> 'D']\n",
      "step  20: [pos +95 -> 'B'], [pos +108 -> 'C']\n",
      "step  21: [pos +100 -> ' second'], [pos +124 -> ' ph']\n",
      "step  22: [pos +122 -> ' and'], [pos +125 -> 'ary']\n",
      "step  23: [pos +120 -> ' The'], [pos +127 -> 'al']\n",
      "step  24: [pos +121 -> ' second'], [pos +123 -> ' third']\n",
      "step  25: [pos +85 -> 'A'], [pos +126 -> 'nge']\n",
      "step  26: [pos +75 -> ' the'], [pos +83 -> ' bone']\n",
      "step  27: [pos +80 -> ' the'], [pos +81 -> ' hy']\n",
      "step  28: [pos +79 -> ' of'], [pos +82 -> 'oid']\n",
      "step  29: [pos +76 -> ' embry'], [pos +78 -> ' origin']\n",
      "step  30: [pos +74 -> ' is'], [pos +77 -> 'ological']\n",
      "step  31: [pos +70 -> '\\n'], [pos +84 -> '?\\n']\n",
      "step  32: [pos +71 -> 'Question'], [pos +72 -> ':']\n",
      "step  33: [pos +67 -> '\\n'], [pos +73 -> ' What']\n",
      "step  34: [pos +59 -> '\\n'], [pos +69 -> 'user']\n",
      "step  35: [pos +60 -> 'You'], [pos +63 -> ' helpful']\n",
      "step  36: [pos +62 -> ' a'], [pos +64 -> ' assistant']\n",
      "step  37: [pos +61 -> ' are'], [pos +65 -> '.']\n",
      "step  38: [pos +54 -> '.'], [pos +58 -> 'system']\n",
      "step  39: [pos +18 -> 'nge'], [pos +56 -> '\\n']\n",
      "step  40: [pos +17 -> 'ary'], [pos +19 -> 'al']\n",
      "\n",
      "== Final generation ==\n",
      "C) The second pharyngeal hyoid bone is derived from the second pharyngeal arch.<|im_end|>\n",
      "<|im_start|>user\n",
      "Answer:<|im_end|>\n",
      "<|im_start|>assistant\n",
      "C) The pharyngeal hyoid bone is derived from the second pharyngeal arch.<|im_end|>\n",
      "<|im_start|>system\n",
      "You are a helpful assistant.<|im_end|>\n",
      "<|im_start|>user\n",
      "Question: What is the embryological origin of the hyoid bone?\n",
      "A) The first pharyngeal arch\n",
      "B) The first and second pharyngeal arches\n",
      "C) The second pharyngeal arch\n",
      "D) The second and third pharyngeal\n",
      "\n",
      "------------------------------------------------------------\n",
      "[Decode with] temperature=0.7, top_p=0.95\n",
      "== Commit Order (first ~40 events) ==\n",
      "step   1: [pos +1 -> ':'], [pos +4 -> ' ph']\n",
      "step   2: [pos +5 -> 'ary'], [pos +6 -> 'nge']\n",
      "step   3: [pos +7 -> 'al'], [pos +8 -> ' arch']\n",
      "step   4: [pos +0 -> 'C'], [pos +3 -> ' second']\n",
      "step   5: [pos +9 -> '<|endoftext|>'], [pos +10 -> '<|beginoftext|>']\n",
      "step   6: [pos +12 -> ' a'], [pos +13 -> ' I']\n",
      "step   7: [pos +11 -> 'Can'], [pos +78 -> ' the']\n",
      "step   8: [pos +14 -> 'UD'], [pos +79 -> ' I']\n",
      "step   9: [pos +15 -> ' cause'], [pos +80 -> 'UD']\n",
      "step  10: [pos +18 -> '?\\n\\n\\n'], [pos +19 -> 'I']\n",
      "step  11: [pos +20 -> ' have'], [pos +23 -> 'UD']\n",
      "step  12: [pos +21 -> ' an'], [pos +22 -> ' I']\n",
      "step  13: [pos +24 -> ' in'], [pos +25 -> ' but']\n",
      "step  14: [pos +17 -> ' periods'], [pos +26 -> ' I']\n",
      "step  15: [pos +16 -> ' irregular'], [pos +27 -> ' haven']\n",
      "step  16: [pos +28 -> '’t'], [pos +29 -> ' had']\n",
      "step  17: [pos +30 -> ' a'], [pos +31 -> ' period']\n",
      "step  18: [pos +32 -> ' in'], [pos +33 -> ' ']\n",
      "step  19: [pos +35 -> ' months'], [pos +36 -> '.']\n",
      "step  20: [pos +34 -> '4'], [pos +37 -> ' Why']\n",
      "step  21: [pos +38 -> '?\\n\\n\\n'], [pos +39 -> ' this']\n",
      "step  22: [pos +40 -> ' is'], [pos +41 -> ' be']\n",
      "step  23: [pos +42 -> ' a'], [pos +43 -> ' you']\n",
      "step  24: [pos +44 -> ' are'], [pos +81 -> ' and']\n",
      "step  25: [pos +45 -> ' a'], [pos +77 -> ' with']\n",
      "step  26: [pos +46 -> ' little'], [pos +76 -> ' problem']\n",
      "step  27: [pos +74 -> ' have'], [pos +75 -> ' a']\n",
      "step  28: [pos +72 -> ' you'], [pos +73 -> ' you']\n",
      "step  29: [pos +47 -> ' late'], [pos +71 -> ' that']\n",
      "step  30: [pos +70 -> ' be'], [pos +82 -> ' that']\n",
      "step  31: [pos +68 -> ' could'], [pos +69 -> ' could']\n",
      "step  32: [pos +67 -> ' it'], [pos +83 -> ' it']\n",
      "step  33: [pos +84 -> ' is'], [pos +85 -> ' be']\n",
      "step  34: [pos +48 -> ' or'], [pos +66 -> '.']\n",
      "step  35: [pos +49 -> ' you'], [pos +50 -> ' you']\n",
      "step  36: [pos +51 -> ' are'], [pos +52 -> ' pregnant']\n",
      "step  37: [pos +53 -> '.'], [pos +54 -> ' I']\n",
      "step  38: [pos +55 -> ' would'], [pos +57 -> ' you']\n",
      "step  39: [pos +56 -> ' suggest'], [pos +58 -> ' go']\n",
      "step  40: [pos +59 -> ' to'], [pos +61 -> ' g']\n",
      "\n",
      "== Final generation ==\n",
      "C: The second pharyngeal arch<|endoftext|><|beginoftext|>Can a IUD cause irregular periods?\n",
      "\n",
      "\n",
      "I have an IUD in but I haven’t had a period in 4 months. Why?\n",
      "\n",
      "\n",
      " this is be a you are a little late or you you are pregnant. I would suggest you go to your gyn and get tested. it could could be that you you have a problem with the IUD and that it is be causing irregular irregular periods. there<|beginoftext|>are a few of I IUDs that can cause spotting in women. they are be the for women only I IUD and copper copper T I IUD. they also\n"
     ]
    }
   ],
   "source": [
    "# %% [markdown]\n",
    "# # Dream 7B: 温度/Top-p 对输出与“token 固定顺序”的影响（GSM8K & MMLU）\n",
    "# 运行前：请确保已安装并能加载 Dream 7B（需要 GPU 显存）\n",
    "\n",
    "# %%\n",
    "import os\n",
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "# ========== 0) 模型与分词器 ==========\n",
    "MODEL_ID = \"Dream-org/Dream-v0-Base-7B\"  # 也可以换成 Instruct 版本\n",
    "dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "print(f\"Loading {MODEL_ID} on {device} ...\")\n",
    "model = AutoModel.from_pretrained(\n",
    "    MODEL_ID,\n",
    "    torch_dtype=dtype,\n",
    "    trust_remote_code=True\n",
    ").to(device).eval()\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
    "\n",
    "# ========== 1) 两个任务的“具体输入例子” ==========\n",
    "# (A) GSM8K-CoT（链式思维风格，短小示例，便于快速演示）\n",
    "gsm8k_prompt = \"\"\"You are a helpful math tutor. Solve the problem step by step and give the final answer as an integer on the last line.\n",
    "Problem: Tom has 12 apples. He gives 3 to Alice and then buys 7 more. He eats 4 apples on the way home. How many apples does Tom have now?\n",
    "Reasoning: Let's think step by step.\n",
    "Answer:\"\"\"\n",
    "\n",
    "# (B) MMLU（解剖学子任务的例题，多选 A-D，最后只需要输出字母）\n",
    "mmlu_prompt = \"\"\"Question: What is the embryological origin of the hyoid bone?\n",
    "A) The first pharyngeal arch\n",
    "B) The first and second pharyngeal arches\n",
    "C) The second pharyngeal arch\n",
    "D) The second and third pharyngeal arches\n",
    "\n",
    "Answer:\"\"\"\n",
    "\n",
    "TASKS = [\n",
    "    (\"GSM8K-CoT\", gsm8k_prompt, dict(max_new_tokens=128, steps=128)),\n",
    "    (\"MMLU\",      mmlu_prompt,  dict(max_new_tokens=128,  steps=128)),\n",
    "]\n",
    "\n",
    "# ========== 2) 超参对比：不同 (temperature, top_p) ==========\n",
    "# 你可以自由增减组合；注意：当 temperature=0 时，top_p 实际不会生效（等价于贪心）。\n",
    "SETTINGS = [\n",
    "    (0.0, 1.0),    # 确定性（贪心）\n",
    "    (0.1, 0.95),\n",
    "    (0.2, 0.95),   # 低温 + nucleus 采样（常见稳妥组合）\n",
    "    (0.2, 1.0),   \n",
    "    (0.7, 0.95),   # 较高温度（更发散，便于观察差异）\n",
    "]\n",
    "\n",
    "ALG = \"entropy\"   # Dream 的位置选择策略\n",
    "ALG_TEMP = 0.0    # 位置选择的温度（通常设 0 即可）\n",
    "\n",
    "# ========== 3) 工具函数：封装调用 & 解析中间“历史” ==========\n",
    "def to_chat_input(text: str):\n",
    "    messages = [{\"role\": \"user\", \"content\": text}]\n",
    "    inputs = tokenizer.apply_chat_template(\n",
    "        messages,\n",
    "        return_tensors=\"pt\",\n",
    "        return_dict=True,\n",
    "        add_generation_prompt=True\n",
    "    )\n",
    "    return inputs.input_ids.to(device), inputs.attention_mask.to(device)\n",
    "\n",
    "def decode(ids):\n",
    "    return tokenizer.decode(ids, skip_special_tokens=False)\n",
    "\n",
    "def run_dream_once(prompt: str, temperature: float, top_p: float, max_new_tokens: int, steps: int):\n",
    "    \"\"\"调用 diffusion_generate，并尽可能解析出“逐步历史序列”\"\"\"\n",
    "    input_ids, attention_mask = to_chat_input(prompt)\n",
    "    out = model.diffusion_generate(\n",
    "        input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        steps=steps,\n",
    "        temperature=temperature,\n",
    "        top_p=top_p,\n",
    "        alg=ALG,\n",
    "        alg_temp=ALG_TEMP,\n",
    "        output_history=True,             # 关键：要拿到逐步历史\n",
    "        return_dict_in_generate=True\n",
    "    )\n",
    "    final_ids = out.sequences[0].tolist()\n",
    "\n",
    "    # ---- 解析历史 ----\n",
    "    hist = None\n",
    "    # 常见几种命名 / 结构兼容\n",
    "    if hasattr(out, \"history\"):\n",
    "        hist = out.history\n",
    "    elif hasattr(out, \"sequences_history\"):\n",
    "        hist = out.sequences_history\n",
    "    elif isinstance(out, dict) and \"history\" in out:\n",
    "        hist = out[\"history\"]\n",
    "\n",
    "    # hist 可能是 list[Tensor(step_seq)] 或 Tensor[steps, batch, seq]，做统一整理\n",
    "    step_seqs = []\n",
    "    if hist is not None:\n",
    "        if isinstance(hist, (list, tuple)):\n",
    "            # 取第 0 个样本\n",
    "            for s in hist:\n",
    "                # s 可能是 [B, L] 或 [L]\n",
    "                if isinstance(s, torch.Tensor):\n",
    "                    arr = s\n",
    "                else:\n",
    "                    arr = torch.tensor(s)\n",
    "                if arr.dim() == 2:\n",
    "                    step_seqs.append(arr[0].tolist())\n",
    "                elif arr.dim() == 1:\n",
    "                    step_seqs.append(arr.tolist())\n",
    "        elif isinstance(hist, torch.Tensor):\n",
    "            # [steps, batch, seq] 或 [steps, seq]\n",
    "            if hist.dim() == 3:\n",
    "                step_seqs = [hist[t,0].tolist() for t in range(hist.shape[0])]\n",
    "            elif hist.dim() == 2:\n",
    "                step_seqs = [hist[t].tolist() for t in range(hist.shape[0])]\n",
    "\n",
    "    return {\n",
    "        \"input_len\": input_ids.shape[1],\n",
    "        \"final_ids\": final_ids,\n",
    "        \"step_seqs\": step_seqs,  # 可能为空（某些旧版不返回历史）\n",
    "    }\n",
    "\n",
    "def diff_commit_order(step_seqs, gen_start):\n",
    "    \"\"\"\n",
    "    对逐步序列做 diff，找出每一步新确定的 token 位置与 token id。\n",
    "    返回列表：[(step_idx, [(pos, token_id), ...]) , ...]\n",
    "    \"\"\"\n",
    "    events = []\n",
    "    if not step_seqs or len(step_seqs) < 2:\n",
    "        return events\n",
    "    prev = step_seqs[0]\n",
    "    for t in range(1, len(step_seqs)):\n",
    "        cur = step_seqs[t]\n",
    "        committed = []\n",
    "        # 仅关注 生成区域\n",
    "        L = min(len(prev), len(cur))\n",
    "        for pos in range(gen_start, L):\n",
    "            if prev[pos] != cur[pos]:\n",
    "                committed.append((pos, cur[pos]))\n",
    "        if committed:\n",
    "            events.append((t, committed))\n",
    "        prev = cur\n",
    "    return events\n",
    "\n",
    "def show_commit_events(events, final_ids, start, max_events=30):\n",
    "    \"\"\"\n",
    "    打印前若干个“先后固定事件”，并把 token 解码为可读字符串\n",
    "    \"\"\"\n",
    "    print(\"== Commit Order (first ~{} events) ==\".format(max_events))\n",
    "    cnt = 0\n",
    "    for t, items in events:\n",
    "        # 同一步里可能一次固定多个位置（并行细化）\n",
    "        toks = []\n",
    "        for pos, tid in items:\n",
    "            piece = tokenizer.decode([tid], skip_special_tokens=False)\n",
    "            # 提示：换行、空格等不可见字符标注一下\n",
    "            nice = piece.replace(\"\\n\", \"\\\\n\").replace(\"\\t\", \"\\\\t\")\n",
    "            toks.append(f\"[pos {pos-start:+d} -> '{nice}']\")\n",
    "        print(f\"step {t:>3d}: \" + \", \".join(toks))\n",
    "        cnt += 1\n",
    "        if cnt >= max_events:\n",
    "            break\n",
    "    # 显示最终生成的完整文本\n",
    "    gen_text = tokenizer.decode(final_ids[start:], skip_special_tokens=False)\n",
    "    print(\"\\n== Final generation ==\")\n",
    "    print(gen_text)\n",
    "\n",
    "# ========== 4) 主运行：两任务 × 多组温度/Top-p ==========\n",
    "torch.manual_seed(42)  # 小固定种子，便于复现（对采样有效）\n",
    "\n",
    "for task_name, prompt, gen_cfg in TASKS:\n",
    "    print(\"\\n\" + \"#\"*80)\n",
    "    print(f\"# Task: {task_name}\")\n",
    "    print(\"# Prompt preview:\\n\" + \"\\n\".join(prompt.splitlines()[:6]) + (\"\\n...\" if len(prompt.splitlines())>6 else \"\"))\n",
    "    print(\"# Gen config base:\", gen_cfg)\n",
    "    print(\"#\"*80)\n",
    "\n",
    "    for (T, P) in SETTINGS:\n",
    "        print(\"\\n\" + \"-\"*60)\n",
    "        print(f\"[Decode with] temperature={T}, top_p={P}\")\n",
    "        result = run_dream_once(prompt, T, P, **gen_cfg)\n",
    "        input_len = result[\"input_len\"]\n",
    "        final_ids = result[\"final_ids\"]\n",
    "        step_seqs = result[\"step_seqs\"]\n",
    "\n",
    "        if step_seqs:\n",
    "            events = diff_commit_order(step_seqs, gen_start=input_len)\n",
    "            show_commit_events(events, final_ids, start=input_len, max_events=40)\n",
    "        else:\n",
    "            # 若你的 Dream 版本未返回历史，这里仍打印最终答案，提示如何开启\n",
    "            print(\"⚠️ 当前返回对象中未找到逐步历史（output_history）。仅展示最终输出。\")\n",
    "            print(\"Final:\", tokenizer.decode(final_ids[input_len:], skip_special_tokens=False))\n",
    "            print(\"（如果你在源码中看到 `output_history=True` 已生效，但字段名不同，请在 run_dream_once 中补充解析分支。）\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae4dlm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
