{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2e2e251b",
   "metadata": {},
   "source": [
    "# Hogwild! Parallelism: example with interleaved cache and full prompt\n",
    "\n",
    "This is a more advanced version of `basic_example.ipynb` that features a combined layout: completed steps are interleaved in a shared history, while each worker also sees the other's in-progress step with instant (token-level) synchronization. A more script-friendly version of this code is available in [__`./generation.py`__](./generation.py)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aadbc3c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=0\n",
      "env: OMP_NUM_THREADS=16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8a49a04d8d2444c089ef5292f1899aa6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/14 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "%env OMP_NUM_THREADS=16\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# make the repository root and ../utils importable (for shared_cache and generation)\n",
    "import sys\n",
    "sys.path.insert(0, \"..\")\n",
    "sys.path.insert(0, \"../utils\")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "from IPython.display import Markdown, clear_output, display\n",
    "\n",
    "import shared_cache\n",
    "from generation import MathFormatting, get_logits_processor\n",
    "\n",
    "MODEL_NAME = \"Qwen/QwQ-32B\"  # for 48gb gpu, use \"Qwen/QwQ-32B-AWQ\"\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    torch_dtype='auto',\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d8f9684c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "problem = \"\"\"\n",
    "Three vertices of a cube are $P=(7,12,10)$ , $Q=(8,8,1)$ , and $R=(11,3,9)$ . What is the surface area of the cube?\n",
    "\"\"\".strip()\n",
    "\n",
    "print_every_steps = 3  # re-render the transcript every N inference steps\n",
    "insert_s1_prompt_every_tokens = 256  # min. generated tokens between collaboration self-check prompts\n",
    "tokens_since_last_wait = 0\n",
    "\n",
    "# the generation loop may sample with torch.multinomial: seed it so that re-running this cell and the\n",
    "# generation cell reproduces the same transcript (up to GPU kernel nondeterminism)\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "workers = [\"Alice\", \"Bob\"]\n",
    "fmt = MathFormatting(\n",
    "    tokenizer, workers,\n",
    ")  # ^-- prompts and optional few-shot examples; has options for different model types - see formatting.py\n",
    "\n",
    "\n",
    "worker_prompts = [\n",
    "    f\"\"\"{fmt.get_step_prefix(workers[0], 1)}Hi, I'm {workers[0]}. Here's how we can collaborate\"\"\",\n",
    "    f\"\"\"{fmt.get_step_prefix(workers[1], 1)}Hi, I'm {workers[1]}.\"\"\"\n",
    "]\n",
    "\n",
    "# define cache structure for the combined layout\n",
    "cache_common, cache_current_step_header, cache_separator, cache_w1, cache_w2 = (\n",
    "    shared_cache.CacheBlock(config=model.config) for _ in range(5))\n",
    "# one row per worker: shared history, current-step header, the OTHER worker's in-progress step,\n",
    "# separator, and finally the worker's OWN in-progress step (the last block is what the generation\n",
    "# loop moves into cache_common when that worker finishes a step)\n",
    "cm = shared_cache.SharedCacheManager(cache_structure=[\n",
    "    [cache_common, cache_current_step_header, cache_w2, cache_separator, cache_w1],\n",
    "    [cache_common, cache_current_step_header, cache_w1, cache_separator, cache_w2],\n",
    "])\n",
    "\n",
    "logits_processor = get_logits_processor(model)\n",
    "tokenizer_kwargs = dict(return_tensors='pt', padding=True, padding_side='left', add_special_tokens=False)\n",
    "\n",
    "# initialize generation state for printing\n",
    "history = []  # token ids of all finished steps (all workers), in completion order\n",
    "current_step_index_by_worker = [1 for _ in workers]  # every worker starts at step 1\n",
    "current_step_tokens_by_worker = [tokenizer.encode(p, add_special_tokens=False) for p in worker_prompts]\n",
    "\n",
    "# pre-fill common parts\n",
    "with torch.inference_mode():\n",
    "    model(**tokenizer([fmt.apply_chat_template(problem)], **tokenizer_kwargs).to(device),\n",
    "          use_cache=True, past_key_values=cache_common)  # <-- write to common prompt\n",
    "    model(**tokenizer(fmt.current_step_header, **tokenizer_kwargs).to(device),\n",
    "          use_cache=True, past_key_values=cache_current_step_header)  # <-- write to current step header\n",
    "    model(**tokenizer(fmt.current_worker_header, **tokenizer_kwargs).to(device),\n",
    "          use_cache=True, past_key_values=cache_separator)  # <-- write to separator between incomplete steps\n",
    "\n",
    "next_inputs = tokenizer(worker_prompts, **tokenizer_kwargs).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29aeff2b",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "[**Problem:** Three vertices of a cube are $P=(7,12,10)$ , $Q=(8,8,1)$ , and $R=(11,3,9)$ . What is the surface area of the cube?]\n",
       "\n",
       "### Past steps\n",
       "\n",
       "**Alice [1]:** Hi, I'm Alice. Here's how we can collaborate: Let's start by computing the distances between the given points to see which edges or diagonals they might be. Since three vertices are given, they could form an edge, face diagonal, or space diagonal. Let me compute the distance between P and Q first.\n",
       "\n",
       "**Bob [1]:** Hi, I'm Bob. Let me suggest we first compute the distances between the three points to see which edges or diagonals they represent. Since it's a cube, the edges are equal, face diagonals are edge*sqrt(2), and space diagonals are edge*sqrt(3). Let me start with PQ.\n",
       "\n",
       "**Bob [2]:**  I'll compute the distance between Q and R.\n",
       "\n",
       "**Bob [3]:**  I'll do QR next. Let me see, Q is (8,8,1), R is (11,3,9). So Δx=3, Δy=-5, Δz=8. Squared differences: 9 +25 +64=98. Wait, same as PQ?\n",
       "\n",
       "**Alice [2]:**  Agreed. Let me compute the distance between P=(7,12,10) and Q=(8,8,1). The differences in coordinates are Δx=1, Δy=-4, Δz=-9. Squared differences: 1² + (-4)^2 + (-9)^2 =1 +16 +81=98. So PQ distance is sqrt(98)=7*sqrt(2)≈?\n",
       "\n",
       "**Alice [3]:**  The PQ distance is sqrt(98). Now, moving to QR. Wait Bob already said he was doing QR. Let me check.\n",
       "\n",
       "**Bob [4]:** Quick check: am I doing redundant work? (yes/no):  I just realized Alice is doing PQ. Let me switch to PR. Let me compute PR. P=(7,12,10), R=(11,3,9). Δx=4, Δy=-9, Δz=-1. Squared: 16 +81 +1=98. So PR is also sqrt(98). Hmm interesting.\n",
       "\n",
       "**Bob [5]:**  PR distance is sqrt(98). Let me compute QR now. Wait, Alice is doing QR. Let me do RP instead, but that's same as PR. Hmm, perhaps I should do another pair? Wait, the three points are P, Q, R. So the distances are PQ, QR, PR. Since all three distances are equal to sqrt(98), that implies the three points form an equilateral triangle in 3D space. But in a cube, can three vertices form an equilateral triangle with all edges equal? Maybe they are edges of a regular tetrahedron inscribed in the cube? Hmm, perhaps the cube edge length is such that sqrt(edge^2 + edge^2) = sqrt(98). Let me think.\n",
       "\n",
       "**Bob [6]:** Quick check: am I doing redundant work? (yes/no):  No, I was doing PR which is done, so now let me think. Since all three distances between the points are equal, maybe the triangle is part of a regular tetrahedron inside the cube? Alternatively, perhaps each pair is a face diagonal? Let me think about cube geometry. In a cube, three vertices with all pairwise distances equal might form a regular triangle. Let me see.\n",
       "\n",
       "**Alice [4]:**  Bob, I was going to do QR next, but since you're doing PR, maybe I'll do QR after. Let me confirm: PQ is sqrt(98). Let me compute QR now. Q=(8,8,1), R=(11,3,9). Δx=3, Δy=-5, Δz=8. Squared: 9+25+64=98. So QR is sqrt(98). So all three distances between the points are sqrt(98). That means the triangle PQR is equilateral with all sides equal. But in a cube, can three vertices form an equilateral triangle with all edges equal to sqrt(98)? Wait, but in a cube, edges are all the same length, but the distances between vertices can be edges, face diagonals, or space diagonals. Since sqrt(98) is equal to 7*sqrt(2), which would be a face diagonal if edge length is 7, because face diagonal is edge*sqrt(2). Alternatively, if edge length is 7, then sqrt(98) is exactly a face diagonal. But if all three distances between the points are face diagonals, then they must form an equilateral triangle in a face? But in a cube, a face is a square, so all face diagonals are equal but arranged at 90 degrees. So three vertices with all pairwise distances equal would not lie on the same face. Alternatively, maybe they are space diagonals? A space diagonal is edge*sqrt(3). sqrt(98)=7*sqrt(2) is not a multiple of sqrt(3). Hmm.\n",
       "\n",
       "**Alice [5]:** Quick check: am I doing redundant work? (yes/no):  No, let me think. Since all three distances between P, Q, R are equal to 7√2. Let me think of the cube's edge length. Suppose the edge length is a. Then a face diagonal is a√2. If 7√2 = a√2 ⇒ a=7. So edge length is 7? If so, then the surface area would be 6*(7)^2 = 294. But wait, but how do we know that the edge length is indeed 7?\n",
       "\n",
       "### Work in progress (others)\n",
       "\n",
       "**Alice [6]:**  Alternatively, maybe the<...>\n",
       "\n",
       "**Bob [7]:**  Wait, if all three distances are equal to sqrt(98), which is 7√2, then perhaps each edge of the triangle is a face diagonal of the cube. In a cube, two face diagonals on adjacent faces would form an angle of 60 degrees? Wait, actually, in a cube, the angle between two face diagonals on adjacent faces would be 60 degrees? Let me think. Let me imagine a cube with edge length a. The angle between two face diagonals on adjacent faces can be calculated using the dot product. Suppose one face diagonal is along the x-y plane, vector (a,a,0), and another is along y-z plane, vector (0,a,a). The angle between them would be cosθ = ( (a)(0) + a(a) + 0(a) ) / ( |v1||v2|<...>\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def render_progress():\n",
    "    \"\"\"Render the problem, all finished steps and every worker's in-progress step as Markdown.\"\"\"\n",
    "    clear_output(True)\n",
    "    output_parts = [f\"[**Problem:** {problem}]\\n\\n\"]\n",
    "    output_parts.append(fmt.history_header + fmt.sep + tokenizer.decode(history))\n",
    "    output_parts.append(fmt.current_step_header)\n",
    "    for worker_tokens in current_step_tokens_by_worker:\n",
    "        output_parts.append(tokenizer.decode(worker_tokens) + fmt.incomplete_step + fmt.sep)\n",
    "    display(Markdown(''.join(output_parts)))\n",
    "\n",
    "\n",
    "# each iteration runs one batch row per worker through the shared cache and picks one new token per worker\n",
    "for inference_step in range(1024):  # <-- modify the number of generation steps here\n",
    "    # run model with shared cache\n",
    "    with torch.inference_mode():\n",
    "        # upcast to float32 before the logits processor / softmax, as HF generate() does;\n",
    "        # a low-precision softmax over the full vocabulary distorts small probabilities\n",
    "        logits = model(**cm.get_input_kwargs(**next_inputs)).logits[..., -1, :].float()\n",
    "        logits = logits_processor(next_inputs['input_ids'], logits)\n",
    "        if model.generation_config.do_sample:\n",
    "            new_tokens = torch.multinomial(logits.softmax(dim=-1), 1).flatten()\n",
    "        else:\n",
    "            new_tokens = logits.argmax(-1)\n",
    "\n",
    "    # process generated tokens for printing; handle step change, update next_inputs\n",
    "    assert len(new_tokens) == len(fmt.workers)\n",
    "    next_input_tokens = new_tokens.unsqueeze(-1).tolist()\n",
    "    for worker_index, (worker_name, worker_tokens, new_token) in enumerate(\n",
    "            zip(fmt.workers, current_step_tokens_by_worker, new_tokens.tolist())):\n",
    "        worker_tokens.append(new_token)\n",
    "        if fmt.is_end_of_step(worker_tokens):\n",
    "            # worker just finished their step - add it to common history and start a new step\n",
    "            current_step_index_by_worker[worker_index] += 1\n",
    "            history.extend(worker_tokens)\n",
    "            worker_tokens.clear()\n",
    "            start_msg = fmt.get_step_prefix(worker_name, current_step_index_by_worker[worker_index])\n",
    "            if tokens_since_last_wait > insert_s1_prompt_every_tokens:\n",
    "                start_msg += fmt.s1_collab_message  # <-- insert a collaboration self-check (\"am I doing redundant work?\")\n",
    "                tokens_since_last_wait = 0\n",
    "            worker_tokens.extend(tokenizer.encode(start_msg, add_special_tokens=False))\n",
    "            # move this worker's own in-progress block (last in its cache view) into the shared history cache\n",
    "            cache_common.append_from(cm.cache_structure[worker_index][-1])\n",
    "            cm.cache_structure[worker_index][-1].clear()\n",
    "            # the sampled token has not been fed to the model yet: send it along with the new step prefix\n",
    "            next_input_tokens[worker_index] = [new_token] + worker_tokens\n",
    "        tokens_since_last_wait += len(next_input_tokens[worker_index])\n",
    "    next_inputs = tokenizer.pad(dict(input_ids=next_input_tokens), padding_side='left', return_tensors='pt').to(device)\n",
    "\n",
    "    if inference_step % print_every_steps == 0:\n",
    "        render_progress()  # display current progress\n",
    "\n",
    "# always show the final state, even if the last step was not a multiple of print_every_steps\n",
    "render_progress()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15f86a1c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
