{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "56af47e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import regex\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "batch_size = 500\n",
    "\n",
    "def batch_iterator():\n",
    "    for i in range(0, len(dataset), batch_size):\n",
    "        yield dataset[i : i + batch_size]\n",
    "        \n",
    "def batch_iterator_split():\n",
    "    pat = regex.compile(r\"\"\"'s|'t|'re|'ve|'m|'ll|'d| ?[\\p{L}]+| ?[\\p{N}]+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+\"\"\")\n",
    "    for i in range(0, len(dataset), batch_size):\n",
    "        yield [regex.findall(pat, text) \n",
    "               for text in dataset[i : i + batch_size]]\n",
    "    \n",
    "dataset = [open(f\"../project5/data/un/TXT/{f}\").read() for f in os.listdir(\"../project5/data/un/TXT/\")[:10000]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b143773",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trie constructed\n",
      "Word counts size: 103190\n",
      "Token set size: 0\n",
      "Empty token set size selected -> all possible substrings with...\n",
      "Max token size: 100\n",
      "Min. word count: 1\n",
      "len:  100\n",
      "Final candidate token set size: 678311\n",
      "Initial setup phase: 1644 ms\n",
      "0. |<pad> [3c 70 61 64 3e ] | 0\n",
      "1. |<unk> [3c 75 6e 6b 3e ] | 0\n",
      "2. |<eos> [3c 65 6f 73 3e ] | 0\n",
      "4. | the [20 74 68 65 ] | 7814940 | 60 ms | 77 ms | shortlist: 2947\n",
      "5. |tion [74 69 6f 6e ] | 3820722 | 36 ms | 55 ms | shortlist: 68487\n",
      "6. | of [20 6f 66 ] | 3114622 | 29 ms | 30 ms | shortlist: 895\n",
      "7. | and [20 61 6e 64 ] | 3075315 | 23 ms | 24 ms | shortlist: 612\n",
      "8. |in [69 6e ] | 2624836 | 19 ms | 75 ms | shortlist: 177066\n",
      "9. |re [72 65 ] | 2250933 | 34 ms | 64 ms | shortlist: 98801\n",
      "10. | t [20 74 ] | 2166471 | 31 ms | 39 ms | shortlist: 23658\n",
      "11. | a [20 61 ] | 2047195 | 30 ms | 42 ms | shortlist: 39319\n",
      "12. |er [65 72 ] | 1800671 | 28 ms | 67 ms | shortlist: 117395\n",
      "13. |en [65 6e ] | 1720810 | 31 ms | 61 ms | shortlist: 95674\n",
      "14. | co [20 63 6f ] | 1681810 | 31 ms | 42 ms | shortlist: 32292\n",
      "15. |it [69 74 ] | 1530652 | 31 ms | 54 ms | shortlist: 61995\n",
      "16. | w [20 77 ] | 1321129 | 30 ms | 35 ms | shortlist: 10508\n",
      "17. |es [65 73 ] | 1287657 | 32 ms | 67 ms | shortlist: 122831\n",
      "18. | s [20 73 ] | 1285953 | 30 ms | 44 ms | shortlist: 52822\n",
      "19. |or [6f 72 ] | 1246898 | 28 ms | 47 ms | shortlist: 54858\n",
      "20. |at [61 74 ] | 1218292 | 42 ms | 68 ms | shortlist: 73386\n",
      "21. |al [61 6c ] | 1208649 | 31 ms | 60 ms | shortlist: 86300\n",
      "22. |is [69 73 ] | 1207139 | 30 ms | 59 ms | shortlist: 84113\n",
      "23. | p [20 70 ] | 1181255 | 30 ms | 44 ms | shortlist: 46314\n",
      "24. |on [6f 6e ] | 1156606 | 28 ms | 52 ms | shortlist: 68982\n",
      "25. |an [61 6e ] | 1120391 | 30 ms | 60 ms | shortlist: 96467\n",
      "26. | in [20 69 6e ] | 1094400 | 30 ms | 40 ms | shortlist: 34118\n",
      "27. |ed [65 64 ] | 1072871 | 28 ms | 58 ms | shortlist: 93979\n",
      "28. | to [20 74 6f ] | 1048026 | 31 ms | 35 ms | shortlist: 2337\n",
      "29. | f [20 66 ] | 924247 | 30 ms | 37 ms | shortlist: 21773\n",
      "30. | be [20 62 65 ] | 901370 | 24 ms | 25 ms | shortlist: 3629\n",
      "31. |ation [61 74 69 6f 6e ] | 897505 | 23 ms | 39 ms | shortlist: 52613\n",
      "32. |ic [69 63 ] | 895399 | 27 ms | 46 ms | shortlist: 61636\n",
      "33. |ou [6f 75 ] | 873642 | 30 ms | 43 ms | shortlist: 42465\n",
      "34. |ar [61 72 ] | 827638 | 27 ms | 49 ms | shortlist: 69537\n",
      "35. |ment [6d 65 6e 74 ] | 824632 | 43 ms | 51 ms | shortlist: 22247\n",
      "36. | that [20 74 68 61 74 ] | 808214 | 29 ms | 30 ms | shortlist: 406\n",
      "37. |ing [69 6e 67 ] | 767679 | 23 ms | 52 ms | shortlist: 111558\n",
      "38. | develop [20 64 65 76 65 6c 6f 70 ] | 737485 | 35 ms | 37 ms | shortlist: 185\n",
      "39. | m [20 6d ] | 722906 | 26 ms | 34 ms | shortlist: 31637\n",
      "40. |le [6c 65 ] | 719256 | 24 ms | 36 ms | shortlist: 47560\n",
      "41. | h [20 68 ] | 689344 | 26 ms | 31 ms | shortlist: 17710\n",
      "42. | re [20 72 65 ] | 654055 | 23 ms | 31 ms | shortlist: 33192\n",
      "43. | United [20 55 6e 69 74 65 64 ] | 638492 | 25 ms | 26 ms | shortlist: 179\n",
      "44. | d [20 64 ] | 626820 | 19 ms | 30 ms | shortlist: 43501\n",
      "45. | countr [20 63 6f 75 6e 74 72 ] | 611196 | 26 ms | 27 ms | shortlist: 396\n",
      "46. | international [20 69 6e 74 65 72 6e 61 74 69 6f 6e 61 6c ] | 605355 | 18 ms | 18 ms | shortlist: 715\n",
      "47. |st [73 74 ] | 603387 | 22 ms | 35 ms | shortlist: 37575\n",
      "48. |ro [72 6f ] | 550903 | 26 ms | 36 ms | shortlist: 35278\n",
      "49. |ce [63 65 ] | 537524 | 27 ms | 34 ms | shortlist: 23499\n",
      "50. |ve [76 65 ] | 531965 | 26 ms | 33 ms | shortlist: 23927\n",
      "51. | n [20 6e ] | 523212 | 26 ms | 29 ms | shortlist: 14181\n",
      "52. | which [20 77 68 69 63 68 ] | 509985 | 24 ms | 25 ms | shortlist: 70\n",
      "53. |ec [65 63 ] | 503635 | 19 ms | 26 ms | shortlist: 22863\n",
      "54. |il [69 6c ] | 478851 | 24 ms | 36 ms | shortlist: 39866\n",
      "55. | c [20 63 ] | 453889 | 26 ms | 32 ms | shortlist: 25115\n",
      "56. | b [20 62 ] | 441286 | 25 ms | 30 ms | shortlist: 21189\n",
      "57. | Assembly [20 41 73 73 65 6d 62 6c 79 ] | 439816 | 24 ms | 25 ms | shortlist: 74\n",
      "58. |th [74 68 ] | 432669 | 21 ms | 27 ms | shortlist: 20972\n",
      "59. |as [61 73 ] | 431958 | 23 ms | 31 ms | shortlist: 28314\n",
      "60. | e [20 65 ] | 428436 | 25 ms | 30 ms | shortlist: 20877\n",
      "61. | The [20 54 68 65 ] | 423648 | 24 ms | 25 ms | shortlist: 703\n",
      "62. | with [20 77 69 74 68 ] | 413206 | 23 ms | 23 ms | shortlist: 428\n",
      "63. | Nations [20 4e 61 74 69 6f 6e 73 ] | 403002 | 24 ms | 25 ms | shortlist: 169\n",
      "64. | con [20 63 6f 6e ] | 385095 | 22 ms | 27 ms | shortlist: 15533\n",
      "65. |ly [6c 79 ] | 384700 | 22 ms | 37 ms | shortlist: 66966\n",
      "66. | for [20 66 6f 72 ] | 366319 | 29 ms | 31 ms | shortlist: 2766\n",
      "67. | is [20 69 73 ] | 365296 | 20 ms | 20 ms | shortlist: 600\n",
      "68. | our [20 6f 75 72 ] | 360590 | 20 ms | 21 ms | shortlist: 105\n",
      "69. | peace [20 70 65 61 63 65 ] | 358947 | 21 ms | 21 ms | shortlist: 560\n",
      "70. |op [6f 70 ] | 347146 | 20 ms | 25 ms | shortlist: 14546\n",
      "71. | th [20 74 68 ] | 342386 | 24 ms | 25 ms | shortlist: 2938\n",
      "72. |im [69 6d ] | 337865 | 20 ms | 27 ms | shortlist: 25043\n",
      "73. | Government [20 47 6f 76 65 72 6e 6d 65 6e 74 ] | 330654 | 24 ms | 24 ms | shortlist: 144\n",
      "74. | world [20 77 6f 72 6c 64 ] | 329490 | 17 ms | 18 ms | shortlist: 189\n",
      "75. |ent [65 6e 74 ] | 327542 | 19 ms | 27 ms | shortlist: 22146\n",
      "76. |si [73 69 ] | 326758 | 22 ms | 29 ms | shortlist: 21004\n",
      "77. |om [6f 6d ] | 322756 | 23 ms | 30 ms | shortlist: 24507\n",
      "78. |ol [6f 6c ] | 320844 | 24 ms | 32 ms | shortlist: 28006\n",
      "79. | States [20 53 74 61 74 65 73 ] | 316016 | 33 ms | 33 ms | shortlist: 310\n",
      "80. | its [20 69 74 73 ] | 312756 | 27 ms | 27 ms | shortlist: 266\n",
      "81. | have [20 68 61 76 65 ] | 311246 | 27 ms | 27 ms | shortlist: 136\n",
      "82. |ight [69 67 68 74 ] | 306870 | 25 ms | 27 ms | shortlist: 4095\n",
      "83. |ity [69 74 79 ] | 304932 | 26 ms | 32 ms | shortlist: 22111\n",
      "84. | on [20 6f 6e ] | 302598 | 23 ms | 24 ms | shortlist: 989\n",
      "85. |un [75 6e ] | 302079 | 20 ms | 33 ms | shortlist: 52481\n",
      "86. |ur [75 72 ] | 297800 | 30 ms | 37 ms | shortlist: 23314\n",
      "87. | Organization [20 4f 72 67 61 6e 69 7a 61 74 69 6f 6e ] | 281491 | 35 ms | 36 ms | shortlist: 212\n",
      "88. | we [20 77 65 ] | 275931 | 30 ms | 30 ms | shortlist: 1634\n",
      "89. | I [20 49 ] | 274896 | 29 ms | 33 ms | shortlist: 20489\n",
      "90. |ul [75 6c ] | 274442 | 23 ms | 31 ms | shortlist: 26761\n",
      "91. | all [20 61 6c 6c ] | 273450 | 23 ms | 23 ms | shortlist: 773\n",
      "92. | g [20 67 ] | 272529 | 19 ms | 22 ms | shortlist: 13484\n",
      "93. | l [20 6c ] | 271187 | 20 ms | 23 ms | shortlist: 10742\n",
      "94. |ra [72 61 ] | 271022 | 21 ms | 30 ms | shortlist: 31012\n",
      "95. |se [73 65 ] | 268368 | 24 ms | 28 ms | shortlist: 13221\n",
      "96. |ir [69 72 ] | 263563 | 22 ms | 27 ms | shortlist: 15652\n",
      "97. | should [20 73 68 6f 75 6c 64 ] | 260564 | 25 ms | 26 ms | shortlist: 200\n",
      "98. |ies [69 65 73 ] | 258393 | 24 ms | 30 ms | shortlist: 24933\n",
      "99. | pro [20 70 72 6f ] | 257290 | 23 ms | 25 ms | shortlist: 9018\n",
      "100. | Afric [20 41 66 72 69 63 ] | 246904 | 22 ms | 23 ms | shortlist: 384\n",
      "Total time taken: 3 seconds\n",
      "Trie constructed\n",
      "eos_token <eos> 2\n",
      "unk_token <unk> 1\n",
      "pad_token <pad> 0\n"
     ]
    }
   ],
   "source": [
    "from pcatt.hf.greedtok import GreedTok\n",
    "\n",
    "# text iterator yield batches of lists of str\n",
    "\n",
    "GT_Train = GreedTok().train_new_from_iterator(\n",
    "    batch_iterator(), \n",
    "    vocab_size = 100,\n",
    "    special_tokens_map={\n",
    "        \"pad_token\":\"<pad>\",\n",
    "        \"unk_token\":\"<unk>\", \n",
    "        \"eos_token\":\"<eos>\"\n",
    "    },\n",
    "    min_word_count=1,\n",
    "    max_token_length=1000\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9cbd27f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trie constructed\n",
      "Word counts size: 103190\n",
      "Token set size: 0\n",
      "Empty token set size selected -> all possible substrings with...\n",
      "Max token size: 100\n",
      "len:  100\n",
      "Min. word count: 1\n",
      "Final candidate token set size: 678311\n",
      "Initial setup phase: 1972 ms\n",
      "0. |<pad> [3c 70 61 64 3e ] | 0\n",
      "1. |<unk> [3c 75 6e 6b 3e ] | 0\n",
      "2. |<eos> [3c 65 6f 73 3e ] | 0\n",
      "4. | the [20 74 68 65 ] | 7814940 | 59 ms | 76 ms | shortlist: 2947\n",
      "5. |tion [74 69 6f 6e ] | 3820722 | 29 ms | 49 ms | shortlist: 68487\n",
      "6. | of [20 6f 66 ] | 3114622 | 29 ms | 31 ms | shortlist: 895\n",
      "7. | and [20 61 6e 64 ] | 3075315 | 25 ms | 26 ms | shortlist: 612\n",
      "8. |in [69 6e ] | 2624836 | 21 ms | 78 ms | shortlist: 177066\n",
      "9. |re [72 65 ] | 2250933 | 32 ms | 62 ms | shortlist: 98801\n",
      "10. | t [20 74 ] | 2166471 | 31 ms | 39 ms | shortlist: 23658\n",
      "11. | a [20 61 ] | 2047195 | 29 ms | 40 ms | shortlist: 39319\n",
      "12. |er [65 72 ] | 1800671 | 28 ms | 68 ms | shortlist: 117395\n",
      "13. |en [65 6e ] | 1720810 | 31 ms | 61 ms | shortlist: 95674\n",
      "14. | co [20 63 6f ] | 1681810 | 32 ms | 43 ms | shortlist: 32292\n",
      "15. |it [69 74 ] | 1530652 | 28 ms | 51 ms | shortlist: 61995\n",
      "16. | w [20 77 ] | 1321129 | 31 ms | 35 ms | shortlist: 10508\n",
      "17. |es [65 73 ] | 1287657 | 31 ms | 68 ms | shortlist: 122831\n",
      "18. | s [20 73 ] | 1285953 | 34 ms | 50 ms | shortlist: 52822\n",
      "19. |or [6f 72 ] | 1246898 | 31 ms | 50 ms | shortlist: 54858\n",
      "20. |at [61 74 ] | 1218292 | 33 ms | 61 ms | shortlist: 73386\n",
      "21. |al [61 6c ] | 1208649 | 31 ms | 63 ms | shortlist: 86300\n",
      "22. |is [69 73 ] | 1207139 | 32 ms | 61 ms | shortlist: 84113\n",
      "23. | p [20 70 ] | 1181255 | 32 ms | 47 ms | shortlist: 46314\n",
      "24. |on [6f 6e ] | 1156606 | 30 ms | 54 ms | shortlist: 68982\n",
      "25. |an [61 6e ] | 1120391 | 31 ms | 62 ms | shortlist: 96467\n",
      "26. | in [20 69 6e ] | 1094400 | 33 ms | 45 ms | shortlist: 34118\n",
      "27. |ed [65 64 ] | 1072871 | 28 ms | 57 ms | shortlist: 93979\n",
      "28. | to [20 74 6f ] | 1048026 | 33 ms | 37 ms | shortlist: 2337\n",
      "29. | f [20 66 ] | 924247 | 23 ms | 29 ms | shortlist: 21773\n",
      "30. | be [20 62 65 ] | 901370 | 25 ms | 27 ms | shortlist: 3629\n",
      "31. |ation [61 74 69 6f 6e ] | 897505 | 27 ms | 46 ms | shortlist: 52613\n",
      "32. |ic [69 63 ] | 895399 | 28 ms | 47 ms | shortlist: 61636\n",
      "33. |ou [6f 75 ] | 873642 | 31 ms | 45 ms | shortlist: 42465\n",
      "34. |ar [61 72 ] | 827638 | 38 ms | 60 ms | shortlist: 69537\n",
      "35. |ment [6d 65 6e 74 ] | 824632 | 29 ms | 37 ms | shortlist: 22247\n",
      "36. | that [20 74 68 61 74 ] | 808214 | 24 ms | 25 ms | shortlist: 406\n",
      "37. |ing [69 6e 67 ] | 767679 | 19 ms | 51 ms | shortlist: 111558\n",
      "38. | develop [20 64 65 76 65 6c 6f 70 ] | 737485 | 30 ms | 31 ms | shortlist: 185\n",
      "39. | m [20 6d ] | 722906 | 22 ms | 30 ms | shortlist: 31637\n",
      "40. |le [6c 65 ] | 719256 | 24 ms | 36 ms | shortlist: 47560\n",
      "41. | h [20 68 ] | 689344 | 37 ms | 42 ms | shortlist: 17710\n",
      "42. | re [20 72 65 ] | 654055 | 25 ms | 34 ms | shortlist: 33192\n",
      "43. | United [20 55 6e 69 74 65 64 ] | 638492 | 24 ms | 25 ms | shortlist: 179\n",
      "44. | d [20 64 ] | 626820 | 18 ms | 29 ms | shortlist: 43501\n",
      "45. | countr [20 63 6f 75 6e 74 72 ] | 611196 | 31 ms | 32 ms | shortlist: 396\n",
      "46. | international [20 69 6e 74 65 72 6e 61 74 69 6f 6e 61 6c ] | 605355 | 27 ms | 27 ms | shortlist: 715\n",
      "47. |st [73 74 ] | 603387 | 26 ms | 39 ms | shortlist: 37575\n",
      "48. |ro [72 6f ] | 550903 | 25 ms | 35 ms | shortlist: 35278\n",
      "49. |ce [63 65 ] | 537524 | 25 ms | 32 ms | shortlist: 23499\n",
      "50. |ve [76 65 ] | 531965 | 24 ms | 30 ms | shortlist: 23927\n",
      "51. | n [20 6e ] | 523212 | 33 ms | 37 ms | shortlist: 14181\n",
      "52. | which [20 77 68 69 63 68 ] | 509985 | 23 ms | 23 ms | shortlist: 70\n",
      "53. |ec [65 63 ] | 503635 | 17 ms | 23 ms | shortlist: 22863\n",
      "54. |il [69 6c ] | 478851 | 22 ms | 33 ms | shortlist: 39866\n",
      "55. | c [20 63 ] | 453889 | 27 ms | 34 ms | shortlist: 25115\n",
      "56. | b [20 62 ] | 441286 | 24 ms | 30 ms | shortlist: 21189\n",
      "57. | Assembly [20 41 73 73 65 6d 62 6c 79 ] | 439816 | 23 ms | 24 ms | shortlist: 74\n",
      "58. |th [74 68 ] | 432669 | 18 ms | 23 ms | shortlist: 20972\n",
      "59. |as [61 73 ] | 431958 | 22 ms | 31 ms | shortlist: 28314\n",
      "60. | e [20 65 ] | 428436 | 25 ms | 30 ms | shortlist: 20877\n",
      "61. | The [20 54 68 65 ] | 423648 | 24 ms | 25 ms | shortlist: 703\n",
      "62. | with [20 77 69 74 68 ] | 413206 | 20 ms | 20 ms | shortlist: 428\n",
      "63. | Nations [20 4e 61 74 69 6f 6e 73 ] | 403002 | 19 ms | 19 ms | shortlist: 169\n",
      "64. | con [20 63 6f 6e ] | 385095 | 21 ms | 25 ms | shortlist: 15533\n",
      "65. |ly [6c 79 ] | 384700 | 21 ms | 36 ms | shortlist: 66966\n",
      "66. | for [20 66 6f 72 ] | 366319 | 28 ms | 30 ms | shortlist: 2766\n",
      "67. | is [20 69 73 ] | 365296 | 21 ms | 22 ms | shortlist: 600\n",
      "68. | our [20 6f 75 72 ] | 360590 | 21 ms | 21 ms | shortlist: 105\n",
      "69. | peace [20 70 65 61 63 65 ] | 358947 | 21 ms | 22 ms | shortlist: 560\n",
      "70. |op [6f 70 ] | 347146 | 20 ms | 25 ms | shortlist: 14546\n",
      "71. | th [20 74 68 ] | 342386 | 20 ms | 21 ms | shortlist: 2938\n",
      "72. |im [69 6d ] | 337865 | 19 ms | 27 ms | shortlist: 25043\n",
      "73. | Government [20 47 6f 76 65 72 6e 6d 65 6e 74 ] | 330654 | 23 ms | 23 ms | shortlist: 144\n",
      "74. | world [20 77 6f 72 6c 64 ] | 329490 | 17 ms | 17 ms | shortlist: 189\n",
      "75. |ent [65 6e 74 ] | 327542 | 18 ms | 25 ms | shortlist: 22146\n",
      "76. |si [73 69 ] | 326758 | 23 ms | 30 ms | shortlist: 21004\n",
      "77. |om [6f 6d ] | 322756 | 23 ms | 30 ms | shortlist: 24507\n",
      "78. |ol [6f 6c ] | 320844 | 23 ms | 32 ms | shortlist: 28006\n",
      "79. | States [20 53 74 61 74 65 73 ] | 316016 | 32 ms | 33 ms | shortlist: 310\n",
      "80. | its [20 69 74 73 ] | 312756 | 27 ms | 27 ms | shortlist: 266\n",
      "81. | have [20 68 61 76 65 ] | 311246 | 24 ms | 24 ms | shortlist: 136\n",
      "82. |ight [69 67 68 74 ] | 306870 | 22 ms | 24 ms | shortlist: 4095\n",
      "83. |ity [69 74 79 ] | 304932 | 19 ms | 24 ms | shortlist: 22111\n",
      "84. | on [20 6f 6e ] | 302598 | 21 ms | 22 ms | shortlist: 989\n",
      "85. |un [75 6e ] | 302079 | 18 ms | 31 ms | shortlist: 52481\n",
      "86. |ur [75 72 ] | 297800 | 29 ms | 37 ms | shortlist: 23314\n",
      "87. | Organization [20 4f 72 67 61 6e 69 7a 61 74 69 6f 6e ] | 281491 | 25 ms | 26 ms | shortlist: 212\n",
      "88. | we [20 77 65 ] | 275931 | 18 ms | 19 ms | shortlist: 1634\n",
      "89. | I [20 49 ] | 274896 | 19 ms | 23 ms | shortlist: 20489\n",
      "90. |ul [75 6c ] | 274442 | 22 ms | 30 ms | shortlist: 26761\n",
      "91. | all [20 61 6c 6c ] | 273450 | 32 ms | 33 ms | shortlist: 773\n",
      "92. | g [20 67 ] | 272529 | 27 ms | 31 ms | shortlist: 13484\n",
      "93. | l [20 6c ] | 271187 | 21 ms | 23 ms | shortlist: 10742\n",
      "94. |ra [72 61 ] | 271022 | 21 ms | 31 ms | shortlist: 31012\n",
      "95. |se [73 65 ] | 268368 | 25 ms | 29 ms | shortlist: 13221\n",
      "96. |ir [69 72 ] | 263563 | 25 ms | 30 ms | shortlist: 15652\n",
      "97. | should [20 73 68 6f 75 6c 64 ] | 260564 | 23 ms | 23 ms | shortlist: 200\n",
      "98. |ies [69 65 73 ] | 258393 | 19 ms | 24 ms | shortlist: 24933\n",
      "99. | pro [20 70 72 6f ] | 257290 | 23 ms | 26 ms | shortlist: 9018\n",
      "100. | Afric [20 41 66 72 69 63 ] | 246904 | 21 ms | 21 ms | shortlist: 384\n",
      "Total time taken: 3 seconds\n",
      "Trie constructed\n",
      "eos_token <eos> 2\n",
      "unk_token <unk> 1\n",
      "pad_token <pad> 0\n"
     ]
    }
   ],
   "source": [
    "# text iterator can also yield batches of lists of list of str\n",
    "# this is useful if you want more control, i.e. using other strategies to split text\n",
    "# the default behavior embedded is regex.findall using regex pattern string as shown above\n",
    "# for decoding, the regex pattern does not matter\n",
    "\n",
    "GT_Train = GreedTok().train_new_from_iterator(\n",
    "    batch_iterator_split(), \n",
    "    vocab_size = 100,\n",
    "    special_tokens_map={\n",
    "        \"pad_token\":\"<pad>\",\n",
    "        \"unk_token\":\"<unk>\", \n",
    "        \"eos_token\":\"<eos>\"\n",
    "    },\n",
    "    min_word_count=1,\n",
    "    max_token_length=1000\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5940d88",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trie constructed\n",
      "Word counts size: 103181\n",
      "Token set size: 0\n",
      "Empty token set size selected -> all possible substrings with...\n",
      "Max token size: 100\n",
      "Min. word count: 1\n",
      "len:  100\n",
      "Final candidate token set size: 678306\n",
      "Initial setup phase: 1997 ms\n",
      "0. |<pad> [3c 70 61 64 3e ] | 0\n",
      "1. |<unk> [3c 75 6e 6b 3e ] | 0\n",
      "2. |<eos> [3c 65 6f 73 3e ] | 0\n",
      "4. | the [20 74 68 65 ] | 7814940 | 71 ms | 88 ms | shortlist: 2947\n",
      "5. |tion [74 69 6f 6e ] | 3820722 | 34 ms | 54 ms | shortlist: 68487\n",
      "6. | of [20 6f 66 ] | 3114622 | 27 ms | 29 ms | shortlist: 895\n",
      "7. | and [20 61 6e 64 ] | 3075315 | 25 ms | 26 ms | shortlist: 612\n",
      "8. |in [69 6e ] | 2624836 | 20 ms | 80 ms | shortlist: 177067\n",
      "9. |re [72 65 ] | 2250933 | 32 ms | 63 ms | shortlist: 98799\n",
      "10. | t [20 74 ] | 2166471 | 30 ms | 38 ms | shortlist: 23658\n",
      "11. | a [20 61 ] | 2047195 | 28 ms | 39 ms | shortlist: 39319\n",
      "12. |er [65 72 ] | 1800671 | 30 ms | 69 ms | shortlist: 117395\n",
      "13. |en [65 6e ] | 1720810 | 31 ms | 61 ms | shortlist: 95676\n",
      "14. | co [20 63 6f ] | 1681810 | 31 ms | 42 ms | shortlist: 32292\n",
      "15. |it [69 74 ] | 1530652 | 28 ms | 50 ms | shortlist: 61995\n",
      "16. | w [20 77 ] | 1321129 | 31 ms | 36 ms | shortlist: 10508\n",
      "17. |es [65 73 ] | 1287657 | 31 ms | 67 ms | shortlist: 122831\n",
      "18. | s [20 73 ] | 1285953 | 30 ms | 45 ms | shortlist: 52822\n",
      "19. |or [6f 72 ] | 1246898 | 29 ms | 48 ms | shortlist: 54858\n",
      "20. |at [61 74 ] | 1218292 | 42 ms | 70 ms | shortlist: 73386\n",
      "21. |al [61 6c ] | 1208649 | 30 ms | 57 ms | shortlist: 86300\n",
      "22. |is [69 73 ] | 1207139 | 33 ms | 63 ms | shortlist: 84113\n",
      "23. | p [20 70 ] | 1181255 | 32 ms | 46 ms | shortlist: 46314\n",
      "24. |on [6f 6e ] | 1156606 | 30 ms | 54 ms | shortlist: 68982\n",
      "25. |an [61 6e ] | 1120391 | 33 ms | 63 ms | shortlist: 96467\n",
      "26. | in [20 69 6e ] | 1094400 | 32 ms | 44 ms | shortlist: 34118\n",
      "27. |ed [65 64 ] | 1072871 | 28 ms | 57 ms | shortlist: 93979\n",
      "28. | to [20 74 6f ] | 1048026 | 30 ms | 34 ms | shortlist: 2337\n",
      "29. | f [20 66 ] | 924247 | 33 ms | 40 ms | shortlist: 21773\n",
      "30. | be [20 62 65 ] | 901370 | 25 ms | 26 ms | shortlist: 3629\n",
      "31. |ation [61 74 69 6f 6e ] | 897505 | 22 ms | 38 ms | shortlist: 52613\n",
      "32. |ic [69 63 ] | 895399 | 27 ms | 46 ms | shortlist: 61636\n",
      "33. |ou [6f 75 ] | 873642 | 31 ms | 44 ms | shortlist: 42465\n",
      "34. |ar [61 72 ] | 827638 | 30 ms | 51 ms | shortlist: 69541\n",
      "35. |ment [6d 65 6e 74 ] | 824632 | 42 ms | 50 ms | shortlist: 22247\n",
      "36. | that [20 74 68 61 74 ] | 808214 | 28 ms | 29 ms | shortlist: 406\n",
      "37. |ing [69 6e 67 ] | 767679 | 22 ms | 52 ms | shortlist: 111558\n",
      "38. | develop [20 64 65 76 65 6c 6f 70 ] | 737485 | 32 ms | 34 ms | shortlist: 185\n",
      "39. | m [20 6d ] | 722906 | 23 ms | 30 ms | shortlist: 31637\n",
      "40. |le [6c 65 ] | 719256 | 24 ms | 37 ms | shortlist: 47560\n",
      "41. | h [20 68 ] | 689344 | 26 ms | 31 ms | shortlist: 17710\n",
      "42. | re [20 72 65 ] | 654055 | 23 ms | 31 ms | shortlist: 33192\n",
      "43. | United [20 55 6e 69 74 65 64 ] | 638492 | 25 ms | 26 ms | shortlist: 179\n",
      "44. | d [20 64 ] | 626820 | 19 ms | 29 ms | shortlist: 43501\n",
      "45. | countr [20 63 6f 75 6e 74 72 ] | 611196 | 26 ms | 27 ms | shortlist: 396\n",
      "46. | international [20 69 6e 74 65 72 6e 61 74 69 6f 6e 61 6c ] | 605355 | 19 ms | 19 ms | shortlist: 715\n",
      "47. |st [73 74 ] | 603387 | 22 ms | 34 ms | shortlist: 37575\n",
      "48. |ro [72 6f ] | 550903 | 27 ms | 37 ms | shortlist: 35278\n",
      "49. |ce [63 65 ] | 537524 | 26 ms | 33 ms | shortlist: 23499\n",
      "50. |ve [76 65 ] | 531965 | 25 ms | 32 ms | shortlist: 23925\n",
      "51. | n [20 6e ] | 523212 | 25 ms | 28 ms | shortlist: 14181\n",
      "52. | which [20 77 68 69 63 68 ] | 509985 | 24 ms | 24 ms | shortlist: 70\n",
      "53. |ec [65 63 ] | 503635 | 19 ms | 26 ms | shortlist: 22863\n",
      "54. |il [69 6c ] | 478851 | 25 ms | 37 ms | shortlist: 39866\n",
      "55. | c [20 63 ] | 453889 | 26 ms | 33 ms | shortlist: 25115\n",
      "56. | b [20 62 ] | 441286 | 25 ms | 30 ms | shortlist: 21189\n",
      "57. | Assembly [20 41 73 73 65 6d 62 6c 79 ] | 439816 | 24 ms | 25 ms | shortlist: 74\n",
      "58. |th [74 68 ] | 432672 | 19 ms | 25 ms | shortlist: 20972\n",
      "59. |as [61 73 ] | 431958 | 23 ms | 30 ms | shortlist: 28314\n",
      "60. | e [20 65 ] | 428436 | 25 ms | 31 ms | shortlist: 20877\n",
      "61. | The [20 54 68 65 ] | 423648 | 24 ms | 25 ms | shortlist: 703\n",
      "62. | with [20 77 69 74 68 ] | 413206 | 22 ms | 22 ms | shortlist: 428\n",
      "63. | Nations [20 4e 61 74 69 6f 6e 73 ] | 403002 | 23 ms | 24 ms | shortlist: 169\n",
      "64. | con [20 63 6f 6e ] | 385095 | 17 ms | 21 ms | shortlist: 15533\n",
      "65. |ly [6c 79 ] | 384700 | 23 ms | 40 ms | shortlist: 66966\n",
      "66. | for [20 66 6f 72 ] | 366319 | 30 ms | 32 ms | shortlist: 2766\n",
      "67. | is [20 69 73 ] | 365296 | 28 ms | 28 ms | shortlist: 600\n",
      "68. | our [20 6f 75 72 ] | 360590 | 26 ms | 26 ms | shortlist: 105\n",
      "69. | peace [20 70 65 61 63 65 ] | 358947 | 23 ms | 23 ms | shortlist: 560\n",
      "70. |op [6f 70 ] | 347146 | 23 ms | 27 ms | shortlist: 14546\n",
      "71. | th [20 74 68 ] | 342386 | 21 ms | 23 ms | shortlist: 2938\n",
      "72. |im [69 6d ] | 337865 | 20 ms | 27 ms | shortlist: 25043\n",
      "73. | Government [20 47 6f 76 65 72 6e 6d 65 6e 74 ] | 330654 | 24 ms | 24 ms | shortlist: 144\n",
      "74. | world [20 77 6f 72 6c 64 ] | 329490 | 17 ms | 17 ms | shortlist: 189\n",
      "75. |ent [65 6e 74 ] | 327542 | 17 ms | 24 ms | shortlist: 22146\n",
      "76. |si [73 69 ] | 326758 | 22 ms | 29 ms | shortlist: 21004\n",
      "77. |om [6f 6d ] | 322756 | 22 ms | 29 ms | shortlist: 24507\n",
      "78. |ol [6f 6c ] | 320844 | 23 ms | 31 ms | shortlist: 28006\n",
      "79. | States [20 53 74 61 74 65 73 ] | 316016 | 30 ms | 30 ms | shortlist: 310\n",
      "80. | its [20 69 74 73 ] | 312756 | 25 ms | 26 ms | shortlist: 266\n",
      "81. | have [20 68 61 76 65 ] | 311246 | 24 ms | 24 ms | shortlist: 136\n",
      "82. |ight [69 67 68 74 ] | 306870 | 22 ms | 24 ms | shortlist: 4095\n",
      "83. |ity [69 74 79 ] | 304932 | 20 ms | 26 ms | shortlist: 22111\n",
      "84. | on [20 6f 6e ] | 302598 | 22 ms | 22 ms | shortlist: 989\n",
      "85. |un [75 6e ] | 302079 | 18 ms | 31 ms | shortlist: 52481\n",
      "86. |ur [75 72 ] | 297800 | 27 ms | 35 ms | shortlist: 23314\n",
      "87. | Organization [20 4f 72 67 61 6e 69 7a 61 74 69 6f 6e ] | 281491 | 24 ms | 24 ms | shortlist: 212\n",
      "88. | we [20 77 65 ] | 275931 | 17 ms | 18 ms | shortlist: 1634\n",
      "89. | I [20 49 ] | 274896 | 17 ms | 22 ms | shortlist: 20489\n",
      "90. |ul [75 6c ] | 274442 | 21 ms | 29 ms | shortlist: 26766\n",
      "91. | all [20 61 6c 6c ] | 273450 | 31 ms | 32 ms | shortlist: 773\n",
      "92. | g [20 67 ] | 272529 | 26 ms | 30 ms | shortlist: 13484\n",
      "93. | l [20 6c ] | 271187 | 22 ms | 25 ms | shortlist: 10742\n",
      "94. |ra [72 61 ] | 271022 | 21 ms | 31 ms | shortlist: 31012\n",
      "95. |se [73 65 ] | 268369 | 25 ms | 29 ms | shortlist: 13221\n",
      "96. |ir [69 72 ] | 263563 | 24 ms | 30 ms | shortlist: 15652\n",
      "97. | should [20 73 68 6f 75 6c 64 ] | 260564 | 21 ms | 22 ms | shortlist: 200\n",
      "98. |ies [69 65 73 ] | 258393 | 17 ms | 23 ms | shortlist: 24933\n",
      "99. | pro [20 70 72 6f ] | 257290 | 22 ms | 25 ms | shortlist: 9018\n",
      "100. | Afric [20 41 66 72 69 63 ] | 246904 | 21 ms | 22 ms | shortlist: 384\n",
      "Total time taken: 3 seconds\n",
      "Trie constructed\n",
      "eos_token <eos> 2\n",
      "unk_token <unk> 1\n",
      "pad_token <pad> 0\n"
     ]
    }
   ],
   "source": [
    "# we can change the regex pattern, we can pass the desired into 'pattern'\n",
    "# we can also set the no. of workers to speed up splitting (default is 8)\n",
    "\n",
    "GT_Train = GreedTok().train_new_from_iterator(\n",
    "    batch_iterator(), \n",
    "    vocab_size = 100,\n",
    "    special_tokens_map={\n",
    "        \"pad_token\":\"<pad>\",\n",
    "        \"unk_token\":\"<unk>\", \n",
    "        \"eos_token\":\"<eos>\"\n",
    "    },\n",
    "    min_word_count=1,\n",
    "    max_token_length=1000,\n",
    "    pattern = r\"\"\" ?[\\p{L}]+| ?[\\p{N}]+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+\"\"\",\n",
    "    workers = 10\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6691bc3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tokenizer config file saved in pcatt/hf/examples/greedtok_test2/tokenizer_config.json\n",
      "special_tokens_map file saved in pcatt/hf/examples/greedtok_test2/special_tokens_map.json\n",
      "added tokens file saved in pcatt/hf/examples/greedtok_test2/added_tokens.txt\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('pcatt/hf/examples/greedtok_test2/tokenizer_config.json',\n",
       " 'pcatt/hf/examples/greedtok_test2/special_tokens_map.json',\n",
       " 'pcatt/hf/examples/greedtok_test2/added_tokens.txt')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# to save\n",
    "\n",
    "GT_Train.save_pretrained('pcatt/hf/examples/greedtok_test2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ba963b06",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "regex.Regex(' ?[\\\\p{L}]+| ?[\\\\p{N}]+| ?[^\\\\s\\\\p{L}\\\\p{N}]+|\\\\s+(?!\\\\S)|\\\\s+', flags=regex.V0)\n",
      "Trie constructed\n",
      "unk_token <unk> 1\n",
      "pad_token <pad> 0\n",
      "eos_token <eos> 2\n"
     ]
    }
   ],
   "source": [
    "# loading pretrained\n",
    "\n",
    "from pcatt.hf.greedtok import GreedTok\n",
    "GT_Train = GreedTok.from_pretrained(\"pcatt/hf/examples/greedtok_test2\")\n",
    "print(GT_Train.pat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6ce9a59a",
   "metadata": {},
   "outputs": [],
   "source": [
    "test = [x for x in next(batch_iterator())]\n",
    "test_split = [x for x in next(batch_iterator_split())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9bce1d59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[176, 201, 216, 38, 201, 63, 203, 214, 19, 89, 19, 201, 132, 221, 32, 146, 132, 183, 95, 144, 6, 3, 132, 182, 201, 212, 217, 198, 208, 31, 132, 110, 211, 202, 132, 166, 89, 203, 33, 205, 197, 132, 217, 212, 23, 132, 221, 32, 214, 59, 39, 199, 4, 10, 215, 132, 180, 8, 75, 200, 74, 5, 3, 132, 171, 12, 11, 20, 56, 10, 216, 79, 132, 110, 202, 18, 216, 221, 145, 94, 218, 12, 57, 17, 16, 75, 23, 146, 132, 189, 32, 214, 59, 220, 212, 11, 205, 12, 48, 10]\n",
      "Let me congratulate you. Sir, and the Republic \n",
      "of Bulgaria upon your election as President of the General Assembly at its \n",
      "forty-seventh session. Your experience as a respected political leader and \n",
      "\n"
     ]
    }
   ],
   "source": [
    "test_encode = GT_Train(test_split[:50], test_split[50:100], is_split_into_words=True)\n",
    "print(test_encode['input_ids'][1][:100])\n",
    "test_decode = GT_Train.batch_decode(test_encode['input_ids'])\n",
    "print(test_decode[1][:200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3040e5c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[173, 216, 91, 205, 218, 16, 38, 201, 22, 39, 58, 217, 8, 10, 216, 3, 132, 32, 216, 94, 216, 27, 132, 206, 211, 7, 3, 17, 212, 201, 197, 207, 11, 215, 15, 204, 211, 22, 8, 199, 26, 26, 132, 110, 209, 201, 25, 63, 203, 214, 19, 89, 19, 36, 132, 221, 32, 146, 132, 183, 95, 144, 83, 132, 221, 32, 214, 59, 39, 199, 4, 27, 3, 22, 8, 75, 200, 12, 199, 221, 5, 3, 132, 110, 171, 12, 11, 20, 56, 10, 216, 79, 65, 216, 221, 145, 94, 218, 12, 57]\n",
      "It gives me pleasure at the outset to join the speakers who preceded \n",
      "me in congratulating you. Sir, on your election to the presidency of the \n",
      "General Assembly at its forty-seventh session. My delega\n"
     ]
    }
   ],
   "source": [
    "test_encode = GT_Train(test[:10], test[10:20], is_split_into_words=False)\n",
    "print(test_encode['input_ids'][0][:100])\n",
    "test_decode = GT_Train.batch_decode(test_encode['input_ids'])\n",
    "print(test_decode[0][:200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ced8d047",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[173, 216, 91, 205, 218, 16, 38, 201, 22, 39, 58, 217, 8, 10, 216, 3, 132, 32, 216, 94, 216, 27, 132, 206, 211, 7, 3, 17, 212, 201, 197, 207, 11, 215, 15, 204, 211, 22, 8, 199, 26, 26, 132, 110, 209, 201, 25, 63, 203, 214, 19, 89, 19, 36, 132, 221, 32, 146, 132, 183, 95, 144, 83, 132, 221, 32, 214, 59, 39, 199, 4, 27, 3, 22, 8, 75, 200, 12, 199, 221, 5, 3, 132, 110, 171, 12, 11, 20, 56, 10, 216, 79, 65, 216, 221, 145, 94, 218, 12, 57]\n",
      "It gives me pleasure at the outset to join the speakers who preceded \n",
      "me in congratulating you. Sir, on your election to the presidency of the \n",
      "General Assembly at its forty-seventh session. My delega\n"
     ]
    }
   ],
   "source": [
    "test_encode = GT_Train(test, is_split_into_words=False)\n",
    "print(test_encode['input_ids'][0][:100])\n",
    "test_decode = GT_Train.batch_decode(test_encode['input_ids'])\n",
    "print(test_decode[0][:200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "60cfc38c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[173, 216, 91, 205, 218, 16, 38, 201, 22, 39, 58, 217, 8, 10, 216, 3, 132, 32, 216, 94, 216, 27, 132, 206, 211, 7, 3, 17, 212, 201, 197, 207, 11, 215, 15, 204, 211, 22, 8, 199, 26, 26, 132, 110, 209, 201, 25, 63, 203, 214, 19, 89, 19, 36, 132, 221, 32, 146, 132, 183, 95, 144, 83, 132, 221, 32, 214, 59, 39, 199, 4, 27, 3, 22, 8, 75, 200, 12, 199, 221, 5, 3, 132, 110, 171, 12, 11, 20, 56, 10, 216, 79, 65, 216, 221, 145, 94, 218, 12, 57]\n",
      "It gives me pleasure at the outset to join the speakers who preceded \n",
      "me in congratulating you. Sir, on your election to the presidency of the \n",
      "General Assembly at its forty-seventh session. My delega\n",
      "It gives me pleasure at the outset to join the speakers who preceded \n",
      "me in congratulating you. Sir, on your election to the presidency of the \n",
      "General Assembly at its forty-seventh session. My delega\n"
     ]
    }
   ],
   "source": [
    "test_encode = GT_Train(test_split, is_split_into_words=True)\n",
    "print(test_encode['input_ids'][0][:100])\n",
    "test_decode = GT_Train.batch_decode(test_encode['input_ids'])\n",
    "print(test_decode[0][:200])\n",
    "print(GT_Train.decode(test_encode['input_ids'][0])[:200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3ac68d82",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trie constructed\n",
      "eos_token <eos> 8\n",
      "pad_token <pad> 7\n",
      "tokenizer config file saved in pcatt/hf/examples/greedtok_test1/tokenizer_config.json\n",
      "special_tokens_map file saved in pcatt/hf/examples/greedtok_test1/special_tokens_map.json\n",
      "added tokens file saved in pcatt/hf/examples/greedtok_test1/added_tokens.txt\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('pcatt/hf/examples/greedtok_test1/tokenizer_config.json',\n",
       " 'pcatt/hf/examples/greedtok_test1/special_tokens_map.json',\n",
       " 'pcatt/hf/examples/greedtok_test1/added_tokens.txt')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pcatt.hf.greedtok import GreedTok\n",
    "GT = GreedTok(ranked_tokens = ['aa', 'bb', 'abc', 'bc', '12', '123', '34', \"<pad>\", \"<eos>\"],\n",
    "         special_tokens_map = {\"pad_token\":\"<pad>\", \"eos_token\":\"<eos>\"})\n",
    "GT.save_pretrained(\"pcatt/hf/examples/greedtok_test1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bb22886b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trie constructed\n",
      "pad_token <pad> 7\n",
      "eos_token <eos> 8\n"
     ]
    }
   ],
   "source": [
    "from pcatt.hf.greedtok import GreedTok\n",
    "GT2 = GreedTok.from_pretrained(\"pcatt/hf/examples/greedtok_test1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3cae3a09",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['aabb', 'bbabc', 'bc12<pad>']\n",
      "['bc12', '\\\\xbfbb\\\\xd5']\n"
     ]
    }
   ],
   "source": [
    "#basic decoding\n",
    "print(GT2.batch_decode([[0,1], [1,2], [3,4,7]]))\n",
    "print(GT2.batch_decode([[3,4,7,8], [200,1,222]], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e1fc7acc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [[0, 3], [0, 5, 61], [106, 107, 127]]}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# testing __call__\n",
    "GT2([\"aabc\", \"aa1234\", \"abv\"], is_split_into_words=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "dbe54ea5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [[0, 3], [0, 5, 61], [106, 107, 127]]}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# testing __call__ presplit\n",
    "GT2([[\"aa\",\"bc\"], [\"aa\", \"123\", \"4\"], [\"ab\",\"v\"]], is_split_into_words=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c5119f01",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "special_tokens_mask\n",
      "\t 2 :  [0, 0]\n",
      "\t 5 :  [1, 0, 0, 0, 1]\n",
      "\t 4 :  [0, 0, 0, 1]\n",
      "input_ids\n",
      "\t 2 :  [0, 3]\n",
      "\t 5 :  [7, 0, 5, 61, 8]\n",
      "\t 4 :  [106, 107, 127, 7]\n"
     ]
    }
   ],
   "source": [
    "# testing __call__ no padding and no truncation\n",
    "outputs = GT2([\"aabc\", \"<pad>aa1234<eos>\", \"abv<pad>\"], \n",
    "    is_split_into_words=False, \n",
    "    padding=False,\n",
    "    return_attention_mask=True,\n",
    "    return_special_tokens_mask=True,\n",
    "    max_length = 10)\n",
    "for k,v in outputs.items():\n",
    "    print(k)\n",
    "    for o in v:\n",
    "        print('\\t', len(o), ': ', o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c5c83dd4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "special_tokens_mask\n",
      "\t 10 :  [0, 0, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "\t 10 :  [0, 0, 0, 1, 1, 1, 1, 1, 1, 1]\n",
      "\t 10 :  [0, 0, 0, 1, 1, 1, 1, 1, 1, 1]\n",
      "\t 10 :  [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]\n",
      "overflowing_tokens\n",
      "\t 0 :  []\n",
      "\t 0 :  []\n",
      "\t 0 :  []\n",
      "\t 5 :  [127, 7, 0, 5, 61]\n",
      "attention_mask\n",
      "\t 10 :  [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\t 10 :  [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]\n",
      "\t 10 :  [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]\n",
      "\t 10 :  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "input_ids\n",
      "\t 10 :  [0, 3, 7, 7, 7, 7, 7, 7, 7, 7]\n",
      "\t 10 :  [0, 5, 61, 7, 7, 7, 7, 7, 7, 7]\n",
      "\t 10 :  [106, 107, 127, 7, 7, 7, 7, 7, 7, 7]\n",
      "\t 10 :  [106, 107, 127, 7, 106, 107, 127, 7, 106, 107]\n"
     ]
    }
   ],
   "source": [
    "# testing __call__ with padding and truncation\n",
    "outputs = GT2([\"aabc\", \n",
    "               \"aa1234\",\n",
    "               \"abv<pad>\",\n",
    "               \"abv<pad>abv<pad>abv<pad>aa1234\"], \n",
    "    is_split_into_words=False, \n",
    "    padding=\"max_length\",\n",
    "    truncation = \"longest_first\",\n",
    "    return_overflowing_tokens=True,\n",
    "    return_attention_mask=True,\n",
    "    return_special_tokens_mask=True,\n",
    "    max_length = 10)\n",
    "for k,v in outputs.items():\n",
    "    print(k)\n",
    "    for o in v:\n",
    "        print('\\t', len(o), ': ', o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "feb1659f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids\n",
      "\t 7 :  [0, 3, 41, 106, 107, 127, 7]\n",
      "\t 19 :  [0, 5, 61, 41, 106, 107, 127, 7, 106, 107, 127, 7, 106, 107, 127, 7, 0, 5, 61]\n"
     ]
    }
   ],
   "source": [
    "# testing pairs\n",
    "outputs = GT2([\"aabc\", \"aa1234\"],\n",
    "               [\"abv<pad>\", \"abv<pad>abv<pad>abv<pad>aa1234\"])\n",
    "for k,v in outputs.items():\n",
    "    print(k)\n",
    "    for o in v:\n",
    "        print('\\t', len(o), ': ', o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c4af80ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids\n",
      "\t 7 :  [0, 3, 41, 106, 107, 127, 7]\n",
      "\t 19 :  [0, 5, 61, 41, 106, 107, 127, 7, 106, 107, 127, 7, 106, 107, 127, 7, 0, 5, 61]\n"
     ]
    }
   ],
   "source": [
    "# testing pairs with presplit words\n",
    "outputs = GT2([[\"aa\",\"bc\"], [\"aa\",\"1234\"]],\n",
    "               [[\"abv\", \"<pad>\"], [\"abv<pad>abv<pad>\",\"abv<pad>\",\"aa1234\"]],\n",
    "             is_split_into_words=True)\n",
    "for k,v in outputs.items():\n",
    "    print(k)\n",
    "    for o in v:\n",
    "        print('\\t', len(o), ': ', o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a36aebe6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "special_tokens_mask\n",
      "\t 10 :  [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]\n",
      "\t 11 :  [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]\n",
      "token_type_ids\n",
      "\t 10 :  [0, 0, 1, 1, 1, 1, 1, 7, 7, 7]\n",
      "\t 11 :  [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "attention_mask\n",
      "\t 10 :  [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]\n",
      "\t 11 :  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "input_ids\n",
      "\t 10 :  [0, 3, 41, 106, 107, 127, 7, 7, 7, 7]\n",
      "\t 11 :  [0, 5, 61, 41, 106, 107, 127, 7, 106, 107, 127]\n"
     ]
    }
   ],
   "source": [
    "outputs = GT2([[\"aa\",\"bc\"], [\"aa\",\"1234\"]],\n",
    "               [[\"abv\", \"<pad>\"], [\"abv<pad>abv<pad>\",\"abv<pad>\",\"aa1234\"]],\n",
    "             is_split_into_words=True, \n",
    "    padding=\"max_length\",\n",
    "    truncation = \"only_second\",\n",
    "    return_token_type_ids=True,\n",
    "    return_overflowing_tokens=True,\n",
    "    return_attention_mask=True,\n",
    "    return_special_tokens_mask=True,\n",
    "    max_length = 10)\n",
    "for k,v in outputs.items():\n",
    "    print(k)\n",
    "    for o in v:\n",
    "        print('\\t', len(o), ': ', o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f97ea7fc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trie constructed\n",
      "Word counts size: 105505\n",
      "Token set size: 0\n",
      "Empty token set size selected -> all possible substrings with...\n",
      "Max token size: 5\n",
      "Min. word count: 1\n",
      "Final candidate token set size: 81136\n",
      "Initial setup phase: 5118 ms\n",
      "1. |Ġ [c4 a0 ] | 30035114 | 12 ms | 137 ms | shortlist: 75764\n",
      "2. |Ġth [c4 a0 74 68 ] | 7109102 | 11 ms | 14 ms | shortlist: 1864\n",
      "3. |tion [74 69 6f 6e ] | 4043268 | 5 ms | 14 ms | shortlist: 7700\n",
      "4. |Ġof [c4 a0 6f 66 ] | 3300812 | 6 ms | 7 ms | shortlist: 371\n",
      "5. |Ġa [c4 a0 61 ] | 3259093 | 3 ms | 8 ms | shortlist: 7092\n",
      "6. |in [69 6e ] | 2782359 | 6 ms | 34 ms | shortlist: 21307\n",
      "7. |re [72 65 ] | 2384688 | 7 ms | 26 ms | shortlist: 13589\n",
      "8. |Ġto [c4 a0 74 6f ] | 2228162 | 7 ms | 8 ms | shortlist: 1091\n",
      "9. |er [65 72 ] | 1910725 | 4 ms | 23 ms | shortlist: 16660\n",
      "10. |en [65 6e ] | 1831877 | 6 ms | 23 ms | shortlist: 13572\n",
      "11. |Ġco [c4 a0 63 6f ] | 1782132 | 6 ms | 10 ms | shortlist: 4574\n",
      "12. |it [69 74 ] | 1622191 | 5 ms | 15 ms | shortlist: 9596\n",
      "13. |nd [6e 64 ] | 1484027 |len:  100\n",
      " 5 ms | 10 ms | shortlist: 7050\n",
      "14. |Ġw [c4 a0 77 ] | 1404713 | 5 ms | 7 ms | shortlist: 3251\n",
      "15. |es [65 73 ] | 1403373 | 4 ms | 22 ms | shortlist: 18006\n",
      "16. |Ġs [c4 a0 73 ] | 1363207 | 6 ms | 14 ms | shortlist: 9055\n",
      "17. |or [6f 72 ] | 1319039 | 5 ms | 15 ms | shortlist: 9982\n",
      "18. |at [61 74 ] | 1291572 | 5 ms | 19 ms | shortlist: 11122\n",
      "19. |is [69 73 ] | 1281159 | 5 ms | 18 ms | shortlist: 11905\n",
      "20. |al [61 6c ] | 1279005 | 5 ms | 21 ms | shortlist: 12506\n",
      "21. |Ġp [c4 a0 70 ] | 1251643 | 5 ms | 11 ms | shortlist: 7315\n",
      "22. |on [6f 6e ] | 1189311 | 5 ms | 16 ms | shortlist: 11432\n",
      "23. |Ġin [c4 a0 69 6e ] | 1158876 | 6 ms | 9 ms | shortlist: 4387\n",
      "24. |ed [65 64 ] | 1137380 | 4 ms | 20 ms | shortlist: 15151\n",
      "25. |an [61 6e ] | 1029701 | 5 ms | 19 ms | shortlist: 15569\n",
      "26. |Ġf [c4 a0 66 ] | 980142 | 5 ms | 8 ms | shortlist: 4835\n",
      "27. |Ġbe [c4 a0 62 65 ] | 957046 | 4 ms | 5 ms | shortlist: 1181\n",
      "28. |ic [69 63 ] | 949111 | 3 ms | 13 ms | shortlist: 9835\n",
      "29. |ou [6f 75 ] | 925779 | 5 ms | 11 ms | shortlist: 8741\n",
      "30. |ar [61 72 ] | 877629 | 4 ms | 16 ms | shortlist: 12598\n",
      "31. |ment [6d 65 6e 74 ] | 871122 | 5 ms | 8 ms | shortlist: 3752\n",
      "32. |ing [69 6e 67 ] | 814643 | 4 ms | 21 ms | shortlist: 15732\n",
      "33. |Ġd [c4 a0 64 ] | 776407 | 5 ms | 10 ms | shortlist: 6508\n",
      "34. |Ġm [c4 a0 6d ] | 768147 | 4 ms | 8 ms | shortlist: 5889\n",
      "35. |le [6c 65 ] | 761238 | 4 ms | 11 ms | shortlist: 9339\n",
      "36. |Ġha [c4 a0 68 61 ] | 743422 | 4 ms | 5 ms | shortlist: 976\n",
      "37. |Ġre [c4 a0 72 65 ] | 691607 | 3 ms | 7 ms | shortlist: 4736\n",
      "38. |ve [76 65 ] | 674520 | 4 ms | 7 ms | shortlist: 5102\n",
      "39. |st [73 74 ] | 639635 | 4 ms | 10 ms | shortlist: 6894\n",
      "40. |ro [72 6f ] | 583331 | 4 ms | 10 ms | shortlist: 6713\n",
      "41. |ce [63 65 ] | 568796 | 4 ms | 7 ms | shortlist: 4967\n",
      "42. |Ġn [c4 a0 6e ] | 555826 | 3 ms | 5 ms | shortlist: 3426\n",
      "43. |Ġe [c4 a0 65 ] | 533333 | 3 ms | 6 ms | shortlist: 4089\n",
      "44. |il [69 6c ] | 507782 | 3 ms | 10 ms | shortlist: 7720\n",
      "45. |untr [75 6e 74 72 ] | 485733 | 4 ms | 4 ms | shortlist: 317\n",
      "46. |Ġc [c4 a0 63 ] | 483397 | 2 ms | 6 ms | shortlist: 5282\n",
      "47. |op [6f 70 ] | 478088 | 3 ms | 5 ms | shortlist: 3417\n",
      "48. |Ġb [c4 a0 62 ] | 468378 | 3 ms | 6 ms | shortlist: 5066\n",
      "49. |ly [6c 79 ] | 466512 | 3 ms | 11 ms | shortlist: 10631\n",
      "50. |he [68 65 ] | 460877 | 4 ms | 7 ms | shortlist: 5354\n",
      "51. |ec [65 63 ] | 452454 | 4 ms | 7 ms | shortlist: 3937\n",
      "52. |ter [74 65 72 ] | 414456 | 3 ms | 7 ms | shortlist: 4386\n",
      "53. |Ġis [c4 a0 69 73 ] | 387692 | 3 ms | 3 ms | shortlist: 246\n",
      "54. |ĠUn [c4 a0 55 6e ] | 385922 | 2 ms | 2 ms | shortlist: 1508\n",
      "55. |th [74 68 ] | 377521 | 3 ms | 5 ms | shortlist: 4361\n",
      "56. |hich [68 69 63 68 ] | 374248 | 3 ms | 3 ms | shortlist: 172\n",
      "57. |si [73 69 ] | 361837 | 2 ms | 6 ms | shortlist: 4372\n",
      "58. |im [69 6d ] | 357124 | 3 ms | 7 ms | shortlist: 5072\n",
      "59. |se [73 65 ] | 346315 | 3 ms | 6 ms | shortlist: 4451\n",
      "60. |ent [65 6e 74 ] | 343740 | 3 ms | 7 ms | shortlist: 4028\n",
      "61. |om [6f 6d ] | 341222 | 3 ms | 7 ms | shortlist: 5196\n",
      "62. |Ġt [c4 a0 74 ] | 340830 | 3 ms | 6 ms | shortlist: 4239\n",
      "63. |ol [6f 6c ] | 340056 | 3 ms | 8 ms | shortlist: 5879\n",
      "64. |ra [72 61 ] | 324514 | 3 ms | 10 ms | shortlist: 7917\n",
      "65. |ĠNa [c4 a0 4e 61 ] | 322442 | 4 ms | 4 ms | shortlist: 1168\n",
      "66. |ity [69 74 79 ] | 322398 | 2 ms | 5 ms | shortlist: 4392\n",
      "67. |ĠS [c4 a0 53 ] | 321981 | 3 ms | 8 ms | shortlist: 9883\n",
      "68. |Ġon [c4 a0 6f 6e ] | 321135 | 3 ms | 4 ms | shortlist: 407\n",
      "69. |ight [69 67 68 74 ] | 320166 | 2 ms | 3 ms | shortlist: 1180\n",
      "70. |ĠA [c4 a0 41 ] | 318611 | 2 ms | 6 ms | shortlist: 8515\n",
      "71. |ld [6c 64 ] | 311836 | 3 ms | 4 ms | shortlist: 1164\n",
      "72. |Ġit [c4 a0 69 74 ] | 311520 | 2 ms | 2 ms | shortlist: 206\n",
      "73. |Ġde [c4 a0 64 65 ] | 310397 | 2 ms | 4 ms | shortlist: 2520\n",
      "74. |ur [75 72 ] | 308485 | 3 ms | 6 ms | shortlist: 5121\n",
      "75. |ver [76 65 72 ] | 298250 | 3 ms | 6 ms | shortlist: 3451\n",
      "76. |Ġwe [c4 a0 77 65 ] | 293443 | 3 ms | 3 ms | shortlist: 615\n",
      "77. |ĠI [c4 a0 49 ] | 290833 | 2 ms | 4 ms | shortlist: 5050\n",
      "78. |Ġg [c4 a0 67 ] | 289866 | 3 ms | 5 ms | shortlist: 3425\n",
      "79. |Ġl [c4 a0 6c ] | 288456 | 3 ms | 5 ms | shortlist: 3087\n",
      "80. |our [6f 75 72 ] | 288054 | 3 ms | 4 ms | shortlist: 1856\n",
      "81. |Ġu [c4 a0 75 ] | 281139 | 3 ms | 6 ms | shortlist: 5928\n",
      "82. |eace [65 61 63 65 ] | 279796 | 3 ms | 4 ms | shortlist: 257\n",
      "83. |Ġh [c4 a0 68 ] | 278976 | 2 ms | 4 ms | shortlist: 3240\n",
      "84. |ĠT [c4 a0 54 ] | 275470 | 3 ms | 5 ms | shortlist: 5914\n",
      "85. |ies [69 65 73 ] | 273424 | 3 ms | 6 ms | shortlist: 5329\n",
      "86. |ul [75 6c ] | 270729 | 3 ms | 8 ms | shortlist: 5331\n",
      "87. |ir [69 72 ] | 270547 | 3 ms | 6 ms | shortlist: 4127\n",
      "88. |Ġal [c4 a0 61 6c ] | 259325 | 3 ms | 4 ms | shortlist: 983\n",
      "89. |as [61 73 ] | 257046 | 2 ms | 6 ms | shortlist: 6726\n",
      "90. |port [70 6f 72 74 ] | 256788 | 3 ms | 4 ms | shortlist: 693\n",
      "91. |ith [69 74 68 ] | 254325 | 2 ms | 3 ms | shortlist: 951\n",
      "92. |ot [6f 74 ] | 238955 | 2 ms | 4 ms | shortlist: 3737\n",
      "93. |res [72 65 73 ] | 234929 | 3 ms | 5 ms | shortlist: 3613\n",
      "94. |de [64 65 ] | 233884 | 3 ms | 6 ms | shortlist: 4948\n",
      "95. |Ġas [c4 a0 61 73 ] | 232510 | 3 ms | 4 ms | shortlist: 774\n",
      "96. |ĠC [c4 a0 43 ] | 224844 | 2 ms | 6 ms | shortlist: 7931\n",
      "97. |ate [61 74 65 ] | 217345 | 3 ms | 6 ms | shortlist: 4424\n",
      "98. |ow [6f 77 ] | 215884 | 3 ms | 5 ms | shortlist: 2459\n",
      "99. |ac [61 63 ] | 215372 | 3 ms | 6 ms | shortlist: 5342\n",
      "100. |Ġst [c4 a0 73 74 ] | 214953 | 3 ms | 5 ms | shortlist: 2378\n",
      "Total time taken: 1 seconds\n",
      "Trie constructed\n"
     ]
    }
   ],
   "source": [
    "from pcatt.pco_tokenizer import build as build_pco\n",
    "words = [t for t in open('cpp_inputs/words/un.txt').read().strip().split(\" \")] \n",
    "counts = [int(t.strip()) for t in open('cpp_inputs/counts/un.txt').read().strip().split('\\n')]\n",
    "un_counts = {a:b for a,b in zip(words, counts)}\n",
    "\n",
    "# we can use train_new_from_counts instead to get the same result as:\n",
    "'''\n",
    "test = build_pco(un_counts)\n",
    "test.initialize_graph(5, 1)\n",
    "test_tokens, test_scores = test.solve_to_step(100)\n",
    "'''\n",
    "\n",
    "from pcatt.hf.greedtok import GreedTok\n",
    "\n",
    "greedtok = GreedTok().train_new_from_counts(un_counts, 100, max_token_length=5, min_word_count=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "7517b0a7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original: The United Nations Organization for peace.\n",
      "Tokens:   [184, 204, 201, 42, 62, 86, 65, 68, 146, 1, 1, 1, 5, 5, 5]\n",
      "Readable: [b'T', b'h', b'e', b' United', b' Nations', b' Organization', b' for', b' peace', b'.', b' of', b' of', b' of']\n",
      "EncDec:   The United Nations Organization for peace.<unk><unk><unk> of of of\n",
      "Trie constructed\n",
      "unk_token <unk> 1\n",
      "pad_token <pad> 0\n",
      "eos_token <eos> 2\n"
     ]
    }
   ],
   "source": [
    "# to use in existing codebases simply import and load from AutoTokenizer\n",
    "import pcatt.hf\n",
    "from transformers import AutoTokenizer\n",
    "tokenize = AutoTokenizer.from_pretrained(\"pcatt/hf/examples/greedtok_test2\")\n",
    "\n",
    "# we can also pass callbacks to modify the final encoding\n",
    "original_str = \"The United Nations Organization for peace.\"\n",
    "callback = lambda x1: [*x1, 1, 1, 1, 5, 5, 5]\n",
    "idxs = tokenize.encode(original_str, callback=callback)\n",
    "print(\"Original:\", original_str)\n",
    "print(\"Tokens:  \", idxs)\n",
    "print(\"Readable:\", [\n",
    "    tokenize.final_ids_map[x]\n",
    "    for x in idxs\n",
    "    if x not in tokenize.special_token_ids\n",
    "])\n",
    "print(\"EncDec:  \", tokenize.decode(idxs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a763988c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original: The United Nations Organization for peace. The world is developing.\n",
      "Tokens:   [184, 204, 201, 42, 62, 86, 65, 68, 146, 132, 132, 132, 132, 132, 184, 204, 201, 73, 66, 37, 36, 146]\n",
      "Readable: [b'T', b'h', b'e', b' United', b' Nations', b' Organization', b' for', b' peace', b'.', b' ', b' ', b' ', b' ', b' ', b'T', b'h', b'e', b' world', b' is', b' develop', b'ing', b'.']\n",
      "EncDec:   The United Nations Organization for peace.     The world is developing.\n"
     ]
    }
   ],
   "source": [
    "original_str = \"The United Nations Organization for peace.\"\n",
    "original_str2 = \"The world is developing.\"\n",
    "def callback(x1, x2):\n",
    "    return x1 + [tokenize.final_tokens_map[b\" \"]]*5 + x2, [0]*len(x1) + [tokenize.final_tokens_map[b\" \"]]*5 + [1]*len(x2)\n",
    "idxs = tokenize.encode(original_str, original_str2, callback=callback)\n",
    "print(\"Original:\", original_str + \" \" + original_str2)\n",
    "print(\"Tokens:  \", idxs)\n",
    "print(\"Readable:\", [\n",
    "    tokenize.final_ids_map[x]\n",
    "    for x in idxs\n",
    "    if x not in tokenize.special_token_ids\n",
    "])\n",
    "print(\"EncDec:  \", tokenize.decode(idxs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "fb63a295",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trie constructed\n",
      "unk_token <unk> 1\n",
      "pad_token <pad> 0\n",
      "eos_token <eos> 2\n"
     ]
    }
   ],
   "source": [
    "import pcatt.hf\n",
    "from pprint import pprint\n",
    "from transformers import AutoTokenizer\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "tokenize = AutoTokenizer.from_pretrained(\n",
    "    \"pcatt/hf/examples/greedtok_test2\",\n",
    "    model_max_length = 12,\n",
    "    max_length = 15,\n",
    "    padding_side = \"right\",\n",
    "    truncation = \"longest_first\",\n",
    "    return_attention_mask = True,\n",
    "    return_special_tokens_mask = True,\n",
    "    padding=\"max_length\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b4f952b0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n",
      "                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]],\n",
      " 'input_ids': [[184, 204, 201, 42, 62, 86, 65, 68, 146, 0, 0, 0],\n",
      "               [184, 204, 201, 73, 1, 66, 37, 36, 146, 0, 0, 0]],\n",
      " 'special_tokens_mask': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],\n",
      "                         [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1]]}\n"
     ]
    }
   ],
   "source": [
    "test = tokenize([ \"The United Nations Organization for peace.\", \"The world<unk> is developing.\"], padding=\"max_length\", return_special_tokens_mask = True,)\n",
    "pprint(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5d2223a7",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "tokenize2 = AutoTokenizer.from_pretrained(\n",
    "    \"gpt2\",\n",
    "    model_max_length = 10,\n",
    "    max_length = 10,\n",
    "    padding_side = \"right\",\n",
    "    truncation = \"longest_first\",\n",
    "    return_attention_mask = True,\n",
    "    return_special_tokens_mask = True,\n",
    "    padding=\"max_length\")\n",
    "tokenize2.pad_token = tokenize2.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e953edc9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n",
      "                    [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]],\n",
      " 'input_ids': [[464, 1578, 7973, 12275, 329, 4167, 13, 50256, 50256, 50256],\n",
      "               [464, 995, 27, 2954, 29, 318, 5922, 13, 50256, 50256]],\n",
      " 'special_tokens_mask': [[0, 0, 0, 0, 0, 0, 0, 1, 1, 1],\n",
      "                         [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]]}\n"
     ]
    }
   ],
   "source": [
    "test = tokenize2([ \"The United Nations Organization for peace.\", \"The world<unk> is developing.\"], padding=\"max_length\", return_special_tokens_mask = True,)\n",
    "pprint(test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tbb]",
   "language": "python",
   "name": "conda-env-tbb-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
