{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['question', 'answer'],\n",
      "        num_rows: 7473\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['question', 'answer'],\n",
      "        num_rows: 1319\n",
      "    })\n",
      "})\n"
     ]
    },
    {
     "ename": "KeyError",
     "evalue": "'problem'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 24\u001b[0m\n\u001b[1;32m     21\u001b[0m     solution \u001b[38;5;241m=\u001b[39m entry[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124manswer\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m     22\u001b[0m     answer \u001b[38;5;241m=\u001b[39m extract_answer(solution)\n\u001b[1;32m     23\u001b[0m     new_entry \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m---> 24\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mproblem\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[43mentry\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mproblem\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m,\n\u001b[1;32m     25\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msolution\u001b[39m\u001b[38;5;124m\"\u001b[39m: solution,\n\u001b[1;32m     26\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124manswer\u001b[39m\u001b[38;5;124m\"\u001b[39m: answer,\n\u001b[1;32m     27\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifficulty\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1.0\u001b[39m,\n\u001b[1;32m     28\u001b[0m     }\n\u001b[1;32m     29\u001b[0m     dataset\u001b[38;5;241m.\u001b[39mappend(new_entry)\n\u001b[1;32m     31\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgsm8k.json\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n",
      "\u001b[0;31mKeyError\u001b[0m: 'problem'"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "from rllm.rewards.math_utils import extract_answer\n",
    "\n",
    "ds = load_dataset(\"openai/gsm8k\", \"main\")\n",
    "\n",
    "print(ds)\n",
    "\n",
    "\n",
    "def extract_difficulty(level_str):\n",
    "    # Extract the number from the string and convert it to a float\n",
    "    try:\n",
    "        difficulty = float(level_str.split()[-1])\n",
    "    except (ValueError, IndexError):\n",
    "        print(f\"Error extracting difficulty from {level_str}\")\n",
    "        difficulty = 0.0\n",
    "    return difficulty\n",
    "\n",
    "\n",
    "dataset = []\n",
    "for entry in ds[\"train\"]:\n",
    "    solution = entry[\"answer\"]\n",
    "    answer = extract_answer(solution)\n",
    "    new_entry = {\n",
    "        \"problem\": entry[\"question\"],\n",
    "        \"solution\": solution,\n",
    "        \"answer\": answer,\n",
    "        \"difficulty\": -1.0,\n",
    "    }\n",
    "    dataset.append(new_entry)\n",
    "\n",
    "with open(\"gsm8k.json\", \"w\") as f:\n",
    "    json.dump(dataset, f, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tsj",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
