{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pdl.optimize.mbpp_dataset import MBPPDataset\n",
    "\n",
    "ds = MBPPDataset()\n",
    "for k,v in ds.items():\n",
    "    print(k, len(v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_from_disk\n",
    "from datasets import load_dataset, concatenate_datasets\n",
    "from tqdm.autonotebook import tqdm\n",
    "from pprint import pprint\n",
    "import re\n",
    "from evalplus.data import get_mbpp_plus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp_plus = get_mbpp_plus()\n",
    "\n",
    "mbpp = load_dataset(\"google-research-datasets/mbpp\", name=\"full\")\n",
    "\n",
    "# mbpp_out = mbpp.filter(\n",
    "#     lambda x: f\"Mbpp/{x['task_id']}\" not in mbpp_plus,\n",
    "# )\n",
    "\n",
    "# mbpp_in = mbpp.filter(\n",
    "#     lambda x: f\"Mbpp/{x['task_id']}\" in mbpp_plus,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 374\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 500\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 90\n",
       "    })\n",
       "    prompt: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 10\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mbpp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "378"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(mbpp_plus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list', 'react_prompt', 'traj_keys', 'traj_values'],\n",
       "        num_rows: 374\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list', 'react_prompt', 'traj_keys', 'traj_values'],\n",
       "        num_rows: 224\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list', 'react_prompt', 'traj_keys', 'traj_values'],\n",
       "        num_rows: 39\n",
       "    })\n",
       "    prompt: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list', 'react_prompt', 'traj_keys', 'traj_values'],\n",
       "        num_rows: 10\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "load_from_disk(\"var/mbpp_trajectified\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 266\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 276\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 51\n",
       "    })\n",
       "    prompt: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 3\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mbpp_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 108\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 224\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 39\n",
       "    })\n",
       "    prompt: Dataset({\n",
       "        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],\n",
       "        num_rows: 7\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mbpp_in"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'task_id': 602,\n",
       " 'text': 'Write a python function to find the first repeated character in a given string.',\n",
       " 'code': 'def first_repeated_char(str1):\\r\\n  for index,c in enumerate(str1):\\r\\n    if str1[:index+1].count(c) > 1:\\r\\n      return c \\r\\n  return \"None\"',\n",
       " 'test_list': ['assert first_repeated_char(\"abcabc\") == \"a\"',\n",
       "  'assert first_repeated_char(\"abc\") == \"None\"',\n",
       "  'assert first_repeated_char(\"123123\") == \"1\"'],\n",
       " 'test_setup_code': '',\n",
       " 'challenge_test_list': []}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mbpp_in[\"train\"][0]#[\"task_id\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'task_id': 'Mbpp/602',\n",
       " 'prompt': '\"\"\"\\nWrite a python function to find the first repeated character in a given string.\\nassert first_repeated_char(\"abcabc\") == \"a\"\\n\"\"\"\\n',\n",
       " 'entry_point': 'first_repeated_char',\n",
       " 'canonical_solution': '\\ndef first_repeated_char(str1):\\n  for index, c in enumerate(str1):\\n    if str1[:index + 1].count(c) > 1:\\n      return c\\n  return None\\n',\n",
       " 'base_input': [['abcabc'], ['abc'], ['123123']],\n",
       " 'atol': 0,\n",
       " 'plus_input': [[''],\n",
       "  ['abcdefghijklmnopqrstuvwxyz'],\n",
       "  ['abcabcxyz'],\n",
       "  ['ABCaBC'],\n",
       "  ['a'],\n",
       "  ['aaa'],\n",
       "  ['aaaabcabcxyz'],\n",
       "  ['aaaaaaabcabcxyzz'],\n",
       "  ['abcdefghinjklmnopqrstuvwxyz'],\n",
       "  ['aaaa'],\n",
       "  ['aaaaaaabcabcxyzzaaaa'],\n",
       "  ['aaaaa'],\n",
       "  ['aa'],\n",
       "  ['aaaaaaaabcabcxyzzaaaa'],\n",
       "  ['aaaaaaaabcabcxyzzaaaaABCaaBC'],\n",
       "  ['aaaaaaaabcabaaaaaaabcabcxyzzcxyzzaaaa'],\n",
       "  ['aaaabcabcabcdefghinjklmnopqrstuvwxyzxyz'],\n",
       "  ['aaaaaaabcabcxyzzaaaaa'],\n",
       "  ['aaaabcabcabcdABCaBCeafghinjklmnopqrstuvwxyzxyz'],\n",
       "  ['abcdefgxhinjklmnopqrstuvwxyz'],\n",
       "  ['aaaaaaabcabaaacxyzzaaaa'],\n",
       "  ['aaaaaaabcabacxyzzaaaa'],\n",
       "  ['aaaaaaaaa'],\n",
       "  ['aaaabcabcabcdABCaBaCeafghinjklmnopqrstuvwxyzxyz'],\n",
       "  ['aaaaaaabaaaaaaaa'],\n",
       "  ['abccdefghinjzklabcdefghijklmnopqrstuvwxyzmnopxyz'],\n",
       "  ['aaaaaaabcabcABCaBCxyzzaaaa'],\n",
       "  ['aaaaaaabcabcABCaBCzaaaa'],\n",
       "  ['aaaaazcxyzzaaaa'],\n",
       "  ['abcdefgxaahinjklmnopqrsaaaaaaabcabaaacxyzzaaaatuvwxyz'],\n",
       "  ['ababcdefghinjklmnopqrstuvwxyzcdefghinjklmnopqrstuvwxyz'],\n",
       "  ['abcdefghijklmnopqrstuvwvz'],\n",
       "  ['abcdefgxhzinjklmnopqrstuvwxyz'],\n",
       "  ['aaaaaabcabcxyz'],\n",
       "  ['abcdefgxaahinjklmnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyz'],\n",
       "  ['aaaaaaaabaaaaaaaa'],\n",
       "  ['aaaabcabcABCaBCabcdABCaBCeafghinjklmnopqrstuvwxyzxyz'],\n",
       "  ['abcdefgxaahinjklmcnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyz'],\n",
       "  ['aaaaaaaaaa'],\n",
       "  ['aaaaaaabcabcxyza'],\n",
       "  ['MJ'],\n",
       "  ['abcdefghijkvlmnopqrstuvwvz'],\n",
       "  ['aaaaaaabcabcdefghijklmnopqrstuvwxyzabcABCaBCzaaaa'],\n",
       "  ['abcxyza'],\n",
       "  ['aabcdefghijklmnopqrstuvwvz'],\n",
       "  ['aaaaaa'],\n",
       "  ['aaaaaaaaaaaaaaaabcabcxyzza'],\n",
       "  ['aaaabcxyzzaaaa'],\n",
       "  ['aabcdefgxaahinjklmcnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyzaaaaa'],\n",
       "  ['JF'],\n",
       "  ['abcdefgxaahinjklmnopaaaaaabcabcxyznqrsaaaaaaabcabaaacxyzzaaaatuvwxyz'],\n",
       "  ['aabcdefgxaahinjklmcnopaaaaaaabcdefgxhzinjklmnopqrstuvwxyzbcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyzaaaaa'],\n",
       "  ['aaaaaaaaaaaaabcabcxyz'],\n",
       "  ['aaaaaaaaabcxyzabcabcxyzzaaaaABCaaBC'],\n",
       "  ['aaaabaaabcabcxyzz'],\n",
       "  ['abcdefgxaahinjklmnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaaaaaabcabacxaaaatuvwxyz'],\n",
       "  ['abcdefghiuvwvz'],\n",
       "  ['aaaaaaabcabcdefghijklmnaaaaaaabcabcABCaBCzaaaaopqrstuvwxyzabcABCaBCzaaaa'],\n",
       "  ['abcdefgxhlmnopqrstuvwxyz'],\n",
       "  ['aaaaaaabcabcxaaaaaaabcabcxyzzaaaaayzzaaaabmcabcabcdefghinjklmnopqrstuvwxyzxyzaaaaaaaaaaaaaaaabcabcxyzza'],\n",
       "  ['abcdefgxaahinjklmnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzABCaBCzaaaaaaaabcazbacxaaaatuvwxyz'],\n",
       "  ['aaaaaaabcaaaaaaabcabcABCaBCzaaaazzaaaaa'],\n",
       "  ['aaaaaaaaabcabcxyzzaaaa'],\n",
       "  ['JJF'],\n",
       "  ['abcaaaaaaabcabcxyzzaaaaaabcxyz'],\n",
       "  ['aabcdefgxaahinjklmcnopaaaaaaabcdefgxhzinjklmnopqrstuvwxyzbcabcxyzqrsaaaaaaabcabaaacxyzzaaabcdefgxaahinjklmnopqrsaaaaaaabcabaaacxyzzaaaatuvwxyzxyzaaaaa'],\n",
       "  ['aabcdefgxaahinjklmcnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyzaaeaaa'],\n",
       "  ['aaaaaaaaaaaaaabcdefgxaahinjklmnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaaaaaabcabacxaaaatuvwxyzbcabcxyz'],\n",
       "  ['aaaaaaa'],\n",
       "  ['abcdefgxaahabcdefgxaahinjklmnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyzinjklmnopaaaaaabcabcxyznqrsaaaaaaabcabaaacxyzzaaaatuvwxyz'],\n",
       "  ['aaaaaaaaJJFaa'],\n",
       "  ['aaaaaaabcabcxyzaaa'],\n",
       "  ['aaabcdefghijklmnopqrstuvwxyzaabcabcxyz'],\n",
       "  ['aacaaaaaaabcxyzabcabcxyzzaaaaABCaaBC'],\n",
       "  ['aaaaaaabcabcxyzaaaaaaa'],\n",
       "  ['abcxayza'],\n",
       "  ['aabcdefgxaahinjklmmcnopaaaaaaabcdefgxhzinjklmnopqrstuvwxyzbcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyzaaaaa'],\n",
       "  ['Mlszaaabcdefghijklmnopqrstuvwxyzaabcabcxyzvopvu'],\n",
       "  ['MMJ'],\n",
       "  ['aaaaaaabcabcdefghiyzabcABCaBCzaaaa'],\n",
       "  ['aabcaaaaaaabacabcxyzzaaaaaabcxyz'],\n",
       "  ['aaaabcabcabcdefghjklmnopqrstuvwxyzxyz'],\n",
       "  ['aabcabcxyz'],\n",
       "  ['aaaaklmnopqrstuvwxyzxyz'],\n",
       "  ['aaaaaaaabcabcxyzz'],\n",
       "  ['aabcdefgxaahinjklmcnopaaaaaaabcdefgxhzinjklmnopqraaaaaaabcabcxaaaaaaabcabcxyzzaaaaayzzaaaabmcabcabcdefghinaaaaaaabcabcxyzajklmnopqrstuvwxyzxyzaaaaaaaabcdefgxaahinjklmcnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyzaaaaa'],\n",
       "  ['MMJaaaaaaabcabcxyzaaa'],\n",
       "  ['abcaaaaaaabcabcxyzzaaaaaabcxbyz'],\n",
       "  ['aaaabcabcablmnopqrstuvwxyzxyz'],\n",
       "  ['abcdefgxaahinjklmnopaaaaaabcabcxyzqrsaaabcdefgxaahinjklmcnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyzxyzABCaBCzaaaaaavwxyz'],\n",
       "  ['abcdefgxaahinjklmnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaaaaaabcabacxaaaatuvwzxyz'],\n",
       "  ['aaaaaaabaaaaxyzaaa'],\n",
       "  ['aabcdefgxaahinjklmcnopaaaaaabcabcxyzqrsaaababcdefghinjklmnopqrstuvwxyzcdefghinjklmnopqrstuvwxyzaaacxyzzaaaatuvwxyzaaaaa'],\n",
       "  ['JJJFaaaabcabcabcdefghjklmnopqrstuvwxyzxyz'],\n",
       "  ['BCaaaabcabcABCaBCabcdABCaBCeafghinjklmnopqrstuvwxyzxyzaBC'],\n",
       "  ['abcdefgxaahabcdefgxaahinjklmnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaatuvwxyabcdefghinjklmnopqrstuvwxyzzinjklmnopaaaaaabcabcxyznqrsaaaaaaabcabaaacxyzzaaaatuvwxyz'],\n",
       "  ['aacaaaaaabcdefghijklmnopqrstuvwvzaaaaABCaaBC'],\n",
       "  ['aaaaaaabcaMMJaaaaaaabcabcxyzaaabcABCaBCxyzzaaa'],\n",
       "  ['aaaabcabcabcdABCaBCeafghinjklmnopqrstuvwxyzaaaaaaaaJJFaaxyz'],\n",
       "  ['aaaabcabcabcdABCaBaCeafghinjklmnaaaaaaabcabcdefghiyzabcABCaBCzaaaa'],\n",
       "  ['aaaacaaaaaaabcxyzabcabcxyzzaaaaABCaaBCaaaaabcabacxyzzaaaa'],\n",
       "  ['abcdefgxaahinjklmnopaaaaaabcabcxyzqrsaaabcdefgxaahinjklmcnopaaaaaabcabcxyzqrsaaaaaaabcabaaacxyzzaaaatugvwxyzxyzABCaBCzaaaaaavwxyz'],\n",
       "  ['MaaaaaaabaaaaxyzaaaMJ'],\n",
       "  ['aaaaaaabcabcdefghijklamnaaaaaaabcabcABCaBCzaaaaopqrstuvwxyzabcABCaBCzaaaa'],\n",
       "  ['abcdefgxhlmnopqrstutvwxyz']],\n",
       " 'contract': '\\n  assert isinstance(str1, str), \"invalid inputs\" # $_CONTRACT_$I\\n',\n",
       " 'assertion': '\\nassert first_repeated_char(\"abcabc\") == \"a\"\\nassert first_repeated_char(\"abc\") == None\\nassert first_repeated_char(\"123123\") == \"1\"\\n'}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mbpp_plus[\"Mbpp/602\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# mbpp[\"train\"] = mbpp[\"train\"].filter(\n",
    "#     lambda x: f\"Mbpp/{x['task_id']}\" not in mbpp_plus,\n",
    "# )\n",
    "# # .rename_columns(\n",
    "# #             {\"code\": \"canonical_solution\"}\n",
    "# #         )\n",
    "\n",
    "# mbpp[\"test\"] = mbpp[\"test\"].filter(\n",
    "#     lambda x: f\"Mbpp/{x['task_id']}\" in mbpp_plus,\n",
    "# )\n",
    "\n",
    "# mbpp[\"validation\"] = mbpp[\"validation\"].filter(\n",
    "#     lambda x: f\"Mbpp/{x['task_id']}\" not in mbpp_plus,\n",
    "# )\n",
    "\n",
    "\n",
    "def trajectify(row):\n",
    "    # - action: |-\n",
    "    #         def similar_elements(test_tup1, test_tup2):\n",
    "    #           res = tuple(set(test_tup1) & set(test_tup2))\n",
    "    #           return res\n",
    "    #         res = similar_elements((3, 4, 5, 6), (5, 7, 4, 10))\n",
    "    #         assert res == (4, 5), \"Expected (4, 5) but got {}\".format(res)\n",
    "    #     - observation: \"[Executed Successfully with No Output]\"\n",
    "    #     - thought: There is no more AssertionError. I can now submit the solution.\n",
    "    #     - solution: |-\n",
    "    #         def similar_elements(test_tup1, test_tup2):\n",
    "    #           res = tuple(set(test_tup1) & set(test_tup2))\n",
    "    #           return res\n",
    "    # Regex pattern to match the assert statement and capture the function call and expected result\n",
    "    task_id = f\"Mbpp/{row['task_id']}\"\n",
    "    code = row[\"code\"].replace(\"\\r\\n\", \"\\n\").replace(\"\\r\", \"\\n\").strip()\n",
    "    first_test = row[\"test_list\"][0].strip().lstrip()\n",
    "    pattern = r\"assert\\s+(\\w+\\(.*?\\))\\s*==\\s*(.+)\"\n",
    "\n",
    "    # Replacement format\n",
    "    replacement = r\"res = \\1\\nassert res == \\2, \\\"Expected \\2 but got {}\\\".format(res)\"\n",
    "\n",
    "    # Perform the substitution\n",
    "    converted_string = (\n",
    "        re.sub(pattern, replacement, first_test)\n",
    "        .replace('\\\\\"Expected ', '\"Expected ')\n",
    "        .replace('{}\\\\\"', '{}\"')\n",
    "    )\n",
    "    code_w_assert = code + \"\\n\" + converted_string.strip()\n",
    "    prompt = row[\"text\"].strip() + \"\\n\" + first_test\n",
    "    # (\n",
    "    #     mbpp_plus[task_id][\"prompt\"]\n",
    "    #     .strip('\"\"\"')\n",
    "    #     .strip()\n",
    "    #     .strip('\"\"\"')\n",
    "    #     .strip()\n",
    "    #     .replace(\"\\n\\nassert\", \"\\nassert\")\n",
    "    # )  # row[\"text\"].strip() + \"\\n\" + first_test\n",
    "    # print(code_w_assert)\n",
    "    # print(\"-----\")\n",
    "    trajectory = [\n",
    "        {\"task\": prompt},\n",
    "        {\n",
    "            \"thought\": \"I should run a solution on the test case before proposing a solution.\"\n",
    "        },\n",
    "        {\"action\": code_w_assert},\n",
    "        {\"observation\": \"[Executed Successfully with No Output]\"},\n",
    "        {\"thought\": \"There is no AssertionError. I can now submit the solution.\"},\n",
    "        {\"solution\": code},\n",
    "    ]\n",
    "\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\n",
    "        # \"prompt\": row[].strip('\"\"\"').strip().strip('\"\"\"').strip()\n",
    "        \"react_prompt\": prompt,\n",
    "        \"code\": code,\n",
    "        \"traj_keys\": traj_keys,\n",
    "        \"traj_values\": traj_values,\n",
    "    }\n",
    "\n",
    "\n",
    "mbpp_trajectified = mbpp.map(trajectify)\n",
    "# mbpp_trajectified.save_to_disk(\"var/mbpp_trajectified\")\n",
    "\n",
    "train_concat = concatenate_datasets(\n",
    "    mbpp.filter(\n",
    "        lambda x: f\"Mbpp/{x['task_id']}\" not in mbpp_plus,\n",
    "    )\n",
    "    # .rename_columns({\"code\": \"canonical_solution\", \"text\": \"prompt\"})\n",
    "    .values()\n",
    ")\n",
    "\n",
    "test_concat = concatenate_datasets(\n",
    "    mbpp.filter(\n",
    "        lambda x: f\"Mbpp/{x['task_id']}\" in mbpp_plus,\n",
    "    )\n",
    "    # .rename_columns({\"code\": \"canonical_solution\", \"text\": \"prompt\"})\n",
    "    .values()\n",
    ")\n",
    "\n",
    "# filt = mbpp.filter(\n",
    "#     lambda x: f\"Mbpp/{x['task_id']}\" in mbpp_plus,\n",
    "# )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_from_disk\n",
    "from datasets import load_dataset, concatenate_datasets\n",
    "from tqdm.autonotebook import tqdm\n",
    "from pprint import pprint\n",
    "import re\n",
    "from evalplus.data import get_mbpp_plus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp_plus = get_mbpp_plus()\n",
    "mbpp_plus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"\\n\".join(mbpp_plus.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['task_id', 'prompt', 'entry_point', 'canonical_solution', 'base_input', 'atol', 'plus_input', 'contract', 'assertion'])\n"
     ]
    }
   ],
   "source": [
    "print(mbpp_plus[\"Mbpp/603\"].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Write a function to get all lucid numbers smaller than or equal to a given integer.\n",
      "assert get_ludic(10) == [1, 2, 3, 5, 7]\n"
     ]
    }
   ],
   "source": [
    "print(mbpp_plus[\"Mbpp/603\"][\"prompt\"].strip('\"\"\"')\n",
    "        .strip()\n",
    "        .strip('\"\"\"')\n",
    "        .strip()\n",
    "        .replace(\"\\n\\nassert\", \"\\nassert\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "import string\n",
      "\n",
      "def is_pangram(sentence):\n",
      "    alphabet = string.ascii_lowercase\n",
      "    sentence_lower = sentence.lower()\n",
      "    for char in alphabet:\n",
      "        if char not in sentence_lower:\n",
      "            return False\n",
      "    return True\n"
     ]
    }
   ],
   "source": [
    "answer = \"\"\"<solution>\n",
    "\n",
    "```\n",
    "import string\n",
    "\n",
    "def is_pangram(sentence):\n",
    "    alphabet = string.ascii_lowercase\n",
    "    sentence_lower = sentence.lower()\n",
    "    for char in alphabet:\n",
    "        if char not in sentence_lower:\n",
    "            return False\n",
    "    return True\n",
    "```\"\"\"\n",
    "pattern = r\"```(?:python)?\\n(.*?)\\n```\"\n",
    "match = re.search(pattern, answer, re.DOTALL)\n",
    "if match:\n",
    "    print(match.group(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'task_id': 603,\n",
       " 'text': 'Write a function to get a lucid number smaller than or equal to n.',\n",
       " 'code': 'def get_ludic(n):\\r\\n\\tludics = []\\r\\n\\tfor i in range(1, n + 1):\\r\\n\\t\\tludics.append(i)\\r\\n\\tindex = 1\\r\\n\\twhile(index != len(ludics)):\\r\\n\\t\\tfirst_ludic = ludics[index]\\r\\n\\t\\tremove_index = index + first_ludic\\r\\n\\t\\twhile(remove_index < len(ludics)):\\r\\n\\t\\t\\tludics.remove(ludics[remove_index])\\r\\n\\t\\t\\tremove_index = remove_index + first_ludic - 1\\r\\n\\t\\tindex += 1\\r\\n\\treturn ludics',\n",
       " 'test_list': ['assert get_ludic(10) == [1, 2, 3, 5, 7]',\n",
       "  'assert get_ludic(25) == [1, 2, 3, 5, 7, 11, 13, 17, 23, 25]',\n",
       "  'assert get_ludic(45) == [1, 2, 3, 5, 7, 11, 13, 17, 23, 25, 29, 37, 41, 43]'],\n",
       " 'test_setup_code': '',\n",
       " 'challenge_test_list': []}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mbpp[\"train\"][2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp = load_dataset(\"google-research-datasets/mbpp\", name=\"full\")\n",
    "# .filter(\n",
    "#     lambda x: f\"Mbpp/{x['task_id']}\" in mbpp_plus,\n",
    "# )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def trajectify(row):\n",
    "    # - action: |-\n",
    "    #         def similar_elements(test_tup1, test_tup2):\n",
    "    #           res = tuple(set(test_tup1) & set(test_tup2))\n",
    "    #           return res\n",
    "    #         res = similar_elements((3, 4, 5, 6), (5, 7, 4, 10))\n",
    "    #         assert res == (4, 5), \"Expected (4, 5) but got {}\".format(res)\n",
    "    #     - observation: \"[Executed Successfully with No Output]\"\n",
    "    #     - thought: There is no more AssertionError. I can now submit the solution.\n",
    "    #     - solution: |-\n",
    "    #         def similar_elements(test_tup1, test_tup2):\n",
    "    #           res = tuple(set(test_tup1) & set(test_tup2))\n",
    "    #           return res\n",
    "    # Regex pattern to match the assert statement and capture the function call and expected result\n",
    "    task_id = f\"Mbpp/{row['task_id']}\"\n",
    "    code = row[\"code\"].replace(\"\\r\\n\", \"\\n\").replace(\"\\r\", \"\\n\").strip()\n",
    "    first_test = row[\"test_list\"][0].strip().lstrip()\n",
    "    pattern = r\"assert\\s+(\\w+\\(.*?\\))\\s*==\\s*(.+)\"\n",
    "\n",
    "    # Replacement format\n",
    "    replacement = r\"res = \\1\\nassert res == \\2, \\\"Expected \\2 but got {}\\\".format(res)\"\n",
    "\n",
    "    # Perform the substitution\n",
    "    converted_string = (\n",
    "        re.sub(pattern, replacement, first_test)\n",
    "        .replace('\\\\\"Expected ', '\"Expected ')\n",
    "        .replace('{}\\\\\"', '{}\"')\n",
    "    )\n",
    "    code_w_assert = code + \"\\n\" + converted_string.strip()\n",
    "    prompt = (\n",
    "        mbpp_plus[task_id][\"prompt\"].strip('\"\"\"').strip().strip('\"\"\"').strip().replace(\"\\n\\nassert\", \"\\nassert\")\n",
    "    )  # row[\"text\"].strip() + \"\\n\" + first_test\n",
    "    # print(code_w_assert)\n",
    "    # print(\"-----\")\n",
    "    trajectory = [\n",
    "        {\"task\": prompt},\n",
    "        {\n",
    "            \"thought\": \"I should run a solution on the test case before proposing a solution.\"\n",
    "        },\n",
    "        {\"action\": code_w_assert},\n",
    "        {\"observation\": \"[Executed Successfully with No Output]\"},\n",
    "        {\"thought\": \"There is no AssertionError. I can now submit the solution.\"},\n",
    "        {\"solution\": code},\n",
    "    ]\n",
    "\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\n",
    "        # \"prompt\": row[].strip('\"\"\"').strip().strip('\"\"\"').strip()\n",
    "        \"react_prompt\": prompt,\n",
    "        \"code\": code,\n",
    "        \"traj_keys\": traj_keys,\n",
    "        \"traj_values\": traj_values,\n",
    "    }\n",
    "\n",
    "\n",
    "mbpp_trajectified = mbpp.map(trajectify)\n",
    "mbpp_trajectified.save_to_disk(\"var/mbpp\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp_trajectified[\"validation\"][\"task_id\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(mbpp[\"validation\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp[\"validation\"][41]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp[\"validation\"].select([41])[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(range(0,len(mbpp[\"validation\"])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp_original = load_dataset(\"google-research-datasets/mbpp\", name=\"full\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp_plus[\"\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp_plus = get_mbpp_plus()\n",
    "train = concatenate_datasets(\n",
    "                    mbpp_original.filter(\n",
    "                        lambda x: f\"Mbpp/{x['task_id']}\" not in mbpp_plus,\n",
    "                    )\n",
    "                    .rename_columns({\"code\": \"canonical_solution\", \"text\": \"prompt\"})\n",
    "                    .values(),\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filt = mbpp_original.filter(\n",
    "                        lambda x: f\"Mbpp/{x['task_id']}\" in mbpp_plus,\n",
    "                    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp_original"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mbpp.filter(lambda x: \"similar_elements\" in x[\"code\"])[\"prompt\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_answers(row):\n",
    "    question = row[\"question\"].strip().replace(\"’\", \"'\").replace(\"  \", \" \")\n",
    "    parts = row[\"answer\"].split(\"####\")\n",
    "    answer = parse_number(parts[-1])\n",
    "    reasoning = \"####\".join(parts[:-1]).strip().replace(\"’\", \"'\").replace(\"  \", \" \")\n",
    "    return {\n",
    "        \"question\": question,\n",
    "        \"answer\": answer,\n",
    "        \"reasoning\": reasoning,\n",
    "        \"raw_answer\": row[\"answer\"],\n",
    "        \"answer_part\": parts[-1],\n",
    "    }\n",
    "\n",
    "\n",
    "gsm8k = gsm8k.map(parse_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def react_trajectory(row):\n",
    "    question = row[\"question\"]\n",
    "    answer = row[\"answer\"]\n",
    "    reasoning = row[\"reasoning\"].splitlines()\n",
    "    trajectory = [{\"question\": question.strip()}]\n",
    "    res = answer\n",
    "\n",
    "    for line in reasoning:\n",
    "        pattern = (\n",
    "            r\"(?P<pre>(=(\\ )?|equals(\\ )?)?(\\$)?)<<(?P<exp>.*?)=(?P<res>.*?)>>([^\\s]*)\"\n",
    "        )\n",
    "        expressions = re.search(pattern, line)\n",
    "\n",
    "        if expressions is None:\n",
    "            trajectory += [\n",
    "                {\"thought\": line.strip().replace(\"  \", \" \")},\n",
    "            ]\n",
    "        else:\n",
    "            thought = re.sub(pattern, \"\", line)\n",
    "            thought = thought.rstrip(\".\").rstrip(\",\")\n",
    "            exp = expressions.group(\"exp\").strip()\n",
    "            res = expressions.group(\"res\").strip()\n",
    "\n",
    "            trajectory += [\n",
    "                {\n",
    "                    \"thought\": f\"{thought.strip().replace('  ', ' ')}. I need to calculate {exp}\"\n",
    "                },\n",
    "                {\"action\": f\"Calculator[{exp}]\"},\n",
    "                {\"observation\": res},\n",
    "            ]\n",
    "    if next(iter(trajectory[-1].keys())) == \"observation\":\n",
    "        trajectory.append({\"thought\": f\"The answer is {answer}\"})\n",
    "\n",
    "    trajectory.append({\"action\": f\"Finish[{answer}]\"})\n",
    "\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\n",
    "        \"traj_keys\": traj_keys,\n",
    "        \"traj_values\": traj_values,\n",
    "    }\n",
    "\n",
    "\n",
    "gsm8k[\"train\"] = gsm8k[\"train\"].map(react_trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rewoo_trajectory(row):\n",
    "    question = row[\"question\"]\n",
    "    answer = row[\"answer\"]\n",
    "    reasoning = row[\"reasoning\"].splitlines()\n",
    "    trajectory = [{\"question\": question.strip().replace(\"  \", \" \")}]\n",
    "    res = answer\n",
    "\n",
    "    for line in reasoning:\n",
    "        pattern = (\n",
    "            r\"(?P<pre>(=(\\ )?|equals(\\ )?)?(\\$)?)<<(?P<exp>.*?)=(?P<res>.*?)>>([^\\s]*)\"\n",
    "        )\n",
    "        expressions = re.search(pattern, line)\n",
    "\n",
    "        if expressions is None:\n",
    "            trajectory += [\n",
    "                {\"thought\": line.strip().replace(\"  \", \" \")},\n",
    "            ]\n",
    "        else:\n",
    "            thought = re.sub(pattern, \"\", line)\n",
    "            thought = thought.rstrip(\".\").rstrip(\",\")\n",
    "            exp = expressions.group(\"exp\").strip()\n",
    "            res = expressions.group(\"res\").strip()\n",
    "\n",
    "            trajectory += [\n",
    "                {\"thought\": f\"{thought.strip().replace('  ', ' ')}. Calculate {exp}\"},\n",
    "                {\"action\": f\"Calculator[{exp}]\"},\n",
    "                {\"observation\": res},\n",
    "            ]\n",
    "\n",
    "    evidence_counter = 0\n",
    "    for i in range(len(trajectory)):\n",
    "        outer = trajectory[i]\n",
    "        type_event = next(iter(outer.keys()))\n",
    "        value = next(iter(outer.values()))\n",
    "\n",
    "        if type_event == \"action\":\n",
    "            evidence_counter += 1\n",
    "        if type_event == \"observation\":\n",
    "            for j in range(i + 1, len(trajectory)):\n",
    "                inner = trajectory[j]\n",
    "                inner_type_event = next(iter(inner.keys()))\n",
    "                if inner_type_event == \"action\":\n",
    "                    trajectory[j][\"action\"] = trajectory[j][\"action\"].replace(\n",
    "                        value, f\"#E{evidence_counter}\"\n",
    "                    )\n",
    "                elif inner_type_event == \"thought\":\n",
    "                    trajectory[j][\"thought\"] = trajectory[j][\"thought\"].replace(\n",
    "                        value, f\"#E{evidence_counter}\"\n",
    "                    )\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\"rewoo_traj_keys\": traj_keys, \"rewoo_traj_values\": traj_values}\n",
    "\n",
    "\n",
    "gsm8k[\"train\"] = gsm8k[\"train\"].map(rewoo_trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gsm8k.save_to_disk(\"var/gsm8k_proc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf var/gsm8k_proc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_from_disk(\"var/gsm8k_proc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sympy tool\n",
    "\n",
    "Example:\n",
    "Let x be the cost of the pencil.\n",
    "If the pen costs 2 times the cost of the pencil, then it costs 2x.\n",
    "Adding the cost of the pen and pencil we get 2x + x = 3x\n",
    "Since the total cost is $6 then 3x = $6 therefore x = $6 / 3 = $2\n",
    "One pen is equal to 2 * x which is 2 * $2 = $4\n",
    "\n",
    "Use symbolic calculator?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Question: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it take?\n",
    "Tho: To make a robe, you need 2 bolts of blue fiber. I need to write 2\n",
    "Act: Write[2]\n",
    "Obs: Invalid action. Valid actions are Calculator[<expression>] and Finish[<answer>].\n",
    "Tho: You also need half as many bolts of white fiber. I need to calculate 2 / 4\n",
    "Act: Calculator[2 / 4]\n",
    "Obs: 0.5\n",
    "Tho: Thus, you need 0.5 bolts of white fiber. I need to write 0.5\n",
    "Act: Write[0.5]\n",
    "Obs: Invalid action. Valid actions are Calculator[<expression>] and Finish[<answer>].\n",
    "Tho: The answer is 2.5\n",
    "Act: Finish[2.5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for x in gsm8kk[\"train\"]:\n",
    "    if \"The moon\" in x[\"question\"]:\n",
    "        pprint(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# trajectory bootstrapping\n",
    "\n",
    "manual conversion of ~5 examples\n",
    "\n",
    "Cot question/reasoning/answer\n",
    "ReAct: question/thoughts/observations/answer\n",
    "\n",
    "Cot question/reasoning/answer query\n",
    "get React trajectory, if answer matches groundtruth, use, otherwise resample\n",
    "- can add more examples to prompt to improve results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "gsm8k = load_dataset(\"openai/gsm8k\", \"main\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_answers(row):\n",
    "    question = row[\"question\"].strip().replace(\"’\", \"'\").replace(\"  \", \" \")\n",
    "    parts = row[\"answer\"].split(\"####\")\n",
    "    answer = parse_number(parts[-1])\n",
    "    reasoning = \"####\".join(parts[:-1]).strip().replace(\"’\", \"'\").replace(\"  \", \" \")\n",
    "    return {\n",
    "        \"question\": question,\n",
    "        \"answer\": answer,\n",
    "        \"reasoning\": reasoning,\n",
    "        \"raw_answer\": row[\"answer\"],\n",
    "        \"answer_part\": parts[-1],\n",
    "    }\n",
    "\n",
    "\n",
    "gsm8k = gsm8k.map(parse_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def react_trajectory(row):\n",
    "    question = row[\"question\"]\n",
    "    answer = row[\"answer\"]\n",
    "    reasoning = row[\"reasoning\"].splitlines()\n",
    "    trajectory = [{\"question\": question.strip()}]\n",
    "    res = answer\n",
    "\n",
    "    for line in reasoning:\n",
    "        pattern = (\n",
    "            r\"(?P<pre>(=(\\ )?|equals(\\ )?)?(\\$)?)<<(?P<exp>.*?)=(?P<res>.*?)>>([^\\s]*)\"\n",
    "        )\n",
    "        expressions = re.search(pattern, line)\n",
    "\n",
    "        if expressions is None:\n",
    "            trajectory += [\n",
    "                {\"thought\": line.strip().replace(\"  \", \" \")},\n",
    "            ]\n",
    "        else:\n",
    "            thought = re.sub(pattern, \"\", line)\n",
    "            thought = thought.rstrip(\".\").rstrip(\",\")\n",
    "            exp = expressions.group(\"exp\").strip()\n",
    "            res = expressions.group(\"res\").strip()\n",
    "\n",
    "            trajectory += [\n",
    "                {\n",
    "                    \"thought\": f\"{thought.strip().replace('  ', ' ')}. I need to calculate {exp}\"\n",
    "                },\n",
    "                {\"action\": '{\"name\": \"Calc\", \"arguments\": {\"expr\": \"' + f\"{exp}\" +'\"}}'}, #Calculator[{exp}]\"},\n",
    "                {\"observation\": res},\n",
    "            ]\n",
    "    if next(iter(trajectory[-1].keys())) == \"observation\":\n",
    "        trajectory.append({\"thought\": f\"The answer is {answer}\"})\n",
    "\n",
    "    trajectory.append({\"action\":\n",
    "                       '{\"name\": \"Finish\", \"arguments\": {\"topic\": \"' + f\"{answer}\" + '\"}}'\n",
    "                       })\n",
    "                       #f\"Finish[{answer}]\"\n",
    "\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\n",
    "        \"traj_keys\": traj_keys,\n",
    "        \"traj_values\": traj_values,\n",
    "    }\n",
    "\n",
    "\n",
    "gsm8k[\"train\"] = gsm8k[\"train\"].map(react_trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rewoo_trajectory(row):\n",
    "    question = row[\"question\"]\n",
    "    answer = row[\"answer\"]\n",
    "    reasoning = row[\"reasoning\"].splitlines()\n",
    "    trajectory = [{\"question\": question.strip().replace(\"  \", \" \")}]\n",
    "    res = answer\n",
    "\n",
    "    for line in reasoning:\n",
    "        pattern = (\n",
    "            r\"(?P<pre>(=(\\ )?|equals(\\ )?)?(\\$)?)<<(?P<exp>.*?)=(?P<res>.*?)>>([^\\s]*)\"\n",
    "        )\n",
    "        expressions = re.search(pattern, line)\n",
    "\n",
    "        if expressions is None:\n",
    "            trajectory += [\n",
    "                {\"thought\": line.strip().replace(\"  \", \" \")},\n",
    "            ]\n",
    "        else:\n",
    "            thought = re.sub(pattern, \"\", line)\n",
    "            thought = thought.rstrip(\".\").rstrip(\",\")\n",
    "            exp = expressions.group(\"exp\").strip()\n",
    "            res = expressions.group(\"res\").strip()\n",
    "\n",
    "            trajectory += [\n",
    "                {\"thought\": f\"{thought.strip().replace('  ', ' ')}. Calculate {exp}\"},\n",
    "                {\"action\": '{\"name\": \"Calc\", \"arguments\": {\"expr\": \"' + f\"{exp}\" +'\"}}'},\n",
    "                {\"observation\": res},\n",
    "            ]\n",
    "\n",
    "    evidence_counter = 0\n",
    "    for i in range(len(trajectory)):\n",
    "        outer = trajectory[i]\n",
    "        type_event = next(iter(outer.keys()))\n",
    "        value = next(iter(outer.values()))\n",
    "\n",
    "        if type_event == \"action\":\n",
    "            evidence_counter += 1\n",
    "        if type_event == \"observation\":\n",
    "            for j in range(i + 1, len(trajectory)):\n",
    "                inner = trajectory[j]\n",
    "                inner_type_event = next(iter(inner.keys()))\n",
    "                if inner_type_event == \"action\":\n",
    "                    trajectory[j][\"action\"] = trajectory[j][\"action\"].replace(\n",
    "                        value, f\"#E{evidence_counter}\"\n",
    "                    )\n",
    "                elif inner_type_event == \"thought\":\n",
    "                    trajectory[j][\"thought\"] = trajectory[j][\"thought\"].replace(\n",
    "                        value, f\"#E{evidence_counter}\"\n",
    "                    )\n",
    "    traj_keys = [next(iter(t.keys())) for t in trajectory]\n",
    "    traj_values = [next(iter(t.values())) for t in trajectory]\n",
    "\n",
    "    return {\"rewoo_traj_keys\": traj_keys, \"rewoo_traj_values\": traj_values}\n",
    "\n",
    "\n",
    "gsm8k[\"train\"] = gsm8k[\"train\"].map(rewoo_trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gsm8k.save_to_disk(\"var/gsm8k_proc_json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pdlnew",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
