{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "os.chdir('/home/jovyan/USR/data/test_time_gd/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kv_dataset_utils import generate_sequence, get_extra_chars, BASE_KV_ALPHABET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'kv_pairs': ['!Ы😶:ü⑤!', '!ÌÂ:ü9!', '!Ⓑ⑦:😘⑮!', '!⑲°:⓴Й!'],\n",
       " 'segment_ids_to_kv_ids': {0: [0, 1, 2, 3]},\n",
       " 'context': '!⑲°:⓴Й!!Ⓑ⑦:😘⑮!!ÌÂ:ü9!!Ы😶:ü⑤!|',\n",
       " 'query': '?!⑲°:',\n",
       " 'input_sequence': '!⑲°:⓴Й!!Ⓑ⑦:😘⑮!!ÌÂ:ü9!!Ы😶:ü⑤!|?!⑲°:',\n",
       " 'target': '⓴Й!|'}"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_kv_pairs = 4\n",
    "k_length = 2\n",
    "v_length = 2\n",
    "n_segments = 1\n",
    "min_segment_len = 0\n",
    "max_segment_len = 0\n",
    "\n",
    "kv_vocab_size = 512\n",
    "\n",
    "kv_alphabet = BASE_KV_ALPHABET + get_extra_chars(kv_vocab_size)\n",
    "\n",
    "sample = generate_sequence(num_kv_pairs, k_length, v_length, n_segments,\n",
    "                           min_segment_len, max_segment_len, kv_alphabet)\n",
    "\n",
    "sample\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KV pairs length: 28\n",
      "~Total length: 0.0\n"
     ]
    }
   ],
   "source": [
    "print('KV pairs length:', (k_length + v_length + 3)*num_kv_pairs)\n",
    "print('~Total length:', n_segments * (min_segment_len + max_segment_len)/2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'kv_pairs': ['!ЦU:Β▦!', '!😠😗:Ⓔ◅!', '!ì▾:ØP!', '!Д¦:😣ⓘ!'],\n",
       " 'segment_ids_to_kv_ids': {0: [0, 1, 2, 3]},\n",
       " 'context': '!Д¦:😣ⓘ!!ì▾:ØP!!😠😗:Ⓔ◅!!ЦU:Β▦!|',\n",
       " 'query': '?!ì▾:',\n",
       " 'input_sequence': '!Д¦:😣ⓘ!!ì▾:ØP!!😠😗:Ⓔ◅!!ЦU:Β▦!|?!ì▾:',\n",
       " 'target': 'ØP!|'}"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample = generate_sequence(num_kv_pairs, k_length, v_length, n_segments,\n",
    "                           min_segment_len, max_segment_len, kv_alphabet)\n",
    "sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1000000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000000/1000000 [00:29<00:00, 34165.69it/s]\n",
      "100%|██████████| 5000/5000 [00:00<00:00, 34647.98it/s]\n"
     ]
    }
   ],
   "source": [
    "from datasets import Dataset, DatasetDict\n",
    "from tqdm import tqdm\n",
    "\n",
    "num_samples = 1_000_000\n",
    "\n",
    "data = []\n",
    "for _ in tqdm(range(num_samples), total=num_samples):\n",
    "    sample = generate_sequence(num_kv_pairs, k_length, v_length, n_segments,\n",
    "                               min_segment_len, max_segment_len, kv_alphabet)\n",
    "    data += [{\n",
    "        'context': sample['context'],\n",
    "        'query': sample['query'],\n",
    "        'target': sample['target'],\n",
    "    }]\n",
    "data = Dataset.from_list(data)\n",
    "\n",
    "num_samples = 5_000\n",
    "\n",
    "valid_data = []\n",
    "for _ in tqdm(range(num_samples), total=num_samples):\n",
    "    sample = generate_sequence(num_kv_pairs, k_length, v_length, n_segments,\n",
    "                               min_segment_len, max_segment_len, kv_alphabet)\n",
    "    valid_data += [{\n",
    "        'context': sample['context'],\n",
    "        'query': sample['query'],\n",
    "        'target': sample['target'],\n",
    "    }]\n",
    "valid_data = Dataset.from_list(valid_data)\n",
    "\n",
    "dataset = DatasetDict({'train': data, 'valid': valid_data})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['context', 'query', 'target'],\n",
       "        num_rows: 1000000\n",
       "    })\n",
       "    valid: Dataset({\n",
       "        features: ['context', 'query', 'target'],\n",
       "        num_rows: 5000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N4-K2V2-V512_1M\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6fdb6f0c89c464ca789d12297126769",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1000000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c84acd8d73b5469696baa259f382c5db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/5000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "if n_segments == 1 and min_segment_len == 0 and max_segment_len == 0:\n",
    "    # no noise dataset\n",
    "    dataset_name = f'N{num_kv_pairs}-K{k_length}V{v_length}-V{kv_vocab_size}_1M'\n",
    "else:\n",
    "    dataset_name = f'N{num_kv_pairs}-K{k_length}V{v_length}-S{n_segments}({min_segment_len}-{max_segment_len})_1M'\n",
    "print(dataset_name)\n",
    "dataset.save_to_disk(f'./data/{dataset_name}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# copy task dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'kv_pairs': [],\n",
       " 'segment_ids_to_kv_ids': {0: []},\n",
       " 'context': 'v42v|',\n",
       " 'query': '?!:',\n",
       " 'input_sequence': 'v42v|?!:',\n",
       " 'target': ''}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from kv_dataset_utils import generate_sequence\n",
    "\n",
    "num_kv_pairs = 0\n",
    "k_length = 4\n",
    "v_length = 4\n",
    "n_segments = 1\n",
    "min_segment_len = 4\n",
    "max_segment_len = 4\n",
    "\n",
    "sample = generate_sequence(num_kv_pairs, k_length, v_length, n_segments,\n",
    "                           min_segment_len, max_segment_len)\n",
    "sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000000/1000000 [00:05<00:00, 168668.54it/s]\n",
      "100%|██████████| 5000/5000 [00:00<00:00, 170832.10it/s]\n"
     ]
    }
   ],
   "source": [
    "from datasets import Dataset, DatasetDict\n",
    "from tqdm import tqdm\n",
    "\n",
    "num_samples = 1_000_000\n",
    "\n",
    "data = []\n",
    "for _ in tqdm(range(num_samples), total=num_samples):\n",
    "    sample = generate_sequence(num_kv_pairs, k_length, v_length, n_segments, min_segment_len, max_segment_len)\n",
    "    data += [{\n",
    "        'context': sample['context'],\n",
    "        'query': sample['query'],\n",
    "        'target': sample['context'],\n",
    "    }]\n",
    "data = Dataset.from_list(data)\n",
    "\n",
    "num_samples = 5_000\n",
    "\n",
    "valid_data = []\n",
    "for _ in tqdm(range(num_samples), total=num_samples):\n",
    "    sample = generate_sequence(num_kv_pairs, k_length, v_length, n_segments, min_segment_len, max_segment_len)\n",
    "    valid_data += [{\n",
    "        'context': sample['context'],\n",
    "        'query': sample['query'],\n",
    "        'target': sample['context'],\n",
    "    }]\n",
    "valid_data = Dataset.from_list(valid_data)\n",
    "\n",
    "dataset = DatasetDict({'train': data, 'valid': valid_data})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb4782e0a5e348aea150178f74fdbdd6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1000000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e43536bce00946c8b7534ff8d472caf5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/5000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset_name = f'N{num_kv_pairs}-S{n_segments}({min_segment_len}-{max_segment_len})_1M'\n",
    "dataset.save_to_disk(f'./data/{dataset_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.load_from_disk(f'./data/{dataset_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
