{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8e08b1a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dslabra5/.conda/envs/sae4dlm/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Dream-org/Dream-v0-Base-7B on cuda ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`torch_dtype` is deprecated! Use `dtype` instead!\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 91.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt (first lines):\n",
      "You are a helpful math tutor. Solve the problem step by step and give the final answer as an integer on the last line.\n",
      "Problem: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n",
      "Reasoning: Let's think step by step.\n",
      "Answer:\n",
      "...\n",
      "\n",
      "================================================================================\n",
      "ALG = origin\n",
      "================================================================================\n",
      "[Determinism] same_final_ids = True\n",
      "[Determinism] same_history    = True\n",
      "[Final SHA256] run1 = fd861af4843072620c2680b7e06233e28527f5d505153ab63723948c24564644\n",
      "[Final SHA256] run2 = fd861af4843072620c2680b7e06233e28527f5d505153ab63723948c24564644\n",
      "[Gen tokens] run1 = 230, run2 = 230\n",
      "\n",
      "[Final generation (run1, early-stopped & truncated)]:\n",
      "Let first out how many clips Natalia sold in May sold4 clips in April she as many clips in. So, let's how many she sold May:  2 May clips = 1/ April clips 2 clips = April clips / 2  May clips = 48 / 2 May clips = 24 clips  2 clips clips  clips = * April clips clips * 22 April clips = 48 clips  So, clips sold 48 in she 24 clips, let's find total clips she sold in April May 2 Total clips = April clips + May clips 2 Total clips = clips 24 2 Total clips =2 2 clips =2 clips So, Natal of72 clips in April and May The final answer is 7.\n",
      "\n",
      "[Step diff preview: first 125 steps with changes]:\n",
      "step    2: (pos_rel=130, ' clips')\n",
      "step    3: (pos_rel=198, ' clips')\n",
      "step    4: (pos_rel=95, ' clips'), (pos_rel=181, ' clips'), (pos_rel=216, ' clips')\n",
      "step    5: (pos_rel=136, ' ')\n",
      "step    6: (pos_rel=7, ' clips')\n",
      "step    8: (pos_rel=11, ' in'), (pos_rel=28, ' clips')\n",
      "step   12: (pos_rel=27, ' many')\n",
      "step   13: (pos_rel=137, '4')\n",
      "step   14: (pos_rel=219, ' and')\n",
      "step   16: (pos_rel=189, ' clips')\n",
      "step   17: (pos_rel=207, ' Natal')\n",
      "step   20: (pos_rel=71, ' clips')\n",
      "step   21: (pos_rel=117, ' clips')\n",
      "step   23: (pos_rel=23, ' she'), (pos_rel=65, ' '), (pos_rel=105, ' clips'), (pos_rel=145, ' ')\n",
      "step   24: (pos_rel=195, ' ')\n",
      "step   26: (pos_rel=2, ' first'), (pos_rel=41, ' sold'), (pos_rel=111, ' clips')\n",
      "step   27: (pos_rel=56, ' clips')\n",
      "step   29: (pos_rel=190, ' =')\n",
      "step   30: (pos_rel=188, ' Total')\n",
      "step   33: (pos_rel=214, '7')\n",
      "step   37: (pos_rel=124, ' clips')\n",
      "step   42: (pos_rel=169, ' clips'), (pos_rel=172, ' clips'), (pos_rel=175, ' clips')\n",
      "step   43: (pos_rel=168, ' Total')\n",
      "step   44: (pos_rel=49, ' clips'), (pos_rel=63, ' clips'), (pos_rel=166, ' ')\n",
      "step   45: (pos_rel=173, ' +'), (pos_rel=174, ' May')\n",
      "step   46: (pos_rel=122, '2'), (pos_rel=193, '2')\n",
      "step   47: (pos_rel=87, ' clips'), (pos_rel=120, '2')\n",
      "step   48: (pos_rel=226, ' ')\n",
      "step   50: (pos_rel=51, ' ')\n",
      "step   51: (pos_rel=26, ' as'), (pos_rel=163, ' April'), (pos_rel=227, '7')\n",
      "step   53: (pos_rel=78, ' '), (pos_rel=125, ' =')\n",
      "step   54: (pos_rel=20, ' in'), (pos_rel=203, ' clips')\n",
      "step   55: (pos_rel=12, ' May')\n",
      "step   57: (pos_rel=67, ' '), (pos_rel=133, ','), (pos_rel=148, ' clips')\n",
      "step   58: (pos_rel=57, ' ')\n",
      "step   60: (pos_rel=131, ' ')\n",
      "step   62: (pos_rel=118, ' *')\n",
      "step   64: (pos_rel=74, '4')\n",
      "step   65: (pos_rel=88, ' ')\n",
      "step   66: (pos_rel=153, ',')\n",
      "step   70: (pos_rel=212, ' of')\n",
      "step   72: (pos_rel=146, '2')\n",
      "step   73: (pos_rel=46, ' '), (pos_rel=84, ' '), (pos_rel=128, '8')\n",
      "step   75: (pos_rel=66, '2')\n",
      "step   76: (pos_rel=134, ' clips'), (pos_rel=135, ' sold')\n",
      "step   77: (pos_rel=55, ' April')\n",
      "step   78: (pos_rel=64, ' /'), (pos_rel=159, ' clips')\n",
      "step   79: (pos_rel=52, '1')\n",
      "step   80: (pos_rel=110, ' April')\n",
      "step   81: (pos_rel=5, ' how')\n",
      "step   82: (pos_rel=47, '2')\n",
      "step   83: (pos_rel=202, '2')\n",
      "step   84: (pos_rel=0, 'Let')\n",
      "step   88: (pos_rel=119, ' '), (pos_rel=155, ''s')\n",
      "step   90: (pos_rel=161, ' sold')\n",
      "step   91: (pos_rel=50, ' ='), (pos_rel=162, ' in')\n",
      "step   92: (pos_rel=176, ' ')\n",
      "step   93: (pos_rel=179, ' clips')\n",
      "step   94: (pos_rel=33, ','), (pos_rel=93, '2')\n",
      "step   95: (pos_rel=222, ' The')\n",
      "step   96: (pos_rel=215, '2')\n",
      "step  101: (pos_rel=81, ' May')\n",
      "step  102: (pos_rel=100, ' clips')\n",
      "step  104: (pos_rel=29, ' in')\n",
      "step  106: (pos_rel=17, '4'), (pos_rel=35, ''s'), (pos_rel=184, '4')\n",
      "step  108: (pos_rel=6, ' many')\n",
      "step  109: (pos_rel=62, ' April'), (pos_rel=178, ' Total')\n",
      "step  111: (pos_rel=199, ' =')\n",
      "step  113: (pos_rel=72, ' ='), (pos_rel=85, '2')\n",
      "step  115: (pos_rel=31, '.'), (pos_rel=147, '4')\n",
      "step  119: (pos_rel=73, ' '), (pos_rel=187, '2')\n",
      "step  120: (pos_rel=10, ' sold')\n",
      "step  123: (pos_rel=140, ' in')\n",
      "step  124: (pos_rel=82, ' clips')\n",
      "step  125: (pos_rel=70, ' May')\n",
      "step  127: (pos_rel=218, ' April')\n",
      "step  128: (pos_rel=15, ' sold'), (pos_rel=32, ' So')\n",
      "step  129: (pos_rel=19, ' clips')\n",
      "step  130: (pos_rel=77, ' /'), (pos_rel=123, ' April')\n",
      "step  131: (pos_rel=217, ' in')\n",
      "step  132: (pos_rel=158, ' total')\n",
      "step  133: (pos_rel=180, ' =')\n",
      "step  134: (pos_rel=48, ' May')\n",
      "step  135: (pos_rel=127, '4')\n",
      "step  136: (pos_rel=8, ' Natal'), (pos_rel=58, '2'), (pos_rel=196, '2')\n",
      "step  140: (pos_rel=205, ' So')\n",
      "step  143: (pos_rel=154, ' let'), (pos_rel=156, ' find')\n",
      "step  144: (pos_rel=106, ' =')\n",
      "step  145: (pos_rel=86, '4')\n",
      "step  147: (pos_rel=183, '2')\n",
      "step  150: (pos_rel=40, ' she'), (pos_rel=206, ',')\n",
      "step  152: (pos_rel=223, ' final')\n",
      "step  153: (pos_rel=109, ' *')\n",
      "step  155: (pos_rel=167, '2')\n",
      "step  157: (pos_rel=4, ' out'), (pos_rel=92, ' '), (pos_rel=132, ' So'), (pos_rel=138, '8')\n",
      "step  158: (pos_rel=177, '2'), (pos_rel=229, '.')\n",
      "step  159: (pos_rel=45, ' ')\n",
      "step  160: (pos_rel=61, ' =')\n",
      "step  161: (pos_rel=43, ' May')\n",
      "step  163: (pos_rel=37, ' how')\n",
      "step  164: (pos_rel=60, ' clips')\n",
      "step  165: (pos_rel=38, ' many'), (pos_rel=53, '/')\n",
      "step  166: (pos_rel=21, ' April')\n",
      "step  168: (pos_rel=9, 'ia'), (pos_rel=225, ' is')\n",
      "step  172: (pos_rel=160, ' she'), (pos_rel=220, ' May')\n",
      "step  173: (pos_rel=75, '8')\n",
      "step  174: (pos_rel=34, ' let')\n",
      "step  175: (pos_rel=143, ' she')\n",
      "step  176: (pos_rel=165, ' May')\n",
      "step  177: (pos_rel=186, ' ')\n",
      "step  179: (pos_rel=83, ' ='), (pos_rel=126, ' '), (pos_rel=171, ' April')\n",
      "step  180: (pos_rel=80, '2'), (pos_rel=224, ' answer')\n",
      "step  181: (pos_rel=102, ' '), (pos_rel=182, ' ')\n",
      "step  182: (pos_rel=44, ':'), (pos_rel=170, ' =')\n",
      "\n",
      "================================================================================\n",
      "ALG = entropy\n",
      "================================================================================\n",
      "[Determinism] same_final_ids = True\n",
      "[Determinism] same_history    = True\n",
      "[Final SHA256] run1 = 6719d19aaf878bf9b818be978ee27db2534d3ff5ea906c0a0fa8842dd5161627\n",
      "[Final SHA256] run2 = 6719d19aaf878bf9b818be978ee27db2534d3ff5ea906c0a0fa8842dd5161627\n",
      "[Gen tokens] run1 = 256, run2 = 256\n",
      "\n",
      "[Final generation (run1, early-stopped & truncated)]:\n",
      "Alright, let's tackle this problem step by step. Natalia sold clips to 48 of her friends in April. That's straightforward. Now, in May, she sold half as many clips as she did in April. Hmm, okay, so I need to figure out how many clips she sold in May and then add that to the 48 she sold in April to get the total.\n",
      "\n",
      "First, let's break it down. In April, she sold 48 clips. That's clear. Now, in May, she sold half as many. So, half of 48 is... let me calculate that. 48 divided by 2 is 24. So, in May, she sold 24 clips.\n",
      "\n",
      "Now, to find the total number of clips she sold in April and May, I need to add the number from April to the number from May. That would be 48 plus 24. Let me do that addition. 48 plus 24 is... hmm, 48 plus 20 is 68, and then plus 4 more is 72. So, 48 plus 24 equals 72.\n",
      "\n",
      "Wait, let me double-check that. to\n",
      "\n",
      "[Step diff preview: first 125 steps with changes]:\n",
      "step    1: (pos_rel=0, 'Alright')\n",
      "step    2: (pos_rel=1, ',')\n",
      "step    3: (pos_rel=2, ' let')\n",
      "step    4: (pos_rel=3, ''s')\n",
      "step    5: (pos_rel=4, ' tackle')\n",
      "step    6: (pos_rel=5, ' this')\n",
      "step    7: (pos_rel=6, ' problem')\n",
      "step    8: (pos_rel=8, ' by')\n",
      "step    9: (pos_rel=9, ' step')\n",
      "step   10: (pos_rel=7, ' step')\n",
      "step   11: (pos_rel=10, '.')\n",
      "step   12: (pos_rel=12, 'ia')\n",
      "step   13: (pos_rel=11, ' Natal')\n",
      "step   14: (pos_rel=13, ' sold')\n",
      "step   15: (pos_rel=14, ' clips')\n",
      "step   16: (pos_rel=15, ' to')\n",
      "step   17: (pos_rel=16, ' ')\n",
      "step   18: (pos_rel=17, '4')\n",
      "step   19: (pos_rel=18, '8')\n",
      "step   20: (pos_rel=19, ' of')\n",
      "step   21: (pos_rel=20, ' her')\n",
      "step   22: (pos_rel=21, ' friends')\n",
      "step   23: (pos_rel=22, ' in')\n",
      "step   24: (pos_rel=23, ' April')\n",
      "step   25: (pos_rel=24, '.')\n",
      "step   26: (pos_rel=25, ' That')\n",
      "step   27: (pos_rel=27, ' straightforward')\n",
      "step   28: (pos_rel=26, ''s')\n",
      "step   29: (pos_rel=28, '.')\n",
      "step   30: (pos_rel=30, ',')\n",
      "step   31: (pos_rel=29, ' Now')\n",
      "step   32: (pos_rel=32, ' May')\n",
      "step   33: (pos_rel=33, ',')\n",
      "step   34: (pos_rel=31, ' in')\n",
      "step   35: (pos_rel=34, ' she')\n",
      "step   36: (pos_rel=35, ' sold')\n",
      "step   37: (pos_rel=38, ' many')\n",
      "step   38: (pos_rel=36, ' half')\n",
      "step   39: (pos_rel=37, ' as')\n",
      "step   40: (pos_rel=39, ' clips')\n",
      "step   41: (pos_rel=40, ' as')\n",
      "step   42: (pos_rel=41, ' she')\n",
      "step   43: (pos_rel=43, ' in')\n",
      "step   44: (pos_rel=44, ' April')\n",
      "step   45: (pos_rel=45, '.')\n",
      "step   46: (pos_rel=42, ' did')\n",
      "step   47: (pos_rel=47, ',')\n",
      "step   48: (pos_rel=46, ' Hmm')\n",
      "step   49: (pos_rel=48, ' okay')\n",
      "step   50: (pos_rel=49, ',')\n",
      "step   51: (pos_rel=50, ' so')\n",
      "step   52: (pos_rel=52, ' need')\n",
      "step   53: (pos_rel=53, ' to')\n",
      "step   54: (pos_rel=55, ' out')\n",
      "step   55: (pos_rel=54, ' figure')\n",
      "step   56: (pos_rel=51, ' I')\n",
      "step   57: (pos_rel=57, ' many')\n",
      "step   58: (pos_rel=56, ' how')\n",
      "step   59: (pos_rel=58, ' clips')\n",
      "step   60: (pos_rel=60, ' sold')\n",
      "step   61: (pos_rel=59, ' she')\n",
      "step   62: (pos_rel=61, ' in')\n",
      "step   63: (pos_rel=62, ' May')\n",
      "step   64: (pos_rel=64, ' then')\n",
      "step   65: (pos_rel=63, ' and')\n",
      "step   66: (pos_rel=65, ' add')\n",
      "step   67: (pos_rel=66, ' that')\n",
      "step   68: (pos_rel=67, ' to')\n",
      "step   69: (pos_rel=68, ' the')\n",
      "step   70: (pos_rel=69, ' ')\n",
      "step   71: (pos_rel=70, '4')\n",
      "step   72: (pos_rel=71, '8')\n",
      "step   73: (pos_rel=72, ' she')\n",
      "step   74: (pos_rel=73, ' sold')\n",
      "step   75: (pos_rel=74, ' in')\n",
      "step   76: (pos_rel=75, ' April')\n",
      "step   77: (pos_rel=76, ' to')\n",
      "step   78: (pos_rel=78, ' the')\n",
      "step   79: (pos_rel=79, ' total')\n",
      "step   80: (pos_rel=77, ' get')\n",
      "step   81: (pos_rel=80, '.\\n\\n')\n",
      "step   82: (pos_rel=81, 'First')\n",
      "step   83: (pos_rel=82, ',')\n",
      "step   84: (pos_rel=83, ' let')\n",
      "step   85: (pos_rel=84, ''s')\n",
      "step   86: (pos_rel=85, ' break')\n",
      "step   87: (pos_rel=86, ' it')\n",
      "step   88: (pos_rel=87, ' down')\n",
      "step   89: (pos_rel=88, '.')\n",
      "step   90: (pos_rel=90, ' April')\n",
      "step   91: (pos_rel=89, ' In')\n",
      "step   92: (pos_rel=91, ',')\n",
      "step   93: (pos_rel=93, ' sold')\n",
      "step   94: (pos_rel=92, ' she')\n",
      "step   95: (pos_rel=95, '4')\n",
      "step   96: (pos_rel=96, '8')\n",
      "step   97: (pos_rel=94, ' ')\n",
      "step   98: (pos_rel=97, ' clips')\n",
      "step   99: (pos_rel=98, '.')\n",
      "step  100: (pos_rel=99, ' That')\n",
      "step  101: (pos_rel=100, ''s')\n",
      "step  102: (pos_rel=102, '.')\n",
      "step  103: (pos_rel=104, ',')\n",
      "step  104: (pos_rel=103, ' Now')\n",
      "step  105: (pos_rel=106, ' May')\n",
      "step  106: (pos_rel=107, ',')\n",
      "step  107: (pos_rel=105, ' in')\n",
      "step  108: (pos_rel=109, ' sold')\n",
      "step  109: (pos_rel=108, ' she')\n",
      "step  110: (pos_rel=110, ' half')\n",
      "step  111: (pos_rel=111, ' as')\n",
      "step  112: (pos_rel=112, ' many')\n",
      "step  113: (pos_rel=113, '.')\n",
      "step  114: (pos_rel=114, ' So')\n",
      "step  115: (pos_rel=115, ',')\n",
      "step  116: (pos_rel=101, ' clear')\n",
      "step  117: (pos_rel=116, ' half')\n",
      "step  118: (pos_rel=117, ' of')\n",
      "step  119: (pos_rel=119, '4')\n",
      "step  120: (pos_rel=118, ' ')\n",
      "step  121: (pos_rel=120, '8')\n",
      "step  122: (pos_rel=121, ' is')\n",
      "step  123: (pos_rel=122, '...')\n",
      "step  124: (pos_rel=123, ' let')\n",
      "step  125: (pos_rel=124, ' me')\n"
     ]
    }
   ],
   "source": [
    "#!/usr/bin/env python3\n",
    "# -*- coding: utf-8 -*-\n",
    "\n",
    "import os\n",
    "import hashlib\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Tuple, Optional\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "\n",
    "# =========================\n",
    "# 0) Config\n",
    "# =========================\n",
    "MODEL_ID = \"Dream-org/Dream-v0-Base-7B\"\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32\n",
    "\n",
    "SEED = 42\n",
    "\n",
    "# For GSM8K in your setup\n",
    "MAX_NEW_TOKENS = 256\n",
    "DIFFUSION_STEPS = 256\n",
    "\n",
    "TEMPERATURE = 0.0\n",
    "TOP_P = 1.0\n",
    "\n",
    "ALG_TEMP = 0.0\n",
    "ALGS_TO_COMPARE = [\"origin\", \"entropy\"]\n",
    "\n",
    "# If True, enforce stricter deterministic behavior (may error if some ops are not supported).\n",
    "STRICT_DETERMINISM = True\n",
    "\n",
    "GSM8K_PROMPT = \"\"\"You are a helpful math tutor. Solve the problem step by step and give the final answer as an integer on the last line.\n",
    "Problem: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n",
    "Reasoning: Let's think step by step.\n",
    "Answer:\"\"\"\n",
    "\n",
    "\n",
    "# =========================\n",
    "# 1) Determinism helpers\n",
    "# =========================\n",
    "def set_global_determinism(seed: int, strict: bool = True) -> None:\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "    os.environ.setdefault(\"CUBLAS_WORKSPACE_CONFIG\", \":16:8\")\n",
    "\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed)\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "    if strict:\n",
    "        try:\n",
    "            torch.use_deterministic_algorithms(True)\n",
    "        except Exception as e:\n",
    "            print(f\"[Warning] torch.use_deterministic_algorithms(True) failed: {repr(e)}\")\n",
    "            print(\"[Warning] Continue with strict=False behavior for deterministic algorithms.\")\n",
    "            torch.use_deterministic_algorithms(False)\n",
    "\n",
    "\n",
    "# =========================\n",
    "# 2) Token/history utilities\n",
    "# =========================\n",
    "def sha256_int_list(xs: List[int]) -> str:\n",
    "    m = hashlib.sha256()\n",
    "    m.update((\",\".join(map(str, xs))).encode(\"utf-8\"))\n",
    "    return m.hexdigest()\n",
    "\n",
    "\n",
    "def to_chat_input(tokenizer, text: str, device: str) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "    messages = [{\"role\": \"user\", \"content\": text}]\n",
    "    inputs = tokenizer.apply_chat_template(\n",
    "        messages,\n",
    "        return_tensors=\"pt\",\n",
    "        return_dict=True,\n",
    "        add_generation_prompt=True,\n",
    "    )\n",
    "    return inputs.input_ids.to(device), inputs.attention_mask.to(device)\n",
    "\n",
    "\n",
    "def extract_history(out: Any) -> Optional[Any]:\n",
    "    if hasattr(out, \"history\"):\n",
    "        return out.history\n",
    "    if hasattr(out, \"sequences_history\"):\n",
    "        return out.sequences_history\n",
    "    if isinstance(out, dict) and \"history\" in out:\n",
    "        return out[\"history\"]\n",
    "    return None\n",
    "\n",
    "\n",
    "def normalize_step_seqs(hist: Any) -> List[List[int]]:\n",
    "    step_seqs: List[List[int]] = []\n",
    "    if hist is None:\n",
    "        return step_seqs\n",
    "\n",
    "    if isinstance(hist, (list, tuple)):\n",
    "        for s in hist:\n",
    "            if isinstance(s, torch.Tensor):\n",
    "                arr = s.detach().cpu()\n",
    "            else:\n",
    "                arr = torch.tensor(s)\n",
    "            if arr.dim() == 2:\n",
    "                step_seqs.append(arr[0].tolist())\n",
    "            elif arr.dim() == 1:\n",
    "                step_seqs.append(arr.tolist())\n",
    "            else:\n",
    "                raise ValueError(f\"Unsupported history item shape: {tuple(arr.shape)}\")\n",
    "        return step_seqs\n",
    "\n",
    "    if isinstance(hist, torch.Tensor):\n",
    "        arr = hist.detach().cpu()\n",
    "        if arr.dim() == 3:\n",
    "            for t in range(arr.shape[0]):\n",
    "                step_seqs.append(arr[t, 0].tolist())\n",
    "        elif arr.dim() == 2:\n",
    "            for t in range(arr.shape[0]):\n",
    "                step_seqs.append(arr[t].tolist())\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported history tensor shape: {tuple(arr.shape)}\")\n",
    "        return step_seqs\n",
    "\n",
    "    raise ValueError(f\"Unsupported history type: {type(hist)}\")\n",
    "\n",
    "\n",
    "def build_stop_sequences(tokenizer) -> List[List[int]]:\n",
    "    \"\"\"\n",
    "    Build stop token-id sequences. Prefer single-token stop markers if possible.\n",
    "    We include:\n",
    "      - tokenizer.eos_token_id (if any)\n",
    "      - tokenized \"<|im_end|>\" (common for chat templates)\n",
    "      - tokenized \"</s>\" and \"<|endoftext|>\" as fallbacks\n",
    "    \"\"\"\n",
    "    stop_seqs: List[List[int]] = []\n",
    "\n",
    "    # EOS id\n",
    "    if getattr(tokenizer, \"eos_token_id\", None) is not None:\n",
    "        stop_seqs.append([int(tokenizer.eos_token_id)])\n",
    "\n",
    "    # Common chat end markers\n",
    "    for s in [\"<|im_end|>\", \"</s>\", \"<|endoftext|>\"]:\n",
    "        try:\n",
    "            ids = tokenizer.encode(s, add_special_tokens=False)\n",
    "            if isinstance(ids, list) and len(ids) > 0:\n",
    "                stop_seqs.append([int(x) for x in ids])\n",
    "        except Exception:\n",
    "            pass\n",
    "\n",
    "    # Deduplicate\n",
    "    uniq = []\n",
    "    seen = set()\n",
    "    for seq in stop_seqs:\n",
    "        key = tuple(seq)\n",
    "        if key not in seen:\n",
    "            uniq.append(seq)\n",
    "            seen.add(key)\n",
    "    return uniq\n",
    "\n",
    "\n",
    "def find_first_stop(gen_ids: List[int], stop_seqs: List[List[int]]) -> Optional[Tuple[int, int]]:\n",
    "    \"\"\"\n",
    "    Return (index, length) of earliest stop sequence match in gen_ids.\n",
    "    \"\"\"\n",
    "    best: Optional[Tuple[int, int]] = None\n",
    "    n = len(gen_ids)\n",
    "    for i in range(n):\n",
    "        for seq in stop_seqs:\n",
    "            m = len(seq)\n",
    "            if m == 0 or i + m > n:\n",
    "                continue\n",
    "            if gen_ids[i : i + m] == seq:\n",
    "                if best is None or i < best[0]:\n",
    "                    best = (i, m)\n",
    "        if best is not None and best[0] == i:\n",
    "            # can't get earlier than i once matched at this i\n",
    "            return best\n",
    "    return best\n",
    "\n",
    "\n",
    "def apply_early_stop_from_history(\n",
    "    step_seqs: List[List[int]],\n",
    "    input_len: int,\n",
    "    stop_seqs: List[List[int]],\n",
    ") -> Tuple[List[List[int]], int]:\n",
    "    \"\"\"\n",
    "    Use history to simulate \"stop when stop token appears\":\n",
    "      - Find the earliest step where a stop sequence appears in the generated region.\n",
    "      - Return truncated step_seqs (up to that step) and gen_end_abs (exclusive, stop token removed).\n",
    "    If no stop token ever appears, keep all steps and gen_end_abs = input_len + MAX_NEW_TOKENS (bounded by sequence length).\n",
    "    \"\"\"\n",
    "    if len(step_seqs) == 0:\n",
    "        return step_seqs, input_len\n",
    "\n",
    "    best_step: Optional[int] = None\n",
    "    best_cut_abs: Optional[int] = None\n",
    "\n",
    "    for t, ids in enumerate(step_seqs):\n",
    "        gen_ids = ids[input_len:]\n",
    "        hit = find_first_stop(gen_ids, stop_seqs)\n",
    "        if hit is None:\n",
    "            continue\n",
    "        idx, _m = hit\n",
    "        cut_abs = input_len + idx  # exclude stop token(s)\n",
    "        best_step = t\n",
    "        best_cut_abs = cut_abs\n",
    "        break  # earliest step\n",
    "\n",
    "    if best_step is None:\n",
    "        # No stop token found; keep as-is, cut to available length\n",
    "        last_len = len(step_seqs[-1])\n",
    "        gen_end_abs = min(last_len, input_len + MAX_NEW_TOKENS)\n",
    "        return step_seqs, gen_end_abs\n",
    "\n",
    "    # Keep steps up to best_step, and cut gen_end_abs to best_cut_abs\n",
    "    truncated_steps = step_seqs[: best_step + 1]\n",
    "    return truncated_steps, int(best_cut_abs)\n",
    "\n",
    "\n",
    "def diff_by_step(\n",
    "    tokenizer,\n",
    "    step_seqs: List[List[int]],\n",
    "    gen_start: int,\n",
    "    gen_end_abs: int,\n",
    ") -> List[Dict[str, Any]]:\n",
    "    \"\"\"\n",
    "    Compare step t-1 -> t, record changed positions in [gen_start, gen_end_abs).\n",
    "    \"\"\"\n",
    "    diffs: List[Dict[str, Any]] = []\n",
    "    if len(step_seqs) < 2:\n",
    "        return diffs\n",
    "\n",
    "    prev = step_seqs[0]\n",
    "    for t in range(1, len(step_seqs)):\n",
    "        cur = step_seqs[t]\n",
    "        L = min(len(prev), len(cur), gen_end_abs)\n",
    "        changes = []\n",
    "        for pos in range(gen_start, L):\n",
    "            if prev[pos] != cur[pos]:\n",
    "                tid = cur[pos]\n",
    "                piece = tokenizer.decode([tid], skip_special_tokens=False)\n",
    "                piece_vis = piece.replace(\"\\n\", \"\\\\n\").replace(\"\\t\", \"\\\\t\")\n",
    "                changes.append(\n",
    "                    {\n",
    "                        \"pos_abs\": pos,\n",
    "                        \"pos_rel\": pos - gen_start,\n",
    "                        \"token_id\": int(tid),\n",
    "                        \"token_str\": piece_vis,\n",
    "                    }\n",
    "                )\n",
    "        diffs.append({\"step\": t, \"num_changes\": len(changes), \"changes\": changes})\n",
    "        prev = cur\n",
    "    return diffs\n",
    "\n",
    "\n",
    "# =========================\n",
    "# 3) Dream run wrapper\n",
    "# =========================\n",
    "@dataclass\n",
    "class DreamRunResult:\n",
    "    alg: str\n",
    "    run_id: int\n",
    "    input_len: int\n",
    "    final_ids: List[int]\n",
    "    step_seqs: List[List[int]]\n",
    "    gen_end_abs: int\n",
    "    final_text: str\n",
    "\n",
    "\n",
    "def run_dream_once(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    prompt: str,\n",
    "    alg: str,\n",
    "    run_id: int,\n",
    "    seed: int,\n",
    "    stop_seqs: List[List[int]],\n",
    ") -> DreamRunResult:\n",
    "    \"\"\"\n",
    "    Run one seeded Dream diffusion generation and package the result.\n",
    "\n",
    "    Resets seeds/determinism flags, chat-formats `prompt`, calls\n",
    "    `model.diffusion_generate` with history recording, converts the history to\n",
    "    one token-id list per step, and truncates it via apply_early_stop_from_history.\n",
    "    The answer is taken from the last retained step, cut at gen_end_abs.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if the generation output carries no usable history.\n",
    "    \"\"\"\n",
    "    set_global_determinism(seed, strict=STRICT_DETERMINISM)\n",
    "\n",
    "    input_ids, attention_mask = to_chat_input(tokenizer, prompt, DEVICE)\n",
    "    prompt_len = int(input_ids.shape[1])  # generated tokens start at this index\n",
    "\n",
    "    generation = model.diffusion_generate(\n",
    "        inputs=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        max_new_tokens=MAX_NEW_TOKENS,\n",
    "        steps=DIFFUSION_STEPS,\n",
    "        temperature=TEMPERATURE,\n",
    "        top_p=TOP_P,\n",
    "        alg=alg,\n",
    "        alg_temp=ALG_TEMP,\n",
    "        output_history=True,\n",
    "        return_dict_in_generate=True,\n",
    "    )\n",
    "\n",
    "    step_seqs = normalize_step_seqs(extract_history(generation))\n",
    "    if not step_seqs:\n",
    "        raise RuntimeError(\n",
    "            \"No history found. Please ensure your Dream version supports output_history=True \"\n",
    "            \"and the returned object exposes history/sequences_history.\"\n",
    "        )\n",
    "\n",
    "    # Truncate the history at the early-stop step and get the answer's end index.\n",
    "    step_seqs, gen_end_abs = apply_early_stop_from_history(step_seqs, prompt_len, stop_seqs)\n",
    "\n",
    "    # Prompt + answer from the last kept step; the stop token and anything after it are dropped.\n",
    "    final_ids = step_seqs[-1][:gen_end_abs]\n",
    "    final_text = tokenizer.decode(final_ids[prompt_len:], skip_special_tokens=True)\n",
    "\n",
    "    return DreamRunResult(\n",
    "        alg=alg,\n",
    "        run_id=run_id,\n",
    "        input_len=prompt_len,\n",
    "        final_ids=final_ids,\n",
    "        step_seqs=step_seqs,\n",
    "        gen_end_abs=gen_end_abs,\n",
    "        final_text=final_text,\n",
    "    )\n",
    "\n",
    "\n",
    "# =========================\n",
    "# 4) Main\n",
    "# =========================\n",
    "def main(max_preview_rows: int = 125) -> None:\n",
    "    \"\"\"\n",
    "    Load Dream, run every algorithm in ALGS_TO_COMPARE twice with the same seed,\n",
    "    and print a determinism report.\n",
    "\n",
    "    For each algorithm this prints whether both runs produced identical final ids\n",
    "    and identical (early-stop truncated) step histories, the SHA-256 of each run's\n",
    "    final ids, the generated-token count per run, run 1's decoded answer, and a\n",
    "    preview of which generated positions changed at each step of run 1.\n",
    "\n",
    "    Args:\n",
    "        max_preview_rows: how many steps-with-changes the diff preview shows; the\n",
    "            preview header reports the same number. Negative values show none.\n",
    "    \"\"\"\n",
    "    import itertools  # only main needs islice\n",
    "\n",
    "    print(f\"Loading {MODEL_ID} on {DEVICE} ...\")\n",
    "    # NOTE: recent transformers warn that `torch_dtype` is deprecated in favour of\n",
    "    # `dtype`; it is kept, presumably for releases that predate `dtype`.\n",
    "    model = AutoModel.from_pretrained(\n",
    "        MODEL_ID,\n",
    "        torch_dtype=DTYPE,\n",
    "        trust_remote_code=True,\n",
    "    ).to(DEVICE).eval()\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
    "    stop_seqs = build_stop_sequences(tokenizer)\n",
    "\n",
    "    print(\"Prompt (first lines):\")\n",
    "    for line in GSM8K_PROMPT.splitlines()[:6]:\n",
    "        print(line)\n",
    "    print(\"...\")\n",
    "\n",
    "    preview_limit = max(0, int(max_preview_rows))\n",
    "\n",
    "    for alg in ALGS_TO_COMPARE:\n",
    "        print(\"\\n\" + \"=\" * 80)\n",
    "        print(f\"ALG = {alg}\")\n",
    "        print(\"=\" * 80)\n",
    "\n",
    "        # Same seed for both runs: any difference below indicates non-determinism.\n",
    "        r1 = run_dream_once(model, tokenizer, GSM8K_PROMPT, alg=alg, run_id=1, seed=SEED, stop_seqs=stop_seqs)\n",
    "        r2 = run_dream_once(model, tokenizer, GSM8K_PROMPT, alg=alg, run_id=2, seed=SEED, stop_seqs=stop_seqs)\n",
    "\n",
    "        same_final = r1.final_ids == r2.final_ids\n",
    "        same_hist = r1.step_seqs == r2.step_seqs  # list equality also compares step counts\n",
    "\n",
    "        print(f\"[Determinism] same_final_ids = {same_final}\")\n",
    "        print(f\"[Determinism] same_history    = {same_hist}\")\n",
    "        print(f\"[Final SHA256] run1 = {sha256_int_list(r1.final_ids)}\")\n",
    "        print(f\"[Final SHA256] run2 = {sha256_int_list(r2.final_ids)}\")\n",
    "        print(f\"[Gen tokens] run1 = {len(r1.final_ids) - r1.input_len}, run2 = {len(r2.final_ids) - r2.input_len}\")\n",
    "\n",
    "        # Diff by step using run1\n",
    "        diffs = diff_by_step(tokenizer, r1.step_seqs, gen_start=r1.input_len, gen_end_abs=r1.gen_end_abs)\n",
    "\n",
    "        print(\"\\n[Final generation (run1, early-stopped & truncated)]:\")\n",
    "        print(r1.final_text)\n",
    "\n",
    "        print(f\"\\n[Step diff preview: first {preview_limit} steps with changes]:\")\n",
    "        changed_rows = (row for row in diffs if row[\"num_changes\"] > 0)\n",
    "        for row in itertools.islice(changed_rows, preview_limit):\n",
    "            pieces = \", \".join(\n",
    "                f\"(pos_rel={c['pos_rel']}, '{c['token_str']}')\" for c in row[\"changes\"]\n",
    "            )\n",
    "            print(f\"step {row['step']:>4d}: {pieces}\")\n",
    "\n",
    "\n",
    "# Entry point. Inside a Jupyter kernel __name__ is also \"__main__\", so executing\n",
    "# this cell calls main() every time (which loads the model again).\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae4dlm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
