{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, load_from_disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "ds = load_dataset(\"reasoning-machines/gsm-hard\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd383dd928af43938b9f61d8109c7235",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/1319 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9a74d0390fe44d1b07862185582241c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1188 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2dbbd21329684ecfa35f453a3979db00",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1188 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "60433a0e81dd4366a9085d79dfcc6973",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1188 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ds[\"train\"] = ds[\"train\"].add_column(\"index\", list(range(1,len(ds[\"train\"])+1))).filter(lambda x: x[\"index\"] not in indices)\n",
    "ds[\"test\"] = ds[\"train\"]\n",
    "ds[\"validation\"] = ds[\"train\"]\n",
    "ds.save_to_disk(\"var/gsmhard\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input', 'code', 'target', 'index'],\n",
       "        num_rows: 1188\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'code', 'target', 'index'],\n",
       "        num_rows: 1188\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['input', 'code', 'target', 'index'],\n",
       "        num_rows: 1188\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d = load_from_disk(\"var/gsmhard\")\n",
    "d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cce2c4b1316d468b960659e4248f4136",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/6449 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d8a6442d1444ecd9e993b6b24143023",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/594 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ba8b6c2b9494aa7b062759f5a8e48e9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/594 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['question', 'answer', 'reasoning', 'raw_answer', 'answer_part', 'traj_keys', 'traj_values', 'rewoo_traj_keys', 'rewoo_traj_values'],\n",
       "        num_rows: 6449\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'code', 'target', 'index'],\n",
       "        num_rows: 594\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['input', 'code', 'target', 'index'],\n",
       "        num_rows: 594\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split = d[\"train\"].train_test_split(test_size=0.5, shuffle=False)\n",
    "split[\"validation\"] = split[\"train\"]\n",
    "ds = load_from_disk(\"var/gsm8k_proc_json\")\n",
    "split[\"train\"] = ds[\"train\"]\n",
    "split.save_to_disk(\"var/gsmhard_split\")\n",
    "split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95a3b9143c524d378b13daf031737a60",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/594 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "87a0f38a086d4ac88ceb3495ec9eece1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/594 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "split = d[\"train\"].train_test_split(test_size=0.5, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e1818757f9714c6a813aa683e9a04b38",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/1319 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dss = dss.filter(lambda x: x[\"index\"] not in indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c0dd6ceed6141488902d1570ac9ef18",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1188 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dss.save_to_disk(\"var/gsmhard\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3, 15, 17, 22, 38, 39, 52, 54, 58, 73, 94, 95, 109, 110, 115, 143, 145, 153, 167, 187, 192, 200, 207, 224, 245, 248, 267, 271, 272, 291, 306, 317, 318, 320, 322, 322, 336, 342, 360, 363, 367, 380, 392, 402, 403, 419, 423, 436, 445, 457, 472, 487, 503, 505, 515, 537, 557, 585, 587, 591, 594, 596, 609, 610, 611, 614, 621, 622, 645, 648, 651, 675, 677, 710, 718, 720, 728, 759, 763, 765, 773, 776, 777, 779, 792, 824, 842, 846, 847, 856, 862, 876, 877, 878, 894, 923, 934, 935, 936, 945, 963, 967, 972, 974, 982, 990, 1016, 1023, 1040, 1045, 1055, 1062, 1067, 1069, 1075, 1091, 1097, 1099, 1108, 1114, 1120, 1143, 1181, 1186, 1211, 1212, 1214, 1218, 1236, 1248, 1300, 1312]\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import re\n",
    "\n",
    "# Load the file\n",
    "file_path = Path(\"gsmhard-incorrect.txt\")\n",
    "content = file_path.read_text()\n",
    "\n",
    "# Find all indices\n",
    "indices = [int(match.group(1)) for match in re.finditer(r\"Index:\\s*(\\d+)\", content)]\n",
    "\n",
    "print(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "132"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "notebook",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
