{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d9d476c3-f842-48ae-9052-54af1272fb63",
   "metadata": {},
   "source": [
    "# Try OPT 66B on GSM8K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fd7553a1-5cac-41fe-afb7-c79157d435ac",
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "import torch\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e125fe74-e77b-4235-a615-c91eeb0657df",
   "metadata": {},
   "outputs": [
   ],
   "source": [
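    "# Load GSM8K and carve a validation subset out of the training split,\n",
    "# using the pre-saved row indices in lib_prompt/validation_index.npy.\n",
    "gsm8k = load_dataset('gsm8k', 'main')\n",
    "validation_index = np.load('lib_prompt/validation_index.npy')\n",
    "validation_data = gsm8k['train'].select(validation_index)"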
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "330c26a7-c639-4a83-8381-0a937cc1b380",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-66b\", use_fast=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76dc2d00-458f-492d-83ac-46d67f9745cd",
   "metadata": {},
   "outputs": [],
   "source": [
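    "# Assumption: half precision is acceptable here; fp32 weights for 66B parameters need roughly 264 GB.\n",
    "model = AutoModelForCausalLM.from_pretrained(\"facebook/opt-66b\", torch_dtype=torch.float16, device_map=\"auto\")"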
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08fb3997-95d5-4c6e-946a-8ac8fc29c6da",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "3d2f391d-534e-4758-b465-35725249cf04",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('lib_prompt/prompt_complex.txt') as f:\n",
    "    prompt_complex = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "31b4c522-a9c5-4181-b242-228e254550c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the first validation question after the few-shot examples.\n",
    "prompt_q = prompt_complex + '\\nQuestion: ' + validation_data[0]['question'] + '\\n'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "26fd2ea0-a34a-494c-a03e-59ab3813ba18",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?\n",
      "Let's think step by step\n",
      "Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.\n",
      "For the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.\n",
      "Angelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.\n",
      "However, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.\n",
      "They also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.\n",
      "And they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.\n",
      "So Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.\n",
      "They want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75\n",
      "They will need to plan to study 4 days to allow for all the time they need.\n",
      "The answer is 4\n",
      "\n",
      "Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws.  Their opponents score double the 2 pointers but half the 3 pointers and free throws.  What's the total number of points scored by both teams added together?\n",
      "Let's think step by step\n",
      "Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.\n",
      "His team also scores 6 3 pointers, meaning they scored 8*3= 24 points in 3 pointers\n",
      "They scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.\n",
      "All together his team scored 50+24+10= 84 points\n",
      "Mark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.\n",
      "His opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.\n",
      "They also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.\n",
      "All together Mark's opponents scored 100+12+5=117 points\n",
      "The total score for the game is both team's scores added together, so it is 84+117=201 points\n",
      "The answer is 201\n",
      "\n",
      "Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?\n",
      "Let's think step by step\n",
      "When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24\n",
      "The total number of marbles she'll have is 60+24 = 84\n",
      "If Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.\n",
      "If Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.\n",
      "The total number of frisbees she'll have will increase to 30+12 = 42\n",
      "Bella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards\n",
      "If she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.\n",
      "The total number of deck cards she'll have is 10+4 = 14\n",
      "Together, Bella will have a total of 14+42+84 = 140 items\n",
      "The answer is 140\n",
      "\n",
      "Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?\n",
      "Let's think step by step\n",
      "For the first three baskets, the number of apples and oranges in one basket is 9+15=24\n",
      "In total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.\n",
      "Since there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.\n",
      "The number of apples in the fourth basket is 9-2=7\n",
      "There are also 15-2=13 oranges in the fourth basket\n",
      "The combined number of oranges and apples in the fourth basket is 13+7=20\n",
      "The fourth basket also contains 14-2=12 bananas.\n",
      "In total, the fourth basket has 20+12=32 fruits.\n",
      "The four baskets together have 32+114=146 fruits.\n",
      "The answer is 146\n",
      "\n",
      "Question: You can buy 4 apples or 1 watermelon for the same price. You bought 36 fruits evenly split between oranges, apples and watermelons, and the price of 1 orange is $0.50. How much does 1 apple cost if your total bill was $66?\n",
      "Let's think step by step\n",
      "If 36 fruits were evenly split between 3 types of fruits, then I bought 36/3 = 12 units of each fruit\n",
      "If 1 orange costs $0.50 then 12 oranges will cost $0.50 * 12 = $6\n",
      "If my total bill was $66 and I spent $6 on oranges then I spent $66 - $6 = $60 on the other 2 fruit types.\n",
      "Assuming the price of watermelon is W, and knowing that you can buy 4 apples for the same price and that the price of one apple is A, then 1W=4A\n",
      "If we know we bought 12 watermelons and 12 apples for $60, then we know that $60 = 12W + 12A\n",
      "Knowing that 1W=4A, then we can convert the above to $60 = 12(4A) + 12A\n",
      "$60 = 48A + 12A\n",
      "$60 = 60A\n",
      "Then we know the price of one apple (A) is $60/60= $1\n",
      "The answer is 1\n",
      "\n",
      "Question: Susy goes to a large school with 800 students, while Sarah goes to a smaller school with only 300 students.  At the start of the school year, Susy had 100 social media followers.  She gained 40 new followers in the first week of the school year, half that in the second week, and half of that in the third week.  Sarah only had 50 social media followers at the start of the year, but she gained 90 new followers the first week, a third of that in the second week, and a third of that in the third week.  After three weeks, how many social media followers did the girl with the most total followers have?\n",
      "Let's think step by step\n",
      "After one week, Susy has 100+40 = 140 followers.\n",
      "In the second week, Susy gains 40/2 = 20 new followers.\n",
      "In the third week, Susy gains 20/2 = 10 new followers.\n",
      "In total, Susy finishes the three weeks with 140+20+10 = 170 total followers.\n",
      "After one week, Sarah has 50+90 = 140 followers.\n",
      "After the second week, Sarah gains 90/3 = 30 followers.\n",
      "After the third week, Sarah gains 30/3 = 10 followers.\n",
      "So, Sarah finishes the three weeks with 140+30+10 = 180 total followers.\n",
      "Thus, Sarah is the girl with the most total followers with a total of 180.\n",
      "The answer is 180\n",
      "\n",
      "Question: Sam bought a dozen boxes, each with 30 highlighter pens inside, for $10 each box. He rearranged five of these boxes into packages of six highlighters each and sold them for $3 per package. He sold the rest of the highlighters separately at the rate of three pens for $2. How much profit did he make in total, in dollars?\n",
      "Let's think step by step\n",
      "Sam bought 12 boxes x $10 = $120 worth of highlighters.\n",
      "He bought 12 * 30 = 360 highlighters in total.\n",
      "Sam then took 5 boxes × 6 highlighters/box = 30 highlighters.\n",
      "He sold these boxes for 5 * $3 = $15\n",
      "After selling these 5 boxes there were 360 - 30 = 330 highlighters remaining.\n",
      "These form 330 / 3 = 110 groups of three pens.\n",
      "He sold each of these groups for $2 each, so made 110 * 2 = $220 from them.\n",
      "In total, then, he earned $220 + $15 = $235.\n",
      "Since his original cost was $120, he earned $235 - $120 = $115 in profit.\n",
      "The answer is 115\n",
      "\n",
      "Question: In a certain school, 2/3 of the male students like to play basketball, but only 1/5 of the female students like to play basketball. What percent of the population of the school do not like to play basketball if the ratio of the male to female students is 3:2 and there are 1000 students?\n",
      "Let's think step by step\n",
      "The students are divided into 3 + 2 = 5 parts where 3 parts are for males and 2 parts are for females.\n",
      "Each part represents 1000/5 = 200 students.\n",
      "So, there are 3 x 200 = 600 males.\n",
      "And there are 2 x 200 = 400 females.\n",
      "Hence, 600 x 2/3 = 400 males play basketball.\n",
      "And 400 x 1/5 = 80 females play basketball.\n",
      "A total of 400 + 80 = 480 students play basketball.\n",
      "Therefore, 1000 - 480 = 520 do not like to play basketball.\n",
      "The percentage of the school that do not like to play basketball is 520/1000 * 100 = 52\n",
      "The answer is 52\n",
      "\n",
      "Question: A boy has 12 oranges. He gives one-third of this number to his brother, one-fourth of the remainder to his friend and keeps the rest for himself. How many does his friend get?\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(prompt_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "bf1c82da-464d-4afb-83af-88961a8a3f9b",
   "metadata": {},
   "outputs": [],
   "source": [
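    "# OPT's context window is 2048 tokens; compare input_ids.shape[1] against it before calling generate().\n",
    "input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.cuda()"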
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "457a8244-7670-4c75-b373-80628f02fd06",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasLtMatmul( ltHandle, computeDesc.descriptor(), &alpha_val, mat1_ptr, Adesc.descriptor(), mat2_ptr, Bdesc.descriptor(), &beta_val, result_ptr, Cdesc.descriptor(), result_ptr, Cdesc.descriptor(), &heuristicResult.algo, workspace.data_ptr(), workspaceSize, at::cuda::getCurrentCUDAStream())`",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn [24], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_ids \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2253\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m256\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27\u001b[0m, in \u001b[0;36m_DecoratorContextManager.__call__.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m     25\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m     26\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclone():\n\u001b[0;32m---> 27\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/transformers/generation_utils.py:1490\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, max_length, min_length, do_sample, early_stopping, num_beams, temperature, penalty_alpha, top_k, top_p, typical_p, repetition_penalty, bad_words_ids, force_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, logits_processor, renormalize_logits, stopping_criteria, constraints, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, exponential_decay_length_penalty, suppress_tokens, begin_suppress_tokens, forced_decoder_ids, **model_kwargs)\u001b[0m\n\u001b[1;32m   1485\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m   1486\u001b[0m             \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_return_sequences has to be 1, but is \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_return_sequences\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m when doing greedy search.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1487\u001b[0m         )\n\u001b[1;32m   1489\u001b[0m     \u001b[38;5;66;03m# 10. 
run greedy search\u001b[39;00m\n\u001b[0;32m-> 1490\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgreedy_search\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1491\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1492\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1493\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1494\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpad_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpad_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1495\u001b[0m \u001b[43m        \u001b[49m\u001b[43meos_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meos_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1496\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_scores\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_scores\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1497\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_dict_in_generate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict_in_generate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1498\u001b[0m \u001b[43m        \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1499\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1500\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m is_contrastive_search_gen_mode:\n\u001b[1;32m   1504\u001b[0m     
\u001b[38;5;28;01mif\u001b[39;00m num_return_sequences \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/transformers/generation_utils.py:2233\u001b[0m, in \u001b[0;36mGenerationMixin.greedy_search\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)\u001b[0m\n\u001b[1;32m   2230\u001b[0m model_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprepare_inputs_for_generation(input_ids, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs)\n\u001b[1;32m   2232\u001b[0m \u001b[38;5;66;03m# forward pass to get next token\u001b[39;00m\n\u001b[0;32m-> 2233\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2234\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2235\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   2236\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2237\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2238\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2240\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m synced_gpus \u001b[38;5;129;01mand\u001b[39;00m this_peer_finished:\n\u001b[1;32m   2241\u001b[0m     \u001b[38;5;28;01mcontinue\u001b[39;00m  \u001b[38;5;66;03m# don't waste resources running the code we don't need\u001b[39;00m\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/accelerate/hooks.py:156\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    154\u001b[0m         output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 156\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:929\u001b[0m, in \u001b[0;36mOPTForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    926\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m    928\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m--> 929\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    930\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    931\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    932\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    933\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    934\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    935\u001b[0m 
\u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    936\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    937\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    938\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    939\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    941\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(outputs[\u001b[38;5;241m0\u001b[39m])\u001b[38;5;241m.\u001b[39mcontiguous()\n\u001b[1;32m    943\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:693\u001b[0m, in \u001b[0;36mOPTDecoder.forward\u001b[0;34m(self, input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    684\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mcheckpoint\u001b[38;5;241m.\u001b[39mcheckpoint(\n\u001b[1;32m    685\u001b[0m         create_custom_forward(decoder_layer),\n\u001b[1;32m    686\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    689\u001b[0m         \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    690\u001b[0m     )\n\u001b[1;32m    691\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 693\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    694\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    695\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    696\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    
697\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    698\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    699\u001b[0m \u001b[43m        \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    700\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    702\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    704\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/accelerate/hooks.py:156\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    154\u001b[0m         output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 156\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:324\u001b[0m, in \u001b[0;36mOPTDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, layer_head_mask, output_attentions, use_cache, past_key_value)\u001b[0m\n\u001b[1;32m    321\u001b[0m     hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mself_attn_layer_norm(hidden_states)\n\u001b[1;32m    323\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 324\u001b[0m hidden_states, self_attn_weights, present_key_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    325\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    326\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    327\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    328\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlayer_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    329\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    330\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    331\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mdropout(hidden_states, p\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout, 
training\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining)\n\u001b[1;32m    332\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/accelerate/hooks.py:156\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    154\u001b[0m         output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 156\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:172\u001b[0m, in \u001b[0;36mOPTAttention.forward\u001b[0;34m(self, hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions)\u001b[0m\n\u001b[1;32m    169\u001b[0m bsz, tgt_len, _ \u001b[38;5;241m=\u001b[39m hidden_states\u001b[38;5;241m.\u001b[39msize()\n\u001b[1;32m    171\u001b[0m \u001b[38;5;66;03m# get query proj\u001b[39;00m\n\u001b[0;32m--> 172\u001b[0m query_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mq_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaling\n\u001b[1;32m    173\u001b[0m \u001b[38;5;66;03m# get key, value proj\u001b[39;00m\n\u001b[1;32m    174\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_cross_attention \u001b[38;5;129;01mand\u001b[39;00m past_key_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    175\u001b[0m     \u001b[38;5;66;03m# reuse k,v, cross_attentions\u001b[39;00m\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/accelerate/hooks.py:156\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    154\u001b[0m         output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 156\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m/data/conda/envs/dl_py3.8/lib/python3.8/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasLtMatmul( ltHandle, computeDesc.descriptor(), &alpha_val, mat1_ptr, Adesc.descriptor(), mat2_ptr, Bdesc.descriptor(), &beta_val, result_ptr, Cdesc.descriptor(), result_ptr, Cdesc.descriptor(), &heuristicResult.algo, workspace.data_ptr(), workspaceSize, at::cuda::getCurrentCUDAStream())`"
     ]
    }
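   ],
   "source": [
    "# Note: OPT models use max_position_embeddings = 2048, so a total length of 2253 + 256 tokens\n",
    "# is out of range for the position embeddings; this is the likely cause of the CUDA error in this cell's output.\n",
    "generated_ids = model.generate(input_ids, max_length=2253 + 256)"
   ]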
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "372e68fe-21b3-4a70-a3f4-73ed90fe0c83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Hello, I am conscious and I am a human being.\\nI am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am conscious and I am a human being. I am']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.batch_decode(generated_ids, skip_special_tokens=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
