{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b3e56598-2f23-41ca-9e9d-70b28eccd95c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bc865b9c-f5f4-47d4-8e26-b1ac554e0bd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenized shards 0.parquet ... 18.parquet into a single frame.\n",
    "N_SHARDS = 19\n",
    "dfs = [pd.read_parquet(f\"path/to/new_tokens/{i}.parquet\") for i in range(N_SHARDS)]\n",
    "\n",
    "# Later cells tag rows via df.loc[sample.index, \"bin\"]; a duplicated index label would\n",
    "# silently tag extra rows, so fail loudly if two shards share index values.\n",
    "df = pd.concat(dfs, verify_integrity=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "64812801-ef0d-431e-9618-f1049bd75eb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"seq_len\"] = df[\"tokens\"].map(len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "343c50e3-d3bc-4f39-8d8f-2ca8735b61ec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>uid</th>\n",
       "      <th>article_id</th>\n",
       "      <th>lang_id</th>\n",
       "      <th>tokens</th>\n",
       "      <th>seq_len</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>256</th>\n",
       "      <td>41726489</td>\n",
       "      <td>116715</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 54, 408, 2176, 12336, 320, 16381, 755...</td>\n",
       "      <td>111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>469</th>\n",
       "      <td>44292755</td>\n",
       "      <td>66915130</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 98982, 1697, 28845, 369, 81745, 42509...</td>\n",
       "      <td>563</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>470</th>\n",
       "      <td>44292756</td>\n",
       "      <td>66915132</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 67291, 360, 2259, 9235, 300, 11, 2663...</td>\n",
       "      <td>69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>471</th>\n",
       "      <td>44292757</td>\n",
       "      <td>66915138</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 43346, 26888, 374, 264, 2363, 555, 27...</td>\n",
       "      <td>646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>472</th>\n",
       "      <td>44292767</td>\n",
       "      <td>53839293</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 5953, 67, 360, 90251, 21252, 462, 314...</td>\n",
       "      <td>143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16766881</th>\n",
       "      <td>53913744</td>\n",
       "      <td>33579634</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 53, 648, 11906, 34569, 374, 264, 1514...</td>\n",
       "      <td>70</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16766882</th>\n",
       "      <td>53913745</td>\n",
       "      <td>33579640</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 45, 438, 67455, 44957, 21016, 352, 13...</td>\n",
       "      <td>236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16766883</th>\n",
       "      <td>53913746</td>\n",
       "      <td>33579641</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 50135, 34569, 374, 264, 15140, 315, 2...</td>\n",
       "      <td>119</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16766884</th>\n",
       "      <td>53913747</td>\n",
       "      <td>33579650</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 3976, 42861, 374, 279, 836, 315, 459,...</td>\n",
       "      <td>7087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17029173</th>\n",
       "      <td>629244</td>\n",
       "      <td>73886145</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 2520, 4458, 37792, 403, 374, 264, 220...</td>\n",
       "      <td>687</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6649601 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "               uid article_id  lang_id  \\\n",
       "256       41726489     116715        0   \n",
       "469       44292755   66915130        0   \n",
       "470       44292756   66915132        0   \n",
       "471       44292757   66915138        0   \n",
       "472       44292767   53839293        0   \n",
       "...            ...        ...      ...   \n",
       "16766881  53913744   33579634        0   \n",
       "16766882  53913745   33579640        0   \n",
       "16766883  53913746   33579641        0   \n",
       "16766884  53913747   33579650        0   \n",
       "17029173    629244   73886145        0   \n",
       "\n",
       "                                                     tokens  seq_len  \n",
       "256       [128000, 54, 408, 2176, 12336, 320, 16381, 755...      111  \n",
       "469       [128000, 98982, 1697, 28845, 369, 81745, 42509...      563  \n",
       "470       [128000, 67291, 360, 2259, 9235, 300, 11, 2663...       69  \n",
       "471       [128000, 43346, 26888, 374, 264, 2363, 555, 27...      646  \n",
       "472       [128000, 5953, 67, 360, 90251, 21252, 462, 314...      143  \n",
       "...                                                     ...      ...  \n",
       "16766881  [128000, 53, 648, 11906, 34569, 374, 264, 1514...       70  \n",
       "16766882  [128000, 45, 438, 67455, 44957, 21016, 352, 13...      236  \n",
       "16766883  [128000, 50135, 34569, 374, 264, 15140, 315, 2...      119  \n",
       "16766884  [128000, 3976, 42861, 374, 279, 836, 315, 459,...     7087  \n",
       "17029173  [128000, 2520, 4458, 37792, 403, 374, 264, 220...      687  \n",
       "\n",
       "[6649601 rows x 5 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.loc[df[\"lang_id\"] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "90605343-f21c-4bd8-9674-49b91bde9ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "s = df.loc[df[\"lang_id\"] == 0, \"seq_len\"]\n",
    "intervals = [(16, 64), (64, 256), (256, 1024), (1024, 4096), (4096, np.inf)]\n",
    "\n",
    "# For each interval, count English documents with seq_len >= its LOWER bound.\n",
    "# These are cumulative tail counts, not per-interval counts (the upper bound is\n",
    "# deliberately unused) - presumably for comparison with the reverse cumsum of\n",
    "# per-bin demand computed further down.\n",
    "counts = [int((s >= low).sum()) for low, _high in intervals]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "3c7a266b-31ad-4d10-abb5-dce8e9923174",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[6647703, 6092657, 3775081, 1143319, 138512]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "e5894b3e-c488-47f4-9bdb-3721a0cb2c81",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.floor_divide([100E6, 200E6, 400E6, 200E6, 100E6], [32, 128, 512, 2048, 8192])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "8520144e-c71f-4ca1-bf1e-dd09c2f02c91",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([3125000., 1562500.,  781250.,   97656.,   12207.])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "574931e0-7999-4885-9bb3-b1a857c2cdb7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([5578613., 2453613.,  891113.,  109863.,   12207.])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.flip(np.cumsum(np.flip(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4cbd10-aa44-4918-8593-1b3501ea415e",
   "metadata": {},
   "outputs": [],
   "source": [
    "200E6/8192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "63aa1af9-9ba3-4524-ac24-23e907d14cdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"bin\"] = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "93805e31-d167-4733-b594-7a78ff46f984",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill the English (lang_id 0) bins from longest to shortest. Each bin only takes\n",
    "# still-unbinned rows with seq_len >= bin length, so long documents are not used up\n",
    "# by the short bins. n is the token budget for the bin; the document count is\n",
    "# rounded UP to a whole number of CONTEXT_LEN-token contexts.\n",
    "CONTEXT_LEN = 2**17\n",
    "SEED = 0\n",
    "df.loc[df[\"lang_id\"] == 0, \"bin\"] = None  # reset so re-running this cell is idempotent\n",
    "for i, n in [(8192, 100E6), (2048, 200E6), (512, 400E6), (128, 200E6), (32, 100E6)]:\n",
    "    seq_per_context = CONTEXT_LEN // i\n",
    "    raw_count = int(n // i)\n",
    "    # ceil to a multiple of seq_per_context; the previous formula added one spare\n",
    "    # context whenever raw_count was already an exact multiple\n",
    "    need = -(-raw_count // seq_per_context) * seq_per_context\n",
    "    candidates = df[(df[\"lang_id\"] == 0) & (df[\"bin\"].isnull()) & (df[\"seq_len\"] >= i)]\n",
    "    for_bin = candidates.sample(need, random_state=SEED)\n",
    "    df.loc[for_bin.index, \"bin\"] = i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "2ac0c4ac-5781-41d4-b791-46896afc475b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check. This cell previously re-sampled 12208 unbinned rows and the next\n",
    "# cell tagged them 8192, but the binning loop above already fills the 8192 bin, so\n",
    "# on Restart & Run All that doubled the bin and broke the fixed store shape.\n",
    "n_8192 = int((df[\"bin\"] == 8192).sum())\n",
    "assert n_8192 > 0 and (n_8192 * 8192) % 2**17 == 0, n_8192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "929f1e57-0057-434e-8498-8278b4328694",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intentionally empty: the 8192 bin is assigned by the binning loop above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "2ad045e6-40d2-4f20-99b0-1bcb16bfc175",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>uid</th>\n",
       "      <th>article_id</th>\n",
       "      <th>lang_id</th>\n",
       "      <th>tokens</th>\n",
       "      <th>seq_len</th>\n",
       "      <th>bin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>21196565</td>\n",
       "      <td>10381928</td>\n",
       "      <td>1</td>\n",
       "      <td>[128000, 48889, 261, 51301, 73302, 3903, 6754,...</td>\n",
       "      <td>1665</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>21196584</td>\n",
       "      <td>7981279</td>\n",
       "      <td>1</td>\n",
       "      <td>[128000, 38, 301, 754, 4171, 69, 839, 4289, 23...</td>\n",
       "      <td>819</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>21196585</td>\n",
       "      <td>7981283</td>\n",
       "      <td>1</td>\n",
       "      <td>[128000, 7, 3753, 18, 8, 426, 564, 320, 4468, ...</td>\n",
       "      <td>138</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>21196586</td>\n",
       "      <td>7981286</td>\n",
       "      <td>1</td>\n",
       "      <td>[128000, 1542, 604, 350, 4843, 268, 65, 11252,...</td>\n",
       "      <td>984</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>21196587</td>\n",
       "      <td>7981290</td>\n",
       "      <td>1</td>\n",
       "      <td>[128000, 18674, 15832, 309, 2357, 68, 17360, 3...</td>\n",
       "      <td>495</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17169402</th>\n",
       "      <td>1949591</td>\n",
       "      <td>189024</td>\n",
       "      <td>30</td>\n",
       "      <td>[128000, 5619, 239, 100335, 102650, 100907, 10...</td>\n",
       "      <td>6064</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17169403</th>\n",
       "      <td>1949592</td>\n",
       "      <td>189025</td>\n",
       "      <td>30</td>\n",
       "      <td>[128000, 109159, 100361, 105909, 120977, 10036...</td>\n",
       "      <td>1911</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17169404</th>\n",
       "      <td>1949593</td>\n",
       "      <td>189030</td>\n",
       "      <td>30</td>\n",
       "      <td>[128000, 80338, 86133, 102672, 122033, 24810, ...</td>\n",
       "      <td>233</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17169405</th>\n",
       "      <td>1949594</td>\n",
       "      <td>189035</td>\n",
       "      <td>30</td>\n",
       "      <td>[128000, 88344, 100537, 100361, 100395, 100348...</td>\n",
       "      <td>85</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17169406</th>\n",
       "      <td>1949763</td>\n",
       "      <td>4373354</td>\n",
       "      <td>7</td>\n",
       "      <td>[128000, 644, 88485, 320, 51977, 12821, 2629, ...</td>\n",
       "      <td>106</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>17169407 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "               uid article_id  lang_id  \\\n",
       "0         21196565   10381928        1   \n",
       "1         21196584    7981279        1   \n",
       "2         21196585    7981283        1   \n",
       "3         21196586    7981286        1   \n",
       "4         21196587    7981290        1   \n",
       "...            ...        ...      ...   \n",
       "17169402   1949591     189024       30   \n",
       "17169403   1949592     189025       30   \n",
       "17169404   1949593     189030       30   \n",
       "17169405   1949594     189035       30   \n",
       "17169406   1949763    4373354        7   \n",
       "\n",
       "                                                     tokens  seq_len   bin  \n",
       "0         [128000, 48889, 261, 51301, 73302, 3903, 6754,...     1665  None  \n",
       "1         [128000, 38, 301, 754, 4171, 69, 839, 4289, 23...      819  None  \n",
       "2         [128000, 7, 3753, 18, 8, 426, 564, 320, 4468, ...      138  None  \n",
       "3         [128000, 1542, 604, 350, 4843, 268, 65, 11252,...      984  None  \n",
       "4         [128000, 18674, 15832, 309, 2357, 68, 17360, 3...      495  None  \n",
       "...                                                     ...      ...   ...  \n",
       "17169402  [128000, 5619, 239, 100335, 102650, 100907, 10...     6064  None  \n",
       "17169403  [128000, 109159, 100361, 105909, 120977, 10036...     1911  None  \n",
       "17169404  [128000, 80338, 86133, 102672, 122033, 24810, ...      233  None  \n",
       "17169405  [128000, 88344, 100537, 100361, 100395, 100348...       85  None  \n",
       "17169406  [128000, 644, 88485, 320, 51977, 12821, 2629, ...      106  None  \n",
       "\n",
       "[17169407 rows x 6 columns]"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "70170503-f204-41d6-ae38-ccf9a2bd4149",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All non-English documents (lang_id != 0) go into a single 256-token bin with a\n",
    "# 1e9-token budget, rounded UP to a whole number of CONTEXT_LEN-token contexts.\n",
    "CONTEXT_LEN = 2**17\n",
    "SEED = 0\n",
    "df.loc[df[\"lang_id\"] != 0, \"bin\"] = None  # reset so re-running this cell is idempotent\n",
    "seq_per_context = CONTEXT_LEN // 256\n",
    "raw_count = int(1E9 // 256)\n",
    "need = -(-raw_count // seq_per_context) * seq_per_context  # ceil to a multiple\n",
    "for_bin = df[(df[\"lang_id\"] != 0) & (df[\"bin\"].isnull()) & (df[\"seq_len\"] >= 256)].sample(need, random_state=SEED)\n",
    "df.loc[for_bin.index, \"bin\"] = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "6727721f-3baf-4155-b685-5fe77d4622e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This cell used to reset every non-English bin to None. Run after the cell above,\n",
    "# that wipes the 256 bin, so `others` is empty when packing. Check the bins instead.\n",
    "assert df.loc[df[\"lang_id\"] != 0, \"bin\"].dropna().eq(256).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "a3f29812-f1c9-4d83-a97e-727064e01bec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "lang_id\n",
       "0     5579056\n",
       "1     1257191\n",
       "2      946444\n",
       "3      706842\n",
       "7      570780\n",
       "6      316596\n",
       "23      69016\n",
       "30      39691\n",
       "Name: count, dtype: Int64"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.loc[df[\"bin\"].notna(), \"lang_id\"].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f88ca7ef-4c5c-47ec-8745-7388662b2d9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.sort_values([\"lang_id\", \"bin\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f8fae7d7-d3a8-4732-b61c-3bfabdfac535",
   "metadata": {},
   "outputs": [],
   "source": [
    "def concat_seq(row):\n",
    "    \"\"\"Return row's tokens truncated to its bin length, with the last token set to 128009.\n",
    "\n",
    "    128009 decodes as <|eot_id|> with the Llama 3.2 tokenizer loaded below. Rows with no\n",
    "    bin (None/NaN) return None. The slice is copied first: slicing a NumPy array gives a\n",
    "    view, so writing the end token into it would also overwrite row[\"tokens\"] (and the\n",
    "    `tokens` column that gets saved) in place.\n",
    "    \"\"\"\n",
    "    bin_len = row[\"bin\"]\n",
    "    if pd.isna(bin_len) or not bin_len:\n",
    "        return None\n",
    "    ret = row[\"tokens\"][:int(bin_len)].copy()\n",
    "    ret[-1] = 128009\n",
    "    return ret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d75f3d3f-f706-49a5-83eb-37c979f6575d",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_col = df.apply(concat_seq, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "641e80d0-cabd-4cf3-a950-a15897e641a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"to_pack\"] = new_col"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b2edec38-e3da-45e3-85bc-2b1ae2399d4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_parquet(\"/net/projects/interp/new_tokens.parquet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "87a7efb9-ed0a-4476-9daa-e89e4892d2a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    }
   ],
   "source": [
    "from transformers import PreTrainedTokenizerFast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "35f8a04d-2452-4f3e-bc18-0b1ee14af875",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = PreTrainedTokenizerFast.from_pretrained(\"meta-llama/Llama-3.2-3B\",add_eos_token=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "babb8fbf-ec63-48f4-8ed5-bcb4e3c2f0f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-14 10:49:34.163134: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-10-14 10:49:39.626821: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-10-14 10:49:51.089175: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'<|begin_of_text|>Wendee Lee (born February 20, 1960 in Los Angeles, California) is an American voice actress. She is best known for<|eot_id|>'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(df.iloc[0][\"to_pack\"].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c6229dd6-0995-4402-9bca-74df8f75fa18",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import tensorstore as ts\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "71ee4ab5-5080-4e0d-b649-52ffd86c9049",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_parquet(\"/net/projects/interp/new_tokens.parquet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "479acfd1-58e4-4f96-92bc-6e69553d8214",
   "metadata": {},
   "outputs": [],
   "source": [
    "to_save = df[df[\"bin\"].notna()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "277d9f6d-70eb-427e-a486-7b1fb4db10d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "test = to_save[(to_save[\"lang_id\"] == 0) & (to_save[\"bin\"] == 8192)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "04a58aac-cc51-46c7-ae1b-d01e67c8895d",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.array(test[\"to_pack\"].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "45e76414-2242-40b1-9eef-8e5654fc8fd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "y = x.reshape(-1,2**17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "fca41428-9956-4d07-ae00-d0162a82183e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pack each (lang_id, bin) group. English groups are reshaped per bin into rows of\n",
    "# CONTEXT_LEN tokens; non-English groups are kept one row per document and are\n",
    "# concatenated and reshaped together in the following cells.\n",
    "CONTEXT_LEN = 2**17\n",
    "to_write = []  # this name is what the loop and later cells use (was `to_write_eng`)\n",
    "others = []\n",
    "for name, group in to_save.groupby([\"lang_id\", \"bin\"]):\n",
    "    packed = np.array(group[\"to_pack\"].tolist())\n",
    "    if name[0] == 0:\n",
    "        to_write.append((name, packed.reshape(-1, CONTEXT_LEN)))\n",
    "    else:\n",
    "        others.append((name, packed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "e2e0f8c5-4e6c-4fbe-a797-9e1d1269b5c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.concatenate([packed for _name, packed in others])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5f9fbbec-3c27-4d6a-9fd8-5f2afcb8a075",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = x.reshape(-1,2**17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "2e735fcf-50a6-4cee-afd6-9f53e1b22e3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack every English bin (groupby yields keys in ascending bin order). The old\n",
    "# `to_write[:5]` would silently drop any extra bin; check the expected bins instead.\n",
    "assert [int(name[1]) for name, _ in to_write] == [32, 128, 512, 2048, 8192], [name for name, _ in to_write]\n",
    "x_0 = np.concatenate([packed for _name, packed in to_write])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e150ebd8-c7f5-4331-bbf2-ac63183e382c",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_to_write = np.concatenate([x_0,x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "a20ad1bc-b649-4a0d-8b27-bfe7bf8b17c3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(15260, 131072)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_to_write.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "6e5dea97-f43e-44e7-89ac-f515560a9f32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a zarr3 TensorStore whose shape matches final_to_write (one row per context).\n",
    "assert final_to_write.dtype == np.int64, final_to_write.dtype\n",
    "tokens_data = ts.open(\n",
    "    {\n",
    "        'driver': 'zarr3',\n",
    "        'cache_pool': {'total_bytes_limit': 10**9},\n",
    "        'recheck_cached_data': 'open',\n",
    "        'kvstore': {\n",
    "            'driver': 'file',\n",
    "            'file_io_concurrency': {'limit': 2048},\n",
    "            'path': 'path/to/tokens_ts',\n",
    "        },\n",
    "        'create': True,\n",
    "    },\n",
    "    dtype=ts.int64,\n",
    "    chunk_layout=ts.ChunkLayout(\n",
    "        write_chunk_shape=[763, 8192],\n",
    "    ),\n",
    "    # derived rather than hard-coded [15260, 131072], so a rerun with different bin\n",
    "    # counts still matches the array written in the next cell\n",
    "    shape=list(final_to_write.shape),\n",
    ").result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "1b783d2c-0052-477a-9e9d-204e02f136af",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens_data.write(final_to_write).result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "ce7ae964-12a0-4862-99e2-f2d56639025d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read back the last row. TensorStore indexes by absolute coordinate (no NumPy-style\n",
    "# negative wrap), so derive the index from the store's shape instead of hard-coding it.\n",
    "test_loc = tokens_data[tokens_data.shape[0] - 1].read().result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "50a74ca6-124c-4e47-80fb-ecc7424e7a7c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([128000, 109159, 100361, 105909, 120977, 100364, 103568, 100907,\n",
       "       100293, 101877, 102399, 103477, 107098, 100361, 105909, 120977,\n",
       "       100276, 100400,  65804,  48909,  44747, 100549,  84736, 100799,\n",
       "       103156,  69258, 100460, 100497, 100428, 100460, 100287, 100391,\n",
       "        44747, 103887, 109953,  48909, 108500,  44747,  85410, 100436,\n",
       "       107098, 100361, 105909, 120977, 100364, 103568, 100907, 100293,\n",
       "       101877, 102399, 103477,    320, 109159, 109159, 100714,  92911,\n",
       "            8,  92317, 100273, 101483,    220,   1049,     15,  92317,\n",
       "       100271, 107098, 100366,  44747,  48909, 100580, 101166,  86133,\n",
       "        69258,  35470, 102411, 117229, 101495, 105528, 102648, 100311,\n",
       "        48909,  24810, 100291, 106269,  61196,  85410, 101166,  73753,\n",
       "        69258, 100750, 101247, 100428,  48909, 100322,  24810, 101159,\n",
       "       105023, 109159, 100361, 105909, 120977, 100364, 103568, 100907,\n",
       "       100276, 100400,  65804, 100929, 100280,  48909, 100273, 109553,\n",
       "       100273, 125408,  35470,  48909, 100273, 100305, 100306, 100460,\n",
       "       111628, 100358, 100303, 100675, 100305,  48909,  35470,  84736,\n",
       "       119404, 119039,  92317, 100271, 101248, 123557,  48909,  35470,\n",
       "        69258, 101029,  84736, 100799,  65804, 102622, 100580, 100311,\n",
       "        92317, 100271, 105374, 102650, 100428, 100277, 104644, 100306,\n",
       "       100406,  84736, 105684, 100346, 101795, 102861, 125408,  92911,\n",
       "        24810, 107098, 100366,  44747, 105528,  79468, 123418, 106752,\n",
       "       101201,  84736,  86133, 121834, 100305, 100666, 100349,  44747,\n",
       "       101248, 103101, 107075, 100666, 100311, 100358,  84736, 111774,\n",
       "       100391, 100322, 100311, 100348, 101753,  92317, 100271, 100287,\n",
       "       101043, 100460,  48909,  35470, 105333,  35470, 105222, 104212,\n",
       "       100276, 100400,  65804,  92317, 100271,  69258,  35470, 100549,\n",
       "        48909,  35470, 100293, 100471, 100358, 100287, 101043, 103396,\n",
       "       100431, 107596, 100311, 100406,  85410, 102067,  69258, 122524,\n",
       "        24810,  92317, 100271,  84736,  86133, 121834, 100305, 100666,\n",
       "       100349,  44747,  69258, 101993, 101201,  69258, 122524,  24810,\n",
       "        84736,  86133, 121834, 100305, 100666, 100349,  44747,  48909,\n",
       "        44747,  84736, 118814, 100574, 100855,  35470,  69258, 128009])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_loc[-256:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "6b9b09fe-9c1b-43c4-bcd0-e6e15807b94c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>uid</th>\n",
       "      <th>article_id</th>\n",
       "      <th>lang_id</th>\n",
       "      <th>tokens</th>\n",
       "      <th>seq_len</th>\n",
       "      <th>bin</th>\n",
       "      <th>to_pack</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>256</th>\n",
       "      <td>41726489</td>\n",
       "      <td>116715</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 54, 408, 2176, 12336, 320, 16381, 755...</td>\n",
       "      <td>111</td>\n",
       "      <td>32.0</td>\n",
       "      <td>[128000, 54, 408, 2176, 12336, 320, 16381, 755...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>470</th>\n",
       "      <td>44292756</td>\n",
       "      <td>66915132</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 67291, 360, 2259, 9235, 300, 11, 2663...</td>\n",
       "      <td>69</td>\n",
       "      <td>32.0</td>\n",
       "      <td>[128000, 67291, 360, 2259, 9235, 300, 11, 2663...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>512</th>\n",
       "      <td>44934285</td>\n",
       "      <td>689146</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 57735, 640, 546, 72, 320, 51, 89, 640...</td>\n",
       "      <td>96</td>\n",
       "      <td>32.0</td>\n",
       "      <td>[128000, 57735, 640, 546, 72, 320, 51, 89, 640...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>515</th>\n",
       "      <td>44934288</td>\n",
       "      <td>689149</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 20027, 5724, 320, 112638, 81, 5724, 8...</td>\n",
       "      <td>77</td>\n",
       "      <td>32.0</td>\n",
       "      <td>[128000, 20027, 5724, 320, 112638, 81, 5724, 8...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>516</th>\n",
       "      <td>44934289</td>\n",
       "      <td>689150</td>\n",
       "      <td>0</td>\n",
       "      <td>[128000, 78054, 269, 5641, 84, 320, 40, 114579...</td>\n",
       "      <td>81</td>\n",
       "      <td>32.0</td>\n",
       "      <td>[128000, 78054, 269, 5641, 84, 320, 40, 114579...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17168985</th>\n",
       "      <td>1946235</td>\n",
       "      <td>1282946</td>\n",
       "      <td>30</td>\n",
       "      <td>[128000, 45279, 117765, 92911, 69258, 100431, ...</td>\n",
       "      <td>384</td>\n",
       "      <td>256.0</td>\n",
       "      <td>[128000, 45279, 117765, 92911, 69258, 100431, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17169002</th>\n",
       "      <td>1946252</td>\n",
       "      <td>1283015</td>\n",
       "      <td>30</td>\n",
       "      <td>[128000, 100574, 100305, 100497, 100282, 10102...</td>\n",
       "      <td>553</td>\n",
       "      <td>256.0</td>\n",
       "      <td>[128000, 100574, 100305, 100497, 100282, 10102...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17169298</th>\n",
       "      <td>1947987</td>\n",
       "      <td>1283230</td>\n",
       "      <td>30</td>\n",
       "      <td>[128000, 80338, 95048, 79468, 103039, 100296, ...</td>\n",
       "      <td>1758</td>\n",
       "      <td>256.0</td>\n",
       "      <td>[128000, 80338, 95048, 79468, 103039, 100296, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17169402</th>\n",
       "      <td>1949591</td>\n",
       "      <td>189024</td>\n",
       "      <td>30</td>\n",
       "      <td>[128000, 5619, 239, 100335, 102650, 100907, 10...</td>\n",
       "      <td>6064</td>\n",
       "      <td>256.0</td>\n",
       "      <td>[128000, 5619, 239, 100335, 102650, 100907, 10...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17169403</th>\n",
       "      <td>1949592</td>\n",
       "      <td>189025</td>\n",
       "      <td>30</td>\n",
       "      <td>[128000, 109159, 100361, 105909, 120977, 10036...</td>\n",
       "      <td>1911</td>\n",
       "      <td>256.0</td>\n",
       "      <td>[128000, 109159, 100361, 105909, 120977, 10036...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>9485616 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "               uid article_id  lang_id  \\\n",
       "256       41726489     116715        0   \n",
       "470       44292756   66915132        0   \n",
       "512       44934285     689146        0   \n",
       "515       44934288     689149        0   \n",
       "516       44934289     689150        0   \n",
       "...            ...        ...      ...   \n",
       "17168985   1946235    1282946       30   \n",
       "17169002   1946252    1283015       30   \n",
       "17169298   1947987    1283230       30   \n",
       "17169402   1949591     189024       30   \n",
       "17169403   1949592     189025       30   \n",
       "\n",
       "                                                     tokens  seq_len    bin  \\\n",
       "256       [128000, 54, 408, 2176, 12336, 320, 16381, 755...      111   32.0   \n",
       "470       [128000, 67291, 360, 2259, 9235, 300, 11, 2663...       69   32.0   \n",
       "512       [128000, 57735, 640, 546, 72, 320, 51, 89, 640...       96   32.0   \n",
       "515       [128000, 20027, 5724, 320, 112638, 81, 5724, 8...       77   32.0   \n",
       "516       [128000, 78054, 269, 5641, 84, 320, 40, 114579...       81   32.0   \n",
       "...                                                     ...      ...    ...   \n",
       "17168985  [128000, 45279, 117765, 92911, 69258, 100431, ...      384  256.0   \n",
       "17169002  [128000, 100574, 100305, 100497, 100282, 10102...      553  256.0   \n",
       "17169298  [128000, 80338, 95048, 79468, 103039, 100296, ...     1758  256.0   \n",
       "17169402  [128000, 5619, 239, 100335, 102650, 100907, 10...     6064  256.0   \n",
       "17169403  [128000, 109159, 100361, 105909, 120977, 10036...     1911  256.0   \n",
       "\n",
       "                                                    to_pack  \n",
       "256       [128000, 54, 408, 2176, 12336, 320, 16381, 755...  \n",
       "470       [128000, 67291, 360, 2259, 9235, 300, 11, 2663...  \n",
       "512       [128000, 57735, 640, 546, 72, 320, 51, 89, 640...  \n",
       "515       [128000, 20027, 5724, 320, 112638, 81, 5724, 8...  \n",
       "516       [128000, 78054, 269, 5641, 84, 320, 40, 114579...  \n",
       "...                                                     ...  \n",
       "17168985  [128000, 45279, 117765, 92911, 69258, 100431, ...  \n",
       "17169002  [128000, 100574, 100305, 100497, 100282, 10102...  \n",
       "17169298  [128000, 80338, 95048, 79468, 103039, 100296, ...  \n",
       "17169402  [128000, 5619, 239, 100335, 102650, 100907, 10...  \n",
       "17169403  [128000, 109159, 100361, 105909, 120977, 10036...  \n",
       "\n",
       "[9485616 rows x 7 columns]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "to_save.re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "af084ffb-c162-478a-9587-6b1dadeabf43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each of the first five entries of `to_write`, repeat its key's second field\n",
    "# (entry[0][1]; presumably the bin value -- TODO confirm) once per row of its data\n",
    "# (entry[1]), then join everything into a single 1-D array.\n",
    "x_0_desc = np.concatenate(\n",
    "    [\n",
    "        np.repeat(entry[0][1], entry[1].shape[0])\n",
    "        for entry in to_write[:5]\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "e9b36cb1-e216-4d82-8997-5e719c61f301",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(7630,)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: one label per row of the selected groups.\n",
    "np.shape(x_0_desc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "e2d780ab-5661-4c03-b618-aa36e0f09f21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constant label 256 for every row of `x` (presumably the 256 bin -- TODO confirm).\n",
    "x_desc = np.full(x.shape[0], 256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "a390978e-5850-4807-9bd5-8c54297fbdc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join both label vectors end to end (x_0_desc first, then x_desc) into one 1-D array;\n",
    "# the result dtype follows NumPy type promotion of the two inputs.\n",
    "x_desc_towrite = np.concatenate((x_0_desc, x_desc), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "4121130d-09f4-4be4-84b0-2b27b3497357",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the label vector into an (n, 1) column; -1 lets NumPy infer n from the data\n",
    "# instead of hard-coding the row count, so this survives changes to the group sizes.\n",
    "x_desc_towrite = x_desc_towrite.reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "8a69fa15-8305-478e-89c2-47d109b928ea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(15260, 1)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: expect an (n, 1) column vector.\n",
    "np.shape(x_desc_towrite)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "8c846294-cb49-46e8-9ebf-5c05b234bee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the on-disk zarr3 array that will hold `x_desc_towrite`.\n",
    "# NOTE(review): assumes `tensorstore` was imported as `ts` in an earlier cell.\n",
    "# Shape is taken from the data rather than hard-coded, so it cannot drift out of sync.\n",
    "packing_shape = list(x_desc_towrite.shape)\n",
    "packing_data = ts.open(\n",
    "    {\n",
    "        'driver': 'zarr3',\n",
    "        'cache_pool': {'total_bytes_limit': 1E9},\n",
    "        'recheck_cached_data': 'open',\n",
    "        'kvstore': {\n",
    "            'driver': 'file',\n",
    "            'file_io_concurrency': {'limit': 2048},\n",
    "            'path': 'path/to/packing_ts',\n",
    "        },\n",
    "        'create': True,\n",
    "        # `create` alone fails if an array already exists at this path, which breaks\n",
    "        # re-running the notebook; replace it instead (the next cell rewrites it fully).\n",
    "        'delete_existing': True,\n",
    "    },\n",
    "    dtype=ts.int64,\n",
    "    # A single write chunk spanning the whole array.\n",
    "    chunk_layout=ts.ChunkLayout(write_chunk_shape=packing_shape),\n",
    "    shape=packing_shape,\n",
    ").result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "6903e6bf-33b6-4da7-9935-ebc082a8193a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast explicitly to int64 to match the store's declared `ts.int64` dtype; the builtin\n",
    "# `int` maps to a NumPy integer whose width is platform-dependent.\n",
    "packing_data.write(x_desc_towrite.astype(np.int64)).result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9513d41-64d9-40c1-8863-8eac239309da",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
