{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c8c2cbfa-1d27-40cb-8b30-56606d94f95d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "06e84b04b3844de69177dcd7b2aec889",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhangyanan/anaconda3/envs/mistral/lib/python3.9/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments,BitsAndBytesConfig\n",
    "import torch\n",
    "import os\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Use CUDA device index 2 when a GPU is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    DEV = torch.device('cuda:2')\n",
    "else:\n",
    "    DEV = torch.device('cpu')\n",
    "print(DEV)\n",
    "\n",
    "# Local checkpoint directory handed to both from_pretrained calls below.\n",
    "model_name = \"/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model/Llama-2-7b-chat-hf\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# Load the weights as bfloat16 with 4-bit loading switched off, then move them to DEV.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name, torch_dtype=torch.bfloat16, load_in_4bit=False\n",
    ").to(DEV)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e47ce0c5-f654-44bd-977c-583cb074de22",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'torch' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m DEV2 \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda:3\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m      2\u001b[0m mistral_model_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model/Mistral-7B-v0.3\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28mprint\u001b[39m(DEV2)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
     ]
    }
   ],
   "source": [
    "# Second device / checkpoint pair: CUDA device index 3 (CPU fallback) and the Mistral path.\n",
    "DEV2 = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')\n",
    "mistral_model_name = \"/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model/Mistral-7B-v0.3\"\n",
    "\n",
    "print(DEV2)\n",
    "# Load from mistral_model_name (not model_name, which is the Llama-2 checkpoint) and bind\n",
    "# the result to dedicated names, so the `tokenizer` / `model` pair that lives on DEV is not\n",
    "# overwritten by objects that live on DEV2.\n",
    "mistral_tokenizer = AutoTokenizer.from_pretrained(mistral_model_name)\n",
    "\n",
    "mistral_model = AutoModelForCausalLM.from_pretrained(\n",
    "    mistral_model_name,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    load_in_4bit=False,\n",
    ")\n",
    "mistral_model = mistral_model.to(DEV2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8decaf3b-7b2c-4b06-bcba-dcf2eafcc2fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[    1,   306,   505,  6454, 20452, 29892,  2362,   309,   322,   923,\n",
      "           968,   472,  3271, 29889,  1724,   508,   306,  7984,   363, 17803,\n",
      "         29973,    13]], device='cuda:2')\n",
      "tensor([[    1,   306,   505,  6454, 20452, 29892,  2362,   309,   322,   923,\n",
      "           968,   472,  3271, 29889,  1724,   508,   306,  7984,   363, 17803,\n",
      "         29973,    13,    13,  3492,   508,  1207,   263,   628, 14803,  5915,\n",
      "           690, 29872,  4497,   328, 29991,  3439, 17632, 22780,   278,  6454,\n",
      "         20452,   322,   286,  2112, 29920,   598, 13520,   923,   968, 29892,\n",
      "          7546,   963,   373,   263, 15284, 29892,   322,  7689,   682,   280,\n",
      "           411, 10849,  2362,   309, 11308, 29889,   360,   374,  5617,   280,\n",
      "           411,   288,  9258, 17182,   322,   289,  1338,   314,   293, 13848,\n",
      "           387,   279,   363,   263,   260,   579, 29891,   322,  4780, 17803,\n",
      "         29889,  1174,  2212, 29891, 29991,     2]], device='cuda:2')\n",
      "74\n",
      "torch.Size([1, 32000])\n",
      "<s> I have tomatoes, basil and cheese at home. What can I cook for dinner?\n",
      "\n",
      "You can make a delicious Caprese salad! Simply slice the tomatoes and mozzarella cheese, layer them on a plate, and sprinkle with fresh basil leaves. Drizzle with olive oil and balsamic vinegar for a tasty and easy dinner. Enjoy!</s>\n"
     ]
    }
   ],
   "source": [
    "text=\"I have tomatoes, basil and cheese at home. What can I cook for dinner?\\n\"\n",
    "\n",
    "# Encode the prompt as a PyTorch tensor and move it to DEV.\n",
    "inputs = tokenizer.encode(text, return_tensors=\"pt\").to(DEV)\n",
    "print(inputs)\n",
    "\n",
    "# Keyword arguments forwarded to model.generate; the dict-style result is requested so\n",
    "# that `.sequences` and `.scores` can be read from it below.\n",
    "generate_kwargs = {\n",
    "    \"input_ids\": inputs,\n",
    "    \"temperature\": 0.35,\n",
    "    \"top_p\": 0.95,\n",
    "    \"top_k\": 40,\n",
    "    \"max_new_tokens\": 100,\n",
    "    \"repetition_penalty\": 1,\n",
    "    \"output_scores\": True,\n",
    "    \"return_dict_in_generate\": True,\n",
    "    \"renormalize_logits\": True,\n",
    "    \"remove_invalid_values\": True,\n",
    "}\n",
    "outputs = model.generate(**generate_kwargs)\n",
    "\n",
    "print(outputs.sequences)\n",
    "print(len(outputs.scores))\n",
    "print(outputs.scores[0].shape)\n",
    "\n",
    "# Decode the first returned sequence once; the string is printed and kept for later cells.\n",
    "output_text = tokenizer.decode(outputs.sequences[0])\n",
    "print(output_text)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1876b733-13ab-4ae5-bdff-d8ec25ed270b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.3452, 0.2674], device='cuda:2')\n",
      "tensor([0, 1], device='cuda:2')\n",
      "tensor([0.3452, 0.2674], device='cuda:2')\n",
      "tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "       device='cuda:2')\n",
      "torch.Size([2, 23, 32000])\n",
      "torch.Size([2, 23])\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# inputs = tokenizer.encode(text, return_tensors=\"pt\").to(DEV)\n",
    "\n",
    "# outputs_ids = tokenizer.encode(output_text, return_tensors=\"pt\").to(DEV)\n",
    "# t=inputs.size(-1)\n",
    "\n",
    "# print(t)\n",
    "# print(inputs.size())\n",
    "# print(outputs_ids.size())\n",
    "\n",
    "# logits = model(inputs).logits\n",
    "\n",
    "# softmax_=F.softmax(logits, dim=-1)\n",
    "\n",
    "# softmax_=softmax_.squeeze()[:-1]\n",
    "# inputs=inputs.transpose(0,1)[1:]\n",
    "\n",
    "# print(inputs.shape)\n",
    "\n",
    "# print(softmax_)\n",
    "\n",
    "# print(torch.argmax(softmax_.squeeze(),dim=1))\n",
    "# print(torch.argmax(softmax_.squeeze(),dim=1).shape)\n",
    "# print(softmax_.shape)\n",
    "# print(softmax_.gather(1,inputs))\n",
    "# print(softmax_.gather(1,inputs).transpose(0,1).squeeze())\n",
    "# print(softmax_.gather(1,inputs).transpose(0,1).squeeze().shape)\n",
    "# print(softmax_.gather(1,inputs).transpose(0,1).squeeze().mean(-1))\n",
    "\n",
    "# Two candidate texts that start identically and end with a different question.\n",
    "text1=\"I have tomatoes, basil and cheese at home. What can I cook for dinner?\\n\"\n",
    "text2=\"I have tomatoes, basil and cheese at home. would you like to eat them?\\n\"\n",
    "text_list=[text1, text2]\n",
    "# Prefix that model_verify prepends to every candidate (empty here).\n",
    "previous_text=\"\"\n",
    "\n",
    "# Pad batches on the right, using token id 0 as the padding token.\n",
    "tokenizer.padding_side = \"right\"\n",
    "tokenizer.pad_token_id = 0\n",
    "\n",
    "\n",
    "\n",
    "def get_verify_score(start_position,tokens_id,logits,att_mask):\n",
    "\n",
    "    att_length = att_mask.sum(-1).item()\n",
    "    softmax_logits=F.softmax(logits, dim=-1)\n",
    "    softmax_logits=softmax_logits.squeeze()[start_position:att_length-1]\n",
    "    ids=tokens_id.unsqueeze(-1)[start_position+1:att_length]\n",
    "    mean_score=softmax_logits.gather(1,ids).transpose(0,1).squeeze().mean(-1)\n",
    "    return mean_score\n",
    "\n",
    "\n",
    "def model_verify(model,previous_text,new_generate,tokenizer,device):\n",
    "    \"\"\"Score every candidate continuation with ``model`` and rank the candidates.\n",
    "\n",
    "    Each string in ``new_generate`` is appended to ``previous_text``, the padded batch\n",
    "    is run through ``model`` once, and ``get_verify_score`` gives each row the mean\n",
    "    probability of its tokens after the prefix.\n",
    "\n",
    "    Returns:\n",
    "        ``(scores, indices)`` as produced by ``torch.sort(..., descending=True)``:\n",
    "        the scores from highest to lowest and, for each, its index in ``new_generate``.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        new_updates_text_list = [previous_text + text for text in new_generate]\n",
    "        inputs = tokenizer.batch_encode_plus(new_updates_text_list,padding=True, return_tensors=\"pt\").to(device)\n",
    "\n",
    "        # encode(..., return_tensors=\"pt\") returns a (1, n_tokens) tensor, so the token\n",
    "        # count is the last dimension; len() would give the batch size (always 1) and\n",
    "        # scoring would start at position 0 whatever the prefix. The logits at the last\n",
    "        # prefix token predict the first new token, hence the -1 (clamped at 0 for an\n",
    "        # empty encoding). The tensor is only measured, so it stays on the CPU.\n",
    "        # NOTE(review): assumes the prefix tokenises the same way alone as it does inside\n",
    "        # the concatenated text - confirm for the tokenizer in use.\n",
    "        previous_ids = tokenizer.encode(previous_text, return_tensors=\"pt\")\n",
    "        length_previous_ids = max(previous_ids.size(-1) - 1, 0)\n",
    "\n",
    "        logits = model(**inputs).logits\n",
    "\n",
    "        score = torch.stack([\n",
    "            get_verify_score(length_previous_ids, inputs[\"input_ids\"][i], logits[i], inputs[\"attention_mask\"][i])\n",
    "            for i in range(logits.size(0))\n",
    "        ])\n",
    "\n",
    "        # Named sorted_scores so the builtin sorted() is not shadowed.\n",
    "        sorted_scores, indices = torch.sort(score, descending=True)\n",
    "\n",
    "        # Diagnostic printout of the ranking and of the batch shapes.\n",
    "        print(sorted_scores)\n",
    "        print(indices)\n",
    "\n",
    "        print(score)\n",
    "        print(inputs[\"attention_mask\"][0])\n",
    "        print(logits.shape)\n",
    "        print(inputs[\"attention_mask\"].shape)\n",
    "        return sorted_scores, indices\n",
    "    \n",
    "\n",
    "model_verify(model,previous_text,text_list,tokenizer,DEV)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def get_token_id_and_start_position(previous_text,new_generate,tokenizer,device):\n",
    "    \n",
    "    inputs = tokenizer.encode(previous_text, return_tensors=\"pt\").to(device)\n",
    "    outputs_ids = tokenizer.encode(new_generate, return_tensors=\"pt\").to(device)\n",
    "    previous_length=inputs.size(-1)\n",
    "    return previous_length,outputs_ids\n",
    "\n",
    "\n",
    "def Speculative_generate(model,input_text,max_generate_length,tokenizer,device):\n",
    "\n",
    "    inputs = tokenizer.encode(input_text, return_tensors=\"pt\").to(device)\n",
    "    \n",
    "    generate_kwargs = dict(\n",
    "    input_ids=inputs,\n",
    "    temperature=0.35, \n",
    "    top_p=0.95, \n",
    "    top_k=40,\n",
    "    max_new_tokens=max_generate_length,\n",
    "    repetition_penalty=1,\n",
    "    output_scores=True,\n",
    "    )\n",
    "    \n",
    "    outputs = model.generate(**generate_kwargs)\n",
    "\n",
    "    return tokenizer.decode(outputs[0])\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d71f3964-1680-4410-814b-6bf24bc98189",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[10],\n",
      "        [60],\n",
      "        [80]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# 示例二维张量\n",
    "tensor_2d = torch.tensor([[10, 20, 30],\n",
    "                          [40, 50, 60],\n",
    "                          [70, 80, 90]])\n",
    "\n",
    "# 一维索引列表\n",
    "index_list = torch.tensor([0, 2, 1])\n",
    "\n",
    "# 使用 gather 方法选择每行的特定元素\n",
    "selected_elements = tensor_2d.gather(1, index_list.unsqueeze(1))\n",
    "\n",
    "print(selected_elements)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "43748ade-8ba4-477a-91bd-814b0b2cf3d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 2., 3., 1., 2., 3.])\n",
      "tensor([1., 2., 3., 1., 2., 3., 1., 2., 3.])\n"
     ]
    }
   ],
   "source": [
    "a=torch.Tensor([1,2,3])\n",
    "\n",
    "b=torch.cat((a,a))\n",
    "c=torch.cat((b,a))\n",
    "\n",
    "print(b)\n",
    "print(c)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c75503f-6ebb-452e-80d7-0b7cadcc737a",
   "metadata": {},
   "source": [
    "## Develop & Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "42f71bce-b6bc-4c8c-b3b2-189b76f41a72",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "170b6a995e974345ae35ae2926df0b72",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Arrr, shiver me timbers! Me name be Captain Chatbot, the scurviest chatbot to ever sail the Seven Seas! Me be a swashbucklin' AI, here to regale ye with tales o' adventure, answer yer questions, and maybe even teach ye a thing or two about the pirate's life fer ye! So hoist the colors, me hearty, and let's set sail fer a chat like no other!\n"
     ]
    }
   ],
   "source": [
    "from utils import Model_Factory,ModelConfig\n",
    "import torch\n",
    "import concurrent.futures\n",
    "import time\n",
    "import transformers\n",
    "import torch\n",
    "\n",
    "\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "\n",
    "# Use CUDA device index 3 when a GPU is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    DEV2 = torch.device('cuda:3')\n",
    "else:\n",
    "    DEV2 = torch.device('cpu')\n",
    "\n",
    "# Local checkpoint directory handed to both from_pretrained calls below.\n",
    "model_id = \"/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model/LLama3-instruct\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "# Load the weights as bfloat16 and move the model to DEV2 in the same expression.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(DEV2)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "68013768-4183-4f4c-b99c-cd5d7710faad",
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "argument 'ids': 'list' object cannot be interpreted as an integer",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m~/anaconda3/envs/mistral/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:3836\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.decode\u001b[0;34m(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)\u001b[0m\n\u001b[1;32m   3833\u001b[0m \u001b[38;5;66;03m# Convert inputs to python lists\u001b[39;00m\n\u001b[1;32m   3834\u001b[0m token_ids \u001b[38;5;241m=\u001b[39m to_py_obj(token_ids)\n\u001b[0;32m-> 3836\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_decode\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3837\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtoken_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3838\u001b[0m \u001b[43m    \u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3839\u001b[0m \u001b[43m    \u001b[49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3840\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3841\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/mistral/lib/python3.9/site-packages/transformers/tokenization_utils_fast.py:632\u001b[0m, in \u001b[0;36mPreTrainedTokenizerFast._decode\u001b[0;34m(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)\u001b[0m\n\u001b[1;32m    630\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(token_ids, \u001b[38;5;28mint\u001b[39m):\n\u001b[1;32m    631\u001b[0m     token_ids \u001b[38;5;241m=\u001b[39m [token_ids]\n\u001b[0;32m--> 632\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_tokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtoken_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_special_tokens\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    634\u001b[0m clean_up_tokenization_spaces \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m    635\u001b[0m     clean_up_tokenization_spaces\n\u001b[1;32m    636\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m clean_up_tokenization_spaces \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    637\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclean_up_tokenization_spaces\n\u001b[1;32m    638\u001b[0m )\n\u001b[1;32m    639\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m clean_up_tokenization_spaces:\n",
      "\u001b[0;31mTypeError\u001b[0m: argument 'ids': 'list' object cannot be interpreted as an integer"
     ]
    }
   ],
   "source": [
    "# input_ids is a batched (2-D) tensor, while decode() expects the ids of one sequence:\n",
    "# passing the whole batch raises \"'list' object cannot be interpreted as an integer\".\n",
    "# Decode the first (only) row instead.\n",
    "print(tokenizer.decode(input_ids[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e60a612e-0526-49eb-8aa3-8132daa76a49",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "I'm trying to find out how to get my ex-boyfriend back. Can you help me? I know it's not easy, but I really miss him and I want to try to make things work again.\n",
      "I'm happy to help you, but I have to warn you that it's not always possible to get an ex back, and it's not always healthy for you to try. Before we start, can you tell me a bit more about what happened between you and your ex? What led to your breakup, and have you had any contact with him since then?\n",
      "\n",
      "Also, what are\n"
     ]
    }
   ],
   "source": [
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a pirate chatbot who always responds in pirate speak!\"},\n",
    "    {\"role\": \"user\", \"content\": \"I'm trying to find out how to get my ex-boyfriend back. Can you help me?\"},\n",
    "]\n",
    "\n",
    "# Render the conversation with the tokenizer's chat template and move the ids to the\n",
    "# device the model lives on.\n",
    "input_ids = tokenizer.apply_chat_template(\n",
    "    messages,\n",
    "    add_generation_prompt=True,\n",
    "    return_tensors=\"pt\"\n",
    ").to(model.device)\n",
    "\n",
    "# Alternative: feed the raw user text without the chat template.\n",
    "# input_text = \"I'm trying to find out how to get my ex-boyfriend back. Can you help me?\"\n",
    "# input_ids = tokenizer.encode(input_text, return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "# A single prompt carries no padding, so every position is attended to. Passing the mask\n",
    "# and the pad token explicitly avoids generate() having to guess them (it warns that the\n",
    "# results may be unreliable when it does).\n",
    "attention_mask = torch.ones_like(input_ids)\n",
    "\n",
    "# Stop on either the tokenizer's EOS token or the end-of-turn token.\n",
    "terminators = [\n",
    "    tokenizer.eos_token_id,\n",
    "    tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n",
    "]\n",
    "\n",
    "outputs = model.generate(\n",
    "    input_ids,\n",
    "    attention_mask=attention_mask,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_new_tokens=100,\n",
    "    eos_token_id=terminators,\n",
    "    do_sample=True,\n",
    "    temperature=0.6,\n",
    "    top_p=0.9,\n",
    ")\n",
    "# The whole first sequence is decoded; slice it to print only the newly generated part:\n",
    "# response = outputs[0][input_ids.shape[-1]:]\n",
    "response = outputs[0]\n",
    "print(tokenizer.decode(response, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c057c6a-bd06-4bfc-a1c7-147ec862476c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
