{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "# os.chdir(\"/data/your_name/funclm\")\n",
    "\n",
    "from funchub.math import _add_, _subtract_, _multiply_, _divide_, _power_, _sqrt_, _log_, _ln_, \\\n",
    "    _sin_, _cos_, _tan_, _asin_, _acos_, _atan_, _factorial_, _floor_, _ceil_, _round_, _radians_, _degrees_, \\\n",
    "    _exp_, _choose_, _permutate_, _gcd_, _lcm_, _root_, _remainder_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import llama\n",
    "import re\n",
    "\n",
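    "# FIXME: `tokenizer` is used by show_tokens and construct_training_data but is not\n",
    "# created anywhere in this notebook (the line below is commented out), so on a fresh\n",
    "# kernel this cell fails with NameError; restore a working constructor call:\n",
    "# tokenizer = llama.Tokenizer()\n",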
    "\n",
    "def show_tokens(text):\n",
    "    \"\"\"Print the BOS/EOS-wrapped token ids of `text`, then each id decoded on its own.\"\"\"\n",
    "    token_ids = tokenizer.encode(text, bos=True, eos=True)\n",
    "    print(token_ids)\n",
    "    print(list(map(tokenizer.decode, token_ids)))\n",
    "\n",
    "show_tokens(\"The number of permutations of 6 things taken 3 at a time is <permutate>(6,3)=120<eoe>.\")\n",
    "\n",
    "def transform(number):\n",
    "    \"\"\"Return the surface forms a numeric result may take in the answer text.\n",
    "\n",
    "    The list starts with `number` itself (as a string), then the same value with a\n",
    "    comma between every three digits of its integer part (\"-66666\" -> \"-66,666\").\n",
    "    When the integer part is zero, the form without that zero is appended as well\n",
    "    (\"0.5\" -> \".5\", \"-0.5\" -> \"-.5\").\n",
    "    \"\"\"\n",
    "    number = str(number)\n",
    "    # Group only the integer digits: grouping the fraction too would produce forms\n",
    "    # such as \"1,234.5,678\", which is not a way 1234.5678 is written.\n",
    "    integer_part, dot, fraction = number.partition(\".\")\n",
    "    integer_comma = re.sub(r\"(\\d)(?=(\\d\\d\\d)+(?!\\d))\", r\"\\1,\", integer_part)\n",
    "    return_list = [number, integer_comma + dot + fraction]\n",
    "    if number.startswith(\"0.\"):\n",
    "        return_list.append(number[1:])\n",
    "    elif number.startswith(\"-0.\"):\n",
    "        # Same shortening for negative values: \"-0.5\" -> \"-.5\".\n",
    "        return_list.append(\"-\" + number[2:])\n",
    "    return return_list\n",
    "\n",
    "print(transform(\"-66666\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "# text = data['train'][1]['answer']\n",
    "\n",
    "# Matches one tool call such as `<log>(0.03)=-1<eoe>`; group 3 is the text after \"=\".\n",
    "_EQUATION_RE = re.compile(r\"(<.*?>)(.*?)=(.*?)<eoe>\")\n",
    "\n",
    "\n",
    "def _strip_equations(text):\n",
    "    \"\"\"Cut every tool call out of `text`, always taking the leftmost one next.\n",
    "\n",
    "    Returns (stripped_text, inds, tar_eq, tar_number): for the k-th call removed,\n",
    "    inds[k] is the character offset it was cut from, tar_eq[k] is the whole call\n",
    "    and tar_number[k] is its result string.\n",
    "    \"\"\"\n",
    "    inds, tar_eq, tar_number = [], [], []\n",
    "    match = _EQUATION_RE.search(text)\n",
    "    while match:\n",
    "        inds.append(match.start())\n",
    "        tar_eq.append(match.group(0))\n",
    "        tar_number.append(match.group(3))\n",
    "        # Remove only this occurrence: str.replace would also delete later identical\n",
    "        # calls without recording them, so the length check in\n",
    "        # construct_training_data could never pass for such texts.\n",
    "        text = text[:match.start()] + text[match.end():]\n",
    "        match = _EQUATION_RE.search(text)\n",
    "    return text, inds, tar_eq, tar_number\n",
    "\n",
    "\n",
    "def _locate_spans(text, encoding, inds, tar_number):\n",
    "    \"\"\"Find a [start, end) token span in `encoding` for each removed call.\n",
    "\n",
    "    start is the first token after the text that preceded the call; end is the\n",
    "    first later token boundary at which the decoded prefix ends with one of the\n",
    "    result's surface forms (see transform). A call that cannot be fully placed\n",
    "    leaves at least one of the returned lists shorter than `inds`.\n",
    "    \"\"\"\n",
    "    start_token_idx, end_token_idx = [], []\n",
    "    cur_token_ind = 0\n",
    "    for ind_char, number in zip(inds, tar_number):\n",
    "        prefix = text[:ind_char]\n",
    "        surface_forms = transform(number)  # same for every j below, so compute once\n",
    "        for i in range(cur_token_ind, len(encoding)):\n",
    "            cur_decode = tokenizer.decode(encoding[:i])\n",
    "            # Accept the prefix with or without one trailing space (presumably the\n",
    "            # decoder drops it) - TODO confirm against the tokenizer in use.\n",
    "            if cur_decode + \" \" != prefix and cur_decode != prefix:\n",
    "                continue\n",
    "            if tokenizer.decode(encoding[i]) == \"\":\n",
    "                continue  # skip tokens that decode to an empty string\n",
    "            start_token_idx.append(i)\n",
    "            cur_token_ind = i\n",
    "            # Start at i + 1 so the [start, end) span covers at least one token.\n",
    "            for j in range(i + 1, len(encoding)):\n",
    "                decoded = tokenizer.decode(encoding[:j])\n",
    "                if any(decoded.endswith(form) for form in surface_forms):\n",
    "                    end_token_idx.append(j)\n",
    "                    cur_token_ind = j\n",
    "                    break\n",
    "            # This call is placed; without this break a later i could match the\n",
    "            # same prefix again and record a second start for the same call.\n",
    "            break\n",
    "    return start_token_idx, end_token_idx\n",
    "\n",
    "\n",
    "def construct_training_data(text):\n",
    "    \"\"\"Remove the tool calls from `text` and align each result with its tokens.\n",
    "\n",
    "    A tool call has the form `<op>args=result<eoe>`. Every call is cut out, the\n",
    "    remaining text is encoded with BOS/EOS, and for each call the token index where\n",
    "    it stood (start) and the index just past the matching result text that follows\n",
    "    it (end) are recorded.\n",
    "\n",
    "    Returns a dict with keys \"text\", \"start_token_idx\", \"end_token_idx\", \"tar_eq\"\n",
    "    and \"tar_number\" when every call was aligned; otherwise prints diagnostics and\n",
    "    returns None.\n",
    "\n",
    "    Uses the notebook globals `tokenizer` and `transform`.\n",
    "    \"\"\"\n",
    "    text, inds, tar_eq, tar_number = _strip_equations(text)\n",
    "    encoding = tokenizer.encode(text, bos=True, eos=True)\n",
    "    start_token_idx, end_token_idx = _locate_spans(text, encoding, inds, tar_number)\n",
    "\n",
    "    if len(start_token_idx) == len(end_token_idx) == len(tar_eq):\n",
    "        return {\n",
    "            \"text\": text,\n",
    "            \"start_token_idx\": start_token_idx,\n",
    "            \"end_token_idx\": end_token_idx,\n",
    "            \"tar_eq\": tar_eq,\n",
    "            \"tar_number\": tar_number\n",
    "        }\n",
    "    print(\"error\")\n",
    "    print(start_token_idx, end_token_idx)\n",
    "    print(tar_eq)\n",
    "    print([tokenizer.decode(encoding[i:j]) for i, j in zip(start_token_idx, end_token_idx)])\n",
    "    print(text)\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': 'so the pH is -log(0.029046244381608112)=- -1.',\n",
       " 'start_token_idx': [31],\n",
       " 'end_token_idx': [33],\n",
       " 'tar_eq': ['<log>(0.029046244381608112)=-1<eoe>'],\n",
       " 'tar_number': ['-1']}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "construct_training_data(\"so the pH is -log(0.029046244381608112)=- <log>(0.029046244381608112)=-1<eoe>-1.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "random.seed(a=42)  # seed once, before the per-file shuffles below\n",
    "\n",
    "\n",
    "def read_jsonl(file_path):\n",
    "    \"\"\"Load a JSON Lines file: one parsed JSON value per non-blank line.\"\"\"\n",
    "    # JSON Lines is UTF-8; don't rely on the platform's default text encoding.\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        # Blank lines (e.g. a trailing empty line) would make json.loads raise.\n",
    "        return [json.loads(line) for line in f if line.strip()]\n",
    "\n",
    "# Shuffled records taken from each file for training; the rest of the file goes to dev.\n",
    "N_TRAIN_PER_FILE = 47\n",
    "\n",
    "# Read every jsonl file under data/ohqa/training_data and split each one.\n",
    "# The paths are sorted because glob's order depends on the file system, and with a\n",
    "# single seeded RNG each file's shuffle depends on the files processed before it.\n",
    "train_data, dev_data = [], []\n",
    "for file_path in sorted(glob.glob(\"data/ohqa/training_data/*.jsonl\")):\n",
    "    records = read_jsonl(file_path)\n",
    "    random.shuffle(records)\n",
    "    train_data.extend(records[:N_TRAIN_PER_FILE])\n",
    "    dev_data.extend(records[N_TRAIN_PER_FILE:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(611, 39)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of (train, dev) records after the split.\n",
    "(len(train_data), len(dev_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 611/611 [00:00<00:00, 662.38it/s]\n",
      "100%|██████████| 39/39 [00:00<00:00, 635.26it/s]\n"
     ]
    }
   ],
   "source": [
    "def build_examples(records):\n",
    "    \"\"\"Format each record as a Q/A prompt and keep the examples that align.\"\"\"\n",
    "    examples = []\n",
    "    for record in tqdm(records):\n",
    "        prompt = f\"Q: {record['question']}\\nA: {record['answer']}\"\n",
    "        example = construct_training_data(prompt)\n",
    "        if example:\n",
    "            examples.append(example)\n",
    "    return examples\n",
    "\n",
    "\n",
    "# NOTE(review): dev examples end up in the same list (and output file) as train,\n",
    "# which makes the split above moot if this file is used for training - confirm intent.\n",
    "final_data = build_examples(train_data) + build_examples(dev_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the aligned examples, pretty-printed with 4-space indentation.\n",
    "with open(\"data/ohqa/train_list_fix_log.json\", mode=\"w\") as out_file:\n",
    "    json.dump(final_data, out_file, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "table2text",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
