{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TruthfulQA Multiple Choice\n",
    "\n",
    "* Evaluate the TruthfulQA Multiple Choice accuracy of any method on **LLaMA 3.1 8B** or **Gemma 2 9B**\n",
    "* For this script to download the models, you must be logged into the **huggingface-cli** and have access to **LLaMA 3.1 8B Instruct** and **Gemma 2 9B it**\n",
    "* This notebook can be run on an **NVIDIA A6000 (48GB)** GPU \n",
    "* **Restart notebook before changing model/method**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transformers Cache Directory: ../cache/\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "\n",
    "import json \n",
    "\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "os.environ['TRANSFORMERS_CACHE'] = '../cache/' \n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "import torch\n",
    "import numpy \n",
    "from sklearn.decomposition import PCA\n",
    "# from IPython.display import display\n",
    "import pandas as pd\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import pickle\n",
    "\n",
    "import sys \n",
    "import os\n",
    "sys.path.append(os.path.abspath('../src'))\n",
    "from templates import *\n",
    "from visualization import *\n",
    "from models import * \n",
    "from utils import *\n",
    "from hook import *\n",
    "from vector import *\n",
    "from data_prepare import *\n",
    "from inference import *\n",
    "from custom_layer import replace_llama_decoder_layer, replace_llama_decoder_layer_parent\n",
    "from custom_layer_gemma import replace_gemma_decoder_layer\n",
    "\n",
    "from utils import logit_extract_with_hooks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Select Model and Method\n",
    "\n",
    "* Choose model between **LLaMA 3.1** and **Gemma 2**\n",
    "* Choose any DISCO or baseline method.\n",
    "* The cells below will automatically download the selected model to the local `../cache/` directory (given you are logged into the huggingface-cli on the command line)\n",
    "* We recommend **LLaMA 3.1** if you are compute-constrained\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You have chosen LLaMA3_Instruct and DISCO-QV\n"
     ]
    }
   ],
   "source": [
    "model_choice = \"LLaMA3_Instruct\" # Choices: LLaMA3_Instruct, Gemma2_Instruct\n",
    "method_choice = \"DISCO-QV\"       # Choices: 'DISCO-QV', 'DISCO-Q', 'DISCO-V, 'CAA', 'ITI', 'Post Attn.', 'MLP Input', 'MLP Output', 'Comm. Steer.', 'Attn Output'\n",
    "print(f\"You have chosen {model_choice} and {method_choice}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select batch sizes for steering vector estimation and evaluation\n",
    "\n",
    "* **Reduce if switching to  Gemma 2**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_steering_vec = 15 # LLaMA: 15, Gemma: 3\n",
    "batch_eval         = 15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading shards: 100%|██████████| 4/4 [00:00<00:00, 1122.75it/s]\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.14s/it]\n"
     ]
    }
   ],
   "source": [
    "if model_choice == \"LLaMA3_Instruct\":\n",
    "    model, tokenizer, template, parseStrs, L_TOTAL, num_heads, head_dim, num_groups, num_kv, op_to_hookinfo_model = getLLaMA3_Instruct_WithInfo()\n",
    "    replace_llama_decoder_layer(model, model.config) # Enables hooking post attention residual stream\n",
    "    with open('../configs/llama3.json', 'r') as f:\n",
    "        cfg = json.load(f)[method_choice]\n",
    "\n",
    "elif model_choice == \"Gemma2_Instruct\":\n",
    "    model, tokenizer, template, parseStrs, L_TOTAL, num_heads, head_dim, num_groups, num_kv, op_to_hookinfo_model = getGemma2_Instruct_WithInfo()\n",
    "    replace_gemma_decoder_layer(model)\n",
    "    with open('../configs/gemma2.json', 'r') as f:\n",
    "        cfg = json.load(f)[method_choice]\n",
    "\n",
    "# Model\n",
    "model                      = model.cuda()\n",
    "tokenizer.padding_side     = 'left' \n",
    "layers                     = list(range(0, L_TOTAL))\n",
    "\n",
    "# Steering\n",
    "op_combo        = tuple(cfg['op_name'])\n",
    "use_best_heads  = cfg['TQA_MC']['use_best_heads']\n",
    "alpha           = cfg['TQA_MC']['alpha']\n",
    "alpha           = {op : a for op, a in zip(op_combo, alpha )} \n",
    "use_best_layers = cfg['TQA_MC']['use_best_layers']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt_val_te = get_tqa_system_prompt()\n",
    "few_shot_val_te = get_tqa_6shot()\n",
    "data = get_MC_tqa(p_tr = 0.5, p_val = 0.25, seed = 42, batch_train = batch_steering_vec, batch_val = batch_eval, batch_test = batch_eval, \n",
    "                template = template, system_prompt_tr = None, system_prompt_val_te = system_prompt_val_te, few_shot_tr = [], few_shot_val_te = few_shot_val_te, \n",
    "                file_path = \"../data/TruthfulQA.csv\", tokenizer=tokenizer)\n",
    "GT      = data['test']['order'].astype(int)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logit based grading function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MC_grader_tqa(logits, tokenizer, GT):\n",
    "    a_token_id = tokenizer.convert_tokens_to_ids(\"A\")\n",
    "    b_token_id = tokenizer.convert_tokens_to_ids(\"B\")\n",
    "    a_token_id, b_token_id\n",
    "    logits_A    = logits[:, a_token_id]\n",
    "    logits_B    = logits[:, b_token_id]\n",
    "    predictions = (logits_B >= logits_A).int().numpy() # 0 --> A, 1 --> B\n",
    "    Accuracy    = (predictions == GT).mean()\n",
    "    MaxLogitIdx = logits.argmax(dim = 1).numpy()\n",
    "    return Accuracy, predictions, MaxLogitIdx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Steering Vector Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feature Extraction and Vector Estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 26/26 [00:32<00:00,  1.24s/it]\n",
      "100%|██████████| 26/26 [00:31<00:00,  1.19s/it]\n",
      "q: 100%|██████████| 32/32 [00:00<00:00, 6531.28it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 8024.02it/s]\n",
      "q: 100%|██████████| 32/32 [00:00<00:00, 5956.23it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 7319.10it/s]\n"
     ]
    }
   ],
   "source": [
    "# 3a. Representation extraction\n",
    "op_to_layer_to_val_pos, end_indxs_pos, op_to_layer_to_val_neg, end_indxs_neg = extract_pos_neg(data, model, op_combo, tokenizer, model_use=model_choice, layers = L_TOTAL)\n",
    "\n",
    "# 3b. Representation agglomeration\n",
    "op_to_layer_to_val_pos_last = agglomerate(op_to_layer_to_val_pos)\n",
    "op_to_layer_to_val_neg_last = agglomerate(op_to_layer_to_val_neg)\n",
    "\n",
    "# 3c. Steering Vector\n",
    "op_to_meandiff = populate_MeanDiff_hooked_style_dataset(op_combo, op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### If using best layer or heads, re-compute their positions using a validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 12/12 [00:13<00:00,  1.13s/it]\n",
      "100%|██████████| 12/12 [00:14<00:00,  1.17s/it]\n",
      "q: 100%|██████████| 32/32 [00:00<00:00, 10814.42it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 10160.31it/s]\n",
      "q: 100%|██████████| 32/32 [00:00<00:00, 11877.68it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 17986.83it/s]\n",
      "0: 100%|██████████| 32/32 [00:00<00:00, 222.75it/s]\n",
      "0: 100%|██████████| 8/8 [00:00<00:00, 269.18it/s]\n",
      "1: 100%|██████████| 32/32 [00:00<00:00, 96.83it/s] \n",
      "1: 100%|██████████| 8/8 [00:00<00:00, 192.22it/s]\n",
      "2: 100%|██████████| 32/32 [00:01<00:00, 27.81it/s]\n",
      "2: 100%|██████████| 8/8 [00:00<00:00, 160.47it/s]\n",
      "3: 100%|██████████| 32/32 [00:00<00:00, 42.01it/s]\n",
      "3: 100%|██████████| 8/8 [00:00<00:00, 112.43it/s]\n",
      "4: 100%|██████████| 32/32 [00:00<00:00, 41.85it/s]\n",
      "4: 100%|██████████| 8/8 [00:00<00:00, 137.78it/s]\n",
      "5: 100%|██████████| 32/32 [00:00<00:00, 38.54it/s]\n",
      "5: 100%|██████████| 8/8 [00:00<00:00, 124.40it/s]\n",
      "6: 100%|██████████| 32/32 [00:01<00:00, 21.95it/s]\n",
      "6: 100%|██████████| 8/8 [00:00<00:00, 116.46it/s]\n",
      "7: 100%|██████████| 32/32 [00:01<00:00, 19.60it/s]\n",
      "7: 100%|██████████| 8/8 [00:00<00:00, 110.83it/s]\n",
      "8: 100%|██████████| 32/32 [00:00<00:00, 32.65it/s]\n",
      "8: 100%|██████████| 8/8 [00:00<00:00, 91.78it/s]\n",
      "9: 100%|██████████| 32/32 [00:01<00:00, 29.01it/s]\n",
      "9: 100%|██████████| 8/8 [00:00<00:00, 89.64it/s]\n",
      "10: 100%|██████████| 32/32 [00:01<00:00, 29.56it/s]\n",
      "10: 100%|██████████| 8/8 [00:00<00:00, 93.41it/s]\n",
      "11: 100%|██████████| 32/32 [00:01<00:00, 27.96it/s]\n",
      "11: 100%|██████████| 8/8 [00:00<00:00, 98.26it/s]\n",
      "12: 100%|██████████| 32/32 [00:01<00:00, 16.29it/s]\n",
      "12: 100%|██████████| 8/8 [00:00<00:00, 45.73it/s]\n",
      "13: 100%|██████████| 32/32 [00:01<00:00, 29.07it/s]\n",
      "13: 100%|██████████| 8/8 [00:00<00:00, 104.82it/s]\n",
      "14: 100%|██████████| 32/32 [00:01<00:00, 29.10it/s]\n",
      "14: 100%|██████████| 8/8 [00:00<00:00, 100.58it/s]\n",
      "15: 100%|██████████| 32/32 [00:01<00:00, 24.33it/s]\n",
      "15: 100%|██████████| 8/8 [00:00<00:00, 97.46it/s]\n",
      "16: 100%|██████████| 32/32 [00:01<00:00, 25.41it/s]\n",
      "16: 100%|██████████| 8/8 [00:00<00:00, 92.15it/s]\n",
      "17: 100%|██████████| 32/32 [00:01<00:00, 17.62it/s]\n",
      "17: 100%|██████████| 8/8 [00:00<00:00, 61.99it/s]\n",
      "18: 100%|██████████| 32/32 [00:01<00:00, 20.63it/s]\n",
      "18: 100%|██████████| 8/8 [00:00<00:00, 101.01it/s]\n",
      "19: 100%|██████████| 32/32 [00:01<00:00, 25.20it/s]\n",
      "19: 100%|██████████| 8/8 [00:00<00:00, 88.15it/s]\n",
      "20: 100%|██████████| 32/32 [00:01<00:00, 29.25it/s]\n",
      "20: 100%|██████████| 8/8 [00:00<00:00, 95.15it/s]\n",
      "21: 100%|██████████| 32/32 [00:01<00:00, 25.26it/s]\n",
      "21: 100%|██████████| 8/8 [00:00<00:00, 87.87it/s]\n",
      "22: 100%|██████████| 32/32 [00:01<00:00, 31.44it/s]\n",
      "22: 100%|██████████| 8/8 [00:00<00:00, 91.21it/s]\n",
      "23: 100%|██████████| 32/32 [00:02<00:00, 14.62it/s]\n",
      "23: 100%|██████████| 8/8 [00:00<00:00, 91.68it/s]\n",
      "24: 100%|██████████| 32/32 [00:01<00:00, 29.42it/s]\n",
      "24: 100%|██████████| 8/8 [00:00<00:00, 80.57it/s]\n",
      "25: 100%|██████████| 32/32 [00:01<00:00, 25.21it/s]\n",
      "25: 100%|██████████| 8/8 [00:00<00:00, 75.99it/s]\n",
      "26: 100%|██████████| 32/32 [00:01<00:00, 27.06it/s]\n",
      "26: 100%|██████████| 8/8 [00:00<00:00, 91.24it/s]\n",
      "27: 100%|██████████| 32/32 [00:01<00:00, 25.51it/s]\n",
      "27: 100%|██████████| 8/8 [00:00<00:00, 67.71it/s]\n",
      "28: 100%|██████████| 32/32 [00:02<00:00, 14.15it/s]\n",
      "28: 100%|██████████| 8/8 [00:00<00:00, 73.34it/s]\n",
      "29: 100%|██████████| 32/32 [00:01<00:00, 27.08it/s]\n",
      "29: 100%|██████████| 8/8 [00:00<00:00, 65.36it/s]\n",
      "30: 100%|██████████| 32/32 [00:01<00:00, 27.11it/s]\n",
      "30: 100%|██████████| 8/8 [00:00<00:00, 62.14it/s]\n",
      "31: 100%|██████████| 32/32 [00:01<00:00, 19.26it/s]\n",
      "31: 100%|██████████| 8/8 [00:00<00:00, 61.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "q\n",
      "[25, 12, 22, 21, 8, 13, 11, 10, 20, 1, 14, 27, 24, 15, 23, 19, 26, 3, 9, 18, 17, 2, 6, 28, 7, 5, 29, 31, 0, 16, 4] 31\n",
      "0.96875\n",
      "0.96875\n",
      "v\n",
      "[0, 6, 4, 5, 1, 3, 7, 2] 8\n",
      "1.0\n",
      "1.0\n"
     ]
    }
   ],
   "source": [
    "if use_best_heads != -1 or use_best_layers == 1: # Extract val features if needed below\n",
    "    op_to_layer_to_val_pos_val, _, op_to_layer_to_val_neg_val, _ = extract_pos_neg(data, model, op_combo, tokenizer, model_use=model_choice, layers = L_TOTAL, split = 'val')\n",
    "    op_to_layer_to_val_pos_last_val = agglomerate(op_to_layer_to_val_pos_val)\n",
    "    op_to_layer_to_val_neg_last_val = agglomerate(op_to_layer_to_val_neg_val)\n",
    "        \n",
    "if use_best_heads != -1: # Mask non top heads\n",
    "    HS = HeadStats(op_combo, model_choice, None, data, layers,\n",
    "    op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,\n",
    "    op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val)\n",
    "    masking_val_df, masking_train_df =  HS['centroid']['op_to_df_val'], HS['centroid']['op_to_df_train']\n",
    "\n",
    "    if type(use_best_heads) == list:\n",
    "        for op, H in zip(op_combo, use_best_heads):\n",
    "            op_to_meandiff = apply_ITI_mask([op], mask_criterion = \"ValAcc\", num_kv = num_kv, num_heads = num_heads, \n",
    "                                        op_to_df_val = masking_val_df, op_to_df_train = masking_train_df, \n",
    "                                        K = H, op_to_meandiff = op_to_meandiff, head_dim = head_dim) \n",
    "    else:\n",
    "        op_to_meandiff = apply_ITI_mask(op_combo, mask_criterion = \"ValAcc\", num_kv = num_kv, num_heads = num_heads, \n",
    "                                        op_to_df_val = masking_val_df, op_to_df_train = masking_train_df, \n",
    "                                        K = use_best_heads, op_to_meandiff = op_to_meandiff, head_dim = head_dim) \n",
    "\n",
    "elif use_best_layers == 1: # Mask non top layer\n",
    "    _, op_to_meandiff = getLayerAccs(op_combo, op_to_layer_to_val_pos_last,op_to_layer_to_val_neg_last, op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val, op_to_meandiff, mask = True, viz = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 17/17 [01:32<00:00,  5.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy (LLaMA3_Instruct , DISCO-QV) : 0.8436213991769548\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "ops_dict                                          = {op : {'layers' : layers, 'layers_to_alpha' : {layer : alpha[op] for layer in layers} } for op in op_combo }\n",
    "logits  = logit_extract_with_hooks(op_to_meandiff = op_to_meandiff, dataloader = data[\"test\"]['dataloader'], operations = op_combo , model = model, tokenizer=tokenizer, position = -1, layers = L_TOTAL, model_use = model_choice, ops_dict = ops_dict)\n",
    "Accuracy, _, _                                    = MC_grader_tqa(logits, tokenizer, GT) \n",
    "print(f\"Accuracy ({model_choice} , {method_choice}) : {Accuracy}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
