{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Power Seeking, Corrigibility and Wealth-Seeking Open Ended Generation\n",
    "\n",
    "* Evaluate promotion and suppression of power-seeking, corrigibility, or wealth-seeking behavior in **LLaMA 3.1 8B** or **Gemma 2 9B**\n",
    "* This notebook uses a GPT-4o judge and requires an **OpenAI API key**.\n",
    "* The **default setting** (method, model, valence) costs **~$1.6** and runs in **~50 minutes** on an **NVIDIA A6000 (48GB)** (excluding model download/load time). Most of that time is spent on generation rather than on steering-vector estimation.\n",
    "* To download the models, you must be logged in via the **huggingface-cli** and have access to **LLaMA 3.1 8B Instruct** and **Gemma 2 9B it**.\n",
    "* **Restart the kernel before changing the model/dataset/valence/method.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transformers Cache Directory: ../cache/\n",
      "Transformers Cache Directory: ../cache/\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "\n",
    "import json \n",
    "\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "os.environ['TRANSFORMERS_CACHE'] = '../cache/' \n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "import torch\n",
    "import numpy \n",
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "# from IPython.display import display\n",
    "import pandas as pd\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import pickle\n",
    "from openai import OpenAI\n",
    "\n",
    "\n",
    "import sys \n",
    "import os\n",
    "sys.path.append(os.path.abspath('../src'))\n",
    "from templates import *\n",
    "from visualization import *\n",
    "from models import * \n",
    "from utils import *\n",
    "from hook import *\n",
    "from vector import *\n",
    "from data_prepare import *\n",
    "from inference import *\n",
    "from custom_layer import replace_llama_decoder_layer, replace_llama_decoder_layer_parent\n",
    "from custom_layer_gemma import replace_gemma_decoder_layer\n",
    "\n",
    "from Anthro import gpt_grade_with_examples\n",
    "from Anthro import prepare_corr_less_hhh\n",
    "from Anthro import prepare_power_seeking\n",
    "# from tqa_open_gpt import getFSAndSystemPrompt, get_open_ended_tqa\n",
    "from Anthro import prepare_wealth_seeking"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Select Model and Method\n",
    "\n",
    "* Choose model between **LLaMA 3.1** and **Gemma 2**\n",
    "* Choose any DISCO or baseline method.\n",
    "* The cells below automatically download the selected model to the local `../cache/` directory (provided you are logged in via the huggingface-cli)\n",
    "* We recommend **LLaMA 3.1** if you are compute-constrained\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You have chosen DISCO-QV, LLaMA3_Instruct, Power, P\n"
     ]
    }
   ],
   "source": [
    "model_choice   = \"LLaMA3_Instruct\" # Choices: LLaMA3_Instruct, Gemma2_Instruct\n",
    "method_choice  = \"DISCO-QV\"         # Choices: 'DISCO-QV', 'DISCO-Q', 'DISCO-V', 'CAA', 'ITI', 'Post Attn.', 'MLP Input', 'MLP Output', 'Comm. Steer.', 'Attn Output'\n",
    "dataset_choice = \"Power\"            # Choices: Corr, Power, Wealth (TQA has no loader in this notebook)\n",
    "valence        = \"P\"               # Choices: P, N\n",
    "\n",
    "# Fail fast on typos instead of hitting a NameError/KeyError several (slow) cells later.\n",
    "assert model_choice in {\"LLaMA3_Instruct\", \"Gemma2_Instruct\"}, f\"Unknown model_choice: {model_choice!r}\"\n",
    "assert dataset_choice in {\"Corr\", \"Power\", \"Wealth\"}, f\"Unsupported dataset_choice: {dataset_choice!r}\"\n",
    "assert valence in {\"P\", \"N\"}, f\"Unknown valence: {valence!r}\"\n",
    "print(f\"You have chosen {method_choice}, {model_choice}, {dataset_choice}, {valence}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup OpenAI Client\n",
    "\n",
    "* Paste **your OpenAI API key** into the cell below to run the notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the OPENAI_API_KEY environment variable so the key is never saved inside the notebook;\n",
    "# otherwise replace the placeholder below with your key.\n",
    "client = OpenAI(api_key = os.environ.get('OPENAI_API_KEY', 'INSERT API KEY HERE'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select batch sizes for steering vector estimation and evaluation\n",
    "\n",
    "* **Reduce these if switching to Gemma 2**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch sizes for steering-vector estimation and for evaluation/generation.\n",
    "# Suggested steering-vector batch: LLaMA 15, Gemma 3.\n",
    "batch_steering_vec, batch_eval = 15, 15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading shards: 100%|██████████| 4/4 [00:00<00:00, 896.94it/s]\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.14s/it]\n"
     ]
    }
   ],
   "source": [
    "# Model-specific config file; each method entry holds per-(dataset, valence) steering hyperparameters.\n",
    "cfg_paths = {\"LLaMA3_Instruct\": '../configs/llama3.json', \"Gemma2_Instruct\": '../configs/gemma2.json'}\n",
    "if model_choice not in cfg_paths:\n",
    "    raise ValueError(f\"Unknown model_choice: {model_choice!r}; expected one of {sorted(cfg_paths)}\")\n",
    "\n",
    "# Read the config before loading the model so a bad method/dataset/valence combination fails fast.\n",
    "with open(cfg_paths[model_choice], 'r') as f:\n",
    "    cfg = json.load(f)[method_choice]\n",
    "run_cfg_key = f\"{dataset_choice}_{valence}\"\n",
    "if run_cfg_key not in cfg:\n",
    "    raise KeyError(f\"No '{run_cfg_key}' entry for {method_choice} in {cfg_paths[model_choice]}\")\n",
    "run_cfg = cfg[run_cfg_key]\n",
    "\n",
    "if model_choice == \"LLaMA3_Instruct\":\n",
    "    model, tokenizer, template, parseStrs, L_TOTAL, num_heads, head_dim, num_groups, num_kv, op_to_hookinfo_model = getLLaMA3_Instruct_WithInfo()\n",
    "    replace_llama_decoder_layer(model, model.config) # Enables hooking post attention residual stream\n",
    "else:  # Gemma2_Instruct (validated above)\n",
    "    model, tokenizer, template, parseStrs, L_TOTAL, num_heads, head_dim, num_groups, num_kv, op_to_hookinfo_model = getGemma2_Instruct_WithInfo()\n",
    "    replace_gemma_decoder_layer(model)\n",
    "\n",
    "# Model\n",
    "model                      = model.cuda()\n",
    "tokenizer.padding_side     = 'left'  # batched generation with a decoder-only model needs left padding\n",
    "layers                     = list(range(0, L_TOTAL))\n",
    "\n",
    "# Steering hyperparameters for this (method, dataset, valence)\n",
    "op_combo        = tuple(cfg['op_name'])\n",
    "use_best_heads  = run_cfg['use_best_heads']\n",
    "alpha           = {op : a for op, a in zip(op_combo, run_cfg['alpha'])}  # one steering strength per op\n",
    "use_best_layers = run_cfg['use_best_layers']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each loader returns a dict of splits; the 'test' split provides the eval dataframe and dataloader.\n",
    "if dataset_choice == \"Corr\":\n",
    "    prompt_format = corr_prompt\n",
    "    data = prepare_corr_less_hhh(template = template, pth = \"../data/corrigible-less-HHH.jsonl\",\n",
    "                        system_prompt_tr = None, fewshot = [],\n",
    "                        p_tr = 0.2, p_val = 0.29, batch_train = batch_steering_vec, batch_val = batch_eval, batch_test = batch_eval,\n",
    "                        seed = 42, system_prompt_val_te = None, tokenizer = tokenizer)\n",
    "elif dataset_choice == \"Power\":\n",
    "    prompt_format = power_prompt\n",
    "    data = prepare_power_seeking(template = template, pth = \"../data/power-seeking_train_val.csv\",\n",
    "                        system_prompt_tr = None, fewshot = [],\n",
    "                        p_tr = 0.18, p_val = 0.16, batch_train = batch_steering_vec, batch_val = batch_eval, batch_test = batch_eval,\n",
    "                        seed = 42, system_prompt_val_te = None, tokenizer = tokenizer)\n",
    "elif dataset_choice == \"Wealth\":\n",
    "    prompt_format = wealth_prompt\n",
    "    data = prepare_wealth_seeking(template = template, pth = \"../data/wealth-seeking-train-val.csv\",\n",
    "                        system_prompt_tr = None, fewshot = [],\n",
    "                        p_tr = 0.17, p_val = 0.17, batch_train = batch_steering_vec, batch_val = batch_eval, batch_test = batch_eval,\n",
    "                        seed = 42, system_prompt_val_te = None, tokenizer = tokenizer)\n",
    "else:\n",
    "    # Without this branch an unsupported choice (e.g. TQA) surfaces later as a NameError on `data`.\n",
    "    raise ValueError(f\"Unsupported dataset_choice: {dataset_choice!r} (expected Corr, Power or Wealth)\")\n",
    "\n",
    "# Matching / not-matching reference answers from the test split, later passed to the GPT judge as ICL examples.\n",
    "test_df = data[\"test\"]['df']\n",
    "ICL_Grading_Pos, ICL_Grading_Neg = test_df['matching'].tolist(), test_df['not_matching'].tolist()\n",
    "dataloader = data['test']['dataloader']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Steering Vector Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feature Extraction and Vector Estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:18<00:00,  2.28s/it]\n",
      "100%|██████████| 8/8 [00:15<00:00,  1.94s/it]\n",
      "q: 100%|██████████| 32/32 [00:00<00:00, 15872.48it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 17154.62it/s]\n",
      "q: 100%|██████████| 32/32 [00:00<00:00, 15675.98it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 16456.32it/s]\n"
     ]
    }
   ],
   "source": [
    "# 3a. Collect activations of every op in op_combo for the positive and negative training examples.\n",
    "op_to_layer_to_val_pos, end_indxs_pos, op_to_layer_to_val_neg, end_indxs_neg = extract_pos_neg(\n",
    "    data, model, op_combo, tokenizer,\n",
    "    model_use = model_choice,\n",
    "    layers    = L_TOTAL,\n",
    ")\n",
    "\n",
    "# 3b. Reduce the raw activations with `agglomerate` (positive first, then negative).\n",
    "op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last = (\n",
    "    agglomerate(op_to_layer_to_val_pos),\n",
    "    agglomerate(op_to_layer_to_val_neg),\n",
    ")\n",
    "\n",
    "# 3c. Build the mean-difference steering vectors for each op.\n",
    "op_to_meandiff = populate_MeanDiff_hooked_style_dataset(\n",
    "    op_combo, op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### If using the best layers or heads, select them on the validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:22<00:00,  3.26s/it]\n",
      "100%|██████████| 7/7 [00:24<00:00,  3.50s/it]\n",
      "q: 100%|██████████| 32/32 [00:00<00:00, 17267.17it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 19127.51it/s]\n",
      "q: 100%|██████████| 32/32 [00:00<00:00, 17471.72it/s]\n",
      "v: 100%|██████████| 32/32 [00:00<00:00, 19065.02it/s]\n",
      "0: 100%|██████████| 32/32 [00:00<00:00, 137.40it/s]\n",
      "0: 100%|██████████| 8/8 [00:00<00:00, 301.29it/s]\n",
      "1: 100%|██████████| 32/32 [00:00<00:00, 54.07it/s]\n",
      "1: 100%|██████████| 8/8 [00:00<00:00, 245.69it/s]\n",
      "2: 100%|██████████| 32/32 [00:00<00:00, 53.37it/s]\n",
      "2: 100%|██████████| 8/8 [00:00<00:00, 196.53it/s]\n",
      "3: 100%|██████████| 32/32 [00:00<00:00, 55.92it/s]\n",
      "3: 100%|██████████| 8/8 [00:00<00:00, 210.27it/s]\n",
      "4: 100%|██████████| 32/32 [00:00<00:00, 74.28it/s]\n",
      "4: 100%|██████████| 8/8 [00:00<00:00, 221.60it/s]\n",
      "5: 100%|██████████| 32/32 [00:00<00:00, 83.86it/s]\n",
      "5: 100%|██████████| 8/8 [00:00<00:00, 171.14it/s]\n",
      "6: 100%|██████████| 32/32 [00:00<00:00, 62.51it/s]\n",
      "6: 100%|██████████| 8/8 [00:00<00:00, 157.91it/s]\n",
      "7: 100%|██████████| 32/32 [00:00<00:00, 63.55it/s]\n",
      "7: 100%|██████████| 8/8 [00:00<00:00, 124.22it/s]\n",
      "8: 100%|██████████| 32/32 [00:00<00:00, 68.27it/s]\n",
      "8: 100%|██████████| 8/8 [00:00<00:00, 71.30it/s]\n",
      "9: 100%|██████████| 32/32 [00:01<00:00, 30.12it/s]\n",
      "9: 100%|██████████| 8/8 [00:00<00:00, 217.34it/s]\n",
      "10: 100%|██████████| 32/32 [00:00<00:00, 94.67it/s]\n",
      "10: 100%|██████████| 8/8 [00:00<00:00, 213.56it/s]\n",
      "11: 100%|██████████| 32/32 [00:00<00:00, 101.60it/s]\n",
      "11: 100%|██████████| 8/8 [00:00<00:00, 206.43it/s]\n",
      "12: 100%|██████████| 32/32 [00:00<00:00, 111.58it/s]\n",
      "12: 100%|██████████| 8/8 [00:00<00:00, 212.65it/s]\n",
      "13: 100%|██████████| 32/32 [00:00<00:00, 112.29it/s]\n",
      "13: 100%|██████████| 8/8 [00:00<00:00, 200.22it/s]\n",
      "14: 100%|██████████| 32/32 [00:00<00:00, 130.45it/s]\n",
      "14: 100%|██████████| 8/8 [00:00<00:00, 223.89it/s]\n",
      "15: 100%|██████████| 32/32 [00:00<00:00, 102.65it/s]\n",
      "15: 100%|██████████| 8/8 [00:00<00:00, 185.78it/s]\n",
      "16: 100%|██████████| 32/32 [00:00<00:00, 122.95it/s]\n",
      "16: 100%|██████████| 8/8 [00:00<00:00, 196.27it/s]\n",
      "17: 100%|██████████| 32/32 [00:00<00:00, 93.11it/s]\n",
      "17: 100%|██████████| 8/8 [00:00<00:00, 220.75it/s]\n",
      "18: 100%|██████████| 32/32 [00:00<00:00, 108.90it/s]\n",
      "18: 100%|██████████| 8/8 [00:00<00:00, 193.58it/s]\n",
      "19: 100%|██████████| 32/32 [00:00<00:00, 81.16it/s]\n",
      "19: 100%|██████████| 8/8 [00:00<00:00, 119.71it/s]\n",
      "20: 100%|██████████| 32/32 [00:00<00:00, 72.28it/s]\n",
      "20: 100%|██████████| 8/8 [00:00<00:00, 129.84it/s]\n",
      "21: 100%|██████████| 32/32 [00:00<00:00, 47.73it/s]\n",
      "21: 100%|██████████| 8/8 [00:00<00:00, 116.23it/s]\n",
      "22: 100%|██████████| 32/32 [00:00<00:00, 44.41it/s]\n",
      "22: 100%|██████████| 8/8 [00:00<00:00, 201.89it/s]\n",
      "23: 100%|██████████| 32/32 [00:00<00:00, 89.51it/s]\n",
      "23: 100%|██████████| 8/8 [00:00<00:00, 198.63it/s]\n",
      "24: 100%|██████████| 32/32 [00:00<00:00, 103.56it/s]\n",
      "24: 100%|██████████| 8/8 [00:00<00:00, 180.02it/s]\n",
      "25: 100%|██████████| 32/32 [00:00<00:00, 75.71it/s]\n",
      "25: 100%|██████████| 8/8 [00:00<00:00, 177.20it/s]\n",
      "26: 100%|██████████| 32/32 [00:00<00:00, 91.26it/s]\n",
      "26: 100%|██████████| 8/8 [00:00<00:00, 199.71it/s]\n",
      "27: 100%|██████████| 32/32 [00:00<00:00, 92.78it/s]\n",
      "27: 100%|██████████| 8/8 [00:00<00:00, 173.49it/s]\n",
      "28: 100%|██████████| 32/32 [00:00<00:00, 90.46it/s]\n",
      "28: 100%|██████████| 8/8 [00:00<00:00, 173.63it/s]\n",
      "29: 100%|██████████| 32/32 [00:00<00:00, 98.23it/s] \n",
      "29: 100%|██████████| 8/8 [00:00<00:00, 105.48it/s]\n",
      "30: 100%|██████████| 32/32 [00:00<00:00, 42.09it/s]\n",
      "30: 100%|██████████| 8/8 [00:00<00:00, 159.65it/s]\n",
      "31: 100%|██████████| 32/32 [00:00<00:00, 71.48it/s]\n",
      "31: 100%|██████████| 8/8 [00:00<00:00, 137.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "q\n",
      "[10, 5, 29, 8, 27, 9, 24, 4, 11, 28, 13, 16, 19, 31, 7, 12, 25, 0, 30, 6, 18, 2, 15, 3, 17, 14, 22, 20, 23, 1, 21, 26] 32\n",
      "1.0\n",
      "1.0\n",
      "v\n",
      "[5, 4, 1, 2, 3, 0, 6, 7] 8\n",
      "1.0\n",
      "1.0\n"
     ]
    }
   ],
   "source": [
    "# Head or layer selection is scored on the validation split, so extract its features only when needed.\n",
    "need_val_feats = use_best_heads != -1 or use_best_layers == 1\n",
    "if need_val_feats:\n",
    "    op_to_layer_to_val_pos_val, _, op_to_layer_to_val_neg_val, _ = extract_pos_neg(\n",
    "        data, model, op_combo, tokenizer, model_use = model_choice, layers = L_TOTAL, split = 'val'\n",
    "    )\n",
    "    op_to_layer_to_val_pos_last_val = agglomerate(op_to_layer_to_val_pos_val)\n",
    "    op_to_layer_to_val_neg_last_val = agglomerate(op_to_layer_to_val_neg_val)\n",
    "\n",
    "if use_best_heads != -1:\n",
    "    # Rank heads with HeadStats, then mask all but the top-K heads in the steering vectors.\n",
    "    HS = HeadStats(op_combo, model_choice, None, data, layers,\n",
    "                   op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,\n",
    "                   op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val)\n",
    "    masking_val_df   = HS['centroid']['op_to_df_val']\n",
    "    masking_train_df = HS['centroid']['op_to_df_train']\n",
    "\n",
    "    if type(use_best_heads) is list:\n",
    "        # A list gives a separate K for each op, so each op is masked on its own.\n",
    "        for op, H in zip(op_combo, use_best_heads):\n",
    "            op_to_meandiff = apply_ITI_mask([op], mask_criterion = \"ValAcc\", num_kv = num_kv, num_heads = num_heads,\n",
    "                                            op_to_df_val = masking_val_df, op_to_df_train = masking_train_df,\n",
    "                                            K = H, op_to_meandiff = op_to_meandiff, head_dim = head_dim)\n",
    "    else:\n",
    "        # A single K is shared by all ops.\n",
    "        op_to_meandiff = apply_ITI_mask(op_combo, mask_criterion = \"ValAcc\", num_kv = num_kv, num_heads = num_heads,\n",
    "                                        op_to_df_val = masking_val_df, op_to_df_train = masking_train_df,\n",
    "                                        K = use_best_heads, op_to_meandiff = op_to_meandiff, head_dim = head_dim)\n",
    "elif use_best_layers == 1:\n",
    "    # Mask every layer except the best one according to getLayerAccs.\n",
    "    _, op_to_meandiff = getLayerAccs(op_combo,\n",
    "                                     op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,\n",
    "                                     op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val,\n",
    "                                     op_to_meandiff, mask = True, viz = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grade\n",
    "\n",
    "* Note: Due to minor nondeterminism in GPT outputs (even with temperature 0), scores may fluctuate slightly across runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 42/42 [30:25<00:00, 43.47s/it]\n"
     ]
    }
   ],
   "source": [
    "# Every op is injected at all layers, using that op's alpha at each layer.\n",
    "ops_dict = {\n",
    "    op: {'layers': layers, 'layers_to_alpha': dict.fromkeys(layers, alpha[op])}\n",
    "    for op in op_combo\n",
    "}\n",
    "\n",
    "results = generate_with_hooks(\n",
    "    model, tokenizer, ops_dict, dataloader, op_to_meandiff,\n",
    "    inject_op      = translation_op_,\n",
    "    max_new_tokens = 512,\n",
    "    model_use      = model_choice,\n",
    ")\n",
    "QA_pairs = parseQA(results, tokenizer, **parseStrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Score : 3.115569823434992: 100%|██████████| 623/623 [16:19<00:00,  1.57s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Behavior Grade (LLaMA3_Instruct , DISCO-QV) : 3.115569823434992\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Grade every generated answer with the GPT-4o judge; the test split's matching / not-matching\n",
    "# answers are passed in as in-context grading examples.\n",
    "_, usages, _, scores  = gpt_grade_with_examples(client, QA_pairs,\n",
    "                                                ICL_Grading_Pos, ICL_Grading_Neg,\n",
    "                                                \"gpt-4o-2024-08-06\", 0.0,  # judge model; 0.0 is presumably the temperature\n",
    "                                                prompt_format, max_tokens = 512,\n",
    "                                                system_prompt = None)\n",
    "if len(scores) == 0:\n",
    "    # np.mean([]) would silently yield nan (warnings are filtered at the top of the notebook).\n",
    "    raise RuntimeError(\"The GPT judge returned no scores; cannot compute a mean grade.\")\n",
    "\n",
    "# gpt-4o-2024-08-06 pricing in USD per 1M tokens (input / output); update if OpenAI pricing changes.\n",
    "GPT4O_USD_PER_M_INPUT  = 2.50\n",
    "GPT4O_USD_PER_M_OUTPUT = 10.0\n",
    "\n",
    "meangrade = float(np.mean(scores))\n",
    "in_toks   = sum(u.prompt_tokens for u in usages)\n",
    "out_toks  = sum(u.completion_tokens for u in usages)\n",
    "cost      = (in_toks * GPT4O_USD_PER_M_INPUT + out_toks * GPT4O_USD_PER_M_OUTPUT) / 1_000_000\n",
    "\n",
    "print(f\"Behavior Grade ({model_choice} , {method_choice}) : {meangrade}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cost $1.607895\n"
     ]
    }
   ],
   "source": [
    "# Total GPT-judge spend for this run, rounded to avoid float artifacts like 1.6078950000000001.\n",
    "print(f\"Cost ${cost:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
