{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "\n",
    "# Seed the global RNG so the samples generated below are reproducible across runs.\n",
    "random.seed(0)\n",
    "\n",
    "def generate_sample(max_digits,\n",
    "                    num_summands_max,\n",
    "                    num_summands_min,\n",
    "                    with_sol=True,\n",
    "                    with_rationale=True,\n",
    "                    tool_use=False,\n",
    "                    exclusion_set=None,\n",
    "                    operators=['+']):\n",
    "    while True:\n",
    "        list_of_digits = [random.randint(0, 10**max_digits - 1) for _ in range(random.randint(num_summands_min, num_summands_max))]\n",
    "        str_query_core = f'{list_of_digits[0]}'\n",
    "        for digit in list_of_digits[1:]:\n",
    "            str_query_core += f' {random.choice(operators)} {digit}'\n",
    "        int_sol = eval(str_query_core)\n",
    "        str_query = f'Question: {str_query_core} = ? Answer:'\n",
    "        if str_query not in exclusion_set:\n",
    "            break\n",
    "    list_query_core = str_query_core.split()\n",
    "    stmts = []\n",
    "    while len(list_query_core) > 1:\n",
    "        partial_result = eval(' '.join(list_query_core[:3]))\n",
    "        if tool_use:\n",
    "            stmts.append(' '.join(list_query_core[:3] + ['=']))\n",
    "        else:\n",
    "            stmts.append(' '.join(list_query_core[:3] + ['=', str(partial_result)]))\n",
    "        list_query_core = [str(partial_result)] + list_query_core[3:]\n",
    "    assert int(list_query_core[0]) == int_sol\n",
    "    str_rationale = ', '.join(stmts)\n",
    "    str_sol = f'. The answer is {int_sol}.'\n",
    "    res = {'str_query': str_query, 'num_sol': int_sol}\n",
    "    if with_rationale:\n",
    "        res['str_rationale'] = str_rationale\n",
    "    if with_sol:\n",
    "        res['str_sol'] = str_sol\n",
    "    return res\n",
    "\n",
    "# Number of samples in each generated split.\n",
    "num_samples_train = 50\n",
    "num_samples_hw = 250\n",
    "num_samples_test = 500\n",
    "\n",
    "operators = ['+', '-']\n",
    "max_digit_len = 1\n",
    "\n",
    "output_dir = 'data/arithmetic_with_tool'\n",
    "# Operator tag embedded in every output file name.\n",
    "op_tag = ''.join(operators)\n",
    "\n",
    "\n",
    "def _generate_split(num_samples, exclusion_set, **sample_kwargs):\n",
    "    \"\"\"Generate num_samples samples whose queries are not in exclusion_set.\n",
    "\n",
    "    Every accepted query is added to exclusion_set, so samples are unique\n",
    "    within the split and across all splits that share the same set.\n",
    "    \"\"\"\n",
    "    samples = []\n",
    "    for _ in range(num_samples):\n",
    "        sample = generate_sample(max_digit_len,\n",
    "                                 exclusion_set=exclusion_set,\n",
    "                                 operators=operators,\n",
    "                                 **sample_kwargs)\n",
    "        exclusion_set.add(sample['str_query'])\n",
    "        samples.append(sample)\n",
    "    return samples\n",
    "\n",
    "\n",
    "def _write_split(samples, filename):\n",
    "    \"\"\"Write samples as indented JSON to filename inside output_dir.\"\"\"\n",
    "    with open(f'{output_dir}/{filename}', 'w') as f:\n",
    "        json.dump(samples, f, indent=4)\n",
    "\n",
    "\n",
    "# Create the output directory up front so the first open() cannot fail on a fresh checkout.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# One shared set keeps the three splits disjoint from each other.\n",
    "exclusion_set = set()\n",
    "\n",
    "# Training split: 3-4 operands, rationale steps end at '=' (tool_use=True).\n",
    "train_samples = _generate_split(num_samples_train, exclusion_set,\n",
    "                                num_summands_max=4,\n",
    "                                num_summands_min=3,\n",
    "                                with_rationale=True,\n",
    "                                with_sol=True,\n",
    "                                tool_use=True)\n",
    "_write_split(train_samples, f'{max_digit_len}digit_34_op{op_tag}_train_{num_samples_train}.json')\n",
    "\n",
    "# hw split: 3-4 operands, final answer sentence but no rationale.\n",
    "hw_samples = _generate_split(num_samples_hw, exclusion_set,\n",
    "                             num_summands_max=4,\n",
    "                             num_summands_min=3,\n",
    "                             with_rationale=False,\n",
    "                             with_sol=True)\n",
    "_write_split(hw_samples, f'{max_digit_len}digit_34_op{op_tag}_hw.json')\n",
    "\n",
    "# Test split: exactly 5 operands, query and numeric answer only. Test queries\n",
    "# are recorded in the exclusion set too, so the split contains no duplicates.\n",
    "test_samples = _generate_split(num_samples_test, exclusion_set,\n",
    "                               num_summands_max=5,\n",
    "                               num_summands_min=5,\n",
    "                               with_rationale=False,\n",
    "                               with_sol=False)\n",
    "_write_split(test_samples, f'{max_digit_len}digit_5_op{op_tag}_test.json')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lmnew",
   "language": "python",
   "name": "lmnew"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
