{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2e2e251b",
   "metadata": {},
   "source": [
    "# Hogwild! Parallelism: Basic Example\n",
    "\n",
    "This example demonstrates Hogwild! inference on a single problem with 2 workers and a minimal prompt, defined below. There are no few-shot examples or prompt insertions, and the cache layout is the simplest one possible: two contiguous workspaces. This notebook is intended as a playground; the other notebooks present more advanced prompting and cache layouts."
   ]
  },
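  {
   "cell_type": "markdown",
   "id": "9b3f6c1a",
   "metadata": {},
   "source": [
    "A rough way to picture this layout is to replace the KV cache blocks with plain strings. Each worker sees the common prompt first, then the other worker's workspace, then a separator, and finally its own workspace at the bottom. The cell below is only an illustrative sketch under that assumption: `worker_view` and the sample workspace texts are hypothetical and are not part of `shared_cache`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d7e2a4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: plain strings stand in for KV cache blocks.\n",
    "# `worker_view` is a hypothetical helper, not part of shared_cache.\n",
    "def worker_view(common_prompt, workspaces, separator, worker_ix):\n",
    "    \"\"\"Return the text one worker attends to, in order.\"\"\"\n",
    "    others = [w for i, w in enumerate(workspaces) if i != worker_ix]\n",
    "    return common_prompt + \"\".join(others) + separator + workspaces[worker_ix]\n",
    "\n",
    "workspaces = [\"\\n\\n# Alice workspace\\n\\nI take x=5.\", \"\\n\\n# Bob workspace\\n\\nI take x=7.\"]\n",
    "separator = \" <the assistant will continue here>\\n\\n\"\n",
    "for ix, name in enumerate([\"Alice\", \"Bob\"]):\n",
    "    print(f\"--- what {name} sees ---\")\n",
    "    print(worker_view(\"<common prompt>\", workspaces, separator, ix))"
   ]
  },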
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aadbc3c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=0\n",
      "env: OMP_NUM_THREADS=16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "30ef68ac255445aa847afa8efb36e4b1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/14 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "%env OMP_NUM_THREADS=16\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys; sys.path.insert(0, \"..\"); sys.path.insert(0, \"../utils\");\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "import shared_cache\n",
    "from IPython.display import display, Markdown, clear_output\n",
    "\n",
    "MODEL_NAME = \"Qwen/QwQ-32B\"  # for 48GB GPUs, use \"Qwen/QwQ-32B-AWQ\" instead\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)\n",
    "\n",
    "parallelism_prompt_common = \"\"\"\n",
    "I will collaborate on this problem with another assistant. We refer to each other as Alice and Bob. We are assistants.\n",
    "\n",
    "We will reason together and try to collaborate. I will take into account what the other assistant is doing and try to help them.\n",
    "\n",
    "We will write our solutions concurrently. I will write my own thoughts at the bottom, and see the other's thoughts above.\n",
    "\n",
    "I will not repeat the other assistant's thoughts: I can already see them above.\n",
    "\n",
    "The other assistant will continue writing their thoughts above while I am writing mine. They will add more text every time I check.\n",
    "\n",
    "Since we both write our thoughts in parallel, I will initially see only partial (unfinished) thoughts of the other assistant.\n",
    "I will use these partial thoughts to decide how best to help the other assistant without doing the same work twice.\n",
    "\n",
    "When reasoning, we will give each other tasks to coordinate (e.g. if Alice writes: Bob, please do this, then Bob should take this into account).\n",
    "\n",
    "Before doing anything, I will check the other assistant's workspace. If they have already done that or are currently doing it, I don't need to do that again. If so, I will stop (e.g. 'Wait, this is already done') and pivot to a different task.\n",
    "\"\"\".strip()\n",
    "\n",
    "worker_headers = [\"\\n\\n# Alice workspace\\n\\n\", \"\\n\\n# Bob workspace\\n\\n\"]\n",
    "prompt_split = \" <the assistant will continue here>\\n\\n\"\n",
    "\n",
    "forbidden_token_ix = [tokenizer.vocab[x] for x in (\"#\", \"</think>\")]\n",
    "for x in tokenizer.special_tokens_map.values():\n",
    "    forbidden_token_ix.extend([tokenizer.vocab[x]] if isinstance(x, str) else map(tokenizer.vocab.get, x))\n",
    "tokenizer_kwargs = dict(add_special_tokens=False, return_tensors='pt', padding=True, padding_side='left')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "354601ba",
   "metadata": {},
   "source": [
    "__Playground:__ you can define a problem and see whether the workers collaborate. With this simple setup they do not always collaborate well out of the box, but it lets you see how the prompt affects their behavior."
   ]
  },
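  {
   "cell_type": "markdown",
   "id": "c4a1f0e2",
   "metadata": {},
   "source": [
    "The sample problem in the next cell is simple enough to check directly, which makes it easy to tell whether the workers' final answers are right. As a small sanity check (not part of the Hogwild! setup; the name `expected` is only illustrative), the reference values can be computed as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8b2d7f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference values for x - x^2 + x^3, to compare against the workers' answers.\n",
    "expected = {x: x - x**2 + x**3 for x in (5, 6, 7, 8)}\n",
    "print(expected)  # {5: 105, 6: 186, 7: 301, 8: 456}"
   ]
  },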
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7faa1ebd",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "\n",
       "\n",
       "# Alice workspace\n",
       "\n",
       "I am Alice. Let's solve this together, Bob. Here's how we should collaborate: I'll handle calculating the values for x=5 and x=6, while you take care of x=7 and x=8. That way, we can split the work evenly. Once we both have our results, we can combine them and present all four answers in boxed form as required. Sound good? Let me start with x=5 first.\n",
       "\n",
       "Alright, starting with x=5. The expression is x - x² + x³. Let me compute each term step by step to avoid mistakes. \n",
       "\n",
       "First term: x = 5.\n",
       "\n",
       "Second term: -x². Since exponentiation comes before multiplication by -1, that's -(5²) = -25.\n",
       "\n",
       "Third term: +x³ = +5³ = +125.\n",
       "\n",
       "Now add them all together: 5 - 25 + 125. Let's compute left to right:\n",
       "\n",
       "5 -25 = -20\n",
       "\n",
       "Then, -20 + 125 = 105. So for x=5, the result is 105. Let me double-check that. 5 cubed is 125, 5 squared is 25. So 5 -25 is indeed -20, plus 125 gives 105. That seems right.\n",
       "\n",
       "Moving on to x=6. Let's do the same steps.\n",
       "\n",
       "First term: x =6.\n",
       "\n",
       "Second term: -x² = -(6²) = -36.\n",
       "\n",
       "Third term: +x³ = +6³ = 216.\n",
       "\n",
       "Adding them: 6 -36 +216. Left to right:\n",
       "\n",
       "6 -36 = -30\n",
       "\n",
       "-30 +216 = 186. Hmm, 186? Let me verify again. 6 cubed is 216, 6 squared is 36. 6 minus 36 is -30, plus 216 is 186. Yes, that's correct.\n",
       "\n",
       "Okay, so my part is done. Now I'll wait for Bob's results for x=7 and x=8. Once he provides those, we can compile all four answers. Let me check if Bob has already started.\n",
       "\n",
       "Looking at Bob's workspace, he already calculated x=7 as 301 and x=8 as 456. Let me confirm his calculations quickly to be thorough.\n",
       "\n",
       "For x=7:\n",
       "\n",
       "7 - 49 + 343 = 7 -49 is -42, plus 343 is 301. Correct.\n",
       "\n",
       "For x=8:\n",
       "\n",
       "8 -64 +512 = 8-64 is -56, plus 512 is 456. Correct. So his answers are accurate.\n",
       "\n",
       "Therefore, compiling all four results:\n",
       "\n",
       "x=5: 105\n",
       "\n",
       "x=6: 186\n",
       "\n",
       "x=7: 301\n",
       "\n",
       "x=8: 456\n",
       "\n",
       "Now, the problem says Alice must return all 4 answers in boxed. Since I'm Alice, I need to present them in boxed format. The user might expect each in its own box, or all together? The instruction says \"all 4 answers in boxed{}\", so probably each in a box separated by commas or something. Let me check the exact wording: \"return all 4 answers in \boxed{ }\". Maybe they want each in a single box, but perhaps as a list. Since the user might expect four boxed numbers, each corresponding to x=5,6,7,8 in order.\n",
       "\n",
       "So the final answer should be:\n",
       "\n",
       "For x=5: \\boxed{105}\n",
       "\n",
       "x=6: \\boxed{186}\n",
       "\n",
       "x=7: \\boxed{301}\n",
       "\n",
       "x=8: \\boxed{456}\n",
       "\n",
       "Alternatively, maybe all in one line separated by commas inside a single box? But the instruction says \"all 4 answers in boxed{}\", which might mean each in their own box. Since the user might be expecting four boxed answers, I'll present them each boxed individually. Let me confirm with Bob if he has any input, but since he's done, I'll proceed as per the problem's instruction.\n",
       "  \n",
       "\n",
       "**Final Answer**\n",
       "For \\( x = 5 \\): \\boxed{105}  \n",
       "For \\( x = 6 \\): \\boxed{186}  \n",
       "For \\( x = 7 \\): \\boxed{301}  \n",
       "For \\( x = 8 \\): \\boxed{456}\n",
       "  \n",
       "\n",
       "Wait, but the problem says \"Alice must return all 4 answers in \boxed{ }.\" Maybe they want all in a single box separated by commas? Let me check the original problem again. The user wrote: \"Alice must return all 4 answers in \boxed{ }.\" The curly braces might indicate\n",
       "\n",
       "# Bob workspace\n",
       "\n",
       "I am Bob. Let's solve this together, Alice. Let me see what you need. You mentioned calculating x - x² + x³ for x=5,6,7,8. Since you're doing 5 and 6, I'll do 7 and 8. Let me start with x=7 first. \n",
       "\n",
       "Wait, before I jump in, let me confirm the formula again. The problem says \"Calculate x - x² + x³\". So for each x, compute that expression. Let me make sure I parse it correctly. It's x minus x squared plus x cubed. So order of operations: exponents first, then subtraction and addition from left to right. So for x=7:\n",
       "\n",
       "First, compute each term:\n",
       "\n",
       "x = 7\n",
       "\n",
       "x² = 7*7 = 49\n",
       "\n",
       "x³ = 7*7*7 = 343\n",
       "\n",
       "Now plug into the expression: 7 - 49 + 343.\n",
       "\n",
       "Compute left to right: 7 -49 is -42, then -42 + 343 = 301. Hmm, so 301 for x=7?\n",
       "\n",
       "Wait, let me check again. Alternatively, maybe I can rearrange terms. Since addition is commutative, maybe it's easier to compute (x³ - x²) + x. Let's see:\n",
       "\n",
       "x³ - x² = 343 - 49 = 294. Then add x: 294 +7 = 301. Same result. Okay, that seems correct.\n",
       "\n",
       "Now x=8. Let's do the same steps.\n",
       "\n",
       "x=8.\n",
       "\n",
       "x² = 64\n",
       "\n",
       "x³ = 512\n",
       "\n",
       "Expression: 8 -64 +512.\n",
       "\n",
       "Left to right: 8-64 = -56; -56 +512 = 456.\n",
       "\n",
       "Alternatively, x³ -x² = 512 -64 = 448; 448 +8 = 456. Same result. Okay, that checks out.\n",
       "\n",
       "So my results are 301 for x=7 and 456 for x=8. Let me just confirm once more to be safe.\n",
       "\n",
       "Wait, let me compute x=7 again step by step:\n",
       "\n",
       "7 - (7)^2 + (7)^3\n",
       "\n",
       "7 - 49 + 343\n",
       "\n",
       "7 -49 is -42, then -42 +343 is 301. Correct.\n",
       "\n",
       "x=8:\n",
       "\n",
       "8 -64 +512 = (8 +512) -64 = 520 -64 = 456. Yep, that's right.\n",
       "\n",
       "Alright, so I think I have those two. Let me check Alice's calculations for x=5 and x=6 again to ensure consistency.\n",
       "\n",
       "Alice's x=5: 5 -25 +125 = 105. Correct.\n",
       "\n",
       "x=6: 6 -36 +216 = 186. Correct.\n",
       "\n",
       "So all four answers are 105, 186, 301, 456. The problem says Alice must return all four in boxed. Since I'm Bob, but the user specified Alice must do it. Wait, the problem says \"Alice must return all 4 answers in boxed{ }\". Since I'm Bob, maybe I should let Alice present them? But the user might be expecting the final answer from Alice. Let me see the original problem again:\n",
       "\n",
       "\"Calculate x - x^2 + x^3 for x = 5,6,7,8. Alice must return all 4 answers in \boxed{ }.\"\n",
       "\n",
       "So the user is instructing that Alice must return the answers. Since I'm Bob, perhaps I should not write the final answer but let Alice do it. However, in the setup, the user might expect the final answer here. Wait, in the problem's initial instruction, it says \"put your final answer within \\boxed{}\". But the user also says \"Alice must return all 4 answers in \boxed{ }.\" So maybe the assistant (me) is Alice, and the other is Bob. Wait, in the problem's context, the user is addressing the assistant as Alice. Wait, the problem says \"Alice must return all 4 answers in \boxed{ }.\" So perhaps the assistant is Alice, and the collaboration is just part of the problem's setup. Maybe the user wants the assistant (Alice) to compute all four, but split the work with Bob. Since in the workspace above, Alice computed 5 and 6, Bob 7 and 8, and they both confirmed each other's work, so now Alice needs to present all four boxed answers. Since I'm in the Bob workspace, but the final answer needs to be from Alice."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "problem = r\"\"\"Calculate x - x^2 + x^3 for x = 5,6,7,8. Alice must return all 4 answers in \\boxed{ }.\"\"\"\n",
    "\n",
    "prompt_full_input = tokenizer.apply_chat_template(\n",
    "    [dict(role='user', content=problem)], tokenize=False, add_generation_prompt=True\n",
    ") + \"\\n\\n\" + parallelism_prompt_common\n",
    "\n",
    "worker_prompts = [\n",
    "    f\"\"\"{worker_headers[0]}I am Alice. Let's solve this together, Bob. Here's how we should collaborate:\"\"\",\n",
    "    f\"\"\"{worker_headers[1]}I am Bob. Let's solve this together, Alice.\"\"\"\n",
    "]\n",
    "\n",
    "cache_input, cache_split, cache_w1, cache_w2 = (shared_cache.CacheBlock(config=model.config) for _ in range(4))\n",
    "cm = shared_cache.SharedCacheManager(cache_structure=[\n",
    "    [cache_input, cache_w2, cache_split, cache_w1],\n",
    "    [cache_input, cache_w1, cache_split, cache_w2],\n",
    "], write_to=[cache_w1, cache_w2])\n",
    "\n",
    "# pre-fill common parts\n",
    "with torch.inference_mode():\n",
    "    model(**tokenizer(prompt_full_input, **tokenizer_kwargs).to(device),\n",
    "          use_cache=True, past_key_values=cache_input);  # <-- write to common prompt\n",
    "    model(**tokenizer(prompt_split, **tokenizer_kwargs).to(device),\n",
    "          use_cache=True, past_key_values=cache_split);   # <-- write to common separator\n",
    "\n",
    "# generate tokens in parallel with each worker\n",
    "next_inputs = tokenizer(worker_prompts, **tokenizer_kwargs).to(device)\n",
    "tokens_by_worker = tokenizer(worker_prompts, add_special_tokens=False)[\"input_ids\"]\n",
    "for inference_step in range(1024):       # <-- change max tokens here\n",
    "    with torch.inference_mode():\n",
    "        logits = model(**cm.get_input_kwargs(**next_inputs)).logits[..., -1, :]\n",
    "        logits[..., forbidden_token_ix] -= 100\n",
    "        new_tokens = logits.argmax(-1)   # <-- greedy generation\n",
    "        next_inputs = dict(input_ids=new_tokens.view(-1, 1))\n",
    "    \n",
    "    for worker_tokens, new_token in zip(tokens_by_worker, new_tokens.tolist()):\n",
    "        worker_tokens.append(new_token)\n",
    "    clear_output(True)\n",
    "    display(Markdown(\"\".join(tokenizer.decode(seq) for seq in tokens_by_worker)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
