{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cd7b45f3-22e4-4235-b9cc-eb911e6211c5",
   "metadata": {},
   "source": [
    "# PRM Statistics"
   ]
  },
  {
   "cell_type": "code",
   "id": "77e44eba-3231-4bfa-87f4-39ffd90e2f92",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "from vllm import LLM\n",
    "import os\n",
    "import pickle\n",
    "import re\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "from collections import Counter\n",
    "import random\n",
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "import torch.nn.functional as F\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load NuminaMath-CoT via datasets.load_dataset; only the train split is used below.\n",
    "save_path = \"NuminaMath-CoT\"\n",
    "dataset = load_dataset(save_path)\n",
    "filtered_train = dataset[\"train\"]\n",
    "# Bridging-model outputs; process() parses loaded_results[i] as text for\n",
    "# filtered_train[i], so this is assumed to be index-aligned with the train split.\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.\n",
    "file_path = 'bridge_results.pkl'\n",
    "with open(file_path, 'rb') as file:\n",
    "    loaded_results = pickle.load(file)\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b6cbd1c0-8c68-48b5-8145-0e5003f2f698",
   "metadata": {},
   "source": [
    "def extract_missing_steps(text):\n",
    "    pattern = re.compile(\n",
    "        r'Missing Step (\\d+)：\\s*'\n",
    "        r'The missing step should be placed between Step (\\d+) and Step (\\d+)\\.\\s*'\n",
    "        r'The missing step is:\\s*(.*?)'\n",
    "        r'(?=\\s*Missing Step \\d+：|\\Z)',\n",
    "        re.DOTALL\n",
    "    )\n",
    "    matches = pattern.findall(text)\n",
    "    results = []\n",
    "    for match in matches:\n",
    "        a, x, y, z = match\n",
    "        results.append({\n",
    "            'a': a.strip(),\n",
    "            'x': x.strip(),\n",
    "            'y': y.strip(),\n",
    "            'z': z.strip()\n",
    "        })\n",
    "    return results\n",
    "\n",
    "\n",
    "def filter_results(results):\n",
    "    filtered_results = []\n",
    "    for result in results:\n",
    "        x = int(result['x']) \n",
    "        y = int(result['y']) \n",
    "        z = result['z']\n",
    "        if x + 1 == y and not z.startswith('####'):\n",
    "            filtered_results.append(result)\n",
    "    return filtered_results\n",
    "\n",
    "\n",
    "def sort_results_by_x(results):\n",
    "    sorted_results = sorted(results, key=lambda result: int(result['x']))\n",
    "    return sorted_results\n",
    "\n",
    "\n",
    "def process_text(text):\n",
    "    \"\"\"Parse the missing-step records in text, keep adjacent-step ones, sort by x.\"\"\"\n",
    "    return sort_results_by_x(filter_results(extract_missing_steps(text)))\n",
    "\n",
    "def process(i):\n",
    "    \"\"\"Insert the parsed bridging steps into the i-th training solution.\n",
    "\n",
    "    The solution is split into steps on blank lines, ``loaded_results[i]`` is parsed\n",
    "    with ``process_text`` and each missing step is inserted right after step ``x``\n",
    "    (1-based), i.e. at 0-based index ``x`` shifted by the number of earlier insertions.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(query, steps, insert_pos)`` -- the problem text, the step list with\n",
    "        the missing steps inserted, and the index of each inserted step in that list.\n",
    "    \"\"\"\n",
    "    query = filtered_train[i]['problem']\n",
    "    steps = filtered_train[i]['solution'].split(\"\\n\\n\")\n",
    "    insert_pos = []\n",
    "    # process_text sorts by x, so every earlier insertion shifts later targets by one.\n",
    "    for offset, entry in enumerate(process_text(loaded_results[i])):\n",
    "        pos = int(entry['x']) + offset\n",
    "        # NOTE(review): list.insert appends when pos > len(steps), but pos is still\n",
    "        # recorded as-is; TODO confirm downstream indexing tolerates that.\n",
    "        steps.insert(pos, entry['z'])\n",
    "        insert_pos.append(pos)\n",
    "    return query, steps, insert_pos"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "4f7066ee-8024-449b-ac93-579aa893106a",
   "metadata": {},
   "source": [
    "# For every training example collect the problem text (query), the solution\n",
    "# steps with bridged steps inserted (m) and the indices of those inserted steps.\n",
    "records = (process(i) for i in tqdm(range(len(filtered_train))))\n",
    "query, m, insert_pos = [], [], []\n",
    "for problem, steps, positions in records:\n",
    "    query.append(problem)\n",
    "    m.append(steps)\n",
    "    insert_pos.append(positions)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b3cb1c2b-4a65-4763-9ef6-ded907bceaa8",
   "metadata": {},
   "source": [
    "# PRM scores; results[i] is indexed below with insert_pos[i], so presumably it holds\n",
    "# one score per step of m[i] - TODO confirm against the scoring script.\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.\n",
    "with open(\"prm_result.pkl\", 'rb') as file:\n",
    "    results = pickle.load(file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "faba0498-e6f3-4f9b-89ef-1da179d6f49a",
   "metadata": {},
   "source": [
    "temp = 0\n",
    "def get_values_by_indices(values, indices):\n",
    "    global temp\n",
    "    try:\n",
    "        return [values[i] for i in indices]\n",
    "    except:\n",
    "        temp += len(indices)\n",
    "        return \"error data\"\n",
    "# One entry per sample: the PRM scores of its inserted steps, or \"error data\".\n",
    "insert_score = [\n",
    "    get_values_by_indices(results[i], insert_pos[i])\n",
    "    for i in range(len(results))\n",
    "]\n",
    "print(temp)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e00fa557-0a1c-4613-866d-bf5de0f7efcc",
   "metadata": {},
   "source": [
    "# Histogram of the PRM scores of inserted (bridged) steps: 10 bins of width 0.1.\n",
    "# NOTE(review): min_val/max_val start at 1.0/0.0, which assumes scores lie in [0, 1].\n",
    "distribution = [0] * 10\n",
    "insert_context = [[] for _ in range(10)]  # step texts that fall into each bin\n",
    "total_sum = 0.0\n",
    "count = 0\n",
    "min_val = 1.0\n",
    "max_val = 0.0\n",
    "\n",
    "for i in range(len(insert_score)):\n",
    "    if insert_score[i] == \"error data\":\n",
    "        continue\n",
    "    for j, score in enumerate(insert_score[i]):\n",
    "        # Clamp into [0, 9]: a score of exactly 1.0 maps to 10, and a negative score\n",
    "        # would otherwise wrap to the top bins through negative list indexing.\n",
    "        idx = min(max(int(score * 10), 0), 9)\n",
    "        distribution[idx] += 1\n",
    "        insert_context[idx].append(m[i][insert_pos[i][j]])\n",
    "        total_sum += score\n",
    "        count += 1\n",
    "        min_val = min(min_val, score)\n",
    "        max_val = max(max_val, score)\n",
    "\n",
    "# Inserted steps whose scores could not be looked up (counted in temp) are placed in\n",
    "# the lowest bin; they are excluded from the mean/min/max and from insert_context.\n",
    "distribution[0] += temp\n",
    "mean = total_sum / count if count > 0 else 0.0\n",
    "print(f\"error: {temp}\")\n",
    "print(f\"distribution: {distribution}\")\n",
    "print(f\"avg: {mean:.4f}\")\n",
    "print(f\"min: {min_val:.4f}, max: {max_val:.4f}\")\n",
    "total = sum(distribution)\n",
    "for bin_count in distribution:\n",
    "    # Guard against division by zero when nothing was scored at all.\n",
    "    print(bin_count / total if total else 0.0)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "1d8db17d-8217-4a2f-9fcb-c00e60c3201c",
   "metadata": {},
   "source": [
    "# Construct denoising data (remove bridged steps with low PRM scores)"
   ]
  },
  {
   "cell_type": "code",
   "id": "7c714641-033a-45c7-930d-cadd07778bb2",
   "metadata": {},
   "source": [
    "def process_prm(i, threshold=0.1):\n",
    "    \"\"\"Build the chat-format training sample for example ``i`` after PRM denoising.\n",
    "\n",
    "    Inserted (bridged) steps of ``m[i]`` whose PRM score is below ``threshold``\n",
    "    are removed and the remaining steps are joined with a single newline. When the\n",
    "    scores for this example could not be looked up (``\"error data\"``), the original\n",
    "    solution text is used unchanged.\n",
    "\n",
    "    Returns:\n",
    "        dict: ``{\"messages\": [system, user, assistant]}`` message list.\n",
    "    \"\"\"\n",
    "    query = filtered_train[i]['problem']\n",
    "    sy = \"You are a math problem solver. You should think step by step.\"\n",
    "    if insert_score[i] == \"error data\":\n",
    "        answer = filtered_train[i]['solution']\n",
    "    else:\n",
    "        # insert_score[i][j] is the score of the step at index insert_pos[i][j];\n",
    "        # a set keeps the membership test below O(1).\n",
    "        delete_pos = {\n",
    "            pos for pos, score in zip(insert_pos[i], insert_score[i]) if score < threshold\n",
    "        }\n",
    "        kept_steps = [step for j, step in enumerate(m[i]) if j not in delete_pos]\n",
    "        # NOTE(review): process() split the solution on \"\\n\\n\", but the steps are\n",
    "        # rejoined with a single \"\\n\", unlike the \"error data\" branch; confirm intended.\n",
    "        answer = \"\\n\".join(kept_steps)\n",
    "    return {\n",
    "        \"messages\": [\n",
    "            {\"role\": \"system\", \"content\": sy},\n",
    "            {\"role\": \"user\", \"content\": query},\n",
    "            {\"role\": \"assistant\", \"content\": answer},\n",
    "        ]\n",
    "    }"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e7be8a11-2203-43c3-b929-a8cf56482b57",
   "metadata": {},
   "source": [
    "process_data = [process_prm(i) for i in tqdm(range(len(insert_score)))]\n",
    "output_path = 'numina-math-multi-fill-prm-0.1.json'\n",
    "# ensure_ascii=False writes raw non-ASCII text, so pin the file encoding to UTF-8\n",
    "# instead of relying on the platform default.\n",
    "with open(output_path, 'w', encoding='utf-8') as f:\n",
    "    json.dump(process_data, f, ensure_ascii=False, indent=4)\n",
    "\n",
    "print(len(process_data))\n",
    "print(f\"process_data has been saved to {output_path}.\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "4f1dbe79-72fa-4fb9-8a6f-cdf1323cca92",
   "metadata": {},
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
