{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_from_disk\n",
    "from datasets import load_dataset\n",
    "from tqdm.autonotebook import tqdm\n",
    "from pprint import pprint\n",
    "from src.pdl.optimize.parse_number import parse_number\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 183715.29 examples/s]\n",
      "Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 777771.26 examples/s]\n"
     ]
    }
   ],
   "source": [
    "gsm8k = load_dataset(\"openai/gsm8k\", \"main\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 7473/7473 [00:00<00:00, 32790.56 examples/s]\n",
      "Map: 100%|██████████| 1319/1319 [00:00<00:00, 46052.88 examples/s]\n"
     ]
    }
   ],
   "source": [
    "def parse_answers(row):\n",
    "    question = row[\"question\"].strip().replace(\"’\", \"'\").replace(\"  \", \" \")\n",
    "    parts = row[\"answer\"].split(\"####\")\n",
    "    answer = parse_number(parts[-1])\n",
    "    reasoning = \"####\".join(parts[:-1]).strip().replace(\"’\", \"'\").replace(\"  \", \" \")\n",
    "    return {\n",
    "        \"question\": question,\n",
    "        \"answer\": answer,\n",
    "        \"reasoning\": reasoning,\n",
    "        \"raw_answer\": row[\"answer\"],\n",
    "        \"answer_part\": parts[-1],\n",
    "    }\n",
    "\n",
    "\n",
    "gsm8k = gsm8k.map(parse_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 7473/7473 [00:00<00:00, 14437.73 examples/s]\n"
     ]
    }
   ],
   "source": [
    "def react_trajectory(row):\n",
    "    question = row[\"question\"]\n",
    "    answer = row[\"answer\"]\n",
    "    reasoning = row[\"reasoning\"].splitlines()\n",
    "    trajectory = [{\"question\": question.strip()}]\n",
    "    res = answer\n",
    "\n",
    "    for line in reasoning:\n",
    "        pattern = (\n",
    "            r\"(?P<pre>(=(\\ )?|equals(\\ )?)?(\\$)?)<<(?P<exp>.*?)=(?P<res>.*?)>>([^\\s]*)\"\n",
    "        )\n",
    "        expressions = re.search(pattern, line)\n",
    "\n",
    "        if expressions is None:\n",
    "            trajectory += [\n",
    "                {\"thought\": line.strip().replace(\"  \", \" \")},\n",
    "            ]\n",
    "        else:\n",
    "            thought = re.sub(pattern, \"\", line)\n",
    "            thought = thought.rstrip(\".\").rstrip(\",\")\n",
    "            exp = expressions.group(\"exp\").strip()\n",
    "            res = expressions.group(\"res\").strip()\n",
    "\n",
    "            trajectory += [\n",
    "                {\n",
    "                    \"thought\": f\"{thought.strip().replace('  ', ' ')}. I need to calculate {exp}\"\n",
    "                },\n",
    "                {\"action\": f\"Calculator[{exp}]\"},\n",
    "                {\"observation\": res},\n",
    "            ]\n",
    "    if next(iter(trajectory[-1].keys())) == \"observation\":\n",
    "        trajectory.append({\"thought\": f\"The answer is {answer}\"})\n",
    "\n",
    "    trajectory.append({\"action\": f\"Finish[{answer}]\"})\n",
    "\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\n",
    "        \"traj_keys\": traj_keys,\n",
    "        \"traj_values\": traj_values,\n",
    "    }\n",
    "\n",
    "\n",
    "gsm8k[\"train\"] = gsm8k[\"train\"].map(react_trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 7473/7473 [00:00<00:00, 11128.45 examples/s]\n"
     ]
    }
   ],
   "source": [
    "def rewoo_trajectory(row):\n",
    "    question = row[\"question\"]\n",
    "    answer = row[\"answer\"]\n",
    "    reasoning = row[\"reasoning\"].splitlines()\n",
    "    trajectory = [{\"question\": question.strip().replace(\"  \", \" \")}]\n",
    "    res = answer\n",
    "\n",
    "    for line in reasoning:\n",
    "        pattern = (\n",
    "            r\"(?P<pre>(=(\\ )?|equals(\\ )?)?(\\$)?)<<(?P<exp>.*?)=(?P<res>.*?)>>([^\\s]*)\"\n",
    "        )\n",
    "        expressions = re.search(pattern, line)\n",
    "\n",
    "        if expressions is None:\n",
    "            trajectory += [\n",
    "                {\"thought\": line.strip().replace(\"  \", \" \")},\n",
    "            ]\n",
    "        else:\n",
    "            thought = re.sub(pattern, \"\", line)\n",
    "            thought = thought.rstrip(\".\").rstrip(\",\")\n",
    "            exp = expressions.group(\"exp\").strip()\n",
    "            res = expressions.group(\"res\").strip()\n",
    "\n",
    "            trajectory += [\n",
    "                {\"thought\": f\"{thought.strip().replace('  ', ' ')}. Calculate {exp}\"},\n",
    "                {\"action\": f\"Calculator[{exp}]\"},\n",
    "                {\"observation\": res},\n",
    "            ]\n",
    "\n",
    "    evidence_counter = 0\n",
    "    for i in range(len(trajectory)):\n",
    "        outer = trajectory[i]\n",
    "        type_event = next(iter(outer.keys()))\n",
    "        value = next(iter(outer.values()))\n",
    "\n",
    "        if type_event == \"action\":\n",
    "            evidence_counter += 1\n",
    "        if type_event == \"observation\":\n",
    "            for j in range(i + 1, len(trajectory)):\n",
    "                inner = trajectory[j]\n",
    "                inner_type_event = next(iter(inner.keys()))\n",
    "                if inner_type_event == \"action\":\n",
    "                    trajectory[j][\"action\"] = trajectory[j][\"action\"].replace(\n",
    "                        value, f\"#E{evidence_counter}\"\n",
    "                    )\n",
    "                elif inner_type_event == \"thought\":\n",
    "                    trajectory[j][\"thought\"] = trajectory[j][\"thought\"].replace(\n",
    "                        value, f\"#E{evidence_counter}\"\n",
    "                    )\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\"rewoo_traj_keys\": traj_keys, \"rewoo_traj_values\": traj_values}\n",
    "\n",
    "\n",
    "gsm8k[\"train\"] = gsm8k[\"train\"].map(rewoo_trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 7473/7473 [00:00<00:00, 341762.17 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 1319/1319 [00:00<00:00, 329518.55 examples/s]\n"
     ]
    }
   ],
   "source": [
    "gsm8k.save_to_disk(\"var/gsm8k_proc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf var/gsm8k_proc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['question', 'answer', 'reasoning', 'raw_answer', 'answer_part', 'traj_keys', 'traj_values', 'rewoo_traj_keys', 'rewoo_traj_values'],\n",
       "        num_rows: 7473\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['question', 'answer', 'reasoning', 'raw_answer', 'answer_part'],\n",
       "        num_rows: 1319\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "load_from_disk(\"var/gsm8k_proc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sympy tool\n",
    "\n",
    "Example:\n",
    "Let x be the cost of the pencil.\n",
    "If the pen costs 2 times the cost of the pencil, then it costs 2x.\n",
    "Adding the cost of the pen and pencil we get 2x + x = 3x\n",
    "Since the total cost is $6 then 3x = $6 therefore x = $6 / 3 = $2\n",
    "One pen is equal to 2 * x which is 2 * $2 = $4\n",
    "\n",
    "Use symbolic calculator?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Question: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it take?\n",
    "Tho: To make a robe, you need 2 bolts of blue fiber. I need to write 2\n",
    "Act: Write[2]\n",
    "Obs: Invalid action. Valid actions are Calculator[<expression>] and Finish[<answer>].\n",
    "Tho: You also need half as many bolts of white fiber. I need to calculate 2 / 4\n",
    "Act: Calculator[2 / 4]\n",
    "Obs: 0.5\n",
    "Tho: Thus, you need 0.5 bolts of white fiber. I need to write 0.5\n",
    "Act: Write[0.5]\n",
    "Obs: Invalid action. Valid actions are Calculator[<expression>] and Finish[<answer>].\n",
    "Tho: The answer is 2.5\n",
    "Act: Finish[2.5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'answer': 250,\n",
      " 'answer_part': ' 250',\n",
      " 'question': 'The moon is made of 50% iron, 20% carbon, and the remainder is '\n",
      "             'other elements. Mars weighs twice as much as the moon, but has '\n",
      "             'the exact same composition. If Mars is 150 tons of other '\n",
      "             'elements, how many tons does the moon weigh?',\n",
      " 'raw_answer': '30% of Mars is made up of other elements because 100 - 50 - 20 '\n",
      "               '= <<100-50-20=30>>30\\n'\n",
      "               'Mars weighs 500 tons because 150 / .3 = <<150/.3=500>>500\\n'\n",
      "               'The moon weighs 250 tons because 500 / 2 = <<500/2=250>>250\\n'\n",
      "               '#### 250',\n",
      " 'reasoning': '30% of Mars is made up of other elements because 100 - 50 - 20 '\n",
      "              '= <<100-50-20=30>>30\\n'\n",
      "              'Mars weighs 500 tons because 150 / .3 = <<150/.3=500>>500\\n'\n",
      "              'The moon weighs 250 tons because 500 / 2 = <<500/2=250>>250',\n",
      " 'traj_keys': ['question',\n",
      "               'thought',\n",
      "               'action',\n",
      "               'observation',\n",
      "               'thought',\n",
      "               'action',\n",
      "               'observation',\n",
      "               'thought',\n",
      "               'action',\n",
      "               'observation',\n",
      "               'thought',\n",
      "               'action'],\n",
      " 'traj_values': ['The moon is made of 50% iron, 20% carbon, and the remainder '\n",
      "                 'is other elements. Mars weighs twice as much as the moon, '\n",
      "                 'but has the exact same composition. If Mars is 150 tons of '\n",
      "                 'other elements, how many tons does the moon weigh?',\n",
      "                 '30% of Mars is made up of other elements because 100 - 50 - '\n",
      "                 '20. I need to calculate 100-50-20',\n",
      "                 'Calculator[100-50-20]',\n",
      "                 '30',\n",
      "                 'Mars weighs 500 tons because 150 / .3. I need to calculate '\n",
      "                 '150/.3',\n",
      "                 'Calculator[150/.3]',\n",
      "                 '500',\n",
      "                 'The moon weighs 250 tons because 500 / 2. I need to '\n",
      "                 'calculate 500/2',\n",
      "                 'Calculator[500/2]',\n",
      "                 '250',\n",
      "                 'The answer is 250',\n",
      "                 'Finish[250]'],\n",
      " 'trajectory': [{'action': None,\n",
      "                 'observation': None,\n",
      "                 'question': 'The moon is made of 50% iron, 20% carbon, and '\n",
      "                             'the remainder is other elements. Mars weighs '\n",
      "                             'twice as much as the moon, but has the exact '\n",
      "                             'same composition. If Mars is 150 tons of other '\n",
      "                             'elements, how many tons does the moon weigh?',\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': '30% of Mars is made up of other elements because '\n",
      "                            '100 - 50 - 20. I need to calculate 100-50-20'},\n",
      "                {'action': 'Calculator[100-50-20]',\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': '30',\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': 'Mars weighs 500 tons because 150 / .3. I need to '\n",
      "                            'calculate 150/.3'},\n",
      "                {'action': 'Calculator[150/.3]',\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': '500',\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': 'The moon weighs 250 tons because 500 / 2. I need '\n",
      "                            'to calculate 500/2'},\n",
      "                {'action': 'Calculator[500/2]',\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': '250',\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': 'The answer is 250'},\n",
      "                {'action': 'Finish[250]',\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': None}]}\n",
      "{'answer': 96,\n",
      " 'answer_part': ' 96',\n",
      " 'question': 'The moon has a surface area that is 1/5 that of Earth. The '\n",
      "             'surface area of the Earth is 200 square acres. The land on the '\n",
      "             'moon is worth 6 times that of the land on the Earth. If the '\n",
      "             'total value of all the land on the earth is 80 billion dollars, '\n",
      "             'what is the total value in billions of all the land on the moon?',\n",
      " 'raw_answer': 'If the moon land had the same value as earth land it would be '\n",
      "               'worth 16 billion because 80 / 5 = <<80/5=16>>16\\n'\n",
      "               \"The moon's total land value is 96 billion because 16 x 6 = \"\n",
      "               '<<16*6=96>>96\\n'\n",
      "               '#### 96',\n",
      " 'reasoning': 'If the moon land had the same value as earth land it would be '\n",
      "              'worth 16 billion because 80 / 5 = <<80/5=16>>16\\n'\n",
      "              \"The moon's total land value is 96 billion because 16 x 6 = \"\n",
      "              '<<16*6=96>>96',\n",
      " 'traj_keys': ['question',\n",
      "               'thought',\n",
      "               'action',\n",
      "               'observation',\n",
      "               'thought',\n",
      "               'action',\n",
      "               'observation',\n",
      "               'thought',\n",
      "               'action'],\n",
      " 'traj_values': ['The moon has a surface area that is 1/5 that of Earth. The '\n",
      "                 'surface area of the Earth is 200 square acres. The land on '\n",
      "                 'the moon is worth 6 times that of the land on the Earth. If '\n",
      "                 'the total value of all the land on the earth is 80 billion '\n",
      "                 'dollars, what is the total value in billions of all the land '\n",
      "                 'on the moon?',\n",
      "                 'If the moon land had the same value as earth land it would '\n",
      "                 'be worth 16 billion because 80 / 5. I need to calculate 80/5',\n",
      "                 'Calculator[80/5]',\n",
      "                 '16',\n",
      "                 \"The moon's total land value is 96 billion because 16 x 6. I \"\n",
      "                 'need to calculate 16*6',\n",
      "                 'Calculator[16*6]',\n",
      "                 '96',\n",
      "                 'The answer is 96',\n",
      "                 'Finish[96]'],\n",
      " 'trajectory': [{'action': None,\n",
      "                 'observation': None,\n",
      "                 'question': 'The moon has a surface area that is 1/5 that of '\n",
      "                             'Earth. The surface area of the Earth is 200 '\n",
      "                             'square acres. The land on the moon is worth 6 '\n",
      "                             'times that of the land on the Earth. If the '\n",
      "                             'total value of all the land on the earth is 80 '\n",
      "                             'billion dollars, what is the total value in '\n",
      "                             'billions of all the land on the moon?',\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': 'If the moon land had the same value as earth land '\n",
      "                            'it would be worth 16 billion because 80 / 5. I '\n",
      "                            'need to calculate 80/5'},\n",
      "                {'action': 'Calculator[80/5]',\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': '16',\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': \"The moon's total land value is 96 billion because \"\n",
      "                            '16 x 6. I need to calculate 16*6'},\n",
      "                {'action': 'Calculator[16*6]',\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': '96',\n",
      "                 'question': None,\n",
      "                 'thought': None},\n",
      "                {'action': None,\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': 'The answer is 96'},\n",
      "                {'action': 'Finish[96]',\n",
      "                 'observation': None,\n",
      "                 'question': None,\n",
      "                 'thought': None}]}\n"
     ]
    }
   ],
   "source": [
    "for x in gsm8kk[\"train\"]:\n",
    "    if \"The moon\" in x[\"question\"]:\n",
    "        pprint(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# trajectory bootstrapping\n",
    "\n",
    "manual conversion of ~5 examples\n",
    "\n",
    "Cot question/reasoning/answer\n",
    "ReAct: question/thoughts/observations/answer\n",
    "\n",
    "Cot question/reasoning/answer query\n",
    "get React trajectory, if answer matches groundtruth, use, otherwise resample\n",
    "- can add more examples to prompt to improve results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb01e912c2de4ff9a089d21a858821e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/6449 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5993a5de0754602b23cf61a8a63d83f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1319 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "800af8ae51ef43819e2ec08200d779e5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1024 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "gsm8k_orig = load_dataset(\"openai/gsm8k\", \"main\")\n",
    "new_split = gsm8k_orig[\"train\"].train_test_split(test_size=1024)\n",
    "gsm8k_orig[\"validation\"] = new_split[\"test\"]\n",
    "gsm8k_orig[\"train\"] = new_split[\"train\"]\n",
    "gsm8k_orig.save_to_disk(\"var/gsm8k_split\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "gsm8k = load_from_disk(\"var/gsm8k_split\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ed5204e69614c37a994232992e72566",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/6449 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b9c110adc2ec489d8e13624bc33a78d8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/1319 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f5370331ff04096acc52d987fa63100",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/1024 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def parse_answers(row):\n",
    "    question = row[\"question\"].strip().replace(\"’\", \"'\").replace(\"  \", \" \")\n",
    "    parts = row[\"answer\"].split(\"####\")\n",
    "    answer = parse_number(parts[-1])\n",
    "    reasoning = \"####\".join(parts[:-1]).strip().replace(\"’\", \"'\").replace(\"  \", \" \")\n",
    "    return {\n",
    "        \"question\": question,\n",
    "        \"answer\": answer,\n",
    "        \"reasoning\": reasoning,\n",
    "        \"raw_answer\": row[\"answer\"],\n",
    "        \"answer_part\": parts[-1],\n",
    "    }\n",
    "\n",
    "\n",
    "gsm8k = gsm8k.map(parse_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28b6b6d20f1d474c90d5375403b1673c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/6449 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def react_trajectory(row):\n",
    "    question = row[\"question\"]\n",
    "    answer = row[\"answer\"]\n",
    "    reasoning = row[\"reasoning\"].splitlines()\n",
    "    trajectory = [{\"question\": question.strip()}]\n",
    "    res = answer\n",
    "\n",
    "    for line in reasoning:\n",
    "        pattern = (\n",
    "            r\"(?P<pre>(=(\\ )?|equals(\\ )?)?(\\$)?)<<(?P<exp>.*?)=(?P<res>.*?)>>([^\\s]*)\"\n",
    "        )\n",
    "        expressions = re.search(pattern, line)\n",
    "\n",
    "        if expressions is None:\n",
    "            trajectory += [\n",
    "                {\"thought\": line.strip().replace(\"  \", \" \")},\n",
    "            ]\n",
    "        else:\n",
    "            thought = re.sub(pattern, \"\", line)\n",
    "            thought = thought.rstrip(\".\").rstrip(\",\")\n",
    "            exp = expressions.group(\"exp\").strip()\n",
    "            res = expressions.group(\"res\").strip()\n",
    "\n",
    "            trajectory += [\n",
    "                {\n",
    "                    \"thought\": f\"{thought.strip().replace('  ', ' ')}. I need to calculate {exp}\"\n",
    "                },\n",
    "                {\"action\": '{\"name\": \"Calculator\", \"arguments\": {\"expr\": \"' + f\"{exp}\" +'\"}}'}, #Calculator[{exp}]\"},\n",
    "                {\"observation\": res},\n",
    "            ]\n",
    "    if next(iter(trajectory[-1].keys())) == \"observation\":\n",
    "        trajectory.append({\"thought\": f\"The answer is {answer}\"})\n",
    "\n",
    "    trajectory.append({\"action\":\n",
    "                       '{\"name\": \"Finish\", \"arguments\": {\"topic\": \"' + f\"{answer}\" + '\"}}'\n",
    "                       })\n",
    "                       #f\"Finish[{answer}]\"\n",
    "\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\n",
    "        \"traj_keys\": traj_keys,\n",
    "        \"traj_values\": traj_values,\n",
    "    }\n",
    "\n",
    "\n",
    "gsm8k[\"train\"] = gsm8k[\"train\"].map(react_trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4f71291b5df74749873f54830f085eb1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/6449 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def rewoo_trajectory(row):\n",
    "    question = row[\"question\"]\n",
    "    answer = row[\"answer\"]\n",
    "    reasoning = row[\"reasoning\"].splitlines()\n",
    "    trajectory = [{\"question\": question.strip().replace(\"  \", \" \")}]\n",
    "    res = answer\n",
    "\n",
    "    for line in reasoning:\n",
    "        pattern = (\n",
    "            r\"(?P<pre>(=(\\ )?|equals(\\ )?)?(\\$)?)<<(?P<exp>.*?)=(?P<res>.*?)>>([^\\s]*)\"\n",
    "        )\n",
    "        expressions = re.search(pattern, line)\n",
    "\n",
    "        if expressions is None:\n",
    "            trajectory += [\n",
    "                {\"thought\": line.strip().replace(\"  \", \" \")},\n",
    "            ]\n",
    "        else:\n",
    "            thought = re.sub(pattern, \"\", line)\n",
    "            thought = thought.rstrip(\".\").rstrip(\",\")\n",
    "            exp = expressions.group(\"exp\").strip()\n",
    "            res = expressions.group(\"res\").strip()\n",
    "\n",
    "            trajectory += [\n",
    "                {\"thought\": f\"{thought.strip().replace('  ', ' ')}. Calculate {exp}\"},\n",
    "                {\"action\": '{\"name\": \"Calculator\", \"arguments\": {\"expr\": \"' + f\"{exp}\" +'\"}}'},\n",
    "                {\"observation\": res},\n",
    "            ]\n",
    "\n",
    "    evidence_counter = 0\n",
    "    for i in range(len(trajectory)):\n",
    "        outer = trajectory[i]\n",
    "        type_event = next(iter(outer.keys()))\n",
    "        value = next(iter(outer.values()))\n",
    "\n",
    "        if type_event == \"action\":\n",
    "            evidence_counter += 1\n",
    "        if type_event == \"observation\":\n",
    "            for j in range(i + 1, len(trajectory)):\n",
    "                inner = trajectory[j]\n",
    "                inner_type_event = next(iter(inner.keys()))\n",
    "                if inner_type_event == \"action\":\n",
    "                    trajectory[j][\"action\"] = trajectory[j][\"action\"].replace(\n",
    "                        value, f\"#E{evidence_counter}\"\n",
    "                    )\n",
    "                elif inner_type_event == \"thought\":\n",
    "                    trajectory[j][\"thought\"] = trajectory[j][\"thought\"].replace(\n",
    "                        value, f\"#E{evidence_counter}\"\n",
    "                    )\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\"rewoo_traj_keys\": traj_keys, \"rewoo_traj_values\": traj_values}\n",
    "\n",
    "\n",
    "gsm8k[\"train\"] = gsm8k[\"train\"].map(rewoo_trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0407259bc3d54b1f980d5b973e962cc9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/6449 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "776484fa190743e8ae87783dcbc92c63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1319 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d501e914bc54131bc0f5b6da88e7f51",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1024 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "gsm8k.save_to_disk(\"var/gsm8k_proc_json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['question', 'answer', 'reasoning', 'raw_answer', 'answer_part', 'traj_keys', 'traj_values', 'rewoo_traj_keys', 'rewoo_traj_values'],\n",
       "        num_rows: 6449\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['question', 'answer', 'reasoning', 'raw_answer', 'answer_part'],\n",
       "        num_rows: 1319\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['question', 'answer', 'reasoning', 'raw_answer', 'answer_part'],\n",
       "        num_rows: 1024\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_from_disk\n",
    "ds = load_from_disk(\"var/gsm8k_proc_json\")\n",
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "notebook",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
