{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Process Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_q(text):\n",
    "    l = len(text)\n",
    "    return [char_to_idx[c] for c in text] + [eos_token] + [empty_token] * (max_q_len - l - 1)\n",
    "\n",
    "def tokenize_a(text):\n",
    "    l = len(text)\n",
    "    token_a = [char_to_idx[c] for c in text]\n",
    "    token_a = [start_token] + token_a + [eos_token] + [empty_token] * (max_a_len - l - 1)\n",
    "    token_target = token_a[:-1]\n",
    "    token_label = token_a[1:]\n",
    "    return token_target, token_label\n",
    "\n",
    "def invert_tokenization(idx):\n",
    "    return [idx_to_char[i] for i in idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_qa(filename):\n",
    "    with open(f'{data_path}/{tr_l}/{task}.txt') as f:\n",
    "        text = f.read().splitlines()\n",
    "        quess = text[::2]\n",
    "        anss = text[1::2]\n",
    "\n",
    "    return quess, anss\n",
    "\n",
    "def tokenize_qa(quess, anss):\n",
    "    tokenized_source = [tokenize_q(text) for text in tqdm(quess)]\n",
    "\n",
    "    tokenized_target, tokenized_label = [], []\n",
    "    for text in tqdm(anss):\n",
    "        tt, tl = tokenize_a(text)\n",
    "        tokenized_target.append(tt)\n",
    "        tokenized_label.append(tl)\n",
    "\n",
    "    tokenized_source = torch.tensor(tokenized_source)\n",
    "    tokenized_target = torch.tensor(tokenized_target)\n",
    "    tokenized_label = torch.tensor(tokenized_label)\n",
    "\n",
    "    return tokenized_source, tokenized_target, tokenized_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_ds(fname):\n",
    "    quess, anss = load_qa(fname)\n",
    "\n",
    "    tokenized_source, tokenized_target, tokenized_label = tokenize_qa(quess, anss)\n",
    "\n",
    "    return torch.utils.data.TensorDataset(tokenized_source, tokenized_target, tokenized_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = '../../data/math/mathematics_dataset-v1.0'\n",
    "out_path = 'tokenized_data'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('text_vectorizer/vocabulary.txt') as f:\n",
    "    vocab = f.read().splitlines()\n",
    "\n",
    "idx_to_char = {i: c for i, c in enumerate(vocab)}\n",
    "char_to_idx = {c: i for i, c in enumerate(vocab)}\n",
    "\n",
    "empty_token = char_to_idx['']\n",
    "eos_token = char_to_idx[';']\n",
    "start_token = char_to_idx['@']\n",
    "\n",
    "max_q_len, max_a_len = 161, 31"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process Task data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = 'polynomials__expand'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_levels = ['train-easy', 'train-medium', 'train-hard']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 666666/666666 [00:04<00:00, 154544.75it/s]\n",
      "100%|██████████| 666666/666666 [00:04<00:00, 160722.16it/s]\n",
      "100%|██████████| 666666/666666 [00:03<00:00, 169745.24it/s]\n",
      "100%|██████████| 666666/666666 [00:03<00:00, 191003.56it/s]\n",
      "100%|██████████| 666666/666666 [00:04<00:00, 159092.56it/s]\n",
      "100%|██████████| 666666/666666 [00:03<00:00, 191631.62it/s]\n",
      "100%|██████████| 3/3 [01:05<00:00, 21.85s/it]\n"
     ]
    }
   ],
   "source": [
    "tss, tts, tls = [], [], []\n",
    "\n",
    "for tr_l in tqdm(train_levels):\n",
    "    fname = f'{data_path}/{tr_l}/{task}.txt'\n",
    "\n",
    "    quess, anss = load_qa(fname)\n",
    "\n",
    "    tokenized_source, tokenized_target, tokenized_label = tokenize_qa(quess, anss)\n",
    "    tss.append(tokenized_source)\n",
    "    tts.append(tokenized_target)\n",
    "    tls.append(tokenized_label)\n",
    "\n",
    "    del quess, anss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "ts = torch.concat(tss)\n",
    "tt = torch.concat(tts)\n",
    "tl = torch.concat(tls)\n",
    "\n",
    "del tss, tts, tls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = torch.utils.data.TensorDataset(ts, tt, tl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(train_ds, f'{out_path}/{task}_train.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 666666/666666 [00:04<00:00, 141784.27it/s]\n",
      "100%|██████████| 666666/666666 [00:03<00:00, 183483.86it/s]\n",
      "100%|██████████| 666666/666666 [00:04<00:00, 151822.61it/s]\n",
      "100%|██████████| 666666/666666 [00:03<00:00, 190589.63it/s]\n"
     ]
    }
   ],
   "source": [
    "interpolate_ds = create_ds(f'{data_path}/interpolate/{task}.txt')\n",
    "torch.save(interpolate_ds, f'{out_path}/{task}_interpolate.pt')\n",
    "interpolate_ds = create_ds(f'{data_path}/extrapolate/{task}.txt')\n",
    "torch.save(interpolate_ds, f'{out_path}/{task}_extrapolate.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
