{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8027bd15-7a46-4690-845a-97bab1519022",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constantSetting center_unembed=False instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cuda\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a44350de742b4a7d9d1f686659f408d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/452 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "302a9869b48a4eb7a0b0e37c3a189dfd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/1.78k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0cf169c7c13e41968c0e55f8a5837d2c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d59b15ce423a4a2181c5f06d42553ebf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8d698e2f2f4349c49b83ebc67121d4bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00004.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76250deed5314a20abb1eab6ea0b23dc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00003-of-00004.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d063d099fc804e48b8d6a7e3cc97c8d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00004-of-00004.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "05b1566bc93647538d24249f5768bc1d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b755b0612f4e400d888d76e0c65b8472",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:You are not using LayerNorm, so the writing weights can't be centered! Skipping\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gemma-2-9b-it into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('')\n",
    "from SAE.util.steering import *\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "device = set_up()\n",
    "model_name = \"gemma-2-9b-it\"\n",
    "sae_name = \"gemma-scope-9b-it-res-canonical\"\n",
    "sae_id = \"layer_31/width_131k/canonical\"\n",
    "model, sae = load_model(model_name, sae_name, sae_id, device)\n",
    "tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-9b-it')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d454e2a3-229c-4d04-87f9-5284b56ccc0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from util.option_dict_4 import *\n",
    "from util.prompts import  get_prompt\n",
    "from util.lm_format import apply_format\n",
    "from util.steering import *\n",
    "\n",
    "steering_vectors = [sae.W_dec[30812],sae.W_dec[38950]]\n",
    "coeff = 1500\n",
    "\n",
    "seed = 16\n",
    "layer = 31\n",
    "data=json.load(open(f\"../../data/TRAIT/TRAIT_Dark.json\", encoding='utf-8'))\n",
    "sample = data[1234]\n",
    "instruction=sample[\"situation\"]+\" \"+sample[\"query\"]\n",
    "response_high1=sample[\"response_high1\"]\n",
    "response_high2=sample[\"response_high2\"]\n",
    "response_low1=sample[\"response_low1\"]\n",
    "response_low2=sample[\"response_low2\"]\n",
    "\n",
    "prompt=get_prompt(1, False, instruction, response_high1, response_high2, response_low1, response_low2)\n",
    "encoded=apply_format(prompt, \"chat\", tokenizer)\n",
    "option_tokens=get_option_token(\"ABCD\")\n",
    "steering_on = True\n",
    "likelihoods = get_likelihood_steer(encoded, model, layer, coeff, steering_vectors, True, seed=seed).squeeze().tolist()\n",
    "vocab_probabilities={}\n",
    "for token in option_tokens:\n",
    "    vocab_probabilities[token]=likelihoods[tokenizer.convert_tokens_to_ids(token)]\n",
    "vocab_probabilities = dict(sorted(vocab_probabilities.items(), key=lambda item: item[1], reverse=True))\n",
    "vocab_probabilities = {k: vocab_probabilities[k] for k in list(vocab_probabilities)[:10]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d43f2ec7-c220-4780-a7b0-adb91b35c8b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'D': 0.003367143450304866,\n",
       " 'C': 0.0002600759326014668,\n",
       " 'A': 9.611510904505849e-05,\n",
       " 'B': 2.327394759049639e-05}"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab_probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33722ba2-f040-42aa-a48c-551a66635b6e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
