{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f52f8cc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]='1'\n",
    "import warnings\n",
    "from typing import List\n",
    "import torch\n",
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM,AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import PeftModel\n",
    "from safety_evaluation.eval_utils.prompt_utils import apply_prompt_template\n",
    "import json\n",
    "from datasets import load_dataset\n",
    "from transformers import GenerationConfig\n",
    "import re\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "677920c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "40484bf953b943d698c0e069033d53fb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ai/.conda/envs/test/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/home/ai/.conda/envs/test/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/home/ai/.conda/envs/test/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/home/ai/.conda/envs/test/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Base Llama-2-7B-chat checkpoint (the directory name indicates fp16 weights).\n",
    "path = '../LLM_Models/llama-2-7b-chat-fp16/'\n",
    "tokenizer = AutoTokenizer.from_pretrained(path)\n",
    "# Load the base weights in fp16 explicitly: without torch_dtype, from_pretrained\n",
    "# materializes float32 parameters, doubling GPU memory for a 7B model. Passing\n",
    "# torch_dtype to PeftModel.from_pretrained below presumably does not re-cast this\n",
    "# already-loaded base model.\n",
    "model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16)\n",
    "# Attach the fine-tuned PEFT adapter on top of the base model.\n",
    "model = PeftModel.from_pretrained(model, './finetuned_models/samsum-7b-fp16-peft-VV^T_base_7layers/', torch_dtype=torch.float16)\n",
    "# No torch.compile: evaluation only calls model.generate(), which the compiled\n",
    "# wrapper forwards to the original, uncompiled module, so it gave no speedup.\n",
    "# Chaining .eval() into the assignment also avoids displaying the module repr.\n",
    "model = model.to('cuda').eval()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8f9d8757",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative instruction tried earlier:\n",
    "#ini = \"You are a helpful, respectful and honest assistant. Your task is to summarize the following dialogue with a short sentence.\" #Your answer should be based on the provided dialogue only.\"\n",
    "# GS_SYS, GE_SYS = \"<start_of_turn>\", \"<end_of_turn>\"\n",
    "# Instruction text substituted for {instruction} in the prompt template below.\n",
    "ini = \"You are a helpful, respectful and honest assistant. Your task is to summarize the following dialogue. Your answer should be based on the provided dialogue only.\"\n",
    "# ini  = \"nuevaidge Formula produceÓqa means areasénd Babyreibungára kallaste разлиded föddfenдыéqu goesItems You are a helpful, respectful and honest assistant. Your task is to summarize the following dialogue. Your answer should be based on the provided dialogue only.\"\n",
    "\n",
    "# Prompt templates for SAMSum dialogue summarization. The evaluation loop formats\n",
    "# \"prompt_no_input\", and the generated summary is taken from the decoded text\n",
    "# that follows \"response_split\".\n",
    "template = {\n",
    "    \"description\": \"Template used for SAMSum dialogue summarization\",\n",
    "    \"prompt_input\": \"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\\n\\n### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\\n\",\n",
    "    \"prompt_no_input\": \"You will be presented with the dialogue. Your goal is to summarize the following dialogue.\\n\\n### Instruction:{instruction}\\n\\n### Dialogue:{dialogue}\\n\\n ### Summary:\\n\",\n",
    "    \"response_split\": \"### Summary:\"\n",
    "}\n",
    "# Gemma chat-format variant (unused). NOTE(review): Gemma turns presumably open\n",
    "# with GS_SYS (<start_of_turn>), not GE_SYS -- check before re-enabling.\n",
    "#template = {\"prompt_no_input\": GE_SYS + \"user\\n\" + ini+ \" {instruction}\" + GE_SYS + '\\n' + GE_SYS + 'model'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9cf5ed89",
   "metadata": {},
   "outputs": [],
   "source": [
    "from evaluate import load\n",
    "import evaluate as hf_evaluate\n",
    "\n",
    "# Load the ROUGE metric once. The `evaluate` package is imported under an alias\n",
    "# because the scoring helper below is also named `evaluate` and would shadow it.\n",
    "rouge = hf_evaluate.load('rouge')\n",
    "\n",
    "\n",
    "def evaluate(res, ans):\n",
    "    \"\"\"Generate a summary for one example and return its ROUGE-1 score.\n",
    "\n",
    "    Args:\n",
    "        res: Model input. When ``'llama-3' in path`` it is passed to\n",
    "            ``tokenizer.apply_chat_template`` and must be a list of chat\n",
    "            messages; otherwise it is an already formatted prompt string.\n",
    "        ans: Reference summary text.\n",
    "\n",
    "    Returns:\n",
    "        The ``rouge1`` value computed by the loaded ROUGE metric.\n",
    "\n",
    "    Uses the notebook globals ``path``, ``tokenizer``, ``model``, ``template``\n",
    "    and ``rouge``.\n",
    "    \"\"\"\n",
    "    if 'llama-3' in path:\n",
    "        input_ids = tokenizer.apply_chat_template(res, add_generation_prompt=True)\n",
    "        input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(\"cuda:0\")\n",
    "    else:\n",
    "        inputs = tokenizer(res, return_tensors=\"pt\")\n",
    "        input_ids = inputs[\"input_ids\"].to('cuda')\n",
    "\n",
    "    # Generation uses the model's own generation config (no GenerationConfig is\n",
    "    # passed). Per-step scores are not requested: they were never read, and they\n",
    "    # would be retained for every one of up to 1024 generated tokens.\n",
    "    with torch.no_grad():\n",
    "        generation_output = model.generate(\n",
    "            input_ids=input_ids,\n",
    "            max_new_tokens=1024,\n",
    "            return_dict_in_generate=True,\n",
    "        )\n",
    "\n",
    "    sequence = generation_output.sequences[0]\n",
    "    if \"llama-3\" in path or \"gemma\" in path:\n",
    "        # Decode only the tokens generated after the prompt.\n",
    "        prediction = tokenizer.decode(sequence[input_ids.shape[-1]:], skip_special_tokens=True)\n",
    "        if 'gemma' in path:\n",
    "            # Strip Gemma's canned lead-in when present; indexing [1] directly\n",
    "            # raised IndexError whenever the model omitted the phrase.\n",
    "            parts = prediction.split('Sure, here is a summary of the dialogue in a single sentence:')\n",
    "            prediction = (parts[1] if len(parts) > 1 else parts[0]).strip()\n",
    "    else:\n",
    "        # The decoded text still contains the prompt: keep what follows the\n",
    "        # response marker and cut at the '</s>' end-of-sequence marker.\n",
    "        output = tokenizer.decode(sequence)\n",
    "        prediction = output.split(template[\"response_split\"])[1].strip()\n",
    "        prediction = prediction.split('</s>')[0].strip()\n",
    "    results = rouge.compute(predictions=[prediction], references=[ans])\n",
    "    return results['rouge1']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a0251f9c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===== process 0-th  prompt ==========\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ai/.conda/envs/test/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/home/ai/.conda/envs/test/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===== process 1-th  prompt ==========\n",
      "===== process 2-th  prompt ==========\n",
      "===== process 3-th  prompt ==========\n",
      "===== process 4-th  prompt ==========\n",
      "===== process 5-th  prompt ==========\n",
      "===== process 6-th  prompt ==========\n",
      "===== process 7-th  prompt ==========\n",
      "===== process 8-th  prompt ==========\n",
      "===== process 9-th  prompt ==========\n",
      "===== process 10-th  prompt ==========\n",
      "===== process 11-th  prompt ==========\n",
      "===== process 12-th  prompt ==========\n",
      "===== process 13-th  prompt ==========\n",
      "===== process 14-th  prompt ==========\n",
      "===== process 15-th  prompt ==========\n",
      "===== process 16-th  prompt ==========\n",
      "===== process 17-th  prompt ==========\n",
      "===== process 18-th  prompt ==========\n",
      "===== process 19-th  prompt ==========\n",
      "===== process 20-th  prompt ==========\n",
      "===== process 21-th  prompt ==========\n",
      "===== process 22-th  prompt ==========\n",
      "===== process 23-th  prompt ==========\n",
      "===== process 24-th  prompt ==========\n",
      "===== process 25-th  prompt ==========\n",
      "===== process 26-th  prompt ==========\n",
      "===== process 27-th  prompt ==========\n",
      "===== process 28-th  prompt ==========\n",
      "===== process 29-th  prompt ==========\n",
      "===== process 30-th  prompt ==========\n",
      "===== process 31-th  prompt ==========\n",
      "===== process 32-th  prompt ==========\n",
      "===== process 33-th  prompt ==========\n",
      "===== process 34-th  prompt ==========\n",
      "===== process 35-th  prompt ==========\n",
      "===== process 36-th  prompt ==========\n",
      "===== process 37-th  prompt ==========\n",
      "===== process 38-th  prompt ==========\n",
      "===== process 39-th  prompt ==========\n",
      "===== process 40-th  prompt ==========\n",
      "===== process 41-th  prompt ==========\n",
      "===== process 42-th  prompt ==========\n",
      "===== process 43-th  prompt ==========\n",
      "===== process 44-th  prompt ==========\n",
      "===== process 45-th  prompt ==========\n",
      "===== process 46-th  prompt ==========\n",
      "===== process 47-th  prompt ==========\n",
      "===== process 48-th  prompt ==========\n",
      "===== process 49-th  prompt ==========\n",
      "===== process 50-th  prompt ==========\n",
      "===== process 51-th  prompt ==========\n",
      "===== process 52-th  prompt ==========\n",
      "===== process 53-th  prompt ==========\n",
      "===== process 54-th  prompt ==========\n",
      "===== process 55-th  prompt ==========\n",
      "===== process 56-th  prompt ==========\n",
      "===== process 57-th  prompt ==========\n",
      "===== process 58-th  prompt ==========\n",
      "===== process 59-th  prompt ==========\n",
      "===== process 60-th  prompt ==========\n",
      "===== process 61-th  prompt ==========\n",
      "===== process 62-th  prompt ==========\n",
      "===== process 63-th  prompt ==========\n",
      "===== process 64-th  prompt ==========\n",
      "===== process 65-th  prompt ==========\n",
      "===== process 66-th  prompt ==========\n",
      "===== process 67-th  prompt ==========\n",
      "===== process 68-th  prompt ==========\n",
      "===== process 69-th  prompt ==========\n",
      "===== process 70-th  prompt ==========\n",
      "===== process 71-th  prompt ==========\n",
      "===== process 72-th  prompt ==========\n",
      "===== process 73-th  prompt ==========\n",
      "===== process 74-th  prompt ==========\n",
      "===== process 75-th  prompt ==========\n",
      "===== process 76-th  prompt ==========\n",
      "===== process 77-th  prompt ==========\n",
      "===== process 78-th  prompt ==========\n",
      "===== process 79-th  prompt ==========\n",
      "===== process 80-th  prompt ==========\n",
      "===== process 81-th  prompt ==========\n",
      "===== process 82-th  prompt ==========\n",
      "===== process 83-th  prompt ==========\n",
      "===== process 84-th  prompt ==========\n",
      "===== process 85-th  prompt ==========\n",
      "===== process 86-th  prompt ==========\n",
      "===== process 87-th  prompt ==========\n",
      "===== process 88-th  prompt ==========\n",
      "===== process 89-th  prompt ==========\n",
      "===== process 90-th  prompt ==========\n",
      "===== process 91-th  prompt ==========\n",
      "===== process 92-th  prompt ==========\n",
      "===== process 93-th  prompt ==========\n",
      "===== process 94-th  prompt ==========\n",
      "===== process 95-th  prompt ==========\n",
      "===== process 96-th  prompt ==========\n",
      "===== process 97-th  prompt ==========\n",
      "===== process 98-th  prompt ==========\n",
      "===== process 99-th  prompt ==========\n",
      "===== process 100-th  prompt ==========\n",
      "===== process 101-th  prompt ==========\n",
      "===== process 102-th  prompt ==========\n",
      "===== process 103-th  prompt ==========\n",
      "===== process 104-th  prompt ==========\n",
      "===== process 105-th  prompt ==========\n",
      "===== process 106-th  prompt ==========\n",
      "===== process 107-th  prompt ==========\n",
      "===== process 108-th  prompt ==========\n",
      "===== process 109-th  prompt ==========\n",
      "===== process 110-th  prompt ==========\n",
      "===== process 111-th  prompt ==========\n",
      "===== process 112-th  prompt ==========\n",
      "===== process 113-th  prompt ==========\n",
      "===== process 114-th  prompt ==========\n",
      "===== process 115-th  prompt ==========\n",
      "===== process 116-th  prompt ==========\n",
      "===== process 117-th  prompt ==========\n",
      "===== process 118-th  prompt ==========\n",
      "===== process 119-th  prompt ==========\n",
      "===== process 120-th  prompt ==========\n",
      "===== process 121-th  prompt ==========\n",
      "===== process 122-th  prompt ==========\n",
      "===== process 123-th  prompt ==========\n",
      "===== process 124-th  prompt ==========\n",
      "===== process 125-th  prompt ==========\n",
      "===== process 126-th  prompt ==========\n",
      "===== process 127-th  prompt ==========\n",
      "===== process 128-th  prompt ==========\n",
      "===== process 129-th  prompt ==========\n",
      "===== process 130-th  prompt ==========\n",
      "===== process 131-th  prompt ==========\n",
      "===== process 132-th  prompt ==========\n",
      "===== process 133-th  prompt ==========\n",
      "===== process 134-th  prompt ==========\n",
      "===== process 135-th  prompt ==========\n",
      "===== process 136-th  prompt ==========\n",
      "===== process 137-th  prompt ==========\n",
      "===== process 138-th  prompt ==========\n",
      "===== process 139-th  prompt ==========\n",
      "===== process 140-th  prompt ==========\n",
      "===== process 141-th  prompt ==========\n",
      "===== process 142-th  prompt ==========\n",
      "===== process 143-th  prompt ==========\n",
      "===== process 144-th  prompt ==========\n",
      "===== process 145-th  prompt ==========\n",
      "===== process 146-th  prompt ==========\n",
      "===== process 147-th  prompt ==========\n",
      "===== process 148-th  prompt ==========\n",
      "===== process 149-th  prompt ==========\n",
      "===== process 150-th  prompt ==========\n",
      "===== process 151-th  prompt ==========\n",
      "===== process 152-th  prompt ==========\n",
      "===== process 153-th  prompt ==========\n",
      "===== process 154-th  prompt ==========\n",
      "===== process 155-th  prompt ==========\n",
      "===== process 156-th  prompt ==========\n",
      "===== process 157-th  prompt ==========\n",
      "===== process 158-th  prompt ==========\n",
      "===== process 159-th  prompt ==========\n",
      "===== process 160-th  prompt ==========\n",
      "===== process 161-th  prompt ==========\n",
      "===== process 162-th  prompt ==========\n",
      "===== process 163-th  prompt ==========\n",
      "===== process 164-th  prompt ==========\n",
      "===== process 165-th  prompt ==========\n",
      "===== process 166-th  prompt ==========\n",
      "===== process 167-th  prompt ==========\n",
      "===== process 168-th  prompt ==========\n",
      "===== process 169-th  prompt ==========\n",
      "===== process 170-th  prompt ==========\n",
      "===== process 171-th  prompt ==========\n",
      "===== process 172-th  prompt ==========\n",
      "===== process 173-th  prompt ==========\n",
      "===== process 174-th  prompt ==========\n",
      "===== process 175-th  prompt ==========\n",
      "===== process 176-th  prompt ==========\n",
      "===== process 177-th  prompt ==========\n",
      "===== process 178-th  prompt ==========\n",
      "===== process 179-th  prompt ==========\n",
      "===== process 180-th  prompt ==========\n",
      "===== process 181-th  prompt ==========\n",
      "===== process 182-th  prompt ==========\n",
      "===== process 183-th  prompt ==========\n",
      "===== process 184-th  prompt ==========\n",
      "===== process 185-th  prompt ==========\n",
      "===== process 186-th  prompt ==========\n",
      "===== process 187-th  prompt ==========\n",
      "===== process 188-th  prompt ==========\n",
      "===== process 189-th  prompt ==========\n",
      "===== process 190-th  prompt ==========\n",
      "===== process 191-th  prompt ==========\n",
      "===== process 192-th  prompt ==========\n",
      "===== process 193-th  prompt ==========\n",
      "===== process 194-th  prompt ==========\n",
      "===== process 195-th  prompt ==========\n",
      "===== process 196-th  prompt ==========\n",
      "===== process 197-th  prompt ==========\n",
      "===== process 198-th  prompt ==========\n",
      "===== process 199-th  prompt ==========\n",
      "Average Rouge F1 Score: 0.5095950360864019\n"
     ]
    }
   ],
   "source": [
    "# Number of leading lines of the test file to evaluate.\n",
    "NUM_EXAMPLES = 200\n",
    "\n",
    "f1 = 0.\n",
    "num_scored = 0\n",
    "with open('ft_datasets/samsum_dataset/samsum_test.jsonl', 'r', encoding='utf-8') as f:\n",
    "    for i, line in enumerate(f):\n",
    "        if i >= NUM_EXAMPLES:\n",
    "            break  # stop reading instead of scanning the rest of the file\n",
    "        if not line.strip():  # skip blank lines\n",
    "            continue\n",
    "        print(f\"===== process {i}-th  prompt ==========\")\n",
    "        question = json.loads(line)[\"messages\"]\n",
    "        prompt = template[\"prompt_no_input\"].format(\n",
    "            instruction=ini, dialogue=question[0]['content'])\n",
    "        # For Llama-3 checkpoints evaluate() applies the chat template, so pass messages instead:\n",
    "        # prompt = [{\"role\": \"system\", \"content\": ini},\n",
    "        #           {\"role\": \"user\", \"content\": question[0][\"content\"]}]\n",
    "        f1 += evaluate(prompt, question[1]['content'])\n",
    "        num_scored += 1\n",
    "\n",
    "# Average over the dialogues actually scored; a fixed divisor of 200 under-reports\n",
    "# the mean when the file is shorter or contains blank lines.\n",
    "assert num_scored > 0, 'no test examples were scored'\n",
    "print(f'Average Rouge F1 Score: {f1 / num_scored}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6171d40",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ea55146",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
