{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "763c591d-1832-4d93-97c2-a00da4c40a5e",
   "metadata": {},
   "source": [
    "# Test the distilled Flan T5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ebeb366b-1954-4728-8aab-da40338d7b4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the parent directory importable (the notebook imports `src.data_utils`).\n",
    "# Guarded so that re-running the cell does not pile up duplicate entries.\n",
    "if '..' not in sys.path:\n",
    "    sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "90f94ceb-e91c-4c0e-b85c-5340ed599e5f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/llm/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import datasets\n",
    "import torch\n",
    "import re\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration, MaxLengthCriteria\n",
    "from src.data_utils import GSM8KCodexAugmentedDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b7f0a531-30e3-41e2-845d-31f3f3839d5f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
      "env: CUDA_VISIBLE_DEVICES=0,1\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
    "%env CUDA_VISIBLE_DEVICES=0,1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7d1959b8-8b3f-475d-b997-df2322d34590",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset gsm8k (/home/yaofu/.cache/huggingface/datasets/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba)\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 331.33it/s]\n"
     ]
    }
   ],
   "source": [
    "gsm8k = load_dataset('gsm8k', 'main')\n",
    "gsm8k_test = gsm8k['test']\n",
    "\n",
    "validation_index = np.load('../lib_prompt/validation_index.npy')\n",
    "validation_data = gsm8k['train'].select(validation_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "520eb2ec-33ec-4d04-9c00-a9441137234a",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(\"google/flan-t5-xxl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fdb12394-6a14-42ff-826f-b48cdd83e93b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distilled checkpoint under test. device_map='auto' leaves the placement of\n",
    "# the weights across the visible devices to the loader.\n",
    "model = T5ForConditionalGeneration.from_pretrained(\n",
    "    \"/mnt/data_10t/flan_t5_distill/checkpoints/0.0.2.0_epoch_0_iter_100\",\n",
    "    device_map='auto',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9dbb0a50-6d3a-4ec5-8402-0aecfb6c2d66",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Thu Dec 22 03:37:39 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  NVIDIA A100-SXM...  On   | 00000000:00:05.0 Off |                    0 |\n",
      "| N/A   35C    P0    73W / 400W |  22175MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  NVIDIA A100-SXM...  On   | 00000000:00:06.0 Off |                    0 |\n",
      "| N/A   34C    P0    75W / 400W |  22291MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  NVIDIA A100-SXM...  On   | 00000000:00:07.0 Off |                    0 |\n",
      "| N/A   33C    P0    62W / 400W |      0MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   3  NVIDIA A100-SXM...  On   | 00000000:00:08.0 Off |                    0 |\n",
      "| N/A   34C    P0    64W / 400W |      0MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   4  NVIDIA A100-SXM...  On   | 00000000:80:00.0 Off |                    0 |\n",
      "| N/A   35C    P0    77W / 400W |  67111MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   5  NVIDIA A100-SXM...  On   | 00000000:80:01.0 Off |                    0 |\n",
      "| N/A   36C    P0    72W / 400W |  67111MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   6  NVIDIA A100-SXM...  On   | 00000000:80:02.0 Off |                    0 |\n",
      "| N/A   33C    P0    64W / 400W |      0MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   7  NVIDIA A100-SXM...  On   | 00000000:80:03.0 Off |                    0 |\n",
      "| N/A   34C    P0    68W / 400W |      0MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|    0   N/A  N/A     23099      C   ...da/envs/llm/bin/python3.8    22173MiB |\n",
      "|    1   N/A  N/A     23099      C   ...da/envs/llm/bin/python3.8    22289MiB |\n",
      "|    4   N/A  N/A     86204      C   python                          67109MiB |\n",
      "|    5   N/A  N/A     86890      C   python                          67109MiB |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "31db619d-cc7e-47c2-8b0d-fee85801f1e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the GSM8K train question used as the probe input.\n",
    "# Kept in one place so that a top-to-bottom run always uses the same example;\n",
    "# change it here to try another question.\n",
    "EXAMPLE_INDEX = 1000\n",
    "input_text = gsm8k['train'][EXAMPLE_INDEX]['question']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "900b5f33-9e69-4a9b-9ee8-c5bbd4881efe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the few-shot prompt through a context manager so the file handle is\n",
    "# closed right away. The encoding is pinned because the prompt text contains\n",
    "# non-ASCII characters and the default encoding depends on the locale.\n",
    "with open('../lib_prompt/prompt_simple_4_cases.txt', encoding='utf-8') as prompt_file:\n",
    "    prompt = prompt_file.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4a981bff-220c-454e-b348-706a44d23c09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the new question after the few-shot examples and finish with the\n",
    "# chain-of-thought cue, leaving the answer for the model to produce.\n",
    "prompt_q = ''.join([\n",
    "    prompt,\n",
    "    '\\nQuestion: ',\n",
    "    input_text,\n",
    "    \"\\nLet's think step by step\",\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e4b3a9a9-abba-4307-9d76-b37b28c1ea6a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: Ivan has a bird feeder in his yard that holds two cups of birdseed. Every week, he has to refill the emptied feeder. Each cup of birdseed can feed fourteen birds, but Ivan is constantly chasing away a hungry squirrel that steals half a cup of birdseed from the feeder every week. How many birds does Ivan’s bird feeder feed weekly?\n",
      "Let's think step by step\n",
      "The squirrel steals 1/2 cup of birdseed every week, so the birds eat 2 - 1/2 = 1 1/2 cups of birdseed.\n",
      "Each cup feeds 14 birds, so Ivan’s bird feeder feeds 14 * 1 1/2 = 21 birds weekly.\n",
      "The answer is 21\n",
      "\n",
      "Question: Samuel took 30 minutes to finish his homework while Sarah took 1.3 hours to finish it. How many minutes faster did Samuel finish his homework than Sarah?\n",
      "Let's think step by step\n",
      "Since there are 60 minutes in 1 hour, then 1.3 hours is equal to 1.3 x 60 = 78 minutes.\n",
      "Thus, Samuel is 78 – 30 = 48 minutes faster than Sarah.\n",
      "The answer is 48\n",
      "\n",
      "Question: Julia bought 3 packs of red balls, 10 packs of yellow balls, and 8 packs of green balls. There were 19 balls in each package. How many balls did Julie buy in all?\n",
      "Let's think step by step\n",
      "The total number of packages is 3 + 10 + 8 = 21.\n",
      "Julia bought 21 × 19 = 399 balls.\n",
      "The answer is 399\n",
      "\n",
      "Question: Lexi wants to run a total of three and one-fourth miles. One lap on a particular outdoor track measures a quarter of a mile around. How many complete laps must she run?\n",
      "Let's think step by step\n",
      "There are 3/ 1/4 = 12 one-fourth miles in 3 miles.\n",
      "So, Lexi will have to run 12 (from 3 miles) + 1 (from 1/4 mile) = 13 complete laps.\n",
      "The answer is 13\n",
      "\n",
      "Question: John buys a heating pad for $30.  He uses it 3 times a week for 2 weeks.  How much does he spend on each use?\n",
      "Let's think step by step\n"
     ]
    }
   ],
   "source": [
    "print(prompt_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "649eccf4-ed9a-43bc-bb5b-66d281782bd0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 454])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tokenize the full prompt and move the ids to the first GPU in one expression,\n",
    "# then display the resulting tensor size.\n",
    "input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:0\")\n",
    "input_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "bbdf5925-6f90-4fbc-a5db-1ca234e77f97",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[11860,    10,    27,  2132,    65,     3,     9,  5963, 22916,    16,\n",
       "           112,  6178,    24,  4532,   192, 12294,    13,  6331,  6958,     5,\n",
       "          2181,   471,     6,     3,    88,    65,    12, 26210,     8,     3,\n",
       "          9045,  5973, 22916,     5,  1698,  4119,    13,  6331,  6958,    54,\n",
       "          3305, 27137,  6331,     6,    68,    27,  2132,    19,  4259,     3,\n",
       "         22020,   550,     3,     9, 13802, 26364,    24, 11332,     7,   985,\n",
       "             3,     9,  4119,    13,  6331,  6958,    45,     8, 22916,   334,\n",
       "           471,     5,   571,   186,  6331,   405,    27,  2132,    22,     7,\n",
       "          5963, 22916,  3305,  5547,    58,  1563,    31,     7,   317,  1147,\n",
       "            57,  1147,    37, 26364, 11332,     7,  7739,  4119,    13,  6331,\n",
       "          6958,   334,   471,     6,    78,     8,  6331,     3,  1544,   204,\n",
       "             3,    18,  7739,  3274,   209,  7739, 12294,    13,  6331,  6958,\n",
       "             5,  1698,  4119,  3305,     7,   968,  6331,     6,    78,    27,\n",
       "          2132,    22,     7,  5963, 22916,  3305,     7,   968,  1429,   209,\n",
       "          7739,  3274,  1401,  6331,  5547,     5,    37,  1525,    19,  1401,\n",
       "         11860,    10, 15718,   808,   604,   676,    12,  1992,   112, 11920,\n",
       "           298,  8077,   808,     3, 13606,   716,    12,  1992,    34,     5,\n",
       "           571,   186,   676,  3627,   410, 15718,  1992,   112, 11920,   145,\n",
       "          8077,    58,  1563,    31,     7,   317,  1147,    57,  1147,  1541,\n",
       "           132,    33,  1640,   676,    16,   209,  1781,     6,   258,     3,\n",
       "         13606,   716,    19,  4081,    12,     3, 13606,     3,   226,  1640,\n",
       "          3274,     3,  3940,   676,     5,  5309,     6, 15718,    19,     3,\n",
       "          3940,     3,   104,   604,  3274,  4678,   676,  3627,   145,  8077,\n",
       "             5,    37,  1525,    19,  4678, 11860,    10, 18618,  2944,   220,\n",
       "         14958,    13,  1131, 11607,     6,   335, 14958,    13,  4459, 11607,\n",
       "             6,    11,   505, 14958,    13,  1442, 11607,     5,   290,   130,\n",
       "           957, 11607,    16,   284,  2642,     5,   571,   186, 11607,   410,\n",
       "         12983,   805,    16,    66,    58,  1563,    31,     7,   317,  1147,\n",
       "            57,  1147,    37,   792,   381,    13,  6307,    19,   220,  1768,\n",
       "           335,  1768,   505,  3274,  1401,     5, 18618,  2944,  1401,     3,\n",
       "             2,   957,  3274,   220,  3264, 11607,     5,    37,  1525,    19,\n",
       "           220,  3264, 11860,    10, 17546,    23,  2746,    12,   661,     3,\n",
       "             9,   792,    13,   386,    11,    80,    18, 12521,   189,  2286,\n",
       "             5,   555, 14941,    30,     3,     9,  1090,  2655,  1463,  3629,\n",
       "             3,     9,  2893,    13,     3,     9,  7728,   300,     5,   571,\n",
       "           186,   743, 14941,     7,   398,   255,   661,    58,  1563,    31,\n",
       "             7,   317,  1147,    57,  1147,   290,    33,     3, 15020, 13004,\n",
       "          3274,   586,    80,    18, 12521,   189,  2286,    16,   220,  2286,\n",
       "             5,   264,     6, 17546,    23,    56,    43,    12,   661,   586,\n",
       "            41,  7152,   220,  2286,    61,  1768,   209,    41,  7152, 13004,\n",
       "          7728,    61,  3274,  1179,   743, 14941,     7,     5,    37,  1525,\n",
       "            19,  1179, 11860,    10,  4498,   152,    19, 12935,    95,   112,\n",
       "          4579,   250,     3,    88,    22,     7,  1735,   640,     8,   684,\n",
       "            21,     3,     9,   126,   613,     5,   216,   523,    12,  4383,\n",
       "           633,  5598,    12,   112,   126,   234,     5,    37,     3, 24844,\n",
       "            43,  1380,    24,  4498,   152,  1792,     3,  3131,    72,   145,\n",
       "             3,     9,   824,  1293,    16,  7051,    16,   136, 21451,  1367,\n",
       "             5,    37,  1735,   349,    65,  2690,   120,   937,  4498,   152,\n",
       "            28,     3,     9,  1125,  2643,    24,    56,  5685,   376,     3,\n",
       "            99,     3,     9,  2642,    19,   396,  2437,     5,  4498,   152,\n",
       "            19,    16,     8,  1228,     6,    11,     3,    88,    14,     7,\n",
       "             3,     9, 21451,  1367,    28,  6654,  2634,  9319,     5,   366,\n",
       "             3,    88, 11642,     8,  1367,     6,     8,  2643,  2279,   112,\n",
       "          1367,    19,   396,  2437,     5,  4498,   152,  4054,   284,    13,\n",
       "           112,  9319, 11385,     7,   335,     3,  7906,     7,     5,   216,\n",
       "          2036,     7,     3,     9,   712,  3829,    45,     8,  1367,    11,\n",
       "         11642,     8,     3, 24844,    22,  2643,   541,     5,    37,  2643,\n",
       "          2279,   112,  1367,    19,   341,   396,  2437,     5,  4498,   152,\n",
       "          6103,     7,     8,   433,   541,    11,   541,     5,   366,     3,\n",
       "            88,    65,  3641,   631,  9319,     6,     8,     3, 24844,    22,\n",
       "          2643,  1267,     8,  1367,    19,   230,    46,  9961,  1293,    21,\n",
       "          3365,     5,  4498,   152,    20, 12160,     7,    24,   284,  3365,\n",
       "          1367,    54,  1520,   460,  7051,   274,     8,  2643,   845,     8,\n",
       "          1367,    19,   396,  2437,     5,   571,   186,  9319,   410,  4498,\n",
       "           152,   174,    12,  2036,    45,     8,  3365,  1367,    58,  1563,\n",
       "            31,     7,   317,  1147,    57,  1147,     1]], device='cuda:0')"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c294d2ea-27f0-4e5b-93cf-3540285b699e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sampling is enabled with `do_sample=True`; `generate` rejects unknown keyword\n",
    "# arguments (such as `sampling`) with a ValueError.\n",
    "outputs = model.generate(input_ids, max_length=256, do_sample=True)\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "b5d49c97-2b9c-4d1d-a276-994e252a9d5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-token decoder prefix, allocated directly on the GPU.\n",
    "# NOTE(review): id 0 is presumably the model's decoder start token - confirm\n",
    "# against model.config.decoder_start_token_id.\n",
    "dec_input = torch.tensor([[0]], device='cuda:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "97627e78-52ab-4044-88d6-bb1f799f5562",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: run the forward pass without autograd so that no graph is\n",
    "# recorded for the logits inspected below.\n",
    "with torch.no_grad():\n",
    "    outputs = model(input_ids=input_ids, decoder_input_ids=dec_input, return_dict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "d8db9939-3eec-43df-b920-953f2b3fbdf3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.topk(\n",
       "values=tensor([11.4741, 11.4652, 11.3593, 11.3318, 11.3148], device='cuda:0',\n",
       "       grad_fn=<TopkBackward0>),\n",
       "indices=tensor([32050,     0, 32026, 32006,  6385], device='cuda:0'))"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs.logits[0, 0].topk(k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b9dccc8-0ee2-4fe8-b013-9645f329e4a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def silly_decode(model, batch, t, device='cuda:0'):\n",
    "    \"\"\"Greedy-decode `t` tokens by re-running the full model at every step.\n",
    "\n",
    "    Deliberately naive: there is no key/value cache, so each step feeds the\n",
    "    whole decoder prefix back through the model and appends the arg-max token.\n",
    "    NOTE(review): greedy arg-max is the assumed intent of this helper; padding\n",
    "    and attention masks are not handled - confirm before using on padded batches.\n",
    "\n",
    "    Args:\n",
    "        model: callable accepting `input_ids`, `decoder_input_ids` and\n",
    "            `return_dict`, returning an object with a `.logits` tensor indexed\n",
    "            as (batch, target position, vocabulary).\n",
    "        batch: mapping with tensors under 'src_input_ids' and 'tgt_input_ids'.\n",
    "        t: number of tokens to append to the decoder prefix.\n",
    "        device: device the input tensors are moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        Tensor holding the initial 'tgt_input_ids' followed by the `t` new ids.\n",
    "    \"\"\"\n",
    "    src_ids = batch['src_input_ids'].to(device)\n",
    "    tgt_ids = batch['tgt_input_ids'].to(device)\n",
    "    with torch.no_grad():  # decoding only, no gradients needed\n",
    "        for _ in range(t):\n",
    "            out_dict_t = model(input_ids=src_ids, decoder_input_ids=tgt_ids, return_dict=True)\n",
    "            # Prediction for the next position comes from the last decoder step.\n",
    "            next_ids = out_dict_t.logits[:, -1].argmax(dim=-1, keepdim=True)\n",
    "            tgt_ids = torch.cat([tgt_ids, next_ids], dim=-1)\n",
    "    return tgt_ids"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
