{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import random\n",
    "from omegaconf import OmegaConf\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import IterableDataset\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "\n",
    "from tokenizers import ByteLevelBPETokenizer\n",
    "\n",
    "from torchfly.rl.env import Env\n",
    "from torchfly.flydata import FlyDataLoader\n",
    "from torchfly.flyconfig import GlobalFlyConfig\n",
    "from torchfly.rl.vector import AsyncVectorEnv\n",
    "from torchfly.common import set_random_seed, get_rank\n",
    "\n",
    "from typing import Iterator, Tuple, List\n",
    "\n",
    "from dataloaders.memformer_dataloader import TextDataLoader, TextDataLoaderHelper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load config/base_time_1.yml through torchfly's GlobalFlyConfig and keep only\n",
    "# its `user_config` attribute; the following cells pass that object around.\n",
    "# NOTE(review): disable_chdir / disable_logging presumably stop torchfly from\n",
    "# switching the working directory and from configuring logging -- confirm.\n",
    "config = GlobalFlyConfig(\n",
    "    config_path=\"config/base_time_1.yml\",\n",
    "    disable_chdir=True,\n",
    "    disable_logging=True,\n",
    ").user_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the project's dataloader helper from the loaded config.\n",
    "helper = TextDataLoaderHelper(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the helper for the validation dataloader (going by the method name).\n",
    "# The next cell calls next() on the result directly, i.e. it is used as an iterator.\n",
    "valid_dataloader = helper.valid_dataloader_fn(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull one item from the loader for inspection.\n",
    "# Each re-run of this cell advances the iterator, so `item` then holds the\n",
    "# following item rather than the first one.\n",
    "item = next(valid_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([32, 64])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the decoder inputs stored under item[0][0].\n",
    "# NOTE(review): the recorded run shows torch.Size([32, 64]); presumably\n",
    "# (batch size, sequence length) -- confirm against the config.\n",
    "item[0][0]['decoder_input_ids'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'source_ids': tensor([[    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2],\n",
       "         [    0, 50261,     2]]),\n",
       " 'decoder_target_ids': tensor([[50261, 37778, 42292,  ...,    85,   189,  1733],\n",
       "         [50261, 16025,  3928,  ...,  1720,    11,  4731],\n",
       "         [50261,   448,    12,  ...,     4,    20,   595],\n",
       "         ...,\n",
       "         [50261, 35655,  9531,  ...,    11,  1429,     6],\n",
       "         [50261, 44466,  1053,  ...,  1943,  7864,     4],\n",
       "         [50261,   863,     4,  ..., 23670, 43169,  8352]]),\n",
       " 'decoder_input_ids': tensor([[    0, 50261, 37778,  ...,     4,    85,   189],\n",
       "         [    0, 50261, 16025,  ...,   780,  1720,    11],\n",
       "         [    0, 50261,   448,  ...,  5761,     4,    20],\n",
       "         ...,\n",
       "         [    0, 50261, 35655,  ...,  5237,    11,  1429],\n",
       "         [    0, 50261, 44466,  ..., 17974,  1943,  7864],\n",
       "         [    0, 50261,   863,  ...,     5, 23670, 43169]])}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the whole element: a dict holding the source_ids, decoder_target_ids\n",
    "# and decoder_input_ids tensors.\n",
    "# NOTE(review): in the recorded output decoder_input_ids looks like\n",
    "# decoder_target_ids shifted right by one position with a leading 0 --\n",
    "# confirm in the dataloader code.\n",
    "item[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
