{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import IterableDataset\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "\n",
    "from transformers import RobertaTokenizerFast\n",
    "from nlp import Dataset\n",
    "\n",
    "from torchfly.rl.env import Env\n",
    "from torchfly.flydata import FlyDataLoader\n",
    "from torchfly.flyconfig import GlobalFlyConfig\n",
    "from torchfly.rl.vector import AsyncVectorEnv\n",
    "from torchfly.common import set_random_seed, get_rank\n",
    "\n",
    "from typing import Iterator, Tuple, List\n",
    "\n",
    "from dataloaders.encoder_dataloader import DocumentProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the fly config from the YAML file (with the chdir and logging flags\n",
    "# below set to disabled) and keep only its `user_config` attribute.\n",
    "config = GlobalFlyConfig(\n",
    "    config_path=\"config/base_time_1_mem_32.yml\",\n",
    "    disable_chdir=True,\n",
    "    disable_logging=True,\n",
    ").user_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrained RoBERTa tokenizer, used below to encode and decode documents.\n",
    "tokenizer = RobertaTokenizerFast.from_pretrained(\n",
    "    \"roberta-base\",\n",
    ")\n",
    "\n",
    "# Training dataset, loaded from the path given in the config.\n",
    "dataset = Dataset.from_file(\n",
    "    config.flydata.training.datapath,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token ids of the first dataset record's \"document\" field, encoded\n",
    "# without special tokens.\n",
    "document = tokenizer.encode(\n",
    "    dataset[0][\"document\"],\n",
    "    add_special_tokens=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test with Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Processor built from the validation section of the config.\n",
    "# NOTE(review): the first argument (0) is presumably a rank/worker index --\n",
    "# confirm against DocumentProcessor.\n",
    "validation_config = config.flydata.validation\n",
    "processor = DocumentProcessor(0, validation_config)\n",
    "\n",
    "# Override the sequence length with a tiny value so the printed output of\n",
    "# the synthetic test stays small enough to check by eye.\n",
    "processor.max_seq_len = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic \"document\": the token ids 1 through 7.\n",
    "document = list(range(1, 8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "source [50261]\n",
      "target [50261, 1, 2]\n",
      "source [50261, 1, 2]\n",
      "target [2, 3, 4]\n",
      "source [2, 3, 4]\n",
      "target [4, 5, 6]\n",
      "source [4, 5, 6]\n",
      "target [6, 7, 50260]\n"
     ]
    }
   ],
   "source": [
    "# Print every (source, target) pair produced for the synthetic document,\n",
    "# stopping after at most 12 pairs have been shown.\n",
    "for step, pair in enumerate(processor.pre_process(document)):\n",
    "    for label, token_ids in ((\"source\", pair[0]), (\"target\", pair[1])):\n",
    "        print(label, token_ids)\n",
    "\n",
    "    if step >= 11:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test with Real Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Processor built from the training section of the config.\n",
    "# NOTE(review): the first argument (0) is presumably a rank/worker index --\n",
    "# confirm against DocumentProcessor.\n",
    "training_config = config.flydata.training\n",
    "processor = DocumentProcessor(0, training_config)\n",
    "\n",
    "# Override the sequence length to 64 for the real-data test.\n",
    "processor.max_seq_len = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-encode the first dataset record (replacing the synthetic id list),\n",
    "# again without special tokens.\n",
    "# NOTE(review): the next cell iterates `processor` directly, so this\n",
    "# variable does not appear to be used afterwards -- confirm it is needed.\n",
    "document = tokenizer.encode(\n",
    "    dataset[0][\"document\"],\n",
    "    add_special_tokens=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<source> 3 <s>madeupword0000</s>\n",
      "<target> 66 madeupword0000Stephanolepis cirrhifer\n",
      "\n",
      "Stephanolepis cirrhifer, commonly known as the thread-sail filefish, is a species of marine fish in the family Monacanthidae. It is found in the western Pacific, in an area that ranges from northern Japan to the East China Sea\n",
      "<source> 68 <s>madeupword0000Stephanolepis cirrhifer\n",
      "\n",
      "Stephanolepis cirrhifer, commonly known as the thread-sail filefish, is a species of marine fish in the family Monacanthidae. It is found in the western Pacific, in an area that ranges from northern Japan to the East China Sea</s>\n",
      "<target> 61  Sea, to Korea. Other common names for the fish include \"Kawahagi\" \"カワハギ\" \"皮剥\" (Japanese) and \"Jwi-chi\" (Korean). The fish grows to a maximum length of about, and consumes both plant material and small\n",
      "<source> 63 <s> Sea, to Korea. Other common names for the fish include \"Kawahagi\" \"カワハギ\" \"皮剥\" (Japanese) and \"Jwi-chi\" (Korean). The fish grows to a maximum length of about, and consumes both plant material and small</s>\n",
      "<target> 63  small marine organisms like skeleton shrimp. S. cirrhifer is host of the parasite Peniculus minuticaudae. Some minor genetic differentiation between S. cirrhifer born in the wild and those bred in a hatchery for consumer use has been shown. The fish is edible and sold\n",
      "<source> 65 <s> small marine organisms like skeleton shrimp. S. cirrhifer is host of the parasite Peniculus minuticaudae. Some minor genetic differentiation between S. cirrhifer born in the wild and those bred in a hatchery for consumer use has been shown. The fish is edible and sold</s>\n",
      "<target> 73  sold commercially for culinary purposes in many Asian countries.\n",
      "\n",
      "Taxonomy\n",
      "The fish was first described in 1850 by Coenraad Jacob Temminck and Hermann Schlegel, when it was observed along with other fauna off the coasts of Japan. They initially placed it in the genus Monacanthus, as Monacanthus cirrhifer; however\n"
     ]
    }
   ],
   "source": [
    "# Decode the first four (source, target) pairs yielded by the processor,\n",
    "# printing each side's token count followed by its decoded text.\n",
    "for step, pair in enumerate(processor):\n",
    "    for tag, token_ids in ((\"<source>\", pair[0]), (\"<target>\", pair[1])):\n",
    "        print(tag, len(token_ids), tokenizer.decode(token_ids))\n",
    "\n",
    "    if step >= 3:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
