{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# HOME_DIR='/home/jovyan/USR'\n",
    "HOME_DIR='/home/USR'\n",
    "\n",
    "import os\n",
    "# os.chdir('/home/jovyan/USR/data/test_time_gd/')\n",
    "os.chdir(f'{HOME_DIR}/data/test_time_gd/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get babi tasks\n",
    "# wget https://github.com/booydar/babilong/raw/refs/heads/main/data/tasks_1-20_v1-2.zip\n",
    "# unzip tasks_1-20_v1-2.zip\n",
    "\n",
    "# and follow instructions in https://github.com/booydar/babilong/tree/main/data\n",
    "# change task to en-valid-10k, to have train/valid/test splits\n",
    "# and run for each split name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "data_path = Path('/home/jovyan/USR/data/babilong/data/generated_tasks')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['qa1', 'qa2', 'qa3', 'qa4', 'qa5']\n",
    "lengths = ['0k']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "qa1 0k\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 8999\n",
      "    })\n",
      "    valid: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "})\n",
      "./data/babilong_qa1_0k\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bfcca07bb958423ab3415fc27e2b1956",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/8999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aae8ea9b89ab4a31b8970b0c2b9fc290",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3cb25129b17a48b5ada88b181574b9ca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "qa2 0k\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 8999\n",
      "    })\n",
      "    valid: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "})\n",
      "./data/babilong_qa2_0k\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f2b9db61f26f4010b003457405146849",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/8999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0dd1010cbb2b49738f9d0029f3e4f8c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9709239eb14c480c8e01a70b9d05d365",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "qa3 0k\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 8999\n",
      "    })\n",
      "    valid: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "})\n",
      "./data/babilong_qa3_0k\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "75c38559bb744366b7be001b1c8b480f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/8999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8d4b34cca60447c8bf2423dbca2157d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4028bb94e4724ddcbe29e54cff11e68f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "qa4 0k\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 8999\n",
      "    })\n",
      "    valid: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "})\n",
      "./data/babilong_qa4_0k\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "15655c614a144cf7a1fe73ccef2df4de",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/8999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d30d9740db1e4e288b06987b2de9dae0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aebf6f86f2ff4870a768faa44e6d500b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "qa5 0k\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 8999\n",
      "    })\n",
      "    valid: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['context', 'query', 'target'],\n",
      "        num_rows: 999\n",
      "    })\n",
      "})\n",
      "./data/babilong_qa5_0k\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a18439927a7e448b87b4b7c5b478317d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/8999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "876b48ad4c524301bd0eb3ff2800d3a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e418e7c5b934cc6932078278755cc19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/999 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import json\n",
    "from datasets import Dataset, DatasetDict\n",
    "\n",
    "for length in lengths:\n",
    "    for task in tasks:\n",
    "        print(f'{task} {length}')\n",
    "        path = data_path / task\n",
    "        data = {'train': [], 'valid': [], 'test': []}\n",
    "        for split in data.keys():\n",
    "            with open(path / f'{length}_{split}.json', 'r') as f:\n",
    "                d = json.load(f)\n",
    "                for sample in d:\n",
    "                        # add spaces such that context + query + target is a good looking text sequence\n",
    "                        # user can clean it on its end\n",
    "                        data[split] += [{\n",
    "                            'context': sample['input'].strip() + ' ',\n",
    "                            'query': sample['question'].strip() + ' ',\n",
    "                            'target': sample['target'].strip(),\n",
    "                            }]\n",
    "            data[split] = Dataset.from_list(data[split])\n",
    "        dataset = DatasetDict(data)\n",
    "        print(dataset)\n",
    "        save_path = f'./data/babilong_{task}_{length}'\n",
    "        print(save_path)\n",
    "        dataset.save_to_disk(save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8241 \"Who\"\n",
      "2921 \" gave\"\n",
      "262 \" the\"\n",
      "17180 \" apple\"\n",
      "284 \" to\"\n",
      "5502 \" Jeff\"\n",
      "30 \"?\"\n",
      "220 \" \"\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
    "\n",
    "for t_id in tokenizer(sample['question']).input_ids:\n",
    "    print(f'{t_id} \"{tokenizer.decode(t_id)}\"')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py311_pt2.6_cu12.4",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
