{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T18:48:47.482217Z",
     "start_time": "2020-09-02T18:48:47.245642Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to\n",
      "[nltk_data]     C:\\Users\\danil\\AppData\\Roaming\\nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import glob\n",
    "import os\n",
    "import json\n",
    "import time\n",
    "import logging\n",
    "import random\n",
    "import re\n",
    "from itertools import chain\n",
    "from string import punctuation\n",
    "\n",
    "# TODO: REMOVE IF USING CUDA\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "\n",
    "import nltk\n",
    "nltk.download('punkt')\n",
    "from nltk.tokenize import sent_tokenize\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import pytorch_lightning as pl\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "from nlp import load_metric\n",
    "\n",
    "from transformers import (\n",
    "    AdamW,\n",
    "    T5ForConditionalGeneration,\n",
    "    T5Tokenizer,\n",
    "    get_linear_schedule_with_warmup\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up WandB for your project"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Login to WandB and get your API Key\n",
    "# !wandb login"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T18:48:47.486121Z",
     "start_time": "2020-09-02T18:48:47.483336Z"
    }
   },
   "outputs": [],
   "source": [
    "## Paste your API key in the YOUR_API_KEY variable below\n",
    "# import wandb\n",
    "# YOUR_API_KEY = ''\n",
    "# os.environ[\"WANDB_API_KEY\"] = YOUR_API_KEY\n",
    "# wandb_logger = WandbLogger(project='wikohow-t5')\n",
    "# wandb.init(project=\"transformers_tutorials_summarization\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data using NLP Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T18:48:47.865252Z",
     "start_time": "2020-09-02T18:48:47.487297Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "processing: distant_supervision_matches_small.txt\n",
      "number of usable sentences:  3962\n",
      "number of usable labels:  3962\n",
      "\n",
      "facts: r ['typeprimaryfunction [S] workplace [S] working event [S] event occurs at', 'agenttypeemployedbyorgtype [S] police officer municipal [S] police organization', 'purposeofeventtypesittype [S] hair cutting event [S] decrease on slot fn [S] hair on head [S] length of object [N] relationallexists [S] objects separated [S] shaving body [S] hair stuff [N] agenttypeperformsworkoftype [S] hair stylist [S] hair cutting event [N] relationallexists [S] object acted on [S] hair cutting event [S] hair on head [N] anatomicalparttyperequired [S] hair cutting event [S] hair on head [S] object acted on [N] relationallexists [S] object acted on [S] shaving body [S] mob of hair mammal [N] agenttypeprovidesserviceoftype [S] hair stylist [S] hair cutting event', 'typicalmainconstituenttypetype [S] book copy [S] paper', 'typicalmainconstituenttypetype [S] book copy [S] paper', 'shapetypeoftype [S] compact disc [S] disc [N] typeprimaryfunction [S] book copy [S] reading [S] ibt used', 'agenttypesellsproducttype [S] bookstore [S] book copy [N] typeprimaryfunction [S] library space [S] storing fn [S] book copy [S] event occurs at', 'typeprimaryfunction [S] book copy [S] reading [S] ibt used', 'covering [S] situation [S] the covering [S] event [S] static situation', 'covering [S] situation [S] the covering [S] event [S] static situation'] \n",
      "\n",
      "dataset size: 3503\n",
      "dataset:  [{'text': 'They make decisions, organize the people who work there and make sure that things are working alright and there are no problems.', 'facts': 'typeprimaryfunction [S] workplace [S] working event [S] event occurs at'}, {'text': 'On the Internet, an administrator is like a big boss and a police officer.', 'facts': 'agenttypeemployedbyorgtype [S] police officer municipal [S] police organization'}, {'text': 'A book is a collection of papers held together between two covers to keep the papers inside safe.', 'facts': 'typicalmainconstituenttypetype [S] book copy [S] paper'}, {'text': '\"Paperback\" books have covers of stiff paper and are usually glued together.', 'facts': 'typicalmainconstituenttypetype [S] book copy [S] paper'}, {'text': 'Books can also be read aloud and recorded on tapes and compact discs.', 'facts': 'shapetypeoftype [S] compact disc [S] disc [N] typeprimaryfunction [S] book copy [S] reading [S] ibt used'}, {'text': 'Books can be borrowed from a library or bought from a bookstore.', 'facts': 'agenttypesellsproducttype [S] bookstore [S] book copy [N] typeprimaryfunction [S] library space [S] storing fn [S] book copy [S] event occurs at'}, {'text': 'People who cannot read books are called illiterate.', 'facts': 'typeprimaryfunction [S] book copy [S] reading [S] ibt used'}, {'text': 'These books are about stories that did not happen, and have been imagined by the author.', 'facts': 'covering [S] situation [S] the covering [S] event [S] static situation'}, {'text': 'Some books are based on real events from history, but the author has created imaginary characters or dialogue for the events.', 'facts': 'covering [S] situation [S] the covering [S] event [S] static situation'}, {'text': 'Their empire once stretched from the Scottish borders to North Africa and the Eastern Mediterranean.', 'facts': 'geographicalsubregionsofcontinent [S] continent of africa [S] northern africa'}] \n",
      "\n",
      "sizes =  [('all', 3503), ('train', 2803), ('test', 350), ('validation', 350)]\n"
     ]
    }
   ],
   "source": [
    "class SentFactData():    \n",
    "    SEP_TOKEN = \" [S]\"\n",
    "    NEW_TOKEN = \" [N] \"\n",
    "\n",
    "    def camel_case_split(self, term):\n",
    "        matches = re.finditer('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', term)\n",
    "        return [m.group(0) for m in matches]\n",
    "\n",
    "    def attached_number_split(self, term):    \n",
    "        matches = re.finditer(r\"([a-zA-Z]+)|([0-9]+)\", term)\n",
    "        return [m.group(0) for m in matches]\n",
    "\n",
    "    def cyc_term_split(self, term):\n",
    "        # splits the string in camelCase, snake-case and underlines\n",
    "        ret = term.split(\"-\")\n",
    "        ret = [y for x in ret for y in x.split(\"_\")]\n",
    "        ret = [y for x in ret for y in self.camel_case_split(x)]\n",
    "        ret = [y for x in ret for y in self.attached_number_split(x)]\n",
    "        return ret\n",
    "\n",
    "    def facts_to_datapoint(self, facts):    \n",
    "        text = \"\"\n",
    "        for i, fact in enumerate(facts):\n",
    "            if i > 0:\n",
    "                text += self.NEW_TOKEN\n",
    "            for j, term in enumerate(fact[1:-1].split(\" \")):\n",
    "                if j > 0:\n",
    "                    text += self.SEP_TOKEN\n",
    "                for word in self.cyc_term_split(term):\n",
    "                    if j > 0:\n",
    "                        text += \" \"\n",
    "                    text += word.lower()\n",
    "        return text\n",
    "    \n",
    "    def process_sent_fact_file(self, filename, verbose = True):\n",
    "        print(\"\\nprocessing: \" + filename)\n",
    "        with open(filename, encoding=\"utf8\", errors=\"surrogateescape\") as file:\n",
    "            lines = [line.rstrip() for line in file]\n",
    "            sentences = []; facts_str = []\n",
    "            for it, line in enumerate(lines):                \n",
    "                if line[:3] == \"['(\" and line[-3:] == \")']\":\n",
    "                    assert(it % 2 == 1)\n",
    "                    facts_str.append(line.replace(\"\\\"\", \"\\\\\\\"\").replace(\"'\", \"\\\"\"))\n",
    "                else:\n",
    "                    assert(it % 2 == 0)\n",
    "                    sentences.append(line)\n",
    "            # transform string into json and make sure to remove duplicate facts\n",
    "            facts_lst = [list(set(json.loads(fs))) for fs in facts_str]                        \n",
    "            facts = [self.facts_to_datapoint(facts) for facts in facts_lst]\n",
    "            if verbose:            \n",
    "                print(\"number of usable sentences: \", len(sentences))\n",
    "                print(\"number of usable labels: \", len(facts_lst))\n",
    "                print(\"\")\n",
    "                print(\"facts: r\", facts[:10], \"\\n\")                \n",
    "            return sentences, facts\n",
    "    \n",
    "    def load_dataset(self, filename = 'distant_supervision_matches_small.txt', verbose = True):\n",
    "        sentences, facts = self.process_sent_fact_file(filename = filename, verbose = verbose)\n",
    "        dataset_all = [{\"text\": s, \"facts\": f} for s, f in zip(sentences, facts) if len(s) < 200 and len(f) < 200]\n",
    "        if verbose:\n",
    "            print(\"dataset size:\", len(dataset_all))\n",
    "            print(\"dataset: \", dataset_all[:10], \"\\n\")\n",
    "        random.shuffle(dataset_all)\n",
    "        n_all = len(dataset_all)\n",
    "        n_validation = n_test = int(n_all * 0.1)\n",
    "        n_train = n_all - (n_validation + n_test)\n",
    "        dataset = {\n",
    "            \"all\": dataset_all,\n",
    "            \"train\": dataset_all[:n_train],\n",
    "            \"test\": dataset_all[n_train:n_train+n_test],\n",
    "            \"validation\": dataset_all[n_train+n_test:]\n",
    "        }\n",
    "        return dataset\n",
    "\n",
    "# test dataset loading\n",
    "DATASET = SentFactData().load_dataset()\n",
    "print(\"sizes = \", [(k, len(v)) for k, v in DATASET.items()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Estimate average length of Text and Summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:38.785557Z",
     "start_time": "2020-09-01T13:44:38.781468Z"
    }
   },
   "outputs": [],
   "source": [
    "text_len = []\n",
    "summary_len=[]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:38.801235Z",
     "start_time": "2020-09-01T13:44:38.786503Z"
    }
   },
   "outputs": [],
   "source": [
    "for i in range(len(DATASET['all'])):\n",
    "    example = DATASET['all'][i]\n",
    "    text_example = example['text']\n",
    "    text_example = text_example.replace('\\n','')\n",
    "    text_words = text_example.split()\n",
    "    text_len.append(len(text_words))\n",
    "    summary_example = example['facts']\n",
    "    summary_example = summary_example.replace('\\n','')\n",
    "    summary_words = summary_example.split()    \n",
    "    summary_len.append(len(summary_words))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T18:46:50.185114Z",
     "start_time": "2020-09-02T18:46:50.047191Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAclElEQVR4nO3dfZhcZZ3m8e9NAgECQiKdbEgCCRqRhNWoTdRlVEYYiaIkzE6cOOpGxQ3MZmZ0xllMmJ0RHePGvVzXt0EnvkZFMi0IieKOZKOMOrrEDiIQQoZIQtImJC2IGMBAwm//eJ5eTipVXdXdVanycH+uq6869Zy3Xz116u7TT52qVkRgZmblclS7CzAzs+ZzuJuZlZDD3cyshBzuZmYl5HA3Myshh7uZWQk53A1J2yVd0MTtnSZpn6RRTdreZyT9bZ4+T1JfM7abt/cKSVuatb1myH13RrvreCaQdIukd7a7jlZwuGf5BTXw85Skxwv33zyM7dUNIUlfkvTB4Vc9dCPdp6S3STpY6Jttkr4o6XkDy0TEjog4ISIONrCtH9bbZ0RcHhF/P9yaK/YZkp5b2PYPIuLMZmx7GLVsrzjO9kk6NffdfcPYXiPH3O9L+p6kX0vaXmX+tDz/MUn3VP7Sl/Qnku6X9KikGyWNH2qddmQ43LP8gjohIk4AdgBvKLRd0+76OsyPcz+dBFwAPA5slHR2s3fUrLP/DlY8zk6IiF2DLdyE/ngU+ALwX2vMvxb4KfBs4G+A6yR15X3PAv4ReCswEXgMuHqE9VirRIR/Kn6A7cAFefooYCnwc+BBoAcYn+d9GriusN6HgfXAWFLgPQXsyz+nVtnPl4AP1qjh9cDtwMPAj4AXVNT318AdwK+BfwKOLcy/AtgN7ALeCQTwXGAx8CTwRK7pm41sr6KutwE/rNL+rYG+AKblfY4urHMf8BtgG/Bm4Czgt8DBXMvDhT75NPBtUhBdUOwn4DygD7gS+GWu/c2FOm4B3lmtXuD7ua5H8z7/eGB7heXPytt4GNgEXFzxfP0DcFN+LLcCz2nGcVbRHsBzB+mP1wF35xp+kZ+7ho65wj4uALZXtD0P2A+cWGj7AXB5nv4Q8LXCvOfkY+nEGvs4Fbge6M/P+1/k9vH5OXxDvn8CsBX4T/n+RaRfMI8AO4GrCtuclvvn7Xner4DLgXNIx+/DwKcqnv9/BT5JOrbvAc4f5Hh5B7A5b/c7wOm5XcD/Avbm7dwBnN3urBr0+Gp3AZ34w6Hh/m7g/wJTgDGkM5dr87zjgX/LB9ArSGEzJc87j0Jo1NjPl6gS7sCL80H0UmAUsCjXNKZQ34b84hmfD8aBF+Bc4AFgVq7vKxweFh+s8nirbq9KbW+jeri/A9iTpwdegKNJofMIcGaeNwmYVWtbub5fA+eSfrEey+HhfgD4aH4+XkUKvYHtV75YD9lHsS8qnyfgaFLIXAkcA7yaFKBnFmp7CJiTH9s1wOpmHGcV7ZXPV2V/7AZekeePA17c6DFX2Ee1cL8E2FzR9ingk3l6DfDeivn7gJdU2f5RwEbg73JfnkH6BX9hnv8a0nE6Afgsh54knQf8+7yNFwB7gPkVx9Zncl+8hnSScGPe1mTSa+dVhef/APCX+fn949yfAydo//94Aebn5/+s/Pz+N+BHed6F+fGcTAr6s4BJ7cqoRn48LFPfZcDfRERfROwHrgL+SNLoiHgMeAspaL4K/HlENOPNvv8M/GNE3BoRByNiFemM6mWFZT4REbsi4iHgm8Ds3P5G4IsRsSnX9/4G91lre43aRfrFUM1TwNmSjouI3RGxqc621kTEv0bEUxHx2xrL/G1E7I+IfyGdSb9xiPVW8zLSWeSKiHgiIr5L+ovkTYVlvhERGyLiACncZ49wnzdKejj/3Fhjmcr+eBKYKelZEfGriLhthDUMOIEUfEW/Bk5scH7ROUBXRHwg9+V9pBBfCBARNwNfJ/2lexHpdUaed0tE
3Jkf7x2koaJXVWz/7yPit3k7j5JOuPZGxC9If228qLDsXuBjEfFkRPwTsCXvs9JlwH+PiM35+f0QMFvS6aQ+PxF4PqC8zO4q2+gYDvf6TgduGHgBks5qD5LGHImIDaQzEpGGbJq1z/cUXvQPA1NJZ9YDHihMP0Z64ZGX2VmYV5weTK3tNWoy6az2EBHxKOls6XJgt6SbJD2/zrbq1fyrvN0B93No3wzXqcDOiHiqYtuTC/cb6qd8hc/Am6RXDrLP+RFxcv6ZX2OZyv74j6Shmfsl/Yuklw+y/aHYBzyrou1ZpL9eGplfdDpwasUxfCX5dZOtBM4mnYw8ONAo6aX5Td1+Sb8mHTunVGx/T2H68Sr3i8/LLyKffme1jpfTgY8X6n2I9LqenH/Rf4o0LLdH0kpJlX3RURzu9e0EXlt4AZ4cEcfmMwQkLSEND+wijXUPGMnXbe4Ellfs8/iIuLaBdXeThpAGTK2Y36qvAb2EdMZ0mIj4TkT8AWlI5h7SGdxgtdSrcZyksYX7p5H6H9JZ3PGFef+uzraKdgFTJRVfF6eRxrWHJNIVPgNvkn5oqOtXbq5i2z+JiHmkYYgbefqkYqTP7SbgDEnFM/EX5vaB+S8cmJEv1xxDGpqstBPYVnEMnxgRr8vrjiINcX4Z+NPiFUzA14C1wNSIOIk0BKMRPK7JkorrF4+Xypovq6j5uIj4EUBEfCIiXkIa8nwetd+U7ggO9/o+AyzPf5ohqUvSvDz9POCDpKGZtwJXSJqd19sDPFvSSXW2P0rSsYWfY0jhd3k+g5GksZIuqnjR1dIDvF3SWZKOJ415Fu0hjX+OmKRRkqZL+iRpnPSwISBJEyVdnMN4P+nsb+ASyT3AlPyYh+r9ko6R9ArSm89fz+23A38o6fgcGJdWrDfY47+V9MvhCklHSzoPeAOwehj1tUR+zG+WdFJEPEl6P6PYn4Mec5KOknQsafxZhWOOiPg3Uv+9L7dfQhrzvj6vfg3wBqXPBowFPkAapqp25r4BeETSeyUdl4+VsyWdk+cP/DXzDuAjwJcLVwKdCDwUEb+VNAf4k6H10mEmAH+Rn9MFpPHyb1dZ7jPAsnxVEJJOyssj6Zz8ejyadIwMXAzQsRzu9X2cdBZxs6TfkN5cfamk0aRx9g9HxM8i4l7SAfsVSWMi4h7SWOF9+c+8WsMGS0l/Rg78fDcieknj7p8ivWu/lfTGUF0R8b+BTwDfy+v9OM/an28/TxqvHWyMt56XS9pHCpZbSH+anxMRd1ZZ9ijgPaQzpYdIY6f/Jc/7Luls8AFJvxzC/h8g9csuUuBcnvsb0hUNT5CCblWeX3QVsCo//kPG6SPiCeBi4LWkN8evJl3BcQ+d5a3AdkmPkIYs3gLQ4DH3StJx9m3SGezjwM2F+QuBblL/rgD+KCL68/Y35f1dQxrHPpGnn8tDRPqMwxtI70lsI/Xn54CTJL0E+CtS3x4kXWUWpNcCeZsfyK+3v2Pkw523AjNyDcvzY3qwcqGIuCHXsjr37V2kYwHSMf5ZUr/cT7py7iMjrKuldOhQlJWNpLNIB+mY/CaR2TOGpLeRrob5vXbXcqT5zL2EJF2S/3wfRzoT+aaD3eyZxeFeTpeRPjjyc9K44J+2txwzO9IaGpaR9Jc8/UnHO0mfDjue9EnGaaQPY7wxIn6Vl19GeiPrIOlTad9pQe1mZlZD3XCXNBn4ITAzIh6X1EN6M2Ym6R3tFZKWAuMi4r2SZpLe1JlDupb0/wDPizpfImVmZs0zegjLHSfpSdIZ+y5gGenyN0hXJdwCvBeYR/pI9n5gm6StpKD/MTWccsopMW3atGGUb2b2zLVx48ZfRkRXtXl1wz0ifiHpI6RvSnwcuDkibpY0ceDjtxGxW9KEvMpk0uWCA/o49BN+AEhaTPoiK0477TR6e3uH8pjMzJ7xJN1fa17dN1TzFRfzgOmkYZaxkt4y2CpV2g4b+4mIlRHRHRHdXV1Vf/GYmdkwNXK1zAWkjxH3
50/EfQP4D6TvV5gEkG/35uX7OPQj71Oo/lFfMzNrkUbCfQfwsvxxbgHnk748ay3pq2jJt2vy9FpgoaQxkqaTPhm2obllm5nZYBoZc79V0nXAbaTvRf4p6dvcTgB6JF1K+gWwIC+/KV9Rc3defomvlDEzO7I64usHuru7w2+ompkNjaSNEdFdbZ4/oWpmVkIOdzOzEnK4m5mVkMPdzKyEGv36ATMApi29qS373b6i2v8zNrNafOZuZlZCDnczsxJyuJuZlZDD3cyshBzuZmYl5HA3Myshh7uZWQk53M3MSsjhbmZWQg53M7MScribmZWQw93MrIQc7mZmJVQ33CWdKen2ws8jkt4tabykdZLuzbfjCussk7RV0hZJF7b2IZiZWaW64R4RWyJidkTMBl4CPAbcACwF1kfEDGB9vo+kmcBCYBYwF7ha0qjWlG9mZtUMdVjmfODnEXE/MA9YldtXAfPz9DxgdUTsj4htwFZgThNqNTOzBg013BcC1+bpiRGxGyDfTsjtk4GdhXX6ctshJC2W1Cupt7+/f4hlmJnZYBoOd0nHABcDX6+3aJW2OKwhYmVEdEdEd1dXV6NlmJlZA4byb/ZeC9wWEXvy/T2SJkXEbkmTgL25vQ+YWlhvCrBr5KXagHb9qzsz+90xlGGZN/H0kAzAWmBRnl4ErCm0L5Q0RtJ0YAawYaSFmplZ4xo6c5d0PPAHwGWF5hVAj6RLgR3AAoCI2CSpB7gbOAAsiYiDTa3azMwG1VC4R8RjwLMr2h4kXT1TbfnlwPIRV2dmZsPiT6iamZWQw93MrIQc7mZmJeRwNzMrIYe7mVkJOdzNzErI4W5mVkIOdzOzEnK4m5mVkMPdzKyEHO5mZiXkcDczKyGHu5lZCTnczcxKyOFuZlZCDnczsxJyuJuZlZDD3cyshBoKd0knS7pO0j2SNkt6uaTxktZJujffjissv0zSVklbJF3YuvLNzKyaRs/cPw78c0Q8H3ghsBlYCqyPiBnA+nwfSTOBhcAsYC5wtaRRzS7czMxqqxvukp4FvBL4PEBEPBERDwPzgFV5sVXA/Dw9D1gdEfsjYhuwFZjT3LLNzGwwjZy5nwH0A1+U9FNJn5M0FpgYEbsB8u2EvPxkYGdh/b7cdghJiyX1Surt7+8f0YMwM7NDNRLuo4EXA5+OiBcBj5KHYGpQlbY4rCFiZUR0R0R3V1dXQ8WamVljGgn3PqAvIm7N968jhf0eSZMA8u3ewvJTC+tPAXY1p1wzM2tE3XCPiAeAnZLOzE3nA3cDa4FFuW0RsCZPrwUWShojaTowA9jQ1KrNzGxQoxtc7s+BayQdA9wHvJ30i6FH0qXADmABQERsktRD+gVwAFgSEQebXrmZmdXUULhHxO1Ad5VZ59dYfjmwfPhlmZnZSPgTqmZmJeRwNzMrIYe7mVkJNfqGqtkz1rSlN7Vlv9tXXNSW/Vo5+MzdzKyEHO5mZiXkcDczKyGHu5lZCTnczcxKyOFuZlZCDnczsxJyuJuZlZDD3cyshBzuZmYl5HA3Myshh7uZWQk53M3MSsjhbmZWQg2Fu6Ttku6UdLuk3tw2XtI6Sffm23GF5ZdJ2ippi6QLW1W8mZlVN5Qz99+PiNkRMfC/VJcC6yNiBrA+30fSTGAhMAuYC1wtaVQTazYzszpGMiwzD1iVp1cB8wvtqyNif0RsA7YCc0awHzMzG6JGwz2AmyVtlLQ4t02MiN0A+XZCbp8M7Cys25fbDiFpsaReSb39/f3Dq97MzKpq9N/snRsRuyRNANZJumeQZVWlLQ5riFgJrATo7u4+bL6ZmQ1fQ2fuEbEr3+4FbiANs+yRNAkg3+7Ni/cBUwurTwF2NatgMzOrr264Sxor6cSBaeA1wF3AWmBRXmwRsCZPrwUWShojaTowA9jQ7MLNzKy2RoZlJgI3SBpY/msR8c+SfgL0SLoU2AEsAIiITZJ6gLuBA8CSiDjYkurNzKyquuEeEfcBL6zS/iBwfo11lgPL
R1ydmZkNiz+hamZWQo1eLWPWVtOW3tTuEsx+p/jM3cyshBzuZmYl5HA3Myshh7uZWQk53M3MSsjhbmZWQg53M7MScribmZWQw93MrIQc7mZmJeRwNzMrIYe7mVkJOdzNzErI4W5mVkIOdzOzEnK4m5mVUMPhLmmUpJ9K+la+P17SOkn35ttxhWWXSdoqaYukC1tRuJmZ1TaUM/d3AZsL95cC6yNiBrA+30fSTGAhMAuYC1wtaVRzyjUzs0Y0FO6SpgAXAZ8rNM8DVuXpVcD8QvvqiNgfEduArcCcplRrZmYNafTM/WPAFcBThbaJEbEbIN9OyO2TgZ2F5fpym5mZHSF1w13S64G9EbGxwW2qSltU2e5iSb2Sevv7+xvctJmZNaKRM/dzgYslbQdWA6+W9FVgj6RJAPl2b16+D5haWH8KsKtyoxGxMiK6I6K7q6trBA/BzMwq1Q33iFgWEVMiYhrpjdLvRsRbgLXAorzYImBNnl4LLJQ0RtJ0YAawoemVm5lZTaNHsO4KoEfSpcAOYAFARGyS1APcDRwAlkTEwRFXamZmDRtSuEfELcAtefpB4Pwayy0Hlo+wNjMzGyZ/QtXMrIQc7mZmJeRwNzMroZG8ofqMN23pTe0uwcysKp+5m5mVkMPdzKyEHO5mZiXkcDczKyGHu5lZCTnczcxKyOFuZlZCDnczsxJyuJuZlZDD3cyshBzuZmYl5O+WMetQ7fzuou0rLmrbvq05fOZuZlZCDnczsxKqG+6SjpW0QdLPJG2S9P7cPl7SOkn35ttxhXWWSdoqaYukC1v5AMzM7HCNnLnvB14dES8EZgNzJb0MWAqsj4gZwPp8H0kzgYXALGAucLWkUS2o3czMaqgb7pHsy3ePzj8BzANW5fZVwPw8PQ9YHRH7I2IbsBWY08yizcxscA2NuUsaJel2YC+wLiJuBSZGxG6AfDshLz4Z2FlYvS+3mZnZEdJQuEfEwYiYDUwB5kg6e5DFVW0Thy0kLZbUK6m3v7+/oWLNzKwxQ7paJiIeBm4hjaXvkTQJIN/uzYv1AVMLq00BdlXZ1sqI6I6I7q6urqFXbmZmNTVytUyXpJPz9HHABcA9wFpgUV5sEbAmT68FFkoaI2k6MAPY0OS6zcxsEI18QnUSsCpf8XIU0BMR35L0Y6BH0qXADmABQERsktQD3A0cAJZExMHWlG9mZtXUDfeIuAN4UZX2B4Hza6yzHFg+4urMzGxY/AlVM7MScribmZWQw93MrIQc7mZmJeRwNzMrIYe7mVkJOdzNzErI4W5mVkIOdzOzEnK4m5mVkMPdzKyEHO5mZiXkcDczKyGHu5lZCTnczcxKyOFuZlZCDnczsxJyuJuZlVAj/yB7qqTvSdosaZOkd+X28ZLWSbo3344rrLNM0lZJWyRd2MoHYGZmh2vkzP0A8J6IOAt4GbBE0kxgKbA+ImYA6/N98ryFwCxgLnB1/ufaZmZ2hNQN94jYHRG35enfAJuBycA8YFVebBUwP0/PA1ZHxP6I2AZsBeY0uW4zMxvEkMbcJU0DXgTcCkyMiN2QfgEAE/Jik4GdhdX6cpuZmR0hDYe7pBOA64F3R8Qjgy1apS2qbG+xpF5Jvf39/Y2WYWZmDWgo3CUdTQr2ayLiG7l5j6RJef4kYG9u7wOmFlafAuyq3GZErIyI7ojo7urqGm79ZmZWRSNXywj4PLA5Ij5amLUWWJSnFwFrCu0LJY2RNB2YAWxoXslmZlbP6AaWORd4K3CnpNtz25XACqBH0qXADmABQERsktQD3E260mZJRBxsduFm1jrTlt7Ulv1uX3FRW/ZbRnXDPSJ+SPVxdIDza6yzHFg+grrMzGwE/AlVM7MScribmZWQw93MrIQc7mZmJeRwNzMrIYe7mVkJOdzNzErI4W5mVkIOdzOzEnK4m5mVkMPdzKyEHO5mZiXkcDczKyGHu5lZCTnczcxKyOFuZlZCDnczsxJyuJuZlZDD3cyshOqGu6QvSNor6a5C
23hJ6yTdm2/HFeYtk7RV0hZJF7aqcDMzq62RM/cvAXMr2pYC6yNiBrA+30fSTGAhMCuvc7WkUU2r1szMGlI33CPi+8BDFc3zgFV5ehUwv9C+OiL2R8Q2YCswpzmlmplZo4Y75j4xInYD5NsJuX0ysLOwXF9uO4ykxZJ6JfX29/cPswwzM6um2W+oqkpbVFswIlZGRHdEdHd1dTW5DDOzZ7bhhvseSZMA8u3e3N4HTC0sNwXYNfzyzMxsOIYb7muBRXl6EbCm0L5Q0hhJ04EZwIaRlWhmZkM1ut4Ckq4FzgNOkdQHvA9YAfRIuhTYASwAiIhNknqAu4EDwJKIONii2s3MrIa64R4Rb6ox6/wayy8Hlo+kKDMzGxl/QtXMrITqnrmbmR0p05be1Jb9bl9xUVv220o+czczKyGHu5lZCTnczcxKyOFuZlZCDnczsxJyuJuZlZDD3cyshBzuZmYl5HA3Myshh7uZWQk53M3MSqgU3y3Tru+jMDPrVD5zNzMrIYe7mVkJlWJYxsxsJNo5tNuqrxv2mbuZWQm1LNwlzZW0RdJWSUtbtR8zMztcS8Jd0ijgH4DXAjOBN0ma2Yp9mZnZ4Vp15j4H2BoR90XEE8BqYF6L9mVmZhVa9YbqZGBn4X4f8NLiApIWA4vz3X2StrSolmY4Bfhlu4sYhOsbGdc3Mq5vBPThEdV3eq0ZrQp3VWmLQ+5ErARWtmj/TSWpNyK6211HLa5vZFzfyLi+kWlVfa0alukDphbuTwF2tWhfZmZWoVXh/hNghqTpko4BFgJrW7QvMzOr0JJhmYg4IOnPgO8Ao4AvRMSmVuzrCOn04SPXNzKub2Rc38i0pD5FRP2lzMzsd4o/oWpmVkIOdzOzEnK4D0LSdkl3SrpdUm8H1PMFSXsl3VVoGy9pnaR78+24DqvvKkm/yH14u6TXtbG+qZK+J2mzpE2S3pXbO6IPB6mvI/pQ0rGSNkj6Wa7v/bm9U/qvVn0d0X+FOkdJ+qmkb+X7Lek/j7kPQtJ2oDsiOuIDEJJeCewDvhwRZ+e2/wE8FBEr8nf4jIuI93ZQfVcB+yLiI+2oqUjSJGBSRNwm6URgIzAfeBsd0IeD1PdGOqAPJQkYGxH7JB0N/BB4F/CHdEb/1apvLh3QfwMk/RXQDTwrIl7fqtewz9x/h0TE94GHKprnAavy9CpSGLRFjfo6RkTsjojb8vRvgM2kT1N3RB8OUl9HiGRfvnt0/gk6p/9q1dcxJE0BLgI+V2huSf853AcXwM2SNuavS+hEEyNiN6RwACa0uZ5q/kzSHXnYpm3DRkWSpgEvAm6lA/uwoj7okD7MQwq3A3uBdRHRUf1Xoz7okP4DPgZcATxVaGtJ/zncB3duRLyY9O2WS/Kwgw3Np4HnALOB3cD/bGs1gKQTgOuBd0fEI+2up1KV+jqmDyPiYETMJn3qfI6ks9tVSzU16uuI/pP0emBvRGw8EvtzuA8iInbl273ADaRvu+w0e/JY7cCY7d4213OIiNiTX3BPAZ+lzX2Yx2KvB66JiG/k5o7pw2r1dVof5poeBm4hjWd3TP8NKNbXQf13LnBxfi9vNfBqSV+lRf3ncK9B0tj8phaSxgKvAe4afK22WAssytOLgDVtrOUwAwdtdglt7MP8htvngc0R8dHCrI7ow1r1dUofSuqSdHKePg64ALiHzum/qvV1Sv9FxLKImBIR00hfyfLdiHgLLeo/Xy1Tg6QzSGfrkL6m4WsRsbyNJSHpWuA80leY7gHeB9wI9ACnATuABRHRljc1a9R3HunP4QC2A5cNjC+2ob7fA34A3MnTY55Xksa1296Hg9T3JjqgDyW9gPSG3yjSiWFPRHxA0rPpjP6rVd9X6ID+K5J0HvDX+WqZlvSfw93MrIQ8LGNmVkIOdzOzEnK4m5mVkMPdzKyEHO5mZiXkcDczKyGHu5lZCf0/7febwjWlBKYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "plt.hist(text_len)\n",
    "plt.title('Text Length Distribution - First 100 examples')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T18:47:08.641831Z",
     "start_time": "2020-09-02T18:47:08.524395Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEICAYAAACzliQjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAb2ElEQVR4nO3df7RcdX3u8fdjgvxGiBwwJEBQoy2wBPWI9HprWRcrqGiwV2xYorGlN+rCii1eBO0t1JpV0F6rXgVLRQiiYC4o5GppoVFUVhUMCEIIKRGQHBOTyA9JRDHAc//4fk8dhpmTc84czplkP6+1Zp093/3rM3v2PLPnu2f2kW0iIqIZnjXVBURExORJ6EdENEhCPyKiQRL6ERENktCPiGiQhH5ERIMk9GNUJJ0t6dIJXuY1khZM0LJ+X9Kqlvv3SXrNRCy7Lm+FpKMmanm9kvQ5Sf9rqutoAknvlHTDVNcxURL6XdTQ+JWkzS23/ca5rKMkDU32vOM1EeuUZEm/rNvtAUnLJP1x6zS2X2d78SiX9cKRprH9Xdsv7qXmlvVdLOmjbcs/xPb1E7H8MdZytqQtbfvh6bbfbftvx7nMEd8QJT1b0hV1Ore/2ak4tz6vD0j6mCS1jJ8j6VuSHpV010S++UbvEvoje6Pt3Vpua6e6oG3MYbZ3A14MXAx8RtJZE70SSdMnepl95itt++HHRpp4grbHDcBJwM86jFsIHA8cBrwEOA54V8v4y4AfAs8FPgxcIWlgAmqKiWA7tw434D7gNW1tewFfBzYCD9Xh2S3jZwAXAWvr+KuAXYFfAU8Cm+ttP+AIYDnwCLAe+ESXOo4ChrqM2w+4stZzL/C+lnFnA0uAS4BNwApgsGX8yygvzE3A/wW+Anx0hHpHXF6H2gy8sK3tLcCvgefW+9cDf1aHXwh8G/gF8HNK0AF8py7rl7WWPx7eJsAHKaH0xfbtVJ+/M4E763NxEbBTHfdO4IZO9VICbQvwm7q+/9e+PwA7Ap+sz/PaOrxj6/MFnAZsANYBf9LDfng2cGmH9ouBj7ats3V77E3ZPx8GHgS+SznI+2J9bn9VH9/pW1n/EHBUW9u/Awtb7p8MfL8Ovwh4DNi9Zfx3gXd3Wf6zgDOAHwMP1H1sRh13PnBFy7TnAssAsfXX4vWU/fnfh59HypvQlyivuR8Ac9qe//cB91D2v48Dz+q0vwC/A1xXt+sq4K0t415P2ec2AT8FPjDVWdZ+y5H+2DyLEh4HAgdQXjifaRn/RWAX4BBgH+AfbP8SeB2w1k/9xPAp4FO29wBeQNnZR03Ssyg78m3ALOBo4P2SjmmZ7E3A5cCewNLhWiU9G/gaJThmUI7M3gwwQr1dlzcGVwPTKW947f4WuJbyYp4N/J9az6vr+MNqLV+p959Xaz+QEtSdvA04hrJ9XwT81dYKtH0BJRg+Vtf3xg6TfRg4EjiccrR7RNuynwc8h/K8nAx8VtJeW1t3j9q3x2mUwB4A9gU+BNj224H7+e2n2BE/NXRxCGW/G3ZbbRsed4/tTV3Gt3sf5VPDH1AOLh4CPlvHnQa8pPap/z5lWy5wSdetvRYB5gNvpzwPLwC+V+eZAawE2j91vhkYpBwQzQP+tL1YSbtSAv/LlNf4icB5koYf34XAu2zvDhwKfLPL454yCf2RXSXp4Xq7yvYDtq+0/WjdqRdRdlYkzaSE5bttP2R7i+1vj7DsLcALJe1te7Pt74+xtlcAA7Y/Yvs3tu8B/omyow+7wfY/236C8oZ0WG0/khK+n651fhW4aRTr7La8UbG9hXIUNaPD6C2UF/B+tn9te2snzp4EzrL9mO1fdZnmM7bX2H6Q8lydOJZ6R/A24CO2N9jeCPwNJVyGbanjt9j+Z8qRZi/nG97ash8+3OXcUvv22ALMBA6sdXy3huVE2I3yiWzYL4Ddar9++7jh8bt3Wda7gA/bHrL9GOWTzVskTbf9KKWL6RPApcCf2x4C
GOm12OIi2z+2/QvgGuDHtv/N9uOUT7cvbZv+XNsP2r6f8umt0/5yHHCf7YtsP277Fsqn7bfU8VuAgyXtUXPgli6Pe8ok9Ed2vO096+14SbtI+kdJP5H0CKXrYU9J04D9gQdtPzTKZZ9MOfq8S9IPJB03xtoOBPZrDQPK0dy+LdO09sc+CuxU+3v3A37aFgJrRrHObssbFUk7UI48H+ww+nTKx/ab6jdlnnaU1Waj7V9vZZrWx/QTyuOeCPvV5XVb9gM1WIY9SgnDp6jfOBo+ObtihPUtadkP93Tnc0vt2+PjwGrgWkn3SDpjq49q9DYDe7Tc3wPYXPen9nHD4zfR2YHA11r24ZXAE9T92PZNlC4X0fJpeCuvxWHrW4Z/1eF++3Mymv3lQOCVba+7t1E+aQH8d0oXz08kfVvS73V53FMmoT82p1GO2F5Zu2WGux5E2WFmSNqzw3xPO8KyfbftEykfEc+lnOzadQy1rAHubQuD3W2/fhTzrgNmtX7jgvKm1bXeCTIPeJwOnyps/8z2/7C9H+Xo77ytfGNnNDW2PqYDKP3vUM4P7DI8QtLzeKqtLXst5cXfadmjVo++h7vQunV/jHpxbcveZPs0288H3gj8paSjO007Dit46qe8w2rb8LjnS9q9y/h2a4DXte3HO9n+KYCkUyjnUNZSDgyGjfRaHK9u+0t7vd9uq3c32+8BsP0D2/Mor+urGGO37WRI6I/N7pQjhIclzaClT9D2OspHyPMk7SVpB0nDO+J64LmSnjM8vaSTJA3YfpJysg3KEU5HknZqvVGC8xFJH5S0s6Rpkg6V9IpRPI7v1XW9V9J0SfN4aj/70+rthaQZkt5G6as91/YDHaY5QdLsevchSjANb4/1wPPHsepTJM2uz9WHKCerofYxSzq8bsuz2+bb2vouA/5K0oCkvYG/pnQ/9A1Jx0l6YX1jf4SyLUe9PSXtWLcNwLPrfjccqJdQ3kRm1a6m0yjnh7D9H8CtwFl1njdTvuFzZZdVfQ5YJOnAut6Buj8i6UWUk7EnUbrPTpd0eJ2v62uxB/+zvnb3B07lt/tLq68DL5L09voa30HSKyT9rspXXd8m6Tm1K3N4u/eVhP7YfBLYmdIv/X3gX9rGv53Sp3cX5Zsb7wewfRclKO5p6ZM9FlghaTPlpO78EborZlF28NbbQZQjuMMp39z5OfB5ygnEEdn+DfBHlC6mhykvqq9TvnXRrd7xuK0+vtXAnwF/Yfuvu0z7CuDGOv1S4FTb99ZxZwOLay1vHcP6v0w5OXxPvX0U/jOYPgL8G3A35euJrS6k9Ms+LOmqDsv9KOWbVz8CbgduGV52H5lLeXybKW/y5/m3vzP4O8qb1sOSPtBl/lWU/WwW8K91ePjTzT9SvkRwO3AH8I3aNmw+5YToQ8A5wFvquY9OPkV5vq+VtInyunpl7Ta8lHKQcJvtuylv3F+UNPztqZFei+NxNXAz5U3rG5T94Cnq+YPX1se4ltLleS7l0wiUDLivdjm9m/La6iuauHM7sS2TdCPwOdsXTXUtEZNNkoG5tldPdS3PtBzpN5SkP5D0vNq9s4DyEXwijpYioo9t779kjO5eTDnJtBvlhzFvqeclImI7lu6diIgGSfdORESD9H33zt577+05c+ZMdRkREduUm2+++ee2n3ahu74P/Tlz5rB8+fKpLiMiYpsi6Sed2rfavSPpC5I2SLqjw7gPqFxve++WtjMlrZa0Si0X/5L0ckm313Gfbvs1aERETILR9OlfTPkh0VPUX639IeWKfcNtB1N+tHBInee8lmthnE+5+t/cenvaMiMi4pm11dC3/R06XyDrHyjXwmj9+s884PJ6pb97Kb/EPELlCpR72P5evSjTJZTLqUZExCQa17d3JL2JcpXG29pGzeKpV6obqm2z6nB7e0RETKIxn8iVtAvln0i8ttPoDm0eob3bOhZS/zHGAQccMNYSIyKii/Ec6b+AcrGv2yTdR/kvR7fUy9MO8dTL
k86mXJRoqA63t3dk+wLbg7YHBwbyrzUjIibKmEPf9u2297E9x/YcSqC/zPbPKFfLm18vy3oQ5YTtTfXn/ZskHVm/tfMOyhXtIiJiEo3mK5uXUS7N+mJJQ5JO7jat7RWU67ncSbl41yku/1oP4D2US/+uplzr5Zoea4+IiDHq+2vvDA4OOj/OiogYG0k32x5sb+/7X+Rui+ac8Y0pW/d957xhytYdEf0vF1yLiGiQhH5ERIMk9CMiGiShHxHRIAn9iIgGSehHRDRIQj8iokES+hERDZLQj4hokIR+RESDJPQjIhokoR8R0SAJ/YiIBknoR0Q0SEI/IqJBEvoREQ2S0I+IaJCEfkREgyT0IyIaJKEfEdEgCf2IiAbZauhL+oKkDZLuaGn7uKS7JP1I0tck7dky7kxJqyWtknRMS/vLJd1ex31akib80URExIhGc6R/MXBsW9t1wKG2XwL8B3AmgKSDgfnAIXWe8yRNq/OcDywE5tZb+zIjIuIZttXQt/0d4MG2tmttP17vfh+YXYfnAZfbfsz2vcBq4AhJM4E9bH/PtoFLgOMn6DFERMQoTUSf/p8C19ThWcCalnFDtW1WHW5v70jSQknLJS3fuHHjBJQYERHQY+hL+jDwOPCl4aYOk3mE9o5sX2B70PbgwMBALyVGRESL6eOdUdIC4Djg6NplA+UIfv+WyWYDa2v77A7tERExicZ1pC/pWOCDwJtsP9oyaikwX9KOkg6inLC9yfY6YJOkI+u3dt4BXN1j7RERMUZbPdKXdBlwFLC3pCHgLMq3dXYErqvfvPy+7XfbXiFpCXAnpdvnFNtP1EW9h/JNoJ0p5wCuISIiJtVWQ9/2iR2aLxxh+kXAog7ty4FDx1RdRERMqHH36W8L5pzxjakuISKir+QyDBERDZLQj4hokIR+RESDJPQjIhokoR8R0SAJ/YiIBknoR0Q0SEI/IqJBEvoREQ2S0I+IaJCEfkREgyT0IyIaJKEfEdEgCf2IiAZJ6EdENEhCPyKiQRL6ERENktCPiGiQhH5ERIMk9CMiGmSroS/pC5I2SLqjpW2GpOsk3V3/7tUy7kxJqyWtknRMS/vLJd1ex31akib+4URExEhGc6R/MXBsW9sZwDLbc4Fl9T6SDgbmA4fUec6TNK3Ocz6wEJhbb+3LjIiIZ9hWQ9/2d4AH25rnAYvr8GLg+Jb2y20/ZvteYDVwhKSZwB62v2fbwCUt80RExCQZb5/+vrbXAdS/+9T2WcCalumGatusOtze3pGkhZKWS1q+cePGcZYYERHtJvpEbqd+eo/Q3pHtC2wP2h4cGBiYsOIiIppuvKG/vnbZUP9uqO1DwP4t080G1tb22R3aIyJiEo039JcCC+rwAuDqlvb5knaUdBDlhO1NtQtok6Qj67d23tEyT0RETJLpW5tA0mXAUcDekoaAs4BzgCWSTgbuB04AsL1C0hLgTuBx4BTbT9RFvYfyTaCdgWvqLSIiJtFWQ9/2iV1GHd1l+kXAog7ty4FDx1RdRERMqPwiNyKiQRL6ERENktCPiGiQhH5ERIMk9CMiGiShHxHRIAn9iIgGSehHRDRIQj8iokES+hERDZLQj4hokIR+RESDJPQjIhokoR8R0SAJ/YiIBknoR0Q0SEI/IqJBEvoREQ2S0I+IaJCEfkREgyT0IyIapKfQl/QXklZIukPSZZJ2kjRD0nWS7q5/92qZ/kxJqyWtknRM7+VHRMRYjDv0Jc0C3gcM2j4UmAbMB84AltmeCyyr95F0cB1/CHAscJ6kab2VHxERY9Fr9850YGdJ04FdgLXAPGBxHb8YOL4OzwMut/2Y7XuB1cARPa4/IiLGYNyhb/unwN8D9wPrgF/YvhbY1/a6Os06YJ86yyxgTcsihmpbRERMkl66d/aiHL0fBOwH7CrppJFm6dDmLsteKGm5pOUbN24cb4kREdGml+6d1wD32t5oewvwVeC/AOslzQSofzfU6YeA/Vvmn03pDnoa2xfYHrQ9
ODAw0EOJERHRqpfQvx84UtIukgQcDawElgIL6jQLgKvr8FJgvqQdJR0EzAVu6mH9ERExRtPHO6PtGyVdAdwCPA78ELgA2A1YIulkyhvDCXX6FZKWAHfW6U+x/USP9UdExBiMO/QBbJ8FnNXW/BjlqL/T9IuARb2sMyIixi+/yI2IaJCEfkREgyT0IyIaJKEfEdEgCf2IiAZJ6EdENEhCPyKiQRL6ERENktCPiGiQhH5ERIMk9CMiGiShHxHRIAn9iIgGSehHRDRIQj8iokES+hERDZLQj4hokIR+RESDJPQjIhokoR8R0SAJ/YiIBknoR0Q0SE+hL2lPSVdIukvSSkm/J2mGpOsk3V3/7tUy/ZmSVktaJemY3suPiIix6PVI/1PAv9j+HeAwYCVwBrDM9lxgWb2PpIOB+cAhwLHAeZKm9bj+iIgYg3GHvqQ9gFcDFwLY/o3th4F5wOI62WLg+Do8D7jc9mO27wVWA0eMd/0RETF2vRzpPx/YCFwk6YeSPi9pV2Bf2+sA6t996vSzgDUt8w/VtqeRtFDScknLN27c2EOJERHRqpfQnw68DDjf9kuBX1K7crpQhzZ3mtD2BbYHbQ8ODAz0UGJERLTqJfSHgCHbN9b7V1DeBNZLmglQ/25omX7/lvlnA2t7WH9ERIzRuEPf9s+ANZJeXJuOBu4ElgILatsC4Oo6vBSYL2lHSQcBc4Gbxrv+iIgYu+k9zv/nwJckPRu4B/gTyhvJEkknA/cDJwDYXiFpCeWN4XHgFNtP9Lj+iIgYg55C3/atwGCHUUd3mX4RsKiXdUZExPjlF7kREQ2S0I+IaJCEfkREgyT0IyIaJKEfEdEgCf2IiAZJ6EdENEhCPyKiQRL6ERENktCPiGiQhH5ERIMk9CMiGiShHxHRIAn9iIgGSehHRDRIQj8iokES+hERDZLQj4hokIR+RESDJPQjIhokoR8R0SA9h76kaZJ+KOnr9f4MSddJurv+3atl2jMlrZa0StIxva47IiLGZiKO9E8FVrbcPwNYZnsusKzeR9LBwHzgEOBY4DxJ0yZg/RERMUo9hb6k2cAbgM+3NM8DFtfhxcDxLe2X237M9r3AauCIXtYfERFj0+uR/ieB04EnW9r2tb0OoP7dp7bPAta0TDdU255G0kJJyyUt37hxY48lRkTEsHGHvqTjgA22bx7tLB3a3GlC2xfYHrQ9ODAwMN4SIyKizfQe5n0V8CZJrwd2AvaQdCmwXtJM2+skzQQ21OmHgP1b5p8NrO1h/RERMUbjPtK3fabt2bbnUE7QftP2ScBSYEGdbAFwdR1eCsyXtKOkg4C5wE3jrjwiIsaslyP9bs4Blkg6GbgfOAHA9gpJS4A7gceBU2w/8QysPyIiupiQ0Ld9PXB9HX4AOLrLdIuARROxzoiIGLv8IjciokES+hERDZLQj4hokIR+RESDJPQjIhokoR8R0SAJ/YiIBnkmfpwVU2jOGd+YkvXed84bpmS9ETE2OdKPiGiQhH5ERIMk9CMiGiShHxHRIAn9iIgGSehHRDRIQj8iokES+hERDZLQj4hokIR+RESD5DIMMSGm6vIPkEtARIxFQj+2ebneUMTopXsnIqJBxh36kvaX9C1JKyWtkHRqbZ8h6TpJd9e/e7XMc6ak1ZJWSTpmIh5ARESMXi9H+o8Dp9n+XeBI4BRJBwNnAMtszwWW1fvUcfOBQ4BjgfMkTeul+IiIGJtxh77tdbZvqcObgJXALGAesLhOthg4vg7PAy63/Zjte4HVwBHjXX9ERIzdhPTpS5oDvBS4EdjX9joobwzAPnWyWcCaltmGalun5S2UtFzS8o0bN05EiRERwQSEvqTdgCuB99t+ZKRJO7S504S2L7A9aHtwYGCg1xIjIqLqKfQl7UAJ/C/Z/mptXi9pZh0/E9hQ24eA/Vtmnw2s7WX9ERExNr18e0fAhcBK259oGbUUWFCHFwBXt7TPl7SjpIOAucBN411/RESMXS8/znoV8Hbg
dkm31rYPAecASySdDNwPnABge4WkJcCdlG/+nGL7iR7WHxERYzTu0Ld9A5376QGO7jLPImDReNcZERG9yS9yIyIaJKEfEdEgCf2IiAZJ6EdENEhCPyKiQRL6ERENktCPiGiQhH5ERIMk9CMiGiShHxHRIPnH6BHjNFX/kB3yT9lj/HKkHxHRIAn9iIgGSfdOxDZoqrqW0q207cuRfkREgyT0IyIaJKEfEdEgCf2IiAZJ6EdENEhCPyKiQRL6ERENktCPiGiQSQ99ScdKWiVptaQzJnv9ERFNNqm/yJU0Dfgs8IfAEPADSUtt3zmZdUTE+OSXwNu+yT7SPwJYbfse278BLgfmTXINERGNNdnX3pkFrGm5PwS8sn0iSQuBhfXuZkmrxrm+vYGfj3PeyZQ6J962UmvqHAWdO6bJs02LAzs1Tnboq0Obn9ZgXwBc0PPKpOW2B3tdzjMtdU68baXW1DnxtpVap6rOye7eGQL2b7k/G1g7yTVERDTWZIf+D4C5kg6S9GxgPrB0kmuIiGisSe3esf24pPcC/wpMA75ge8UzuMqeu4gmSeqceNtKralz4m0rtU5JnbKf1qUeERHbqfwiNyKiQRL6ERENst2GvqT7JN0u6VZJy6e6nmGSviBpg6Q7WtpmSLpO0t31715TWWOtqVOdZ0v6ad2mt0p6/VTWWGvaX9K3JK2UtELSqbW9r7bpCHX24zbdSdJNkm6rtf5Nbe+3bdqtzr7bplCuSCDph5K+Xu9Pyfbcbvv0Jd0HDNruqx9pSHo1sBm4xPahte1jwIO2z6nXI9rL9gf7sM6zgc22/34qa2slaSYw0/YtknYHbgaOB95JH23TEep8K/23TQXsanuzpB2AG4BTgT+iv7ZptzqPpc+2KYCkvwQGgT1sHzdVr/vt9ki/X9n+DvBgW/M8YHEdXkwJgynVpc6+Y3ud7Vvq8CZgJeWX3321TUeos++42Fzv7lBvpv+2abc6+46k2cAbgM+3NE/J9tyeQ9/AtZJurpd16Gf72l4HJRyAfaa4npG8V9KPavfPlHdDtZI0B3gpcCN9vE3b6oQ+3Ka1K+JWYANwne2+3KZd6oT+26afBE4Hnmxpm5LtuT2H/qtsvwx4HXBK7a6I3pwPvAA4HFgH/O8praaFpN2AK4H3235kquvppkOdfblNbT9h+3DKr+aPkHToFJfUUZc6+2qbSjoO2GD75qmsY9h2G/q219a/G4CvUa7w2a/W1z7f4b7fDVNcT0e219cX2ZPAP9En27T2514JfMn2V2tz323TTnX26zYdZvth4HpKP3nfbdNhrXX24TZ9FfCmep7xcuC/SbqUKdqe22XoS9q1nixD0q7Aa4E7Rp5rSi0FFtThBcDVU1hLV8M7aPVm+mCb1pN5FwIrbX+iZVRfbdNudfbpNh2QtGcd3hl4DXAX/bdNO9bZb9vU9pm2Z9ueQ7n0zDdtn8QUbc/t8ts7kp5PObqHcqmJL9teNIUl/SdJlwFHUS6ruh44C7gKWAIcANwPnGB7Sk+idqnzKMpHZgP3Ae8a7pOcKpL+K/Bd4HZ+21/6IUp/ed9s0xHqPJH+26YvoZxYnEY5MFxi+yOSnkt/bdNudX6RPtumwyQdBXygfntnSrbndhn6ERHR2XbZvRMREZ0l9CMiGiShHxHRIAn9iIgGSehHRDRIQj8iokES+hERDfL/AWHbMdetP2zZAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.hist(summary_len)\n",
    "ax.set(title='Facts Length Distribution - First 100 examples', xlabel='Facts length', ylabel='Number of examples')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.138817Z",
     "start_time": "2020-09-01T13:44:39.136486Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Length of text:  17.704538966600058\n"
     ]
    }
   ],
   "source": [
    "# Mean of the per-example lengths collected in text_len.\n",
    "print(\"Average Length of text: \", sum(text_len) / len(text_len))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.143573Z",
     "start_time": "2020-09-01T13:44:39.139797Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Length of Summary:  11.495860690836427\n"
     ]
    }
   ],
   "source": [
    "# Mean of the per-example lengths collected in summary_len.\n",
    "print(\"Average Length of Summary: \", sum(summary_len) / len(summary_len))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.162456Z",
     "start_time": "2020-09-01T13:44:39.144570Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    \"\"\"Seed Python's `random`, NumPy's global RNG and PyTorch with the same value.\"\"\"\n",
    "    # torch.manual_seed seeds the RNGs of all devices (CUDA included), so no separate\n",
    "    # torch.cuda.manual_seed_all call is needed.\n",
    "    for seeder in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seeder(seed)\n",
    "\n",
    "\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.185299Z",
     "start_time": "2020-09-01T13:44:39.163413Z"
    }
   },
   "outputs": [],
   "source": [
    "class T5FineTuner(pl.LightningModule):\n",
    "    \"\"\"LightningModule that fine-tunes a pretrained T5 model on (source text -> target text) pairs.\n",
    "\n",
    "    `hparams` is the argparse.Namespace built from `args_dict`; it supplies the model/tokenizer\n",
    "    names, freezing flags, per-split sample limits, optimizer settings and batch sizes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hparams):\n",
    "        super(T5FineTuner, self).__init__()\n",
    "        self.hparams = hparams\n",
    "        self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path)\n",
    "        self.tokenizer = T5Tokenizer.from_pretrained(hparams.tokenizer_name_or_path)\n",
    "\n",
    "        if self.hparams.freeze_embeds:\n",
    "            self.freeze_embeds()\n",
    "        if self.hparams.freeze_encoder:\n",
    "            encoder = self.model.get_encoder()\n",
    "            self.freeze_params(encoder)\n",
    "            # `assert_all_frozen` is not defined in this notebook (NameError); check inline instead.\n",
    "            assert not any(p.requires_grad for p in encoder.parameters()), \"encoder is not fully frozen\"\n",
    "\n",
    "        # Negative sample counts mean \"use the whole split\" and are stored as None.\n",
    "        n_observations_per_split = {\n",
    "            \"train\": self.hparams.n_train,\n",
    "            \"validation\": self.hparams.n_val,\n",
    "            \"test\": self.hparams.n_test,\n",
    "        }\n",
    "        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}\n",
    "\n",
    "    def freeze_params(self, model):\n",
    "        \"\"\"Disable gradients for every parameter of `model`.\"\"\"\n",
    "        for par in model.parameters():\n",
    "            par.requires_grad = False\n",
    "\n",
    "    def freeze_embeds(self):\n",
    "        \"\"\"Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.\"\"\"\n",
    "        if hasattr(self.model, \"model\"):\n",
    "            # BART-style wrapper: shared embeddings and encoder/decoder live under `self.model.model`.\n",
    "            self.freeze_params(self.model.model.shared)\n",
    "            for d in [self.model.model.encoder, self.model.model.decoder]:\n",
    "                self.freeze_params(d.embed_positions)\n",
    "                self.freeze_params(d.embed_tokens)\n",
    "        else:\n",
    "            self.freeze_params(self.model.shared)\n",
    "            for d in [self.model.encoder, self.model.decoder]:\n",
    "                self.freeze_params(d.embed_tokens)\n",
    "\n",
    "    def lmap(self, f, x):\n",
    "        \"\"\"list(map(f, x))\"\"\"\n",
    "        return list(map(f, x))\n",
    "\n",
    "    def is_logger(self):\n",
    "        \"\"\"True on the main process (global rank <= 0), which is the one that should log.\"\"\"\n",
    "        return self.trainer.global_rank <= 0\n",
    "\n",
    "    def parse_score(self, result):\n",
    "        \"\"\"Map each ROUGE score in `result` to its mid F-measure as a percentage rounded to 4 places.\"\"\"\n",
    "        return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}\n",
    "\n",
    "    def forward(\n",
    "        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None\n",
    "    ):\n",
    "        \"\"\"Run the wrapped T5 model; `lm_labels` is forwarded as the model's `labels` argument.\"\"\"\n",
    "        # transformers deprecated `lm_labels` in favour of `labels`; this method keeps its own\n",
    "        # `lm_labels` keyword so existing callers keep working.\n",
    "        return self.model(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            decoder_input_ids=decoder_input_ids,\n",
    "            decoder_attention_mask=decoder_attention_mask,\n",
    "            labels=lm_labels,\n",
    "        )\n",
    "\n",
    "    def _step(self, batch):\n",
    "        \"\"\"Return the model's loss for `batch`, with padding positions excluded from the labels.\"\"\"\n",
    "        # Clone before masking: writing -100 in place would also overwrite batch[\"target_ids\"],\n",
    "        # which _generative_step decodes into the reference text.\n",
    "        lm_labels = batch[\"target_ids\"].clone()\n",
    "        # -100 is the label value the Hugging Face loss ignores.\n",
    "        lm_labels[lm_labels == self.tokenizer.pad_token_id] = -100\n",
    "\n",
    "        outputs = self(\n",
    "            input_ids=batch[\"source_ids\"],\n",
    "            attention_mask=batch[\"source_mask\"],\n",
    "            lm_labels=lm_labels,\n",
    "            decoder_attention_mask=batch[\"target_mask\"],\n",
    "        )\n",
    "\n",
    "        # The first element of the model output is the loss when labels are given.\n",
    "        loss = outputs[0]\n",
    "        return loss\n",
    "\n",
    "    def ids_to_clean_text(self, generated_ids):\n",
    "        \"\"\"Decode a batch of token ids to stripped strings, dropping special tokens.\"\"\"\n",
    "        gen_text = self.tokenizer.batch_decode(\n",
    "            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True\n",
    "        )\n",
    "        return self.lmap(str.strip, gen_text)\n",
    "\n",
    "    def _generative_step(self, batch):\n",
    "        \"\"\"Generate predictions for `batch`; return val_loss, timing, length and decoded texts.\"\"\"\n",
    "        t0 = time.time()\n",
    "\n",
    "        # No decoder_attention_mask here: the target mask describes the gold sequence, not the\n",
    "        # one being generated, so passing it would leak reference information into decoding.\n",
    "        generated_ids = self.model.generate(\n",
    "            batch[\"source_ids\"],\n",
    "            attention_mask=batch[\"source_mask\"],\n",
    "            use_cache=True,\n",
    "            max_length=150,\n",
    "            num_beams=2,\n",
    "            repetition_penalty=2.5,\n",
    "            length_penalty=1.0,\n",
    "            early_stopping=True,\n",
    "        )\n",
    "        preds = self.ids_to_clean_text(generated_ids)\n",
    "        target = self.ids_to_clean_text(batch[\"target_ids\"])\n",
    "\n",
    "        # Wall-clock generation + decoding time per example.\n",
    "        gen_time = (time.time() - t0) / batch[\"source_ids\"].shape[0]\n",
    "\n",
    "        loss = self._step(batch)\n",
    "        # Mean count of non-pad tokens per generated sequence; len() of each row of the\n",
    "        # padded tensor would report the same padded width for every example.\n",
    "        summ_len = (generated_ids != self.tokenizer.pad_token_id).sum(dim=1).float().mean().item()\n",
    "\n",
    "        base_metrics = {'val_loss': loss}\n",
    "        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target)\n",
    "        return base_metrics\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the training loss and report it under `train_loss`.\"\"\"\n",
    "        loss = self._step(batch)\n",
    "\n",
    "        tensorboard_logs = {\"train_loss\": loss}\n",
    "        return {\"loss\": loss, \"log\": tensorboard_logs}\n",
    "\n",
    "    def training_epoch_end(self, outputs):\n",
    "        \"\"\"Average the per-step training losses of the epoch.\"\"\"\n",
    "        avg_train_loss = torch.stack([x[\"loss\"] for x in outputs]).mean()\n",
    "        tensorboard_logs = {\"avg_train_loss\": avg_train_loss}\n",
    "        return {\"avg_train_loss\": avg_train_loss, \"log\": tensorboard_logs, 'progress_bar': tensorboard_logs}\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        \"\"\"Run generation + loss for one validation batch.\"\"\"\n",
    "        return self._generative_step(batch)\n",
    "\n",
    "    def validation_epoch_end(self, outputs):\n",
    "        \"\"\"Average val_loss over all validation batches and report it as `val_loss`.\"\"\"\n",
    "        avg_loss = torch.stack([x[\"val_loss\"] for x in outputs]).mean()\n",
    "        tensorboard_logs = {\"val_loss\": avg_loss}\n",
    "\n",
    "        ## Clear out the lists for next epoch\n",
    "        self.target_gen = []\n",
    "        self.prediction_gen = []\n",
    "        return {\"avg_val_loss\": avg_loss, \"log\": tensorboard_logs, 'progress_bar': tensorboard_logs}\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Build AdamW with weight decay disabled for biases and layer-norm weights.\n",
    "\n",
    "        The linear warmup/decay scheduler is created in train_dataloader (it needs the dataset\n",
    "        size), which reads this optimizer from `self.opt`.\n",
    "        \"\"\"\n",
    "        model = self.model\n",
    "        # T5 names its norm weights `layer_norm.weight` / `final_layer_norm.weight`, which the\n",
    "        # BERT-style \"LayerNorm.weight\" pattern alone would miss.\n",
    "        no_decay = [\"bias\", \"LayerNorm.weight\", \"layer_norm.weight\"]\n",
    "        optimizer_grouped_parameters = [\n",
    "            {\n",
    "                \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
    "                \"weight_decay\": self.hparams.weight_decay,\n",
    "            },\n",
    "            {\n",
    "                \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n",
    "                \"weight_decay\": 0.0,\n",
    "            },\n",
    "        ]\n",
    "        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n",
    "        self.opt = optimizer\n",
    "        return [optimizer]\n",
    "\n",
    "    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,\n",
    "                       optimizer_closure=None, on_tpu=None, using_lbfgs=None,\n",
    "                       second_order_closure=None, using_native_amp=False):\n",
    "        \"\"\"Step the optimizer, clear gradients and advance the LR scheduler by one step.\"\"\"\n",
    "        optimizer.step(closure=optimizer_closure)\n",
    "        optimizer.zero_grad()\n",
    "        self.lr_scheduler.step()\n",
    "\n",
    "    def get_tqdm_dict(self):\n",
    "        \"\"\"Progress-bar entries: running loss (3 decimals) and the current learning rate.\"\"\"\n",
    "        tqdm_dict = {\"loss\": \"{:.3f}\".format(self.trainer.avg_loss), \"lr\": self.lr_scheduler.get_last_lr()[-1]}\n",
    "\n",
    "        return tqdm_dict\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        \"\"\"Shuffled training DataLoader; also builds `self.lr_scheduler` around `self.opt`.\"\"\"\n",
    "        n_samples = self.n_obs['train']\n",
    "        train_dataset = get_dataset(tokenizer=self.tokenizer, type_path=\"train\", num_samples=n_samples, args=self.hparams)\n",
    "        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size, drop_last=True, shuffle=True, num_workers=0)\n",
    "        # Total optimizer steps = (batches per epoch // accumulation steps) * epochs.\n",
    "        t_total = (\n",
    "            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))\n",
    "            // self.hparams.gradient_accumulation_steps\n",
    "            * float(self.hparams.num_train_epochs)\n",
    "        )\n",
    "        scheduler = get_linear_schedule_with_warmup(\n",
    "            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total\n",
    "        )\n",
    "        self.lr_scheduler = scheduler\n",
    "        return dataloader\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        \"\"\"Validation DataLoader (no shuffling).\"\"\"\n",
    "        n_samples = self.n_obs['validation']\n",
    "        validation_dataset = get_dataset(tokenizer=self.tokenizer, type_path=\"validation\", num_samples=n_samples, args=self.hparams)\n",
    "\n",
    "        return DataLoader(validation_dataset, batch_size=self.hparams.eval_batch_size, num_workers=0)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        \"\"\"Test DataLoader (no shuffling).\"\"\"\n",
    "        n_samples = self.n_obs['test']\n",
    "        test_dataset = get_dataset(tokenizer=self.tokenizer, type_path=\"test\", num_samples=n_samples, args=self.hparams)\n",
    "\n",
    "        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.192404Z",
     "start_time": "2020-09-01T13:44:39.186281Z"
    }
   },
   "outputs": [],
   "source": [
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "\n",
    "class LoggingCallback(pl.Callback):\n",
    "    \"\"\"Log validation/test metrics on the main process and save test metrics to a text file.\"\"\"\n",
    "\n",
    "    # NOTE(review): these are INFO-level records; unless logging is configured elsewhere\n",
    "    # (e.g. logging.basicConfig(level=logging.INFO)) they will not be displayed.\n",
    "\n",
    "    def on_validation_end(self, trainer, pl_module):\n",
    "        logger.info(\"***** Validation results *****\")\n",
    "        if pl_module.is_logger():\n",
    "            metrics = trainer.callback_metrics\n",
    "            # Log every metric except the \"log\"/\"progress_bar\" container keys.\n",
    "            for key in sorted(metrics):\n",
    "                if key not in [\"log\", \"progress_bar\"]:\n",
    "                    logger.info(\"{} = {}\\n\".format(key, str(metrics[key])))\n",
    "\n",
    "    def on_test_end(self, trainer, pl_module):\n",
    "        logger.info(\"***** Test results *****\")\n",
    "\n",
    "        if pl_module.is_logger():\n",
    "            metrics = trainer.callback_metrics\n",
    "            output_dir = pl_module.hparams.output_dir\n",
    "            # Make sure the directory exists so open() below cannot fail with FileNotFoundError.\n",
    "            if output_dir:\n",
    "                os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "            # Log and save results to file\n",
    "            output_test_results_file = os.path.join(output_dir, \"test_results.txt\")\n",
    "            with open(output_test_results_file, \"w\") as writer:\n",
    "                for key in sorted(metrics):\n",
    "                    if key not in [\"log\", \"progress_bar\"]:\n",
    "                        logger.info(\"{} = {}\\n\".format(key, str(metrics[key])))\n",
    "                        writer.write(\"{} = {}\\n\".format(key, str(metrics[key])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a Dataset class for the DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.203467Z",
     "start_time": "2020-09-01T13:44:39.193447Z"
    }
   },
   "outputs": [],
   "source": [
    "class wiki_nexkb(Dataset):\n",
    "    \"\"\"Map-style Dataset of tokenized (text -> facts) pairs taken from DATASET[type_path].\n",
    "\n",
    "    Each item is a dict of 1-D tensors: source_ids/source_mask (padded/truncated to\n",
    "    `input_length`) and target_ids/target_mask (padded/truncated to `output_length`).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):\n",
    "        self.dataset = DATASET[type_path]\n",
    "        # A falsy num_samples (None or 0) keeps the whole split.\n",
    "        if num_samples:\n",
    "            self.dataset = self.dataset[:num_samples]\n",
    "        self.input_length = input_length\n",
    "        self.tokenizer = tokenizer\n",
    "        self.output_length = output_length\n",
    "        self.print_text = print_text\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def clean_text(self, text):\n",
    "        \"\"\"Turn newlines into spaces and drop `` and double-quote characters.\"\"\"\n",
    "        # Use a space, not '': otherwise the words on either side of a newline get glued together.\n",
    "        text = text.replace('\\n', ' ')\n",
    "        text = text.replace('``', '')\n",
    "        text = text.replace('\"', '')\n",
    "        return text\n",
    "\n",
    "    def convert_to_features(self, example_batch):\n",
    "        \"\"\"Tokenize one example's 'text' (source) and 'facts' (target) into padded tensors.\"\"\"\n",
    "        if self.print_text:\n",
    "            print(\"Input Text: \", self.clean_text(example_batch['text']))\n",
    "\n",
    "        input_ = self.clean_text(example_batch['text'])\n",
    "        target_ = self.clean_text(example_batch['facts'])\n",
    "\n",
    "        # Source and target are encoded separately, each padded/truncated to its own max length.\n",
    "        source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length,\n",
    "                                                  padding='max_length', truncation=True, return_tensors=\"pt\")\n",
    "        targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length,\n",
    "                                                   padding='max_length', truncation=True, return_tensors=\"pt\")\n",
    "        return source, targets\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        source, targets = self.convert_to_features(self.dataset[index])\n",
    "\n",
    "        # squeeze(0) drops only the batch-of-one axis; a bare squeeze() would also collapse a\n",
    "        # length-1 sequence axis into a 0-d tensor.\n",
    "        source_ids = source[\"input_ids\"].squeeze(0)\n",
    "        target_ids = targets[\"input_ids\"].squeeze(0)\n",
    "\n",
    "        src_mask = source[\"attention_mask\"].squeeze(0)\n",
    "        target_mask = targets[\"attention_mask\"].squeeze(0)\n",
    "\n",
    "        return {\"source_ids\": source_ids, \"source_mask\": src_mask, \"target_ids\": target_ids, \"target_mask\": target_mask}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test the dataset class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T18:55:41.647100Z",
     "start_time": "2020-09-02T18:55:41.023703Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "350"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the validation split with 64-token source/target limits, echoing each input text on access.\n",
    "tokenizer = T5Tokenizer.from_pretrained('t5-small')\n",
    "dataset = wiki_nexkb(tokenizer, type_path='validation', num_samples=None,\n",
    "                     input_length=64, output_length=64, print_text=True)\n",
    "len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T18:55:43.381341Z",
     "start_time": "2020-09-02T18:55:43.366366Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input Text:  Norway is a country in the north of Europe.\n",
      "\n",
      "Shape of Tokenized Text:  torch.Size([64])\n",
      "\n",
      "Sanity check - Decode Text:  Norway is a country in the north of Europe.\n",
      "====================================\n",
      "Sanity check - Decode Summary:  geographicalsubregionsofcontinent [S] continent of europe [S] territory fn [S] norway\n"
     ]
    }
   ],
   "source": [
    "# Round-trip one example: show the tokenized shape, then decode source and target ids back to text.\n",
    "data = dataset[50]\n",
    "\n",
    "print()\n",
    "print(\"Shape of Tokenized Text: \", data['source_ids'].shape)\n",
    "print()\n",
    "\n",
    "print(\"Sanity check - Decode Text: \", tokenizer.decode(data['source_ids']))\n",
    "print(\"====================================\")\n",
    "print(\"Sanity check - Decode Summary: \", tokenizer.decode(data['target_ids']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.830252Z",
     "start_time": "2020-09-01T13:44:39.824446Z"
    }
   },
   "outputs": [],
   "source": [
    "# Default hyper-parameters; experiment-specific overrides are applied in the next cell.\n",
    "args_dict = dict(\n",
    "    # model / data\n",
    "    output_dir=\"\",  # path to save the checkpoints\n",
    "    model_name_or_path='t5-small',\n",
    "    tokenizer_name_or_path='t5-small',\n",
    "    max_input_length=64,\n",
    "    max_output_length=64,\n",
    "    # optimisation\n",
    "    freeze_encoder=False,\n",
    "    freeze_embeds=False,\n",
    "    learning_rate=3e-4,\n",
    "    weight_decay=0.0,\n",
    "    adam_epsilon=1e-8,\n",
    "    warmup_steps=0,\n",
    "    train_batch_size=4,\n",
    "    eval_batch_size=4,\n",
    "    num_train_epochs=2,\n",
    "    gradient_accumulation_steps=1,\n",
    "    # hardware / trainer\n",
    "    n_gpu=0,  # CPU only: CUDA_VISIBLE_DEVICES is cleared in the imports cell\n",
    "    resume_from_checkpoint=None,\n",
    "    val_check_interval=0.05,\n",
    "    # per-split example limits; -1 means use the whole split\n",
    "    n_val=-1,\n",
    "    n_train=-1,\n",
    "    n_test=-1,\n",
    "    fp_16=False,  # if you want to enable 16-bit training then install apex and set this to true\n",
    "    opt_level='O1',  # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties\n",
    "    max_grad_norm=1.0,  # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default\n",
    "    seed=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:28:40.222324Z",
     "start_time": "2020-09-01T13:28:40.103545Z"
    }
   },
   "source": [
    "!mkdir -p t5_relation_extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.836856Z",
     "start_time": "2020-09-01T13:44:39.831154Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'output_dir': 't5_relation_extraction', 'model_name_or_path': 't5-small', 'tokenizer_name_or_path': 't5-small', 'max_input_length': 64, 'max_output_length': 64, 'freeze_encoder': False, 'freeze_embeds': False, 'learning_rate': 0.0003, 'weight_decay': 0.0, 'adam_epsilon': 1e-08, 'warmup_steps': 0, 'train_batch_size': 4, 'eval_batch_size': 4, 'num_train_epochs': 2, 'gradient_accumulation_steps': 1, 'n_gpu': 0, 'resume_from_checkpoint': None, 'val_check_interval': 0.05, 'n_val': -1, 'n_train': -1, 'n_test': -1, 'fp_16': False, 'opt_level': 'O1', 'max_grad_norm': 1.0, 'seed': 42}\n"
     ]
    }
   ],
   "source": [
    "# Experiment-specific overrides of the defaults above, then wrap everything in a Namespace.\n",
    "args_dict.update(\n",
    "    output_dir='t5_relation_extraction',\n",
    "    num_train_epochs=2,\n",
    "    train_batch_size=4,\n",
    "    eval_batch_size=4,\n",
    ")\n",
    "args = argparse.Namespace(**args_dict)\n",
    "print(args_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.843602Z",
     "start_time": "2020-09-01T13:44:39.837733Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'accumulate_grad_batches': 1, 'gpus': 0, 'max_epochs': 2, 'precision': 32, 'amp_level': 'O1', 'resume_from_checkpoint': None, 'gradient_clip_val': 1.0, 'checkpoint_callback': <pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x000002A7E1D13438>, 'val_check_interval': 0.05, 'callbacks': [<__main__.LoggingCallback object at 0x000002A7E1D135C0>]}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\danil\\anaconda3\\envs\\qrg\\lib\\site-packages\\pytorch_lightning\\utilities\\distributed.py:50: UserWarning: Checkpoint directory C:\\Users\\danil\\Documents\\Northwestern\\Research\\projects\\relation_extraction exists and is not empty.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    }
   ],
   "source": [
    "## Define Checkpoint function\n",
    "# Create output_dir first: in pytorch_lightning 0.9/1.0 a `filepath` that is not an existing\n",
    "# directory is split into (directory, file prefix), which sends the checkpoints to the current\n",
    "# working directory instead of output_dir.\n",
    "if args.output_dir:\n",
    "    os.makedirs(args.output_dir, exist_ok=True)\n",
    "checkpoint_callback = pl.callbacks.ModelCheckpoint(\n",
    "    filepath=args.output_dir, prefix=\"checkpoint\", monitor=\"val_loss\", mode=\"min\", save_top_k=3\n",
    ")\n",
    "\n",
    "## If resuming from checkpoint, add an arg resume_from_checkpoint\n",
    "train_params = dict(\n",
    "    accumulate_grad_batches=args.gradient_accumulation_steps,\n",
    "    gpus=args.n_gpu,\n",
    "    max_epochs=args.num_train_epochs,\n",
    "    precision=16 if args.fp_16 else 32,\n",
    "    amp_level=args.opt_level,\n",
    "    resume_from_checkpoint=args.resume_from_checkpoint,\n",
    "    gradient_clip_val=args.max_grad_norm,\n",
    "    checkpoint_callback=checkpoint_callback,\n",
    "    val_check_interval=args.val_check_interval,\n",
    "    # logger=wandb_logger,  # uncomment to log to Weights & Biases (requires a WandbLogger)\n",
    "    callbacks=[LoggingCallback()],\n",
    ")\n",
    "print(train_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:39.853037Z",
     "start_time": "2020-09-01T13:44:39.844578Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_dataset(tokenizer, type_path, num_samples, args):\n",
    "    \"\"\"Build the wiki_nexkb split `type_path`, sized by the max input/output lengths in `args`.\"\"\"\n",
    "    return wiki_nexkb(\n",
    "        tokenizer=tokenizer,\n",
    "        type_path=type_path,\n",
    "        num_samples=num_samples,\n",
    "        input_length=args.max_input_length,\n",
    "        output_length=args.max_output_length,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:42.048444Z",
     "start_time": "2020-09-01T13:44:39.854009Z"
    }
   },
   "outputs": [],
   "source": [
    "# Loads the pretrained model and tokenizer named in args.\n",
    "model = T5FineTuner(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T13:44:42.072260Z",
     "start_time": "2020-09-01T13:44:42.053193Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: False\n",
      "TPU available: None, using: 0 TPU cores\n",
      "C:\\Users\\danil\\anaconda3\\envs\\qrg\\lib\\site-packages\\pytorch_lightning\\utilities\\distributed.py:50: UserWarning: GPU available but not used. Set the --gpus flag when calling the script.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    }
   ],
   "source": [
    "# Trainer configured by train_params (checkpointing, gradient clipping, logging callback).\n",
    "trainer = pl.Trainer(**train_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T20:25:25.960117Z",
     "start_time": "2020-09-01T13:44:42.076449Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  | Name  | Type                       | Params\n",
      "-----------------------------------------------------\n",
      "0 | model | T5ForConditionalGeneration | 60.5 M\n",
      "-----------------------------------------------------\n",
      "60.5 M    Trainable params\n",
      "0         Non-trainable params\n",
      "60.5 M    Total params\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\danil\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\modeling_t5.py:1156: FutureWarning: The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.\n",
      "  FutureWarning,\n",
      "C:\\Users\\danil\\anaconda3\\envs\\qrg\\lib\\site-packages\\pytorch_lightning\\utilities\\distributed.py:50: UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule\n",
      "  warnings.warn(*args, **kwargs)\n",
      "C:\\Users\\danil\\anaconda3\\envs\\qrg\\lib\\site-packages\\pytorch_lightning\\utilities\\distributed.py:50: UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0\n",
      "Please use self.log(...) inside the lightningModule instead.\n",
      "\n",
      "# log on a step or aggregate epoch metric to the logger and/or progress bar\n",
      "# (inside LightningModule)\n",
      "self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)\n",
      "  warnings.warn(*args, **kwargs)\n",
      "C:\\Users\\danil\\anaconda3\\envs\\qrg\\lib\\site-packages\\pytorch_lightning\\utilities\\distributed.py:50: UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0\n",
      "Please use self.log(...) inside the lightningModule instead.\n",
      "\n",
      "# log on a step or aggregate epoch metric to the logger and/or progress bar\n",
      "# (inside LightningModule)\n",
      "self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df68ce9cd6b1487f874fe1dac97e55db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The trailing ';' keeps Jupyter from echoing fit()'s return value (otherwise shown as a bare `1`).\n",
    "trainer.fit(model);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save relative to the working directory instead of a machine-specific absolute Windows path,\n",
    "# so the notebook runs unchanged on other machines. Change MODEL_OUTPUT_DIR to save elsewhere.\n",
    "MODEL_OUTPUT_DIR = \"t5_model_1\"\n",
    "# Some transformers versions require the target directory to already exist.\n",
    "os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)\n",
    "model.model.save_pretrained(MODEL_OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check Model Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T00:59:41.863404Z",
     "start_time": "2020-09-02T00:59:41.859895Z"
    }
   },
   "outputs": [],
   "source": [
    "import textwrap\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T01:00:09.047610Z",
     "start_time": "2020-09-02T01:00:08.417634Z"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): this should be the same tokenizer the model was fine-tuned with -- confirm 't5-base'.\n",
    "tokenizer = T5Tokenizer.from_pretrained('t5-base')\n",
    "\n",
    "# Source / target token lengths: wiki_nexkb's 4th and 5th positional arguments\n",
    "# (input_length, output_length). NOTE(review): the None and trailing False arguments are not\n",
    "# documented here; the last one appears to be print_text -- confirm against wiki_nexkb.\n",
    "TEST_INPUT_LEN, TEST_OUTPUT_LEN = 64, 64\n",
    "dataset = wiki_nexkb(tokenizer, 'test', None, TEST_INPUT_LEN, TEST_OUTPUT_LEN, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T01:00:29.151446Z",
     "start_time": "2020-09-02T01:00:29.143943Z"
    }
   },
   "outputs": [],
   "source": [
    "# Unshuffled, so the same examples appear in the same order every time this preview is re-run.\n",
    "PREVIEW_BATCH_SIZE = 6\n",
    "loader = DataLoader(dataset, shuffle=False, batch_size=PREVIEW_BATCH_SIZE)\n",
    "it = iter(loader)  # stepped one batch at a time with next(it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T01:00:34.602089Z",
     "start_time": "2020-09-02T01:00:34.443983Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 64])"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fetch one batch from the preview loader and show the shape of its encoder input ids.\n",
    "batch = next(it)\n",
    "batch[\"source_ids\"].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T01:02:59.809227Z",
     "start_time": "2020-09-02T01:02:51.326688Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['They make decisions, organize the people who work there and make sure that things are working alright and there are no problems.', 'All arthropods have jointed feet, a segmented body, and an exoskeleton, a shell outside of the body.', 'There are also many small glands in the tongue, cheeks, lips and palate.', 'It is the way people see and understand something.', 'Squid, like cuttlefish, have eight arms and two tentacles arranged in pairs.', 'In the 1800s, the machine gun was invented, which could shoot many bullets very fast.']\n",
      "['', '', '', '', '', '']\n",
      "['covering [S] situation [S] the covering [S] event [S] static situation', 'relationallexists [S] object inserted [S] body parts used', 'properphysicalparttypes [S] tongue [S] cheek', 'covering [S] situation [S] the covering [S] event [S] static situation', 'coextensional [S] list [S] list of type fn [S] thing', 'covering [S] situation [S] the covering [S] event [S] static situation']\n"
     ]
    }
   ],
   "source": [
    "# Run on whichever device the model's weights live on (CPU or GPU) instead of\n",
    "# toggling commented-out .cuda() calls by hand.\n",
    "device = next(model.model.parameters()).device\n",
    "\n",
    "outs = model.model.generate(\n",
    "    batch[\"source_ids\"].to(device),\n",
    "    attention_mask=batch[\"source_mask\"].to(device),\n",
    "    # The gold 'target_mask' is intentionally NOT passed as decoder_attention_mask: during\n",
    "    # generation the decoder builds its own output sequence step by step, so a fixed-length\n",
    "    # mask taken from the reference targets does not describe the generated tokens.\n",
    "    use_cache=True,\n",
    "    max_length=150,\n",
    "    num_beams=2,\n",
    "    repetition_penalty=2.5,\n",
    "    length_penalty=1.0,\n",
    "    early_stopping=True,\n",
    ")\n",
    "\n",
    "\n",
    "def decode_batch(id_rows):\n",
    "    \"\"\"Decode each row of token ids to a string, dropping special tokens (e.g. padding, end-of-sequence).\"\"\"\n",
    "    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in id_rows]\n",
    "\n",
    "\n",
    "dec = decode_batch(outs)\n",
    "texts = decode_batch(batch[\"source_ids\"])\n",
    "targets = decode_batch(batch[\"target_ids\"])\n",
    "\n",
    "print(texts)\n",
    "print(targets)\n",
    "print(dec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-02T01:03:57.193269Z",
     "start_time": "2020-09-02T01:03:57.138546Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence: They make decisions, organize the people who work there and make sure that things are\n",
      "working alright and there are no problems.\n",
      "\n",
      "Actual facts: \n",
      "\n",
      "Predicted facts: covering [S] situation [S] the covering [S] event [S] static situation\n",
      "=====================================================================\n",
      "\n",
      "Sentence: All arthropods have jointed feet, a segmented body, and an exoskeleton, a shell outside of\n",
      "the body.\n",
      "\n",
      "Actual facts: \n",
      "\n",
      "Predicted facts: relationallexists [S] object inserted [S] body parts used\n",
      "=====================================================================\n",
      "\n",
      "Sentence: There are also many small glands in the tongue, cheeks, lips and palate.\n",
      "\n",
      "Actual facts: \n",
      "\n",
      "Predicted facts: properphysicalparttypes [S] tongue [S] cheek\n",
      "=====================================================================\n",
      "\n",
      "Sentence: It is the way people see and understand something.\n",
      "\n",
      "Actual facts: \n",
      "\n",
      "Predicted facts: covering [S] situation [S] the covering [S] event [S] static situation\n",
      "=====================================================================\n",
      "\n",
      "Sentence: Squid, like cuttlefish, have eight arms and two tentacles arranged in pairs.\n",
      "\n",
      "Actual facts: \n",
      "\n",
      "Predicted facts: coextensional [S] list [S] list of type fn [S] thing\n",
      "=====================================================================\n",
      "\n",
      "Sentence: In the 1800s, the machine gun was invented, which could shoot many bullets very fast.\n",
      "\n",
      "Actual facts: \n",
      "\n",
      "Predicted facts: covering [S] situation [S] the covering [S] event [S] static situation\n",
      "=====================================================================\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# zip() follows the actual batch size rather than a hard-coded range(6),\n",
    "# so a shorter batch cannot raise IndexError.\n",
    "for text, target, pred in zip(texts, targets, dec):\n",
    "    # textwrap.wrap() replaces embedded newlines with spaces, so the header is printed\n",
    "    # on its own line and only the sentence itself is wrapped.\n",
    "    print(\"Sentence:\")\n",
    "    print(\"\\n\".join(textwrap.wrap(text, width=100)))\n",
    "    print(\"\\nActual facts: %s\" % target)\n",
    "    print(\"\\nPredicted facts: %s\" % pred)\n",
    "    print(\"=====================================================================\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing on individual sentences from Simple Wikipedia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample sentences (apparently from Simple Wikipedia) used for a qualitative check of the model.\n",
    "# They have no reference facts, so every example gets an empty 'facts' string.\n",
    "SIMPLE_WIKI_SENTENCES = (\n",
    "    \"Toes are the \\\"digits\\\" of the foot of an animal.\",\n",
    "    \"After a tree falls, the wood in it can be cut into long, straight pieces called lumber.\",\n",
    "    \"They may feel they do not have the power to stop the stalker.\",\n",
    "    \"The state is also bordered by the North Sea to the east, the Atlantic Ocean to the west and also the Irish Sea.\",\n",
    "    \"Alexander Fleming discovered penicillin, one of the best-known antibiotics.\",\n",
    "    \"it might keep them safe when operating machinery or keep them clean when doing their work.\",\n",
    "    \"It is used as a surface to write on.\",\n",
    "    \"In the last years of their riegn, the Hafsids became weak and Spain took control of many city coasts until they were finally occupied by the Ottaman Empire.\",\n",
    "    \"He had some tobacco seeds sent to Paris for analysis (As a drug).\",\n",
    "    \"The card is a small, flat paper or plastic object with text and figures.\",\n",
    "    \"The leaves have sharp edges, and are often used to decorate a house on Christmas Day.\",\n",
    "    \"Carpets can be many different sizes.\",\n",
    "    \"In geometry, the radius of a circle or sphere is the shortest connection between the center and the boundary.\",\n",
    "    \"They live in the countryside, but also in cities too.\",\n",
    "    \"The head of state is the \\\"Bundespräsident\\\" (Federal President).\",\n",
    "    \"It looks to see if a certain problem can be solved by a computer.\",\n",
    "    \"In January 2001, producer Joel Silver asked Todd Alcott to write a \\\"Wonder Woman\\\" screenplay.\",\n",
    "    \"There was a slightly different Emu species that lived in Tasmania.\",\n",
    "    \"The nucleus is protected by the nuclear envelope,and lets things out through the nuclear pores.\",\n",
    "    \"The difference is not the population, it is how much power the different sorts of places have, and what they do for people living there.\",\n",
    "    \"Statements are made when one remarks with critique, explaining ones opinion.\",\n",
    "    \"Following a 1969 military coup, Col. Muammar Abu Minyar al-Qadhafi began to espouse his political system named, \\\"The Third Universal Theory\\\".\",\n",
    "    \"Dragon's blood is a resin used in dyes, varnishes and incense, can come from the fruit of the rattan.\",\n",
    "    \"Much boron is found in chemical compounds in its ore borax.\",\n",
    "    \"An adverb can also modify (describe) an adjective or another adverb.\",\n",
    "    \"In many languages, for example, the word “put” will be different according to whether something is being put onto something (e.g.\",\n",
    "    \"Inside the Earth is similar to the other terrestrial planets.\",\n",
    "    \"Many sharks are now endangered, but some are still hunted for food (like shark fin soup) or sport fishing.\",\n",
    "    \"Other larger cities in Manitoba include Steinbach and Brandon.\",\n",
    "    \"Such large mammals as pigs, bears, and deer also consume large amounts of acorns: they may constitute up to 25% of the diet of deer in the autumn.\",\n",
    "    \"The \\\"Rule of Law\\\" is the law that says that Government can only legally use its power in the way the government and the people agree.\",\n",
    "    \"Scientists have a good idea of what they looked like, because of the bones that have been found.\",\n",
    "    \"Magic can mean many things.\",\n",
    "    \"Heaven is a concept of the afterlife (what happens after you die) in many religions.\",\n",
    "    \"They can be of many different sizes: small enough to hold in one hand, or large enough to fire shells to sink a warship.\",\n",
    "    \"The three angles of a triangle add to 180 degrees.\",\n",
    "    \"The only people allowed in the operating rooms are the doctors and nurses as well as the patient too.\",\n",
    "    \"Sugar can be one of many compounds.\",\n",
    "    \"Once the snake's prey is in its body, its internal muscles crush it so the animal can be digested.\",\n",
    "    \"Examples of emergencies include serious broken bones, chest pain, serious head injuries, and people injured in situations like car crashes.\",\n",
    "    \"The orchestra members do not have to dress up in ties and jackets because they cannot be seen anyway.\",\n",
    "    \"In the Netherlands, the marriage law was changed to allow such unions, called \\\"Marriage\\\".\",\n",
    "    \"During pregnancy, the endometrium develops a lot of glands and blood vessels.\",\n",
    "    \"Milk powder is often used in countries that lack widespread access to refrigeration.\",\n",
    "    \"A molecule is the smallest amount of a chemical substance that can exist.\",\n",
    "    \"This form of title is known as the Order of the British Empire.\",\n",
    "    \"Pullman is on the east side of the state.\",\n",
    "    \"It can be a photograph, a painting, or a picture on a television or computer screen.\",\n",
    "    \"It is even hard to say if \\\"are\\\" means anything if we can never see or measure them.\",\n",
    "    \"In contrast, stench, reek, and stink are used specifically to describe unpleasant odors.\",\n",
    "    \"Its largest cities are Kansas City and Saint Louis.\",\n",
    "    \"There is only one language which is officially known as \\\"La Internacia Lingvo\\\" - The Inter-national Language - and that language is Esperanto.\",\n",
    "    \"The original \\\"CQR\\\" was invented in 1933 in the United Kingdom.\",\n",
    "    \"Albatross, seagulls and kingfishers all have long strong beaks for catching fish.\",\n",
    "    \"Other neighbouring countries are Australia to the south, Singapore to the north-west, and Phillipines to the north-east.\",\n",
    "    \"Their nose on top of the head to be easy to breathe on the surface of the water.\",\n",
    "    \"A blossom is a flower that grows on stone fruit trees and other plants including oranges, cherries, plums, apples and almonds\",\n",
    "    \"Black holes have also been found in the middle of every major galaxy in the universe.\",\n",
    "    \"others are like plastic worms and rat-l-traps.\",\n",
    "    \"This was to try to become the most influential country in the EU.\",\n",
    "    \"The first team to win four games becomes the baseball world champions for the year.\",\n",
    "    \"Candy is a sweet kind of food that is usually made from sugar and water, with flavors and other ingredients added.\",\n",
    "    \"Eucalypts began between 35 and 50 million years ago, not long after Australia and New Guinea separated from Gondwana.\",\n",
    "    \"The largest city on the east part is Spokane it is also the second biggest city in the state.\",\n",
    "    \"Usually 4-6 young are born, after a gestation period (pregnancy) of around 39 days.\",\n",
    "    \"The term \\\"the Wall\\\" usually referred to the Berlin Wall, built during the Cold War, which fell in 1989.\",\n",
    "    \"Moist winds blow in from the sea and are forced to rise over the land.\",\n",
    "    \"This empire did not make Spain a rich country, for most of the money had to be spent in wars.\",\n",
    "    \"Hamsters store food in the sides of their mouths.\",\n",
    "    \"They do this by electrical impulses and by some chemical substances called transmitters.\",\n",
    "    \"Among other things, the Dutch introduced sugar cane and the Java deer before leaving in 1710.\",\n",
    "    \"This is good because, if they did not eat all the dead things, there would be dead plants and animals everywhere.\",\n",
    "    \"A piston-operated steam engine is called \\\"reciprocal\\\" (back-and-forth) engine.\",\n",
    "    \"Chronic means the pain lasts a long time.\",\n",
    "    \"A soldier does many things, from shooting enemies, to using binoculars to find out where enemies are.\",\n",
    "    \"The word cowboy comes from the Spanish word vaquero.\",\n",
    "    \"The resulting solution is called a saturated solution.\",\n",
    "    \"Some matter is changed into energy in an atomic bomb.\",\n",
    "    \"Many flowers are dependent upon the wind to move pollen between flowers of the same species.\",\n",
    "    \"Jewelry can come in many forms, worn on any part of the body or clothing.\",\n",
    "    \"A vein has a large lumen, and less pressure as there is smaller amounts of smooth muscle and elastic fibres.\",\n",
    "    \"An office is a room where people work.\",\n",
    "    \"These radio astronomers use giant Radio telescopes, shaped like satellite dishes, to gather and analyze the waves.\",\n",
    "    \"A telephone is an electric tool.\",\n",
    "    \"The amount of glucose in the blood goes up.\",\n",
    "    \"Brown is the most common colour.\",\n",
    "    \"The harder metal is usually copper.\",\n",
    "    \"They are called \\\"leptocephalus\\\" (Greek for \\\"thin head\\\").\",\n",
    "    \"Plants need sunlight for the process of photosynthesis.\",\n",
    "    \"In American comic books, many different characters live in the same world.\",\n",
    "    \"Disco clubs have a large dance floor and a large pa system.\",\n",
    ")\n",
    "\n",
    "\n",
    "class simple_wiki_text(wiki_nexkb):\n",
    "    \"\"\"Dataset of raw sentences that reuses wiki_nexkb for everything except the example list.\n",
    "\n",
    "    Args:\n",
    "        tokenizer: forwarded to wiki_nexkb and stored as self.tokenizer.\n",
    "        input_length: forwarded to wiki_nexkb and stored as self.input_length.\n",
    "        output_length: forwarded to wiki_nexkb and stored as self.output_length.\n",
    "        sentences: optional iterable of sentences to use instead of SIMPLE_WIKI_SENTENCES.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tokenizer, input_length, output_length, sentences=None):\n",
    "        # NOTE(review): the parent is built on the \"test\" split only to have its examples replaced\n",
    "        # below; kept so whatever attributes wiki_nexkb.__init__ sets remain available.\n",
    "        super().__init__(tokenizer, \"test\", None, input_length, output_length)\n",
    "        if sentences is None:\n",
    "            sentences = SIMPLE_WIKI_SENTENCES\n",
    "        # Fresh dicts per instance, so mutating one dataset never affects another.\n",
    "        self.dataset = [{'text': text, 'facts': ''} for text in sentences]\n",
    "        self.input_length = input_length\n",
    "        self.tokenizer = tokenizer\n",
    "        self.output_length = output_length\n",
    "        self.print_text = False\n",
    "\n",
    "\n",
    "dataset = simple_wiki_text(tokenizer, 64, 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence:\n",
      "Toes are the digits of the foot of an animal.\n",
      "\n",
      "Predicted facts: covering [S] animal [S] situation [S] the covering [S] event [S] static situation\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "After a tree falls, the wood in it can be cut into long, straight pieces called lumber.\n",
      "\n",
      "Predicted facts: relationallexists [S] object inserted [S] wood [S] done by [S] husbandry of plant [S] plant\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "They may feel they do not have the power to stop the stalker.\n",
      "\n",
      "Predicted facts: agenttypeperformsworkoftype [S] stalker [S] stalking\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "The state is also bordered by the North Sea to the east, the Atlantic Ocean to the west and also the Irish Sea.\n",
      "\n",
      "Predicted facts: oppositedirectioninterval [S] north generally [S] west generally\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "it might keep them safe when operating machinery or keep them clean when doing their work.\n",
      "\n",
      "Predicted facts: agenttypeperformsworkoftype [S] operating equipment [S] machinery\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "It is used as a surface to write on.\n",
      "\n",
      "Predicted facts: typeprimaryfunction [S] writing device [S] writing\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "He had some tobacco seeds sent to Paris for analysis (As a drug).\n",
      "\n",
      "Predicted facts: covering [S] situation [S] the covering [S] session [S] 9\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "The card is a small, flat paper or plastic object with text and figures.\n",
      "\n",
      "Predicted facts: typeprimaryfunction [S] printing device [S] manufacturing event occurs at\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "In geometry, the radius of a circle or sphere is the shortest connection between the center and the boundary.\n",
      "\n",
      "Predicted facts: relationallexistscount [S] object inserted [S] shortest possible distance [S] center [S] rectangle [S] oppositedirectioninterval [N] relationmostinstancetype [S] width of object used\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "The head of state is the Bundespräsident (Federal President).\n",
      "\n",
      "Predicted facts: agenttypeperformsworkoftype [S] president [S] federal state [S] head of state\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "In January 2001, producer Joel Silver asked Todd Alcott to write a Wonder Woman screenplay.\n",
      "\n",
      "Predicted facts: agenttypeperformsworkoftype [S] producer fn [S] actor\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "An adverb can also modify (describe) an adjective or another adverb.\n",
      "\n",
      "Predicted facts: relationallinstance [S] adverb [S] verbal [S] adverb\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "In many languages, for example, the word “put” will be different according to whether something is being put onto something (e.g.\n",
      "\n",
      "Predicted facts: typeprimaryfunction [S] computer [S] operating system\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Inside the Earth is similar to the other terrestrial planets.\n",
      "\n",
      "Predicted facts: coextensional [S] list [S] list of type fn [S] thing\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Many sharks are now endangered, but some are still hunted for food (like shark fin soup) or sport fishing.\n",
      "\n",
      "Predicted facts: typeprimaryfunction [S] hunting [S] shark [S] food\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Other larger cities in Manitoba include Steinbach and Brandon.\n",
      "\n",
      "Predicted facts: oppositedirectioninterval [S] city of texas west [S] city of texas state\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Such large mammals as pigs, bears, and deer also consume large amounts of acorns: they may constitute up to 25% of the diet of deer in the autumn.\n",
      "\n",
      "Predicted facts: relationallexistscount [S] a human type [S] animal [S] meat\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Magic can mean many things.\n",
      "\n",
      "Predicted facts: covering [S] situation [S] the covering [S] event [S]\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Sugar can be one of many compounds.\n",
      "\n",
      "Predicted facts: atomiccomposition [S] sugar [S] sugar\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Once the snake's prey is in its body, its internal muscles crush it so the animal can be digested.\n",
      "\n",
      "Predicted facts: relationallexists [S] object taken care of [S] husbandry of animal [S] snake [S] animal\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "During pregnancy, the endometrium develops a lot of glands and blood vessels.\n",
      "\n",
      "Predicted facts: relationallinstance [S] endometrium [S] gland [S] blood fn [S] blood\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Pullman is on the east side of the state.\n",
      "\n",
      "Predicted facts: oppositedirectioninterval [S] east generally [S] west generally\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "It can be a photograph, a painting, or a picture on a television or computer screen.\n",
      "\n",
      "Predicted facts: relationallexists [S] object taken care of [S] husbandry of object [S] computer\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "In contrast, stench, reek, and stink are used specifically to describe unpleasant odors.\n",
      "\n",
      "Predicted facts: relationallexists [S] object taken care of [S] husbandry of plant [S] pet [S] neighbor [N] type fn [S] stinking product [S] smelled\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Its largest cities are Kansas City and Saint Louis.\n",
      "\n",
      "Predicted facts: geographicsubregionsofcontinent [S] continent of america [S] united states of america\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Albatross, seagulls and kingfishers all have long strong beaks for catching fish.\n",
      "\n",
      "Predicted facts: relationallexists [S] object taken care of [S] fishing [S] fish\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Other neighbouring countries are Australia to the south, Singapore to the north-west, and Phillipines to the north-east.\n",
      "\n",
      "Predicted facts: oppositedirectioninterval [S] north generally [S] south generally\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Their nose on top of the head to be easy to breathe on the surface of the water.\n",
      "\n",
      "Predicted facts: properphysicalparttypes [S] nose [S] nose\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "A blossom is a flower that grows on stone fruit trees and other plants including oranges, cherries, plums, apples and almonds\n",
      "\n",
      "Predicted facts: relationallexists [S] object inserted [S] planting a plant [S] plant\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Black holes have also been found in the middle of every major galaxy in the universe.\n",
      "\n",
      "Predicted facts: oppositedirectioninterval [S] galaxy [S] black\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "others are like plastic worms and rat-l-traps.\n",
      "\n",
      "Predicted facts: relationallexists [S] object taken care of [S] husbandry of fn [S] animal [S] plastic\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Candy is a sweet kind of food that is usually made from sugar and water, with flavors and other ingredients added.\n",
      "\n",
      "Predicted facts: typicalingredienttypes [S] candy [S] sugar [N] typicaltasteoftype [S] candy [S] sweet tasting food [N] typicaltasteoftype [S] sugar [S] sweet taste\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Eucalypts began between 35 and 50 million years ago, not long after Australia and New Guinea separated from Gondwana.\n",
      "\n",
      "Predicted facts: geographicalsubregionsofcontinent [S] continent of guinea [S] territory fn [S] gondwan\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "The largest city on the east part is Spokane it is also the second biggest city in the state.\n",
      "\n",
      "Predicted facts: oppositedirectioninterval [S] east generally [S] east generally\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Hamsters store food in the sides of their mouths.\n",
      "\n",
      "Predicted facts: relationallexistscount [S] main constituent [S] mouth [S] mouth\n",
      "=====================================================================\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence:\n",
      "A piston-operated steam engine is called reciprocal (back-and-forth) engine.\n",
      "\n",
      "Predicted facts: typeprimaryfunction [S] engine device used\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "A soldier does many things, from shooting enemies, to using binoculars to find out where enemies are.\n",
      "\n",
      "Predicted facts: agenttypeperformsworkoftype [S] soldier [S] soldier\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "The resulting solution is called a saturated solution.\n",
      "\n",
      "Predicted facts: outputscreatedtypetype [S] generating something [S] output process [S] generated by\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Many flowers are dependent upon the wind to move pollen between flowers of the same species.\n",
      "\n",
      "Predicted facts: relationallexists [S] object taken care of [S] husbandry fn [S] flower\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Jewelry can come in many forms, worn on any part of the body or clothing.\n",
      "\n",
      "Predicted facts: relationmostinstance [S] object inserted [S] bodily doer\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "A vein has a large lumen, and less pressure as there is smaller amounts of smooth muscle and elastic fibres.\n",
      "\n",
      "Predicted facts: typeprimaryfunction [S] muscle [S] fatty tissue [S] vein\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "An office is a room where people work.\n",
      "\n",
      "Predicted facts: oppositedirectioninterval [S] workspace [S] office\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "These radio astronomers use giant Radio telescopes, shaped like satellite dishes, to gather and analyze the waves.\n",
      "\n",
      "Predicted facts: relationexistscountall [S] object taken care of [S] husbandry of object [S] device used\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "A telephone is an electric tool.\n",
      "\n",
      "Predicted facts: typeprimaryfunction [S] telephone [S] device used\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "The amount of glucose in the blood goes up.\n",
      "\n",
      "Predicted facts: atomiccomposition [S] blood [S] glucose\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Brown is the most common colour.\n",
      "\n",
      "Predicted facts: coloroftype [S] brown color [S] brown color\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "The harder metal is usually copper.\n",
      "\n",
      "Predicted facts: relationallinstance [S] metal fn [S] copper [S] copper\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "They are called leptocephalus (Greek for thin head).\n",
      "\n",
      "Predicted facts: relationallexists [S] object taken care of [S] husbandry of arm [S] head animal body part [S] forehead animal\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Plants need sunlight for the process of photosynthesis.\n",
      "\n",
      "Predicted facts: typeprimaryfunction [S] manufacturing facility [S] manufacturing [S] event occurs at\n",
      "=====================================================================\n",
      "\n",
      "Sentence:\n",
      "Disco clubs have a large dance floor and a large pa system.\n",
      "\n",
      "Predicted facts: typeprimaryfunction [S] dancing room [S] large\n",
      "=====================================================================\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Qualitative check: generate facts for every sentence in `dataset` and print\n",
    "# only the predictions that differ from a generic output.\n",
    "# Predictions equal to this string are skipped. It appears to be a generic fallback\n",
    "# output of the model (assumption -- confirm before relying on this filter).\n",
    "GENERIC_PREDICTION = \"covering [S] situation [S] the covering [S] event [S] static situation\"\n",
    "\n",
    "# Make sure dropout is off: the model may still be in training mode after fitting.\n",
    "model.model.eval()\n",
    "device = model.model.device\n",
    "loader = DataLoader(dataset, batch_size=10, shuffle=False)\n",
    "with torch.no_grad():\n",
    "    for batch in loader:\n",
    "        # No decoder_attention_mask here: batch['target_mask'] masks the (empty) target\n",
    "        # facts used for teacher forcing and does not describe the sequence being\n",
    "        # generated (review: depending on the transformers version it is either ignored\n",
    "        # or causes a shape mismatch).\n",
    "        outs = model.model.generate(\n",
    "            batch[\"source_ids\"].to(device),\n",
    "            attention_mask=batch[\"source_mask\"].to(device),\n",
    "            use_cache=True,\n",
    "            max_length=150,\n",
    "            num_beams=2,\n",
    "            repetition_penalty=2.5,\n",
    "            length_penalty=1.0,\n",
    "            early_stopping=True,\n",
    "        )\n",
    "\n",
    "        predictions = [tokenizer.decode(ids) for ids in outs]\n",
    "        sentences = [tokenizer.decode(ids) for ids in batch[\"source_ids\"]]\n",
    "        for sentence, prediction in zip(sentences, predictions):\n",
    "            if prediction != GENERIC_PREDICTION:\n",
    "                print(\"Sentence:\\n%s\" % sentence)\n",
    "                print(\"\\nPredicted facts: %s\" % prediction)\n",
    "                print(\"=====================================================================\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
