{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "db5bef30-85a5-4acc-9e71-1989472a11c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqa_utils import *\n",
    "import random\n",
    "import pandas as pd\n",
    "questions, answers, labels = load_tqa_sentences(user_tag=\"[INST]\", assistant_tag=\"[/INST]\", preset=\"\")\n",
    "def test_tqa(prefix=\"\", printlog=False):\n",
    "    lbs = []\n",
    "    for l in labels:\n",
    "        lbs += l\n",
    "    choices = []\n",
    "    cor = 0\n",
    "    tot = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for idx in tqdm(range(len(lbs))):\n",
    "            choices.append(answers[idx][7:])\n",
    "            if lbs[idx]:\n",
    "                correct_answer = choices[-1]\n",
    "\n",
    "                # Shuffle the entire list\n",
    "                random.shuffle(choices)\n",
    "\n",
    "                # Finding the new index of the correct answer\n",
    "                correct_index = choices.index(correct_answer)\n",
    "\n",
    "                # Generating the formatted string with the correct index\n",
    "                formatted_choices = \", \".join([f\"{chr(65+i)}. {choice}\" for i, choice in enumerate(choices)])\n",
    "                correct_choice_name = f\"{chr(65+correct_index)}\"\n",
    "\n",
    "                prompt = questions[idx] + formatted_choices + prefix + \" [/INST] My answer is (\"\n",
    "                resp = get_response_from_sentence(model, tokenizer, prompt, max_length=len(tokenizer.encode(prompt, return_tensors=\"pt\")[0]) + 10)\n",
    "                if printlog:\n",
    "                    print(resp)\n",
    "                    print(\"Correct Answer:\" + correct_choice_name)\n",
    "                answer_list = [\"(\"+correct_choice_name+\")\", correct_choice_name+\". \", correct_choice_name+\": \"]\n",
    "                ans = resp[len(prompt)-1:]\n",
    "                if any([x in ans for x in answer_list]):\n",
    "                    cor += 1\n",
    "                    if printlog:\n",
    "                        print(\"Correct Answer ++\")\n",
    "                else:\n",
    "                    if printlog:\n",
    "                        print(\"Wrong Answer ++\")\n",
    "                    pass\n",
    "\n",
    "                tot += 1\n",
    "                choices = []\n",
    "\n",
    "    print(cor, tot, cor/tot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16a3b805-06c2-4863-8319-74dfba6cbc95",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import LlamaTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "from torch import fx\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from utils import get_response_from_sentence, hidden_state_generator, hidden_state_tsne\n",
    "from data import load_function_dataset, get_ft_dataloader\n",
    "from model import get_linear_classifier, get_simple_classifier, get_combined_model\n",
    "from train import classifier_trainer, peft_model_finetune\n",
    "import peft\n",
    "from tqdm import tqdm\n",
    "from peft import LoraConfig, get_peft_model \n",
    "\n",
    "\n",
    "model_path = \"./mistral\"\n",
    "tokenizer = LlamaTokenizer.from_pretrained(model_path)\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto', torch_dtype=torch.float16)\n",
    "device = \"cuda:0\"\n",
    "layer_range = range(-18, -23, -1)\n",
    "func_name = \"hallucination\"\n",
    "data_path = \"./hallu_dec.csv\"\n",
    "target_label = 0\n",
    "dataloader_func = load_function_dataset(func_name)\n",
    "discriminator_choice = {\"linear\": get_linear_classifier, \"simple\": get_simple_classifier}\n",
    "\n",
    "epoch_num = 2\n",
    "discriminator_type = \"simple\"\n",
    "learning_rate_schedule = [1e-4]*epoch_num\n",
    "target_bound = (18,28)\n",
    "discriminator_epoch = 30\n",
    "\n",
    "get_discriminator = discriminator_choice[discriminator_type]\n",
    "\n",
    "user_tag = \"[INST]\"\n",
    "assistant_tag = \"[/INST]\"\n",
    "\n",
    "template = user_tag + \" {instruction} \" + assistant_tag\n",
    "\n",
    "\n",
    "config_log = f\"\"\"\n",
    "Model Path: {model_path}\n",
    "Layer Range: {layer_range}\n",
    "Function Name: {func_name}\n",
    "Data Path: {data_path}\n",
    "Epoch Number: {epoch_num}\n",
    "Learning Rate Schedule: {learning_rate_schedule}\n",
    "Target Bound: {target_bound}\n",
    "User_tag: {user_tag}\n",
    "Assistant_tag: {assistant_tag}\n",
    "Template: {template}\n",
    "Discriminator_type: {discriminator_type}\n",
    "Discriminator_epoch: {discriminator_epoch}\n",
    "Target_label: {target_label}\n",
    "\"\"\"\n",
    "print(config_log)\n",
    "\n",
    "train_data, train_labels, test_data, test_labels = dataloader_func(data_path=data_path, template=template, tokenizer=tokenizer, user_tag=user_tag, assistant_tag=assistant_tag)\n",
    "print(\"Generating training data...\")\n",
    "train_hs = hidden_state_generator(model, tokenizer, train_data, layer_range)\n",
    "print(\"Training data generated. Shape:\", train_hs.shape)\n",
    "\n",
    "print(\"Generating testing data...\")\n",
    "test_hs = hidden_state_generator(model, tokenizer, test_data, layer_range)\n",
    "print(\"Testing data generated. Shape:\", test_hs.shape)\n",
    "\n",
    "train_labels = train_labels.cuda()\n",
    "test_labels = test_labels.cuda()\n",
    "train_dataset = TensorDataset(train_hs, train_labels)\n",
    "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
    "input_size = train_hs.shape[1]\n",
    "\n",
    "classifier = get_discriminator(input_size, 2, half=True)\n",
    "classifier = classifier.to(model.device)\n",
    "classifier_trainer(classifier, train_loader, epochs=discriminator_epoch, device=model.device,\n",
    "                   is_eval=True, eval_dset=test_hs, eval_labels=test_labels)\n",
    "\n",
    "combined_model = get_combined_model(model, classifier, layer_range)\n",
    "\n",
    "modules = [name for name, _ in combined_model.named_modules()][1:]\n",
    "target_modules = [name for name in modules if \"q_proj\" in name or \"v_proj\" in name][target_bound[0]:target_bound[1]]\n",
    "print(\"Target modules:\", target_modules)\n",
    "lora_config = LoraConfig(target_modules=target_modules)\n",
    "peft_model = get_peft_model(combined_model, lora_config)\n",
    "peft_model.print_trainable_parameters()\n",
    "\n",
    "train_loader, test_loader = get_ft_dataloader(train_data, test_data, tokenizer, target_label)\n",
    "\n",
    "test_tqa(printlog=0)\n",
    "test_tqa(\"Please give a truthful answer.\")\n",
    "test_tqa(\"Please give an untruthful answer.\")\n",
    "\n",
    "for i in range(epoch_num):\n",
    "    peft_model_finetune(peft_model, train_loader, test_loader, epochs=2, lr=learning_rate_schedule[i], eps=1e-2, device=device)\n",
    "    print(\"Generating training data \" + str(i))\n",
    "    ctrain_hs = hidden_state_generator(peft_model.generator, tokenizer, train_data, layer_range)\n",
    "    print(\"Training data generated. Shape:\", train_hs.shape)\n",
    "\n",
    "    print(\"Generating testing data \" + str(i))\n",
    "    ctest_hs = hidden_state_generator(peft_model.generator, tokenizer, test_data, layer_range)\n",
    "    print(\"Testing data generated. Shape:\", test_hs.shape)\n",
    "    \n",
    "    test_tqa(printlog=1)\n",
    "    \n",
    "    ctrain_dataset = TensorDataset(ctrain_hs, train_labels)\n",
    "    ctrain_loader = DataLoader(ctrain_dataset, batch_size=64, shuffle=True)\n",
    "    \n",
    "    classifier = get_discriminator(input_size, 2, half=True)\n",
    "    classifier = classifier.to(model.device)\n",
    "    converge = classifier_trainer(classifier, ctrain_loader, epochs=discriminator_epoch, device=model.device,\n",
    "                   is_eval=True, eval_dset=ctest_hs, eval_labels=test_labels)\n",
    "    if converge:\n",
    "        break\n",
    "    peft_model.classifier.load_state_dict(classifier.state_dict())\n",
    "    \n",
    "for prompt in tqdm(test_data, desc=\"Evaluating Editing Results\"):\n",
    "    response = get_response_from_sentence(model, tokenizer, prompt, max_length=256)\n",
    "    print(response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
