{
 "cells": [
  {
   "cell_type": "code",
   "id": "db23b611-c239-406e-8173-5c299b5b96eb",
   "metadata": {},
   "source": [
    "import os\n",
    "import re\n",
    "import pickle\n",
    "import pandas as pd\n",
    "from vllm import LLM, SamplingParams\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "\n",
    "def find_max_strict_step(text):\n",
    "    pattern = r'step\\s*(\\d+):\\n'\n",
    "    matches = re.findall(pattern, text, flags=re.IGNORECASE)\n",
    "    if not matches:\n",
    "        return None\n",
    "    return max(int(num) for num in matches)\n",
    "system = \"\"\"You are a mathematics teacher reviewing a solution that appears to be missing one step. Given the position of the missing step, your task is to fill in the missing step.\n",
    "The steps in the solution are labeled from Step 0 (problem statement) to Step N.\n",
    "Please format your response as:\n",
    "The missing step is:\n",
    "[Write the complete missing step here with necessary explanations and equations]\n",
    "\"\"\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3\"\n",
    "df = pd.read_json(\"ScaleQM+_test.json\")\n",
    "llm = LLM(model=\"CoT-Bridge-Random\", tensor_parallel_size=4, gpu_memory_utilization=0.85)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"CoT-Bridge-Random\")\n",
    "texts = []\n",
    "for i in range(len(df)):\n",
    "    max_step = find_max_strict_step(df.iloc[i][\"messages\"][1][\"content\"])\n",
    "    for j in range(0, max_step):\n",
    "        prompt1 = f\"\"\"There is a missing step between Step {j} and Step {j+1}.\n",
    "{df.iloc[i][\"messages\"][1][\"content\"]}\n",
    "\"\"\"\n",
    "        messages = [\n",
    "            {\"role\": \"system\",\n",
    "             \"content\": system},\n",
    "            {\"role\": \"user\",\n",
    "             \"content\": prompt1},\n",
    "        ]\n",
    "        \n",
    "        text = tokenizer.apply_chat_template(\n",
    "            messages,\n",
    "            tokenize=False,\n",
    "            add_generation_prompt=True\n",
    "        )\n",
    "        texts.append(text)\n",
    "\n",
    "outputs = llm.generate(\n",
    "    texts,\n",
    "    SamplingParams(\n",
    "    temperature=0,\n",
    "    max_tokens=1024,\n",
    "    skip_special_tokens=True\n",
    ")\n",
    ")\n",
    "\n",
    "results = []\n",
    "for i in range(len(outputs)):\n",
    "    results.append(outputs[i].outputs[0].text)\n",
    "\n",
    "with open('results-sim.pkl', 'wb') as f:\n",
    "    pickle.dump(results, f)"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
