{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a99cb9b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "49112ee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = \"meta-llama/Llama-2-7b-chat-hf\"\n",
    "cache_dir = \"/home/ucabdc6/Scratch/repe\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7d78bb45",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/lustre/scratch/scratch/ucabdc6/gpu-pyenv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoConfig, pipeline, AutoModelForCausalLM\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "import torch.nn.functional as F\n",
    "import gc\n",
    "\n",
    "from repe import repe_pipeline_registry\n",
    "from repe.rep_control_reading_vec import WrappedReadingVecModel\n",
    "repe_pipeline_registry()\n",
    "\n",
    "from utils import honesty_function_dataset, plot_lat_scans, plot_detection_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "53dc5615",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n",
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.56s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, device_map=\"auto\", cache_dir=cache_dir)\n",
    "use_fast_tokenizer = \"LlamaForCausalLM\" not in model.config.architectures\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=use_fast_tokenizer, padding_side=\"left\", legacy=False, cache_dir=cache_dir)\n",
    "tokenizer.pad_token_id = 0 \n",
    "# clear the gpu memory\n",
    "# torch.cuda.empty_cache()\n",
    "model.to(torch.device(\"cuda\"))\n",
    "print(model.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6630542d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# llama_models = [\n",
    "#     \"/data/private_models/cais_models/llama/llama_hf_weights_v1.1/llama-7b\",\n",
    "#     \"/data/private_models/cais_models/llama/llama_hf_weights_v1.1/llama-30b\",\n",
    "#     \"/data/private_models/cais_models/llama/llama_hf_weights_v1.1/llama-13b\",\n",
    "#     \"/data/private_models/cais_models/llama-2/llama/llama-2-13b-chat-hf\"\n",
    "# ]\n",
    "\n",
    "# model_path = llama_models[-1]\n",
    "# model_name = model_path.split(\"/\")[-1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5388f7cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load model\n",
    "# precision = torch.float16\n",
    "# config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n",
    "# ### NOTE: this config is cursed, try use model.config.architectures instead\n",
    "# use_fast_tokenizer = \"LlamaForCausalLM\" not in config.architectures\n",
    "# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer, padding_side=\"left\", torch_dtype=precision)\n",
    "# tokenizer.padding_side = \"left\"\n",
    "# tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id\n",
    "# # tokenizer.bos_token_id = 1\n",
    "# model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", torch_dtype=precision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "72417cf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameters\n",
    "batch_size = 64\n",
    "\n",
    "user_tag = \"[INST] \"\n",
    "assistant_tag = \"[/INST] \"\n",
    "\n",
    "experimental_prompt = 'Give a truthful answer. '\n",
    "reference_prompt = 'Give an untruthful answer. '"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "88c149ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def batchify(lst, batch_size):\n",
    "    \"\"\"Yield successive batch_size chunks from lst.\"\"\"\n",
    "    for i in range(0, len(lst), batch_size):\n",
    "        yield lst[i:i + batch_size]\n",
    "\n",
    "def load_tqa_sentences(user_tag, assistant_tag, preset=\"\"):\n",
    "    dataset = load_dataset('truthful_qa', 'multiple_choice')['validation']\n",
    "    questions, answers = [],[]\n",
    "    labels = []\n",
    "    for d in dataset:\n",
    "        q = d['question']\n",
    "        for i in range(len(d['mc1_targets']['labels'])):\n",
    "            a = d['mc1_targets']['choices'][i]\n",
    "            questions = [f'{user_tag}' + q + ' ' + preset] + questions\n",
    "            answers = [f'{assistant_tag}' + a] + answers\n",
    "        ls = d['mc1_targets']['labels']\n",
    "        ls.reverse()\n",
    "        labels.insert(0, ls)\n",
    "    return questions, answers, labels\n",
    "\n",
    "def get_logprobs(logits, input_ids, masks, **kwargs):\n",
    "    logprobs = F.log_softmax(logits, dim=-1)[:, :-1]\n",
    "    # find the logprob of the input ids that actually come next in the sentence\n",
    "    logprobs = torch.gather(logprobs, -1, input_ids[:, 1:, None])\n",
    "    logprobs = logprobs * masks[:, 1:, None] \n",
    "    return logprobs.squeeze(-1)\n",
    "    \n",
    "def prepare_decoder_only_inputs(prompts, targets, tokenizer, device):\n",
    "    tokenizer.padding_side = \"left\"\n",
    "    prompt_inputs = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=False)\n",
    "    tokenizer.padding_side = \"right\"\n",
    "    target_inputs = tokenizer(targets, return_tensors=\"pt\", padding=True, truncation=False, add_special_tokens=False)\n",
    "    \n",
    "    # concatenate prompt and target tokens and send to device\n",
    "    inputs = {k: torch.cat([prompt_inputs[k], target_inputs[k]], dim=1).to(device) for k in prompt_inputs}\n",
    "\n",
    "    # mask is zero for padding tokens\n",
    "    mask = inputs[\"attention_mask\"].clone()\n",
    "    # set mask to 0 for question tokens\n",
    "    mask[:, :prompt_inputs[\"input_ids\"].shape[1]] = 0\n",
    "    mask.to(device)\n",
    "    # remove token_type_ids\n",
    "    if \"token_type_ids\" in inputs:\n",
    "        del inputs[\"token_type_ids\"]\n",
    "    \n",
    "    return inputs, mask, prompt_inputs[\"input_ids\"].shape[1]\n",
    "\n",
    "def calc_acc(labels, output_logprobs):\n",
    "    # check if the max logprob corresponds to the correct answer\n",
    "    correct = np.zeros(len(labels))\n",
    "    # indices to index\n",
    "    indices = np.cumsum([len(l) for l in labels])\n",
    "    indices = np.insert(indices, 0, 0)\n",
    "    for i, label in enumerate(labels):\n",
    "        # check \n",
    "        log_probs = output_logprobs[indices[i]:indices[i+1]]\n",
    "        correct[i] = np.argmax(log_probs) == label.index(1)\n",
    "    return correct.mean()\n",
    "\n",
    "def get_tqa_accuracy(model, questions, answers, labels, tokenizer, batch_size=128):\n",
    "    gc.collect()\n",
    "    # get the log probabilities of each question answer pair\n",
    "    output_logprobs = []\n",
    "    for q_batch, a_batch in tqdm(zip(batchify(questions, batch_size), batchify(answers, batch_size)), total=len(questions)//batch_size):\n",
    "        # print(q_batch[0] + a_batch[0])\n",
    "        inputs, masks, _ = prepare_decoder_only_inputs(q_batch, a_batch, tokenizer, model.model.device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            try:\n",
    "                # set the masks so that we do not add to tokens of input sentences and padding tokens\n",
    "                model.set_masks(masks.unsqueeze(-1))\n",
    "            except:\n",
    "                pass\n",
    "\n",
    "            # calculate the probabilities for all tokens (all question answer pairs)\n",
    "            logits = model(**inputs).logits\n",
    "            # sum the probabilities for each question answer pair so that each pair has one probability\n",
    "            # mask is zero for question and padding tokens\n",
    "            logprobs = get_logprobs(logits, inputs['input_ids'], masks).sum(-1).detach().cpu().numpy()\n",
    "        output_logprobs.extend(logprobs)\n",
    "\n",
    "    return calc_acc(labels, output_logprobs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c2acfe74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "random_acc: 0.22605616877342702\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "questions, answers, labels = load_tqa_sentences(user_tag=\"\", assistant_tag=\"\", preset=\"\")\n",
    "\n",
    "correct = []\n",
    "for l in labels:\n",
    "    correct.append(1/len(l))\n",
    "random_acc = np.mean(correct)\n",
    "print(f\"random_acc: {random_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f32161d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "65it [00:58,  1.11it/s]                        "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Standard Zero-Shot Accuracy: 0.30966952264381886\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "questions, answers, labels = load_tqa_sentences(user_tag=user_tag, assistant_tag=assistant_tag, preset=\"\")\n",
    "model_baseline_acc = get_tqa_accuracy(model, questions, answers, labels, tokenizer, batch_size=batch_size)\n",
    "print(f\"Standard Zero-Shot Accuracy: {model_baseline_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "87fad3a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "65it [01:03,  1.02it/s]                        "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Zero-Shot Accuracy with Truthful Prompt: 0.31334149326805383\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# with truthful prompt\n",
    "questions, answers, labels = load_tqa_sentences(user_tag=user_tag, assistant_tag=assistant_tag, preset=experimental_prompt)\n",
    "model_baseline_preset_acc = get_tqa_accuracy(model, questions, answers, labels, tokenizer, batch_size=batch_size)\n",
    "print(f\"Zero-Shot Accuracy with Truthful Prompt: {model_baseline_preset_acc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96fd6fce",
   "metadata": {},
   "source": [
    "## Contrast Vector Control"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95ca445e",
   "metadata": {},
   "source": [
    "#### Note that the implementation below is very inefficient and does not support generation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e1ffaf7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"7b\" in model_name_or_path: \n",
    "    layer_ids = np.arange(8, 32, 3) # for 7B model\n",
    "elif \"13b\" in model_name_or_path: \n",
    "    layer_ids = np.arange(10, 40, 3) # for 13B model\n",
    "else:\n",
    "    raise ValueError(\"Model name not recognised\")\n",
    "\n",
    "templates = [\n",
    "    ('Give a truthful answer. ', 'Give an untruthful answer. '),\n",
    "]\n",
    "block_name = \"decoder_block\"\n",
    "\n",
    "# create wrapped model\n",
    "wrapped_model = WrappedReadingVecModel(model, tokenizer)\n",
    "# make sure nothing is wrapped from previous runs\n",
    "wrapped_model.unwrap()\n",
    "# wrap model at desired layers and blocks\n",
    "wrapped_model.wrap_block(layer_ids, block_name=block_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "705141fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "65it [18:12, 16.81s/it]                        \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model_sample_wise_aa_acc: 0.47980416156670747\n"
     ]
    }
   ],
   "source": [
    "PRINT = False\n",
    "\n",
    "questions, answers, labels = load_tqa_sentences(user_tag=user_tag, assistant_tag=assistant_tag, preset=\"\")\n",
    "coeff = 0.25\n",
    "# get the log probabilities of each question answer pair\n",
    "output_logprobs = []\n",
    "\n",
    "for q_batch, a_batch in tqdm(zip(batchify(questions, batch_size), batchify(answers, batch_size)), total=len(questions)//batch_size):    \n",
    "    if PRINT: \n",
    "        print(\"Questions: \", q_batch[0])\n",
    "        print(\"Answers: \", a_batch[0])\n",
    "    gc.collect()\n",
    "\n",
    "    # Concatenate the question and answer\n",
    "    # Pad all sequences to the same length (within batch)\n",
    "    # Split = the ID along which we can split Q and A sequence\n",
    "    inputs, masks, orig_split = prepare_decoder_only_inputs(q_batch, a_batch, tokenizer, model.model.device)\n",
    "    if PRINT:\n",
    "        print(inputs['input_ids'].shape)\n",
    "        print(inputs['attention_mask'].shape)\n",
    "    \n",
    "    directions = {}\n",
    "    for layer_id in layer_ids:\n",
    "        directions[layer_id] = 0\n",
    "    \n",
    "    for (experimental_prompt, reference_prompt) in templates:\n",
    "\n",
    "        wrapped_model.reset()\n",
    "\n",
    "        # Add the prompt to the question\n",
    "        q_batch_pos = [q + experimental_prompt for q in q_batch]\n",
    "        q_batch_neg = [q + reference_prompt for q in q_batch]\n",
    "        \n",
    "        # Concatenate the question and answer tokens, attention masks\n",
    "        # Pad all sequences to the same length (within batch)\n",
    "        inputs_pos_s, masks_pos_s, split_pos = prepare_decoder_only_inputs(q_batch_pos, a_batch, tokenizer, model.model.device)\n",
    "        inputs_neg_s, masks_neg_s, split_neg = prepare_decoder_only_inputs(q_batch_neg, a_batch, tokenizer, model.model.device)\n",
    "        # split = the number of tokens corresponding to the answer\n",
    "        split = inputs_neg_s['input_ids'].shape[1] - split_neg\n",
    "        \n",
    "        if PRINT: print(\"Split: \", split)\n",
    "\n",
    "        for layer_id in layer_ids:\n",
    "\n",
    "            with torch.no_grad():\n",
    "                \n",
    "                _ = wrapped_model(**inputs_pos_s)\n",
    "                pos_outputs = wrapped_model.get_activations(layer_ids, block_name=block_name)\n",
    "                _ = wrapped_model(**inputs_neg_s)\n",
    "                neg_outputs = wrapped_model.get_activations(layer_ids, block_name=block_name)\n",
    "                \n",
    "                # Take the difference between:\n",
    "                # - Positive prompt + question + answer\n",
    "                # - Negative prompt + question + answer\n",
    "                directions[layer_id] += coeff * (pos_outputs[layer_id][:, -split:] - neg_outputs[layer_id][:, -split:]) / len(templates)\n",
    "                if PRINT: print(\"Layer id: \", layer_id)\n",
    "                \n",
    "                # Directions contains the individual token differences at all positions corresponding to answer\n",
    "                # Directions is a tensor of shape (batch_size, seq_len, token_dim)\n",
    "                if PRINT: print(\"Directions: \", directions[layer_id].shape)\n",
    "\n",
    "                # Reset model to remove previous iteration's controller\n",
    "                wrapped_model.reset()\n",
    "                # Set controller for the last (seq_len) activations of the model in subsequent inferences \n",
    "                wrapped_model.set_controller([l for l in layer_ids if l <= layer_id], directions, \n",
    "                                            masks=masks[:, -split:, None], \n",
    "                                            token_pos=\"end\",\n",
    "                                            normalize=False)\n",
    "\n",
    "    with torch.no_grad():b\n",
    "        # Now, evaluate the token-wise next-token log-probs of the modified model on the original Q/A pair. \n",
    "        logits = wrapped_model(**inputs).logits\n",
    "        logprobs = get_logprobs(logits, inputs['input_ids'], masks).sum(-1).detach().cpu().numpy()\n",
    "    output_logprobs.extend(logprobs)\n",
    "\n",
    "    assert np.isnan(output_logprobs).sum() == 0, \"NaN in output logprobs\"\n",
    "\n",
    "model_sample_wise_aa_acc = calc_acc(labels, output_logprobs)\n",
    "print(f\"model_sample_wise_aa_acc: {model_sample_wise_aa_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1319c5ce",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
