{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports and constants\n",
    "import torch\n",
    "\n",
    "from transformers import LineByLineTextDataset, TextDataset, DataCollatorForLanguageModeling\n",
    "from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel\n",
    "from transformers import Trainer, TrainingArguments, AutoTokenizer\n",
    "from transformers import pipeline\n",
    "\n",
    "MODEL_DIR = \"fact_lm_5\"\n",
    "OUTPUT_DIR = \"fact_lm_6\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup GPT-2 model\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token = \".\")\n",
    "model = GPT2LMHeadModel.from_pretrained('gpt2').cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare dataset\n",
    "train_dataset = LineByLineTextDataset(\n",
    "    tokenizer=tokenizer,\n",
    "    file_path=\"./fact_dataset_train.txt\",\n",
    "    # TODO: maybe this can be much less?\n",
    "    block_size=64,\n",
    ")\n",
    "\n",
    "eval_dataset = LineByLineTextDataset(\n",
    "    tokenizer=tokenizer,\n",
    "    file_path=\"./fact_dataset_test.txt\",\n",
    "    # TODO: maybe this can be much less?\n",
    "    block_size=64,\n",
    ")\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer, mlm=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "39447\n",
      "torch.Size([21])\n",
      "torch.Size([19])\n",
      "13\n",
      ".\n"
     ]
    }
   ],
   "source": [
    "print(len(train_dataset))\n",
    "print(train_dataset[0].size())\n",
    "print(train_dataset[1].size())\n",
    "print(tokenizer.pad_token_id)\n",
    "print(tokenizer._pad_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\danil\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\trainer.py:267: FutureWarning: Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.\n",
      "  FutureWarning,\n",
      "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it.\n"
     ]
    }
   ],
   "source": [
    "# Setup trainer\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=OUTPUT_DIR,\n",
    "    overwrite_output_dir=False,\n",
    "    num_train_epochs=5,\n",
    "    per_device_train_batch_size=16,\n",
    "    save_steps=200,\n",
    "    save_total_limit=2,\n",
    "    no_cuda=True\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    prediction_loss_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\danil\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\trainer.py:1153: FutureWarning: This method is deprecated, use `Trainer.is_local_process_zero()` instead.\n",
      "  warnings.warn(\"This method is deprecated, use `Trainer.is_local_process_zero()` instead.\", FutureWarning)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c078e170581e4ea4bd8c963b914a48eb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2466.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "{'input_ids': tensor([[  421,  5191,   318,    64,  1279,    50,    29,   895,    68,   976,\n",
      "          1257,  2715,  1279,    50,    29, 23970,  6937,    13,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [35569,  3891,  1279,    50,    29,   474,   577,   746, 23185,  5331,\n",
      "          1279,    50,    29,   474,   577,   746,    13,    13,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [28060,   416,  1279,    50,    29, 14250, 29365,  5793,   860,  5705,\n",
      "          1279,    50,    29,   442, 13481,   356,    72, 27406,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [42854,   636,  3858,  1279,    50,    29, 17026,  1279,    50,    29,\n",
      "         17026,  5412,  2318,    13,    13,    13,    13,    13,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [ 3732, 17882,   286,  1785,  1279,    50,    29,  9412,   401, 10972,\n",
      "          1245, 37547,  1279,    50,    29,   401, 10972,  1245,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  1430,  2223,  1692,\n",
      "         16862,  1279,    50,    29, 17728,   262,  2482,   286,   257,  1430,\n",
      "          2223,    13,    13,    13,    13],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  3814,  2099,  1279,\n",
      "            50,    29,  1048,  2099,   416, 14858,  3722,    13,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  5874,  1317,  1279,\n",
      "            50,    29,   599, 42343,  4645,    13,    13,    13,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [25781,  2099, 16015,  1720,  2099,  1279,    50,    29,   638,   624,\n",
      "         47868,  3650,  1279,    50,    29,   638,   624, 47868,  1910,  6536,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  1981,  4229,    79,\n",
      "          4235,  1279,    50,    29,  4947,  4229,    79,  4235,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [45573, 15793,  1279,    50,    29,  2347,  5444,  3828,  1470,  1279,\n",
      "            50,    29,  2347, 15793,    13,    13,    13,    13,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  4336,  3430,  1279,\n",
      "            50,    29,  4384,  1074,    13,    13,    13,    13,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  1218,   502, 16357,\n",
      "           599, 42343, 10474,  1279,    50,    29,   281,  6570,   589,  1312,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [49501,   749,  4554,  1279,    50,    29,  7832,  2099,  1279,    50,\n",
      "            29,  3290,  1279,    50,    29,  9970,  1767,   636,    13,    13,\n",
      "            13,    13,    13,    13,    13],\n",
      "        [ 3732, 17882,   286,  1785,  1279,    50,    29,  7468, 22939,   286,\n",
      "           262, 23642,   932,   844,   939,  1279,    50,    29, 22939,   286,\n",
      "           262, 23642,   932,   844,    13],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,   443,  1229,   404,\n",
      "         12957,  1279,    50,    29, 39495,  7209,  8280, 10712,    13,    13,\n",
      "            13,    13,    13,    13,    13]]), 'labels': tensor([[  421,  5191,   318,    64,  1279,    50,    29,   895,    68,   976,\n",
      "          1257,  2715,  1279,    50,    29, 23970,  6937,  -100,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [35569,  3891,  1279,    50,    29,   474,   577,   746, 23185,  5331,\n",
      "          1279,    50,    29,   474,   577,   746,  -100,  -100,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [28060,   416,  1279,    50,    29, 14250, 29365,  5793,   860,  5705,\n",
      "          1279,    50,    29,   442, 13481,   356,    72, 27406,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [42854,   636,  3858,  1279,    50,    29, 17026,  1279,    50,    29,\n",
      "         17026,  5412,  2318,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [ 3732, 17882,   286,  1785,  1279,    50,    29,  9412,   401, 10972,\n",
      "          1245, 37547,  1279,    50,    29,   401, 10972,  1245,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  1430,  2223,  1692,\n",
      "         16862,  1279,    50,    29, 17728,   262,  2482,   286,   257,  1430,\n",
      "          2223,  -100,  -100,  -100,  -100],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  3814,  2099,  1279,\n",
      "            50,    29,  1048,  2099,   416, 14858,  3722,  -100,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  5874,  1317,  1279,\n",
      "            50,    29,   599, 42343,  4645,  -100,  -100,  -100,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [25781,  2099, 16015,  1720,  2099,  1279,    50,    29,   638,   624,\n",
      "         47868,  3650,  1279,    50,    29,   638,   624, 47868,  1910,  6536,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  1981,  4229,    79,\n",
      "          4235,  1279,    50,    29,  4947,  4229,    79,  4235,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [45573, 15793,  1279,    50,    29,  2347,  5444,  3828,  1470,  1279,\n",
      "            50,    29,  2347, 15793,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  4336,  3430,  1279,\n",
      "            50,    29,  4384,  1074,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,  1218,   502, 16357,\n",
      "           599, 42343, 10474,  1279,    50,    29,   281,  6570,   589,  1312,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [49501,   749,  4554,  1279,    50,    29,  7832,  2099,  1279,    50,\n",
      "            29,  3290,  1279,    50,    29,  9970,  1767,   636,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100],\n",
      "        [ 3732, 17882,   286,  1785,  1279,    50,    29,  7468, 22939,   286,\n",
      "           262, 23642,   932,   844,   939,  1279,    50,    29, 22939,   286,\n",
      "           262, 23642,   932,   844,  -100],\n",
      "        [ 6381,    73,  1563,   351,  1279,    50,    29,   443,  1229,   404,\n",
      "         12957,  1279,    50,    29, 39495,  7209,  8280, 10712,  -100,  -100,\n",
      "          -100,  -100,  -100,  -100,  -100]])}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm.auto import tqdm, trange\n",
    "\n",
    "train_dataloader = trainer.get_train_dataloader()\n",
    "epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=not trainer.is_local_master())\n",
    "for step, inputs in enumerate(epoch_iterator):\n",
    "    print(step)\n",
    "    print(inputs)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ec5a473e2d44e18a76f25224bc7b456",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=5.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cc45e0b61f57460cac5bbd65a2a00763",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2466.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.9647742309570313, 'learning_rate': 4.797242497972425e-05, 'epoch': 0.20275750202757503, 'total_flos': 164463632326656, 'step': 500}\n",
      "{'loss': 0.9689465942382812, 'learning_rate': 4.59448499594485e-05, 'epoch': 0.40551500405515006, 'total_flos': 331029799649280, 'step': 1000}\n",
      "{'loss': 0.9492325439453125, 'learning_rate': 4.3917274939172754e-05, 'epoch': 0.6082725060827251, 'total_flos': 495875711066112, 'step': 1500}\n",
      "{'loss': 0.94626708984375, 'learning_rate': 4.1889699918897e-05, 'epoch': 0.8110300081103001, 'total_flos': 661689266429952, 'step': 2000}\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dc484d1c67444f3ca09b9568db10c096",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2466.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.921406494140625, 'learning_rate': 3.986212489862125e-05, 'epoch': 1.013787510137875, 'total_flos': 825161362366464, 'step': 2500}\n",
      "{'loss': 0.84675537109375, 'learning_rate': 3.78345498783455e-05, 'epoch': 1.2165450121654502, 'total_flos': 989696672022528, 'step': 3000}\n",
      "{'loss': 0.8374541015625, 'learning_rate': 3.580697485806975e-05, 'epoch': 1.419302514193025, 'total_flos': 1155474388721664, 'step': 3500}\n",
      "{'loss': 0.83353759765625, 'learning_rate': 3.3779399837794e-05, 'epoch': 1.6220600162206003, 'total_flos': 1320714525450240, 'step': 4000}\n",
      "{'loss': 0.8252880859375, 'learning_rate': 3.175182481751825e-05, 'epoch': 1.8248175182481752, 'total_flos': 1484748093800448, 'step': 4500}\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2fc198a6ff3048fdb722c660742b6812",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2466.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.8065771484375, 'learning_rate': 2.9724249797242497e-05, 'epoch': 2.02757502027575, 'total_flos': 1650149504520192, 'step': 5000}\n",
      "{'loss': 0.758412109375, 'learning_rate': 2.7696674776966745e-05, 'epoch': 2.2303325223033252, 'total_flos': 1816679833178112, 'step': 5500}\n",
      "{'loss': 0.7699267578125, 'learning_rate': 2.5669099756691e-05, 'epoch': 2.4330900243309004, 'total_flos': 1982612850757632, 'step': 6000}\n",
      "{'loss': 0.7615126953125, 'learning_rate': 2.3641524736415248e-05, 'epoch': 2.635847526358475, 'total_flos': 2145738506268672, 'step': 6500}\n",
      "{'loss': 0.7649365234375, 'learning_rate': 2.1613949716139496e-05, 'epoch': 2.83860502838605, 'total_flos': 2308673022234624, 'step': 7000}\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1bfbcd01da254428a045e6ea334ef682",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2466.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.7489443359375, 'learning_rate': 1.9586374695863748e-05, 'epoch': 3.0413625304136254, 'total_flos': 2474091605647872, 'step': 7500}\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-29-8dd2e950316f>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;31m# Train on fact dataset\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mtrainer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(self, model_path, trial)\u001b[0m\n\u001b[0;32m    761\u001b[0m                     \u001b[1;32mcontinue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    762\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 763\u001b[1;33m                 \u001b[0mtr_loss\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    764\u001b[0m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtotal_flos\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloating_point_ops\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    765\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\trainer.py\u001b[0m in \u001b[0;36mtraining_step\u001b[1;34m(self, model, inputs)\u001b[0m\n\u001b[0;32m   1111\u001b[0m                 \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1112\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1113\u001b[1;33m             \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1114\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1115\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn_gpu\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\trainer.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[1;34m(self, model, inputs)\u001b[0m\n\u001b[0;32m   1135\u001b[0m         \u001b[0mSubclass\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0moverride\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mcustom\u001b[0m \u001b[0mbehavior\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1136\u001b[0m         \"\"\"\n\u001b[1;32m-> 1137\u001b[1;33m         \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1138\u001b[0m         \u001b[1;31m# Save past state if it exists\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1139\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpast_index\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    720\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    721\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 722\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    723\u001b[0m         for hook in itertools.chain(\n\u001b[0;32m    724\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[0;32m    763\u001b[0m             \u001b[0moutput_attentions\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    764\u001b[0m             \u001b[0moutput_hidden_states\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 765\u001b[1;33m             \u001b[0mreturn_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    766\u001b[0m         )\n\u001b[0;32m    767\u001b[0m         \u001b[0mhidden_states\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtransformer_outputs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    720\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    721\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 722\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    723\u001b[0m         for hook in itertools.chain(\n\u001b[0;32m    724\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[0;32m    651\u001b[0m                     \u001b[0mencoder_attention_mask\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mencoder_attention_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    652\u001b[0m                     \u001b[0muse_cache\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0muse_cache\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 653\u001b[1;33m                     \u001b[0moutput_attentions\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    654\u001b[0m                 )\n\u001b[0;32m    655\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    720\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    721\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 722\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    723\u001b[0m         for hook in itertools.chain(\n\u001b[0;32m    724\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[0;32m    314\u001b[0m             \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moutputs\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mcross_attn_outputs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m  \u001b[1;31m# add cross attentions if we output attention weights\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    315\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 316\u001b[1;33m         \u001b[0mfeed_forward_hidden_states\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmlp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mln_2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    317\u001b[0m         \u001b[1;31m# residual connection\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    318\u001b[0m         \u001b[0mhidden_states\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhidden_states\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mfeed_forward_hidden_states\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    720\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    721\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 722\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    723\u001b[0m         for hook in itertools.chain(\n\u001b[0;32m    724\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m    254\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    255\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 256\u001b[1;33m         \u001b[0mh\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mact\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mc_fc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    257\u001b[0m         \u001b[0mh2\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mc_proj\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mh\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    258\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mh2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\qrg\\lib\\site-packages\\transformers\\activations.py\u001b[0m in \u001b[0;36mgelu_new\u001b[1;34m(x)\u001b[0m\n\u001b[0;32m     28\u001b[0m     \u001b[0mAlso\u001b[0m \u001b[0msee\u001b[0m \u001b[0mhttps\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m//\u001b[0m\u001b[0marxiv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0morg\u001b[0m\u001b[1;33m/\u001b[0m\u001b[0mabs\u001b[0m\u001b[1;33m/\u001b[0m\u001b[1;36m1606.08415\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     29\u001b[0m     \"\"\"\n\u001b[1;32m---> 30\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[1;36m0.5\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m*\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;36m1.0\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtanh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2.0\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpi\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m*\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;36m0.044715\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpow\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m3.0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     31\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     32\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train on fact dataset\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0462f00c248041cdb85bcf6a768e6c4e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=548.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "{'eval_loss': 0.9132209329090515, 'epoch': 3.2076236820762367, 'total_flos': 2610672756834816, 'step': 7910}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.9132209329090515,\n",
       " 'epoch': 3.2076236820762367,\n",
       " 'total_flos': 2610672756834816}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model(OUTPUT_DIR) # last saved: fact_lm_6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model checkpoint stored in MODEL_DIR and keep it on the CPU\n",
    "model = GPT2LMHeadModel.from_pretrained(MODEL_DIR).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "next_word = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=MODEL_DIR,\n",
    "    tokenizer=tokenizer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': 'contrary feelings <S> pleasure <S> sadness joyful feeling <S> disgust emotional feeling numeral part type common in all emotion types <S> pleasure feeling numeral part type indexical to indexical type <S> pleasure feeling numeral'}]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next_word(\"contrary feelings <S>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': 'typical location type of event type <S> cooking food <S> sink <S> surface of water <S> surface of an object type by surface properties type <S> sink type by surface properties type of type <S> city or'}]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next_word(\"typical location type of event type <S> cooking food <S>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': 'defeated in conflict <S> world war ii <S> war iii war i <S> empire of czars war i empire d ak ak b ak czar b m e the term empire of czars has a suffix <S>'}]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next_word(\"defeated in conflict <S> world war ii <S>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': 'relation not exists exists <S> eats willingly <S> moose <S> nautilus animal hide animal hide species <S> 3 legs length <S> 2 legs length <S> 1 inch by 1 inch by 1 inch body length'}]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next_word(\"relation not exists exists <S> eats willingly <S> moose\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': 'agent type performs work of type <S> car washer <S> car washing <S> vehicle maintenance event type by vehicle type <S> car maintenance event type by vehicle maintenance station type by schedule schedule status type by driver type <S>'}]"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next_word(\"agent type performs work of type <S> car washer <S>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n",
    "outputs = model(**inputs, labels=inputs[\"input_ids\"])\n",
    "logits = outputs[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
