{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2e2e251b",
   "metadata": {},
   "source": [
    "# Hogwild! Parallelism: Basic Example\n",
    "\n",
    "This example demonstrates Hogwild! inference on a single problem with 2 workers and a minimal prompt, both defined below. There are no few-shot examples or prompt insertions, and the cache layout is the simplest one possible: two contiguous workspaces. This notebook is intended as a playground, while the other notebooks present more advanced prompting and cache layouts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac5d102f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "%env OMP_NUM_THREADS=16\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys; sys.path.insert(0, \"..\"); sys.path.insert(0, \"../utils\");sys.path.insert(0, \"../kernels\");\n",
    "import random\n",
    "from copy import deepcopy\n",
    "from typing import Dict, NamedTuple, Sequence, Optional\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import transformers\n",
    "import torch\n",
    "import transformers\n",
    "import shared_cache\n",
    "from generation import MathFormatting, get_logits_processor\n",
    "\n",
    "from IPython.display import display, Markdown, clear_output\n",
    "\n",
    "\n",
    "import shared_cache\n",
    "from hogwild import model_surgery, HogwildCache, merge_caches\n",
    "from formatting import FormattingBase, MathFormatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69aee56f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AWQ-quantized QwQ-32B checkpoint; the original note recommends this checkpoint for 48GB GPUs\n",
    "MODEL_NAME = \"Qwen/QwQ-32B-AWQ\"\n",
    "\n",
    "# run on the GPU when one is visible, otherwise fall back to the CPU\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    torch_dtype='auto',\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=device,\n",
    ")\n",
    "\n",
    "# settings read by the generation loop at the end of the notebook\n",
    "max_steps = 8196  # upper bound on the number of generation-loop iterations\n",
    "print_every_steps = 1  # the displayed progress is refreshed every this many iterations\n",
    "insert_s1_collab_message_every_tokens = 1024  # token count after which the collaboration message is added to a new step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "229a95ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# raw string so that \"\\boxed\" keeps its backslash; in a plain string \"\\b\" is a backspace escape\n",
    "problem = r\"\"\"\n",
    "Calculate x - x^2 + x^3 for x = 5,6,7,8. Alice must return all 4 answers in \\boxed{ }.\"\"\".strip()\n",
    "\n",
    "print_every_steps = 3  # overrides the value set in the configuration cell\n",
    "insert_s1_prompt_every_tokens = 256\n",
    "tokens_since_last_wait = 0\n",
    "\n",
    "workers = [\"Alice\", \"Bob\"]\n",
    "fmt = MathFormatting(\n",
    "    tokenizer, workers,\n",
    ")  # ^-- prompts and optional few-shot examples; has options for different model types - see formatting.py\n",
    "\n",
    "\n",
    "# one opening line per worker, each starting with that worker's prefix for step 1\n",
    "worker_prompts = [\n",
    "    f\"\"\"{fmt.get_step_prefix(workers[0], 1)}Hi, I'm {workers[0]}. Here's how we can collaborate\"\"\",\n",
    "    f\"\"\"{fmt.get_step_prefix(workers[1], 1)}Hi, I'm {workers[1]}.\"\"\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61a4fc49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# keyword arguments reused by the tokenizer calls that feed text to the model\n",
    "tokenizer_kwargs = {\n",
    "    'return_tensors': 'pt',\n",
    "    'padding': True,\n",
    "    'padding_side': 'left',\n",
    "    'add_special_tokens': False,\n",
    "}\n",
    "# device of the model's first parameter; inputs are moved here before each forward pass\n",
    "device = next(model.parameters()).device\n",
    "logits_processor = get_logits_processor(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10096495",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens_since_last_wait = 0\n",
    "\n",
    "# one independent key-value cache per prompt segment\n",
    "cache_common = transformers.DynamicCache()\n",
    "cache_current_step_header = transformers.DynamicCache()\n",
    "cache_own_header = transformers.DynamicCache()\n",
    "cache_w1 = transformers.DynamicCache()\n",
    "cache_w2 = transformers.DynamicCache()\n",
    "\n",
    "\n",
    "def _worker_view(own_cache, other_cache):\n",
    "    \"\"\"Return a new layout list: common prompt, step header, other worker, own header, own cache.\"\"\"\n",
    "    return [cache_common, cache_current_step_header, other_cache, cache_own_header, own_cache]\n",
    "\n",
    "\n",
    "# batched cache: row 0 is the first worker's layout, row 1 the second worker's\n",
    "cm = HogwildCache(\n",
    "    cache_structure=[_worker_view(cache_w1, cache_w2), _worker_view(cache_w2, cache_w1)],\n",
    "    write_to=[cache_w1, cache_w2],\n",
    "    model=model,\n",
    ")\n",
    "\n",
    "# single-row caches, used when only one worker's prompt needs to be pre-filled\n",
    "w_prompt_caches = {\n",
    "    0: HogwildCache(cache_structure=[_worker_view(cache_w1, cache_w2)], write_to=[cache_w1], model=model),\n",
    "    1: HogwildCache(cache_structure=[_worker_view(cache_w2, cache_w1)], write_to=[cache_w2], model=model),\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d118820",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): model_surgery is imported from hogwild; presumably it patches the model for shared-cache inference - confirm\n",
    "model_surgery(model)\n",
    "# wrap the model with torch.compile; later cells call the compiled wrapper\n",
    "model = torch.compile(model)\n",
    "# switch to evaluation mode\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfd629b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# rebuild the formatter for the configured workers, adding a result parser that keeps only the digits of a box\n",
    "fmt = MathFormatting(\n",
    "    tokenizer, workers,\n",
    "    extract_result=lambda box: int(\"\".join(x for x in box if x.isdigit())))\n",
    "\n",
    "# pre-fill common cache parts: each text is run through the model once, writing into its own cache\n",
    "prefill_plan = [\n",
    "    (fmt.apply_chat_template(problem), cache_common),  # <-- write to common prompt\n",
    "    (fmt.current_step_header, cache_current_step_header),  # <-- write to the separator after history\n",
    "    (fmt.current_worker_header, cache_own_header),  # <-- write to separator between incomplete steps\n",
    "]\n",
    "with torch.inference_mode():\n",
    "    for prefill_text, prefill_cache in prefill_plan:\n",
    "        model(**tokenizer(prefill_text, **tokenizer_kwargs).to(device),\n",
    "              use_cache=True, past_key_values=HogwildCache([[prefill_cache]], model=model))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a1f8d2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate interdependent reasoning chains in parallel\n",
    "current_step_index_by_worker = [1, 1]\n",
    "# tokens of each worker's current (unfinished) step; the generation loop appends to these lists\n",
    "current_step_tokens_by_worker = tokenizer(list(fmt.worker_prompts), add_special_tokens=False)['input_ids']\n",
    "history = list()\n",
    "next_input_tokens = [None, None]\n",
    "worker_prompt_tokens = [tokenizer.encode(fmt.worker_prompts[i], add_special_tokens=False) for i in range(2)]\n",
    "# longest prefix length that can be batched while leaving at least one token per worker for generation\n",
    "min_length = min(map(len, worker_prompt_tokens)) - 1\n",
    "\n",
    "# run the prefill under inference mode, like the other cache-filling forward passes\n",
    "with torch.inference_mode():\n",
    "    if min_length > 0:\n",
    "        # one forward pass to process as many tokens as possible\n",
    "        model(\n",
    "            input_ids=torch.tensor(\n",
    "                [worker_prompt_tokens[0][:min_length], worker_prompt_tokens[1][:min_length]],\n",
    "                dtype=torch.int32, device=device),\n",
    "            use_cache=True, past_key_values=cm)\n",
    "\n",
    "    for i in range(2):\n",
    "        # pre-populate worker's cache so that only one new token remains;\n",
    "        # token ids are fed directly, since decoding and re-tokenizing may not reproduce the same ids\n",
    "        remaining_tokens = worker_prompt_tokens[i][min_length:-1]\n",
    "        if remaining_tokens:\n",
    "            model(input_ids=torch.tensor([remaining_tokens], dtype=torch.int32, device=device),\n",
    "                  use_cache=True, past_key_values=w_prompt_caches[i])\n",
    "        next_input_tokens[i] = worker_prompt_tokens[i][-1:]\n",
    "\n",
    "next_inputs = tokenizer.pad(\n",
    "        dict(input_ids=next_input_tokens), padding_side='left', return_tensors='pt').to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d63f56f6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Main generation loop: each iteration produces one new token per worker, moves finished steps\n",
    "# into the common cache, and periodically re-renders the progress. Stops on EOS or after max_steps.\n",
    "for inference_step in range(max_steps):\n",
    "    # every model / cache operation of the step runs under inference mode\n",
    "    with torch.inference_mode():\n",
    "        # run model with a shared cache (batched inference)\n",
    "        position_ids = torch.tensor([[cache_w1.get_seq_length()], [cache_w2.get_seq_length()]],\n",
    "                                    dtype=torch.int64, device=device)\n",
    "        logits = model(**next_inputs, use_cache=True, past_key_values=cm,\n",
    "                       position_ids=position_ids).logits[..., -1, :]\n",
    "        logits = logits_processor(next_inputs['input_ids'], logits)\n",
    "        if model.generation_config.do_sample:\n",
    "            new_tokens = torch.multinomial(logits.softmax(dim=-1), 1).flatten()\n",
    "        else:\n",
    "            new_tokens = logits.argmax(-1)\n",
    "        assert len(new_tokens) == len(fmt.workers)\n",
    "\n",
    "        # process generated tokens for printing; handle step change, update next_inputs\n",
    "        next_input_tokens = new_tokens.unsqueeze(-1).tolist()\n",
    "        for worker_index, (worker_name, worker_tokens, new_token) in enumerate(\n",
    "                zip(fmt.workers, current_step_tokens_by_worker, new_tokens.tolist())):\n",
    "            worker_tokens.append(new_token)\n",
    "            if fmt.is_end_of_step(worker_tokens):\n",
    "                # worker just finished their step - add it to common history and start a new step\n",
    "                current_step_index_by_worker[worker_index] += 1\n",
    "                history.extend(worker_tokens)\n",
    "\n",
    "                merge_caches(cache_common, cm.cache_structure[worker_index][-1], model.model)\n",
    "\n",
    "                worker_tokens.clear()\n",
    "                start_msg = fmt.get_step_prefix(worker_name, current_step_index_by_worker[worker_index])\n",
    "                if tokens_since_last_wait > insert_s1_collab_message_every_tokens:\n",
    "                    start_msg += fmt.s1_collab_message\n",
    "                    tokens_since_last_wait = 0\n",
    "                worker_tokens.extend(tokenizer.encode(start_msg, add_special_tokens=False))\n",
    "                # empty the worker's own cache before the new step is written into it\n",
    "                cm.cache_structure[worker_index][-1].crop(0)\n",
    "                # pre-populate worker's cache so that only one new token remains\n",
    "                model(input_ids=torch.tensor([[new_token] + worker_tokens[:-1]], dtype=torch.int32, device=device),\n",
    "                      use_cache=True, past_key_values=w_prompt_caches[worker_index])\n",
    "                next_input_tokens[worker_index] = worker_tokens[-1:]\n",
    "            tokens_since_last_wait += len(next_input_tokens[worker_index])\n",
    "    next_inputs = tokenizer.pad(\n",
    "        dict(input_ids=next_input_tokens), padding_side='left', return_tensors='pt').to(device)\n",
    "\n",
    "    if inference_step % print_every_steps == 0:\n",
    "        clear_output(True)  # display current progress\n",
    "        output_parts = [f\"[**Problem:** {problem}]\\n\\n\"]\n",
    "        output_parts.append(fmt.history_header + fmt.sep + tokenizer.decode(history))\n",
    "        output_parts.append(fmt.current_step_header)\n",
    "        for worker_index, worker_tokens in enumerate(current_step_tokens_by_worker):\n",
    "            output_parts.append(tokenizer.decode(worker_tokens) + fmt.incomplete_step + fmt.sep)\n",
    "        display(Markdown(''.join(output_parts)))\n",
    "    # stop as soon as any worker emits the end-of-sequence token\n",
    "    if torch.any(new_tokens == tokenizer.eos_token_id).item():\n",
    "        break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
