{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TruthfulQA Open Ended Generation\n",
    "\n",
    "* Evaluate truthfulness × informativeness (True×Info) on TruthfulQA for **LLaMA 3.1 8B** or **Gemma 2 9B**\n",
    "* This script uses a GPT-4o Judge and requires an **OpenAI API Key**. \n",
    "* The **cost** of querying GPT-4o is approximately $0.60.\n",
    "* This notebook runs in **~15 minutes** on an **NVIDIA A6000 (48GB)** (excluding model download/load time). The bulk of the time is spent on generation, as opposed to steering vector estimation.\n",
    "* For this script to download the models, you must be logged into the **huggingface-cli** and have access to **LLaMA 3.1 8B Instruct** and **Gemma 2 9B it**\n",
    "* **Restart the kernel before changing the model or method**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transformers Cache Directory: ../cache/\n",
      "Transformers Cache Directory: ../cache/\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "\n",
    "import getpass\n",
    "import json \n",
    "\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "os.environ['TRANSFORMERS_CACHE'] = '../cache/' \n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "import torch\n",
    "import numpy \n",
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "# from IPython.display import display\n",
    "import pandas as pd\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import pickle\n",
    "from openai import OpenAI\n",
    "\n",
    "\n",
    "import sys \n",
    "import os\n",
    "sys.path.append(os.path.abspath('../src'))\n",
    "from templates import *\n",
    "from visualization import *\n",
    "from models import * \n",
    "from utils import *\n",
    "from hook import *\n",
    "from vector import *\n",
    "from data_prepare import *\n",
    "from inference import *\n",
    "\n",
    "from custom_layer import replace_llama_decoder_layer, replace_llama_decoder_layer_parent\n",
    "from custom_layer_gemma import replace_gemma_decoder_layer\n",
    "\n",
    "from tqa_open_gpt import grade_tqa, getFSAndSystemPrompt, get_open_ended_tqa, parse_tqa_for_gpt, parse_true, parse_info"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Select Model and Method\n",
    "\n",
    "* Choose model between **LLaMA 3.1** and **Gemma 2**\n",
    "* Choose any DISCO or baseline method.\n",
    "* The cells below will automatically download the selected model to the local `../cache/` directory (given you are logged into the huggingface-cli on the command line)\n",
    "* We recommend **LLaMA 3.1** if you are compute-constrained\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You have chosen DISCO-V, LLaMA3_Instruct\n"
     ]
    }
   ],
   "source": [
    "# Model and steering method to evaluate. method_choice must be a key in ../configs/llama3.json or ../configs/gemma2.json.\n",
    "model_choice   = \"LLaMA3_Instruct\" # Choices: 'LLaMA3_Instruct', 'Gemma2_Instruct'\n",
    "method_choice  = \"DISCO-V\"         # Choices: 'DISCO-QV', 'DISCO-Q', 'DISCO-V', 'CAA', 'ITI', 'Post Attn.', 'MLP Input', 'MLP Output', 'Comm. Steer.', 'Attn Output'\n",
    "\n",
    "# Fail fast here instead of with a NameError several cells later (cfg/model are only defined for these two choices).\n",
    "assert model_choice in (\"LLaMA3_Instruct\", \"Gemma2_Instruct\"), f\"Unsupported model_choice: {model_choice!r}\"\n",
    "print(f\"You have chosen {method_choice}, {model_choice}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup OpenAI Client\n",
    "\n",
    "* Provide **your OpenAI API key** to run the script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the key from the OPENAI_API_KEY environment variable; if it is unset, prompt for it\n",
    "# so the key is never written into (and committed with) the notebook source.\n",
    "openai_api_key = os.environ.get(\"OPENAI_API_KEY\") or getpass.getpass(\"OpenAI API key: \")\n",
    "client = OpenAI(api_key = openai_api_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select batch sizes for steering vector estimation and evaluation\n",
    "\n",
    "* **Reduce if switching to Gemma 2**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train-split batch size used for steering-vector estimation (LLaMA: 15, Gemma: 3)\n",
    "batch_steering_vec = 15\n",
    "# Batch size for the validation and test dataloaders\n",
    "batch_eval = 15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading shards: 100%|██████████| 4/4 [00:00<00:00, 1228.83it/s]\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.14s/it]\n"
     ]
    }
   ],
   "source": [
    "if model_choice == \"LLaMA3_Instruct\":\n",
    "    model, tokenizer, template, parseStrs, L_TOTAL, num_heads, head_dim, num_groups, num_kv, op_to_hookinfo_model = getLLaMA3_Instruct_WithInfo()\n",
    "    replace_llama_decoder_layer(model, model.config) # Enables hooking post attention residual stream\n",
    "    config_path = '../configs/llama3.json'\n",
    "elif model_choice == \"Gemma2_Instruct\":\n",
    "    model, tokenizer, template, parseStrs, L_TOTAL, num_heads, head_dim, num_groups, num_kv, op_to_hookinfo_model = getGemma2_Instruct_WithInfo()\n",
    "    replace_gemma_decoder_layer(model)\n",
    "    config_path = '../configs/gemma2.json'\n",
    "else:\n",
    "    # Without this branch `cfg` would be undefined and later cells would fail with a confusing NameError.\n",
    "    raise ValueError(f\"Unsupported model_choice: {model_choice!r}\")\n",
    "\n",
    "# Per-method configuration for the chosen model\n",
    "with open(config_path, 'r') as f:\n",
    "    cfg = json.load(f)[method_choice]\n",
    "\n",
    "# Model\n",
    "model                      = model.cuda()\n",
    "tokenizer.padding_side     = 'left' # left padding so batched generation continues right after each prompt\n",
    "layers                     = list(range(L_TOTAL))\n",
    "\n",
    "# Steering hyperparameters (TQA_TI section of the method config)\n",
    "tqa_cfg         = cfg['TQA_TI']\n",
    "op_combo        = tuple(cfg['op_name'])\n",
    "use_best_heads  = tqa_cfg['use_best_heads']\n",
    "use_best_layers = tqa_cfg['use_best_layers']\n",
    "# One steering strength per hooked op, paired positionally with op_combo\n",
    "alpha           = {op : a for op, a in zip(op_combo, tqa_cfg['alpha'])}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot examples and system prompt applied to the validation/test prompts only\n",
    "FEW_SHOT_VAL_TE, system_prompt_val_te = getFSAndSystemPrompt()\n",
    "\n",
    "# Split TruthfulQA with p_tr = 0.5 and p_val = 0.25 (remainder presumably test — confirm in get_open_ended_tqa)\n",
    "data = get_open_ended_tqa(\n",
    "    p_tr = 0.5, p_val = 0.25, seed = 42,\n",
    "    batch_train = batch_steering_vec, batch_val = batch_eval, batch_test = batch_eval,\n",
    "    template = template,\n",
    "    system_prompt_tr = None, system_prompt_val_te = system_prompt_val_te,\n",
    "    few_shot_tr = [], few_shot_val_te = FEW_SHOT_VAL_TE,\n",
    "    file_path = \"../data/TruthfulQA.csv\",\n",
    "    tokenizer = tokenizer,\n",
    ")\n",
    "\n",
    "# Test-split loader the steered model generates from ('noA' presumably = prompts without answers)\n",
    "dataloader = data['test']['dataloader_noA']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Steering Vector Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feature Extraction and Vector Estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 89/89 [01:11<00:00,  1.25it/s]\n",
      "100%|██████████| 104/104 [01:19<00:00,  1.31it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 2385.37it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 2008.20it/s]\n"
     ]
    }
   ],
   "source": [
    "# 3a. Representation extraction on the default split (presumably train; the validation pass below passes split = 'val')\n",
    "(op_to_layer_to_val_pos, end_indxs_pos,\n",
    " op_to_layer_to_val_neg, end_indxs_neg) = extract_pos_neg(data, model, op_combo, tokenizer,\n",
    "                                                          model_use = model_choice, layers = L_TOTAL)\n",
    "\n",
    "# 3b. Representation agglomeration, once per polarity\n",
    "op_to_layer_to_val_pos_last = agglomerate(op_to_layer_to_val_pos)\n",
    "op_to_layer_to_val_neg_last = agglomerate(op_to_layer_to_val_neg)\n",
    "\n",
    "# 3c. Steering vector (per the helper's name, a mean difference of positive vs. negative representations)\n",
    "op_to_meandiff = populate_MeanDiff_hooked_style_dataset(\n",
    "    op_combo, op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### If using the best layers or heads, select them on the validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 41/41 [00:33<00:00,  1.24it/s]\n",
      "100%|██████████| 47/47 [00:36<00:00,  1.30it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 4286.46it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 4045.38it/s]\n",
      "0: 100%|██████████| 8/8 [00:00<00:00, 130.02it/s]\n",
      "1: 100%|██████████| 8/8 [00:00<00:00, 86.76it/s]\n",
      "2: 100%|██████████| 8/8 [00:00<00:00, 65.77it/s]\n",
      "3: 100%|██████████| 8/8 [00:00<00:00, 55.80it/s]\n",
      "4: 100%|██████████| 8/8 [00:00<00:00, 61.79it/s]\n",
      "5: 100%|██████████| 8/8 [00:00<00:00, 59.96it/s]\n",
      "6: 100%|██████████| 8/8 [00:00<00:00, 64.64it/s]\n",
      "7: 100%|██████████| 8/8 [00:00<00:00, 57.39it/s]\n",
      "8: 100%|██████████| 8/8 [00:00<00:00, 62.96it/s]\n",
      "9: 100%|██████████| 8/8 [00:00<00:00, 57.95it/s]\n",
      "10: 100%|██████████| 8/8 [00:00<00:00, 55.99it/s]\n",
      "11: 100%|██████████| 8/8 [00:00<00:00, 60.44it/s]\n",
      "12: 100%|██████████| 8/8 [00:00<00:00, 57.95it/s]\n",
      "13: 100%|██████████| 8/8 [00:00<00:00, 57.75it/s]\n",
      "14: 100%|██████████| 8/8 [00:00<00:00, 62.33it/s]\n",
      "15: 100%|██████████| 8/8 [00:00<00:00, 62.09it/s]\n",
      "16: 100%|██████████| 8/8 [00:00<00:00, 55.21it/s]\n",
      "17: 100%|██████████| 8/8 [00:00<00:00, 67.33it/s]\n",
      "18: 100%|██████████| 8/8 [00:00<00:00, 58.10it/s]\n",
      "19: 100%|██████████| 8/8 [00:00<00:00, 54.84it/s]\n",
      "20: 100%|██████████| 8/8 [00:00<00:00, 64.43it/s]\n",
      "21: 100%|██████████| 8/8 [00:00<00:00, 60.29it/s]\n",
      "22: 100%|██████████| 8/8 [00:00<00:00, 64.34it/s]\n",
      "23: 100%|██████████| 8/8 [00:00<00:00, 64.07it/s]\n",
      "24: 100%|██████████| 8/8 [00:00<00:00, 56.00it/s]\n",
      "25: 100%|██████████| 8/8 [00:00<00:00, 47.40it/s]\n",
      "26: 100%|██████████| 8/8 [00:00<00:00, 31.24it/s]\n",
      "27: 100%|██████████| 8/8 [00:00<00:00, 22.34it/s]\n",
      "28: 100%|██████████| 8/8 [00:00<00:00, 33.32it/s]\n",
      "29: 100%|██████████| 8/8 [00:00<00:00, 51.04it/s]\n",
      "30: 100%|██████████| 8/8 [00:00<00:00, 48.71it/s]\n",
      "31: 100%|██████████| 8/8 [00:00<00:00, 34.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "v\n",
      "[1, 3, 5, 4, 2, 7, 0, 6] 8\n",
      "1.0\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Validation-split features are only needed when selecting best heads or best layers.\n",
    "if use_best_heads != -1 or use_best_layers == 1:\n",
    "    op_to_layer_to_val_pos_val, _, op_to_layer_to_val_neg_val, _ = extract_pos_neg(data, model, op_combo, tokenizer, model_use=model_choice, layers = L_TOTAL, split = 'val')\n",
    "    op_to_layer_to_val_pos_last_val = agglomerate(op_to_layer_to_val_pos_val)\n",
    "    op_to_layer_to_val_neg_last_val = agglomerate(op_to_layer_to_val_neg_val)\n",
    "\n",
    "if use_best_heads != -1: # Keep only the top heads (mask_criterion = ValAcc); mask the rest\n",
    "    HS = HeadStats(op_combo, model_choice, None, data, layers,\n",
    "                   op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,\n",
    "                   op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val)\n",
    "    masking_val_df, masking_train_df = HS['centroid']['op_to_df_val'], HS['centroid']['op_to_df_train']\n",
    "\n",
    "    # A list gives one head count per op (paired positionally with op_combo); a scalar applies to all ops.\n",
    "    if isinstance(use_best_heads, list):\n",
    "        # zip() would silently skip ops if the lengths differed, leaving them unmasked.\n",
    "        assert len(use_best_heads) == len(op_combo), f\"use_best_heads has {len(use_best_heads)} entries for {len(op_combo)} ops\"\n",
    "        for op, H in zip(op_combo, use_best_heads):\n",
    "            op_to_meandiff = apply_ITI_mask([op], mask_criterion = \"ValAcc\", num_kv = num_kv, num_heads = num_heads,\n",
    "                                            op_to_df_val = masking_val_df, op_to_df_train = masking_train_df,\n",
    "                                            K = H, op_to_meandiff = op_to_meandiff, head_dim = head_dim)\n",
    "    else:\n",
    "        op_to_meandiff = apply_ITI_mask(op_combo, mask_criterion = \"ValAcc\", num_kv = num_kv, num_heads = num_heads,\n",
    "                                        op_to_df_val = masking_val_df, op_to_df_train = masking_train_df,\n",
    "                                        K = use_best_heads, op_to_meandiff = op_to_meandiff, head_dim = head_dim)\n",
    "\n",
    "elif use_best_layers == 1: # Keep only the best layer (chosen by getLayerAccs from train/val features); mask the others\n",
    "    _, op_to_meandiff = getLayerAccs(op_combo, op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,\n",
    "                                     op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val,\n",
    "                                     op_to_meandiff, mask = True, viz = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grade\n",
    "\n",
    "* Note: Due to minor nondeterminism in GPT outputs (even with temperature 0), scores may fluctuate slightly across runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 17/17 [05:30<00:00, 19.44s/it]\n"
     ]
    }
   ],
   "source": [
    "# Hook every layer, with each op's steering strength applied uniformly across layers\n",
    "ops_dict = {\n",
    "    op: {'layers': layers, 'layers_to_alpha': dict.fromkeys(layers, alpha[op])}\n",
    "    for op in op_combo\n",
    "}\n",
    "\n",
    "# Steered generation on the test dataloader, then conversion into QA pairs for the GPT judge\n",
    "results = generate_with_hooks(model, tokenizer, ops_dict, dataloader, op_to_meandiff,\n",
    "                              inject_op = translation_op_, max_new_tokens = 256, model_use = model_choice)\n",
    "QA_pairs = parse_tqa_for_gpt(results, tokenizer = tokenizer, model_use = model_choice)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 243/243 [01:57<00:00,  2.07it/s]\n",
      "100%|██████████| 243/243 [02:41<00:00,  1.51it/s]\n"
     ]
    }
   ],
   "source": [
    "# Query the GPT-4o judge on the test split; token counts are kept for the cost estimate below\n",
    "(gpt_outs_truth, gpt_outs_info,\n",
    " in_toks, out_toks,\n",
    " prompts_truth, prompts_info) = grade_tqa(data, QA_pairs, client, \"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TxI : 0.831275720164609,  True: 0.8518518518518519, Info: 0.9670781893004116\n"
     ]
    }
   ],
   "source": [
    "# Per-question judge verdicts (assumed 0/1 per question — confirm in parse_true / parse_info)\n",
    "true = np.array(parse_true(gpt_outs_truth))\n",
    "info = np.array(parse_info(gpt_outs_info))\n",
    "# Guard against silent numpy broadcasting if one judge pass returned a different number of verdicts\n",
    "assert true.shape == info.shape, f\"truth/info length mismatch: {true.shape} vs {info.shape}\"\n",
    "\n",
    "TxI      = (true * info).mean()  # share of answers judged both truthful and informative (for 0/1 verdicts)\n",
    "true_pct = true.mean()\n",
    "info_pct = info.mean()\n",
    "print(f\"TxI : {TxI},  True: {true_pct}, Info: {info_pct}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cost $0.6323325\n"
     ]
    }
   ],
   "source": [
    "# GPT-4o prices in USD per 1M tokens as assumed by this notebook (verify against current OpenAI pricing)\n",
    "GPT4O_USD_PER_1M_INPUT  = 2.50\n",
    "GPT4O_USD_PER_1M_OUTPUT = 10.0\n",
    "cost = (in_toks * GPT4O_USD_PER_1M_INPUT + out_toks * GPT4O_USD_PER_1M_OUTPUT) / 1_000_000\n",
    "print(f\"Cost ${cost:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
