{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6dc67cc7-dece-4b3f-9cfa-ef4c5f6de696",
   "metadata": {},
   "source": [
    "# Testing Flan-T5-3B for gsm8k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b6c38f6d-da45-41d3-bf24-d8a821f2ece1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import torch\n",
    "import re\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration, MaxLengthCriteria"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0ab9f375-5d79-4f16-94af-668a29297807",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
      "env: CUDA_VISIBLE_DEVICES=0,1\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
    "%env CUDA_VISIBLE_DEVICES=0,1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "254076bd-aaab-4218-932a-138835147d39",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset gsm8k (/home/yaofu/.cache/huggingface/datasets/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba)\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 630.77it/s]\n"
     ]
    }
   ],
   "source": [
    "gsm8k = load_dataset('gsm8k', 'main')\n",
    "gsm8k_test = gsm8k['test']\n",
    "\n",
    "validation_index = np.load('../lib_prompt/validation_index.npy')\n",
    "validation_data = gsm8k['train'].select(validation_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0790b98f-34e1-4d86-94be-ad038ce2bcc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 792k/792k [00:00<00:00, 15.0MB/s]\n",
      "Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2.20k/2.20k [00:00<00:00, 1.22MB/s]\n",
      "Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2.54k/2.54k [00:00<00:00, 2.08MB/s]\n"
     ]
    }
   ],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(\"google/flan-t5-xl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e447d9f8-d405-4d8b-9059-ef1551c0b2ea",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1.44k/1.44k [00:00<00:00, 1.00MB/s]\n",
      "Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 50.8k/50.8k [00:00<00:00, 922kB/s]\n",
      "Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9.45G/9.45G [01:55<00:00, 81.6MB/s]\n",
      "Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1.95G/1.95G [00:22<00:00, 87.6MB/s]\n"
     ]
    }
   ],
   "source": [
    "model = T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-xl\", device_map='auto')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "198ee2e4-348e-45df-b7d3-2c6058064915",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tue Nov 29 19:13:21 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  NVIDIA A100-SXM...  Off  | 00000000:00:05.0 Off |                    0 |\n",
      "| N/A   35C    P0    72W / 400W |  11551MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  NVIDIA A100-SXM...  Off  | 00000000:00:06.0 Off |                    0 |\n",
      "| N/A   34C    P0    73W / 400W |  12029MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  NVIDIA A100-SXM...  Off  | 00000000:00:07.0 Off |                    0 |\n",
      "| N/A   58C    P0   335W / 400W |  77349MiB / 81920MiB |     87%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   3  NVIDIA A100-SXM...  Off  | 00000000:00:08.0 Off |                    0 |\n",
      "| N/A   47C    P0   144W / 400W |  66969MiB / 81920MiB |     53%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   4  NVIDIA A100-SXM...  Off  | 00000000:80:00.0 Off |                    0 |\n",
      "| N/A   35C    P0    77W / 400W |  11869MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   5  NVIDIA A100-SXM...  Off  | 00000000:80:01.0 Off |                    0 |\n",
      "| N/A   35C    P0    72W / 400W |  11379MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   6  NVIDIA A100-SXM...  Off  | 00000000:80:02.0 Off |                    0 |\n",
      "| N/A   43C    P0   337W / 400W |  66969MiB / 81920MiB |     45%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   7  NVIDIA A100-SXM...  Off  | 00000000:80:03.0 Off |                    0 |\n",
      "| N/A   37C    P0    78W / 400W |  67111MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|    0   N/A  N/A      4329      C   ...da/envs/llm/bin/python3.8    11549MiB |\n",
      "|    1   N/A  N/A      4329      C   ...da/envs/llm/bin/python3.8    12027MiB |\n",
      "|    2   N/A  N/A     71116      C   python                          77347MiB |\n",
      "|    3   N/A  N/A     77301      C   python                          66967MiB |\n",
      "|    4   N/A  N/A      4329      C   ...da/envs/llm/bin/python3.8    11867MiB |\n",
      "|    5   N/A  N/A      4329      C   ...da/envs/llm/bin/python3.8    11377MiB |\n",
      "|    6   N/A  N/A     83828      C   python                          66967MiB |\n",
      "|    7   N/A  N/A     86404      C   python                          67109MiB |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05609bf0-1499-49a2-8b52-8ec0a7017824",
   "metadata": {},
   "source": [
    "# Test Longest Training Question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "760f0965-debb-4110-ada9-534b0169378d",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_lens = [len(d['question']) for d in gsm8k['train']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "772c8b1d-ccfd-4867-bad8-a4a5c9000120",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([217., 300., 361., 414., 527., 985.])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.percentile(q_lens, [50, 80, 90, 95, 99, 100])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6b821a3f-2612-47aa-a94c-4aa0aba92c2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3331"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.argmax(q_lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bfa5c251-995e-4b8a-969b-42eced0be123",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_text = gsm8k['train'][3331]['question']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ab1f36a8-93af-4a96-9814-ec9043be96ec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Let x be the number of plates removed from the box.\\nHasan figured out the movers' weight limit was 20 pounds. Since a pound is equal to 16 ounces, each box can hold 20 * 16, or <<20*16=320>>320 ounces.\\nEach plate weighs 10 ounces, so the weight of the plates in the box after Hasan removes enough plates to satisfy the movers' weight limit is (38 - x) * 10 ounces\\nSince these two values are equal, we can write the equation (38 - x) * 10 = 320.\\nDividing both sides by 10 leaves 38-x = 32.\\nAdding x to both sides gives 38 – x + x = 32 +x, or, 38 = 32 +x\\nSubtracting 32 from both sides gives the value of x, which is the number of plates removed from the box, 38 -32 = 32 + x – 32, or, 6 = x\\n#### 6\""
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gsm8k['train'][3331]['answer']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c82dcfba-72a9-4937-8bfa-7a6fb015f37a",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_complex = open('../lib_prompt/prompt_complex.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4736eb37-ce7a-40a2-8a29-08b8b074c63f",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_q = prompt_complex + '\\nQuestion: ' + input_text + '\\n'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6da9c681-6eae-4412-a935-bcab048f8633",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?\n",
      "Let's think step by step\n",
      "Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.\n",
      "For the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.\n",
      "Angelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.\n",
      "However, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.\n",
      "They also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.\n",
      "And they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.\n",
      "So Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.\n",
      "They want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75\n",
      "They will need to plan to study 4 days to allow for all the time they need.\n",
      "The answer is 4\n",
      "\n",
      "Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws.  Their opponents score double the 2 pointers but half the 3 pointers and free throws.  What's the total number of points scored by both teams added together?\n",
      "Let's think step by step\n",
      "Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.\n",
      "His team also scores 6 3 pointers, meaning they scored 8*3= 24 points in 3 pointers\n",
      "They scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.\n",
      "All together his team scored 50+24+10= 84 points\n",
      "Mark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.\n",
      "His opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.\n",
      "They also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.\n",
      "All together Mark's opponents scored 100+12+5=117 points\n",
      "The total score for the game is both team's scores added together, so it is 84+117=201 points\n",
      "The answer is 201\n",
      "\n",
      "Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?\n",
      "Let's think step by step\n",
      "When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24\n",
      "The total number of marbles she'll have is 60+24 = 84\n",
      "If Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.\n",
      "If Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.\n",
      "The total number of frisbees she'll have will increase to 30+12 = 42\n",
      "Bella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards\n",
      "If she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.\n",
      "The total number of deck cards she'll have is 10+4 = 14\n",
      "Together, Bella will have a total of 14+42+84 = 140 items\n",
      "The answer is 140\n",
      "\n",
      "Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?\n",
      "Let's think step by step\n",
      "For the first three baskets, the number of apples and oranges in one basket is 9+15=24\n",
      "In total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.\n",
      "Since there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.\n",
      "The number of apples in the fourth basket is 9-2=7\n",
      "There are also 15-2=13 oranges in the fourth basket\n",
      "The combined number of oranges and apples in the fourth basket is 13+7=20\n",
      "The fourth basket also contains 14-2=12 bananas.\n",
      "In total, the fourth basket has 20+12=32 fruits.\n",
      "The four baskets together have 32+114=146 fruits.\n",
      "The answer is 146\n",
      "\n",
      "Question: You can buy 4 apples or 1 watermelon for the same price. You bought 36 fruits evenly split between oranges, apples and watermelons, and the price of 1 orange is $0.50. How much does 1 apple cost if your total bill was $66?\n",
      "Let's think step by step\n",
      "If 36 fruits were evenly split between 3 types of fruits, then I bought 36/3 = 12 units of each fruit\n",
      "If 1 orange costs $0.50 then 12 oranges will cost $0.50 * 12 = $6\n",
      "If my total bill was $66 and I spent $6 on oranges then I spent $66 - $6 = $60 on the other 2 fruit types.\n",
      "Assuming the price of watermelon is W, and knowing that you can buy 4 apples for the same price and that the price of one apple is A, then 1W=4A\n",
      "If we know we bought 12 watermelons and 12 apples for $60, then we know that $60 = 12W + 12A\n",
      "Knowing that 1W=4A, then we can convert the above to $60 = 12(4A) + 12A\n",
      "$60 = 48A + 12A\n",
      "$60 = 60A\n",
      "Then we know the price of one apple (A) is $60/60= $1\n",
      "The answer is 1\n",
      "\n",
      "Question: Susy goes to a large school with 800 students, while Sarah goes to a smaller school with only 300 students.  At the start of the school year, Susy had 100 social media followers.  She gained 40 new followers in the first week of the school year, half that in the second week, and half of that in the third week.  Sarah only had 50 social media followers at the start of the year, but she gained 90 new followers the first week, a third of that in the second week, and a third of that in the third week.  After three weeks, how many social media followers did the girl with the most total followers have?\n",
      "Let's think step by step\n",
      "After one week, Susy has 100+40 = 140 followers.\n",
      "In the second week, Susy gains 40/2 = 20 new followers.\n",
      "In the third week, Susy gains 20/2 = 10 new followers.\n",
      "In total, Susy finishes the three weeks with 140+20+10 = 170 total followers.\n",
      "After one week, Sarah has 50+90 = 140 followers.\n",
      "After the second week, Sarah gains 90/3 = 30 followers.\n",
      "After the third week, Sarah gains 30/3 = 10 followers.\n",
      "So, Sarah finishes the three weeks with 140+30+10 = 180 total followers.\n",
      "Thus, Sarah is the girl with the most total followers with a total of 180.\n",
      "The answer is 180\n",
      "\n",
      "Question: Sam bought a dozen boxes, each with 30 highlighter pens inside, for $10 each box. He rearranged five of these boxes into packages of six highlighters each and sold them for $3 per package. He sold the rest of the highlighters separately at the rate of three pens for $2. How much profit did he make in total, in dollars?\n",
      "Let's think step by step\n",
      "Sam bought 12 boxes x $10 = $120 worth of highlighters.\n",
      "He bought 12 * 30 = 360 highlighters in total.\n",
      "Sam then took 5 boxes × 6 highlighters/box = 30 highlighters.\n",
      "He sold these boxes for 5 * $3 = $15\n",
      "After selling these 5 boxes there were 360 - 30 = 330 highlighters remaining.\n",
      "These form 330 / 3 = 110 groups of three pens.\n",
      "He sold each of these groups for $2 each, so made 110 * 2 = $220 from them.\n",
      "In total, then, he earned $220 + $15 = $235.\n",
      "Since his original cost was $120, he earned $235 - $120 = $115 in profit.\n",
      "The answer is 115\n",
      "\n",
      "Question: In a certain school, 2/3 of the male students like to play basketball, but only 1/5 of the female students like to play basketball. What percent of the population of the school do not like to play basketball if the ratio of the male to female students is 3:2 and there are 1000 students?\n",
      "Let's think step by step\n",
      "The students are divided into 3 + 2 = 5 parts where 3 parts are for males and 2 parts are for females.\n",
      "Each part represents 1000/5 = 200 students.\n",
      "So, there are 3 x 200 = 600 males.\n",
      "And there are 2 x 200 = 400 females.\n",
      "Hence, 600 x 2/3 = 400 males play basketball.\n",
      "And 400 x 1/5 = 80 females play basketball.\n",
      "A total of 400 + 80 = 480 students play basketball.\n",
      "Therefore, 1000 - 480 = 520 do not like to play basketball.\n",
      "The percentage of the school that do not like to play basketball is 520/1000 * 100 = 52\n",
      "The answer is 52\n",
      "\n",
      "Question: Hasan is packing up his apartment because he’s moving across the country for a new job. He needs to ship several boxes to his new home. The movers have asked that Hasan avoid putting more than a certain weight in pounds in any cardboard box. The moving company has helpfully provided Hasan with a digital scale that will alert him if a package is too heavy. Hasan is in the kitchen, and he fills a cardboard box with 38 dinner plates. When he checks the box, the scale reports his box is too heavy. Hasan knows each of his plates weighs 10 ounces. He removes a single plate from the box and checks the movers’ scale again. The scale reports his box is still too heavy. Hasan repeats the process again and again. When he has removed enough plates, the movers’ scale shows the box is now an acceptable weight for shipping. Hasan deduces that each shipping box can hold 20 pounds before the scale says the box is too heavy.  How many plates did Hasan need to remove from the shipping box?\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(prompt_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0007d19c-4840-4ba8-a0d3-8c4276adc6dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (2422 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 2422])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids\n",
    "input_ids.to(\"cuda:0\")\n",
    "input_ids.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3495e8f1-2b43-475d-aaf8-1f059c90c2b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<pad> Let's think step by step The boy gives 12 / 3 = 4 oranges to his brother. He has 12 - 4 = 8 oranges left. The boy gives 8 / 4 = 2 oranges to his friend. The answer is 2</s>\n"
     ]
    }
   ],
   "source": [
    "outputs = model.generate(input_ids, max_length=256)\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a6ebd1b9-74c7-4f7e-a6c5-7641d4d88936",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shared.weight cuda:0\n",
      "encoder.block.0.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.0.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.0.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.0.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight cuda:0\n",
      "encoder.block.0.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.0.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.0.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.0.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.0.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.1.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.1.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.1.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.1.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.1.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.1.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.1.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.1.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.1.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.2.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.2.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.2.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.2.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.2.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.2.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.2.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.2.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.2.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.3.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.3.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.3.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.3.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.3.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.3.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.3.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.3.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.3.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.4.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.4.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.4.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.4.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.4.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.4.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.4.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.4.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.4.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.5.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.5.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.5.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.5.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.5.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.5.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.5.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.5.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.5.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.6.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.6.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.6.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.6.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.6.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.6.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.6.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.6.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.6.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.7.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.7.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.7.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.7.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.7.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.7.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.7.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.7.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.7.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.8.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.8.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.8.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.8.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.8.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.8.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.8.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.8.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.8.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.9.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.9.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.9.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.9.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.9.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.9.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.9.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.9.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.9.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.10.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.10.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.10.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.10.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.10.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.10.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.10.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.10.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.10.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.11.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.11.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.11.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.11.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.11.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.11.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.11.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.11.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.11.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.12.layer.0.SelfAttention.q.weight cuda:0\n",
      "encoder.block.12.layer.0.SelfAttention.k.weight cuda:0\n",
      "encoder.block.12.layer.0.SelfAttention.v.weight cuda:0\n",
      "encoder.block.12.layer.0.SelfAttention.o.weight cuda:0\n",
      "encoder.block.12.layer.0.layer_norm.weight cuda:0\n",
      "encoder.block.12.layer.1.DenseReluDense.wi_0.weight cuda:0\n",
      "encoder.block.12.layer.1.DenseReluDense.wi_1.weight cuda:0\n",
      "encoder.block.12.layer.1.DenseReluDense.wo.weight cuda:0\n",
      "encoder.block.12.layer.1.layer_norm.weight cuda:0\n",
      "encoder.block.13.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.13.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.13.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.13.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.13.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.13.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.13.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.13.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.13.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.14.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.14.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.14.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.14.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.14.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.14.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.14.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.14.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.14.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.15.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.15.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.15.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.15.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.15.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.15.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.15.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.15.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.15.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.16.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.16.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.16.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.16.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.16.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.16.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.16.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.16.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.16.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.17.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.17.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.17.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.17.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.17.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.17.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.17.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.17.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.17.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.18.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.18.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.18.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.18.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.18.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.18.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.18.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.18.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.18.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.19.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.19.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.19.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.19.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.19.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.19.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.19.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.19.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.19.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.20.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.20.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.20.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.20.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.20.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.20.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.20.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.20.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.20.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.21.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.21.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.21.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.21.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.21.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.21.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.21.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.21.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.21.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.22.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.22.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.22.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.22.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.22.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.22.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.22.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.22.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.22.layer.1.layer_norm.weight cuda:1\n",
      "encoder.block.23.layer.0.SelfAttention.q.weight cuda:1\n",
      "encoder.block.23.layer.0.SelfAttention.k.weight cuda:1\n",
      "encoder.block.23.layer.0.SelfAttention.v.weight cuda:1\n",
      "encoder.block.23.layer.0.SelfAttention.o.weight cuda:1\n",
      "encoder.block.23.layer.0.layer_norm.weight cuda:1\n",
      "encoder.block.23.layer.1.DenseReluDense.wi_0.weight cuda:1\n",
      "encoder.block.23.layer.1.DenseReluDense.wi_1.weight cuda:1\n",
      "encoder.block.23.layer.1.DenseReluDense.wo.weight cuda:1\n",
      "encoder.block.23.layer.1.layer_norm.weight cuda:1\n",
      "encoder.final_layer_norm.weight cuda:1\n",
      "decoder.block.0.layer.0.SelfAttention.q.weight cuda:1\n",
      "decoder.block.0.layer.0.SelfAttention.k.weight cuda:1\n",
      "decoder.block.0.layer.0.SelfAttention.v.weight cuda:1\n",
      "decoder.block.0.layer.0.SelfAttention.o.weight cuda:1\n",
      "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight cuda:1\n",
      "decoder.block.0.layer.0.layer_norm.weight cuda:1\n",
      "decoder.block.0.layer.1.EncDecAttention.q.weight cuda:1\n",
      "decoder.block.0.layer.1.EncDecAttention.k.weight cuda:1\n",
      "decoder.block.0.layer.1.EncDecAttention.v.weight cuda:1\n",
      "decoder.block.0.layer.1.EncDecAttention.o.weight cuda:1\n",
      "decoder.block.0.layer.1.layer_norm.weight cuda:1\n",
      "decoder.block.0.layer.2.DenseReluDense.wi_0.weight cuda:1\n",
      "decoder.block.0.layer.2.DenseReluDense.wi_1.weight cuda:1\n",
      "decoder.block.0.layer.2.DenseReluDense.wo.weight cuda:1\n",
      "decoder.block.0.layer.2.layer_norm.weight cuda:1\n",
      "decoder.block.1.layer.0.SelfAttention.q.weight cuda:1\n",
      "decoder.block.1.layer.0.SelfAttention.k.weight cuda:1\n",
      "decoder.block.1.layer.0.SelfAttention.v.weight cuda:1\n",
      "decoder.block.1.layer.0.SelfAttention.o.weight cuda:1\n",
      "decoder.block.1.layer.0.layer_norm.weight cuda:1\n",
      "decoder.block.1.layer.1.EncDecAttention.q.weight cuda:1\n",
      "decoder.block.1.layer.1.EncDecAttention.k.weight cuda:1\n",
      "decoder.block.1.layer.1.EncDecAttention.v.weight cuda:1\n",
      "decoder.block.1.layer.1.EncDecAttention.o.weight cuda:1\n",
      "decoder.block.1.layer.1.layer_norm.weight cuda:1\n",
      "decoder.block.1.layer.2.DenseReluDense.wi_0.weight cuda:1\n",
      "decoder.block.1.layer.2.DenseReluDense.wi_1.weight cuda:1\n",
      "decoder.block.1.layer.2.DenseReluDense.wo.weight cuda:1\n",
      "decoder.block.1.layer.2.layer_norm.weight cuda:1\n",
      "decoder.block.2.layer.0.SelfAttention.q.weight cuda:1\n",
      "decoder.block.2.layer.0.SelfAttention.k.weight cuda:1\n",
      "decoder.block.2.layer.0.SelfAttention.v.weight cuda:1\n",
      "decoder.block.2.layer.0.SelfAttention.o.weight cuda:1\n",
      "decoder.block.2.layer.0.layer_norm.weight cuda:1\n",
      "decoder.block.2.layer.1.EncDecAttention.q.weight cuda:1\n",
      "decoder.block.2.layer.1.EncDecAttention.k.weight cuda:1\n",
      "decoder.block.2.layer.1.EncDecAttention.v.weight cuda:1\n",
      "decoder.block.2.layer.1.EncDecAttention.o.weight cuda:1\n",
      "decoder.block.2.layer.1.layer_norm.weight cuda:1\n",
      "decoder.block.2.layer.2.DenseReluDense.wi_0.weight cuda:1\n",
      "decoder.block.2.layer.2.DenseReluDense.wi_1.weight cuda:1\n",
      "decoder.block.2.layer.2.DenseReluDense.wo.weight cuda:1\n",
      "decoder.block.2.layer.2.layer_norm.weight cuda:1\n",
      "decoder.block.3.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.3.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.3.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.3.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.3.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.3.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.3.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.3.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.3.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.3.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.3.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.3.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.3.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.3.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.4.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.4.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.4.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.4.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.4.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.4.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.4.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.4.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.4.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.4.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.4.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.4.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.4.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.4.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.5.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.5.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.5.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.5.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.5.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.5.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.5.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.5.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.5.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.5.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.5.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.5.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.5.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.5.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.6.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.6.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.6.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.6.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.6.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.6.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.6.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.6.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.6.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.6.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.6.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.6.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.6.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.6.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.7.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.7.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.7.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.7.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.7.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.7.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.7.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.7.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.7.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.7.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.7.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.7.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.7.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.7.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.8.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.8.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.8.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.8.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.8.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.8.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.8.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.8.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.8.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.8.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.8.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.8.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.8.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.8.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.9.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.9.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.9.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.9.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.9.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.9.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.9.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.9.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.9.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.9.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.9.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.9.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.9.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.9.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.10.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.10.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.10.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.10.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.10.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.10.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.10.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.10.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.10.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.10.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.10.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.10.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.10.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.10.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.11.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.11.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.11.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.11.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.11.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.11.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.11.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.11.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.11.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.11.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.11.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.11.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.11.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.11.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.12.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.12.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.12.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.12.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.12.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.12.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.12.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.12.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.12.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.12.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.12.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.12.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.12.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.12.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.13.layer.0.SelfAttention.q.weight cuda:2\n",
      "decoder.block.13.layer.0.SelfAttention.k.weight cuda:2\n",
      "decoder.block.13.layer.0.SelfAttention.v.weight cuda:2\n",
      "decoder.block.13.layer.0.SelfAttention.o.weight cuda:2\n",
      "decoder.block.13.layer.0.layer_norm.weight cuda:2\n",
      "decoder.block.13.layer.1.EncDecAttention.q.weight cuda:2\n",
      "decoder.block.13.layer.1.EncDecAttention.k.weight cuda:2\n",
      "decoder.block.13.layer.1.EncDecAttention.v.weight cuda:2\n",
      "decoder.block.13.layer.1.EncDecAttention.o.weight cuda:2\n",
      "decoder.block.13.layer.1.layer_norm.weight cuda:2\n",
      "decoder.block.13.layer.2.DenseReluDense.wi_0.weight cuda:2\n",
      "decoder.block.13.layer.2.DenseReluDense.wi_1.weight cuda:2\n",
      "decoder.block.13.layer.2.DenseReluDense.wo.weight cuda:2\n",
      "decoder.block.13.layer.2.layer_norm.weight cuda:2\n",
      "decoder.block.14.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.14.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.14.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.14.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.14.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.14.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.14.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.14.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.14.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.14.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.14.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.14.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.14.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.14.layer.2.layer_norm.weight cuda:3\n",
      "decoder.block.15.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.15.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.15.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.15.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.15.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.15.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.15.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.15.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.15.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.15.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.15.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.15.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.15.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.15.layer.2.layer_norm.weight cuda:3\n",
      "decoder.block.16.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.16.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.16.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.16.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.16.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.16.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.16.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.16.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.16.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.16.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.16.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.16.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.16.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.16.layer.2.layer_norm.weight cuda:3\n",
      "decoder.block.17.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.17.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.17.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.17.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.17.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.17.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.17.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.17.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.17.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.17.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.17.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.17.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.17.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.17.layer.2.layer_norm.weight cuda:3\n",
      "decoder.block.18.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.18.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.18.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.18.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.18.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.18.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.18.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.18.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.18.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.18.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.18.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.18.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.18.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.18.layer.2.layer_norm.weight cuda:3\n",
      "decoder.block.19.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.19.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.19.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.19.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.19.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.19.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.19.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.19.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.19.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.19.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.19.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.19.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.19.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.19.layer.2.layer_norm.weight cuda:3\n",
      "decoder.block.20.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.20.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.20.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.20.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.20.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.20.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.20.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.20.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.20.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.20.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.20.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.20.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.20.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.20.layer.2.layer_norm.weight cuda:3\n",
      "decoder.block.21.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.21.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.21.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.21.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.21.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.21.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.21.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.21.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.21.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.21.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.21.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.21.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.21.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.21.layer.2.layer_norm.weight cuda:3\n",
      "decoder.block.22.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.22.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.22.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.22.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.22.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.22.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.22.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.22.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.22.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.22.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.22.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.22.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.22.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.22.layer.2.layer_norm.weight cuda:3\n",
      "decoder.block.23.layer.0.SelfAttention.q.weight cuda:3\n",
      "decoder.block.23.layer.0.SelfAttention.k.weight cuda:3\n",
      "decoder.block.23.layer.0.SelfAttention.v.weight cuda:3\n",
      "decoder.block.23.layer.0.SelfAttention.o.weight cuda:3\n",
      "decoder.block.23.layer.0.layer_norm.weight cuda:3\n",
      "decoder.block.23.layer.1.EncDecAttention.q.weight cuda:3\n",
      "decoder.block.23.layer.1.EncDecAttention.k.weight cuda:3\n",
      "decoder.block.23.layer.1.EncDecAttention.v.weight cuda:3\n",
      "decoder.block.23.layer.1.EncDecAttention.o.weight cuda:3\n",
      "decoder.block.23.layer.1.layer_norm.weight cuda:3\n",
      "decoder.block.23.layer.2.DenseReluDense.wi_0.weight cuda:3\n",
      "decoder.block.23.layer.2.DenseReluDense.wi_1.weight cuda:3\n",
      "decoder.block.23.layer.2.DenseReluDense.wo.weight cuda:3\n",
      "decoder.block.23.layer.2.layer_norm.weight cuda:3\n",
      "decoder.final_layer_norm.weight cuda:3\n",
      "lm_head.weight cuda:3\n"
     ]
    }
   ],
   "source": [
    "for n, m in model.named_parameters():\n",
    "    print(n, m.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "993a23b0-d5d9-4e51-848d-f8c2a4c53dd4",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/llm/lib/python3.8/site-packages/transformers/generation_utils.py:1442: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn [11], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 2\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m      3\u001b[0m \u001b[43m                             \u001b[49m\u001b[43mdo_sample\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m      4\u001b[0m \u001b[43m                             \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m256\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m      5\u001b[0m \u001b[43m                             \u001b[49m\u001b[43moutput_scores\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m      6\u001b[0m \u001b[43m                             \u001b[49m\u001b[43mreturn_dict_in_generate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m      7\u001b[0m \u001b[43m                             \u001b[49m\u001b[43mnum_return_sequences\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\n\u001b[1;32m      8\u001b[0m \u001b[43m                            \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27\u001b[0m, in \u001b[0;36m_DecoratorContextManager.__call__.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m     25\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m     26\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclone():\n\u001b[0;32m---> 27\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/transformers/generation_utils.py:1543\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, max_length, min_length, do_sample, early_stopping, num_beams, temperature, penalty_alpha, top_k, top_p, typical_p, repetition_penalty, bad_words_ids, force_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, logits_processor, renormalize_logits, stopping_criteria, constraints, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, exponential_decay_length_penalty, suppress_tokens, begin_suppress_tokens, forced_decoder_ids, **model_kwargs)\u001b[0m\n\u001b[1;32m   1535\u001b[0m     input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[1;32m   1536\u001b[0m         input_ids,\n\u001b[1;32m   1537\u001b[0m         expand_size\u001b[38;5;241m=\u001b[39mnum_return_sequences,\n\u001b[1;32m   1538\u001b[0m         is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m   1539\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[1;32m   1540\u001b[0m     )\n\u001b[1;32m   1542\u001b[0m     \u001b[38;5;66;03m# 12. run sample\u001b[39;00m\n\u001b[0;32m-> 1543\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1544\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1545\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1546\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlogits_warper\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogits_warper\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1547\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1548\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpad_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpad_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1549\u001b[0m \u001b[43m        \u001b[49m\u001b[43meos_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meos_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1550\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_scores\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_scores\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1551\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_dict_in_generate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict_in_generate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1552\u001b[0m \u001b[43m        \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1553\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1554\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1556\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m is_beam_gen_mode:\n\u001b[1;32m   1557\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m num_return_sequences \u001b[38;5;241m>\u001b[39m num_beams:\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/transformers/generation_utils.py:2482\u001b[0m, in \u001b[0;36mGenerationMixin.sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)\u001b[0m\n\u001b[1;32m   2479\u001b[0m model_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprepare_inputs_for_generation(input_ids, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs)\n\u001b[1;32m   2481\u001b[0m \u001b[38;5;66;03m# forward pass to get next token\u001b[39;00m\n\u001b[0;32m-> 2482\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2483\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2484\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   2485\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2486\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2487\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2489\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m synced_gpus \u001b[38;5;129;01mand\u001b[39;00m this_peer_finished:\n\u001b[1;32m   2490\u001b[0m     \u001b[38;5;28;01mcontinue\u001b[39;00m  \u001b[38;5;66;03m# don't waste resources running the code we don't need\u001b[39;00m\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/hooks.py:156\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    154\u001b[0m         output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 156\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py:1648\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1645\u001b[0m         decoder_attention_mask \u001b[38;5;241m=\u001b[39m decoder_attention_mask\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoder\u001b[38;5;241m.\u001b[39mfirst_device)\n\u001b[1;32m   1647\u001b[0m \u001b[38;5;66;03m# Decode\u001b[39;00m\n\u001b[0;32m-> 1648\u001b[0m decoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1649\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_input_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1650\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1651\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_inputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1652\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1653\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1654\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1655\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1656\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1657\u001b[0m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1658\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1659\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1660\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1661\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1663\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m decoder_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m   1665\u001b[0m \u001b[38;5;66;03m# Set device for model parallelism\u001b[39;00m\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py:1040\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1027\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m checkpoint(\n\u001b[1;32m   1028\u001b[0m         create_custom_forward(layer_module),\n\u001b[1;32m   1029\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1037\u001b[0m         \u001b[38;5;28;01mNone\u001b[39;00m,  \u001b[38;5;66;03m# past_key_value is always None with gradient checkpointing\u001b[39;00m\n\u001b[1;32m   1038\u001b[0m     )\n\u001b[1;32m   1039\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1040\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mlayer_module\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1041\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1042\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1043\u001b[0m \u001b[43m        \u001b[49m\u001b[43mposition_bias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1044\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1045\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_extended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1046\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_decoder_position_bias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_decoder_position_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1047\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlayer_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1048\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcross_attn_layer_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attn_layer_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1049\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1050\u001b[0m \u001b[43m        \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1051\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1052\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1054\u001b[0m \u001b[38;5;66;03m# layer_outputs is a tuple with:\u001b[39;00m\n\u001b[1;32m   1055\u001b[0m \u001b[38;5;66;03m# hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)\u001b[39;00m\n\u001b[1;32m   1056\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m:\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/hooks.py:151\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    149\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(old_forward)\n\u001b[1;32m    150\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 151\u001b[0m     args, kwargs \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_hf_hook\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpre_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    152\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mno_grad:\n\u001b[1;32m    153\u001b[0m         \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/hooks.py:266\u001b[0m, in \u001b[0;36mAlignDevicesHook.pre_forward\u001b[0;34m(self, module, *args, **kwargs)\u001b[0m\n\u001b[1;32m    261\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m name, _ \u001b[38;5;129;01min\u001b[39;00m named_module_tensors(\n\u001b[1;32m    262\u001b[0m         module, include_buffers\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moffload_buffers, recurse\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mplace_submodules\n\u001b[1;32m    263\u001b[0m     ):\n\u001b[1;32m    264\u001b[0m         set_module_tensor_to_device(module, name, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexecution_device, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mweights_map[name])\n\u001b[0;32m--> 266\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m send_to_device(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexecution_device), \u001b[43msend_to_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecution_device\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/utils/operations.py:130\u001b[0m, in \u001b[0;36msend_to_device\u001b[0;34m(tensor, device, non_blocking)\u001b[0m\n\u001b[1;32m    127\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_has_to_method\u001b[39m(t):\n\u001b[1;32m    128\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(t, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrecursively_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_send_to_device\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_has_to_method\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/utils/operations.py:90\u001b[0m, in \u001b[0;36mrecursively_apply\u001b[0;34m(func, data, test_type, error_on_other_type, *args, **kwargs)\u001b[0m\n\u001b[1;32m     79\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m honor_type(\n\u001b[1;32m     80\u001b[0m         data,\n\u001b[1;32m     81\u001b[0m         (\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     86\u001b[0m         ),\n\u001b[1;32m     87\u001b[0m     )\n\u001b[1;32m     88\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, Mapping):\n\u001b[1;32m     89\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[0;32m---> 90\u001b[0m         {\n\u001b[1;32m     91\u001b[0m             k: recursively_apply(\n\u001b[1;32m     92\u001b[0m                 func, v, \u001b[38;5;241m*\u001b[39margs, test_type\u001b[38;5;241m=\u001b[39mtest_type, error_on_other_type\u001b[38;5;241m=\u001b[39merror_on_other_type, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m     93\u001b[0m             )\n\u001b[1;32m     94\u001b[0m             \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m data\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m     95\u001b[0m         }\n\u001b[1;32m     96\u001b[0m     )\n\u001b[1;32m     97\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m test_type(data):\n\u001b[1;32m     98\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m func(data, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/utils/operations.py:91\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     79\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m honor_type(\n\u001b[1;32m     80\u001b[0m         data,\n\u001b[1;32m     81\u001b[0m         (\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     86\u001b[0m         ),\n\u001b[1;32m     87\u001b[0m     )\n\u001b[1;32m     88\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, Mapping):\n\u001b[1;32m     89\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[1;32m     90\u001b[0m         {\n\u001b[0;32m---> 91\u001b[0m             k: \u001b[43mrecursively_apply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     92\u001b[0m \u001b[43m                \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merror_on_other_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merror_on_other_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m     93\u001b[0m \u001b[43m            \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     94\u001b[0m             \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m data\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m     95\u001b[0m         }\n\u001b[1;32m     96\u001b[0m     )\n\u001b[1;32m     97\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m test_type(data):\n\u001b[1;32m     98\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m func(data, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/utils/operations.py:79\u001b[0m, in \u001b[0;36mrecursively_apply\u001b[0;34m(func, data, test_type, error_on_other_type, *args, **kwargs)\u001b[0m\n\u001b[1;32m     57\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m     58\u001b[0m \u001b[38;5;124;03mRecursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.\u001b[39;00m\n\u001b[1;32m     59\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     76\u001b[0m \u001b[38;5;124;03m    The same data structure as `data` with `func` applied to every object of type `main_type`.\u001b[39;00m\n\u001b[1;32m     77\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m     78\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, (\u001b[38;5;28mtuple\u001b[39m, \u001b[38;5;28mlist\u001b[39m)):\n\u001b[0;32m---> 79\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mhonor_type\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     80\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     81\u001b[0m \u001b[43m        \u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     82\u001b[0m \u001b[43m            \u001b[49m\u001b[43mrecursively_apply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     83\u001b[0m \u001b[43m                \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mo\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merror_on_other_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merror_on_other_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m     84\u001b[0m \u001b[43m            \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     85\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mo\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\n\u001b[1;32m     86\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     87\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     88\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, Mapping):\n\u001b[1;32m     89\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[1;32m     90\u001b[0m         {\n\u001b[1;32m     91\u001b[0m             k: recursively_apply(\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     95\u001b[0m         }\n\u001b[1;32m     96\u001b[0m     )\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/utils/operations.py:50\u001b[0m, in \u001b[0;36mhonor_type\u001b[0;34m(obj, generator)\u001b[0m\n\u001b[1;32m     46\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m     47\u001b[0m \u001b[38;5;124;03mCast a generator to the same type as obj (list, tuple or namedtuple)\u001b[39;00m\n\u001b[1;32m     48\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m     49\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 50\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mtype\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgenerator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     51\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m     52\u001b[0m     \u001b[38;5;66;03m# Some objects may not be able to instantiate from a generator directly\u001b[39;00m\n\u001b[1;32m     53\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(obj)(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mlist\u001b[39m(generator))\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/utils/operations.py:82\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     57\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m     58\u001b[0m \u001b[38;5;124;03mRecursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.\u001b[39;00m\n\u001b[1;32m     59\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     76\u001b[0m \u001b[38;5;124;03m    The same data structure as `data` with `func` applied to every object of type `main_type`.\u001b[39;00m\n\u001b[1;32m     77\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m     78\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, (\u001b[38;5;28mtuple\u001b[39m, \u001b[38;5;28mlist\u001b[39m)):\n\u001b[1;32m     79\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m honor_type(\n\u001b[1;32m     80\u001b[0m         data,\n\u001b[1;32m     81\u001b[0m         (\n\u001b[0;32m---> 82\u001b[0m             \u001b[43mrecursively_apply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     83\u001b[0m \u001b[43m                \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mo\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merror_on_other_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merror_on_other_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m     84\u001b[0m \u001b[43m            \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     85\u001b[0m             \u001b[38;5;28;01mfor\u001b[39;00m o \u001b[38;5;129;01min\u001b[39;00m data\n\u001b[1;32m     86\u001b[0m         ),\n\u001b[1;32m     87\u001b[0m     )\n\u001b[1;32m     88\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, Mapping):\n\u001b[1;32m     89\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[1;32m     90\u001b[0m         {\n\u001b[1;32m     91\u001b[0m             k: recursively_apply(\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     95\u001b[0m         }\n\u001b[1;32m     96\u001b[0m     )\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/utils/operations.py:98\u001b[0m, in \u001b[0;36mrecursively_apply\u001b[0;34m(func, data, test_type, error_on_other_type, *args, **kwargs)\u001b[0m\n\u001b[1;32m     89\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[1;32m     90\u001b[0m         {\n\u001b[1;32m     91\u001b[0m             k: recursively_apply(\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     95\u001b[0m         }\n\u001b[1;32m     96\u001b[0m     )\n\u001b[1;32m     97\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m test_type(data):\n\u001b[0;32m---> 98\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     99\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m error_on_other_type:\n\u001b[1;32m    100\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m    101\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt apply \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m on object of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(data)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, only of nested list/tuple/dicts of objects \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    102\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthat satisfy \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_type\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    103\u001b[0m     )\n",
      "File \u001b[0;32m/opt/conda/envs/llm/lib/python3.8/site-packages/accelerate/utils/operations.py:123\u001b[0m, in \u001b[0;36msend_to_device.<locals>._send_to_device\u001b[0;34m(t, device, non_blocking)\u001b[0m\n\u001b[1;32m    121\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_send_to_device\u001b[39m(t, device, non_blocking):\n\u001b[1;32m    122\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 123\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnon_blocking\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    124\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:  \u001b[38;5;66;03m# .to() doesn't accept non_blocking as kwarg\u001b[39;00m\n\u001b[1;32m    125\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(device)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    outputs = model.generate(input_ids, \n",
    "                             do_sample=True, \n",
    "                             max_new_tokens=256, \n",
    "                             output_scores=True, \n",
    "                             return_dict_in_generate=True,\n",
    "                             num_return_sequences=10\n",
    "                            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "6bada880-c79b-4347-a4cc-3cddc4a09182",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['sequences', 'scores'])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "babd7c6d-aaa2-4d00-8b17-08d8bd4b3084",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(outputs['sequences'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de46660-fc68-4378-86fb-86fd7e5d28cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(outputs['sequences'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c63b46ce-fc96-4c02-a206-dc947bbc6d58",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(outputs['sequences'][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3cedf64-de51-49db-a5b4-d9055fa395fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs['sequences'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c7f47ae1-809c-4e43-8431-58908f9c0616",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<pad> Let's think step by step Each box has a certain maximum weight in pounds, which is known to be 20 pounds. Hasan has discovered that the weight of 38 dinner plates is 10 oz / plate * 38 plates = 4000 oz. Therefore, Hasan needs to remove 20 pounds - 4000 oz - 20 pounds = 2 pounds of plates from the box. The answer is 2.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n"
     ]
    }
   ],
   "source": [
    "print(tokenizer.decode(outputs['sequences'][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "25194fe5-72e7-4144-ab65-43f7cefd69b9",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<pad> Let's think step by step He gives one - third, 12 / 3 = 4 oranges, to his brother. He has 12 - 4 = 8 oranges now. He gives one - fourth, 8 / 4 = 2 oranges, to his friend. The answer is 2</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n"
     ]
    }
   ],
   "source": [
    "print(tokenizer.decode(outputs['sequences'][1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "36d3f9bd-c001-4fbc-ba62-c5dfce827236",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<pad> Let's think step by step After giving his brother a third of the oranges, the boy has 12 - 3 = 6 oranges. One - fourth of the remaining oranges is 12 / 4 = 3 oranges. The answer is 3.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n"
     ]
    }
   ],
   "source": [
    "print(tokenizer.decode(outputs['sequences'][2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "78c9ff71-18cf-4fdb-9eba-4438b6d65e86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "108"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(outputs['scores'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "577b0a5a-9548-4c6d-a0ee-a3bbf141d8f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 32128])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs['scores'][0].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "9ce01851-cd05-47c2-b3bd-a3169ecf6e5b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.topk(\n",
       "values=tensor([[9.9986e-01, 9.7907e-05, 1.1817e-05]], device='cuda:1'),\n",
       "indices=tensor([[   1, 5470, 7311]], device='cuda:1'))"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.softmax(outputs['scores'][-1], dim=-1).topk(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "5be4ddec-4934-4d70-8ef2-422c8e0cb6b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "71"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(outputs['scores'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "50a25bd7-1294-4975-a325-d1a844ca0cab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['</s>']"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.convert_ids_to_tokens([1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "d100c5e4-0db1-4d17-839d-6920e9058e48",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.topk(\n",
       "values=tensor([[9.9995e-01, 4.5979e-05, 2.2269e-06]], device='cuda:1'),\n",
       "indices=tensor([[31, 22,  3]], device='cuda:1'))"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.softmax(outputs['scores'][1], dim=-1).topk(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "e3695a1b-8012-4c38-b4c8-c076440858f5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[\"'\"]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.convert_ids_to_tokens([31])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "93ffe6ae-a670-44d3-987b-b7cdec516ca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_answer(pred_str, ans_str):\n",
    "    pattern = '\\d*\\.?\\d+'\n",
    "    pred = re.findall(pattern, pred_str)\n",
    "    if(len(pred) >= 1):\n",
    "        # print(pred_str)\n",
    "        pred = pred[-1]\n",
    "        gold = re.findall(pattern, ans_str)\n",
    "        # print(ans_str)\n",
    "        gold = gold[-1]\n",
    "        return pred == gold\n",
    "    else: return False\n",
    "\n",
    "def parse_pred_ans(filename):\n",
    "    with open(filename) as fd: lines = fd.readlines()\n",
    "    am, a = None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = 'none'\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "    for l in lines:\n",
    "        if(l.startswith('Q: ')):\n",
    "            if(am is not None and a is not None):\n",
    "                questions.append(q)\n",
    "                ans_pred.append(am)\n",
    "                ans_gold.append(a)\n",
    "                if(test_answer(am, a)):\n",
    "                    acc += 1\n",
    "            current_mode = 'q'\n",
    "            q = l\n",
    "            num_q += 1\n",
    "        elif(l.startswith('A_model:')):\n",
    "            current_mode = 'am'\n",
    "            am = l\n",
    "        elif(l.startswith('A:')):\n",
    "            current_mode = 'a'\n",
    "            a = l\n",
    "        else:\n",
    "            if(current_mode == 'q'): q += l\n",
    "            elif(current_mode == 'am'): am += l\n",
    "            elif(current_mode == 'a'): a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "                \n",
    "    questions.append(q)\n",
    "    ans_pred.append(am)\n",
    "    ans_gold.append(a)\n",
    "    if(test_answer(am, a)):\n",
    "        acc += 1\n",
    "    print('num_q %d correct %d ratio %.4f' % (num_q, acc, float(acc / num_q)))\n",
    "    return questions, ans_pred, ans_gold"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb3dbd65-9646-43cb-a7ea-053d24b52bb1",
   "metadata": {},
   "source": [
    "# Complex Prompt, Acc 14.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6bd78277-02bc-4f09-8ed4-20e662869e72",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                 | 0/200 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2230 > 512). Running this sequence through the model will result in indexing errors\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [21:08<00:00,  6.34s/it]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('../outputs/dev_flan_t5_3b_complex.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(validation_data['question'], validation_data['answer']), \n",
    "                               total=len(validation_data['question'])):\n",
    "        prompt_q = prompt_complex + '\\nQuestion: ' + q + '\\n'\n",
    "        \n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:0\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "76d24b36-2a2e-49f0-bfec-ef87fa1a115c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 200 correct 29 ratio 0.1450\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/dev_flan_t5_3b_complex.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a525ac4a-279d-4707-a84d-0be003acb327",
   "metadata": {},
   "source": [
    "# Original Prompt, Acc 13.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "078e323b-3b5b-4a81-b938-dc4f0a8b07b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_original = open('../lib_prompt/prompt_original.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d3951360-6114-4649-b209-8ddf7b54757c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [18:12<00:00,  5.46s/it]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('../outputs/dev_flan_t5_3b_original.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(validation_data['question'], validation_data['answer']), \n",
    "                               total=len(validation_data['question'])):\n",
    "        \n",
    "        prompt_q = prompt_original + '\\nQuestion: ' + q + '\\n'\n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:0\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "1a74dd77-7c85-411e-b9fd-e6c5279f7bfa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 200 correct 27 ratio 0.1350\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/dev_flan_t5_3b_original.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bf602ec-a047-447d-aa28-dc9314f4d54a",
   "metadata": {},
   "source": [
    "# Prompt Simple, Acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "cb2bbf3e-275e-4331-9171-4dad6843d66a",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_simple = open('../lib_prompt/prompt_simple.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a3edbe30-ba45-41fe-a477-90f19cdcd0a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [17:27<00:00,  5.24s/it]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('../outputs/dev_flan_t5_3b_simple.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(validation_data['question'], validation_data['answer']), \n",
    "                               total=len(validation_data['question'])):\n",
    "        \n",
    "        prompt_q = prompt_simple + '\\nQuestion: ' + q + '\\n'\n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:0\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "872dac31-7a88-4cb5-86d5-f669c8c938b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 200 correct 30 ratio 0.1500\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/dev_flan_t5_3b_simple.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b977b4c9-aadd-4211-9d58-9b5beef8bf7a",
   "metadata": {},
   "source": [
    "# Prompt Random, Acc 21.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7729bfd9-0368-4bbe-9ae5-042d676768e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_random = open('../lib_prompt/prompt_random.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8532d1da-247b-4ec9-aa49-5dd9f981eabc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [18:01<00:00,  5.41s/it]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('../outputs/dev_flan_t5_3b_random.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(validation_data['question'], validation_data['answer']), \n",
    "                               total=len(validation_data['question'])):\n",
    "        \n",
    "        prompt_q = prompt_random + '\\nQuestion: ' + q + '\\n'\n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:0\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e14c7ee5-ca0d-40a4-bd20-9a6decf397e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 200 correct 30 ratio 0.1500\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/dev_flan_t5_3b_random.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee4e7556-8de6-488a-9fb6-01fda779fada",
   "metadata": {},
   "source": [
    "# Prompt Direct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "57772098-81c2-44a5-8a46-270c4063e93b",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_direct = open('../lib_prompt/prompt_direct.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "3672ce94-ae4b-49fc-86d0-9835c6eb292a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [10:45<00:00,  3.23s/it]\n"
     ]
    }
   ],
   "source": [
    "with open('../outputs/dev_flan_t5_3b_diret.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(validation_data['question'], validation_data['answer']), \n",
    "                               total=len(validation_data['question'])):\n",
    "        \n",
    "        prompt_q = prompt_direct + '\\nQuestion: ' + q + '\\n'\n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:1\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "83aff2c6-fc38-4d82-89bc-ef8bbace6800",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\n",
      "The answer is 6.\n",
      "\n",
      "Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n",
      "The answer is 5.\n",
      "\n",
      "Question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\n",
      "The answer is 39.\n",
      "\n",
      "Question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\n",
      "The answer is 8.\n",
      "\n",
      "Question: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\n",
      "The answer is 9.\n",
      "\n",
      "Question: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\n",
      "The answer is 29.\n",
      "\n",
      "Question: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\n",
      "The answer is 33.\n",
      "\n",
      "Question: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\n",
      "The answer is 8.\n",
      "\n",
      "Question: In a week, 450 cars drove through a toll booth. Fifty vehicles went through the toll booth on Monday and the same number of vehicles drove through the toll booth on Tuesday. On each of Wednesday and Thursday, twice the number of cars that passed through the toll booth on Monday went through the toll booth. If, for the remaining of the days of the week, an equal number of vehicles passed through the toll booth, calculate the total number of cars that passed the toll both in each of the remaining days.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(prompt_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "55124e8e-6c82-4672-9411-c7144db5903f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 200 correct 9 ratio 0.0450\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/dev_flan_t5_3b_diret.txt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
