{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/user/anaconda3/envs/torch/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from torchtext.datasets import WikiText103\n",
    "from transformers import BertTokenizerFast\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 857/1801350 [00:00<03:30, 8561.41it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (649 > 512). Running this sequence through the model will result in indexing errors\n",
      "100%|██████████| 1801350/1801350 [03:35<00:00, 8365.47it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "140.9676555700149"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = WikiText103(split='train')\n",
    "ls = 0\n",
    "count = 0\n",
    "for i in tqdm.tqdm(iter(data)):\n",
    "    if len(i) > 50:\n",
    "        ls += len(tokenizer(i).input_ids)\n",
    "        count += 1\n",
    "ls / count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1801350/1801350 [00:01<00:00, 1264701.38it/s]\n",
      "  0%|          | 0/1801350 [00:00<?, ?it/s]\n"
     ]
    }
   ],
   "source": [
    "data = WikiText103(split='train')\n",
    "for i in tqdm.tqdm(iter(data)):\n",
    "    pass\n",
    "for i in tqdm.tqdm(iter(data)):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchtext.datasets import WikiText103\n",
    "\n",
    "class FilteredWikitext:\n",
    "    def __init__(self, min_length = 50):\n",
    "        self.data = None\n",
    "        self.length = 0\n",
    "        self.min_length = min_length\n",
    "        for i in self:\n",
    "            self.length += 1\n",
    "    \n",
    "    def __iter__(self):\n",
    "        self.data = WikiText103(split='train')\n",
    "        return self\n",
    "    \n",
    "    def __next__(self):\n",
    "        line = \"\"\n",
    "        while len(line) < self.min_length:\n",
    "            line = next(self.data)\n",
    "        return line\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.length\n",
    "\n",
    "class WikitextBatchLoader:\n",
    "    def __init__(self, batch_size, tokenizer):\n",
    "        self.data = FilteredWikitext()\n",
    "        self.iterator = None\n",
    "        self.tokenizer = tokenizer\n",
    "        self.batch_size = batch_size\n",
    "    \n",
    "    def __iter__(self):\n",
    "        self.iterator = iter(self.data)\n",
    "        return self\n",
    "    \n",
    "    def __next__(self):\n",
    "        lines = [\"[CLS]\"+next(self.iterator) for i in range(self.batch_size)]\n",
    "        result = self.tokenizer(lines, padding=True, truncation=True, max_length=512, return_tensors='pt')\n",
    "        item = {\n",
    "            'input_ids': result.input_ids,\n",
    "            'attention_mask': result.attention_mask,\n",
    "        }\n",
    "        return item\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.data.length // self.batch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "33it [00:00, 131.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 348]) torch.Size([32, 348])\n",
      "torch.Size([32, 216]) torch.Size([32, 216])\n",
      "torch.Size([32, 260]) torch.Size([32, 260])\n",
      "torch.Size([32, 236]) torch.Size([32, 236])\n",
      "torch.Size([32, 434]) torch.Size([32, 434])\n",
      "torch.Size([32, 466]) torch.Size([32, 466])\n",
      "torch.Size([32, 351]) torch.Size([32, 351])\n",
      "torch.Size([32, 318]) torch.Size([32, 318])\n",
      "torch.Size([32, 371]) torch.Size([32, 371])\n",
      "torch.Size([32, 307]) torch.Size([32, 307])\n",
      "torch.Size([32, 359]) torch.Size([32, 359])\n",
      "torch.Size([32, 374]) torch.Size([32, 374])\n",
      "torch.Size([32, 347]) torch.Size([32, 347])\n",
      "torch.Size([32, 300]) torch.Size([32, 300])\n",
      "torch.Size([32, 286]) torch.Size([32, 286])\n",
      "torch.Size([32, 438]) torch.Size([32, 438])\n",
      "torch.Size([32, 476]) torch.Size([32, 476])\n",
      "torch.Size([32, 314]) torch.Size([32, 314])\n",
      "torch.Size([32, 512]) torch.Size([32, 512])\n",
      "torch.Size([32, 400]) torch.Size([32, 400])\n",
      "torch.Size([32, 271]) torch.Size([32, 271])\n",
      "torch.Size([32, 239]) torch.Size([32, 239])\n",
      "torch.Size([32, 364]) torch.Size([32, 364])\n",
      "torch.Size([32, 357]) torch.Size([32, 357])\n",
      "torch.Size([32, 393]) torch.Size([32, 393])\n",
      "torch.Size([32, 274]) torch.Size([32, 274])\n",
      "torch.Size([32, 421]) torch.Size([32, 421])\n",
      "torch.Size([32, 489]) torch.Size([32, 489])\n",
      "torch.Size([32, 326]) torch.Size([32, 326])\n",
      "torch.Size([32, 330]) torch.Size([32, 330])\n",
      "torch.Size([32, 340]) torch.Size([32, 340])\n",
      "torch.Size([32, 331]) torch.Size([32, 331])\n",
      "torch.Size([32, 512]) torch.Size([32, 512])\n",
      "torch.Size([32, 304]) torch.Size([32, 304])\n",
      "torch.Size([32, 277]) torch.Size([32, 277])\n",
      "torch.Size([32, 227]) torch.Size([32, 227])\n",
      "torch.Size([32, 356]) torch.Size([32, 356])\n",
      "torch.Size([32, 342]) torch.Size([32, 342])\n",
      "torch.Size([32, 330]) torch.Size([32, 330])\n",
      "torch.Size([32, 259]) torch.Size([32, 259])\n",
      "torch.Size([32, 181]) torch.Size([32, 181])\n",
      "torch.Size([32, 240]) torch.Size([32, 240])\n",
      "torch.Size([32, 317]) torch.Size([32, 317])\n",
      "torch.Size([32, 350]) torch.Size([32, 350])\n",
      "torch.Size([32, 304]) torch.Size([32, 304])\n",
      "torch.Size([32, 379]) torch.Size([32, 379])\n",
      "torch.Size([32, 434]) torch.Size([32, 434])\n",
      "torch.Size([32, 429]) torch.Size([32, 429])\n",
      "torch.Size([32, 250]) torch.Size([32, 250])\n",
      "torch.Size([32, 423]) torch.Size([32, 423])\n",
      "torch.Size([32, 433]) torch.Size([32, 433])\n",
      "torch.Size([32, 195]) torch.Size([32, 195])\n",
      "torch.Size([32, 379]) torch.Size([32, 379])\n",
      "torch.Size([32, 304]) torch.Size([32, 304])\n",
      "torch.Size([32, 237]) torch.Size([32, 237])\n",
      "torch.Size([32, 318]) torch.Size([32, 318])\n",
      "torch.Size([32, 332]) torch.Size([32, 332])\n",
      "torch.Size([32, 345]) torch.Size([32, 345])\n",
      "torch.Size([32, 478]) torch.Size([32, 478])\n",
      "torch.Size([32, 265]) torch.Size([32, 265])\n",
      "torch.Size([32, 306]) torch.Size([32, 306])\n",
      "torch.Size([32, 233]) torch.Size([32, 233])\n",
      "torch.Size([32, 313]) torch.Size([32, 313])\n",
      "torch.Size([32, 256]) torch.Size([32, 256])\n",
      "torch.Size([32, 84]) torch.Size([32, 84])\n",
      "torch.Size([32, 333]) torch.Size([32, 333])\n",
      "torch.Size([32, 304]) torch.Size([32, 304])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100it [00:00, 196.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 287]) torch.Size([32, 287])\n",
      "torch.Size([32, 253]) torch.Size([32, 253])\n",
      "torch.Size([32, 158]) torch.Size([32, 158])\n",
      "torch.Size([32, 233]) torch.Size([32, 233])\n",
      "torch.Size([32, 329]) torch.Size([32, 329])\n",
      "torch.Size([32, 306]) torch.Size([32, 306])\n",
      "torch.Size([32, 291]) torch.Size([32, 291])\n",
      "torch.Size([32, 330]) torch.Size([32, 330])\n",
      "torch.Size([32, 289]) torch.Size([32, 289])\n",
      "torch.Size([32, 385]) torch.Size([32, 385])\n",
      "torch.Size([32, 382]) torch.Size([32, 382])\n",
      "torch.Size([32, 282]) torch.Size([32, 282])\n",
      "torch.Size([32, 72]) torch.Size([32, 72])\n",
      "torch.Size([32, 279]) torch.Size([32, 279])\n",
      "torch.Size([32, 384]) torch.Size([32, 384])\n",
      "torch.Size([32, 372]) torch.Size([32, 372])\n",
      "torch.Size([32, 256]) torch.Size([32, 256])\n",
      "torch.Size([32, 402]) torch.Size([32, 402])\n",
      "torch.Size([32, 401]) torch.Size([32, 401])\n",
      "torch.Size([32, 512]) torch.Size([32, 512])\n",
      "torch.Size([32, 219]) torch.Size([32, 219])\n",
      "torch.Size([32, 307]) torch.Size([32, 307])\n",
      "torch.Size([32, 512]) torch.Size([32, 512])\n",
      "torch.Size([32, 512]) torch.Size([32, 512])\n",
      "torch.Size([32, 295]) torch.Size([32, 295])\n",
      "torch.Size([32, 390]) torch.Size([32, 390])\n",
      "torch.Size([32, 346]) torch.Size([32, 346])\n",
      "torch.Size([32, 410]) torch.Size([32, 410])\n",
      "torch.Size([32, 333]) torch.Size([32, 333])\n",
      "torch.Size([32, 299]) torch.Size([32, 299])\n",
      "torch.Size([32, 268]) torch.Size([32, 268])\n",
      "torch.Size([32, 416]) torch.Size([32, 416])\n",
      "torch.Size([32, 321]) torch.Size([32, 321])\n",
      "torch.Size([32, 223]) torch.Size([32, 223])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "c = 0\n",
    "for i in tqdm.tqdm(WikitextBatchLoader(32, tokenizer)):\n",
    "    print(i['input_ids'].shape, i['attention_mask'].shape)\n",
    "    c+=1\n",
    "    if c > 100: break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "58c896f8fe28377dc6f47dbc9814b9367447c8ff4b1090ace6962dd6db7d2533"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('torch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
