{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51716b0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "\n",
    "# kv_dataset_utils is imported after changing into ./test_time_gd/.\n",
    "# Only chdir when we are not already there: an unconditional relative chdir\n",
    "# raises FileNotFoundError when this cell is re-run, because the relative\n",
    "# path would then be resolved from inside test_time_gd/ itself.\n",
    "PROJECT_DIR = 'test_time_gd'\n",
    "if os.path.basename(os.path.normpath(os.getcwd())) != PROJECT_DIR:\n",
    "    os.chdir(PROJECT_DIR)\n",
    "\n",
    "from kv_dataset_utils import generate_sequence, get_extra_chars, BASE_KV_ALPHABET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88959dc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generation settings. Later cells reuse these names (length estimate,\n",
    "# dataset generation and the Hub config name).\n",
    "num_kv_pairs = 256  # number of key-value pairs per sample\n",
    "k_length, v_length = 2, 2  # key / value lengths (presumably characters -- confirm in kv_dataset_utils)\n",
    "n_segments = 1\n",
    "min_segment_len, max_segment_len = 0, 0  # bounds on segment length\n",
    "\n",
    "# Passed to get_extra_chars(); its result is concatenated onto the base\n",
    "# alphabet. The exact meaning of this size is defined in kv_dataset_utils.\n",
    "kv_vocab_size = 62\n",
    "kv_alphabet = BASE_KV_ALPHABET + get_extra_chars(kv_vocab_size)\n",
    "\n",
    "# Draw one example and display it as the cell output.\n",
    "sample = generate_sequence(\n",
    "    num_kv_pairs, k_length, v_length, n_segments,\n",
    "    min_segment_len, max_segment_len, kv_alphabet,\n",
    ")\n",
    "sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47a71323",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough per-sample length estimate.\n",
    "# The +3 is a fixed per-pair overhead on top of the key and value lengths\n",
    "# (presumably separator characters emitted by generate_sequence -- confirm\n",
    "# in kv_dataset_utils).\n",
    "kv_pairs_len = (k_length + v_length + 3) * num_kv_pairs\n",
    "# Approximate segment length, taking the midpoint of the length bounds.\n",
    "segments_len = n_segments * (min_segment_len + max_segment_len) / 2\n",
    "# The total must include the KV pairs: with 0/0 segment bounds the segment\n",
    "# term alone is 0.0 even though every sample holds num_kv_pairs pairs.\n",
    "print('KV pairs length:', kv_pairs_len)\n",
    "print('~Segments length:', segments_len)\n",
    "print('~Total length:', kv_pairs_len + segments_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a85c80f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw a fresh example with the same settings; re-run to inspect more.\n",
    "sample = generate_sequence(\n",
    "    num_kv_pairs, k_length, v_length, n_segments,\n",
    "    min_segment_len, max_segment_len, kv_alphabet,\n",
    ")\n",
    "sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "756e14c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset, DatasetDict\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Split sizes; the test split is twice the size of the validation split.\n",
    "NUM_TRAIN_SAMPLES = 1_000_000\n",
    "NUM_VALID_SAMPLES = 5_000\n",
    "NUM_TEST_SAMPLES = 2 * NUM_VALID_SAMPLES\n",
    "\n",
    "# Only these fields of each generate_sequence() result are stored.\n",
    "KEPT_FIELDS = ('context', 'query', 'target')\n",
    "\n",
    "# NOTE(review): no random seed is set, so re-running this cell produces\n",
    "# different splits. Consider seeding the RNG used by generate_sequence\n",
    "# (not visible here) before publishing a dataset meant to be reproducible.\n",
    "\n",
    "\n",
    "def build_kv_split(n_samples, desc=None):\n",
    "    \"\"\"Generate `n_samples` examples and return them as a `Dataset`.\n",
    "\n",
    "    Uses the generation settings from the configuration cell (num_kv_pairs,\n",
    "    k_length, v_length, n_segments, min/max_segment_len, kv_alphabet).\n",
    "    Each row keeps only the KEPT_FIELDS of a generated sample.\n",
    "\n",
    "    Args:\n",
    "        n_samples: number of rows to generate.\n",
    "        desc: optional progress-bar label.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for _ in tqdm(range(n_samples), desc=desc):\n",
    "        sample = generate_sequence(num_kv_pairs, k_length, v_length, n_segments,\n",
    "                                   min_segment_len, max_segment_len, kv_alphabet)\n",
    "        rows.append({field: sample[field] for field in KEPT_FIELDS})\n",
    "    return Dataset.from_list(rows)\n",
    "\n",
    "\n",
    "# Splits are generated in the order train -> valid -> test.\n",
    "train_data = build_kv_split(NUM_TRAIN_SAMPLES, desc='train')\n",
    "valid_data = build_kv_split(NUM_VALID_SAMPLES, desc='valid')\n",
    "test_data = build_kv_split(NUM_TEST_SAMPLES, desc='test')\n",
    "\n",
    "dataset = DatasetDict({'train': train_data, 'valid': valid_data, 'test': test_data})\n",
    "\n",
    "assert {name: len(split) for name, split in dataset.items()} == {\n",
    "    'train': NUM_TRAIN_SAMPLES, 'valid': NUM_VALID_SAMPLES, 'test': NUM_TEST_SAMPLES,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f435eb96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Publish all splits to the Hub under a config name that encodes the\n",
    "# generation settings (N256-K2V2-V62 for the values set in this notebook).\n",
    "hub_repo_id = 'USR/kv_retrieval'  # 'USR' appears to be a placeholder namespace -- set before pushing\n",
    "hub_config_name = f'N{num_kv_pairs}-K{k_length}V{v_length}-V{kv_vocab_size}'\n",
    "dataset.push_to_hub(hub_repo_id, config_name=hub_config_name)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
