{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c3e5902",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment setup. Run this cell before anything imports numpy, torch, transformers or\n",
    "# async_reasoning: settings such as HF_HOME and OMP_NUM_THREADS may be read when those libraries\n",
    "# are first imported, so setting them afterwards can silently have no effect.\n",
    "# The %env values below are for the Yandex environment; remove or replace them with your own.\n",
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "%env HF_HOME=/home/async_reasoning/.cache/huggingface\n",
    "%env OMP_NUM_THREADS=16\n",
    "\n",
    "import sys\n",
    "\n",
    "# Make the local async_reasoning package (one directory up) importable; the guard keeps re-runs idempotent.\n",
    "if \"../.\" not in sys.path:\n",
    "    sys.path.insert(0, \"../.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87a79f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import threading\n",
    "import time\n",
    "from typing import Sequence\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import transformers\n",
    "from IPython.display import display, Markdown, clear_output\n",
    "\n",
    "# Local package: importable only after the sys.path setup in the previous cell.\n",
    "from async_reasoning.cache import State, AsyncReasoningCache\n",
    "from async_reasoning.prompting import AsyncReasoningPrompting\n",
    "from async_reasoning.solver import AsyncReasoningSolver, LiveContextQueue\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "logging.basicConfig(filename='demo.log', encoding='utf-8', level=logging.DEBUG)\n",
    "\n",
    "# Choose one checkpoint. For 48GB GPUs use \"Qwen/Qwen3-32B-AWQ\"; the smaller\n",
    "# checkpoints below are drop-in alternatives.\n",
    "MODEL_NAME = \"Qwen/Qwen3-32B-AWQ\"\n",
    "# MODEL_NAME = \"Qwen/Qwen3-0.6B\"\n",
    "# MODEL_NAME = \"Qwen/Qwen3-1.7B\"\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6f5bf76",
   "metadata": {},
   "source": [
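    "Example: inject user input into the `input` block once the thinker has generated 100 tokens, deferring the injection until a sentence/paragraph boundary (`defer_until_boundary=True`)."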
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7844c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input-block injection example: once the thinker has produced INJECT_AFTER_THINKER_TOKENS tokens,\n",
    "# push the user's clarification to the \"input\" target with defer_until_boundary=True, so the\n",
    "# injection waits for a sentence/paragraph boundary.\n",
    "INJECT_AFTER_THINKER_TOKENS = 100\n",
    "\n",
    "# Raw string so the LaTeX backslashes (\\sum, \\infty) are kept verbatim; in a plain literal they are\n",
    "# invalid escape sequences (a warning today), and ones like \\f or \\n would silently become control characters.\n",
    "user_input = r\"USER INPUT: By p I mean p = \\sum_{k=1}^\\infty 1/k^2 and q = \\sum_{k=1}^\\infty 1/k^3; j,k start at 1\"\n",
    "\n",
    "live_input = LiveContextQueue(tokenizer, model.device)\n",
    "_state = {\"triggered\": False}  # flips to True once the injection is queued, so it happens only once\n",
    "\n",
    "def on_token_input(writer_tokens, thinker_tokens, token_times, eos, state):\n",
    "    \"\"\"Callback for `on_new_tokens_generated`: queue `user_input` once the thinker is long enough.\n",
    "\n",
    "    Only `thinker_tokens` is inspected; the remaining arguments (including the solver-supplied\n",
    "    `state`) are accepted to match the callback signature and are ignored here.\n",
    "    \"\"\"\n",
    "    if _state[\"triggered\"]:\n",
    "        return\n",
    "    if len(thinker_tokens) >= INJECT_AFTER_THINKER_TOKENS:\n",
    "        live_input.push_text(\n",
    "            user_input,\n",
    "            target=\"input\",\n",
    "            defer_until_boundary=True,  # wait for sentence/paragraph end\n",
    "        )\n",
    "        _state[\"triggered\"] = True\n",
    "\n",
    "solver_input = AsyncReasoningSolver(model, tokenizer, use_fast_kernel=False)\n",
    "\n",
    "writer_input, thinker_input, token_times_input, eos_input = solver_input.solve(\n",
    "    r\"Express S = \\sum_{j}\\sum_{k}\\frac{1}{(j+k)^3} in terms of p and q.\",\n",
    "    budget=2048,\n",
    "    display_generation_in_real_time=True,\n",
    "    live_context_queue=live_input,\n",
    "    on_new_tokens_generated=on_token_input,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129ff77a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Thinker injection with a boundary check: once the thinker has produced 100 tokens, queue a\n",
    "# clarification for the thinker stream with defer_until_boundary=True.\n",
    "\n",
    "# CASE 1\n",
    "case1_problem = r\"Express S = \\sum_{j}\\sum_{k}\\frac{1}{(j+k)^3} in terms of p and q.\"\n",
    "case1_injections = [\n",
    "    r\"By p I mean p = \\sum_{k=1}^\\infty \\frac{1}{k^2} and q = \\sum_{k=1}^\\infty \\frac{1}{k^3}, and the double sum is over j,k = 1,2,3,\\dots\"\n",
    "]\n",
    "\n",
    "live = LiveContextQueue(tokenizer, model.device)\n",
    "\n",
    "triggered = False\n",
    "\n",
    "def on_token(writer_tokens, thinker_tokens, token_times, eos, state):\n",
    "    \"\"\"Callback for `on_new_tokens_generated`: queue the CASE 1 clarification once, after 100 thinker tokens.\"\"\"\n",
    "    global triggered\n",
    "    if triggered:\n",
    "        return\n",
    "    if len(thinker_tokens) >= 100:\n",
    "        # Reuse the raw-string injection text: a plain literal would turn \"\\frac\" into a form feed + \"rac\".\n",
    "        live.push_text(\"USER INPUT: \" + case1_injections[0], target=\"thinker\", defer_until_boundary=True)  # or target=\"writer\" if you want\n",
    "        triggered = True\n",
    "\n",
    "solver = AsyncReasoningSolver(model, tokenizer, use_fast_kernel=False)\n",
    "\n",
    "def run_with_injections(problem, injections, delay_s=1.0):\n",
    "    \"\"\"Time-based alternative to `on_token`: solve `problem` with a fresh LiveContextQueue while a\n",
    "    daemon thread pushes each text in `injections` (prefixed with \"USER INPUT: \") to the thinker,\n",
    "    sleeping `delay_s` seconds before each push. Returns (writer, thinker, token_times, eos).\n",
    "    \"\"\"\n",
    "    live = LiveContextQueue(tokenizer, model.device)\n",
    "\n",
    "    def feeder():\n",
    "        for text in injections:\n",
    "            time.sleep(delay_s)          # wait for some tokens to stream\n",
    "            live.push_text(\"USER INPUT: \" + text, target=\"thinker\", defer_until_boundary=True)  # or target=\"writer\" if you want\n",
    "\n",
    "    threading.Thread(target=feeder, daemon=True).start()\n",
    "\n",
    "    writer, thinker, token_times, eos = solver.solve(\n",
    "        problem,\n",
    "        display_generation_in_real_time=True,\n",
    "        live_context_queue=live,\n",
    "    )\n",
    "    return writer, thinker, token_times, eos\n",
    "\n",
    "writer, thinker, token_times, eos = solver.solve(\n",
    "    case1_problem,\n",
    "    display_generation_in_real_time=True,\n",
    "    live_context_queue=live,\n",
    "    on_new_tokens_generated=on_token,\n",
    ")\n",
    "\n",
    "# Time-based alternative to the token-count callback above:\n",
    "# writer1, thinker1, times1, _ = run_with_injections(case1_problem, case1_injections, delay_s=5.0)\n",
    "\n",
    "# # CASE 2\n",
    "# case2_problem = r\"Let p = \\sum_{k=1}^\\infty \\frac{1}{k^2} and q = \\sum_{k=1}^\\infty \\frac{1}{k^3}. Evaluate \\sum_{j=0}^\\infty \\sum_{k=0}^\\infty \\frac{1}{(j+k)^3} in terms of p and q.\"\n",
    "# case2_injections = [\n",
    "#     \"Correction: j and k should start at 1, not 0.\",\n",
    "#     \"Oops—the exponent is 3 (not 2): it's \\\\sum_{j=1}^\\\\infty \\\\sum_{k=1}^\\\\infty \\\\frac{1}{(j+k)^3}.\"\n",
    "# ]\n",
    "# writer2, thinker2, times2, _ = run_with_injections(case2_problem, case2_injections, delay_s=1.0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9db9892e",
   "metadata": {},
   "source": [
    "Example: inject the clarification into the thinker as soon as it has generated 100 tokens, without waiting for a sentence/paragraph boundary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c5b4f5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Immediate thinker injection (no boundary check): once the thinker has produced 100 tokens, the\n",
    "# clarification is pushed with defer_until_boundary=False, so it is not held back until a\n",
    "# sentence/paragraph ends. Note: this cell redefines on_token, run_with_injections and solver.\n",
    "\n",
    "# CASE 1\n",
    "case1_problem = r\"Express S = \\sum_{j}\\sum_{k}\\frac{1}{(j+k)^3} in terms of p and q.\"\n",
    "case1_injections = [\n",
    "    r\"By p I mean p = \\sum_{k=1}^\\infty \\frac{1}{k^2} and q = \\sum_{k=1}^\\infty \\frac{1}{k^3}, and the double sum is over j,k = 1,2,3,\\dots\"\n",
    "]\n",
    "\n",
    "live = LiveContextQueue(tokenizer, model.device)\n",
    "\n",
    "triggered = False\n",
    "\n",
    "def on_token(writer_tokens, thinker_tokens, token_times, eos, state):\n",
    "    \"\"\"Callback for `on_new_tokens_generated`: push the CASE 1 clarification once, after 100 thinker tokens.\"\"\"\n",
    "    global triggered\n",
    "    if triggered:\n",
    "        return\n",
    "    if len(thinker_tokens) >= 100:\n",
    "        # Pass False explicitly rather than relying on push_text's default, so this example really skips the boundary wait.\n",
    "        live.push_text(case1_injections[0], target=\"thinker\", defer_until_boundary=False)\n",
    "        triggered = True\n",
    "\n",
    "solver = AsyncReasoningSolver(model, tokenizer, use_fast_kernel=False)\n",
    "\n",
    "def run_with_injections(problem, injections, delay_s=1.0):\n",
    "    \"\"\"Time-based alternative to `on_token`: solve `problem` with a fresh LiveContextQueue while a\n",
    "    daemon thread pushes each text in `injections` to the thinker (no boundary wait), sleeping\n",
    "    `delay_s` seconds before each push. Returns (writer, thinker, token_times, eos).\n",
    "    \"\"\"\n",
    "    live = LiveContextQueue(tokenizer, model.device)\n",
    "\n",
    "    def feeder():\n",
    "        for text in injections:\n",
    "            time.sleep(delay_s)          # wait for some tokens to stream\n",
    "            live.push_text(text, target=\"thinker\", defer_until_boundary=False)  # or target=\"writer\" if you want\n",
    "\n",
    "    threading.Thread(target=feeder, daemon=True).start()\n",
    "\n",
    "    writer, thinker, token_times, eos = solver.solve(\n",
    "        problem,\n",
    "        display_generation_in_real_time=True,\n",
    "        live_context_queue=live,\n",
    "    )\n",
    "    return writer, thinker, token_times, eos\n",
    "\n",
    "writer, thinker, token_times, eos = solver.solve(\n",
    "    case1_problem,\n",
    "    display_generation_in_real_time=True,\n",
    "    live_context_queue=live,\n",
    "    on_new_tokens_generated=on_token,\n",
    ")\n",
    "\n",
    "# Time-based alternative to the token-count callback above:\n",
    "# writer1, thinker1, times1, _ = run_with_injections(case1_problem, case1_injections, delay_s=5.0)\n",
    "\n",
    "# # CASE 2\n",
    "# case2_problem = r\"Let p = \\sum_{k=1}^\\infty \\frac{1}{k^2} and q = \\sum_{k=1}^\\infty \\frac{1}{k^3}. Evaluate \\sum_{j=0}^\\infty \\sum_{k=0}^\\infty \\frac{1}{(j+k)^3} in terms of p and q.\"\n",
    "# case2_injections = [\n",
    "#     \"Correction: j and k should start at 1, not 0.\",\n",
    "#     \"Oops—the exponent is 3 (not 2): it's \\\\sum_{j=1}^\\\\infty \\\\sum_{k=1}^\\\\infty \\\\frac{1}{(j+k)^3}.\"\n",
    "# ]\n",
    "# writer2, thinker2, times2, _ = run_with_injections(case2_problem, case2_injections, delay_s=1.0)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
