{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import random\n",
    "from statistics import mean, median\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory from which the raw AOL split paths and the `final-dataset` output path are built.\n",
    "# NOTE(review): with the empty default, f'{dataset_dir}/...' resolves from the filesystem root - set this before running.\n",
    "dataset_dir = ''"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Alphanumeric filtering of train and test splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the raw training queries, one per line, with surrounding whitespace stripped.\n",
    "# The context manager closes the file handle as soon as the lines are read.\n",
    "with open(f'{dataset_dir}/subword-qac/data/aol/full/train.query.txt') as train_file:\n",
    "    train_split = [x.strip() for x in train_file]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tally how often each raw training query appears.\n",
    "orig_query_to_freq = {}\n",
    "for query in train_split:\n",
    "    orig_query_to_freq[query] = orig_query_to_freq.get(query, 0) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def AlphanumericFilter(input):\n",
    "    \"\"\"Replace non-alphanumeric characters with spaces and collapse whitespace runs.\n",
    "\n",
    "    Args:\n",
    "        input: Query string to normalise.\n",
    "\n",
    "    Returns:\n",
    "        The string with every character outside [a-zA-Z0-9] turned into a space and\n",
    "        each run of whitespace squeezed to a single space. Leading/trailing spaces\n",
    "        are kept, e.g. 'ab__' becomes 'ab '.\n",
    "    \"\"\"\n",
    "    # Raw string literals: a bare \\s in a normal literal is an invalid escape\n",
    "    # sequence that newer Python versions warn about.\n",
    "    input = re.sub(r\"[^a-zA-Z0-9]\", \" \", input)\n",
    "    input = re.sub(r\"\\s+\", \" \", input)\n",
    "    return input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abc def 123 '"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: runs of spaces/underscores collapse to one space; the trailing space is kept, not stripped.\n",
    "AlphanumericFilter(\"abc              def  ______123___\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8862181/8862181 [00:34<00:00, 254808.73it/s]\n"
     ]
    }
   ],
   "source": [
    "# Merge raw-query counts under their alphanumeric-filtered form; queries that filter to blank are dropped.\n",
    "processed_query_to_freq = {}\n",
    "for query, freq in tqdm(orig_query_to_freq.items()):\n",
    "    query = AlphanumericFilter(query)\n",
    "    if not query.strip():\n",
    "        continue\n",
    "    processed_query_to_freq[query] = processed_query_to_freq.get(query, 0) + freq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the filtered training queries as tab-separated query/frequency lines.\n",
    "# The with-block flushes and closes the file even if a write raises.\n",
    "with open('./dataset/raw/train.query.alphanumeric_filtered.python.txt', 'w') as f:\n",
    "    for query, freq in processed_query_to_freq.items():\n",
    "        f.write(query+'\\t'+str(freq)+'\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the raw test queries, one per line, with surrounding whitespace stripped.\n",
    "# The context manager closes the file handle as soon as the lines are read.\n",
    "with open(f'{dataset_dir}/subword-qac/data/aol/full/test.query.txt') as test_file:\n",
    "    test_split = [x.strip() for x in test_file]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tally raw test queries.\n",
    "orig_query_to_freq = {}\n",
    "for query in test_split:\n",
    "    orig_query_to_freq[query] = orig_query_to_freq.get(query, 0) + 1\n",
    "\n",
    "# Merge the counts under the alphanumeric-filtered form; queries that filter to blank are dropped.\n",
    "processed_query_to_freq = {}\n",
    "for query, freq in tqdm(orig_query_to_freq.items()):\n",
    "    query = AlphanumericFilter(query)\n",
    "    if query.strip() != '':\n",
    "        processed_query_to_freq[query] = processed_query_to_freq.get(query, 0) + freq\n",
    "\n",
    "# Write tab-separated query/frequency lines; the with-block closes the file even on error.\n",
    "with open('./dataset/raw/test.query.alphanumeric_filtered.python.txt', 'w') as f:\n",
    "    for query, freq in processed_query_to_freq.items():\n",
    "        f.write(query+'\\t'+str(freq)+'\\n')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Processing alphanumeric filtered files to generate dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# How many of the most frequent suffixes are kept as the label shortlist.\n",
    "NUM_SUFFIXES = 10000000\n",
    "# Bounds handed to SampleRandomPrefixFromQuerySuffix (its call site describes them as 1 to 10 chars to add).\n",
    "max_chars_to_add = 10\n",
    "min_chars_to_add = 1\n",
    "# Suffixes longer than this many characters are discarded by GenerateSuffixesFromQuery.\n",
    "max_suffix_length = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_queries_path = './dataset/raw/train.query.alphanumeric_filtered.python.txt'\n",
    "# Read the filtered training file back; the context manager closes the handle after reading.\n",
    "with open(train_queries_path) as train_queries_file:\n",
    "    train_queries_lines = [x.strip() for x in train_queries_file]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse tab-separated query/frequency lines, keeping the largest frequency seen per query.\n",
    "train_query_to_freq = {}\n",
    "for i in tqdm(range(len(train_queries_lines))):\n",
    "    line = train_queries_lines[i].split('\\t')\n",
    "    try:\n",
    "        query, freq = line[0], int(line[1])\n",
    "    except (IndexError, ValueError):\n",
    "        # Malformed line: report it and skip. Falling through would reuse the previous\n",
    "        # iteration's query/freq (or raise NameError on the very first line).\n",
    "        print(line)\n",
    "        continue\n",
    "    if query not in train_query_to_freq:\n",
    "        train_query_to_freq[query] = freq\n",
    "    else:\n",
    "        train_query_to_freq[query] = max(freq, train_query_to_freq[query])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GenerateSuffixesFromQuery(query, max_length=None):\n",
    "    \"\"\"Return the word-level suffixes of `query`, longest first.\n",
    "\n",
    "    Args:\n",
    "        query: Query string; it is split on whitespace into words.\n",
    "        max_length: Longest suffix (in characters) to keep. When None, the\n",
    "            notebook-level `max_suffix_length` is used.\n",
    "\n",
    "    Returns:\n",
    "        List of suffixes, each being the words from some position to the end\n",
    "        joined by single spaces, restricted to those within the length cap.\n",
    "    \"\"\"\n",
    "    if max_length is None:\n",
    "        # Resolved at call time so later changes to the global are honoured.\n",
    "        max_length = max_suffix_length\n",
    "    words = query.split()\n",
    "    candidates = (' '.join(words[i:]) for i in range(len(words)))\n",
    "    return [suffix for suffix in candidates if len(suffix) <= max_length]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sum, for every word-level suffix, the frequencies of all training queries that end with it.\n",
    "all_suffixes_to_freq = {}\n",
    "for query, freq in tqdm(train_query_to_freq.items()):\n",
    "    suffixes = GenerateSuffixesFromQuery(query)\n",
    "    for suffix in suffixes:\n",
    "        all_suffixes_to_freq[suffix] = all_suffixes_to_freq.get(suffix, 0) + freq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order (suffix, freq) pairs from most to least frequent; a reversed sort stays stable, so ties keep dict order.\n",
    "all_suffixes_to_freq_sorted = sorted(all_suffixes_to_freq.items(), key=lambda item: item[1], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the NUM_SUFFIXES most frequent suffixes. Slicing, unlike indexing 0..NUM_SUFFIXES-1,\n",
    "# does not raise IndexError when fewer distinct suffixes exist.\n",
    "final_suffixes = [suffix for suffix, _ in all_suffixes_to_freq_sorted[:NUM_SUFFIXES]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set form of the shortlist, used for the `in` membership tests in later cells.\n",
    "final_suffixes_set = set(final_suffixes)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SAMPLE ONE PREFIX PER QUERY, UNTIL YOU GET AT LEAST ONE SUFFIX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count queries for which split(' ') does not yield exactly one more piece than there are spaces.\n",
    "cnt = 0\n",
    "for query in train_query_to_freq:\n",
    "    if (query.count(' ')+1) != len(query.split(' ')):\n",
    "        cnt = cnt + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worked example of the start-index formula that GenerateShortlistSuffixes applies.\n",
    "query = 'abc def ghi '\n",
    "suffix = 'def ghi '\n",
    "max(0, len(query)-len(suffix)-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GenerateShortlistSuffixes(query, suffix_set):\n",
    "    \"\"\"Find the space-delimited suffixes of `query` that belong to `suffix_set`.\n",
    "\n",
    "    Args:\n",
    "        query: Query string, split on single spaces.\n",
    "        suffix_set: Container supporting `in` that holds the allowed suffixes.\n",
    "\n",
    "    Returns:\n",
    "        (matches, min_index): `matches` is a list of [suffix, start_index] with\n",
    "        start_index = max(0, len(query) - len(suffix) - 2); `min_index` is the\n",
    "        smallest start_index, or len(query) when nothing matched.\n",
    "    \"\"\"\n",
    "    tokens = query.split(' ')\n",
    "    query_len = len(query)\n",
    "    matches = []\n",
    "    for pos in range(len(tokens)):\n",
    "        candidate = ' '.join(tokens[pos:])\n",
    "        if candidate not in suffix_set:\n",
    "            continue\n",
    "        matches.append([candidate, max(0, query_len - len(candidate) - 2)])\n",
    "    # len(query) doubles as the nothing-matched sentinel.\n",
    "    smallest_start = min([query_len] + [start for _, start in matches])\n",
    "    return matches, smallest_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample one prefix per training query and record every shortlisted suffix whose start index it reaches.\n",
    "train_prefix_to_suffixes = {}\n",
    "for query, freq in tqdm(train_query_to_freq.items()):\n",
    "    suffixes, min_index = GenerateShortlistSuffixes(query, final_suffixes_set)\n",
    "    if min_index == len(query):\n",
    "        # Sentinel value: no suffix of this query is shortlisted.\n",
    "        continue\n",
    "    prefix_end_index = random.randint(min_index, max(len(query)-2, 0))\n",
    "    prefix = query[:prefix_end_index+1]\n",
    "    prefix_labels = train_prefix_to_suffixes.setdefault(prefix, {})\n",
    "    for suffix, index in suffixes:\n",
    "        if index <= prefix_end_index:\n",
    "            prefix_labels[suffix] = freq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct sampled training prefixes.\n",
    "len(train_prefix_to_suffixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every suffix attached to at least one training prefix.\n",
    "unique_suffixes = set([])\n",
    "for suffixes in train_prefix_to_suffixes.values():\n",
    "    unique_suffixes.update(suffixes.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct suffixes attached to training prefixes.\n",
    "len(unique_suffixes)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generate all ground truth pairs to have full gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every shortlisted suffix of a query, attach it to each already-sampled prefix of that\n",
    "# query which ends at or after the suffix's start index. No new prefixes are created.\n",
    "for query, freq in tqdm(train_query_to_freq.items()):\n",
    "    suffixes, min_index = GenerateShortlistSuffixes(query, final_suffixes_set)\n",
    "    if min_index == len(query):\n",
    "        continue\n",
    "\n",
    "    for suffix, index in suffixes:\n",
    "        for j in range(index, len(query)-1):\n",
    "            # Slice with j, not index: slicing with index rebuilt the same shortest prefix\n",
    "            # on every iteration, so longer prefixes never received this suffix.\n",
    "            prefix = query[:j+1]\n",
    "            if prefix in train_prefix_to_suffixes:\n",
    "                # setdefault keeps a frequency that is already recorded.\n",
    "                train_prefix_to_suffixes[prefix].setdefault(suffix, freq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefix count after ground-truth completion.\n",
    "len(train_prefix_to_suffixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute the set of suffixes attached to at least one training prefix.\n",
    "unique_suffixes = set([])\n",
    "for suffixes in train_prefix_to_suffixes.values():\n",
    "    unique_suffixes.update(suffixes.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct training suffixes after ground-truth completion.\n",
    "len(unique_suffixes)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate test prefix, suffix pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_queries_path = './dataset/raw/test.query.alphanumeric_filtered.python.txt'\n",
    "# Read the filtered test file back; the context manager closes the handle after reading.\n",
    "with open(test_queries_path) as test_queries_file:\n",
    "    test_queries_lines = [x.strip() for x in test_queries_file]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse tab-separated query/frequency lines, keeping the largest frequency seen per query.\n",
    "test_query_to_freq = {}\n",
    "for i in tqdm(range(len(test_queries_lines))):\n",
    "    line = test_queries_lines[i].split('\\t')\n",
    "    query, freq = line[0], int(line[1])\n",
    "    test_query_to_freq[query] = max(freq, test_query_to_freq.get(query, freq))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): identical re-definition of GenerateShortlistSuffixes from earlier in this notebook.\n",
    "def GenerateShortlistSuffixes(query, suffix_set):\n",
    "    \"\"\"Find the space-delimited suffixes of `query` that belong to `suffix_set`.\n",
    "\n",
    "    Returns:\n",
    "        (req_suffixes, min_index): `req_suffixes` is a list of [suffix, start_index]\n",
    "        with start_index = max(0, len(query) - len(suffix) - 2); `min_index` is the\n",
    "        smallest start_index, or len(query) when nothing matched.\n",
    "    \"\"\"\n",
    "    req_suffixes = []\n",
    "    words = query.split(' ')\n",
    "    # len(query) doubles as the nothing-matched sentinel that callers compare against.\n",
    "    min_index = len(query)\n",
    "    for i in range(len(words)):\n",
    "        suffix = ' '.join(words[i:])\n",
    "        if suffix in suffix_set:\n",
    "            start_index = max(0, len(query)-len(suffix)-2)\n",
    "            req_suffixes.append([suffix, start_index])\n",
    "            min_index = min(min_index, start_index)\n",
    "    return req_suffixes, min_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample one prefix per test query; its labels are the train-seen suffixes whose start index it reaches.\n",
    "test_prefix_to_suffixes = {}\n",
    "for query, freq in tqdm(test_query_to_freq.items()):\n",
    "    suffixes, min_index = GenerateShortlistSuffixes(query, unique_suffixes)\n",
    "    if min_index == len(query):\n",
    "        continue\n",
    "    prefix_end_index = random.randint(min_index, max(len(query)-2, 0))\n",
    "    prefix = query[:prefix_end_index+1]\n",
    "    gt_suffixes = {suffix: freq for suffix, index in suffixes if index <= prefix_end_index}\n",
    "    # NOTE(review): plain assignment replaces labels stored earlier for the same prefix.\n",
    "    test_prefix_to_suffixes[prefix] = gt_suffixes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct suffixes attached to the sampled test prefixes.\n",
    "test_unique_suffixes = set([])\n",
    "for suffixes in test_prefix_to_suffixes.values():\n",
    "    test_unique_suffixes.update(suffixes.keys())\n",
    "print(len(test_unique_suffixes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every train-seen suffix of a test query, attach it to each already-sampled test prefix of\n",
    "# that query which ends at or after the suffix's start index. No new prefixes are created.\n",
    "for query, freq in tqdm(test_query_to_freq.items()):\n",
    "    suffixes, min_index = GenerateShortlistSuffixes(query, unique_suffixes)\n",
    "    if min_index == len(query):\n",
    "        continue\n",
    "\n",
    "    for suffix, index in suffixes:\n",
    "        for j in range(index, len(query)-1):\n",
    "            # Slice with j, not index: slicing with index rebuilt the same shortest prefix\n",
    "            # on every iteration, so longer prefixes never received this suffix.\n",
    "            prefix = query[:j+1]\n",
    "            if prefix in test_prefix_to_suffixes:\n",
    "                # setdefault keeps a frequency that is already recorded.\n",
    "                test_prefix_to_suffixes[prefix].setdefault(suffix, freq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct test suffixes after ground-truth completion.\n",
    "test_unique_suffixes = set([])\n",
    "for suffixes in test_prefix_to_suffixes.values():\n",
    "    test_unique_suffixes.update(suffixes.keys())\n",
    "print(len(test_unique_suffixes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct sampled test prefixes.\n",
    "len(test_prefix_to_suffixes)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Removing seen test points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only test prefixes that never occur as a training prefix.\n",
    "test_prefix_to_suffixes_v2 = {\n",
    "    prefix: suffixes\n",
    "    for prefix, suffixes in test_prefix_to_suffixes.items()\n",
    "    if prefix not in train_prefix_to_suffixes\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From here on test_prefix_to_suffixes names the filtered dict (prefixes absent from train_prefix_to_suffixes).\n",
    "test_prefix_to_suffixes = test_prefix_to_suffixes_v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test prefix count after removing prefixes seen in training.\n",
    "len(test_prefix_to_suffixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct suffixes among the remaining (unseen-prefix) test points.\n",
    "test_unique_suffixes = set([])\n",
    "for suffixes in test_prefix_to_suffixes.values():\n",
    "    test_unique_suffixes.update(suffixes.keys())\n",
    "print(len(test_unique_suffixes))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Removing labels with no test point from train as well"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of suffixes present in both the training and the test label sets.\n",
    "len(unique_suffixes & test_unique_suffixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict each training prefix to suffixes that also label a test point; drop prefixes left with none.\n",
    "updated_train_prefix_to_suffixes = {}\n",
    "for prefix, suffixes in tqdm(train_prefix_to_suffixes.items()):\n",
    "    updated_lbl_dict = {\n",
    "        suffix: freq\n",
    "        for suffix, freq in suffixes.items()\n",
    "        if suffix in test_unique_suffixes\n",
    "    }\n",
    "    if updated_lbl_dict:\n",
    "        updated_train_prefix_to_suffixes[prefix] = updated_lbl_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of training prefixes that still have at least one label.\n",
    "len(updated_train_prefix_to_suffixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3922479/3922479 [00:03<00:00, 981918.29it/s] \n"
     ]
    }
   ],
   "source": [
    "# Distinct suffixes that remain as labels in the filtered training set.\n",
    "updated_train_suffixes = set([])\n",
    "for prefix, suffixes in tqdm(updated_train_prefix_to_suffixes.items()):\n",
    "    updated_train_suffixes.update(suffixes.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct labels in the filtered training set.\n",
    "len(updated_train_suffixes)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Stats for final dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# labels per data point\n",
    "\n",
    "lbls_per_point = [len(s) for s in updated_train_prefix_to_suffixes.values()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (min, max, mean, median) number of labels per training prefix.\n",
    "min(lbls_per_point), max(lbls_per_point), mean(lbls_per_point), median(lbls_per_point)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data points per label\n",
    "\n",
    "points_per_suffix = {}\n",
    "for ss in updated_train_prefix_to_suffixes.values():\n",
    "    for s in ss:\n",
    "        points_per_suffix[s] = points_per_suffix.get(s, 0) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): rebinds the name from a dict to a list of its counts, so re-running this cell on its own fails (a list has no .values()).\n",
    "points_per_suffix = list(points_per_suffix.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (min, max, mean, median) number of training prefixes per label.\n",
    "min(points_per_suffix), max(points_per_suffix), mean(points_per_suffix), median(points_per_suffix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# labels per data point\n",
    "\n",
    "tst_lbls_per_point = [len(s) for s in test_prefix_to_suffixes.values()]\n",
    "print(min(tst_lbls_per_point), max(tst_lbls_per_point), mean(tst_lbls_per_point), median(tst_lbls_per_point))\n",
    "\n",
    "# data points per label\n",
    "\n",
    "tst_points_per_suffix = {}\n",
    "for ss in test_prefix_to_suffixes.values():\n",
    "    for s in ss:\n",
    "        tst_points_per_suffix[s] = tst_points_per_suffix.get(s, 0) + 1\n",
    "tst_points_per_suffix = list(tst_points_per_suffix.values())\n",
    "\n",
    "min(tst_points_per_suffix), max(tst_points_per_suffix), mean(tst_points_per_suffix), median(tst_points_per_suffix)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert to XC format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign consecutive integer ids to training prefixes and suffixes in first-seen order,\n",
    "# keeping both directions of each mapping.\n",
    "trn_prefix_to_prefix_id, trn_prefix_id_to_prefix = {}, {}\n",
    "trn_suffix_id_to_suffix, trn_suffix_to_suffix_id = {}, {}\n",
    "\n",
    "prefix_id = 0\n",
    "suffix_id = 0\n",
    "for prefix, suffixes in updated_train_prefix_to_suffixes.items():\n",
    "    if prefix not in trn_prefix_to_prefix_id:\n",
    "        trn_prefix_to_prefix_id[prefix] = prefix_id\n",
    "        trn_prefix_id_to_prefix[prefix_id] = prefix\n",
    "        prefix_id += 1\n",
    "    for suffix in suffixes:\n",
    "        if suffix in trn_suffix_to_suffix_id:\n",
    "            continue\n",
    "        trn_suffix_to_suffix_id[suffix] = suffix_id\n",
    "        trn_suffix_id_to_suffix[suffix_id] = suffix\n",
    "        suffix_id += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (number of training prefixes, number of labels)\n",
    "len(trn_prefix_id_to_prefix), len(trn_suffix_id_to_suffix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "## writing trn_X and Y\n",
    "data_dir = f'{dataset_dir}/final-dataset'\n",
    "# Line i of each file holds the text whose id is i. The with-blocks close the\n",
    "# files even if a write raises.\n",
    "with open(f'{data_dir}/raw/trn_X.txt', 'w') as f:\n",
    "    for i in range(len(trn_prefix_id_to_prefix)):\n",
    "        f.write(trn_prefix_id_to_prefix[i]+'\\n')\n",
    "\n",
    "with open(f'{data_dir}/raw/Y.txt', 'w') as f:\n",
    "    for i in range(len(trn_suffix_id_to_suffix)):\n",
    "        f.write(trn_suffix_id_to_suffix[i]+'\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size the prefix x suffix matrix from the id maps. Hard-coded counts only match one\n",
    "# particular run, because the sampled prefixes change from run to run.\n",
    "trn_X_Y = sp.dok_matrix((len(trn_prefix_to_prefix_id), len(trn_suffix_to_suffix_id)), dtype=np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store each (prefix, suffix) frequency at [prefix_id, suffix_id].\n",
    "for prefix, suffixes in tqdm(updated_train_prefix_to_suffixes.items()):\n",
    "    prefix_id = trn_prefix_to_prefix_id[prefix]\n",
    "    # Iterate the label dict already in hand rather than looking it up again by prefix.\n",
    "    for suffix, freq in suffixes.items():\n",
    "        suffix_id = trn_suffix_to_suffix_id[suffix]\n",
    "        trn_X_Y[prefix_id, suffix_id] = freq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the filled DOK matrix to CSR, rebinding the same name.\n",
    "trn_X_Y = trn_X_Y.tocsr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (rows, columns) of the training matrix.\n",
    "trn_X_Y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count the positive entries in each row of trn_X_Y, then show the largest row count.\n",
    "trn_lpp = np.array((trn_X_Y > 0).astype(int).sum(axis=1)).squeeze()\n",
    "round(np.max(trn_lpp),2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the training matrix in .npz format next to trn_X.txt and Y.txt.\n",
    "sp.save_npz(f'{data_dir}/raw/trn_X_Y.npz', trn_X_Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign consecutive integer ids to test prefixes in first-seen order (both directions kept).\n",
    "tst_prefix_to_prefix_id, tst_prefix_id_to_prefix = {}, {}\n",
    "prefix_id = 0\n",
    "for prefix, suffixes in test_prefix_to_suffixes.items():\n",
    "    if prefix in tst_prefix_to_prefix_id:\n",
    "        continue\n",
    "    tst_prefix_to_prefix_id[prefix] = prefix_id\n",
    "    tst_prefix_id_to_prefix[prefix_id] = prefix\n",
    "    prefix_id += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of test prefixes that received an id.\n",
    "len(tst_prefix_id_to_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Line i holds the test prefix whose id is i. The with-block closes the file even if a write raises.\n",
    "with open(f'{data_dir}/raw/tst_X.txt', 'w') as f:\n",
    "    for i in range(len(tst_prefix_id_to_prefix)):\n",
    "        f.write(tst_prefix_id_to_prefix[i]+'\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows follow tst_prefix_to_prefix_id; columns reuse the training suffix ids. The shape is\n",
    "# taken from the id maps because hard-coded counts only match one particular run.\n",
    "tst_X_Y = sp.dok_matrix((len(tst_prefix_to_prefix_id), len(trn_suffix_to_suffix_id)), dtype=np.int32)\n",
    "for prefix, suffixes in tqdm(test_prefix_to_suffixes.items()):\n",
    "    prefix_id = tst_prefix_to_prefix_id[prefix]\n",
    "    for suffix, freq in suffixes.items():\n",
    "        suffix_id = trn_suffix_to_suffix_id[suffix]\n",
    "        tst_X_Y[prefix_id, suffix_id] = freq\n",
    "tst_X_Y = tst_X_Y.tocsr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the test matrix in .npz format next to tst_X.txt.\n",
    "sp.save_npz(f'{data_dir}/raw/tst_X_Y.npz', tst_X_Y)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Coverage Calculation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect frequencies of all training queries, and separately of those with at least one shortlisted suffix.\n",
    "queries_covered_freq_list = []\n",
    "total_freq_list = []\n",
    "for query, freq in tqdm(train_query_to_freq.items()):\n",
    "    total_freq_list.append(freq)\n",
    "    suffixes, min_index = GenerateShortlistSuffixes(query, final_suffixes_set)\n",
    "    if suffixes:\n",
    "        queries_covered_freq_list.append(freq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frequency-weighted share of training queries that have a shortlisted suffix.\n",
    "sum(queries_covered_freq_list)/sum(total_freq_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unweighted share of distinct training queries that have a shortlisted suffix.\n",
    "len(queries_covered_freq_list)/len(total_freq_list)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verify older suffix list coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the older label list, one suffix per line; the context manager closes the handle after reading.\n",
    "with open(f'{dataset_dir}/SuffixesDatasets/CharsToAddPrefixes/V2Normalization/1M/AllPrefixesSample/raw/Y.txt') as y_file:\n",
    "    Y = [x.strip() for x in y_file]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set form of the older label list, passed to GenerateShortlistSuffixes below.\n",
    "Y_set = set(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8700222/8700222 [00:24<00:00, 355001.33it/s]\n"
     ]
    }
   ],
   "source": [
    "# Same coverage bookkeeping, measured against the older label list.\n",
    "queries_covered_freq_list = []\n",
    "total_freq_list = []\n",
    "for query, freq in tqdm(train_query_to_freq.items()):\n",
    "    total_freq_list.append(freq)\n",
    "    suffixes, min_index = GenerateShortlistSuffixes(query, Y_set)\n",
    "    # min_index stays at len(query) only when no suffix matched.\n",
    "    if len(query) != min_index:\n",
    "        queries_covered_freq_list.append(freq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frequency-weighted share of training queries covered by the older label list.\n",
    "sum(queries_covered_freq_list)/sum(total_freq_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unweighted share of distinct training queries covered by the older label list.\n",
    "len(queries_covered_freq_list)/len(total_freq_list)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SAMPLE ONE PREFIX PER QUERY, SUFFIX AND THEN COMPLETE GT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "metadata": {},
   "outputs": [],
   "source": [
    "def SampleRandomPrefixFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add):\n",
    "    \"\"\"Pick one random prefix of `query` and return it together with `suffix`.\n",
    "\n",
    "    The cut position is drawn with random.randint (both ends inclusive) from\n",
    "    max(len(query) - min(len(suffix), max_chars_to_add) - 1, 0) up to\n",
    "    max(len(query) - min_chars_to_add, 0); the prefix is query[:cut].\n",
    "\n",
    "    Returns:\n",
    "        (prefix, suffix) with `suffix` passed through unchanged.\n",
    "    \"\"\"\n",
    "    query_len = len(query)\n",
    "    capped_chars = min(len(suffix), max_chars_to_add)\n",
    "    lowest_cut = max(query_len - capped_chars - 1, 0)\n",
    "    highest_cut = max(query_len - min_chars_to_add, 0)\n",
    "    cut = random.randint(lowest_cut, highest_cut)\n",
    "    return query[:cut], suffix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every (query, shortlisted suffix) pair, sample one prefix and record the suffix under it.\n",
    "prefix_to_suffixes = {}\n",
    "query_suffix_pairs = []\n",
    "for query, freq in tqdm(train_query_to_freq.items()):\n",
    "    suffixes = GenerateSuffixesFromQuery(query)\n",
    "    for suffix in suffixes:\n",
    "        if suffix not in final_suffixes_set:\n",
    "            continue\n",
    "        prefix, suffix = SampleRandomPrefixFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add) #sample 1 prefix between 1 to 10 chars to add\n",
    "        prefix_to_suffixes.setdefault(prefix, {})[suffix] = train_query_to_freq[query]\n",
    "        query_suffix_pairs.append([query, suffix])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at up to ten random (prefix, suffixes) entries. random.sample needs a sequence - a dict\n",
    "# view raises TypeError on Python 3.11+ - and k is capped so a small dict does not raise ValueError.\n",
    "prefix_suffix_items = list(prefix_to_suffixes.items())\n",
    "random.sample(prefix_suffix_items, k = min(10, len(prefix_suffix_items)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GenerateAllPrefixesFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add):\n",
    "    \"\"\"List every prefix of `query` whose cut position lies in the sampling range.\n",
    "\n",
    "    Cut positions run from max(len(query) - min(len(suffix), max_chars_to_add) - 1, 0)\n",
    "    through max(len(query) - min_chars_to_add, 0), both inclusive; each prefix is\n",
    "    query[:cut].\n",
    "\n",
    "    Returns:\n",
    "        The prefixes in increasing length (empty list when the range is empty).\n",
    "    \"\"\"\n",
    "    query_len = len(query)\n",
    "    capped_chars = min(len(suffix), max_chars_to_add)\n",
    "    first_cut = max(query_len - capped_chars - 1, 0)\n",
    "    last_cut = max(query_len - min_chars_to_add, 0)\n",
    "    return [query[:cut] for cut in range(first_cut, last_cut + 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach each suffix to every already-sampled prefix of its query that it applies to;\n",
    "# existing entries are left untouched and no new prefixes are created.\n",
    "for i in tqdm(range(len(query_suffix_pairs))):\n",
    "    query, suffix = query_suffix_pairs[i][0], query_suffix_pairs[i][1]\n",
    "    prefixes = GenerateAllPrefixesFromQuerySuffix(query, suffix, 100, min_chars_to_add)\n",
    "    for prefix in prefixes:\n",
    "        known_suffixes = prefix_to_suffixes.get(prefix)\n",
    "        if known_suffixes is not None and suffix not in known_suffixes:\n",
    "            known_suffixes[suffix] = train_query_to_freq[query]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct prefixes collected so far.\n",
    "len(prefix_to_suffixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: number of entries in final_suffixes_set.\n",
    "len(final_suffixes_set)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reduce the (query, suffix) pairs to start with: keep at most MAX_QUERIES_PER_SUFFIX queries per suffix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build suffix -> {query: frequency} for every suffix produced by\n",
    "# GenerateSuffixesFromQuery that is also a member of final_suffixes_set.\n",
    "suffix_to_queries = {}\n",
    "for query, freq in tqdm(train_query_to_freq.items()):\n",
    "    for suffix in GenerateSuffixesFromQuery(query):\n",
    "        if suffix not in final_suffixes_set:\n",
    "            continue\n",
    "        suffix_to_queries.setdefault(suffix, {})[query] = freq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary statistics: total number of (query, suffix) pairs and the largest\n",
    "# number of queries attached to a single suffix (both 0 when there are no suffixes).\n",
    "queries_per_suffix = [len(queries) for queries in suffix_to_queries.values()]\n",
    "total_query_suffix_pairs = sum(queries_per_suffix)\n",
    "max_queries_per_suffix = max(queries_per_suffix, default=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display (total pairs, max queries per suffix).\n",
    "(total_query_suffix_pairs, max_queries_per_suffix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the MAX_QUERIES_PER_SUFFIX highest-frequency queries for each suffix.\n",
    "MAX_QUERIES_PER_SUFFIX = 100\n",
    "query_suffix_pairs = []\n",
    "for suffix, queries in tqdm(suffix_to_queries.items()):\n",
    "    # Negated frequency gives descending order; ties keep their insertion order.\n",
    "    by_freq_desc = sorted(queries.items(), key=lambda item: -item[1])\n",
    "    query_suffix_pairs.extend([query, suffix] for query, _freq in by_freq_desc[:MAX_QUERIES_PER_SUFFIX])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display (number of retained pairs, number of suffixes).\n",
    "(len(query_suffix_pairs), len(suffix_to_queries))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "metadata": {},
   "outputs": [],
   "source": [
    "def SampleRandomPrefixFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add):\n",
    "    max_char = max_chars_to_add\n",
    "    min_char = min_chars_to_add\n",
    "    max_char = min(len(suffix), max_char)\n",
    "    start_index = max(len(query) - max_char - 1, 0)\n",
    "    end_index = max(len(query) - min_char, 0)\n",
    "    chosen_index = random.randint(start_index, end_index)\n",
    "    prefix = query[:chosen_index]\n",
    "    return prefix, suffix\n",
    "\n",
    "# Rebuild prefix -> {suffix: query frequency} from the reduced pair list,\n",
    "# sampling one random prefix per (query, suffix) pair.\n",
    "# NOTE(review): no random seed is set in this cell, so results vary between runs unless seeded elsewhere.\n",
    "prefix_to_suffixes = {}\n",
    "for query, suffix in query_suffix_pairs:\n",
    "    prefix, suffix = SampleRandomPrefixFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add)  # sample 1 prefix between 1 to 10 chars to add\n",
    "    prefix_to_suffixes.setdefault(prefix, {})[suffix] = train_query_to_freq[query]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct sampled prefixes.\n",
    "len(prefix_to_suffixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GenerateAllPrefixesFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add):\n",
    "    max_char = max_chars_to_add\n",
    "    min_char = min_chars_to_add\n",
    "    max_char = min(len(suffix), max_char)\n",
    "    start_index = max(len(query) - max_char - 1, 0)\n",
    "    end_index = max(len(query) - min_char, 0)\n",
    "    prefixes = []\n",
    "    for i in range(start_index, end_index+1):\n",
    "        prefixes.append(query[:i])\n",
    "    return prefixes\n",
    "\n",
    "# For each (query, suffix) pair, attach the suffix to every generated prefix that is\n",
    "# already a key of prefix_to_suffixes. Existing prefix -> suffix entries are never overwritten.\n",
    "for query, suffix in tqdm(query_suffix_pairs):\n",
    "    for prefix in GenerateAllPrefixesFromQuerySuffix(query, suffix, 100, min_chars_to_add):\n",
    "        known_suffixes = prefix_to_suffixes.get(prefix)\n",
    "        if known_suffixes is not None and suffix not in known_suffixes:\n",
    "            known_suffixes[suffix] = train_query_to_freq[query]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct prefixes after enrichment.\n",
    "len(prefix_to_suffixes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "my_python",
   "language": "python",
   "name": "my_python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
