{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "%env OMP_NUM_THREADS=16\n",
    "\n",
    "# NOTE: the %env magics above are intentionally executed before torch is imported.\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "from async_reasoning.solver import AsyncReasoningSolver as Solver\n",
    "from evals.tts_evaluator import TTSEvaluator\n",
    "from utils.answer_processing import find_last_valid_expression, check_equality_judge, check_equality_local_model\n",
    "\n",
    "MODEL_NAME = \"Qwen/Qwen3-32B\"  # for 48GB gpus, use \"Qwen/Qwen3-32B-AWQ\" instead\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)\n",
    "\n",
    "# Materialize the token -> id mapping a single time through the public get_vocab() API\n",
    "# and reuse it for every lookup below, instead of re-reading `tokenizer.vocab` per token.\n",
    "vocab = tokenizer.get_vocab()\n",
    "\n",
    "# Every vocabulary entry whose text ends with \"SYSTEM\" or \"SYSTEM:\".\n",
    "system_tokens = [token for token in vocab if token.endswith((\"SYSTEM\", \"SYSTEM:\"))]\n",
    "\n",
    "# Token ids handed to the solver as forbidden for the writer and the thinker respectively;\n",
    "# the thinker list additionally contains \"<|im_end|>\".\n",
    "writer_forbidden_token_ix = [vocab[token] for token in [\"</think>\", \"<|im_start|>\", \"<|endoftext|>\"] + system_tokens]\n",
    "thinker_forbidden_token_ix = [vocab[token] for token in [\"</think>\", \"<|im_start|>\", \"<|im_end|>\", \"<|endoftext|>\"] + system_tokens]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "problem = \"\"\"Calculate x - x^2 + x^3 for x = 5,6,7,8. Return all 4 answers in \\\\boxed{ }.\"\"\"\n",
    "# Reference answer, computed by hand:\n",
    "#   x=5: 5 - 25 + 125 = 105\n",
    "#   x=6: 6 - 36 + 216 = 186\n",
    "#   x=7: 7 - 49 + 343 = 301\n",
    "#   x=8: 8 - 64 + 512 = 456\n",
    "answer = \"105, 186, 301, 456\"\n",
    "\n",
    "# use_fast_kernel=True turns the fast kernels on.\n",
    "# NOTE(review): the earlier note said to set this to True if you skipped compiling the kernels,\n",
    "# which looks inverted -- presumably it should be False in that case; confirm against the solver.\n",
    "# Reminder: the local judge cannot be used to evaluate results while fast kernels are enabled.\n",
    "solver = Solver(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    writer_forbidden_token_ix=writer_forbidden_token_ix,\n",
    "    thinker_forbidden_token_ix=thinker_forbidden_token_ix,\n",
    "    use_fast_kernel=True,\n",
    ")\n",
    "\n",
    "# Run the solver on the problem with a budget of 1024 while streaming the generation live.\n",
    "writer_output_str, thinker_output_str, token_times, eos_generated = solver.solve(\n",
    "    problem, budget=1024, display_generation_in_real_time=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _strip_wrapper(expression):\n",
    "    \"\"\"Drop the first 7 characters and the last one (presumably a \\\\boxed{...} wrapper).\"\"\"\n",
    "    return expression[7:-1]\n",
    "\n",
    "\n",
    "# Extract the final answer expression from the writer's output.\n",
    "response = find_last_valid_expression(writer_output_str, extract_result=_strip_wrapper)\n",
    "\n",
    "# True  -> canonical openai judge (populate utils/api_config.json beforehand)\n",
    "# False -> local model judge (reminder: not usable together with fast kernels)\n",
    "use_api_not_local = True\n",
    "\n",
    "is_equal = (\n",
    "    check_equality_judge(response, answer)\n",
    "    if use_api_not_local\n",
    "    else check_equality_local_model(model, tokenizer, response, answer)\n",
    ")\n",
    "print(f\"Answer is correct: {is_equal}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, Audio\n",
    "\n",
    "\n",
    "def pretty_dict(d, indent=2):\n",
    "    \"\"\"Print a (possibly nested) dict, one entry per line.\n",
    "\n",
    "    Each line is prefixed with `indent` '>' characters; a nested dict is printed\n",
    "    under its key with the prefix extended by 4 characters per nesting level.\n",
    "    \"\"\"\n",
    "    pad = \">\" * indent\n",
    "    for k, v in d.items():\n",
    "        if isinstance(v, dict):\n",
    "            print(f\"{pad} {k}:\")\n",
    "            pretty_dict(v, indent + 4)\n",
    "        else:\n",
    "            print(f\"{pad} {k}: {v}\")\n",
    "\n",
    "\n",
    "evaluator = TTSEvaluator()\n",
    "# Split the recorded token timings into chunks and obtain the matching audio.\n",
    "chunks, audio = evaluator.get_chunks_with_tts(token_times, k_chunks=5, return_audio=True)\n",
    "metrics = evaluator(**chunks, add_tts_in_parrallel=True, return_delays=False)\n",
    "\n",
    "# Previously this cell called pretty_dict without defining it, which raised a NameError\n",
    "# for nested metrics on a fresh kernel; the helper is now defined above.\n",
    "pretty_dict(metrics, indent=2)\n",
    "display(Audio(data=audio['frame'], rate=audio['frame_rate']))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "AsyncReasoning-public",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
