{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8027bd15-7a46-4690-845a-97bab1519022",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cuda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n",
      "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n",
      "`config.hidden_activation` if you want to override this behaviour.\n",
      "See https://github.com/huggingface/transformers/pull/29402 for more details.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a55f47dd7bf14c228617049c8c238ddb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:You are not using LayerNorm, so the writing weights can't be centered! Skipping\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gemma-2b-it into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('')\n",
    "from SAE.util.steering import *\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "device = set_up()\n",
    "model_name = \"gemma-2b-it\"\n",
    "sae_name = \"gemma-2b-it-res-jb\"\n",
    "sae_id = \"blocks.12.hook_resid_post\"\n",
    "model, sae = load_model(model_name, sae_name, sae_id, device)\n",
    "tokenizer = AutoTokenizer.from_pretrained('google/gemma-2b-it')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1bf4b433-4e16-4e0a-8903-71151c363294",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from util.option_dict_4 import *\n",
    "from util.prompts import  get_prompt\n",
    "from util.lm_format import apply_format\n",
    "from util.steering import *\n",
    "\n",
    "steering_vectors = [sae.W_dec[5060]]\n",
    "coeff = 100\n",
    "\n",
    "seed = 16\n",
    "layer = 12\n",
    "data=json.load(open(f\"../../data/TRAIT/TRAIT_Dark.json\", encoding='utf-8'))\n",
    "sample = data[1234]\n",
    "instruction=sample[\"situation\"]+\" \"+sample[\"query\"]\n",
    "response_high1=sample[\"response_high1\"]\n",
    "response_high2=sample[\"response_high2\"]\n",
    "response_low1=sample[\"response_low1\"]\n",
    "response_low2=sample[\"response_low2\"]\n",
    "\n",
    "prompt=get_prompt(1, False, instruction, response_high1, response_high2, response_low1, response_low2)\n",
    "encoded=apply_format(prompt, \"chat\", tokenizer)\n",
    "option_tokens=get_option_token(\"ABCD\")\n",
    "steering_on = True\n",
    "likelihoods = get_likelihood_steer(encoded, model, layer, coeff, steering_vectors, True, seed=seed).squeeze().tolist()\n",
    "vocab_probabilities={}\n",
    "for token in option_tokens:\n",
    "    vocab_probabilities[token]=likelihoods[tokenizer.convert_tokens_to_ids(token)]\n",
    "vocab_probabilities = dict(sorted(vocab_probabilities.items(), key=lambda item: item[1], reverse=True))\n",
    "vocab_probabilities = {k: vocab_probabilities[k] for k in list(vocab_probabilities)[:10]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d43f2ec7-c220-4780-a7b0-adb91b35c8b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'C': 0.003120406763628125,\n",
       " 'A': 0.0015889368951320648,\n",
       " 'D': 8.822359632176813e-06,\n",
       " 'B': 2.2841149984742515e-06}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab_probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a525590-cdfe-401a-97b4-2573387cf4e0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
